diff --git a/dialectid/model.py b/dialectid/model.py index 4c31800..774b4be 100644 --- a/dialectid/model.py +++ b/dialectid/model.py @@ -32,6 +32,7 @@ class DialectId: """DialectId""" lang: str='es' voc_size_exponent: int=15 + subwords: bool=True @property def bow(self): @@ -43,8 +44,12 @@ def bow(self): path = BOW[self.lang].split('.') module = '.'.join(path[:-1]) text_repr = importlib.import_module(module) + kwargs = {} + if module != 'EvoMSA.text_repr': + kwargs = dict(subwords=self.subwords) _ = getattr(text_repr, path[-1])(lang=self.lang, - voc_size_exponent=self.voc_size_exponent) + voc_size_exponent=self.voc_size_exponent, + **kwargs) self._bow = _ return self._bow @@ -55,7 +60,8 @@ def weights(self): return self._weights except AttributeError: self._weights = load_dialectid(self.lang, - self.voc_size_exponent) + self.voc_size_exponent, + self.subwords) return self._weights @property diff --git a/dialectid/tests/test_model.py b/dialectid/tests/test_model.py index baf316b..e2acb29 100644 --- a/dialectid/tests/test_model.py +++ b/dialectid/tests/test_model.py @@ -28,7 +28,7 @@ def test_DialectId(): from dialectid.model import DialectId from dialectid import BoW - dialectid = DialectId(voc_size_exponent=15) + dialectid = DialectId(voc_size_exponent=15, subwords=False) assert dialectid.lang == 'es' and dialectid.voc_size_exponent == 15 assert isinstance(dialectid.bow, BoW) @@ -38,7 +38,7 @@ def test_DialectId_df(): from dialectid.model import DialectId - dialectid = DialectId(voc_size_exponent=15) + dialectid = DialectId(voc_size_exponent=15, subwords=False) hy = dialectid.decision_function('comiendo tacos') assert hy.shape == (1, 20) assert hy.argmax(axis=1)[0] == 0 @@ -49,7 +49,7 @@ def test_countries(): from dialectid.model import DialectId - dialectid = DialectId(voc_size_exponent=15) + dialectid = DialectId(voc_size_exponent=15, subwords=False) assert len(dialectid.countries) == 20 assert dialectid.countries[0] == 'mx' @@ -59,10 +59,18 @@ def test_predict(): from dialectid.model import DialectId - dialectid = DialectId(voc_size_exponent=15) + dialectid = DialectId(voc_size_exponent=15, subwords=False) countries = dialectid.predict('comiendo tacos') assert countries[0] == 'mx' countries = dialectid.predict(['comiendo tacos', 'tomando vino']) assert countries.shape == (2, ) + +def test_DialectId_subwords(): + """Test DialectId subwords""" + + from dialectid.model import DialectId + dialectid = DialectId(voc_size_exponent=15) + countries = dialectid.predict('comiendo tacos') + assert countries[0] == 'mx' diff --git a/dialectid/utils.py b/dialectid/utils.py index 2932c3b..5e1c5cf 100644 --- a/dialectid/utils.py +++ b/dialectid/utils.py @@ -146,13 +146,16 @@ def load(filename): return data -def load_dialectid(lang, dim): +def load_dialectid(lang, dim, subwords): """Load url""" diroutput = join(dirname(__file__), 'models') if not isdir(diroutput): os.mkdir(diroutput) - filename = f'dialectid_{lang}_{dim}.json.gz' + if subwords: + filename = f'dialectid_subwords_{lang}_{dim}.json.gz' + else: + filename = f'dialectid_{lang}_{dim}.json.gz' output = join(diroutput, filename) if not isfile(output): Download(f'{BASEURL}/{filename}', output) diff --git a/pyproject.toml b/pyproject.toml index 1a14217..0d035e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,7 @@ [project] name = 'dialectid' +description = "Set of algorithms to detect the dialect of a given text" +readme = "README.rst" dependencies = [ 'numpy', 'scikit-learn>=1.3.0', @@ -8,9 +10,27 @@ dependencies = [ 'EvoMSA' ] dynamic = ['version'] +classifiers = [ + "Development Status :: 3 - Alpha", + "Environment :: Console", + "Intended Audience :: Developers", + "Intended Audience :: Information Technology", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Topic :: Scientific/Engineering :: Artificial Intelligence" +] + + [tool.setuptools.dynamic] version = {attr = 'dialectid.__version__'} [tool.setuptools] -packages = ['dialectid', 'dialectid.tests'] \ No newline at end of file +packages = ['dialectid', 'dialectid.tests'] + +[project.urls] +Homepage = "https://ingeotec.github.io/dialectid" +Repository = "https://github.com/INGEOTEC/dialectid" +Issues = "https://github.com/INGEOTEC/dialectid/issues" \ No newline at end of file