diff --git a/keras_cv_attention_models/beit/beit.py b/keras_cv_attention_models/beit/beit.py
index ca5b3afc..3de0182d 100644
--- a/keras_cv_attention_models/beit/beit.py
+++ b/keras_cv_attention_models/beit/beit.py
@@ -84,11 +84,11 @@ def build(self, input_shape):
         pos_sin, pos_cos = np.reshape(pos_sin, [height * width, pos_fileters * 4]), np.reshape(pos_cos, [height * width, pos_fileters * 4])

         if hasattr(self, "register_buffer"):  # PyTorch
-            self.register_buffer("pos_sin", functional.convert_to_tensor(pos_sin, dtype="float32"), persistent=False)
-            self.register_buffer("pos_cos", functional.convert_to_tensor(pos_cos, dtype="float32"), persistent=False)
+            self.register_buffer("pos_sin", functional.convert_to_tensor(pos_sin, dtype=self.compute_dtype), persistent=False)
+            self.register_buffer("pos_cos", functional.convert_to_tensor(pos_cos, dtype=self.compute_dtype), persistent=False)
         else:
-            self.pos_sin = functional.convert_to_tensor(pos_sin, dtype="float32")
-            self.pos_cos = functional.convert_to_tensor(pos_cos, dtype="float32")
+            self.pos_sin = functional.convert_to_tensor(pos_sin, dtype=self.compute_dtype)
+            self.pos_cos = functional.convert_to_tensor(pos_cos, dtype=self.compute_dtype)
         super().build(input_shape)

     def call(self, inputs, **kwargs):
diff --git a/keras_cv_attention_models/edgenext/edgenext.py b/keras_cv_attention_models/edgenext/edgenext.py
index 1794b010..d3b8f226 100644
--- a/keras_cv_attention_models/edgenext/edgenext.py
+++ b/keras_cv_attention_models/edgenext/edgenext.py
@@ -49,12 +49,12 @@ def build(self, input_shape):
         positional_embedding = np.concatenate([pos_hh, pos_ww], axis=-1)  # [12, 27, 64]

         if hasattr(self, "register_buffer"):  # PyTorch
-            self.register_buffer("positional_embedding", functional.convert_to_tensor(positional_embedding, dtype="float32"), persistent=False)
+            self.register_buffer("positional_embedding", functional.convert_to_tensor(positional_embedding, dtype=self.compute_dtype), persistent=False)
         else:
-            self.positional_embedding = functional.convert_to_tensor(positional_embedding, dtype="float32")
+            self.positional_embedding = functional.convert_to_tensor(positional_embedding, dtype=self.compute_dtype)

-        self.token_projection_ww = self.add_weight(name="ww", shape=(self.filters * 2, channels), trainable=True, dtype="float32")
-        self.token_projection_bb = self.add_weight(name="bb", shape=(channels,), trainable=True, dtype="float32")
+        self.token_projection_ww = self.add_weight(name="ww", shape=(self.filters * 2, channels), trainable=True)
+        self.token_projection_bb = self.add_weight(name="bb", shape=(channels,), trainable=True)
         super().build(input_shape)

     def call(self, inputs, **kwargs):
diff --git a/keras_cv_attention_models/efficientdet/efficientdet.py b/keras_cv_attention_models/efficientdet/efficientdet.py
index 919cce27..4d5829d2 100644
--- a/keras_cv_attention_models/efficientdet/efficientdet.py
+++ b/keras_cv_attention_models/efficientdet/efficientdet.py
@@ -36,7 +36,7 @@ def __init__(self, initializer="ones", epsilon=1e-4, **kwargs):

     def build(self, input_shape):
         self.total = len(input_shape)
-        self.gain = self.add_weight(name="gain", shape=(self.total,), initializer=self.initializer, dtype="float32", trainable=True)
+        self.gain = self.add_weight(name="gain", shape=(self.total,), initializer=self.initializer, trainable=True)
         self.__epsilon__ = float(self.epsilon)
         super().build(input_shape)

diff --git a/keras_cv_attention_models/gpt2/gpt2.py b/keras_cv_attention_models/gpt2/gpt2.py
index e854b7d7..1fac024a 100644
--- a/keras_cv_attention_models/gpt2/gpt2.py
+++ b/keras_cv_attention_models/gpt2/gpt2.py
@@ -48,13 +48,13 @@ def __init__(self, block_size, **kwargs):
     def build(self, input_shape):
         causal_mask = (1 - np.tri(self.block_size).astype("float32")[None, None]) * -1e10
         if hasattr(self, "register_buffer"):  # PyTorch
-            self.register_buffer("causal_mask", functional.convert_to_tensor(causal_mask, dtype="float32"), persistent=False)
+            self.register_buffer("causal_mask", functional.convert_to_tensor(causal_mask, dtype=self.compute_dtype), persistent=False)
         else:
-            self.causal_mask = functional.convert_to_tensor(causal_mask, dtype="float32")
+            self.causal_mask = functional.convert_to_tensor(causal_mask, dtype=self.compute_dtype)
         super().build(input_shape)

     def call(self, inputs):
-        return inputs + functional.cast(self.causal_mask[:, :, : inputs.shape[2], : inputs.shape[3]], inputs.dtype)
+        return inputs + self.causal_mask[:, :, : inputs.shape[2], : inputs.shape[3]]

     def get_config(self):
         base_config = super().get_config()
diff --git a/keras_cv_attention_models/gpvit/gpvit.py b/keras_cv_attention_models/gpvit/gpvit.py
index 5b01105a..9d356f18 100644
--- a/keras_cv_attention_models/gpvit/gpvit.py
+++ b/keras_cv_attention_models/gpvit/gpvit.py
@@ -37,7 +37,7 @@ def __init__(self, shape, **kwargs):
         self.shape = shape

     def build(self, input_shape):
-        self.gain = self.add_weight(name="gain", shape=self.shape, dtype="float32", trainable=True)
+        self.gain = self.add_weight(name="gain", shape=self.shape, trainable=True)
         super().build(input_shape)

     def call(self, inputs, **kwargs):
diff --git a/keras_cv_attention_models/nfnets/nfnets.py b/keras_cv_attention_models/nfnets/nfnets.py
index 84213e48..a9493ba0 100644
--- a/keras_cv_attention_models/nfnets/nfnets.py
+++ b/keras_cv_attention_models/nfnets/nfnets.py
@@ -57,7 +57,7 @@ def build(self, input_shape):
             default_conv_op = self._convolution_op  # TF < 2.7.0
         else:
             default_conv_op = self.convolution_op  # TF 2.7.0
-        self.gain = self.add_weight(name="gain", shape=(self.filters,), initializer="ones", trainable=True, dtype="float32")
+        self.gain = self.add_weight(name="gain", shape=(self.filters,), initializer="ones", trainable=True)
         self.fan_in = float(np.prod(self.kernel.shape[:-1]))
         self.__eps__ = float(self.eps)
         self.__gamma__ = float(self.gamma)
@@ -91,9 +91,9 @@ def __init__(self, use_bias=False, weight_init_value=0, bias_init_value=0, **kwa
         self.bb_init = initializers.Constant(bias_init_value) if bias_init_value != 0 else "zeros"

     def build(self, input_shape):
-        self.gain = self.add_weight(name="gain", shape=(), initializer=self.ww_init, dtype="float32", trainable=True)
+        self.gain = self.add_weight(name="gain", shape=(), initializer=self.ww_init, trainable=True)
         if self.use_bias:
-            self.bias = self.add_weight(name="bias", shape=(), initializer=self.bb_init, dtype="float32", trainable=True)
+            self.bias = self.add_weight(name="bias", shape=(), initializer=self.bb_init, trainable=True)
         super().build(input_shape)

     def call(self, inputs):
diff --git a/keras_cv_attention_models/pytorch_backend/functional.py b/keras_cv_attention_models/pytorch_backend/functional.py
index c9107868..3b1d207c 100644
--- a/keras_cv_attention_models/pytorch_backend/functional.py
+++ b/keras_cv_attention_models/pytorch_backend/functional.py
@@ -51,7 +51,7 @@ def concat(inputs, axis, name=None):


 def convert_to_tensor(inputs, dtype="float32"):
-    return torch.tensor(inputs, dtype=getattr(torch, dtype))
+    return torch.tensor(inputs, dtype=getattr(torch, dtype) if isinstance(dtype, str) else dtype)


 def cos(inputs, name=None):
diff --git a/keras_cv_attention_models/pytorch_backend/layers.py b/keras_cv_attention_models/pytorch_backend/layers.py
index 46c941ae..cc12c023 100644
--- a/keras_cv_attention_models/pytorch_backend/layers.py
+++ b/keras_cv_attention_models/pytorch_backend/layers.py
@@ -53,6 +53,10 @@ def __init__(self, name, value):
     def __repr__(self):
         return "{}, shape={}".format(self.name, self.shape)

+    @property
+    def dtype(self):
+        return self.__value__.dtype
+
     def value(self):
         return self.__value__

@@ -280,6 +284,13 @@ def forward(self, inputs, **kwargs):
         else:
             return self.call(inputs, **kwargs)

+    @property
+    def compute_dtype(self):
+        try:
+            return next(self.parameters()).dtype
+        except StopIteration:
+            return torch.get_default_dtype()
+
     @property
     def weights(self):
         return [Weight(name=self.name + "/" + kk.split(".")[-1], value=vv) for kk, vv in self.state_dict().items() if not kk.endswith(".num_batches_tracked")]
diff --git a/keras_cv_attention_models/swin_transformer_v2/swin_transformer_v2.py b/keras_cv_attention_models/swin_transformer_v2/swin_transformer_v2.py
index e4b41591..f2f01031 100644
--- a/keras_cv_attention_models/swin_transformer_v2/swin_transformer_v2.py
+++ b/keras_cv_attention_models/swin_transformer_v2/swin_transformer_v2.py
@@ -45,7 +45,7 @@ def build(self, input_shape):
             weight_shape[ii] = input_shape[ii]

         initializer = initializers.constant(math.log(self.init_value))
-        self.scale = self.add_weight(name="weight", shape=weight_shape, initializer=initializer, trainable=True, dtype="float32")
+        self.scale = self.add_weight(name="weight", shape=weight_shape, initializer=initializer, trainable=True)
         # self.__max_value__ = functional.convert_to_tensor(float(math.log(self.max_value)))
         self.__max_value__ = float(math.log(self.max_value))
         super().build(input_shape)
@@ -82,9 +82,9 @@ def build(self, input_shape):
         relative_log_coords = np.sign(coords) * np.log(1.0 + np.abs(coords)) / (np.log(2.0) * 3.0)
         relative_log_coords = np.reshape(relative_log_coords, [-1, 2])  # [23 * 29, 2]
         if hasattr(self, "register_buffer"):  # PyTorch
-            self.register_buffer("relative_log_coords", functional.convert_to_tensor(relative_log_coords, dtype="float32"), persistent=False)
+            self.register_buffer("relative_log_coords", functional.convert_to_tensor(relative_log_coords, dtype=self.compute_dtype), persistent=False)
         else:
-            self.relative_log_coords = functional.convert_to_tensor(relative_log_coords, dtype="float32")
+            self.relative_log_coords = functional.convert_to_tensor(relative_log_coords, dtype=self.compute_dtype)
         self.height, self.width = height, width  # For reload with shape mismatched
         super().build(input_shape)

@@ -168,15 +168,13 @@ def build(self, input_shape):
         mask = np.transpose(mask, [0, 2, 1, 3])
         mask = np.reshape(mask, [-1, self.window_height * self.window_width])
         attn_mask = np.expand_dims(mask, 1) - np.expand_dims(mask, 2)
-        # attn_mask = tf.cast(np.where(attn_mask != 0, -100, 0), self._compute_dtype)
         attn_mask = np.where(attn_mask != 0, -100, 0)
         attn_mask = np.expand_dims(np.expand_dims(attn_mask, 1), 0)  # expand dims on batch and num_heads
-        # attn_mask = functional.convert_to_tensor(attn_mask, dtype="float32")

         if hasattr(self, "register_buffer"):  # PyTorch
-            self.register_buffer("attn_mask", functional.convert_to_tensor(attn_mask, dtype="float32"), persistent=False)
+            self.register_buffer("attn_mask", functional.convert_to_tensor(attn_mask, dtype=self.compute_dtype), persistent=False)
         else:
-            self.attn_mask = functional.convert_to_tensor(attn_mask, dtype="float32")
+            self.attn_mask = functional.convert_to_tensor(attn_mask, dtype=self.compute_dtype)
         self.num_heads, self.query_blocks = input_shape[1], input_shape[2]
         super().build(input_shape)

diff --git a/keras_cv_attention_models/swin_transformer_v2/swin_transformer_v2_timm.py b/keras_cv_attention_models/swin_transformer_v2/swin_transformer_v2_timm.py
index 07cad1df..f4cfd0ce 100644
--- a/keras_cv_attention_models/swin_transformer_v2/swin_transformer_v2_timm.py
+++ b/keras_cv_attention_models/swin_transformer_v2/swin_transformer_v2_timm.py
@@ -33,7 +33,7 @@ def build(self, input_shape):
         axis = self.axis if isinstance(self.axis, (list, tuple)) else [self.axis]
         for ii in axis:
             weight_shape[ii] = input_shape[ii]
-        self.scale = self.add_weight(name="weight", shape=weight_shape, initializer=self.initializer, trainable=True, dtype="float32")
+        self.scale = self.add_weight(name="weight", shape=weight_shape, initializer=self.initializer, trainable=True)
         super().build(input_shape)

     def call(self, inputs, **kwargs):
@@ -59,14 +59,12 @@ def build(self, input_shape):
         coords = np.stack([yy, xx], axis=-1).astype("float32")  # [14, 14, 2]
         coords_flatten = np.reshape(coords, [-1, 2])  # [196, 2]
         relative_coords = coords_flatten[:, None, :] - coords_flatten[None, :, :]  # [196, 196, 2]
-        # relative_coords = tf.reshape(relative_coords, [-1, 2])  # [196 * 196, 2]
-        # relative_coords = tf.cast(relative_coords, self.dtype)
         relative_coords_log = np.sign(relative_coords) * np.log(1.0 + np.abs(relative_coords))

         if hasattr(self, "register_buffer"):  # PyTorch
-            self.register_buffer("relative_coords_log", functional.convert_to_tensor(relative_coords_log, dtype="float32"), persistent=False)
+            self.register_buffer("relative_coords_log", functional.convert_to_tensor(relative_coords_log, dtype=self.compute_dtype), persistent=False)
         else:
-            self.relative_coords_log = functional.convert_to_tensor(relative_coords_log, dtype="float32")
+            self.relative_coords_log = functional.convert_to_tensor(relative_coords_log, dtype=self.compute_dtype)
         self.height, self.width = height, width
         super().build(input_shape)

@@ -102,15 +100,13 @@ def build(self, input_shape):
         mask = np.transpose(mask, [0, 2, 1, 3])
         mask = np.reshape(mask, [-1, self.window_height * self.window_width])
         attn_mask = np.expand_dims(mask, 1) - np.expand_dims(mask, 2)
-        # attn_mask = tf.cast(np.where(attn_mask != 0, -100, 0), self._compute_dtype)
         attn_mask = np.where(attn_mask != 0, -100, 0)
         attn_mask = np.expand_dims(np.expand_dims(attn_mask, 1), 0)  # expand dims on batch and num_heads
-        # attn_mask = functional.convert_to_tensor(attn_mask, dtype="float32")

         if hasattr(self, "register_buffer"):  # PyTorch
-            self.register_buffer("attn_mask", functional.convert_to_tensor(attn_mask, dtype="float32"), persistent=False)
+            self.register_buffer("attn_mask", functional.convert_to_tensor(attn_mask, dtype=self.compute_dtype), persistent=False)
         else:
-            self.attn_mask = functional.convert_to_tensor(attn_mask, dtype="float32")
+            self.attn_mask = functional.convert_to_tensor(attn_mask, dtype=self.compute_dtype)
         self.num_heads, self.query_blocks = input_shape[1], input_shape[2]
         super().build(input_shape)
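The sketch below is not part of the patch; it is a minimal standalone illustration of the behavior the diff targets, assuming the `compute_dtype` property added to the PyTorch-backend `Layer`: the dtype is resolved from the module's own parameters (falling back to `torch.get_default_dtype()`), so buffers built with `dtype=self.compute_dtype` follow the model when it is cast to half precision instead of staying hard-coded to float32. `TinyLayer` and `build_buffers` are hypothetical names used only for this demonstration.

# Illustrative sketch only, not part of the patch. Mirrors the compute_dtype
# property added in pytorch_backend/layers.py: dtype comes from the first
# parameter, with torch.get_default_dtype() as the parameter-free fallback.
import torch
import torch.nn as nn


class TinyLayer(nn.Module):  # hypothetical demo module
    def __init__(self, units=4):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(units))

    @property
    def compute_dtype(self):
        try:
            return next(self.parameters()).dtype  # dtype of the first parameter
        except StopIteration:
            return torch.get_default_dtype()  # module has no parameters

    def build_buffers(self):
        # Buffers follow the parameter dtype rather than hard-coded float32,
        # so casting the module (e.g. .half()) keeps weights and buffers aligned.
        mask = torch.zeros(3, 3, dtype=self.compute_dtype)
        self.register_buffer("mask", mask, persistent=False)


layer = TinyLayer().half()  # parameters become float16
layer.build_buffers()
print(layer.compute_dtype, layer.mask.dtype)  # torch.float16 torch.float16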