Skip to content

Commit

Permalink
Add support for always detecting specific languages in language identification
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipMay committed Jan 29, 2024
1 parent 9e83c95 commit dae8c67
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
10 changes: 9 additions & 1 deletion mltb2/fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import os
from dataclasses import dataclass, field
from typing import List, Optional

import fasttext
from fasttext.FastText import _FastText
Expand Down Expand Up @@ -51,12 +52,15 @@ def get_model_path_and_download() -> str:

return model_full_path

def __call__(self, text: str, num_lang: int = 10):
def __call__(self, text: str, num_lang: int = 10, always_detect_lang: Optional[List[str]] = None):
"""Identify languages of a given text.
Args:
text: the text for which the language is to be recognized
num_lang: number of returned languages
always_detect_lang: A list of languages that should always be returned
even if not detected. If the language is not detected, the probability
is set to 0.0.
Returns:
A dict from language to probability.
This dict contains no more than ``num_lang`` elements.
Expand All @@ -76,4 +80,8 @@ def __call__(self, text: str, num_lang: int = 10):
languages = predictions[0]
probabilities = predictions[1]
lang_to_prob = {lang[9:]: prob for lang, prob in zip(languages, probabilities)}
if always_detect_lang is not None:
for lang in always_detect_lang:
if lang not in lang_to_prob:
lang_to_prob[lang] = 0.0
return lang_to_prob
11 changes: 11 additions & 0 deletions tests/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,14 @@ def test_fasttext_language_identification_call():
languages = language_identification("This is an English sentence.")
assert languages is not None
assert len(languages) == 10


def test_fasttext_language_identification_call_with_always_detect_lang():
    """Check that ``always_detect_lang`` adds missing languages with probability 0.0.

    The expectation is derived from a baseline call instead of relying only on
    a hard-coded count, so the test still holds if the model happens to rank
    ``"de"`` among its regular predictions for the sample sentence.
    """
    text = "This is an English sentence."
    language_identification = FastTextLanguageIdentification()

    # Baseline without forced languages: the default ``num_lang`` is 10.
    languages = language_identification(text)
    assert languages is not None
    assert len(languages) == 10

    languages_with_de = language_identification(text, always_detect_lang=["de"])
    assert languages_with_de is not None
    assert "de" in languages_with_de

    # Forcing a language may only add that language; no other key may
    # appear or disappear compared to the baseline call.
    assert set(languages_with_de) == set(languages) | {"de"}
    if "de" not in languages:
        # An undetected forced language is appended with probability 0.0.
        assert len(languages_with_de) == 11
        assert languages_with_de["de"] == 0.0

0 comments on commit dae8c67

Please sign in to comment.