Skip to content

Commit

Permalink
fix multigpu tricks for tf < 2.10
Browse files Browse the repository at this point in the history
  • Loading branch information
lfoppiano committed Jan 18, 2024
1 parent 4e6ed0d commit 92d9573
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
2 changes: 1 addition & 1 deletion delft/applications/textClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def train(model_name,
if fold_count == 1:
model.train(xtr, y_, incremental=incremental, multi_gpu=multi_gpu)
else:
model.train_nfold(xtr, y_)
model.train_nfold(xtr, y_, multi_gpu=multi_gpu)
# saving the model
model.save()

Expand Down
10 changes: 6 additions & 4 deletions delft/sequenceLabelling/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,9 @@ def train(self, x_train, y_train, f_train=None, x_valid=None, y_valid=None, f_va

# This trick avoids an exception being thrown when the --multi-gpu approach is used on a single GPU system.
# It might be removed with TF 2.10 https://github.com/tensorflow/tensorflow/issues/50487
import atexit
atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore
if version.parse(tf.__version__) < version.parse('2.10.0'):
import atexit
atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore

with strategy.scope():
self.train_(x_train, y_train, f_train, x_valid, y_valid, f_valid, incremental, callbacks)
Expand Down Expand Up @@ -219,8 +220,9 @@ def train_nfold(self, x_train, y_train, x_valid=None, y_valid=None, f_train=None

# This trick avoids an exception being thrown when the --multi-gpu approach is used on a single GPU system.
# It might be removed with TF 2.10 https://github.com/tensorflow/tensorflow/issues/50487
import atexit
atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore
if version.parse(tf.__version__) < version.parse('2.10.0'):
import atexit
atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore

with strategy.scope():
self.train_nfold_(x_train, y_train, x_valid, y_valid, f_train, f_valid, incremental, callbacks)
Expand Down
10 changes: 6 additions & 4 deletions delft/textClassification/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,9 @@ def train(self, x_train, y_train, vocab_init=None, incremental=False, callbacks=

# This trick avoids an exception being thrown when the --multi-gpu approach is used on a single GPU system.
# It might be removed with TF 2.10 https://github.com/tensorflow/tensorflow/issues/50487
import atexit
atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore
if version.parse(tf.__version__) < version.parse('2.10.0'):
import atexit
atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore

with strategy.scope():
self.train_(x_train, y_train, vocab_init, incremental, callbacks)
Expand Down Expand Up @@ -217,8 +218,9 @@ def train_nfold(self, x_train, y_train, vocab_init=None, incremental=False, call

# This trick avoids an exception being thrown when the --multi-gpu approach is used on a single GPU system.
# It might be removed with TF 2.10 https://github.com/tensorflow/tensorflow/issues/50487
import atexit
atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore
if version.parse(tf.__version__) < version.parse('2.10.0'):
import atexit
atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore

with strategy.scope():
self.train_nfold_(x_train, y_train,vocab_init, incremental, callbacks)
Expand Down

0 comments on commit 92d9573

Please sign in to comment.