fix the closing pool only for tf < 2.10.0
lfoppiano committed Sep 1, 2023
1 parent 8a03944 commit f8df013
Showing 2 changed files with 9 additions and 5 deletions.
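
Both files apply the same guard: the atexit workaround that closes MirroredStrategy's collective-ops pool is now registered only when the installed TensorFlow is older than 2.10.0. A minimal, self-contained sketch of the pattern (it touches the same TF-internal attributes as the diffs below, so it is illustrative rather than a supported API):

    import atexit

    import tensorflow as tf
    from packaging import version

    strategy = tf.distribute.MirroredStrategy()

    # Workaround for https://github.com/tensorflow/tensorflow/issues/50487:
    # on TF < 2.10.0, close the strategy's private collective-ops pool at
    # interpreter exit so single-GPU runs do not raise during shutdown.
    if version.parse(tf.__version__) < version.parse('2.10.0'):
        atexit.register(strategy._extended._collective_ops._pool.close)  # type: ignore
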
delft/sequenceLabelling/tagger.py: 7 changes (4 additions, 3 deletions)
@@ -2,6 +2,7 @@
 
 import numpy as np
 import tensorflow as tf
+from packaging import version
 
 from delft.sequenceLabelling.data_generator import DataGeneratorTransformers
 from delft.sequenceLabelling.preprocess import Preprocessor
@@ -23,16 +24,16 @@ def __init__(self,
         self.model_config = model_config
         self.embeddings = embeddings
 
-
     def tag(self, texts, output_format, features=None, multi_gpu=False):
         if multi_gpu:
             strategy = tf.distribute.MirroredStrategy()
             print('Running with multi-gpu. Number of devices: {}'.format(strategy.num_replicas_in_sync))
 
             # This trick avoids an exception being thrown when the --multi-gpu approach is used on a single-GPU system.
             # It might be removed with TF 2.10 https://github.com/tensorflow/tensorflow/issues/50487
-            import atexit
-            atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore
+            if version.parse(tf.__version__) < version.parse('2.10.0'):
+                import atexit
+                atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore
 
         with strategy.scope():
             return self.tag_(texts, output_format, features)
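
The gate uses packaging.version rather than a plain string comparison because version strings do not order correctly as text; a quick, self-contained illustration:

    from packaging import version

    # '2.9.3' sorts after '2.10.0' lexicographically, which would wrongly skip
    # the workaround on TF 2.9.x; packaging.version compares numeric components.
    print('2.9.3' < '2.10.0')                                # False (string order)
    print(version.parse('2.9.3') < version.parse('2.10.0'))  # True  (version order)
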
delft/sequenceLabelling/wrapper.py: 7 changes (5 additions, 2 deletions)
@@ -1,5 +1,7 @@
 import os
 
+from packaging import version
+
 # ask tensorflow to be quiet and not print hundreds of lines of logs
 from delft.utilities.Transformer import TRANSFORMER_CONFIG_FILE_NAME, DEFAULT_TRANSFORMER_TOKENIZER_DIR
 from delft.utilities.misc import print_parameters
@@ -473,8 +475,9 @@ def tag(self, texts, output_format, features=None, batch_size=None, multi_gpu=False):
 
             # This trick avoids an exception being thrown when the --multi-gpu approach is used on a single-GPU system.
             # It might be removed with TF 2.10 https://github.com/tensorflow/tensorflow/issues/50487
-            import atexit
-            atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore
+            if version.parse(tf.__version__) < version.parse('2.10.0'):
+                import atexit
+                atexit.register(strategy._extended._collective_ops._pool.close) # type: ignore
 
         with strategy.scope():
             return self.tag_(texts, output_format, features, batch_size)
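
For context, the guarded branch is reached only when tag() is called with multi_gpu=True. A hypothetical call, assuming a delft sequence-labelling model has already been loaded into a variable named model (the variable name, sample text, and 'json' output format are illustrative, not taken from this commit):

    # Hypothetical usage sketch; `model` is assumed to be an already-loaded
    # instance of the wrapper class defined in delft/sequenceLabelling/wrapper.py.
    texts = ["Water boils at 373.15 K at standard pressure."]

    # multi_gpu=True runs inference under a tf.distribute.MirroredStrategy scope;
    # on TF < 2.10.0 the commit above also registers the pool-closing workaround.
    results = model.tag(texts, "json", multi_gpu=True)
    print(results)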
