diff --git a/delft/utilities/Transformer.py b/delft/utilities/Transformer.py index 23c7292..24e3a74 100644 --- a/delft/utilities/Transformer.py +++ b/delft/utilities/Transformer.py @@ -199,7 +199,12 @@ def instantiate_layer(self, load_pretrained_weights=True) -> Union[object, TFAut elif self.loading_method == LOADING_METHOD_LOCAL_MODEL_DIR: if load_pretrained_weights: - transformer_model = TFAutoModel.from_pretrained(self.local_dir_path, from_pt=True) + try: + transformer_model = TFAutoModel.from_pretrained(self.local_dir_path, from_pt=True) + except: + # failure might be due to safetensors format for the weights, we can try an alternative loading + # for this case + transformer_model = TFAutoModel.from_pretrained(self.local_dir_path) self.transformer_config = transformer_model.config return transformer_model else: