[gemini] fix param op hook when output is tuple (#5355)
* [gemini] fix param op hook when output is tuple

* [gemini] fix param op hook
ver217 authored Feb 4, 2024
1 parent 1c790c0 commit 2dd01e3
Showing 2 changed files with 8 additions and 5 deletions.
colossalai/tensor/colo_parameter.py (3 additions, 2 deletions)

@@ -7,11 +7,12 @@

 from .colo_tensor import _convert_output

-WHITE_LIST_FUNCS = {torch.Tensor.__getitem__, torch.Tensor.is_floating_point}
+WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}
+NO_HOOK_FUNCS = {torch.Tensor.is_floating_point}


 def is_no_hook_op(func) -> bool:
-    return func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS
+    return (func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS) or func in NO_HOOK_FUNCS


 def filter_colo_parameters(*args, **kwargs):
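A note on this change: in the old predicate, the whitelist was only consulted for dunder names, so torch.Tensor.is_floating_point (not a dunder) was still hooked even though it sat in WHITE_LIST_FUNCS. Moving it to the new NO_HOOK_FUNCS set skips it unconditionally. A minimal, runnable sketch of the new behavior (plain torch only; the module's other imports and helpers are omitted):

import torch

# Mirrors the diff above: dunders are skipped unless whitelisted, and
# NO_HOOK_FUNCS members are skipped regardless of their name.
WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}
NO_HOOK_FUNCS = {torch.Tensor.is_floating_point}

def is_no_hook_op(func) -> bool:
    return (func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS) or func in NO_HOOK_FUNCS

assert is_no_hook_op(torch.Tensor.is_floating_point)  # explicitly skipped now
assert is_no_hook_op(torch.Tensor.__repr__)           # dunder, skipped
assert not is_no_hook_op(torch.Tensor.__getitem__)    # whitelisted dunder, still hooked
assert not is_no_hook_op(torch.Tensor.add)            # ordinary method, still hooked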
colossalai/tensor/param_op_hook.py (5 additions, 3 deletions)

@@ -92,7 +92,10 @@ def pre_op(params: List[torch.Tensor], *args: Any) -> list:
     @staticmethod
     def post_op(params: List[torch.Tensor], arg: Any) -> Any:
         ColoParamOpHookManager._trigger_post_forward(params)
-        return PostFwdPreBwd.apply(params, arg)
+        # in case the output is a tuple, we have to flatten it
+        grad_args, other_args, grad_flags, spec = _flatten_grad_args(arg)
+        new_grad_args = PostFwdPreBwd.apply(params, *grad_args)
+        return _merge_args(new_grad_args, other_args, grad_flags, spec)

     @staticmethod
     def has_hook() -> bool:
@@ -113,7 +116,7 @@ def backward(ctx, *grads):

 class PostFwdPreBwd(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, params, args):
+    def forward(ctx, params, *args):
         ctx.params = params
         return args

@@ -142,7 +145,6 @@ def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]:
             grad_args.append(arg)
         else:
             other_args.append(arg)
-    assert len(grad_args) > 0
     return grad_args, other_args, grad_flags, spec


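Why the param_op_hook.py change takes this shape: torch.autograd.Function only tracks tensors passed as direct positional arguments (and returned as direct outputs); tensors nested inside a tuple argument are invisible to autograd, so handing a tuple output to PostFwdPreBwd.apply as a single argument broke the backward hook. The fix flattens the output, runs the Function over the grad-carrying tensors alone (hence forward(ctx, params, *args)), and reassembles the original structure. Below is a minimal sketch of that flatten/apply/merge round trip using torch.utils._pytree, as the TreeSpec annotation in the diff suggests; the helper bodies are illustrative reconstructions, not verbatim library code:

from typing import Any, List, Tuple

import torch
from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten


def _flatten_grad_args(args: Any) -> Tuple[list, list, List[bool], TreeSpec]:
    # Split a possibly nested output into grad-carrying tensors and everything
    # else, recording per-leaf flags so the structure can be rebuilt later.
    flat_args, spec = tree_flatten(args)
    grad_args, other_args, grad_flags = [], [], []
    for arg in flat_args:
        flag = isinstance(arg, torch.Tensor) and arg.requires_grad
        grad_flags.append(flag)
        (grad_args if flag else other_args).append(arg)
    # The commit drops the old `assert len(grad_args) > 0`: an op's output may
    # legitimately contain no grad-requiring tensor at all.
    return grad_args, other_args, grad_flags, spec


def _merge_args(grad_args: list, other_args: list, grad_flags: List[bool], spec: TreeSpec) -> Any:
    # Interleave the two lists back into the original nested structure.
    grad_iter, other_iter = iter(grad_args), iter(other_args)
    flat = [next(grad_iter) if flag else next(other_iter) for flag in grad_flags]
    return tree_unflatten(flat, spec)


# Round trip: a tuple output with a non-tensor element survives intact.
out = (torch.randn(2, requires_grad=True), 3, torch.zeros(1))
g, o, flags, spec = _flatten_grad_args(out)
merged = _merge_args(g, o, flags, spec)
assert all(a is b for a, b in zip(merged, out))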