-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathot_dataset.py
26 lines (21 loc) · 1.01 KB
/
ot_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from torchtext.datasets import TranslationDataset
import torchtext.data as data
class OTDataset(TranslationDataset):
    """Defines a dataset for translation between output examples and OT constraint rankings."""

    def __init__(self, tupled_examples, fields, **kwargs):
        """Create an OTDataset from in-memory (source, target) pairs.

        Unlike ``TranslationDataset``, nothing is read from files here: each
        pair in ``tupled_examples`` is converted directly into a
        ``data.Example``.

        Arguments:
            tupled_examples: Iterable of 2-item ``(src_words, trg_ranking)``
                pairs. Each item is handed unchanged to the corresponding
                field via ``data.Example.fromlist`` (whether items are raw
                strings or pre-tokenized sequences depends on how the fields
                are configured — confirm against callers).
            fields: Either a pair ``(src_field, trg_field)``, which is bound
                to the attribute names ``'src'`` and ``'trg'``, or a sequence
                of ``(name, field)`` tuples used as-is.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        # Accept a bare (src_field, trg_field) pair by naming it src/trg.
        if not isinstance(fields[0], (tuple, list)):
            fields = [('src', fields[0]), ('trg', fields[1])]
        examples = [
            data.Example.fromlist([src_words, trg_ranking], fields)
            for src_words, trg_ranking in tupled_examples
        ]
        # Intentionally bypass TranslationDataset.__init__, whose signature
        # expects file paths/extensions to read data from. Naming
        # TranslationDataset in super() jumps to the next class in the MRO
        # (presumably data.Dataset — verify against the installed torchtext).
        # That class accepts prebuilt examples. Do not change this to a bare
        # super().__init__().
        super(TranslationDataset, self).__init__(examples, fields, **kwargs)