engine.py
import torch
import torch.nn as nn
from tqdm import tqdm


# Binary cross-entropy loss on raw logits: BCEWithLogitsLoss applies the
# sigmoid internally, which is more numerically stable than sigmoid + BCELoss.
def loss_fn(outputs, targets):
    # view(-1, 1) reshapes the 1-D target vector to match the [batch, 1]
    # shape of the model's logits.
    return nn.BCEWithLogitsLoss()(outputs, targets.view(-1, 1))
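
# A quick sanity check of loss_fn (hypothetical shapes: a batch of 4 logits):
# >>> logits = torch.randn(4, 1)
# >>> targets = torch.tensor([0.0, 1.0, 1.0, 0.0])
# >>> loss_fn(logits, targets)  # scalar tensor: mean BCE over the batch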

# define the training function
def train_fn(data_loader, model, optimizer, device, scheduler):
    # Set the model to training mode
    model.train()
    # Iterate over batches, wrapping the loader in a tqdm progress bar
    for bi, d in tqdm(enumerate(data_loader), total=len(data_loader)):
        # Unpack a training batch from our dataloader
        ids = d["ids"]
        mask = d["mask"]
        targets = d["targets"]
        # Copy each tensor to the GPU
        ids = ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)
        targets = targets.to(device, dtype=torch.float)
        # Clear any previously calculated gradients
        optimizer.zero_grad()
        # Forward pass; outputs are logits (prior to activation)
        outputs = model(ids=ids, mask=mask)
        loss = loss_fn(outputs, targets)  # Compute the loss
        loss.backward()  # Backward pass to calculate the gradients
        optimizer.step()  # Update parameters
        scheduler.step()  # Update the learning rate
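
# A minimal sketch of how the optimizer and scheduler passed to train_fn
# might be constructed. AdamW with a linear warmup schedule (from the
# `transformers` library) is an assumption, typical for BERT-style models
# that consume `ids`/`mask` batches; the actual project may differ.
# `num_epochs` and the learning rate below are hypothetical.
#
# from torch.optim import AdamW
# from transformers import get_linear_schedule_with_warmup
#
# optimizer = AdamW(model.parameters(), lr=3e-5)
# num_train_steps = len(data_loader) * num_epochs  # one scheduler step per batch
# scheduler = get_linear_schedule_with_warmup(
#     optimizer, num_warmup_steps=0, num_training_steps=num_train_steps
# )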

# define the validation function
def eval_fn(data_loader, model, device):
    model.eval()  # Set the model to evaluation mode
    fin_targets = []  # accumulated ground-truth labels
    fin_outputs = []  # accumulated model predictions
    # Tell PyTorch not to bother constructing the compute graph during
    # the forward pass, since it is only needed for backprop (training).
    with torch.no_grad():
        # Iterate over batches, wrapping the loader in a tqdm progress bar
        for bi, d in tqdm(enumerate(data_loader), total=len(data_loader)):
            # Unpack a validation batch from our dataloader
            ids = d["ids"]
            mask = d["mask"]
            targets = d["targets"]
            # Copy each tensor to the GPU
            ids = ids.to(device, dtype=torch.long)
            mask = mask.to(device, dtype=torch.long)
            targets = targets.to(device, dtype=torch.float)
            # Forward pass; outputs are logits (prior to activation)
            outputs = model(ids=ids, mask=mask)
            # Apply sigmoid to the logits, then move targets and outputs to
            # the CPU and collect them as plain Python lists
            fin_targets.extend(targets.cpu().detach().numpy().tolist())
            fin_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())
    return fin_outputs, fin_targets
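
# A minimal usage sketch of the full train/eval loop. `model`,
# `train_data_loader`, `valid_data_loader`, and `num_epochs` are hypothetical
# names not defined in this file, and thresholding the sigmoid outputs at 0.5
# is an assumption for binary classification.
#
# import numpy as np
#
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)
# for epoch in range(num_epochs):
#     train_fn(train_data_loader, model, optimizer, device, scheduler)
#     outputs, targets = eval_fn(valid_data_loader, model, device)
#     preds = (np.array(outputs).reshape(-1) >= 0.5).astype(int)
#     accuracy = (preds == np.array(targets)).mean()
#     print(f"epoch {epoch}: accuracy = {accuracy:.4f}")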