[Distributed] adapt sequence parallel on LoRA (PaddlePaddle#8235)
SylarTiaNII authored Apr 10, 2024
1 parent 3cc32da commit 06c2fdc
Showing 2 changed files with 309 additions and 1 deletion.
245 changes: 245 additions & 0 deletions paddlenlp/peft/lora/lora_layers.py
@@ -25,6 +25,16 @@
RowParallelLinear,
)

from paddlenlp.transformers.sequence_parallel_utils import (
AllGatherOp,
ColumnSequenceParallelLinear,
MC2ColumnSeqParallelLinear,
MC2RowSeqParallelLinear,
ReduceScatterOp,
RowSequenceParallelLinear,
mark_as_sequence_parallel_parameter,
)

from .lora_quick_layers import quick_lora

if "npu" in paddle.device.get_all_custom_device_type():
@@ -34,6 +44,10 @@
MC2LoRaColumnParallelLinear = None


def is_mc2_valid():
return "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0"))


class LoRALinear(nn.Linear):
# LoRA implemented in a dense layer
def __init__(
@@ -265,6 +279,120 @@ def extra_repr(self):
return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


class RowSequenceParallelLoRALinear(RowSequenceParallelLinear):
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
rslora: bool = False,
lora_plus_scale: float = 1.0,
merge_weights: bool = True,
use_quick_lora: bool = False,
**kwargs
):
RowSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
if not isinstance(r, int) or r <= 0:
raise ValueError("Lora rank r should be a positive integer")
self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
if lora_dropout > 0.0:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
# Mark the weight as unmerged
self.merged = False
self.merge_weights = merge_weights

# compatible
self.name = self._name

# Actual trainable parameters
self.lora_A = self.create_parameter(
shape=[self.input_size_per_partition, r],
dtype=self._dtype,
is_bias=False,
attr=paddle.ParamAttr(
initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
),
)
self.lora_B = self.create_parameter(
shape=[r, self.out_features],
dtype=self._dtype,
is_bias=False,
attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.0),
learning_rate=lora_plus_scale,
),
)

self.lora_A.is_distributed = True
self.lora_A.split_axis = 0
self.lora_B.is_distributed = False
mark_as_sequence_parallel_parameter(self.lora_B)
if not rslora:
self.scaling = self.lora_alpha / self.r
else:
self.scaling = self.lora_alpha / math.sqrt(self.r)

# Freezing the pre-trained weight matrix
self.weight.stop_gradient = True
self._use_quick_lora = use_quick_lora and lora_dropout == 0.0

@property
def use_quick_lora(self):
# TODO(@gexiao): support qlora
return False # self._use_quick_lora and self.training and not self.merged

def train(self):
super().train()
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
self.weight.set_value(new_weight)
self.merged = False

def eval(self):
super().eval()
if self.merge_weights and not self.merged:
# Merge the weights and mark it
new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
self.weight.set_value(new_weight)
self.merged = True

def forward(self, x: paddle.Tensor):
if not self.input_is_parallel:
input_mp = mp_ops._c_split(x, group=self.model_parallel_group)
else:
input_mp = x

if not is_mc2_valid():
output_parallel = self.linear(input_mp, self.weight, name=self._name)
output_ = ReduceScatterOp.apply(output_parallel)
result_mp = output_ + self.bias if self.bias is not None else output_
else:
output_ = MC2RowSeqParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
result_mp = output_ + self.bias if self.bias is not None else output_

if not self.merged:
input_mp = self.lora_dropout(input_mp)
if not is_mc2_valid():
input_mp = input_mp @ self.lora_A
input_mp = ReduceScatterOp.apply(input_mp)
else:
input_mp = MC2RowSeqParallelLinear.apply(input_mp, self.lora_A, self.model_parallel_group)
delta_mp = (input_mp @ self.lora_B) * self.scaling
result_mp += delta_mp
return result_mp

def extra_repr(self):
name = f", name={self.name}" if self.name else ""
return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


class ColumnParallelLoRALinear(ColumnParallelLinear):
def __init__(
self,
@@ -390,6 +518,123 @@ def extra_repr(self):
return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


class ColumnSequenceParallelLoRALinear(ColumnSequenceParallelLinear):
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
rslora: bool = False,
lora_plus_scale: float = 1.0,
merge_weights: bool = True,
lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
use_quick_lora: bool = False,
**kwargs
):
ColumnSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
if not isinstance(r, int) or r <= 0:
raise ValueError("Lora rank r should be a positive integer")
self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
if lora_dropout > 0.0:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
# Mark the weight as unmerged
self.merged = False
self.merge_weights = merge_weights

# compatible
self.name = self._name

# Actual trainable parameters
self.lora_A = self.create_parameter(
shape=[in_features, r],
dtype=self._dtype,
is_bias=False,
attr=lora_A_weight_attr,
)
self.lora_A.is_distributed = False
mark_as_sequence_parallel_parameter(self.lora_A)

self.lora_B = self.create_parameter(
shape=[r, self.output_size_per_partition],
dtype=self._dtype,
is_bias=False,
attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.0),
learning_rate=lora_plus_scale,
),
)

self.lora_B.is_distributed = True
self.lora_B.split_axis = 1
if not rslora:
self.scaling = self.lora_alpha / self.r
else:
self.scaling = self.lora_alpha / math.sqrt(self.r)

# Freezing the pre-trained weight matrix
self.weight.stop_gradient = True
self._use_quick_lora = use_quick_lora and lora_dropout == 0.0

@property
def use_quick_lora(self):
# TODO(@gexiao): support qlora
return False # self._use_quick_lora and self.training and not self.merged

def train(self):
super().train()
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
self.weight.set_value(new_weight)
self.merged = False

def eval(self):
super().eval()
if self.merge_weights and not self.merged:
# Merge the weights and mark it
new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
self.weight.set_value(new_weight)
self.merged = True

def forward(self, x: paddle.Tensor):
if not is_mc2_valid():
if self.is_mp:
input_parallel = AllGatherOp.apply(x)
else:
input_parallel = x
result_mp = self.linear(input_parallel, self.weight, self.bias, name=self._name)
else:
result_mp = MC2ColumnSeqParallelLinear.apply(x, self.weight, self.model_parallel_group)
if self.bias is not None:
result_mp += self.bias

if not self.merged:
input_a = self.lora_dropout(x) @ self.lora_A
if not is_mc2_valid():
input_a = AllGatherOp.apply(input_a)
delta_mp = (input_a @ self.lora_B) * self.scaling
else:
input_a = MC2ColumnSeqParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
delta_mp = input_a * self.scaling
result_mp += delta_mp

if self.gather_output and self.is_mp:
result = mp_ops._c_concat(result_mp, group=self.model_parallel_group)
else:
result = result_mp
return result

def extra_repr(self):
name = f", name={self.name}" if self.name else ""
return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


class LoRAMergedLinear(nn.Linear):
# LoRA implemented in a dense layer with merged linear weights for q, k, v
def __init__(
65 changes: 64 additions & 1 deletion paddlenlp/peft/lora/lora_model.py
@@ -31,6 +31,11 @@
RowParallelLinear,
)

from paddlenlp.transformers.sequence_parallel_utils import (
ColumnSequenceParallelLinear,
RowSequenceParallelLinear,
)

from ...transformers.conversion_utils import ConversionMixin
from ...transformers.model_utils import PretrainedModel, _add_variant, dtype_guard
from ...transformers.utils import weight_name_suffix
@@ -41,10 +46,12 @@
from .lora_layers import (
ColumnParallelLoRALinear,
ColumnParallelLoRAMergedLinear,
ColumnSequenceParallelLoRALinear,
LoRAConv2D,
LoRALinear,
LoRAMergedLinear,
RowParallelLoRALinear,
RowSequenceParallelLoRALinear,
)

try:
@@ -366,6 +373,58 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
# Lora row parallel will split lora A matrix
self.add_lora_split_mapping(module_name + ".lora_A", is_column=False)

# for lora qat
if self.lora_config.do_qat:
self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False)
self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
elif isinstance(module, ColumnSequenceParallelLinear):
# recover the original output_features
output_features = module.weight.shape[1] * module.world_size
lora_module = ColumnSequenceParallelLoRALinear(
in_features=module.weight.shape[0],
out_features=output_features,
gather_output=module.gather_output,
has_bias=module.bias is not None,
r=lora_config.r,
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.lora_dropout,
rslora=lora_config.rslora,
lora_plus_scale=lora_config.lora_plus_scale,
merge_weights=lora_config.merge_weights,
lora_A_weight_attr=paddle.ParamAttr(
initializer=nn.initializer.KaimingUniform(
negative_slope=math.sqrt(5), nonlinearity="leaky_relu"
)
),
use_quick_lora=lora_config.use_quick_lora,
)
# Lora column parallel will split lora B matrix
self.add_lora_split_mapping(module_name + ".lora_B", is_column=True)

# for lora qat
if self.lora_config.do_qat:
self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=True)
self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
elif isinstance(module, RowSequenceParallelLinear):
# recover the original in_features
lora_module = RowSequenceParallelLoRALinear(
in_features=module.weight.shape[0] * module.world_size,
out_features=module.weight.shape[1],
has_bias=module.bias is not None,
input_is_parallel=module.input_is_parallel,
r=lora_config.r,
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.lora_dropout,
rslora=lora_config.rslora,
lora_plus_scale=lora_config.lora_plus_scale,
merge_weights=lora_config.merge_weights,
use_quick_lora=lora_config.use_quick_lora,
)
# Lora row parallel will split lora A matrix
self.add_lora_split_mapping(module_name + ".lora_A", is_column=False)

# for lora qat
if self.lora_config.do_qat:
self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False)
@@ -451,7 +510,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
)
if lora_module is None:
raise ValueError(
f"LoRA strategy only supports paddle.nn.Linear or paddle.distributed.fleet.meta_parallel.ColumnParallelLinear. {module}({module_name}) is not supported。"
f"LoRA strategy only supports paddle.nn.Linear or paddle.distributed.fleet.meta_parallel.ColumnParallelLinear or paddlenlp.transformers.sequence_utils. {module}({module_name} {type(module).__name__}) is not supported。"
)
if getattr(lora_module, "quant_weight", None) is not None:
lora_module.quant_weight = module.quant_weight
@@ -509,6 +568,8 @@ def mark_only_lora_as_trainable(self) -> None:
or isinstance(layer, LoRAConv2D)
or isinstance(layer, ColumnParallelLoRALinear)
or isinstance(layer, RowParallelLoRALinear)
or isinstance(layer, ColumnSequenceParallelLoRALinear)
or isinstance(layer, RowSequenceParallelLoRALinear)
or isinstance(layer, LoRAMergedLinear)
or isinstance(layer, ColumnParallelLoRAMergedLinear)
or (QuantizationLoRALinear is not None and isinstance(layer, QuantizationLoRALinear))
@@ -596,9 +657,11 @@ def restore_original_model(self):
self._find_and_restore_module(layer_name)
elif (
isinstance(layer, ColumnParallelLoRALinear)
or isinstance(layer, ColumnSequenceParallelLoRALinear)
or isinstance(layer, LoRAConv2D)
or isinstance(layer, ColumnParallelLoRAMergedLinear)
or isinstance(layer, RowParallelLoRALinear)
or isinstance(layer, RowSequenceParallelLoRALinear)
or (QuantizationLoRALinear is not None and isinstance(layer, QuantizationLoRALinear))
or (
ColumnParallelQuantizationLoRALinear is not None
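
With both branches in place, a model whose linears were already swapped for ColumnSequenceParallelLinear/RowSequenceParallelLinear (i.e. built with sequence parallelism enabled) can be wrapped the same way as the existing parallel LoRA layers. A hedged usage sketch: the checkpoint name, target_modules patterns, and the exact way sequence parallelism is switched on are placeholders, while the LoRAConfig fields and mark_only_lora_as_trainable are the ones referenced in this diff.

from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM

# Placeholder checkpoint; what matters is that the base model's linear layers are the
# sequence-parallel variants before LoRA is applied (how that flag is set depends on the model).
config = AutoConfig.from_pretrained("facebook/llama-7b")
config.tensor_parallel_degree = 2
config.sequence_parallel = True
model = AutoModelForCausalLM.from_pretrained("facebook/llama-7b", config=config)

lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],  # placeholder patterns
    r=8,
    lora_alpha=16,
    lora_dropout=0.0,
    rslora=False,
    merge_weights=False,
)
model = LoRAModel(model, lora_config)
model.mark_only_lora_as_trainable()  # freezes everything except the lora_A / lora_B parameters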
