Workaround for minilmv2 pytorch output issue
ataheridezfouli-groq committed Feb 1, 2024
1 parent 2582343 commit a4c8e47
Showing 2 changed files with 6 additions and 5 deletions.
7 changes: 5 additions & 2 deletions demo_helpers/demo_helpers/compute_performance.py
@@ -182,8 +182,11 @@ def pytorch_model_inference(dataset, model):
         out = model(**inputs)
 
         if not isinstance(out, torch.Tensor):
-            if isinstance(out, tuple) and len(out) == 1:
-                out = out[0]
+            if isinstance(out, tuple):
+                if len(out) == 1:
+                    out = out[0]
+                else:
+                    raise ValueError("Cannot handle tuple with len", len(out))
             elif isinstance(out, dict):
                 if "logits" in out:
                     out = out.logits
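For context, the updated branch in pytorch_model_inference still unwraps single-element tuples but now fails loudly on longer ones instead of passing them through. A minimal, self-contained sketch of that logic (the helper name normalize_output is hypothetical, for illustration only; in the real code `out` is a transformers ModelOutput, which subclasses dict and supports attribute access):

import torch

def normalize_output(out):
    # Reduce a raw model output to a single torch.Tensor, mirroring the diff above.
    if not isinstance(out, torch.Tensor):
        if isinstance(out, tuple):
            if len(out) == 1:
                out = out[0]  # unwrap single-element tuples
            else:
                # fail loudly rather than silently guessing which element to keep
                raise ValueError("Cannot handle tuple with len", len(out))
        elif isinstance(out, dict):
            if "logits" in out:
                # dict-like transformers ModelOutput allows attribute access
                out = out.logits
    return out

For example, normalize_output((torch.ones(2),)) returns the wrapped tensor, while a two-element tuple raises the ValueError added in this commit.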
4 changes: 1 addition & 3 deletions proof_points/natural_language_processing/minilm/minilmv2.py
@@ -19,9 +19,7 @@ def evaluate_minilm(rebuild_policy=None, should_execute=True):
 
     # load pre-trained torch model
     tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
-    model = AutoModel.from_pretrained(
-        "sentence-transformers/all-MiniLM-L6-v2", torchscript=True
-    )
+    model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
 
     # dummy inputs to generate the groq model
     max_seq_length = 128
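As the commit title suggests, dropping torchscript=True changes the shape of the PyTorch output: with torchscript=True a Hugging Face model returns a plain tuple (for this MiniLM checkpoint, more than one element, which would now hit the ValueError branch above), whereas the default call returns a dict-like ModelOutput that takes the dict branch. A small sketch to inspect the difference (exact field names depend on the installed transformers version):

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

inputs = tokenizer("hello world", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs)

# The default (non-torchscript) output is a dict-like ModelOutput,
# e.g. with last_hidden_state and pooler_output fields, not a plain tuple.
print(type(out).__name__, list(out.keys()))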

0 comments on commit a4c8e47