Skip to content

Commit

Permalink
Fix casting step outputs to device (#253)
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti authored Feb 20, 2024
1 parent ff4ac86 commit b038877
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions inseq/attr/feat/feature_attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,11 +590,11 @@ def filtered_attribute_step(
batch=batch,
)
step_fn_extra_args = get_step_scores_args([score], step_scores_args)
step_output.step_scores[score] = get_step_scores(score, step_fn_args, step_fn_extra_args)
step_output.step_scores[score] = get_step_scores(score, step_fn_args, step_fn_extra_args).to("cpu")
# Reinsert finished sentences
if target_attention_mask is not None and is_filtered:
step_output.remap_from_filtered(target_attention_mask, orig_batch)
step_output = step_output.detach()
step_output = step_output.detach().to("cpu")
return step_output

def get_attribution_args(self, **kwargs) -> tuple[dict[str, Any], dict[str, Any]]:
Expand Down

0 comments on commit b038877

Please sign in to comment.