From 860725cfde71c6a115b4b9c5a245fe3448aeff97 Mon Sep 17 00:00:00 2001 From: Keisuke Fukuda Date: Tue, 13 Jun 2017 16:40:04 +0900 Subject: [PATCH 1/8] WIP --- chainermn/dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chainermn/dataset.py b/chainermn/dataset.py index 7b16e373..b754744a 100644 --- a/chainermn/dataset.py +++ b/chainermn/dataset.py @@ -2,7 +2,7 @@ import warnings -def scatter_dataset(dataset, comm): +def scatter_dataset(dataset, comm, shuffle=False): """Scatter the given dataset to the workers in the communicator. The dataset of worker 0 (i.e., the worker whose ``comm.rank`` is 0) is @@ -27,6 +27,7 @@ def scatter_dataset(dataset, comm): # TODO(akiba): write why we do not use mpi_comm.scatter if comm.rank == 0: + # TODO(keisukefukuda) mine = None n_total_samples = len(dataset) n_sub_samples = (n_total_samples + comm.size - 1) // comm.size From 609fa2d8baf64e7a3ce6cbc05a280b3edd932ad3 Mon Sep 17 00:00:00 2001 From: Keisuke Fukuda Date: Fri, 18 Aug 2017 14:46:31 +0900 Subject: [PATCH 2/8] implemented shuffle=True/False option in scatter_dataset --- chainermn/dataset.py | 16 +++++++++++----- tests/test_dataset.py | 28 ++++++++++++++++------------ tests/test_mnist.py | 4 ++-- 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/chainermn/dataset.py b/chainermn/dataset.py index d753bf82..bcbaa42b 100644 --- a/chainermn/dataset.py +++ b/chainermn/dataset.py @@ -1,8 +1,9 @@ import chainer.datasets +import numpy import warnings -def scatter_dataset(dataset, comm, shuffle=False): +def scatter_dataset(dataset, comm, root=0, shuffle=False, seed=None): """Scatter the given dataset to the workers in the communicator. The dataset of worker 0 (i.e., the worker whose ``comm.rank`` is 0) is @@ -27,16 +28,21 @@ def scatter_dataset(dataset, comm, shuffle=False): # We cannot use `mpi_comm.scatter`. This is due to MPI4py's bug. # For large datasets, when using `mpi_comm.scatter`, it causes MemoryError. 
- if comm.rank == 0: - # TODO(keisukefukuda) + if comm.rank == root: mine = None n_total_samples = len(dataset) n_sub_samples = (n_total_samples + comm.size - 1) // comm.size + + if shuffle: + order = numpy.random.RandomState(seed).permutation(n_total_samples) + else: + order = numpy.arange(n_total_samples) + for i in range(comm.size): b = n_total_samples * i // comm.size e = b + n_sub_samples - subds = chainer.datasets.SubDataset(dataset, b, e) - if i == 0: + subds = chainer.datasets.SubDataset(dataset, b, e, order) + if i == root: mine = subds else: comm.send(subds, dest=i) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index bf5487d9..4ac05a6b 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -13,12 +13,13 @@ def setUp(self): self.mpi_comm = mpi4py.MPI.COMM_WORLD self.communicator = NaiveCommunicator(self.mpi_comm) - def check_scatter_dataset(self, original_dataset): + def check_scatter_dataset(self, original_dataset, shuffle=False, root=0): my_dataset = chainermn.scatter_dataset( - original_dataset, self.communicator) + original_dataset, self.communicator, + shuffle=shuffle, root=root) sub_datasets = self.mpi_comm.gather(my_dataset) - if self.mpi_comm.rank == 0: + if self.mpi_comm.rank == root: # Test the sizes sub_sizes = [len(sub_dataset) for sub_dataset in sub_datasets] self.assertEqual(len(set(sub_sizes)), 1) @@ -36,12 +37,15 @@ def check_scatter_dataset(self, original_dataset): def test_scatter_dataset(self): n = self.communicator.size - self.check_scatter_dataset([]) - self.check_scatter_dataset([0]) - self.check_scatter_dataset(list(range(n))) - self.check_scatter_dataset(list(range(n * 5 - 1))) - - self.check_scatter_dataset(np.array([])) - self.check_scatter_dataset(np.array([0])) - self.check_scatter_dataset(np.arange(n)) - self.check_scatter_dataset(np.arange(n * 5 - 1)) + for shuffle in [True, False]: + for root in range(self.communicator.size): + self.check_scatter_dataset([], root, shuffle) + 
self.check_scatter_dataset([0], root, shuffle) + self.check_scatter_dataset(list(range(n)), root, shuffle) + self.check_scatter_dataset(list(range(n * 5 - 1)), + root, shuffle) + + self.check_scatter_dataset(np.array([]), root, shuffle) + self.check_scatter_dataset(np.array([0]), root, shuffle) + self.check_scatter_dataset(np.arange(n), root, shuffle) + self.check_scatter_dataset(np.arange(n * 5 - 1), root, shuffle) diff --git a/tests/test_mnist.py b/tests/test_mnist.py index 17dfd78d..589e742d 100644 --- a/tests/test_mnist.py +++ b/tests/test_mnist.py @@ -45,8 +45,8 @@ def test_mnist(self, display_log=True): else: train, test = None, None - train = chainermn.scatter_dataset(train, comm) - test = chainermn.scatter_dataset(test, comm) + train = chainermn.scatter_dataset(train, comm, shuffle=True) + test = chainermn.scatter_dataset(test, comm, shuffle=True) train_iter = chainer.iterators.SerialIterator(train, batchsize) test_iter = chainer.iterators.SerialIterator(test, batchsize, From 96608990b5ddb1598e22ff91070cfc9512cb8359 Mon Sep 17 00:00:00 2001 From: Keisuke Fukuda Date: Fri, 18 Aug 2017 17:01:03 +0900 Subject: [PATCH 3/8] misc. fixes --- chainermn/dataset.py | 7 ++++++- examples/mnist/train_mnist.py | 4 ++-- tests/test_dataset.py | 16 ++++++++-------- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/chainermn/dataset.py b/chainermn/dataset.py index bcbaa42b..94ad0d5d 100644 --- a/chainermn/dataset.py +++ b/chainermn/dataset.py @@ -16,6 +16,8 @@ def scatter_dataset(dataset, comm, root=0, shuffle=False, seed=None): dataset: A dataset (e.g., ``list``, ``numpy.ndarray``, ``chainer.datasets.TupleDataset``, ...). comm: ChainerMN communicator or MPI4py communicator. + shuffle: Shuffle the dataset before being scattered. + root: The root process of the scatter operation. Returns: Scattered dataset. 
@@ -25,9 +27,12 @@ def scatter_dataset(dataset, comm, root=0, shuffle=False, seed=None): comm = comm.mpi_comm assert hasattr(comm, 'send') assert hasattr(comm, 'recv') + assert 0 <= root and root < comm.size, "root={},rank={}".format(root,comm.rank) # We cannot use `mpi_comm.scatter`. This is due to MPI4py's bug. # For large datasets, when using `mpi_comm.scatter`, it causes MemoryError. + # import sys + # sys.stderr.write("scatter_dataset(): root={}".format(root)) if comm.rank == root: mine = None n_total_samples = len(dataset) @@ -48,7 +53,7 @@ def scatter_dataset(dataset, comm, root=0, shuffle=False, seed=None): comm.send(subds, dest=i) return mine else: - return comm.recv(source=0) + return comm.recv(source=root) def get_n_iterations_for_one_epoch(dataset, local_batch_size, comm): diff --git a/examples/mnist/train_mnist.py b/examples/mnist/train_mnist.py index eb0de74d..b0bcaecb 100644 --- a/examples/mnist/train_mnist.py +++ b/examples/mnist/train_mnist.py @@ -89,8 +89,8 @@ def main(): train, test = chainer.datasets.get_mnist() else: train, test = None, None - train = chainermn.scatter_dataset(train, comm) - test = chainermn.scatter_dataset(test, comm) + train = chainermn.scatter_dataset(train, comm, shuffle=True) + test = chainermn.scatter_dataset(test, comm, shuffle=True) train_iter = chainer.iterators.SerialIterator(train, args.batchsize) test_iter = chainer.iterators.SerialIterator(test, args.batchsize, diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 4ac05a6b..d2db2125 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -39,13 +39,13 @@ def test_scatter_dataset(self): for shuffle in [True, False]: for root in range(self.communicator.size): - self.check_scatter_dataset([], root, shuffle) - self.check_scatter_dataset([0], root, shuffle) - self.check_scatter_dataset(list(range(n)), root, shuffle) + self.check_scatter_dataset([], shuffle, root) + self.check_scatter_dataset([0], shuffle, root) + 
self.check_scatter_dataset(list(range(n)), shuffle, root) self.check_scatter_dataset(list(range(n * 5 - 1)), - root, shuffle) + shuffle, root) - self.check_scatter_dataset(np.array([]), root, shuffle) - self.check_scatter_dataset(np.array([0]), root, shuffle) - self.check_scatter_dataset(np.arange(n), root, shuffle) - self.check_scatter_dataset(np.arange(n * 5 - 1), root, shuffle) + self.check_scatter_dataset(np.array([]), shuffle, root) + self.check_scatter_dataset(np.array([0]), shuffle, root) + self.check_scatter_dataset(np.arange(n), shuffle, root) + self.check_scatter_dataset(np.arange(n * 5 - 1), shuffle, root) From a13db87d40be9b71d27c4ca4cc8e6e5d2fcc27aa Mon Sep 17 00:00:00 2001 From: Keisuke Fukuda Date: Mon, 21 Aug 2017 10:24:58 +0900 Subject: [PATCH 4/8] removed unnecessary assertion message --- chainermn/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chainermn/dataset.py b/chainermn/dataset.py index 94ad0d5d..c1aa2c52 100644 --- a/chainermn/dataset.py +++ b/chainermn/dataset.py @@ -27,7 +27,7 @@ def scatter_dataset(dataset, comm, root=0, shuffle=False, seed=None): comm = comm.mpi_comm assert hasattr(comm, 'send') assert hasattr(comm, 'recv') - assert 0 <= root and root < comm.size, "root={},rank={}".format(root,comm.rank) + assert 0 <= root and root < comm.size # We cannot use `mpi_comm.scatter`. This is due to MPI4py's bug. # For large datasets, when using `mpi_comm.scatter`, it causes MemoryError. 
From 74e6b09dcf0b27d1bbc4e1f1db701f6ea06088cf Mon Sep 17 00:00:00 2001 From: Keisuke Fukuda Date: Mon, 21 Aug 2017 14:34:27 +0900 Subject: [PATCH 5/8] minor bugfix --- tests/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index d2db2125..23a0dbdd 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -17,7 +17,7 @@ def check_scatter_dataset(self, original_dataset, shuffle=False, root=0): my_dataset = chainermn.scatter_dataset( original_dataset, self.communicator, shuffle=shuffle, root=root) - sub_datasets = self.mpi_comm.gather(my_dataset) + sub_datasets = self.mpi_comm.gather(my_dataset, root=root) if self.mpi_comm.rank == root: # Test the sizes From d3233f8f60bb4c483da3103e16cf13a6bcf87c3c Mon Sep 17 00:00:00 2001 From: Keisuke Fukuda Date: Wed, 23 Aug 2017 11:19:47 +0900 Subject: [PATCH 6/8] Pass None as ``order`` argument to ``SubDataset`` if ``shuffle`` is False --- chainermn/dataset.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/chainermn/dataset.py b/chainermn/dataset.py index c1aa2c52..a46fb2cf 100644 --- a/chainermn/dataset.py +++ b/chainermn/dataset.py @@ -16,8 +16,14 @@ def scatter_dataset(dataset, comm, root=0, shuffle=False, seed=None): dataset: A dataset (e.g., ``list``, ``numpy.ndarray``, ``chainer.datasets.TupleDataset``, ...). comm: ChainerMN communicator or MPI4py communicator. - shuffle: Shuffle the dataset before being scattered. - root: The root process of the scatter operation. + shuffle (bool): If ``True``, the order of examples is shuffled + before being scattered. + root (int): The root process of the scatter operation. + seed (int): Seed the generator used for the permutation of indexes. + If an integer being convertible to 32 bit unsigned integers is + specified, it is guaranteed that each sample + in the given dataset always belongs to a specific subset. + If ``None``, the permutation is changed randomly. 
Returns: Scattered dataset. @@ -37,11 +43,10 @@ def scatter_dataset(dataset, comm, root=0, shuffle=False, seed=None): mine = None n_total_samples = len(dataset) n_sub_samples = (n_total_samples + comm.size - 1) // comm.size + order = None if shuffle: order = numpy.random.RandomState(seed).permutation(n_total_samples) - else: - order = numpy.arange(n_total_samples) for i in range(comm.size): b = n_total_samples * i // comm.size From 246482a7b9ecb8dfdc49385500a6209fb39361a6 Mon Sep 17 00:00:00 2001 From: Keisuke Fukuda Date: Wed, 23 Aug 2017 11:21:53 +0900 Subject: [PATCH 7/8] Removed unused code --- chainermn/dataset.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/chainermn/dataset.py b/chainermn/dataset.py index a46fb2cf..c8647e42 100644 --- a/chainermn/dataset.py +++ b/chainermn/dataset.py @@ -37,8 +37,6 @@ def scatter_dataset(dataset, comm, root=0, shuffle=False, seed=None): # We cannot use `mpi_comm.scatter`. This is due to MPI4py's bug. # For large datasets, when using `mpi_comm.scatter`, it causes MemoryError. - # import sys - # sys.stderr.write("scatter_dataset(): root={}".format(root)) if comm.rank == root: mine = None n_total_samples = len(dataset) From 98308970e1277299604cce2883ae2d6c4b8e2948 Mon Sep 17 00:00:00 2001 From: Keisuke Fukuda Date: Wed, 23 Aug 2017 16:30:00 +0900 Subject: [PATCH 8/8] added shuffle=True to imagenet example --- examples/imagenet/train_imagenet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/imagenet/train_imagenet.py b/examples/imagenet/train_imagenet.py index 706a6d0e..667bef7c 100644 --- a/examples/imagenet/train_imagenet.py +++ b/examples/imagenet/train_imagenet.py @@ -168,7 +168,7 @@ def main(): else: train = None val = None - train = chainermn.scatter_dataset(train, comm) + train = chainermn.scatter_dataset(train, comm, shuffle=True) val = chainermn.scatter_dataset(val, comm) # We need to change the start method of multiprocessing module if we are