-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e83f80f
commit 01577d8
Showing
2 changed files
with
175 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
""" | ||
Custom data collator for multiple choice questions | ||
""" | ||
from typing import Optional, Union | ||
from dataclasses import dataclass | ||
import torch | ||
from transformers.tokenization_utils_base import ( | ||
PreTrainedTokenizerBase, | ||
PaddingStrategy | ||
) | ||
|
||
|
||
@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that dynamically pads the inputs for multiple choice.

    Each incoming feature holds ``num_choices`` tokenized sequences (one per
    candidate answer). The choices are flattened into one long batch, padded
    together by the tokenizer, then reshaped back to
    ``(batch_size, num_choices, seq_len)``.
    """

    # Tokenizer whose ``pad`` method performs the actual padding.
    tokenizer: PreTrainedTokenizerBase
    # Padding strategy forwarded verbatim to ``tokenizer.pad``.
    padding: Union[bool, str, PaddingStrategy] = True
    # Optional hard cap on the padded sequence length.
    max_length: Optional[int] = None
    # Pad lengths up to a multiple of this value (e.g. 8 for tensor cores).
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        """Collate a list of multiple-choice features into a padded batch.

        Args:
            features: list of dicts; each maps tokenizer keys (e.g.
                ``input_ids``) to a list of ``num_choices`` token sequences,
                plus a ``label``/``labels`` entry with the gold choice index.

        Returns:
            dict of tensors shaped ``(batch_size, num_choices, seq_len)``,
            plus ``labels`` of shape ``(batch_size,)`` (int64).
        """
        label_name = "label" if "label" in features[0] else "labels"
        labels = [feature.pop(label_name) for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])
        # Flatten (batch, choices) -> (batch * choices,) in one pass;
        # avoids the quadratic cost of sum(list_of_lists, []).
        flattened_features = [
            {k: v[i] for k, v in feature.items()}
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Un-flatten back to (batch_size, num_choices, seq_len).
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
""" | ||
Fine-tune BERT-based models | ||
""" | ||
from math import floor | ||
import argparse | ||
import numpy as np | ||
import evaluate | ||
from transformers import ( | ||
AutoTokenizer, | ||
AutoModelForMultipleChoice, | ||
TrainingArguments, | ||
Trainer, | ||
EarlyStoppingCallback | ||
) | ||
from custom_data_collator import DataCollatorForMultipleChoice | ||
from load_datasets import DatasetManager | ||
from custom_trainer_callback import CustomTrainerCallback | ||
|
||
|
||
# Command-line arguments.
# NOTE: the description previously said "T5-based", but this script loads
# AutoModelForMultipleChoice BERT-style checkpoints (see module docstring).
parser = argparse.ArgumentParser(description="Fine-tune BERT-based models on BrainTeaser")
parser.add_argument("--dataset", required=True)
parser.add_argument("--checkpoint", required=True)
parser.add_argument("--tokenizer")
parser.add_argument("--name", required=True)
parser.add_argument(
    "--log_steps",
    type=float,
    required=True,
    help="A float number in range [0,1] specifying a ratio of epochs"
)
parser.add_argument("--epochs", type=int, default=4)
parser.add_argument("--batch_size", type=int, default=2)
parser.add_argument("--accumulation_steps", type=int, default=2)
parser.add_argument("--learning_rate", type=float, default=5e-5)
parser.add_argument("--early_stopping_patience", type=int, default=10)
args = parser.parse_args()
# Default the tokenizer to the model checkpoint when not given explicitly.
args.tokenizer = args.checkpoint if args.tokenizer is None else args.tokenizer
# Validate via parser.error rather than assert: asserts are stripped under
# `python -O`, and parser.error prints usage plus a clean exit for the user.
if not 0 < args.log_steps <= 1:
    parser.error("Invalid value for log_steps")
|
||
|
||
# Process examples | ||
def preprocess(examples):
    """Tokenize each context paired with its four candidate answers.

    Builds two flat, lockstep lists (repeated contexts and their candidate
    answers), tokenizes them as sentence pairs, then regroups the flat
    output into chunks of ``n_choices`` per example.
    """
    n_choices = 4
    n_examples = len(examples['label'])

    # Repeat every context once per candidate answer.
    contexts = []
    for text in examples["text"]:
        contexts.extend([text] * n_choices)
    # Matching flat list of the candidate answers, in the same order.
    answers = []
    for idx in range(n_examples):
        for c in range(n_choices):
            answers.append(examples[f'choice{c}'][idx])

    tokenized = tokenizer(contexts, answers, truncation=True, max_length=512)
    # Regroup the flat tokenizer output back into n_choices-sized chunks.
    return {
        key: [values[i : i + n_choices] for i in range(0, len(values), n_choices)]
        for key, values in tokenized.items()
    }
|
||
|
||
# Load dataset: "primary|secondary" requests a combined training set,
# anything else names a single dataset.
dataset_manager = DatasetManager(ignore_case=False, force_4_choices=True, ds_format='bert')
if '|' not in args.dataset:
    dataset = dataset_manager.load_ds(args.dataset)
else:
    assert args.dataset.count('|') == 1, "Invalid number of datasets"
    primary_ds, secondary_ds = args.dataset.split('|')
    dataset = dataset_manager.load_combined_datasets(primary_ds, secondary_ds)

# Convert the log_steps ratio into a concrete step count derived from the
# number of optimizer steps per epoch, never dropping below one step.
effective_batch_size = args.batch_size * args.accumulation_steps
args.log_steps = max(floor(args.log_steps * len(dataset["train"]) / effective_batch_size), 1)

# Load the tokenizer and tokenize the whole dataset up front.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
dataset = dataset.map(preprocess, batched=True)
|
||
# Evaluation metric: plain accuracy over the predicted choice index.
accuracy = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    """Return accuracy of the argmax prediction against the gold labels."""
    logits, labels = eval_pred
    predicted = np.argmax(logits, axis=1)
    return accuracy.compute(predictions=predicted, references=labels)
|
||
|
||
# Load model and assemble training callbacks.
model = AutoModelForMultipleChoice.from_pretrained(args.checkpoint)
progress_callback = CustomTrainerCallback(vars(args))
early_stopping = EarlyStoppingCallback(early_stopping_patience=args.early_stopping_patience)
callbacks = [progress_callback, early_stopping]

# Training configuration.
training_args = TrainingArguments(
    # Checkpointing: keep the two most recent saves and reload the best
    # checkpoint (by eval accuracy) once training finishes.
    output_dir=args.name,
    logging_dir=f"{args.name}/logs",
    save_strategy="steps",
    save_steps=args.log_steps,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",

    # Logging
    logging_strategy="steps",
    logging_steps=args.log_steps,

    # Optimization
    learning_rate=args.learning_rate,
    num_train_epochs=args.epochs,
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    gradient_accumulation_steps=args.accumulation_steps,

    # Evaluation
    evaluation_strategy="steps",
    eval_steps=args.log_steps,
)

# Build the Trainer and run fine-tuning.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
    compute_metrics=compute_metrics,
    callbacks=callbacks
)

trainer.train()