feat(README.md): results added
AdityaNG committed May 8, 2024
1 parent 310d335 commit 7835d51
Showing 5 changed files with 64 additions and 2 deletions.
7 changes: 7 additions & 0 deletions README.md
@@ -85,6 +85,12 @@ You can prompt the model to produce text as follows
python -m kan_gpt.prompt --prompt "Bangalore is often described as the " --model_path (checkpoint)
```

## Results

We train and compare KAN-GPT against an equivalent MLP-GPT model on the Tiny Shakespeare dataset, and observe that KAN-GPT performs slightly better than MLP-GPT. Further experiments are planned to explore this in more depth. The results are shown below:

<img src="media/results.png">

## TODOs

- [x] Integrate [minGPT](https://github.com/karpathy/minGPT) and [pykan](https://github.com/KindXiaoming/pykan)
@@ -107,6 +113,7 @@ python -m kan_gpt.prompt --prompt "Bangalore is often described as the " --model
- [ ] Define pydantic model for training and sweep args
- [ ] Pruning the package, get rid of unused code
- [ ] Training script to PyTorch Lightning
- [x] Documentation: `mkdocs gh-deploy`
- [x] Integrate with [efficient-kan](https://github.com/Blealtan/efficient-kan/blob/master/src/efficient_kan/kan.py)
- [x] Test Cases
- [x] KAN: Forward-Backward test
4 changes: 2 additions & 2 deletions kan_gpt/prompt.py
@@ -52,12 +52,12 @@ def main(args):
import argparse

parser = argparse.ArgumentParser("KAN-GPT Trainer")
parser.add_argument("--model_type", default="gpt-mini")
parser.add_argument("--model_type", default="gpt-micro")
parser.add_argument("--model_path", default=None)
parser.add_argument("--max_tokens", default=100)

parser.add_argument(
"--prompt", default="Bangalore is often described as the "
"--prompt", default="Out of thy sleep. What is it thou didst say?"
)
parser.add_argument(
"--architecture", choices=["MLP", "KAN"], default="KAN"
Binary file added media/results.png
9 changes: 9 additions & 0 deletions scripts/download_checkpoints.py
@@ -0,0 +1,9 @@
import wandb

api = wandb.Api()

mlp_model = api.artifact("adityang/KAN-GPT/model:v35")
kan_model = api.artifact("adityang/KAN-GPT/model:v34")

mlp_model.download(root="weights/")
kan_model.download(root="weights/")
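The script above pulls both checkpoints into `weights/`. They can then be passed to the prompt command shown in the README; the exact checkpoint filename inside `weights/` depends on the contents of the W&B artifact, so the path below is only a placeholder:

```
python -m kan_gpt.prompt --prompt "Bangalore is often described as the " --model_path weights/<checkpoint-file>
```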
46 changes: 46 additions & 0 deletions scripts/plot_results.py
@@ -0,0 +1,46 @@
import wandb

api = wandb.Api()

mlp_run = api.run("KAN-GPT/axi1qzwv")
kan_run = api.run("KAN-GPT/eusdq4te")

keys = [
'train_loss', 'test_loss'
]

mlp_metrics = mlp_run.history(keys=keys)
kan_metrics = kan_run.history(keys=keys)

print("MLP")
print(mlp_metrics)
print("="*20)

print("KAN")
print(kan_metrics)
print("="*20)

# Plot the test and train losses for the two models

import matplotlib.pyplot as plt

plt.plot(kan_metrics['test_loss'], label='KAN Test', linestyle="--")
plt.plot(kan_metrics['train_loss'], label='KAN Train')

plt.plot(mlp_metrics['test_loss'], label='MLP Test', linestyle="--")
plt.plot(mlp_metrics['train_loss'], label='MLP Train')

# Add a legend and show the plot

plt.xlabel('Steps')
plt.ylabel('Loss')

plt.title("Training Curves of KAN-GPT and MLP-GPT")

# Grid

plt.grid(True)

plt.legend()
plt.show()
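The committed `media/results.png` appears to be the output of this plotting script; to regenerate the figure non-interactively, the final `plt.show()` could be replaced or preceded by `plt.savefig("media/results.png")`. The output path is inferred from the README's `<img>` tag rather than taken from the script itself.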
