-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo.py
94 lines (79 loc) · 2.08 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import json

import altair as alt
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
# Classifier output labels, in logit order: the i-th logit returned by
# `inference` is displayed under cols[i] (JSON panel and bar chart).
# NOTE(review): order must match the model's label mapping — confirm against
# the checkpoint's config.id2label.
cols = "stereotype anti-stereotype unrelated profession race gender religion".split()
def inference(inputs: str):
    """Run the stereotype classifier on one sentence and return its logits.

    Relies on the module-level ``tokenizer``, ``model`` and ``device`` names,
    which are only created inside the ``__main__`` block.

    Args:
        inputs: The sentence to classify.

    Returns:
        The raw logits of the first (and only) batch item as a list of
        floats; element ``i`` is displayed under ``cols[i]`` by the UI.
    """
    # Truncate to the model's 512-token limit, but do not pad: a single
    # sentence needs no padding, and padding to 512 made every call run the
    # encoder over hundreds of masked filler tokens on CPU.
    tokenized = tokenizer(
        inputs,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    ).to(device)
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**tokenized)
    return outputs.logits.tolist()[0]
def write_header():
    """Render the page title followed by the short usage notes."""
    usage_notes = """
- Write any sentence containing stereotypes and click the Run button.
- This application is made with TUNiB-Electra and K-StereoSet.
- Using CPU, the result might be slow
"""
    st.title("Korean Stereotype Detector")
    st.markdown(usage_notes)
def get_json_str(lists, labels=None):
    """Format per-class scores as a tab-indented JSON object string.

    Args:
        lists: Scores to report (e.g. the logits returned by ``inference``);
            ``lists[i]`` is written under the key ``labels[i]``.
        labels: One key per score, in the same order. Defaults to the
            module-level ``cols``.

    Returns:
        A JSON object string with one ``"label": score`` pair per line,
        indented with a tab.

    Raises:
        ValueError: If there are more scores than labels.
    """
    if labels is None:
        labels = cols
    values = list(lists)
    if len(values) > len(labels):
        raise ValueError(
            f"got {len(values)} scores but only {len(labels)} labels"
        )
    # json.dumps produces valid JSON; the previous hand-built string left a
    # trailing comma after the last entry.
    return json.dumps(dict(zip(labels, values)), indent="\t")
def write_textbox():
    """Render the input box and Run button, and show results when Run is clicked.

    Streamlit re-executes the script on every widget interaction, so the
    (slow, CPU-bound) model call is made only on the run where the button
    was pressed, not on every rerun.
    """
    input_text = st.text_area(label="Write your sentence", key=1, height=40)
    button = st.button(label="Run")
    st.markdown(
        """
#
## result
***
"""
    )
    if not button:
        return
    if not input_text.strip():
        st.warning("Please write a sentence first.")
        return

    output = inference(input_text)
    col1, col2 = st.columns([5, 5])
    with col1:
        st.code(get_json_str(output))
    with col2:
        st.write(
            alt.Chart(
                pd.DataFrame({"Class": cols, "Logits": output}),
                width=490,
                height=360,
            )
            .mark_bar()
            # sort=None keeps bars in label order (matching the JSON panel)
            # instead of the default alphabetical sort of a nominal axis.
            .encode(x=alt.X("Class", sort=None), y="Logits")
        )
if __name__ == "__main__":
    st.set_page_config(
        page_title="Korean stereotype detector", page_icon="☮️", layout="wide"
    )

    # Streamlit re-runs this script on every interaction; cache the tokenizer
    # and model so they are loaded once per server process, not per rerun.
    _cache_resource = getattr(st, "cache_resource", None)
    if _cache_resource is None:  # older Streamlit without st.cache_resource
        _cache_resource = st.cache(allow_output_mutation=True)

    @_cache_resource
    def _load_classifier(name):
        """Load and return the (tokenizer, model) pair for checkpoint *name*."""
        return (
            AutoTokenizer.from_pretrained(name),
            AutoModelForSequenceClassification.from_pretrained(name),
        )

    tokenizer, model = _load_classifier(
        "dhtocks/tunib-electra-stereotype-classifier"
    )
    device = "cpu"
    write_header()
    write_textbox()