Skip to content

Workaround to duplicate decorators #71

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion refactor/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,17 @@ def _replace_input(self, node: ast.AST) -> _LazyActionMixin[K, T]:

class _ReplaceCodeSegmentAction(BaseAction):
def apply(self, context: Context, source: str) -> str:
# The decorators are removed in the 'lines' but present in the 'context`
# This lead to the 'replacement' containing the decorators and the returned
# 'lines' to duplicate them. Proposed workaround is to add the decorators in
# the 'view', in case the '_resynthesize()' adds/modifies them
lines = split_lines(source, encoding=context.file_info.get_encoding())
(
lineno,
col_offset,
end_lineno,
end_col_offset,
) = self._get_segment_span(context)
) = self._get_decorated_segment_span(context)

view = slice(lineno - 1, end_lineno)
source_lines = lines[view]
Expand All @@ -102,6 +106,9 @@ def apply(self, context: Context, source: str) -> str:
def _get_segment_span(self, context: Context) -> PositionType:
raise NotImplementedError

def _get_decorated_segment_span(self, context: Context) -> PositionType:
raise NotImplementedError

def _resynthesize(self, context: Context) -> str:
raise NotImplementedError

Expand All @@ -121,6 +128,13 @@ class LazyReplace(_ReplaceCodeSegmentAction, _LazyActionMixin[ast.AST, ast.AST])
def _get_segment_span(self, context: Context) -> PositionType:
return position_for(self.node)

def _get_decorated_segment_span(self, context: Context) -> PositionType:
lineno, col_offset, end_lineno, end_col_offset = position_for(self.node)
# Add the decorators to the segment span to resolve an issue with def -> async def
if hasattr(self.node, "decorator_list") and len(getattr(self.node, "decorator_list")) > 0:
lineno, _, _, _ = position_for(getattr(self.node, "decorator_list")[0])
return lineno, col_offset, end_lineno, end_col_offset

def _resynthesize(self, context: Context) -> str:
return context.unparse(self.build())

Expand Down Expand Up @@ -228,6 +242,9 @@ class _Rename(Replace):
def _get_segment_span(self, context: Context) -> PositionType:
return self.identifier_span

def _get_decorated_segment_span(self, context: Context) -> PositionType:
return self.identifier_span

def _resynthesize(self, context: Context) -> str:
return self.target.name

Expand Down Expand Up @@ -260,6 +277,13 @@ def is_critical_node(self, context: Context) -> bool:
def _get_segment_span(self, context: Context) -> PositionType:
return position_for(self.node)

def _get_decorated_segment_span(self, context: Context) -> PositionType:
lineno, col_offset, end_lineno, end_col_offset = position_for(self.node)
# Add the decorators to the segment span to resolve an issue with def -> async def
if hasattr(self.node, "decorator_list") and len(getattr(self.node, "decorator_list")) > 0:
lineno, _, _, _ = position_for(getattr(self.node, "decorator_list")[0])
return lineno, col_offset, end_lineno, end_col_offset

def _resynthesize(self, context: Context) -> str:
if self.is_critical_node(context):
raise InvalidActionError(
Expand Down
18 changes: 18 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,24 @@ def func():
assert position_for(right_node) == (3, 23, 3, 25)


def test_get_positions_with_decorator():
source = textwrap.dedent(
"""\
@deco0
@deco1(arg0,
arg1)
def func():
if a > 5:
return 5 + 3 + 25
elif b > 10:
return 1 + 3 + 5 + 7
"""
)
tree = ast.parse(source)
right_node = tree.body[0].body[0].body[0].value.right
assert position_for(right_node) == (6, 23, 6, 25)


def test_singleton():
from dataclasses import dataclass

Expand Down
102 changes: 102 additions & 0 deletions tests/test_complete_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,107 @@ def match(self, node):
return AsyncifierAction(node)


class MakeFunctionAsyncWithDecorators(Rule):
INPUT_SOURCE = """
@deco0
@deco1(arg0,
arg1)
def something():
a += .1
'''you know
this is custom
literal
'''
print(we,
preserve,
everything
)
return (
right + "?")
"""

EXPECTED_SOURCE = """
@deco0
@deco1(arg0,
arg1)
async def something():
a += .1
'''you know
this is custom
literal
'''
print(we,
preserve,
everything
)
return (
right + "?")
"""

def match(self, node):
assert isinstance(node, ast.FunctionDef)
return AsyncifierAction(node)


class OnlyKeywordArgumentDefaultNotSetCheckRule(Rule):
context_providers = (context.Scope,)

INPUT_SOURCE = """
class Klass:
def method(self, *, a):
print()

lambda self, *, a: print

"""

EXPECTED_SOURCE = """
class Klass:
def method(self, *, a=None):
print()

lambda self, *, a=None: print

"""

def match(self, node: ast.AST) -> BaseAction | None:
assert isinstance(node, (ast.FunctionDef, ast.Lambda))
assert any(kw_default is None for kw_default in node.args.kw_defaults)

if isinstance(node, ast.Lambda) and not (
isinstance(node.body, ast.Name) and isinstance(node.body.ctx, ast.Load)
):
scope = self.context["scope"].resolve(node.body)
scope.definitions.get(node.body.id, [])

elif isinstance(node, ast.FunctionDef):
for stmt in node.body:
for identifier in ast.walk(stmt):
if not (
isinstance(identifier, ast.Name)
and isinstance(identifier.ctx, ast.Load)
):
continue

scope = self.context["scope"].resolve(identifier)
while not scope.definitions.get(identifier.id, []):
scope = scope.parent
if scope is None:
break

kw_defaults = []
for kw_default in node.args.kw_defaults:
if kw_default is None:
kw_defaults.append(ast.Constant(value=None))
else:
kw_defaults.append(kw_default)

target = deepcopy(node)
target.args.kw_defaults = kw_defaults

return Replace(node, target)


class OnlyKeywordArgumentDefaultNotSetCheckRule(Rule):
context_providers = (context.Scope,)

Expand Down Expand Up @@ -944,6 +1045,7 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]:
@pytest.mark.parametrize(
"rule",
[
MakeFunctionAsyncWithDecorators,
ReplaceNexts,
ReplacePlaceholders,
PropagateConstants,
Expand Down