-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo.py
32 lines (24 loc) · 788 Bytes
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import logging
import hydra
import omegaconf
import wandb
from utilpy import log_init
from src.configs import Config
from src.fastdtm import DTM
@hydra.main(version_base=None, config_path="src/configs/", config_name="demo")
def main(cfg: Config):
    """Fit a DTM on a tiny hard-coded corpus and save the results.

    Hydra builds ``cfg`` from the ``demo`` config under ``src/configs/``.
    The fields read here are ``cfg.model`` (passed to ``DTM``),
    ``cfg.data.epochs`` (passed to ``estimate``) and ``cfg.data.output_dir``
    (passed to ``save_data``).

    Raises:
        Exception: any error raised while constructing, initializing,
            estimating or saving the model. It is logged with its traceback
            on the ``"main"`` logger and then re-raised, so a failed run is
            not reported as a success.
    """
    log_init()
    logger = logging.getLogger("main")

    # Toy corpus: three top-level groups, each a list of documents written
    # as lists of integer word ids.
    # NOTE(review): presumably the outer level is the time slice and the ids
    # index into `vocabulary` -- confirm against DTM. No document uses id 0;
    # if ids are 0-based, "func" never occurs. Verify that is intended.
    docs = [[[1, 2, 5], [3, 4], [2]], [[1, 3]], [[1], [5, 2]]]
    vocabulary = ["func", "yellow", "prefix", "func1", "yellow1", "prefix1"]

    try:
        dtm = DTM(docs, vocabulary, cfg.model)
        logger.info("start initialize")
        # NOTE(review): the meaning of this positional flag is not visible
        # here -- check DTM.initialize's signature.
        dtm.initialize(True)
        logger.info("start estimate")
        dtm.estimate(cfg.data.epochs)
        dtm.save_data(cfg.data.output_dir)
    except Exception as ex:
        # Record the full traceback, then propagate. Previously the error
        # was swallowed here, so a crashed run returned normally.
        logger.exception("DTM demo failed: %s", ex)
        raise
if __name__ == "__main__":
    # Run only when executed as a script, not on import. The @hydra.main
    # decorator on `main` supplies its `cfg` argument, so it is called
    # with no arguments here.
    main()