Skip to content

Commit 37dcb95

Browse files
authored
Merge pull request #20 from zenml-io/bugfix/trust_remote_code_in_load_metrics
add trust_remote_code=True
2 parents cf674b5 + 4fc1ae5 commit 37dcb95

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

template/steps/dataset_loader/data_loader.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,23 @@ def data_loader(
3232

3333
# Load dataset based on the dataset value
3434
{%- if dataset == 'financial_news' %}
35-
dataset = load_dataset("zeroshot/twitter-financial-news-sentiment")
35+
dataset = load_dataset(
36+
"zeroshot/twitter-financial-news-sentiment",
37+
trust_remote_code=True,
38+
)
3639
{%- endif %}
3740
{%- if dataset == 'imdb_reviews' %}
38-
dataset = load_dataset("imdb")["train"]
41+
dataset = load_dataset(
42+
"imdb",
43+
trust_remote_code=True,
44+
)["train"]
3945
dataset = dataset.train_test_split(test_size=0.25, shuffle=True)
4046
{%- endif %}
4147
{%- if dataset == 'airline_reviews' %}
42-
dataset = load_dataset("Shayanvsf/US_Airline_Sentiment")
48+
dataset = load_dataset(
49+
"Shayanvsf/US_Airline_Sentiment",
50+
trust_remote_code=True,
51+
)
4352
dataset = dataset.rename_column("airline_sentiment", "label")
4453
dataset = dataset.remove_columns(["airline_sentiment_confidence","negativereason_confidence"])
4554
{%- endif %}

template/utils/misc.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ def compute_metrics(eval_pred: Tuple[np.ndarray, np.ndarray]) -> Dict[str, float
1818
logits, labels = eval_pred
1919
predictions = np.argmax(logits, axis=-1)
2020
# calculate the mertic using the predicted and true value
21-
accuracy = load_metric("accuracy").compute(
21+
accuracy = load_metric("accuracy", trust_remote_code=True).compute(
2222
predictions=predictions, references=labels
2323
)
24-
f1 = load_metric("f1").compute(
24+
f1 = load_metric("f1", trust_remote_code=True).compute(
2525
predictions=predictions, references=labels, average="weighted"
2626
)
27-
precision = load_metric("precision").compute(
27+
precision = load_metric("precision", trust_remote_code=True).compute(
2828
predictions=predictions, references=labels, average="weighted"
2929
)
3030
return {

0 commit comments

Comments
 (0)