Psst... there's a new poll here.
Psst... there are new forums here.
Microsoft is blocking us again (thanks, IP Reputation!), so just use OAuth login instead. :)
Paste
Pasted as Python by Margaret (6 years ago)
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
from parlai.core.teachers import (
FixedDialogTeacher,
DialogTeacher,
create_task_agent_from_taskname,
)
from .build import build
import json
import random
import os
import copy
class MultiTurnTeacher(FixedDialogTeacher):
    """Generation teacher built from the ``dialogue_safety:multiturn`` task.

    Every nested example's text is split on newlines: all lines but the last
    become the context (``text``) and the last line becomes the target
    (``labels``). The nested example's first label is kept under ``safe`` and
    mapped to ``reward`` -1 when it is ``__notok__`` and +1 otherwise.
    Examples can be filtered by that label (``--ok``) and the context can be
    prefixed with a control token (``--add_control`` / ``--force_control``).
    """

    @staticmethod
    def add_cmdline_args(parser):
        """Register this teacher's command-line options on ``parser``."""
        parser.add_argument(
            '--ok',
            type=str,
            choices=['__ok__', '__notok__', 'all'],
            default='__notok__',
            help='keep only examples with this safety label; "all" keeps every example',
        )
        parser.add_argument(
            '--add_control',
            type='bool',
            choices=[True, False],
            default=False,
            help="prefix each context with the example's own safety label",
        )
        parser.add_argument(
            '--force_control',
            type=str,
            choices=['__notok__', '', '__ok__'],
            default='',
            help='prefix every context with this token (overrides --add_control)',
        )

    def __init__(self, opt, shared=None):
        build(opt)
        super().__init__(opt, shared)
        if shared is None:
            # Load the multiturn safety task inside this task. Streaming is
            # dropped because flatten() reads the whole nested task up front.
            nested_opt = {
                'datapath': opt['datapath'],
                'datatype': self._strip_stream(opt['datatype']),
                'task': 'dialogue_safety:multiturn',
                'single_turn': False,
            }
            teacher = create_task_agent_from_taskname(nested_opt)[0]
            teacher.reset()
            self.teacher_data = self.flatten(teacher, opt)
        else:
            # Copies created through share() reuse the flattened data.
            self.teacher_data = shared['teacher_data']
        self.reset()

    @staticmethod
    def _strip_stream(datatype):
        """Return ``datatype`` with every ``stream`` component removed.

        e.g. ``'train:stream'`` -> ``'train'``,
        ``'valid:stream'`` -> ``'valid'``, ``'test'`` -> ``'test'``.
        """
        return ':'.join(part for part in datatype.split(':') if part != 'stream')

    @staticmethod
    def _get_labels(act):
        """Return ``act['labels']``, falling back to ``act['eval_labels']``.

        Returns ``None`` when neither key holds a value (a key that is present
        but set to ``None`` also falls through to the next one).
        """
        labels = act.get('labels')
        if labels is None:
            labels = act.get('eval_labels')
        return labels

    def share(self):
        """Share the flattened data with copies of this teacher."""
        shared = super().share()
        shared['teacher_data'] = self.teacher_data
        return shared

    def flatten(self, teacher, opt):
        """Read every example from ``teacher`` into a list of flat episodes.

        :param teacher: the nested ``dialogue_safety:multiturn`` teacher.
        :param opt: options dict; ``ok`` and ``add_control`` are read from it
            (``force_control`` is read from ``self.opt``).
        :return: list of single-example episodes with ``id``, ``text``
            (context), ``labels`` (final line), ``episode_done``, ``safe``
            and ``reward`` keys.
        """
        label_filter = opt['ok']
        force_control = self.opt['force_control']
        examples = []
        for _ in range(teacher.num_examples()):
            # act() must run for every example to advance the nested teacher,
            # even for examples that are filtered out below.
            a = teacher.act()
            labels = self._get_labels(a)
            if not labels:
                # Without a safety label the example can be neither filtered
                # nor rewarded; skip it instead of crashing.
                continue
            safe = labels[0]
            if label_filter != 'all' and safe != label_filter:
                continue

            if force_control != '':
                control = force_control
            elif opt['add_control']:
                control = safe
            else:
                control = ''

            turns = a['text'].split('\n')
            context = '\n'.join(turns[:-1])
            target = turns[-1]
            if control:
                # The control token conditions the context; it must never
                # end up in the target, even for single-line examples.
                context = control + ' ' + context if context else control

            examples.append(
                {
                    'id': 'generation_safety',
                    'text': context,
                    'labels': [target],
                    'episode_done': True,
                    'safe': safe,
                    'reward': -1 if safe == '__notok__' else 1,
                }
            )
        return examples

    def num_examples(self):
        """Return the number of flattened examples."""
        return len(self.teacher_data)

    def num_episodes(self):
        """Return the number of episodes (one example per episode)."""
        return len(self.teacher_data)

    def get(self, episode_idx, entry_idx=None):
        """Return the flattened example for ``episode_idx``.

        ``entry_idx`` is ignored because every episode has exactly one entry.
        """
        return self.teacher_data[episode_idx]
class DefaultTeacher(FixedDialogTeacher):
    """Pair dialogue contexts with utterances drawn from safety datasets.

    The context (``text``) of each example comes from ``--dialogue-task``,
    flattened so that every example carries the episode history so far. The
    target (``labels``) and its safety label (``safe``) come from
    ``--safety-task``, indexed cyclically. ``reward`` is -1 for ``__notok__``
    and +1 otherwise. The context can be prefixed with a control token
    (``--add_control`` / ``--force_control``).
    """

    @staticmethod
    def add_cmdline_args(parser):
        """Register this teacher's command-line options on ``parser``."""
        parser = parser.add_argument_group('Multiturn Safety Teacher Args')
        parser.add_argument(
            '--safety-task',
            type=str,
            default='dialogue_safety:standard,dialogue_safety:adversarial',
            help='task(s) supplying target utterances and safety labels, '
            'e.g. dialogue_safety:WikiToxicComments',
        )
        parser.add_argument(
            '--dialogue-task',
            type=str,
            default='convai2:none',
            help='task supplying the dialogue contexts',
        )
        parser.add_argument(
            '--ok',
            type=str,
            choices=['__ok__', '__notok__', 'all'],
            default='__notok__',
            help='keep only safety examples with this label; "all" keeps every example',
        )
        parser.add_argument(
            '--add_control',
            type='bool',
            choices=[True, False],
            default=False,
            help="prefix each context with the target's safety label",
        )
        parser.add_argument(
            '--force_control',
            type=str,
            choices=['__notok__', '', '__ok__'],
            default='',
            help='prefix every context with this token (overrides --add_control)',
        )

    def __init__(self, opt, shared=None):
        build(opt)
        super().__init__(opt, shared)
        if shared is None:
            # Load the dialogue task inside this task. Streaming is dropped
            # because flatten() reads the whole nested task up front.
            dialogue_opt = {
                'datapath': opt['datapath'],
                'datatype': self._strip_stream(opt['datatype']),
                'task': opt['dialogue_task'],
            }
            teacher = create_task_agent_from_taskname(dialogue_opt)[0]
            teacher.reset()
            self.teacher_data = self.flatten(teacher)

            safety_opt = {
                'datapath': opt['datapath'],
                'datatype': opt['datatype'],
                'task': opt['safety_task'],
                'round': 3,
                'use_test_set': False,
                'round_only': False,
                'balance_data': False,
            }
            safety = create_task_agent_from_taskname(safety_opt)[0]
            safety.reset()
            self.safety_data = self.safety_flatten(safety, opt)
        else:
            # Copies created through share() reuse the flattened data.
            self.teacher_data = shared['teacher_data']
            self.safety_data = shared['safety_data']
        self.reset()

    @staticmethod
    def _strip_stream(datatype):
        """Return ``datatype`` with every ``stream`` component removed.

        e.g. ``'train:stream'`` -> ``'train'``, ``'test'`` -> ``'test'``.
        (``str.rstrip(':stream')`` is not equivalent: it strips a character
        set and would turn ``'test'`` into ``''``.)
        """
        return ':'.join(part for part in datatype.split(':') if part != 'stream')

    @staticmethod
    def _get_labels(act):
        """Return ``act['labels']``, falling back to ``act['eval_labels']``.

        Returns ``None`` when neither key holds a value (a key that is present
        but set to ``None`` also falls through to the next one).
        """
        labels = act.get('labels')
        if labels is None:
            labels = act.get('eval_labels')
        return labels

    def share(self):
        """Share both flattened datasets with copies of this teacher."""
        shared = super().share()
        shared['teacher_data'] = self.teacher_data
        shared['safety_data'] = self.safety_data
        return shared

    def flatten(self, teacher):
        """Turn each dialogue turn into a one-example episode with history.

        The ``text`` of each produced example is the newline-joined history of
        the current episode (previous texts and first labels) plus the current
        text. Examples without text or labels only contribute what they have
        to the history. The history is cleared whenever the nested act has
        ``episode_done`` set.

        :param teacher: the nested dialogue teacher.
        :return: list of dicts with ``text``, ``labels`` and ``episode_done``.
        """
        examples = []
        history = []
        for _ in range(teacher.num_examples()):
            a = teacher.act()
            text = a.get('text')
            if text is not None:
                history.append(text)
                labels = self._get_labels(a)
                if labels:
                    examples.append(
                        {
                            'text': '\n'.join(history),
                            'labels': labels,
                            'episode_done': True,
                        }
                    )
                    history.append(labels[0])
            # Reset on every episode boundary, including for acts that were
            # skipped above, so context never leaks into the next episode.
            if a.get('episode_done'):
                history = []
        return examples

    def safety_flatten(self, teacher, opt):
        """Collect safety utterances (with in-episode history) as targets.

        Each produced example has ``text`` ``'_silence_'``, ``labels`` holding
        the newline-joined texts of the current safety episode so far, and
        ``safe`` holding the nested example's first label. Only examples whose
        label matches ``opt['ok']`` are kept (all of them for ``'all'``).

        :param teacher: the nested safety teacher.
        :param opt: options dict; ``ok`` is read from it.
        :return: list of dicts with ``text``, ``labels``, ``safe`` and
            ``episode_done`` keys.
        """
        label_filter = opt['ok']
        examples = []
        history = []
        for _ in range(teacher.num_examples()):
            a = teacher.act()
            history.append(a['text'])
            utterance = '\n'.join(history)
            labels = self._get_labels(a)
            if a.get('episode_done'):
                history = []
            if not labels:
                # Nothing to filter or reward by; skip instead of crashing.
                continue
            safe = labels[0]
            if label_filter == 'all' or safe == label_filter:
                examples.append(
                    {
                        'text': '_silence_',
                        'labels': [utterance],
                        'safe': safe,
                        'episode_done': True,
                    }
                )
        return examples

    def num_examples(self):
        """Return the number of dialogue examples."""
        return len(self.teacher_data)

    def num_episodes(self):
        """Return the number of episodes (one example per episode)."""
        return len(self.teacher_data)

    def get(self, episode_idx, entry_idx=None):
        """Build the example for ``episode_idx``.

        The context comes from the dialogue data at ``episode_idx``; the
        target and safety label come from the safety data, indexed modulo its
        length. ``entry_idx`` is ignored.

        :raises RuntimeError: if no safety examples survived ``--ok``.
        """
        if not self.safety_data:
            raise RuntimeError(
                'No safety examples available; check --safety-task and --ok.'
            )
        dialogue = self.teacher_data[episode_idx]
        safety = self.safety_data[episode_idx % len(self.safety_data)]
        reward = -1 if safety['safe'] == '__notok__' else 1
        action = {
            'id': 'generation_safety',
            'text': dialogue['text'],
            'labels': safety['labels'],
            'episode_done': dialogue['episode_done'],
            'safe': safety['safe'],
            'reward': reward,
        }
        force_control = self.opt['force_control']
        if force_control != '':
            action['text'] = force_control + ' ' + action['text']
        elif self.opt['add_control']:
            control = '__ok__' if reward == 1 else '__notok__'
            action['text'] = control + ' ' + action['text']
        return action
Revise this Paste