Skip to content

Commit

Permalink
SeqTM
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Aug 23, 2024
1 parent e3947c2 commit 7df00fb
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 6 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include dialectid/data/emojis.json.gz
2 changes: 1 addition & 1 deletion dialectid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

__version__ = '0.0.4'
__version__ = '0.0.5'

from dialectid.text_repr import BoW
from dialectid.model import DialectId
Binary file added dialectid/data/emojis.json.gz
Binary file not shown.
16 changes: 14 additions & 2 deletions dialectid/tests/test_text_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# https://www.cia.gov/the-world-factbook/about/archives/2021/field/languages/


from dialectid.text_repr import BoW
from dialectid.text_repr import BoW, SeqTM
import numpy as np


Expand All @@ -44,4 +44,16 @@ def test_subwords():
bow = BoW(lang='es', voc_size_exponent=13,
subwords=True)
bow.transform(['Hola'])



def test_SeqTM():
    """Check SeqTM's constructor, token computation, and tokenizer."""

    model = SeqTM(lang='es', subwords=True, voc_size_exponent=13)
    assert model.language == 'es'
    assert model.voc_size_exponent == 13
    expected = [['q:~dia', 'q:s~', 'duro', 'q:s~']]
    assert model.compute_tokens('~dias~duros~') == expected
    # a bare emoji is kept as a token; skin-tone modifiers are dropped
    assert model.compute_tokens('~🤷~') == [['🤷']]
    assert model.compute_tokens('~🙇🏿~') == [['🙇']]
    assert model.tokenize('buenos dias 🙇🏿')[-1] == '🙇'
193 changes: 190 additions & 3 deletions dialectid/text_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from collections import OrderedDict
from os.path import join, dirname
from EvoMSA import BoW as EvoMSABoW
from EvoMSA.utils import b4msa_params
from b4msa.textmodel import TextModel
from microtc.weighting import TFIDF
from microtc import emoticons
from microtc.utils import tweet_iterator
from dialectid.utils import load_bow


Expand All @@ -48,6 +51,7 @@ def __init__(self, pretrain: bool=True,
assert loc is None
loc = 'qgrams'
self.loc = loc
self.subwords = subwords
if estimator_kwargs is None:
estimator_kwargs = {'dual': True, 'class_weight': 'balanced'}
super().__init__(pretrain=pretrain,
Expand All @@ -63,6 +67,15 @@ def loc(self):
def loc(self, value):
self._loc = value

@property
def subwords(self) -> bool:
    """Whether the subword (q-gram) vocabulary is used"""
    return self._subwords

@subwords.setter
def subwords(self, value: bool):
    # plain storage; the flag is only consulted in __init__ to pick the vocabulary
    self._subwords = value

@property
def bow(self):
"""BoW"""
Expand All @@ -72,7 +85,7 @@ def bow(self):
data = load_bow(lang=self.lang,
d=self.voc_size_exponent,
func=self.voc_selection,
loc=self._loc)
loc=self.loc)
params = data['params']
counter = data['counter']
params.update(self.b4msa_kwargs)
Expand All @@ -82,4 +95,178 @@ def bow(self):
tfidf.word2id, tfidf.wordWeight = tfidf.counter2weight(counter)
bow.model = tfidf
self._bow = bow
return bow
return bow

class SeqTM(TextModel):
    """TextModel where the utterance is segmented in a sequence.

    The pre-trained vocabulary (loaded with :py:func:`load_bow`) is turned
    into a trie so the text can be scanned left-to-right, greedily matching
    the longest known token at each position.
    """

    def __init__(self, language='es',
                 voc_size_exponent: int=17,
                 voc_selection: str='most_common_by_type',
                 loc: str=None,
                 subwords: bool=True,
                 **kwargs):
        """
        :param language: language of the pre-trained text representation
        :type language: str
        :param voc_size_exponent: vocabulary size is :math:`2^v`
        :type voc_size_exponent: int
        :param voc_selection: method used to select the vocabulary
        :type voc_selection: str
        :param loc: location/country; must be None when subwords is True
        :type loc: str
        :param subwords: use the subword (q-gram) vocabulary
        :type subwords: bool
        """
        if subwords:
            # the subword vocabulary is stored under the 'qgrams' location,
            # so an explicit loc cannot be combined with subwords
            assert loc is None
            loc = 'qgrams'
        # surface form -> vocabulary key; filled by __vocabulary
        self._map = {}
        data = load_bow(lang=language,
                        d=voc_size_exponent,
                        func=voc_selection,
                        loc=loc)
        params = data['params']
        counter = data['counter']
        # caller-supplied keyword arguments override the pre-trained params
        params.update(kwargs)
        super().__init__(**params)
        self.language = language
        self.voc_size_exponent = voc_size_exponent
        self.voc_selection = voc_selection
        self.loc = loc
        self.subwords = subwords
        self.__vocabulary(counter)

    def __vocabulary(self, counter):
        """Build the TFIDF model and the surface-form -> vocabulary-key map.

        :param counter: pre-trained token statistics used to weight the TFIDF
        """

        tfidf = TFIDF()
        tfidf.N = counter.update_calls
        tfidf.word2id, tfidf.wordWeight = tfidf.counter2weight(counter)
        self.model = tfidf
        tokens = self.tokens
        for value in tfidf.word2id:
            key = value
            if value[:2] == 'q:':
                # q-grams are stored with a 'q:' prefix; the surface form
                # (prefix removed) is what is matched against the text
                key = value[2:]
            self._map[key] = value
            tokens[key] = value
        # extend the vocabulary with the emojis shipped inside the package
        _ = join(dirname(__file__), 'data', 'emojis.json.gz')
        emojis = next(tweet_iterator(_))
        for k, v in emojis.items():
            self._map[k] = v
            tokens[k] = v

    @property
    def language(self):
        """Language of the pre-trained text representations"""

        return self._language

    @language.setter
    def language(self, value):
        self._language = value

    @property
    def voc_selection(self):
        """Method used to select the vocabulary"""

        return self._voc_selection

    @voc_selection.setter
    def voc_selection(self, value):
        self._voc_selection = value

    @property
    def voc_size_exponent(self):
        """Vocabulary size :math:`2^v`; where :math:`v` is :py:attr:`voc_size_exponent` """
        return self._voc_size_exponent

    @voc_size_exponent.setter
    def voc_size_exponent(self, value):
        self._voc_size_exponent = value

    @property
    def loc(self):
        """Location/Country"""

        return self._loc

    @loc.setter
    def loc(self, value):
        self._loc = value

    @property
    def subwords(self):
        """Whether the subword (q-gram) vocabulary is used"""

        return self._subwords

    @subwords.setter
    def subwords(self, value):
        self._subwords = value

    @property
    def tokens(self):
        """Mapping from surface form to vocabulary key (created lazily)"""

        try:
            return self._tokens
        except AttributeError:
            self._tokens = OrderedDict()
            return self._tokens

    @property
    def data_structure(self):
        """Trie (nested dicts with '__end__' markers) built from
        :py:attr:`tokens`; created lazily and used by :py:func:`find_token`."""

        try:
            return self._data_structure
        except AttributeError:
            _ = emoticons.create_data_structure
            self._data_structure = _(self.tokens)
            return self._data_structure

    def compute_tokens(self, text):
        """
        Vocabulary keys of the tokens found in the text
        :param text: normalized text
        :type text: str
        :returns: a list containing one list with the vocabulary keys
        :rtype: list
        """

        get = self._map.get
        lst = self.find_token(text)
        _ = [text[a:b] for a, b in lst]
        # map each surface form back to its vocabulary key; unknown
        # surface forms are kept as-is
        return [[get(x, x) for x in _]]

    def find_token(self, text):
        """Obtain the position of each token in the text
        :param text: text
        :type text: str
        :return: list of pairs, init and end of the word
        :rtype: list
        """

        blocks = list()
        init = i = end = 0
        head = self.data_structure
        current = head
        text_length = len(text)
        while i < text_length:
            char = text[i]
            try:
                # walk the trie; '__end__' marks a complete token, so end
                # always records the longest match seen from init
                current = current[char]
                i += 1
                if "__end__" in current:
                    end = i
            except KeyError:
                # mismatch: restart the trie walk
                current = head
                if end > init:
                    # a complete token was matched; keep it
                    blocks.append([init, end])
                    if (end - init) > 2 and text[end - 1] == '~':
                        # the match ends on a '~' separator; rescan from the
                        # '~' so it can also begin the next token
                        init = i = end = end - 1
                    else:
                        init = i = end
                elif i > init:
                    # a partial (incomplete) match; discard it, with the same
                    # '~'-boundary treatment as above
                    if (i - init) > 2 and text[i - 1] == '~':
                        init = end = i = i - 1
                    else:
                        init = end = i
                else:
                    # no progress from this character; skip it
                    init += 1
                    i = end = init
        if end > init:
            # flush the trailing match, if any
            blocks.append([init, end])
        return blocks

0 comments on commit 7df00fb

Please sign in to comment.