Skip to content

Commit

Permalink
docs(README,media): results and metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
AdityaNG committed May 9, 2024
1 parent af51c39 commit 5ed621d
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 18 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ python -m kan_gpt.prompt --prompt "Bangalore is often described as the " --model

We train and compare KAN-GPT with an equivalent MLP-GPT model on the Tiny Shakespeare dataset. We observe that the KAN-GPT performs slightly better than the MLP-GPT. We are looking into further experiments to dive deeper. The results are shown below:

<img src="media/results.png">

| Metrics | | |
|---------|---------|---------|
| <img src="media/results_loss.png"> | <img src="media/results_cross_entropy.png"> | <img src="media/results_perplexity.png"> |

## TODOs

Expand Down
Binary file removed media/results.png
Binary file not shown.
Binary file added media/results_cross_entropy.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added media/results_loss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added media/results_perplexity.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
75 changes: 58 additions & 17 deletions scripts/plot_results.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
import wandb
import numpy as np
import matplotlib.pyplot as plt

api = wandb.Api()

mlp_run = api.run("KAN-GPT/axi1qzwv")
kan_run = api.run("KAN-GPT/eusdq4te")
mlp_run = api.run("KAN-GPT/rk3dmrwh") # axi1qzwv
kan_run = api.run("KAN-GPT/m6msyzlz") # eusdq4te

keys = [
'train_loss', 'test_loss'
"train_loss",
"train_perplexity",
"train_f1",
"train_precision",
"train_recall",
"train_cross_entropy",
"test_loss",
"test_perplexity",
"test_f1",
"test_precision",
"test_recall",
"test_cross_entropy",
]

mlp_metrics = mlp_run.history(keys=keys)
Expand All @@ -20,27 +33,55 @@
print(kan_metrics)
print("="*20)

# Plot the test and train losses for the two models
metrics = [
"Loss",
"Perplexity",
"F1",
"Precision",
"Recall",
"Cross Entropy",
]

for metric in metrics:
id = metric.lower().replace(" ", "_").replace("-", "_")
# Plot the test and train losses for the two models

import matplotlib.pyplot as plt
kan_data = kan_metrics[[f"test_{id}", f"train_{id}"]]
mlp_data = mlp_metrics[[f"test_{id}", f"train_{id}"]]

kan_data = kan_data.dropna()
mlp_data = mlp_data.dropna()

print("MLP")
print(mlp_data)
print("="*20)

print("KAN")
print(kan_data)
print("="*20)

plt.plot(kan_data[f'test_{id}'].astype(np.float16), label='KAN Test', linestyle="--")
plt.plot(kan_data[f'train_{id}'].astype(np.float16), label='KAN Train')
plt.plot(mlp_data[f'test_{id}'].astype(np.float16), label='MLP Test', linestyle="--")
plt.plot(mlp_data[f'train_{id}'].astype(np.float16), label='MLP Train')

plt.plot(kan_metrics['test_loss'], label='KAN Test', linestyle="--")
plt.plot(kan_metrics['train_loss'], label='KAN Train')
# Add a legend and show the plot

plt.plot(mlp_metrics['test_loss'], label='MLP Test', linestyle="--")
plt.plot(mlp_metrics['train_loss'], label='MLP Train')
plt.xlabel('Steps')
plt.ylabel(metric)

# Add a legend and show the plot
plt.title(f"{metric} curves: KAN-GPT and MLP-GPT")

plt.xlabel('Steps')
plt.ylabel('Loss')
# Grid

plt.title("Training Curves of KAN-GPT and MLP-GPT")
plt.grid(True)

# Grid
plt.legend()
plt.draw()

plt.grid(True)
# Save to media/results_loss.png

plt.legend()
plt.show()
plt.savefig(f'media/results_{id}.png')

plt.show()
plt.cla()

0 comments on commit 5ed621d

Please sign in to comment.