Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Aug 2, 2024
2 parents 37d09d7 + 3eb1dad commit 9ec0892
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 9 deletions.
10 changes: 8 additions & 2 deletions dialectid/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class DialectId:
"""DialectId"""
lang: str='es'
voc_size_exponent: int=15
subwords: bool=True

@property
def bow(self):
Expand All @@ -43,8 +44,12 @@ def bow(self):
path = BOW[self.lang].split('.')
module = '.'.join(path[:-1])
text_repr = importlib.import_module(module)
kwargs = {}
if module != 'EvoMSA.text_repr':
kwargs = dict(subwords=self.subwords)
_ = getattr(text_repr, path[-1])(lang=self.lang,
voc_size_exponent=self.voc_size_exponent)
voc_size_exponent=self.voc_size_exponent,
**kwargs)
self._bow = _
return self._bow

Expand All @@ -55,7 +60,8 @@ def weights(self):
return self._weights
except AttributeError:
self._weights = load_dialectid(self.lang,
self.voc_size_exponent)
self.voc_size_exponent,
self.subwords)
return self._weights

@property
Expand Down
16 changes: 12 additions & 4 deletions dialectid/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_DialectId():
from dialectid.model import DialectId
from dialectid import BoW

dialectid = DialectId(voc_size_exponent=15)
dialectid = DialectId(voc_size_exponent=15, subwords=False)
assert dialectid.lang == 'es' and dialectid.voc_size_exponent == 15
assert isinstance(dialectid.bow, BoW)

Expand All @@ -38,7 +38,7 @@ def test_DialectId_df():

from dialectid.model import DialectId

dialectid = DialectId(voc_size_exponent=15)
dialectid = DialectId(voc_size_exponent=15, subwords=False)
hy = dialectid.decision_function('comiendo tacos')
assert hy.shape == (1, 20)
assert hy.argmax(axis=1)[0] == 0
Expand All @@ -49,7 +49,7 @@ def test_countries():

from dialectid.model import DialectId

dialectid = DialectId(voc_size_exponent=15)
dialectid = DialectId(voc_size_exponent=15, subwords=False)
assert len(dialectid.countries) == 20
assert dialectid.countries[0] == 'mx'

Expand All @@ -59,10 +59,18 @@ def test_predict():

from dialectid.model import DialectId

dialectid = DialectId(voc_size_exponent=15)
dialectid = DialectId(voc_size_exponent=15, subwords=False)
countries = dialectid.predict('comiendo tacos')
assert countries[0] == 'mx'
countries = dialectid.predict(['comiendo tacos',
'tomando vino'])
assert countries.shape == (2, )


def test_DialectId_subwords():
    """Test DialectId prediction when `subwords` is left at its default."""

    from dialectid.model import DialectId

    # No `subwords` argument is passed here, unlike the other tests in
    # this module, so the class default is the configuration under test.
    model = DialectId(voc_size_exponent=15)
    predicted = model.predict('comiendo tacos')
    assert predicted[0] == 'mx'
7 changes: 5 additions & 2 deletions dialectid/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,16 @@ def load(filename):
return data


def load_dialectid(lang, dim):
def load_dialectid(lang, dim, subwords):
"""Load url"""

diroutput = join(dirname(__file__), 'models')
if not isdir(diroutput):
os.mkdir(diroutput)
filename = f'dialectid_{lang}_{dim}.json.gz'
if subwords:
filename = f'dialectid_subwords_{lang}_{dim}.json.gz'
else:
filename = f'dialectid_{lang}_{dim}.json.gz'
output = join(diroutput, filename)
if not isfile(output):
Download(f'{BASEURL}/{filename}', output)
Expand Down
22 changes: 21 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[project]
name = 'dialectid'
description = "Set of algorithms to detect the dialect of a given text"
readme = "README.rst"
dependencies = [
'numpy',
'scikit-learn>=1.3.0',
Expand All @@ -8,9 +10,27 @@ dependencies = [
'EvoMSA'
]
dynamic = ['version']
classifiers = [
"Development Status :: 3 - Alpha",
"Environment :: Console",
"Intended Audience :: Developers",
"Intended Audience :: Information Technology",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Topic :: Scientific/Engineering :: Artificial Intelligence"
]



[tool.setuptools.dynamic]
version = {attr = 'dialectid.__version__'}

[tool.setuptools]
packages = ['dialectid', 'dialectid.tests']
packages = ['dialectid', 'dialectid.tests']

[project.urls]
Homepage = "https://ingeotec.github.io/dialectid"
Repository = "https://github.com/INGEOTEC/dialectid"
Issues = "https://github.com/INGEOTEC/dialectid/issues"

0 comments on commit 9ec0892

Please sign in to comment.