Commit

configure wandb with env variables
lfoppiano committed Oct 29, 2024
1 parent d6ab668 commit 8f95a13
Showing 5 changed files with 52 additions and 48 deletions.
3 changes: 3 additions & 0 deletions delft/__init__.py
@@ -0,0 +1,3 @@
+import os
+
+DELFT_PROJECT_DIR = os.path.dirname(__file__)
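
Note: the new DELFT_PROJECT_DIR constant lets resources be resolved relative to the installed package rather than the current working directory. A minimal usage sketch (the printed path is illustrative):

import os
import delft

# resolve a packaged resource independently of where the script is launched from
registry_path = os.path.join(delft.DELFT_PROJECT_DIR, "resources-registry.json")
print(registry_path)  # e.g. .../site-packages/delft/resources-registry.json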
4 changes: 2 additions & 2 deletions delft/applications/grobidTagger.py
@@ -217,7 +217,7 @@ def train(model, embeddings_name=None, architecture=None, transformer=None, inpu
early_stop=early_stop,
patience=patience,
learning_rate=learning_rate,
-wandb_config=wandb_config
+report_to_wandb=wandb_config
)

if incremental:
@@ -279,7 +279,7 @@ def train_eval(model, embeddings_name=None, architecture='BidLSTM_CRF', transfor
max_sequence_length=max_sequence_length, recurrent_dropout=0.50, batch_size=batch_size,
learning_rate=learning_rate, max_epoch=max_epoch, early_stop=early_stop, patience=patience,
use_ELMo=use_ELMo, fold_number=fold_count, multiprocessing=multiprocessing,
-features_indices=features_indices, transformer_name=transformer, wandb_config=wandb_config)
+features_indices=features_indices, transformer_name=transformer, report_to_wandb=wandb_config)

if incremental:
if input_model_path != None:
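
Note: callers now pass the boolean report_to_wandb instead of a wandb_config dict. A hypothetical invocation sketch (model and embeddings names are placeholders):

from delft.sequenceLabelling import Sequence

model = Sequence(
    model_name="grobid-header",     # placeholder model name
    architecture="BidLSTM_CRF",
    embeddings_name="glove-840B",   # placeholder embeddings name
    report_to_wandb=True,           # was: wandb_config={"project": ...}
)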
90 changes: 45 additions & 45 deletions delft/sequenceLabelling/wrapper.py
@@ -1,9 +1,8 @@
import os
from typing import Dict

from packaging import version
-from tensorflow.python.keras.models import model_from_config

+from delft import DELFT_PROJECT_DIR
# ask tensorflow to be quiet and not print hundreds of lines of logs
from delft.utilities.Transformer import TRANSFORMER_CONFIG_FILE_NAME, DEFAULT_TRANSFORMER_TOKENIZER_DIR
from delft.utilities.misc import print_parameters
@@ -63,33 +62,35 @@ class Sequence(object):
# number of parallel workers for the data generator
nb_workers = 6

-def __init__(self,
-model_name=None,
-architecture=None,
-embeddings_name=None,
-char_emb_size=25,
-max_char_length=30,
-char_lstm_units=25,
-word_lstm_units=100,
-max_sequence_length=300,
-dropout=0.5,
-recurrent_dropout=0.25,
-batch_size=20,
-optimizer='adam',
-learning_rate=None,
-lr_decay=0.9,
-clip_gradients=5.0,
-max_epoch=50,
-early_stop=True,
-patience=5,
-max_checkpoints_to_keep=0,
-use_ELMo=False,
-log_dir=None,
-fold_number=1,
-multiprocessing=True,
-features_indices=None,
-transformer_name: str = None,
-wandb_config = None):
+def __init__(
+self,
+model_name=None,
+architecture=None,
+embeddings_name=None,
+char_emb_size=25,
+max_char_length=30,
+char_lstm_units=25,
+word_lstm_units=100,
+max_sequence_length=300,
+dropout=0.5,
+recurrent_dropout=0.25,
+batch_size=20,
+optimizer='adam',
+learning_rate=None,
+lr_decay=0.9,
+clip_gradients=5.0,
+max_epoch=50,
+early_stop=True,
+patience=5,
+max_checkpoints_to_keep=0,
+use_ELMo=False,
+log_dir=None,
+fold_number=1,
+multiprocessing=True,
+features_indices=None,
+transformer_name: str = None,
+report_to_wandb = False
+):

if model_name is None:
# add a dummy name based on the architecture
@@ -109,14 +110,9 @@ def __init__(self,
self.embeddings = None
self.model_local_path = None

-self.wandb_config = None
-if wandb_config is not None:
-if 'project' in wandb_config:
-self.wandb_config = wandb_config
-else:
-raise ValueError("The wandb_config should be a dictionary with at least the string parameter 'project'. ")
+self.report_to_wandb = report_to_wandb

-self.registry = load_resource_registry("delft/resources-registry.json")
+self.registry = load_resource_registry(os.path.join(DELFT_PROJECT_DIR, "resources-registry.json"))

if self.embeddings_name is not None:
self.embeddings = Embeddings(self.embeddings_name, resource_registry=self.registry, use_ELMo=use_ELMo)
@@ -153,10 +149,15 @@ def __init__(self,
early_stop, patience,
max_checkpoints_to_keep, multiprocessing)

-if wandb_config:
+if report_to_wandb:
import wandb
+from dotenv import load_dotenv
+load_dotenv(override=True)
+if os.getenv("WANDB_API_KEY") is None:
+print("Warning: WANDB_API_KEY not set, wandb disabled")
+self.report_to_wandb = False
+return
wandb.init(
-project=wandb_config["project"],
name=model_name,
config={
"model_name": self.model_config.model_name,
@@ -184,15 +185,14 @@


def train(self, x_train, y_train, f_train=None, x_valid=None, y_valid=None, f_valid=None, incremental=False, callbacks=None, multi_gpu=False):
-if self.wandb_config:
+if self.report_to_wandb:
from wandb.integration.keras import WandbMetricsLogger
-from wandb.integration.keras import WandbModelCheckpoint
callbacks = callbacks + [
-WandbMetricsLogger(log_freq=5),
-WandbModelCheckpoint("models")
+WandbMetricsLogger(),
+# WandbModelCheckpoint("models", monitor='f1', mode='max')
] if callbacks is not None else [
-WandbMetricsLogger(log_freq=5),
-WandbModelCheckpoint("models")
+WandbMetricsLogger(),
+# WandbModelCheckpoint("models", monitor='f1', mode='max')
]

# TBD if valid is None, segment train to get one if early_stop is True
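
Note: the conditional-expression merge above is equivalent to this flatter sketch (not the committed code):

from wandb.integration.keras import WandbMetricsLogger

# epoch-level logging is the WandbMetricsLogger default when log_freq is omitted
wandb_callbacks = [WandbMetricsLogger()]
callbacks = (callbacks or []) + wandb_callbacks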
@@ -259,7 +259,7 @@ def train_(self, x_train, y_train, f_train=None, x_valid=None, y_valid=None, f_v
checkpoint_path=self.log_dir,
preprocessor=self.p,
transformer_preprocessor=self.model.transformer_preprocessor,
-enable_wandb=self.wandb_config
+enable_wandb=self.report_to_wandb
)
trainer.train(x_train, y_train, x_valid, y_valid, features_train=f_train, features_valid=f_valid, callbacks=callbacks)
if self.embeddings and self.embeddings.use_ELMo:
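
Note: taken together, the wrapper now reads wandb credentials from the environment (a .env file loaded with python-dotenv) instead of requiring a wandb_config dict. A self-contained sketch of the flow; the .env contents and project name are assumptions for illustration:

import os
from dotenv import load_dotenv  # provided by the python-dotenv package

# .env (kept out of version control):
#   WANDB_API_KEY=xxxxxxxxxxxxxxxx
load_dotenv(override=True)  # values from .env take precedence over the inherited env

if os.getenv("WANDB_API_KEY") is None:
    print("Warning: WANDB_API_KEY not set, wandb disabled")
else:
    import wandb
    wandb.init(
        project="delft-sequence-labelling",  # hypothetical project name
        name="my-run",
        config={"batch_size": 20, "max_epoch": 50},
    )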
2 changes: 1 addition & 1 deletion delft/utilities/Transformer.py
@@ -170,7 +170,7 @@ def save_tokenizer(self, output_directory):

def instantiate_layer(self, load_pretrained_weights=True) -> Union[object, TFAutoModel, TFBertModel]:
"""
-Instanciate a transformer to be loaded in a Keras layer using the availability method of the pre-trained transformer.
+Instantiate a transformer to be loaded in a Keras layer using the availability method of the pre-trained transformer.
"""
if self.loading_method == LOADING_METHOD_HUGGINGFACE_NAME:
if load_pretrained_weights:
1 change: 1 addition & 0 deletions requirements.txt
@@ -16,3 +16,4 @@ pytest
tensorflow-addons==0.19.0
blingfire==0.1.8
accelerate>=0.20.3
+python-dotenv
