improve the handling of parameters
lfoppiano committed Oct 26, 2024
1 parent a6ae760 commit 9517d9f
Showing 3 changed files with 94 additions and 70 deletions.
94 changes: 48 additions & 46 deletions delft/applications/grobidTagger.py
@@ -3,6 +3,7 @@
import argparse
import json
import time
from typing import Dict

from sklearn.model_selection import train_test_split

@@ -166,9 +167,9 @@ def configure(model, architecture, output_path=None, max_sequence_length=-1, bat
# train a GROBID model with all available data

def train(model, embeddings_name=None, architecture=None, transformer=None, input_path=None,
output_path=None, features_indices=None, max_sequence_length=-1, batch_size=-1, max_epoch=-1,
use_ELMo=False, incremental=False, input_model_path=None, patience=-1, learning_rate=None, early_stop=None, multi_gpu=False,
enable_wandb=False):
output_path=None, features_indices=None, max_sequence_length=-1, batch_size=-1, max_epoch=-1,
use_ELMo=False, incremental=False, input_model_path=None, patience=-1, learning_rate=None, early_stop=None, multi_gpu=False,
wandb_config=None):

print('Loading data...')
if input_path == None:
@@ -186,31 +187,38 @@ def train(model, embeddings_name=None, architecture=None, transformer=None, inpu
print("\nmax train sequence length:", str(longest_row(x_train)))
print("max validation sequence length:", str(longest_row(x_valid)))

batch_size, max_sequence_length, model_name, embeddings_name, max_epoch, multiprocessing, early_stop, patience = configure(model,
architecture,
output_path,
max_sequence_length,
batch_size,
embeddings_name,
max_epoch,
use_ELMo,
patience, early_stop)

model = Sequence(model_name,
recurrent_dropout=0.50,
embeddings_name=embeddings_name,
architecture=architecture,
transformer_name=transformer,
batch_size=batch_size,
max_sequence_length=max_sequence_length,
features_indices=features_indices,
max_epoch=max_epoch,
use_ELMo=use_ELMo,
multiprocessing=multiprocessing,
early_stop=early_stop,
patience=patience,
learning_rate=learning_rate,
enable_wandb=enable_wandb)
(batch_size, max_sequence_length, model_name,
embeddings_name, max_epoch, multiprocessing,
early_stop, patience) = configure(
model,
architecture,
output_path,
max_sequence_length,
batch_size,
embeddings_name,
max_epoch,
use_ELMo,
patience,
early_stop
)

model = Sequence(
model_name,
recurrent_dropout=0.50,
embeddings_name=embeddings_name,
architecture=architecture,
transformer_name=transformer,
batch_size=batch_size,
max_sequence_length=max_sequence_length,
features_indices=features_indices,
max_epoch=max_epoch,
use_ELMo=use_ELMo,
multiprocessing=multiprocessing,
early_stop=early_stop,
patience=patience,
learning_rate=learning_rate,
wandb_config=wandb_config
)

if incremental:
if input_model_path != None:
@@ -238,7 +246,7 @@ def train_eval(model, embeddings_name=None, architecture='BidLSTM_CRF', transfor
input_path=None, output_path=None, fold_count=1,
features_indices=None, max_sequence_length=-1, batch_size=-1, max_epoch=-1,
use_ELMo=False, incremental=False, input_model_path=None, patience=-1,
learning_rate=None, early_stop=None, multi_gpu=False, enable_wandb: bool=False):
learning_rate=None, early_stop=None, multi_gpu=False, wandb_config=None):

print('Loading data...')
if input_path is None:
@@ -267,22 +275,11 @@ def train_eval(model, embeddings_name=None, architecture='BidLSTM_CRF', transfor
use_ELMo,
patience,
early_stop)
model = Sequence(model_name,
recurrent_dropout=0.50,
embeddings_name=embeddings_name,
architecture=architecture,
transformer_name=transformer,
max_sequence_length=max_sequence_length,
batch_size=batch_size,
fold_number=fold_count,
features_indices=features_indices,
max_epoch=max_epoch,
use_ELMo=use_ELMo,
multiprocessing=multiprocessing,
early_stop=early_stop,
patience=patience,
learning_rate=learning_rate,
enable_wandb=enable_wandb)
model = Sequence(model_name, architecture=architecture, embeddings_name=embeddings_name,
max_sequence_length=max_sequence_length, recurrent_dropout=0.50, batch_size=batch_size,
learning_rate=learning_rate, max_epoch=max_epoch, early_stop=early_stop, patience=patience,
use_ELMo=use_ELMo, fold_number=fold_count, multiprocessing=multiprocessing,
features_indices=features_indices, transformer_name=transformer, wandb_config=wandb_config)

if incremental:
if input_model_path != None:
@@ -482,6 +479,11 @@ class Tasks:
# default word embeddings
embeddings_name = "glove-840B"

if wandb:
    wandb_config = {
        "project": "delft-grobidTagger"
    }
else:
    # keep wandb_config defined when W&B logging is not requested
    wandb_config = None

if action == Tasks.TRAIN:
train(
model,
Expand All @@ -500,7 +502,7 @@ class Tasks:
max_epoch=max_epoch,
early_stop=early_stop,
multi_gpu=multi_gpu,
enable_wandb=wandb
wandb_config=wandb_config
)

if action == Tasks.EVAL:
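For reference, a minimal call sketch of the updated train() entry point: wandb_config replaces the former enable_wandb boolean and can be left as None to skip Weights & Biases logging. The model name and the project value below are illustrative assumptions, not values taken from this commit.

# Sketch only: argument values are assumptions for illustration.
from delft.applications.grobidTagger import train

train(
    "date",                                          # GROBID model name (assumed)
    embeddings_name="glove-840B",                     # default word embeddings in this script
    architecture="BidLSTM_CRF",
    wandb_config={"project": "delft-grobidTagger"},   # or None to disable W&B logging
)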
33 changes: 26 additions & 7 deletions delft/sequenceLabelling/trainer.py
@@ -80,18 +80,22 @@ def compile_model(self, local_model, train_size):
local_model.compile(
optimizer=optimizer,
loss=local_model.crf.sparse_crf_loss_bert_masked,
metrics = [wandb.config.metric] if self.enable_wandb else []
metrics = ["accuracy"] if self.enable_wandb else []
)
elif local_model.config.use_crf:
# loss is calculated by the custom CRF wrapper
local_model.compile(
optimizer=optimizer,
metrics = [wandb.config.metric] if self.enable_wandb else []
metrics = ["accuracy"] if self.enable_wandb else []
)
else:
# we apply a mask on the predicted labels so that the weights
# corresponding to special symbols are neutralized
local_model.compile(optimizer=optimizer, loss=sparse_crossentropy_masked)
local_model.compile(
optimizer=optimizer,
loss=sparse_crossentropy_masked,
metrics=["accuracy"] if self.enable_wandb else []
)
else:

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
@@ -102,23 +106,38 @@ def compile_model(self, local_model, train_size):

#optimizer = tf.keras.optimizers.Adam(self.training_config.learning_rate)
if local_model.config.use_chain_crf:
local_model.compile(optimizer=optimizer, loss=local_model.crf.loss)
local_model.compile(
optimizer=optimizer,
loss=local_model.crf.loss,
metrics = ["accuracy"] if self.enable_wandb else []
)
elif local_model.config.use_crf:
if tf.executing_eagerly():
# loss is calculated by the custom CRF wrapper, no need to specify a loss function here
local_model.compile(optimizer=optimizer)
local_model.compile(
optimizer=optimizer,
metrics = ["accuracy"] if self.enable_wandb else []
)
else:
print("compile model, graph mode")
# always expecting a loss function here, but it is calculated internally by the CRF wrapper
# the following will fail in graph mode because
# '<tf.Variable 'chain_kernel:0' shape=(10, 10) dtype=float32> has `None` for gradient.'
# however this variable cannot be accessed, so there is no solution for the moment
# (probably requires avoiding keras fit and writing a custom training loop to get the gradient)
local_model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy')
local_model.compile(
optimizer=optimizer,
loss='sparse_categorical_crossentropy',
metrics = ["accuracy"] if self.enable_wandb else []
)
#local_model.compile(optimizer=optimizer, loss=InnerLossPusher(local_model))
else:
# only sparse label encoding is used (no one-hot encoded labels as was the case in DeLFT < 0.3)
local_model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy')
local_model.compile(
optimizer=optimizer,
loss='sparse_categorical_crossentropy',
metrics = ["accuracy"] if self.enable_wandb else []
)

return local_model

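The trainer change applies a single pattern to every compile path: an accuracy metric is attached only when W&B logging is enabled, so runs without W&B keep the previous behaviour. A minimal standalone sketch of that pattern follows; the function name is illustrative and not part of DeLFT.

import tensorflow as tf

def compile_with_optional_metrics(model, optimizer, loss=None, enable_wandb=False):
    # Metrics are attached only when W&B logging is active; loss may be None
    # when a CRF wrapper computes the loss internally.
    model.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=["accuracy"] if enable_wandb else [],
    )
    return model

# Example on a toy model:
toy = tf.keras.Sequential([tf.keras.layers.Dense(4)])
compile_with_optional_metrics(toy, tf.keras.optimizers.Adam(1e-3),
                              loss="sparse_categorical_crossentropy",
                              enable_wandb=True)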
37 changes: 20 additions & 17 deletions delft/sequenceLabelling/wrapper.py
@@ -1,4 +1,5 @@
import os
from typing import Dict

from packaging import version
from tensorflow.python.keras.models import model_from_config
@@ -88,7 +89,7 @@ def __init__(self,
multiprocessing=True,
features_indices=None,
transformer_name: str = None,
enable_wandb: bool = False):
wandb_config = None):

if model_name is None:
# add a dummy name based on the architecture
@@ -107,7 +108,10 @@ def __init__(self,
word_emb_size = 0
self.embeddings = None
self.model_local_path = None
self.enable_wandb = enable_wandb
# accept None (W&B disabled) or a dict providing at least the W&B project name
if wandb_config is None or 'project' in wandb_config:
    self.wandb_config = wandb_config
else:
    raise ValueError("The wandb_config should be a dictionary with at least the string parameter 'project'.")

self.registry = load_resource_registry("delft/resources-registry.json")

@@ -146,37 +150,36 @@ def __init__(self,
early_stop, patience,
max_checkpoints_to_keep, multiprocessing)

if enable_wandb:
if wandb_config:
import wandb
wandb.init(
project="delft",
project=wandb_config["project"],
name=model_name,
# settings=wandb.Settings(start_method="fork"),
config={
"model_name": self.model_config.model_name,
"architecture": self.model_config.architecture,
"char_emb_size": self.model_config.char_embedding_size,
"max_char_length": self.model_config.max_char_length,
"max_sequence_length": self.model_config.max_sequence_length,
"dropout": self.model_config.dropout,
"recurrent_dropout": self.model_config.recurrent_dropout,
"fold_number": self.model_config.fold_number,
"transformer_name": self.model_config.transformer_name,
"embeddings_name": self.model_config.embeddings_name,
"embedding_size": self.model_config.word_embedding_size,
"batch_size": self.training_config.batch_size,
"optimizer": self.training_config.optimizer,
"learning_rate": self.training_config.learning_rate,
"lr_decay": self.training_config.lr_decay,
"clip_gradients": self.training_config.clip_gradients,
"max_epoch": self.training_config.max_epoch,
"early_stop": self.training_config.early_stop,
"patience": self.training_config.patience,
"max_checkpoints_to_keep": self.training_config.max_checkpoints_to_keep,
"multiprocessing": self.training_config.multiprocessing
"char_emb_size": self.model_config.char_embedding_size,
"max_char_length": self.model_config.max_char_length,
"max_sequence_length": self.model_config.max_sequence_length,
"dropout": self.model_config.dropout,
"recurrent_dropout": self.model_config.recurrent_dropout,
"optimizer": self.training_config.optimizer,
"clip_gradients": self.training_config.clip_gradients
}
)


def train(self, x_train, y_train, f_train=None, x_valid=None, y_valid=None, f_valid=None, incremental=False, callbacks=None, multi_gpu=False):
if self.enable_wandb:
if self.wandb_config:
from wandb.integration.keras import WandbMetricsLogger
from wandb.integration.keras import WandbModelCheckpoint
callbacks = (callbacks or []) + [
@@ -251,7 +254,7 @@ def train_(self, x_train, y_train, f_train=None, x_valid=None, y_valid=None, f_v
checkpoint_path=self.log_dir,
preprocessor=self.p,
transformer_preprocessor=self.model.transformer_preprocessor,
enable_wandb=self.enable_wandb
enable_wandb=self.wandb_config
)
trainer.train(x_train, y_train, x_valid, y_valid, features_train=f_train, features_valid=f_valid, callbacks=callbacks)
if self.embeddings and self.embeddings.use_ELMo:
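After the wrapper change, Sequence accepts a wandb_config dict (or None) instead of the enable_wandb flag; the dict must carry at least a 'project' entry, which is forwarded to wandb.init. A hedged construction sketch follows; the import path and all argument values are assumptions based on how Sequence is used in grobidTagger.py above.

# Sketch only: names and values are placeholders, not taken from this commit.
from delft.sequenceLabelling import Sequence

model = Sequence(
    "date-BidLSTM_CRF",
    architecture="BidLSTM_CRF",
    embeddings_name="glove-840B",
    wandb_config={"project": "delft-grobidTagger"},  # must include 'project'; None disables W&B
)
# model.train(x_train, y_train, x_valid=x_valid, y_valid=y_valid) would then add the
# WandbMetricsLogger / WandbModelCheckpoint callbacks automatically.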
