Skip to content

Commit

Permalink
Handle tuple in pytorch model output
Browse files Browse the repository at this point in the history
  • Loading branch information
ataheridezfouli-groq committed Jan 29, 2024
1 parent 04f1b90 commit 2582343
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions demo_helpers/demo_helpers/compute_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,21 @@ def pytorch_model_inference(dataset, model):
out = model(**inputs)

if not isinstance(out, torch.Tensor):
if "logits" in out:
out = out.logits
elif "start_logits" in out and "end_logits" in out:
out = torch.vstack((out["start_logits"], out["end_logits"]))
elif "last_hidden_state" in out:
out = out.last_hidden_state
if isinstance(out, tuple) and len(out) == 1:
out = out[0]
elif isinstance(out, dict):
if "logits" in out:
out = out.logits
elif "start_logits" in out and "end_logits" in out:
out = torch.vstack((out["start_logits"], out["end_logits"]))
elif "last_hidden_state" in out:
out = out.last_hidden_state
else:
raise ValueError(
"Unknown output key. List of keys:", list(out.keys())
)
else:
raise ValueError(
"Unknown output key. List of keys:", list(out.keys())
)
raise ValueError("Unknown output type", type(out))
pred.append(out)

return dataset.postprocess(pred)
Expand Down

0 comments on commit 2582343

Please sign in to comment.