diff --git a/delft/sequenceLabelling/trainer.py b/delft/sequenceLabelling/trainer.py
index cdc20a8..7dcc3c9 100644
--- a/delft/sequenceLabelling/trainer.py
+++ b/delft/sequenceLabelling/trainer.py
@@ -79,14 +79,12 @@ def compile_model(self, local_model, train_size):
             if local_model.config.use_chain_crf:
                 local_model.compile(
                     optimizer=optimizer,
-                    loss=local_model.crf.sparse_crf_loss_bert_masked,
-                    metrics = ["accuracy"] if self.enable_wandb else []
+                    loss=local_model.crf.sparse_crf_loss_bert_masked
                 )
             elif local_model.config.use_crf:
                 # loss is calculated by the custom CRF wrapper
                 local_model.compile(
                     optimizer=optimizer,
-                    metrics = ["accuracy"] if self.enable_wandb else []
                 )
             else:
                 # we apply a mask on the predicted labels so that the weights
@@ -94,7 +92,6 @@ def compile_model(self, local_model, train_size):
                 local_model.compile(
                     optimizer=optimizer,
                     loss=sparse_crossentropy_masked,
-                    metrics=["accuracy"] if self.enable_wandb else []
                 )
 
         else:
@@ -109,14 +106,12 @@ def compile_model(self, local_model, train_size):
                 local_model.compile(
                     optimizer=optimizer,
                     loss=local_model.crf.loss,
-                    metrics = ["accuracy"] if self.enable_wandb else []
                 )
             elif local_model.config.use_crf:
                 if tf.executing_eagerly():
                     # loss is calculated by the custom CRF wrapper, no need to specify a loss function here
                     local_model.compile(
                         optimizer=optimizer,
-                        metrics = ["accuracy"] if self.enable_wandb else []
                     )
                 else:
                     print("compile model, graph mode")
@@ -128,7 +123,6 @@ def compile_model(self, local_model, train_size):
                     local_model.compile(
                         optimizer=optimizer,
                         loss='sparse_categorical_crossentropy',
-                        metrics = ["accuracy"] if self.enable_wandb else []
                     )
                     #local_model.compile(optimizer=optimizer, loss=InnerLossPusher(local_model))
             else:
@@ -136,7 +130,6 @@ def compile_model(self, local_model, train_size):
                 local_model.compile(
                     optimizer=optimizer,
                     loss='sparse_categorical_crossentropy',
-                    metrics = ["accuracy"] if self.enable_wandb else []
                 )
 
         return local_model
diff --git a/delft/sequenceLabelling/wrapper.py b/delft/sequenceLabelling/wrapper.py
index e52e047..d6bbb04 100644
--- a/delft/sequenceLabelling/wrapper.py
+++ b/delft/sequenceLabelling/wrapper.py
@@ -108,10 +108,13 @@ def __init__(self,
         word_emb_size = 0
         self.embeddings = None
         self.model_local_path = None
-        if wandb_config is not None and 'project' not in wandb_config:
-            self.wandb_config = wandb_config
-        else:
-            raise ValueError("The wandb_config should be a dictionary with at least the string parameter 'project'. ")
+
+        self.wandb_config = None
+        if wandb_config is not None:
+            if 'project' in wandb_config:
+                self.wandb_config = wandb_config
+            else:
+                raise ValueError("The wandb_config should be a dictionary with at least the string parameter 'project'. ")
 
         self.registry = load_resource_registry("delft/resources-registry.json")
 
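The wrapper.py hunk replaces a check that accepted a `wandb_config` only when it did *not* contain a `'project'` key (and raised for everything else, including `None`) with validation that defaults `self.wandb_config` to `None` and raises only when a dict is passed without `'project'`. A minimal standalone sketch of the corrected logic; `validate_wandb_config` is a hypothetical helper used here for illustration, not part of the DeLFT API:

```python
# Hypothetical helper mirroring the corrected validation in wrapper.py __init__;
# not part of the DeLFT API.
def validate_wandb_config(wandb_config):
    if wandb_config is None:
        return None                 # wandb disabled, nothing to validate
    if 'project' in wandb_config:
        return wandb_config         # accepted as-is
    raise ValueError(
        "The wandb_config should be a dictionary with at least "
        "the string parameter 'project'. "
    )

# expected behavior of the new branch logic
assert validate_wandb_config(None) is None
assert validate_wandb_config({'project': 'delft-ner'}) == {'project': 'delft-ner'}
# validate_wandb_config({})  # would raise ValueError
```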