diff --git a/demo_helpers/demo_helpers/compute_performance.py b/demo_helpers/demo_helpers/compute_performance.py index 8e3bfef..1e030cc 100644 --- a/demo_helpers/demo_helpers/compute_performance.py +++ b/demo_helpers/demo_helpers/compute_performance.py @@ -182,16 +182,21 @@ def pytorch_model_inference(dataset, model): out = model(**inputs) if not isinstance(out, torch.Tensor): - if "logits" in out: - out = out.logits - elif "start_logits" in out and "end_logits" in out: - out = torch.vstack((out["start_logits"], out["end_logits"])) - elif "last_hidden_state" in out: - out = out.last_hidden_state + if isinstance(out, tuple) and len(out) == 1: + out = out[0] + elif isinstance(out, dict): + if "logits" in out: + out = out.logits + elif "start_logits" in out and "end_logits" in out: + out = torch.vstack((out["start_logits"], out["end_logits"])) + elif "last_hidden_state" in out: + out = out.last_hidden_state + else: + raise ValueError( + "Unknown output key. List of keys:", list(out.keys()) + ) else: - raise ValueError( - "Unknown output key. List of keys:", list(out.keys()) - ) + raise ValueError("Unknown output type", type(out)) pred.append(out) return dataset.postprocess(pred)