-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathtest_slot.py
64 lines (52 loc) · 1.59 KB
/
test_slot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import json
import pickle
from argparse import ArgumentParser, ArgumentTypeError, Namespace
from pathlib import Path
from typing import Dict, Optional, Sequence

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import SeqTaggingClsDataset
from model import SeqTagger
from utils import Vocab
def main(args):
    """Script entry point; receives the ``Namespace`` built by ``parse_args``.

    Not implemented yet. The intended use of ``args`` (dataset, cache,
    checkpoint and prediction paths, model hyper-parameters, device) is
    presumably to run the slot tagger and write ``args.pred_file`` -- TODO
    confirm once implemented.

    Raises:
        NotImplementedError: unconditionally, until the body is written.
    """
    # TODO: implement main function
    raise NotImplementedError()
def _str2bool(value: str) -> bool:
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` (every non-empty string is truthy), so such a flag could never
    be disabled from the command line.

    Raises:
        ArgumentTypeError: if ``value`` is not a recognised boolean spelling;
            argparse reports this as a normal usage error.
    """
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv: Optional[Sequence[str]] = None) -> Namespace:
    """Build this script's command-line parser and parse the arguments.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            original zero-argument version did.

    Returns:
        The parsed ``Namespace``. argparse runs string defaults through their
        ``type``, so the path options are ``pathlib.Path`` and ``device`` is a
        ``torch.device`` even when not given on the command line.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--data_dir",
        type=Path,
        help="Directory to the dataset.",
        default="./data/slot/",
    )
    parser.add_argument(
        "--cache_dir",
        type=Path,
        help="Directory to the preprocessed caches.",
        default="./cache/slot/",
    )
    parser.add_argument(
        "--ckpt_dir",
        type=Path,
        help="Directory to save the model file.",
        default="./ckpt/slot/",
    )
    parser.add_argument("--pred_file", type=Path, default="pred.slot.csv")

    # data
    parser.add_argument("--max_len", type=int, default=128)

    # model
    parser.add_argument("--hidden_size", type=int, default=512)
    parser.add_argument("--num_layers", type=int, default=2)
    parser.add_argument("--dropout", type=float, default=0.1)
    # _str2bool instead of bool so that "--bidirectional False" really
    # yields False; the non-string default True is kept as-is by argparse.
    parser.add_argument(
        "--bidirectional",
        type=_str2bool,
        default=True,
        help="true/false (also yes/no, 1/0).",
    )

    # data loader
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument(
        "--device", type=torch.device, help="cpu, cuda, cuda:0, cuda:1", default="cpu"
    )

    return parser.parse_args(argv)
if __name__ == "__main__":
    # Only when run as a script: parse the command line, then hand the
    # resulting Namespace to main().
    args = parse_args()
    main(args)