Skip to content

Commit

Permalink
[#27] Support multi-GPU training
Browse files Browse the repository at this point in the history
  • Loading branch information
linchuming committed Oct 26, 2021
1 parent 657411e commit ed86a0d
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 31 deletions.
2 changes: 1 addition & 1 deletion AFSD/anet/BDNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def forward(self, feat_dict, ssl=False):
prop_loc = torch.cat([o.view(batch_num, -1, 2) for o in prop_locs], 1)
prop_conf = torch.cat([o.view(batch_num, -1, num_classes) for o in prop_confs], 1)
center = torch.cat([o.view(batch_num, -1, 1) for o in centers], 1)
priors = torch.cat(self.priors, 0).to(loc.device)
priors = torch.cat(self.priors, 0).to(loc.device).unsqueeze(0)
return loc, conf, prop_loc, prop_conf, center, priors, start, end, \
start_loc_prop, end_loc_prop, start_conf_prop, end_conf_prop

Expand Down
2 changes: 1 addition & 1 deletion AFSD/anet/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def sub_processor(lock, pid, video_list):
with torch.no_grad():
output_dict = net(clip)

loc, conf, priors = output_dict['loc'], output_dict['conf'], output_dict['priors']
loc, conf, priors = output_dict['loc'], output_dict['conf'], output_dict['priors'][0]
prop_loc, prop_conf = output_dict['prop_loc'], output_dict['prop_conf']
center = output_dict['center']
loc = loc[0]
Expand Down
4 changes: 2 additions & 2 deletions AFSD/anet/test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def sub_processor(lock, pid, video_list):
flow_output_dict = flow_net(flow_clip)

loc, conf, priors = rgb_output_dict['loc'], rgb_output_dict['conf'], \
rgb_output_dict['priors']
rgb_output_dict['priors'][0]
prop_loc, prop_conf = rgb_output_dict['prop_loc'], rgb_output_dict['prop_conf']
center = rgb_output_dict['center']

Expand All @@ -154,7 +154,7 @@ def sub_processor(lock, pid, video_list):
rgb_center = center

loc, conf, priors = flow_output_dict['loc'], flow_output_dict['conf'], \
flow_output_dict['priors']
flow_output_dict['priors'][0]
prop_loc, prop_conf = flow_output_dict['prop_loc'], flow_output_dict['prop_conf']
center = flow_output_dict['center']

Expand Down
27 changes: 17 additions & 10 deletions AFSD/anet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
os.makedirs(train_state_path)

resume = config['training']['resume']
config['training']['ssl'] = 0.1


def print_training_info():
Expand All @@ -37,8 +36,10 @@ def print_training_info():
print('checkpoint path: ', checkpoint_path)
print('loc weight: ', config['training']['lw'])
print('cls weight: ', config['training']['cw'])
print('piou: ', config['training']['piou'])
print('ssl weight: ', config['training']['ssl'])
print('piou:', config['training']['piou'])
print('resume: ', resume)
print('gpu num: ', ngpu)


def set_seed(seed):
Expand Down Expand Up @@ -115,7 +116,7 @@ def forward_one_epoch(net, clips, targets, scores=None, training=True, ssl=True)

if training:
if ssl:
output_dict = net(clips, proposals=targets, ssl=ssl)
output_dict = net.module(clips, proposals=targets, ssl=ssl)
else:
output_dict = net(clips, ssl=False)
else:
Expand All @@ -134,7 +135,7 @@ def forward_one_epoch(net, clips, targets, scores=None, training=True, ssl=True)
loss_l, loss_c, loss_prop_l, loss_prop_c, loss_ct = CPD_Loss(
[output_dict['loc'], output_dict['conf'],
output_dict['prop_loc'], output_dict['prop_conf'],
output_dict['center'], output_dict['priors']],
output_dict['center'], output_dict['priors'][0]],
targets)
loss_start, loss_end = calc_bce_loss(output_dict['start'], output_dict['end'], scores)
scores_ = F.interpolate(scores, scale_factor=1.0 / 8)
Expand Down Expand Up @@ -177,12 +178,18 @@ def run_one_epoch(epoch, net, optimizer, data_loader, epoch_step_num, training=T
loss_ct = loss_ct * config['training']['cw']
cost = loss_l + loss_c + loss_prop_l + loss_prop_c + loss_ct + loss_start + loss_end

if flags[0]:
loss_trip = forward_one_epoch(net, ssl_clips, ssl_targets, training=training,
ssl=True)
loss_trip *= config['training']['ssl']
cost = cost + loss_trip
loss_trip_val += loss_trip.cpu().detach().numpy()
ssl_count = 0
loss_trip = 0
for i in range(len(flags)):
if flags[i] and config['training']['ssl'] > 0:
loss_trip += forward_one_epoch(net, ssl_clips[i].unsqueeze(0), [ssl_targets[i]],
training=training, ssl=True) * config['training']['ssl']
loss_trip_val += loss_trip.cpu().detach().numpy()
ssl_count += 1
if ssl_count:
loss_trip_val /= ssl_count
loss_trip /= ssl_count
cost = cost + loss_trip

if training:
optimizer.zero_grad()
Expand Down
2 changes: 2 additions & 0 deletions AFSD/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def get_config():

parser.add_argument('--lw', type=float, default=10.0)
parser.add_argument('--cw', type=float, default=1)
parser.add_argument('--ssl', type=float, default=0.1)
parser.add_argument('--piou', type=float, default=0)
parser.add_argument('--resume', type=int, default=0)
parser.add_argument('--ngpu', type=int, default=1)
Expand Down Expand Up @@ -55,6 +56,7 @@ def get_config():
data['training']['lw'] = args.lw
data['training']['cw'] = args.cw
data['training']['piou'] = args.piou
data['training']['ssl'] = args.ssl
data['training']['resume'] = args.resume
data['ngpu'] = args.ngpu
data['testing']['fusion'] = args.fusion
Expand Down
2 changes: 1 addition & 1 deletion AFSD/thumos14/BDNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def forward(self, feat_dict, ssl=False):
prop_loc = torch.cat([o.view(batch_num, -1, 2) for o in prop_locs], 1)
prop_conf = torch.cat([o.view(batch_num, -1, num_classes) for o in prop_confs], 1)
center = torch.cat([o.view(batch_num, -1, 1) for o in centers], 1)
priors = torch.cat(self.priors, 0).to(loc.device)
priors = torch.cat(self.priors, 0).to(loc.device).unsqueeze(0)
return loc, conf, prop_loc, prop_conf, center, priors, start, end, \
start_loc_prop, end_loc_prop, start_conf_prop, end_conf_prop

Expand Down
4 changes: 2 additions & 2 deletions AFSD/thumos14/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
if fusion:
flow_output_dict = flow_net(flow_clip)

loc, conf, priors = output_dict['loc'], output_dict['conf'], output_dict['priors']
loc, conf, priors = output_dict['loc'], output_dict['conf'], output_dict['priors'][0]
prop_loc, prop_conf = output_dict['prop_loc'], output_dict['prop_conf']
center = output_dict['center']
if fusion:
Expand All @@ -128,7 +128,7 @@
rgb_center = center[0]

loc, conf, priors = flow_output_dict['loc'], flow_output_dict['conf'], \
flow_output_dict['priors']
flow_output_dict['priors'][0]
prop_loc, prop_conf = flow_output_dict['prop_loc'], flow_output_dict['prop_conf']
center = flow_output_dict['center']

Expand Down
32 changes: 18 additions & 14 deletions AFSD/thumos14/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@
checkpoint_path = config['training']['checkpoint_path']
focal_loss = config['training']['focal_loss']
random_seed = config['training']['random_seed']
ngpu = config['ngpu']

train_state_path = os.path.join(checkpoint_path, 'training')
if not os.path.exists(train_state_path):
os.makedirs(train_state_path)

resume = config['training']['resume']
config['training']['ssl'] = 0.1


def print_training_info():
print('batch size: ', batch_size)
Expand All @@ -37,11 +36,11 @@ def print_training_info():
print('checkpoint path: ', checkpoint_path)
print('loc weight: ', config['training']['lw'])
print('cls weight: ', config['training']['cw'])
print('iou weight: ', config['training']['piou'])
print('ssl weight: ', config['training']['ssl'])
print('piou:', config['training']['piou'])
print('resume: ', resume)

print('gpu num: ', ngpu)


def set_seed(seed):
torch.manual_seed(seed)
Expand Down Expand Up @@ -117,7 +116,7 @@ def forward_one_epoch(net, clips, targets, scores=None, training=True, ssl=True)

if training:
if ssl:
output_dict = net(clips, proposals=targets, ssl=ssl)
output_dict = net.module(clips, proposals=targets, ssl=ssl)
else:
output_dict = net(clips, ssl=False)
else:
Expand All @@ -136,7 +135,7 @@ def forward_one_epoch(net, clips, targets, scores=None, training=True, ssl=True)
loss_l, loss_c, loss_prop_l, loss_prop_c, loss_ct = CPD_Loss(
[output_dict['loc'], output_dict['conf'],
output_dict['prop_loc'], output_dict['prop_conf'],
output_dict['center'], output_dict['priors']],
output_dict['center'], output_dict['priors'][0]],
targets)
loss_start, loss_end = calc_bce_loss(output_dict['start'], output_dict['end'], scores)
scores_ = F.interpolate(scores, scale_factor=1.0 / 4)
Expand Down Expand Up @@ -180,13 +179,18 @@ def run_one_epoch(epoch, net, optimizer, data_loader, epoch_step_num, training=T
loss_ct = loss_ct * config['training']['cw']
cost = loss_l + loss_c + loss_prop_l + loss_prop_c + loss_ct + loss_start + loss_end

if flags[0]:
loss_trip = forward_one_epoch(net, ssl_clips, ssl_targets, training=training,
ssl=True)
loss_trip *= config['training']['ssl']
cost = cost + loss_trip
loss_trip_val += loss_trip.cpu().detach().numpy()

ssl_count = 0
loss_trip = 0
for i in range(len(flags)):
if flags[i] and config['training']['ssl'] > 0:
loss_trip += forward_one_epoch(net, ssl_clips[i].unsqueeze(0), [ssl_targets[i]],
training=training, ssl=True) * config['training']['ssl']
loss_trip_val += loss_trip.cpu().detach().numpy()
ssl_count += 1
if ssl_count:
loss_trip_val /= ssl_count
loss_trip /= ssl_count
cost = cost + loss_trip
if training:
optimizer.zero_grad()
cost.backward()
Expand Down Expand Up @@ -236,7 +240,7 @@ def run_one_epoch(epoch, net, optimizer, data_loader, epoch_step_num, training=T
"""
net = BDNet(in_channels=config['model']['in_channels'],
backbone_model=config['model']['backbone_model'])
net = nn.DataParallel(net, device_ids=[0]).cuda()
net = nn.DataParallel(net, device_ids=list(range(ngpu))).cuda()

"""
Setup optimizer
Expand Down

0 comments on commit ed86a0d

Please sign in to comment.