From b0388777a7873b4c13de3583cba1de9f25728705 Mon Sep 17 00:00:00 2001 From: Gabriele Sarti Date: Tue, 20 Feb 2024 14:56:47 +0100 Subject: [PATCH] Fix casting step outputs to device (#253) --- inseq/attr/feat/feature_attribution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/inseq/attr/feat/feature_attribution.py b/inseq/attr/feat/feature_attribution.py index f0e61319..250cd700 100644 --- a/inseq/attr/feat/feature_attribution.py +++ b/inseq/attr/feat/feature_attribution.py @@ -590,11 +590,11 @@ def filtered_attribute_step( batch=batch, ) step_fn_extra_args = get_step_scores_args([score], step_scores_args) - step_output.step_scores[score] = get_step_scores(score, step_fn_args, step_fn_extra_args) + step_output.step_scores[score] = get_step_scores(score, step_fn_args, step_fn_extra_args).to("cpu") # Reinsert finished sentences if target_attention_mask is not None and is_filtered: step_output.remap_from_filtered(target_attention_mask, orig_batch) - step_output = step_output.detach() + step_output = step_output.detach().to("cpu") return step_output def get_attribution_args(self, **kwargs) -> tuple[dict[str, Any], dict[str, Any]]: