-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainTextExtractorNet.py
90 lines (72 loc) · 2.62 KB
/
trainTextExtractorNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Trains the text extraction network
import sys
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.optim as optim
# fancy output
# from tqdm import tqdm
# ssim
from pytorch_ssim import SSIM
# local
from model import TextExtractor
from data_gen import PatchifyDB
if __name__ == '__main__':
    # Trains TextExtractor to map input patches to target (text-mask) patches,
    # minimizing an SSIM-based loss, with periodic checkpointing.
    in_path = '../data/ourdata/X/s'
    target_path = '../data/ourdata/Y/s'
    model_path = './model/m1-{}-{}.pt'          # formatted with (epoch, loss value)
    optimizer_state_path = './model/m1-opt-{}.pt'  # formatted with (epoch,)
    patch_size = 256
    patch_per_image = 1
    model_save_freq = 5      # save a checkpoint every N epochs
    # checkpointing
    checkpoint_load = True
    model_state = './model/m1-4-0.17196649312973022.pt'
    optimizer_state = './model/m1-opt-4.pt'
    device = torch.device("cuda:1" if torch.cuda.is_available()
                          else "cpu")
    # param
    batch_size = 8
    num_epoch = 100
    ssim_window_size = 23
    num_workers = 4
    # get data
    transform = transforms.Compose([transforms.Grayscale(),
                                    transforms.ToTensor()])
    db = PatchifyDB(in_path, target_path, patch_size,
                    patch_per_image, transform=transform)
    data_loader = DataLoader(db, batch_size=batch_size, shuffle=True,
                             num_workers=num_workers)
    # get the model
    model = TextExtractor().to(device)
    ssim = SSIM(window_size=ssim_window_size, size_average=True)
    optimizer = optim.Adam(model.parameters(), lr=0.001, eps=1e-07)
    if checkpoint_load:
        # map_location lets a checkpoint saved on cuda:1 restore on any device
        model.load_state_dict(torch.load(model_state, map_location=device))
        optimizer.load_state_dict(torch.load(optimizer_state,
                                             map_location=device))
        print('Restoring checkpoint: {} - {}'.format(
            model_state, optimizer_state))
    # ensure training mode regardless of whether a checkpoint was restored
    model.train()
    # Train
    # ADD: make a tqdm based progress bar
    for epoch in range(num_epoch):
        for in_batch, target_batch in data_loader:
            # to GPU
            in_batch = in_batch.to(device)
            target_batch = target_batch.to(device)
            optimizer.zero_grad()
            outputs = model(in_batch)
            # SSIM is a similarity in [-1, 1]; map to a loss in [0, 1]
            loss = 0.5*(1 - ssim(outputs, target_batch))
            loss.backward()
            optimizer.step()
            sys.stdout.write('[%d/%d] - loss: %.5f\r' %
                             (epoch + 1, num_epoch,
                              loss.item()))
            sys.stdout.flush()
        sys.stdout.write('\n')
        # save model
        if epoch % model_save_freq == 0:
            # use loss.item(): formatting the raw tensor would embed
            # "tensor(..., device=..., grad_fn=...)" in the filename
            torch.save(model.state_dict(),
                       model_path.format(epoch, loss.item()))
            torch.save(optimizer.state_dict(),
                       optimizer_state_path.format(epoch))