Skip to content

Commit

Permalink
fix(typing): improve decorator type hinting
Browse files Browse the repository at this point in the history
The type hinting for the most commonly used decorators were incomplete,
resulting in decorated functions being obscured.

This makes use of the special type variable `ParamSpec` which allows the
type hinting a view into the parameters of a function. As ``ParamSpec`
was introduced in Python 3.10, `ParamSpec` is imported from the
`typing_extensions` module instead of the standard library.

I have also taken the opportunity to fix other instances of `Callable`
type hints missing their arguments.

Signed-off-by: JP-Ellis <josh@jpellis.me>
  • Loading branch information
JP-Ellis committed Nov 15, 2023
1 parent 5707669 commit 4508b98
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 29 deletions.
23 changes: 15 additions & 8 deletions src/pytest_bdd/plugin.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
"""Pytest plugin entry point. Used for any fixtures needed."""
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, cast
from typing import TYPE_CHECKING, Any, Callable, Generator, TypeVar, cast

import pytest
from typing_extensions import ParamSpec

from . import cucumber_json, generation, gherkin_terminal_reporter, given, reporting, then, when
from .utils import CONFIG_STACK

if TYPE_CHECKING:
from typing import Any, Generator

from _pytest.config import Config, PytestPluginManager
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
Expand All @@ -21,6 +20,10 @@
from .parser import Feature, Scenario, Step


P = ParamSpec("P")
T = TypeVar("T")


def pytest_addhooks(pluginmanager: PytestPluginManager) -> None:
"""Register plugin hooks."""
from pytest_bdd import hooks
Expand Down Expand Up @@ -93,7 +96,7 @@ def pytest_bdd_step_error(
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable,
step_func: Callable[..., Any],
step_func_args: dict,
exception: Exception,
) -> None:
Expand All @@ -102,7 +105,11 @@ def pytest_bdd_step_error(

@pytest.hookimpl(tryfirst=True)
def pytest_bdd_before_step(
request: FixtureRequest, feature: Feature, scenario: Scenario, step: Step, step_func: Callable
request: FixtureRequest,
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable[..., Any],
) -> None:
reporting.before_step(request, feature, scenario, step, step_func)

Expand All @@ -113,7 +120,7 @@ def pytest_bdd_after_step(
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable,
step_func: Callable[..., Any],
step_func_args: dict[str, Any],
) -> None:
reporting.after_step(request, feature, scenario, step, step_func, step_func_args)
Expand All @@ -123,7 +130,7 @@ def pytest_cmdline_main(config: Config) -> int | None:
return generation.cmdline_main(config)


def pytest_bdd_apply_tag(tag: str, function: Callable) -> Callable:
def pytest_bdd_apply_tag(tag: str, function: Callable[P, T]) -> Callable[P, T]:
mark = getattr(pytest.mark, tag)
marked = mark(function)
return cast(Callable, marked)
return cast(Callable[P, T], marked)
10 changes: 8 additions & 2 deletions src/pytest_bdd/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,21 @@ def step_error(
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable,
step_func: Callable[..., Any],
step_func_args: dict,
exception: Exception,
) -> None:
"""Finalize the step report as failed."""
request.node.__scenario_report__.fail()


def before_step(request: FixtureRequest, feature: Feature, scenario: Scenario, step: Step, step_func: Callable) -> None:
def before_step(
request: FixtureRequest,
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable[..., Any],
) -> None:
"""Store step start time."""
request.node.__scenario_report__.add_step_report(StepReport(step=step))

Expand Down
20 changes: 12 additions & 8 deletions src/pytest_bdd/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,25 @@
import logging
import os
import re
from typing import TYPE_CHECKING, Callable, Iterator, cast
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, TypeVar, cast

import pytest
from _pytest.fixtures import FixtureDef, FixtureManager, FixtureRequest, call_fixture_func
from _pytest.nodes import iterparentnodeids
from typing_extensions import ParamSpec

from . import exceptions
from .feature import get_feature, get_features
from .steps import StepFunctionContext, get_step_fixture_name, inject_fixture
from .utils import CONFIG_STACK, get_args, get_caller_module_locals, get_caller_module_path

if TYPE_CHECKING:
from typing import Any, Iterable

from _pytest.mark.structures import ParameterSet

from .parser import Feature, Scenario, ScenarioTemplate, Step

P = ParamSpec("P")
T = TypeVar("T")

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -197,14 +198,14 @@ def _execute_scenario(feature: Feature, scenario: Scenario, request: FixtureRequ

def _get_scenario_decorator(
feature: Feature, feature_name: str, templated_scenario: ScenarioTemplate, scenario_name: str
) -> Callable[[Callable], Callable]:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
# HACK: Ideally we would use `def decorator(fn)`, but we want to return a custom exception
# when the decorator is misused.
# Pytest inspect the signature to determine the required fixtures, and in that case it would look
# for a fixture called "fn" that doesn't exist (if it exists then it's even worse).
# It will error with a "fixture 'fn' not found" message instead.
# We can avoid this hack by using a pytest hook and check for misuse instead.
def decorator(*args: Callable) -> Callable:
def decorator(*args: Callable[P, T]) -> Callable[P, T]:
if not args:
raise exceptions.ScenarioIsDecoratorOnly(
"scenario function can only be used as a decorator. Refer to the documentation."
Expand Down Expand Up @@ -236,7 +237,7 @@ def scenario_wrapper(request: FixtureRequest, _pytest_bdd_example: dict[str, str

scenario_wrapper.__doc__ = f"{feature_name}: {scenario_name}"
scenario_wrapper.__scenario__ = templated_scenario
return cast(Callable, scenario_wrapper)
return cast(Callable[P, T], scenario_wrapper)

return decorator

Expand All @@ -254,8 +255,11 @@ def collect_example_parametrizations(


def scenario(
feature_name: str, scenario_name: str, encoding: str = "utf-8", features_base_dir=None
) -> Callable[[Callable], Callable]:
feature_name: str,
scenario_name: str,
encoding: str = "utf-8",
features_base_dir: str | None = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Scenario decorator.
:param str feature_name: Feature file name. Absolute or relative to the configured feature base path.
Expand Down
22 changes: 12 additions & 10 deletions src/pytest_bdd/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,15 @@ def _(article):

import pytest
from _pytest.fixtures import FixtureDef, FixtureRequest
from typing_extensions import ParamSpec

from .parser import Step
from .parsers import StepParser, get_parser
from .types import GIVEN, THEN, WHEN
from .utils import get_caller_module_locals

TCallable = TypeVar("TCallable", bound=Callable[..., Any])
P = ParamSpec("P")
T = TypeVar("T")


@enum.unique
Expand All @@ -74,10 +76,10 @@ def get_step_fixture_name(step: Step) -> str:

def given(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
converters: dict[str, Callable[[Any], Any]] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Given step decorator.
:param name: Step name or a parser object.
Expand All @@ -93,10 +95,10 @@ def given(

def when(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
converters: dict[str, Callable[[Any], Any]] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""When step decorator.
:param name: Step name or a parser object.
Expand All @@ -112,10 +114,10 @@ def when(

def then(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
converters: dict[str, Callable[[Any], Any]] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Then step decorator.
:param name: Step name or a parser object.
Expand All @@ -132,10 +134,10 @@ def then(
def step(
name: str | StepParser,
type_: Literal["given", "when", "then"] | None = None,
converters: dict[str, Callable] | None = None,
converters: dict[str, Callable[[Any], Any]] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable[[TCallable], TCallable]:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Generic step decorator.
:param name: Step name as in the feature file.
Expand All @@ -155,7 +157,7 @@ def step(
if converters is None:
converters = {}

def decorator(func: TCallable) -> TCallable:
def decorator(func: Callable[P, T]) -> Callable[P, T]:
parser = get_parser(name)

context = StepFunctionContext(
Expand Down
2 changes: 1 addition & 1 deletion src/pytest_bdd/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
CONFIG_STACK: list[Config] = []


def get_args(func: Callable) -> list[str]:
def get_args(func: Callable[..., Any]) -> list[str]:
"""Get a list of argument names for a function.
:param func: The function to inspect.
Expand Down

0 comments on commit 4508b98

Please sign in to comment.