Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validation Curves for SEA and Performer on OpenWebText Dataset (Figure A.5 in Appendix A.9) #9

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ plots/
submission/
submission*.zip

hello.png
hello.png
68 changes: 68 additions & 0 deletions src/main/plot/figure_opt_curve_web.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import math
import os
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
# Start from the 'seaborn-v0_8-bright' style sheet, then set the font family.
# The font override is applied after the style sheet so the sheet cannot undo it.
plt.style.use('seaborn-v0_8-bright')
matplotlib.rcParams.update({'font.family': 'Noto Sans, DejaVu Sans'})

# Source table for the figure: the script reads its 'Step' column and the
# metric columns listed in `colnames` from this CSV.
csv_path = "./plots/main/wandb_opt125_openwebtext.csv"
data = pd.read_csv(csv_path)
# One row per plotted curve: (CSV column name, legend label, method key).
# The three parallel structures below are derived from this single table so
# they cannot drift out of step with each other.
_SERIES = (
    ('opt125 k64 w64 openwebtext - eval/score', 'Ours', 'perlin'),
    ('opt125 performer openwebtext - eval/score', 'Performer', 'performer'),
)

# CSV columns to read, in plotting order.
colnames = [column for column, _, _ in _SERIES]
# Legend labels, aligned index-by-index with `colnames`.
names = [label for _, label, _ in _SERIES]
# Legend label -> method key used to look up the line colour in COLORS.
TOMETHOD = {label: method for _, label, method in _SERIES}

# Method key -> matplotlib colour name. Only the keys reachable through
# TOMETHOD are looked up by this script; the remaining entries are kept as-is.
COLORS = dict(
    none='green',
    perlin='magenta',
    performer='blue',
    reformer='purple',
    scatterbrain='gray',
    sinkhorn='orange',
    synthesizer='yellow',
)

# Build one (steps, values) pair per entry of `colnames`, keeping only the
# rows whose value in that column is not NaN (presumably steps at which the
# run did not log the metric -- TODO confirm against the CSV export).
#
# xss[k] / yss[k] hold the x and y values for colnames[k]; both are NumPy
# arrays of equal length.
steps = data['Step']  # same for every curve, so read it once outside the loop

xss = []
yss = []

for cn in colnames:
    column = data[cn]
    present = column.notna()  # boolean mask with one entry per CSV row
    xss.append(steps[present].to_numpy())
    yss.append(column[present].to_numpy())

# Draw every collected series as one line on a single 3.5 x 2.7 inch figure.
# The axes created here stay the current pyplot axes, so later pyplot calls
# (e.g. plt.savefig) still act on this figure.
fig, ax = plt.subplots(figsize=(3.5, 2.7))

# `names`, `xss` and `yss` are aligned index-by-index; the line colour is
# resolved from the legend label via TOMETHOD -> COLORS.
for label, steps, scores in zip(names, xss, yss):
    ax.plot(steps, scores, label=label, color=COLORS[TOMETHOD[label]])

ax.grid()
ax.legend()
ax.set_ylim(0, 150)  # fixed y-range; values above 150 are cut off
ax.set_xlabel('Optimizer Steps', fontweight=500)
ax.set_ylabel('PPL. ↓', fontweight=500)
ax.set_title('Validation Curve', fontweight=500)

# Write the current figure as both PDF and PNG into the output directory
# (created on demand), then print the PNG path.
root = './plots/main/figure_opt_curve_openwebtext'
os.makedirs(root, exist_ok=True)

pdf_path = os.path.join(root, 'plot_opt_curve_openwebtext.pdf')
png_path = os.path.join(root, 'plot_opt_curve_openwebtext.png')

plt.savefig(pdf_path, bbox_inches='tight')
plt.savefig(png_path, bbox_inches='tight')
print(png_path)
2 changes: 1 addition & 1 deletion src/main/visualize/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,4 @@ def process_batch_index(attentions: List[torch.Tensor], i: int, T: int, gs = [0.
# for img in tqdm.tqdm(iterator, dynamic_ncols=True, desc='render.layer', total=N):
# imgs.append(img)

return np.concatenate(imgs, axis=0)
return np.concatenate(imgs, axis=0)
2 changes: 1 addition & 1 deletion src/main/visualize/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,4 @@ def main(
'evaluate': args.evaluate
})

main(**kwargs)
main(**kwargs)
2 changes: 1 addition & 1 deletion src/models/perlin_bert/perlin_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2527,4 +2527,4 @@ def forward(
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
)
2 changes: 1 addition & 1 deletion src/trainer/glue_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,4 +431,4 @@ def main(self):
trainer = Trainer(
subset='mnli'
)
trainer.main()
trainer.main()