From 8f95a13f1b37909c58082cc8ec95851af44698e2 Mon Sep 17 00:00:00 2001
From: Luca Foppiano
Date: Tue, 29 Oct 2024 07:40:47 +0100
Subject: [PATCH] configure wandb with env variables

---
 delft/__init__.py                  |  3 +
 delft/applications/grobidTagger.py |  4 +-
 delft/sequenceLabelling/wrapper.py | 90 +++++++++++++++----------
 delft/utilities/Transformer.py     |  2 +-
 requirements.txt                   |  1 +
 5 files changed, 52 insertions(+), 48 deletions(-)

diff --git a/delft/__init__.py b/delft/__init__.py
index e69de29b..d5621680 100644
--- a/delft/__init__.py
+++ b/delft/__init__.py
@@ -0,0 +1,3 @@
+import os
+
+DELFT_PROJECT_DIR = os.path.dirname(__file__)
\ No newline at end of file
diff --git a/delft/applications/grobidTagger.py b/delft/applications/grobidTagger.py
index 474d9260..d3f0a1dd 100644
--- a/delft/applications/grobidTagger.py
+++ b/delft/applications/grobidTagger.py
@@ -217,7 +217,7 @@ def train(model, embeddings_name=None, architecture=None, transformer=None, inpu
                         early_stop=early_stop,
                         patience=patience,
                         learning_rate=learning_rate,
-                        wandb_config=wandb_config
+                        report_to_wandb=wandb_config
                         )

     if incremental:
@@ -279,7 +279,7 @@ def train_eval(model, embeddings_name=None, architecture='BidLSTM_CRF', transfor
                          max_sequence_length=max_sequence_length, recurrent_dropout=0.50, batch_size=batch_size,
                          learning_rate=learning_rate, max_epoch=max_epoch, early_stop=early_stop, patience=patience,
                          use_ELMo=use_ELMo, fold_number=fold_count, multiprocessing=multiprocessing,
-                         features_indices=features_indices, transformer_name=transformer, wandb_config=wandb_config)
+                         features_indices=features_indices, transformer_name=transformer, report_to_wandb=wandb_config)

     if incremental:
         if input_model_path != None:
diff --git a/delft/sequenceLabelling/wrapper.py b/delft/sequenceLabelling/wrapper.py
index 7313d2d2..4ca86b60 100644
--- a/delft/sequenceLabelling/wrapper.py
+++ b/delft/sequenceLabelling/wrapper.py
@@ -1,9 +1,8 @@
 import os

-from typing import Dict
 from packaging import version

-from tensorflow.python.keras.models import model_from_config
+from delft import DELFT_PROJECT_DIR
 # ask tensorflow to be quiet and not print hundred lines of logs
 from delft.utilities.Transformer import TRANSFORMER_CONFIG_FILE_NAME, DEFAULT_TRANSFORMER_TOKENIZER_DIR
 from delft.utilities.misc import print_parameters
@@ -63,33 +62,35 @@ class Sequence(object):
     # number of parallel worker for the data generator
     nb_workers = 6

-    def __init__(self,
-                 model_name=None,
-                 architecture=None,
-                 embeddings_name=None,
-                 char_emb_size=25,
-                 max_char_length=30,
-                 char_lstm_units=25,
-                 word_lstm_units=100,
-                 max_sequence_length=300,
-                 dropout=0.5,
-                 recurrent_dropout=0.25,
-                 batch_size=20,
-                 optimizer='adam',
-                 learning_rate=None,
-                 lr_decay=0.9,
-                 clip_gradients=5.0,
-                 max_epoch=50,
-                 early_stop=True,
-                 patience=5,
-                 max_checkpoints_to_keep=0,
-                 use_ELMo=False,
-                 log_dir=None,
-                 fold_number=1,
-                 multiprocessing=True,
-                 features_indices=None,
-                 transformer_name: str = None,
-                 wandb_config = None):
+    def __init__(
+        self,
+        model_name=None,
+        architecture=None,
+        embeddings_name=None,
+        char_emb_size=25,
+        max_char_length=30,
+        char_lstm_units=25,
+        word_lstm_units=100,
+        max_sequence_length=300,
+        dropout=0.5,
+        recurrent_dropout=0.25,
+        batch_size=20,
+        optimizer='adam',
+        learning_rate=None,
+        lr_decay=0.9,
+        clip_gradients=5.0,
+        max_epoch=50,
+        early_stop=True,
+        patience=5,
+        max_checkpoints_to_keep=0,
+        use_ELMo=False,
+        log_dir=None,
+        fold_number=1,
+        multiprocessing=True,
+        features_indices=None,
+        transformer_name: str = None,
+        report_to_wandb = False
+    ):

         if model_name is None:
             # add a dummy name based on the architecture
@@ -109,14 +110,9 @@ def __init__(self,
         self.embeddings = None
         self.model_local_path = None

-        self.wandb_config = None
-        if wandb_config is not None:
-            if 'project' in wandb_config:
-                self.wandb_config = wandb_config
-            else:
-                raise ValueError("The wandb_config should be a dictionary with at least the string parameter 'project'. ")
+        self.report_to_wandb = report_to_wandb

-        self.registry = load_resource_registry("delft/resources-registry.json")
+        self.registry = load_resource_registry(os.path.join(DELFT_PROJECT_DIR, "resources-registry.json"))

         if self.embeddings_name is not None:
             self.embeddings = Embeddings(self.embeddings_name, resource_registry=self.registry, use_ELMo=use_ELMo)
@@ -153,10 +149,15 @@ def __init__(self,
                                               early_stop, patience,
                                               max_checkpoints_to_keep, multiprocessing)

-        if wandb_config:
+        if report_to_wandb:
             import wandb
+            from dotenv import load_dotenv
+            load_dotenv(override=True)
+            if os.getenv("WANDB_API_KEY") is None:
+                print("Warning: WANDB_API_KEY not set, wandb disabled")
+                self.report_to_wandb = False
+                return
             wandb.init(
-                project=wandb_config["project"],
                 name=model_name,
                 config={
                     "model_name": self.model_config.model_name,
@@ -184,15 +185,14 @@ def __init__(self,

     def train(self, x_train, y_train, f_train=None, x_valid=None, y_valid=None, f_valid=None, incremental=False,
               callbacks=None, multi_gpu=False):
-        if self.wandb_config:
+        if self.report_to_wandb:
             from wandb.integration.keras import WandbMetricsLogger
-            from wandb.integration.keras import WandbModelCheckpoint
             callbacks = callbacks + [
-                WandbMetricsLogger(log_freq=5),
-                WandbModelCheckpoint("models")
+                WandbMetricsLogger(),
+                # WandbModelCheckpoint("models", monitor='f1', mode='max')
             ] if callbacks is not None else [
-                WandbMetricsLogger(log_freq=5),
-                WandbModelCheckpoint("models")
+                WandbMetricsLogger(),
+                # WandbModelCheckpoint("models", monitor='f1', mode='max')
             ]

         # TBD if valid is None, segment train to get one if early_stop is True
@@ -259,7 +259,7 @@ def train_(self, x_train, y_train, f_train=None, x_valid=None, y_valid=None, f_v
                           checkpoint_path=self.log_dir,
                           preprocessor=self.p,
                           transformer_preprocessor=self.model.transformer_preprocessor,
-                          enable_wandb=self.wandb_config
+                          enable_wandb=self.report_to_wandb
                           )
         trainer.train(x_train, y_train, x_valid, y_valid, features_train=f_train, features_valid=f_valid, callbacks=callbacks)
         if self.embeddings and self.embeddings.use_ELMo:
diff --git a/delft/utilities/Transformer.py b/delft/utilities/Transformer.py
index 24e3a746..ad474d89 100644
--- a/delft/utilities/Transformer.py
+++ b/delft/utilities/Transformer.py
@@ -170,7 +170,7 @@ def save_tokenizer(self, output_directory):

     def instantiate_layer(self, load_pretrained_weights=True) -> Union[object, TFAutoModel, TFBertModel]:
         """
-        Instanciate a transformer to be loaded in a Keras layer using the availability method of the pre-trained transformer.
+        Instantiate a transformer to be loaded in a Keras layer using the availability method of the pre-trained transformer.
         """
         if self.loading_method == LOADING_METHOD_HUGGINGFACE_NAME:
             if load_pretrained_weights:
diff --git a/requirements.txt b/requirements.txt
index c45c1297..74cf833b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,3 +16,4 @@ pytest
 tensorflow-addons==0.19.0
 blingfire==0.1.8
 accelerate>=0.20.3
+python-dotenv
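
Note (not part of the patch): a minimal usage sketch of the new report_to_wandb flag, assuming the usual delft Sequence import; the model name and embeddings name below are placeholders.

# Illustrative only: exercise the env-variable based wandb configuration
# introduced above. WANDB_API_KEY may come from the shell or from a .env
# file picked up by python-dotenv; when it is absent, wandb reporting is
# disabled with a warning instead of raising an error.
from delft.sequenceLabelling import Sequence

model = Sequence(
    "example-ner",                 # placeholder model name
    architecture="BidLSTM_CRF",
    embeddings_name="glove-840B",  # placeholder embeddings entry
    report_to_wandb=True           # new flag added by this patch
)
# model.train(x_train, y_train, x_valid=x_valid, y_valid=y_valid)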