diff --git a/seq2seq/metrics/bleu.py b/seq2seq/metrics/bleu.py index 331f37c1..57fd85db 100644 --- a/seq2seq/metrics/bleu.py +++ b/seq2seq/metrics/bleu.py @@ -30,7 +30,8 @@ import tensorflow as tf -def moses_multi_bleu(hypotheses, references, lowercase=False): +def moses_multi_bleu(hypotheses, references, lowercase=False, + multi_bleu_path=None): """Calculate the bleu score for hypotheses and references using the MOSES ulti-bleu.perl script. @@ -38,6 +39,7 @@ def moses_multi_bleu(hypotheses, references, lowercase=False): hypotheses: A numpy array of strings where each string is a single example. references: A numpy array of strings where each string is a single example. lowercase: If true, pass the "-lc" flag to the multi-bleu script + multi_bleu_path: Path where the downloaded multi-bleu.perl script is stored. Defaults to multi-bleu.perl inside the system temporary directory (tempfile.gettempdir()). Returns: The BLEU score as a float32 value. @@ -48,9 +50,12 @@ def moses_multi_bleu(hypotheses, references, lowercase=False): # Get MOSES multi-bleu script try: - multi_bleu_path, _ = urllib.request.urlretrieve( + if multi_bleu_path == None: + multi_bleu_path = os.path.join(tempfile.gettempdir(), "multi-bleu.perl") + urllib.request.urlretrieve( "https://raw.githubusercontent.com/moses-smt/mosesdecoder/" - "master/scripts/generic/multi-bleu.perl") + "master/scripts/generic/multi-bleu.perl", + multi_bleu_path) os.chmod(multi_bleu_path, 0o755) except: #pylint: disable=W0702 tf.logging.info("Unable to fetch multi-bleu.perl script, using local.")