-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4545426
commit 1dd90fc
Showing
81 changed files
with
321 additions
and
1,752 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,44 @@ | ||
# GAAA | ||
# A3V | ||
data_preprocessing configs are defined in config/preprocessing | ||
train configs are defined in config/training | ||
the valid process is contained in each train epoch | ||
|
||
all the data preprocessing file in scripts/preprocessing: | ||
ESIM: | ||
python preprocess_quora.py | ||
python preprocess_snli.py | ||
python preprocess_mnli.py | ||
BERT: | ||
python preprocess_quora_bert.py | ||
python preprocess_snli_bert.py | ||
python preprocess_mnli_bert.py | ||
|
||
Stage One:pre-train model A | ||
ESIM: | ||
python esim_quora.py | ||
python esim_snli.py | ||
python esim_mnli.py | ||
BERT: | ||
python bert_quora.py | ||
python bert_snli.py | ||
python bert_mnli.py | ||
|
||
Stage Two:fine-tuning model B | ||
ESIM: | ||
python top_esim_quora.py | ||
python top_esim_snli.py | ||
python top_esim_mnli.py | ||
BERT: | ||
python top_bert_quora.py | ||
python top_bert_snli.py | ||
python top_bert_mnli.py | ||
|
||
To get Kaggle Open Evaluation submission file: | ||
python esim_mnli_test.py | ||
python top_bert_mnli_test.py | ||
|
||
|
||
|
||
|
||
|
||
|
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
""" | ||
Train the ESIM model on the preprocessed SNLI dataset. | ||
""" | ||
# Aurelien Coet, 2018. | ||
|
||
from utils.runned.utils_test_three import validate | ||
from a3v.model_transformer import ESIM | ||
# from a3v.model_bert_transformer import ESIM | ||
import os | ||
import argparse | ||
import json | ||
import numpy as np | ||
import pickle | ||
import torch | ||
import matplotlib | ||
matplotlib.use('Agg') | ||
|
||
|
||
def transform_batch_data(data, batch_size=64, shuffle=True): | ||
data_batch = dict() | ||
data_batch['premises'] = dict() | ||
data_batch['hypotheses'] = dict() | ||
data_batch['labels'] = dict() | ||
index = np.arange(len(data['labels'])) | ||
if shuffle: | ||
np.random.shuffle(index) | ||
|
||
idx = -1 | ||
for i in range(len(index)): | ||
if i % batch_size == 0: | ||
idx += 1 | ||
data_batch['premises'][idx] = [] | ||
data_batch['hypotheses'][idx] = [] | ||
data_batch['labels'][idx] = [] | ||
data_batch['premises'][idx].append(data['premises'][index[i]]) | ||
data_batch['hypotheses'][idx].append(data['hypotheses'][index[i]]) | ||
data_batch['labels'][idx].append(int(data['labels'][index[i]])) | ||
return data_batch | ||
|
||
|
||
def main(train_file, | ||
valid_file, | ||
test_file, | ||
target_dir, | ||
embedding_size=512, | ||
hidden_size=512, | ||
dropout=0.5, | ||
num_classes=3, | ||
epochs=64, | ||
batch_size=32, | ||
lr=0.0004, | ||
patience=5, | ||
max_grad_norm=10.0, | ||
checkpoint=None): | ||
""" | ||
Train the ESIM model on the Quora dataset. | ||
Args: | ||
train_file: A path to some preprocessed data that must be used | ||
to train the model. | ||
valid_file: A path to some preprocessed data that must be used | ||
to validate the model. | ||
embeddings_file: A path to some preprocessed word embeddings that | ||
must be used to initialise the model. | ||
target_dir: The path to a directory where the trained model must | ||
be saved. | ||
hidden_size: The size of the hidden layers in the model. Defaults | ||
to 300. | ||
dropout: The dropout rate to use in the model. Defaults to 0.5. | ||
num_classes: The number of classes in the output of the model. | ||
Defaults to 3. | ||
epochs: The maximum number of epochs for training. Defaults to 64. | ||
batch_size: The size of the batches for training. Defaults to 32. | ||
lr: The learning rate for the optimizer. Defaults to 0.0004. | ||
patience: The patience to use for early stopping. Defaults to 5. | ||
checkpoint: A checkpoint from which to continue training. If None, | ||
training starts from scratch. Defaults to None. | ||
""" | ||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
print(20 * "=", " Preparing for training ", 20 * "=") | ||
|
||
if not os.path.exists(target_dir): | ||
os.makedirs(target_dir) | ||
|
||
# -------------------- Data loading ------------------- # | ||
print("\t* Loading training data...") | ||
with open(train_file, "rb") as pkl: | ||
train_data = pickle.load(pkl) | ||
|
||
print("\t* Loading validation data...") | ||
with open(valid_file, "rb") as pkl: | ||
valid_data = pickle.load(pkl) | ||
valid_dataloader = transform_batch_data(valid_data, batch_size=batch_size, shuffle=False) | ||
|
||
print("\t* Loading test data...") | ||
with open(test_file, "rb") as pkl: | ||
test_data = pickle.load(pkl) | ||
test_dataloader = transform_batch_data(test_data, batch_size=batch_size, shuffle=False) | ||
|
||
# -------------------- Model definition ------------------- # | ||
print("\t* Building model...") | ||
|
||
model = ESIM(embedding_size, | ||
hidden_size, | ||
dropout=dropout, | ||
num_classes=num_classes, | ||
device=device).to(device) | ||
|
||
# -------------------- Preparation for training ------------------- # | ||
|
||
# Continuing training from a checkpoint if one was given as argument. | ||
if checkpoint: | ||
checkpoint = torch.load(checkpoint) | ||
start_epoch = checkpoint["epoch"] + 1 | ||
|
||
print("\t* Training will continue on existing model from epoch {}..." | ||
.format(start_epoch)) | ||
|
||
model.load_state_dict(checkpoint["model"]) | ||
|
||
# Compute loss and accuracy before starting (or resuming) training. | ||
_, valid_accuracy = validate(model, test_dataloader) | ||
print("\t* Validation accuracy: {:.4f}%".format(valid_accuracy*100)) | ||
|
||
# _, test_loss, test_accuracy = validate(model, | ||
# test_dataloader, | ||
# criterion) | ||
# print("\t* test loss before training: {:.4f}, accuracy: {:.4f}%" | ||
# .format(test_loss, (test_accuracy*100))) | ||
|
||
|
||
|
||
if __name__ == "__main__": | ||
default_config = "../../config/training/mnli_training_bert.json" | ||
|
||
parser = argparse.ArgumentParser( | ||
description="Train the ESIM model on quora") | ||
parser.add_argument("--config", | ||
default=default_config, | ||
help="Path to a json configuration file") | ||
|
||
script_dir = os.path.dirname(os.path.realpath(__file__)) | ||
script_dir = script_dir + '/scripts/training' | ||
|
||
parser.add_argument("--checkpoint", | ||
default=os.path.dirname(os.path.realpath(__file__)) + '/data/checkpoints/MNLI/bert/' +"esim_{}.pth.tar".format(12), | ||
help="Path to a checkpoint file to resume training") | ||
args = parser.parse_args() | ||
|
||
if args.config == default_config: | ||
config_path = os.path.join(script_dir, args.config) | ||
else: | ||
config_path = args.config | ||
|
||
with open(os.path.normpath(config_path), 'r') as config_file: | ||
config = json.load(config_file) | ||
|
||
main(os.path.normpath(os.path.join(script_dir, config["train_data"])), | ||
os.path.normpath(os.path.join(script_dir, config["valid_data_matched"])), | ||
os.path.normpath(os.path.join(script_dir, config["valid_data_mismatched"])), | ||
os.path.normpath(os.path.join(script_dir, config["target_dir"])), | ||
config["embedding_size"], | ||
config["hidden_size"], | ||
0,#config["dropout"], | ||
config["num_classes"], | ||
config["epochs"], | ||
config["batch_size"], | ||
config["lr"], | ||
config["patience"], | ||
config["max_gradient_norm"], | ||
args.checkpoint) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.