Commit

test(kan): coverage improved
AdityaNG committed May 9, 2024
1 parent f9827dc commit 87425fd
Showing 7 changed files with 111 additions and 20 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -108,7 +108,7 @@ We train and compare KAN-GPT with an equivalent MLP-GPT model on the Tiny Shakes
 - [x] Mini training POC for MLP-GPT
 - [x] Train MLP-GPT on the webtext dataset as a baseline
 - [x] Train KAN-GPT on the webtext dataset as a baseline
-- [ ] Metrics comparing KAN-GPT and MLP-GPT
+- [x] Metrics comparing KAN-GPT and MLP-GPT
 - [x] Auto Save checkpoints
 - [x] Auto Save checkpoints to W&B
 - [ ] Auto Download model weights from git / huggingface
kan_gpt/VERSION (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
-1.0.0
+1.0.1
kan_gpt/kan/KAN.py (14 changes: 12 additions & 2 deletions)
@@ -357,7 +357,14 @@ def forward(self, x):
         torch.Size([100, 3])
         """
 
-        B, C, T = x.shape
+        shape_size = len(x.shape)
+
+        if shape_size == 3:
+            B, C, T = x.shape
+        elif shape_size == 2:
+            B, T = x.shape
+        else:
+            raise NotImplementedError()
 
         x = x.view(-1, T)
 
@@ -403,7 +410,10 @@ def forward(self, x):
 
         U = x.shape[1]
 
-        x = x.view(B, C, U)
+        if shape_size == 3:
+            x = x.view(B, C, U)
+        elif shape_size == 2:
+            assert x.shape == (B, U)
 
         return x
 
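
The change above lets KAN.forward accept both a 3-D (B, C, T) batch, which is flattened to (B*C, T) and reshaped back to (B, C, U) on the way out, and a plain 2-D (B, T) batch. A minimal usage sketch under that assumption; the widths and batch sizes below are illustrative, not values from the repository:

import torch

from kan_gpt.kan.KAN import KAN

model = KAN(width=[3, 5, 2])       # illustrative widths: 3 inputs, 2 outputs

x_3d = torch.randn(100, 4, 3)      # (B, C, T): flattened internally to (400, 3)
y_3d = model(x_3d)                 # expected to come back as (100, 4, 2)

x_2d = torch.randn(100, 3)         # (B, T): handled by the new 2-D branch
y_2d = model(x_2d)                 # expected shape (100, 2)
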
kan_gpt/model.py (54 changes: 38 additions & 16 deletions)
@@ -20,16 +20,6 @@
 from kan_gpt.mingpt.utils import CfgNode as CN
 from kan_gpt.settings import settings
 
-
-def get_KAN():
-    if settings.kan.KAN_IMPLEMENTATION == "EFFICIENT_KAN":
-        KAN = EFFICIENT_KAN  # type: ignore
-    elif settings.kan.KAN_IMPLEMENTATION == "ORIGINAL_KAN":
-        KAN = ORIGINAL_KAN  # type: ignore
-
-    return KAN
-
-
 # -----------------------------------------------------------------------------
 
 
@@ -65,7 +55,8 @@ class CausalSelfAttention(nn.Module):
 
     def __init__(self, config):
         super().__init__()
-        KAN = get_KAN()
+        self.kan_implementation = config.kan_implementation
+        KAN = self.get_KAN()
 
         assert config.n_embd % config.n_head == 0
         # key, query, value projections for all heads, but in a batch
@@ -86,6 +77,16 @@ def __init__(self, config):
         self.n_head = config.n_head
         self.n_embd = config.n_embd
 
+    def get_KAN(self):
+        if self.kan_implementation == "EFFICIENT_KAN":
+            KAN = EFFICIENT_KAN  # type: ignore
+        elif self.kan_implementation == "ORIGINAL_KAN":
+            KAN = ORIGINAL_KAN  # type: ignore
+        else:
+            raise NotImplementedError()
+
+        return KAN
+
     def forward(self, x):
         B, T, C = (
             x.size()
@@ -125,7 +126,8 @@ class Block(nn.Module):
 
     def __init__(self, config):
         super().__init__()
-        KAN = get_KAN()
+        self.kan_implementation = config.kan_implementation
+        KAN = self.get_KAN()
         self.ln_1 = nn.LayerNorm(config.n_embd)
         self.attn = CausalSelfAttention(config)
         self.ln_2 = nn.LayerNorm(config.n_embd)
@@ -142,6 +144,14 @@ def __init__(self, config):
             m.c_proj(m.act(m.c_fc(x)))
         )  # MLP forward
 
+    def get_KAN(self):
+        if self.kan_implementation == "EFFICIENT_KAN":
+            KAN = EFFICIENT_KAN  # type: ignore
+        elif self.kan_implementation == "ORIGINAL_KAN":
+            KAN = ORIGINAL_KAN  # type: ignore
+
+        return KAN
+
     def forward(self, x):
         x = x + self.attn(self.ln_1(x))
         x = x + self.mlpf(self.ln_2(x))
@@ -167,13 +177,17 @@ def get_default_config():
         C.embd_pdrop = 0.1
         C.resid_pdrop = 0.1
         C.attn_pdrop = 0.1
         C.attn_pdrop = 0.1
+        # KAN Implementation
+        C.kan_implementation = settings.kan.KAN_IMPLEMENTATION
         return C
 
     def __init__(self, config):
         super().__init__()
         assert config.vocab_size is not None
         assert config.block_size is not None
         self.block_size = config.block_size
+        self.kan_implementation = config.kan_implementation
 
         type_given = config.model_type is not None
         params_given = all(
@@ -228,7 +242,7 @@ def __init__(self, config):
                 ln_f=nn.LayerNorm(config.n_embd),
             )
         )
-        KAN = get_KAN()
+        KAN = self.get_KAN()
        self.lm_head = KAN(
             width=[config.n_embd, config.vocab_size], bias_trainable=False
         )
@@ -247,6 +261,14 @@ def __init__(self, config):
         n_params = sum(p.numel() for p in self.transformer.parameters())
         print("number of parameters: %.2fM" % (n_params / 1e6,))
 
+    def get_KAN(self):
+        if self.kan_implementation == "EFFICIENT_KAN":
+            KAN = EFFICIENT_KAN  # type: ignore
+        elif self.kan_implementation == "ORIGINAL_KAN":
+            KAN = ORIGINAL_KAN  # type: ignore
+
+        return KAN
+
     def kan_loss(
         self,
         x: torch.Tensor,
@@ -294,7 +316,7 @@ def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):
 
         total_reg = torch.tensor(0.0).to(device=x.device, dtype=torch.float32)
         size = 0
-        KAN = get_KAN()
+        KAN = self.get_KAN()
         for mod in self.modules():
             if isinstance(mod, KAN):
                 total_reg += reg(mod)
@@ -304,7 +326,7 @@ def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):
         return mean_reg
 
     def _init_weights(self, module):
-        KAN = get_KAN()
+        KAN = self.get_KAN()
         if isinstance(module, KAN):
             # torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
             # if module.bias is not None:
@@ -380,7 +402,7 @@ def configure_optimizers(self, train_config):
         # regularizing weight decay
         decay = set()
         no_decay = set()
-        KAN = get_KAN()
+        KAN = self.get_KAN()
         whitelist_weight_modules = (KAN,)
         blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
         for mn, m in self.named_modules():
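
With get_KAN now a method that reads config.kan_implementation, the backend is chosen per model instance instead of from the global settings module. A minimal construction sketch, assuming the class is exposed as GPT in kan_gpt.model (as the tests' KAN_GPT alias suggests); the model_type, vocab_size, and block_size values are illustrative:

from kan_gpt.model import GPT as KAN_GPT

config = KAN_GPT.get_default_config()
config.model_type = "gpt-nano"               # illustrative minGPT preset
config.vocab_size = 8
config.block_size = 16
config.kan_implementation = "EFFICIENT_KAN"  # or "ORIGINAL_KAN"

model = KAN_GPT(config)                      # lm_head and blocks use the selected KAN
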
tests/test_gpt_kan.py (2 changes: 2 additions & 0 deletions)
@@ -18,6 +18,7 @@ def get_gpt_model_efficient():
     model_config.model_type = MODEL_TYPE
     model_config.vocab_size = VOCAB_SIZE
     model_config.block_size = BLOCK_SIZE
+    model_config.kan_implementation = os.getenv("KAN_IMPLEMENTATION")
     model = KAN_GPT(model_config)
 
     del KAN_GPT
@@ -35,6 +36,7 @@ def get_gpt_model_original():
     model_config.model_type = MODEL_TYPE
     model_config.vocab_size = VOCAB_SIZE
     model_config.block_size = BLOCK_SIZE
+    model_config.kan_implementation = os.getenv("KAN_IMPLEMENTATION")
     model = KAN_GPT(model_config)
 
     del KAN_GPT
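
Both helpers now read the backend from the KAN_IMPLEMENTATION environment variable. One assumption worth noting (not stated in the repository): os.getenv returns None when the variable is unset, and that None would flow straight into model_config.kan_implementation, so a test run is expected to export the variable first. An illustrative guard a caller might add:

import os

# Illustrative only; the repository's helpers pass the raw os.getenv value through.
kan_implementation = os.getenv("KAN_IMPLEMENTATION", "EFFICIENT_KAN")
assert kan_implementation in ("EFFICIENT_KAN", "ORIGINAL_KAN")
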
tests/test_gpt_mlp.py (33 changes: 33 additions & 0 deletions)
@@ -1,5 +1,9 @@
+import random
+from tempfile import TemporaryDirectory
+
 import torch
 from kan_gpt.mingpt.model import GPT as MLP_GPT
+from kan_gpt.mingpt.utils import set_seed, setup_logging, CfgNode as CN
 
 VOCAB_SIZE = 8
 BLOCK_SIZE = 16
@@ -93,3 +97,32 @@ def test_backward_batched():
         if isinstance(param.grad, torch.Tensor):
             grad_set.add(param)
     assert len(grad_set) > 0, f"Tensor.grad missing"
+
+
+def test_CN():
+    C = CN()
+    C.device = "auto"
+    assert C.device == "auto", "Unable to set param"
+
+
+def test_seed_set():
+    seed = 0
+    set_seed(seed)
+
+    rr1 = random.random()
+    rr2 = random.random()
+
+    set_seed(seed)
+
+    rr3 = random.random()
+
+    assert rr1 == rr3, "seed not set correctly"
+    assert rr1 != rr2, "seed not set correctly"
+
+
+def test_setup_logging():
+    C = CN()
+    with TemporaryDirectory() as folder:
+        C.system = CN()
+        C.system.work_dir = folder
+        setup_logging(C)
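
test_seed_set only checks that Python's random module is reseeded; for reference, a minGPT-style set_seed typically seeds the other generators as well. The body below is a sketch of that convention, not the verbatim contents of kan_gpt.mingpt.utils:

import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    # Seed every generator the training code may touch so runs are repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
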
tests/test_kan.py (24 changes: 24 additions & 0 deletions)
@@ -1,5 +1,7 @@
+from tempfile import TemporaryDirectory
 import torch
 from kan_gpt.kan.KAN import KAN
+from kan_gpt.kan.utils import create_dataset
 
 
 def test_forward():
@@ -72,3 +74,25 @@ def test_backward_batched():
         if isinstance(param.grad, torch.Tensor):
             grad_set.add(param)
     assert len(grad_set) > 0, f"Tensor.grad missing"
+
+
+def test_plot():
+    model = KAN(width=[2, 3, 2, 1])
+    x = torch.normal(0, 1, size=(100, 1, 2))
+    model(x)
+    beta = 100
+    with TemporaryDirectory() as folder:
+        model.plot(beta=beta, folder=folder)
+
+
+def test_train():
+    f = lambda x: torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)
+    dataset = create_dataset(f, n_var=2)
+    dataset["train_input"].shape, dataset["train_label"].shape
+
+    model = KAN(width=[2, 1], grid=5, k=3, seed=0)
+    model.train_kan(dataset, opt="LBFGS", steps=1, lamb=0.1)
+    model.plot()
+    model.prune()
+    with TemporaryDirectory() as folder:
+        model.plot(mask=True, folder=folder)