Skip to content

Commit

Permalink
first
Browse files Browse the repository at this point in the history
  • Loading branch information
lezhang7 committed Nov 4, 2021
1 parent 9a63746 commit cbb77b3
Show file tree
Hide file tree
Showing 20 changed files with 1,803 additions and 0 deletions.
617 changes: 617 additions & 0 deletions Augmentation.py

Large diffs are not rendered by default.

102 changes: 102 additions & 0 deletions batch_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import argparse
import os
from process_data.settings import TaskSettings
def parse_argument():
    """Parse and validate the command-line options of the batch-training script.

    Besides parsing, this resolves two derived paths on the returned namespace:

    * ``aug_dir`` -- for ``--data trec`` it is always set to
      ``DATA/TREC/generated/<label_name>`` (any ``--aug_dir`` given on the
      command line is overwritten); otherwise it defaults to
      ``DATA/<DATA>/generated`` when ``--aug_dir`` is omitted.
    * ``low_resource`` -- when the flag is given, the boolean is replaced by
      the path ``DATA/<DATA>/low_resource``.

    Returns:
        argparse.Namespace: the parsed (and path-resolved) arguments.

    Raises:
        AssertionError: if the TREC label name is not ``label-fine`` /
            ``label-coarse``, if mode ``aug`` is requested but ``aug_dir``
            is missing or holds no file whose name contains ``times``, or if
            the low-resource directory does not exist.
    """
    parser = argparse.ArgumentParser(description='download and parsing datasets')
    parser.add_argument('--data', type=str, required=True, help='data list')
    parser.add_argument('--aug_dir', help='Augmentation file directory')
    # type=int keeps command-line seeds consistent with the integer defaults.
    parser.add_argument('--seeds', default=[0, 1, 2, 3, 4], nargs='+', type=int,
                        help='seed list')
    parser.add_argument('--modes', nargs='+', required=True, help='training mode list')
    parser.add_argument('--label_name', type=str, default='label')
    parser.add_argument('--random_mix', type=str,
                        choices=['zero_one', 'zero', 'one', 'all'], help="random mixup ")
    parser.add_argument('--prefix', type=str,
                        help="only choosing the datasets with the prefix,for ablation study")
    parser.add_argument('--GPU', type=int, default=0, help="available GPU number")
    parser.add_argument('--low_resource', action='store_true',
                        help='whether to train low resource dataset')

    args = parser.parse_args()

    # Validation is done with explicit raises rather than ``assert`` statements
    # so that it still runs under ``python -O``; AssertionError is kept as the
    # exception type for backward compatibility.
    if args.data == 'trec':
        if args.label_name not in ('label-fine', 'label-coarse'):
            raise AssertionError(
                "If you want to train on TREC dataset with augmentation, you have to name the label of split either 'label-fine' or 'label-coarse'")
        args.aug_dir = os.path.join('DATA', args.data.upper(), 'generated', args.label_name)
    if args.aug_dir is None:
        args.aug_dir = os.path.join('DATA', args.data.upper(), 'generated')

    if 'aug' in args.modes:
        # A missing directory is reported with the same message as an empty
        # one instead of surfacing a bare FileNotFoundError from os.listdir.
        has_aug_file = os.path.isdir(args.aug_dir) and any(
            'times' in name for name in os.listdir(args.aug_dir))
        if not has_aug_file:
            raise AssertionError(
                'This directory has no augmentation file, please input correct aug_dir!')

    if args.low_resource:
        # From here on ``low_resource`` holds the dataset path, not a boolean.
        args.low_resource = os.path.join('DATA', args.data.upper(), 'low_resource')
        if not os.path.exists(args.low_resource):
            raise AssertionError("There is no any low resource datasets in this data")

    return args
def batch_train(args):
    """Launch ``run.py`` once per (seed, mode[, augmentation file]) combination.

    For mode ``raw`` a single run is started per seed, with ``--random_mix``
    appended when ``args.random_mix`` is set.  For every other mode one run is
    started per entry of ``args.aug_dir`` (restricted to names starting with
    ``args.prefix`` when a prefix is given), passing the entry as
    ``--data_path``.

    The per-dataset hyper-parameters are read from the module-level
    ``settings`` mapping, keyed by ``args.data``.  Commands are executed
    synchronously through ``os.system``.
    """
    raw_mix_cmd = 'CUDA_VISIBLE_DEVICES={} python run.py --label_name {} --mode {} --seed {} --data {} --random_mix {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '
    raw_cmd = 'CUDA_VISIBLE_DEVICES={} python run.py --label_name {} --mode {} --seed {} --data {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '
    aug_cmd = 'CUDA_VISIBLE_DEVICES={} python run.py --label_name {} --mode {} --seed {} --data {} --data_path {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '

    for seed in args.seeds:
        for mode in args.modes:
            if mode == 'raw':
                # Plain training on the original data, optionally with random mixup.
                if args.random_mix:
                    command = raw_mix_cmd.format(
                        args.GPU, args.label_name, mode, int(seed), args.data,
                        args.random_mix, **settings[args.data])
                else:
                    command = raw_cmd.format(
                        args.GPU, args.label_name, mode, int(seed), args.data,
                        **settings[args.data])
                os.system(command)
                continue

            # Any other mode trains once per augmentation file.
            for aug_file in os.listdir(args.aug_dir):
                if args.prefix and not aug_file.startswith(args.prefix):
                    # A prefix was requested and this file does not match it.
                    continue
                aug_file_path = os.path.join(args.aug_dir, aug_file)
                assert os.path.exists(aug_file_path)
                os.system(aug_cmd.format(
                    args.GPU, args.label_name, mode, int(seed), args.data,
                    aug_file_path, **settings[args.data]))
def low_resource_train(args):
    """Launch ``run.py`` for every low-resource split / seed directory.

    ``args.low_resource`` is walked two levels deep: each first-level entry is
    a partial split and each entry inside it is a seed directory whose name is
    expected to look like ``<something>_<int>`` (the integer after the first
    underscore is passed as ``--seed``).  ``args.output_dir`` is set to
    ``<args.low_resource_dir>/<split>`` (created if missing) before the runs
    of that split are started.

    Mode ``raw`` starts one run per seed directory (with ``--random_mix`` when
    ``args.random_mix`` is set); mode ``raw_aug`` starts one run per file in
    the seed directory whose name starts with ``times``.  Other modes start
    nothing.  Hyper-parameters come from the module-level ``settings``
    mapping, keyed by ``args.data``.
    """
    raw_mix_cmd = 'CUDA_VISIBLE_DEVICES={} python run.py --low_resource_dir {} --seed {} --output_dir {} --label_name {} --mode {} --data {} --random_mix {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '
    raw_cmd = 'CUDA_VISIBLE_DEVICES={} python run.py --low_resource_dir {} --seed {} --output_dir {} --label_name {} --mode {} --data {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '
    aug_cmd = 'CUDA_VISIBLE_DEVICES={} python run.py --low_resource_dir {} --seed {} --output_dir {} --label_name {} --mode {} --data {} --data_path {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '

    for split_name in os.listdir(args.low_resource):
        split_dir = os.path.join(args.low_resource, split_name)
        args.output_dir = os.path.join(args.low_resource_dir, split_name)
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)

        for seed_num in os.listdir(split_dir):
            seed_dir = os.path.join(split_dir, seed_num)
            for mode in args.modes:
                if mode == 'raw':
                    if args.random_mix:
                        command = raw_mix_cmd.format(
                            args.GPU, seed_dir, int(seed_num.split('_')[1]),
                            args.output_dir, args.label_name, mode, args.data,
                            args.random_mix, **settings[args.data])
                    else:
                        command = raw_cmd.format(
                            args.GPU, seed_dir, int(seed_num.split('_')[1]),
                            args.output_dir, args.label_name, mode, args.data,
                            **settings[args.data])
                    os.system(command)
                elif mode == 'raw_aug':
                    # Only files named ``times*`` are treated as augmentation data.
                    for entry in os.listdir(seed_dir):
                        if not entry.startswith('times'):
                            continue
                        aug_file_path = os.path.join(seed_dir, entry)
                        assert os.path.exists(aug_file_path)
                        os.system(aug_cmd.format(
                            args.GPU, seed_dir, int(seed_num.split('_')[1]),
                            args.output_dir, args.label_name, mode, args.data,
                            aug_file_path, **settings[args.data]))
if __name__ == '__main__':
    # Script entry point.  ``settings`` is deliberately a module-level name:
    # batch_train() and low_resource_train() read it as a global.
    args = parse_argument()
    tasksettings = TaskSettings()
    settings = tasksettings.train_settings
    if not args.low_resource:
        batch_train(args)
    else:
        # Results of low-resource runs are collected under DATA/<DATA>/runs/low_resource.
        args.low_resource_dir = os.path.join('DATA', args.data.upper(), 'runs', 'low_resource')
        if not os.path.exists(args.low_resource_dir):
            os.makedirs(args.low_resource_dir)
        low_resource_train(args)
131 changes: 131 additions & 0 deletions online_augmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import torch
import random
import numpy as np
def random_mixup_process(args, ids1, lam):
    """Replace a random span of every row with a span from a random partner row.

    Rows are paired through a random permutation.  A row's "length" is its
    number of ids that are not 101, 102 or 0 (presumably the BERT
    [CLS]/[SEP]/[PAD] ids -- confirm against the tokenizer in use), and those
    tokens are assumed to sit contiguously right after position 0.

    * ``args.difflen`` false: a span of ``int(length * lam)`` tokens is
      overwritten in place by an equally long span of the partner.  The span
      is clipped to the partner's length so that both slices always have the
      same size and stay inside the token region.
    * ``args.difflen`` true: the row's span is replaced by a span of
      ``int(partner_length * lam)`` tokens, so the row may grow or shrink; the
      result is truncated / zero-padded back to the original width.

    Args:
        args: namespace providing the boolean ``difflen``.
        ids1: 2-D integer tensor of token ids.  It is modified IN PLACE.
        lam: span-length ratio, expected in [0, 1]; spans are clamped to the
            valid range otherwise.

    Returns:
        tuple: ``(ids1, rand_index)`` -- the same tensor object that was
        passed in, and the permutation that paired the rows.
    """
    rand_index = torch.randperm(ids1.shape[0])
    partners = rand_index.tolist()

    lengths = [int(((row != 101) & (row != 0) & (row != 102)).sum()) for row in ids1]
    partner_lengths = [lengths[j] for j in partners]

    if args.difflen:
        spans = [max(0, min(int(n * lam), n)) for n in lengths]
        partner_spans = [max(0, min(int(m * lam), m)) for m in partner_lengths]
    else:
        # Same-size exchange: never ask the partner for more tokens than it has.
        spans = [max(0, min(int(n * lam), n, m))
                 for n, m in zip(lengths, partner_lengths)]
        partner_spans = spans

    # Position 0 is skipped (+1).  All own starts are drawn before all partner
    # starts, which keeps the consumption order of the ``random`` stream.
    begins = [1 + random.randint(0, n - s) for n, s in zip(lengths, spans)]
    partner_begins = [1 + random.randint(0, m - s)
                      for m, s in zip(partner_lengths, partner_spans)]

    # Partner spans are read from a snapshot so that rows already rewritten
    # earlier in the loop do not leak into later rows.
    ids2 = ids1.clone()
    for row, partner in enumerate(partners):
        start, span = begins[row], spans[row]
        p_start, p_span = partner_begins[row], partner_spans[row]
        donor = ids2[partner, p_start:p_start + p_span]
        if args.difflen:
            width = ids1.shape[1]
            merged = torch.cat(
                (ids1[row, :start], donor, ids1[row, start + span:]), dim=0)[:width]
            # new_zeros keeps the dtype/device of the ids (torch.zeros is float).
            padding = merged.new_zeros(width - merged.shape[0])
            ids1[row] = torch.cat((merged, padding), dim=0)
        else:
            ids1[row, start:start + span] = donor
    return ids1, rand_index
def mixup_01(args, input_ids, lam, idx1, idx2):
    '''
    Exchange spans between the two index groups (0/1 exchange).

    Every row listed in ``idx1`` gets a partner drawn (with replacement) from
    ``idx2`` and vice versa.  A span of ``int(length * lam)`` tokens of each
    row is then overwritten by an equally long span of its partner, where a
    row's "length" is its number of ids that are not 101, 102 or 0
    (presumably the BERT [CLS]/[SEP]/[PAD] ids -- confirm) and tokens are
    assumed to sit contiguously right after position 0.  The span is clipped
    to the partner's length so both slices always have the same size.

    ``idx1`` and ``idx2`` together are assumed to cover every row of
    ``input_ids`` exactly once -- TODO confirm with the caller.

    Args:
        args: unused; kept for a uniform signature with the other mixers.
        input_ids: 2-D integer tensor of token ids (not modified).
        lam: span-length ratio, expected in [0, 1].
        idx1, idx2: integer index tensors of the two groups.

    Returns:
        tuple: ``(new_ids, random_index)`` -- the mixed copy of ``input_ids``
        and, for every row, the index of the row it was mixed with.  If one
        of the groups is empty no exchange is possible and an unmodified copy
        with the identity index is returned.
    '''
    # reshape(-1) also accepts 0-d index tensors (e.g. produced by squeeze()).
    idx1 = idx1.reshape(-1)
    idx2 = idx2.reshape(-1)
    total = len(idx1) + len(idx2)
    if len(idx1) == 0 or len(idx2) == 0:
        return input_ids.clone(), torch.arange(total)

    random_index = torch.zeros(total).long()
    random_index[idx1] = torch.tensor(np.random.choice(idx2, size=len(idx1)))
    random_index[idx2] = torch.tensor(np.random.choice(idx1, size=len(idx2)))
    partners = random_index.tolist()

    # Number of real tokens in every sentence.
    lengths = [int(((row != 101) & (row != 0) & (row != 102)).sum())
               for row in input_ids]
    partner_lengths = [lengths[j] for j in partners]

    # Never ask the partner for more tokens than it has.
    spans = [max(0, min(int(n * lam), n, m))
             for n, m in zip(lengths, partner_lengths)]
    # Position 0 is skipped (+1); own starts are drawn before partner starts.
    begins = [1 + random.randint(0, n - s) for n, s in zip(lengths, spans)]
    partner_begins = [1 + random.randint(0, m - s)
                      for m, s in zip(partner_lengths, spans)]

    new_ids = input_ids.clone()
    for row, partner in enumerate(partners):
        start, p_start, span = begins[row], partner_begins[row], spans[row]
        new_ids[row, start:start + span] = input_ids[partner, p_start:p_start + span]
    return new_ids, random_index
def mixup(args,input_ids,lam,idx1,idx2=None):
'''
只针对idx1索引对应的sample内部进行交换,如果idx2也给的话就是idx1 idx2进行交换
'''
select_input_ids=torch.index_select(input_ids,0,idx1)
rand_index=torch.randperm(select_input_ids.shape[0])
new_idx=torch.tensor(list(range(input_ids.shape[0])))
len_list1=[]
len_list2=[]
for input_id1 in select_input_ids:
#calculte length of tokens in each sentence
mask=((input_id1!=101)&(input_id1!=0)&(input_id1!=102))
len_list1.append(int(mask.sum()))
len_list2=torch.tensor(len_list1)[rand_index]

spanlen=torch.tensor([int(x*lam) for x in len_list1])
beginlist=[1+random.randint(0,x-y) for x,y in zip(len_list1,spanlen)]
beginlist2=[1+random.randint(0,max(0,x-y)) for x,y in zip(len_list2,spanlen)]

spanlist=[(x,int(y)) for x,y in zip(beginlist,spanlen)]
spanlist2 = [(x, min(int(y),z)) for x, y, z in zip(beginlist2, spanlen, len_list2)]
new_ids=input_ids.clone()
new_idx[idx1]=idx1[rand_index]
for i,idx in enumerate(idx1):
new_ids[idx][spanlist[i][0]:spanlist[i][0]+spanlist[i][1]]=input_ids[idx1[rand_index[i]]][spanlist2[i][0]:spanlist2[i][0]+spanlist2[i][1]]

return new_ids,new_idx
def random_mixup(args, ids1, lab1, lam):
    """
    Dispatch to the span-exchange routine selected by ``args.random_mix``.

    function: random select span to exchange based on lam to decide span length and rand_index decide selected candidate exchange sentence
    input:
        args -- namespace providing ``random_mix``, one of
                'all'      : exchange among all samples
                'zero'     : exchange only among samples with label 0
                'one'      : exchange only among samples with label 1
                'zero_one' : exchange between label-1 and label-0 samples
        ids1 -- tensors of tensors input_ids
        lab1 -- tensor of labels (compared against 0 and 1)
        lam -- span length rate
    output:
        ids -- tensors of tensors , exchanged span
        rand_index -- tensors , for every sample the index of its exchange partner
        (for any other value of ``args.random_mix`` nothing is returned, i.e. None)
    """
    if args.random_mix == 'all':
        return mixup(args, ids1, lam, torch.arange(ids1.shape[0]))

    # Take the row coordinate of every match.  Unlike ``.squeeze()`` this
    # stays 1-D when exactly one sample matches, so the index can still be
    # iterated and measured with len() downstream.
    pos_idx = (lab1 == 1).nonzero()[:, 0]
    neg_idx = (lab1 == 0).nonzero()[:, 0]
    if args.random_mix == 'zero':
        return mixup(args, ids1, lam, neg_idx)
    if args.random_mix == 'one':
        return mixup(args, ids1, lam, pos_idx)
    if args.random_mix == 'zero_one':
        return mixup_01(args, ids1, lam, pos_idx, neg_idx)
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit cbb77b3

Please sign in to comment.