-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathflan_t5_augment.py
113 lines (95 loc) · 5.67 KB
/
flan_t5_augment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import argparse
import csv
from tqdm import tqdm
from transformers import pipeline
from datasets import load_dataset
def fetcher(dataset):
    """Yield the ``"text"`` field of each example in *dataset*.

    Used to stream prompts into the HF pipeline one at a time instead of
    materializing the whole list in memory.
    """
    for example in dataset:
        yield example["text"]
def get_feedback(inp):
    """Replace ``inp["text"]`` with the filled-in feedback prompt.

    The untouched argument text is preserved under ``inp["argument"]``.
    Relies on the module-level ``prompts`` table built in the
    ``__main__`` section.
    """
    original = inp["text"]
    inp["argument"] = original
    inp["text"] = prompts["feedback"][0].format(text=original)
    return inp
def get_similar_quality_arg(inp):
    """Turn ``inp["text"]`` into a 'generate a similar argument' prompt.

    Keeps the original argument in ``inp["argument"]``; the prompt
    template comes from the module-level ``prompts`` table.
    """
    argument = inp["text"]
    inp["argument"] = argument
    inp["text"] = prompts["similar"][0].format(text=argument)
    return inp
def get_counter(inp):
    """Swap ``inp["text"]`` for a counter-argument prompt.

    The raw argument survives under ``inp["argument"]``. Depends on the
    module-level ``prompts`` mapping.
    """
    raw = inp["text"]
    inp["argument"] = raw
    inp["text"] = prompts["counter"][0].format(text=raw)
    return inp
def get_assumptions(inp):
    """Rewrite ``inp["text"]`` as an assumption-summarization prompt.

    Preserves the original text in ``inp["argument"]``; uses the
    module-level ``prompts`` table for the template.
    """
    source_text = inp["text"]
    inp["argument"] = source_text
    inp["text"] = prompts["assumption"][0].format(text=source_text)
    return inp
def sample_feedback(dataset):
    """Generate writing feedback for every argument and write it to CSV.

    Maps each example through the feedback prompt template, streams the
    prompts through the global ``pipe`` text2text pipeline, and writes
    one row per example to the CSV writer stored in
    ``prompts["feedback"][1]``.
    """
    dataset = dataset.map(get_feedback, batched=False)
    writer = prompts["feedback"][1]
    # NOTE(review): "cogeny" looks like a typo for "cogency", kept as-is
    # in case downstream consumers already key on this header.
    writer.writerow(["cogeny", "effectiveness", "reasonableness", "text", "title", "feedback"])
    for i, output in enumerate(tqdm(pipe(fetcher(dataset), max_new_tokens=512, batch_size=16))):
        text = output[0]["generated_text"]
        # Trim at the EOS marker only when it is present: str.find returns
        # -1 when absent, and text[:-1] would silently drop the final
        # character of every generation (the original bug).
        eos = text.find("</s>")
        if eos != -1:
            text = text[:eos]
        text = text.replace("\n", ".").strip()
        writer.writerow([dataset[i]["cogency_mean"], dataset[i]["effectiveness_mean"], dataset[i]["reasonableness_mean"], dataset[i]["argument"], dataset[i]["title"], text])
def sample_similar_instance(dataset):
    """Generate a similar-quality argument for every example, to CSV.

    Fix: the original mapped ``get_feedback`` here (copy-paste error),
    so the "similar" output file was generated from the feedback prompt.
    Now maps ``get_similar_quality_arg`` so the prompt matches the task.
    """
    dataset = dataset.map(get_similar_quality_arg, batched=False)
    writer = prompts["similar"][1]
    writer.writerow(["cogeny", "effectiveness", "reasonableness", "text", "title", "feedback"])
    for i, output in enumerate(tqdm(pipe(fetcher(dataset), max_new_tokens=512, batch_size=16))):
        text = output[0]["generated_text"]
        # Trim at EOS only when present; find() == -1 would otherwise
        # chop the last character via text[:-1].
        eos = text.find("</s>")
        if eos != -1:
            text = text[:eos]
        text = text.replace("\n", ".").strip()
        writer.writerow([dataset[i]["cogency_mean"], dataset[i]["effectiveness_mean"], dataset[i]["reasonableness_mean"], dataset[i]["argument"], dataset[i]["title"], text])
def sample_assumptions(dataset):
    """Summarize the assumptions of every argument and write them to CSV.

    Fix: the original mapped ``get_feedback`` here (copy-paste error),
    so the "assumptions" output file was generated from the feedback
    prompt. Now maps ``get_assumptions`` so the prompt matches the task.
    """
    dataset = dataset.map(get_assumptions, batched=False)
    writer = prompts["assumption"][1]
    writer.writerow(["cogeny", "effectiveness", "reasonableness", "text", "title", "feedback"])
    for i, output in enumerate(tqdm(pipe(fetcher(dataset), max_new_tokens=512, batch_size=16))):
        text = output[0]["generated_text"]
        # Trim at EOS only when present; find() == -1 would otherwise
        # chop the last character via text[:-1].
        eos = text.find("</s>")
        if eos != -1:
            text = text[:eos]
        text = text.replace("\n", ".").strip()
        writer.writerow([dataset[i]["cogency_mean"], dataset[i]["effectiveness_mean"], dataset[i]["reasonableness_mean"], dataset[i]["argument"], dataset[i]["title"], text])
def sample_counter_text(dataset):
    """Generate a counter-argument for every example and write it to CSV.

    Fix: the original mapped ``get_feedback`` here (copy-paste error),
    so the "counter" output file was generated from the feedback prompt.
    Now maps ``get_counter`` so the prompt matches the task.
    """
    dataset = dataset.map(get_counter, batched=False)
    writer = prompts["counter"][1]
    writer.writerow(["cogeny", "effectiveness", "reasonableness", "text", "title", "feedback"])
    for i, output in enumerate(tqdm(pipe(fetcher(dataset), max_new_tokens=512, batch_size=16))):
        text = output[0]["generated_text"]
        # Trim at EOS only when present; find() == -1 would otherwise
        # chop the last character via text[:-1].
        eos = text.find("</s>")
        if eos != -1:
            text = text[:eos]
        text = text.replace("\n", ".").strip()
        writer.writerow([dataset[i]["cogency_mean"], dataset[i]["effectiveness_mean"], dataset[i]["reasonableness_mean"], dataset[i]["argument"], dataset[i]["title"], text])
if __name__=="__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_file",
default="train_dataset.csv",
nargs="*",
required=True,
)
parser.add_argument(
"--output_dir",
default="augmented",
type=str,
required=True,
)
parser.add_argument("--add_similar", action='store_true')
args = parser.parse_args()
dataset = load_dataset("csv", data_files=args.input_file)["train"]
pipe = pipeline("text2text-generation", model="google/flan-t5-xl", device_map="auto")
csv_write_file_counter = f'{args.output_dir}/counter.csv'
csv_write_file_feedback = f'{args.output_dir}/feedback.csv'
csv_write_file_assumptions = f'{args.output_dir}/assumptions.csv'
csv_write_file_similar = f'{args.output_dir}/similar.csv'
prompts = {"feedback": ['''Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGive concise writing feedback for the following argument in context with the topic, preferably in bullet points\n{text}\n\n### Response:''', csv.writer(open(csv_write_file_feedback, "w+"))],
"similar": ['''Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGenerate a similar quality argument as the following argument:\n{text}\n\n### Response:''', csv.writer(open(csv_write_file_similar, "w+"))],
"counter": ['''Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGive a counter-argument for the following argument\n{text}\n\n### Response:''', csv.writer(open(csv_write_file_counter, "w+"))],
"assumption": ['''Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nSummarize the assumptions, if any, in the following argument in a bullet format"\n{text}\n\n### Response:''', csv.writer(open(csv_write_file_assumptions, "w+"))]}
sample_feedback(dataset)
sample_assumptions(dataset)
sample_counter_text(dataset)
if args.add_similar:
sample_similar_instance(dataset)