From 89b6d5b108a2df5b030640af9992cabd1338829f Mon Sep 17 00:00:00 2001 From: yangheng95 Date: Sat, 18 Mar 2023 18:19:15 +0000 Subject: [PATCH] 2.1.12 --- .../inference.py | 4 +- .../multitask_train.py | 21 ++++--- pyabsa/__init__.py | 2 +- pyabsa/tasks/ABSAInstruction/data_utils.py | 61 ++++++++++-------- pyabsa/tasks/ABSAInstruction/instruction.py | 16 ++--- pyabsa/tasks/ABSAInstruction/model.py | 63 ++++++++++++------- 6 files changed, 100 insertions(+), 67 deletions(-) diff --git a/examples-v2/aspect_opinion_sentiment_category_extraction/inference.py b/examples-v2/aspect_opinion_sentiment_category_extraction/inference.py index baee4442..82a589bd 100644 --- a/examples-v2/aspect_opinion_sentiment_category_extraction/inference.py +++ b/examples-v2/aspect_opinion_sentiment_category_extraction/inference.py @@ -10,7 +10,9 @@ from pyabsa import ABSAInstruction if __name__ == "__main__": - generator = ABSAInstruction.ABSAGenerator("multilingual") + generator = ABSAInstruction.ABSAGenerator( + "checkpoints/multitask/googleflan-t5-base-instruction/checkpoint-2745" + ) example = [ "The food is good, but the service is bad.", "The laptop is good, but the battery life is bad.", diff --git a/examples-v2/aspect_opinion_sentiment_category_extraction/multitask_train.py b/examples-v2/aspect_opinion_sentiment_category_extraction/multitask_train.py index 6863a55c..6874a196 100644 --- a/examples-v2/aspect_opinion_sentiment_category_extraction/multitask_train.py +++ b/examples-v2/aspect_opinion_sentiment_category_extraction/multitask_train.py @@ -11,6 +11,7 @@ import findfile from pyabsa import ABSAInstruction as absa_instruction + warnings.filterwarnings("ignore") import pandas as pd @@ -18,10 +19,10 @@ task_name = "multitask" experiment_name = "instruction" # model_checkpoint = 'allenai/tk-instruct-base-def-pos' -model_checkpoint = "kevinscaria/ate_tk-instruct-base-def-pos-neg-neut-combined" +# model_checkpoint = "kevinscaria/ate_tk-instruct-base-def-pos-neg-neut-combined" # model_checkpoint = 'allenai/tk-instruct-large-def-pos' # model_checkpoint = 'allenai/tk-instruct-3b-def-pos' -# model_checkpoint = 'google/mt5-base' +model_checkpoint = "google/flan-t5-base" print("Experiment Name: ", experiment_name) model_out_path = "checkpoints" @@ -33,12 +34,12 @@ # Load the data # id_train_file_path = './integrated_datasets' # id_test_file_path = './integrated_datasets' -# id_train_file_path = "./integrated_datasets/acos_datasets/" -# id_test_file_path = "./integrated_datasets/acos_datasets" -id_train_file_path = './integrated_datasets/acos_datasets/501.Laptop14' -id_test_file_path = './integrated_datasets/acos_datasets/501.Laptop14' -# id_train_file_path = './integrated_datasets/acos_datasets/504.Restaurant16' -# id_test_file_path = './integrated_datasets/acos_datasets/504.Restaurant16' +id_train_file_path = "./integrated_datasets/acos_datasets/" +id_test_file_path = "./integrated_datasets/acos_datasets" +# id_train_file_path = './integrated_datasets/acos_datasets/501.Laptop14' +# id_test_file_path = './integrated_datasets/acos_datasets/501.Laptop14' +# id_train_file_path = './integrated_datasets/acos_datasets/502.Restaurant14' +# id_test_file_path = './integrated_datasets/acos_datasets/502.Restaurant14' id_tr_df = absa_instruction.data_utils.read_json(id_train_file_path, "train") @@ -72,9 +73,9 @@ "evaluation_strategy": "epoch", "save_strategy": "epoch", "learning_rate": 5e-5, - "per_device_train_batch_size": 4, + "per_device_train_batch_size": 16, "per_device_eval_batch_size": 16, - "num_train_epochs": 6, + "num_train_epochs": 3, "weight_decay": 0.01, "warmup_ratio": 0.1, "load_best_model_at_end": True, diff --git a/pyabsa/__init__.py b/pyabsa/__init__.py index 266f111f..222595d0 100644 --- a/pyabsa/__init__.py +++ b/pyabsa/__init__.py @@ -7,7 +7,7 @@ # Copyright (C) 2021. All Rights Reserved. __name__ = "pyabsa" -__version__ = "2.1.11" +__version__ = "2.1.12" from pyabsa.framework.flag_class import * diff --git a/pyabsa/tasks/ABSAInstruction/data_utils.py b/pyabsa/tasks/ABSAInstruction/data_utils.py index 429a4024..da4040ac 100644 --- a/pyabsa/tasks/ABSAInstruction/data_utils.py +++ b/pyabsa/tasks/ABSAInstruction/data_utils.py @@ -48,35 +48,43 @@ def prepare_instruction_dataloader(self, df): cat_instructor = CategoryInstruction() alldata = [] for i, data in df.iterrows(): - _aspects = [label["aspect"] for label in data["labels"]] + _aspects = ["aspect:" + label["aspect"] for label in data["labels"]] aspects = [] for asp in _aspects: if asp.strip() not in aspects: aspects.append(asp.strip()) - aspects = ", ".join(aspects) - alldata.append( - {"text": ate_instructor.prepare_input(data["text"]), "labels": aspects} - ) + aspects = "|".join(aspects) - opinions = ", ".join( + polarities = [] + _polarities = [ + "{}:{}".format(label["aspect"], label["polarity"]) + for label in data["labels"] + ] + for pol in _polarities: + if pol not in polarities: + polarities.append(pol) + polarities = "|".join(polarities) + + opinions = "|".join( [ "{}:{}".format(label["aspect"], label["opinion"]) for label in data["labels"] ] ) - alldata.append( - { - "text": op_instructor.prepare_input(data["text"], aspects), - "labels": opinions, - } - ) - polarities = ", ".join( + categories = "|".join( [ - "{}:{}".format(label["aspect"], label["polarity"]) + "{}:{}".format(label["aspect"], label["category"]) for label in data["labels"] ] ) + + # ATE task + alldata.append( + {"text": ate_instructor.prepare_input(data["text"]), "labels": aspects} + ) + + # APC task alldata.append( { "text": apc_instructor.prepare_input(data["text"], aspects), @@ -84,21 +92,23 @@ def prepare_instruction_dataloader(self, df): } ) - categories = ", ".join( - [ - "{}:{}".format( - label["aspect"], label["category"].replace("NULL", "") - ) - for label in data["labels"] - ] - ) + # Opinion task alldata.append( { - "text": cat_instructor.prepare_input(data["text"], aspects), - "labels": categories, + "text": op_instructor.prepare_input(data["text"], aspects), + "labels": opinions, } ) - # print(alldata[-1]['labels']) + + # Category task + if "NULL" not in categories: + alldata.append( + { + "text": cat_instructor.prepare_input(data["text"], aspects), + "labels": categories, + } + ) + alldata = pd.DataFrame(alldata) return alldata @@ -163,6 +173,7 @@ def read_json(data_path, data_type="train"): files = findfile.find_files(data_path, [data_type, ".jsonl"], exclude_key=[".txt"]) for f in files: + print(f) with open(f, "r", encoding="utf8") as fin: for line in fin: data.append(json.loads(line)) diff --git a/pyabsa/tasks/ABSAInstruction/instruction.py b/pyabsa/tasks/ABSAInstruction/instruction.py index 00374589..c30c4305 100644 --- a/pyabsa/tasks/ABSAInstruction/instruction.py +++ b/pyabsa/tasks/ABSAInstruction/instruction.py @@ -31,12 +31,12 @@ def __init__(self, bos_instruction=None, eos_instruction=None): example 1- input: I charge it at night and skip taking the cord with me because of the good battery life. {self.eos_instruction} -battery life, cord +aspect:battery life|aspect:cord example 2- input: Great food, good size menu, great service and an unpretensious setting. {self.eos_instruction} -food, menu, service, setting +aspect:food|aspect:menu|aspect:service|aspect:setting Now extract aspects from the following example: input: """ @@ -64,13 +64,13 @@ def __init__(self, bos_instruction=None, eos_instruction=None): input: I charge it at night and skip taking the cord with me because of the good battery life. The aspects are: battery life, cord {self.eos_instruction} -battery life:positive, cord:positive +battery life:positive|cord:positive example 2- input: Great food, good size menu, great service and an unpretensious setting. The aspects are: food, menu, service, setting {self.eos_instruction} -food:positive, menu:positive, service:positive, setting:positive +food:positive|menu:positive|service:positive|setting:positive Now predict aspect sentiments from the following example: @@ -103,13 +103,13 @@ def __init__(self, bos_instruction=None, eos_instruction=None): input: I charge it at night and skip taking the cord with me because of the good battery life. The aspects are: battery life, cord {self.eos_instruction} -battery life:good, cord:NULL +battery life:good|cord:NULL example 2- input: Great food, good size menu, great service and an unpretensious setting. The aspects are: food, menu, service, setting {self.eos_instruction} -food:great, menu:good, service:great, setting:unpretensious +food:great|menu:good|service:great|setting:unpretensious Now extract opinions for the following example: input:""" @@ -141,11 +141,11 @@ def __init__(self, bos_instruction=None, eos_instruction=None): input: I charge it at night and skip taking the cord with me because of the good battery life. The aspects are: battery life, cord {self.eos_instruction} -battery life:POWER_SUPPLY#GENERAL, cord:NULL +battery life:POWER_SUPPLY#GENERAL|cord:NULL example 2- input: Great food, good size menu, great service and an unpretensious setting. -The aspects are: food, menu, service, setting +The aspects are: food:FOOD#QUALITY| menu:RESTAURANT#GENERAL|service:SERVICE#GENERAL|setting:SERVICE#GENERAL {self.eos_instruction} food:FOOD#QUALITY, menu:RESTAURANT#GENERAL, service:SERVICE#GENERAL, setting:SERVICE#GENERAL diff --git a/pyabsa/tasks/ABSAInstruction/model.py b/pyabsa/tasks/ABSAInstruction/model.py index 8989e95a..1650caeb 100644 --- a/pyabsa/tasks/ABSAInstruction/model.py +++ b/pyabsa/tasks/ABSAInstruction/model.py @@ -1,4 +1,5 @@ import autocuda +import sklearn import torch from pyabsa.framework.checkpoint_class.checkpoint_template import CheckpointManager from torch.utils.data import DataLoader @@ -32,6 +33,7 @@ def __init__(self, checkpoint): self.tokenizer = AutoTokenizer.from_pretrained(checkpoint) self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint) + self.model.config.max_length = 128 self.data_collator = DataCollatorForSeq2Seq(self.tokenizer) self.device = autocuda.auto_cuda() self.model.to(self.device) @@ -94,7 +96,7 @@ def predict(self, text, **kwargs): ate_outputs = self.tokenizer.batch_decode( ate_outputs, skip_special_tokens=True )[0] - result["aspect"] = [asp.strip() for asp in ate_outputs.split(",")] + result["aspect"] = [asp.strip() for asp in ate_outputs.split("|")] # APC inference inputs = self.tokenizer( @@ -106,7 +108,7 @@ def predict(self, text, **kwargs): apc_outputs = self.tokenizer.batch_decode( apc_outputs, skip_special_tokens=True )[0] - result["sentiment"] = [sent.strip() for sent in apc_outputs.split(",")] + result["sentiment"] = [sent.strip() for sent in apc_outputs.split("|")] # Opinion inference inputs = self.tokenizer( @@ -118,7 +120,7 @@ def predict(self, text, **kwargs): op_outputs = self.tokenizer.batch_decode(op_outputs, skip_special_tokens=True)[ 0 ] - result["opinion"] = [op.strip() for op in op_outputs.split(",")] + result["opinion"] = [op.strip() for op in op_outputs.split("|")] # Category inference inputs = self.tokenizer( @@ -130,7 +132,7 @@ def predict(self, text, **kwargs): cat_outputs = self.tokenizer.batch_decode( cat_outputs, skip_special_tokens=True )[0] - result["category"] = [cat.strip() for cat in cat_outputs.split(",")] + result["category"] = [cat.strip() for cat in cat_outputs.split("|")] ensemble_result = { "text": text, "Quadruples": [ @@ -207,26 +209,43 @@ def get_aspect_metrics(self, true_aspects, pred_aspects): return aspect_p, aspect_r, aspect_f1 def get_classic_metrics(self, y_true, y_pred): - total_pred = 0 - total_gt = 0 - tp = 1e-6 + valid_gts = [] + valid_preds = [] for gt, pred in zip(y_true, y_pred): - print(gt) - print(pred) - - gt_list = gt.split(", ") - pred_list = pred.split(", ") - total_pred += len(pred_list) - total_gt += len(gt_list) - for gt_val in gt_list: + gt_list = gt.split("|") + pred_list = pred.split("|") + while gt_list: + gt_val = gt_list[-1].strip().lower() for pred_val in pred_list: - gt_val = gt_val.replace(" ", "") - pred_val = pred_val.replace(" ", "") - if pred_val.strip().lower() == gt_val.strip().lower(): - tp += 1 - p = tp / total_pred - r = tp / total_gt - return {"precision": p, "recall": r, "f1": 2 * p * r / (p + r)} + pred_val = pred_val.strip().lower() + gt_key, _, gt_label = gt_val.partition(":") + pred_key, _, pred_label = pred_val.partition(":") + if gt_key.startswith(pred_key): + if gt_label: + valid_gts.append(gt_label) + else: + break + if pred_label: + valid_preds.append(pred_label) + else: + valid_preds.append("") + break + + gt_list.pop() + + report = sklearn.metrics.classification_report(valid_gts, valid_preds) + print(report) + accuracy = sklearn.metrics.accuracy_score(valid_gts, valid_preds) + precision = precision_score(valid_gts, valid_preds, average="macro") + recall = recall_score(valid_gts, valid_preds, average="macro") + f1 = f1_score(valid_gts, valid_preds, average="macro") + + return { + "accuracy": accuracy, + "precision": precision, + "recall": recall, + "f1": f1, + } # def get_classic_metrics(self, y_true, y_pred): #