Bug fix for lsq with QATv2 mode
oguzhanbsolak committed Jan 23, 2025
1 parent df0875b commit a36e8b9
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions ai8x.py
@@ -804,6 +804,15 @@ def backward(ctx, grad_output):
        return torch.ones(grad_output.shape).type_as(grad_output)


class ConvertToTensorwithSpecialway(Function):
    @staticmethod
    def forward(ctx, x):
        return torch.mul(x, 1)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

class qSigned(Function):
    @staticmethod
    def forward(ctx, x, log2_t, bit_width, mode=False):
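For context: the ConvertToTensorwithSpecialway Function added above acts as an identity with a straight-through gradient, and torch.mul(x, 1) returns a tensor even when x is a plain Python number, which makes the clamp bounds n and p tensor arguments for ctx.save_for_backward() in the hunk below. A minimal standalone sketch of the same behaviour (the stand-in name is hypothetical; this is not code from the commit):

import torch
from torch.autograd import Function

class IdentityToTensor(Function):
    """Illustrative stand-in for ConvertToTensorwithSpecialway."""

    @staticmethod
    def forward(ctx, x):
        # torch.mul wraps a Python scalar into a tensor and leaves values unchanged
        return torch.mul(x, 1)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pass the incoming gradient along unchanged
        return grad_output

bound = IdentityToTensor.apply(-128)            # tensor(-128): scalar promoted to a tensor
w = torch.randn(4, requires_grad=True)
IdentityToTensor.apply(w).sum().backward()
print(bound, w.grad)                            # w.grad is all ones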
@@ -814,15 +823,21 @@ def forward(ctx, x, log2_t, bit_width, mode=False):
        n = -bit_max
        p = bit_max - 1
        s = 2.**Ceil.apply(log2_t) / bit_max

        if not mode:
            q = torch.clamp(RoundToEven.apply(x / s), n, p) * s
            ctx.save_for_backward(x / s, s, number_to_tensor(n, x),
                                  number_to_tensor(p, x))
        else:
            s2 = 1 / bit_max

            #print("S2", s2)
            q = torch.clamp(RoundToEven.apply(x / s), n, p) * s2
            #TODO: Check if this is correct = Seems Correct
            ctx.save_for_backward(x / s, s, number_to_tensor(n, x),
                                  number_to_tensor(p, x))
            #print(n, p)
            #print("Q", q.min(), q.max())
            n_save = ConvertToTensorwithSpecialway.apply(n)
            p_save = ConvertToTensorwithSpecialway.apply(p)
            ctx.save_for_backward(x / s, s, n_save, p_save)
        return q

    @staticmethod
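To make the arithmetic in the hunk above concrete: the step size is the power-of-two value s = 2.**ceil(log2_t) / bit_max, the rounded code is clamped to [n, p], and in the QATv2 branch (mode=True) the clamped code is rescaled by s2 = 1 / bit_max rather than by s. A worked numeric sketch using plain tensor ops in place of the custom Ceil/RoundToEven Functions (the bit_max definition and the sample numbers are assumptions for illustration):

import torch

bit_width = 8
bit_max = 2. ** (bit_width - 1)            # 128 for 8-bit signed (assumed definition)
n, p = -bit_max, bit_max - 1               # clamp bounds: -128 .. 127
log2_t = torch.tensor(0.7)                 # example learned log2 threshold

s = 2. ** torch.ceil(log2_t) / bit_max     # 2**1 / 128 = 0.015625
x = torch.tensor([0.5, -1.2, 0.009])
code = torch.clamp(torch.round(x / s), n, p)   # integer codes: [32., -77., 1.]

q_normal = code * s                        # mode=False: back on the input scale
s2 = 1 / bit_max
q_qatv2 = code * s2                        # mode=True: fixed 1/bit_max scale

print(q_normal)   # tensor([ 0.5000, -1.2031,  0.0156])
print(q_qatv2)    # tensor([ 0.2500, -0.6016,  0.0078])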
@@ -1374,7 +1389,7 @@ def forward(self, x): # pylint: disable=arguments-differ
                params_r = torch.flatten(self.op.weight.detach())
            else:
                params_r = self.op.weight.detach()
            if dev.lsq_weight_scale and self.adjust_output_shift.detach():
            if dev.lsq_weight_scale:
                pass
                #alpha = grad_scale(self.alpha, self.g)
                #Prevent alpha from going to 0
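The commented-out alpha = grad_scale(self.alpha, self.g) above refers to the gradient-scaling trick from the LSQ paper (Esser et al., "Learned Step Size Quantization"), which keeps the learned step size alpha at its value in the forward pass while damping its gradient by a factor g. A minimal sketch of that standard formulation, assuming the usual g = 1/sqrt(num_weights * q_p); this is the generic technique, not this repository's implementation:

import math
import torch

def grad_scale(x, scale):
    # Same value as x in the forward pass, gradient scaled by `scale` in the backward pass
    y_grad = x * scale
    return (x - y_grad).detach() + y_grad

alpha = torch.tensor(0.1, requires_grad=True)   # learned step size
num_weights, q_p = 1024, 127                    # weight count and positive clip level (example values)
g = 1.0 / math.sqrt(num_weights * q_p)

scaled_alpha = grad_scale(alpha, g)
scaled_alpha.backward()
print(alpha.grad)    # equals g: the step size learns with a damped gradient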
@@ -1402,6 +1417,7 @@ def forward(self, x): # pylint: disable=arguments-differ
                #self.op.bias,
            )
        elif not dev.fakeactquant and dev.lsq_weight_scale:
            #print("BURDAYIM")
            x = self._conv_forward( # pylint: disable=protected-access
                x,
                self.FakeQuantizeWeight.apply(self.op.weight, self.alpha, self.weight_bits.detach().item(), not dev.fakeactquant),
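In the LSQ branch above, the weights are passed through FakeQuantizeWeight with the learned step alpha before the convolution runs on them. A simplified sketch of an LSQ-style fake weight quantizer using a plain straight-through estimator (the full LSQ backward also gates gradients in the clamped region); the function below is a generic illustration, not the repository's FakeQuantizeWeight:

import torch

def fake_quantize_weight(w, alpha, bit_width):
    # Quantize to signed integer codes with step alpha, then dequantize.
    # The detach trick makes round/clamp act as identity for gradients.
    q_n, q_p = -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
    w_int = torch.clamp(torch.round(w / alpha), q_n, q_p)
    w_int = (w_int - w / alpha).detach() + w / alpha      # straight-through estimator
    return w_int * alpha

w = torch.randn(8, 4, 3, 3, requires_grad=True)
alpha = torch.tensor(0.05, requires_grad=True)
w_q = fake_quantize_weight(w, alpha, bit_width=8)
w_q.sum().backward()      # gradients reach both w and alpha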
@@ -2969,7 +2985,7 @@ def pre_qat(model, train_loader, args, qat_policy):
    if args.fake_act_quant and args.lsq_weight_scale:
        set_device(dev.device, dev.simulate, dev.round_avg, args.fake_act_quant, args.per_channel, args.lsq_weight_scale)
        #print("Device configuration: ", dev)
    elif args.fake_act_quant and not args.lsq_weight_scale:
    elif not args.fake_act_quant and args.lsq_weight_scale:
        set_device(dev.device, dev.simulate, dev.round_avg, args.fake_act_quant, args.per_channel, args.lsq_weight_scale)
        apply_scales(model)
    else:
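The one-line condition change above is the fix named in the commit title: apply_scales(model) now runs when LSQ weight scaling is enabled without fake activation quantization, instead of the reverse combination. A condensed sketch of the corrected dispatch (branch bodies reduced to labels; the remaining else path is not shown in the hunk and stands in as "default" here):

def qat_mode(fake_act_quant, lsq_weight_scale):
    # Mirrors the corrected branching in pre_qat (illustration only)
    if fake_act_quant and lsq_weight_scale:
        return ["set_device"]
    elif not fake_act_quant and lsq_weight_scale:
        return ["set_device", "apply_scales"]   # path fixed by this commit
    else:
        return ["default"]

for fa in (False, True):
    for lsq in (False, True):
        print(fa, lsq, qat_mode(fa, lsq))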
