main.py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Define the encoder part of the VAE: it maps a flattened image to the mean
# and log-variance of the approximate posterior q(z|x).
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, z_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2_mu = nn.Linear(hidden_dim, z_dim)
        self.fc2_logvar = nn.Linear(hidden_dim, z_dim)

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        mu = self.fc2_mu(h)
        logvar = self.fc2_logvar(h)
        return mu, logvar

# Define the decoder part of the VAE: it maps a latent code z back to pixel
# space. The sigmoid keeps outputs in [0, 1] to match the BCE loss below.
class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h = torch.relu(self.fc1(z))
        x_recon = torch.sigmoid(self.fc2(h))
        return x_recon

# Define the VAE model: encoder + reparameterized sampling + decoder.
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, z_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, z_dim)
        self.decoder = Decoder(z_dim, hidden_dim, input_dim)

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: sample z = mu + sigma * eps with
        # eps ~ N(0, I), so gradients can flow through mu and logvar.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(z)
        return x_recon, mu, logvar
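
# A quick shape check of the model above (an illustrative sketch, not part of
# the original training script; the underscore-prefixed names are hypothetical).
_vae = VAE(784, 400, 20)
_recon, _mu, _logvar = _vae(torch.rand(8, 784))
assert _recon.shape == (8, 784) and _mu.shape == _logvar.shape == (8, 20)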

# Loss function: reconstruction term (BCE) plus the KL divergence between
# q(z|x) = N(mu, sigma^2) and the prior p(z) = N(0, I), which has the
# closed form -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
def loss_function(x_recon, x, mu, logvar):
    BCE = nn.functional.binary_cross_entropy(x_recon, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
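
# A sanity check of the loss on dummy tensors (an illustrative sketch; the
# values are hypothetical). With mu = logvar = 0 the KL term is exactly 0,
# so the loss reduces to the BCE of the reconstruction: 8 * ln(2) here.
_x = torch.full((2, 4), 0.5)
_zero = torch.zeros(2, 3)
assert torch.isclose(loss_function(_x, _x, _zero, _zero),
                     torch.tensor(8 * 0.6931), atol=1e-3)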

# Hyperparameters
input_dim = 28 * 28   # flattened 28x28 MNIST images
hidden_dim = 400      # hidden layer size in encoder and decoder
z_dim = 20            # dimensionality of the latent space
batch_size = 128
learning_rate = 1e-3
num_epochs = 10

# Data loading
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
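
# Optional peek at one batch (an illustrative sketch, assuming the loader
# above): MNIST batches arrive as (images, labels) with images of shape
# [batch, 1, 28, 28], which is why training flattens them to 784 below.
_images, _labels = next(iter(train_loader))
assert _images.shape[1:] == (1, 28, 28)
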
# Model and optimizer
model = VAE(input_dim, hidden_dim, z_dim)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
model.train()
for epoch in range(num_epochs):
    train_loss = 0
    for data, _ in train_loader:
        data = data.view(-1, 28 * 28)  # Flatten each image to a 784-dim vector
        optimizer.zero_grad()
        x_recon, mu, logvar = model(data)
        loss = loss_function(x_recon, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    # Report the average loss per training example for this epoch
    print(f'Epoch {epoch + 1}, Loss: {train_loss / len(train_loader.dataset):.4f}')

# Save the trained model
torch.save(model.state_dict(), 'vae.pth')
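
# Once trained, new digits can be generated by decoding latent samples drawn
# from the prior (a minimal sketch, assuming the model trained above; the
# variable names here are illustrative, not part of the original script).
model.eval()
with torch.no_grad():
    z = torch.randn(16, z_dim)             # 16 samples from the prior N(0, I)
    samples = model.decoder(z)             # decode to pixel space, shape [16, 784]
    samples = samples.view(-1, 1, 28, 28)  # reshape back to image format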