-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path__main__.py
31 lines (25 loc) · 1.09 KB
/
__main__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from __params__ import MODEL, CRAWL_PLATFORM, SKIP_TRAINING, SKIP_PREDICTION
# Platforms that can be crawled and whose predictions are visualized.
PLATFORMS = ("twitter", "youtube", "bluesky")


def run_crawler(platform):
    """Run the crawler for a single *platform*.

    Args:
        platform: The platform name; must be one of ``PLATFORMS``.

    Raises:
        ValueError: If *platform* is not in ``PLATFORMS``.
    """
    # Fail loudly: an unrecognised value would otherwise skip every branch
    # below and the script would exit as if the crawl had succeeded.
    if platform not in PLATFORMS:
        raise ValueError(
            f"Unknown CRAWL_PLATFORM {platform!r}; expected one of {PLATFORMS}"
        )

    # Deferred imports: the crawl mode never loads the ML stack and the
    # pipeline mode never loads the crawlers.
    from asyncio import run
    from src.crawl import crawl_twitter, crawl_youtube, crawl_bluesky

    if platform == "twitter":
        # crawl_twitter() returns a coroutine, so it is driven by asyncio.run.
        run(crawl_twitter())
    elif platform == "youtube":
        crawl_youtube()
    elif platform == "bluesky":
        crawl_bluesky()


def run_pipeline():
    """Optionally fine-tune and evaluate MODEL, then predict and visualize.

    Training and evaluation are skipped when SKIP_TRAINING is set. For every
    platform in ``PLATFORMS``, ``predict`` is skipped when SKIP_PREDICTION is
    set, but ``visualize`` is always called with ``f"{platform}-predictions"``.
    """
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    from src import preprocess, train, evaluate, predict, visualize

    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    # NOTE(review): preprocess() still runs when SKIP_TRAINING is set even
    # though its splits are only consumed by training/evaluation below. Kept
    # in case it has side effects that predict() relies on -- confirm.
    training, validation, testing = preprocess(tokenizer)

    if not SKIP_TRAINING:
        # Loaded only here because train() is the model's sole consumer;
        # loading it when training is skipped just wastes time and memory.
        # NOTE(review): the meaning of the 3 labels is defined outside this
        # file.
        model = AutoModelForSequenceClassification.from_pretrained(
            MODEL, num_labels=3
        )
        trainer = train(model, tokenizer, training, validation)
        results = evaluate(trainer, testing)
        print(results)

    for platform in PLATFORMS:
        if not SKIP_PREDICTION:
            predict(platform)
        visualize(f"{platform}-predictions")


def main():
    """Crawl CRAWL_PLATFORM if it is set; otherwise run the model pipeline."""
    if CRAWL_PLATFORM:
        run_crawler(CRAWL_PLATFORM)
    else:
        run_pipeline()


if __name__ == "__main__":
    main()