diff --git a/Augmentation.py b/Augmentation.py new file mode 100644 index 0000000..99ea548 --- /dev/null +++ b/Augmentation.py @@ -0,0 +1,617 @@ +import random +from nltk import Tree +from tqdm import tqdm +import pandas as pd +import argparse +import numpy as np +import os +from datasets import load_dataset, Dataset, concatenate_datasets +from process_data import settings +from multiprocessing import Pool,cpu_count + + +def modify(commands, debug=0): + commands = commands.split(' ') + verb = ['look', 'jump', 'walk', 'turn', 'run'] + end_sign = ['left', 'right', 'twice', 'thrice'] + add_pos = [] + if debug: + print(commands) + for i in range(len(commands)-1): + if commands[i] in end_sign and commands[i+1] in verb: + add_pos.append(i+1) + # commands.insert(i+1,'and') + if debug: + print(commands) + if commands[i] in verb and commands[i+1] in verb: + add_pos.append(i+1) + # commands.insert(i+1,'and') + if debug: + print(commands) + for i, pos in enumerate(add_pos): + commands.insert(pos+i, 'and') + if debug: + print(commands) + return ' '.join(commands) + + +def c2a(commands, debug=0): + verb = {'look': 'I_LOOK', 'walk': 'I_WALK', + 'run': 'I_RUN', 'jump': 'I_JUMP'} + direction = {'left': 'I_TURN_LEFT', 'right': 'I_TURN_RIGHT'} + times = {'twice': 2, 'thrice': 3} + conjunction = ['and', 'after'] + + commands = commands.split(' ') + actions = [] + previous_command = [] + pre_actions = [] + flag = 0 + i = 0 + if debug: + print('raw:', commands) + while len(commands) > 0: + current = commands.pop(0) + if debug: + print('-'*50) + print('step ', i) + i += 1 + print('current command:', current, len(commands)) + print('curret waiting commands list:', previous_command) + print('already actions:', actions) + print('previous waiting actions:', pre_actions) + if current in verb.keys() or current == 'turn' or current in conjunction: # add new actions + if current == 'and': + continue + if not previous_command: # initialization + previous_command.append(current) + + else: # one conmands over + if debug: + print('##### one commands over#####') + current_action = translate(previous_command) + previous_command = [] + if debug: + print( + '****got new action from previous commandsa list:{}****'.format(current_action[0])) + + if current == 'after': + pre_actions.extend(current_action) + if debug: + print('****this action into pre_actions****') + elif pre_actions: + if debug: + print( + '****pre_actions and current_actions into action list****') + actions.extend(current_action) + actions.extend(pre_actions) + pre_actions = [] + previous_command.append(current) + else: + # current is a verb + previous_command.append(current) + actions.extend(current_action) + else: + previous_command.append(current) + if previous_command: + current_action = translate(previous_command) + actions.extend(current_action) + if pre_actions: + actions.extend(pre_actions) + if debug: + print('-'*50) + print('over') + print('previous_command', previous_command) + print('pre_actions', pre_actions) + print('current action', current_action) + return actions + + +def translate(previous_command): + verb = {'look': 'I_LOOK', 'walk': 'I_WALK', + 'run': 'I_RUN', 'jump': 'I_JUMP'} + direction = {'left': 'I_TURN_LEFT', 'right': 'I_TURN_RIGHT'} + times = {'twice': 2, 'thrice': 3} + conjunction = ['and', 'after'] + if previous_command[-1] in times.keys(): + return translate(previous_command[:-1])*times[previous_command[-1]] + if len(previous_command) == 1: + return [verb[previous_command[0]]] + elif len(previous_command) == 2: + if 
previous_command[0] == 'turn': + return [direction[previous_command[1]]] + elif previous_command[1] in direction: + return [direction[previous_command[1]], verb[previous_command[0]]] + elif len(previous_command) == 3: + if previous_command[0] == 'turn': + if previous_command[1] == 'opposite': + return [direction[previous_command[2]]]*2 + else: + return [direction[previous_command[2]]]*4 + elif previous_command[0] in verb.keys(): + if previous_command[1] == 'opposite': + return [direction[previous_command[2]], direction[previous_command[2]], verb[previous_command[0]]] + else: + return [direction[previous_command[2]], verb[previous_command[0]]]*4 +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) +def subtree_exchange_scan(args,parsing1,parsing2): + new_sentence=None + try: + if args.debug: + print('check5') + t1 = Tree.fromstring(parsing1) + t2 = Tree.fromstring(parsing2) + t1_len=len(t1.leaves()) + t2_len=len(t2.leaves()) + # ----- restrict label-------------- + # candidate_subtree1=list(t1.subtrees(lambda t: t.label() in ['VP','VB'])) + # candidate_subtree2 = list(t2.subtrees(lambda t: t.label() in ['VP', 'VB'])) + # tree_labels1 = [tree.label() for tree in candidate_subtree1] + # tree_labels2 = [tree.label() for tree in candidate_subtree2] + # same_labels = list(set(tree_labels1) & set(tree_labels2)) + # if not same_labels: + # if args.debug: + # print('no same label') + # return None + # select_label=random.choice(same_labels) + # candidate1 = random.choice( + # [t for t in candidate_subtree1 if t.label() == select_label]) + # candidate2 = random.choice( + # [t for t in candidate_subtree2 if t.label() == select_label]) + candidate_subtree1 = list(t1.subtrees()) + candidate_subtree2 = list(t2.subtrees()) + candidate1 = random.choice( + [t for t in candidate_subtree1]) + candidate2 = random.choice( + [t for t in candidate_subtree2]) + exchanged_span = ' '.join(candidate1.leaves()) + exchanging_span = ' '.join(candidate2.leaves()) + original_sentence = ' '.join(t1.leaves()) + new_sentence = original_sentence.replace(exchanged_span, exchanging_span) + debug=0 + if args.debug: + print('check6') + print(new_sentence) + debug=1 + modified_sentence=modify(new_sentence,debug) + new_label=c2a(modified_sentence,debug) + if args.showinfo: + print('cand1:', ' '.join(candidate1.leaves()), + 'cand2:', ' '.join(candidate2.leaves())) + # print([' '.join(c.leaves()) for c in cand1]) + # print([' '.join(c.leaves()) for c in cand2]) + print('src1:', parsing1) + print('src2:', parsing2) + print('new:',new_sentence) + return modified_sentence,new_label + except Exception as e: + if args.debug: + print('Error!!') + print(e) + return None +def subtree_exchange_single(args,parsing1,label1,parsing2,label2,lam1,lam2): + """ + For a pair sentence, exchange subtree and return a label based on subtree length + + Find the candidate subtree, and extract correspoding span, and exchange span + + """ + if args.debug: + print('check4') + assert lam1>lam2 + t1=Tree.fromstring(parsing1) + original_sentence=' '.join(t1.leaves()) + t1_len=len(t1.leaves()) + candidate_subtree1=list(t1.subtrees(lambda t: lam1>len(t.leaves())/t1_len>lam2)) + t2=Tree.fromstring(parsing2) + candidate_subtree2=list(t2.subtrees(lambda t: lam1>len(t.leaves())/t1_len>lam2)) + if args.debug: + print('check5') + # print('subtree1:',len(candidate_subtree1),'\nsubtree2:',len(candidate_subtree2)) + if len(candidate_subtree1)==0 or len(candidate_subtree2)==0: + + return None + if args.debug: + print('check6') + if args.phrase_label: + if 
args.debug: + print('phrase_label') + tree_labels1=[tree.label() for tree in candidate_subtree1] + tree_labels2=[tree.label() for tree in candidate_subtree2] + same_labels=list(set(tree_labels1)&set(tree_labels2)) + if not same_labels: + # print('无相同类型的子树') + return None + if args.phrase_length: + if args.debug: + print('phrase_lable_length') + candidate=[(t1,t2) for t1 in candidate_subtree1 for t2 in candidate_subtree2 if len(t1.leaves())==len(t2.leaves()) and t1.label()==t2.label()] + candidate1,candidate2= random.choice(candidate) + else: + if args.debug: + print('phrase_lable') + select_label=random.choice(same_labels) + candidate1=random.choice([t for t in candidate_subtree1 if t.label()==select_label]) + candidate2=random.choice([t for t in candidate_subtree2 if t.label()==select_label]) + else: + if args.debug: + print('no phrase_label') + if args.phrase_length: + if args.debug: + print('phrase_length') + candidate=[(t1,t2) for t1 in candidate_subtree1 for t2 in candidate_subtree2 if len(t1.leaves())==len(t2.leaves())] + candidate1,candidate2= random.choice(candidate) + else: + if args.debug: + print('normal TreeMix') + candidate1=random.choice(candidate_subtree1) + candidate2=random.choice(candidate_subtree2) + + exchanged_span=' '.join(candidate1.leaves()) + exchanged_len=len(candidate1.leaves()) + exchanging_span=' '.join(candidate2.leaves()) + new_sentence=original_sentence.replace(exchanged_span,exchanging_span) + # if args.mixup_cross: + new_label=np.zeros(len(args.label_list)) + + exchanging_len=len(candidate2.leaves()) + new_len=t1_len-exchanged_len+exchanging_len + + new_label[int(label2)]+=exchanging_len/new_len + new_label[int(label1)]+=(new_len-exchanging_len)/new_len + + # else: + # new_label=label1 + if args.showinfo: + # print('树1 {}'.format(t1)) + # print('树2 {}'.format(t2)) + print('-'*50) + print('candidate1:{}'.format([' '.join(x.leaves()) for x in candidate_subtree1])) + print('candidate2:{}'.format([' '.join(x.leaves()) for x in candidate_subtree2])) + print('sentence1 ## {} [{}]\nsentence2 ## {} [{}]'.format(original_sentence,label1,' '.join(t2.leaves()),label2)) + print('{} <=========== {}'.format(exchanged_span,exchanging_span)) + print('new sentence: ## {}'.format(new_sentence)) + print('new label:[{}]'.format(new_label)) + return new_sentence,new_label +def subtree_exchange_pair(args,parsing11,parsing12,label1,parsing21,parsing22,label2,lam1,lam2): + """ + For a pair sentence, exchange subtree and return a label based on subtree length + + Find the candidate subtree, and extract correspoding span, and exchange span + + """ + assert lam1>lam2 + lam2=lam1-0.2 + t11=Tree.fromstring(parsing11) + t12=Tree.fromstring(parsing12) + original_sentence1=' '.join(t11.leaves()) + t11_len=len(t11.leaves()) + original_sentence2=' '.join(t12.leaves()) + t12_len=len(t12.leaves()) + candidate_subtree11=list(t11.subtrees(lambda t: lam1>len(t.leaves())/t11_len>lam2)) + candidate_subtree12=list(t12.subtrees(lambda t: lam1>len(t.leaves())/t12_len>lam2)) + t21=Tree.fromstring(parsing21) + t22=Tree.fromstring(parsing22) + t21_len=len(t21.leaves()) + t22_len=len(t22.leaves()) + candidate_subtree21=list(t21.subtrees(lambda t: lam1>len(t.leaves())/t11_len>lam2)) + candidate_subtree22=list(t22.subtrees(lambda t: lam1>len(t.leaves())/t12_len>lam2)) + if args.showinfo: + print('\n') + print('*'*50) + print('t11_len:{}\tt12_len:{}\tt21_len:{}\tt22_len:{}\ncandidate_subtree11:{}\ncandidate_subtree12:{}\ncandidate_subtree21:{}\ncandidate_subtree21:{}' + 
.format(t11_len,t12_len,t21_len,t22_len,candidate_subtree11,candidate_subtree12,candidate_subtree21,candidate_subtree22)) + + # print('subtree1:',len(candidate_subtree1),'\nsubtree2:',len(candidate_subtree2)) + if len(candidate_subtree11)==0 or len(candidate_subtree12)==0 or len(candidate_subtree21)==0 or len(candidate_subtree22)==0: + # print("this pair fail",len(candidate_subtree1),len(candidate_subtree2)) + return None + + if args.phrase_label: + tree_labels11=[tree.label() for tree in candidate_subtree11] + tree_labels12=[tree.label() for tree in candidate_subtree12] + tree_labels21=[tree.label() for tree in candidate_subtree21] + tree_labels22=[tree.label() for tree in candidate_subtree22] + same_labels1=list(set(tree_labels11)&set(tree_labels21)) + same_labels2=list(set(tree_labels12)&set(tree_labels22)) + if not (same_labels1 and same_labels2) : + # print('无相同类型的子树') + return None + select_label1=random.choice(same_labels1) + select_label2=random.choice(same_labels2) + displaced1=random.choice([t for t in candidate_subtree11 if t.label()==select_label1]) + displacing1=random.choice([t for t in candidate_subtree21 if t.label()==select_label1]) + displaced2=random.choice([t for t in candidate_subtree12 if t.label()==select_label2]) + displacing2=random.choice([t for t in candidate_subtree22 if t.label()==select_label2]) + else: + displaced1=random.choice(candidate_subtree11) + displacing1=random.choice(candidate_subtree21) + displaced2=random.choice(candidate_subtree12) + displacing2=random.choice(candidate_subtree22) + + + displaced_span1=' '.join(displaced1.leaves()) + displaced_len1=len(displaced1.leaves()) + displacing_span1=' '.join(displacing1.leaves()) + new_sentence1=original_sentence1.replace(displaced_span1,displacing_span1) + + displaced_span2=' '.join(displaced2.leaves()) + displaced_len2=len(displaced2.leaves()) + displacing_span2=' '.join(displacing2.leaves()) + new_sentence2=original_sentence2.replace(displaced_span2,displacing_span2) + + # if args.mixup_cross: + new_label=np.zeros(len(args.label_list)) + displacing_len1=len(displacing1.leaves()) + displacing_len2=len(displacing2.leaves()) + new_len=t11_len+t12_len-displaced_len1-displaced_len2+displacing_len1+displacing_len2 + displacing_len=displacing_len1+displacing_len2 + new_label[int(label2)]+=displacing_len/new_len + new_label[int(label1)]+=(new_len-displacing_len)/new_len + + + if args.showinfo: + print('Before\nsentence1:{}\nsentence2:{}\nlabel1:{}\tlabel2:{}'.format(original_sentence1,original_sentence2,label1,label2)) + print('replaced1:{} replacing1:{}\nreplaced2:{} replacing2:{}'.format(displaced_span1,displacing_span1,displaced_span2,displacing2)) + print('After\nsentence1:{}\nsentence2:{}\nnew_label:{}'.format(new_sentence1,new_sentence2,new_label)) + print('*'*50) + + # print('被替换的span:{}\n替换的span:{}'.format(exchanged_span,exchanging_span)) + return new_sentence1,new_sentence2,new_label +def augmentation(args,data,seed,dataset,aug_times,lam1=0.1,lam2=0.3): + """ + generate aug_num augmentation dataset + input: + dataset --- pd.dataframe + output: + aug_dataset --- pd.dataframe + """ + generated_list=[] + # print('check2') + if args.debug: + print('check3') + shuffled_dataset=dataset.shuffle() + success=0 + total=0 + with tqdm(total=int(aug_times)*len(dataset)) as bar: + while success < int(aug_times)*len(dataset): + # for idx in range(len(dataset)): + idx = total % len(dataset) + if args.fraction: + bar.set_description('| Dataset:{:<5} | seed:{} | times:{} | fraction:{} 
|'.format(data,seed,aug_times,args.fraction)) + else: + bar.set_description('| Dataset:{:<5} | seed:{} | times:{} | '.format(data,seed,aug_times)) + + if args.data_type=='single_cls': + if args.debug: + print('check4') + if 'None' not in [dataset[idx]['parsing1'], shuffled_dataset[idx]['parsing1']]: + aug_sample=subtree_exchange_single( + args,dataset[idx]['parsing1'],dataset[idx][args.label_name], + shuffled_dataset[idx]['parsing1'],shuffled_dataset[idx][args.label_name], + lam1,lam2) + else: + continue + + elif args.data_type=='pair_cls': + # print('check4:pair') + if args.debug: + print('check4') + if 'None' not in [dataset[idx]['parsing1'], dataset[idx]['parsing2'], dataset[idx][args.label_name], + shuffled_dataset[idx]['parsing1'], shuffled_dataset[idx]['parsing2']]: + aug_sample=subtree_exchange_pair( + args,dataset[idx]['parsing1'],dataset[idx]['parsing2'],dataset[idx][args.label_name], + shuffled_dataset[idx]['parsing1'],shuffled_dataset[idx]['parsing2'],shuffled_dataset[idx][args.label_name], + lam1,lam2) + else: + continue + + elif args.data_type=='semantic_parsing': + if args.debug: + print('check4') + aug_sample=subtree_exchange_scan( + args,dataset[idx]['parsing1'], + shuffled_dataset[idx]['parsing1']) + + if args.debug: + print('ok') + print('got one aug_sample : {}'.format(aug_sample)) + if aug_sample: + bar.update(1) + success+=1 + generated_list.append(aug_sample) + else: + if args.debug: + print('fail this time') + total+=1 + #De-duplication + # generated_list=list(set(generated_list)) + return generated_list +def parse_argument(): + parser=argparse.ArgumentParser() + parser.add_argument('--lam1',type=float,default=0.3) + parser.add_argument('--lam2',type=float,default=0.1) + parser.add_argument('--times',default=[2,5],nargs='+',help='augmentation times list') + parser.add_argument('--min_token',type=int,default=0,help='minimum token numbers of augmentation samples') + parser.add_argument('--label_name',type=str,default='label') + parser.add_argument('--phrase_label',action='store_true',help='subtree lable must be same if set') + parser.add_argument('--phrase_length',action='store_true',help='subtree phrase must be same length if set') + # parser.add_argument('--data_type',type=str,required=True,help='This is a single classification task or pair sentences classification task') + parser.add_argument('--seeds',default=[0,1,2,3,4],nargs='+',help='seed list') + parser.add_argument('--showinfo',action='store_true') + parser.add_argument('--mixup_cross',action='store_false',help="NO mix across different classes if set") + parser.add_argument('--low_resource',action='store_true',help="create low source raw and aug datasets if set") + parser.add_argument('--debug',action='store_true',help="display debug information") + parser.add_argument('--data',nargs='+',required=True,help='data list') + parser.add_argument('--proc',type=int,help='processing number for multiprocessing') + args=parser.parse_args() + if not args.proc: + args.proc=cpu_count() + return args +def create_aug_data(args,dataset,data,seed,times,test_dataset=None): + + if args.phrase_label and not args.phrase_length: + prefix_save_path=os.path.join(args.output_dir,'samephraselabel_times{}_min{}_seed{}_{}_{}'.format(times,args.min_token,seed,args.lam1,args.lam2)) + elif args.phrase_length and not args.phrase_label: + prefix_save_path=os.path.join(args.output_dir,'samephraselength_times{}_min{}_seed{}_{}_{}'.format(times,args.min_token,seed,args.lam1,args.lam2)) + elif args.phrase_length and args.phrase_label: + 
prefix_save_path=os.path.join(args.output_dir,'samephraselabel_length_times{}_min{}_seed{}_{}_{}'.format(times,args.min_token,seed,args.lam1,args.lam2)) + elif not args.mixup_cross: + prefix_save_path=os.path.join(args.output_dir,'sameclass_times{}_min{}_seed{}_{}_{}'.format(times,args.min_token,seed,args.lam1,args.lam2)) + elif args.data_type == 'semantic_parsing': + prefix_save_path = os.path.join(args.output_dir, 'scan_times{}_seed{}'.format( + times, seed)) + else: + prefix_save_path=os.path.join(args.output_dir,'times{}_min{}_seed{}_{}_{}'.format(times,args.min_token,seed,args.lam1,args.lam2)) + if args.debug: + print('check1') + if not [file_name for file_name in os.listdir(args.output_dir) if file_name.startswith(prefix_save_path)]: + if args.min_token: + dataset=dataset.filter(lambda sample: len(sample[tasksettings.task_to_keys[data][0]].split(' '))>args.min_token) + if tasksettings.task_to_keys[data][1]: + dataset=dataset.filter(lambda sample: len(sample[tasksettings.task_to_keys[data][1]].split(' '))>args.min_token) + if args.data_type=='single_cls': + if args.debug: + print('check2') + if args.mixup_cross: + new_pd=pd.DataFrame(augmentation(args,data,seed,dataset,times,args.lam1,args.lam2),columns=[tasksettings.task_to_keys[data][0],args.label_name]) + else: + if args.debug: + print('label_list',args.label_list) + new_pd=None + for i in args.label_list: + samples=dataset.filter(lambda sample:sample[args.label_name]==i) + dataframe=pd.DataFrame(augmentation(args,data,seed,samples,times,args.lam1,args.lam2),columns=[tasksettings.task_to_keys[data][0],args.label_name]) + new_pd=pd.concat([new_pd,dataframe],axis=0) + elif args.data_type=='pair_cls': + if args.debug: + print('check2') + if args.mixup_cross: + # print('check1') + # print(args, seed, dataset, times,tasksettings.task_to_keys[data][0], tasksettings.task_to_keys[data][1], args.label_name) + new_pd=pd.DataFrame(augmentation(args,data,seed,dataset,times,args.lam1,args.lam2),columns=[tasksettings.task_to_keys[data][0],tasksettings.task_to_keys[data][1],args.label_name]) + else: + new_pd=None + if args.debug: + print('label_list',args.label_list) + for i in args.label_list: + samples=dataset.filter(lambda sample:sample[args.label_name]==i) + dataframe=pd.DataFrame(augmentation(args,data,seed,samples,times,args.lam1,args.lam2),columns=[tasksettings.task_to_keys[data][0],tasksettings.task_to_keys[data][1],args.label_name]) + new_pd=pd.concat([new_pd,dataframe],axis=0) + elif args.data_type=='semantic_parsing': + if args.debug: + print('check2') + new_pd=pd.DataFrame(augmentation(args,data,seed,dataset,times),columns=[tasksettings.task_to_keys[data][0],args.label_name]) + + + new_pd=new_pd.sample(frac=1) + + + + if args.data_type == 'semantic_parsing': + + train_pd=pd.read_csv('DATA/ADDPRIM_JUMP/data/train.csv') + frames = [train_pd,new_pd] + aug_dataset=pd.concat(frames,ignore_index=True) + else: + aug_dataset = Dataset.from_pandas(new_pd) + aug_dataset = aug_dataset.remove_columns("__index_level_0__") + + if args.phrase_label: + save_path = os.path.join(args.output_dir, 'samephraselabel_times{}_min{}_seed{}_{}_{}_{}k'.format( + times, args.min_token, seed, args.lam1, args.lam2, round(len(new_pd)//1000,-1))) + elif args.phrase_length and not args.phrase_label: + save_path=os.path.join(args.output_dir,'samephraselength_times{}_min{}_seed{}_{}_{}_{}k'.format(times,args.min_token,seed,args.lam1,args.lam2,round(len(new_pd)//1000,-1))) + elif args.phrase_length and args.phrase_label: + 
save_path=os.path.join(args.output_dir,'samephraselabel_length_times{}_min{}_seed{}_{}_{}_{}k'.format(times,args.min_token,seed,args.lam1,args.lam2,round(len(new_pd)//1000,-1)))
+        elif not args.mixup_cross:
+            save_path=os.path.join(args.output_dir,'sameclass_times{}_min{}_seed{}_{}_{}_{}k'.format(times,args.min_token,seed,args.lam1,args.lam2,round(len(new_pd)//1000,-1)))
+        elif args.data_type=='semantic_parsing':
+            save_path_train = os.path.join(prefix_save_path, 'train.csv')
+            save_path_test = os.path.join(prefix_save_path, 'test.csv')
+        else:
+            save_path=os.path.join(args.output_dir,'times{}_min{}_seed{}_{}_{}_{}k'.format(times,args.min_token,seed,args.lam1,args.lam2,round(len(new_pd)//1000,-1)))
+        if args.data_type == 'semantic_parsing':
+            if not os.path.exists(prefix_save_path):
+                os.makedirs(prefix_save_path)
+            aug_dataset.to_csv(save_path_train,index=0)
+            test_dataset.to_csv(save_path_test,index=0)
+        else:
+            aug_dataset.save_to_disk(save_path)
+    else:
+        print('file {} already exists!'.format(prefix_save_path))
+
+
+def main():
+    p=Pool(args.proc)
+    for data in args.data:
+        path_dir=os.path.join('DATA',data.upper())
+        if data in tasksettings.pair_datasets:
+            args.data_type='pair_cls'
+        elif data in tasksettings.SCAN:
+            args.label_name='actions'
+            args.data_type='semantic_parsing'
+            testset_path=os.path.join(path_dir,'data','test.csv')
+        else:
+            args.data_type='single_cls'
+        if data=='trec':
+            try:
+                assert args.label_name in ['label-fine', 'label-coarse']
+            except AssertionError:
+                raise(AssertionError(
+                    "If you want to train on the trec dataset with augmentation, you have to set the label name to either 'label-fine' or 'label-coarse'"))
+
+            print(args.label_name,data)
+            args.output_dir=os.path.join(path_dir,'generated/{}'.format(args.label_name))
+        else:
+            args.output_dir=os.path.join(path_dir,'generated')
+        args.data_path=os.path.join(path_dir,'data','train_parsing.csv')
+
+        dataset=load_dataset('csv',data_files=[args.data_path],split='train')
+        testset=None  # only SCAN-style tasks provide a separate test split here
+        if args.data_type=='semantic_parsing':
+            testset=load_dataset('csv',data_files=[testset_path],split='train')
+        if args.data_type in ['single_cls','pair_cls']:
+            args.label_list=list(set(dataset[args.label_name]))  # collect every label that occurs in this dataset
+        for seed in args.seeds:
+            seed=int(seed)
+            set_seed(seed)
+            dataset=dataset.shuffle()
+            if args.low_resource:
+                for fraction in tasksettings.low_resource[data]:
+                    args.fraction=fraction
+                    train_dataset=dataset.select(random.sample(range(len(dataset)),int(fraction*len(dataset))))
+                    low_resource_dir=os.path.join(path_dir,'low_resource','low_resource_{}'.format(fraction),'seed_{}'.format(seed))
+                    if not os.path.exists(low_resource_dir):
+                        os.makedirs(low_resource_dir)
+                    args.output_dir=low_resource_dir
+                    train_path=os.path.join(args.output_dir,'partial_train')
+                    if not os.path.exists(train_path):
+                        train_dataset.save_to_disk(train_path)
+                    for times in args.times:
+                        times=int(times)
+                        p.apply_async(create_aug_data, args=(
+                            args, train_dataset, data, seed, times))
+            else:
+                args.fraction=None
+                for times in args.times:
+                    times=int(times)
+                    p.apply_async(create_aug_data, args=(
+                        args, dataset, data, seed, times,testset))
+    print('='*20,'Start generating augmentation datasets !',"="*20)
+    # p.close()
+    # p.join()
+
+    p.close()
+    p.join()
+    print('='*20, 'Augmentation done !', "="*20)
+if __name__=='__main__':
+
+    tasksettings=settings.TaskSettings()
+    args=parse_argument()
+    print(args)
+    main()
diff --git a/batch_train.py b/batch_train.py
new file mode 100644
index 0000000..2c5e72f
--- /dev/null
+++ 
b/batch_train.py @@ -0,0 +1,102 @@ +import argparse +import os +from process_data.settings import TaskSettings +def parse_argument(): + parser = argparse.ArgumentParser(description='download and parsing datasets') + parser.add_argument('--data',type=str,required=True,help='data list') + parser.add_argument('--aug_dir',help='Augmentation file directory') + parser.add_argument('--seeds',default=[0,1,2,3,4],nargs='+',help='seed list') + parser.add_argument('--modes',nargs='+',required=True,help='seed list') + parser.add_argument('--label_name',type=str,default='label') + # parser.add_argument('--batch_size',default=128,type=int,help='train examples in each batch') + # parser.add_argument('--aug_batch_size',default=128,type=int,help='train examples in each batch') + parser.add_argument('--random_mix',type=str,choices=['zero_one','zero','one','all'],help="random mixup ") + parser.add_argument('--prefix',type=str,help="only choosing the datasets with the prefix,for ablation study") + parser.add_argument('--GPU',type=int,default=0,help="available GPU number") + parser.add_argument('--low_resource', action='store_true', + help='whther to train low resource dataset') + + args=parser.parse_args() + if args.data=='trec': + try: + assert args.label_name in ['label-fine','label-coarse'] + except AssertionError: + raise( AssertionError("If you want to train on TREC dataset with augmentation, you have to name the label of split either 'label-fine' or 'label-coarse'")) + args.aug_dir = os.path.join('DATA', args.data.upper(), 'generated',args.label_name) + if args.aug_dir is None : + args.aug_dir=os.path.join('DATA',args.data.upper(),'generated') + + if 'aug' in args.modes: + try: + assert [file for file in os.listdir(args.aug_dir) if 'times' in file] + except AssertionError: + raise( AssertionError( "{}".format('This directory has no augmentation file, please input correct aug_dir!') ) ) + if args.low_resource: + try: + args.low_resource = os.path.join('DATA', args.data.upper(),'low_resource') + assert os.path.exists(args.low_resource) + except AssertionError: + raise( AssertionError("There is no any low resource datasets in this data")) + + return args +def batch_train(args): + for seed in args.seeds: + # for aug_file in os.listdir(args.aug_dir): + for mode in args.modes: + if mode=='raw': + # data_path=os.path.join(args.aug_dir,aug_file) + if args.random_mix: + os.system('CUDA_VISIBLE_DEVICES={} python run.py --label_name {} --mode {} --seed {} --data {} --random_mix {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '.format(args.GPU,args.label_name,mode,int(seed),args.data,args.random_mix,**settings[args.data])) + else: + os.system('CUDA_VISIBLE_DEVICES={} python run.py --label_name {} --mode {} --seed {} --data {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '.format(args.GPU,args.label_name,mode,int(seed),args.data,**settings[args.data])) + else: + for aug_file in os.listdir(args.aug_dir): + if args.prefix: + # only train on file with prefix + if aug_file.startswith(args.prefix): + aug_file_path = os.path.join( + args.aug_dir, aug_file) + assert os.path.exists(aug_file_path) + os.system('CUDA_VISIBLE_DEVICES={} python run.py --label_name {} --mode {} --seed {} --data {} --data_path {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length 
{max_length} --augweight {augweight} '.format( + args.GPU, args.label_name, mode, int(seed), args.data, aug_file_path, **settings[args.data])) + else: + # train on every file in dir + aug_file_path = os.path.join( + args.aug_dir, aug_file) + assert os.path.exists(aug_file_path) + os.system('CUDA_VISIBLE_DEVICES={} python run.py --label_name {} --mode {} --seed {} --data {} --data_path {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '.format( + args.GPU, args.label_name, mode, int(seed), args.data, aug_file_path, **settings[args.data])) +def low_resource_train(args): + for partial_split in os.listdir(args.low_resource): + partial_split_path=os.path.join(args.low_resource,partial_split) + args.output_dir = os.path.join( + args.low_resource_dir, partial_split) + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + for seed_num in os.listdir(partial_split_path): + partial_split_seed_path=os.path.join(partial_split_path,seed_num) + for mode in args.modes: + if mode=='raw': + if args.random_mix: + os.system('CUDA_VISIBLE_DEVICES={} python run.py --low_resource_dir {} --seed {} --output_dir {} --label_name {} --mode {} --data {} --random_mix {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} ' + .format(args.GPU, partial_split_seed_path, int(seed_num.split('_')[1]), args.output_dir, args.label_name, mode, args.data, args.random_mix, **settings[args.data])) + else: + os.system('CUDA_VISIBLE_DEVICES={} python run.py --low_resource_dir {} --seed {} --output_dir {} --label_name {} --mode {} --data {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} ' + .format(args.GPU, partial_split_seed_path, int(seed_num.split('_')[1]),args.output_dir, args.label_name, mode, args.data, **settings[args.data])) + elif mode=='raw_aug': + for aug_file in [file for file in os.listdir(partial_split_seed_path) if file.startswith('times')]: + aug_file_path=os.path.join(partial_split_seed_path,aug_file) + assert os.path.exists(aug_file_path) + os.system('CUDA_VISIBLE_DEVICES={} python run.py --low_resource_dir {} --seed {} --output_dir {} --label_name {} --mode {} --data {} --data_path {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '.format( + args.GPU, partial_split_seed_path, int(seed_num.split('_')[1]) , args.output_dir, args.label_name, mode, args.data, aug_file_path, **settings[args.data])) +if __name__=='__main__': + args=parse_argument() + tasksettings=TaskSettings() + settings=tasksettings.train_settings + if args.low_resource: + args.low_resource_dir=os.path.join('DATA',args.data.upper(),'runs','low_resource') + if not os.path.exists(args.low_resource_dir): + os.makedirs(args.low_resource_dir) + low_resource_train(args) + else: + batch_train(args) diff --git a/online_augmentation/__init__.py b/online_augmentation/__init__.py new file mode 100644 index 0000000..df94742 --- /dev/null +++ b/online_augmentation/__init__.py @@ -0,0 +1,131 @@ +import torch +import random +import numpy as np +def random_mixup_process(args,ids1,lam): + + rand_index=torch.randperm(ids1.shape[0]) + lenlist=[] + # rand_index=torch.randperm(len(ids1)) + for x in ids1: + mask=((x!=101)&(x!=0)&(x!=102)) + 
lenlist.append(int(mask.sum()))
+    lenlist2=torch.tensor(lenlist)[rand_index]
+    spanlen=torch.tensor([int(x*lam) for x in lenlist])
+
+    beginlist=[1+random.randint(0,x-int(x*lam)) for x in lenlist]
+    beginlist2=[1+random.randint(0,x-y) for x,y in zip(lenlist2,spanlen)]
+    if args.difflen:
+
+        spanlen2=torch.tensor([int(x*lam) for x in lenlist2])
+        # beginlist2=[1+random.randint(0,x-y) for x,y in zip(lenlist2,spanlen2)]
+        spanlist2=[(x,int(y)) for x,y in zip(beginlist2,spanlen2)]
+    else:
+        # beginlist2=[1+random.randint(0,x-y) for x,y in zip(lenlist2,spanlen)]
+        spanlist2=[(x,int(y)) for x,y in zip(beginlist2,spanlen)]
+    spanlist=[(x,int(y)) for x,y in zip(beginlist,spanlen)]
+
+    ids2=ids1.clone()
+    if args.difflen:
+        for idx in range(len(ids1)):
+            tmp=torch.cat((ids1[idx][:spanlist[idx][0]],ids2[rand_index[idx]][spanlist2[idx][0]:spanlist2[idx][0]+spanlist2[idx][1]],ids1[idx][spanlist[idx][0]+spanlist[idx][1]:]),dim=0)[:ids1.shape[1]]
+            ids1[idx]=torch.cat((tmp,torch.zeros(ids1.shape[1]-len(tmp))))
+    else:
+        for idx in range(len(ids1)):
+            ids1[idx][spanlist[idx][0]:spanlist[idx][0]+spanlist[idx][1]]=ids2[rand_index[idx]][spanlist2[idx][0]:spanlist2[idx][0]+spanlist2[idx][1]]
+    assert ids1.shape==ids2.shape
+    return ids1,rand_index
+def mixup_01(args,input_ids,lam,idx1,idx2):
+    '''
+    Exchange spans between class-0 and class-1 samples.
+    '''
+    difflen=False
+    random_index=torch.zeros(len(idx1)+len(idx2)).long()
+    random_index[idx1]=torch.tensor(np.random.choice(idx2,size=len(idx1)))
+    random_index[idx2]=torch.tensor(np.random.choice(idx1,size=len(idx2)))
+
+    len_list1=[]
+    len_list2=[]
+    for input_id1 in input_ids:
+        # count the real tokens in each sentence ([CLS]=101, [SEP]=102 and padding 0 excluded)
+        mask=((input_id1!=101)&(input_id1!=0)&(input_id1!=102))
+        len_list1.append(int(mask.sum()))
+    # print(len_list1)
+    len_list2=torch.tensor(len_list1)[random_index]
+
+    spanlen=torch.tensor([int(x*lam) for x in len_list1])
+    beginlist=[1+random.randint(0,x-int(x*lam)) for x in len_list1]
+    beginlist2=[1+random.randint(0,x-y) for x,y in zip(len_list2,spanlen)]
+    if difflen:
+        spanlen2=torch.tensor([int(x*lam) for x in len_list2])
+        # beginlist2=[1+random.randint(0,x-y) for x,y in zip(lenlist2,spanlen2)]
+        spanlist2=[(x,int(y)) for x,y in zip(beginlist2,spanlen2)]
+    else:
+        # beginlist2=[1+random.randint(0,x-y) for x,y in zip(lenlist2,spanlen)]
+        spanlist2=[(x,int(y)) for x,y in zip(beginlist2,spanlen)]
+    spanlist=[(x,int(y)) for x,y in zip(beginlist,spanlen)]
+    new_ids=input_ids.clone()
+
+    # print(random_index)
+    if difflen:
+        # dead branch (difflen is hard-coded to False above); mirrors the variable-length
+        # exchange of random_mixup_process using this function's own tensors
+        for idx in range(len(input_ids)):
+            tmp=torch.cat((input_ids[idx][:spanlist[idx][0]],input_ids[random_index[idx]][spanlist2[idx][0]:spanlist2[idx][0]+spanlist2[idx][1]],input_ids[idx][spanlist[idx][0]+spanlist[idx][1]:]),dim=0)[:input_ids.shape[1]]
+            new_ids[idx]=torch.cat((tmp,torch.zeros(input_ids.shape[1]-len(tmp))))
+    else:
+        for idx in range(len(input_ids)):
+            new_ids[idx][spanlist[idx][0]:spanlist[idx][0]+spanlist[idx][1]]=input_ids[random_index[idx]][spanlist2[idx][0]:spanlist2[idx][0]+spanlist2[idx][1]]
+    # for i in range(len(input_ids)):
+    #     print('{}: exchanging {} with {}; sentence 1 takes the span from {} to {}, sentence 2 supplies the span from {} to {}'.format(
+    #         i,i,random_index[i],spanlist[i][0],spanlist[i][0]+spanlist[i][1],spanlist2[i][0],spanlist2[i][0]+spanlist2[i][1]))
+    return new_ids,random_index
+def mixup(args,input_ids,lam,idx1,idx2=None):
+    '''
+    Exchange spans only among the samples indexed by idx1; if idx2 is also given, exchange between idx1 and idx2.
+    '''
+    select_input_ids=torch.index_select(input_ids,0,idx1)
+    rand_index=torch.randperm(select_input_ids.shape[0])
+    new_idx=torch.tensor(list(range(input_ids.shape[0])))
+    len_list1=[]
+    len_list2=[]
+    for input_id1 in select_input_ids:
+        #calculate 
length of tokens in each sentence + mask=((input_id1!=101)&(input_id1!=0)&(input_id1!=102)) + len_list1.append(int(mask.sum())) + len_list2=torch.tensor(len_list1)[rand_index] + + spanlen=torch.tensor([int(x*lam) for x in len_list1]) + beginlist=[1+random.randint(0,x-y) for x,y in zip(len_list1,spanlen)] + beginlist2=[1+random.randint(0,max(0,x-y)) for x,y in zip(len_list2,spanlen)] + + spanlist=[(x,int(y)) for x,y in zip(beginlist,spanlen)] + spanlist2 = [(x, min(int(y),z)) for x, y, z in zip(beginlist2, spanlen, len_list2)] + new_ids=input_ids.clone() + new_idx[idx1]=idx1[rand_index] + for i,idx in enumerate(idx1): + new_ids[idx][spanlist[i][0]:spanlist[i][0]+spanlist[i][1]]=input_ids[idx1[rand_index[i]]][spanlist2[i][0]:spanlist2[i][0]+spanlist2[i][1]] + + return new_ids,new_idx +def random_mixup(args,ids1,lab1,lam): + """ + function: random select span to exchange based on lam to decide span length and rand_index decide selected candidate exchange sentece + input: + ids1 -- tensors of tensors input_ids + lab1 -- tensors of tensors labels + lam -- span length rate + output: + ids1 -- tensors of tensors , exchanged span + rand_index -- tensors , permutation index + + """ + if args.random_mix=='all': + return mixup(args,ids1,lam,torch.tensor(range(ids1.shape[0]))) + else: + pos_idx=(lab1==1).nonzero().squeeze() + neg_idx=(lab1==0).nonzero().squeeze() + pos_samples=torch.index_select(ids1,0,pos_idx) + neg_samples=torch.index_select(ids1,0,neg_idx) + if args.random_mix=='zero': + return mixup(args,ids1,lam,neg_idx) + if args.random_mix=='one': + return mixup(args,ids1,lam,pos_idx) + if args.random_mix=='zero_one': + return mixup_01(args,ids1,lam,pos_idx,neg_idx) diff --git a/online_augmentation/__pycache__/__init__.cpython-38.pyc b/online_augmentation/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..6b59384 Binary files /dev/null and b/online_augmentation/__pycache__/__init__.cpython-38.pyc differ diff --git a/online_augmentation/__pycache__/__init__.cpython-39.pyc b/online_augmentation/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..2c905d6 Binary files /dev/null and b/online_augmentation/__pycache__/__init__.cpython-39.pyc differ diff --git a/process_data/Load_data.py b/process_data/Load_data.py new file mode 100644 index 0000000..472f6de --- /dev/null +++ b/process_data/Load_data.py @@ -0,0 +1,157 @@ +from datasets import load_dataset, load_from_disk +from transformers import BertTokenizer +import torch +import numpy as np +import os +from . 
import settings +class DATA_process(object): + def __init__(self, args=None): + if args: + print('Initializing with args') + self.data = args.data if args.data else None + self.task = args.task if args.task else None + self.tokenizer = BertTokenizer.from_pretrained( + args.model, do_lower_case=True) if args.model else None + self.tasksettings = settings.TaskSettings() + self.max_length = args.max_length if args.max_length else None + self.label_name = args.label_name if args.label_name else None + self.batch_size = args.batch_size if args.batch_size else None + self.aug_batch_size=args.aug_batch_size if args.aug_batch_size else None + self.min_train_token = args.min_train_token if args.min_train_token else None + self.max_train_token = args.max_train_token if args.max_train_token else None + self.num_proc = args.num_proc if args.num_proc else None + self.low_resource_dir = args.low_resource_dir if args.low_resource_dir else None + self.data_path = args.data_path if args.data_path else None + self.random_mix = args.random_mix if args.random_mix else None + + def validation_data(self): + validation_set = self.validationset( + data=self.data) + print('='*20,'multiprocess processing test dataset','='*20) + # Process dataset to make dataloader + if self.task == 'single': + validation_set = validation_set.map( + self.encode, batched=True, num_proc=self.num_proc) + else: + validation_set = validation_set.map( + self.encode_pair, batched=True, num_proc=self.num_proc) + # validation_set = validation_set.map(lambda examples: {'labels': examples[args.label_name]}, batched=True) + validation_set = validation_set.rename_column( + self.label_name, "labels") + validation_set.set_format(type='torch', columns=[ + 'input_ids', 'token_type_ids', 'attention_mask', 'labels']) + + val_dataloader = torch.utils.data.DataLoader( + validation_set, batch_size=self.batch_size, shuffle=True) + return val_dataloader + def encode(self, examples): + return self.tokenizer(examples[self.tasksettings.task_to_keys[self.data][0]], max_length=self.max_length, truncation=True, padding='max_length') + def encode_pair(self, examples): + return self.tokenizer(examples[self.tasksettings.task_to_keys[self.data][0]], examples[self.tasksettings.task_to_keys[self.data][1]], max_length=self.max_length, truncation=True, padding='max_length') + + def train_data(self, count_label=False): + train_set, label_num = self.traindataset( + data=self.data, low_resource_dir=self.low_resource_dir, label_num=count_label) + print('='*20,'multiprocess processing train dataset','='*20) + if self.task == 'single': + train_set = train_set.map( + self.encode, batched=True, num_proc=self.num_proc) + else: + train_set = train_set.map( + self.encode_pair, batched=True, num_proc=self.num_proc) + if self.random_mix: + # sort the train dataset + print('-'*20, 'random_mixup', '-'*20) + train_set = train_set.map( + lambda examples: {'token_num': np.sum(np.array(examples['attention_mask']))}) + train_set = train_set.sort('token_num', reverse=True) + # train_set = train_set.map(lambda examples: {'labels': examples[args.label_name]}, batched=True) + train_set = train_set.rename_column(self.label_name, "labels") + if self.min_train_token: + print( + '-'*20, 'filter sample whose sentence shorter than {}'.format(self.min_train_token), '-'*20) + train_set = train_set.filter(lambda example: sum( + example['attention_mask']) > self.min_train_token+2) + if self.max_train_token: + print( + '-'*20, 'filter sample whose sentence longer than {}'.format(self.max_train_token), 
'-'*20) + train_set = train_set.filter(lambda example: sum( + example['attention_mask']) < self.max_train_token+2) + train_set.set_format(type='torch', columns=[ + 'input_ids', 'token_type_ids', 'attention_mask', 'labels']) + + train_dataloader = torch.utils.data.DataLoader( + train_set, batch_size=self.batch_size, shuffle=True) + if count_label: + return train_dataloader, label_num + else: + return train_dataloader + def augmentation_data(self): + try: + aug_dataset = load_dataset( + 'csv', data_files=[self.data_path])['train'] + except Exception as e: + aug_dataset = load_from_disk(self.data_path) + print('='*20, 'multiprocess processing aug dataset', '='*20) + if self.task == 'single': + aug_dataset = aug_dataset.map( + self.encode, batched=True, num_proc=self.num_proc) + else: + aug_dataset = aug_dataset.map( + self.encode_pair, batched=True, num_proc=self.num_proc) + # if self.mix: + # # label has more than one dimension + # # aug_dataset = aug_dataset.map(lambda examples: {'labels':examples[self.label_name]},batched=True) + # else: + # # aug_dataset = aug_dataset.map(lambda examples: {'labels':int(examples[self.label_name])}) + aug_dataset = aug_dataset.rename_column(self.label_name, 'labels') + + aug_dataset.set_format(type='torch', columns=[ + 'input_ids', 'token_type_ids', 'attention_mask', 'labels']) + aug_dataloader = torch.utils.data.DataLoader( + aug_dataset, batch_size=self.aug_batch_size, shuffle=True) + return aug_dataloader + + def validationset(self,data): + if data in ['sst2', 'rte', 'mrpc', 'qqp', 'mnli', 'qnli']: + if data == 'mnli': + validation_set = load_dataset( + 'glue', data, split='validation_mismatched') + else: + validation_set = load_dataset('glue', data, split='validation') + print('-'*20, 'Test on glue@{}'.format(data), '-'*20) + elif data in ['imdb', 'ag_news', 'trec']: + validation_set = load_dataset(data, split='test') + print('-'*20, 'Test on {}'.format(data), '-'*20) + elif data == 'sst': + validation_set = load_dataset(data, 'default', split='test') + validation_set = validation_set.map(lambda example: {'label': int( + example['label']*10//2)}, remove_columns=['tokens', 'tree'], num_proc=4) + print('-'*20, 'Test on {}'.format(data), '-'*20) + else: + validation_set = load_dataset(data, split='validation') + print('-'*20, 'Test on {}'.format(data), '-'*20) + + return validation_set + + def traindataset(self, data, low_resource_dir=None, split='train', label_num=False): + if low_resource_dir: + train_set = load_from_disk(os.path.join( + low_resource_dir, 'partial_train')) + else: + if data in ['sst2', 'rte', 'mrpc', 'qqp', 'mnli', 'qnli']: + train_set = load_dataset('glue', data, split=split) + elif data == 'sst': + train_set = load_dataset(data, 'default', split=split) + train_set = train_set.map(lambda example: {'label': int( + example['label']*10//2)}, remove_columns=['tokens', 'tree'], num_proc=4) + else: + train_set = load_dataset(data, split=split) + if label_num: + return train_set, len(set(train_set[self.label_name])) + else: + return train_set +if __name__=="__main__": + data_processor=DATA_process() + valset=data_processor.validationset(data='ag_news') + print(valset) diff --git a/process_data/__init__.py b/process_data/__init__.py new file mode 100644 index 0000000..3c63951 --- /dev/null +++ b/process_data/__init__.py @@ -0,0 +1,2 @@ +if __name__=='__main__': + print('Using process_data package') \ No newline at end of file diff --git a/process_data/__pycache__/Augmentation.cpython-39.pyc b/process_data/__pycache__/Augmentation.cpython-39.pyc 
new file mode 100644 index 0000000..98e2761 Binary files /dev/null and b/process_data/__pycache__/Augmentation.cpython-39.pyc differ diff --git a/process_data/__pycache__/Load_data.cpython-38.pyc b/process_data/__pycache__/Load_data.cpython-38.pyc new file mode 100644 index 0000000..91bbb4c Binary files /dev/null and b/process_data/__pycache__/Load_data.cpython-38.pyc differ diff --git a/process_data/__pycache__/Load_data.cpython-39.pyc b/process_data/__pycache__/Load_data.cpython-39.pyc new file mode 100644 index 0000000..3fe5d32 Binary files /dev/null and b/process_data/__pycache__/Load_data.cpython-39.pyc differ diff --git a/process_data/__pycache__/__init__.cpython-38.pyc b/process_data/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..857a025 Binary files /dev/null and b/process_data/__pycache__/__init__.cpython-38.pyc differ diff --git a/process_data/__pycache__/__init__.cpython-39.pyc b/process_data/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..232c9cd Binary files /dev/null and b/process_data/__pycache__/__init__.cpython-39.pyc differ diff --git a/process_data/__pycache__/ceshi.cpython-39.pyc b/process_data/__pycache__/ceshi.cpython-39.pyc new file mode 100644 index 0000000..70c6e8b Binary files /dev/null and b/process_data/__pycache__/ceshi.cpython-39.pyc differ diff --git a/process_data/__pycache__/settings.cpython-38.pyc b/process_data/__pycache__/settings.cpython-38.pyc new file mode 100644 index 0000000..10dc713 Binary files /dev/null and b/process_data/__pycache__/settings.cpython-38.pyc differ diff --git a/process_data/__pycache__/settings.cpython-39.pyc b/process_data/__pycache__/settings.cpython-39.pyc new file mode 100644 index 0000000..5daf79b Binary files /dev/null and b/process_data/__pycache__/settings.cpython-39.pyc differ diff --git a/process_data/ceshi.py b/process_data/ceshi.py new file mode 100644 index 0000000..a584b53 --- /dev/null +++ b/process_data/ceshi.py @@ -0,0 +1,63 @@ +from datasets import load_dataset +import pandas as np +import argparse +import os +import pandas as pd +from stanfordcorenlp import StanfordCoreNLP +from tqdm import tqdm +import time +import settings +from multiprocessing import cpu_count +from pandarallel import pandarallel + + +def parse_argument(): + parser = argparse.ArgumentParser( + description='download and parsing datasets') + parser.add_argument('--data', nargs='+', required=True, help='data list') + parser.add_argument('--corenlp_dir', type=str, + default='/remote-home/lzhang/stanford-corenlp-full-2018-10-05/') + parser.add_argument('--proc', type=int, help='multiprocessing num') + args = parser.parse_args() + return args + +def parsing_using_stanfordnlp(raw_text): + try: + parsing= snlp.parse(raw_text) + return parsing + except Exception as e: + return 'None' + + +def constituency_parsing(args): + if not args.proc: + args.proc = cpu_count() + pandarallel.initialize(nb_workers=args.proc, progress_bar=True) + for dataset in args.data: + DATA_dir = os.path.join(os.path.abspath( + os.path.join(os.getcwd(), "..")), 'DATA') + path_dir = os.path.join(DATA_dir, dataset.upper()) + output_path = os.path.join(path_dir, 'data', 'ceshi_parsing.csv') + if os.path.exists(output_path): + print('The data {} has already parsed!'.format(dataset.upper())) + continue + train = pd.read_csv(os.path.join( + path_dir, 'data', 'test.csv'), encoding="utf-8") + del train['Unnamed: 0'] + for i,text_name in enumerate(task_to_keys[dataset]): + parsing_name = 'parsing{}'.format(i+1) + train[parsing_name] = 
train[text_name].parallel_apply(
+                parsing_using_stanfordnlp)
+
+        for i,text_name in enumerate(task_to_keys[dataset]):
+            parsing_name='parsing{}'.format(i+1)
+            train=train.drop(train[train[parsing_name]=='None'].index)
+        train.to_csv(output_path, index=0)
+
+
+if __name__ == '__main__':
+    args = parse_argument()
+    tasksettings = settings.TaskSettings()
+    task_to_keys = tasksettings.task_to_keys
+    snlp = StanfordCoreNLP(args.corenlp_dir)
+    constituency_parsing(args)
diff --git a/process_data/get_data.py b/process_data/get_data.py
new file mode 100644
index 0000000..980a052
--- /dev/null
+++ b/process_data/get_data.py
@@ -0,0 +1,114 @@
+from datasets import load_dataset
+import numpy as np
+import argparse
+import os
+import pandas as pd
+from stanfordcorenlp import StanfordCoreNLP
+from tqdm import tqdm
+import time
+import settings
+from multiprocessing import cpu_count
+from pandarallel import pandarallel
+def parse_argument():
+    parser = argparse.ArgumentParser(description='download and parsing datasets')
+    parser.add_argument('--data',nargs='+',required=True,help='data list')
+    parser.add_argument('--corenlp_dir',type=str,default='/remote-home/lzhang/stanford-corenlp-full-2018-10-05/')
+    parser.add_argument('--proc',type=int,help='multiprocessing num')
+    args=parser.parse_args()
+    return args
+
+
+def parsing_stanfordnlp(raw_text):
+    try:
+        parsing = snlp.parse(raw_text)
+        return parsing
+    except Exception as e:
+        return 'None'
+
+def constituency_parsing(args):
+    if not args.proc:
+        args.proc = cpu_count()
+    pandarallel.initialize(nb_workers=args.proc, progress_bar=True)
+    for dataset in args.data:
+        DATA_dir = os.path.join(os.path.abspath(
+            os.path.join(os.getcwd(), "..")), 'DATA')
+        path_dir = os.path.join(DATA_dir, dataset.upper())
+        output_path = os.path.join(path_dir, 'data', 'train_parsing.csv')
+        if os.path.exists(output_path):
+            print('The data {} has already been parsed!'.format(dataset.upper()))
+            continue
+        train = pd.read_csv(os.path.join(
+            path_dir, 'data', 'train.csv'), encoding="utf-8")
+        for i,text_name in enumerate(task_to_keys[dataset]):
+            parsing_name = 'parsing{}'.format(i+1)
+            train[parsing_name] = train[text_name].parallel_apply(
+                parsing_stanfordnlp)
+
+        for i,text_name in enumerate(task_to_keys[dataset]):
+            parsing_name='parsing{}'.format(i+1)
+            train=train.drop(train[train[parsing_name]=='None'].index)
+        train.to_csv(output_path, index=0)
+def download_data(args):
+
+    for dataset in args.data:
+        DATA_dir=os.path.join(os.path.abspath(os.path.join(os.getcwd(), "..")),'DATA')
+        path_dir=os.path.join(DATA_dir,dataset.upper())
+        if dataset.upper() in os.listdir(DATA_dir):
+            print('{} directory already exists !'.format(dataset.upper()))
+            continue
+        try:
+            if dataset in ['addprim_jump','addprim_turn_left','simple']:
+                downloaded_data_list = [load_dataset('scan', dataset)]
+            elif dataset in ['sst2', 'rte', 'mrpc', 'qqp', 'mnli', 'qnli']:
+                downloaded_data_list=[load_dataset('glue',dataset)]
+            elif dataset =='sst':
+                downloaded_data_list = [load_dataset("sst", "default")]
+            else:
+                downloaded_data_list=[load_dataset(dataset)]
+
+            if not os.path.exists(path_dir):
+                if 
dataset=='trec': + os.makedirs(os.path.join(path_dir,'generated/fine')) + os.makedirs(os.path.join(path_dir,'generated/coarse')) + os.makedirs(os.path.join( + path_dir, 'runs/label-coarse/raw')) + os.makedirs(os.path.join( + path_dir, 'runs/label-coarse/aug')) + os.makedirs(os.path.join( + path_dir, 'runs/label-coarse/raw_aug')) + os.makedirs(os.path.join( + path_dir, 'runs/label-fine/raw')) + os.makedirs(os.path.join( + path_dir, 'runs/label-fine/aug')) + os.makedirs(os.path.join( + path_dir, 'runs/label-fine/raw_aug')) + else: + os.makedirs(os.path.join(path_dir,'generated')) + os.makedirs(os.path.join(path_dir,'runs/raw')) + os.makedirs(os.path.join(path_dir,'runs/aug')) + os.makedirs(os.path.join(path_dir,'runs/raw_aug')) + os.makedirs(os.path.join(path_dir,'logs')) + os.makedirs(os.path.join(path_dir,'data')) + for downloaded_data in downloaded_data_list: + for data_split in downloaded_data: + dataset_split=downloaded_data[data_split] + dataset_split.to_csv(os.path.join(path_dir,'data',data_split+'.csv'),index=0) + except Exception as e: + print('Downloading failed on {}, due to error {}'.format(dataset,e)) +if __name__=='__main__': + args = parse_argument() + tasksettings=settings.TaskSettings() + task_to_keys=tasksettings.task_to_keys + print('='*20,'Start Downloading Datasets','='*20) + download_data(args) + print('='*20,'Start Parsing Datasets','='*20) + snlp = StanfordCoreNLP(args.corenlp_dir) + constituency_parsing(args) diff --git a/process_data/settings.py b/process_data/settings.py new file mode 100644 index 0000000..def6afb --- /dev/null +++ b/process_data/settings.py @@ -0,0 +1,39 @@ +class TaskSettings(object): + def __init__(self): + self.train_settings={ + "mnli":{'epoch':5,'batch_size':96,'aug_batch_size':96,'val_steps':100,'max_length':128,'augweight':0.2}, + "mrpc":{'epoch':10,'batch_size':32,'aug_batch_size':32,'val_steps':50,'max_length':128,'augweight':0.2}, + "qnli":{'epoch':5,'batch_size':96,'aug_batch_size':96,'val_steps':100,'max_length':128,'augweight':0.2}, + "qqp": {'epoch':5,'batch_size':96,'aug_batch_size':96,'val_steps':300,'max_length':128,'augweight':0.2}, + "rte": {'epoch':10,'batch_size':32,'aug_batch_size':32,'val_steps':50,'max_length':128,'augweight':-0.2}, + "sst2":{'epoch':5,'batch_size':96,'aug_batch_size':96,'val_steps':100,'max_length':128,'augweight':0.5}, + "trec":{'epoch':20,'batch_size':96,'aug_batch_size':96,'val_steps':100,'max_length':128,'augweight':0.5}, + "imdb":{'epoch':5,'batch_size':8,'aug_batch_size':8,'val_steps':500,'max_length':512,'augweight':0.5}, + "ag_news": {'epoch': 5, 'batch_size': 96, 'aug_batch_size': 96, 'val_steps': 500, 'max_length': 128, 'augweight': 0.5}, + + } + self.task_to_keys = { + "mnli": ["premise", "hypothesis"], + "mrpc": ["sentence1", "sentence2"], + "qnli": ["question", "sentence"], + "qqp": ["question1", "question2"], + "rte": ["sentence1", "sentence2"], + "sst2": ["sentence"], + "trec": ["text"], + "anli": ["premise", "hypothesis"], + "imdb": ["text"], + "ag_news":["text"], + "sst":["sentence"], + "addprim_jump":["commands"], + "addprim_turn_left":["commands"] + } + self.pair_datasets=['qqp','rte','qnli','mrpc','mnli'] + self.SCAN = ['addprim_turn_left', 'addprim_jump','simple'] + self.low_resource={ + "ag_news":[0.01,0.02,0.05,0.1,0.2], + "sst":[0.01,0.02,0.05,0.1,0.2], + "sst2":[0.01,0.02,0.05,0.1,0.2] + } + + + diff --git a/process_data/subtree_substitution.py b/process_data/subtree_substitution.py new file mode 100644 index 0000000..ee5951d --- /dev/null +++ b/process_data/subtree_substitution.py 
@@ -0,0 +1,112 @@ +import random +from nltk import Tree +from tqdm import tqdm +import pandas as pd +import argparse +import numpy as np +import os +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) +def subtree_exchange(args,parsing1,label1,parsing2,label2,lam1,lam2): + """ + For a pair sentence, exchange subtree and return a label based on subtree length + + Find the candidate subtree, and extract correspoding span, and exchange span + + """ + assert lam1>lam2 + t1=Tree.fromstring(parsing1) + original_sentence=' '.join(t1.leaves()) + t1_len=len(t1.leaves()) + candicate_subtree1=list(t1.subtrees(lambda t: lam1>len(t.leaves())/t1_len>lam2)) + t2=Tree.fromstring(parsing2) + candicate_subtree2=list(t2.subtrees(lambda t: lam1>len(t.leaves())/t1_len>lam2)) + + # print('subtree1:',len(candicate_subtree1),'\nsubtree2:',len(candicate_subtree2)) + if len(candicate_subtree1)==0 or len(candicate_subtree2)==0: + # print("this pair fail",len(candicate_subtree1),len(candicate_subtree2)) + return None + + if args.same_type: + tree_labels1=[tree.label() for tree in candicate_subtree1] + tree_labels2=[tree.label() for tree in candicate_subtree2] + same_labels=list(set(tree_labels1)&set(tree_labels2)) + if not same_labels: + # print('无相同类型的子树') + return None + select_label=random.choice(same_labels) + candicate1=random.choice([t for t in candicate_subtree1 if t.label()==select_label]) + candicate2=random.choice([t for t in candicate_subtree2 if t.label()==select_label]) + else: + candicate1=random.choice(candicate_subtree1) + candicate2=random.choice(candicate_subtree2) + + exchanged_span=' '.join(candicate1.leaves()) + exchanged_len=len(candicate1.leaves()) + exchanging_span=' '.join(candicate2.leaves()) + new_sentence=original_sentence.replace(exchanged_span,exchanging_span) + if label1!=label2: + exchanging_len=len(candicate2.leaves()) + new_len=t1_len-exchanged_len+exchanging_len + new_label=(exchanging_len/new_len)*label2+(new_len-exchanging_len)/new_len*label1 + else: + new_label=label1 + # print('被替换的span:{}\n替换的span:{}'.format(exchanged_span,exchanging_span)) + return new_sentence,new_label +def augmentation(args,dataset,aug_times,lam1,lam2): + """ + generate aug_num augmentation dataset + input: + dataset --- pd.dataframe + output: + aug_dataset --- pd.dataframe + """ + generated_list=[] + data_list=dataset.values.tolist() + shuffled_list=data_list.copy() + with tqdm(total=aug_times*len(data_list)) as bar: + for i in range(aug_times): + np.random.shuffle(shuffled_list) + for idx in range(len(data_list)): + bar.update(1) + aug_sample=subtree_exchange(args,data_list[idx][2],data_list[idx][1],shuffled_list[idx][2],shuffled_list[idx][1],lam1,lam2) + if aug_sample: + generated_list.append(aug_sample) + #De-duplication + generated_list=list(set(generated_list)) + return generated_list +def main(): + parser=argparse.ArgumentParser() + parser.add_argument('--attention',action='store_true',help='labels weight on attention score') + parser.add_argument('--lam1',type=float,default=0.6) + parser.add_argument('--lam2',type=float,default=0.3) + parser.add_argument('--times',type=int,default=5) + parser.add_argument('--min_token',type=int,default=10,help='minimum token numbers of augmentation samples') + parser.add_argument('--same_type',action='store_true') + parser.add_argument('--seed',default=7,type=int) + # parser.add_argument('--data_path',type=str,required=True) + parser.add_argument('--output_dir',type=str,required=True) + + + # 
parser.add_argument('--load_path',metavar='dir',required=True,help='directory of created augmentation dataset') + args=parser.parse_args() + set_seed(args.seed) + dataset=pd.read_csv("SST-2/train_parsing.csv") + if args.min_token: + dataset=dataset.loc[dataset['sentence'].str.split().apply(lambda x:len(x)>args.min_token)] + pos_samples=dataset.loc[dataset["label"]==1] + neg_samples=dataset.loc[dataset["label"]==0] + pos_pd=pd.DataFrame(augmentation(args,pos_samples,args.times,args.lam1,args.lam2),columns=["sentence","label"]) + neg_pd=pd.DataFrame(augmentation(args,neg_samples,args.times,args.lam1,args.lam2),columns=["sentence","label"]) + new_pd=pd.concat([pos_pd,neg_pd],axis=0) + new_pd=new_pd.sample(frac=1) + if args.same_type: + new_pd.to_csv(os.path.join(args.output_dir,'sametype_generated_times{}_seed{}_{}_{}_{}k.csv'.format(args.times,args.seed,args.lam1,args.lam2,round(len(new_pd))//1000,-1)),index=0) + else: + new_pd.to_csv(os.path.join(args.output_dir,'generated_times{}_seed{}_{}_{}_{}k.csv'.format(args.times,args.seed,args.lam1,args.lam2,round(len(new_pd))//1000,-1)),index=0) + + + +if __name__=='__main__': + main() \ No newline at end of file diff --git a/run.py b/run.py new file mode 100644 index 0000000..9ac4134 --- /dev/null +++ b/run.py @@ -0,0 +1,466 @@ +import random +import torch +import torch.utils.data.distributed +from torch.utils.data.distributed import DistributedSampler +import torch.nn.parallel +from transformers import BertForSequenceClassification, AdamW +from transformers import get_linear_schedule_with_warmup +import numpy as np +import torch.nn as nn +from sklearn.metrics import f1_score, accuracy_score +from tqdm import tqdm +import os +import re +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +from itertools import cycle +import argparse +import torch.distributed as dist +import time +import online_augmentation +import logging +from process_data.Load_data import DATA_process + + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + + +def cross_entropy(logits, target): + p = F.softmax(logits, dim=1) + log_p = -torch.log(p) + loss = target*log_p + # print(target,p,log_p,loss) + batch_num = logits.shape[0] + return loss.sum()/batch_num + + + +def flat_accuracy(preds, labels): + pred_flat = np.argmax(preds, axis=1).flatten() + labels_flat = labels.flatten() + return accuracy_score(labels_flat, pred_flat) + + +def reduce_tensor(tensor, args): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + rt /= args.world_size + return rt + + +def tensorboard_settings(args): + if 'raw' in args.mode: + if args.data_path: + # raw_aug + log_dir = os.path.join(args.output_dir, 'Raw_Aug_{}_{}_{}_{}_{}'.format(args.data_path.split( + '/')[-1], args.seed, args.augweight, args.batch_size, args.aug_batch_size)) + if os.path.exists(log_dir): + raise IOError( + 'This tensorboard file {} already exists! Please do not train the same data repeatedly, if you want to train this dataset, delete corresponding tensorboard file first! '.format(log_dir)) + writer = SummaryWriter(log_dir=log_dir) + else: + # raw + if args.random_mix: + log_dir = os.path.join(args.output_dir, 'Raw_random_mixup_{}_{}_{}'.format( + args.random_mix, args.alpha, args.seed)) + if os.path.exists(log_dir): + raise IOError( + 'This tensorboard file {} already exists! 
Please do not train the same data repeatedly, if you want to train this dataset, delete corresponding tensorboard file first! '.format(log_dir)) + writer = SummaryWriter(log_dir=log_dir) + else: + log_dir = os.path.join( + args.output_dir, 'Raw_{}'.format(args.seed)) + if os.path.exists(log_dir): + raise IOError( + 'This tensorboard file {} already exists! Please do not train the same data repeatedly, if you want to train this dataset, delete corresponding tensorboard file first! '.format(log_dir)) + writer = SummaryWriter(log_dir=log_dir) + elif args.mode == 'aug': + # aug + log_dir = os.path.join(args.output_dir, 'Aug_{}_{}_{}_{}_{}'.format(args.data_path.split( + '/')[-1], args.seed, args.augweight, args.batch_size, args.aug_batch_size)) + if os.path.exists(log_dir): + raise IOError( + 'This tensorboard file {} already exists! Please do not train the same data repeatedly, if you want to train this dataset, delete corresponding tensorboard file first! '.format(log_dir)) + writer = SummaryWriter(log_dir=log_dir) + return writer + + +def logging_settings(args): + logger = logging.getLogger('result') + logger.setLevel(logging.INFO) + fmt = logging.Formatter( + fmt='%(asctime)s - %(filename)s - %(levelname)s: %(message)s') + if not os.path.exists(os.path.join('DATA', args.data.upper(), 'logs')): + os.makedirs(os.path.join( + 'DATA', args.data.upper(), 'logs')) + if args.low_resource_dir: + log_path = os.path.join('DATA', args.data.upper(),'logs', 'lowresourcebest_result.log') + else: + log_path = os.path.join('DATA', args.data.upper(),'logs', 'best_result.log') + + fh = logging.FileHandler(log_path, mode='a+', encoding='utf-8') + ft=logging.Filter(name='result.a') + fh.setFormatter(fmt) + fh.setLevel(logging.INFO) + fh.addFilter(ft) + logger.addHandler(fh) + result_logger=logging.getLogger('result.a') + return result_logger +def loading_model(args,label_num): + t1 = time.time() + if args.local_rank == -1: + device = torch.device( + "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + args.n_gpu = torch.cuda.device_count() + else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + torch.distributed.init_process_group(backend='nccl') + args.n_gpu = 1 # the number of gpu on each proc + args.device = device + if args.local_rank != -1: + args.world_size = torch.cuda.device_count() + else: + args.world_size = 1 + print('*'*40, '\nSettings:{}'.format(args)) + print('*'*40) + print('='*20, 'Loading models', '='*20) + model = BertForSequenceClassification.from_pretrained( + args.model, num_labels=label_num) + model.to(device) + t2 = time.time() + print( + '='*20, 'Loading models complete!, cost {:.2f}s'.format(t2-t1), '='*20) + # model parrallel + if args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.local_rank]) + elif args.n_gpu > 1: + model = nn.DataParallel(model) + if args.load_model_path is not None: + print("="*20, "Load model from %s", args.load_model_path,) + model.load_state_dict(torch.load(args.load_model_path)) + return model + +def parse_argument(): + parser = argparse.ArgumentParser() + parser.add_argument('--local_rank', default=-1, type=int, + help='node rank for distributed training') + parser.add_argument("--no_cuda", action='store_true', + help="Avoid using CUDA when available") + + parser.add_argument( + '--mode', type=str, choices=['raw', 'aug', 'raw_aug', 'visualize'], required=True) + 
parser.add_argument('--save_model', action='store_true') + parser.add_argument('--load_model_path', type=str) + parser.add_argument('--data', type=str, required=True) + parser.add_argument('--num_proc', type=int, default=8, + help='multi process number used in dataloader process') + + # training settings + parser.add_argument('--output_dir', type=str, help="tensorboard fileoutput directory") + parser.add_argument('--epoch', type=int, default=5, help='train epochs') + parser.add_argument('--lr', type=float, default=2e-5, help='learning rate') + parser.add_argument('--seed', default=42, type=int, help='seed ') + parser.add_argument('--batch_size', default=128, type=int, + help='train examples in each batch') + parser.add_argument('--val_steps', default=100, type=int, + help='evaluate on dev datasets every steps') + parser.add_argument('--max_length', default=128, + type=int, help='encode max length') + parser.add_argument('--label_name', type=str, default='label') + parser.add_argument('--model', type=str, default='bert-base-uncased') + parser.add_argument('--low_resource_dir', type=str, + help='Low resource data dir') + + # train on augmentation dataset parameters + parser.add_argument('--aug_batch_size', default=128, + type=int, help='train examples in each batch') + parser.add_argument('--augweight', default=0.2, type=float) + parser.add_argument('--data_path', type=str, help="augmentation file path") + parser.add_argument('--min_train_token', type=int, default=0, + help="minimum token num restriction for train dataset") + parser.add_argument('--max_train_token', type=int, default=0, + help="maximum token num restriction for train dataset") + parser.add_argument('--mix', action='store_false', help='train on 01mixup') + + # random mixup + parser.add_argument('--alpha', type=float, default=0.1, + help="online augmentation alpha") + parser.add_argument('--onlyaug', action='store_true', + help="train only on online aug batch") + parser.add_argument('--difflen', action='store_true', + help="train only on online aug batch") + parser.add_argument('--random_mix', type=str, help="random mixup ") + + # visualize dataset + + args = parser.parse_args() + if args.data == 'trec': + try: + assert args.label_name in ['label-fine', 'label-coarse'] + except AssertionError: + raise(AssertionError( + "If you want to train on trec dataset with augmentation, you have to name the label of split")) + if not args.output_dir: + args.output_dir = os.path.join( + 'DATA', args.data.upper(), 'runs', args.label_name, args.mode) + if args.mode == 'raw': + args.batch_size = 128 + if 'aug' in args.mode: + assert args.data_path + if args.mode == 'aug': + args.seed = 42 + if not args.output_dir: + args.output_dir = os.path.join( + 'DATA', args.data.upper(), 'runs', args.mode) + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + if args.data in ['rte', 'mrpc', 'qqp', 'mnli', 'qnli']: + args.task = 'pair' + else: + args.task = 'single' + + return args + + +def train(args): + # ======================================== + # Tensorboard &Logging + # ======================================== + writer = tensorboard_settings(args) + result_logger = logging_settings(args) + data_process = DATA_process(args) + # ======================================== + # Loading datasets + # ======================================== + print('='*20, 'Start processing dataset', '='*20) + t1 = time.time() + + val_dataloader = data_process.validation_data() + + if args.mode != 'aug': + train_dataloader, label_num = 
data_process.train_data(count_label=True) + # print('Label_num',label_num) + if args.data_path: + print('='*20, 'Train Augmentation dataset path: {}'.format(args.data_path), '='*20) + aug_dataloader = data_process.augmentation_data() + if args.mode == 'aug': + train_dataloader = aug_dataloader + else: + aug_dataloader = cycle(aug_dataloader) + + t2 = time.time() + print('='*20, 'Dataset process done! cost {:.2f}s'.format(t2-t1), '='*20) + + # ======================================== + # Model + # ======================================== + model=loading_model(args,label_num) + # ======================================== + # Optimizer Settings + # ======================================== + optimizer = AdamW(model.parameters(), lr=args.lr) + all_steps = args.epoch*len(train_dataloader) + scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=20, num_training_steps=all_steps) + criterion = nn.CrossEntropyLoss() + model.train() + + # ======================================== + # Train + # ======================================== + print('='*20, 'Start training', '='*20) + best_acc = 0 + args.val_steps = min(len(train_dataloader), args.val_steps) + + for epoch in range(args.epoch): + bar = tqdm(enumerate(train_dataloader), total=len( + train_dataloader)//args.world_size) + fail = 0 + loss = 0 + for step, batch in bar: + model.zero_grad() + + # ---------------------------------------------- + # Train_dataloader + # ---------------------------------------------- + if args.random_mix: + try: + + input_ids, target_a = batch['input_ids'], batch['labels'] + lam = np.random.choice([0, 0.1, 0.2, 0.3]) + exchanged_ids, new_index = online_augmentation.random_mixup( + args, input_ids, target_a, lam) + target_b = target_a[new_index] + outputs = model(exchanged_ids.to(args.device), token_type_ids=None, attention_mask=( + exchanged_ids > 0).to(args.device)) + logits = outputs.logits + loss = criterion(logits.to(args.device), target_a.to( + args.device))*(1-lam)+criterion(logits.to(args.device), target_b.to(args.device))*lam + + + except Exception as e: + fail += 1 + batch = {k: v.to(args.device) for k, v in batch.items()} + outputs = model(**batch) + loss = outputs.loss + elif args.mode == 'aug': + # train only on augmentation dataset + batch = {k: v.to(args.device) for k, v in batch.items()} + if args.mix: + # train on 01 tree mixup augmentation dataset + mix_label = batch['labels'] + del batch['labels'] + + outputs = model(**batch) + logits = outputs.logits + + loss = cross_entropy(logits, mix_label) + else: + # train on 00&11 tree mixup augmentation dataset + outputs = model(**batch) + loss = outputs.loss + else: + # normal train + + batch = {k: v.to(args.device) for k, v in batch.items()} + + outputs = model(**batch) + loss = outputs.loss + # ---------------------------------------------- + # Aug_dataloader + # ---------------------------------------------- + if args.mode == 'raw_aug': + aug_batch = next(aug_dataloader) + aug_batch = {k: v.to(args.device) for k, v in aug_batch.items()} + + if args.mix: + mix_label = aug_batch['labels'] + del aug_batch['labels'] + aug_outputs = model(**aug_batch) + aug_logits = aug_outputs.logits + + aug_loss = cross_entropy(aug_logits, mix_label) + else: + aug_outputs = model(**aug_batch) + aug_loss = aug_outputs.loss + loss += aug_loss*args.augweight # weighted augmentation loss; this weighting reaches the best performance on sst2 and rte + + # Backward propagation + if args.n_gpu > 1: + loss = loss.mean() + loss.backward() + optimizer.step() + scheduler.step() + optimizer.zero_grad() + if 
args.local_rank == 0 or args.local_rank == -1: + writer.add_scalar("Loss/loss", loss, step + + epoch*len(train_dataloader)) + writer.flush() + if args.random_mix: + bar.set_description( + '| Epoch: {:<2}/{:<2}| Best acc:{:.2f}| Fail:{}|'.format(epoch, args.epoch, best_acc*100, fail)) + else: + bar.set_description( + '| Epoch: {:<2}/{:<2}| Best acc:{:.2f}|'.format(epoch, args.epoch, best_acc*100)) + + # ================================================= + # Validation + # ================================================= + if (epoch*len(train_dataloader)+step+1) % args.val_steps == 0: + total_eval_accuracy = 0 + total_val_loss = 0 + model.eval() # evaluation after each epoch + for i, batch in enumerate(val_dataloader): + with torch.no_grad(): + batch = {k: v.to(args.device) + for k, v in batch.items()} + outputs = model(**batch) + logits = outputs.logits + loss = outputs.loss + + if args.n_gpu > 1: + loss = loss.mean() + logits = logits.detach().cpu().numpy() + label_ids = batch['labels'].to('cpu').numpy() + + accuracy = flat_accuracy(logits, label_ids) + if args.local_rank != -1: + torch.distributed.barrier() + reduced_loss = reduce_tensor(loss, args) + accuracy = torch.tensor(accuracy).to(args.device) + reduced_acc = reduce_tensor(accuracy, args) + total_val_loss += reduced_loss + total_eval_accuracy += reduced_acc + else: + total_eval_accuracy += accuracy.item() + total_val_loss += loss.item() + avg_val_loss = total_val_loss/len(val_dataloader) + avg_val_accuracy = total_eval_accuracy/len(val_dataloader) + if avg_val_accuracy > best_acc: + best_acc = avg_val_accuracy + bset_steps = (epoch*len(train_dataloader) + + step)*args.batch_size + if args.save_model: + torch.save(model.state_dict(), 'best_model.pt') + if args.local_rank == 0 or args.local_rank == -1: + writer.add_scalar("Test/Loss", avg_val_loss, + epoch*len(train_dataloader)+step) + writer.add_scalar( + "Test/Accuracy", avg_val_accuracy, epoch*len(train_dataloader)+step) + writer.flush() + # print(f'Validation loss: {avg_val_loss}') + # print(f'Accuracy: {avg_val_accuracy:.5f}') + # print('Best Accuracy:{:.5f} Steps:{}\n'.format(best_acc, bset_steps)) + + if args.data_path: + aug_num=args.data_path.split('_')[-1] + + if args.low_resource_dir: + # low resource raw_aug + partial = re.findall(r'low_resource_(0.\d+)', + args.low_resource_dir)[0] + aug_num_seed = aug_num+'_'+str(args.seed) + result_logger.info('-'*160) + result_logger.info('| Data : {} | Mode: {:<8} | #Aug {:<6} | Best acc:{} | Steps:{} | Weight {} |Aug data: {}'.format( + args.data+'_'+partial, args.mode, aug_num_seed, round(best_acc*100, 3), bset_steps, args.augweight, args.data_path)) + else: + # raw_aug + aug_data_seed=re.findall(r'seed(\d)',args.data_path)[0] + aug_num_seed = aug_num+'_'+aug_data_seed + result_logger.info('-'*160) + result_logger.info('| Data : {} | Mode: {:<8} | #Aug {:<6} | Best acc:{} | Steps:{} | Weight {} |Aug data: {}'.format( + args.data, args.mode, aug_num_seed ,round(best_acc*100,3), bset_steps, args.augweight,args.data_path)) + else: + if args.low_resource_dir: + # low resource raw + partial=re.findall(r'low_resource_(0.\d+)',args.low_resource_dir)[0] + result_logger.info('-'*160) + result_logger.info('| Data : {} | Mode: {:.8} | Seed: {} | Best acc:{} | Steps:{} | Randommix: {} | Aug data: {}'.format( + args.data+'-'+partial, args.mode, args.seed, round(best_acc*100,3), bset_steps,bool(args.random_mix) ,args.data_path)) + else: + # raw + result_logger.info('-'*160) + result_logger.info('| Data : {} | Mode: {:.8} | Seed: {} | Best acc:{} 
| Steps:{} | Randommix: {} | Aug data: {}'.format( + args.data, args.mode, args.seed, round(best_acc*100,3), bset_steps, bool(args.random_mix),args.data_path)) + + + + + +def main(args): + set_seed(args.seed) + if args.mode in ['raw', 'raw_aug', 'aug']: + if args.low_resource_dir: + print("="*20, ' Lowresource ', '='*20) + train(args) +if __name__ == '__main__': + args = parse_argument() + main(args)
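
For readers skimming the patch, here is a minimal, self-contained sketch of the subtree-exchange augmentation implemented in process_data/subtree_substitution.py. It is illustrative only: demo_subtree_exchange and the two toy parse strings are invented for this example and are not part of the patch, and the real code additionally handles the --same_type constraint and takes lam1/lam2 from command-line arguments. What it shows is the core move: a constituent whose relative length lies between lam2 and lam1 is swapped in from a second parse, and the new soft label is the length-weighted mixture of the two source labels.

import random
from nltk import Tree

def demo_subtree_exchange(parsing1, label1, parsing2, label2, lam1=0.8, lam2=0.2):
    # Hypothetical helper mirroring subtree_exchange(); toy version without args/--same_type.
    t1, t2 = Tree.fromstring(parsing1), Tree.fromstring(parsing2)
    sentence1 = ' '.join(t1.leaves())
    t1_len = len(t1.leaves())
    # candidate subtrees whose relative length lies strictly between lam2 and lam1
    cands1 = [t for t in t1.subtrees() if lam1 > len(t.leaves()) / t1_len > lam2]
    cands2 = [t for t in t2.subtrees() if lam1 > len(t.leaves()) / t1_len > lam2]
    if not cands1 or not cands2:
        return None
    c1, c2 = random.choice(cands1), random.choice(cands2)
    exchanged_span = ' '.join(c1.leaves())
    exchanging_span = ' '.join(c2.leaves())
    new_sentence = sentence1.replace(exchanged_span, exchanging_span)
    if label1 != label2:
        new_len = t1_len - len(c1.leaves()) + len(c2.leaves())
        w = len(c2.leaves()) / new_len   # share contributed by the donor parse
        return new_sentence, w * label2 + (1 - w) * label1
    return new_sentence, label1

p1 = "(S (NP (DT the) (NN movie)) (VP (VBZ is) (ADJP (JJ wonderful))))"             # label 1
p2 = "(S (NP (DT the) (NN plot)) (VP (VBZ is) (ADJP (RB painfully) (JJ dull))))"    # label 0
print(demo_subtree_exchange(p1, 1, p2, 0))   # e.g. ('the movie is painfully dull', 0.6)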
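The soft labels produced this way are consumed in run.py through the custom cross_entropy() helper whenever --mix is active. The short check below uses toy tensors and is illustrative only; soft_cross_entropy is a local re-implementation written with log_softmax for numerical stability, computing the same quantity as run.py's helper. It shows that the helper reduces to the standard cross-entropy for one-hot targets and simply interpolates the per-class losses for a mixed target such as 0.6/0.4.

import torch
import torch.nn.functional as F

def soft_cross_entropy(logits, target):
    # batch mean of -sum_c target_c * log softmax(logits)_c, as in run.py's cross_entropy()
    log_p = -F.log_softmax(logits, dim=1)
    return (target * log_p).sum() / logits.shape[0]

logits = torch.tensor([[2.0, 0.5], [0.1, 1.2]])
hard = torch.tensor([0, 1])
one_hot = F.one_hot(hard, num_classes=2).float()
soft = torch.tensor([[0.6, 0.4], [0.4, 0.6]])   # e.g. a 0.6/0.4 subtree mixture

print(F.cross_entropy(logits, hard))        # standard hard-label cross-entropy
print(soft_cross_entropy(logits, one_hot))  # identical value: one-hot targets recover it
print(soft_cross_entropy(logits, soft))     # interpolated loss for the mixed labels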