diff --git a/code/split_dataset.py b/code/split_dataset.py index 01be65f..2854b8d 100644 --- a/code/split_dataset.py +++ b/code/split_dataset.py @@ -56,7 +56,7 @@ def hash_withhold(withheld_idxs, length=6): def train_tune_test(ds, train_size=.90, tune_size=.1, test_size=0., withhold=None, rseed=8, out_dir=None, overwrite=False): """ split data into train, tune, and test sets """ - if train_size + tune_size + test_size != 1: + if not np.isclose(train_size + tune_size + test_size, 1): raise ValueError("train_size, tune_size, and test_size must add up to 1. current values are " "tr={}, tu={}, and te={}".format(train_size, tune_size, test_size))