Skip to content

Commit

Permalink
fix logic for wandb
Browse files Browse the repository at this point in the history
  • Loading branch information
lfoppiano committed Oct 28, 2024
1 parent 63cc126 commit eef22fb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
9 changes: 1 addition & 8 deletions delft/sequenceLabelling/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,22 +79,19 @@ def compile_model(self, local_model, train_size):
if local_model.config.use_chain_crf:
local_model.compile(
optimizer=optimizer,
loss=local_model.crf.sparse_crf_loss_bert_masked,
metrics = ["accuracy"] if self.enable_wandb else []
loss=local_model.crf.sparse_crf_loss_bert_masked
)
elif local_model.config.use_crf:
# loss is calculated by the custom CRF wrapper
local_model.compile(
optimizer=optimizer,
metrics = ["accuracy"] if self.enable_wandb else []
)
else:
# we apply a mask on the predicted labels so that the weights
# corresponding to special symbols are neutralized
local_model.compile(
optimizer=optimizer,
loss=sparse_crossentropy_masked,
metrics=["accuracy"] if self.enable_wandb else []
)
else:

Expand All @@ -109,14 +106,12 @@ def compile_model(self, local_model, train_size):
local_model.compile(
optimizer=optimizer,
loss=local_model.crf.loss,
metrics = ["accuracy"] if self.enable_wandb else []
)
elif local_model.config.use_crf:
if tf.executing_eagerly():
# loss is calculated by the custom CRF wrapper, no need to specify a loss function here
local_model.compile(
optimizer=optimizer,
metrics = ["accuracy"] if self.enable_wandb else []
)
else:
print("compile model, graph mode")
Expand All @@ -128,15 +123,13 @@ def compile_model(self, local_model, train_size):
local_model.compile(
optimizer=optimizer,
loss='sparse_categorical_crossentropy',
metrics = ["accuracy"] if self.enable_wandb else []
)
#local_model.compile(optimizer=optimizer, loss=InnerLossPusher(local_model))
else:
# only sparse label encoding is used (no one-hot encoded labels as it was the case in DeLFT < 0.3)
local_model.compile(
optimizer=optimizer,
loss='sparse_categorical_crossentropy',
metrics = ["accuracy"] if self.enable_wandb else []
)

return local_model
Expand Down
11 changes: 7 additions & 4 deletions delft/sequenceLabelling/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,13 @@ def __init__(self,
word_emb_size = 0
self.embeddings = None
self.model_local_path = None
if wandb_config is not None and 'project' not in wandb_config:
self.wandb_config = wandb_config
else:
raise ValueError("The wandb_config should be a dictionary with at least the string parameter 'project'. ")

self.wandb_config = None
if wandb_config is not None:
if 'project' in wandb_config:
self.wandb_config = wandb_config
else:
raise ValueError("The wandb_config should be a dictionary with at least the string parameter 'project'. ")

self.registry = load_resource_registry("delft/resources-registry.json")

Expand Down

0 comments on commit eef22fb

Please sign in to comment.