Skip to content

Commit

Permalink
cherry pick pp fix (PaddlePaddle#8060)
Browse files Browse the repository at this point in the history
* cp pp fix

* fix
  • Loading branch information
lugimzzz authored Mar 6, 2024
1 parent 6ba8ae2 commit e27f3c0
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions paddlenlp/transformers/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2426,20 +2426,26 @@ def save_pretrained(


class PipelinePretrainedModel(PretrainedModel):
_sequential_layers = []
_single_to_pp_mapping = None
_pp_to_single_mapping = None
def __init_hook__(self):
if not hasattr(self, "_sequential_layers"):
self._sequential_layers = []
self._single_to_pp_mapping = None
self._pp_to_single_mapping = None

def __init__(self, config, *args, **kwargs):
self.__init_hook__()
super().__init__(config, *args, **kwargs)

def add_sequential_layer(self, layer_desc, name_prefix=""):
self.__init_hook__()
self._sequential_layers.append({"layer": layer_desc, "name_prefix": name_prefix})

def get_sequential_layers(self):
self.__init_hook__()
return [x["layer"] for x in self._sequential_layers]

def get_sequential_name_prefixes(self):
self.__init_hook__()
return {str(index): x["name_prefix"] for index, x in enumerate(self._sequential_layers)}

def _set_pipeline_name_mapping(self, mappings=None):
Expand Down

0 comments on commit e27f3c0

Please sign in to comment.