From 498c4d1133e9c8261535abd7f1dec7baeb0d5629 Mon Sep 17 00:00:00 2001 From: Kelly Marchisio Date: Mon, 27 Feb 2023 10:33:27 -0500 Subject: [PATCH] add back rs to train.py --- src/train.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/train.py b/src/train.py index 181ec6e..45a0dbd 100644 --- a/src/train.py +++ b/src/train.py @@ -98,7 +98,8 @@ def make_data(args, infile, total_lines, outfile, vocab, pos_sample_rates_dict, if shuffle: ts_print('Shuffling training file...', flush=True) - subprocess.run(['bash', '-c', 'shuf -o ' + outfile + '<' + outfile]) + # On Mac, use gshuf instead of shuf + subprocess.run(['bash', '-c', 'gshuf -o ' + outfile + '<' + outfile]) ts_print('Done shuffling training file.', flush=True) return outfile_len @@ -197,6 +198,8 @@ def forward(self, context_idxs, target_idxs): # Maybe does more computation, though? # # https://github.com/LakheyM/word2vec/blob/master/word2vec_SGNS_git.ipynb + # inner_products = torch.sum(torch.mul( + # context_vectors, target_vectors), dim=1).reshape(-1, 1) inner_products = torch.diagonal( context_vectors @ target_vectors.T).reshape(-1, 1) return inner_products @@ -511,6 +514,9 @@ def train(args, device, vocab, train_file, total_train_n, total_lines, # Source: https://www.stat.cmu.edu/~larry/=sml/Opt.pdf p. 5, 12 iso_loss_unscaled = torch.linalg.norm(model_vecs_tmp - loaded_vecs) iso_loss = iso_loss_unscaled / len(loaded_vecs) + elif args.loss == 'rs': + iso_loss = gh.diffble_rs_distance(model_vecs_tmp, + loaded_vecs, device, args.mode == 'unsupervised') elif args.loss == 'evs': iso_loss = iso.diffble_evs_distance(model_vecs_tmp, loaded_vecs, device)