Skip to content

Commit

Permalink
more debug
Browse files Browse the repository at this point in the history
  • Loading branch information
kermitt2 committed Jan 28, 2024
1 parent 7061ff9 commit bf19c65
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions delft/applications/licenseClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def train_binary(embeddings_name, fold_count, architecture="gru", transformer=No
print('loading multiclass copyright/license dataset...')
xtr, y_copyrights = _read_data('data/textClassification/licenses/copyrights-licenses-data-validated.csv', data_type="copyrights")

report_training_contexts(y_copyrights)
report_training_copyrights(y_copyrights)

for class_rank in range(len(list_classes_copyright)):
model_name = 'copyright_' + list_classes_copyright[class_rank] + '_'+architecture
Expand All @@ -161,9 +161,9 @@ def train_binary(embeddings_name, fold_count, architecture="gru", transformer=No
class_weights=class_weights, transformer_name=transformer)

if fold_count == 1:
model.train(x_train, y_train_class_rank)
model.train(xtr, y_train_class_rank)
else:
model.train_nfold(x_train, y_train_class_rank)
model.train_nfold(xtr, y_train_class_rank)
# saving the model
model.save()

Expand All @@ -186,9 +186,9 @@ def train_binary(embeddings_name, fold_count, architecture="gru", transformer=No
class_weights=class_weights, transformer_name=transformer)

if fold_count == 1:
model.train(x_train, y_train_class_rank)
model.train(xtr, y_train_class_rank)
else:
model.train_nfold(x_train, y_train_class_rank)
model.train_nfold(xtr, y_train_class_rank)
# saving the model
model.save()

Expand All @@ -197,7 +197,7 @@ def train_and_eval_binary(embeddings_name, fold_count, architecture="gru", trans
print('loading multiclass copyright/license dataset...')
xtr, y_copyrights = _read_data('data/textClassification/licenses/copyrights-licenses-data-validated.csv', data_type="copyrights")

report_training_contexts(y_copyrights)
report_training_copyrights(y_copyrights)

# segment train and eval sets
x_train, y_train, x_test, y_test = split_data_and_labels(xtr, y_copyrights, 0.9)
Expand Down

0 comments on commit bf19c65

Please sign in to comment.