-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
42 lines (34 loc) · 1.32 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import my_datasets
from model import mymodel
# Number of passes over the training set.
max_epoch = 30

if __name__ == '__main__':
    # Select GPU when available so the script also runs on CPU-only machines
    # (original hard-coded .cuda() and crashed without a GPU).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Datasets and loaders. The test loader is built but not consumed in this
    # loop — presumably reserved for a separate evaluation step; confirm.
    train_datas = my_datasets.mydatasets("./dataset/train")
    test_data = my_datasets.mydatasets("./dataset/test")
    train_dataloader = DataLoader(train_datas, batch_size=64, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

    m = mymodel().to(device)
    # Multi-label loss: targets are treated as independent binary labels.
    loss_fn = nn.MultiLabelSoftMarginLoss().to(device)
    optimizer = torch.optim.Adam(m.parameters(), lr=0.001)

    w = SummaryWriter("logs")  # view with: tensorboard --logdir=logs
    total_step = 0  # counts logging events (one per 100 batches)
    for epoch in range(max_epoch):
        print("外层训练次数{}".format(epoch))
        m.train()  # ensure training mode (dropout/batchnorm) each epoch
        for i, (imgs, targets) in enumerate(train_dataloader):
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = m(imgs)
            loss = loss_fn(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                total_step += 1
                # Fix: logging happens every 100 batches, so the batch count
                # is total_step*100 (original printed total_step*10).
                print("训练{}次,loss:{}".format(total_step * 100, loss.item()))
                # Log a plain float, not the live tensor attached to the graph.
                w.add_scalar("loss", loss.item(), total_step)
                # writer.add_images("imgs", imgs, i)
        # Checkpoint after every epoch; torch.save does not create parent
        # directories, so make sure the target directory exists first.
        os.makedirs("./checkpoints", exist_ok=True)
        torch.save(m, "./checkpoints/model.pth")
    w.close()  # flush pending TensorBoard events