Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Mrzhouqifei committed Nov 21, 2019
1 parent 4545426 commit 1dd90fc
Show file tree
Hide file tree
Showing 81 changed files with 321 additions and 1,752 deletions.
45 changes: 44 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,44 @@
# GAAA
# A3V
Data preprocessing configs are defined in config/preprocessing.
Training configs are defined in config/training.
Validation is performed within each training epoch.

All data preprocessing scripts are in scripts/preprocessing:
ESIM:
python preprocess_quora.py
python preprocess_snli.py
python preprocess_mnli.py
BERT:
python preprocess_quora_bert.py
python preprocess_snli_bert.py
python preprocess_mnli_bert.py

Stage One: pre-train model A
ESIM:
python esim_quora.py
python esim_snli.py
python esim_mnli.py
BERT:
python bert_quora.py
python bert_snli.py
python bert_mnli.py

Stage Two: fine-tune model B
ESIM:
python top_esim_quora.py
python top_esim_snli.py
python top_esim_mnli.py
BERT:
python top_bert_quora.py
python top_bert_snli.py
python top_bert_mnli.py

To generate the Kaggle Open Evaluation submission file:
python esim_mnli_test.py
python top_bert_mnli_test.py






File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions mfae/droped/droped.py → a3v/droped/droped.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import torch
import torch.nn as nn
from mfae.layers import RNNDropout, Seq2SeqEncoder, SoftmaxAttention
from mfae.utils import replace_masked
from a3v.layers import RNNDropout, Seq2SeqEncoder, SoftmaxAttention
from a3v.utils import replace_masked
import math
from torch.nn.modules.transformer import *

Expand Down
2 changes: 1 addition & 1 deletion mfae/droped/layers.py → a3v/droped/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.activation import MultiheadAttention
from mfae.utils import sort_by_seq_lens, masked_softmax, weighted_sum, normal_softmax
from a3v.utils import sort_by_seq_lens, masked_softmax, weighted_sum, normal_softmax

# Class widely inspired from:
# https://github.com/allenai/allennlp/blob/master/allennlp/modules/input_variational_dropout.py
Expand Down
4 changes: 2 additions & 2 deletions mfae/droped/model_new.py → a3v/droped/model_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import torch
import torch.nn as nn
from mfae.layers import RNNDropout, Seq2SeqEncoder, SoftmaxAttention, LinerEncoder
from mfae.utils import get_mask, replace_masked
from a3v.layers import RNNDropout, Seq2SeqEncoder, SoftmaxAttention, LinerEncoder
from a3v.utils import get_mask, replace_masked
# from allennlp.modules.elmo import Elmo, batch_to_ids

class ESIM(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions mfae/droped/model_top.py → a3v/droped/model_top.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import torch
import torch.nn as nn
from mfae.layers import RNNDropout, Seq2SeqEncoder, SoftmaxAttention, LinerEncoder
from mfae.utils import get_mask, replace_masked
from a3v.layers import RNNDropout, Seq2SeqEncoder, SoftmaxAttention, LinerEncoder
from a3v.utils import get_mask, replace_masked
# from allennlp.modules.elmo import Elmo, batch_to_ids

class TOP(nn.Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import torch
import torch.nn as nn
from mfae.layers import RNNDropout, Seq2SeqEncoderLast, SoftmaxAttention, LengthEncoder
from mfae.utils import replace_masked
from a3v.layers import RNNDropout, Seq2SeqEncoderLast, SoftmaxAttention, LengthEncoder
from a3v.utils import replace_masked
from torch.nn.modules import TransformerEncoder, TransformerEncoderLayer, LayerNorm
import math

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
172 changes: 172 additions & 0 deletions alreadrun/test_bert_mnli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""
Evaluate a saved ESIM model checkpoint on the preprocessed MNLI dataset.
"""
# Aurelien Coet, 2018.

from utils.runned.utils_test_three import validate
from a3v.model_transformer import ESIM
# from a3v.model_bert_transformer import ESIM
import os
import argparse
import json
import numpy as np
import pickle
import torch
import matplotlib
matplotlib.use('Agg')


def transform_batch_data(data, batch_size=64, shuffle=True):
    """Group a preprocessed dataset into consecutive mini-batches.

    Args:
        data: Mapping with the keys 'premises', 'hypotheses' and 'labels'.
            Each value must be indexable by position and all three must
            have at least ``len(data['labels'])`` entries.
        batch_size: Maximum number of examples per batch; the last batch
            may be smaller. Must be a positive integer.
        shuffle: If True, the example order is permuted in place with
            ``np.random.shuffle`` before batching.

    Returns:
        A dict with the keys 'premises', 'hypotheses' and 'labels'. Each
        value is a dict mapping a 0-based batch index to the list of items
        in that batch. Labels are converted with ``int()``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(
            "batch_size must be a positive integer, got {}".format(batch_size))

    order = np.arange(len(data['labels']))
    if shuffle:
        np.random.shuffle(order)

    data_batch = {'premises': {}, 'hypotheses': {}, 'labels': {}}
    # Walk the (possibly shuffled) order in strides of batch_size; the same
    # slice of indices is used for all three fields so examples stay aligned.
    for batch_idx, start in enumerate(range(0, len(order), batch_size)):
        chunk = order[start:start + batch_size]
        data_batch['premises'][batch_idx] = [data['premises'][j] for j in chunk]
        data_batch['hypotheses'][batch_idx] = [data['hypotheses'][j] for j in chunk]
        data_batch['labels'][batch_idx] = [int(data['labels'][j]) for j in chunk]
    return data_batch


def main(train_file,
         valid_file,
         test_file,
         target_dir,
         embedding_size=512,
         hidden_size=512,
         dropout=0.5,
         num_classes=3,
         epochs=64,
         batch_size=32,
         lr=0.0004,
         patience=5,
         max_grad_norm=10.0,
         checkpoint=None):
    """
    Evaluate an ESIM model, optionally restored from a checkpoint.

    Despite the training-style signature, this function does not train:
    it loads the pickled datasets, builds the model, restores checkpoint
    weights when a checkpoint is given, and runs `validate` once on the
    batches built from `test_file`, printing the resulting accuracy.

    Args:
        train_file: A path to pickled preprocessed training data. It is
            loaded but not otherwise used by this function.
        valid_file: A path to pickled preprocessed validation data. It is
            batched but not evaluated by this function.
        test_file: A path to pickled preprocessed data; the model is
            evaluated on the batches built from this file.
        target_dir: A directory path; it is created if it does not exist.
        embedding_size: First positional argument passed to the ESIM
            constructor. Defaults to 512.
        hidden_size: Second positional argument passed to the ESIM
            constructor. Defaults to 512.
        dropout: The dropout rate passed to the model. Defaults to 0.5.
        num_classes: The number of classes in the output of the model.
            Defaults to 3.
        epochs: Unused here; kept for signature compatibility.
        batch_size: The size of the evaluation batches. Defaults to 32.
        lr: Unused here; kept for signature compatibility.
        patience: Unused here; kept for signature compatibility.
        max_grad_norm: Unused here; kept for signature compatibility.
        checkpoint: A path to a checkpoint file holding the keys "epoch"
            and "model". If None (or empty), the freshly initialised model
            is evaluated. Defaults to None.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(20 * "=", " Preparing for training ", 20 * "=")

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    # -------------------- Data loading ------------------- #
    # NOTE: pickle.load can execute arbitrary code; only point these paths
    # at trusted, locally produced files.
    print("\t* Loading training data...")
    with open(train_file, "rb") as pkl:
        # NOTE(review): train_data is never used below; the load is kept so
        # the script's I/O behaviour is unchanged.
        train_data = pickle.load(pkl)

    print("\t* Loading validation data...")
    with open(valid_file, "rb") as pkl:
        valid_data = pickle.load(pkl)
    # NOTE(review): valid_dataloader is built but not evaluated below.
    valid_dataloader = transform_batch_data(valid_data,
                                            batch_size=batch_size,
                                            shuffle=False)

    print("\t* Loading test data...")
    with open(test_file, "rb") as pkl:
        test_data = pickle.load(pkl)
    test_dataloader = transform_batch_data(test_data,
                                           batch_size=batch_size,
                                           shuffle=False)

    # -------------------- Model definition ------------------- #
    print("\t* Building model...")

    model = ESIM(embedding_size,
                 hidden_size,
                 dropout=dropout,
                 num_classes=num_classes,
                 device=device).to(device)

    # -------------------- Preparation for evaluation ------------------- #

    # Restore the model weights from a checkpoint if one was given.
    if checkpoint:
        # map_location remaps the stored tensors onto the device selected
        # above, so a checkpoint saved on GPU also loads on a CPU-only host.
        checkpoint = torch.load(checkpoint, map_location=device)
        start_epoch = checkpoint["epoch"] + 1

        print("\t* Training will continue on existing model from epoch {}..."
              .format(start_epoch))

        model.load_state_dict(checkpoint["model"])

    # Evaluate once on the batches built from test_file.
    _, valid_accuracy = validate(model, test_dataloader)
    print("\t* Validation accuracy: {:.4f}%".format(valid_accuracy*100))

# _, test_loss, test_accuracy = validate(model,
# test_dataloader,
# criterion)
# print("\t* test loss before training: {:.4f}, accuracy: {:.4f}%"
# .format(test_loss, (test_accuracy*100)))



if __name__ == "__main__":
    # Script entry point: read a JSON config, resolve the data paths it
    # names relative to the script directory, and hand everything to main().
    default_config = "../../config/training/mnli_training_bert.json"

    here = os.path.dirname(os.path.realpath(__file__))
    script_dir = here + '/scripts/training'
    default_checkpoint = (here + '/data/checkpoints/MNLI/bert/'
                          + "esim_{}.pth.tar".format(12))

    parser = argparse.ArgumentParser(
        description="Train the ESIM model on quora")
    parser.add_argument("--config",
                        default=default_config,
                        help="Path to a json configuration file")
    parser.add_argument("--checkpoint",
                        default=default_checkpoint,
                        help="Path to a checkpoint file to resume training")
    args = parser.parse_args()

    # The default config path is relative to script_dir; a user-supplied
    # path is used exactly as given.
    config_path = (os.path.join(script_dir, args.config)
                   if args.config == default_config
                   else args.config)

    with open(os.path.normpath(config_path), 'r') as config_file:
        config = json.load(config_file)

    def _resolve(key):
        """Return the normalised path stored under `key`, relative to script_dir."""
        return os.path.normpath(os.path.join(script_dir, config[key]))

    main(_resolve("train_data"),
         _resolve("valid_data_matched"),
         _resolve("valid_data_mismatched"),
         _resolve("target_dir"),
         config["embedding_size"],
         config["hidden_size"],
         0,  # dropout is forced to 0 here instead of config["dropout"]
         config["num_classes"],
         config["epochs"],
         config["batch_size"],
         config["lr"],
         config["patience"],
         config["max_gradient_norm"],
         args.checkpoint)
2 changes: 1 addition & 1 deletion alreadrun/test_bert_quora.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Aurelien Coet, 2018.

from utils.runned.utils_test_two import validate
from mfae.model_transformer import ESIM
from a3v.model_transformer import ESIM
import os
import argparse
import json
Expand Down
4 changes: 2 additions & 2 deletions alreadrun/test_bert_snli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
# Aurelien Coet, 2018.

from utils.runned.utils_test_three import validate
from mfae.model_transformer import ESIM
# from mfae.model_bert_transformer import ESIM
from a3v.model_transformer import ESIM
# from a3v.model_bert_transformer import ESIM
import os
import argparse
import json
Expand Down
4 changes: 2 additions & 2 deletions alreadrun/test_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import os
import torch
import torch.nn.functional as F
from mfae.droped.resnet import PreActResNet18
from mfae.droped.resnet_top import PreActResNet18Top
from a3v.droped.resnet import PreActResNet18
from a3v.droped.resnet_top import PreActResNet18Top
from torch.autograd import Variable
import sys
from utils.utils_base import creterion_cifar
Expand Down
4 changes: 2 additions & 2 deletions alreadrun/test_esim_quora.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
# Aurelien Coet, 2018.

from utils.runned.utils_test_esim_quora import validate
from mfae.model import ESIM
from mfae.data import NLIDataset
from a3v.model import ESIM
from a3v.data import NLIDataset
from torch.utils.data import DataLoader
import os
import argparse
Expand Down
4 changes: 2 additions & 2 deletions alreadrun/test_esim_snli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
# Aurelien Coet, 2018.

from utils.runned.utils_test_esim_snli import validate
from mfae.model import ESIM
from mfae.data import NLIDataset
from a3v.model import ESIM
from a3v.data import NLIDataset
from torch.utils.data import DataLoader
import os
import argparse
Expand Down
4 changes: 2 additions & 2 deletions alreadrun/top_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import torchvision.transforms as transforms
import os
import torch
from mfae.droped.resnet import PreActResNet18
from mfae.droped.resnet_top import PreActResNet18Top
from a3v.droped.resnet import PreActResNet18
from a3v.droped.resnet_top import PreActResNet18Top
from torch.autograd import Variable
import sys

Expand Down
2 changes: 1 addition & 1 deletion bert_mnli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Aurelien Coet, 2018.

from utils.utils_transformer import train, validate
from mfae.model_transformer import ESIM
from a3v.model_transformer import ESIM
import torch.nn as nn
import matplotlib.pyplot as plt
import os
Expand Down
2 changes: 1 addition & 1 deletion bert_mnli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Aurelien Coet, 2018.

from utils.utils_transformer import test
from mfae.model_transformer import ESIM
from a3v.model_transformer import ESIM
import torch.nn as nn
import matplotlib.pyplot as plt
import os
Expand Down
2 changes: 1 addition & 1 deletion bert_quora.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Aurelien Coet, 2018.

from utils.utils_transformer import train, validate
from mfae.model_transformer import ESIM
from a3v.model_transformer import ESIM
import torch.nn as nn
import matplotlib.pyplot as plt
import os
Expand Down
2 changes: 1 addition & 1 deletion bert_quora_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Aurelien Coet, 2018.

from utils.utils_transformer import train_loss
from mfae.model_transformer import ESIM
from a3v.model_transformer import ESIM
import torch.nn as nn
import matplotlib.pyplot as plt
import os
Expand Down
2 changes: 1 addition & 1 deletion bert_snli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Aurelien Coet, 2018.

from utils.utils_transformer import train, validate
from mfae.model_transformer import ESIM
from a3v.model_transformer import ESIM
import torch.nn as nn
import matplotlib.pyplot as plt
import os
Expand Down
Loading

0 comments on commit 1dd90fc

Please sign in to comment.