fix dtypes for float16
leondgarse committed Jun 17, 2023
1 parent e9c1c41 commit 1c35306
Showing 10 changed files with 38 additions and 33 deletions.
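In short, the diffs below stop pinning freshly created weights and buffers to "float32" and instead let them follow the layer's compute dtype, so models built under float16 or mixed precision no longer hit dtype mismatches. As a rough illustration only, a minimal tf.keras sketch of the pattern (the AddBias layer is hypothetical and not part of keras_cv_attention_models):

# Minimal sketch of the pattern, assuming TF Keras with a mixed_float16 policy.
# "AddBias" is a hypothetical layer for illustration, not code from this repository.
import numpy as np
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

class AddBias(tf.keras.layers.Layer):
    def build(self, input_shape):
        bias = np.ones([input_shape[-1]], dtype="float32")
        # Pinning this constant to "float32" would clash with float16 activations;
        # self.compute_dtype follows whatever precision policy is active.
        self.bias = tf.constant(bias, dtype=self.compute_dtype)
        super().build(input_shape)

    def call(self, inputs):
        return inputs + self.bias  # both operands are float16 under mixed_float16

print(AddBias()(tf.ones([1, 4])).dtype)  # float16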
8 changes: 4 additions & 4 deletions keras_cv_attention_models/beit/beit.py
@@ -84,11 +84,11 @@ def build(self, input_shape):
         pos_sin, pos_cos = np.reshape(pos_sin, [height * width, pos_fileters * 4]), np.reshape(pos_cos, [height * width, pos_fileters * 4])

         if hasattr(self, "register_buffer"): # PyTorch
-            self.register_buffer("pos_sin", functional.convert_to_tensor(pos_sin, dtype="float32"), persistent=False)
-            self.register_buffer("pos_cos", functional.convert_to_tensor(pos_cos, dtype="float32"), persistent=False)
+            self.register_buffer("pos_sin", functional.convert_to_tensor(pos_sin, dtype=self.compute_dtype), persistent=False)
+            self.register_buffer("pos_cos", functional.convert_to_tensor(pos_cos, dtype=self.compute_dtype), persistent=False)
         else:
-            self.pos_sin = functional.convert_to_tensor(pos_sin, dtype="float32")
-            self.pos_cos = functional.convert_to_tensor(pos_cos, dtype="float32")
+            self.pos_sin = functional.convert_to_tensor(pos_sin, dtype=self.compute_dtype)
+            self.pos_cos = functional.convert_to_tensor(pos_cos, dtype=self.compute_dtype)
         super().build(input_shape)

     def call(self, inputs, **kwargs):
8 changes: 4 additions & 4 deletions keras_cv_attention_models/edgenext/edgenext.py
@@ -49,12 +49,12 @@ def build(self, input_shape):
         positional_embedding = np.concatenate([pos_hh, pos_ww], axis=-1) # [12, 27, 64]

         if hasattr(self, "register_buffer"): # PyTorch
-            self.register_buffer("positional_embedding", functional.convert_to_tensor(positional_embedding, dtype="float32"), persistent=False)
+            self.register_buffer("positional_embedding", functional.convert_to_tensor(positional_embedding, dtype=self.compute_dtype), persistent=False)
         else:
-            self.positional_embedding = functional.convert_to_tensor(positional_embedding, dtype="float32")
+            self.positional_embedding = functional.convert_to_tensor(positional_embedding, dtype=self.compute_dtype)

-        self.token_projection_ww = self.add_weight(name="ww", shape=(self.filters * 2, channels), trainable=True, dtype="float32")
-        self.token_projection_bb = self.add_weight(name="bb", shape=(channels,), trainable=True, dtype="float32")
+        self.token_projection_ww = self.add_weight(name="ww", shape=(self.filters * 2, channels), trainable=True)
+        self.token_projection_bb = self.add_weight(name="bb", shape=(channels,), trainable=True)
         super().build(input_shape)

     def call(self, inputs, **kwargs):
2 changes: 1 addition & 1 deletion keras_cv_attention_models/efficientdet/efficientdet.py
@@ -36,7 +36,7 @@ def __init__(self, initializer="ones", epsilon=1e-4, **kwargs):

     def build(self, input_shape):
         self.total = len(input_shape)
-        self.gain = self.add_weight(name="gain", shape=(self.total,), initializer=self.initializer, dtype="float32", trainable=True)
+        self.gain = self.add_weight(name="gain", shape=(self.total,), initializer=self.initializer, trainable=True)
         self.__epsilon__ = float(self.epsilon)
         super().build(input_shape)

6 changes: 3 additions & 3 deletions keras_cv_attention_models/gpt2/gpt2.py
@@ -48,13 +48,13 @@ def __init__(self, block_size, **kwargs):
     def build(self, input_shape):
         causal_mask = (1 - np.tri(self.block_size).astype("float32")[None, None]) * -1e10
         if hasattr(self, "register_buffer"): # PyTorch
-            self.register_buffer("causal_mask", functional.convert_to_tensor(causal_mask, dtype="float32"), persistent=False)
+            self.register_buffer("causal_mask", functional.convert_to_tensor(causal_mask, dtype=self.compute_dtype), persistent=False)
         else:
-            self.causal_mask = functional.convert_to_tensor(causal_mask, dtype="float32")
+            self.causal_mask = functional.convert_to_tensor(causal_mask, dtype=self.compute_dtype)
         super().build(input_shape)

     def call(self, inputs):
-        return inputs + functional.cast(self.causal_mask[:, :, : inputs.shape[2], : inputs.shape[3]], inputs.dtype)
+        return inputs + self.causal_mask[:, :, : inputs.shape[2], : inputs.shape[3]]

     def get_config(self):
         base_config = super().get_config()
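A side note on the dropped cast in call(), sketched in plain torch rather than the library's own classes: once the mask buffer is created in the layer's compute dtype, adding it to float16 inputs already stays in float16, so the explicit functional.cast becomes redundant.

# Plain-torch check of the assumption behind dropping the cast: when the sliced mask
# already matches the inputs' dtype, the addition keeps that dtype.
import torch

causal_mask = torch.zeros([1, 1, 8, 8], dtype=torch.float16)
inputs = torch.ones([2, 4, 6, 6], dtype=torch.float16)
out = inputs + causal_mask[:, :, : inputs.shape[2], : inputs.shape[3]]
print(out.dtype)  # torch.float16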
2 changes: 1 addition & 1 deletion keras_cv_attention_models/gpvit/gpvit.py
@@ -37,7 +37,7 @@ def __init__(self, shape, **kwargs):
         self.shape = shape

     def build(self, input_shape):
-        self.gain = self.add_weight(name="gain", shape=self.shape, dtype="float32", trainable=True)
+        self.gain = self.add_weight(name="gain", shape=self.shape, trainable=True)
         super().build(input_shape)

     def call(self, inputs, **kwargs):
6 changes: 3 additions & 3 deletions keras_cv_attention_models/nfnets/nfnets.py
@@ -57,7 +57,7 @@ def build(self, input_shape):
             default_conv_op = self._convolution_op # TF < 2.7.0
         else:
             default_conv_op = self.convolution_op # TF 2.7.0
-        self.gain = self.add_weight(name="gain", shape=(self.filters,), initializer="ones", trainable=True, dtype="float32")
+        self.gain = self.add_weight(name="gain", shape=(self.filters,), initializer="ones", trainable=True)
         self.fan_in = float(np.prod(self.kernel.shape[:-1]))
         self.__eps__ = float(self.eps)
         self.__gamma__ = float(self.gamma)
@@ -91,9 +91,9 @@ def __init__(self, use_bias=False, weight_init_value=0, bias_init_value=0, **kwa
         self.bb_init = initializers.Constant(bias_init_value) if bias_init_value != 0 else "zeros"

     def build(self, input_shape):
-        self.gain = self.add_weight(name="gain", shape=(), initializer=self.ww_init, dtype="float32", trainable=True)
+        self.gain = self.add_weight(name="gain", shape=(), initializer=self.ww_init, trainable=True)
         if self.use_bias:
-            self.bias = self.add_weight(name="bias", shape=(), initializer=self.bb_init, dtype="float32", trainable=True)
+            self.bias = self.add_weight(name="bias", shape=(), initializer=self.bb_init, trainable=True)
         super().build(input_shape)

     def call(self, inputs):
2 changes: 1 addition & 1 deletion keras_cv_attention_models/pytorch_backend/functional.py
@@ -51,7 +51,7 @@ def concat(inputs, axis, name=None):


 def convert_to_tensor(inputs, dtype="float32"):
-    return torch.tensor(inputs, dtype=getattr(torch, dtype))
+    return torch.tensor(inputs, dtype=getattr(torch, dtype) if isinstance(dtype, str) else dtype)


 def cos(inputs, name=None):
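The point of the isinstance check, sketched standalone below: getattr(torch, dtype) only resolves string names, while the compute_dtype added to the PyTorch backend (next file) hands back a torch.dtype object, so the helper now has to accept both forms.

# Standalone sketch of the updated helper's behavior; mirrors the one-line change above.
import numpy as np
import torch

def convert_to_tensor(inputs, dtype="float32"):
    return torch.tensor(inputs, dtype=getattr(torch, dtype) if isinstance(dtype, str) else dtype)

print(convert_to_tensor(np.zeros([2, 2]), dtype="float16").dtype)      # torch.float16, from string name
print(convert_to_tensor(np.zeros([2, 2]), dtype=torch.float16).dtype)  # torch.float16, from dtype object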
11 changes: 11 additions & 0 deletions keras_cv_attention_models/pytorch_backend/layers.py
@@ -53,6 +53,10 @@ def __init__(self, name, value):
     def __repr__(self):
         return "{}, shape={}".format(self.name, self.shape)

+    @property
+    def dtype(self):
+        return self.__value__.dtype
+
     def value(self):
         return self.__value__

@@ -280,6 +284,13 @@ def forward(self, inputs, **kwargs):
         else:
             return self.call(inputs, **kwargs)

+    @property
+    def compute_dtype(self):
+        try:
+            return next(self.parameters()).dtype
+        except StopIteration:
+            return torch.get_default_dtype()
+
     @property
     def weights(self):
         return [Weight(name=self.name + "/" + kk.split(".")[-1], value=vv) for kk, vv in self.state_dict().items() if not kk.endswith(".num_batches_tracked")]
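The new compute_dtype reads the dtype of the module's first parameter and falls back to torch.get_default_dtype() for parameter-free layers, so buffers built afterwards match the weights. A minimal sketch of the same lookup on a plain torch.nn.Module (the Toy class and its Linear member are illustrative, not part of the backend):

# Minimal sketch of the compute_dtype lookup on a plain torch module; it tracks a later .half() call.
import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(4, 4)

    @property
    def compute_dtype(self):
        try:
            return next(self.parameters()).dtype  # dtype of the first registered parameter
        except StopIteration:
            return torch.get_default_dtype()      # no parameters yet: use the global default

mm = Toy()
print(mm.compute_dtype)         # torch.float32
print(mm.half().compute_dtype)  # torch.float16 after converting parameters to half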
(additional changed file, path not shown)
@@ -45,7 +45,7 @@ def build(self, input_shape):
             weight_shape[ii] = input_shape[ii]

         initializer = initializers.constant(math.log(self.init_value))
-        self.scale = self.add_weight(name="weight", shape=weight_shape, initializer=initializer, trainable=True, dtype="float32")
+        self.scale = self.add_weight(name="weight", shape=weight_shape, initializer=initializer, trainable=True)
         # self.__max_value__ = functional.convert_to_tensor(float(math.log(self.max_value)))
         self.__max_value__ = float(math.log(self.max_value))
         super().build(input_shape)
@@ -82,9 +82,9 @@ def build(self, input_shape):
         relative_log_coords = np.sign(coords) * np.log(1.0 + np.abs(coords)) / (np.log(2.0) * 3.0)
         relative_log_coords = np.reshape(relative_log_coords, [-1, 2]) # [23 * 29, 2]
         if hasattr(self, "register_buffer"): # PyTorch
-            self.register_buffer("relative_log_coords", functional.convert_to_tensor(relative_log_coords, dtype="float32"), persistent=False)
+            self.register_buffer("relative_log_coords", functional.convert_to_tensor(relative_log_coords, dtype=self.compute_dtype), persistent=False)
         else:
-            self.relative_log_coords = functional.convert_to_tensor(relative_log_coords, dtype="float32")
+            self.relative_log_coords = functional.convert_to_tensor(relative_log_coords, dtype=self.compute_dtype)
         self.height, self.width = height, width # For reload with shape mismatched
         super().build(input_shape)

@@ -168,15 +168,13 @@ def build(self, input_shape):
         mask = np.transpose(mask, [0, 2, 1, 3])
         mask = np.reshape(mask, [-1, self.window_height * self.window_width])
         attn_mask = np.expand_dims(mask, 1) - np.expand_dims(mask, 2)
-        # attn_mask = tf.cast(np.where(attn_mask != 0, -100, 0), self._compute_dtype)
         attn_mask = np.where(attn_mask != 0, -100, 0)
         attn_mask = np.expand_dims(np.expand_dims(attn_mask, 1), 0) # expand dims on batch and num_heads
-        # attn_mask = functional.convert_to_tensor(attn_mask, dtype="float32")

         if hasattr(self, "register_buffer"): # PyTorch
-            self.register_buffer("attn_mask", functional.convert_to_tensor(attn_mask, dtype="float32"), persistent=False)
+            self.register_buffer("attn_mask", functional.convert_to_tensor(attn_mask, dtype=self.compute_dtype), persistent=False)
         else:
-            self.attn_mask = functional.convert_to_tensor(attn_mask, dtype="float32")
+            self.attn_mask = functional.convert_to_tensor(attn_mask, dtype=self.compute_dtype)

         self.num_heads, self.query_blocks = input_shape[1], input_shape[2]
         super().build(input_shape)
(additional changed file, path not shown)
@@ -33,7 +33,7 @@ def build(self, input_shape):
         axis = self.axis if isinstance(self.axis, (list, tuple)) else [self.axis]
         for ii in axis:
             weight_shape[ii] = input_shape[ii]
-        self.scale = self.add_weight(name="weight", shape=weight_shape, initializer=self.initializer, trainable=True, dtype="float32")
+        self.scale = self.add_weight(name="weight", shape=weight_shape, initializer=self.initializer, trainable=True)
         super().build(input_shape)

     def call(self, inputs, **kwargs):
@@ -59,14 +59,12 @@ def build(self, input_shape):
         coords = np.stack([yy, xx], axis=-1).astype("float32") # [14, 14, 2]
         coords_flatten = np.reshape(coords, [-1, 2]) # [196, 2]
         relative_coords = coords_flatten[:, None, :] - coords_flatten[None, :, :] # [196, 196, 2]
-        # relative_coords = tf.reshape(relative_coords, [-1, 2]) # [196 * 196, 2]
-        # relative_coords = tf.cast(relative_coords, self.dtype)
         relative_coords_log = np.sign(relative_coords) * np.log(1.0 + np.abs(relative_coords))

         if hasattr(self, "register_buffer"): # PyTorch
-            self.register_buffer("relative_coords_log", functional.convert_to_tensor(relative_coords_log, dtype="float32"), persistent=False)
+            self.register_buffer("relative_coords_log", functional.convert_to_tensor(relative_coords_log, dtype=self.compute_dtype), persistent=False)
         else:
-            self.relative_coords_log = functional.convert_to_tensor(relative_coords_log, dtype="float32")
+            self.relative_coords_log = functional.convert_to_tensor(relative_coords_log, dtype=self.compute_dtype)
         self.height, self.width = height, width
         super().build(input_shape)

@@ -102,15 +100,13 @@ def build(self, input_shape):
         mask = np.transpose(mask, [0, 2, 1, 3])
         mask = np.reshape(mask, [-1, self.window_height * self.window_width])
         attn_mask = np.expand_dims(mask, 1) - np.expand_dims(mask, 2)
-        # attn_mask = tf.cast(np.where(attn_mask != 0, -100, 0), self._compute_dtype)
         attn_mask = np.where(attn_mask != 0, -100, 0)
         attn_mask = np.expand_dims(np.expand_dims(attn_mask, 1), 0) # expand dims on batch and num_heads
-        # attn_mask = functional.convert_to_tensor(attn_mask, dtype="float32")

         if hasattr(self, "register_buffer"): # PyTorch
-            self.register_buffer("attn_mask", functional.convert_to_tensor(attn_mask, dtype="float32"), persistent=False)
+            self.register_buffer("attn_mask", functional.convert_to_tensor(attn_mask, dtype=self.compute_dtype), persistent=False)
         else:
-            self.attn_mask = functional.convert_to_tensor(attn_mask, dtype="float32")
+            self.attn_mask = functional.convert_to_tensor(attn_mask, dtype=self.compute_dtype)

         self.num_heads, self.query_blocks = input_shape[1], input_shape[2]
         super().build(input_shape)

1 comment on commit 1c35306

@leondgarse
Owner Author


For #122
