Skip to content

Commit

Permalink
code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
felixgwu committed Nov 29, 2019
1 parent ef88ac4 commit 55345eb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
4 changes: 1 addition & 3 deletions bert_score/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@


model2layers = {
'bert-base-multilingual-cased' : 9,
'bert-base-uncased': 9,
'bert-large-uncased': 18,
'bert-base-cased-finetuned-mrpc': 9,
Expand Down Expand Up @@ -111,7 +110,6 @@ def padding(arr, pad_token, dtype=torch.long):

def bert_encode(model, x, attention_mask, all_layers=False):
model.eval()
x_seg = torch.zeros_like(x, dtype=torch.long)
with torch.no_grad():
out = model(x, attention_mask=attention_mask)
if all_layers:
Expand Down Expand Up @@ -170,7 +168,7 @@ def collate_idf(arr, tokenizer, idf_dict, device='cuda:0'):

idf_weights = [[idf_dict[i] for i in a] for a in arr]

pad_token = tokenizer._convert_token_to_id(tokenizer.pad_token)
pad_token = tokenizer.pad_token_id

padded, lens, mask = padding(arr, pad_token, dtype=torch.long)
padded_idf, _, _ = padding(idf_weights, 0, dtype=torch.float)
Expand Down
14 changes: 7 additions & 7 deletions tests/test_bert_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import bert_score

eps = 1e-6
EPS = 1e-6

cands = [
"28-year-old chef found dead in San Francisco mall",
Expand All @@ -25,9 +25,9 @@ def test_score(self):
self.assertTrue(torch.is_tensor(R))
self.assertTrue(torch.is_tensor(F))
self.assertEqual(hash_code, f'roberta-large_L17_no-idf_version={bert_score.__version__}')
self.assertTrue((P - torch.tensor([0.9843302369117737, 0.9832239747047424, 0.9120386242866516])).abs_().max() < eps)
self.assertTrue((R - torch.tensor([0.9823839068412781, 0.9732863903045654, 0.920428991317749])).abs_().max() < eps)
self.assertTrue((F - torch.tensor([0.9833561182022095, 0.9782299995422363, 0.916214644908905])).abs_().max() < eps)
self.assertTrue((P - torch.tensor([0.9843302369117737, 0.9832239747047424, 0.9120386242866516])).abs_().max() < EPS)
self.assertTrue((R - torch.tensor([0.9823839068412781, 0.9732863903045654, 0.920428991317749])).abs_().max() < EPS)
self.assertTrue((F - torch.tensor([0.9833561182022095, 0.9782299995422363, 0.916214644908905])).abs_().max() < EPS)

def test_idf_score(self):
(P, R, F), hash_code = bert_score.score(cands, refs, model_type='roberta-large', num_layers=17,
Expand All @@ -38,9 +38,9 @@ def test_idf_score(self):
self.assertTrue(torch.is_tensor(R))
self.assertTrue(torch.is_tensor(F))
self.assertEqual(hash_code, f'roberta-large_L17_idf_version={bert_score.__version__}')
self.assertTrue((P - torch.tensor([0.9837872385978699, 0.9754738807678223, 0.8947395086288452])).abs_().max() < eps)
self.assertTrue((R - torch.tensor([0.9827190637588501, 0.9697767496109009, 0.9172918796539307])).abs_().max() < eps)
self.assertTrue((F - torch.tensor([0.9832529425621033, 0.972616970539093, 0.9058753848075867])).abs_().max() < eps)
self.assertTrue((P - torch.tensor([0.9837872385978699, 0.9754738807678223, 0.8947395086288452])).abs_().max() < EPS)
self.assertTrue((R - torch.tensor([0.9827190637588501, 0.9697767496109009, 0.9172918796539307])).abs_().max() < EPS)
self.assertTrue((F - torch.tensor([0.9832529425621033, 0.972616970539093, 0.9058753848075867])).abs_().max() < EPS)

if __name__ == '__main__':
unittest.main()

0 comments on commit 55345eb

Please sign in to comment.