Skip to content

Commit

Permalink
add back rs to train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kellymarchisio committed Feb 27, 2023
1 parent dcbf6de commit 498c4d1
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def make_data(args, infile, total_lines, outfile, vocab, pos_sample_rates_dict,

if shuffle:
ts_print('Shuffling training file...', flush=True)
subprocess.run(['bash', '-c', 'shuf -o ' + outfile + '<' + outfile])
# On Mac, use gshuf instead of shuf
subprocess.run(['bash', '-c', 'gshuf -o ' + outfile + '<' + outfile])
ts_print('Done shuffling training file.', flush=True)

return outfile_len
Expand Down Expand Up @@ -197,6 +198,8 @@ def forward(self, context_idxs, target_idxs):
# Maybe does more computation, though?
#
# https://github.com/LakheyM/word2vec/blob/master/word2vec_SGNS_git.ipynb
# inner_products = torch.sum(torch.mul(
# context_vectors, target_vectors), dim=1).reshape(-1, 1)
inner_products = torch.diagonal(
context_vectors @ target_vectors.T).reshape(-1, 1)
return inner_products
Expand Down Expand Up @@ -511,6 +514,9 @@ def train(args, device, vocab, train_file, total_train_n, total_lines,
# Source: https://www.stat.cmu.edu/~larry/=sml/Opt.pdf p. 5, 12
iso_loss_unscaled = torch.linalg.norm(model_vecs_tmp - loaded_vecs)
iso_loss = iso_loss_unscaled / len(loaded_vecs)
elif args.loss == 'rs':
iso_loss = gh.diffble_rs_distance(model_vecs_tmp,
loaded_vecs, device, args.mode == 'unsupervised')
elif args.loss == 'evs':
iso_loss = iso.diffble_evs_distance(model_vecs_tmp,
loaded_vecs, device)
Expand Down

0 comments on commit 498c4d1

Please sign in to comment.