diff --git a/delft/applications/textClassifier.py b/delft/applications/textClassifier.py index 79a3e6f..979b542 100644 --- a/delft/applications/textClassifier.py +++ b/delft/applications/textClassifier.py @@ -93,7 +93,7 @@ def train(model_name, if fold_count == 1: model.train(xtr, y_, incremental=incremental, multi_gpu=multi_gpu) else: - model.train_nfold(xtr, y_) + model.train_nfold(xtr, y_, multi_gpu=multi_gpu) # saving the model model.save() diff --git a/delft/sequenceLabelling/wrapper.py b/delft/sequenceLabelling/wrapper.py index f2b2d80..538e483 100644 --- a/delft/sequenceLabelling/wrapper.py +++ b/delft/sequenceLabelling/wrapper.py @@ -151,8 +151,9 @@ def train(self, x_train, y_train, f_train=None, x_valid=None, y_valid=None, f_va # This trick avoid an exception being through when the --multi-gpu approach is used on a single GPU system. # It might be removed with TF 2.10 https://github.com/tensorflow/tensorflow/issues/50487 - import atexit - atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore + if version.parse(tf.__version__) < version.parse('2.10.0'): + import atexit + atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore with strategy.scope(): self.train_(x_train, y_train, f_train, x_valid, y_valid, f_valid, incremental, callbacks) @@ -219,8 +220,9 @@ def train_nfold(self, x_train, y_train, x_valid=None, y_valid=None, f_train=None # This trick avoid an exception being through when the --multi-gpu approach is used on a single GPU system. # It might be removed with TF 2.10 https://github.com/tensorflow/tensorflow/issues/50487 - import atexit - atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore + if version.parse(tf.__version__) < version.parse('2.10.0'): + import atexit + atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore with strategy.scope(): self.train_nfold_(x_train, y_train, x_valid, y_valid, f_train, f_valid, incremental, callbacks) diff --git a/delft/textClassification/wrapper.py b/delft/textClassification/wrapper.py index 4a5eeca..b44b633 100644 --- a/delft/textClassification/wrapper.py +++ b/delft/textClassification/wrapper.py @@ -145,8 +145,9 @@ def train(self, x_train, y_train, vocab_init=None, incremental=False, callbacks= # This trick avoid an exception being through when the --multi-gpu approach is used on a single GPU system. # It might be removed with TF 2.10 https://github.com/tensorflow/tensorflow/issues/50487 - import atexit - atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore + if version.parse(tf.__version__) < version.parse('2.10.0'): + import atexit + atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore with strategy.scope(): self.train_(x_train, y_train, vocab_init, incremental, callbacks) @@ -217,8 +218,9 @@ def train_nfold(self, x_train, y_train, vocab_init=None, incremental=False, call # This trick avoid an exception being through when the --multi-gpu approach is used on a single GPU system. # It might be removed with TF 2.10 https://github.com/tensorflow/tensorflow/issues/50487 - import atexit - atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore + if version.parse(tf.__version__) < version.parse('2.10.0'): + import atexit + atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore with strategy.scope(): self.train_nfold_(x_train, y_train,vocab_init, incremental, callbacks)