Skip to content

Commit

Permalink
Merge pull request #14 from INGEOTEC/develop
Browse files Browse the repository at this point in the history
Version - 0.0.5
  • Loading branch information
mgraffg authored Sep 9, 2024
2 parents 268c598 + 7d44bc3 commit c3a1f82
Show file tree
Hide file tree
Showing 10 changed files with 408 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ jobs:
git config --global user.email "mgraffg@ieee.org"
git config --global user.name "mgraffg"
cd quarto
quarto publish gh-pages . --no-browser
quarto publish gh-pages dialectid.qmd --no-browser
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include dialectid/data/emojis.json.gz
4 changes: 2 additions & 2 deletions dialectid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

__version__ = '0.0.4'
__version__ = '0.0.5'

from dialectid.text_repr import BoW
from dialectid.text_repr import BoW, SeqTM
from dialectid.model import DialectId
Binary file added dialectid/data/emojis.json.gz
Binary file not shown.
68 changes: 67 additions & 1 deletion dialectid/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from dataclasses import dataclass
import importlib
import numpy as np
from dialectid.utils import BOW, load_dialectid
from dialectid.utils import BOW, load_dialectid, load_seqtm

@dataclass
class DialectId:
Expand Down Expand Up @@ -87,3 +87,69 @@ def predict(self, D: List[Union[dict, list, str]]) -> np.ndarray:

hy = self.decision_function(D)
return self.countries[hy.argmax(axis=1)]


@dataclass
class DenseBoW:
"""DenseBoW"""

lang: str='es'
voc_size_exponent: int=13
precision: int=32

def estimator(self, **kwargs):
"""Estimator"""

from sklearn.svm import LinearSVC
return LinearSVC(class_weight='balanced')

@property
def bow(self):
"""BoW"""

try:
return self._bow
except AttributeError:
from dialectid.text_repr import SeqTM
self._bow = SeqTM(language=self.lang,
voc_size_exponent=self.voc_size_exponent)
return self._bow

@property
def weights(self):
"""Weights"""
try:
return self._weights
except AttributeError:
iterator = load_seqtm(self.lang,
self.voc_size_exponent,
self.precision)
precision = getattr(np, f'float{self.precision}')
weights = []
names = []
for data in iterator:
_ = np.frombuffer(bytes.fromhex(data['coef']), dtype=precision)
weights.append(_)
names.append(data['labels'][-1])
self._weights = np.vstack(weights)
self._names = np.array(names)
return self._weights

@property
def names(self):
"""Vector space components"""

return self._names

def encode(self, text):
"""Encode utterace into a matrix"""

token2id = self.bow.token2id
seq = []
for token in self.bow.tokenize(text):
try:
seq.append(token2id[token])
except KeyError:
continue
W = self.weights
return np.vstack([W[:, x] for x in seq]).T
18 changes: 18 additions & 0 deletions dialectid/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,21 @@ def test_DialectId_subwords():
dialectid = DialectId(voc_size_exponent=15)
countries = dialectid.predict('comiendo tacos')
assert countries[0] == 'mx'


def test_DenseBoW():
"""Test DenseBoW based on SeqTM"""

from dialectid.model import DenseBoW

dense = DenseBoW(precision=16)
assert dense.weights.shape[0] == dense.names.shape[0]
dense.weights[0, 0] > 25


def test_DenseBoW_encode():
"""Test DenseBoW sentence repr"""

from dialectid.model import DenseBoW
dense = DenseBoW(precision=16)
assert dense.encode('buenos dΓ­as').shape[1] == 2
65 changes: 63 additions & 2 deletions dialectid/tests/test_text_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# https://www.cia.gov/the-world-factbook/about/archives/2021/field/languages/


from dialectid.text_repr import BoW
from dialectid.text_repr import BoW, SeqTM
import numpy as np


Expand All @@ -44,4 +44,65 @@ def test_subwords():
bow = BoW(lang='es', voc_size_exponent=13,
subwords=True)
bow.transform(['Hola'])



def test_SeqTM():
"""Test SeqTM class"""

seq = SeqTM(language='es', subwords=True,
sequence=False,
voc_selection='most_common_by_type',
voc_size_exponent=13)
assert seq.language == 'es'
assert seq.voc_size_exponent == 13
_ = [['dias', 'q:~dur', 'q:os~']]
assert seq.compute_tokens('~dias~duros~') == _
assert seq.compute_tokens('~🀷~') == [['🀷']]
assert seq.compute_tokens('~πŸ™‡πŸΏ~') == [['πŸ™‡']]
assert seq.tokenize('buenos dias πŸ™‡πŸΏ')[-1] == 'πŸ™‡'


def test_SeqTM_bug():
"""Test SeqTM class"""

seq = SeqTM(language='es', subwords=True,
sequence=False,
voc_selection='most_common_by_type',
voc_size_exponent=13)
res1 = seq.tokenize('mira pinche a')
res2 = seq.tokenize('a pinche a')
assert res1[1:] == res2[1:]


def test_SeqTM_seq():
"""Test SeqTM seq option"""

seq = SeqTM(language='es', sequence=True,
voc_selection='most_common',
voc_size_exponent=13)
res1 = seq.tokenize('mira pinche a')
res2 = seq.tokenize('a pinche a')
assert res1[1:] == res2[1:]


def test_SeqTM_seq_bug():
"""Test SeqTM seq option"""

seq = SeqTM(language='es', sequence=True,
voc_selection='most_common',
voc_size_exponent=13)
assert seq.del_dup == False


def test_SeqTM_names():
seq = SeqTM(language='es', sequence=True,
voc_selection='most_common',
voc_size_exponent=13)
assert len(seq.names) == len(seq.model.word2id)


def test_SeqTM_weights():
seq = SeqTM(language='es', sequence=True,
voc_selection='most_common',
voc_size_exponent=13)
assert len(seq.weights) == len(seq.names)
Loading

0 comments on commit c3a1f82

Please sign in to comment.