Skip to content

Commit

Permalink
Bugfix; added test run file
Browse files Browse the repository at this point in the history
  • Loading branch information
RileyLazarou committed Oct 30, 2021
1 parent 65be12d commit 1c4bfc1
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 3 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# PokeGAN
GAN for generating pokemon sprites

## Quickstart:

In a virtual environment with Python 3.7+, install everything in `requirements.txt`. Then, run `test.py`

Samples:

![Alt text](samples/1.png) ![Alt text](samples/2.png) ![Alt text](samples/3.png) ![Alt text](samples/4.png) ![Alt text](samples/5.png) ![Alt text](samples/6.png) ![Alt text](samples/7.png) ![Alt text](samples/8.png) ![Alt text](samples/9.png) ![Alt text](samples/a.png) ![Alt text](samples/b.png) ![Alt text](samples/c.png) ![Alt text](samples/d.png) ![Alt text](samples/e.png) ![Alt text](samples/f.png) ![Alt text](samples/g.png) ![Alt text](samples/h.png) ![Alt text](samples/i.png) ![Alt text](samples/j.png)
2 changes: 1 addition & 1 deletion aegan.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Generator(nn.Module):
Output shape: (?, 3, 96, 96)
"""

def __init__(self, latent_dim: int = 8):
def __init__(self, latent_dim: int = 16):
"""Initialize generator.
Args:
Expand Down
6 changes: 4 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@ Pillow==8.3.2
pyparsing==2.4.7
python-dateutil==2.8.1
six==1.15.0
torch==1.6.0
torchvision==0.7.0
torch==1.10.0+cu113
torchaudio==0.10.0+cu113
torchvision==0.11.1+cu113
typing-extensions==3.10.0.2
13 changes: 13 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import torch
from aegan import Generator as G
import torchvision.utils as vutils

device = torch.device('cpu')
netG = G()
netG.load_state_dict(torch.load('trained_generator_weights.pt', map_location=device))
vec = torch.randn((32, 16))
with torch.no_grad():
fake = netG(vec)

for i in range(32):
vutils.save_image(fake.data[i], f'testfake.{i:02d}.png', normalize=True)

0 comments on commit 1c4bfc1

Please sign in to comment.