From 5c6f3670df90f460d6a183cb9495dff150656412 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sat, 21 Dec 2024 14:27:33 +0100 Subject: [PATCH] feat: drop support for Python 3.8 (#3730) --- .github/workflows/test.yml | 1 - .pre-commit-config.yaml | 2 +- RELEASE.md | 7 ++ federation-compatibility/schema.py | 6 +- noxfile.py | 6 +- poetry.lock | 88 +-------------- pyproject.toml | 13 +-- strawberry/__init__.py | 18 +-- strawberry/aiohttp/test/client.py | 18 +-- strawberry/aiohttp/views.py | 12 +- strawberry/annotation.py | 27 ++--- strawberry/asgi/__init__.py | 9 +- strawberry/asgi/test/client.py | 17 +-- strawberry/chalice/views.py | 6 +- strawberry/channels/__init__.py | 2 +- strawberry/channels/handlers/base.py | 10 +- strawberry/channels/handlers/http_handler.py | 11 +- strawberry/channels/handlers/ws_handler.py | 7 +- strawberry/channels/testing.py | 14 +-- strawberry/cli/commands/codegen.py | 18 +-- strawberry/cli/commands/upgrade/__init__.py | 5 +- .../cli/commands/upgrade/_run_codemod.py | 12 +- strawberry/codegen/exceptions.py | 4 +- strawberry/codegen/plugins/print_operation.py | 12 +- strawberry/codegen/plugins/python.py | 12 +- strawberry/codegen/plugins/typescript.py | 6 +- strawberry/codegen/query_codegen.py | 63 +++++------ strawberry/codegen/types.py | 69 ++++++------ strawberry/codemods/annotated_unions.py | 7 +- strawberry/dataloader.py | 33 +++--- strawberry/directive.py | 17 ++- strawberry/django/test/client.py | 8 +- strawberry/django/views.py | 9 +- strawberry/exceptions/__init__.py | 48 ++++---- .../exceptions/conflicting_arguments.py | 4 +- strawberry/exceptions/duplicated_type_name.py | 6 +- strawberry/exceptions/handler.py | 14 +-- strawberry/exceptions/invalid_union_type.py | 4 +- .../missing_arguments_annotations.py | 4 +- .../exceptions/missing_field_annotation.py | 4 +- .../exceptions/object_is_not_an_enum.py | 4 +- .../exceptions/private_strawberry_field.py | 4 +- strawberry/exceptions/syntax.py | 8 +- strawberry/exceptions/utils/source_finder.py | 13 ++- strawberry/experimental/pydantic/__init__.py | 6 +- strawberry/experimental/pydantic/_compat.py | 28 ++--- .../experimental/pydantic/conversion.py | 4 +- .../experimental/pydantic/conversion_types.py | 6 +- .../experimental/pydantic/error_type.py | 34 +++--- .../experimental/pydantic/exceptions.py | 10 +- strawberry/experimental/pydantic/fields.py | 15 +-- .../experimental/pydantic/object_type.py | 42 ++++--- strawberry/experimental/pydantic/utils.py | 16 +-- strawberry/ext/dataclasses/dataclasses.py | 6 +- strawberry/ext/mypy_plugin.py | 15 +-- strawberry/extensions/__init__.py | 15 ++- strawberry/extensions/add_validation_rules.py | 8 +- strawberry/extensions/base_extension.py | 8 +- strawberry/extensions/context.py | 29 ++--- strawberry/extensions/directives.py | 4 +- strawberry/extensions/disable_validation.py | 2 +- strawberry/extensions/field_extension.py | 3 +- strawberry/extensions/mask_errors.py | 5 +- strawberry/extensions/max_aliases.py | 4 +- strawberry/extensions/max_tokens.py | 2 +- strawberry/extensions/parser_cache.py | 3 +- strawberry/extensions/pyinstrument.py | 5 +- strawberry/extensions/query_depth_limiter.py | 26 ++--- strawberry/extensions/runner.py | 14 +-- strawberry/extensions/tracing/apollo.py | 20 ++-- strawberry/extensions/tracing/datadog.py | 4 +- .../extensions/tracing/opentelemetry.py | 17 ++- strawberry/extensions/utils.py | 6 +- strawberry/extensions/validation_cache.py | 3 +- strawberry/fastapi/context.py | 6 +- strawberry/fastapi/router.py | 23 ++-- strawberry/federation/__init__.py | 8 +- strawberry/federation/argument.py | 3 +- strawberry/federation/enum.py | 16 +-- strawberry/federation/field.py | 53 +++++---- strawberry/federation/object_type.py | 50 ++++----- strawberry/federation/scalar.py | 15 ++- strawberry/federation/schema.py | 66 +++++------ strawberry/federation/schema_directive.py | 10 +- strawberry/federation/schema_directives.py | 28 ++--- strawberry/federation/union.py | 5 +- strawberry/field_extensions/input_mutation.py | 3 +- strawberry/file_uploads/utils.py | 7 +- strawberry/flask/views.py | 5 +- strawberry/http/__init__.py | 12 +- strawberry/http/async_base_view.py | 23 ++-- strawberry/http/base.py | 9 +- strawberry/http/ides.py | 2 +- strawberry/http/parse_content_type.py | 3 +- strawberry/http/sync_base_view.py | 8 +- strawberry/http/temporal_response.py | 3 +- strawberry/http/types.py | 5 +- strawberry/litestar/controller.py | 22 ++-- strawberry/parent.py | 3 +- strawberry/permission.py | 14 +-- strawberry/printer/ast_from_value.py | 3 +- strawberry/printer/printer.py | 80 ++++++++----- strawberry/quart/views.py | 6 +- strawberry/relay/exceptions.py | 8 +- strawberry/relay/fields.py | 46 ++++---- strawberry/relay/types.py | 56 +++++----- strawberry/relay/utils.py | 8 +- strawberry/sanic/utils.py | 8 +- strawberry/sanic/views.py | 12 +- strawberry/scalars.py | 4 +- strawberry/schema/base.py | 27 +++-- strawberry/schema/compat.py | 8 +- strawberry/schema/execute.py | 16 +-- strawberry/schema/name_converter.py | 6 +- strawberry/schema/schema.py | 62 ++++++----- strawberry/schema/schema_converter.py | 46 ++++---- strawberry/schema/subscribe.py | 7 +- strawberry/schema/types/base_scalars.py | 2 +- strawberry/schema/types/concrete_type.py | 4 +- strawberry/schema/types/scalar.py | 7 +- strawberry/schema_codegen/__init__.py | 8 +- strawberry/schema_directive.py | 16 +-- .../graphql_transport_ws/handlers.py | 17 ++- .../protocols/graphql_transport_ws/types.py | 32 +++--- .../protocols/graphql_ws/handlers.py | 11 +- .../protocols/graphql_ws/types.py | 26 ++--- strawberry/test/__init__.py | 2 +- strawberry/test/client.py | 40 +++---- strawberry/tools/create_type.py | 7 +- strawberry/tools/merge_types.py | 3 +- strawberry/types/__init__.py | 2 +- strawberry/types/arguments.py | 22 ++-- strawberry/types/auto.py | 4 +- strawberry/types/base.py | 38 +++---- strawberry/types/enum.py | 8 +- strawberry/types/execution.py | 20 ++-- strawberry/types/field.py | 57 +++++----- strawberry/types/fields/resolver.py | 32 +++--- strawberry/types/graphql.py | 4 +- strawberry/types/info.py | 14 +-- strawberry/types/lazy_type.py | 8 +- strawberry/types/mutation.py | 53 +++++---- strawberry/types/nodes.py | 20 ++-- strawberry/types/object_type.py | 30 +++-- strawberry/types/private.py | 3 +- strawberry/types/scalar.py | 4 +- strawberry/types/type_resolver.py | 10 +- strawberry/types/union.py | 19 ++-- strawberry/types/unset.py | 6 +- strawberry/utils/aio.py | 11 +- strawberry/utils/await_maybe.py | 5 +- strawberry/utils/debug.py | 4 +- strawberry/utils/deprecations.py | 4 +- strawberry/utils/inspect.py | 8 +- strawberry/utils/str_converters.py | 2 +- strawberry/utils/typing.py | 105 +++++++----------- tests/a.py | 3 +- tests/b.py | 5 +- tests/benchmarks/api.py | 6 +- tests/benchmarks/schema.py | 11 +- tests/benchmarks/test_execute.py | 12 +- .../test_execute_with_extensions.py | 6 +- tests/benchmarks/test_generic_input.py | 10 +- tests/benchmarks/test_subscriptions.py | 2 +- tests/chalice/app.py | 4 +- tests/channels/test_layers.py | 3 +- tests/channels/test_testing.py | 3 +- tests/cli/test_codegen.py | 9 +- tests/codegen/conftest.py | 33 +++--- .../codegen/snapshots/python/generic_types.py | 4 +- .../snapshots/python/multiple_types.py | 2 +- .../snapshots/python/mutation_with_object.py | 4 +- .../python/nullable_list_of_non_scalars.py | 2 +- .../snapshots/python/optional_and_lists.py | 6 +- tests/codegen/snapshots/python/variables.py | 10 +- tests/codegen/test_query_codegen.py | 3 +- tests/conftest.py | 8 +- tests/d.py | 5 +- tests/django/test_dataloaders.py | 7 +- .../pydantic/schema/test_basic.py | 12 +- .../pydantic/schema/test_federation.py | 2 +- .../pydantic/schema/test_mutation.py | 4 +- tests/experimental/pydantic/test_basic.py | 19 ++-- .../experimental/pydantic/test_conversion.py | 49 ++++---- .../experimental/pydantic/test_error_type.py | 16 +-- tests/experimental/pydantic/test_fields.py | 5 +- tests/fastapi/app.py | 4 +- tests/fastapi/test_context.py | 4 +- .../federation/printer/test_authenticated.py | 5 +- tests/federation/printer/test_entities.py | 9 +- tests/federation/printer/test_inaccessible.py | 7 +- tests/federation/printer/test_interface.py | 3 +- tests/federation/printer/test_keys.py | 9 +- tests/federation/printer/test_override.py | 5 +- tests/federation/printer/test_policy.py | 5 +- tests/federation/printer/test_provides.py | 9 +- tests/federation/printer/test_requires.py | 5 +- .../printer/test_requires_scopes.py | 5 +- tests/federation/printer/test_shareable.py | 3 +- tests/federation/printer/test_tag.py | 5 +- tests/federation/test_entities.py | 16 +-- tests/federation/test_schema.py | 20 ++-- tests/fields/test_arguments.py | 50 ++++----- tests/fields/test_field_defaults.py | 4 +- tests/fields/test_resolvers.py | 4 +- tests/http/clients/aiohttp.py | 23 ++-- tests/http/clients/asgi.py | 23 ++-- tests/http/clients/async_flask.py | 4 +- tests/http/clients/base.py | 64 +++++------ tests/http/clients/chalice.py | 16 +-- tests/http/clients/channels.py | 25 +++-- tests/http/clients/django.py | 22 ++-- tests/http/clients/fastapi.py | 19 ++-- tests/http/clients/flask.py | 20 ++-- tests/http/clients/litestar.py | 19 ++-- tests/http/clients/quart.py | 16 +-- tests/http/clients/sanic.py | 14 +-- tests/http/conftest.py | 7 +- tests/http/context.py | 5 +- tests/http/test_graphql_ide.py | 12 +- tests/http/test_multipart_subscription.py | 3 +- tests/http/test_parse_content_type.py | 4 +- tests/http/test_upload.py | 5 +- tests/litestar/test_context.py | 4 +- .../objects/generics/test_generic_objects.py | 23 ++-- tests/objects/generics/test_names.py | 7 +- tests/plugins/strawberry_exceptions.py | 4 +- tests/python_312/test_generic_objects.py | 23 ++-- tests/python_312/test_generics_schema.py | 26 ++--- tests/relay/schema.py | 24 ++-- tests/relay/schema_future_annotations.py | 24 ++-- tests/relay/test_connection.py | 13 ++- tests/relay/test_exceptions.py | 14 +-- tests/relay/test_fields.py | 7 +- tests/relay/test_schema.py | 7 +- tests/relay/test_types.py | 3 +- .../extensions/schema_extensions/conftest.py | 9 +- .../schema_extensions/test_extensions.py | 8 +- .../schema_extensions/test_subscription.py | 8 +- tests/schema/extensions/test_datadog.py | 7 +- .../extensions/test_field_extensions.py | 3 +- .../schema/extensions/test_input_mutation.py | 2 +- .../test_input_mutation_federation.py | 2 +- .../extensions/test_query_depth_limiter.py | 8 +- tests/schema/test_annotated/type_a.py | 3 +- tests/schema/test_annotated/type_b.py | 3 +- tests/schema/test_arguments.py | 3 +- tests/schema/test_dataloaders.py | 3 +- tests/schema/test_directives.py | 10 +- tests/schema/test_enum.py | 9 +- tests/schema/test_extensions.py | 5 +- tests/schema/test_generics.py | 30 ++--- tests/schema/test_generics_nested.py | 80 ++++++------- tests/schema/test_info.py | 7 +- tests/schema/test_interface.py | 6 +- tests/schema/test_lazy/test_lazy_generic.py | 8 +- tests/schema/test_lazy/type_a.py | 5 +- tests/schema/test_lazy/type_b.py | 11 +- tests/schema/test_lazy/type_c.py | 3 +- tests/schema/test_lazy/type_d.py | 3 +- tests/schema/test_lazy_types/type_a.py | 4 +- tests/schema/test_list.py | 10 +- tests/schema/test_name_converter.py | 4 +- tests/schema/test_permission.py | 6 +- tests/schema/test_resolvers.py | 16 +-- tests/schema/test_schema_generation.py | 6 +- tests/schema/test_schema_hooks.py | 3 +- tests/schema/test_subscription.py | 21 +--- tests/schema/test_union.py | 15 +-- tests/test/conftest.py | 3 +- tests/test_auto.py | 4 +- tests/test_dataloaders.py | 29 ++--- tests/test_forward_references.py | 44 ++------ tests/test_printer/test_basic.py | 10 +- tests/test_printer/test_schema_directives.py | 11 +- tests/typecheckers/test_relay.py | 91 ++++++++------- tests/typecheckers/utils/mypy.py | 8 +- tests/typecheckers/utils/pyright.py | 6 +- tests/types/cross_module_resolvers/a_mod.py | 4 +- tests/types/cross_module_resolvers/b_mod.py | 4 +- tests/types/cross_module_resolvers/c_mod.py | 28 +++-- .../test_cross_module_resolvers.py | 58 +++++----- tests/types/cross_module_resolvers/x_mod.py | 5 +- tests/types/resolving/test_generics.py | 6 +- tests/types/resolving/test_lists.py | 69 ++++-------- tests/types/resolving/test_optionals.py | 10 +- .../resolving/test_string_annotations.py | 22 ++-- tests/types/resolving/test_unions.py | 3 +- tests/types/test_argument_types.py | 6 +- tests/types/test_convert_to_dictionary.py | 6 +- tests/types/test_field_types.py | 6 +- tests/types/test_lazy_types.py | 4 +- .../test_lazy_types_future_annotations.py | 2 +- tests/types/test_object_types.py | 11 +- tests/types/test_resolver_types.py | 6 +- tests/utils/test_arguments_converter.py | 15 ++- tests/utils/test_typing.py | 41 +------ tests/utils/test_typing_forward_refs.py | 17 +-- tests/views/schema.py | 13 ++- tests/websockets/conftest.py | 7 +- tests/websockets/test_graphql_transport_ws.py | 11 +- tests/websockets/test_graphql_ws.py | 3 +- tests/websockets/test_websockets.py | 10 +- tests/websockets/views.py | 6 +- 304 files changed, 1982 insertions(+), 2330 deletions(-) create mode 100644 RELEASE.md diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bf9c4fe53b..6f544df95a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -57,7 +57,6 @@ jobs: - uses: actions/setup-python@v5 with: python-version: | - 3.8 3.9 3.10 3.11 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6c4163edbd..e08ba53530 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.8.3 hooks: - id: ruff-format exclude: ^tests/\w+/snapshots/ diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..6388bb6e73 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,7 @@ +Release type: minor + +This release drops support for Python 3.8, which reached its end-of-life (EOL) +in October 2024. The minimum supported Python version is now 3.9. + +We strongly recommend upgrading to Python 3.9 or a newer version, as older +versions are no longer maintained and may contain security vulnerabilities. diff --git a/federation-compatibility/schema.py b/federation-compatibility/schema.py index b3d5c5bafa..4952977ee3 100644 --- a/federation-compatibility/schema.py +++ b/federation-compatibility/schema.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional +from typing import Any, Optional import strawberry from strawberry.schema_directive import Location @@ -252,7 +252,7 @@ def created_by(self) -> Optional[User]: return User(**user) notes: Optional[str] = strawberry.federation.field(tags=["internal"]) - research: List[ProductResearch] + research: list[ProductResearch] @classmethod def from_data(cls, data: dict) -> "Product": @@ -290,7 +290,7 @@ def resolve_reference(cls, **data: Any) -> Optional["Product"]: @strawberry.federation.interface_object(keys=["id"]) class Inventory: id: strawberry.ID - deprecated_products: List[DeprecatedProduct] + deprecated_products: list[DeprecatedProduct] @classmethod def resolve_reference(cls, id: strawberry.ID) -> "Inventory": diff --git a/noxfile.py b/noxfile.py index cf9d3d450f..699077dd63 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,5 +1,5 @@ import itertools -from typing import Any, Callable, List +from typing import Any, Callable import nox from nox_poetry import Session, session @@ -8,7 +8,7 @@ nox.options.error_on_external_run = True nox.options.default_venv_backend = "uv" -PYTHON_VERSIONS = ["3.13", "3.12", "3.11", "3.10", "3.9", "3.8"] +PYTHON_VERSIONS = ["3.13", "3.12", "3.11", "3.10", "3.9"] GQL_CORE_VERSIONS = [ "3.2.3", @@ -54,7 +54,7 @@ def _install_gql_core(session: Session, version: str) -> None: ) -def with_gql_core_parametrize(name: str, params: List[str]) -> Callable[[Any], Any]: +def with_gql_core_parametrize(name: str, params: list[str]) -> Callable[[Any], Any]: # github cache doesn't support comma in the name, this is a workaround. arg_names = f"{name}, gql_core" combinations = list(itertools.product(params, GQL_CORE_VERSIONS)) diff --git a/poetry.lock b/poetry.lock index 9265195a9b..5428088851 100644 --- a/poetry.lock +++ b/poetry.lock @@ -159,9 +159,6 @@ files = [ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] -[package.dependencies] -typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} - [[package]] name = "ansicon" version = "1.89.0" @@ -241,21 +238,6 @@ files = [ astroid = ["astroid (>=2,<4)"] test = ["astroid (>=2,<4)", "pytest", "pytest-cov", "pytest-xdist"] -[[package]] -name = "astunparse" -version = "1.6.3" -description = "An AST unparser for Python" -optional = false -python-versions = "*" -files = [ - {file = "astunparse-1.6.3-py2.py3-none-any.whl", hash = "sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8"}, - {file = "astunparse-1.6.3.tar.gz", hash = "sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872"}, -] - -[package.dependencies] -six = ">=1.6.1,<2.0" -wheel = ">=0.23.0,<1.0" - [[package]] name = "async-timeout" version = "5.0.1" @@ -331,34 +313,6 @@ typing-extensions = {version = "*", markers = "python_version < \"3.10\""} [package.extras] visualize = ["Twisted (>=16.1.1)", "graphviz (>0.5.1)"] -[[package]] -name = "backports-zoneinfo" -version = "0.2.1" -description = "Backport of the standard library zoneinfo module" -optional = false -python-versions = ">=3.6" -files = [ - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:da6013fd84a690242c310d77ddb8441a559e9cb3d3d59ebac9aca1a57b2e18bc"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:89a48c0d158a3cc3f654da4c2de1ceba85263fafb861b98b59040a5086259722"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:1c5742112073a563c81f786e77514969acb58649bcdf6cdf0b4ed31a348d4546"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-win32.whl", hash = "sha256:e8236383a20872c0cdf5a62b554b27538db7fa1bbec52429d8d106effbaeca08"}, - {file = "backports.zoneinfo-0.2.1-cp36-cp36m-win_amd64.whl", hash = "sha256:8439c030a11780786a2002261569bdf362264f605dfa4d65090b64b05c9f79a7"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:f04e857b59d9d1ccc39ce2da1021d196e47234873820cbeaad210724b1ee28ac"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:17746bd546106fa389c51dbea67c8b7c8f0d14b5526a579ca6ccf5ed72c526cf"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5c144945a7752ca544b4b78c8c41544cdfaf9786f25fe5ffb10e838e19a27570"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-win32.whl", hash = "sha256:e55b384612d93be96506932a786bbcde5a2db7a9e6a4bb4bffe8b733f5b9036b"}, - {file = "backports.zoneinfo-0.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a76b38c52400b762e48131494ba26be363491ac4f9a04c1b7e92483d169f6582"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:8961c0f32cd0336fb8e8ead11a1f8cd99ec07145ec2931122faaac1c8f7fd987"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:e81b76cace8eda1fca50e345242ba977f9be6ae3945af8d46326d776b4cf78d1"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7b0a64cda4145548fed9efc10322770f929b944ce5cee6c0dfe0c87bf4c0c8c9"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-win32.whl", hash = "sha256:1b13e654a55cd45672cb54ed12148cd33628f672548f373963b0bff67b217328"}, - {file = "backports.zoneinfo-0.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:4a0f800587060bf8880f954dbef70de6c11bbe59c673c3d818921f042f9954a6"}, - {file = "backports.zoneinfo-0.2.1.tar.gz", hash = "sha256:fadbfe37f74051d024037f223b8e001611eac868b5c5b06144ef4d8b799862f2"}, -] - -[package.extras] -tzdata = ["tzdata"] - [[package]] name = "black" version = "24.8.0" @@ -1025,7 +979,6 @@ files = [ [package.dependencies] asgiref = ">=3.6.0,<4" -"backports.zoneinfo" = {version = "*", markers = "python_version < \"3.9\""} sqlparse = ">=0.3.1" tzdata = {version = "*", markers = "sys_platform == \"win32\""} @@ -1400,17 +1353,6 @@ files = [ {file = "frozenlist-1.5.0.tar.gz", hash = "sha256:81d5af29e61b9c8348e876d442253723928dce6433e0e76cd925cd83f1b4b817"}, ] -[[package]] -name = "graphlib-backport" -version = "1.1.0" -description = "Backport of the Python 3.9 graphlib module for Python 3.6+" -optional = false -python-versions = ">=3.6,<4.0" -files = [ - {file = "graphlib_backport-1.1.0-py3-none-any.whl", hash = "sha256:eccacf9f2126cdf89ce32a6018c88e1ecd3e4898a07568add6e1907a439055ba"}, - {file = "graphlib_backport-1.1.0.tar.gz", hash = "sha256:00a7888b21e5393064a133209cb5d3b3ef0a2096cf023914c9d778dff5644125"}, -] - [[package]] name = "graphql-core" version = "3.2.5" @@ -1662,28 +1604,6 @@ perf = ["ipython"] test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-perf (>=0.9.2)"] type = ["pytest-mypy"] -[[package]] -name = "importlib-resources" -version = "6.4.5" -description = "Read resources from Python packages" -optional = false -python-versions = ">=3.8" -files = [ - {file = "importlib_resources-6.4.5-py3-none-any.whl", hash = "sha256:ac29d5f956f01d5e4bb63102a5a19957f1b9175e45649977264a1416783bb717"}, - {file = "importlib_resources-6.4.5.tar.gz", hash = "sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065"}, -] - -[package.dependencies] -zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} - -[package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] -cover = ["pytest-cov"] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -enabler = ["pytest-enabler (>=2.2)"] -test = ["jaraco.test (>=5.4)", "pytest (>=6,!=8.1.*)", "zipp (>=3.17)"] -type = ["pytest-mypy"] - [[package]] name = "incremental" version = "24.7.2" @@ -1859,7 +1779,6 @@ files = [ [package.dependencies] importlib-metadata = {version = ">=4.11.4", markers = "python_version < \"3.12\""} -importlib-resources = {version = "*", markers = "python_version < \"3.9\""} "jaraco.classes" = "*" jeepney = {version = ">=0.4.2", markers = "sys_platform == \"linux\""} pywin32-ctypes = {version = ">=0.2.0", markers = "sys_platform == \"win32\""} @@ -1935,7 +1854,6 @@ click = "*" exceptiongroup = {version = "*", markers = "python_version < \"3.11\""} httpx = ">=0.22" importlib-metadata = {version = "*", markers = "python_version < \"3.10\""} -importlib-resources = {version = ">=5.12.0", markers = "python_version < \"3.9\""} litestar-htmx = ">=0.3.0" msgspec = ">=0.18.2" multidict = ">=6.0.2" @@ -4840,7 +4758,7 @@ aiohttp = ["aiohttp"] asgi = ["python-multipart", "starlette"] chalice = ["chalice"] channels = ["asgiref", "channels"] -cli = ["graphlib_backport", "libcst", "pygments", "rich", "typer"] +cli = ["libcst", "pygments", "rich", "typer"] debug = ["libcst", "rich"] debug-server = ["libcst", "pygments", "python-multipart", "rich", "starlette", "typer", "uvicorn"] django = ["Django", "asgiref"] @@ -4855,5 +4773,5 @@ sanic = ["sanic"] [metadata] lock-version = "2.0" -python-versions = "^3.8" -content-hash = "b79e8f6ac8156a509e46956feec5bb4e8bd369938b7fba0d7c6f0aa58a9cb14a" +python-versions = "^3.9" +content-hash = "7ff91376b60e5b89d4a0e2016507c73c73d1063ba9a3dff6f1400473a820bab9" diff --git a/pyproject.toml b/pyproject.toml index 0fa895578c..af146e54b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ requires = ["poetry-core>=1.6"] build-backend = "poetry.core.masonry.api" [tool.poetry.dependencies] -python = "^3.8" +python = "^3.9" graphql-core = ">=3.2.0,<3.4.0" typing-extensions = ">=4.5.0" python-dateutil = "^2.7.0" @@ -54,13 +54,11 @@ python-multipart = {version = ">=0.0.7", optional = true} sanic = {version = ">=20.12.2", optional = true} aiohttp = {version = "^3.7.4.post0", optional = true} fastapi = {version = ">=0.65.2", optional = true} -litestar = {version = ">=2", optional = true, python = ">=3.8"} +litestar = {version = ">=2", optional = true} channels = {version = ">=3.0.5", optional = true} -astunparse = {version = "^1.6.3", python = "<3.9"} libcst = {version = ">=0.4.7", optional = true} rich = {version = ">=12.0.0", optional = true} pyinstrument = {version = ">=4.0.0", optional = true} -graphlib_backport = {version = "*", python = "<3.9", optional = true} [tool.poetry.group.dev.dependencies] asgiref = "^3.2" @@ -195,12 +193,12 @@ append-github-contributor = true exclude = ["**/__pycache__",] reportMissingImports = true reportMissingTypeStubs = false -pythonVersion = "3.8" +pythonVersion = "3.9" stubPath = "" [tool.ruff] line-length = 88 -target-version = "py38" +target-version = "py39" fix = true exclude = [ ".bzr", @@ -240,9 +238,7 @@ ignore = [ "S102", "S104", "S324", - "ANN101", # missing annotation for self? # definitely enable these, maybe not in tests - "ANN102", "ANN401", "PGH003", "PGH004", @@ -287,7 +283,6 @@ ignore = [ # enable these, we have some in tests "B006", - "PT004", "PT007", "PT011", "PT012", diff --git a/strawberry/__init__.py b/strawberry/__init__.py index 2cbda32b2f..d03f15d998 100644 --- a/strawberry/__init__.py +++ b/strawberry/__init__.py @@ -18,39 +18,39 @@ from .types.info import Info from .types.lazy_type import LazyType, lazy from .types.mutation import mutation, subscription -from .types.object_type import asdict, input, interface, type +from .types.object_type import asdict, input, interface, type # noqa: A004 from .types.private import Private from .types.scalar import scalar from .types.union import union from .types.unset import UNSET __all__ = [ - "BasePermission", - "experimental", "ID", - "Info", "UNSET", - "lazy", + "BasePermission", + "Info", "LazyType", "Parent", "Private", "Schema", "argument", + "asdict", + "auto", "directive", "directive_field", - "schema_directive", "enum", "enum_value", + "experimental", "federation", "field", "input", "interface", + "lazy", "mutation", + "relay", "scalar", + "schema_directive", "subscription", "type", "union", - "auto", - "asdict", - "relay", ] diff --git a/strawberry/aiohttp/test/client.py b/strawberry/aiohttp/test/client.py index 0d25f4043a..86b08bf341 100644 --- a/strawberry/aiohttp/test/client.py +++ b/strawberry/aiohttp/test/client.py @@ -2,23 +2,25 @@ import warnings from typing import ( + TYPE_CHECKING, Any, - Dict, - Mapping, Optional, ) from strawberry.test.client import BaseGraphQLTestClient, Response +if TYPE_CHECKING: + from collections.abc import Mapping + class GraphQLTestClient(BaseGraphQLTestClient): async def query( self, query: str, - variables: Optional[Dict[str, Mapping]] = None, - headers: Optional[Dict[str, object]] = None, + variables: Optional[dict[str, Mapping]] = None, + headers: Optional[dict[str, object]] = None, asserts_errors: Optional[bool] = None, - files: Optional[Dict[str, object]] = None, + files: Optional[dict[str, object]] = None, assert_no_errors: Optional[bool] = True, ) -> Response: body = self._build_body(query, variables, files) @@ -51,9 +53,9 @@ async def query( async def request( self, - body: Dict[str, object], - headers: Optional[Dict[str, object]] = None, - files: Optional[Dict[str, object]] = None, + body: dict[str, object], + headers: Optional[dict[str, object]] = None, + files: Optional[dict[str, object]] = None, ) -> Any: response = await self._client.post( self.url, diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index aa07e34d8b..443264ecd4 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -8,11 +8,7 @@ from typing import ( TYPE_CHECKING, Any, - AsyncGenerator, Callable, - Dict, - Iterable, - Mapping, Optional, Union, cast, @@ -40,6 +36,8 @@ from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Iterable, Mapping + from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.schema import BaseSchema @@ -67,8 +65,8 @@ def headers(self) -> Mapping[str, str]: async def get_form_data(self) -> FormData: reader = await self.request.multipart() - data: Dict[str, Any] = {} - files: Dict[str, Any] = {} + data: dict[str, Any] = {} + files: dict[str, Any] = {} async for field in reader: assert isinstance(field, BodyPartReader) @@ -224,7 +222,7 @@ async def create_streaming_response( request: web.Request, stream: Callable[[], AsyncGenerator[str, None]], sub_response: web.Response, - headers: Dict[str, str], + headers: dict[str, str], ) -> web.StreamResponse: response = web.StreamResponse( status=sub_response.status, diff --git a/strawberry/annotation.py b/strawberry/annotation.py index 8934f5bad8..d3fb65deba 100644 --- a/strawberry/annotation.py +++ b/strawberry/annotation.py @@ -7,18 +7,15 @@ from enum import Enum from typing import ( TYPE_CHECKING, + Annotated, Any, - Dict, ForwardRef, - List, Optional, - Tuple, - Type, TypeVar, Union, cast, ) -from typing_extensions import Annotated, Self, get_args, get_origin +from typing_extensions import Self, get_args, get_origin from strawberry.types.base import ( StrawberryList, @@ -54,13 +51,13 @@ class StrawberryAnnotation: - __slots__ = "raw_annotation", "namespace", "__resolve_cache__" + __slots__ = "__resolve_cache__", "namespace", "raw_annotation" def __init__( self, annotation: Union[object, str], *, - namespace: Optional[Dict[str, Any]] = None, + namespace: Optional[dict[str, Any]] = None, ) -> None: self.raw_annotation = annotation self.namespace = namespace @@ -78,7 +75,7 @@ def __hash__(self) -> int: @staticmethod def from_annotation( - annotation: object, namespace: Optional[Dict[str, Any]] = None + annotation: object, namespace: Optional[dict[str, Any]] = None ) -> Optional[StrawberryAnnotation]: if annotation is None: return None @@ -115,8 +112,8 @@ def evaluate(self) -> type: return evaled_type def _get_type_with_args( - self, evaled_type: Type[Any] - ) -> Tuple[Type[Any], List[Any]]: + self, evaled_type: type[Any] + ) -> tuple[type[Any], list[Any]]: if self._is_async_type(evaled_type): return self._get_type_with_args(self._strip_async_type(evaled_type)) @@ -140,7 +137,7 @@ def _resolve(self) -> Union[StrawberryType, type]: if is_private(evaled_type): return evaled_type - args: List[Any] = [] + args: list[Any] = [] evaled_type, args = self._get_type_with_args(evaled_type) @@ -224,7 +221,7 @@ def create_optional(self, evaled_type: Any) -> StrawberryOptional: def create_type_var(self, evaled_type: TypeVar) -> StrawberryTypeVar: return StrawberryTypeVar(evaled_type) - def create_union(self, evaled_type: Type[Any], args: list[Any]) -> StrawberryUnion: + def create_union(self, evaled_type: type[Any], args: list[Any]) -> StrawberryUnion: # Prevent import cycles from strawberry.types.union import StrawberryUnion @@ -289,7 +286,7 @@ def _is_lazy_type(cls, annotation: Any) -> bool: return isinstance(annotation, LazyType) @classmethod - def _is_optional(cls, annotation: Any, args: List[Any]) -> bool: + def _is_optional(cls, annotation: Any, args: list[Any]) -> bool: """Returns True if the annotation is Optional[SomeType].""" # Optionals are represented as unions if not cls._is_union(annotation, args): @@ -341,7 +338,7 @@ def _is_strawberry_type(cls, evaled_type: Any) -> bool: return False @classmethod - def _is_union(cls, annotation: Any, args: List[Any]) -> bool: + def _is_union(cls, annotation: Any, args: list[Any]) -> bool: """Returns True if annotation is a Union.""" # this check is needed because unions declared with the new syntax `A | B` # don't have a `__origin__` property on them, but they are instances of @@ -365,7 +362,7 @@ def _is_union(cls, annotation: Any, args: List[Any]) -> bool: return any(isinstance(arg, StrawberryUnion) for arg in args) @classmethod - def _strip_async_type(cls, annotation: Type[Any]) -> type: + def _strip_async_type(cls, annotation: type[Any]) -> type: return annotation.__args__[0] @classmethod diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index 9075b6eabe..2fd33a4210 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -5,13 +5,8 @@ from json import JSONDecodeError from typing import ( TYPE_CHECKING, - AsyncGenerator, - AsyncIterator, Callable, - Dict, - Mapping, Optional, - Sequence, Union, cast, ) @@ -46,6 +41,8 @@ from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL if TYPE_CHECKING: + from collections.abc import AsyncGenerator, AsyncIterator, Mapping, Sequence + from starlette.types import Receive, Scope, Send from strawberry.http import GraphQLHTTPResponse @@ -231,7 +228,7 @@ async def create_streaming_response( request: Request | WebSocket, stream: Callable[[], AsyncIterator[str]], sub_response: Response, - headers: Dict[str, str], + headers: dict[str, str], ) -> Response: return StreamingResponse( stream(), diff --git a/strawberry/asgi/test/client.py b/strawberry/asgi/test/client.py index 74afd78229..b6a814be30 100644 --- a/strawberry/asgi/test/client.py +++ b/strawberry/asgi/test/client.py @@ -1,11 +1,12 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional +from typing import TYPE_CHECKING, Any, Optional from strawberry.test import BaseGraphQLTestClient if TYPE_CHECKING: + from collections.abc import Mapping from typing_extensions import Literal @@ -13,10 +14,10 @@ class GraphQLTestClient(BaseGraphQLTestClient): def _build_body( self, query: str, - variables: Optional[Dict[str, Mapping]] = None, - files: Optional[Dict[str, object]] = None, - ) -> Dict[str, object]: - body: Dict[str, object] = {"query": query} + variables: Optional[dict[str, Mapping]] = None, + files: Optional[dict[str, object]] = None, + ) -> dict[str, object]: + body: dict[str, object] = {"query": query} if variables: body["variables"] = variables @@ -34,9 +35,9 @@ def _build_body( def request( self, - body: Dict[str, object], - headers: Optional[Dict[str, object]] = None, - files: Optional[Dict[str, object]] = None, + body: dict[str, object], + headers: Optional[dict[str, object]] = None, + files: Optional[dict[str, object]] = None, ) -> Any: return self._client.post( self.url, diff --git a/strawberry/chalice/views.py b/strawberry/chalice/views.py index 9c131f202d..9d5c424402 100644 --- a/strawberry/chalice/views.py +++ b/strawberry/chalice/views.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast from chalice.app import Request, Response from strawberry.http.exceptions import HTTPException @@ -11,6 +11,8 @@ from strawberry.http.typevars import Context, RootValue if TYPE_CHECKING: + from collections.abc import Mapping + from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.schema import BaseSchema @@ -91,7 +93,7 @@ def error_response( message: str, error_code: str, http_status_code: int, - headers: Optional[Dict[str, str | List[str]]] = None, + headers: Optional[dict[str, str | list[str]]] = None, ) -> Response: """A wrapper for error responses. diff --git a/strawberry/channels/__init__.py b/strawberry/channels/__init__.py index f67fb25a82..f680cbfe50 100644 --- a/strawberry/channels/__init__.py +++ b/strawberry/channels/__init__.py @@ -10,8 +10,8 @@ __all__ = [ "ChannelsConsumer", "ChannelsRequest", - "GraphQLProtocolTypeRouter", "GraphQLHTTPConsumer", + "GraphQLProtocolTypeRouter", "GraphQLWSConsumer", "SyncGraphQLHTTPConsumer", ] diff --git a/strawberry/channels/handlers/base.py b/strawberry/channels/handlers/base.py index 769ec569e5..d4d8774fe8 100644 --- a/strawberry/channels/handlers/base.py +++ b/strawberry/channels/handlers/base.py @@ -2,15 +2,11 @@ import contextlib import warnings from collections import defaultdict +from collections.abc import AsyncGenerator, Awaitable, Sequence from typing import ( Any, - AsyncGenerator, - Awaitable, Callable, - DefaultDict, - List, Optional, - Sequence, ) from typing_extensions import Literal, Protocol, TypedDict from weakref import WeakSet @@ -31,7 +27,7 @@ class ChannelsLayer(Protocol): # pragma: no cover # Default channels API - extensions: List[Literal["groups", "flush"]] + extensions: list[Literal["groups", "flush"]] async def send(self, channel: str, message: dict) -> None: ... @@ -62,7 +58,7 @@ class ChannelsConsumer(AsyncConsumer): channel_receive: Callable[[], Awaitable[dict]] def __init__(self, *args: str, **kwargs: Any) -> None: - self.listen_queues: DefaultDict[str, WeakSet[asyncio.Queue]] = defaultdict( + self.listen_queues: defaultdict[str, WeakSet[asyncio.Queue]] = defaultdict( WeakSet ) super().__init__(*args, **kwargs) diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index a60bf2789e..7281f53cdf 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -8,10 +8,7 @@ from typing import ( TYPE_CHECKING, Any, - AsyncGenerator, Callable, - Dict, - Mapping, Optional, Union, ) @@ -35,6 +32,8 @@ from .base import ChannelsConsumer if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Mapping + from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.http.types import HTTPMethod, QueryParams @@ -46,7 +45,7 @@ class ChannelsResponse: content: bytes status: int = 200 content_type: str = "application/json" - headers: Dict[bytes, bytes] = dataclasses.field(default_factory=dict) + headers: dict[bytes, bytes] = dataclasses.field(default_factory=dict) @dataclasses.dataclass @@ -54,7 +53,7 @@ class MultipartChannelsResponse: stream: Callable[[], AsyncGenerator[str, None]] status: int = 200 content_type: str = "multipart/mixed;boundary=graphql;subscriptionSpec=1.0" - headers: Dict[bytes, bytes] = dataclasses.field(default_factory=dict) + headers: dict[bytes, bytes] = dataclasses.field(default_factory=dict) @dataclasses.dataclass @@ -279,7 +278,7 @@ async def create_streaming_response( request: ChannelsRequest, stream: Callable[[], AsyncGenerator[str, None]], sub_response: TemporalResponse, - headers: Dict[str, str], + headers: dict[str, str], ) -> MultipartChannelsResponse: status = sub_response.status_code or 200 diff --git a/strawberry/channels/handlers/ws_handler.py b/strawberry/channels/handlers/ws_handler.py index 34ba50f8bb..54992b1d44 100644 --- a/strawberry/channels/handlers/ws_handler.py +++ b/strawberry/channels/handlers/ws_handler.py @@ -5,10 +5,7 @@ import json from typing import ( TYPE_CHECKING, - AsyncGenerator, - Mapping, Optional, - Tuple, TypedDict, Union, ) @@ -22,6 +19,8 @@ from .base import ChannelsWSConsumer if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Mapping + from strawberry.http import GraphQLHTTPResponse from strawberry.schema import BaseSchema @@ -111,7 +110,7 @@ def __init__( keep_alive: bool = False, keep_alive_interval: float = 1, debug: bool = False, - subscription_protocols: Tuple[str, str] = ( + subscription_protocols: tuple[str, str] = ( GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL, ), diff --git a/strawberry/channels/testing.py b/strawberry/channels/testing.py index 892bb4bda0..be3c276fdf 100644 --- a/strawberry/channels/testing.py +++ b/strawberry/channels/testing.py @@ -4,12 +4,7 @@ from typing import ( TYPE_CHECKING, Any, - AsyncIterator, - Dict, - List, Optional, - Tuple, - Type, Union, ) @@ -24,6 +19,7 @@ from strawberry.types import ExecutionResult if TYPE_CHECKING: + from collections.abc import AsyncIterator from types import TracebackType from typing_extensions import Self @@ -57,7 +53,7 @@ def __init__( self, application: ASGIApplication, path: str, - headers: Optional[List[Tuple[bytes, bytes]]] = None, + headers: Optional[list[tuple[bytes, bytes]]] = None, protocol: str = GRAPHQL_TRANSPORT_WS_PROTOCOL, connection_params: dict = {}, **kwargs: Any, @@ -85,7 +81,7 @@ async def __aenter__(self) -> Self: async def __aexit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: @@ -116,7 +112,7 @@ async def gql_init(self) -> None: # get transformed into `FormattedExecutionResult` on the wire, but we attempt # to do a limited representation of them here, to make testing simpler. async def subscribe( - self, query: str, variables: Optional[Dict] = None + self, query: str, variables: Optional[dict] = None ) -> Union[ExecutionResult, AsyncIterator[ExecutionResult]]: id_ = uuid.uuid4().hex @@ -164,7 +160,7 @@ async def subscribe( else: return - def process_errors(self, errors: List[GraphQLFormattedError]) -> List[GraphQLError]: + def process_errors(self, errors: list[GraphQLFormattedError]) -> list[GraphQLError]: """Reconstructs a GraphQLError from a FormattedGraphQLError.""" result = [] for f_error in errors: diff --git a/strawberry/cli/commands/codegen.py b/strawberry/cli/commands/codegen.py index 6fe784da30..50c3c3327b 100644 --- a/strawberry/cli/commands/codegen.py +++ b/strawberry/cli/commands/codegen.py @@ -3,8 +3,8 @@ import functools import importlib import inspect -from pathlib import Path # noqa: TCH003 -from typing import List, Optional, Type, Union, cast +from pathlib import Path # noqa: TC003 +from typing import Optional, Union, cast import rich import typer @@ -22,7 +22,7 @@ def _is_codegen_plugin(obj: object) -> bool: ) -def _import_plugin(plugin: str) -> Optional[Type[QueryCodegenPlugin]]: +def _import_plugin(plugin: str) -> Optional[type[QueryCodegenPlugin]]: module_name = plugin symbol_name: Optional[str] = None @@ -63,7 +63,7 @@ def _import_plugin(plugin: str) -> Optional[Type[QueryCodegenPlugin]]: @functools.lru_cache def _load_plugin( plugin_path: str, -) -> Union[Type[QueryCodegenPlugin], Type[ConsolePlugin]]: +) -> type[Union[QueryCodegenPlugin, ConsolePlugin]]: # try to import plugin_name from current folder # then try to import from strawberry.codegen.plugins @@ -80,8 +80,8 @@ def _load_plugin( def _load_plugins( - plugin_ids: List[str], query: Path -) -> List[Union[QueryCodegenPlugin, ConsolePlugin]]: + plugin_ids: list[str], query: Path +) -> list[Union[QueryCodegenPlugin, ConsolePlugin]]: plugins = [] for ptype_id in plugin_ids: ptype = _load_plugin(ptype_id) @@ -93,7 +93,7 @@ def _load_plugins( @app.command(help="Generate code from a query") def codegen( - query: Optional[List[Path]] = typer.Argument( + query: Optional[list[Path]] = typer.Argument( default=None, exists=True, dir_okay=False ), schema: str = typer.Option(..., help="Python path to the schema file"), @@ -117,7 +117,7 @@ def codegen( writable=True, resolve_path=True, ), - selected_plugins: List[str] = typer.Option( + selected_plugins: list[str] = typer.Option( ..., "-p", "--plugins", @@ -135,7 +135,7 @@ def codegen( console_plugin.before_any_start() for q in query: - plugins = cast(List[QueryCodegenPlugin], _load_plugins(selected_plugins, q)) + plugins = cast(list[QueryCodegenPlugin], _load_plugins(selected_plugins, q)) code_generator = QueryCodegen( schema_symbol, plugins=plugins, console_plugin=console_plugin diff --git a/strawberry/cli/commands/upgrade/__init__.py b/strawberry/cli/commands/upgrade/__init__.py index 2b8f387ccb..4c18016f6a 100644 --- a/strawberry/cli/commands/upgrade/__init__.py +++ b/strawberry/cli/commands/upgrade/__init__.py @@ -1,9 +1,8 @@ from __future__ import annotations import glob -import pathlib # noqa: TCH003 +import pathlib # noqa: TC003 import sys -from typing import List import rich import typer @@ -29,7 +28,7 @@ def upgrade( autocompletion=lambda: list(codemods.keys()), help="Name of the upgrade to run", ), - paths: List[pathlib.Path] = typer.Argument(..., file_okay=True, dir_okay=True), + paths: list[pathlib.Path] = typer.Argument(..., file_okay=True, dir_okay=True), python_target: str = typer.Option( ".".join(str(x) for x in sys.version_info[:2]), "--python-target", diff --git a/strawberry/cli/commands/upgrade/_run_codemod.py b/strawberry/cli/commands/upgrade/_run_codemod.py index 04e168e240..abd6e6e8b1 100644 --- a/strawberry/cli/commands/upgrade/_run_codemod.py +++ b/strawberry/cli/commands/upgrade/_run_codemod.py @@ -4,7 +4,7 @@ import os from importlib.metadata import version from multiprocessing import Pool, cpu_count -from typing import TYPE_CHECKING, Any, Dict, Generator, Sequence, Type, Union +from typing import TYPE_CHECKING, Any, Union from libcst.codemod._cli import ExecutionConfig, ExecutionResult, _execute_transform from libcst.codemod._dummy_pool import DummyPool @@ -13,10 +13,12 @@ from ._fake_progress import FakeProgress if TYPE_CHECKING: + from collections.abc import Generator, Sequence + from libcst.codemod import Codemod -ProgressType = Union[Type[Progress], Type[FakeProgress]] -PoolType = Union[Type[Pool], Type[DummyPool]] # type: ignore +ProgressType = Union[type[Progress], type[FakeProgress]] +PoolType = Union[type[Pool], type[DummyPool]] # type: ignore def _get_libcst_version() -> tuple[int, int, int]: @@ -31,9 +33,9 @@ def _get_libcst_version() -> tuple[int, int, int]: def _execute_transform_wrap( - job: Dict[str, Any], + job: dict[str, Any], ) -> ExecutionResult: - additional_kwargs: Dict[str, Any] = {} + additional_kwargs: dict[str, Any] = {} if _get_libcst_version() >= (1, 4, 0): additional_kwargs["scratch"] = {} diff --git a/strawberry/codegen/exceptions.py b/strawberry/codegen/exceptions.py index ca7bd6ad10..3db33c1537 100644 --- a/strawberry/codegen/exceptions.py +++ b/strawberry/codegen/exceptions.py @@ -16,7 +16,7 @@ class MultipleOperationsProvidedError(CodegenError): __all__ = [ "CodegenError", - "NoOperationProvidedError", - "NoOperationNameProvidedError", "MultipleOperationsProvidedError", + "NoOperationNameProvidedError", + "NoOperationProvidedError", ] diff --git a/strawberry/codegen/plugins/print_operation.py b/strawberry/codegen/plugins/print_operation.py index 9d71c0e704..5a37c87086 100644 --- a/strawberry/codegen/plugins/print_operation.py +++ b/strawberry/codegen/plugins/print_operation.py @@ -1,7 +1,7 @@ from __future__ import annotations import textwrap -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from strawberry.codegen import CodegenFile, QueryCodegenPlugin from strawberry.codegen.types import ( @@ -35,8 +35,8 @@ class PrintOperationPlugin(QueryCodegenPlugin): def generate_code( - self, types: List[GraphQLType], operation: GraphQLOperation - ) -> List[CodegenFile]: + self, types: list[GraphQLType], operation: GraphQLOperation + ) -> list[CodegenFile]: code_lines = [] for t in types: if not isinstance(t, GraphQLFragmentType): @@ -139,7 +139,7 @@ def _print_argument_value(self, value: GraphQLArgumentValue) -> str: raise ValueError(f"not supported: {type(value)}") # pragma: no cover - def _print_arguments(self, arguments: List[GraphQLArgument]) -> str: + def _print_arguments(self, arguments: list[GraphQLArgument]) -> str: if not arguments: return "" @@ -154,7 +154,7 @@ def _print_arguments(self, arguments: List[GraphQLArgument]) -> str: + ")" ) - def _print_directives(self, directives: List[GraphQLDirective]) -> str: + def _print_directives(self, directives: list[GraphQLDirective]) -> str: if not directives: return "" @@ -204,7 +204,7 @@ def _print_selection(self, selection: GraphQLSelection) -> str: raise ValueError(f"Unsupported selection: {selection}") # pragma: no cover - def _print_selections(self, selections: List[GraphQLSelection]) -> str: + def _print_selections(self, selections: list[GraphQLSelection]) -> str: selections_text = "\n".join( [self._print_selection(selection) for selection in selections] ) diff --git a/strawberry/codegen/plugins/python.py b/strawberry/codegen/plugins/python.py index 03f97aceb8..ab7ba5e00c 100644 --- a/strawberry/codegen/plugins/python.py +++ b/strawberry/codegen/plugins/python.py @@ -3,7 +3,7 @@ import textwrap from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Optional from strawberry.codegen import CodegenFile, QueryCodegenPlugin from strawberry.codegen.types import ( @@ -35,7 +35,7 @@ class PythonType: class PythonPlugin(QueryCodegenPlugin): - SCALARS_TO_PYTHON_TYPES: Dict[str, PythonType] = { + SCALARS_TO_PYTHON_TYPES: dict[str, PythonType] = { "ID": PythonType("str"), "Int": PythonType("int"), "String": PythonType("str"), @@ -49,13 +49,13 @@ class PythonPlugin(QueryCodegenPlugin): } def __init__(self, query: Path) -> None: - self.imports: Dict[str, Set[str]] = defaultdict(set) + self.imports: dict[str, set[str]] = defaultdict(set) self.outfile_name: str = query.with_suffix(".py").name self.query = query def generate_code( - self, types: List[GraphQLType], operation: GraphQLOperation - ) -> List[CodegenFile]: + self, types: list[GraphQLType], operation: GraphQLOperation + ) -> list[CodegenFile]: printed_types = list(filter(None, (self._print_type(type) for type in types))) imports = self._print_imports() @@ -80,7 +80,7 @@ def _get_type_name(self, type_: GraphQLType) -> str: if isinstance(type_, GraphQLList): self.imports["typing"].add("List") - return f"List[{self._get_type_name(type_.of_type)}]" + return f"list[{self._get_type_name(type_.of_type)}]" if isinstance(type_, GraphQLUnion): # TODO: wrong place for this diff --git a/strawberry/codegen/plugins/typescript.py b/strawberry/codegen/plugins/typescript.py index ede9bd859d..057afd4870 100644 --- a/strawberry/codegen/plugins/typescript.py +++ b/strawberry/codegen/plugins/typescript.py @@ -1,7 +1,7 @@ from __future__ import annotations import textwrap -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING from strawberry.codegen import CodegenFile, QueryCodegenPlugin from strawberry.codegen.types import ( @@ -40,8 +40,8 @@ def __init__(self, query: Path) -> None: self.query = query def generate_code( - self, types: List[GraphQLType], operation: GraphQLOperation - ) -> List[CodegenFile]: + self, types: list[GraphQLType], operation: GraphQLOperation + ) -> list[CodegenFile]: printed_types = list(filter(None, (self._print_type(type) for type in types))) return [CodegenFile(self.outfile_name, "\n\n".join(printed_types))] diff --git a/strawberry/codegen/query_codegen.py b/strawberry/codegen/query_codegen.py index 583f7fbe39..675820e72d 100644 --- a/strawberry/codegen/query_codegen.py +++ b/strawberry/codegen/query_codegen.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Iterable, Mapping, Sequence from dataclasses import MISSING, dataclass from enum import Enum from functools import cmp_to_key, partial @@ -8,13 +9,7 @@ TYPE_CHECKING, Any, Callable, - Iterable, - List, - Mapping, Optional, - Sequence, - Tuple, - Type, Union, cast, ) @@ -114,7 +109,7 @@ class CodegenFile: @dataclass class CodegenResult: - files: List[CodegenFile] + files: list[CodegenFile] def to_string(self) -> str: return "\n".join(f.content for f in self.files) + "\n" @@ -144,15 +139,15 @@ def on_start(self) -> None: ... def on_end(self, result: CodegenResult) -> None: ... def generate_code( - self, types: List[GraphQLType], operation: GraphQLOperation - ) -> List[CodegenFile]: + self, types: list[GraphQLType], operation: GraphQLOperation + ) -> list[CodegenFile]: return [] class ConsolePlugin: def __init__(self, output_dir: Path) -> None: self.output_dir = output_dir - self.files_generated: List[Path] = [] + self.files_generated: list[Path] = [] def before_any_start(self) -> None: rich.print( @@ -238,13 +233,13 @@ def _py_to_graphql_value(obj: Any) -> GraphQLArgumentValue: class QueryCodegenPluginManager: def __init__( self, - plugins: List[QueryCodegenPlugin], + plugins: list[QueryCodegenPlugin], console_plugin: Optional[ConsolePlugin] = None, ) -> None: self.plugins = plugins self.console_plugin = console_plugin - def _sort_types(self, types: List[GraphQLType]) -> List[GraphQLType]: + def _sort_types(self, types: list[GraphQLType]) -> list[GraphQLType]: """Sort the types. t1 < t2 iff t2 has a dependency on t1. @@ -266,7 +261,7 @@ def type_cmp(t1: GraphQLType, t2: GraphQLType) -> int: return sorted(types, key=cmp_to_key(type_cmp)) def generate_code( - self, types: List[GraphQLType], operation: GraphQLOperation + self, types: list[GraphQLType], operation: GraphQLOperation ) -> CodegenResult: result = CodegenResult(files=[]) @@ -301,12 +296,12 @@ class QueryCodegen: def __init__( self, schema: Schema, - plugins: List[QueryCodegenPlugin], + plugins: list[QueryCodegenPlugin], console_plugin: Optional[ConsolePlugin] = None, ) -> None: self.schema = schema self.plugin_manager = QueryCodegenPluginManager(plugins, console_plugin) - self.types: List[GraphQLType] = [] + self.types: list[GraphQLType] = [] def run(self, query: str) -> CodegenResult: self.plugin_manager.on_start() @@ -395,7 +390,7 @@ def _convert_selection(self, selection: SelectionNode) -> GraphQLSelection: def _convert_selection_set( self, selection_set: Optional[SelectionSetNode] - ) -> List[GraphQLSelection]: + ) -> list[GraphQLSelection]: if selection_set is None: return [] @@ -442,7 +437,7 @@ def _convert_value(self, value: ValueNode) -> GraphQLArgumentValue: def _convert_arguments( self, arguments: Iterable[ArgumentNode] - ) -> List[GraphQLArgument]: + ) -> list[GraphQLArgument]: return [ GraphQLArgument(argument.name.value, self._convert_value(argument.value)) for argument in arguments @@ -450,7 +445,7 @@ def _convert_arguments( def _convert_directives( self, directives: Iterable[DirectiveNode] - ) -> List[GraphQLDirective]: + ) -> list[GraphQLDirective]: return [ GraphQLDirective( directive.name.value, @@ -500,7 +495,7 @@ def _convert_variable_definitions( self, variable_definitions: Optional[Iterable[VariableDefinitionNode]], operation_name: str, - ) -> Tuple[List[GraphQLVariable], Optional[GraphQLObjectType]]: + ) -> tuple[list[GraphQLVariable], Optional[GraphQLObjectType]]: if not variable_definitions: return [], None @@ -508,7 +503,7 @@ def _convert_variable_definitions( self._collect_type(type_) - variables: List[GraphQLVariable] = [] + variables: list[GraphQLVariable] = [] for variable_definition in variable_definitions: variable_type = self._collect_type_from_variable(variable_definition.type) @@ -522,7 +517,7 @@ def _convert_variable_definitions( return variables, type_ - def _get_operations(self, ast: DocumentNode) -> List[OperationDefinitionNode]: + def _get_operations(self, ast: DocumentNode) -> list[OperationDefinitionNode]: return [ definition for definition in ast.definitions @@ -642,7 +637,7 @@ def _field_from_selection( def _unwrap_type( self, type_: Union[type, StrawberryType] - ) -> Tuple[ + ) -> tuple[ Union[type, StrawberryType], Optional[Callable[[GraphQLType], GraphQLType]] ]: wrapper: Optional[Callable[[GraphQLType], GraphQLType]] = None @@ -777,7 +772,7 @@ def _collect_types( ) current_type = graph_ql_object_type_factory(class_name) - fields: List[Union[GraphQLFragmentSpread, GraphQLField]] = [] + fields: list[Union[GraphQLFragmentSpread, GraphQLField]] = [] for sub_selection in selection_set.selections: if isinstance(sub_selection, FragmentSpreadNode): @@ -805,7 +800,7 @@ def _collect_types( # `GraphQLField` or `GraphQLFragmentSpread` # and the suite above will cause this statement to be # skipped if there are any `GraphQLFragmentSpread`. - current_type.fields = cast(List[GraphQLField], fields) + current_type.fields = cast(list[GraphQLField], fields) self._collect_type(current_type) @@ -821,12 +816,12 @@ def _collect_types_using_fragments( selection: HasSelectionSet, parent_type: StrawberryObjectDefinition, class_name: str, - ) -> List[GraphQLObjectType]: + ) -> list[GraphQLObjectType]: assert selection.selection_set - common_fields: List[GraphQLField] = [] - fragments: List[InlineFragmentNode] = [] - sub_types: List[GraphQLObjectType] = [] + common_fields: list[GraphQLField] = [] + fragments: list[InlineFragmentNode] = [] + sub_types: list[GraphQLObjectType] = [] for sub_selection in selection.selection_set.selections: if isinstance(sub_selection, FieldNode): @@ -848,7 +843,7 @@ def _collect_types_using_fragments( list(common_fields), graphql_typename=type_condition_name, ) - fields: List[Union[GraphQLFragmentSpread, GraphQLField]] = [] + fields: list[Union[GraphQLFragmentSpread, GraphQLField]] = [] for sub_selection in fragment.selection_set.selections: if isinstance(sub_selection, FragmentSpreadNode): @@ -893,7 +888,7 @@ def _collect_types_using_fragments( # `GraphQLField` or `GraphQLFragmentSpread` # and the suite above will cause this statement to be # skipped if there are any `GraphQLFragmentSpread`. - current_type.fields.extend(cast(List[GraphQLField], fields)) + current_type.fields.extend(cast(list[GraphQLField], fields)) sub_types.append(current_type) @@ -903,7 +898,7 @@ def _collect_types_using_fragments( return sub_types def _collect_scalar( - self, scalar_definition: ScalarDefinition, python_type: Optional[Type] + self, scalar_definition: ScalarDefinition, python_type: Optional[type] ) -> GraphQLScalar: graphql_scalar = GraphQLScalar(scalar_definition.name, python_type=python_type) @@ -922,9 +917,9 @@ def _collect_enum(self, enum: EnumDefinition) -> GraphQLEnum: __all__ = [ - "QueryCodegen", - "QueryCodegenPlugin", - "ConsolePlugin", "CodegenFile", "CodegenResult", + "ConsolePlugin", + "QueryCodegen", + "QueryCodegenPlugin", ] diff --git a/strawberry/codegen/types.py b/strawberry/codegen/types.py index 1518cb5f1f..a0e824cf1e 100644 --- a/strawberry/codegen/types.py +++ b/strawberry/codegen/types.py @@ -1,9 +1,10 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Mapping, Optional, Type, Union +from typing import TYPE_CHECKING, Optional, Union if TYPE_CHECKING: + from collections.abc import Mapping from enum import EnumMeta from typing_extensions import Literal @@ -23,7 +24,7 @@ class GraphQLList: @dataclass class GraphQLUnion: name: str - types: List[GraphQLObjectType] + types: list[GraphQLObjectType] @dataclass @@ -42,7 +43,7 @@ class GraphQLFragmentSpread: @dataclass class GraphQLObjectType: name: str - fields: List[GraphQLField] = field(default_factory=list) + fields: list[GraphQLField] = field(default_factory=list) graphql_typename: Optional[str] = None @@ -52,7 +53,7 @@ class GraphQLObjectType: @dataclass class GraphQLFragmentType(GraphQLObjectType): name: str - fields: List[GraphQLField] = field(default_factory=list) + fields: list[GraphQLField] = field(default_factory=list) graphql_typename: Optional[str] = None on: str = "" @@ -66,14 +67,14 @@ def __post_init__(self) -> None: @dataclass class GraphQLEnum: name: str - values: List[str] + values: list[str] python_type: EnumMeta @dataclass class GraphQLScalar: name: str - python_type: Optional[Type] + python_type: Optional[type] GraphQLType = Union[ @@ -90,15 +91,15 @@ class GraphQLScalar: class GraphQLFieldSelection: field: str alias: Optional[str] - selections: List[GraphQLSelection] - directives: List[GraphQLDirective] - arguments: List[GraphQLArgument] + selections: list[GraphQLSelection] + directives: list[GraphQLDirective] + arguments: list[GraphQLArgument] @dataclass class GraphQLInlineFragment: type_condition: str - selections: List[GraphQLSelection] + selections: list[GraphQLSelection] GraphQLSelection = Union[ @@ -141,7 +142,7 @@ class GraphQLNullValue: @dataclass class GraphQLListValue: - values: List[GraphQLArgumentValue] + values: list[GraphQLArgumentValue] @dataclass @@ -176,7 +177,7 @@ class GraphQLArgument: @dataclass class GraphQLDirective: name: str - arguments: List[GraphQLArgument] + arguments: list[GraphQLArgument] @dataclass @@ -189,39 +190,39 @@ class GraphQLVariable: class GraphQLOperation: name: str kind: Literal["query", "mutation", "subscription"] - selections: List[GraphQLSelection] - directives: List[GraphQLDirective] - variables: List[GraphQLVariable] + selections: list[GraphQLSelection] + directives: list[GraphQLDirective] + variables: list[GraphQLVariable] type: GraphQLObjectType variables_type: Optional[GraphQLObjectType] __all__ = [ - "GraphQLOptional", - "GraphQLList", - "GraphQLUnion", + "GraphQLArgument", + "GraphQLArgumentValue", + "GraphQLBoolValue", + "GraphQLDirective", + "GraphQLEnum", + "GraphQLEnumValue", "GraphQLField", + "GraphQLFieldSelection", + "GraphQLFloatValue", "GraphQLFragmentSpread", - "GraphQLObjectType", "GraphQLFragmentType", - "GraphQLEnum", - "GraphQLScalar", - "GraphQLType", - "GraphQLFieldSelection", "GraphQLInlineFragment", - "GraphQLSelection", - "GraphQLStringValue", "GraphQLIntValue", - "GraphQLFloatValue", - "GraphQLEnumValue", - "GraphQLBoolValue", - "GraphQLNullValue", + "GraphQLList", "GraphQLListValue", + "GraphQLNullValue", + "GraphQLObjectType", "GraphQLObjectValue", - "GraphQLVariableReference", - "GraphQLArgumentValue", - "GraphQLArgument", - "GraphQLDirective", - "GraphQLVariable", "GraphQLOperation", + "GraphQLOptional", + "GraphQLScalar", + "GraphQLSelection", + "GraphQLStringValue", + "GraphQLType", + "GraphQLUnion", + "GraphQLVariable", + "GraphQLVariableReference", ] diff --git a/strawberry/codemods/annotated_unions.py b/strawberry/codemods/annotated_unions.py index 0da878b75c..096c78ac0a 100644 --- a/strawberry/codemods/annotated_unions.py +++ b/strawberry/codemods/annotated_unions.py @@ -1,13 +1,16 @@ from __future__ import annotations -from typing import Optional, Sequence +from typing import TYPE_CHECKING, Optional import libcst as cst import libcst.matchers as m -from libcst._nodes.expression import BaseExpression, Call # noqa: TCH002 +from libcst._nodes.expression import BaseExpression, Call # noqa: TC002 from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor +if TYPE_CHECKING: + from collections.abc import Sequence + def _find_named_argument(args: Sequence[cst.Arg], name: str) -> cst.Arg | None: return next( diff --git a/strawberry/dataloader.py b/strawberry/dataloader.py index ce2d247b27..c40d274b37 100644 --- a/strawberry/dataloader.py +++ b/strawberry/dataloader.py @@ -8,16 +8,9 @@ from typing import ( TYPE_CHECKING, Any, - Awaitable, Callable, - Dict, Generic, - Hashable, - Iterable, - List, - Mapping, Optional, - Sequence, TypeVar, Union, overload, @@ -27,6 +20,7 @@ if TYPE_CHECKING: from asyncio.events import AbstractEventLoop + from collections.abc import Awaitable, Hashable, Iterable, Mapping, Sequence T = TypeVar("T") @@ -41,7 +35,7 @@ class LoaderTask(Generic[K, T]): @dataclass class Batch(Generic[K, T]): - tasks: List[LoaderTask] = dataclasses.field(default_factory=list) + tasks: list[LoaderTask] = dataclasses.field(default_factory=list) dispatched: bool = False def add_task(self, key: Any, future: Future) -> None: @@ -75,7 +69,7 @@ def __init__(self, cache_key_fn: Optional[Callable[[K], Hashable]] = None) -> No self.cache_key_fn: Callable[[K], Hashable] = ( cache_key_fn if cache_key_fn is not None else lambda x: x ) - self.cache_map: Dict[Hashable, Future[T]] = {} + self.cache_map: dict[Hashable, Future[T]] = {} def get(self, key: K) -> Union[Future[T], None]: return self.cache_map.get(self.cache_key_fn(key)) @@ -99,7 +93,7 @@ class DataLoader(Generic[K, T]): def __init__( self, # any BaseException is rethrown in 'load', so should be excluded from the T type - load_fn: Callable[[List[K]], Awaitable[Sequence[Union[T, BaseException]]]], + load_fn: Callable[[list[K]], Awaitable[Sequence[Union[T, BaseException]]]], max_batch_size: Optional[int] = None, cache: bool = True, loop: Optional[AbstractEventLoop] = None, @@ -111,7 +105,7 @@ def __init__( @overload def __init__( self: DataLoader[K, Any], - load_fn: Callable[[List[K]], Awaitable[List[Any]]], + load_fn: Callable[[list[K]], Awaitable[list[Any]]], max_batch_size: Optional[int] = None, cache: bool = True, loop: Optional[AbstractEventLoop] = None, @@ -121,7 +115,7 @@ def __init__( def __init__( self, - load_fn: Callable[[List[K]], Awaitable[Sequence[Union[T, BaseException]]]], + load_fn: Callable[[list[K]], Awaitable[Sequence[Union[T, BaseException]]]], max_batch_size: Optional[int] = None, cache: bool = True, loop: Optional[AbstractEventLoop] = None, @@ -164,7 +158,7 @@ def load(self, key: K) -> Awaitable[T]: return future - def load_many(self, keys: Iterable[K]) -> Awaitable[List[T]]: + def load_many(self, keys: Iterable[K]) -> Awaitable[list[T]]: return gather(*map(self.load, keys)) def clear(self, key: K) -> None: @@ -210,8 +204,7 @@ def prime_many(self, data: Mapping[K, T], force: bool = False) -> None: def should_create_new_batch(loader: DataLoader, batch: Batch) -> bool: return bool( batch.dispatched - or loader.max_batch_size - and len(batch) >= loader.max_batch_size + or (loader.max_batch_size and len(batch) >= loader.max_batch_size) ) @@ -267,13 +260,13 @@ async def dispatch_batch(loader: DataLoader, batch: Batch) -> None: __all__ = [ - "DataLoader", - "Batch", - "LoaderTask", "AbstractCache", + "Batch", + "DataLoader", "DefaultCache", - "should_create_new_batch", - "get_current_batch", + "LoaderTask", "dispatch", "dispatch_batch", + "get_current_batch", + "should_create_new_batch", ] diff --git a/strawberry/directive.py b/strawberry/directive.py index 17eaafd0c8..4ec5dd6b70 100644 --- a/strawberry/directive.py +++ b/strawberry/directive.py @@ -2,8 +2,15 @@ import dataclasses from functools import cached_property -from typing import TYPE_CHECKING, Any, Callable, Generic, List, Optional, TypeVar -from typing_extensions import Annotated +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + Generic, + Optional, + TypeVar, +) from graphql import DirectiveLocation @@ -84,17 +91,17 @@ class StrawberryDirective(Generic[T]): python_name: str graphql_name: Optional[str] resolver: StrawberryDirectiveResolver[T] - locations: List[DirectiveLocation] + locations: list[DirectiveLocation] description: Optional[str] = None @cached_property - def arguments(self) -> List[StrawberryArgument]: + def arguments(self) -> list[StrawberryArgument]: return self.resolver.arguments def directive( *, - locations: List[DirectiveLocation], + locations: list[DirectiveLocation], description: Optional[str] = None, name: Optional[str] = None, ) -> Callable[[Callable[..., T]], StrawberryDirective[T]]: diff --git a/strawberry/django/test/client.py b/strawberry/django/test/client.py index 1ce5ab4df4..b9d5c994fb 100644 --- a/strawberry/django/test/client.py +++ b/strawberry/django/test/client.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Optional from strawberry.test import BaseGraphQLTestClient @@ -6,9 +6,9 @@ class GraphQLTestClient(BaseGraphQLTestClient): def request( self, - body: Dict[str, object], - headers: Optional[Dict[str, object]] = None, - files: Optional[Dict[str, object]] = None, + body: dict[str, object], + headers: Optional[dict[str, object]] = None, + files: Optional[dict[str, object]] = None, ) -> Any: if files: return self._client.post( diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 47cd2b1d98..97a0955b75 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -5,10 +5,7 @@ from typing import ( TYPE_CHECKING, Any, - AsyncIterator, Callable, - Dict, - Mapping, Optional, Union, cast, @@ -42,6 +39,8 @@ from .context import StrawberryDjangoContext if TYPE_CHECKING: + from collections.abc import AsyncIterator, Mapping + from django.template.response import TemplateResponse from strawberry.http import GraphQLHTTPResponse @@ -190,7 +189,7 @@ async def create_streaming_response( request: HttpRequest, stream: Callable[[], AsyncIterator[Any]], sub_response: TemporalHttpResponse, - headers: Dict[str, str], + headers: dict[str, str], ) -> HttpResponseBase: return StreamingHttpResponse( streaming_content=stream(), @@ -318,4 +317,4 @@ async def create_websocket_response( raise NotImplementedError -__all__ = ["GraphQLView", "AsyncGraphQLView"] +__all__ = ["AsyncGraphQLView", "GraphQLView"] diff --git a/strawberry/exceptions/__init__.py b/strawberry/exceptions/__init__.py index d492c7e9bb..acd3cc8f89 100644 --- a/strawberry/exceptions/__init__.py +++ b/strawberry/exceptions/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Dict, Optional, Set, Union +from typing import TYPE_CHECKING, Optional, Union from graphql import GraphQLError @@ -48,7 +48,7 @@ class UnallowedReturnTypeForUnion(Exception): """The return type is not in the list of Union types.""" def __init__( - self, field_name: str, result_type: str, allowed_types: Set[GraphQLObjectType] + self, field_name: str, result_type: str, allowed_types: set[GraphQLObjectType] ) -> None: formatted_allowed_types = list(sorted(type_.name for type_ in allowed_types)) @@ -160,37 +160,37 @@ class StrawberryGraphQLError(GraphQLError): class ConnectionRejectionError(Exception): """Use it when you want to reject a WebSocket connection.""" - def __init__(self, payload: Dict[str, object] = {}) -> None: + def __init__(self, payload: dict[str, object] = {}) -> None: self.payload = payload __all__ = [ - "StrawberryException", - "UnableToFindExceptionSource", + "ConflictingArgumentsError", + "DuplicatedTypeName", + "FieldWithResolverAndDefaultFactoryError", + "FieldWithResolverAndDefaultValueError", + "InvalidArgumentTypeError", + "InvalidCustomContext", + "InvalidDefaultFactoryError", + "InvalidTypeForUnionMergeError", + "InvalidUnionTypeError", "MissingArgumentsAnnotationsError", + "MissingFieldAnnotationError", + "MissingOptionalDependenciesError", + "MissingQueryError", "MissingReturnAnnotationError", - "WrongReturnTypeForUnion", - "UnallowedReturnTypeForUnion", + "MissingTypesForGenericError", + "MultipleStrawberryArgumentsError", "ObjectIsNotAnEnumError", "ObjectIsNotClassError", - "InvalidUnionTypeError", - "InvalidTypeForUnionMergeError", - "MissingTypesForGenericError", - "UnsupportedTypeError", - "UnresolvedFieldTypeError", "PrivateStrawberryFieldError", - "MultipleStrawberryArgumentsError", "ScalarAlreadyRegisteredError", - "WrongNumberOfResultsReturned", - "FieldWithResolverAndDefaultValueError", - "FieldWithResolverAndDefaultFactoryError", - "ConflictingArgumentsError", - "MissingQueryError", - "InvalidArgumentTypeError", - "InvalidDefaultFactoryError", - "InvalidCustomContext", - "MissingFieldAnnotationError", - "DuplicatedTypeName", + "StrawberryException", "StrawberryGraphQLError", - "MissingOptionalDependenciesError", + "UnableToFindExceptionSource", + "UnallowedReturnTypeForUnion", + "UnresolvedFieldTypeError", + "UnsupportedTypeError", + "WrongNumberOfResultsReturned", + "WrongReturnTypeForUnion", ] diff --git a/strawberry/exceptions/conflicting_arguments.py b/strawberry/exceptions/conflicting_arguments.py index d83fb32572..8315363606 100644 --- a/strawberry/exceptions/conflicting_arguments.py +++ b/strawberry/exceptions/conflicting_arguments.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from .exception import StrawberryException from .utils.source_finder import SourceFinder @@ -16,7 +16,7 @@ class ConflictingArgumentsError(StrawberryException): def __init__( self, resolver: StrawberryResolver, - arguments: List[str], + arguments: list[str], ) -> None: self.function = resolver.wrapped_func self.argument_names = arguments diff --git a/strawberry/exceptions/duplicated_type_name.py b/strawberry/exceptions/duplicated_type_name.py index 9f3eb691f3..be98c79f9c 100644 --- a/strawberry/exceptions/duplicated_type_name.py +++ b/strawberry/exceptions/duplicated_type_name.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING, Optional from .exception import StrawberryException from .utils.source_finder import SourceFinder @@ -17,8 +17,8 @@ class DuplicatedTypeName(StrawberryException): def __init__( self, - first_cls: Optional[Type], - second_cls: Optional[Type], + first_cls: Optional[type], + second_cls: Optional[type], duplicated_type_name: str, ) -> None: self.first_cls = first_cls diff --git a/strawberry/exceptions/handler.py b/strawberry/exceptions/handler.py index 2a49d497b6..414db397ed 100644 --- a/strawberry/exceptions/handler.py +++ b/strawberry/exceptions/handler.py @@ -2,7 +2,7 @@ import sys import threading from types import TracebackType -from typing import Any, Callable, Optional, Tuple, Type, cast +from typing import Any, Callable, Optional, cast from .exception import StrawberryException, UnableToFindExceptionSource @@ -10,7 +10,7 @@ ExceptionHandler = Callable[ - [Type[BaseException], BaseException, Optional[TracebackType]], None + [type[BaseException], BaseException, Optional[TracebackType]], None ] @@ -20,7 +20,7 @@ def should_use_rich_exceptions() -> bool: return errors_disabled.lower() not in ["true", "1", "yes"] -def _get_handler(exception_type: Type[BaseException]) -> ExceptionHandler: +def _get_handler(exception_type: type[BaseException]) -> ExceptionHandler: if issubclass(exception_type, StrawberryException): try: import rich @@ -29,7 +29,7 @@ def _get_handler(exception_type: Type[BaseException]) -> ExceptionHandler: else: def _handler( - exception_type: Type[BaseException], + exception_type: type[BaseException], exception: BaseException, traceback: Optional[TracebackType], ) -> None: @@ -47,7 +47,7 @@ def _handler( def strawberry_exception_handler( - exception_type: Type[BaseException], + exception_type: type[BaseException], exception: BaseException, traceback: Optional[TracebackType], ) -> None: @@ -55,8 +55,8 @@ def strawberry_exception_handler( def strawberry_threading_exception_handler( - args: Tuple[ - Type[BaseException], + args: tuple[ + type[BaseException], Optional[BaseException], Optional[TracebackType], Optional[threading.Thread], diff --git a/strawberry/exceptions/invalid_union_type.py b/strawberry/exceptions/invalid_union_type.py index 28e65de6c5..25139cfb33 100644 --- a/strawberry/exceptions/invalid_union_type.py +++ b/strawberry/exceptions/invalid_union_type.py @@ -3,7 +3,7 @@ from functools import cached_property from inspect import getframeinfo, stack from pathlib import Path -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING, Optional from strawberry.exceptions.utils.source_finder import SourceFinder @@ -80,7 +80,7 @@ def exception_source(self) -> Optional[ExceptionSource]: class InvalidTypeForUnionMergeError(StrawberryException): """A specialized version of InvalidUnionTypeError for when trying to merge unions using the pipe operator.""" - invalid_type: Type + invalid_type: type def __init__(self, union: StrawberryUnion, other: object) -> None: self.union = union diff --git a/strawberry/exceptions/missing_arguments_annotations.py b/strawberry/exceptions/missing_arguments_annotations.py index 97610584fa..a5fd5cfd48 100644 --- a/strawberry/exceptions/missing_arguments_annotations.py +++ b/strawberry/exceptions/missing_arguments_annotations.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from .exception import StrawberryException from .utils.source_finder import SourceFinder @@ -18,7 +18,7 @@ class MissingArgumentsAnnotationsError(StrawberryException): def __init__( self, resolver: StrawberryResolver, - arguments: List[str], + arguments: list[str], ) -> None: self.missing_arguments = arguments self.function = resolver.wrapped_func diff --git a/strawberry/exceptions/missing_field_annotation.py b/strawberry/exceptions/missing_field_annotation.py index a417c15dd8..594c575393 100644 --- a/strawberry/exceptions/missing_field_annotation.py +++ b/strawberry/exceptions/missing_field_annotation.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING, Optional from .exception import StrawberryException from .utils.source_finder import SourceFinder @@ -11,7 +11,7 @@ class MissingFieldAnnotationError(StrawberryException): - def __init__(self, field_name: str, cls: Type) -> None: + def __init__(self, field_name: str, cls: type) -> None: self.cls = cls self.field_name = field_name diff --git a/strawberry/exceptions/object_is_not_an_enum.py b/strawberry/exceptions/object_is_not_an_enum.py index fe79fe2550..e3817dcb94 100644 --- a/strawberry/exceptions/object_is_not_an_enum.py +++ b/strawberry/exceptions/object_is_not_an_enum.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING, Optional from .exception import StrawberryException from .utils.source_finder import SourceFinder @@ -13,7 +13,7 @@ class ObjectIsNotAnEnumError(StrawberryException): - def __init__(self, cls: Type[Enum]) -> None: + def __init__(self, cls: type[Enum]) -> None: self.cls = cls self.message = ( "strawberry.enum can only be used with subclasses of Enum. " diff --git a/strawberry/exceptions/private_strawberry_field.py b/strawberry/exceptions/private_strawberry_field.py index c36a1f8993..918cb64223 100644 --- a/strawberry/exceptions/private_strawberry_field.py +++ b/strawberry/exceptions/private_strawberry_field.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING, Optional from .exception import StrawberryException from .utils.source_finder import SourceFinder @@ -11,7 +11,7 @@ class PrivateStrawberryFieldError(StrawberryException): - def __init__(self, field_name: str, cls: Type) -> None: + def __init__(self, field_name: str, cls: type) -> None: self.cls = cls self.field_name = field_name diff --git a/strawberry/exceptions/syntax.py b/strawberry/exceptions/syntax.py index ce460fdbbd..0403e07a2c 100644 --- a/strawberry/exceptions/syntax.py +++ b/strawberry/exceptions/syntax.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, Optional from pygments.lexers import PythonLexer from rich.segment import Segment @@ -14,10 +14,10 @@ class Syntax(RichSyntax): def __init__( self, code: str, - line_range: Tuple[int, int], - highlight_lines: Optional[Set[int]] = None, + line_range: tuple[int, int], + highlight_lines: Optional[set[int]] = None, line_offset: int = 0, - line_annotations: Optional[Dict[int, str]] = None, + line_annotations: Optional[dict[int, str]] = None, ) -> None: self.line_offset = line_offset self.line_annotations = line_annotations or {} diff --git a/strawberry/exceptions/utils/source_finder.py b/strawberry/exceptions/utils/source_finder.py index 9b2c8010d0..a24d4cc776 100644 --- a/strawberry/exceptions/utils/source_finder.py +++ b/strawberry/exceptions/utils/source_finder.py @@ -6,11 +6,12 @@ from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Type, cast +from typing import TYPE_CHECKING, Any, Callable, Optional, cast from ..exception_source import ExceptionSource if TYPE_CHECKING: + from collections.abc import Sequence from inspect import Traceback from libcst import BinaryOperation, Call, CSTNode, FunctionDef @@ -113,7 +114,7 @@ def _find_function_definition( ) def _find_class_definition( - self, source: SourcePath, cls: Type[Any] + self, source: SourcePath, cls: type[Any] ) -> Optional[CSTNode]: import libcst.matchers as m @@ -122,7 +123,7 @@ def _find_class_definition( class_defs = self._find(source.code, matcher) return self._find_definition_by_qualname(cls.__qualname__, class_defs) - def find_class(self, cls: Type[Any]) -> Optional[ExceptionSource]: + def find_class(self, cls: type[Any]) -> Optional[ExceptionSource]: source = self.find_source(cls.__module__) if source is None: @@ -147,7 +148,7 @@ def find_class(self, cls: Type[Any]) -> Optional[ExceptionSource]: ) def find_class_attribute( - self, cls: Type[Any], attribute_name: str + self, cls: type[Any], attribute_name: str ) -> Optional[ExceptionSource]: source = self.find_source(cls.__module__) @@ -560,11 +561,11 @@ def cst(self) -> Optional[LibCSTSourceFinder]: except ImportError: return None # pragma: no cover - def find_class_from_object(self, cls: Type[Any]) -> Optional[ExceptionSource]: + def find_class_from_object(self, cls: type[Any]) -> Optional[ExceptionSource]: return self.cst.find_class(cls) if self.cst else None def find_class_attribute_from_object( - self, cls: Type[Any], attribute_name: str + self, cls: type[Any], attribute_name: str ) -> Optional[ExceptionSource]: return self.cst.find_class_attribute(cls, attribute_name) if self.cst else None diff --git a/strawberry/experimental/pydantic/__init__.py b/strawberry/experimental/pydantic/__init__.py index 10f382650d..37b2d6878c 100644 --- a/strawberry/experimental/pydantic/__init__.py +++ b/strawberry/experimental/pydantic/__init__.py @@ -1,11 +1,11 @@ from .error_type import error_type from .exceptions import UnregisteredTypeException -from .object_type import input, interface, type +from .object_type import input, interface, type # noqa: A004 __all__ = [ - "error_type", "UnregisteredTypeException", + "error_type", "input", - "type", "interface", + "type", ] diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index 7a5f776a4d..b2166c2bf9 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from decimal import Decimal from functools import cached_property -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any, Callable, Optional from uuid import UUID import pydantic @@ -98,7 +98,7 @@ def has_default(self) -> bool: } -def get_fields_map_for_v2() -> Dict[Any, Any]: +def get_fields_map_for_v2() -> dict[Any, Any]: import pydantic_core fields_map = { @@ -124,7 +124,7 @@ def PYDANTIC_MISSING_TYPE(self) -> Any: return PydanticUndefined - def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField]: + def get_model_fields(self, model: type[BaseModel]) -> dict[str, CompatModelField]: field_info: dict[str, FieldInfo] = model.model_fields new_fields = {} # Convert it into CompatModelField @@ -147,10 +147,10 @@ def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField return new_fields @cached_property - def fields_map(self) -> Dict[Any, Any]: + def fields_map(self) -> dict[Any, Any]: return get_fields_map_for_v2() - def get_basic_type(self, type_: Any) -> Type[Any]: + def get_basic_type(self, type_: Any) -> type[Any]: if type_ in self.fields_map: type_ = self.fields_map[type_] @@ -162,7 +162,7 @@ def get_basic_type(self, type_: Any) -> Type[Any]: return type_ - def model_dump(self, model_instance: BaseModel) -> Dict[Any, Any]: + def model_dump(self, model_instance: BaseModel) -> dict[Any, Any]: return model_instance.model_dump() @@ -171,7 +171,7 @@ class PydanticV1Compat: def PYDANTIC_MISSING_TYPE(self) -> Any: return dataclasses.MISSING - def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField]: + def get_model_fields(self, model: type[BaseModel]) -> dict[str, CompatModelField]: new_fields = {} # Convert it into CompatModelField for name, field in model.__fields__.items(): # type: ignore[attr-defined] @@ -192,7 +192,7 @@ def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField return new_fields @cached_property - def fields_map(self) -> Dict[Any, Any]: + def fields_map(self) -> dict[Any, Any]: if IS_PYDANTIC_V2: return { getattr(pydantic.v1, field_name): type @@ -206,7 +206,7 @@ def fields_map(self) -> Dict[Any, Any]: if hasattr(pydantic, field_name) } - def get_basic_type(self, type_: Any) -> Type[Any]: + def get_basic_type(self, type_: Any) -> type[Any]: if IS_PYDANTIC_V1: ConstrainedInt = pydantic.ConstrainedInt ConstrainedFloat = pydantic.ConstrainedFloat @@ -225,7 +225,7 @@ def get_basic_type(self, type_: Any) -> Type[Any]: if lenient_issubclass(type_, ConstrainedStr): return str if lenient_issubclass(type_, ConstrainedList): - return List[self.get_basic_type(type_.item_type)] # type: ignore + return list[self.get_basic_type(type_.item_type)] # type: ignore if type_ in self.fields_map: type_ = self.fields_map[type_] @@ -238,7 +238,7 @@ def get_basic_type(self, type_: Any) -> Type[Any]: return type_ - def model_dump(self, model_instance: BaseModel) -> Dict[Any, Any]: + def model_dump(self, model_instance: BaseModel) -> dict[Any, Any]: return model_instance.dict() @@ -250,7 +250,7 @@ def __init__(self, is_v2: bool) -> None: self._compat = PydanticV1Compat() # type: ignore[assignment] @classmethod - def from_model(cls, model: Type[BaseModel]) -> "PydanticCompat": + def from_model(cls, model: type[BaseModel]) -> "PydanticCompat": if hasattr(model, "model_fields"): return cls(is_v2=True) @@ -283,10 +283,10 @@ def new_type_supertype(type_: Any) -> Any: __all__ = [ "PydanticCompat", + "get_args", + "get_origin", "is_new_type", "lenient_issubclass", - "get_origin", - "get_args", "new_type_supertype", "smart_deepcopy", ] diff --git a/strawberry/experimental/pydantic/conversion.py b/strawberry/experimental/pydantic/conversion.py index ae3d2aaddb..5296f60acc 100644 --- a/strawberry/experimental/pydantic/conversion.py +++ b/strawberry/experimental/pydantic/conversion.py @@ -2,7 +2,7 @@ import copy import dataclasses -from typing import TYPE_CHECKING, Any, Type, Union, cast +from typing import TYPE_CHECKING, Any, Union, cast from strawberry.types.base import ( StrawberryList, @@ -98,7 +98,7 @@ def convert_pydantic_model_to_strawberry_class( return cls(**kwargs) -def convert_strawberry_class_to_pydantic_model(obj: Type) -> Any: +def convert_strawberry_class_to_pydantic_model(obj: type) -> Any: if hasattr(obj, "to_pydantic"): return obj.to_pydantic() elif dataclasses.is_dataclass(obj): diff --git a/strawberry/experimental/pydantic/conversion_types.py b/strawberry/experimental/pydantic/conversion_types.py index 0e24ba2f00..747c67f351 100644 --- a/strawberry/experimental/pydantic/conversion_types.py +++ b/strawberry/experimental/pydantic/conversion_types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar +from typing import TYPE_CHECKING, Any, Optional, TypeVar from typing_extensions import Protocol from pydantic import BaseModel @@ -22,7 +22,7 @@ def __init__(self, **kwargs: Any) -> None: ... @staticmethod def from_pydantic( - instance: PydanticModel, extra: Optional[Dict[str, Any]] = None + instance: PydanticModel, extra: Optional[dict[str, Any]] = None ) -> StrawberryTypeFromPydantic[PydanticModel]: ... def to_pydantic(self, **kwargs: Any) -> PydanticModel: ... @@ -31,4 +31,4 @@ def to_pydantic(self, **kwargs: Any) -> PydanticModel: ... def __strawberry_definition__(self) -> StrawberryObjectDefinition: ... @property - def _pydantic_type(self) -> Type[PydanticModel]: ... + def _pydantic_type(self) -> type[PydanticModel]: ... diff --git a/strawberry/experimental/pydantic/error_type.py b/strawberry/experimental/pydantic/error_type.py index bbedfe610b..1f87778697 100644 --- a/strawberry/experimental/pydantic/error_type.py +++ b/strawberry/experimental/pydantic/error_type.py @@ -3,13 +3,10 @@ import dataclasses import warnings from typing import ( + TYPE_CHECKING, Any, Callable, - List, Optional, - Sequence, - Tuple, - Type, Union, cast, ) @@ -33,14 +30,19 @@ from .exceptions import MissingFieldsListError +if TYPE_CHECKING: + from collections.abc import Sequence -def get_type_for_field(field: CompatModelField) -> Union[Any, Type[None], Type[List]]: + from strawberry.types.base import WithStrawberryObjectDefinition + + +def get_type_for_field(field: CompatModelField) -> Union[type[Union[None, list]], Any]: type_ = field.outer_type_ type_ = normalize_type(type_) return field_type_to_type(type_) -def field_type_to_type(type_: Type) -> Union[Any, List[Any], None]: +def field_type_to_type(type_: type) -> Union[Any, list[Any], None]: error_class: Any = str strawberry_type: Any = error_class @@ -52,26 +54,26 @@ def field_type_to_type(type_: Type) -> Union[Any, List[Any], None]: elif lenient_issubclass(child_type, BaseModel): strawberry_type = get_strawberry_type_from_model(child_type) else: - strawberry_type = List[error_class] + strawberry_type = list[error_class] strawberry_type = Optional[strawberry_type] elif lenient_issubclass(type_, BaseModel): strawberry_type = get_strawberry_type_from_model(type_) return Optional[strawberry_type] - return Optional[List[strawberry_type]] + return Optional[list[strawberry_type]] def error_type( - model: Type[BaseModel], + model: type[BaseModel], *, - fields: Optional[List[str]] = None, + fields: Optional[list[str]] = None, name: Optional[str] = None, description: Optional[str] = None, directives: Optional[Sequence[object]] = (), all_fields: bool = False, -) -> Callable[..., Type]: - def wrap(cls: Type) -> Type: +) -> Callable[..., type]: + def wrap(cls: type) -> type: compat = PydanticCompat.from_model(model) model_fields = compat.get_model_fields(model) fields_set = set(fields) if fields else set() @@ -103,7 +105,7 @@ def wrap(cls: Type) -> Type: if not fields_set: raise MissingFieldsListError(cls) - all_model_fields: List[Tuple[str, Any, dataclasses.Field]] = [ + all_model_fields: list[tuple[str, Any, dataclasses.Field]] = [ ( name, get_type_for_field(field), @@ -113,8 +115,8 @@ def wrap(cls: Type) -> Type: if name in fields_set ] - wrapped = _wrap_dataclass(cls) - extra_fields = cast(List[dataclasses.Field], _get_fields(wrapped, {})) + wrapped: type[WithStrawberryObjectDefinition] = _wrap_dataclass(cls) + extra_fields = cast(list[dataclasses.Field], _get_fields(wrapped, {})) private_fields = get_private_fields(wrapped) all_model_fields.extend( @@ -146,7 +148,7 @@ def wrap(cls: Type) -> Type: ) model._strawberry_type = cls # type: ignore[attr-defined] - cls._pydantic_type = model + cls._pydantic_type = model # type: ignore[attr-defined] return cls return wrap diff --git a/strawberry/experimental/pydantic/exceptions.py b/strawberry/experimental/pydantic/exceptions.py index 10b7999fa9..9c54cffc87 100644 --- a/strawberry/experimental/pydantic/exceptions.py +++ b/strawberry/experimental/pydantic/exceptions.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, List, Type +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from pydantic import BaseModel @@ -8,7 +8,7 @@ class MissingFieldsListError(Exception): - def __init__(self, type: Type[BaseModel]) -> None: + def __init__(self, type: type[BaseModel]) -> None: message = ( f"List of fields to copy from {type} is empty. Add fields with the " f"`auto` type annotation" @@ -22,7 +22,7 @@ class UnsupportedTypeError(Exception): class UnregisteredTypeException(Exception): - def __init__(self, type: Type[BaseModel]) -> None: + def __init__(self, type: type[BaseModel]) -> None: message = ( f"Cannot find a Strawberry Type for {type} did you forget to register it?" ) @@ -43,9 +43,9 @@ def __init__(self, default: Any, default_factory: NoArgAnyCallable) -> None: class AutoFieldsNotInBaseModelError(Exception): def __init__( self, - fields: List[str], + fields: list[str], cls_name: str, - model: Type[BaseModel], + model: type[BaseModel], ) -> None: message = ( f"{cls_name} defines {fields} with strawberry.auto. " diff --git a/strawberry/experimental/pydantic/fields.py b/strawberry/experimental/pydantic/fields.py index 9cac486290..fe6b863431 100644 --- a/strawberry/experimental/pydantic/fields.py +++ b/strawberry/experimental/pydantic/fields.py @@ -1,6 +1,5 @@ import builtins -from typing import Any, Union -from typing_extensions import Annotated +from typing import Annotated, Any, Union from pydantic import BaseModel @@ -25,17 +24,7 @@ else: raise -try: - from typing import GenericAlias as TypingGenericAlias # type: ignore -except ImportError: - import sys - - # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on) - # we do this under a conditional to avoid a mypy :) - if sys.version_info < (3, 9): - TypingGenericAlias = () - else: - raise +from typing import GenericAlias as TypingGenericAlias # type: ignore def replace_pydantic_types(type_: Any, is_input: bool) -> Any: diff --git a/strawberry/experimental/pydantic/object_type.py b/strawberry/experimental/pydantic/object_type.py index c9e1bb1161..caf8571b87 100644 --- a/strawberry/experimental/pydantic/object_type.py +++ b/strawberry/experimental/pydantic/object_type.py @@ -7,12 +7,7 @@ TYPE_CHECKING, Any, Callable, - Dict, - List, Optional, - Sequence, - Set, - Type, cast, ) @@ -39,6 +34,9 @@ from strawberry.types.type_resolver import _get_fields if TYPE_CHECKING: + import builtins + from collections.abc import Sequence + from graphql import GraphQLResolveInfo @@ -59,8 +57,8 @@ def get_type_for_field(field: CompatModelField, is_input: bool, compat: Pydantic def _build_dataclass_creation_fields( field: CompatModelField, is_input: bool, - existing_fields: Dict[str, StrawberryField], - auto_fields_set: Set[str], + existing_fields: dict[str, StrawberryField], + auto_fields_set: set[str], use_pydantic_alias: bool, compat: PydanticCompat, ) -> DataclassCreationFields: @@ -118,9 +116,9 @@ def _build_dataclass_creation_fields( def type( - model: Type[PydanticModel], + model: builtins.type[PydanticModel], *, - fields: Optional[List[str]] = None, + fields: Optional[list[str]] = None, name: Optional[str] = None, is_input: bool = False, is_interface: bool = False, @@ -128,8 +126,8 @@ def type( directives: Optional[Sequence[object]] = (), all_fields: bool = False, use_pydantic_alias: bool = True, -) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: - def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: +) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]: + def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]: compat = PydanticCompat.from_model(model) model_fields = compat.get_model_fields(model) original_fields_set = set(fields) if fields else set() @@ -177,12 +175,12 @@ def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: wrapped = _wrap_dataclass(cls) extra_strawberry_fields = _get_fields(wrapped, {}) - extra_fields = cast(List[dataclasses.Field], extra_strawberry_fields) + extra_fields = cast(list[dataclasses.Field], extra_strawberry_fields) private_fields = get_private_fields(wrapped) extra_fields_dict = {field.name: field for field in extra_strawberry_fields} - all_model_fields: List[DataclassCreationFields] = [ + all_model_fields: list[DataclassCreationFields] = [ _build_dataclass_creation_fields( field, is_input, @@ -208,7 +206,7 @@ def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: # Implicitly define `is_type_of` to support interfaces/unions that use # pydantic objects (not the corresponding strawberry type) @classmethod # type: ignore - def is_type_of(cls: Type, obj: Any, _info: GraphQLResolveInfo) -> bool: + def is_type_of(cls: builtins.type, obj: Any, _info: GraphQLResolveInfo) -> bool: return isinstance(obj, (cls, model)) namespace = {"is_type_of": is_type_of} @@ -232,7 +230,7 @@ def is_type_of(cls: Type, obj: Any, _info: GraphQLResolveInfo) -> bool: if hasattr(cls, "resolve_reference"): namespace["resolve_reference"] = cls.resolve_reference - kwargs: Dict[str, object] = {} + kwargs: dict[str, object] = {} # Python 3.10.1 introduces the kw_only param to `make_dataclass`. # If we're on an older version then generate our own custom init function @@ -273,7 +271,7 @@ def is_type_of(cls: Type, obj: Any, _info: GraphQLResolveInfo) -> bool: cls._pydantic_type = model def from_pydantic_default( - instance: PydanticModel, extra: Optional[Dict[str, Any]] = None + instance: PydanticModel, extra: Optional[dict[str, Any]] = None ) -> StrawberryTypeFromPydantic[PydanticModel]: ret = convert_pydantic_model_to_strawberry_class( cls=cls, model_instance=instance, extra=extra @@ -302,16 +300,16 @@ def to_pydantic_default(self: Any, **kwargs: Any) -> PydanticModel: def input( - model: Type[PydanticModel], + model: builtins.type[PydanticModel], *, - fields: Optional[List[str]] = None, + fields: Optional[list[str]] = None, name: Optional[str] = None, is_interface: bool = False, description: Optional[str] = None, directives: Optional[Sequence[object]] = (), all_fields: bool = False, use_pydantic_alias: bool = True, -) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: +) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]: """Convenience decorator for creating an input type from a Pydantic model. Equal to `partial(type, is_input=True)` @@ -332,16 +330,16 @@ def input( def interface( - model: Type[PydanticModel], + model: builtins.type[PydanticModel], *, - fields: Optional[List[str]] = None, + fields: Optional[list[str]] = None, name: Optional[str] = None, is_input: bool = False, description: Optional[str] = None, directives: Optional[Sequence[object]] = (), all_fields: bool = False, use_pydantic_alias: bool = True, -) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: +) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]: """Convenience decorator for creating an interface type from a Pydantic model. Equal to `partial(type, is_interface=True)` diff --git a/strawberry/experimental/pydantic/utils.py b/strawberry/experimental/pydantic/utils.py index 912553fb98..acc9eba635 100644 --- a/strawberry/experimental/pydantic/utils.py +++ b/strawberry/experimental/pydantic/utils.py @@ -4,11 +4,7 @@ from typing import ( TYPE_CHECKING, Any, - List, NamedTuple, - Set, - Tuple, - Type, Union, cast, ) @@ -38,9 +34,9 @@ from pydantic.typing import NoArgAnyCallable -def normalize_type(type_: Type) -> Any: +def normalize_type(type_: type) -> Any: if is_list(type_): - return List[normalize_type(get_list_annotation(type_))] # type: ignore + return list[normalize_type(get_list_annotation(type_))] # type: ignore if is_optional(type_): return get_optional_annotation(type_) @@ -55,7 +51,7 @@ def get_strawberry_type_from_model(type_: Any) -> Any: raise UnregisteredTypeException(type_) -def get_private_fields(cls: Type) -> List[dataclasses.Field]: +def get_private_fields(cls: type) -> list[dataclasses.Field]: return [field for field in dataclasses.fields(cls) if is_private(field.type)] @@ -63,10 +59,10 @@ class DataclassCreationFields(NamedTuple): """Fields required for the fields parameter of make_dataclass.""" name: str - field_type: Type + field_type: type field: dataclasses.Field - def to_tuple(self) -> Tuple[str, Type, dataclasses.Field]: + def to_tuple(self) -> tuple[str, type, dataclasses.Field]: # fields parameter wants (name, type, Field) return self.name, self.field_type, self.field @@ -125,7 +121,7 @@ def get_default_factory_for_field( def ensure_all_auto_fields_in_pydantic( - model: Type[BaseModel], auto_fields: Set[str], cls_name: str + model: type[BaseModel], auto_fields: set[str], cls_name: str ) -> None: compat = PydanticCompat.from_model(model) # Raise error if user defined a strawberry.auto field not present in the model diff --git a/strawberry/ext/dataclasses/dataclasses.py b/strawberry/ext/dataclasses/dataclasses.py index 820649c18d..7c4bb89b80 100644 --- a/strawberry/ext/dataclasses/dataclasses.py +++ b/strawberry/ext/dataclasses/dataclasses.py @@ -11,15 +11,15 @@ _field_init, _init_param, ) -from typing import Any, Dict, List +from typing import Any def dataclass_init_fn( - fields: List[Any], + fields: list[Any], frozen: bool, has_post_init: bool, self_name: str, - globals_: Dict[str, Any], + globals_: dict[str, Any], ) -> Any: """Create an __init__ function for a dataclass. diff --git a/strawberry/ext/mypy_plugin.py b/strawberry/ext/mypy_plugin.py index 77a8a406f0..797428019f 100644 --- a/strawberry/ext/mypy_plugin.py +++ b/strawberry/ext/mypy_plugin.py @@ -8,10 +8,7 @@ TYPE_CHECKING, Any, Callable, - List, Optional, - Set, - Tuple, Union, cast, ) @@ -60,7 +57,7 @@ except ImportError: TypeVarDef = TypeVarType -PYDANTIC_VERSION: Optional[Tuple[int, ...]] = None +PYDANTIC_VERSION: Optional[tuple[int, ...]] = None # To be compatible with user who don't use pydantic try: @@ -326,7 +323,7 @@ def add_static_method_to_class( api: Union[SemanticAnalyzerPluginInterface, CheckerPluginInterface], cls: ClassDef, name: str, - args: List[Argument], + args: list[Argument], return_type: Type, tvar_def: Optional[TypeVarType] = None, ) -> None: @@ -410,7 +407,7 @@ def strawberry_pydantic_class_callback(ctx: ClassDefContext) -> None: model_type = cast(Instance, _get_type_for_expr(model_expression, ctx.api)) # these are the fields that the user added to the strawberry type - new_strawberry_fields: Set[str] = set() + new_strawberry_fields: set[str] = set() # TODO: think about inheritance for strawberry? for stmt in ctx.cls.defs.body: @@ -418,7 +415,7 @@ def strawberry_pydantic_class_callback(ctx: ClassDefContext) -> None: lhs = cast(NameExpr, stmt.lvalues[0]) new_strawberry_fields.add(lhs.name) - pydantic_fields: Set[PydanticModelField] = set() + pydantic_fields: set[PydanticModelField] = set() try: fields = model_type.type.metadata[PYDANTIC_METADATA_KEY]["fields"] for data in fields.items(): @@ -438,7 +435,7 @@ def strawberry_pydantic_class_callback(ctx: ClassDefContext) -> None: ctx.reason, ) - potentially_missing_fields: Set[PydanticModelField] = { + potentially_missing_fields: set[PydanticModelField] = { f for f in pydantic_fields if f.name not in new_strawberry_fields } @@ -449,7 +446,7 @@ def strawberry_pydantic_class_callback(ctx: ClassDefContext) -> None: This means that the user is using all_fields=True """ is_all_fields: bool = len(potentially_missing_fields) == len(pydantic_fields) - missing_pydantic_fields: Set[PydanticModelField] = ( + missing_pydantic_fields: set[PydanticModelField] = ( potentially_missing_fields if not is_all_fields else set() ) diff --git a/strawberry/extensions/__init__.py b/strawberry/extensions/__init__.py index a94f69651b..e77c8e1590 100644 --- a/strawberry/extensions/__init__.py +++ b/strawberry/extensions/__init__.py @@ -1,5 +1,4 @@ import warnings -from typing import Type from .add_validation_rules import AddValidationRules from .base_extension import LifecycleStep, SchemaExtension @@ -13,7 +12,7 @@ from .validation_cache import ValidationCache -def __getattr__(name: str) -> Type[SchemaExtension]: +def __getattr__(name: str) -> type[SchemaExtension]: if name == "Extension": warnings.warn( ( @@ -29,16 +28,16 @@ def __getattr__(name: str) -> Type[SchemaExtension]: __all__ = [ - "FieldExtension", - "SchemaExtension", - "LifecycleStep", "AddValidationRules", "DisableValidation", - "ParserCache", - "QueryDepthLimiter", + "FieldExtension", "IgnoreContext", - "ValidationCache", + "LifecycleStep", "MaskErrors", "MaxAliasesLimiter", "MaxTokensLimiter", + "ParserCache", + "QueryDepthLimiter", + "SchemaExtension", + "ValidationCache", ] diff --git a/strawberry/extensions/add_validation_rules.py b/strawberry/extensions/add_validation_rules.py index 763ef70b05..51ea96bf78 100644 --- a/strawberry/extensions/add_validation_rules.py +++ b/strawberry/extensions/add_validation_rules.py @@ -1,10 +1,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Iterator, List, Type +from typing import TYPE_CHECKING from strawberry.extensions.base_extension import SchemaExtension if TYPE_CHECKING: + from collections.abc import Iterator + from graphql import ASTValidationRule @@ -37,9 +39,9 @@ def enter_field(self, node, *args) -> None: ``` """ - validation_rules: List[Type[ASTValidationRule]] + validation_rules: list[type[ASTValidationRule]] - def __init__(self, validation_rules: List[Type[ASTValidationRule]]) -> None: + def __init__(self, validation_rules: list[type[ASTValidationRule]]) -> None: self.validation_rules = validation_rules def on_operation(self) -> Iterator[None]: diff --git a/strawberry/extensions/base_extension.py b/strawberry/extensions/base_extension.py index ff8d75d7ce..d3279d53e9 100644 --- a/strawberry/extensions/base_extension.py +++ b/strawberry/extensions/base_extension.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Dict, Set +from typing import TYPE_CHECKING, Any, Callable from strawberry.utils.await_maybe import AsyncIteratorOrIterator, AwaitableOrValue @@ -60,7 +60,7 @@ def resolve( ) -> AwaitableOrValue[object]: return _next(root, info, *args, **kwargs) - def get_results(self) -> AwaitableOrValue[Dict[str, Any]]: + def get_results(self) -> AwaitableOrValue[dict[str, Any]]: return {} @classmethod @@ -71,11 +71,11 @@ def _implements_resolve(cls) -> bool: Hook = Callable[[SchemaExtension], AsyncIteratorOrIterator[None]] -HOOK_METHODS: Set[str] = { +HOOK_METHODS: set[str] = { SchemaExtension.on_operation.__name__, SchemaExtension.on_validate.__name__, SchemaExtension.on_parse.__name__, SchemaExtension.on_execute.__name__, } -__all__ = ["SchemaExtension", "Hook", "HOOK_METHODS", "LifecycleStep"] +__all__ = ["HOOK_METHODS", "Hook", "LifecycleStep", "SchemaExtension"] diff --git a/strawberry/extensions/context.py b/strawberry/extensions/context.py index f9a231fed4..040ce83143 100644 --- a/strawberry/extensions/context.py +++ b/strawberry/extensions/context.py @@ -8,15 +8,9 @@ from typing import ( TYPE_CHECKING, Any, - AsyncContextManager, - AsyncIterator, Callable, - ContextManager, - Iterator, - List, NamedTuple, Optional, - Type, Union, ) @@ -24,6 +18,7 @@ from strawberry.utils.await_maybe import AwaitableOrValue, await_maybe if TYPE_CHECKING: + from collections.abc import AsyncIterator, Iterator from types import TracebackType from strawberry.extensions.base_extension import Hook @@ -31,17 +26,23 @@ class WrappedHook(NamedTuple): extension: SchemaExtension - hook: Callable[..., Union[AsyncContextManager[None], ContextManager[None]]] + hook: Callable[ + ..., + Union[ + contextlib.AbstractAsyncContextManager[None], + contextlib.AbstractContextManager[None], + ], + ] is_async: bool class ExtensionContextManagerBase: __slots__ = ( - "hooks", - "deprecation_message", - "default_hook", "async_exit_stack", + "default_hook", + "deprecation_message", "exit_stack", + "hooks", ) def __init_subclass__(cls) -> None: @@ -56,8 +57,8 @@ def __init_subclass__(cls) -> None: LEGACY_ENTER: str LEGACY_EXIT: str - def __init__(self, extensions: List[SchemaExtension]) -> None: - self.hooks: List[WrappedHook] = [] + def __init__(self, extensions: list[SchemaExtension]) -> None: + self.hooks: list[WrappedHook] = [] self.default_hook: Hook = getattr(SchemaExtension, self.HOOK_NAME) for extension in extensions: hook = self.get_hook(extension) @@ -179,7 +180,7 @@ def __enter__(self) -> None: def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: @@ -198,7 +199,7 @@ async def __aenter__(self) -> None: async def __aexit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: diff --git a/strawberry/extensions/directives.py b/strawberry/extensions/directives.py index 82f9efe146..cee1889189 100644 --- a/strawberry/extensions/directives.py +++ b/strawberry/extensions/directives.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple +from typing import TYPE_CHECKING, Any, Callable from strawberry.extensions import SchemaExtension from strawberry.types.nodes import convert_arguments @@ -62,7 +62,7 @@ def process_directive( directive: DirectiveNode, value: Any, info: GraphQLResolveInfo, -) -> Tuple[StrawberryDirective, Dict[str, Any]]: +) -> tuple[StrawberryDirective, dict[str, Any]]: """Get a `StrawberryDirective` from ``directive` and prepare its arguments.""" directive_name = directive.name.value schema: Schema = info.schema._strawberry_schema # type: ignore diff --git a/strawberry/extensions/disable_validation.py b/strawberry/extensions/disable_validation.py index cd9aeafaed..d4731c045c 100644 --- a/strawberry/extensions/disable_validation.py +++ b/strawberry/extensions/disable_validation.py @@ -1,4 +1,4 @@ -from typing import Iterator +from collections.abc import Iterator from strawberry.extensions.base_extension import SchemaExtension diff --git a/strawberry/extensions/field_extension.py b/strawberry/extensions/field_extension.py index afb70e53ab..6683247a98 100644 --- a/strawberry/extensions/field_extension.py +++ b/strawberry/extensions/field_extension.py @@ -1,8 +1,9 @@ from __future__ import annotations import itertools +from collections.abc import Awaitable from functools import cached_property -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Union +from typing import TYPE_CHECKING, Any, Callable, Union if TYPE_CHECKING: from typing_extensions import TypeAlias diff --git a/strawberry/extensions/mask_errors.py b/strawberry/extensions/mask_errors.py index 5cb0ffbaa7..8a6ec571e2 100644 --- a/strawberry/extensions/mask_errors.py +++ b/strawberry/extensions/mask_errors.py @@ -1,4 +1,5 @@ -from typing import Callable, Iterator, List +from collections.abc import Iterator +from typing import Callable from graphql.error import GraphQLError @@ -36,7 +37,7 @@ def on_operation(self) -> Iterator[None]: yield result = self.execution_context.result if result and result.errors: - processed_errors: List[GraphQLError] = [] + processed_errors: list[GraphQLError] = [] for error in result.errors: if self.should_mask_error(error): processed_errors.append(self.anonymise_error(error)) diff --git a/strawberry/extensions/max_aliases.py b/strawberry/extensions/max_aliases.py index b1d0aba268..c77d524222 100644 --- a/strawberry/extensions/max_aliases.py +++ b/strawberry/extensions/max_aliases.py @@ -1,4 +1,4 @@ -from typing import Type, Union +from typing import Union from graphql import ( ExecutableDefinitionNode, @@ -35,7 +35,7 @@ def __init__(self, max_alias_count: int) -> None: super().__init__([validator]) -def create_validator(max_alias_count: int) -> Type[ValidationRule]: +def create_validator(max_alias_count: int) -> type[ValidationRule]: """Create a validator that checks the number of aliases in a document. Args: diff --git a/strawberry/extensions/max_tokens.py b/strawberry/extensions/max_tokens.py index 60accd8a8a..7c54cbf264 100644 --- a/strawberry/extensions/max_tokens.py +++ b/strawberry/extensions/max_tokens.py @@ -1,4 +1,4 @@ -from typing import Iterator +from collections.abc import Iterator from strawberry.extensions.base_extension import SchemaExtension diff --git a/strawberry/extensions/parser_cache.py b/strawberry/extensions/parser_cache.py index 39b28c039b..fdfdd9b23b 100644 --- a/strawberry/extensions/parser_cache.py +++ b/strawberry/extensions/parser_cache.py @@ -1,5 +1,6 @@ +from collections.abc import Iterator from functools import lru_cache -from typing import Iterator, Optional +from typing import Optional from strawberry.extensions.base_extension import SchemaExtension from strawberry.schema.execute import parse_document diff --git a/strawberry/extensions/pyinstrument.py b/strawberry/extensions/pyinstrument.py index 53dd9fe66a..89b2486d7f 100644 --- a/strawberry/extensions/pyinstrument.py +++ b/strawberry/extensions/pyinstrument.py @@ -1,12 +1,15 @@ from __future__ import annotations from pathlib import Path -from typing import Iterator +from typing import TYPE_CHECKING from pyinstrument import Profiler from strawberry.extensions.base_extension import SchemaExtension +if TYPE_CHECKING: + from collections.abc import Iterator + class PyInstrument(SchemaExtension): """Extension to profile the execution time of resolvers using PyInstrument.""" diff --git a/strawberry/extensions/query_depth_limiter.py b/strawberry/extensions/query_depth_limiter.py index f23831a00b..ef801120ae 100644 --- a/strawberry/extensions/query_depth_limiter.py +++ b/strawberry/extensions/query_depth_limiter.py @@ -30,12 +30,9 @@ import re from dataclasses import dataclass from typing import ( + TYPE_CHECKING, Callable, - Dict, - Iterable, - List, Optional, - Type, Union, ) @@ -61,12 +58,15 @@ from strawberry.extensions import AddValidationRules from strawberry.extensions.utils import is_introspection_key +if TYPE_CHECKING: + from collections.abc import Iterable + IgnoreType = Union[Callable[[str], bool], re.Pattern, str] FieldArgumentType = Union[ - bool, int, float, str, List["FieldArgumentType"], Dict[str, "FieldArgumentType"] + bool, int, float, str, list["FieldArgumentType"], dict[str, "FieldArgumentType"] ] -FieldArgumentsType = Dict[str, FieldArgumentType] +FieldArgumentsType = dict[str, FieldArgumentType] @dataclass @@ -99,7 +99,7 @@ class QueryDepthLimiter(AddValidationRules): def __init__( self, max_depth: int, - callback: Optional[Callable[[Dict[str, int]], None]] = None, + callback: Optional[Callable[[dict[str, int]], None]] = None, should_ignore: Optional[ShouldIgnoreType] = None, ) -> None: """Initialize the QueryDepthLimiter. @@ -123,8 +123,8 @@ def __init__( def create_validator( max_depth: int, should_ignore: Optional[ShouldIgnoreType], - callback: Optional[Callable[[Dict[str, int]], None]] = None, -) -> Type[ValidationRule]: + callback: Optional[Callable[[dict[str, int]], None]] = None, +) -> type[ValidationRule]: class DepthLimitValidator(ValidationRule): def __init__(self, validation_context: ValidationContext) -> None: document = validation_context.document @@ -154,7 +154,7 @@ def __init__(self, validation_context: ValidationContext) -> None: def get_fragments( definitions: Iterable[DefinitionNode], -) -> Dict[str, FragmentDefinitionNode]: +) -> dict[str, FragmentDefinitionNode]: fragments = {} for definition in definitions: if isinstance(definition, FragmentDefinitionNode): @@ -167,7 +167,7 @@ def get_fragments( # We can basically treat those the same def get_queries_and_mutations( definitions: Iterable[DefinitionNode], -) -> Dict[str, OperationDefinitionNode]: +) -> dict[str, OperationDefinitionNode]: operations = {} for definition in definitions: @@ -214,7 +214,7 @@ def get_field_arguments( def determine_depth( node: Node, - fragments: Dict[str, FragmentDefinitionNode], + fragments: dict[str, FragmentDefinitionNode], depth_so_far: int, max_depth: int, context: ValidationContext, @@ -294,7 +294,7 @@ def determine_depth( raise TypeError(f"Depth crawler cannot handle: {node.kind}") # pragma: no cover -def is_ignored(node: FieldNode, ignore: Optional[List[IgnoreType]] = None) -> bool: +def is_ignored(node: FieldNode, ignore: Optional[list[IgnoreType]] = None) -> bool: if ignore is None: return False diff --git a/strawberry/extensions/runner.py b/strawberry/extensions/runner.py index 1e249fc1e8..3f307807f8 100644 --- a/strawberry/extensions/runner.py +++ b/strawberry/extensions/runner.py @@ -1,7 +1,7 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Optional from strawberry.extensions.context import ( ExecutingContextManager, @@ -18,12 +18,12 @@ class SchemaExtensionsRunner: - extensions: List[SchemaExtension] + extensions: list[SchemaExtension] def __init__( self, execution_context: ExecutionContext, - extensions: Optional[List[SchemaExtension]] = None, + extensions: Optional[list[SchemaExtension]] = None, ) -> None: self.execution_context = execution_context self.extensions = extensions or [] @@ -40,8 +40,8 @@ def parsing(self) -> ParsingContextManager: def executing(self) -> ExecutingContextManager: return ExecutingContextManager(self.extensions) - def get_extensions_results_sync(self) -> Dict[str, Any]: - data: Dict[str, Any] = {} + def get_extensions_results_sync(self) -> dict[str, Any]: + data: dict[str, Any] = {} for extension in self.extensions: if inspect.iscoroutinefunction(extension.get_results): msg = "Cannot use async extension hook during sync execution" @@ -50,8 +50,8 @@ def get_extensions_results_sync(self) -> Dict[str, Any]: return data - async def get_extensions_results(self, ctx: ExecutionContext) -> Dict[str, Any]: - data: Dict[str, Any] = {} + async def get_extensions_results(self, ctx: ExecutionContext) -> dict[str, Any]: + data: dict[str, Any] = {} for extension in self.extensions: data.update(await await_maybe(extension.get_results())) diff --git a/strawberry/extensions/tracing/apollo.py b/strawberry/extensions/tracing/apollo.py index 2245f54643..2cc9583e07 100644 --- a/strawberry/extensions/tracing/apollo.py +++ b/strawberry/extensions/tracing/apollo.py @@ -4,7 +4,7 @@ import time from datetime import datetime, timezone from inspect import isawaitable -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional from strawberry.extensions import SchemaExtension from strawberry.extensions.utils import get_path_from_info @@ -12,6 +12,8 @@ from .utils import should_skip_tracing if TYPE_CHECKING: + from collections.abc import Generator + from graphql import GraphQLResolveInfo DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" @@ -25,20 +27,20 @@ class ApolloStepStats: start_offset: int duration: int - def to_json(self) -> Dict[str, Any]: + def to_json(self) -> dict[str, Any]: return {"startOffset": self.start_offset, "duration": self.duration} @dataclasses.dataclass class ApolloResolverStats: - path: List[str] + path: list[str] parent_type: Any field_name: str return_type: Any start_offset: int duration: Optional[int] = None - def to_json(self) -> Dict[str, Any]: + def to_json(self) -> dict[str, Any]: return { "path": self.path, "field_name": self.field_name, @@ -51,9 +53,9 @@ def to_json(self) -> Dict[str, Any]: @dataclasses.dataclass class ApolloExecutionStats: - resolvers: List[ApolloResolverStats] + resolvers: list[ApolloResolverStats] - def to_json(self) -> Dict[str, Any]: + def to_json(self) -> dict[str, Any]: return {"resolvers": [resolver.to_json() for resolver in self.resolvers]} @@ -67,7 +69,7 @@ class ApolloTracingStats: parsing: ApolloStepStats version: int = 1 - def to_json(self) -> Dict[str, Any]: + def to_json(self) -> dict[str, Any]: return { "version": self.version, "startTime": self.start_time.strftime(DATETIME_FORMAT), @@ -81,7 +83,7 @@ def to_json(self) -> Dict[str, Any]: class ApolloTracingExtension(SchemaExtension): def __init__(self, execution_context: ExecutionContext) -> None: - self._resolver_stats: List[ApolloResolverStats] = [] + self._resolver_stats: list[ApolloResolverStats] = [] self.execution_context = execution_context def on_operation(self) -> Generator[None, None, None]: @@ -121,7 +123,7 @@ def stats(self) -> ApolloTracingStats: ), ) - def get_results(self) -> Dict[str, Dict[str, Any]]: + def get_results(self) -> dict[str, dict[str, Any]]: return {"tracing": self.stats.to_json()} async def resolve( diff --git a/strawberry/extensions/tracing/datadog.py b/strawberry/extensions/tracing/datadog.py index 2b8c676ca2..02a722f29f 100644 --- a/strawberry/extensions/tracing/datadog.py +++ b/strawberry/extensions/tracing/datadog.py @@ -3,7 +3,7 @@ import hashlib from functools import cached_property from inspect import isawaitable -from typing import TYPE_CHECKING, Any, Callable, Generator, Iterator, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional from ddtrace import Span, tracer @@ -11,6 +11,8 @@ from strawberry.extensions.tracing.utils import should_skip_tracing if TYPE_CHECKING: + from collections.abc import Generator, Iterator + from graphql import GraphQLResolveInfo from strawberry.types.execution import ExecutionContext diff --git a/strawberry/extensions/tracing/opentelemetry.py b/strawberry/extensions/tracing/opentelemetry.py index 133f3a373b..686d311c96 100644 --- a/strawberry/extensions/tracing/opentelemetry.py +++ b/strawberry/extensions/tracing/opentelemetry.py @@ -6,12 +6,7 @@ TYPE_CHECKING, Any, Callable, - Dict, - FrozenSet, - Generator, - Iterable, Optional, - Set, Union, ) @@ -24,6 +19,8 @@ from .utils import should_skip_tracing if TYPE_CHECKING: + from collections.abc import Generator, Iterable + from graphql import GraphQLResolveInfo from opentelemetry.trace import Span, Tracer @@ -32,12 +29,12 @@ DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" -ArgFilter = Callable[[Dict[str, Any], "GraphQLResolveInfo"], Dict[str, Any]] +ArgFilter = Callable[[dict[str, Any], "GraphQLResolveInfo"], dict[str, Any]] class OpenTelemetryExtension(SchemaExtension): _arg_filter: Optional[ArgFilter] - _span_holder: Dict[LifecycleStep, Span] + _span_holder: dict[LifecycleStep, Span] _tracer: Tracer def __init__( @@ -100,8 +97,8 @@ def on_parse(self) -> Generator[None, None, None]: self._span_holder[LifecycleStep.PARSE].end() def filter_resolver_args( - self, args: Dict[str, Any], info: GraphQLResolveInfo - ) -> Dict[str, Any]: + self, args: dict[str, Any], info: GraphQLResolveInfo + ) -> dict[str, Any]: if not self._arg_filter: return args return self._arg_filter(deepcopy(args), info) @@ -132,7 +129,7 @@ def convert_to_allowed_types(self, value: Any) -> Any: else: return str(value) - def convert_set_to_allowed_types(self, value: Union[Set, FrozenSet]) -> str: + def convert_set_to_allowed_types(self, value: Union[set, frozenset]) -> str: return ( "{" + ", ".join(str(self.convert_to_allowed_types(x)) for x in value) + "}" ) diff --git a/strawberry/extensions/utils.py b/strawberry/extensions/utils.py index f857db5eaf..e47ec681e2 100644 --- a/strawberry/extensions/utils.py +++ b/strawberry/extensions/utils.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Union +from typing import TYPE_CHECKING, Union if TYPE_CHECKING: from graphql import GraphQLResolveInfo @@ -26,7 +26,7 @@ def is_introspection_field(info: GraphQLResolveInfo) -> bool: return False -def get_path_from_info(info: GraphQLResolveInfo) -> List[str]: +def get_path_from_info(info: GraphQLResolveInfo) -> list[str]: path = info.path elements = [] @@ -37,4 +37,4 @@ def get_path_from_info(info: GraphQLResolveInfo) -> List[str]: return elements[::-1] -__all__ = ["is_introspection_key", "is_introspection_field", "get_path_from_info"] +__all__ = ["get_path_from_info", "is_introspection_field", "is_introspection_key"] diff --git a/strawberry/extensions/validation_cache.py b/strawberry/extensions/validation_cache.py index 6c9cc153c4..35f4baa02c 100644 --- a/strawberry/extensions/validation_cache.py +++ b/strawberry/extensions/validation_cache.py @@ -1,5 +1,6 @@ +from collections.abc import Iterator from functools import lru_cache -from typing import Iterator, Optional +from typing import Optional from strawberry.extensions.base_extension import SchemaExtension from strawberry.schema.execute import validate_document diff --git a/strawberry/fastapi/context.py b/strawberry/fastapi/context.py index d6af1642cd..1c79711698 100644 --- a/strawberry/fastapi/context.py +++ b/strawberry/fastapi/context.py @@ -1,13 +1,13 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from starlette.background import BackgroundTasks from starlette.requests import Request from starlette.responses import Response from starlette.websockets import WebSocket -CustomContext = Union["BaseContext", Dict[str, Any]] +CustomContext = Union["BaseContext", dict[str, Any]] MergedContext = Union[ - "BaseContext", Dict[str, Union[Any, BackgroundTasks, Request, Response, WebSocket]] + "BaseContext", dict[str, Union[Any, BackgroundTasks, Request, Response, WebSocket]] ] diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index 51ccd9731f..28d158bc4d 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -6,21 +6,15 @@ from typing import ( TYPE_CHECKING, Any, - AsyncIterator, - Awaitable, Callable, - Dict, - List, Optional, - Sequence, - Type, Union, cast, ) from typing_extensions import TypeGuard from starlette import status -from starlette.background import BackgroundTasks # noqa: TCH002 +from starlette.background import BackgroundTasks # noqa: TC002 from starlette.requests import HTTPConnection, Request from starlette.responses import ( HTMLResponse, @@ -44,6 +38,7 @@ from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL if TYPE_CHECKING: + from collections.abc import AsyncIterator, Awaitable, Sequence from enum import Enum from starlette.routing import BaseRoute @@ -138,16 +133,16 @@ def __init__( ), connection_init_wait_timeout: timedelta = timedelta(minutes=1), prefix: str = "", - tags: Optional[List[Union[str, Enum]]] = None, + tags: Optional[list[Union[str, Enum]]] = None, dependencies: Optional[Sequence[params.Depends]] = None, - default_response_class: Type[Response] = Default(JSONResponse), - responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, - callbacks: Optional[List[BaseRoute]] = None, - routes: Optional[List[BaseRoute]] = None, + default_response_class: type[Response] = Default(JSONResponse), + responses: Optional[dict[Union[int, str], dict[str, Any]]] = None, + callbacks: Optional[list[BaseRoute]] = None, + routes: Optional[list[BaseRoute]] = None, redirect_slashes: bool = True, default: Optional[ASGIApp] = None, dependency_overrides_provider: Optional[Any] = None, - route_class: Type[APIRoute] = APIRoute, + route_class: type[APIRoute] = APIRoute, on_startup: Optional[Sequence[Callable[[], Any]]] = None, on_shutdown: Optional[Sequence[Callable[[], Any]]] = None, lifespan: Optional[Lifespan[Any]] = None, @@ -297,7 +292,7 @@ async def create_streaming_response( request: Request, stream: Callable[[], AsyncIterator[str]], sub_response: Response, - headers: Dict[str, str], + headers: dict[str, str], ) -> Response: return StreamingResponse( stream(), diff --git a/strawberry/federation/__init__.py b/strawberry/federation/__init__.py index 1ba201045f..e398148ef0 100644 --- a/strawberry/federation/__init__.py +++ b/strawberry/federation/__init__.py @@ -2,24 +2,24 @@ from .enum import enum, enum_value from .field import field from .mutation import mutation -from .object_type import input, interface, interface_object, type +from .object_type import input, interface, interface_object, type # noqa: A004 from .scalar import scalar from .schema import Schema from .schema_directive import schema_directive from .union import union __all__ = [ + "Schema", "argument", "enum", "enum_value", "field", - "mutation", "input", "interface", "interface_object", - "type", + "mutation", "scalar", - "Schema", "schema_directive", + "type", "union", ] diff --git a/strawberry/federation/argument.py b/strawberry/federation/argument.py index 2268211a76..9c42fad6cc 100644 --- a/strawberry/federation/argument.py +++ b/strawberry/federation/argument.py @@ -1,4 +1,5 @@ -from typing import Iterable, Optional +from collections.abc import Iterable +from typing import Optional from strawberry.types.arguments import StrawberryArgumentAnnotation diff --git a/strawberry/federation/enum.py b/strawberry/federation/enum.py index 0fdebcfedb..cdf85575a7 100644 --- a/strawberry/federation/enum.py +++ b/strawberry/federation/enum.py @@ -4,8 +4,6 @@ TYPE_CHECKING, Any, Callable, - Iterable, - List, Optional, Union, overload, @@ -15,6 +13,8 @@ from strawberry.types.enum import enum_value as base_enum_value if TYPE_CHECKING: + from collections.abc import Iterable + from strawberry.enum import EnumType, EnumValueDefinition @@ -47,8 +47,8 @@ def enum( directives: Iterable[object] = (), authenticated: bool = False, inaccessible: bool = False, - policy: Optional[List[List[str]]] = None, - requires_scopes: Optional[List[List[str]]] = None, + policy: Optional[list[list[str]]] = None, + requires_scopes: Optional[list[list[str]]] = None, tags: Optional[Iterable[str]] = (), ) -> EnumType: ... @@ -62,8 +62,8 @@ def enum( directives: Iterable[object] = (), authenticated: bool = False, inaccessible: bool = False, - policy: Optional[List[List[str]]] = None, - requires_scopes: Optional[List[List[str]]] = None, + policy: Optional[list[list[str]]] = None, + requires_scopes: Optional[list[list[str]]] = None, tags: Optional[Iterable[str]] = (), ) -> Callable[[EnumType], EnumType]: ... @@ -76,8 +76,8 @@ def enum( directives=(), authenticated: bool = False, inaccessible: bool = False, - policy: Optional[List[List[str]]] = None, - requires_scopes: Optional[List[List[str]]] = None, + policy: Optional[list[list[str]]] = None, + requires_scopes: Optional[list[list[str]]] = None, tags: Optional[Iterable[str]] = (), ) -> Union[EnumType, Callable[[EnumType], EnumType]]: """Registers the enum in the GraphQL type system. diff --git a/strawberry/federation/field.py b/strawberry/federation/field.py index 35d4f34b47..c9789e650c 100644 --- a/strawberry/federation/field.py +++ b/strawberry/federation/field.py @@ -5,11 +5,7 @@ TYPE_CHECKING, Any, Callable, - Iterable, - List, Optional, - Sequence, - Type, TypeVar, Union, overload, @@ -19,6 +15,7 @@ from strawberry.types.unset import UNSET if TYPE_CHECKING: + from collections.abc import Iterable, Sequence from typing_extensions import Literal from strawberry.extensions.field_extension import FieldExtension @@ -40,20 +37,20 @@ def field( authenticated: bool = False, external: bool = False, inaccessible: bool = False, - policy: Optional[List[List[str]]] = None, - provides: Optional[List[str]] = None, + policy: Optional[list[list[str]]] = None, + provides: Optional[list[str]] = None, override: Optional[Union[Override, str]] = None, - requires: Optional[List[str]] = None, - requires_scopes: Optional[List[List[str]]] = None, + requires: Optional[list[str]] = None, + requires_scopes: Optional[list[list[str]]] = None, tags: Optional[Iterable[str]] = (), shareable: bool = False, init: Literal[False] = False, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = UNSET, default_factory: Union[Callable[..., object], object] = UNSET, directives: Sequence[object] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> T: ... @@ -67,20 +64,20 @@ def field( authenticated: bool = False, external: bool = False, inaccessible: bool = False, - policy: Optional[List[List[str]]] = None, - provides: Optional[List[str]] = None, + policy: Optional[list[list[str]]] = None, + provides: Optional[list[str]] = None, override: Optional[Union[Override, str]] = None, - requires: Optional[List[str]] = None, - requires_scopes: Optional[List[List[str]]] = None, + requires: Optional[list[str]] = None, + requires_scopes: Optional[list[list[str]]] = None, tags: Optional[Iterable[str]] = (), shareable: bool = False, init: Literal[True] = True, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = UNSET, default_factory: Union[Callable[..., object], object] = UNSET, directives: Sequence[object] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> Any: ... @@ -95,19 +92,19 @@ def field( authenticated: bool = False, external: bool = False, inaccessible: bool = False, - policy: Optional[List[List[str]]] = None, - provides: Optional[List[str]] = None, + policy: Optional[list[list[str]]] = None, + provides: Optional[list[str]] = None, override: Optional[Union[Override, str]] = None, - requires: Optional[List[str]] = None, - requires_scopes: Optional[List[List[str]]] = None, + requires: Optional[list[str]] = None, + requires_scopes: Optional[list[list[str]]] = None, tags: Optional[Iterable[str]] = (), shareable: bool = False, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = UNSET, default_factory: Union[Callable[..., object], object] = UNSET, directives: Sequence[object] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> StrawberryField: ... @@ -121,19 +118,19 @@ def field( authenticated: bool = False, external: bool = False, inaccessible: bool = False, - policy: Optional[List[List[str]]] = None, - provides: Optional[List[str]] = None, + policy: Optional[list[list[str]]] = None, + provides: Optional[list[str]] = None, override: Optional[Union[Override, str]] = None, - requires: Optional[List[str]] = None, - requires_scopes: Optional[List[List[str]]] = None, + requires: Optional[list[str]] = None, + requires_scopes: Optional[list[list[str]]] = None, tags: Optional[Iterable[str]] = (), shareable: bool = False, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, directives: Sequence[object] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, # This init parameter is used by PyRight to determine whether this field # is added in the constructor or not. It is not used to change diff --git a/strawberry/federation/object_type.py b/strawberry/federation/object_type.py index 7005302370..e31f3d2cae 100644 --- a/strawberry/federation/object_type.py +++ b/strawberry/federation/object_type.py @@ -1,11 +1,9 @@ +import builtins +from collections.abc import Iterable, Sequence from typing import ( TYPE_CHECKING, Callable, - Iterable, - List, Optional, - Sequence, - Type, TypeVar, Union, overload, @@ -23,7 +21,7 @@ from .schema_directives import Key -T = TypeVar("T", bound=Type) +T = TypeVar("T", bound=builtins.type) def _impl_type( @@ -38,8 +36,8 @@ def _impl_type( extend: bool = False, shareable: bool = False, inaccessible: bool = UNSET, - policy: Optional[List[List[str]]] = None, - requires_scopes: Optional[List[List[str]]] = None, + policy: Optional[list[list[str]]] = None, + requires_scopes: Optional[list[list[str]]] = None, tags: Iterable[str] = (), is_input: bool = False, is_interface: bool = False, @@ -115,8 +113,8 @@ def type( extend: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[List[List[str]]] = None, - requires_scopes: Optional[List[List[str]]] = None, + policy: Optional[list[list[str]]] = None, + requires_scopes: Optional[list[list[str]]] = None, shareable: bool = False, tags: Iterable[str] = (), ) -> T: ... @@ -137,8 +135,8 @@ def type( extend: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[List[List[str]]] = None, - requires_scopes: Optional[List[List[str]]] = None, + policy: Optional[list[list[str]]] = None, + requires_scopes: Optional[list[list[str]]] = None, shareable: bool = False, tags: Iterable[str] = (), ) -> Callable[[T], T]: ... @@ -154,8 +152,8 @@ def type( extend: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[List[List[str]]] = None, - requires_scopes: Optional[List[List[str]]] = None, + policy: Optional[list[list[str]]] = None, + requires_scopes: Optional[list[list[str]]] = None, shareable: bool = False, tags: Iterable[str] = (), ): @@ -247,8 +245,8 @@ def interface( authenticated: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[List[List[str]]] = None, - requires_scopes: Optional[List[List[str]]] = None, + policy: Optional[list[list[str]]] = None, + requires_scopes: Optional[list[list[str]]] = None, tags: Iterable[str] = (), ) -> T: ... @@ -267,8 +265,8 @@ def interface( authenticated: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[List[List[str]]] = None, - requires_scopes: Optional[List[List[str]]] = None, + policy: Optional[list[list[str]]] = None, + requires_scopes: Optional[list[list[str]]] = None, tags: Iterable[str] = (), ) -> Callable[[T], T]: ... @@ -282,8 +280,8 @@ def interface( authenticated: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[List[List[str]]] = None, - requires_scopes: Optional[List[List[str]]] = None, + policy: Optional[list[list[str]]] = None, + requires_scopes: Optional[list[list[str]]] = None, tags: Iterable[str] = (), ): return _impl_type( @@ -316,8 +314,8 @@ def interface_object( authenticated: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[List[List[str]]] = None, - requires_scopes: Optional[List[List[str]]] = None, + policy: Optional[list[list[str]]] = None, + requires_scopes: Optional[list[list[str]]] = None, tags: Iterable[str] = (), ) -> T: ... @@ -336,8 +334,8 @@ def interface_object( authenticated: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[List[List[str]]] = None, - requires_scopes: Optional[List[List[str]]] = None, + policy: Optional[list[list[str]]] = None, + requires_scopes: Optional[list[list[str]]] = None, tags: Iterable[str] = (), ) -> Callable[[T], T]: ... @@ -351,8 +349,8 @@ def interface_object( authenticated: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[List[List[str]]] = None, - requires_scopes: Optional[List[List[str]]] = None, + policy: Optional[list[list[str]]] = None, + requires_scopes: Optional[list[list[str]]] = None, tags: Iterable[str] = (), ): return _impl_type( @@ -371,4 +369,4 @@ def interface_object( ) -__all__ = ["type", "input", "interface", "interface_object"] +__all__ = ["input", "interface", "interface_object", "type"] diff --git a/strawberry/federation/scalar.py b/strawberry/federation/scalar.py index 9cb1183318..0bb5310ea3 100644 --- a/strawberry/federation/scalar.py +++ b/strawberry/federation/scalar.py @@ -1,9 +1,8 @@ import sys +from collections.abc import Iterable from typing import ( Any, Callable, - Iterable, - List, NewType, Optional, TypeVar, @@ -36,8 +35,8 @@ def scalar( directives: Iterable[object] = (), authenticated: bool = False, inaccessible: bool = False, - policy: Optional[List[List[str]]] = None, - requires_scopes: Optional[List[List[str]]] = None, + policy: Optional[list[list[str]]] = None, + requires_scopes: Optional[list[list[str]]] = None, tags: Optional[Iterable[str]] = (), ) -> Callable[[_T], _T]: ... @@ -55,8 +54,8 @@ def scalar( directives: Iterable[object] = (), authenticated: bool = False, inaccessible: bool = False, - policy: Optional[List[List[str]]] = None, - requires_scopes: Optional[List[List[str]]] = None, + policy: Optional[list[list[str]]] = None, + requires_scopes: Optional[list[list[str]]] = None, tags: Optional[Iterable[str]] = (), ) -> _T: ... @@ -73,8 +72,8 @@ def scalar( directives: Iterable[object] = (), authenticated: bool = False, inaccessible: bool = False, - policy: Optional[List[List[str]]] = None, - requires_scopes: Optional[List[List[str]]] = None, + policy: Optional[list[list[str]]] = None, + requires_scopes: Optional[list[list[str]]] = None, tags: Optional[Iterable[str]] = (), ) -> Any: """Annotates a class or type as a GraphQL custom scalar. diff --git a/strawberry/federation/schema.py b/strawberry/federation/schema.py index e9cbf1625e..a8acc0e289 100644 --- a/strawberry/federation/schema.py +++ b/strawberry/federation/schema.py @@ -1,18 +1,12 @@ from collections import defaultdict +from collections.abc import Iterable, Mapping from functools import cached_property from itertools import chain from typing import ( TYPE_CHECKING, Any, - DefaultDict, - Dict, - Iterable, - List, - Mapping, NewType, Optional, - Set, - Type, Union, cast, ) @@ -50,17 +44,17 @@ class Schema(BaseSchema): def __init__( self, - query: Optional[Type] = None, - mutation: Optional[Type] = None, - subscription: Optional[Type] = None, + query: Optional[type] = None, + mutation: Optional[type] = None, + subscription: Optional[type] = None, # TODO: we should update directives' type in the main schema - directives: Iterable[Type] = (), - types: Iterable[Type] = (), - extensions: Iterable[Union[Type["SchemaExtension"], "SchemaExtension"]] = (), - execution_context_class: Optional[Type["GraphQLExecutionContext"]] = None, + directives: Iterable[type] = (), + types: Iterable[type] = (), + extensions: Iterable[Union[type["SchemaExtension"], "SchemaExtension"]] = (), + execution_context_class: Optional[type["GraphQLExecutionContext"]] = None, config: Optional["StrawberryConfig"] = None, scalar_overrides: Optional[ - Dict[object, Union[Type, "ScalarWrapper", "ScalarDefinition"]] + dict[object, Union[type, "ScalarWrapper", "ScalarDefinition"]] ] = None, schema_directives: Iterable[object] = (), enable_federation_2: bool = False, @@ -91,11 +85,11 @@ def __init__( def _get_federation_query_type( self, - query: Optional[Type[WithStrawberryObjectDefinition]], - mutation: Optional[Type[WithStrawberryObjectDefinition]], - subscription: Optional[Type[WithStrawberryObjectDefinition]], - additional_types: Iterable[Type[WithStrawberryObjectDefinition]], - ) -> Type: + query: Optional[type[WithStrawberryObjectDefinition]], + mutation: Optional[type[WithStrawberryObjectDefinition]], + subscription: Optional[type[WithStrawberryObjectDefinition]], + additional_types: Iterable[type[WithStrawberryObjectDefinition]], + ) -> type: """Returns a new query type that includes the _service field. If the query type is provided, it will be used as the base for the new @@ -129,7 +123,7 @@ def service() -> Service: entity_type = _get_entity_type(query, mutation, subscription, additional_types) if entity_type: - self.entities_resolver.__annotations__["return"] = List[ + self.entities_resolver.__annotations__["return"] = list[ Optional[entity_type] # type: ignore ] @@ -156,8 +150,8 @@ def service() -> Service: return query_type def entities_resolver( - self, info: Info, representations: List[FederationAny] - ) -> List[FederationAny]: + self, info: Info, representations: list[FederationAny] + ) -> list[FederationAny]: results = [] for representation in representations: @@ -212,10 +206,10 @@ def _remove_resolvable_field(self) -> None: directive.resolvable = UNSET @cached_property - def schema_directives_in_use(self) -> List[object]: + def schema_directives_in_use(self) -> list[object]: all_graphql_types = self._schema.type_map.values() - directives: List[object] = [] + directives: list[object] = [] for type_ in all_graphql_types: strawberry_definition = type_.extensions.get("strawberry-definition") @@ -236,7 +230,7 @@ def schema_directives_in_use(self) -> List[object]: def _add_link_for_composed_directive( self, directive: "StrawberrySchemaDirective", - directive_by_url: Mapping[str, Set[str]], + directive_by_url: Mapping[str, set[str]], ) -> None: if not isinstance(directive, StrawberryFederationSchemaDirective): return @@ -256,11 +250,11 @@ def _add_link_for_composed_directive( directive_by_url[import_url].add(f"@{name}") def _add_link_directives( - self, additional_directives: Optional[List[object]] = None + self, additional_directives: Optional[list[object]] = None ) -> None: from .schema_directives import FederationDirective, Link - directive_by_url: DefaultDict[str, Set[str]] = defaultdict(set) + directive_by_url: defaultdict[str, set[str]] = defaultdict(set) additional_directives = additional_directives or [] @@ -274,7 +268,7 @@ def _add_link_directives( f"@{directive.imported_from.name}" ) - link_directives: List[object] = [ + link_directives: list[object] = [ Link( url=url, import_=list(sorted(directives)), @@ -284,10 +278,10 @@ def _add_link_directives( self.schema_directives = self.schema_directives + link_directives - def _add_compose_directives(self) -> List["ComposeDirective"]: + def _add_compose_directives(self) -> list["ComposeDirective"]: from .schema_directives import ComposeDirective - compose_directives: List[ComposeDirective] = [] + compose_directives: list[ComposeDirective] = [] for directive in self.schema_directives_in_use: definition = directive.__strawberry_directive__ # type: ignore @@ -318,10 +312,10 @@ def _warn_for_federation_directives(self) -> None: def _get_entity_type( - query: Optional[Type[WithStrawberryObjectDefinition]], - mutation: Optional[Type[WithStrawberryObjectDefinition]], - subscription: Optional[Type[WithStrawberryObjectDefinition]], - additional_types: Iterable[Type[WithStrawberryObjectDefinition]], + query: Optional[type[WithStrawberryObjectDefinition]], + mutation: Optional[type[WithStrawberryObjectDefinition]], + subscription: Optional[type[WithStrawberryObjectDefinition]], + additional_types: Iterable[type[WithStrawberryObjectDefinition]], ) -> Optional[StrawberryUnion]: # recursively iterate over the schema to find all types annotated with @key # if no types are annotated with @key, then the _Entity union and Query._entities @@ -330,7 +324,7 @@ def _get_entity_type( entity_types = set() # need a stack to keep track of the types we need to visit - stack: List[Any] = [query, mutation, subscription, *additional_types] + stack: list[Any] = [query, mutation, subscription, *additional_types] seen = set() diff --git a/strawberry/federation/schema_directive.py b/strawberry/federation/schema_directive.py index 06ffe76336..6ff9138e83 100644 --- a/strawberry/federation/schema_directive.py +++ b/strawberry/federation/schema_directive.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Callable, List, Optional, Type, TypeVar +from typing import Callable, Optional, TypeVar from typing_extensions import dataclass_transform from strawberry.directive import directive_field @@ -19,7 +19,7 @@ class StrawberryFederationSchemaDirective(StrawberrySchemaDirective): compose_options: Optional[ComposeOptions] = None -T = TypeVar("T", bound=Type) +T = TypeVar("T", bound=type) @dataclass_transform( @@ -29,19 +29,19 @@ class StrawberryFederationSchemaDirective(StrawberrySchemaDirective): ) def schema_directive( *, - locations: List[Location], + locations: list[Location], description: Optional[str] = None, name: Optional[str] = None, repeatable: bool = False, print_definition: bool = True, compose: bool = False, import_url: Optional[str] = None, -) -> Callable[..., T]: +) -> Callable[[T], T]: def _wrap(cls: T) -> T: cls = _wrap_dataclass(cls) # type: ignore fields = _get_fields(cls, {}) - cls.__strawberry_directive__ = StrawberryFederationSchemaDirective( + cls.__strawberry_directive__ = StrawberryFederationSchemaDirective( # type: ignore[attr-defined] python_name=cls.__name__, graphql_name=name, locations=locations, diff --git a/strawberry/federation/schema_directives.py b/strawberry/federation/schema_directives.py index 0f9232d7f3..c249f7d212 100644 --- a/strawberry/federation/schema_directives.py +++ b/strawberry/federation/schema_directives.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import ClassVar, List, Optional +from typing import ClassVar, Optional from strawberry import directive_field from strawberry.schema_directive import Location, schema_directive @@ -84,14 +84,14 @@ class Link: url: Optional[str] as_: Optional[str] = directive_field(name="as") for_: Optional[LinkPurpose] = directive_field(name="for") - import_: Optional[List[Optional[LinkImport]]] = directive_field(name="import") + import_: Optional[list[Optional[LinkImport]]] = directive_field(name="import") def __init__( self, url: Optional[str] = UNSET, as_: Optional[str] = UNSET, for_: Optional[LinkPurpose] = UNSET, - import_: Optional[List[Optional[LinkImport]]] = UNSET, + import_: Optional[list[Optional[LinkImport]]] = UNSET, ) -> None: self.url = url self.as_ = as_ @@ -204,7 +204,7 @@ class Authenticated(FederationDirective): print_definition=False, ) class RequiresScopes(FederationDirective): - scopes: "List[List[str]]" + scopes: "list[list[str]]" imported_from: ClassVar[ImportedFrom] = ImportedFrom( name="requiresScopes", url="https://specs.apollo.dev/federation/v2.7" ) @@ -222,25 +222,25 @@ class RequiresScopes(FederationDirective): print_definition=False, ) class Policy(FederationDirective): - policies: "List[List[str]]" + policies: "list[list[str]]" imported_from: ClassVar[ImportedFrom] = ImportedFrom( name="policy", url="https://specs.apollo.dev/federation/v2.7" ) __all__ = [ + "Authenticated", + "ComposeDirective", "External", - "Requires", - "Provides", + "Inaccessible", + "InterfaceObject", "Key", - "Shareable", "Link", - "Tag", "Override", - "Inaccessible", - "ComposeDirective", - "InterfaceObject", - "Authenticated", - "RequiresScopes", "Policy", + "Provides", + "Requires", + "RequiresScopes", + "Shareable", + "Tag", ] diff --git a/strawberry/federation/union.py b/strawberry/federation/union.py index c61e4f5578..b3d1c8a2ce 100644 --- a/strawberry/federation/union.py +++ b/strawberry/federation/union.py @@ -1,4 +1,5 @@ -from typing import Any, Collection, Iterable, Optional, Type +from collections.abc import Collection, Iterable +from typing import Any, Optional from strawberry.types.union import StrawberryUnion from strawberry.types.union import union as base_union @@ -6,7 +7,7 @@ def union( name: str, - types: Optional[Collection[Type[Any]]] = None, + types: Optional[Collection[type[Any]]] = None, *, description: Optional[str] = None, directives: Iterable[object] = (), diff --git a/strawberry/field_extensions/input_mutation.py b/strawberry/field_extensions/input_mutation.py index f251d016da..6625ecf005 100644 --- a/strawberry/field_extensions/input_mutation.py +++ b/strawberry/field_extensions/input_mutation.py @@ -3,7 +3,6 @@ from typing import ( TYPE_CHECKING, Any, - Dict, ) import strawberry @@ -27,7 +26,7 @@ def apply(self, field: StrawberryField) -> None: assert resolver name = field.graphql_name or to_camel_case(resolver.name) - type_dict: Dict[str, Any] = { + type_dict: dict[str, Any] = { "__doc__": f"Input data for `{name}` mutation", "__annotations__": {}, } diff --git a/strawberry/file_uploads/utils.py b/strawberry/file_uploads/utils.py index 0fec3a2c82..ea08c2200b 100644 --- a/strawberry/file_uploads/utils.py +++ b/strawberry/file_uploads/utils.py @@ -1,12 +1,13 @@ import copy -from typing import Any, Dict, Mapping +from collections.abc import Mapping +from typing import Any def replace_placeholders_with_files( - operations_with_placeholders: Dict[str, Any], + operations_with_placeholders: dict[str, Any], files_map: Mapping[str, Any], files: Mapping[str, Any], -) -> Dict[str, Any]: +) -> dict[str, Any]: # TODO: test this with missing variables in operations_with_placeholders operations = copy.deepcopy(operations_with_placeholders) diff --git a/strawberry/flask/views.py b/strawberry/flask/views.py index d93f480b1e..f730e5408e 100644 --- a/strawberry/flask/views.py +++ b/strawberry/flask/views.py @@ -4,7 +4,6 @@ from typing import ( TYPE_CHECKING, Any, - Mapping, Optional, Union, cast, @@ -23,6 +22,8 @@ from strawberry.http.typevars import Context, RootValue if TYPE_CHECKING: + from collections.abc import Mapping + from flask.typing import ResponseReturnValue from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE @@ -203,6 +204,6 @@ async def create_websocket_response( __all__ = [ - "GraphQLView", "AsyncGraphQLView", + "GraphQLView", ] diff --git a/strawberry/http/__init__.py b/strawberry/http/__init__.py index 86f295b9c3..326c14a4fd 100644 --- a/strawberry/http/__init__.py +++ b/strawberry/http/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Optional from typing_extensions import Literal, TypedDict if TYPE_CHECKING: @@ -9,9 +9,9 @@ class GraphQLHTTPResponse(TypedDict, total=False): - data: Optional[Dict[str, object]] - errors: Optional[List[object]] - extensions: Optional[Dict[str, object]] + data: Optional[dict[str, object]] + errors: Optional[list[object]] + extensions: Optional[dict[str, object]] def process_result(result: ExecutionResult) -> GraphQLHTTPResponse: @@ -30,13 +30,13 @@ class GraphQLRequestData: # query is optional here as it can be added by an extensions # (for example an extension for persisted queries) query: Optional[str] - variables: Optional[Dict[str, Any]] + variables: Optional[dict[str, Any]] operation_name: Optional[str] protocol: Literal["http", "multipart-subscription"] = "http" __all__ = [ "GraphQLHTTPResponse", - "process_result", "GraphQLRequestData", + "process_result", ] diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index d10b688003..799fec8ba3 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -2,18 +2,13 @@ import asyncio import contextlib import json +from collections.abc import AsyncGenerator, Mapping from datetime import timedelta from typing import ( Any, - AsyncGenerator, Callable, - Dict, Generic, - List, - Mapping, Optional, - Tuple, - Type, Union, cast, overload, @@ -124,10 +119,10 @@ class AsyncBaseHTTPView( ], AsyncWebSocketAdapter, ] - graphql_transport_ws_handler_class: Type[ + graphql_transport_ws_handler_class: type[ BaseGraphQLTransportWSHandler[Context, RootValue] ] = BaseGraphQLTransportWSHandler[Context, RootValue] - graphql_ws_handler_class: Type[BaseGraphQLWSHandler[Context, RootValue]] = ( + graphql_ws_handler_class: type[BaseGraphQLWSHandler[Context, RootValue]] = ( BaseGraphQLWSHandler[Context, RootValue] ) @@ -163,7 +158,7 @@ async def create_streaming_response( request: Request, stream: Callable[[], AsyncGenerator[str, None]], sub_response: SubResponse, - headers: Dict[str, str], + headers: dict[str, str], ) -> Response: raise ValueError("Multipart responses are not supported") @@ -220,7 +215,7 @@ async def execute_operation( allowed_operation_types=allowed_operation_types, ) - async def parse_multipart(self, request: AsyncHTTPRequestAdapter) -> Dict[str, str]: + async def parse_multipart(self, request: AsyncHTTPRequestAdapter) -> dict[str, str]: try: form_data = await request.get_form_data() except ValueError as e: @@ -243,7 +238,7 @@ async def parse_multipart(self, request: AsyncHTTPRequestAdapter) -> Dict[str, s raise HTTPException(400, "File(s) missing in form data") from e def _handle_errors( - self, errors: List[GraphQLError], response_data: GraphQLHTTPResponse + self, errors: list[GraphQLError], response_data: GraphQLHTTPResponse ) -> None: """Hook to allow custom handling of errors, used by the Sentry Integration.""" @@ -380,7 +375,7 @@ def _stream_with_heartbeat( self, stream: Callable[[], AsyncGenerator[str, None]] ) -> Callable[[], AsyncGenerator[str, None]]: """Adds a heartbeat to the stream, to prevent the connection from closing when there are no messages being sent.""" - queue: asyncio.Queue[Tuple[bool, Any]] = asyncio.Queue(1) + queue: asyncio.Queue[tuple[bool, Any]] = asyncio.Queue(1) cancelling = False @@ -448,7 +443,7 @@ async def stream() -> AsyncGenerator[str, None]: async def parse_multipart_subscriptions( self, request: AsyncHTTPRequestAdapter - ) -> Dict[str, str]: + ) -> dict[str, str]: if request.method == "GET": return self.parse_query_params(request.query_params) @@ -489,7 +484,7 @@ async def process_result( async def on_ws_connect( self, context: Context - ) -> Union[UnsetType, None, Dict[str, object]]: + ) -> Union[UnsetType, None, dict[str, object]]: return UNSET diff --git a/strawberry/http/base.py b/strawberry/http/base.py index ffb41bf751..1cb1904888 100644 --- a/strawberry/http/base.py +++ b/strawberry/http/base.py @@ -1,5 +1,6 @@ import json -from typing import Any, Dict, Generic, List, Mapping, Optional, Union +from collections.abc import Mapping +from typing import Any, Generic, Optional, Union from typing_extensions import Protocol from strawberry.http.ides import GraphQL_IDE, get_graphql_ide_html @@ -11,7 +12,7 @@ class BaseRequestProtocol(Protocol): @property - def query_params(self) -> Mapping[str, Optional[Union[str, List[str]]]]: ... + def query_params(self) -> Mapping[str, Optional[Union[str, list[str]]]]: ... @property def method(self) -> HTTPMethod: ... @@ -49,7 +50,7 @@ def decode_json(self, data: Union[str, bytes]) -> object: def encode_json(self, data: object) -> str: return json.dumps(data) - def parse_query_params(self, params: QueryParams) -> Dict[str, Any]: + def parse_query_params(self, params: QueryParams) -> dict[str, Any]: params = dict(params) if "variables" in params: @@ -65,7 +66,7 @@ def graphql_ide_html(self) -> str: return get_graphql_ide_html(graphql_ide=self.graphql_ide) def _is_multipart_subscriptions( - self, content_type: str, params: Dict[str, str] + self, content_type: str, params: dict[str, str] ) -> bool: if content_type != "multipart/mixed": return False diff --git a/strawberry/http/ides.py b/strawberry/http/ides.py index 9680a0277a..63d7d4af10 100644 --- a/strawberry/http/ides.py +++ b/strawberry/http/ides.py @@ -22,4 +22,4 @@ def get_graphql_ide_html( return template -__all__ = ["get_graphql_ide_html", "GraphQL_IDE"] +__all__ = ["GraphQL_IDE", "get_graphql_ide_html"] diff --git a/strawberry/http/parse_content_type.py b/strawberry/http/parse_content_type.py index d28be1a337..da54798b64 100644 --- a/strawberry/http/parse_content_type.py +++ b/strawberry/http/parse_content_type.py @@ -1,8 +1,7 @@ from email.message import Message -from typing import Dict, Tuple -def parse_content_type(content_type: str) -> Tuple[str, Dict[str, str]]: +def parse_content_type(content_type: str) -> tuple[str, dict[str, str]]: """Parse a content type header into a mime-type and a dictionary of parameters.""" email = Message() email["content-type"] = content_type diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py index df770e0541..555d7708d0 100644 --- a/strawberry/http/sync_base_view.py +++ b/strawberry/http/sync_base_view.py @@ -1,12 +1,10 @@ import abc import json +from collections.abc import Mapping from typing import ( Any, Callable, - Dict, Generic, - List, - Mapping, Optional, Union, ) @@ -126,7 +124,7 @@ def execute_operation( allowed_operation_types=allowed_operation_types, ) - def parse_multipart(self, request: SyncHTTPRequestAdapter) -> Dict[str, str]: + def parse_multipart(self, request: SyncHTTPRequestAdapter) -> dict[str, str]: operations = self.parse_json(request.post_data.get("operations", "{}")) files_map = self.parse_json(request.post_data.get("map", "{}")) @@ -159,7 +157,7 @@ def parse_http_body(self, request: SyncHTTPRequestAdapter) -> GraphQLRequestData ) def _handle_errors( - self, errors: List[GraphQLError], response_data: GraphQLHTTPResponse + self, errors: list[GraphQLError], response_data: GraphQLHTTPResponse ) -> None: """Hook to allow custom handling of errors, used by the Sentry Integration.""" diff --git a/strawberry/http/temporal_response.py b/strawberry/http/temporal_response.py index 47bf52fa9d..8c93a54161 100644 --- a/strawberry/http/temporal_response.py +++ b/strawberry/http/temporal_response.py @@ -1,11 +1,10 @@ from dataclasses import dataclass, field -from typing import Dict @dataclass class TemporalResponse: status_code: int = 200 - headers: Dict[str, str] = field(default_factory=dict) + headers: dict[str, str] = field(default_factory=dict) __all__ = ["TemporalResponse"] diff --git a/strawberry/http/types.py b/strawberry/http/types.py index 28794c9afa..3aeaa7ec4b 100644 --- a/strawberry/http/types.py +++ b/strawberry/http/types.py @@ -1,4 +1,5 @@ -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any, Optional from typing_extensions import Literal, TypedDict HTTPMethod = Literal[ @@ -13,4 +14,4 @@ class FormData(TypedDict): form: Mapping[str, Any] -__all__ = ["HTTPMethod", "QueryParams", "FormData"] +__all__ = ["FormData", "HTTPMethod", "QueryParams"] diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py index 2ffc456df8..0e91aacc97 100644 --- a/strawberry/litestar/controller.py +++ b/strawberry/litestar/controller.py @@ -8,14 +8,8 @@ from typing import ( TYPE_CHECKING, Any, - AsyncGenerator, - AsyncIterator, Callable, - Dict, - FrozenSet, Optional, - Tuple, - Type, TypedDict, Union, cast, @@ -60,7 +54,7 @@ from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import AsyncGenerator, AsyncIterator, Mapping from litestar.types import AnyCallable, Dependencies from strawberry.http import GraphQLHTTPResponse @@ -97,7 +91,7 @@ class WebSocketContextDict(TypedDict): MergedContext = Union[ - BaseContext, WebSocketContextDict, HTTPContextDict, Dict[str, Any] + BaseContext, WebSocketContextDict, HTTPContextDict, dict[str, Any] ] @@ -254,11 +248,11 @@ class GraphQLController( websocket_adapter_class = LitestarWebSocketAdapter allow_queries_via_get: bool = True - graphiql_allowed_accept: FrozenSet[str] = frozenset({"text/html", "*/*"}) + graphiql_allowed_accept: frozenset[str] = frozenset({"text/html", "*/*"}) graphql_ide: Optional[GraphQL_IDE] = "graphiql" debug: bool = False connection_init_wait_timeout: timedelta = timedelta(minutes=1) - protocols: Tuple[str, ...] = ( + protocols: tuple[str, ...] = ( GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL, ) @@ -329,7 +323,7 @@ async def create_streaming_response( request: Request, stream: Callable[[], AsyncIterator[str]], sub_response: Response, - headers: Dict[str, str], + headers: dict[str, str], ) -> Response: return Stream( stream(), @@ -416,13 +410,13 @@ def make_graphql_controller( root_value_getter: Optional[AnyCallable] = None, # TODO: context typevar context_getter: Optional[AnyCallable] = None, - subscription_protocols: Tuple[str, ...] = ( + subscription_protocols: tuple[str, ...] = ( GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL, ), connection_init_wait_timeout: timedelta = timedelta(minutes=1), multipart_uploads_enabled: bool = False, -) -> Type[GraphQLController]: # sourcery skip: move-assign +) -> type[GraphQLController]: # sourcery skip: move-assign if context_getter is None: custom_context_getter_ = _none_custom_context_getter else: @@ -474,6 +468,6 @@ class _GraphQLController(GraphQLController): __all__ = [ - "make_graphql_controller", "GraphQLController", + "make_graphql_controller", ] diff --git a/strawberry/parent.py b/strawberry/parent.py index f95f90b8e7..99223028ba 100644 --- a/strawberry/parent.py +++ b/strawberry/parent.py @@ -1,5 +1,4 @@ -from typing import TypeVar -from typing_extensions import Annotated +from typing import Annotated, TypeVar class StrawberryParent: ... diff --git a/strawberry/permission.py b/strawberry/permission.py index 91f796069f..9624cce535 100644 --- a/strawberry/permission.py +++ b/strawberry/permission.py @@ -7,11 +7,7 @@ from typing import ( TYPE_CHECKING, Any, - Awaitable, - Dict, - List, Optional, - Type, Union, ) @@ -25,6 +21,8 @@ from strawberry.utils.await_maybe import await_maybe if TYPE_CHECKING: + from collections.abc import Awaitable + from graphql import GraphQLError, GraphQLErrorExtensions from strawberry.extensions.field_extension import ( @@ -56,7 +54,7 @@ def has_permission(self, source, info, **kwargs): error_extensions: Optional[GraphQLErrorExtensions] = None - error_class: Type[GraphQLError] = StrawberryGraphQLError + error_class: type[GraphQLError] = StrawberryGraphQLError _schema_directive: Optional[object] = None @@ -138,7 +136,7 @@ class PermissionExtension(FieldExtension): def __init__( self, - permissions: List[BasePermission], + permissions: list[BasePermission], use_directives: bool = True, fail_silently: bool = False, ) -> None: @@ -189,7 +187,7 @@ def resolve( next_: SyncExtensionResolver, source: Any, info: Info, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> Any: """Checks if the permission should be accepted and raises an exception if not.""" for permission in self.permissions: @@ -202,7 +200,7 @@ async def resolve_async( next_: AsyncExtensionResolver, source: Any, info: Info, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> Any: for permission in self.permissions: has_permission = await await_maybe( diff --git a/strawberry/printer/ast_from_value.py b/strawberry/printer/ast_from_value.py index 7242d7b9df..eb2da54853 100644 --- a/strawberry/printer/ast_from_value.py +++ b/strawberry/printer/ast_from_value.py @@ -1,8 +1,9 @@ from __future__ import annotations import re +from collections.abc import Mapping from math import isfinite -from typing import TYPE_CHECKING, Any, Mapping, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from graphql.language import ( BooleanValueNode, diff --git a/strawberry/printer/printer.py b/strawberry/printer/printer.py index d5ac653251..d51ae32a2c 100644 --- a/strawberry/printer/printer.py +++ b/strawberry/printer/printer.py @@ -5,19 +5,14 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, Optional, - Set, - Tuple, - Type, TypeVar, Union, cast, overload, ) -from graphql import is_union_type +from graphql import GraphQLObjectType, GraphQLSchema, is_union_type from graphql.language.printer import print_ast from graphql.type import ( is_enum_type, @@ -38,7 +33,11 @@ from graphql.utilities.print_schema import print_type as original_print_type from strawberry.schema_directive import Location, StrawberrySchemaDirective -from strawberry.types.base import StrawberryContainer, has_object_definition +from strawberry.types.base import ( + StrawberryContainer, + StrawberryObjectDefinition, + has_object_definition, +) from strawberry.types.enum import EnumDefinition from strawberry.types.scalar import ScalarWrapper from strawberry.types.unset import UNSET @@ -64,18 +63,18 @@ @dataclasses.dataclass class PrintExtras: - directives: Set[str] = dataclasses.field(default_factory=set) - types: Set[type] = dataclasses.field(default_factory=set) + directives: set[str] = dataclasses.field(default_factory=set) + types: set[type] = dataclasses.field(default_factory=set) @overload -def _serialize_dataclasses(value: Dict[_T, object]) -> Dict[_T, object]: ... +def _serialize_dataclasses(value: dict[_T, object]) -> dict[_T, object]: ... @overload def _serialize_dataclasses( - value: Union[List[object], Tuple[object]], -) -> List[object]: ... + value: Union[list[object], tuple[object]], +) -> list[object]: ... @overload @@ -94,7 +93,7 @@ def _serialize_dataclasses(value): def print_schema_directive_params( - directive: GraphQLDirective, values: Dict[str, Any] + directive: GraphQLDirective, values: dict[str, Any] ) -> str: params = [] for name, arg in directive.args.items(): @@ -189,7 +188,7 @@ def print_argument_directives( def print_args( - args: Dict[str, GraphQLArgument], + args: dict[str, GraphQLArgument], indentation: str = "", *, schema: BaseSchema, @@ -225,7 +224,12 @@ def print_args( ) -def print_fields(type_: Type, schema: BaseSchema, *, extras: PrintExtras) -> str: +def print_fields( + type_: GraphQLObjectType, + schema: BaseSchema, + *, + extras: PrintExtras, +) -> str: from strawberry.schema.schema_converter import GraphQLCoreConverter fields = [] @@ -320,11 +324,13 @@ def print_enum( ) -def print_extends(type_: Type, schema: BaseSchema) -> str: +def print_extends(type_: GraphQLObjectType, schema: BaseSchema) -> str: from strawberry.schema.schema_converter import GraphQLCoreConverter - strawberry_type = type_.extensions and type_.extensions.get( - GraphQLCoreConverter.DEFINITION_BACKREF + strawberry_type = cast( + Optional[StrawberryObjectDefinition], + type_.extensions + and type_.extensions.get(GraphQLCoreConverter.DEFINITION_BACKREF), ) if strawberry_type and strawberry_type.extend: @@ -334,12 +340,14 @@ def print_extends(type_: Type, schema: BaseSchema) -> str: def print_type_directives( - type_: Type, schema: BaseSchema, *, extras: PrintExtras + type_: GraphQLObjectType, schema: BaseSchema, *, extras: PrintExtras ) -> str: from strawberry.schema.schema_converter import GraphQLCoreConverter - strawberry_type = type_.extensions and type_.extensions.get( - GraphQLCoreConverter.DEFINITION_BACKREF + strawberry_type = cast( + Optional[StrawberryObjectDefinition], + type_.extensions + and type_.extensions.get(GraphQLCoreConverter.DEFINITION_BACKREF), ) if not strawberry_type: @@ -354,7 +362,7 @@ def print_type_directives( for directive in strawberry_type.directives or [] if any( location in allowed_locations - for location in directive.__strawberry_directive__.locations + for location in directive.__strawberry_directive__.locations # type: ignore[attr-defined] ) ) @@ -550,21 +558,33 @@ def is_builtin_directive(directive: GraphQLDirective) -> bool: def print_schema(schema: BaseSchema) -> str: - graphql_core_schema = schema._schema # type: ignore + graphql_core_schema = cast( + GraphQLSchema, + schema._schema, # type: ignore + ) extras = PrintExtras() - directives = filter( - lambda n: not is_builtin_directive(n), graphql_core_schema.directives - ) + filtered_directives = [ + directive + for directive in graphql_core_schema.directives + if not is_builtin_directive(directive) + ] + type_map = graphql_core_schema.type_map - types = filter(is_defined_type, map(type_map.get, sorted(type_map))) + types = [ + type_ + for type_name in sorted(type_map) + if is_defined_type(type_ := type_map[type_name]) + ] types_printed = [_print_type(type_, schema, extras=extras) for type_ in types] schema_definition = print_schema_definition(schema, extras=extras) - directives = filter( - None, [print_directive(directive, schema=schema) for directive in directives] - ) + directives = [ + printed_directive + for directive in filtered_directives + if (printed_directive := print_directive(directive, schema=schema)) is not None + ] def _name_getter(type_: Any) -> str: if hasattr(type_, "name"): diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index 528a987abc..93afe401aa 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -1,6 +1,6 @@ import warnings -from collections.abc import Mapping -from typing import TYPE_CHECKING, AsyncGenerator, Callable, Dict, Optional, cast +from collections.abc import AsyncGenerator, Mapping +from typing import TYPE_CHECKING, Callable, Optional, cast from typing_extensions import TypeGuard from quart import Request, Response, request @@ -111,7 +111,7 @@ async def create_streaming_response( request: Request, stream: Callable[[], AsyncGenerator[str, None]], sub_response: Response, - headers: Dict[str, str], + headers: dict[str, str], ) -> Response: return ( stream(), diff --git a/strawberry/relay/exceptions.py b/strawberry/relay/exceptions.py index 4150633d8c..78536f1c3a 100644 --- a/strawberry/relay/exceptions.py +++ b/strawberry/relay/exceptions.py @@ -2,7 +2,7 @@ from collections.abc import Callable from functools import cached_property -from typing import TYPE_CHECKING, Optional, Type, cast +from typing import TYPE_CHECKING, Optional, cast from strawberry.exceptions.exception import StrawberryException from strawberry.exceptions.utils.source_finder import SourceFinder @@ -13,7 +13,7 @@ class NodeIDAnnotationError(StrawberryException): - def __init__(self, message: str, cls: Type) -> None: + def __init__(self, message: str, cls: type) -> None: self.cls = cls self.message = message @@ -41,7 +41,7 @@ def exception_source(self) -> Optional[ExceptionSource]: class RelayWrongAnnotationError(StrawberryException): - def __init__(self, field_name: str, cls: Type) -> None: + def __init__(self, field_name: str, cls: type) -> None: self.cls = cls self.field_name = field_name @@ -85,7 +85,7 @@ def __init__(self, field_name: str, resolver: StrawberryResolver) -> None: ) self.suggestion = ( "To fix this error you can annootate your resolver to return " - "one of the following options: `List[]`, " + "one of the following options: `list[]`, " "`Iterator[]`, `Iterable[]`, " "`AsyncIterator[]`, `AsyncIterable[]`, " "`Generator[, Any, Any]` and " diff --git a/strawberry/relay/fields.py b/strawberry/relay/fields.py index 32673cfc10..347fd22169 100644 --- a/strawberry/relay/fields.py +++ b/strawberry/relay/fields.py @@ -4,28 +4,26 @@ import dataclasses import inspect from collections import defaultdict -from collections.abc import AsyncIterable -from typing import ( - TYPE_CHECKING, - Any, +from collections.abc import ( + AsyncIterable, AsyncIterator, Awaitable, - Callable, - DefaultDict, - Dict, - ForwardRef, Iterable, Iterator, - List, Mapping, - Optional, Sequence, - Tuple, - Type, +) +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + ForwardRef, + Optional, Union, cast, ) -from typing_extensions import Annotated, get_args, get_origin +from typing_extensions import get_args, get_origin from strawberry.annotation import StrawberryAnnotation from strawberry.extensions.field_extension import ( @@ -100,7 +98,7 @@ def resolver( def get_node_list_resolver( self, field: StrawberryField - ) -> Callable[[Info, List[GlobalID]], Union[List[Node], Awaitable[List[Node]]]]: + ) -> Callable[[Info, list[GlobalID]], Union[list[Node], Awaitable[list[Node]]]]: type_ = field.type assert isinstance(type_, StrawberryList) is_optional = isinstance(type_.of_type, StrawberryOptional) @@ -108,14 +106,14 @@ def get_node_list_resolver( def resolver( info: Info, ids: Annotated[ - List[GlobalID], argument(description="The IDs of the objects.") + list[GlobalID], argument(description="The IDs of the objects.") ], - ) -> Union[List[Node], Awaitable[List[Node]]]: - nodes_map: DefaultDict[Type[Node], List[str]] = defaultdict(list) + ) -> Union[list[Node], Awaitable[list[Node]]]: + nodes_map: defaultdict[type[Node], list[str]] = defaultdict(list) # Store the index of the node in the list of nodes of the same type # so that we can return them in the same order while also supporting # different types - index_map: Dict[GlobalID, Tuple[Type[Node], int]] = {} + index_map: dict[GlobalID, tuple[type[Node], int]] = {} for gid in ids: node_t = gid.resolve_type(info) nodes_map[node_t].append(gid.node_id) @@ -143,7 +141,7 @@ def resolver( if awaitable_nodes or asyncgen_nodes: - async def resolve(resolved: Any = resolved_nodes) -> List[Node]: + async def resolve(resolved: Any = resolved_nodes) -> list[Node]: resolved.update( zip( [ @@ -182,7 +180,7 @@ async def resolve(resolved: Any = resolved_nodes) -> List[Node]: class ConnectionExtension(FieldExtension): - connection_type: Type[Connection[Node]] + connection_type: type[Connection[Node]] def apply(self, field: StrawberryField) -> None: field.arguments = [ @@ -268,7 +266,7 @@ def apply(self, field: StrawberryField) -> None: ): raise RelayWrongResolverAnnotationError(field.name, field.base_resolver) - self.connection_type = cast(Type[Connection[Node]], f_type) + self.connection_type = cast(type[Connection[Node]], f_type) def resolve( self, @@ -352,13 +350,13 @@ def connection( name: Optional[str] = None, is_subscription: bool = False, description: Optional[str] = None, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: List[FieldExtension] = (), # type: ignore + extensions: list[FieldExtension] = (), # type: ignore # This init parameter is used by pyright to determine whether this field # is added in the constructor or not. It is not used to change # any behaviour at the moment. @@ -460,4 +458,4 @@ def get_some_nodes(self, age: int) -> Iterable[SomeType]: ... return f -__all__ = ["node", "connection"] +__all__ = ["connection", "node"] diff --git a/strawberry/relay/types.py b/strawberry/relay/types.py index 869e017be9..5342d6e29a 100644 --- a/strawberry/relay/types.py +++ b/strawberry/relay/types.py @@ -4,27 +4,28 @@ import inspect import itertools import sys -from typing import ( - TYPE_CHECKING, - Any, +from collections.abc import ( AsyncIterable, AsyncIterator, Awaitable, + Iterable, + Iterator, + Sequence, +) +from typing import ( + TYPE_CHECKING, + Annotated, + Any, ClassVar, ForwardRef, Generic, - Iterable, - Iterator, - List, Optional, - Sequence, - Type, TypeVar, Union, cast, overload, ) -from typing_extensions import Annotated, Literal, Self, TypeAlias, get_args, get_origin +from typing_extensions import Literal, Self, TypeAlias, get_args, get_origin from strawberry.relay.exceptions import NodeIDAnnotationError from strawberry.types.base import ( @@ -33,9 +34,10 @@ get_object_definition, ) from strawberry.types.field import field -from strawberry.types.info import Info # noqa: TCH001 +from strawberry.types.info import Info # noqa: TC001 from strawberry.types.lazy_type import LazyType -from strawberry.types.object_type import interface, type +from strawberry.types.object_type import interface +from strawberry.types.object_type import type as strawberry_type from strawberry.types.private import StrawberryPrivate from strawberry.utils.aio import aenumerate, aislice, resolve_awaitable from strawberry.utils.inspect import in_async_context @@ -137,7 +139,7 @@ async def resolve_node( info: Info, *, required: Literal[True] = ..., - ensure_type: Type[_T], + ensure_type: type[_T], ) -> _T: ... @overload @@ -211,7 +213,7 @@ async def resolve_node(self, info, *, required=False, ensure_type=None) -> Any: return node - def resolve_type(self, info: Info) -> Type[Node]: + def resolve_type(self, info: Info) -> type[Node]: """Resolve the internal type name to its type itself. Args: @@ -247,7 +249,7 @@ def resolve_node_sync( info: Info, *, required: Literal[True] = ..., - ensure_type: Type[_T], + ensure_type: type[_T], ) -> _T: ... @overload @@ -391,7 +393,7 @@ def _id(cls, root: Node, info: Info) -> GlobalID: parent_type = info._raw_info.parent_type type_def = info.schema.get_type_by_name(parent_type.name) assert isinstance(type_def, StrawberryObjectDefinition) - origin = cast(Type[Node], type_def.origin) + origin = cast(type[Node], type_def.origin) resolve_id = origin.resolve_id resolve_typename = origin.resolve_typename @@ -618,7 +620,7 @@ def resolve_node( return next(iter(cast(Iterable[Self], retval))) -@type(description="Information to aid in pagination.") +@strawberry_type(description="Information to aid in pagination.") class PageInfo: """Information to aid in pagination. @@ -647,7 +649,7 @@ class PageInfo: ) -@type(description="An edge in a connection.") +@strawberry_type(description="An edge in a connection.") class Edge(Generic[NodeType]): """An edge in a connection. @@ -666,7 +668,7 @@ def resolve_edge(cls, node: NodeType, *, cursor: Any = None) -> Self: return cls(cursor=to_base64(PREFIX, cursor), node=node) -@type(description="A connection to a list of items.") +@strawberry_type(description="A connection to a list of items.") class Connection(Generic[NodeType]): """A connection to a list of items. @@ -679,7 +681,7 @@ class Connection(Generic[NodeType]): """ page_info: PageInfo = field(description="Pagination data for this connection") - edges: List[Edge[NodeType]] = field( + edges: list[Edge[NodeType]] = field( description="Contains the nodes in this connection" ) @@ -738,7 +740,7 @@ def resolve_connection( raise NotImplementedError -@type(name="Connection", description="A connection to a list of items.") +@strawberry_type(name="Connection", description="A connection to a list of items.") class ListConnection(Connection[NodeType]): """A connection to a list of items. @@ -751,7 +753,7 @@ class ListConnection(Connection[NodeType]): """ page_info: PageInfo = field(description="Pagination data for this connection") - edges: List[Edge[NodeType]] = field( + edges: list[Edge[NodeType]] = field( description="Contains the nodes in this connection" ) @@ -827,7 +829,7 @@ async def resolver() -> Self: # The slice above might return an object that now is not async # iterable anymore (e.g. an already cached django queryset) if isinstance(iterator, (AsyncIterator, AsyncIterable)): - edges: List[Edge] = [ + edges: list[Edge] = [ edge_class.resolve_edge( cls.resolve_node(v, info=info, **kwargs), cursor=slice_metadata.start + i, @@ -835,7 +837,7 @@ async def resolver() -> Self: async for i, v in aenumerate(iterator) ] else: - edges: List[Edge] = [ # type: ignore[no-redef] + edges: list[Edge] = [ # type: ignore[no-redef] edge_class.resolve_edge( cls.resolve_node(v, info=info, **kwargs), cursor=slice_metadata.start + i, @@ -935,17 +937,17 @@ async def resolver() -> Self: __all__ = [ + "PREFIX", + "Connection", + "Edge", "GlobalID", "GlobalIDValueError", + "ListConnection", "Node", "NodeID", "NodeIDAnnotationError", "NodeIDPrivate", "NodeIterableType", "NodeType", - "PREFIX", - "Connection", - "Edge", "PageInfo", - "ListConnection", ] diff --git a/strawberry/relay/utils.py b/strawberry/relay/utils.py index c25bd537b0..d25eacb447 100644 --- a/strawberry/relay/utils.py +++ b/strawberry/relay/utils.py @@ -3,7 +3,7 @@ import base64 import dataclasses import sys -from typing import TYPE_CHECKING, Any, Tuple, Union +from typing import TYPE_CHECKING, Any, Union from typing_extensions import Self, assert_never from strawberry.types.base import StrawberryObjectDefinition @@ -13,7 +13,7 @@ from strawberry.types.info import Info -def from_base64(value: str) -> Tuple[str, str]: +def from_base64(value: str) -> tuple[str, str]: """Parse the base64 encoded relay value. Args: @@ -191,8 +191,8 @@ def from_arguments( __all__ = [ + "SliceMetadata", "from_base64", - "to_base64", "should_resolve_list_connection_edges", - "SliceMetadata", + "to_base64", ] diff --git a/strawberry/sanic/utils.py b/strawberry/sanic/utils.py index 7c57b02664..1d78c09118 100644 --- a/strawberry/sanic/utils.py +++ b/strawberry/sanic/utils.py @@ -1,7 +1,7 @@ from __future__ import annotations from io import BytesIO -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast from sanic.request import File @@ -9,7 +9,7 @@ from sanic.request import Request -def convert_request_to_files_dict(request: Request) -> Dict[str, Any]: +def convert_request_to_files_dict(request: Request) -> dict[str, Any]: """Converts the request.files dictionary to a dictionary of BytesIO objects. `request.files` has the following format, even if only a single file is uploaded: @@ -24,12 +24,12 @@ def convert_request_to_files_dict(request: Request) -> Dict[str, Any]: Note that the dictionary entries are lists. """ - request_files = cast(Optional[Dict[str, List[File]]], request.files) + request_files = cast(Optional[dict[str, list[File]]], request.files) if not request_files: return {} - files_dict: Dict[str, Union[BytesIO, List[BytesIO]]] = {} + files_dict: dict[str, Union[BytesIO, list[BytesIO]]] = {} for field_name, file_list in request_files.items(): assert len(file_list) == 1 diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index ee76d2e946..9044df2b0a 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -5,12 +5,8 @@ from typing import ( TYPE_CHECKING, Any, - AsyncGenerator, Callable, - Dict, - Mapping, Optional, - Type, cast, ) from typing_extensions import TypeGuard @@ -29,6 +25,8 @@ from strawberry.sanic.utils import convert_request_to_files_dict if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Mapping + from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.schema import BaseSchema @@ -109,8 +107,8 @@ def __init__( graphiql: Optional[bool] = None, graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, - json_encoder: Optional[Type[json.JSONEncoder]] = None, - json_dumps_params: Optional[Dict[str, Any]] = None, + json_encoder: Optional[type[json.JSONEncoder]] = None, + json_dumps_params: Optional[dict[str, Any]] = None, multipart_uploads_enabled: bool = False, ) -> None: self.schema = schema @@ -194,7 +192,7 @@ async def create_streaming_response( request: Request, stream: Callable[[], AsyncGenerator[str, None]], sub_response: TemporalResponse, - headers: Dict[str, str], + headers: dict[str, str], ) -> HTTPResponse: response = await self.request.respond( status=sub_response.status_code, diff --git a/strawberry/scalars.py b/strawberry/scalars.py index f87f81889a..fa1e1ea902 100644 --- a/strawberry/scalars.py +++ b/strawberry/scalars.py @@ -1,7 +1,7 @@ from __future__ import annotations import base64 -from typing import TYPE_CHECKING, Any, Dict, NewType, Union +from typing import TYPE_CHECKING, Any, NewType, Union from strawberry.types.scalar import scalar @@ -57,7 +57,7 @@ def is_scalar( annotation: Any, - scalar_registry: Dict[object, Union[ScalarWrapper, ScalarDefinition]], + scalar_registry: dict[object, Union[ScalarWrapper, ScalarDefinition]], ) -> bool: if annotation in scalar_registry: return True diff --git a/strawberry/schema/base.py b/strawberry/schema/base.py index f34fc98632..cc580f12c5 100644 --- a/strawberry/schema/base.py +++ b/strawberry/schema/base.py @@ -2,12 +2,14 @@ from abc import abstractmethod from functools import lru_cache -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Optional, Union from typing_extensions import Protocol from strawberry.utils.logging import StrawberryLogger if TYPE_CHECKING: + from collections.abc import Iterable + from graphql import GraphQLError from strawberry.directive import StrawberryDirective @@ -16,7 +18,10 @@ ExecutionContext, ExecutionResult, ) - from strawberry.types.base import StrawberryObjectDefinition + from strawberry.types.base import ( + StrawberryObjectDefinition, + WithStrawberryObjectDefinition, + ) from strawberry.types.enum import EnumDefinition from strawberry.types.graphql import OperationType from strawberry.types.scalar import ScalarDefinition @@ -29,16 +34,16 @@ class BaseSchema(Protocol): config: StrawberryConfig schema_converter: GraphQLCoreConverter - query: Type - mutation: Optional[Type] - subscription: Optional[Type] - schema_directives: List[object] + query: type[WithStrawberryObjectDefinition] + mutation: Optional[type[WithStrawberryObjectDefinition]] + subscription: Optional[type[WithStrawberryObjectDefinition]] + schema_directives: list[object] @abstractmethod async def execute( self, query: Optional[str], - variable_values: Optional[Dict[str, Any]] = None, + variable_values: Optional[dict[str, Any]] = None, context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, @@ -50,7 +55,7 @@ async def execute( def execute_sync( self, query: Optional[str], - variable_values: Optional[Dict[str, Any]] = None, + variable_values: Optional[dict[str, Any]] = None, context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, @@ -62,7 +67,7 @@ def execute_sync( async def subscribe( self, query: str, - variable_values: Optional[Dict[str, Any]] = None, + variable_values: Optional[dict[str, Any]] = None, context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, @@ -101,7 +106,7 @@ def remove_field_suggestion(error: GraphQLError) -> None: def _process_errors( self, - errors: List[GraphQLError], + errors: list[GraphQLError], execution_context: Optional[ExecutionContext] = None, ) -> None: if self.config.disable_field_suggestions: @@ -112,7 +117,7 @@ def _process_errors( def process_errors( self, - errors: List[GraphQLError], + errors: list[GraphQLError], execution_context: Optional[ExecutionContext] = None, ) -> None: for error in errors: diff --git a/strawberry/schema/compat.py b/strawberry/schema/compat.py index 2831eab78a..b5e00e7158 100644 --- a/strawberry/schema/compat.py +++ b/strawberry/schema/compat.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Union +from typing import TYPE_CHECKING, Union from strawberry.scalars import is_scalar as is_strawberry_scalar from strawberry.types.base import StrawberryType, has_object_definition @@ -29,7 +29,7 @@ def is_interface_type(type_: Union[StrawberryType, type]) -> TypeGuard[type]: def is_scalar( type_: Union[StrawberryType, type], - scalar_registry: Dict[object, Union[ScalarWrapper, ScalarDefinition]], + scalar_registry: dict[object, Union[ScalarWrapper, ScalarDefinition]], ) -> TypeGuard[type]: return is_strawberry_scalar(type_, scalar_registry) @@ -54,10 +54,10 @@ def is_graphql_generic(type_: Union[StrawberryType, type]) -> bool: __all__ = [ + "is_enum", + "is_graphql_generic", "is_input_type", "is_interface_type", "is_scalar", - "is_enum", "is_schema_directive", - "is_graphql_generic", ] diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index 5262535062..d47baa8a6a 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -1,16 +1,12 @@ from __future__ import annotations from asyncio import ensure_future +from collections.abc import Awaitable, Iterable from inspect import isawaitable from typing import ( TYPE_CHECKING, - Awaitable, Callable, - Iterable, - List, Optional, - Tuple, - Type, TypedDict, Union, cast, @@ -49,7 +45,7 @@ class ParseOptions(TypedDict): ProcessErrors: TypeAlias = ( - "Callable[[List[GraphQLError], Optional[ExecutionContext]], None]" + "Callable[[list[GraphQLError], Optional[ExecutionContext]], None]" ) @@ -60,8 +56,8 @@ def parse_document(query: str, **kwargs: Unpack[ParseOptions]) -> DocumentNode: def validate_document( schema: GraphQLSchema, document: DocumentNode, - validation_rules: Tuple[Type[ASTValidationRule], ...], -) -> List[GraphQLError]: + validation_rules: tuple[type[ASTValidationRule], ...], +) -> list[GraphQLError]: validation_rules = ( *validation_rules, OneOfInputValidationRule, @@ -150,7 +146,7 @@ async def execute( extensions_runner: SchemaExtensionsRunner, process_errors: ProcessErrors, middleware_manager: MiddlewareManager, - execution_context_class: Optional[Type[GraphQLExecutionContext]] = None, + execution_context_class: Optional[type[GraphQLExecutionContext]] = None, ) -> ExecutionResult | PreExecutionError: try: async with extensions_runner.operation(): @@ -214,7 +210,7 @@ def execute_sync( allowed_operation_types: Iterable[OperationType], extensions_runner: SchemaExtensionsRunner, execution_context: ExecutionContext, - execution_context_class: Optional[Type[GraphQLExecutionContext]] = None, + execution_context_class: Optional[type[GraphQLExecutionContext]] = None, process_errors: ProcessErrors, middleware_manager: MiddlewareManager, ) -> ExecutionResult: diff --git a/strawberry/schema/name_converter.py b/strawberry/schema/name_converter.py index ec6d0edf2a..1dc30edd55 100644 --- a/strawberry/schema/name_converter.py +++ b/strawberry/schema/name_converter.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional, Union, cast +from typing import TYPE_CHECKING, Optional, Union, cast from typing_extensions import Protocol from strawberry.directive import StrawberryDirective @@ -126,11 +126,11 @@ def from_union(self, union: StrawberryUnion) -> str: def from_generic( self, generic_type: StrawberryObjectDefinition, - types: List[Union[StrawberryType, type]], + types: list[Union[StrawberryType, type]], ) -> str: generic_type_name = generic_type.name - names: List[str] = [] + names: list[str] = [] for type_ in types: name = self.get_from_type(type_) diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index a8fce095b5..2351c8d51b 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -5,11 +5,7 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - Iterable, - List, Optional, - Type, Union, cast, ) @@ -37,7 +33,11 @@ from strawberry.schema.schema_converter import GraphQLCoreConverter from strawberry.schema.types.scalar import DEFAULT_SCALAR_REGISTRY from strawberry.types import ExecutionContext -from strawberry.types.base import StrawberryObjectDefinition, has_object_definition +from strawberry.types.base import ( + StrawberryObjectDefinition, + WithStrawberryObjectDefinition, + has_object_definition, +) from strawberry.types.graphql import OperationType from ..printer import print_schema @@ -48,6 +48,8 @@ from .subscribe import SubscriptionResult, subscribe if TYPE_CHECKING: + from collections.abc import Iterable + from graphql import ExecutionContext as GraphQLExecutionContext from strawberry.directive import StrawberryDirective @@ -70,16 +72,16 @@ def __init__( self, # TODO: can we make sure we only allow to pass # something that has been decorated? - query: Type, - mutation: Optional[Type] = None, - subscription: Optional[Type] = None, + query: type, + mutation: Optional[type] = None, + subscription: Optional[type] = None, directives: Iterable[StrawberryDirective] = (), - types: Iterable[Union[Type, StrawberryType]] = (), - extensions: Iterable[Union[Type[SchemaExtension], SchemaExtension]] = (), - execution_context_class: Optional[Type[GraphQLExecutionContext]] = None, + types: Iterable[Union[type, StrawberryType]] = (), + extensions: Iterable[Union[type[SchemaExtension], SchemaExtension]] = (), + execution_context_class: Optional[type[GraphQLExecutionContext]] = None, config: Optional[StrawberryConfig] = None, scalar_overrides: Optional[ - Dict[object, Union[Type, ScalarWrapper, ScalarDefinition]], + dict[object, Union[type, ScalarWrapper, ScalarDefinition]], ] = None, schema_directives: Iterable[object] = (), ) -> None: @@ -127,7 +129,7 @@ class Query: self.execution_context_class = execution_context_class self.config = config or StrawberryConfig() - SCALAR_OVERRIDES_DICT_TYPE = Dict[ + SCALAR_OVERRIDES_DICT_TYPE = dict[ object, Union["ScalarWrapper", "ScalarDefinition"] ] @@ -142,14 +144,24 @@ class Query: self.directives = directives self.schema_directives = list(schema_directives) - query_type = self.schema_converter.from_object(query.__strawberry_definition__) + query_type = self.schema_converter.from_object( + cast(type[WithStrawberryObjectDefinition], query).__strawberry_definition__ + ) mutation_type = ( - self.schema_converter.from_object(mutation.__strawberry_definition__) + self.schema_converter.from_object( + cast( + type[WithStrawberryObjectDefinition], mutation + ).__strawberry_definition__ + ) if mutation else None ) subscription_type = ( - self.schema_converter.from_object(subscription.__strawberry_definition__) + self.schema_converter.from_object( + cast( + type[WithStrawberryObjectDefinition], subscription + ).__strawberry_definition__ + ) if subscription else None ) @@ -213,7 +225,7 @@ class Query: formatted_errors = "\n\n".join(f"❌ {error.message}" for error in errors) raise ValueError(f"Invalid Schema. Errors:\n\n{formatted_errors}") - def get_extensions(self, sync: bool = False) -> List[SchemaExtension]: + def get_extensions(self, sync: bool = False) -> list[SchemaExtension]: extensions = [] if self.directives: extensions = [ @@ -227,11 +239,11 @@ def get_extensions(self, sync: bool = False) -> List[SchemaExtension]: ] @cached_property - def _sync_extensions(self) -> List[SchemaExtension]: + def _sync_extensions(self) -> list[SchemaExtension]: return self.get_extensions(sync=True) @cached_property - def _async_extensions(self) -> List[SchemaExtension]: + def _async_extensions(self) -> list[SchemaExtension]: return self.get_extensions(sync=False) def create_extensions_runner( @@ -256,7 +268,7 @@ def _create_execution_context( self, query: Optional[str], allowed_operation_types: Iterable[OperationType], - variable_values: Optional[Dict[str, Any]] = None, + variable_values: Optional[dict[str, Any]] = None, context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, @@ -320,13 +332,13 @@ def get_directive_by_name(self, graphql_name: str) -> Optional[StrawberryDirecti def get_fields( self, type_definition: StrawberryObjectDefinition - ) -> List[StrawberryField]: + ) -> list[StrawberryField]: return type_definition.fields async def execute( self, query: Optional[str], - variable_values: Optional[Dict[str, Any]] = None, + variable_values: Optional[dict[str, Any]] = None, context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, @@ -361,7 +373,7 @@ async def execute( def execute_sync( self, query: Optional[str], - variable_values: Optional[Dict[str, Any]] = None, + variable_values: Optional[dict[str, Any]] = None, context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, @@ -397,7 +409,7 @@ def execute_sync( async def subscribe( self, query: Optional[str], - variable_values: Optional[Dict[str, Any]] = None, + variable_values: Optional[dict[str, Any]] = None, context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, @@ -495,7 +507,7 @@ def as_str(self) -> str: __str__ = as_str - def introspect(self) -> Dict[str, Any]: + def introspect(self) -> dict[str, Any]: """Return the introspection query result for the current schema. Raises: diff --git a/strawberry/schema/schema_converter.py b/strawberry/schema/schema_converter.py index 1083b46f9b..bc364ef219 100644 --- a/strawberry/schema/schema_converter.py +++ b/strawberry/schema/schema_converter.py @@ -6,14 +6,9 @@ from typing import ( TYPE_CHECKING, Any, - Awaitable, Callable, - Dict, Generic, - List, Optional, - Tuple, - Type, TypeVar, Union, cast, @@ -76,6 +71,8 @@ from .types.concrete_type import ConcreteType if TYPE_CHECKING: + from collections.abc import Awaitable + from graphql import ( GraphQLInputType, GraphQLNullableType, @@ -111,8 +108,8 @@ def _get_thunk_mapping( type_definition: StrawberryObjectDefinition, name_converter: Callable[[StrawberryField], str], field_converter: FieldConverterProtocol[FieldType], - get_fields: Callable[[StrawberryObjectDefinition], List[StrawberryField]], -) -> Dict[str, FieldType]: + get_fields: Callable[[StrawberryObjectDefinition], list[StrawberryField]], +) -> dict[str, FieldType]: """Create a GraphQL core `ThunkMapping` mapping of field names to field types. This method filters out remaining `strawberry.Private` annotated fields that @@ -124,7 +121,7 @@ def _get_thunk_mapping( Raises: TypeError: If the type of a field in ``fields`` is `UNRESOLVED` """ - thunk_mapping: Dict[str, FieldType] = {} + thunk_mapping: dict[str, FieldType] = {} fields = get_fields(type_definition) @@ -173,7 +170,7 @@ def parse_value(self, input_value: str) -> Any: return self.wrapped_cls(super().parse_value(input_value)) def parse_literal( - self, value_node: ValueNode, _variables: Optional[Dict[str, Any]] = None + self, value_node: ValueNode, _variables: Optional[dict[str, Any]] = None ) -> Any: return self.wrapped_cls(super().parse_literal(value_node, _variables)) @@ -185,8 +182,8 @@ def get_arguments( info: Info, kwargs: Any, config: StrawberryConfig, - scalar_registry: Dict[object, Union[ScalarWrapper, ScalarDefinition]], -) -> Tuple[List[Any], Dict[str, Any]]: + scalar_registry: dict[object, Union[ScalarWrapper, ScalarDefinition]], +) -> tuple[list[Any], dict[str, Any]]: # TODO: An extension might have changed the resolver arguments, # but we need them here since we are calling it. # This is a bit of a hack, but it's the easiest way to get the arguments @@ -242,10 +239,10 @@ class GraphQLCoreConverter: def __init__( self, config: StrawberryConfig, - scalar_registry: Dict[object, Union[ScalarWrapper, ScalarDefinition]], - get_fields: Callable[[StrawberryObjectDefinition], List[StrawberryField]], + scalar_registry: dict[object, Union[ScalarWrapper, ScalarDefinition]], + get_fields: Callable[[StrawberryObjectDefinition], list[StrawberryField]], ) -> None: - self.type_map: Dict[str, ConcreteType] = {} + self.type_map: dict[str, ConcreteType] = {} self.config = config self.scalar_registry = scalar_registry self.get_fields = get_fields @@ -329,13 +326,14 @@ def from_directive(self, directive: StrawberryDirective) -> GraphQLDirective: }, ) - def from_schema_directive(self, cls: Type) -> GraphQLDirective: + def from_schema_directive(self, cls: type) -> GraphQLDirective: strawberry_directive = cast( - "StrawberrySchemaDirective", cls.__strawberry_directive__ + "StrawberrySchemaDirective", + cls.__strawberry_directive__, # type: ignore[attr-defined] ) module = sys.modules[cls.__module__] - args: Dict[str, GraphQLArgument] = {} + args: dict[str, GraphQLArgument] = {} for field in strawberry_directive.fields: default = field.default if default == dataclasses.MISSING: @@ -436,7 +434,7 @@ def from_input_field( def get_graphql_fields( self, type_definition: StrawberryObjectDefinition - ) -> Dict[str, GraphQLField]: + ) -> dict[str, GraphQLField]: return _get_thunk_mapping( type_definition=type_definition, name_converter=self.config.name_converter.from_field, @@ -446,7 +444,7 @@ def get_graphql_fields( def get_graphql_input_fields( self, type_definition: StrawberryObjectDefinition - ) -> Dict[str, GraphQLInputField]: + ) -> dict[str, GraphQLInputField]: return _get_thunk_mapping( type_definition=type_definition, name_converter=self.config.name_converter.from_field, @@ -672,8 +670,8 @@ def _strawberry_info_from_graphql(info: GraphQLResolveInfo) -> Info: def _get_result( _source: Any, info: Info, - field_args: List[Any], - field_kwargs: Dict[str, Any], + field_args: list[Any], + field_kwargs: dict[str, Any], ) -> Any: return field.get_result( _source, info=info, args=field_args, kwargs=field_kwargs @@ -762,7 +760,7 @@ async def _async_resolver( _resolver._is_default = not field.base_resolver # type: ignore return _resolver - def from_scalar(self, scalar: Type) -> GraphQLScalarType: + def from_scalar(self, scalar: type) -> GraphQLScalarType: scalar_definition: ScalarDefinition if scalar in self.scalar_registry: @@ -773,7 +771,7 @@ def from_scalar(self, scalar: Type) -> GraphQLScalarType: else: scalar_definition = _scalar_definition else: - scalar_definition = scalar._scalar_definition + scalar_definition = scalar._scalar_definition # type: ignore[attr-defined] scalar_name = self.config.name_converter.from_type(scalar_definition) @@ -864,7 +862,7 @@ def from_union(self, union: StrawberryUnion) -> GraphQLUnionType: assert isinstance(graphql_union, GraphQLUnionType) # For mypy return graphql_union - graphql_types: List[GraphQLObjectType] = [] + graphql_types: list[GraphQLObjectType] = [] for type_ in union.types: graphql_type = self.from_type(type_) diff --git a/strawberry/schema/subscribe.py b/strawberry/schema/subscribe.py index 22052e36cf..5417958b62 100644 --- a/strawberry/schema/subscribe.py +++ b/strawberry/schema/subscribe.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Optional, Type, Union +from collections.abc import AsyncGenerator, AsyncIterator +from typing import TYPE_CHECKING, Optional, Union from graphql import ( ExecutionResult as OriginalExecutionResult, @@ -44,7 +45,7 @@ async def _subscribe( extensions_runner: SchemaExtensionsRunner, process_errors: ProcessErrors, middleware_manager: MiddlewareManager, - execution_context_class: Optional[Type[GraphQLExecutionContext]] = None, + execution_context_class: Optional[type[GraphQLExecutionContext]] = None, ) -> AsyncGenerator[Union[PreExecutionError, ExecutionResult], None]: async with extensions_runner.operation(): if initial_error := await _parse_and_validate_async( @@ -128,7 +129,7 @@ async def subscribe( extensions_runner: SchemaExtensionsRunner, process_errors: ProcessErrors, middleware_manager: MiddlewareManager, - execution_context_class: Optional[Type[GraphQLExecutionContext]] = None, + execution_context_class: Optional[type[GraphQLExecutionContext]] = None, ) -> SubscriptionResult: asyncgen = _subscribe( schema, diff --git a/strawberry/schema/types/base_scalars.py b/strawberry/schema/types/base_scalars.py index 8e73a1383a..4d8a66df23 100644 --- a/strawberry/schema/types/base_scalars.py +++ b/strawberry/schema/types/base_scalars.py @@ -81,4 +81,4 @@ def _verify_void(x: None) -> None: description="Represents NULL values", ) -__all__ = ["Date", "DateTime", "Time", "Decimal", "UUID", "Void"] +__all__ = ["UUID", "Date", "DateTime", "Decimal", "Time", "Void"] diff --git a/strawberry/schema/types/concrete_type.py b/strawberry/schema/types/concrete_type.py index 5507209e59..6a421c26a3 100644 --- a/strawberry/schema/types/concrete_type.py +++ b/strawberry/schema/types/concrete_type.py @@ -1,7 +1,7 @@ from __future__ import annotations import dataclasses -from typing import TYPE_CHECKING, Dict, Union +from typing import TYPE_CHECKING, Union from graphql import GraphQLField, GraphQLInputField, GraphQLType @@ -22,7 +22,7 @@ class ConcreteType: implementation: GraphQLType -TypeMap = Dict[str, ConcreteType] +TypeMap = dict[str, ConcreteType] __all__ = ["ConcreteType", "Field", "GraphQLType", "TypeMap"] diff --git a/strawberry/schema/types/scalar.py b/strawberry/schema/types/scalar.py index 0e89ad4476..d07c55b886 100644 --- a/strawberry/schema/types/scalar.py +++ b/strawberry/schema/types/scalar.py @@ -1,6 +1,5 @@ import datetime import decimal -from typing import Dict, Type from uuid import UUID from graphql import ( @@ -45,11 +44,11 @@ def _make_scalar_definition(scalar_type: GraphQLScalarType) -> ScalarDefinition: ) -def _get_scalar_definition(scalar: Type) -> ScalarDefinition: - return scalar._scalar_definition +def _get_scalar_definition(scalar: type) -> ScalarDefinition: + return scalar._scalar_definition # type: ignore[attr-defined] -DEFAULT_SCALAR_REGISTRY: Dict[object, ScalarDefinition] = { +DEFAULT_SCALAR_REGISTRY: dict[object, ScalarDefinition] = { type(None): _get_scalar_definition(base_scalars.Void), None: _get_scalar_definition(base_scalars.Void), str: _make_scalar_definition(GraphQLString), diff --git a/strawberry/schema_codegen/__init__.py b/strawberry/schema_codegen/__init__.py index 5ce5c1de11..92c018d948 100644 --- a/strawberry/schema_codegen/__init__.py +++ b/strawberry/schema_codegen/__init__.py @@ -3,11 +3,11 @@ import dataclasses import keyword from collections import defaultdict -from typing import TYPE_CHECKING, List, Tuple, Union +from graphlib import TopologicalSorter +from typing import TYPE_CHECKING, Union from typing_extensions import Protocol, TypeAlias import libcst as cst -from graphlib import TopologicalSorter from graphql import ( EnumTypeDefinitionNode, EnumValueDefinitionNode, @@ -42,7 +42,7 @@ class HasDirectives(Protocol): - directives: Tuple[ConstDirectiveNode, ...] + directives: tuple[ConstDirectiveNode, ...] _SCALAR_MAP = { @@ -256,7 +256,7 @@ def _get_field( ) -ArgumentValue: TypeAlias = Union[str, bool, List["ArgumentValue"]] +ArgumentValue: TypeAlias = Union[str, bool, list["ArgumentValue"]] def _get_argument_value(argument_value: ConstValueNode) -> ArgumentValue: diff --git a/strawberry/schema_directive.py b/strawberry/schema_directive.py index f519a7a766..6cf5069a11 100644 --- a/strawberry/schema_directive.py +++ b/strawberry/schema_directive.py @@ -1,6 +1,6 @@ import dataclasses from enum import Enum -from typing import Callable, List, Optional, Type, TypeVar +from typing import Callable, Optional, TypeVar from typing_extensions import dataclass_transform from strawberry.types.field import StrawberryField, field @@ -28,15 +28,15 @@ class Location(Enum): class StrawberrySchemaDirective: python_name: str graphql_name: Optional[str] - locations: List[Location] - fields: List["StrawberryField"] + locations: list[Location] + fields: list["StrawberryField"] description: Optional[str] = None repeatable: bool = False print_definition: bool = True - origin: Optional[Type] = None + origin: Optional[type] = None -T = TypeVar("T", bound=Type) +T = TypeVar("T", bound=type) @dataclass_transform( @@ -46,17 +46,17 @@ class StrawberrySchemaDirective: ) def schema_directive( *, - locations: List[Location], + locations: list[Location], description: Optional[str] = None, name: Optional[str] = None, repeatable: bool = False, print_definition: bool = True, -) -> Callable[..., T]: +) -> Callable[[T], T]: def _wrap(cls: T) -> T: cls = _wrap_dataclass(cls) # type: ignore fields = _get_fields(cls, {}) - cls.__strawberry_directive__ = StrawberrySchemaDirective( + cls.__strawberry_directive__ = StrawberrySchemaDirective( # type: ignore[attr-defined] python_name=cls.__name__, graphql_name=name, locations=locations, diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 6c29f9faeb..b4cdc9d4e8 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -2,14 +2,12 @@ import asyncio import logging +from collections.abc import Awaitable from contextlib import suppress from typing import ( TYPE_CHECKING, Any, - Awaitable, - Dict, Generic, - List, Optional, cast, ) @@ -40,6 +38,7 @@ from strawberry.utils.operation import get_operation_type if TYPE_CHECKING: + from collections.abc import Awaitable from datetime import timedelta from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncWebSocketAdapter @@ -71,8 +70,8 @@ def __init__( self.connection_init_received = False self.connection_acknowledged = False self.connection_timed_out = False - self.operations: Dict[str, Operation[Context, RootValue]] = {} - self.completed_tasks: List[asyncio.Task] = [] + self.operations: dict[str, Operation[Context, RootValue]] = {} + self.completed_tasks: list[asyncio.Task] = [] async def handle(self) -> None: self.on_request_accepted() @@ -343,14 +342,14 @@ class Operation(Generic[Context, RootValue]): """A class encapsulating a single operation with its id. Helps enforce protocol state transition.""" __slots__ = [ + "completed", "handler", "id", + "operation_name", "operation_type", "query", - "variables", - "operation_name", - "completed", "task", + "variables", ] def __init__( @@ -359,7 +358,7 @@ def __init__( id: str, operation_type: OperationType, query: str, - variables: Optional[Dict[str, object]], + variables: Optional[dict[str, object]], operation_name: Optional[str], ) -> None: self.handler = handler diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/types.py b/strawberry/subscriptions/protocols/graphql_transport_ws/types.py index 7e5a804f29..d08f38e971 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/types.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/types.py @@ -1,4 +1,4 @@ -from typing import Dict, List, TypedDict, Union +from typing import TypedDict, Union from typing_extensions import Literal, NotRequired from graphql import GraphQLFormattedError @@ -8,35 +8,35 @@ class ConnectionInitMessage(TypedDict): """Direction: Client -> Server.""" type: Literal["connection_init"] - payload: NotRequired[Union[Dict[str, object], None]] + payload: NotRequired[Union[dict[str, object], None]] class ConnectionAckMessage(TypedDict): """Direction: Server -> Client.""" type: Literal["connection_ack"] - payload: NotRequired[Union[Dict[str, object], None]] + payload: NotRequired[Union[dict[str, object], None]] class PingMessage(TypedDict): """Direction: bidirectional.""" type: Literal["ping"] - payload: NotRequired[Union[Dict[str, object], None]] + payload: NotRequired[Union[dict[str, object], None]] class PongMessage(TypedDict): """Direction: bidirectional.""" type: Literal["pong"] - payload: NotRequired[Union[Dict[str, object], None]] + payload: NotRequired[Union[dict[str, object], None]] class SubscribeMessagePayload(TypedDict): operationName: NotRequired[Union[str, None]] query: str - variables: NotRequired[Union[Dict[str, object], None]] - extensions: NotRequired[Union[Dict[str, object], None]] + variables: NotRequired[Union[dict[str, object], None]] + extensions: NotRequired[Union[dict[str, object], None]] class SubscribeMessage(TypedDict): @@ -48,9 +48,9 @@ class SubscribeMessage(TypedDict): class NextMessagePayload(TypedDict): - errors: NotRequired[List[GraphQLFormattedError]] - data: NotRequired[Union[Dict[str, object], None]] - extensions: NotRequired[Dict[str, object]] + errors: NotRequired[list[GraphQLFormattedError]] + data: NotRequired[Union[dict[str, object], None]] + extensions: NotRequired[dict[str, object]] class NextMessage(TypedDict): @@ -66,7 +66,7 @@ class ErrorMessage(TypedDict): id: str type: Literal["error"] - payload: List[GraphQLFormattedError] + payload: list[GraphQLFormattedError] class CompleteMessage(TypedDict): @@ -89,13 +89,13 @@ class CompleteMessage(TypedDict): __all__ = [ - "ConnectionInitMessage", + "CompleteMessage", "ConnectionAckMessage", + "ConnectionInitMessage", + "ErrorMessage", + "Message", + "NextMessage", "PingMessage", "PongMessage", "SubscribeMessage", - "NextMessage", - "ErrorMessage", - "CompleteMessage", - "Message", ] diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 6f2dcb929d..352d5c5f08 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -1,12 +1,11 @@ from __future__ import annotations import asyncio +from collections.abc import AsyncGenerator from contextlib import suppress from typing import ( TYPE_CHECKING, Any, - AsyncGenerator, - Dict, Generic, Optional, cast, @@ -28,6 +27,8 @@ from strawberry.utils.debug import pretty_print_graphql_operation if TYPE_CHECKING: + from collections.abc import AsyncGenerator + from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncWebSocketAdapter from strawberry.schema import BaseSchema @@ -56,8 +57,8 @@ def __init__( self.keep_alive = keep_alive self.keep_alive_interval = keep_alive_interval self.keep_alive_task: Optional[asyncio.Task] = None - self.subscriptions: Dict[str, AsyncGenerator] = {} - self.tasks: Dict[str, asyncio.Task] = {} + self.subscriptions: dict[str, AsyncGenerator] = {} + self.tasks: dict[str, asyncio.Task] = {} async def handle(self) -> None: try: @@ -164,7 +165,7 @@ async def handle_async_results( operation_id: str, query: str, operation_name: Optional[str], - variables: Optional[Dict[str, object]], + variables: Optional[dict[str, object]], ) -> None: try: agen_or_err = await self.schema.subscribe( diff --git a/strawberry/subscriptions/protocols/graphql_ws/types.py b/strawberry/subscriptions/protocols/graphql_ws/types.py index d29a6209fb..891802866e 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/types.py +++ b/strawberry/subscriptions/protocols/graphql_ws/types.py @@ -1,4 +1,4 @@ -from typing import Dict, List, TypedDict, Union +from typing import TypedDict, Union from typing_extensions import Literal, NotRequired from graphql import GraphQLFormattedError @@ -6,12 +6,12 @@ class ConnectionInitMessage(TypedDict): type: Literal["connection_init"] - payload: NotRequired[Dict[str, object]] + payload: NotRequired[dict[str, object]] class StartMessagePayload(TypedDict): query: str - variables: NotRequired[Dict[str, object]] + variables: NotRequired[dict[str, object]] operationName: NotRequired[str] @@ -32,20 +32,20 @@ class ConnectionTerminateMessage(TypedDict): class ConnectionErrorMessage(TypedDict): type: Literal["connection_error"] - payload: NotRequired[Dict[str, object]] + payload: NotRequired[dict[str, object]] class ConnectionAckMessage(TypedDict): type: Literal["connection_ack"] - payload: NotRequired[Dict[str, object]] + payload: NotRequired[dict[str, object]] class DataMessagePayload(TypedDict): data: object - errors: NotRequired[List[GraphQLFormattedError]] + errors: NotRequired[list[GraphQLFormattedError]] # Non-standard field: - extensions: NotRequired[Dict[str, object]] + extensions: NotRequired[dict[str, object]] class DataMessage(TypedDict): @@ -84,15 +84,15 @@ class ConnectionKeepAliveMessage(TypedDict): __all__ = [ + "CompleteMessage", + "ConnectionAckMessage", + "ConnectionErrorMessage", "ConnectionInitMessage", - "StartMessage", - "StopMessage", + "ConnectionKeepAliveMessage", "ConnectionTerminateMessage", - "ConnectionErrorMessage", - "ConnectionAckMessage", "DataMessage", "ErrorMessage", - "CompleteMessage", - "ConnectionKeepAliveMessage", "OperationMessage", + "StartMessage", + "StopMessage", ] diff --git a/strawberry/test/__init__.py b/strawberry/test/__init__.py index 5f6b0b76e5..c81b9637fb 100644 --- a/strawberry/test/__init__.py +++ b/strawberry/test/__init__.py @@ -1,3 +1,3 @@ from .client import BaseGraphQLTestClient, Body, Response -__all__ = ["Body", "Response", "BaseGraphQLTestClient"] +__all__ = ["BaseGraphQLTestClient", "Body", "Response"] diff --git a/strawberry/test/client.py b/strawberry/test/client.py index 243ac3aadb..7ec027f0fc 100644 --- a/strawberry/test/client.py +++ b/strawberry/test/client.py @@ -4,23 +4,25 @@ import warnings from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Coroutine, Dict, List, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from typing_extensions import Literal, TypedDict if TYPE_CHECKING: + from collections.abc import Coroutine, Mapping + from graphql import GraphQLFormattedError @dataclass class Response: - errors: Optional[List[GraphQLFormattedError]] - data: Optional[Dict[str, object]] - extensions: Optional[Dict[str, object]] + errors: Optional[list[GraphQLFormattedError]] + data: Optional[dict[str, object]] + extensions: Optional[dict[str, object]] class Body(TypedDict, total=False): query: str - variables: Optional[Dict[str, object]] + variables: Optional[dict[str, object]] class BaseGraphQLTestClient(ABC): @@ -35,10 +37,10 @@ def __init__( def query( self, query: str, - variables: Optional[Dict[str, Mapping]] = None, - headers: Optional[Dict[str, object]] = None, + variables: Optional[dict[str, Mapping]] = None, + headers: Optional[dict[str, object]] = None, asserts_errors: Optional[bool] = None, - files: Optional[Dict[str, object]] = None, + files: Optional[dict[str, object]] = None, assert_no_errors: Optional[bool] = True, ) -> Union[Coroutine[Any, Any, Response], Response]: body = self._build_body(query, variables, files) @@ -71,19 +73,19 @@ def query( @abstractmethod def request( self, - body: Dict[str, object], - headers: Optional[Dict[str, object]] = None, - files: Optional[Dict[str, object]] = None, + body: dict[str, object], + headers: Optional[dict[str, object]] = None, + files: Optional[dict[str, object]] = None, ) -> Any: raise NotImplementedError def _build_body( self, query: str, - variables: Optional[Dict[str, Mapping]] = None, - files: Optional[Dict[str, object]] = None, - ) -> Dict[str, object]: - body: Dict[str, object] = {"query": query} + variables: Optional[dict[str, Mapping]] = None, + files: Optional[dict[str, object]] = None, + ) -> dict[str, object]: + body: dict[str, object] = {"query": query} if variables: body["variables"] = variables @@ -103,8 +105,8 @@ def _build_body( @staticmethod def _build_multipart_file_map( - variables: Dict[str, Mapping], files: Dict[str, object] - ) -> Dict[str, List[str]]: + variables: dict[str, Mapping], files: dict[str, object] + ) -> dict[str, list[str]]: """Creates the file mapping between the variables and the files objects passed as key arguments. Args: @@ -158,7 +160,7 @@ def _build_multipart_file_map( # } ``` """ - map: Dict[str, List[str]] = {} + map: dict[str, list[str]] = {} for key, values in variables.items(): reference = key variable_values = values @@ -195,4 +197,4 @@ def _decode(self, response: Any, type: Literal["multipart", "json"]) -> Any: return response.json() -__all__ = ["BaseGraphQLTestClient", "Response", "Body"] +__all__ = ["BaseGraphQLTestClient", "Body", "Response"] diff --git a/strawberry/tools/create_type.py b/strawberry/tools/create_type.py index 8436a5969c..e6cc24edf2 100644 --- a/strawberry/tools/create_type.py +++ b/strawberry/tools/create_type.py @@ -1,5 +1,6 @@ import types -from typing import List, Optional, Sequence, Type +from collections.abc import Sequence +from typing import Optional import strawberry from strawberry.types.field import StrawberryField @@ -7,13 +8,13 @@ def create_type( name: str, - fields: List[StrawberryField], + fields: list[StrawberryField], is_input: bool = False, is_interface: bool = False, description: Optional[str] = None, directives: Optional[Sequence[object]] = (), extend: bool = False, -) -> Type: +) -> type: """Create a Strawberry type from a list of StrawberryFields. Args: diff --git a/strawberry/tools/merge_types.py b/strawberry/tools/merge_types.py index 84095d7086..22524a985b 100644 --- a/strawberry/tools/merge_types.py +++ b/strawberry/tools/merge_types.py @@ -1,13 +1,12 @@ import warnings from collections import Counter from itertools import chain -from typing import Tuple import strawberry from strawberry.types.base import has_object_definition -def merge_types(name: str, types: Tuple[type, ...]) -> type: +def merge_types(name: str, types: tuple[type, ...]) -> type: """Merge multiple Strawberry types into one. For example, given two queries `A` and `B`, one can merge them into a diff --git a/strawberry/types/__init__.py b/strawberry/types/__init__.py index 65f055865c..ba3b5158b9 100644 --- a/strawberry/types/__init__.py +++ b/strawberry/types/__init__.py @@ -5,9 +5,9 @@ __all__ = [ "ExecutionContext", "ExecutionResult", - "SubscriptionExecutionResult", "Info", "Info", + "SubscriptionExecutionResult", "get_object_definition", "has_object_definition", ] diff --git a/strawberry/types/arguments.py b/strawberry/types/arguments.py index a48f0e417d..ff5930f431 100644 --- a/strawberry/types/arguments.py +++ b/strawberry/types/arguments.py @@ -2,18 +2,16 @@ import inspect import warnings +from collections.abc import Iterable, Mapping from typing import ( TYPE_CHECKING, + Annotated, Any, - Dict, - Iterable, - List, - Mapping, Optional, Union, cast, ) -from typing_extensions import Annotated, get_args, get_origin +from typing_extensions import get_args, get_origin from strawberry.annotation import StrawberryAnnotation from strawberry.exceptions import MultipleStrawberryArgumentsError, UnsupportedTypeError @@ -34,7 +32,7 @@ from strawberry.types.scalar import ScalarDefinition, ScalarWrapper -DEPRECATED_NAMES: Dict[str, str] = { +DEPRECATED_NAMES: dict[str, str] = { "UNSET": ( "importing `UNSET` from `strawberry.arguments` is deprecated, " "import instead from `strawberry` or from `strawberry.types.unset`" @@ -140,7 +138,7 @@ def is_graphql_generic(self) -> bool: def convert_argument( value: object, type_: Union[StrawberryType, type], - scalar_registry: Dict[object, Union[ScalarWrapper, ScalarDefinition]], + scalar_registry: dict[object, Union[ScalarWrapper, ScalarDefinition]], config: StrawberryConfig, ) -> object: # TODO: move this somewhere else and make it first class @@ -197,11 +195,11 @@ def convert_argument( def convert_arguments( - value: Dict[str, Any], - arguments: List[StrawberryArgument], - scalar_registry: Dict[object, Union[ScalarWrapper, ScalarDefinition]], + value: dict[str, Any], + arguments: list[StrawberryArgument], + scalar_registry: dict[object, Union[ScalarWrapper, ScalarDefinition]], config: StrawberryConfig, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Converts a nested dictionary to a dictionary of actual types. It deals with conversion of input types to proper dataclasses and @@ -283,9 +281,9 @@ def __getattr__(name: str) -> Any: # TODO: check exports __all__ = [ # noqa: F822 + "UNSET", # for backwards compatibility # type: ignore "StrawberryArgument", "StrawberryArgumentAnnotation", - "UNSET", # for backwards compatibility # type: ignore "argument", "is_unset", # for backwards compatibility # type: ignore ] diff --git a/strawberry/types/auto.py b/strawberry/types/auto.py index 9de85b1e34..7ee49b3d4d 100644 --- a/strawberry/types/auto.py +++ b/strawberry/types/auto.py @@ -1,7 +1,7 @@ from __future__ import annotations -from typing import Any, Optional, Union, cast -from typing_extensions import Annotated, get_args, get_origin +from typing import Annotated, Any, Optional, Union, cast +from typing_extensions import get_args, get_origin from strawberry.annotation import StrawberryAnnotation from strawberry.types.base import StrawberryType diff --git a/strawberry/types/base.py b/strawberry/types/base.py index 9c235093d7..636ae3cab9 100644 --- a/strawberry/types/base.py +++ b/strawberry/types/base.py @@ -7,12 +7,7 @@ Any, Callable, ClassVar, - Dict, - List, - Mapping, Optional, - Sequence, - Type, TypeVar, Union, overload, @@ -25,6 +20,7 @@ from strawberry.utils.typing import is_generic as is_type_generic if TYPE_CHECKING: + from collections.abc import Mapping, Sequence from typing_extensions import TypeGuard from graphql import GraphQLAbstractType, GraphQLResolveInfo @@ -44,7 +40,7 @@ class StrawberryType(ABC): """ @property - def type_params(self) -> List[TypeVar]: + def type_params(self) -> list[TypeVar]: return [] @property @@ -55,9 +51,9 @@ def is_one_of(self) -> bool: def copy_with( self, type_var_map: Mapping[ - str, Union[StrawberryType, Type[WithStrawberryObjectDefinition]] + str, Union[StrawberryType, type[WithStrawberryObjectDefinition]] ], - ) -> Union[StrawberryType, Type[WithStrawberryObjectDefinition]]: + ) -> Union[StrawberryType, type[WithStrawberryObjectDefinition]]: raise NotImplementedError() @property @@ -93,7 +89,7 @@ def __hash__(self) -> int: class StrawberryContainer(StrawberryType): def __init__( - self, of_type: Union[StrawberryType, Type[WithStrawberryObjectDefinition], type] + self, of_type: Union[StrawberryType, type[WithStrawberryObjectDefinition], type] ) -> None: self.of_type = of_type @@ -110,7 +106,7 @@ def __eq__(self, other: object) -> bool: return super().__eq__(other) @property - def type_params(self) -> List[TypeVar]: + def type_params(self) -> list[TypeVar]: if has_object_definition(self.of_type): parameters = getattr(self.of_type, "__parameters__", None) @@ -125,7 +121,7 @@ def type_params(self) -> List[TypeVar]: def copy_with( self, type_var_map: Mapping[ - str, Union[StrawberryType, Type[WithStrawberryObjectDefinition]] + str, Union[StrawberryType, type[WithStrawberryObjectDefinition]] ], ) -> Self: of_type_copy = self.of_type @@ -180,7 +176,7 @@ def has_generic(self, type_var: TypeVar) -> bool: return self.type_var == type_var @property - def type_params(self) -> List[TypeVar]: + def type_params(self) -> list[TypeVar]: return [self.type_var] def __eq__(self, other: object) -> bool: @@ -201,7 +197,7 @@ class WithStrawberryObjectDefinition(Protocol): def has_object_definition( obj: Any, -) -> TypeGuard[Type[WithStrawberryObjectDefinition]]: +) -> TypeGuard[type[WithStrawberryObjectDefinition]]: if hasattr(obj, "__strawberry_definition__"): return True # TODO: Generics remove dunder members here, so we inject it here. @@ -254,9 +250,9 @@ class StrawberryObjectDefinition(StrawberryType): name: str is_input: bool is_interface: bool - origin: Type[Any] + origin: type[Any] description: Optional[str] - interfaces: List[StrawberryObjectDefinition] + interfaces: list[StrawberryObjectDefinition] extend: bool directives: Optional[Sequence[object]] is_type_of: Optional[Callable[[Any, GraphQLResolveInfo], bool]] @@ -264,7 +260,7 @@ class StrawberryObjectDefinition(StrawberryType): Callable[[Any, GraphQLResolveInfo, GraphQLAbstractType], str] ] - fields: List[StrawberryField] + fields: list[StrawberryField] concrete_of: Optional[StrawberryObjectDefinition] = None """Concrete implementations of Generic TypeDefinitions fill this in""" @@ -296,7 +292,7 @@ def resolve_generic(self, wrapped_cls: type) -> type: def copy_with( self, type_var_map: Mapping[str, Union[StrawberryType, type]] - ) -> Type[WithStrawberryObjectDefinition]: + ) -> type[WithStrawberryObjectDefinition]: fields = [field.copy_with(type_var_map) for field in self.fields] new_type_definition = StrawberryObjectDefinition( @@ -353,7 +349,7 @@ def is_specialized_generic(self) -> bool: ) @property - def specialized_type_var_map(self) -> Optional[Dict[str, type]]: + def specialized_type_var_map(self) -> Optional[dict[str, type]]: return get_specialized_type_var_map(self.origin) @property @@ -361,14 +357,14 @@ def is_object_type(self) -> bool: return not self.is_input and not self.is_interface @property - def type_params(self) -> List[TypeVar]: - type_params: List[TypeVar] = [] + def type_params(self) -> list[TypeVar]: + type_params: list[TypeVar] = [] for field in self.fields: type_params.extend(field.type_params) return type_params - def is_implemented_by(self, root: Type[WithStrawberryObjectDefinition]) -> bool: + def is_implemented_by(self, root: type[WithStrawberryObjectDefinition]) -> bool: # TODO: Support dicts if isinstance(root, dict): raise NotImplementedError diff --git a/strawberry/types/enum.py b/strawberry/types/enum.py index 4ec01be188..2aaff852e4 100644 --- a/strawberry/types/enum.py +++ b/strawberry/types/enum.py @@ -1,11 +1,9 @@ import dataclasses +from collections.abc import Iterable, Mapping from enum import EnumMeta from typing import ( Any, Callable, - Iterable, - List, - Mapping, Optional, TypeVar, Union, @@ -29,7 +27,7 @@ class EnumValue: class EnumDefinition(StrawberryType): wrapped_cls: EnumMeta name: str - values: List[EnumValue] + values: list[EnumValue] description: Optional[str] directives: Iterable[object] = () @@ -228,4 +226,4 @@ def wrap(cls: EnumType) -> EnumType: return wrap(cls) -__all__ = ["EnumValue", "EnumDefinition", "EnumValueDefinition", "enum", "enum_value"] +__all__ = ["EnumDefinition", "EnumValue", "EnumValueDefinition", "enum", "enum_value"] diff --git a/strawberry/types/execution.py b/strawberry/types/execution.py index 94f3982a6f..d28440246a 100644 --- a/strawberry/types/execution.py +++ b/strawberry/types/execution.py @@ -4,12 +4,7 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - Iterable, - List, Optional, - Tuple, - Type, runtime_checkable, ) from typing_extensions import Protocol, TypedDict @@ -19,6 +14,7 @@ from strawberry.utils.operation import get_first_operation, get_operation_type if TYPE_CHECKING: + from collections.abc import Iterable from typing_extensions import NotRequired from graphql import ASTValidationRule @@ -37,12 +33,12 @@ class ExecutionContext: schema: Schema allowed_operations: Iterable[OperationType] context: Any = None - variables: Optional[Dict[str, Any]] = None + variables: Optional[dict[str, Any]] = None parse_options: ParseOptions = dataclasses.field( default_factory=lambda: ParseOptions() ) root_value: Optional[Any] = None - validation_rules: Tuple[Type[ASTValidationRule], ...] = dataclasses.field( + validation_rules: tuple[type[ASTValidationRule], ...] = dataclasses.field( default_factory=lambda: tuple(specified_rules) ) @@ -52,9 +48,9 @@ class ExecutionContext: # Values that get populated during the GraphQL execution so that they can be # accessed by extensions graphql_document: Optional[DocumentNode] = None - errors: Optional[List[GraphQLError]] = None + errors: Optional[list[GraphQLError]] = None result: Optional[GraphQLExecutionResult] = None - extensions_results: Dict[str, Any] = dataclasses.field(default_factory=dict) + extensions_results: dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self, provided_operation_name: str | None) -> None: self._provided_operation_name = provided_operation_name @@ -91,9 +87,9 @@ def _get_first_operation(self) -> Optional[OperationDefinitionNode]: @dataclasses.dataclass class ExecutionResult: - data: Optional[Dict[str, Any]] - errors: Optional[List[GraphQLError]] - extensions: Optional[Dict[str, Any]] = None + data: Optional[dict[str, Any]] + errors: Optional[list[GraphQLError]] + extensions: Optional[dict[str, Any]] = None @dataclasses.dataclass diff --git a/strawberry/types/field.py b/strawberry/types/field.py index 4c1fb8c812..2c5c95440f 100644 --- a/strawberry/types/field.py +++ b/strawberry/types/field.py @@ -4,18 +4,13 @@ import copy import dataclasses import sys +from collections.abc import Awaitable, Coroutine, Mapping, Sequence from functools import cached_property from typing import ( TYPE_CHECKING, Any, - Awaitable, Callable, - Coroutine, - List, - Mapping, Optional, - Sequence, - Type, TypeVar, Union, cast, @@ -83,17 +78,17 @@ def __init__( python_name: Optional[str] = None, graphql_name: Optional[str] = None, type_annotation: Optional[StrawberryAnnotation] = None, - origin: Optional[Union[Type, Callable, staticmethod, classmethod]] = None, + origin: Optional[Union[type, Callable, staticmethod, classmethod]] = None, is_subscription: bool = False, description: Optional[str] = None, base_resolver: Optional[StrawberryResolver] = None, - permission_classes: List[Type[BasePermission]] = (), # type: ignore + permission_classes: list[type[BasePermission]] = (), # type: ignore default: object = dataclasses.MISSING, default_factory: Union[Callable[[], Any], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, deprecation_reason: Optional[str] = None, directives: Sequence[object] = (), - extensions: List[FieldExtension] = (), # type: ignore + extensions: list[FieldExtension] = (), # type: ignore ) -> None: # basic fields are fields with no provided resolver is_basic_field = not base_resolver @@ -124,7 +119,7 @@ def __init__( self.description: Optional[str] = description self.origin = origin - self._arguments: Optional[List[StrawberryArgument]] = None + self._arguments: Optional[list[StrawberryArgument]] = None self._base_resolver: Optional[StrawberryResolver] = None if base_resolver is not None: self.base_resolver = base_resolver @@ -142,9 +137,9 @@ def __init__( self.is_subscription = is_subscription - self.permission_classes: List[Type[BasePermission]] = list(permission_classes) + self.permission_classes: list[type[BasePermission]] = list(permission_classes) self.directives = list(directives) - self.extensions: List[FieldExtension] = list(extensions) + self.extensions: list[FieldExtension] = list(extensions) # Automatically add the permissions extension if len(self.permission_classes): @@ -213,7 +208,7 @@ def __call__(self, resolver: _RESOLVER_TYPE) -> Self: return self def get_result( - self, source: Any, info: Optional[Info], args: List[Any], kwargs: Any + self, source: Any, info: Optional[Info], args: list[Any], kwargs: Any ) -> Union[Awaitable[Any], Any]: """Calls the resolver defined for the StrawberryField. @@ -238,14 +233,14 @@ def is_basic_field(self) -> bool: return not self.base_resolver and not self.extensions @property - def arguments(self) -> List[StrawberryArgument]: + def arguments(self) -> list[StrawberryArgument]: if self._arguments is None: self._arguments = self.base_resolver.arguments if self.base_resolver else [] return self._arguments @arguments.setter - def arguments(self, value: List[StrawberryArgument]) -> None: + def arguments(self, value: list[StrawberryArgument]) -> None: self._arguments = value @property @@ -299,7 +294,7 @@ def type( self, ) -> Union[ # type: ignore [valid-type] StrawberryType, - Type[WithStrawberryObjectDefinition], + type[WithStrawberryObjectDefinition], Literal[UNRESOLVED], ]: return self.resolve_type() @@ -316,7 +311,7 @@ def type(self, type_: Any) -> None: # TODO: add this to arguments (and/or move it to StrawberryType) @property - def type_params(self) -> List[TypeVar]: + def type_params(self) -> list[TypeVar]: if has_object_definition(self.type): parameters = getattr(self.type, "__parameters__", None) @@ -334,7 +329,7 @@ def resolve_type( type_definition: Optional[StrawberryObjectDefinition] = None, ) -> Union[ # type: ignore [valid-type] StrawberryType, - Type[WithStrawberryObjectDefinition], + type[WithStrawberryObjectDefinition], Literal[UNRESOLVED], ]: # We return UNRESOLVED by default, which means this case will raise a @@ -385,7 +380,7 @@ def copy_with( new_field = copy.copy(self) override_type: Optional[ - Union[StrawberryType, Type[WithStrawberryObjectDefinition]] + Union[StrawberryType, type[WithStrawberryObjectDefinition]] ] = None type_ = self.resolve_type() if has_object_definition(type_): @@ -431,13 +426,13 @@ def field( is_subscription: bool = False, description: Optional[str] = None, init: Literal[False] = False, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> T: ... @@ -450,13 +445,13 @@ def field( is_subscription: bool = False, description: Optional[str] = None, init: Literal[False] = False, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> T: ... @@ -468,13 +463,13 @@ def field( is_subscription: bool = False, description: Optional[str] = None, init: Literal[True] = True, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> Any: ... @@ -486,13 +481,13 @@ def field( name: Optional[str] = None, is_subscription: bool = False, description: Optional[str] = None, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> StrawberryField: ... @@ -504,13 +499,13 @@ def field( name: Optional[str] = None, is_subscription: bool = False, description: Optional[str] = None, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> StrawberryField: ... @@ -521,13 +516,13 @@ def field( name: Optional[str] = None, is_subscription: bool = False, description: Optional[str] = None, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, # This init parameter is used by PyRight to determine whether this field # is added in the constructor or not. It is not used to change diff --git a/strawberry/types/fields/resolver.py b/strawberry/types/fields/resolver.py index fbe69a3844..dfa31e8a52 100644 --- a/strawberry/types/fields/resolver.py +++ b/strawberry/types/fields/resolver.py @@ -8,20 +8,17 @@ from inspect import isasyncgenfunction from typing import ( TYPE_CHECKING, + Annotated, Any, Callable, - Dict, Generic, - List, - Mapping, NamedTuple, Optional, - Tuple, TypeVar, Union, cast, ) -from typing_extensions import Annotated, Protocol, get_origin +from typing_extensions import Protocol, get_origin from strawberry.annotation import StrawberryAnnotation from strawberry.exceptions import ( @@ -36,6 +33,7 @@ if TYPE_CHECKING: import builtins + from collections.abc import Mapping class Parameter(inspect.Parameter): @@ -63,7 +61,7 @@ class Signature(inspect.Signature): class ReservedParameterSpecification(Protocol): def find( self, - parameters: Tuple[inspect.Parameter, ...], + parameters: tuple[inspect.Parameter, ...], resolver: StrawberryResolver[Any], ) -> Optional[inspect.Parameter]: """Finds the reserved parameter from ``parameters``.""" @@ -74,7 +72,7 @@ class ReservedName(NamedTuple): def find( self, - parameters: Tuple[inspect.Parameter, ...], + parameters: tuple[inspect.Parameter, ...], resolver: StrawberryResolver[Any], ) -> Optional[inspect.Parameter]: del resolver @@ -86,7 +84,7 @@ class ReservedNameBoundParameter(NamedTuple): def find( self, - parameters: Tuple[inspect.Parameter, ...], + parameters: tuple[inspect.Parameter, ...], resolver: StrawberryResolver[Any], ) -> Optional[inspect.Parameter]: del resolver @@ -109,7 +107,7 @@ class ReservedType(NamedTuple): def find( self, - parameters: Tuple[inspect.Parameter, ...], + parameters: tuple[inspect.Parameter, ...], resolver: StrawberryResolver[Any], ) -> Optional[inspect.Parameter]: # Go through all the types even after we've found one so we can @@ -181,7 +179,7 @@ def is_reserved_type(self, other: builtins.type) -> bool: class StrawberryResolver(Generic[T]): - RESERVED_PARAMSPEC: Tuple[ReservedParameterSpecification, ...] = ( + RESERVED_PARAMSPEC: tuple[ReservedParameterSpecification, ...] = ( SELF_PARAMSPEC, CLS_PARAMSPEC, ROOT_PARAMSPEC, @@ -218,7 +216,7 @@ def signature(self) -> inspect.Signature: @cached_property def strawberry_annotations( self, - ) -> Dict[inspect.Parameter, Union[StrawberryAnnotation, None]]: + ) -> dict[inspect.Parameter, Union[StrawberryAnnotation, None]]: return { p: ( StrawberryAnnotation(p.annotation, namespace=self._namespace) @@ -231,13 +229,13 @@ def strawberry_annotations( @cached_property def reserved_parameters( self, - ) -> Dict[ReservedParameterSpecification, Optional[inspect.Parameter]]: + ) -> dict[ReservedParameterSpecification, Optional[inspect.Parameter]]: """Mapping of reserved parameter specification to parameter.""" parameters = tuple(self.signature.parameters.values()) return {spec: spec.find(parameters, self) for spec in self.RESERVED_PARAMSPEC} @cached_property - def arguments(self) -> List[StrawberryArgument]: + def arguments(self) -> list[StrawberryArgument]: """Resolver arguments exposed in the GraphQL Schema.""" root_parameter = self.reserved_parameters.get(ROOT_PARAMSPEC) parent_parameter = self.reserved_parameters.get(PARENT_PARAMSPEC) @@ -258,8 +256,8 @@ def arguments(self) -> List[StrawberryArgument]: parameters = self.signature.parameters.values() reserved_parameters = set(self.reserved_parameters.values()) - missing_annotations: List[str] = [] - arguments: List[StrawberryArgument] = [] + missing_annotations: list[str] = [] + arguments: list[StrawberryArgument] = [] user_parameters = (p for p in parameters if p not in reserved_parameters) for param in user_parameters: @@ -301,7 +299,7 @@ def name(self) -> str: # TODO: consider deprecating @cached_property - def annotations(self) -> Dict[str, object]: + def annotations(self) -> dict[str, object]: """Annotations for the resolver. Does not include special args defined in `RESERVED_PARAMSPEC` (e.g. self, root, @@ -387,7 +385,7 @@ def copy_with( return other @cached_property - def _namespace(self) -> Dict[str, Any]: + def _namespace(self) -> dict[str, Any]: return sys.modules[self._unbound_wrapped_func.__module__].__dict__ @cached_property diff --git a/strawberry/types/graphql.py b/strawberry/types/graphql.py index d0bbd0237d..586a0b22c6 100644 --- a/strawberry/types/graphql.py +++ b/strawberry/types/graphql.py @@ -1,7 +1,7 @@ from __future__ import annotations import enum -from typing import TYPE_CHECKING, Set +from typing import TYPE_CHECKING if TYPE_CHECKING: from strawberry.http.types import HTTPMethod @@ -13,7 +13,7 @@ class OperationType(enum.Enum): SUBSCRIPTION = "subscription" @staticmethod - def from_http(method: HTTPMethod) -> Set[OperationType]: + def from_http(method: HTTPMethod) -> set[OperationType]: if method == "GET": return { OperationType.QUERY, diff --git a/strawberry/types/info.py b/strawberry/types/info.py index b5a95c924a..47f46e10d8 100644 --- a/strawberry/types/info.py +++ b/strawberry/types/info.py @@ -6,12 +6,8 @@ from typing import ( TYPE_CHECKING, Any, - Dict, Generic, - List, Optional, - Tuple, - Type, Union, ) from typing_extensions import TypeVar @@ -72,7 +68,7 @@ def hello(self, info: strawberry.Info[str, str]) -> str: _raw_info: GraphQLResolveInfo _field: StrawberryField - def __class_getitem__(cls, types: Union[type, Tuple[type, ...]]) -> Type[Info]: + def __class_getitem__(cls, types: Union[type, tuple[type, ...]]) -> type[Info]: """Workaround for when passing only one type. Python doesn't yet support directly passing only one type to a generic class @@ -97,7 +93,7 @@ def schema(self) -> Schema: return self._raw_info.schema._strawberry_schema # type: ignore @property - def field_nodes(self) -> List[FieldNode]: # deprecated + def field_nodes(self) -> list[FieldNode]: # deprecated warnings.warn( "`info.field_nodes` is deprecated, use `selected_fields` instead", DeprecationWarning, @@ -107,7 +103,7 @@ def field_nodes(self) -> List[FieldNode]: # deprecated return self._raw_info.field_nodes @cached_property - def selected_fields(self) -> List[Selection]: + def selected_fields(self) -> list[Selection]: """The fields that were selected on the current field's type.""" info = self._raw_info return convert_selections(info, info.field_nodes) @@ -123,14 +119,14 @@ def root_value(self) -> RootValueType: return self._raw_info.root_value @property - def variable_values(self) -> Dict[str, Any]: + def variable_values(self) -> dict[str, Any]: """The variable values passed to the query execution.""" return self._raw_info.variable_values @property def return_type( self, - ) -> Optional[Union[Type[WithStrawberryObjectDefinition], StrawberryType]]: + ) -> Optional[Union[type[WithStrawberryObjectDefinition], StrawberryType]]: """The return type of the current field being resolved.""" return self._field.type diff --git a/strawberry/types/lazy_type.py b/strawberry/types/lazy_type.py index 776ad52d4a..e08268f949 100644 --- a/strawberry/types/lazy_type.py +++ b/strawberry/types/lazy_type.py @@ -9,8 +9,6 @@ ForwardRef, Generic, Optional, - Tuple, - Type, TypeVar, Union, cast, @@ -36,7 +34,7 @@ class LazyType(Generic[TypeName, Module]): module: str package: Optional[str] = None - def __class_getitem__(cls, params: Tuple[str, str]) -> "Self": + def __class_getitem__(cls, params: tuple[str, str]) -> "Self": warnings.warn( ( "LazyType is deprecated, use " @@ -61,7 +59,7 @@ def __class_getitem__(cls, params: Tuple[str, str]) -> "Self": def __or__(self, other: Other) -> object: return Union[self, other] - def resolve_type(self) -> Type[Any]: + def resolve_type(self) -> type[Any]: module = importlib.import_module(self.module, self.package) main_module = sys.modules.get("__main__", None) if main_module: @@ -84,7 +82,7 @@ def resolve_type(self) -> Type[Any]: return module.__dict__[self.type_name] # this empty call method allows LazyTypes to be used in generic types - # for example: List[LazyType["A", "module"]] + # for example: list[LazyType["A", "module"]] def __call__(self) -> None: # pragma: no cover return None diff --git a/strawberry/types/mutation.py b/strawberry/types/mutation.py index bf0bb2b792..183240c213 100644 --- a/strawberry/types/mutation.py +++ b/strawberry/types/mutation.py @@ -5,11 +5,7 @@ TYPE_CHECKING, Any, Callable, - List, - Mapping, Optional, - Sequence, - Type, Union, overload, ) @@ -25,6 +21,7 @@ ) if TYPE_CHECKING: + from collections.abc import Mapping, Sequence from typing_extensions import Literal from strawberry.extensions.field_extension import FieldExtension @@ -41,13 +38,13 @@ def mutation( name: Optional[str] = None, description: Optional[str] = None, init: Literal[False] = False, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> T: ... @@ -59,13 +56,13 @@ def mutation( name: Optional[str] = None, description: Optional[str] = None, init: Literal[False] = False, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> T: ... @@ -76,13 +73,13 @@ def mutation( name: Optional[str] = None, description: Optional[str] = None, init: Literal[True] = True, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> Any: ... @@ -93,13 +90,13 @@ def mutation( *, name: Optional[str] = None, description: Optional[str] = None, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> StrawberryField: ... @@ -110,13 +107,13 @@ def mutation( *, name: Optional[str] = None, description: Optional[str] = None, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> StrawberryField: ... @@ -126,13 +123,13 @@ def mutation( *, name: Optional[str] = None, description: Optional[str] = None, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, # This init parameter is used by PyRight to determine whether this field # is added in the constructor or not. It is not used to change @@ -201,13 +198,13 @@ def subscription( name: Optional[str] = None, description: Optional[str] = None, init: Literal[False] = False, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> T: ... @@ -219,13 +216,13 @@ def subscription( name: Optional[str] = None, description: Optional[str] = None, init: Literal[False] = False, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> T: ... @@ -236,13 +233,13 @@ def subscription( name: Optional[str] = None, description: Optional[str] = None, init: Literal[True] = True, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> Any: ... @@ -253,13 +250,13 @@ def subscription( *, name: Optional[str] = None, description: Optional[str] = None, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> StrawberryField: ... @@ -270,13 +267,13 @@ def subscription( *, name: Optional[str] = None, description: Optional[str] = None, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, ) -> StrawberryField: ... @@ -286,13 +283,13 @@ def subscription( *, name: Optional[str] = None, description: Optional[str] = None, - permission_classes: Optional[List[Type[BasePermission]]] = None, + permission_classes: Optional[list[type[BasePermission]]] = None, deprecation_reason: Optional[str] = None, default: Any = dataclasses.MISSING, default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, metadata: Optional[Mapping[Any, Any]] = None, directives: Optional[Sequence[object]] = (), - extensions: Optional[List[FieldExtension]] = None, + extensions: Optional[list[FieldExtension]] = None, graphql_type: Optional[Any] = None, init: Literal[True, False, None] = None, ) -> Any: diff --git a/strawberry/types/nodes.py b/strawberry/types/nodes.py index 20a8407a8e..80da9092fd 100644 --- a/strawberry/types/nodes.py +++ b/strawberry/types/nodes.py @@ -12,7 +12,7 @@ from __future__ import annotations import dataclasses -from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, List, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from graphql.language import FieldNode as GQLFieldNode from graphql.language import FragmentSpreadNode as GQLFragmentSpreadNode @@ -22,12 +22,14 @@ from graphql.language import VariableNode as GQLVariableNode if TYPE_CHECKING: + from collections.abc import Collection, Iterable + from graphql import GraphQLResolveInfo from graphql.language import ArgumentNode as GQLArgumentNode from graphql.language import DirectiveNode as GQLDirectiveNode from graphql.language import ValueNode as GQLValueNode -Arguments = Dict[str, Any] -Directives = Dict[str, Arguments] +Arguments = dict[str, Any] +Directives = dict[str, Arguments] Selection = Union["SelectedField", "FragmentSpread", "InlineFragment"] @@ -62,9 +64,9 @@ def convert_directives( def convert_selections( info: GraphQLResolveInfo, field_nodes: Collection[GQLFieldNode] -) -> List[Selection]: +) -> list[Selection]: """Return typed `Selection` based on node type.""" - selections: List[Selection] = [] + selections: list[Selection] = [] for node in field_nodes: if isinstance(node, GQLFieldNode): selections.append(SelectedField.from_node(info, node)) @@ -85,7 +87,7 @@ class FragmentSpread: name: str type_condition: str directives: Directives - selections: List[Selection] + selections: list[Selection] @classmethod def from_node( @@ -111,7 +113,7 @@ class InlineFragment: """Wrapper for a InlineFragmentNode.""" type_condition: str - selections: List[Selection] + selections: list[Selection] directives: Directives @classmethod @@ -136,7 +138,7 @@ class SelectedField: name: str directives: Directives arguments: Arguments - selections: List[Selection] + selections: list[Selection] alias: Optional[str] = None @classmethod @@ -152,4 +154,4 @@ def from_node(cls, info: GraphQLResolveInfo, node: GQLFieldNode) -> SelectedFiel ) -__all__ = ["convert_selections", "FragmentSpread", "InlineFragment", "SelectedField"] +__all__ = ["FragmentSpread", "InlineFragment", "SelectedField", "convert_selections"] diff --git a/strawberry/types/object_type.py b/strawberry/types/object_type.py index c6120763aa..8160c79f6f 100644 --- a/strawberry/types/object_type.py +++ b/strawberry/types/object_type.py @@ -1,15 +1,13 @@ +import builtins import dataclasses import inspect import sys import types +from collections.abc import Sequence from typing import ( Any, Callable, - Dict, - List, Optional, - Sequence, - Type, TypeVar, Union, overload, @@ -29,11 +27,11 @@ from .field import StrawberryField, field from .type_resolver import _get_fields -T = TypeVar("T", bound=Type) +T = TypeVar("T", bound=builtins.type) -def _get_interfaces(cls: Type[Any]) -> List[StrawberryObjectDefinition]: - interfaces: List[StrawberryObjectDefinition] = [] +def _get_interfaces(cls: builtins.type[Any]) -> list[StrawberryObjectDefinition]: + interfaces: list[StrawberryObjectDefinition] = [] for base in cls.__mro__[1:]: # Exclude current class type_definition = get_object_definition(base) if type_definition and type_definition.is_interface: @@ -42,7 +40,7 @@ def _get_interfaces(cls: Type[Any]) -> List[StrawberryObjectDefinition]: return interfaces -def _check_field_annotations(cls: Type[Any]) -> None: +def _check_field_annotations(cls: builtins.type[Any]) -> None: """Are any of the dataclass Fields missing type annotations? This is similar to the check that dataclasses do during creation, but allows us to @@ -100,12 +98,12 @@ def _check_field_annotations(cls: Type[Any]) -> None: raise MissingFieldAnnotationError(field_name, cls) -def _wrap_dataclass(cls: Type[T]) -> Type[T]: +def _wrap_dataclass(cls: builtins.type[T]) -> builtins.type[T]: """Wrap a strawberry.type class with a dataclass and check for any issues before doing so.""" # Ensure all Fields have been properly type-annotated _check_field_annotations(cls) - dclass_kwargs: Dict[str, bool] = {} + dclass_kwargs: dict[str, bool] = {} # Python 3.10 introduces the kw_only param. If we're on an older version # then generate our own custom init function @@ -133,7 +131,7 @@ def _process_type( description: Optional[str] = None, directives: Optional[Sequence[object]] = (), extend: bool = False, - original_type_annotations: Optional[Dict[str, Any]] = None, + original_type_annotations: Optional[dict[str, Any]] = None, ) -> T: name = name or to_camel_case(cls.__name__) original_type_annotations = original_type_annotations or {} @@ -143,7 +141,7 @@ def _process_type( is_type_of = getattr(cls, "is_type_of", None) resolve_type = getattr(cls, "resolve_type", None) - cls.__strawberry_definition__ = StrawberryObjectDefinition( + cls.__strawberry_definition__ = StrawberryObjectDefinition( # type: ignore[attr-defined] name=name, is_input=is_input, is_interface=is_interface, @@ -159,7 +157,7 @@ def _process_type( # TODO: remove when deprecating _type_definition DeprecatedDescriptor( DEPRECATION_MESSAGES._TYPE_DEFINITION, - cls.__strawberry_definition__, + cls.__strawberry_definition__, # type: ignore[attr-defined] "_type_definition", ).inject(cls) @@ -279,7 +277,7 @@ def wrap(cls: T) -> T: # >>> class Query: # >>> a: int = strawberry.field(graphql_type=str) # so we need to extract the information before running `_wrap_dataclass` - original_type_annotations: Dict[str, Any] = {} + original_type_annotations: dict[str, Any] = {} annotations = getattr(cls, "__annotations__", {}) @@ -460,7 +458,7 @@ class MyNode: ) -def asdict(obj: Any) -> Dict[str, object]: +def asdict(obj: Any) -> dict[str, object]: """Convert a strawberry object into a dictionary. This wraps the dataclasses.asdict function to strawberry. @@ -489,8 +487,8 @@ class User: __all__ = [ "StrawberryObjectDefinition", + "asdict", "input", "interface", "type", - "asdict", ] diff --git a/strawberry/types/private.py b/strawberry/types/private.py index cd35d6ac0d..238bbde2a5 100644 --- a/strawberry/types/private.py +++ b/strawberry/types/private.py @@ -1,5 +1,4 @@ -from typing import TypeVar -from typing_extensions import Annotated +from typing import Annotated, TypeVar from strawberry.utils.typing import type_has_annotation diff --git a/strawberry/types/scalar.py b/strawberry/types/scalar.py index c5adf22ae8..1bb1fefa1e 100644 --- a/strawberry/types/scalar.py +++ b/strawberry/types/scalar.py @@ -6,8 +6,6 @@ TYPE_CHECKING, Any, Callable, - Iterable, - Mapping, NewType, Optional, TypeVar, @@ -20,6 +18,8 @@ from strawberry.utils.str_converters import to_camel_case if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + from graphql import GraphQLScalarType diff --git a/strawberry/types/type_resolver.py b/strawberry/types/type_resolver.py index 975980880e..cabedaa6d9 100644 --- a/strawberry/types/type_resolver.py +++ b/strawberry/types/type_resolver.py @@ -2,7 +2,7 @@ import dataclasses import sys -from typing import Any, Dict, List, Type +from typing import Any from strawberry.annotation import StrawberryAnnotation from strawberry.exceptions import ( @@ -17,8 +17,8 @@ def _get_fields( - cls: Type[Any], original_type_annotations: Dict[str, Type[Any]] -) -> List[StrawberryField]: + cls: type[Any], original_type_annotations: dict[str, type[Any]] +) -> list[StrawberryField]: """Get all the strawberry fields off a strawberry.type cls. This function returns a list of StrawberryFields (one for each field item), while @@ -54,7 +54,7 @@ class if one is not set by either using an explicit strawberry.field(name=...) o passing a named function (i.e. not an anonymous lambda) to strawberry.field (typically as a decorator). """ - fields: Dict[str, StrawberryField] = {} + fields: dict[str, StrawberryField] = {} # before trying to find any fields, let's first add the fields defined in # parent classes, we do this by checking if parents have a type definition @@ -71,7 +71,7 @@ class if one is not set by either using an explicit strawberry.field(name=...) o # Find the class the each field was originally defined on so we can use # that scope later when resolving the type, as it may have different names # available to it. - origins: Dict[str, type] = {field_name: cls for field_name in cls.__annotations__} + origins: dict[str, type] = {field_name: cls for field_name in cls.__annotations__} for base in cls.__mro__: if has_object_definition(base): diff --git a/strawberry/types/union.py b/strawberry/types/union.py index f5d8c6210c..003d532c90 100644 --- a/strawberry/types/union.py +++ b/strawberry/types/union.py @@ -6,20 +6,15 @@ from itertools import chain from typing import ( TYPE_CHECKING, + Annotated, Any, - Collection, - Iterable, - List, - Mapping, NoReturn, Optional, - Tuple, - Type, TypeVar, Union, cast, ) -from typing_extensions import Annotated, get_origin +from typing_extensions import get_origin from graphql import GraphQLNamedType, GraphQLUnionType @@ -39,6 +34,8 @@ from strawberry.types.lazy_type import LazyType if TYPE_CHECKING: + from collections.abc import Collection, Iterable, Mapping + from graphql import ( GraphQLAbstractType, GraphQLResolveInfo, @@ -57,7 +54,7 @@ class StrawberryUnion(StrawberryType): def __init__( self, name: Optional[str] = None, - type_annotations: Tuple[StrawberryAnnotation, ...] = tuple(), + type_annotations: tuple[StrawberryAnnotation, ...] = tuple(), description: Optional[str] = None, directives: Iterable[object] = (), ) -> None: @@ -94,14 +91,14 @@ def __or__(self, other: Union[StrawberryType, type]) -> StrawberryType: raise InvalidTypeForUnionMergeError(self, other) @property - def types(self) -> Tuple[StrawberryType, ...]: + def types(self) -> tuple[StrawberryType, ...]: return tuple( cast(StrawberryType, annotation.resolve()) for annotation in self.type_annotations ) @property - def type_params(self) -> List[TypeVar]: + def type_params(self) -> list[TypeVar]: def _get_type_params(type_: StrawberryType) -> list[TypeVar]: if isinstance(type_, LazyType): type_ = cast("StrawberryType", type_.resolve_type()) @@ -240,7 +237,7 @@ def is_valid_union_type(type_: object) -> bool: def union( name: str, - types: Optional[Collection[Type[Any]]] = None, + types: Optional[Collection[type[Any]]] = None, *, description: Optional[str] = None, directives: Iterable[object] = (), diff --git a/strawberry/types/unset.py b/strawberry/types/unset.py index 31c8012b73..e1d2acea0f 100644 --- a/strawberry/types/unset.py +++ b/strawberry/types/unset.py @@ -1,7 +1,7 @@ import warnings -from typing import Any, Dict, Optional, Type +from typing import Any, Optional -DEPRECATED_NAMES: Dict[str, str] = { +DEPRECATED_NAMES: dict[str, str] = { "is_unset": "`is_unset` is deprecated use `value is UNSET` instead", } @@ -9,7 +9,7 @@ class UnsetType: __instance: Optional["UnsetType"] = None - def __new__(cls: Type["UnsetType"]) -> "UnsetType": + def __new__(cls: type["UnsetType"]) -> "UnsetType": if cls.__instance is None: ret = super().__new__(cls) cls.__instance = ret diff --git a/strawberry/utils/aio.py b/strawberry/utils/aio.py index 7ffd8f9e67..ddbff937f7 100644 --- a/strawberry/utils/aio.py +++ b/strawberry/utils/aio.py @@ -1,14 +1,9 @@ import sys +from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Awaitable from typing import ( Any, - AsyncGenerator, - AsyncIterable, - AsyncIterator, - Awaitable, Callable, - List, Optional, - Tuple, TypeVar, Union, ) @@ -19,7 +14,7 @@ async def aenumerate( iterable: Union[AsyncIterator[_T], AsyncIterable[_T]], -) -> AsyncIterator[Tuple[int, _T]]: +) -> AsyncIterator[tuple[int, _T]]: """Async version of enumerate.""" i = 0 async for element in iterable: @@ -56,7 +51,7 @@ async def aislice( return -async def asyncgen_to_list(generator: AsyncGenerator[_T, Any]) -> List[_T]: +async def asyncgen_to_list(generator: AsyncGenerator[_T, Any]) -> list[_T]: """Convert an async generator to a list.""" return [element async for element in generator] diff --git a/strawberry/utils/await_maybe.py b/strawberry/utils/await_maybe.py index 65b1a23bd6..6833d26d07 100644 --- a/strawberry/utils/await_maybe.py +++ b/strawberry/utils/await_maybe.py @@ -1,5 +1,6 @@ import inspect -from typing import AsyncIterator, Awaitable, Iterator, TypeVar, Union +from collections.abc import AsyncIterator, Awaitable, Iterator +from typing import TypeVar, Union T = TypeVar("T") @@ -14,4 +15,4 @@ async def await_maybe(value: AwaitableOrValue[T]) -> T: return value -__all__ = ["await_maybe", "AwaitableOrValue", "AsyncIteratorOrIterator"] +__all__ = ["AsyncIteratorOrIterator", "AwaitableOrValue", "await_maybe"] diff --git a/strawberry/utils/debug.py b/strawberry/utils/debug.py index 825a2ccc24..25fa0e5f7f 100644 --- a/strawberry/utils/debug.py +++ b/strawberry/utils/debug.py @@ -1,7 +1,7 @@ import datetime import json from json import JSONEncoder -from typing import Any, Dict, Optional +from typing import Any, Optional class StrawberryJSONEncoder(JSONEncoder): @@ -10,7 +10,7 @@ def default(self, o: Any) -> Any: def pretty_print_graphql_operation( - operation_name: Optional[str], query: str, variables: Optional[Dict["str", Any]] + operation_name: Optional[str], query: str, variables: Optional[dict["str", Any]] ) -> None: """Pretty print a GraphQL operation using pygments. diff --git a/strawberry/utils/deprecations.py b/strawberry/utils/deprecations.py index 4e802c5b43..646c31225a 100644 --- a/strawberry/utils/deprecations.py +++ b/strawberry/utils/deprecations.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, Optional, Type +from typing import Any, Optional class DEPRECATION_MESSAGES: @@ -19,7 +19,7 @@ def __init__(self, msg: str, alias: object, attr_name: str) -> None: def warn(self) -> None: warnings.warn(self.msg, stacklevel=2) - def __get__(self, obj: Optional[object], type: Optional[Type] = None) -> Any: + def __get__(self, obj: Optional[object], type: Optional[type] = None) -> Any: self.warn() return self.alias diff --git a/strawberry/utils/inspect.py b/strawberry/utils/inspect.py index 4ed67c6be4..650c53b498 100644 --- a/strawberry/utils/inspect.py +++ b/strawberry/utils/inspect.py @@ -4,8 +4,6 @@ from typing import ( Any, Callable, - Dict, - List, Optional, TypeVar, get_origin, @@ -27,7 +25,7 @@ def in_async_context() -> bool: @lru_cache(maxsize=250) -def get_func_args(func: Callable[[Any], Any]) -> List[str]: +def get_func_args(func: Callable[[Any], Any]) -> list[str]: """Returns a list of arguments for the function.""" sig = inspect.signature(func) @@ -38,7 +36,7 @@ def get_func_args(func: Callable[[Any], Any]) -> List[str]: ] -def get_specialized_type_var_map(cls: type) -> Optional[Dict[str, type]]: +def get_specialized_type_var_map(cls: type) -> Optional[dict[str, type]]: """Get a type var map for specialized types. Consider the following: @@ -122,4 +120,4 @@ class IntBarFoo(IntBar, Foo[str]): ... return type_var_map -__all__ = ["in_async_context", "get_func_args", "get_specialized_type_var_map"] +__all__ = ["get_func_args", "get_specialized_type_var_map", "in_async_context"] diff --git a/strawberry/utils/str_converters.py b/strawberry/utils/str_converters.py index 9c52a7ed7d..2fac7f6df3 100644 --- a/strawberry/utils/str_converters.py +++ b/strawberry/utils/str_converters.py @@ -26,4 +26,4 @@ def to_snake_case(name: str) -> str: return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() -__all__ = ["to_camel_case", "to_kebab_case", "capitalize_first", "to_snake_case"] +__all__ = ["capitalize_first", "to_camel_case", "to_kebab_case", "to_snake_case"] diff --git a/strawberry/utils/typing.py b/strawberry/utils/typing.py index 245f6542ab..e06a7f72a2 100644 --- a/strawberry/utils/typing.py +++ b/strawberry/utils/typing.py @@ -2,19 +2,15 @@ import dataclasses import sys import typing +from collections.abc import AsyncGenerator from functools import lru_cache from typing import ( # type: ignore - TYPE_CHECKING, + Annotated, Any, - AsyncGenerator, ClassVar, - Dict, ForwardRef, Generic, - List, Optional, - Tuple, - Type, TypeVar, Union, _eval_type, @@ -23,22 +19,11 @@ cast, overload, ) -from typing_extensions import Annotated, TypeGuard, get_args, get_origin - -ast_unparse = getattr(ast, "unparse", None) -# ast.unparse is only available on python 3.9+. For older versions we will -# use `astunparse.unparse`. -# We are also using "not TYPE_CHECKING" here because mypy gives an erorr -# on tests because "astunparse" is missing stubs, but the mypy action says -# that the comment is unused. -if not TYPE_CHECKING and ast_unparse is None: - import astunparse - - ast_unparse = astunparse.unparse +from typing_extensions import TypeGuard, get_args, get_origin @lru_cache -def get_generic_alias(type_: Type) -> Type: +def get_generic_alias(type_: type) -> type: """Get the generic alias for a type. Given a type, its generic alias from `typing` module will be returned @@ -105,21 +90,21 @@ def is_union(annotation: object) -> bool: return annotation_origin == Union -def is_optional(annotation: Type) -> bool: +def is_optional(annotation: type) -> bool: """Returns True if the annotation is Optional[SomeType].""" # Optionals are represented as unions if not is_union(annotation): return False - types = annotation.__args__ + types = annotation.__args__ # type: ignore[attr-defined] # A Union to be optional needs to have at least one None type return any(x == None.__class__ for x in types) -def get_optional_annotation(annotation: Type) -> Type: - types = annotation.__args__ +def get_optional_annotation(annotation: type) -> type: + types = annotation.__args__ # type: ignore[attr-defined] non_none_types = tuple(x for x in types if x != None.__class__) @@ -127,13 +112,13 @@ def get_optional_annotation(annotation: Type) -> Type: # type (normally a Union type). if len(non_none_types) > 1: - return annotation.copy_with(non_none_types) + return annotation.copy_with(non_none_types) # type: ignore[attr-defined] return non_none_types[0] -def get_list_annotation(annotation: Type) -> Type: - return annotation.__args__[0] +def get_list_annotation(annotation: type) -> type: + return annotation.__args__[0] # type: ignore[attr-defined] def is_concrete_generic(annotation: type) -> bool: @@ -161,7 +146,7 @@ def is_generic(annotation: type) -> bool: ) -def is_type_var(annotation: Type) -> bool: +def is_type_var(annotation: type) -> bool: """Returns True if the annotation is a TypeVar.""" return isinstance(annotation, TypeVar) @@ -186,7 +171,7 @@ def is_classvar(cls: type, annotation: Union[ForwardRef, str]) -> bool: ) -def type_has_annotation(type_: object, annotation: Type) -> bool: +def type_has_annotation(type_: object, annotation: type) -> bool: """Returns True if the type_ has been annotated with annotation.""" if get_origin(type_) is Annotated: return any(isinstance(argument, annotation) for argument in get_args(type_)) @@ -194,14 +179,13 @@ def type_has_annotation(type_: object, annotation: Type) -> bool: return False -def get_parameters(annotation: Type) -> Union[Tuple[object], Tuple[()]]: - if ( - isinstance(annotation, _GenericAlias) - or isinstance(annotation, type) +def get_parameters(annotation: type) -> Union[tuple[object], tuple[()]]: + if isinstance(annotation, _GenericAlias) or ( + isinstance(annotation, type) and issubclass(annotation, Generic) # type:ignore and annotation is not Generic ): - return annotation.__parameters__ + return annotation.__parameters__ # type: ignore[union-attr] else: return () # pragma: no cover @@ -238,8 +222,7 @@ def _ast_replace_union_operation( if hasattr(ast, "Index") and isinstance(expr.slice, ast.Index): expr = ast.Subscript( expr.value, - # The cast is required for mypy on python 3.7 and 3.8 - ast.Index(_ast_replace_union_operation(cast(Any, expr.slice).value)), # type: ignore + ast.Index(_ast_replace_union_operation(expr.slice.value)), # type: ignore ast.Load(), ) elif isinstance(expr.slice, (ast.BinOp, ast.Tuple)): @@ -254,9 +237,9 @@ def _ast_replace_union_operation( def _get_namespace_from_ast( expr: Union[ast.Expr, ast.expr], - globalns: Optional[Dict] = None, - localns: Optional[Dict] = None, -) -> Dict[str, Type]: + globalns: Optional[dict] = None, + localns: Optional[dict] = None, +) -> dict[str, type]: from strawberry.types.lazy_type import StrawberryLazyReference extra = {} @@ -274,7 +257,6 @@ def _get_namespace_from_ast( and expr.value.id == "Union" ): if hasattr(ast, "Index") and isinstance(expr.slice, ast.Index): - # The cast is required for mypy on python 3.7 and 3.8 expr_slice = cast(Any, expr.slice).value else: expr_slice = expr.slice @@ -292,18 +274,15 @@ def _get_namespace_from_ast( and isinstance(expr.value, ast.Name) and expr.value.id == "Annotated" ): - assert ast_unparse - if hasattr(ast, "Index") and isinstance(expr.slice, ast.Index): - # The cast is required for mypy on python 3.7 and 3.8 expr_slice = cast(Any, expr.slice).value else: expr_slice = expr.slice - args: List[str] = [] + args: list[str] = [] for elt in cast(ast.Tuple, expr_slice).elts: extra.update(_get_namespace_from_ast(elt, globalns, localns)) - args.append(ast_unparse(elt)) + args.append(ast.unparse(elt)) # When using forward refs, the whole # Annotated[SomeType, strawberry.lazy("type.module")] is a forward ref, @@ -322,16 +301,16 @@ def _get_namespace_from_ast( def eval_type( type_: Any, - globalns: Optional[Dict] = None, - localns: Optional[Dict] = None, -) -> Type: + globalns: Optional[dict] = None, + localns: Optional[dict] = None, +) -> type: """Evaluates a type, resolving forward references.""" from strawberry.types.auto import StrawberryAuto from strawberry.types.lazy_type import StrawberryLazyReference from strawberry.types.private import StrawberryPrivate globalns = globalns or {} - # If this is not a string, maybe its args are (e.g. List["Foo"]) + # If this is not a string, maybe its args are (e.g. list["Foo"]) if isinstance(type_, ForwardRef): ast_obj = cast(ast.Expr, ast.parse(type_.__forward_arg__).body[0]) @@ -347,10 +326,9 @@ def eval_type( globalns.update(_get_namespace_from_ast(ast_obj, globalns, localns)) - assert ast_unparse - type_ = ForwardRef(ast_unparse(ast_obj)) + type_ = ForwardRef(ast.unparse(ast_obj)) - extra: Dict[str, Any] = {} + extra: dict[str, Any] = {} if sys.version_info >= (3, 13): extra = {"type_params": None} @@ -400,13 +378,6 @@ def eval_type( if origin is UnionType: origin = Union - # Future annotations in older versions will eval generic aliases to their - # real types (i.e. List[foo] will have its origin set to list instead - # of List). If that type is not subscriptable, retrieve its generic - # alias version instead. - if sys.version_info < (3, 9) and not hasattr(origin, "__class_getitem__"): - origin = get_generic_alias(origin) - type_ = ( origin[tuple(eval_type(a, globalns, localns) for a in args)] if args @@ -417,19 +388,19 @@ def eval_type( __all__ = [ + "eval_type", "get_generic_alias", - "is_generic_alias", - "is_list", - "is_union", - "is_optional", - "get_optional_annotation", "get_list_annotation", + "get_optional_annotation", + "get_parameters", + "is_classvar", "is_concrete_generic", - "is_generic_subclass", "is_generic", + "is_generic_alias", + "is_generic_subclass", + "is_list", + "is_optional", "is_type_var", - "is_classvar", + "is_union", "type_has_annotation", - "get_parameters", - "eval_type", ] diff --git a/tests/a.py b/tests/a.py index 73d46f3d71..fc55f1361c 100644 --- a/tests/a.py +++ b/tests/a.py @@ -1,7 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional -from typing_extensions import Annotated +from typing import TYPE_CHECKING, Annotated, Optional import strawberry diff --git a/tests/b.py b/tests/b.py index b291c107ba..2e9b83e1bf 100644 --- a/tests/b.py +++ b/tests/b.py @@ -1,7 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional -from typing_extensions import Annotated +from typing import TYPE_CHECKING, Annotated, Optional import strawberry @@ -22,7 +21,7 @@ async def a(self) -> Annotated[A, strawberry.lazy("tests.a"), object()]: @strawberry.field async def a_list( self, - ) -> List[Annotated[A, strawberry.lazy("tests.a")]]: # pragma: no cover + ) -> list[Annotated[A, strawberry.lazy("tests.a")]]: # pragma: no cover from tests.a import A return [A(id=self.id)] diff --git a/tests/benchmarks/api.py b/tests/benchmarks/api.py index 4584d944e2..24b3d3c888 100644 --- a/tests/benchmarks/api.py +++ b/tests/benchmarks/api.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, List +from collections.abc import AsyncIterator import strawberry from strawberry.directive import DirectiveLocation @@ -59,11 +59,11 @@ def hello(self) -> str: return "Hello World!" @strawberry.field - def people(self, limit: int = 100) -> List[Person]: + def people(self, limit: int = 100) -> list[Person]: return people[:limit] if limit else people @strawberry.field - def items(self, count: int) -> List[Item]: + def items(self, count: int) -> list[Item]: return [Item(name="Item", index=i) for i in range(count)] diff --git a/tests/benchmarks/schema.py b/tests/benchmarks/schema.py index fd7b306895..decf6b3927 100644 --- a/tests/benchmarks/schema.py +++ b/tests/benchmarks/schema.py @@ -1,8 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import List, Union -from typing_extensions import Annotated +from typing import Annotated, Union import strawberry @@ -118,7 +117,7 @@ def random(cls, seed: int) -> Comment: @strawberry.type class UserConnection: - edges: List[UserEdge | None] | None + edges: list[UserEdge | None] | None page_info: PageInfo @@ -130,7 +129,7 @@ class UserEdge: @strawberry.type class PostConnection: - edges: List[PostEdge | None] | None + edges: list[PostEdge | None] | None page_info: PageInfo @@ -142,7 +141,7 @@ class PostEdge: @strawberry.type class CommentConnection: - edges: List[CommentEdge | None] | None + edges: list[CommentEdge | None] | None page_info: PageInfo @@ -166,7 +165,7 @@ class Query: @strawberry.field async def search( self, query: str, first: int = 10, after: str | None = None - ) -> List[SearchResult | None] | None: + ) -> list[SearchResult | None] | None: div = 3 chunks = [first // div + (1 if x < first % div else 0) for x in range(div)] diff --git a/tests/benchmarks/test_execute.py b/tests/benchmarks/test_execute.py index 865d49c700..643677596a 100644 --- a/tests/benchmarks/test_execute.py +++ b/tests/benchmarks/test_execute.py @@ -2,7 +2,7 @@ import datetime import random from datetime import date -from typing import List, Type, cast +from typing import cast import pytest from pytest_codspeed.plugin import BenchmarkFixture @@ -27,10 +27,10 @@ class Patron: name: str age: int birthday: date - tags: List[str] + tags: list[str] @strawberry.field - def pets(self) -> List[Pet]: + def pets(self) -> list[Pet]: return [ Pet( id=i, @@ -42,7 +42,7 @@ def pets(self) -> List[Pet]: @strawberry.type class Query: @strawberry.field - def patrons(self) -> List[Patron]: + def patrons(self) -> list[Patron]: return [ Patron( id=i, @@ -84,13 +84,13 @@ def test_interface_performance(benchmark: BenchmarkFixture, ntypes: int): class Item: id: ID - CONCRETE_TYPES: List[Type[Item]] = [ + CONCRETE_TYPES: list[type[Item]] = [ strawberry.type(type(f"Item{i}", (Item,), {})) for i in range(ntypes) ] @strawberry.type class Query: - items: List[Item] + items: list[Item] schema = strawberry.Schema(query=Query, types=CONCRETE_TYPES) query = "query { items { id } }" diff --git a/tests/benchmarks/test_execute_with_extensions.py b/tests/benchmarks/test_execute_with_extensions.py index 059ae959ef..21495b45c8 100644 --- a/tests/benchmarks/test_execute_with_extensions.py +++ b/tests/benchmarks/test_execute_with_extensions.py @@ -1,7 +1,7 @@ import asyncio from inspect import isawaitable from pathlib import Path -from typing import Any, Dict, List +from typing import Any import pytest from pytest_codspeed.plugin import BenchmarkFixture @@ -14,7 +14,7 @@ class SimpleExtension(SchemaExtension): - def get_results(self) -> AwaitableOrValue[Dict[str, Any]]: + def get_results(self) -> AwaitableOrValue[dict[str, Any]]: return super().get_results() @@ -39,7 +39,7 @@ async def resolve(self, _next, root, info, *args: Any, **kwargs: Any) -> Any: ids=lambda x: f"with_{'_'.join(type(ext).__name__.lower() for ext in x) or 'no_extensions'}", ) def test_execute( - benchmark: BenchmarkFixture, items: int, extensions: List[SchemaExtension] + benchmark: BenchmarkFixture, items: int, extensions: list[SchemaExtension] ): schema = strawberry.Schema(query=Query, extensions=extensions) diff --git a/tests/benchmarks/test_generic_input.py b/tests/benchmarks/test_generic_input.py index bc937c4605..7be84c567a 100644 --- a/tests/benchmarks/test_generic_input.py +++ b/tests/benchmarks/test_generic_input.py @@ -1,5 +1,5 @@ import asyncio -from typing import Generic, List, Optional, TypeVar +from typing import Generic, Optional, TypeVar from pytest_codspeed.plugin import BenchmarkFixture @@ -13,8 +13,8 @@ class GraphQLFilter(Generic[T]): """EXTERNAL Filter for GraphQL queries""" eq: Optional[T] = None - in_: Optional[List[T]] = None - nin: Optional[List[T]] = None + in_: Optional[list[T]] = None + nin: Optional[list[T]] = None gt: Optional[T] = None gte: Optional[T] = None lt: Optional[T] = None @@ -36,7 +36,7 @@ class Book: async def authors( self, name: Optional[GraphQLFilter[str]] = None, - ) -> List[Author]: + ) -> list[Author]: return [Author(name="F. Scott Fitzgerald")] @@ -48,7 +48,7 @@ def get_books(): @strawberry.type class Query: - books: List[Book] = strawberry.field(resolver=get_books) + books: list[Book] = strawberry.field(resolver=get_books) schema = strawberry.Schema(query=Query) diff --git a/tests/benchmarks/test_subscriptions.py b/tests/benchmarks/test_subscriptions.py index 0bbd512835..79e24a1301 100644 --- a/tests/benchmarks/test_subscriptions.py +++ b/tests/benchmarks/test_subscriptions.py @@ -1,5 +1,5 @@ import asyncio -from typing import AsyncIterator +from collections.abc import AsyncIterator import pytest from graphql import ExecutionResult diff --git a/tests/chalice/app.py b/tests/chalice/app.py index c391ffa3c6..95c60adedb 100644 --- a/tests/chalice/app.py +++ b/tests/chalice/app.py @@ -1,5 +1,3 @@ -from typing import Dict - from chalice import Chalice # type: ignore from chalice.app import Response from strawberry.chalice.views import GraphQLView @@ -14,7 +12,7 @@ @app.route("/") -def index() -> Dict[str, str]: +def index() -> dict[str, str]: return {"strawberry": "cake"} diff --git a/tests/channels/test_layers.py b/tests/channels/test_layers.py index ffc7108703..1610cd178d 100644 --- a/tests/channels/test_layers.py +++ b/tests/channels/test_layers.py @@ -1,7 +1,8 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING import pytest diff --git a/tests/channels/test_testing.py b/tests/channels/test_testing.py index f45f535214..4f790240e0 100644 --- a/tests/channels/test_testing.py +++ b/tests/channels/test_testing.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any import pytest diff --git a/tests/cli/test_codegen.py b/tests/cli/test_codegen.py index 991eaaf488..0a05430522 100644 --- a/tests/cli/test_codegen.py +++ b/tests/cli/test_codegen.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import List import pytest from typer import Typer @@ -19,8 +18,8 @@ def on_end(self, result: CodegenResult): class QueryCodegenTestPlugin(QueryCodegenPlugin): def generate_code( - self, types: List[GraphQLType], operation: GraphQLOperation - ) -> List[CodegenFile]: + self, types: list[GraphQLType], operation: GraphQLOperation + ) -> list[CodegenFile]: return [ CodegenFile( path="test.py", @@ -31,8 +30,8 @@ def generate_code( class EmptyPlugin(QueryCodegenPlugin): def generate_code( - self, types: List[GraphQLType], operation: GraphQLOperation - ) -> List[CodegenFile]: + self, types: list[GraphQLType], operation: GraphQLOperation + ) -> list[CodegenFile]: return [ CodegenFile( path="test.py", diff --git a/tests/codegen/conftest.py b/tests/codegen/conftest.py index d776a474fd..f1cb598958 100644 --- a/tests/codegen/conftest.py +++ b/tests/codegen/conftest.py @@ -2,8 +2,15 @@ import decimal import enum import random -from typing import TYPE_CHECKING, Generic, List, NewType, Optional, TypeVar, Union -from typing_extensions import Annotated +from typing import ( + TYPE_CHECKING, + Annotated, + Generic, + NewType, + Optional, + TypeVar, + Union, +) from uuid import UUID import pytest @@ -41,8 +48,8 @@ class Animal: @strawberry.type class LifeContainer(Generic[LivingThing1, LivingThing2]): - items1: List[LivingThing1] - items2: List[LivingThing2] + items1: list[LivingThing1] + items2: list[LivingThing2] PersonOrAnimal = Annotated[Union[Person, Animal], strawberry.union("PersonOrAnimal")] @@ -79,8 +86,8 @@ class ExampleInput: name: str age: int person: Optional[PersonInput] - people: List[PersonInput] - optional_people: Optional[List[PersonInput]] + people: list[PersonInput] + optional_people: Optional[list[PersonInput]] @strawberry.type @@ -95,13 +102,13 @@ class Query: time: datetime.time decimal: decimal.Decimal optional_int: Optional[int] - list_of_int: List[int] - list_of_optional_int: List[Optional[int]] - optional_list_of_optional_int: Optional[List[Optional[int]]] + list_of_int: list[int] + list_of_optional_int: list[Optional[int]] + optional_list_of_optional_int: Optional[list[Optional[int]]] person: Person optional_person: Optional[Person] - list_of_people: List[Person] - optional_list_of_people: Optional[List[Person]] + list_of_people: list[Person] + optional_list_of_people: Optional[list[Person]] enum: Color json: JSON union: PersonOrAnimal @@ -141,12 +148,12 @@ class BlogPostInput: @strawberry.input class AddBlogPostsInput: - posts: List[BlogPostInput] + posts: list[BlogPostInput] @strawberry.type class AddBlogPostsOutput: - posts: List[BlogPost] + posts: list[BlogPost] @strawberry.type diff --git a/tests/codegen/snapshots/python/generic_types.py b/tests/codegen/snapshots/python/generic_types.py index ccffbb8fb5..8a77c03c74 100644 --- a/tests/codegen/snapshots/python/generic_types.py +++ b/tests/codegen/snapshots/python/generic_types.py @@ -9,8 +9,8 @@ class ListLifeGenericResultListLifeItems2: age: int class ListLifeGenericResultListLife: - items1: List[ListLifeGenericResultListLifeItems1] - items2: List[ListLifeGenericResultListLifeItems2] + items1: list[ListLifeGenericResultListLifeItems1] + items2: list[ListLifeGenericResultListLifeItems2] class ListLifeGenericResult: list_life: ListLifeGenericResultListLife diff --git a/tests/codegen/snapshots/python/multiple_types.py b/tests/codegen/snapshots/python/multiple_types.py index 39f115fe17..e81f613653 100644 --- a/tests/codegen/snapshots/python/multiple_types.py +++ b/tests/codegen/snapshots/python/multiple_types.py @@ -8,4 +8,4 @@ class OperationNameResultListOfPeople: class OperationNameResult: person: OperationNameResultPerson - list_of_people: List[OperationNameResultListOfPeople] + list_of_people: list[OperationNameResultListOfPeople] diff --git a/tests/codegen/snapshots/python/mutation_with_object.py b/tests/codegen/snapshots/python/mutation_with_object.py index dcd0164cd9..d24687bd19 100644 --- a/tests/codegen/snapshots/python/mutation_with_object.py +++ b/tests/codegen/snapshots/python/mutation_with_object.py @@ -5,7 +5,7 @@ class AddBlogPostsResultAddBlogPostsPosts: title: str class AddBlogPostsResultAddBlogPosts: - posts: List[AddBlogPostsResultAddBlogPostsPosts] + posts: list[AddBlogPostsResultAddBlogPostsPosts] class AddBlogPostsResult: add_blog_posts: AddBlogPostsResultAddBlogPosts @@ -24,4 +24,4 @@ class BlogPostInput: an_optional_int: Optional[int] = None class AddBlogPostsVariables: - input: List[BlogPostInput] + input: list[BlogPostInput] diff --git a/tests/codegen/snapshots/python/nullable_list_of_non_scalars.py b/tests/codegen/snapshots/python/nullable_list_of_non_scalars.py index f5efcaa7f8..cf7b196659 100644 --- a/tests/codegen/snapshots/python/nullable_list_of_non_scalars.py +++ b/tests/codegen/snapshots/python/nullable_list_of_non_scalars.py @@ -5,4 +5,4 @@ class OperationNameResultOptionalListOfPeople: age: int class OperationNameResult: - optional_list_of_people: Optional[List[OperationNameResultOptionalListOfPeople]] + optional_list_of_people: Optional[list[OperationNameResultOptionalListOfPeople]] diff --git a/tests/codegen/snapshots/python/optional_and_lists.py b/tests/codegen/snapshots/python/optional_and_lists.py index 57d57b58f1..bc2802386e 100644 --- a/tests/codegen/snapshots/python/optional_and_lists.py +++ b/tests/codegen/snapshots/python/optional_and_lists.py @@ -2,6 +2,6 @@ class OperationNameResult: optional_int: Optional[int] - list_of_int: List[int] - list_of_optional_int: List[Optional[int]] - optional_list_of_optional_int: Optional[List[Optional[int]]] + list_of_int: list[int] + list_of_optional_int: list[Optional[int]] + optional_list_of_optional_int: Optional[list[Optional[int]]] diff --git a/tests/codegen/snapshots/python/variables.py b/tests/codegen/snapshots/python/variables.py index a19d010717..98f913e1a8 100644 --- a/tests/codegen/snapshots/python/variables.py +++ b/tests/codegen/snapshots/python/variables.py @@ -12,12 +12,12 @@ class ExampleInput: name: str age: int person: Optional[PersonInput] - people: List[PersonInput] - optional_people: Optional[List[PersonInput]] + people: list[PersonInput] + optional_people: Optional[list[PersonInput]] class OperationNameVariables: id: Optional[str] input: ExampleInput - ids: List[str] - ids2: Optional[List[Optional[str]]] - ids3: Optional[List[Optional[List[Optional[str]]]]] + ids: list[str] + ids2: Optional[list[Optional[str]]] + ids3: Optional[list[Optional[list[Optional[str]]]]] diff --git a/tests/codegen/test_query_codegen.py b/tests/codegen/test_query_codegen.py index 78a3b71259..d4819db2a9 100644 --- a/tests/codegen/test_query_codegen.py +++ b/tests/codegen/test_query_codegen.py @@ -5,7 +5,6 @@ # - 5. test subscriptions (raise) from pathlib import Path -from typing import Type import pytest from pytest_snapshot.plugin import Snapshot @@ -34,7 +33,7 @@ @pytest.mark.parametrize("query", QUERIES, ids=[x.name for x in QUERIES]) def test_codegen( query: Path, - plugin_class: Type[QueryCodegenPlugin], + plugin_class: type[QueryCodegenPlugin], plugin_name: str, extension: str, snapshot: Snapshot, diff --git a/tests/conftest.py b/tests/conftest.py index 6e1187b042..99049677c0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,17 +1,17 @@ import pathlib import sys -from typing import Any, List, Tuple +from typing import Any import pytest from strawberry.utils import IS_GQL_32 -def pytest_emoji_xfailed(config: pytest.Config) -> Tuple[str, str]: +def pytest_emoji_xfailed(config: pytest.Config) -> tuple[str, str]: return "🤷‍♂️ ", "XFAIL 🤷‍♂️ " -def pytest_emoji_skipped(config: pytest.Config) -> Tuple[str, str]: +def pytest_emoji_skipped(config: pytest.Config) -> tuple[str, str]: return "🦘 ", "SKIPPED 🦘" @@ -19,7 +19,7 @@ def pytest_emoji_skipped(config: pytest.Config) -> Tuple[str, str]: @pytest.hookimpl # type: ignore -def pytest_collection_modifyitems(config: pytest.Config, items: List[pytest.Item]): +def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]): rootdir = pathlib.Path(config.rootdir) # type: ignore for item in items: diff --git a/tests/d.py b/tests/d.py index 44678fdc98..a027e0c8f2 100644 --- a/tests/d.py +++ b/tests/d.py @@ -1,7 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List -from typing_extensions import Annotated +from typing import TYPE_CHECKING, Annotated import strawberry @@ -16,7 +15,7 @@ class D: @strawberry.field async def c_list( self, - ) -> List[Annotated[C, strawberry.lazy("tests.c")]]: # pragma: no cover + ) -> list[Annotated[C, strawberry.lazy("tests.c")]]: # pragma: no cover from tests.c import C return [C(id=self.id)] diff --git a/tests/django/test_dataloaders.py b/tests/django/test_dataloaders.py index 835e6a5c1e..bd043ad925 100644 --- a/tests/django/test_dataloaders.py +++ b/tests/django/test_dataloaders.py @@ -1,5 +1,4 @@ import json -from typing import List, Tuple import pytest from asgiref.sync import sync_to_async @@ -11,7 +10,7 @@ try: import django - DJANGO_VERSION: Tuple[int, int, int] = django.VERSION + DJANGO_VERSION: tuple[int, int, int] = django.VERSION except ImportError: DJANGO_VERSION = (0, 0, 0) @@ -43,7 +42,7 @@ async def test_fetch_data_from_db(mocker: MockerFixture): from .app.models import Example - def _sync_batch_load(keys: List[str]): + def _sync_batch_load(keys: list[str]): data = Example.objects.filter(id__in=keys) return list(data) @@ -53,7 +52,7 @@ def _sync_batch_load(keys: List[str]): ids = await prepare_db() - async def idx(keys: List[str]) -> List[Example]: + async def idx(keys: list[str]) -> list[Example]: return await batch_load(keys) mock_loader = mocker.Mock(side_effect=idx) diff --git a/tests/experimental/pydantic/schema/test_basic.py b/tests/experimental/pydantic/schema/test_basic.py index f816208978..91e6d7317e 100644 --- a/tests/experimental/pydantic/schema/test_basic.py +++ b/tests/experimental/pydantic/schema/test_basic.py @@ -1,10 +1,8 @@ -import sys import textwrap from enum import Enum -from typing import List, Optional, Union +from typing import Optional, Union import pydantic -import pytest import strawberry from tests.experimental.pydantic.utils import needs_pydantic_v1 @@ -162,7 +160,7 @@ def user(self) -> User: def test_basic_type_with_list(): class UserModel(pydantic.BaseModel): age: int - friend_names: List[str] + friend_names: list[str] @strawberry.experimental.pydantic.type(UserModel) class User: @@ -225,7 +223,7 @@ class HobbyType: name: strawberry.auto class User(pydantic.BaseModel): - hobbies: List[Hobby] + hobbies: list[Hobby] @strawberry.experimental.pydantic.type(User) class UserType: @@ -532,10 +530,6 @@ def user(self) -> User: @needs_pydantic_v1 -@pytest.mark.skipif( - sys.version_info < (3, 9), - reason="ConstrainedList with another model does not work with 3.8", -) def test_basic_type_with_constrained_list(): class FriendList(pydantic.ConstrainedList): min_items = 1 diff --git a/tests/experimental/pydantic/schema/test_federation.py b/tests/experimental/pydantic/schema/test_federation.py index 47bd56c2f9..db94a8e336 100644 --- a/tests/experimental/pydantic/schema/test_federation.py +++ b/tests/experimental/pydantic/schema/test_federation.py @@ -26,7 +26,7 @@ def resolve_reference(cls, upc) -> "Product": @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: + def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) diff --git a/tests/experimental/pydantic/schema/test_mutation.py b/tests/experimental/pydantic/schema/test_mutation.py index 03e545eece..d225d75523 100644 --- a/tests/experimental/pydantic/schema/test_mutation.py +++ b/tests/experimental/pydantic/schema/test_mutation.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Union +from typing import Union import pydantic @@ -183,7 +183,7 @@ def create_user(self, input: CreateUserInput) -> Union[UserType, UserError]: try: data = input.to_pydantic() except pydantic.ValidationError as e: - args: Dict[str, List[str]] = {} + args: dict[str, list[str]] = {} for error in e.errors(): field = error["loc"][0] # currently doesn't support nested errors field_errors = args.get(field, []) diff --git a/tests/experimental/pydantic/test_basic.py b/tests/experimental/pydantic/test_basic.py index 9556d3f617..8624a13b7f 100644 --- a/tests/experimental/pydantic/test_basic.py +++ b/tests/experimental/pydantic/test_basic.py @@ -1,7 +1,6 @@ import dataclasses from enum import Enum -from typing import Any, List, Optional, Union -from typing_extensions import Annotated +from typing import Annotated, Any, Optional, Union import pydantic import pytest @@ -221,7 +220,7 @@ class UserType: def test_list(): class User(pydantic.BaseModel): - friend_names: List[str] + friend_names: list[str] @strawberry.experimental.pydantic.type(User) class UserType: @@ -242,7 +241,7 @@ class Friend(pydantic.BaseModel): name: str class User(pydantic.BaseModel): - friends: Optional[List[Optional[Friend]]] + friends: Optional[list[Optional[Friend]]] @strawberry.experimental.pydantic.type(Friend) class FriendType: @@ -353,7 +352,7 @@ class UserModel(pydantic.BaseModel): name: str = pydantic.Field("Michael", description="The user name") password: Optional[str] = pydantic.Field(default="ABC") passwordtwo: Optional[str] = None - some_list: Optional[List[str]] = pydantic.Field(default_factory=list) + some_list: Optional[list[str]] = pydantic.Field(default_factory=list) check: Optional[bool] = False @strawberry.experimental.pydantic.type(UserModel, all_fields=True) @@ -400,8 +399,8 @@ def test_type_with_fields_mutable_default(): empty_list = [] class User(pydantic.BaseModel): - groups: List[str] - friends: List[str] = empty_list + groups: list[str] + friends: list[str] = empty_list @strawberry.experimental.pydantic.type(User) class UserType: @@ -633,7 +632,7 @@ class User(pydantic.BaseModel): work: Optional[Work] = None class Group(pydantic.BaseModel): - users: List[User] + users: list[User] # Test both definition orders @strawberry.experimental.pydantic.input(Work) @@ -664,7 +663,7 @@ class GroupOutput: @strawberry.type class Query: - groups: List[GroupOutput] + groups: list[GroupOutput] @strawberry.type class Mutation: @@ -939,7 +938,7 @@ class UserType: def test_nested_annotated(): class User(pydantic.BaseModel): a: Optional[Annotated[int, "metadata"]] - b: Optional[List[Annotated[int, "metadata"]]] + b: Optional[list[Annotated[int, "metadata"]]] @strawberry.experimental.pydantic.input(User, all_fields=True) class UserType: diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index 745e5e90c4..7841f8de2e 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -3,7 +3,7 @@ import re import sys from enum import Enum -from typing import Any, Dict, List, NewType, Optional, TypeVar, Union +from typing import Any, NewType, Optional, TypeVar, Union import pytest from pydantic import BaseModel, Field, ValidationError @@ -261,7 +261,7 @@ class Work: name: strawberry.auto class UserModel(BaseModel): - work: List[WorkModel] + work: list[WorkModel] @strawberry.experimental.pydantic.type(UserModel) class User: @@ -280,7 +280,7 @@ class User: def test_can_convert_pydantic_type_with_list_of_nested_int_to_strawberry(): class UserModel(BaseModel): - hours: List[int] + hours: list[int] @strawberry.experimental.pydantic.type(UserModel) class User: @@ -300,7 +300,7 @@ class User: def test_can_convert_pydantic_type_with_matrix_list_of_nested_int_to_strawberry(): class UserModel(BaseModel): - hours: List[List[int]] + hours: list[list[int]] @strawberry.experimental.pydantic.type(UserModel) class User: @@ -331,7 +331,7 @@ class Hour: hour: strawberry.auto class UserModel(BaseModel): - hours: List[List[HourModel]] + hours: list[list[HourModel]] @strawberry.experimental.pydantic.type(UserModel) class User: @@ -602,7 +602,7 @@ class UserModel(BaseModel): @strawberry.experimental.pydantic.type(UserModel) class User: - work: List[Work] + work: list[Work] password: strawberry.auto origin_user = UserModel(password="abc") @@ -633,7 +633,7 @@ class Work: name: strawberry.auto class UserModel(BaseModel): - work: List[WorkModel] + work: list[WorkModel] @strawberry.experimental.pydantic.type(UserModel) class User: @@ -665,7 +665,7 @@ class Work: name: strawberry.auto class UserModel(BaseModel): - work: List[Optional[WorkModel]] + work: list[Optional[WorkModel]] @strawberry.experimental.pydantic.type(UserModel) class User: @@ -722,7 +722,7 @@ class User: def test_can_convert_pydantic_type_to_strawberry_with_optional_nested_value(): class UserModel(BaseModel): - names: Optional[List[str]] + names: Optional[list[str]] @strawberry.experimental.pydantic.type(UserModel) class User: @@ -828,7 +828,7 @@ def test_can_convert_pydantic_type_to_strawberry_newtype_list(): class User(BaseModel): age: int - passwords: List[Password] + passwords: list[Password] @strawberry.experimental.pydantic.type(User) class UserType: @@ -960,7 +960,7 @@ class User(BaseModel): work: Optional[Work] class Group(BaseModel): - users: List[User] + users: list[User] # Test both definition orders @strawberry.experimental.pydantic.input(Work) @@ -1016,7 +1016,7 @@ class UserType: @staticmethod def from_pydantic( - instance: User, extra: Optional[Dict[str, Any]] = None + instance: User, extra: Optional[dict[str, Any]] = None ) -> "UserType": return UserType( age=str(instance.age), @@ -1058,7 +1058,7 @@ class UserType: @staticmethod def from_pydantic( - instance: User, extra: Optional[Dict[str, Any]] = None + instance: User, extra: Optional[dict[str, Any]] = None ) -> "UserType": return UserType( age=str(instance.age), @@ -1121,7 +1121,7 @@ class Work(BaseModel): class User(BaseModel): age: int password: Optional[str] - work: Dict[str, Work] + work: dict[str, Work] @strawberry.experimental.pydantic.input(Work) class WorkInput: @@ -1175,10 +1175,6 @@ class UserInput: data.to_pydantic() -@pytest.mark.skipif( - sys.version_info < (3, 9), - reason="generic aliases where added in python 3.9", -) def test_can_convert_generic_alias_fields_to_strawberry(): class TestModel(BaseModel): list_1d: list[int] @@ -1232,21 +1228,17 @@ class Test: @needs_pydantic_v1 -@pytest.mark.skipif( - sys.version_info < (3, 9), - reason="ConstrainedList with another model does not work with 3.8", -) def test_can_convert_pydantic_type_to_strawberry_with_constrained_list(): from pydantic import ConstrainedList class WorkModel(BaseModel): name: str - class workList(ConstrainedList): + class WorkList(ConstrainedList): min_items = 1 class UserModel(BaseModel): - work: workList[WorkModel] + work: WorkList[WorkModel] @strawberry.experimental.pydantic.type(WorkModel) class Work: @@ -1268,23 +1260,20 @@ class User: SI = TypeVar("SI", covariant=True) # pragma: no mutate -class SpecialList(List[SI]): +class SpecialList(list[SI]): pass @needs_pydantic_v1 -@pytest.mark.skipif( - sys.version_info < (3, 9), reason="SpecialList does not work with 3.8" -) def test_can_convert_pydantic_type_to_strawberry_with_specialized_list(): class WorkModel(BaseModel): name: str - class workList(SpecialList[SI]): + class WorkList(SpecialList[SI]): min_items = 1 class UserModel(BaseModel): - work: workList[WorkModel] + work: WorkList[WorkModel] @strawberry.experimental.pydantic.type(WorkModel) class Work: diff --git a/tests/experimental/pydantic/test_error_type.py b/tests/experimental/pydantic/test_error_type.py index 8e37c6402c..ce5893ca02 100644 --- a/tests/experimental/pydantic/test_error_type.py +++ b/tests/experimental/pydantic/test_error_type.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional import pydantic import pytest @@ -175,7 +175,7 @@ class FriendModel(pydantic.BaseModel): food: str class UserModel(pydantic.BaseModel): - friends: List[FriendModel] + friends: list[FriendModel] @strawberry.experimental.pydantic.error_type(FriendModel) class FriendError: @@ -199,7 +199,7 @@ class UserError: def test_error_type_with_list_of_scalar(): class UserModel(pydantic.BaseModel): - friends: List[int] + friends: list[int] @strawberry.experimental.pydantic.error_type(UserModel) class UserError: @@ -239,7 +239,7 @@ class UserError: def test_error_type_with_list_of_optional_scalar(): class UserModel(pydantic.BaseModel): - age: List[Optional[int]] + age: list[Optional[int]] @strawberry.experimental.pydantic.error_type(UserModel) class UserError: @@ -260,7 +260,7 @@ class UserError: def test_error_type_with_optional_list_scalar(): class UserModel(pydantic.BaseModel): - age: Optional[List[int]] + age: Optional[list[int]] @strawberry.experimental.pydantic.error_type(UserModel) class UserError: @@ -281,7 +281,7 @@ class UserError: def test_error_type_with_optional_list_of_optional_scalar(): class UserModel(pydantic.BaseModel): - age: Optional[List[Optional[int]]] + age: Optional[list[Optional[int]]] @strawberry.experimental.pydantic.error_type(UserModel) class UserError: @@ -309,7 +309,7 @@ class FriendError: name: strawberry.auto class UserModel(pydantic.BaseModel): - friends: Optional[List[FriendModel]] + friends: Optional[list[FriendModel]] @strawberry.experimental.pydantic.error_type(UserModel) class UserError: @@ -329,7 +329,7 @@ class UserError: def test_error_type_with_matrix_list_of_scalar(): class UserModel(pydantic.BaseModel): - age: List[List[int]] + age: list[list[int]] @strawberry.experimental.pydantic.error_type(UserModel) class UserError: diff --git a/tests/experimental/pydantic/test_fields.py b/tests/experimental/pydantic/test_fields.py index 878969b9af..0184463a37 100644 --- a/tests/experimental/pydantic/test_fields.py +++ b/tests/experimental/pydantic/test_fields.py @@ -1,5 +1,4 @@ import re -from typing import List from typing_extensions import Literal import pydantic @@ -154,7 +153,7 @@ class UserType: ... assert UserType.__strawberry_definition__.fields[0].name == "friends" assert ( UserType.__strawberry_definition__.fields[0].type_annotation.raw_annotation - == List[str] + == list[str] ) data = UserType(friends=[]) @@ -181,7 +180,7 @@ class UserType: ... assert UserType.__strawberry_definition__.fields[0].name == "friends" assert ( UserType.__strawberry_definition__.fields[0].type_annotation.raw_annotation - == List[List[int]] + == list[list[int]] ) diff --git a/tests/fastapi/app.py b/tests/fastapi/app.py index 0ae009eebc..26f0d7eb61 100644 --- a/tests/fastapi/app.py +++ b/tests/fastapi/app.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Union +from typing import Any, Union from fastapi import BackgroundTasks, Depends, FastAPI, Request, WebSocket from strawberry.fastapi import GraphQLRouter @@ -14,7 +14,7 @@ async def get_context( request: Request = None, ws: WebSocket = None, custom_value=Depends(custom_context_dependency), -) -> Dict[str, Any]: +) -> dict[str, Any]: return { "custom_value": custom_value, "request": request or ws, diff --git a/tests/fastapi/test_context.py b/tests/fastapi/test_context.py index 16530b4e09..c5a9147289 100644 --- a/tests/fastapi/test_context.py +++ b/tests/fastapi/test_context.py @@ -1,5 +1,5 @@ import asyncio -from typing import AsyncGenerator, Dict +from collections.abc import AsyncGenerator import pytest @@ -108,7 +108,7 @@ def abc(self, info: strawberry.Info) -> str: def custom_context_dependency() -> str: return "rocks" - def get_context(value: str = Depends(custom_context_dependency)) -> Dict[str, str]: + def get_context(value: str = Depends(custom_context_dependency)) -> dict[str, str]: return {"strawberry": value} app = FastAPI() diff --git a/tests/federation/printer/test_authenticated.py b/tests/federation/printer/test_authenticated.py index db9ef10a22..d362dad337 100644 --- a/tests/federation/printer/test_authenticated.py +++ b/tests/federation/printer/test_authenticated.py @@ -1,7 +1,6 @@ import textwrap from enum import Enum -from typing import List -from typing_extensions import Annotated +from typing import Annotated import strawberry @@ -20,7 +19,7 @@ class Query: @strawberry.federation.field(authenticated=True) def top_products( self, first: Annotated[int, strawberry.federation.argument()] - ) -> List[Product]: + ) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) diff --git a/tests/federation/printer/test_entities.py b/tests/federation/printer/test_entities.py index eb5866a849..7f7ada547b 100644 --- a/tests/federation/printer/test_entities.py +++ b/tests/federation/printer/test_entities.py @@ -1,7 +1,6 @@ # type: ignore import textwrap -from typing import List import strawberry @@ -16,7 +15,7 @@ class User: @strawberry.federation.type(extend=True) class Product: upc: str = strawberry.federation.field(external=True) - reviews: List["Review"] + reviews: list["Review"] @strawberry.federation.type class Review: @@ -27,7 +26,7 @@ class Review: @strawberry.federation.type class Query: @strawberry.field - def top_products(self, first: int) -> List[Product]: + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -79,7 +78,7 @@ class User: @strawberry.federation.type(keys=["upc"], extend=True) class Product: upc: str = strawberry.federation.field(external=True) - reviews: List["Review"] + reviews: list["Review"] @strawberry.federation.type class Review: @@ -90,7 +89,7 @@ class Review: @strawberry.federation.type class Query: @strawberry.field - def top_products(self, first: int) -> List[Product]: + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) diff --git a/tests/federation/printer/test_inaccessible.py b/tests/federation/printer/test_inaccessible.py index e8801d3e64..110ec1075d 100644 --- a/tests/federation/printer/test_inaccessible.py +++ b/tests/federation/printer/test_inaccessible.py @@ -1,7 +1,6 @@ import textwrap from enum import Enum -from typing import List -from typing_extensions import Annotated +from typing import Annotated import strawberry @@ -34,7 +33,7 @@ class Query: def top_products( self, first: Annotated[int, strawberry.federation.argument(inaccessible=True)], - ) -> List[Product]: + ) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema( @@ -97,7 +96,7 @@ class Query: @strawberry.type class Mutation: @strawberry.federation.mutation(inaccessible=True) - def hello(self) -> str: + def hello(self) -> str: # pragma: no cover return "Hello" schema = strawberry.federation.Schema( diff --git a/tests/federation/printer/test_interface.py b/tests/federation/printer/test_interface.py index 00d2464052..3d96fa0906 100644 --- a/tests/federation/printer/test_interface.py +++ b/tests/federation/printer/test_interface.py @@ -1,5 +1,4 @@ import textwrap -from typing import List import strawberry @@ -16,7 +15,7 @@ class Product(SomeInterface): @strawberry.federation.type class Query: @strawberry.field - def top_products(self, first: int) -> List[Product]: + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) diff --git a/tests/federation/printer/test_keys.py b/tests/federation/printer/test_keys.py index 5cf411aa4c..e6368b24b4 100644 --- a/tests/federation/printer/test_keys.py +++ b/tests/federation/printer/test_keys.py @@ -1,7 +1,6 @@ # type: ignore import textwrap -from typing import List import strawberry from strawberry.federation.schema_directives import Key @@ -17,7 +16,7 @@ class User: @strawberry.federation.type(keys=[Key(fields="upc", resolvable=True)], extend=True) class Product: upc: str = strawberry.federation.field(external=True) - reviews: List["Review"] + reviews: list["Review"] @strawberry.federation.type(keys=["body"]) class Review: @@ -28,7 +27,7 @@ class Review: @strawberry.federation.type class Query: @strawberry.field - def top_products(self, first: int) -> List[Product]: + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=False) @@ -79,7 +78,7 @@ class User: @strawberry.federation.type(keys=[Key(fields="upc", resolvable=True)], extend=True) class Product: upc: str = strawberry.federation.field(external=True) - reviews: List["Review"] + reviews: list["Review"] @strawberry.federation.type(keys=["body"]) class Review: @@ -90,7 +89,7 @@ class Review: @strawberry.federation.type class Query: @strawberry.field - def top_products(self, first: int) -> List[Product]: + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) diff --git a/tests/federation/printer/test_override.py b/tests/federation/printer/test_override.py index bec1f6f95a..a5682b88fd 100644 --- a/tests/federation/printer/test_override.py +++ b/tests/federation/printer/test_override.py @@ -1,7 +1,6 @@ # type: ignore import textwrap -from typing import List import strawberry from strawberry.federation.schema_directives import Override @@ -19,7 +18,7 @@ class Product(SomeInterface): @strawberry.federation.type class Query: @strawberry.field - def top_products(self, first: int) -> List[Product]: + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -71,7 +70,7 @@ class Product(SomeInterface): @strawberry.federation.type class Query: @strawberry.field - def top_products(self, first: int) -> List[Product]: + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) diff --git a/tests/federation/printer/test_policy.py b/tests/federation/printer/test_policy.py index 05dae2eb11..043eafc5a5 100644 --- a/tests/federation/printer/test_policy.py +++ b/tests/federation/printer/test_policy.py @@ -1,7 +1,6 @@ import textwrap from enum import Enum -from typing import List -from typing_extensions import Annotated +from typing import Annotated import strawberry @@ -26,7 +25,7 @@ class Query: ) def top_products( self, first: Annotated[int, strawberry.federation.argument()] - ) -> List[Product]: + ) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) diff --git a/tests/federation/printer/test_provides.py b/tests/federation/printer/test_provides.py index a13130bc82..87c8c6970e 100644 --- a/tests/federation/printer/test_provides.py +++ b/tests/federation/printer/test_provides.py @@ -1,7 +1,6 @@ # type: ignore import textwrap -from typing import List import strawberry from strawberry.schema.config import StrawberryConfig @@ -18,7 +17,7 @@ class User: class Product: upc: str = strawberry.federation.field(external=True) the_name: str = strawberry.federation.field(external=True) - reviews: List["Review"] + reviews: list["Review"] @strawberry.federation.type class Review: @@ -29,7 +28,7 @@ class Review: @strawberry.federation.type class Query: @strawberry.field - def top_products(self, first: int) -> List[Product]: + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema( @@ -90,7 +89,7 @@ class User: class Product: upc: str = strawberry.federation.field(external=True) the_name: str = strawberry.federation.field(external=True) - reviews: List["Review"] + reviews: list["Review"] @strawberry.federation.type class Review: @@ -101,7 +100,7 @@ class Review: @strawberry.federation.type class Query: @strawberry.field - def top_products(self, first: int) -> List[Product]: + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema( diff --git a/tests/federation/printer/test_requires.py b/tests/federation/printer/test_requires.py index fef6011949..8abdfb2ddf 100644 --- a/tests/federation/printer/test_requires.py +++ b/tests/federation/printer/test_requires.py @@ -1,7 +1,6 @@ # type: ignore import textwrap -from typing import List import strawberry @@ -21,7 +20,7 @@ class Product: field3: str = strawberry.federation.field(external=True) @strawberry.federation.field(requires=["field1", "field2", "field3"]) - def reviews(self) -> List["Review"]: + def reviews(self) -> list["Review"]: return [] @strawberry.federation.type @@ -33,7 +32,7 @@ class Review: @strawberry.federation.type class Query: @strawberry.field - def top_products(self, first: int) -> List[Product]: + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) diff --git a/tests/federation/printer/test_requires_scopes.py b/tests/federation/printer/test_requires_scopes.py index 2c42eec0bf..3e87487a12 100644 --- a/tests/federation/printer/test_requires_scopes.py +++ b/tests/federation/printer/test_requires_scopes.py @@ -1,7 +1,6 @@ import textwrap from enum import Enum -from typing import List -from typing_extensions import Annotated +from typing import Annotated import strawberry @@ -26,7 +25,7 @@ class Query: ) def top_products( self, first: Annotated[int, strawberry.federation.argument()] - ) -> List[Product]: + ) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) diff --git a/tests/federation/printer/test_shareable.py b/tests/federation/printer/test_shareable.py index 6212b8425b..820d21819d 100644 --- a/tests/federation/printer/test_shareable.py +++ b/tests/federation/printer/test_shareable.py @@ -1,7 +1,6 @@ # type: ignore import textwrap -from typing import List import strawberry @@ -18,7 +17,7 @@ class Product(SomeInterface): @strawberry.federation.type class Query: @strawberry.field - def top_products(self, first: int) -> List[Product]: + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) diff --git a/tests/federation/printer/test_tag.py b/tests/federation/printer/test_tag.py index 388c0adf93..f485f16524 100644 --- a/tests/federation/printer/test_tag.py +++ b/tests/federation/printer/test_tag.py @@ -1,7 +1,6 @@ import textwrap from enum import Enum -from typing import List -from typing_extensions import Annotated +from typing import Annotated import strawberry @@ -22,7 +21,7 @@ class Query: @strawberry.field def top_products( self, first: Annotated[int, strawberry.federation.argument(tags=["myTag"])] - ) -> List[Product]: + ) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) diff --git a/tests/federation/test_entities.py b/tests/federation/test_entities.py index ad79aeade8..73d9d1d102 100644 --- a/tests/federation/test_entities.py +++ b/tests/federation/test_entities.py @@ -18,7 +18,7 @@ def resolve_reference(cls, upc) -> "Product": @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: + def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -58,7 +58,7 @@ def resolve_reference(cls, info: strawberry.Info, upc: str) -> "Product": @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: + def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -102,7 +102,7 @@ class Product: @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: + def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -142,7 +142,7 @@ class Product: @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: + def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -189,7 +189,7 @@ class Product: @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: + def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -232,7 +232,7 @@ class Product: @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: + def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -430,7 +430,7 @@ async def resolve_reference(cls, upc: str) -> "Product": @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: + def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -469,7 +469,7 @@ async def resolve_reference(cls, upc: str) -> "Product": @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: + def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) diff --git a/tests/federation/test_schema.py b/tests/federation/test_schema.py index ade9a9e235..278d691a52 100644 --- a/tests/federation/test_schema.py +++ b/tests/federation/test_schema.py @@ -1,6 +1,6 @@ import textwrap import warnings -from typing import Generic, List, Optional, TypeVar +from typing import Generic, Optional, TypeVar import pytest @@ -18,7 +18,7 @@ class Product: @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> List[Product]: + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -74,7 +74,7 @@ class Product: @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> List[Product]: + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -136,7 +136,7 @@ class Example: @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> List[Example]: + def top_products(self, first: int) -> list[Example]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -164,7 +164,7 @@ class Product: @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> List[Product]: + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -210,12 +210,14 @@ class Product: @strawberry.type class ListOfProducts(Generic[T]): - products: List[T] + products: list[T] @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> ListOfProducts[Product]: + def top_products( + self, first: int + ) -> ListOfProducts[Product]: # pragma: no cover return ListOfProducts(products=[]) schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -264,7 +266,7 @@ class ExampleInput: @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, example: ExampleInput) -> List[str]: + def top_products(self, example: ExampleInput) -> list[str]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -350,7 +352,7 @@ class ProductFed: @strawberry.type class Query: @strawberry.field - def top_products(self, first: int) -> List[ProductFed]: + def top_products(self, first: int) -> list[ProductFed]: # pragma: no cover return [] with warnings.catch_warnings(record=True) as w: diff --git a/tests/fields/test_arguments.py b/tests/fields/test_arguments.py index f0d8a4991d..5e7959b789 100644 --- a/tests/fields/test_arguments.py +++ b/tests/fields/test_arguments.py @@ -1,6 +1,4 @@ -import sys -from typing import List, Optional, Union -from typing_extensions import Annotated +from typing import Annotated, Optional, Union import pytest @@ -17,7 +15,9 @@ def test_basic_arguments(): @strawberry.type class Query: @strawberry.field - def name(self, argument: str, optional_argument: Optional[str]) -> str: + def name( + self, argument: str, optional_argument: Optional[str] + ) -> str: # pragma: no cover return "Name" definition = Query.__strawberry_definition__ @@ -44,7 +44,9 @@ class Input: @strawberry.type class Query: @strawberry.field - def name(self, input: Input, optional_input: Optional[Input]) -> str: + def name( + self, input: Input, optional_input: Optional[Input] + ) -> str: # pragma: no cover return input.name definition = Query.__strawberry_definition__ @@ -71,7 +73,7 @@ class Input: @strawberry.type class Query: @strawberry.field - def names(self, inputs: List[Input]) -> List[str]: + def names(self, inputs: list[Input]) -> list[str]: # pragma: no cover return [input.name for input in inputs] definition = Query.__strawberry_definition__ @@ -94,7 +96,7 @@ class Input: @strawberry.type class Query: @strawberry.field - def names(self, inputs: List[Optional[Input]]) -> List[str]: + def names(self, inputs: list[Optional[Input]]) -> list[str]: # pragma: no cover return [input_.name for input_ in inputs if input_ is not None] definition = Query.__strawberry_definition__ @@ -111,7 +113,7 @@ def names(self, inputs: List[Optional[Input]]) -> List[str]: def test_basic_arguments_on_resolver(): - def name_resolver( + def name_resolver( # pragma: no cover id: strawberry.ID, argument: str, optional_argument: Optional[str] ) -> str: return "Name" @@ -140,7 +142,7 @@ class Query: def test_arguments_when_extending_a_type(): def name_resolver( id: strawberry.ID, argument: str, optional_argument: Optional[str] - ) -> str: + ) -> str: # pragma: no cover return "Name" @strawberry.type @@ -171,10 +173,10 @@ class Query(NameQuery): def test_arguments_when_extending_multiple_types(): - def name_resolver(id: strawberry.ID) -> str: + def name_resolver(id: strawberry.ID) -> str: # pragma: no cover return "Name" - def name_2_resolver(id: strawberry.ID) -> str: + def name_2_resolver(id: strawberry.ID) -> str: # pragma: no cover return "Name 2" @strawberry.type @@ -212,7 +214,7 @@ def test_argument_with_default_value_none(): @strawberry.type class Query: @strawberry.field - def name(self, argument: Optional[str] = None) -> str: + def name(self, argument: Optional[str] = None) -> str: # pragma: no cover return "Name" definition = Query.__strawberry_definition__ @@ -233,7 +235,7 @@ def test_argument_with_default_value_undefined(): @strawberry.type class Query: @strawberry.field - def name(self, argument: Optional[str]) -> str: + def name(self, argument: Optional[str]) -> str: # pragma: no cover return "Name" definition = Query.__strawberry_definition__ @@ -259,7 +261,7 @@ def name( # type: ignore str, strawberry.argument(description="This is a description"), ], - ) -> str: + ) -> str: # pragma: no cover return "Name" definition = Query.__strawberry_definition__ @@ -283,7 +285,7 @@ def name( # type: ignore Optional[str], strawberry.argument(description="This is a description"), ], - ) -> str: + ) -> str: # pragma: no cover return "Name" definition = Query.__strawberry_definition__ @@ -309,7 +311,7 @@ def name( str, strawberry.argument(description="This is a description"), ] = "Patrick", - ) -> str: + ) -> str: # pragma: no cover return "Name" definition = Query.__strawberry_definition__ @@ -335,7 +337,7 @@ def name( str, strawberry.argument(name="argument"), ] = "Patrick", - ) -> str: + ) -> str: # pragma: no cover return "Name" definition = Query.__strawberry_definition__ @@ -364,7 +366,7 @@ def name( strawberry.argument(description="This is a description"), strawberry.argument(description="Another description"), ], - ) -> str: + ) -> str: # pragma: no cover return "Name" assert str(error.value) == ( @@ -378,7 +380,9 @@ def test_annotated_with_other_information(): @strawberry.type class Query: @strawberry.field - def name(self, argument: Annotated[str, "Some other info"]) -> str: + def name( + self, argument: Annotated[str, "Some other info"] + ) -> str: # pragma: no cover return "Name" definition = Query.__strawberry_definition__ @@ -393,10 +397,6 @@ def name(self, argument: Annotated[str, "Some other info"]) -> str: assert argument.type is str -@pytest.mark.skipif( - sys.version_info < (3, 9), - reason="Annotated type was added in python 3.9", -) def test_annotated_python_39(): from typing import Annotated @@ -409,7 +409,7 @@ def name( str, strawberry.argument(description="This is a description"), ], - ) -> str: + ) -> str: # pragma: no cover return "Name" definition = Query.__strawberry_definition__ @@ -471,7 +471,7 @@ def test_resolver_with_invalid_field_argument_type(): class Adjective: text: str - def add_adjective_resolver(adjective: Adjective) -> bool: + def add_adjective_resolver(adjective: Adjective) -> bool: # pragma: no cover return True @strawberry.type diff --git a/tests/fields/test_field_defaults.py b/tests/fields/test_field_defaults.py index 2ca92069c5..cfbc0a6d8d 100644 --- a/tests/fields/test_field_defaults.py +++ b/tests/fields/test_field_defaults.py @@ -1,5 +1,3 @@ -from typing import List - import pytest import strawberry @@ -50,7 +48,7 @@ def test_field_default_extensions_value_set(): def test_field_default_factory_executed_each_time(): @strawberry.type class Query: - the_list: List[str] = strawberry.field(default_factory=list) + the_list: list[str] = strawberry.field(default_factory=list) assert Query().the_list == Query().the_list assert Query().the_list is not Query().the_list diff --git a/tests/fields/test_resolvers.py b/tests/fields/test_resolvers.py index c5611e369c..dbf9c53536 100644 --- a/tests/fields/test_resolvers.py +++ b/tests/fields/test_resolvers.py @@ -1,7 +1,7 @@ import dataclasses import textwrap import types -from typing import Any, ClassVar, List, no_type_check +from typing import Any, ClassVar, no_type_check import pytest @@ -417,7 +417,7 @@ def test_resolver_with_unhashable_default(): @strawberry.type class Query: @strawberry.field - def field(self, x: List[str] = ["foo"], y: JSON = {"foo": 42}) -> str: + def field(self, x: list[str] = ["foo"], y: JSON = {"foo": 42}) -> str: return f"{x} {y}" schema = strawberry.Schema(Query) diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index c6ce2b8fe7..915559a5f0 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -2,8 +2,9 @@ import contextlib import json +from collections.abc import AsyncGenerator from io import BytesIO -from typing import Any, AsyncGenerator, Dict, List, Optional +from typing import Any, Optional from typing_extensions import Literal from aiohttp import web @@ -30,14 +31,14 @@ ) -class GraphQLView(OnWSConnectMixin, BaseGraphQLView[Dict[str, object], object]): +class GraphQLView(OnWSConnectMixin, BaseGraphQLView[dict[str, object], object]): result_override: ResultOverrideFunction = None graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler graphql_ws_handler_class = DebuggableGraphQLWSHandler async def get_context( self, request: web.Request, response: web.StreamResponse - ) -> Dict[str, object]: + ) -> dict[str, object]: context = await super().get_context(request, response) return get_context(context) @@ -95,9 +96,9 @@ async def _graphql_request( self, method: Literal["get", "post"], query: Optional[str] = None, - variables: Optional[Dict[str, object]] = None, - files: Optional[Dict[str, BytesIO]] = None, - headers: Optional[Dict[str, str]] = None, + variables: Optional[dict[str, object]] = None, + files: Optional[dict[str, BytesIO]] = None, + headers: Optional[dict[str, str]] = None, **kwargs: Any, ) -> Response: async with TestClient(TestServer(self.app)) as client: @@ -129,7 +130,7 @@ async def request( self, url: str, method: Literal["get", "post", "patch", "put", "delete"], - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: async with TestClient(TestServer(self.app)) as client: response = await getattr(client, method)(url, headers=headers) @@ -143,7 +144,7 @@ async def request( async def get( self, url: str, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: return await self.request(url, "get", headers=headers) @@ -152,7 +153,7 @@ async def post( url: str, data: Optional[bytes] = None, json: Optional[JSON] = None, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: async with TestClient(TestServer(self.app)) as client: response = await client.post( @@ -170,7 +171,7 @@ async def ws_connect( self, url: str, *, - protocols: List[str], + protocols: list[str], ) -> AsyncGenerator[WebSocketClient, None]: async with TestClient(TestServer(self.app)) as client: async with client.ws_connect(url, protocols=protocols) as ws: @@ -185,7 +186,7 @@ def __init__(self, ws: ClientWebSocketResponse): async def send_text(self, payload: str) -> None: await self.ws.send_str(payload) - async def send_json(self, payload: Dict[str, Any]) -> None: + async def send_json(self, payload: dict[str, Any]) -> None: await self.ws.send_json(payload) async def send_bytes(self, payload: bytes) -> None: diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 7d9b86ea8e..b836dcce66 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -2,8 +2,9 @@ import contextlib import json +from collections.abc import AsyncGenerator from io import BytesIO -from typing import Any, AsyncGenerator, Dict, List, Optional, Union +from typing import Any, Optional, Union from typing_extensions import Literal from starlette.requests import Request @@ -31,7 +32,7 @@ ) -class GraphQLView(OnWSConnectMixin, BaseGraphQLView[Dict[str, object], object]): +class GraphQLView(OnWSConnectMixin, BaseGraphQLView[dict[str, object], object]): result_override: ResultOverrideFunction = None graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler graphql_ws_handler_class = DebuggableGraphQLWSHandler @@ -43,7 +44,7 @@ async def get_context( self, request: Union[Request, WebSocket], response: Union[StarletteResponse, WebSocket], - ) -> Dict[str, object]: + ) -> dict[str, object]: context = await super().get_context(request, response) return get_context(context) @@ -86,9 +87,9 @@ async def _graphql_request( self, method: Literal["get", "post"], query: Optional[str] = None, - variables: Optional[Dict[str, object]] = None, - files: Optional[Dict[str, BytesIO]] = None, - headers: Optional[Dict[str, str]] = None, + variables: Optional[dict[str, object]] = None, + files: Optional[dict[str, BytesIO]] = None, + headers: Optional[dict[str, str]] = None, **kwargs: Any, ) -> Response: body = self._build_body( @@ -122,7 +123,7 @@ async def request( self, url: str, method: Literal["get", "post", "patch", "put", "delete"], - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: response = getattr(self.client, method)(url, headers=headers) @@ -135,7 +136,7 @@ async def request( async def get( self, url: str, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: return await self.request(url, "get", headers=headers) @@ -144,7 +145,7 @@ async def post( url: str, data: Optional[bytes] = None, json: Optional[JSON] = None, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: response = self.client.post(url, headers=headers, content=data, json=json) @@ -159,7 +160,7 @@ async def ws_connect( self, url: str, *, - protocols: List[str], + protocols: list[str], ) -> AsyncGenerator[WebSocketClient, None]: try: with self.client.websocket_connect(url, protocols) as ws: @@ -185,7 +186,7 @@ def handle_disconnect(self, exc: WebSocketDisconnect) -> None: async def send_text(self, payload: str) -> None: self.ws.send_text(payload) - async def send_json(self, payload: Dict[str, Any]) -> None: + async def send_json(self, payload: dict[str, Any]) -> None: self.ws.send_json(payload) async def send_bytes(self, payload: bytes) -> None: diff --git a/tests/http/clients/async_flask.py b/tests/http/clients/async_flask.py index 4db37c8d1b..3b8e45b755 100644 --- a/tests/http/clients/async_flask.py +++ b/tests/http/clients/async_flask.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict, Optional +from typing import Any, Optional from flask import Flask from flask import Request as FlaskRequest @@ -31,7 +31,7 @@ async def get_root_value(self, request: FlaskRequest) -> Query: async def get_context( self, request: FlaskRequest, response: FlaskResponse - ) -> Dict[str, object]: + ) -> dict[str, object]: context = await super().get_context(request, response) return get_context(context) diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index c7e7f12b44..426aa0e19b 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -2,21 +2,11 @@ import contextlib import json import logging +from collections.abc import AsyncGenerator, AsyncIterable, Mapping from dataclasses import dataclass from functools import cached_property from io import BytesIO -from typing import ( - Any, - AsyncContextManager, - AsyncGenerator, - AsyncIterable, - Callable, - Dict, - List, - Mapping, - Optional, - Union, -) +from typing import Any, Callable, Optional, Union from typing_extensions import Literal from strawberry.http import GraphQLHTTPResponse @@ -33,7 +23,7 @@ logger = logging.getLogger("strawberry.test.http_client") -JSON = Dict[str, object] +JSON = dict[str, object] ResultOverrideFunction = Optional[Callable[[ExecutionResult], GraphQLHTTPResponse]] @@ -47,7 +37,7 @@ def __init__( status_code: int, data: Union[bytes, AsyncIterable[bytes]], *, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> None: self.status_code = status_code self.data = data @@ -119,9 +109,9 @@ async def _graphql_request( self, method: Literal["get", "post"], query: Optional[str] = None, - variables: Optional[Dict[str, object]] = None, - files: Optional[Dict[str, BytesIO]] = None, - headers: Optional[Dict[str, str]] = None, + variables: Optional[dict[str, object]] = None, + files: Optional[dict[str, BytesIO]] = None, + headers: Optional[dict[str, str]] = None, **kwargs: Any, ) -> Response: ... @@ -130,14 +120,14 @@ async def request( self, url: str, method: Literal["get", "post", "patch", "put", "delete"], - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: ... @abc.abstractmethod async def get( self, url: str, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: ... @abc.abstractmethod @@ -146,16 +136,16 @@ async def post( url: str, data: Optional[bytes] = None, json: Optional[JSON] = None, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: ... async def query( self, query: Optional[str] = None, method: Literal["get", "post"] = "post", - variables: Optional[Dict[str, object]] = None, - files: Optional[Dict[str, BytesIO]] = None, - headers: Optional[Dict[str, str]] = None, + variables: Optional[dict[str, object]] = None, + files: Optional[dict[str, BytesIO]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: return await self._graphql_request( method, query=query, headers=headers, variables=variables, files=files @@ -164,9 +154,9 @@ async def query( def _get_headers( self, method: Literal["get", "post"], - headers: Optional[Dict[str, str]], - files: Optional[Dict[str, BytesIO]], - ) -> Dict[str, str]: + headers: Optional[dict[str, str]], + files: Optional[dict[str, BytesIO]], + ) -> dict[str, str]: additional_headers = {} headers = headers or {} @@ -183,17 +173,17 @@ def _get_headers( def _build_body( self, query: Optional[str] = None, - variables: Optional[Dict[str, object]] = None, - files: Optional[Dict[str, BytesIO]] = None, + variables: Optional[dict[str, object]] = None, + files: Optional[dict[str, BytesIO]] = None, method: Literal["get", "post"] = "post", - ) -> Optional[Dict[str, object]]: + ) -> Optional[dict[str, object]]: if query is None: assert files is None assert variables is None return None - body: Dict[str, object] = {"query": query} + body: dict[str, object] = {"query": query} if variables: body["variables"] = variables @@ -215,11 +205,11 @@ def _build_body( @staticmethod def _build_multipart_file_map( - variables: Dict[str, object], files: Dict[str, BytesIO] - ) -> Dict[str, List[str]]: + variables: dict[str, object], files: dict[str, BytesIO] + ) -> dict[str, list[str]]: # TODO: remove code duplication - files_map: Dict[str, List[str]] = {} + files_map: dict[str, list[str]] = {} for key, values in variables.items(): if isinstance(values, dict): folder_key = next(iter(values.keys())) @@ -249,8 +239,8 @@ def ws_connect( self, url: str, *, - protocols: List[str], - ) -> AsyncContextManager["WebSocketClient"]: + protocols: list[str], + ) -> contextlib.AbstractAsyncContextManager["WebSocketClient"]: raise NotImplementedError @@ -323,7 +313,7 @@ def __init__(self, *args: Any, **kwargs: Any): self.original_context = kwargs.get("context", {}) DebuggableGraphQLTransportWSHandler.on_init(self) - def get_tasks(self) -> List: + def get_tasks(self) -> list: return [op.task for op in self.operations.values()] @property @@ -345,7 +335,7 @@ def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self.original_context = self.context - def get_tasks(self) -> List: + def get_tasks(self) -> list: return list(self.tasks.values()) @property diff --git a/tests/http/clients/chalice.py b/tests/http/clients/chalice.py index 2e01c1f400..8fbb31ff46 100644 --- a/tests/http/clients/chalice.py +++ b/tests/http/clients/chalice.py @@ -3,7 +3,7 @@ import urllib.parse from io import BytesIO from json import dumps -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from typing_extensions import Literal from chalice.app import Chalice @@ -72,16 +72,16 @@ async def _graphql_request( self, method: Literal["get", "post"], query: Optional[str] = None, - variables: Optional[Dict[str, object]] = None, - files: Optional[Dict[str, BytesIO]] = None, - headers: Optional[Dict[str, str]] = None, + variables: Optional[dict[str, object]] = None, + files: Optional[dict[str, BytesIO]] = None, + headers: Optional[dict[str, str]] = None, **kwargs: Any, ) -> Response: body = self._build_body( query=query, variables=variables, files=files, method=method ) - data: Union[Dict[str, object], str, None] = None + data: Union[dict[str, object], str, None] = None if body and files: body.update({name: (file, name) for name, file in files.items()}) @@ -113,7 +113,7 @@ async def request( self, url: str, method: Literal["get", "post", "patch", "put", "delete"], - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: with Client(self.app) as client: response = getattr(client.http, method)(url, headers=headers) @@ -127,7 +127,7 @@ async def request( async def get( self, url: str, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: return await self.request(url, "get", headers=headers) @@ -136,7 +136,7 @@ async def post( url: str, data: Optional[bytes] = None, json: Optional[JSON] = None, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: body = data or dumps(json) diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index 802fc263cc..bfec59a1f4 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -2,8 +2,9 @@ import contextlib import json as json_module +from collections.abc import AsyncGenerator from io import BytesIO -from typing import Any, AsyncGenerator, Dict, List, Optional +from typing import Any, Optional from typing_extensions import Literal from urllib3 import encode_multipart_formdata @@ -35,9 +36,9 @@ def generate_get_path( - path, query: str, variables: Optional[Dict[str, Any]] = None + path, query: str, variables: Optional[dict[str, Any]] = None ) -> str: - body: Dict[str, Any] = {"query": query} + body: dict[str, Any] = {"query": query} if variables is not None: body["variables"] = json_module.dumps(variables) @@ -46,7 +47,7 @@ def generate_get_path( def create_multipart_request_body( - body: Dict[str, object], files: Dict[str, BytesIO] + body: dict[str, object], files: dict[str, BytesIO] ) -> tuple[list[tuple[str, str]], bytes]: fields = { "operations": body["operations"], @@ -156,9 +157,9 @@ async def _graphql_request( self, method: Literal["get", "post"], query: Optional[str] = None, - variables: Optional[Dict[str, object]] = None, - files: Optional[Dict[str, BytesIO]] = None, - headers: Optional[Dict[str, str]] = None, + variables: Optional[dict[str, object]] = None, + files: Optional[dict[str, BytesIO]] = None, + headers: Optional[dict[str, str]] = None, **kwargs: Any, ) -> Response: body = self._build_body( @@ -188,7 +189,7 @@ async def request( url: str, method: Literal["get", "post", "patch", "put", "delete"], body: bytes = b"", - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: # HttpCommunicator expects tuples of bytestrings if headers: @@ -212,7 +213,7 @@ async def request( async def get( self, url: str, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: return await self.request(url, "get", headers=headers) @@ -221,7 +222,7 @@ async def post( url: str, data: Optional[bytes] = None, json: Optional[JSON] = None, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: body = b"" if data is not None: @@ -235,7 +236,7 @@ async def ws_connect( self, url: str, *, - protocols: List[str], + protocols: list[str], ) -> AsyncGenerator[WebSocketClient, None]: client = WebsocketCommunicator(self.ws_app, url, subprotocols=protocols) @@ -285,7 +286,7 @@ def name(self) -> str: async def send_text(self, payload: str) -> None: await self.ws.send_to(text_data=payload) - async def send_json(self, payload: Dict[str, Any]) -> None: + async def send_json(self, payload: dict[str, Any]) -> None: await self.ws.send_json_to(payload) async def send_bytes(self, payload: bytes) -> None: diff --git a/tests/http/clients/django.py b/tests/http/clients/django.py index a871d7b63b..10b19893c3 100644 --- a/tests/http/clients/django.py +++ b/tests/http/clients/django.py @@ -2,7 +2,7 @@ from io import BytesIO from json import dumps -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from typing_extensions import Literal from django.core.exceptions import BadRequest, SuspiciousOperation @@ -62,9 +62,9 @@ def _get_header_name(self, key: str) -> str: def _get_headers( self, method: Literal["get", "post"], - headers: Optional[Dict[str, str]], - files: Optional[Dict[str, BytesIO]], - ) -> Dict[str, str]: + headers: Optional[dict[str, str]], + files: Optional[dict[str, BytesIO]], + ) -> dict[str, str]: headers = headers or {} headers = {self._get_header_name(key): value for key, value in headers.items()} @@ -101,9 +101,9 @@ async def _graphql_request( self, method: Literal["get", "post"], query: Optional[str] = None, - variables: Optional[Dict[str, object]] = None, - files: Optional[Dict[str, BytesIO]] = None, - headers: Optional[Dict[str, str]] = None, + variables: Optional[dict[str, object]] = None, + files: Optional[dict[str, BytesIO]] = None, + headers: Optional[dict[str, str]] = None, **kwargs: Any, ) -> Response: headers = self._get_headers(method=method, headers=headers, files=files) @@ -113,7 +113,7 @@ async def _graphql_request( query=query, variables=variables, files=files, method=method ) - data: Union[Dict[str, object], str, None] = None + data: Union[dict[str, object], str, None] = None if body and files: files = { @@ -140,7 +140,7 @@ async def request( self, url: str, method: Literal["get", "post", "patch", "put", "delete"], - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: headers = self._get_headers( method=method, # type: ignore @@ -156,7 +156,7 @@ async def request( async def get( self, url: str, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: return await self.request(url, "get", headers=headers) @@ -165,7 +165,7 @@ async def post( url: str, data: Optional[bytes] = None, json: Optional[JSON] = None, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: headers = self._get_headers(method="post", headers=headers, files=None) diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index 70eded4049..0f0ed88398 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -2,8 +2,9 @@ import contextlib import json +from collections.abc import AsyncGenerator from io import BytesIO -from typing import Any, AsyncGenerator, Dict, List, Optional +from typing import Any, Optional from typing_extensions import Literal from starlette.websockets import WebSocketDisconnect @@ -39,7 +40,7 @@ async def fastapi_get_context( request: Request = None, # type: ignore ws: WebSocket = None, # type: ignore custom_value: str = Depends(custom_context_dependency), -) -> Dict[str, object]: +) -> dict[str, object]: return get_context( { "request": request or ws, @@ -114,9 +115,9 @@ async def _graphql_request( self, method: Literal["get", "post"], query: Optional[str] = None, - variables: Optional[Dict[str, object]] = None, - files: Optional[Dict[str, BytesIO]] = None, - headers: Optional[Dict[str, str]] = None, + variables: Optional[dict[str, object]] = None, + files: Optional[dict[str, BytesIO]] = None, + headers: Optional[dict[str, str]] = None, **kwargs: Any, ) -> Response: body = self._build_body( @@ -147,7 +148,7 @@ async def request( self, url: str, method: Literal["get", "post", "patch", "put", "delete"], - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: response = getattr(self.client, method)(url, headers=headers) @@ -156,7 +157,7 @@ async def request( async def get( self, url: str, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: return await self.request(url, "get", headers=headers) @@ -165,7 +166,7 @@ async def post( url: str, data: Optional[bytes] = None, json: Optional[JSON] = None, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: response = self.client.post(url, headers=headers, content=data, json=json) @@ -176,7 +177,7 @@ async def ws_connect( self, url: str, *, - protocols: List[str], + protocols: list[str], ) -> AsyncGenerator[WebSocketClient, None]: try: with self.client.websocket_connect(url, protocols) as ws: diff --git a/tests/http/clients/flask.py b/tests/http/clients/flask.py index 7da42bbaff..7f5ec5cb8f 100644 --- a/tests/http/clients/flask.py +++ b/tests/http/clients/flask.py @@ -6,7 +6,7 @@ import json import urllib.parse from io import BytesIO -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from typing_extensions import Literal from flask import Flask @@ -40,7 +40,7 @@ def get_root_value(self, request: FlaskRequest) -> object: def get_context( self, request: FlaskRequest, response: FlaskResponse - ) -> Dict[str, object]: + ) -> dict[str, object]: context = super().get_context(request, response) return get_context(context) @@ -85,16 +85,16 @@ async def _graphql_request( self, method: Literal["get", "post"], query: Optional[str] = None, - variables: Optional[Dict[str, object]] = None, - files: Optional[Dict[str, BytesIO]] = None, - headers: Optional[Dict[str, str]] = None, + variables: Optional[dict[str, object]] = None, + files: Optional[dict[str, BytesIO]] = None, + headers: Optional[dict[str, str]] = None, **kwargs: Any, ) -> Response: body = self._build_body( query=query, variables=variables, files=files, method=method ) - data: Union[Dict[str, object], str, None] = None + data: Union[dict[str, object], str, None] = None if body and files: body.update({name: (file, name) for name, file in files.items()}) @@ -117,7 +117,7 @@ def _do_request( self, url: str, method: Literal["get", "post", "patch", "put", "delete"], - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, **kwargs: Any, ): with self.app.test_client() as client: @@ -133,7 +133,7 @@ async def request( self, url: str, method: Literal["get", "post", "patch", "put", "delete"], - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, **kwargs: Any, ) -> Response: loop = asyncio.get_running_loop() @@ -146,7 +146,7 @@ async def request( async def get( self, url: str, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: return await self.request(url, "get", headers=headers) @@ -155,6 +155,6 @@ async def post( url: str, data: Optional[bytes] = None, json: Optional[JSON] = None, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: return await self.request(url, "post", headers=headers, data=data, json=json) diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index dc948e9868..f357bf75f8 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -2,8 +2,9 @@ import contextlib import json +from collections.abc import AsyncGenerator from io import BytesIO -from typing import Any, AsyncGenerator, Dict, List, Optional +from typing import Any, Optional from typing_extensions import Literal from litestar import Litestar, Request @@ -87,9 +88,9 @@ async def _graphql_request( self, method: Literal["get", "post"], query: Optional[str] = None, - variables: Optional[Dict[str, object]] = None, - files: Optional[Dict[str, BytesIO]] = None, - headers: Optional[Dict[str, str]] = None, + variables: Optional[dict[str, object]] = None, + files: Optional[dict[str, BytesIO]] = None, + headers: Optional[dict[str, str]] = None, **kwargs: Any, ) -> Response: if body := self._build_body( @@ -121,7 +122,7 @@ async def request( self, url: str, method: Literal["get", "post", "patch", "put", "delete"], - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: response = getattr(self.client, method)(url, headers=headers) @@ -134,7 +135,7 @@ async def request( async def get( self, url: str, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: return await self.request(url, "get", headers=headers) @@ -143,7 +144,7 @@ async def post( url: str, data: Optional[bytes] = None, json: Optional[JSON] = None, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: response = self.client.post(url, headers=headers, content=data, json=json) @@ -158,7 +159,7 @@ async def ws_connect( self, url: str, *, - protocols: List[str], + protocols: list[str], ) -> AsyncGenerator[WebSocketClient, None]: try: with self.client.websocket_connect(url, protocols) as ws: @@ -183,7 +184,7 @@ def handle_disconnect(self, exc: WebSocketDisconnect) -> None: async def send_text(self, payload: str) -> None: self.ws.send_text(payload) - async def send_json(self, payload: Dict[str, Any]) -> None: + async def send_json(self, payload: dict[str, Any]) -> None: self.ws.send_json(payload) async def send_bytes(self, payload: bytes) -> None: diff --git a/tests/http/clients/quart.py b/tests/http/clients/quart.py index d9a184dfd4..43d21ee28f 100644 --- a/tests/http/clients/quart.py +++ b/tests/http/clients/quart.py @@ -1,7 +1,7 @@ import json import urllib.parse from io import BytesIO -from typing import Any, Dict, Optional +from typing import Any, Optional from typing_extensions import Literal from quart import Quart @@ -33,7 +33,7 @@ async def get_root_value(self, request: QuartRequest) -> Query: async def get_context( self, request: QuartRequest, response: QuartResponse - ) -> Dict[str, object]: + ) -> dict[str, object]: context = await super().get_context(request, response) return get_context(context) @@ -78,9 +78,9 @@ async def _graphql_request( self, method: Literal["get", "post"], query: Optional[str] = None, - variables: Optional[Dict[str, object]] = None, - files: Optional[Dict[str, BytesIO]] = None, - headers: Optional[Dict[str, str]] = None, + variables: Optional[dict[str, object]] = None, + files: Optional[dict[str, BytesIO]] = None, + headers: Optional[dict[str, str]] = None, **kwargs: Any, ) -> Response: body = self._build_body( @@ -109,7 +109,7 @@ async def request( self, url: str, method: Literal["get", "post", "patch", "put", "delete"], - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, **kwargs: Any, ) -> Response: async with self.app.test_app() as test_app: @@ -126,7 +126,7 @@ async def request( async def get( self, url: str, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: return await self.request(url, "get", headers=headers) @@ -135,7 +135,7 @@ async def post( url: str, data: Optional[bytes] = None, json: Optional[JSON] = None, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: kwargs = {"headers": headers, "data": data, "json": json} return await self.request( diff --git a/tests/http/clients/sanic.py b/tests/http/clients/sanic.py index 8edc351db9..86a3346f73 100644 --- a/tests/http/clients/sanic.py +++ b/tests/http/clients/sanic.py @@ -3,7 +3,7 @@ from io import BytesIO from json import dumps from random import randint -from typing import Any, Dict, Optional +from typing import Any, Optional from typing_extensions import Literal from sanic import Sanic @@ -75,9 +75,9 @@ async def _graphql_request( self, method: Literal["get", "post"], query: Optional[str] = None, - variables: Optional[Dict[str, object]] = None, - files: Optional[Dict[str, BytesIO]] = None, - headers: Optional[Dict[str, str]] = None, + variables: Optional[dict[str, object]] = None, + files: Optional[dict[str, BytesIO]] = None, + headers: Optional[dict[str, str]] = None, **kwargs: Any, ) -> Response: body = self._build_body( @@ -111,7 +111,7 @@ async def request( self, url: str, method: Literal["get", "post", "patch", "put", "delete"], - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: request, response = await self.app.asgi_client.request( method, @@ -128,7 +128,7 @@ async def request( async def get( self, url: str, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: return await self.request(url, "get", headers=headers) @@ -137,7 +137,7 @@ async def post( url: str, data: Optional[bytes] = None, json: Optional[JSON] = None, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, ) -> Response: body = data or dumps(json) request, response = await self.app.asgi_client.request( diff --git a/tests/http/conftest.py b/tests/http/conftest.py index d1cecb0c22..3f61148830 100644 --- a/tests/http/conftest.py +++ b/tests/http/conftest.py @@ -1,5 +1,6 @@ import importlib -from typing import Any, Generator, Type +from collections.abc import Generator +from typing import Any import pytest @@ -46,10 +47,10 @@ def _get_http_client_classes() -> Generator[Any, None, None]: @pytest.fixture(params=_get_http_client_classes()) -def http_client_class(request: Any) -> Type[HttpClient]: +def http_client_class(request: Any) -> type[HttpClient]: return request.param @pytest.fixture() -def http_client(http_client_class: Type[HttpClient]) -> HttpClient: +def http_client(http_client_class: type[HttpClient]) -> HttpClient: return http_client_class() diff --git a/tests/http/context.py b/tests/http/context.py index 99985b2434..fbd92e9e0f 100644 --- a/tests/http/context.py +++ b/tests/http/context.py @@ -1,7 +1,4 @@ -from typing import Dict - - -def get_context(context: object) -> Dict[str, object]: +def get_context(context: object) -> dict[str, object]: assert isinstance(context, dict) return {**context, "custom_value": "a value from context"} diff --git a/tests/http/test_graphql_ide.py b/tests/http/test_graphql_ide.py index d9a85ad736..f58f117228 100644 --- a/tests/http/test_graphql_ide.py +++ b/tests/http/test_graphql_ide.py @@ -1,4 +1,4 @@ -from typing import Type, Union +from typing import Union from typing_extensions import Literal import pytest @@ -10,7 +10,7 @@ @pytest.mark.parametrize("graphql_ide", ["graphiql", "apollo-sandbox", "pathfinder"]) async def test_renders_graphql_ide( header_value: str, - http_client_class: Type[HttpClient], + http_client_class: type[HttpClient], graphql_ide: Literal["graphiql", "apollo-sandbox", "pathfinder"], ): http_client = http_client_class(graphql_ide=graphql_ide) @@ -36,7 +36,7 @@ async def test_renders_graphql_ide( @pytest.mark.parametrize("header_value", ["text/html", "*/*"]) async def test_renders_graphql_ide_deprecated( - header_value: str, http_client_class: Type[HttpClient] + header_value: str, http_client_class: type[HttpClient] ): with pytest.deprecated_call( match=r"The `graphiql` argument is deprecated in favor of `graphql_ide`" @@ -57,7 +57,7 @@ async def test_renders_graphql_ide_deprecated( async def test_does_not_render_graphiql_if_wrong_accept( - http_client_class: Type[HttpClient], + http_client_class: type[HttpClient], ): http_client = http_client_class() response = await http_client.get("/graphql", headers={"Accept": "text/xml"}) @@ -69,7 +69,7 @@ async def test_does_not_render_graphiql_if_wrong_accept( @pytest.mark.parametrize("graphql_ide", [False, None]) async def test_renders_graphiql_disabled( - http_client_class: Type[HttpClient], + http_client_class: type[HttpClient], graphql_ide: Union[bool, None], ): http_client = http_client_class(graphql_ide=graphql_ide) @@ -79,7 +79,7 @@ async def test_renders_graphiql_disabled( async def test_renders_graphiql_disabled_deprecated( - http_client_class: Type[HttpClient], + http_client_class: type[HttpClient], ): with pytest.deprecated_call( match=r"The `graphiql` argument is deprecated in favor of `graphql_ide`" diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index b680e9267a..fa8a08248f 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -1,5 +1,4 @@ import contextlib -from typing import Type from typing_extensions import Literal import pytest @@ -10,7 +9,7 @@ @pytest.fixture() -def http_client(http_client_class: Type[HttpClient]) -> HttpClient: +def http_client(http_client_class: type[HttpClient]) -> HttpClient: with contextlib.suppress(ImportError): import django diff --git a/tests/http/test_parse_content_type.py b/tests/http/test_parse_content_type.py index 4ae0017e40..cef6fa0a0b 100644 --- a/tests/http/test_parse_content_type.py +++ b/tests/http/test_parse_content_type.py @@ -1,5 +1,3 @@ -from typing import Dict, Tuple - import pytest from strawberry.http.parse_content_type import parse_content_type @@ -44,6 +42,6 @@ ) async def test_parse_content_type( content_type: str, - expected: Tuple[str, Dict[str, str]], + expected: tuple[str, dict[str, str]], ): assert parse_content_type(content_type) == expected diff --git a/tests/http/test_upload.py b/tests/http/test_upload.py index 7a991db846..e82f7e30b5 100644 --- a/tests/http/test_upload.py +++ b/tests/http/test_upload.py @@ -1,7 +1,6 @@ import contextlib import json from io import BytesIO -from typing import Type import pytest from urllib3 import encode_multipart_formdata @@ -10,7 +9,7 @@ @pytest.fixture() -def http_client(http_client_class: Type[HttpClient]) -> HttpClient: +def http_client(http_client_class: type[HttpClient]) -> HttpClient: with contextlib.suppress(ImportError): from .clients.chalice import ChaliceHttpClient @@ -21,7 +20,7 @@ def http_client(http_client_class: Type[HttpClient]) -> HttpClient: @pytest.fixture() -def enabled_http_client(http_client_class: Type[HttpClient]) -> HttpClient: +def enabled_http_client(http_client_class: type[HttpClient]) -> HttpClient: with contextlib.suppress(ImportError): from .clients.chalice import ChaliceHttpClient diff --git a/tests/litestar/test_context.py b/tests/litestar/test_context.py index ce009ec346..2243eb61ee 100644 --- a/tests/litestar/test_context.py +++ b/tests/litestar/test_context.py @@ -1,5 +1,3 @@ -from typing import Dict - import strawberry @@ -72,7 +70,7 @@ def abc(self, info: strawberry.Info) -> str: def custom_context_dependency() -> str: return "rocks" - async def get_context(custom_context_dependency: str) -> Dict[str, str]: + async def get_context(custom_context_dependency: str) -> dict[str, str]: return {"strawberry": custom_context_dependency} schema = strawberry.Schema(query=Query) diff --git a/tests/objects/generics/test_generic_objects.py b/tests/objects/generics/test_generic_objects.py index 2dcf9d5fe3..b1c0a0e32d 100644 --- a/tests/objects/generics/test_generic_objects.py +++ b/tests/objects/generics/test_generic_objects.py @@ -1,6 +1,5 @@ import datetime -from typing import Generic, List, Optional, TypeVar, Union -from typing_extensions import Annotated +from typing import Annotated, Generic, Optional, TypeVar, Union import pytest @@ -104,7 +103,7 @@ class Edge(Generic[T]): @strawberry.type class Connection(Generic[T]): - edges: List[Edge[T]] + edges: list[Edge[T]] definition = get_object_definition(Connection, strict=True) assert definition.is_graphql_generic @@ -140,9 +139,9 @@ class Value(Generic[T]): @strawberry.type class Foo: string: Value[str] - strings: Value[List[str]] + strings: Value[list[str]] optional_string: Value[Optional[str]] - optional_strings: Value[Optional[List[str]]] + optional_strings: Value[Optional[list[str]]] definition = get_object_definition(Foo, strict=True) assert not definition.is_graphql_generic @@ -191,7 +190,7 @@ class Edge(Generic[T]): def test_generic_with_list(): @strawberry.type class Connection(Generic[T]): - edges: List[T] + edges: list[T] definition = get_object_definition(Connection, strict=True) assert definition.is_graphql_generic @@ -220,7 +219,7 @@ class Connection(Generic[T]): def test_generic_with_list_of_optionals(): @strawberry.type class Connection(Generic[T]): - edges: List[Optional[T]] + edges: list[Optional[T]] definition = get_object_definition(Connection, strict=True) assert definition.is_graphql_generic @@ -371,7 +370,7 @@ class Edge(Generic[T]): @strawberry.type class Connection(Generic[T]): - edges: List[Edge[T]] + edges: list[Edge[T]] @strawberry.type class User: @@ -421,7 +420,7 @@ class Edge(Generic[T]): @strawberry.type class Query: - user: List[Edge[str]] + user: list[Edge[str]] query_definition = get_object_definition(Query, strict=True) assert query_definition.type_params == [] @@ -498,7 +497,7 @@ class Cat: @strawberry.type class Connection(Generic[T]): - nodes: List[T] + nodes: list[T] DogCat = Annotated[Union[Dog, Cat], strawberry.union("DogCat")] @@ -535,7 +534,7 @@ class Cat: @strawberry.type class Connection(Generic[T]): - nodes: List[T] + nodes: list[T] @strawberry.type class Query: @@ -590,7 +589,7 @@ def test_generic_with_arguments(): @strawberry.type class Collection(Generic[T]): @strawberry.field - def by_id(self, ids: List[int]) -> List[T]: + def by_id(self, ids: list[int]) -> list[T]: return [] @strawberry.type diff --git a/tests/objects/generics/test_names.py b/tests/objects/generics/test_names.py index fe5c747122..27ab8f61fd 100644 --- a/tests/objects/generics/test_names.py +++ b/tests/objects/generics/test_names.py @@ -1,6 +1,5 @@ import textwrap -from typing import Generic, List, NewType, TypeVar -from typing_extensions import Annotated +from typing import Annotated, Generic, NewType, TypeVar import pytest @@ -70,7 +69,7 @@ class Edge(Generic[T]): @strawberry.type class Connection(Generic[T]): - edges: List[T] + edges: list[T] type_definition = Connection.__strawberry_definition__ # type: ignore @@ -114,7 +113,7 @@ class DictItem(Generic[K, V]): @strawberry.type class Query: - d: Value[List[DictItem[int, str]]] + d: Value[list[DictItem[int, str]]] schema = strawberry.Schema(query=Query) diff --git a/tests/plugins/strawberry_exceptions.py b/tests/plugins/strawberry_exceptions.py index f937b3834a..e1623f66fc 100644 --- a/tests/plugins/strawberry_exceptions.py +++ b/tests/plugins/strawberry_exceptions.py @@ -2,9 +2,9 @@ import os import re from collections import defaultdict +from collections.abc import Generator from dataclasses import dataclass from pathlib import Path -from typing import DefaultDict, Generator, List, Type import pytest import rich @@ -39,7 +39,7 @@ def suppress_output(verbosity_level: int = 0) -> Generator[None, None, None]: class StrawberryExceptionsPlugin: def __init__(self, verbosity_level: int) -> None: - self._info: DefaultDict[Type[StrawberryException], List[Result]] = defaultdict( + self._info: defaultdict[type[StrawberryException], list[Result]] = defaultdict( list ) self.verbosity_level = verbosity_level diff --git a/tests/python_312/test_generic_objects.py b/tests/python_312/test_generic_objects.py index 107218e1fd..379a1b4608 100644 --- a/tests/python_312/test_generic_objects.py +++ b/tests/python_312/test_generic_objects.py @@ -2,8 +2,7 @@ import datetime import sys -from typing import List, Optional, Union -from typing_extensions import Annotated +from typing import Annotated, Optional, Union import pytest @@ -109,7 +108,7 @@ class Edge[T]: @strawberry.type class Connection[T]: - edges: List[Edge[T]] + edges: list[Edge[T]] definition = get_object_definition(Connection, strict=True) assert definition.is_graphql_generic @@ -146,9 +145,9 @@ class Value[T]: @strawberry.type class Foo: string: Value[str] - strings: Value[List[str]] + strings: Value[list[str]] optional_string: Value[Optional[str]] - optional_strings: Value[Optional[List[str]]] + optional_strings: Value[Optional[list[str]]] definition = get_object_definition(Foo, strict=True) assert not definition.is_graphql_generic @@ -196,7 +195,7 @@ class Edge[T]: def test_generic_with_list(): @strawberry.type class Connection[T]: - edges: List[T] + edges: list[T] definition = get_object_definition(Connection, strict=True) assert definition.is_graphql_generic @@ -225,7 +224,7 @@ class Connection[T]: def test_generic_with_list_of_optionals(): @strawberry.type class Connection[T]: - edges: List[Optional[T]] + edges: list[Optional[T]] definition = get_object_definition(Connection, strict=True) assert definition.is_graphql_generic @@ -374,7 +373,7 @@ class Edge[T]: @strawberry.type class Connection[T]: - edges: List[Edge[T]] + edges: list[Edge[T]] @strawberry.type class User: @@ -424,7 +423,7 @@ class Edge[T]: @strawberry.type class Query: - user: List[Edge[str]] + user: list[Edge[str]] query_definition = get_object_definition(Query, strict=True) assert query_definition.type_params == [] @@ -501,7 +500,7 @@ class Cat: @strawberry.type class Connection[T]: - nodes: List[T] + nodes: list[T] DogCat = Annotated[Union[Dog, Cat], strawberry.union("DogCat")] @@ -538,7 +537,7 @@ class Cat: @strawberry.type class Connection[T]: - nodes: List[T] + nodes: list[T] @strawberry.type class Query: @@ -591,7 +590,7 @@ def test_generic_with_arguments(): @strawberry.type class Collection[T]: @strawberry.field - def by_id(self, ids: List[int]) -> List[T]: + def by_id(self, ids: list[int]) -> list[T]: return [] @strawberry.type diff --git a/tests/python_312/test_generics_schema.py b/tests/python_312/test_generics_schema.py index fb8edb6131..d3611ac368 100644 --- a/tests/python_312/test_generics_schema.py +++ b/tests/python_312/test_generics_schema.py @@ -3,7 +3,7 @@ import sys import textwrap from enum import Enum -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from typing_extensions import Self import pytest @@ -165,7 +165,7 @@ class Fruit: @strawberry.type class Edge[T]: cursor: strawberry.ID - nodes: List[T] + nodes: list[T] @strawberry.type class FruitEdge(Edge[Fruit]): ... @@ -359,7 +359,7 @@ class User: @strawberry.type class Edge[T]: - nodes: List[T] + nodes: list[T] @strawberry.type class Query: @@ -391,7 +391,7 @@ class User: @strawberry.type class Edge[T]: - nodes: List[Optional[T]] + nodes: list[Optional[T]] @strawberry.type class Query: @@ -427,7 +427,7 @@ class Edge[T]: @strawberry.type class Connection[T]: - edges: List[Edge[T]] + edges: list[Edge[T]] @strawberry.type class ConnectionWithMeta[T](Connection[T]): @@ -677,7 +677,7 @@ class Edge[T]: @strawberry.type class Query: @strawberry.field - def example(self) -> List[Union[Edge[int], Edge[str]]]: + def example(self) -> list[Union[Edge[int], Edge[str]]]: return [ Edge(cursor=strawberry.ID("1"), node=1), Edge(cursor=strawberry.ID("2"), node="string"), @@ -781,7 +781,7 @@ class User: @strawberry.type class Edge[T]: - nodes: List[T] + nodes: list[T] @strawberry.type class Query: @@ -816,7 +816,7 @@ class User: @strawberry.type class Edge[T]: - nodes: List[T] + nodes: list[T] @strawberry.type class Query: @@ -852,7 +852,7 @@ class User: @strawberry.type class Edge[T]: - nodes: List[T] + nodes: list[T] @strawberry.type class Query: @@ -887,7 +887,7 @@ def test_generic_with_arguments(): @strawberry.type class Collection[T]: @strawberry.field - def by_id(self, ids: List[int]) -> List[T]: + def by_id(self, ids: list[int]) -> list[T]: return [] @strawberry.type @@ -925,7 +925,7 @@ def edge(self, arg: T) -> bool: return bool(arg) @strawberry.field - def edges(self, args: List[T]) -> int: + def edges(self, args: list[T]) -> int: return len(args) @strawberry.type @@ -994,7 +994,7 @@ class Book(Node[str]): @strawberry.type class Query: @strawberry.field - def books(self) -> List[Book]: + def books(self) -> list[Book]: return list() schema = strawberry.Schema(query=Query) @@ -1021,7 +1021,7 @@ def test_self(): @strawberry.interface class INode: field: Optional[Self] - fields: List[Self] + fields: list[Self] @strawberry.type class Node(INode): ... diff --git a/tests/relay/schema.py b/tests/relay/schema.py index 5bd8cea2be..fb8934b830 100644 --- a/tests/relay/schema.py +++ b/tests/relay/schema.py @@ -1,18 +1,20 @@ import dataclasses -from typing import ( - Any, +from collections.abc import ( AsyncGenerator, AsyncIterable, AsyncIterator, Generator, Iterable, Iterator, - List, +) +from typing import ( + Annotated, + Any, NamedTuple, Optional, cast, ) -from typing_extensions import Annotated, Self, TypeAlias +from typing_extensions import Self, TypeAlias import strawberry from strawberry import relay @@ -200,9 +202,9 @@ class Query: node_with_async_permissions: relay.Node = relay.node( permission_classes=[DummyPermission] ) - nodes: List[relay.Node] = relay.node() + nodes: list[relay.Node] = relay.node() node_optional: Optional[relay.Node] = relay.node() - nodes_optional: List[Optional[relay.Node]] = relay.node() + nodes_optional: list[Optional[relay.Node]] = relay.node() fruits: relay.ListConnection[Fruit] = relay.connection(resolver=fruits_resolver) fruits_lazy: relay.ListConnection[ Annotated["Fruit", strawberry.lazy("tests.relay.schema")] @@ -224,7 +226,7 @@ def fruits_concrete_resolver( self, info: strawberry.Info, name_endswith: Optional[str] = None, - ) -> List[Fruit]: + ) -> list[Fruit]: # This is mimicing integrations, like Django return [ cast( @@ -244,7 +246,7 @@ def fruits_custom_resolver( self, info: strawberry.Info, name_endswith: Optional[str] = None, - ) -> List[Fruit]: + ) -> list[Fruit]: return [ f for f in fruits.values() @@ -256,7 +258,7 @@ def fruits_custom_resolver_lazy( self, info: strawberry.Info, name_endswith: Optional[str] = None, - ) -> List[Annotated["Fruit", strawberry.lazy("tests.relay.schema")]]: + ) -> list[Annotated["Fruit", strawberry.lazy("tests.relay.schema")]]: return [ f for f in fruits.values() @@ -328,7 +330,7 @@ def fruit_alike_connection_custom_resolver( self, info: strawberry.Info, name_endswith: Optional[str] = None, - ) -> List[FruitAlike]: + ) -> list[FruitAlike]: return [ FruitAlike(f.id, f.name, f.color) for f in fruits.values() @@ -336,7 +338,7 @@ def fruit_alike_connection_custom_resolver( ] @strawberry.relay.connection(strawberry.relay.ListConnection[Fruit]) - def some_fruits(self) -> List[Fruit]: + def some_fruits(self) -> list[Fruit]: return [Fruit(id=x, name="apple", color="green") for x in range(200)] diff --git a/tests/relay/schema_future_annotations.py b/tests/relay/schema_future_annotations.py index 464c197f0c..965737328c 100644 --- a/tests/relay/schema_future_annotations.py +++ b/tests/relay/schema_future_annotations.py @@ -1,20 +1,22 @@ from __future__ import annotations import dataclasses -from typing import ( - Any, +from collections.abc import ( AsyncGenerator, AsyncIterable, AsyncIterator, Generator, Iterable, Iterator, - List, +) +from typing import ( + Annotated, + Any, NamedTuple, Optional, cast, ) -from typing_extensions import Annotated, Self +from typing_extensions import Self import strawberry from strawberry import relay @@ -199,9 +201,9 @@ class Query: node_with_async_permissions: relay.Node = relay.node( permission_classes=[DummyPermission] ) - nodes: List[relay.Node] = relay.node() + nodes: list[relay.Node] = relay.node() node_optional: Optional[relay.Node] = relay.node() - nodes_optional: List[Optional[relay.Node]] = relay.node() + nodes_optional: list[Optional[relay.Node]] = relay.node() fruits: relay.ListConnection[Fruit] = relay.connection(resolver=fruits_resolver) fruits_lazy: relay.ListConnection[ Annotated[Fruit, strawberry.lazy("tests.relay.schema")] @@ -218,7 +220,7 @@ def fruits_concrete_resolver( self, info: strawberry.Info, name_endswith: Optional[str] = None, - ) -> List[Fruit]: + ) -> list[Fruit]: # This is mimicing integrations, like Django return [ cast( @@ -238,7 +240,7 @@ def fruits_custom_resolver( self, info: strawberry.Info, name_endswith: Optional[str] = None, - ) -> List[Fruit]: + ) -> list[Fruit]: return [ f for f in fruits.values() @@ -250,7 +252,7 @@ def fruits_custom_resolver_lazy( self, info: strawberry.Info, name_endswith: Optional[str] = None, - ) -> List[Annotated[Fruit, strawberry.lazy("tests.relay.schema")]]: + ) -> list[Annotated[Fruit, strawberry.lazy("tests.relay.schema")]]: return [ f for f in fruits.values() @@ -322,7 +324,7 @@ def fruit_alike_connection_custom_resolver( self, info: strawberry.Info, name_endswith: Optional[str] = None, - ) -> List[FruitAlike]: + ) -> list[FruitAlike]: return [ FruitAlike(f.id, f.name, f.color) for f in fruits.values() @@ -330,7 +332,7 @@ def fruit_alike_connection_custom_resolver( ] @strawberry.relay.connection(strawberry.relay.ListConnection[Fruit]) - def some_fruits(self) -> List[Fruit]: + def some_fruits(self) -> list[Fruit]: return [Fruit(id=x, name="apple", color="green") for x in range(200)] diff --git a/tests/relay/test_connection.py b/tests/relay/test_connection.py index f84d398895..f97bb0d669 100644 --- a/tests/relay/test_connection.py +++ b/tests/relay/test_connection.py @@ -1,5 +1,6 @@ import sys -from typing import Any, Iterable, List, Optional +from collections.abc import Iterable +from typing import Any, Optional from typing_extensions import Self import pytest @@ -16,8 +17,8 @@ class User(Node): @classmethod def resolve_nodes( - cls, *, info: strawberry.Info, node_ids: List[Any], required: bool - ) -> List[Self]: + cls, *, info: strawberry.Info, node_ids: list[Any], required: bool + ) -> list[Self]: return [cls() for _ in node_ids] @@ -48,7 +49,7 @@ def test_nullable_connection_with_optional(): @strawberry.type class Query: @strawberry.relay.connection(Optional[UserConnection]) - def users(self) -> Optional[List[User]]: + def users(self) -> Optional[list[User]]: return None schema = strawberry.Schema(query=Query) @@ -77,7 +78,7 @@ def test_nullable_connection_with_pipe(): @strawberry.type class Query: @strawberry.relay.connection(UserConnection | None) - def users(self) -> List[User] | None: + def users(self) -> list[User] | None: return None schema = strawberry.Schema(query=Query) @@ -104,7 +105,7 @@ class Query: @strawberry.relay.connection( Optional[UserConnection], permission_classes=[TestPermission] ) - def users(self) -> Optional[List[User]]: # pragma: no cover + def users(self) -> Optional[list[User]]: # pragma: no cover pytest.fail("Should not have been called...") schema = strawberry.Schema(query=Query) diff --git a/tests/relay/test_exceptions.py b/tests/relay/test_exceptions.py index 174c3748f7..345ae401bb 100644 --- a/tests/relay/test_exceptions.py +++ b/tests/relay/test_exceptions.py @@ -1,5 +1,3 @@ -from typing import List - import pytest import strawberry @@ -75,7 +73,7 @@ class Fruit(relay.Node): @strawberry.type class Query: @relay.connection(relay.ListConnection[Fruit]) - def fruits(self) -> List[Fruit]: ... + def fruits(self) -> list[Fruit]: ... # pragma: no cover strawberry.Schema(query=Query) @@ -93,7 +91,7 @@ class Fruit(relay.Node): @strawberry.type class Query: @relay.connection(relay.ListConnection[Fruit]) - def fruits(self) -> List[Fruit]: ... + def fruits(self) -> list[Fruit]: ... # pragma: no cover strawberry.Schema(query=Query) @@ -112,7 +110,7 @@ class Fruit(relay.Node): @strawberry.type class Query: - fruits_conn: List[Fruit] = relay.connection() + fruits_conn: list[Fruit] = relay.connection() strawberry.Schema(query=Query) @@ -131,8 +129,8 @@ class Fruit(relay.Node): @strawberry.type class Query: - @relay.connection(List[Fruit]) # type: ignore - def custom_resolver(self) -> List[Fruit]: ... + @relay.connection(list[Fruit]) # type: ignore + def custom_resolver(self) -> list[Fruit]: ... # pragma: no cover strawberry.Schema(query=Query) @@ -152,6 +150,6 @@ class Fruit(relay.Node): @strawberry.type class Query: @relay.connection(relay.Connection[Fruit]) # type: ignore - def custom_resolver(self): ... + def custom_resolver(self): ... # pragma: no cover strawberry.Schema(query=Query) diff --git a/tests/relay/test_fields.py b/tests/relay/test_fields.py index 9993697d17..2ef08c3713 100644 --- a/tests/relay/test_fields.py +++ b/tests/relay/test_fields.py @@ -1,5 +1,4 @@ import textwrap -from typing import List import pytest from pytest_mock import MockerFixture @@ -1460,7 +1459,7 @@ def test_parameters(mocker: MockerFixture): class CustomField(StrawberryField): @property - def arguments(self) -> List[StrawberryArgument]: + def arguments(self) -> list[StrawberryArgument]: return [ *super().arguments, StrawberryArgument( @@ -1472,7 +1471,7 @@ def arguments(self) -> List[StrawberryArgument]: ] @arguments.setter - def arguments(self, value: List[StrawberryArgument]): + def arguments(self, value: list[StrawberryArgument]): cls = self.__class__ return super(cls, cls).arguments.fset(self, value) @@ -1480,7 +1479,7 @@ def arguments(self, value: List[StrawberryArgument]): class Fruit(relay.Node): code: relay.NodeID[str] - def resolver(info: strawberry.Info) -> List[Fruit]: ... + def resolver(info: strawberry.Info) -> list[Fruit]: ... @strawberry.type class Query: diff --git a/tests/relay/test_schema.py b/tests/relay/test_schema.py index a1bca4cb8a..22c8897e08 100644 --- a/tests/relay/test_schema.py +++ b/tests/relay/test_schema.py @@ -1,6 +1,5 @@ import pathlib import textwrap -from typing import List from pytest_mock import MockerFixture from pytest_snapshot.plugin import Snapshot @@ -44,7 +43,7 @@ class Fruit(relay.Node): @strawberry.type class Query: @relay.connection(relay.ListConnection[Fruit]) - def fruits(self) -> List[Fruit]: + def fruits(self) -> list[Fruit]: return [Fruit(code=i) for i in range(10)] schema = strawberry.Schema(query=Query) @@ -154,7 +153,7 @@ class Fruit(BaseFruit): ... @strawberry.type class Query: @relay.connection(relay.ListConnection[Fruit]) - def fruits(self) -> List[Fruit]: + def fruits(self) -> list[Fruit]: return [Fruit(code=i) for i in range(10)] schema = strawberry.Schema(query=Query) @@ -265,7 +264,7 @@ class Fruit(BaseFruit): @strawberry.type class Query: @relay.connection(relay.ListConnection[Fruit]) - def fruits(self) -> List[Fruit]: + def fruits(self) -> list[Fruit]: return [Fruit(code=i, other_code=i) for i in range(10)] schema = strawberry.Schema(query=Query) diff --git a/tests/relay/test_types.py b/tests/relay/test_types.py index 1171d7896a..756f97822f 100644 --- a/tests/relay/test_types.py +++ b/tests/relay/test_types.py @@ -1,4 +1,5 @@ -from typing import Any, AsyncGenerator, AsyncIterable, Optional, Union, cast +from collections.abc import AsyncGenerator, AsyncIterable +from typing import Any, Optional, Union, cast from typing_extensions import assert_type from unittest.mock import MagicMock diff --git a/tests/schema/extensions/schema_extensions/conftest.py b/tests/schema/extensions/schema_extensions/conftest.py index 15023cea80..c859269d21 100644 --- a/tests/schema/extensions/schema_extensions/conftest.py +++ b/tests/schema/extensions/schema_extensions/conftest.py @@ -1,7 +1,8 @@ import contextlib import dataclasses import enum -from typing import Any, AsyncGenerator, List, Type +from collections.abc import AsyncGenerator +from typing import Any import pytest @@ -35,7 +36,7 @@ def __init_subclass__(cls, **kwargs: Any): "on_operation Exited", "get_results", ] - called_hooks: List[str] + called_hooks: list[str] @classmethod def assert_expected(cls) -> None: @@ -85,7 +86,7 @@ def exec_type(request: pytest.FixtureRequest) -> ExecType: @contextlib.contextmanager -def hook_wrap(list_: List[str], hook_name: str): +def hook_wrap(list_: list[str], hook_name: str): list_.append(f"{hook_name} Entered") try: yield @@ -94,7 +95,7 @@ def hook_wrap(list_: List[str], hook_name: str): @pytest.fixture() -def async_extension() -> Type[ExampleExtension]: +def async_extension() -> type[ExampleExtension]: class MyExtension(ExampleExtension): async def on_operation(self): with hook_wrap(self.called_hooks, SchemaExtension.on_operation.__name__): diff --git a/tests/schema/extensions/schema_extensions/test_extensions.py b/tests/schema/extensions/schema_extensions/test_extensions.py index 648fa016fc..f569e13fbd 100644 --- a/tests/schema/extensions/schema_extensions/test_extensions.py +++ b/tests/schema/extensions/schema_extensions/test_extensions.py @@ -1,7 +1,7 @@ import contextlib import json import warnings -from typing import Any, List, Optional, Type +from typing import Any, Optional from unittest.mock import patch import pytest @@ -206,7 +206,7 @@ def on_operation(self): @pytest.fixture() -def sync_extension() -> Type[ExampleExtension]: +def sync_extension() -> type[ExampleExtension]: class MyExtension(ExampleExtension): def on_operation(self): with hook_wrap(self.called_hooks, SchemaExtension.on_operation.__name__): @@ -237,7 +237,7 @@ def resolve(self, _next, root, info, *args: str, **kwargs: Any): @pytest.mark.asyncio async def test_async_extension_hooks( - default_query_types_and_query: SchemaHelper, async_extension: Type[ExampleExtension] + default_query_types_and_query: SchemaHelper, async_extension: type[ExampleExtension] ): schema = strawberry.Schema( query=default_query_types_and_query.query_type, extensions=[async_extension] @@ -817,7 +817,7 @@ def ping(self) -> str: def test_extension_execution_order_sync(): """Ensure mixed hooks (async & sync) are called correctly.""" - execution_order: List[Type[SchemaExtension]] = [] + execution_order: list[type[SchemaExtension]] = [] class ExtensionB(SchemaExtension): def on_execute(self): diff --git a/tests/schema/extensions/schema_extensions/test_subscription.py b/tests/schema/extensions/schema_extensions/test_subscription.py index bd20dc6c83..8ce3b6ba84 100644 --- a/tests/schema/extensions/schema_extensions/test_subscription.py +++ b/tests/schema/extensions/schema_extensions/test_subscription.py @@ -1,4 +1,4 @@ -from typing import AsyncGenerator, Type +from collections.abc import AsyncGenerator import pytest @@ -20,7 +20,7 @@ def assert_agen(obj) -> AsyncGenerator[ExecutionResult, None]: async def test_subscription_success_many_fields( - default_query_types_and_query: SchemaHelper, async_extension: Type[ExampleExtension] + default_query_types_and_query: SchemaHelper, async_extension: type[ExampleExtension] ) -> None: schema = strawberry.Schema( query=default_query_types_and_query.query_type, @@ -54,7 +54,7 @@ async def test_subscription_success_many_fields( async def test_subscription_extension_handles_immediate_errors( - default_query_types_and_query: SchemaHelper, async_extension: Type[ExampleExtension] + default_query_types_and_query: SchemaHelper, async_extension: type[ExampleExtension] ) -> None: @strawberry.type() class Subscription: @@ -86,7 +86,7 @@ async def count(self) -> AsyncGenerator[int, None]: async def test_error_after_first_yield_in_subscription( - default_query_types_and_query: SchemaHelper, async_extension: Type[ExampleExtension] + default_query_types_and_query: SchemaHelper, async_extension: type[ExampleExtension] ) -> None: @strawberry.type() class Subscription: diff --git a/tests/schema/extensions/test_datadog.py b/tests/schema/extensions/test_datadog.py index 0af13877a7..7f32751eaf 100644 --- a/tests/schema/extensions/test_datadog.py +++ b/tests/schema/extensions/test_datadog.py @@ -1,5 +1,6 @@ import typing -from typing import Any, AsyncGenerator, Tuple, Type +from collections.abc import AsyncGenerator +from typing import Any import pytest @@ -10,7 +11,7 @@ @pytest.fixture -def datadog_extension(mocker) -> Tuple[Type["DatadogTracingExtension"], Any]: +def datadog_extension(mocker) -> tuple[type["DatadogTracingExtension"], Any]: datadog_mock = mocker.MagicMock() mocker.patch.dict("sys.modules", ddtrace=datadog_mock) @@ -21,7 +22,7 @@ def datadog_extension(mocker) -> Tuple[Type["DatadogTracingExtension"], Any]: @pytest.fixture -def datadog_extension_sync(mocker) -> Tuple[Type["DatadogTracingExtension"], Any]: +def datadog_extension_sync(mocker) -> tuple[type["DatadogTracingExtension"], Any]: datadog_mock = mocker.MagicMock() mocker.patch.dict("sys.modules", ddtrace=datadog_mock) diff --git a/tests/schema/extensions/test_field_extensions.py b/tests/schema/extensions/test_field_extensions.py index a59eaa9938..c0e8bdb613 100644 --- a/tests/schema/extensions/test_field_extensions.py +++ b/tests/schema/extensions/test_field_extensions.py @@ -1,6 +1,5 @@ import re -from typing import Any, Callable, Optional -from typing_extensions import Annotated +from typing import Annotated, Any, Callable, Optional import pytest diff --git a/tests/schema/extensions/test_input_mutation.py b/tests/schema/extensions/test_input_mutation.py index e63d070029..a86d8a109f 100644 --- a/tests/schema/extensions/test_input_mutation.py +++ b/tests/schema/extensions/test_input_mutation.py @@ -1,5 +1,5 @@ import textwrap -from typing_extensions import Annotated +from typing import Annotated import strawberry from strawberry.field_extensions import InputMutationExtension diff --git a/tests/schema/extensions/test_input_mutation_federation.py b/tests/schema/extensions/test_input_mutation_federation.py index 3aaa84fb98..3c6b8518d9 100644 --- a/tests/schema/extensions/test_input_mutation_federation.py +++ b/tests/schema/extensions/test_input_mutation_federation.py @@ -1,5 +1,5 @@ import textwrap -from typing_extensions import Annotated +from typing import Annotated import strawberry from strawberry.field_extensions import InputMutationExtension diff --git a/tests/schema/extensions/test_query_depth_limiter.py b/tests/schema/extensions/test_query_depth_limiter.py index 22ed5be6c0..89c1cec115 100644 --- a/tests/schema/extensions/test_query_depth_limiter.py +++ b/tests/schema/extensions/test_query_depth_limiter.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union import pytest from graphql import ( @@ -47,7 +47,7 @@ class Human: name: str email: str address: Address - pets: List[Pet] + pets: list[Pet] @strawberry.input @@ -69,7 +69,7 @@ def user( pass @strawberry.field - def users(self, names: Optional[List[str]]) -> List[Human]: + def users(self, names: Optional[list[str]]) -> list[Human]: pass @strawberry.field @@ -87,7 +87,7 @@ def cat(bio: Biography) -> Cat: def run_query( query: str, max_depth: int, should_ignore: ShouldIgnoreType = None -) -> Tuple[List[GraphQLError], Union[Dict[str, int], None]]: +) -> tuple[list[GraphQLError], Union[dict[str, int], None]]: document = parse(query) result = None diff --git a/tests/schema/test_annotated/type_a.py b/tests/schema/test_annotated/type_a.py index a6bb8146c0..a57c145fe7 100644 --- a/tests/schema/test_annotated/type_a.py +++ b/tests/schema/test_annotated/type_a.py @@ -1,7 +1,6 @@ from __future__ import annotations -from typing import Optional -from typing_extensions import Annotated +from typing import Annotated, Optional from uuid import UUID import strawberry diff --git a/tests/schema/test_annotated/type_b.py b/tests/schema/test_annotated/type_b.py index dfd9186876..4ecbb6efa7 100644 --- a/tests/schema/test_annotated/type_b.py +++ b/tests/schema/test_annotated/type_b.py @@ -1,7 +1,6 @@ from __future__ import annotations -from typing import Optional -from typing_extensions import Annotated +from typing import Annotated, Optional from uuid import UUID import strawberry diff --git a/tests/schema/test_arguments.py b/tests/schema/test_arguments.py index 6020fd5a68..478c3723de 100644 --- a/tests/schema/test_arguments.py +++ b/tests/schema/test_arguments.py @@ -1,7 +1,6 @@ import textwrap from textwrap import dedent -from typing import Optional -from typing_extensions import Annotated +from typing import Annotated, Optional import strawberry from strawberry.types.unset import UNSET diff --git a/tests/schema/test_dataloaders.py b/tests/schema/test_dataloaders.py index 755460df69..099217d98b 100644 --- a/tests/schema/test_dataloaders.py +++ b/tests/schema/test_dataloaders.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import List import pytest @@ -13,7 +12,7 @@ async def test_can_use_dataloaders(mocker): class User: id: str - async def idx(keys) -> List[User]: + async def idx(keys) -> list[User]: return [User(key) for key in keys] mock_loader = mocker.Mock(side_effect=idx) diff --git a/tests/schema/test_directives.py b/tests/schema/test_directives.py index adfcd29f6b..b67391b28b 100644 --- a/tests/schema/test_directives.py +++ b/tests/schema/test_directives.py @@ -1,6 +1,6 @@ import textwrap from enum import Enum -from typing import Any, Dict, List, NoReturn, Optional +from typing import Any, NoReturn, Optional import pytest @@ -305,7 +305,7 @@ def person(self) -> Person: return Person() @strawberry.directive(locations=[DirectiveLocation.FIELD]) - def replace(value: DirectiveValue[str], old_list: List[str], new: str): + def replace(value: DirectiveValue[str], old_list: list[str], new: str): for old in old_list: value = value.replace(old, new) @@ -412,7 +412,7 @@ class Locale(Enum): EN: str = "EN" NL: str = "NL" - greetings: Dict[Locale, str] = { + greetings: dict[Locale, str] = { Locale.EN: "Hello {username}", Locale.NL: "Hallo {username}", } @@ -594,7 +594,7 @@ def greeting(self) -> str: return "Hi" @strawberry.directive(locations=[DirectiveLocation.FIELD]) - def append_names(value: DirectiveValue[str], names: List[str]): + def append_names(value: DirectiveValue[str], names: list[str]): assert isinstance(names, list) return f"{value} {', '.join(names)}" @@ -695,7 +695,7 @@ class CustomInfo(Info): test: str = "foo" @strawberry.directive(locations=[DirectiveLocation.FIELD]) - def append_names(value: DirectiveValue[str], names: List[str], info: CustomInfo): + def append_names(value: DirectiveValue[str], names: list[str], info: CustomInfo): assert isinstance(names, list) assert isinstance(info, CustomInfo) assert Info in type(info).__bases__ # Explicitly check it's not Info. diff --git a/tests/schema/test_enum.py b/tests/schema/test_enum.py index 4c6d566e3c..7adda83ea2 100644 --- a/tests/schema/test_enum.py +++ b/tests/schema/test_enum.py @@ -1,8 +1,7 @@ import typing from enum import Enum from textwrap import dedent -from typing import List, Optional -from typing_extensions import Annotated +from typing import Annotated, Optional import pytest @@ -164,7 +163,7 @@ class IceCreamFlavour(Enum): @strawberry.type class Query: @strawberry.field - def best_flavours(self) -> List[IceCreamFlavour]: + def best_flavours(self) -> list[IceCreamFlavour]: return [IceCreamFlavour.STRAWBERRY, IceCreamFlavour.PISTACHIO] schema = strawberry.Schema(query=Query) @@ -188,7 +187,7 @@ class IceCreamFlavour(Enum): @strawberry.type class Query: @strawberry.field - def best_flavours(self) -> Optional[List[IceCreamFlavour]]: + def best_flavours(self) -> Optional[list[IceCreamFlavour]]: return None schema = strawberry.Schema(query=Query) @@ -237,7 +236,7 @@ class IceCreamFlavour(Enum): @strawberry.type class Query: @strawberry.field - async def best_flavours(self) -> List[IceCreamFlavour]: + async def best_flavours(self) -> list[IceCreamFlavour]: return [IceCreamFlavour.STRAWBERRY, IceCreamFlavour.PISTACHIO] schema = strawberry.Schema(query=Query) diff --git a/tests/schema/test_extensions.py b/tests/schema/test_extensions.py index 6f54719669..542f8c1ef9 100644 --- a/tests/schema/test_extensions.py +++ b/tests/schema/test_extensions.py @@ -1,6 +1,5 @@ from enum import Enum, auto -from typing import Union, cast -from typing_extensions import Annotated +from typing import Annotated, Union, cast from graphql import ( DirectiveLocation, @@ -53,7 +52,7 @@ class Query: def test_directive(): @strawberry.directive(locations=[DirectiveLocation.FIELD]) - def uppercase(value: DirectiveValue[str], foo: str): + def uppercase(value: DirectiveValue[str], foo: str): # pragma: no cover return value.upper() @strawberry.type() diff --git a/tests/schema/test_generics.py b/tests/schema/test_generics.py index 17e582afd7..b442d2e5cd 100644 --- a/tests/schema/test_generics.py +++ b/tests/schema/test_generics.py @@ -1,6 +1,6 @@ import textwrap from enum import Enum -from typing import Any, Generic, List, Optional, TypeVar, Union +from typing import Any, Generic, Optional, TypeVar, Union from typing_extensions import Self import strawberry @@ -166,7 +166,7 @@ class Fruit: @strawberry.type class Edge(Generic[T]): cursor: strawberry.ID - nodes: List[T] + nodes: list[T] @strawberry.type class FruitEdge(Edge[Fruit]): ... @@ -325,7 +325,7 @@ class User: @strawberry.type class Edge(Generic[T]): - nodes: List[T] + nodes: list[T] @strawberry.type class Query: @@ -359,7 +359,7 @@ class User: @strawberry.type class Edge(Generic[T]): - nodes: List[Optional[T]] + nodes: list[Optional[T]] @strawberry.type class Query: @@ -397,7 +397,7 @@ class Edge(Generic[T]): @strawberry.type class Connection(Generic[T]): - edges: List[Edge[T]] + edges: list[Edge[T]] @strawberry.type class ConnectionWithMeta(Connection[T]): @@ -608,7 +608,7 @@ class Edge(Generic[T]): @strawberry.type class Query: @strawberry.field - def example(self) -> List[Union[Edge[int], Edge[str]]]: + def example(self) -> list[Union[Edge[int], Edge[str]]]: return [ Edge(cursor=strawberry.ID("1"), node=1), Edge(cursor=strawberry.ID("2"), node="string"), @@ -673,7 +673,7 @@ class Edge(Generic[T]): @strawberry.type class Connection(Generic[T]): - edges: List[Edge[T]] + edges: list[Edge[T]] @strawberry.type class Entity1: @@ -772,7 +772,7 @@ class User: @strawberry.type class Edge(Generic[T]): - nodes: List[T] + nodes: list[T] @strawberry.type class Query: @@ -809,7 +809,7 @@ class User: @strawberry.type class Edge(Generic[T]): - nodes: List[T] + nodes: list[T] @strawberry.type class Query: @@ -846,7 +846,7 @@ class User: @strawberry.type class Edge(Generic[T]): - nodes: List[T] + nodes: list[T] @strawberry.type class Query: @@ -882,7 +882,7 @@ def test_generic_with_arguments(): @strawberry.type class Collection(Generic[T]): @strawberry.field - def by_id(self, ids: List[int]) -> List[T]: + def by_id(self, ids: list[int]) -> list[T]: return [] @strawberry.type @@ -922,7 +922,7 @@ def edge(self, arg: T) -> bool: return bool(arg) @strawberry.field - def edges(self, args: List[T]) -> int: + def edges(self, args: list[T]) -> int: return len(args) @strawberry.type @@ -995,7 +995,7 @@ class Book(Node[str]): @strawberry.type class Query: @strawberry.field - def books(self) -> List[Book]: + def books(self) -> list[Book]: return list() schema = strawberry.Schema(query=Query) @@ -1022,7 +1022,7 @@ def test_self(): @strawberry.interface class INode: field: Optional[Self] - fields: List[Self] + fields: list[Self] @strawberry.type class Node(INode): ... @@ -1180,7 +1180,7 @@ def test_generic_with_interface(): @strawberry.type class Pagination(Generic[T]): - items: List[T] + items: list[T] @strawberry.interface class TestInterface: diff --git a/tests/schema/test_generics_nested.py b/tests/schema/test_generics_nested.py index a4de77b29f..e541c136d8 100644 --- a/tests/schema/test_generics_nested.py +++ b/tests/schema/test_generics_nested.py @@ -1,5 +1,5 @@ import textwrap -from typing import Generic, List, Optional, TypeVar, Union +from typing import Generic, Optional, TypeVar, Union import strawberry from strawberry.scalars import JSON @@ -59,19 +59,19 @@ class JsonBlock: data: JSON @strawberry.type - class BlockRowType(Generic[T]): + class BlockRowtype(Generic[T]): total: int - items: List[T] + items: list[T] @strawberry.type class Query: @strawberry.field def blocks( self, - ) -> List[Union[BlockRowType[int], BlockRowType[str], JsonBlock]]: + ) -> list[Union[BlockRowtype[int], BlockRowtype[str], JsonBlock]]: return [ - BlockRowType(total=3, items=["a", "b", "c"]), - BlockRowType(total=1, items=[1, 2, 3, 4]), + BlockRowtype(total=3, items=["a", "b", "c"]), + BlockRowtype(total=1, items=[1, 2, 3, 4]), JsonBlock(data=JSON({"a": 1})), ] @@ -81,10 +81,10 @@ def blocks( """query { blocks { __typename - ... on IntBlockRowType { + ... on IntBlockRowtype { a: items } - ... on StrBlockRowType { + ... on StrBlockRowtype { b: items } ... on JsonBlock { @@ -98,8 +98,8 @@ def blocks( assert result.data == { "blocks": [ - {"__typename": "StrBlockRowType", "b": ["a", "b", "c"]}, - {"__typename": "IntBlockRowType", "a": [1, 2, 3, 4]}, + {"__typename": "StrBlockRowtype", "b": ["a", "b", "c"]}, + {"__typename": "IntBlockRowtype", "a": [1, 2, 3, 4]}, {"__typename": "JsonBlock", "data": {"a": 1}}, ] } @@ -113,19 +113,19 @@ class JsonBlock: data: JSON @strawberry.type - class BlockRowType(Generic[T]): + class BlockRowtype(Generic[T]): total: int - items: List[T] + items: list[T] @strawberry.type class Query: @strawberry.field def blocks( self, - ) -> List[Union[BlockRowType[int], BlockRowType[str], JsonBlock]]: + ) -> list[Union[BlockRowtype[int], BlockRowtype[str], JsonBlock]]: return [ - BlockRowType(total=3, items=[]), - BlockRowType(total=1, items=[]), + BlockRowtype(total=3, items=[]), + BlockRowtype(total=1, items=[]), JsonBlock(data=JSON({"a": 1})), ] @@ -135,10 +135,10 @@ def blocks( """query { blocks { __typename - ... on IntBlockRowType { + ... on IntBlockRowtype { a: items } - ... on StrBlockRowType { + ... on StrBlockRowtype { b: items } ... on JsonBlock { @@ -152,8 +152,8 @@ def blocks( assert result.data == { "blocks": [ - {"__typename": "IntBlockRowType", "a": []}, - {"__typename": "IntBlockRowType", "a": []}, + {"__typename": "IntBlockRowtype", "a": []}, + {"__typename": "IntBlockRowtype", "a": []}, {"__typename": "JsonBlock", "data": {"a": 1}}, ] } @@ -167,19 +167,19 @@ class JsonBlock: data: JSON @strawberry.type - class BlockRowType(Generic[T]): + class BlockRowtype(Generic[T]): total: int - items: List[List[T]] + items: list[list[T]] @strawberry.type class Query: @strawberry.field def blocks( self, - ) -> List[Union[BlockRowType[int], BlockRowType[str], JsonBlock]]: + ) -> list[Union[BlockRowtype[int], BlockRowtype[str], JsonBlock]]: return [ - BlockRowType(total=3, items=[["a", "b", "c"]]), - BlockRowType(total=1, items=[[1, 2, 3, 4]]), + BlockRowtype(total=3, items=[["a", "b", "c"]]), + BlockRowtype(total=1, items=[[1, 2, 3, 4]]), JsonBlock(data=JSON({"a": 1})), ] @@ -189,10 +189,10 @@ def blocks( """query { blocks { __typename - ... on IntBlockRowType { + ... on IntBlockRowtype { a: items } - ... on StrBlockRowType { + ... on StrBlockRowtype { b: items } ... on JsonBlock { @@ -206,8 +206,8 @@ def blocks( assert result.data == { "blocks": [ - {"__typename": "StrBlockRowType", "b": [["a", "b", "c"]]}, - {"__typename": "IntBlockRowType", "a": [[1, 2, 3, 4]]}, + {"__typename": "StrBlockRowtype", "b": [["a", "b", "c"]]}, + {"__typename": "IntBlockRowtype", "a": [[1, 2, 3, 4]]}, {"__typename": "JsonBlock", "data": {"a": 1}}, ] } @@ -226,22 +226,22 @@ class JsonBlock(BlockInterface): data: JSON @strawberry.type - class BlockRowType(BlockInterface, Generic[T]): + class BlockRowtype(BlockInterface, Generic[T]): total: int - items: List[T] + items: list[T] @strawberry.type class Query: @strawberry.field - def blocks(self) -> List[BlockInterface]: + def blocks(self) -> list[BlockInterface]: return [ - BlockRowType(id=strawberry.ID("3"), total=3, items=["a", "b", "c"]), - BlockRowType(id=strawberry.ID("1"), total=1, items=[1, 2, 3, 4]), + BlockRowtype(id=strawberry.ID("3"), total=3, items=["a", "b", "c"]), + BlockRowtype(id=strawberry.ID("1"), total=1, items=[1, 2, 3, 4]), JsonBlock(id=strawberry.ID("2"), data=JSON({"a": 1})), ] schema = strawberry.Schema( - query=Query, types=[BlockRowType[int], JsonBlock, BlockRowType[str]] + query=Query, types=[BlockRowtype[int], JsonBlock, BlockRowtype[str]] ) expected_schema = textwrap.dedent( @@ -251,7 +251,7 @@ def blocks(self) -> List[BlockInterface]: disclaimer: String } - type IntBlockRowType implements BlockInterface { + type IntBlockRowtype implements BlockInterface { id: ID! disclaimer: String total: Int! @@ -273,7 +273,7 @@ def blocks(self) -> List[BlockInterface]: blocks: [BlockInterface!]! } - type StrBlockRowType implements BlockInterface { + type StrBlockRowtype implements BlockInterface { id: ID! disclaimer: String total: Int! @@ -289,10 +289,10 @@ def blocks(self) -> List[BlockInterface]: blocks { id __typename - ... on IntBlockRowType { + ... on IntBlockRowtype { a: items } - ... on StrBlockRowType { + ... on StrBlockRowtype { b: items } ... on JsonBlock { @@ -306,8 +306,8 @@ def blocks(self) -> List[BlockInterface]: assert result.data == { "blocks": [ - {"id": "3", "__typename": "StrBlockRowType", "b": ["a", "b", "c"]}, - {"id": "1", "__typename": "IntBlockRowType", "a": [1, 2, 3, 4]}, + {"id": "3", "__typename": "StrBlockRowtype", "b": ["a", "b", "c"]}, + {"id": "1", "__typename": "IntBlockRowtype", "a": [1, 2, 3, 4]}, {"id": "2", "__typename": "JsonBlock", "data": {"a": 1}}, ] } diff --git a/tests/schema/test_info.py b/tests/schema/test_info.py index 70fddea56d..d4cdd18955 100644 --- a/tests/schema/test_info.py +++ b/tests/schema/test_info.py @@ -1,7 +1,6 @@ import dataclasses import json -from typing import List, Optional -from typing_extensions import Annotated +from typing import Annotated, Optional import pytest @@ -299,8 +298,8 @@ def hello( ("return_type", "return_value"), [ (str, "text"), - (List[str], ["text"]), - (Optional[List[int]], None), + (list[str], ["text"]), + (Optional[list[int]], None), ], ) def test_return_type_from_resolver(return_type, return_value): diff --git a/tests/schema/test_interface.py b/tests/schema/test_interface.py index 31f106b09e..bfa7b0edd4 100644 --- a/tests/schema/test_interface.py +++ b/tests/schema/test_interface.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, List +from typing import Any import pytest from pytest_mock import MockerFixture @@ -24,7 +24,7 @@ class Italian(Cheese): @strawberry.type class Root: @strawberry.field - def assortment(self) -> List[Cheese]: + def assortment(self) -> list[Cheese]: return [ Italian(name="Asiago", province="Friuli"), Swiss(name="Tomme", canton="Vaud"), @@ -366,7 +366,7 @@ class Person(NamedNode): @strawberry.type class Query: @strawberry.field - def friends(self) -> List[NamedNode]: + def friends(self) -> list[NamedNode]: return [Person(id=1, name="foo"), Person(id=2, name="bar")] schema = strawberry.Schema(Query, types=[Person]) diff --git a/tests/schema/test_lazy/test_lazy_generic.py b/tests/schema/test_lazy/test_lazy_generic.py index e765e1ce63..1f1f62091e 100644 --- a/tests/schema/test_lazy/test_lazy_generic.py +++ b/tests/schema/test_lazy/test_lazy_generic.py @@ -3,9 +3,9 @@ import sys import sysconfig import textwrap +from collections.abc import Sequence from pathlib import Path -from typing import TYPE_CHECKING, Generic, List, Optional, Sequence, TypeVar -from typing_extensions import Annotated +from typing import TYPE_CHECKING, Annotated, Generic, Optional, TypeVar import pytest @@ -142,8 +142,8 @@ def test_lazy_types_declared_within_optional(): @strawberry.type class Query: - normal_edges: List[Edge[Optional[TypeC]]] - lazy_edges: List[ + normal_edges: list[Edge[Optional[TypeC]]] + lazy_edges: list[ Edge[ Optional[ Annotated["TypeC", strawberry.lazy("tests.schema.test_lazy.type_c")] diff --git a/tests/schema/test_lazy/type_a.py b/tests/schema/test_lazy/type_a.py index 1e5727b2f8..4042fca1dd 100644 --- a/tests/schema/test_lazy/type_a.py +++ b/tests/schema/test_lazy/type_a.py @@ -1,5 +1,4 @@ -from typing import TYPE_CHECKING, List, Optional -from typing_extensions import Annotated +from typing import TYPE_CHECKING, Annotated, Optional import strawberry @@ -10,7 +9,7 @@ @strawberry.type class TypeA: list_of_b: Optional[ - List[Annotated["TypeB", strawberry.lazy("tests.schema.test_lazy.type_b")]] + list[Annotated["TypeB", strawberry.lazy("tests.schema.test_lazy.type_b")]] ] = None @strawberry.field diff --git a/tests/schema/test_lazy/type_b.py b/tests/schema/test_lazy/type_b.py index 12c4394d2c..0760514b9e 100644 --- a/tests/schema/test_lazy/type_b.py +++ b/tests/schema/test_lazy/type_b.py @@ -1,5 +1,4 @@ -from typing import TYPE_CHECKING, List -from typing_extensions import Annotated +from typing import TYPE_CHECKING, Annotated import strawberry @@ -7,14 +6,14 @@ from .type_a import TypeA from .type_c import TypeC - ListTypeA = List[TypeA] - ListTypeC = List[TypeC] + ListTypeA = list[TypeA] + ListTypeC = list[TypeC] else: TypeA = Annotated["TypeA", strawberry.lazy("tests.schema.test_lazy.type_a")] - ListTypeA = List[ + ListTypeA = list[ Annotated["TypeA", strawberry.lazy("tests.schema.test_lazy.type_a")] ] - ListTypeC = List[ + ListTypeC = list[ Annotated["TypeC", strawberry.lazy("tests.schema.test_lazy.type_c")] ] diff --git a/tests/schema/test_lazy/type_c.py b/tests/schema/test_lazy/type_c.py index 4bc9b63df4..4858e179f1 100644 --- a/tests/schema/test_lazy/type_c.py +++ b/tests/schema/test_lazy/type_c.py @@ -1,6 +1,5 @@ import sys -from typing import Generic, TypeVar -from typing_extensions import Annotated +from typing import Annotated, Generic, TypeVar import strawberry diff --git a/tests/schema/test_lazy/type_d.py b/tests/schema/test_lazy/type_d.py index d1f1237706..967bc0bf27 100644 --- a/tests/schema/test_lazy/type_d.py +++ b/tests/schema/test_lazy/type_d.py @@ -1,6 +1,5 @@ import sys -from typing import Generic, TypeVar -from typing_extensions import Annotated +from typing import Annotated, Generic, TypeVar import strawberry diff --git a/tests/schema/test_lazy_types/type_a.py b/tests/schema/test_lazy_types/type_a.py index 45dbf1a9e2..dac9ba9c1e 100644 --- a/tests/schema/test_lazy_types/type_a.py +++ b/tests/schema/test_lazy_types/type_a.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional import strawberry @@ -11,7 +11,7 @@ @strawberry.type class TypeA: list_of_b: Optional[ - List[strawberry.LazyType["TypeB", "tests.schema.test_lazy_types.type_b"]] + list[strawberry.LazyType["TypeB", "tests.schema.test_lazy_types.type_b"]] ] = None @strawberry.field diff --git a/tests/schema/test_list.py b/tests/schema/test_list.py index 879518cd37..4119841d55 100644 --- a/tests/schema/test_list.py +++ b/tests/schema/test_list.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional import strawberry @@ -7,7 +7,7 @@ def test_basic_list(): @strawberry.type class Query: @strawberry.field - def example(self) -> List[str]: + def example(self) -> list[str]: return ["Example"] schema = strawberry.Schema(query=Query) @@ -24,7 +24,7 @@ def test_of_optional(): @strawberry.type class Query: @strawberry.field - def example(self) -> List[Optional[str]]: + def example(self) -> list[Optional[str]]: return ["Example", None] schema = strawberry.Schema(query=Query) @@ -38,12 +38,12 @@ def example(self) -> List[Optional[str]]: def test_lists_of_lists(): - def get_polygons() -> List[List[float]]: + def get_polygons() -> list[list[float]]: return [[2.0, 6.0]] @strawberry.type class Query: - polygons: List[List[float]] = strawberry.field(resolver=get_polygons) + polygons: list[list[float]] = strawberry.field(resolver=get_polygons) schema = strawberry.Schema(query=Query) diff --git a/tests/schema/test_name_converter.py b/tests/schema/test_name_converter.py index a406fbd532..8927c5a655 100644 --- a/tests/schema/test_name_converter.py +++ b/tests/schema/test_name_converter.py @@ -1,6 +1,6 @@ import textwrap from enum import Enum -from typing import Generic, List, Optional, TypeVar, Union +from typing import Generic, Optional, TypeVar, Union import strawberry from strawberry.directive import StrawberryDirective @@ -35,7 +35,7 @@ def from_union(self, union: StrawberryUnion) -> str: def from_generic( self, generic_type: StrawberryObjectDefinition, - types: List[Union[StrawberryType, type]], + types: list[Union[StrawberryType, type]], ) -> str: return super().from_generic(generic_type, types) + self.suffix diff --git a/tests/schema/test_permission.py b/tests/schema/test_permission.py index f041dac6b6..bd687d4f21 100644 --- a/tests/schema/test_permission.py +++ b/tests/schema/test_permission.py @@ -1,7 +1,7 @@ import re import textwrap import typing -from typing import List, Optional +from typing import Optional import pytest @@ -441,7 +441,7 @@ class Query: @strawberry.field( extensions=[PermissionExtension([IsAuthorized()], fail_silently=True)] ) - def names(self) -> Optional[List[str]]: # pragma: no cover + def names(self) -> Optional[list[str]]: # pragma: no cover return ["ABC"] schema = strawberry.Schema(query=Query) @@ -464,7 +464,7 @@ class Query: @strawberry.field( extensions=[PermissionExtension([IsAuthorized()], fail_silently=True)] ) - def names(self) -> List[str]: # pragma: no cover + def names(self) -> list[str]: # pragma: no cover return ["ABC"] schema = strawberry.Schema(query=Query) diff --git a/tests/schema/test_resolvers.py b/tests/schema/test_resolvers.py index 8a81626320..4f0758f72f 100644 --- a/tests/schema/test_resolvers.py +++ b/tests/schema/test_resolvers.py @@ -1,7 +1,7 @@ # type: ignore import typing from contextlib import nullcontext -from typing import Any, Generic, List, NamedTuple, Optional, Type, TypeVar, Union +from typing import Any, Generic, NamedTuple, Optional, TypeVar, Union import pytest @@ -215,7 +215,7 @@ class User: age: int @classmethod - def get_users(cls) -> "List[User]": + def get_users(cls) -> "list[User]": return [cls(name="Bob", age=10), cls(name="Nancy", age=30)] @strawberry.type @@ -237,12 +237,12 @@ class Query: def test_staticmethod_resolvers(): class Alphabet: @staticmethod - def get_letters() -> List[str]: + def get_letters() -> list[str]: return ["a", "b", "c"] @strawberry.type class Query: - letters: List[str] = strawberry.field(resolver=Alphabet.get_letters) + letters: list[str] = strawberry.field(resolver=Alphabet.get_letters) schema = strawberry.Schema(query=Query) @@ -325,7 +325,7 @@ async def test_async_list_resolver(): @strawberry.type class Query: @strawberry.field - async def best_flavours(self) -> List[str]: + async def best_flavours(self) -> list[str]: return ["strawberry", "pistachio"] schema = strawberry.Schema(query=Query) @@ -362,7 +362,7 @@ class AType: T = TypeVar("T") - def resolver_factory(strawberry_type: Type[T]): + def resolver_factory(strawberry_type: type[T]): def resolver() -> T: return strawberry_type(some=1) @@ -477,12 +477,12 @@ def test_generic_resolver_list(): class AType: some: int - def resolver() -> List[T]: + def resolver() -> list[T]: return [AType(some=1)] @strawberry.type class Query: - list_type: List[AType] = strawberry.field(resolver) + list_type: list[AType] = strawberry.field(resolver) strawberry.Schema(query=Query) diff --git a/tests/schema/test_schema_generation.py b/tests/schema/test_schema_generation.py index 6826159b3d..fd9623a8b2 100644 --- a/tests/schema/test_schema_generation.py +++ b/tests/schema/test_schema_generation.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Optional import pytest from graphql import ExecutionContext as GraphQLExecutionContext @@ -66,8 +66,8 @@ def test_custom_execution_context(): class CustomExecutionContext(GraphQLExecutionContext): @staticmethod def build_response( - data: Optional[Dict[str, Any]], - errors: List[GraphQLError], + data: Optional[dict[str, Any]], + errors: list[GraphQLError], ) -> ExecutionResult: result = super( CustomExecutionContext, CustomExecutionContext diff --git a/tests/schema/test_schema_hooks.py b/tests/schema/test_schema_hooks.py index c24b352e6e..723d2543f3 100644 --- a/tests/schema/test_schema_hooks.py +++ b/tests/schema/test_schema_hooks.py @@ -1,5 +1,4 @@ import textwrap -from typing import List import strawberry from strawberry.types.base import StrawberryObjectDefinition @@ -22,7 +21,7 @@ def public_field_filter(field: StrawberryField) -> bool: class PublicSchema(strawberry.Schema): def get_fields( self, type_definition: StrawberryObjectDefinition - ) -> List[StrawberryField]: + ) -> list[StrawberryField]: fields = super().get_fields(type_definition) return list(filter(public_field_filter, fields)) diff --git a/tests/schema/test_subscription.py b/tests/schema/test_subscription.py index 14a2781042..4ca0f5d87d 100644 --- a/tests/schema/test_subscription.py +++ b/tests/schema/test_subscription.py @@ -2,16 +2,13 @@ from __future__ import annotations import inspect -import sys from collections import abc # noqa: F401 -from typing import ( # noqa: F401 +from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator # noqa: F401 +from typing import ( + Annotated, Any, - AsyncGenerator, - AsyncIterable, - AsyncIterator, Union, ) -from typing_extensions import Annotated import pytest @@ -98,21 +95,15 @@ async def example(self, name: str) -> AsyncGenerator[str, None]: assert result.data["example"] == "Hi Nina" -requires_builtin_generics = pytest.mark.skipif( - sys.version_info < (3, 9), - reason="built-in generic annotations were added in python 3.9", -) - - @pytest.mark.parametrize( "return_annotation", ( "AsyncGenerator[str, None]", "AsyncIterable[str]", "AsyncIterator[str]", - pytest.param("abc.AsyncIterator[str]", marks=requires_builtin_generics), - pytest.param("abc.AsyncGenerator[str, None]", marks=requires_builtin_generics), - pytest.param("abc.AsyncIterable[str]", marks=requires_builtin_generics), + "abc.AsyncIterator[str]", + "abc.AsyncGenerator[str, None]", + "abc.AsyncIterable[str]", ), ) @pytest.mark.asyncio diff --git a/tests/schema/test_union.py b/tests/schema/test_union.py index 42392b9584..edd219c25e 100644 --- a/tests/schema/test_union.py +++ b/tests/schema/test_union.py @@ -2,8 +2,7 @@ import textwrap from dataclasses import dataclass from textwrap import dedent -from typing import Generic, List, Optional, TypeVar, Union -from typing_extensions import Annotated +from typing import Annotated, Generic, Optional, TypeVar, Union import pytest @@ -543,7 +542,7 @@ def test_union_with_similar_nested_generic_types(): @strawberry.type class Container(Generic[T]): - items: List[T] + items: list[T] @strawberry.type class A: @@ -684,10 +683,6 @@ class Query: InvalidUnionTypeError, match=r"Type `list\[...\]` cannot be used in a GraphQL Union", ) -@pytest.mark.skipif( - sys.version_info < (3, 9, 0), - reason="list[str] is only available on python 3.9+", -) def test_raises_on_union_with_list_str(): global ICanBeInUnion @@ -708,10 +703,6 @@ class Query: InvalidUnionTypeError, match=r"Type `list\[...\]` cannot be used in a GraphQL Union", ) -@pytest.mark.skipif( - sys.version_info < (3, 9, 0), - reason="list[str] is only available on python 3.9+", -) def test_raises_on_union_with_list_str_38(): global ICanBeInUnion @@ -721,7 +712,7 @@ class ICanBeInUnion: @strawberry.type class Query: - union: Union[ICanBeInUnion, List[str]] + union: Union[ICanBeInUnion, list[str]] strawberry.Schema(query=Query) diff --git a/tests/test/conftest.py b/tests/test/conftest.py index 30ecd326bd..0fa1001dd3 100644 --- a/tests/test/conftest.py +++ b/tests/test/conftest.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import TYPE_CHECKING, AsyncGenerator +from typing import TYPE_CHECKING import pytest diff --git a/tests/test_auto.py b/tests/test_auto.py index ddeac975e2..02e0bcb423 100644 --- a/tests/test_auto.py +++ b/tests/test_auto.py @@ -1,5 +1,5 @@ -from typing import Any, cast -from typing_extensions import Annotated, get_args +from typing import Annotated, Any, cast +from typing_extensions import get_args import strawberry from strawberry.annotation import StrawberryAnnotation diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py index f00bc430f7..d56763a24c 100644 --- a/tests/test_dataloaders.py +++ b/tests/test_dataloaders.py @@ -1,6 +1,7 @@ import asyncio from asyncio.futures import Future -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union, cast +from collections.abc import Awaitable +from typing import Any, Callable, Optional, Union, cast import pytest from pytest_mock import MockerFixture @@ -8,10 +9,10 @@ from strawberry.dataloader import AbstractCache, DataLoader from strawberry.exceptions import WrongNumberOfResultsReturned -IDXType = Callable[[List[int]], Awaitable[List[int]]] +IDXType = Callable[[list[int]], Awaitable[list[int]]] -async def idx(keys: List[int]) -> List[int]: +async def idx(keys: list[int]) -> list[int]: return keys @@ -72,7 +73,7 @@ async def test_max_batch_size(mocker: MockerFixture): @pytest.mark.asyncio async def test_error(): - async def idx(keys: List[int]) -> List[Union[int, ValueError]]: + async def idx(keys: list[int]) -> list[Union[int, ValueError]]: return [ValueError()] loader = DataLoader(load_fn=idx) @@ -83,7 +84,7 @@ async def idx(keys: List[int]) -> List[Union[int, ValueError]]: @pytest.mark.asyncio async def test_error_and_values(): - async def idx(keys: List[int]) -> List[Union[int, ValueError]]: + async def idx(keys: list[int]) -> list[Union[int, ValueError]]: return [2] if keys == [2] else [ValueError()] loader = DataLoader(load_fn=idx) @@ -96,7 +97,7 @@ async def idx(keys: List[int]) -> List[Union[int, ValueError]]: @pytest.mark.asyncio async def test_when_raising_error_in_loader(): - async def idx(keys: List[int]) -> List[Union[int, ValueError]]: + async def idx(keys: list[int]) -> list[Union[int, ValueError]]: raise ValueError loader = DataLoader(load_fn=idx) @@ -114,7 +115,7 @@ async def idx(keys: List[int]) -> List[Union[int, ValueError]]: @pytest.mark.asyncio async def test_returning_wrong_number_of_results(): - async def idx(keys: List[int]) -> List[int]: + async def idx(keys: list[int]) -> list[int]: return [1, 2] loader = DataLoader(load_fn=idx) @@ -195,7 +196,7 @@ async def test_cache_disabled_immediate_await(mocker: MockerFixture): @pytest.mark.asyncio async def test_prime(): - async def idx(keys: List[Union[int, float]]) -> List[Union[int, float]]: + async def idx(keys: list[Union[int, float]]) -> list[Union[int, float]]: assert keys, "At least one key must be specified" return keys @@ -241,7 +242,7 @@ async def idx(keys: List[Union[int, float]]) -> List[Union[int, float]]: @pytest.mark.asyncio async def test_prime_nocache(): - async def idx(keys: List[Union[int, float]]) -> List[Union[int, float]]: + async def idx(keys: list[Union[int, float]]) -> list[Union[int, float]]: assert keys, "At least one key must be specified" return keys @@ -266,7 +267,7 @@ async def idx(keys: List[Union[int, float]]) -> List[Union[int, float]]: async def test_clear(): batch_num = 0 - async def idx(keys: List[int]) -> List[Tuple[int, int]]: + async def idx(keys: list[int]) -> list[tuple[int, int]]: """Maps key => (key, batch_num)""" nonlocal batch_num batch_num += 1 @@ -293,7 +294,7 @@ async def idx(keys: List[int]) -> List[Tuple[int, int]]: async def test_clear_nocache(): batch_num = 0 - async def idx(keys: List[int]) -> List[Tuple[int, int]]: + async def idx(keys: list[int]) -> list[tuple[int, int]]: """Maps key => (key, batch_num)""" nonlocal batch_num batch_num += 1 @@ -318,7 +319,7 @@ async def idx(keys: List[int]) -> List[Tuple[int, int]]: @pytest.mark.asyncio async def test_dont_dispatch_cancelled(): - async def idx(keys: List[int]) -> List[int]: + async def idx(keys: list[int]) -> list[int]: await asyncio.sleep(0.2) return keys @@ -363,7 +364,7 @@ async def idx(keys: List[int]) -> List[int]: async def test_cache_override(): class TestCache(AbstractCache[int, int]): def __init__(self): - self.cache: Dict[int, Future[int]] = {} + self.cache: dict[int, Future[int]] = {} def get(self, key: int) -> Optional["Future[int]"]: return self.cache.get(key) @@ -423,7 +424,7 @@ def clear(self) -> None: @pytest.mark.asyncio async def test_custom_cache_key_fn(): - def custom_cache_key(key: List[int]) -> str: + def custom_cache_key(key: list[int]) -> str: return ",".join(str(k) for k in key) loader = DataLoader(load_fn=idx, cache_key_fn=custom_cache_key) diff --git a/tests/test_forward_references.py b/tests/test_forward_references.py index ea771ec9dd..0e1831795c 100644 --- a/tests/test_forward_references.py +++ b/tests/test_forward_references.py @@ -1,12 +1,8 @@ # type: ignore from __future__ import annotations -import sys import textwrap -from typing import List -from typing_extensions import Annotated - -import pytest +from typing import Annotated import strawberry from strawberry.printer import print_schema @@ -53,15 +49,11 @@ class MyType: del MyType -@pytest.mark.skipif( - sys.version_info < (3, 9), - reason="Python 3.8 and previous can't properly resolve this.", -) def test_lazy_forward_reference(): @strawberry.type class Query: @strawberry.field - async def a(self) -> A: + async def a(self) -> A: # pragma: no cover return A(id=strawberry.ID("1")) expected_representation = """ @@ -89,10 +81,6 @@ async def a(self) -> A: assert print_schema(schema) == textwrap.dedent(expected_representation).strip() -@pytest.mark.skipif( - sys.version_info < (3, 9), - reason="Python 3.8 and previous can't properly resolve this.", -) def test_lazy_forward_reference_schema_with_a_list_only(): @strawberry.type class Query: @@ -126,12 +114,12 @@ def test_with_resolver(): class User: name: str - def get_users() -> List[User]: + def get_users() -> list[User]: # pragma: no cover return [] @strawberry.type class Query: - users: List[User] = strawberry.field(resolver=get_users) + users: list[User] = strawberry.field(resolver=get_users) definition = Query.__strawberry_definition__ assert definition.name == "Query" @@ -152,12 +140,12 @@ def test_union_or_notation(): class User: name: str - def get_users() -> List[User] | None: + def get_users() -> list[User] | None: # pragma: no cover return [] @strawberry.type class Query: - users: List[User] | None = strawberry.field(resolver=get_users) + users: list[User] | None = strawberry.field(resolver=get_users) definition = Query.__strawberry_definition__ assert definition.name == "Query" @@ -172,10 +160,6 @@ class Query: del User -@pytest.mark.skipif( - sys.version_info < (3, 9), - reason="generic type alias only available on python 3.9+", -) def test_union_or_notation_generic_type_alias(): global User @@ -183,7 +167,7 @@ def test_union_or_notation_generic_type_alias(): class User: name: str - def get_users() -> list[User] | None: + def get_users() -> list[User] | None: # pragma: no cover return [] @strawberry.type @@ -210,12 +194,12 @@ def test_annotated(): class User: name: str - def get_users() -> List[User]: + def get_users() -> list[User]: # pragma: no cover return [] @strawberry.type class Query: - users: Annotated[List[User], object()] = strawberry.field(resolver=get_users) + users: Annotated[list[User], object()] = strawberry.field(resolver=get_users) definition = Query.__strawberry_definition__ assert definition.name == "Query" @@ -236,12 +220,12 @@ def test_annotated_or_notation(): class User: name: str - def get_users() -> List[User] | None: + def get_users() -> list[User] | None: # pragma: no cover return [] @strawberry.type class Query: - users: Annotated[List[User] | None, object()] = strawberry.field( + users: Annotated[list[User] | None, object()] = strawberry.field( resolver=get_users ) @@ -258,10 +242,6 @@ class Query: del User -@pytest.mark.skipif( - sys.version_info < (3, 9), - reason="generic type alias only available on python 3.9+", -) def test_annotated_or_notation_generic_type_alias(): global User @@ -269,7 +249,7 @@ def test_annotated_or_notation_generic_type_alias(): class User: name: str - def get_users() -> list[User]: + def get_users() -> list[User]: # pragma: no cover return [] @strawberry.type diff --git a/tests/test_printer/test_basic.py b/tests/test_printer/test_basic.py index e97d269268..f4b14a5f6b 100644 --- a/tests/test_printer/test_basic.py +++ b/tests/test_printer/test_basic.py @@ -1,5 +1,5 @@ import textwrap -from typing import List, Optional +from typing import Optional from uuid import UUID import strawberry @@ -142,14 +142,14 @@ class MyInput: id_number: strawberry.ID = strawberry.ID(123) # type: ignore id_number_string: strawberry.ID = strawberry.ID("123") x: Optional[int] = UNSET - l: List[str] = strawberry.field(default_factory=list) # noqa: E741 - list_with_values: List[str] = strawberry.field( + l: list[str] = strawberry.field(default_factory=list) # noqa: E741 + list_with_values: list[str] = strawberry.field( default_factory=lambda: ["a", "b"] ) - list_from_generator: List[str] = strawberry.field( + list_from_generator: list[str] = strawberry.field( default_factory=lambda: (x for x in ["a", "b"]) ) - list_from_string: List[str] = "ab" # type: ignore - we do this for testing purposes + list_from_string: list[str] = "ab" # type: ignore - we do this for testing purposes @strawberry.type class Query: diff --git a/tests/test_printer/test_schema_directives.py b/tests/test_printer/test_schema_directives.py index ac0fb2c6a3..e7f82fea09 100644 --- a/tests/test_printer/test_schema_directives.py +++ b/tests/test_printer/test_schema_directives.py @@ -1,7 +1,6 @@ import textwrap from enum import Enum -from typing import Any, List, Optional, Union -from typing_extensions import Annotated +from typing import Annotated, Any, Optional, Union import strawberry from strawberry import BasePermission, Info @@ -71,12 +70,12 @@ class SensitiveValue: @strawberry.schema_directive(locations=[Location.OBJECT, Location.FIELD_DEFINITION]) class SensitiveData: reason: str - meta: Optional[List[SensitiveValue]] = UNSET + meta: Optional[list[SensitiveValue]] = UNSET @strawberry.schema_directive(locations=[Location.INPUT_OBJECT]) class SensitiveInput: reason: str - meta: Optional[List[SensitiveValue]] = UNSET + meta: Optional[list[SensitiveValue]] = UNSET @strawberry.schema_directive(locations=[Location.INPUT_FIELD_DEFINITION]) class RangeInput: @@ -228,7 +227,7 @@ class Query: def test_respects_schema_parameter_types_for_arguments_list_of_ints(): @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) class Sensitive: - real_age: List[int] + real_age: list[int] @strawberry.type class Query: @@ -252,7 +251,7 @@ class Query: def test_respects_schema_parameter_types_for_arguments_list_of_strings(): @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) class Sensitive: - real_age: List[str] + real_age: list[str] @strawberry.type class Query: diff --git a/tests/typecheckers/test_relay.py b/tests/typecheckers/test_relay.py index 90545def93..e958c95e99 100644 --- a/tests/typecheckers/test_relay.py +++ b/tests/typecheckers/test_relay.py @@ -15,7 +15,6 @@ Generator, Iterable, Iterator, - List, Optional, Union, ) @@ -62,16 +61,16 @@ class FruitAlike: ... -def fruits_resolver() -> List[Fruit]: +def fruits_resolver() -> list[Fruit]: ... @strawberry.type class Query: node: relay.Node - nodes: List[relay.Node] + nodes: list[relay.Node] node_optional: Optional[relay.Node] - nodes_optional: List[Optional[relay.Node]] + nodes_optional: list[Optional[relay.Node]] fruits: relay.Connection[Fruit] = strawberry.relay.connection( resolver=fruits_resolver, ) @@ -85,7 +84,7 @@ def fruits_custom_resolver( self, info: strawberry.Info, name_endswith: Optional[str] = None, - ) -> List[Fruit]: + ) -> list[Fruit]: ... @relay.connection(relay.Connection[Fruit]) @@ -161,198 +160,196 @@ def test(): Result( type="information", message='Type of "Query.node" is "Node"', - line=131, + line=130, column=13, ), Result( type="information", - message='Type of "Query.nodes" is "List[Node]"', - line=132, + message='Type of "Query.nodes" is "list[Node]"', + line=131, column=13, ), Result( type="information", message='Type of "Query.node_optional" is "Node | None"', - line=133, + line=132, column=13, ), Result( type="information", - message='Type of "Query.nodes_optional" is "List[Node | None]"', - line=134, + message='Type of "Query.nodes_optional" is "list[Node | None]"', + line=133, column=13, ), Result( type="information", message='Type of "Query.fruits" is "Connection[Fruit]"', - line=135, + line=134, column=13, ), Result( type="information", message='Type of "Query.fruits_conn" is "Connection[Fruit]"', - line=136, + line=135, column=13, ), Result( type="information", message='Type of "Query.fruits_custom_pagination" is "FruitCustomPaginationConnection"', - line=137, + line=136, column=13, ), Result( type="information", message='Type of "Query.fruits_custom_resolver" is "Any"', - line=138, + line=137, column=13, ), Result( type="information", message='Type of "Query.fruits_custom_resolver_iterator" is "Any"', - line=139, + line=138, column=13, ), Result( type="information", message='Type of "Query.fruits_custom_resolver_iterable" is "Any"', - line=140, + line=139, column=13, ), Result( type="information", message='Type of "Query.fruits_custom_resolver_generator" is "Any"', - line=141, + line=140, column=13, ), Result( type="information", message='Type of "Query.fruits_custom_resolver_async_iterator" is "Any"', - line=142, + line=141, column=13, ), Result( type="information", message='Type of "Query.fruits_custom_resolver_async_iterable" is "Any"', - line=143, + line=142, column=13, ), Result( type="information", message='Type of "Query.fruits_custom_resolver_async_generator" is "Any"', - line=144, + line=143, column=13, ), ] ) assert results.mypy == snapshot( [ - Result(type="error", message="Missing return statement", line=34, column=5), - Result(type="error", message="Missing return statement", line=57, column=1), + Result(type="error", message="Missing return statement", line=33, column=5), + Result(type="error", message="Missing return statement", line=56, column=1), Result( type="error", message='Untyped decorator makes function "fruits_custom_resolver" untyped', - line=75, + line=74, column=6, ), - Result(type="error", message="Missing return statement", line=76, column=5), + Result(type="error", message="Missing return statement", line=75, column=5), Result( type="error", message='Untyped decorator makes function "fruits_custom_resolver_iterator" untyped', - line=83, + line=82, column=6, ), - Result(type="error", message="Missing return statement", line=84, column=5), + Result(type="error", message="Missing return statement", line=83, column=5), Result( type="error", message='Untyped decorator makes function "fruits_custom_resolver_iterable" untyped', - line=91, + line=90, column=6, ), - Result(type="error", message="Missing return statement", line=92, column=5), + Result(type="error", message="Missing return statement", line=91, column=5), Result( type="error", message='Untyped decorator makes function "fruits_custom_resolver_generator" untyped', - line=99, + line=98, column=6, ), - Result( - type="error", message="Missing return statement", line=100, column=5 - ), + Result(type="error", message="Missing return statement", line=99, column=5), Result( type="error", message='Untyped decorator makes function "fruits_custom_resolver_async_iterator" untyped', - line=107, + line=106, column=6, ), Result( - type="error", message="Missing return statement", line=108, column=5 + type="error", message="Missing return statement", line=107, column=5 ), Result( type="error", message='Untyped decorator makes function "fruits_custom_resolver_async_iterable" untyped', - line=115, + line=114, column=6, ), Result( - type="error", message="Missing return statement", line=116, column=5 + type="error", message="Missing return statement", line=115, column=5 ), Result( type="error", message='Untyped decorator makes function "fruits_custom_resolver_async_generator" untyped', - line=123, + line=122, column=6, ), Result( - type="error", message="Missing return statement", line=124, column=5 + type="error", message="Missing return statement", line=123, column=5 ), Result( type="note", message='Revealed type is "strawberry.relay.types.Node"', - line=131, + line=130, column=13, ), Result( type="note", message='Revealed type is "builtins.list[strawberry.relay.types.Node]"', - line=132, + line=131, column=13, ), Result( type="note", message='Revealed type is "Union[strawberry.relay.types.Node, None]"', - line=133, + line=132, column=13, ), Result( type="note", message='Revealed type is "builtins.list[Union[strawberry.relay.types.Node, None]]"', - line=134, + line=133, column=13, ), Result( type="note", message='Revealed type is "strawberry.relay.types.Connection[mypy_test.Fruit]"', - line=135, + line=134, column=13, ), Result( type="note", message='Revealed type is "strawberry.relay.types.Connection[mypy_test.Fruit]"', - line=136, + line=135, column=13, ), Result( type="note", message='Revealed type is "mypy_test.FruitCustomPaginationConnection"', - line=137, + line=136, column=13, ), + Result(type="note", message='Revealed type is "Any"', line=137, column=13), Result(type="note", message='Revealed type is "Any"', line=138, column=13), Result(type="note", message='Revealed type is "Any"', line=139, column=13), Result(type="note", message='Revealed type is "Any"', line=140, column=13), Result(type="note", message='Revealed type is "Any"', line=141, column=13), Result(type="note", message='Revealed type is "Any"', line=142, column=13), Result(type="note", message='Revealed type is "Any"', line=143, column=13), - Result(type="note", message='Revealed type is "Any"', line=144, column=13), ] ) diff --git a/tests/typecheckers/utils/mypy.py b/tests/typecheckers/utils/mypy.py index d78caa3f07..85e858c627 100644 --- a/tests/typecheckers/utils/mypy.py +++ b/tests/typecheckers/utils/mypy.py @@ -4,7 +4,7 @@ import pathlib import subprocess import tempfile -from typing import List, TypedDict +from typing import TypedDict from .result import Result @@ -12,7 +12,7 @@ class PyrightCLIResult(TypedDict): version: str time: str - generalDiagnostics: List[GeneralDiagnostic] + generalDiagnostics: list[GeneralDiagnostic] summary: Summary @@ -41,7 +41,7 @@ class Summary(TypedDict): timeInSec: float -def run_mypy(code: str, strict: bool = True) -> List[Result]: +def run_mypy(code: str, strict: bool = True) -> list[Result]: args = ["mypy", "--output=json"] if strict: @@ -62,7 +62,7 @@ def run_mypy(code: str, strict: bool = True) -> List[Result]: ) full_output = full_output.strip() - results: List[Result] = [] + results: list[Result] = [] try: for line in full_output.split("\n"): diff --git a/tests/typecheckers/utils/pyright.py b/tests/typecheckers/utils/pyright.py index 6d03397fce..8cf42454bf 100644 --- a/tests/typecheckers/utils/pyright.py +++ b/tests/typecheckers/utils/pyright.py @@ -4,7 +4,7 @@ import os import subprocess import tempfile -from typing import List, TypedDict, cast +from typing import TypedDict, cast from .result import Result, ResultType @@ -12,7 +12,7 @@ class PyrightCLIResult(TypedDict): version: str time: str - generalDiagnostics: List[GeneralDiagnostic] + generalDiagnostics: list[GeneralDiagnostic] summary: Summary @@ -41,7 +41,7 @@ class Summary(TypedDict): timeInSec: float -def run_pyright(code: str, strict: bool = True) -> List[Result]: +def run_pyright(code: str, strict: bool = True) -> list[Result]: if strict: code = "# pyright: strict\n" + code diff --git a/tests/types/cross_module_resolvers/a_mod.py b/tests/types/cross_module_resolvers/a_mod.py index c2db7aae62..7b50501e0e 100644 --- a/tests/types/cross_module_resolvers/a_mod.py +++ b/tests/types/cross_module_resolvers/a_mod.py @@ -1,9 +1,7 @@ -from typing import List - import strawberry -def a_resolver() -> List["AObject"]: +def a_resolver() -> list["AObject"]: return [] diff --git a/tests/types/cross_module_resolvers/b_mod.py b/tests/types/cross_module_resolvers/b_mod.py index 92ff2adfff..a3c6ff00e1 100644 --- a/tests/types/cross_module_resolvers/b_mod.py +++ b/tests/types/cross_module_resolvers/b_mod.py @@ -1,9 +1,7 @@ -from typing import List - import strawberry -def b_resolver() -> List["BObject"]: +def b_resolver() -> list["BObject"]: return [] diff --git a/tests/types/cross_module_resolvers/c_mod.py b/tests/types/cross_module_resolvers/c_mod.py index e0392827f6..9182310e8c 100644 --- a/tests/types/cross_module_resolvers/c_mod.py +++ b/tests/types/cross_module_resolvers/c_mod.py @@ -1,5 +1,3 @@ -from typing import List - import a_mod import b_mod import x_mod @@ -9,15 +7,15 @@ import strawberry -def c_inheritance_resolver() -> List["CInheritance"]: +def c_inheritance_resolver() -> list["CInheritance"]: pass -def c_composition_resolver() -> List["CComposition"]: +def c_composition_resolver() -> list["CComposition"]: pass -def c_composition_by_name_resolver() -> List["CCompositionByName"]: +def c_composition_by_name_resolver() -> list["CCompositionByName"]: pass @@ -28,34 +26,34 @@ class CInheritance(a_mod.AObject, b_mod.BObject): @strawberry.type class CComposition: - a_list: List[a_mod.AObject] - b_list: List[b_mod.BObject] + a_list: list[a_mod.AObject] + b_list: list[b_mod.BObject] @strawberry.type class CCompositionByName: - a_list: List["C_AObject"] - b_list: List["C_BObject"] + a_list: list["C_AObject"] + b_list: list["C_BObject"] @strawberry.field - def a_method(self) -> List["C_AObject"]: + def a_method(self) -> list["C_AObject"]: return self.a_list @strawberry.field - def b_method(self) -> List["C_BObject"]: + def b_method(self) -> list["C_BObject"]: return self.b_list @strawberry.type class CCompositionByNameWithResolvers: - a_list: List["C_AObject"] = strawberry.field(resolver=a_mod.a_resolver) - b_list: List["C_BObject"] = strawberry.field(resolver=b_mod.b_resolver) + a_list: list["C_AObject"] = strawberry.field(resolver=a_mod.a_resolver) + b_list: list["C_BObject"] = strawberry.field(resolver=b_mod.b_resolver) @strawberry.type class CCompositionByNameWithTypelessResolvers: - a_list: List["C_AObject"] = strawberry.field(resolver=x_mod.typeless_resolver) - b_list: List["C_BObject"] = strawberry.field(resolver=x_mod.typeless_resolver) + a_list: list["C_AObject"] = strawberry.field(resolver=x_mod.typeless_resolver) + b_list: list["C_BObject"] = strawberry.field(resolver=x_mod.typeless_resolver) @strawberry.type diff --git a/tests/types/cross_module_resolvers/test_cross_module_resolvers.py b/tests/types/cross_module_resolvers/test_cross_module_resolvers.py index 4c218319f8..bc9241c4ca 100644 --- a/tests/types/cross_module_resolvers/test_cross_module_resolvers.py +++ b/tests/types/cross_module_resolvers/test_cross_module_resolvers.py @@ -4,8 +4,6 @@ (forward reference) and can only be resolved at schema construction. """ -from typing import List - import a_mod import b_mod import c_mod @@ -17,19 +15,19 @@ def test_a(): @strawberry.type class Query: - a_list: List[a_mod.AObject] + a_list: list[a_mod.AObject] [field] = Query.__strawberry_definition__.fields - assert field.type == List[a_mod.AObject] + assert field.type == list[a_mod.AObject] def test_a_resolver(): @strawberry.type class Query: - a_list: List[a_mod.AObject] = strawberry.field(resolver=a_mod.a_resolver) + a_list: list[a_mod.AObject] = strawberry.field(resolver=a_mod.a_resolver) [field] = Query.__strawberry_definition__.fields - assert field.type == List[a_mod.AObject] + assert field.type == list[a_mod.AObject] def test_a_only_resolver(): @@ -38,16 +36,16 @@ class Query: a_list = strawberry.field(resolver=a_mod.a_resolver) [field] = Query.__strawberry_definition__.fields - assert field.type == List[a_mod.AObject] + assert field.type == list[a_mod.AObject] def test_a_typeless_resolver(): @strawberry.type class Query: - a_list: List[a_mod.AObject] = strawberry.field(resolver=x_mod.typeless_resolver) + a_list: list[a_mod.AObject] = strawberry.field(resolver=x_mod.typeless_resolver) [field] = Query.__strawberry_definition__.fields - assert field.type == List[a_mod.AObject] + assert field.type == list[a_mod.AObject] def test_c_composition_by_name(): @@ -57,10 +55,10 @@ def test_c_composition_by_name(): a_method, b_method, ] = c_mod.CCompositionByName.__strawberry_definition__.fields - assert a_field.type == List[a_mod.AObject] - assert b_field.type == List[b_mod.BObject] - assert a_method.type == List[a_mod.AObject] - assert b_method.type == List[b_mod.BObject] + assert a_field.type == list[a_mod.AObject] + assert b_field.type == list[b_mod.BObject] + assert a_method.type == list[a_mod.AObject] + assert b_method.type == list[b_mod.BObject] def test_c_inheritance(): @@ -83,21 +81,21 @@ def test_c_inheritance(): def test_c_inheritance_resolver(): @strawberry.type class Query: - c: List[c_mod.CInheritance] = strawberry.field( + c: list[c_mod.CInheritance] = strawberry.field( resolver=c_mod.c_inheritance_resolver ) [field] = Query.__strawberry_definition__.fields - assert field.type == List[c_mod.CInheritance] + assert field.type == list[c_mod.CInheritance] def test_c_inheritance_typeless_resolver(): @strawberry.type class Query: - c: List[c_mod.CInheritance] = strawberry.field(resolver=x_mod.typeless_resolver) + c: list[c_mod.CInheritance] = strawberry.field(resolver=x_mod.typeless_resolver) [field] = Query.__strawberry_definition__.fields - assert field.type == List[c_mod.CInheritance] + assert field.type == list[c_mod.CInheritance] def test_c_inheritance_resolver_only(): @@ -106,21 +104,21 @@ class Query: c = strawberry.field(resolver=c_mod.c_inheritance_resolver) [field] = Query.__strawberry_definition__.fields - assert field.type == List[c_mod.CInheritance] + assert field.type == list[c_mod.CInheritance] def test_c_composition_resolver(): @strawberry.type class Query: - c: List[c_mod.CComposition] = strawberry.field( + c: list[c_mod.CComposition] = strawberry.field( resolver=c_mod.c_composition_resolver ) [field] = Query.__strawberry_definition__.fields - assert field.type == List[c_mod.CComposition] + assert field.type == list[c_mod.CComposition] [a_field, b_field] = field.type.of_type.__strawberry_definition__.fields - assert a_field.type == List[a_mod.AObject] - assert b_field.type == List[b_mod.BObject] + assert a_field.type == list[a_mod.AObject] + assert b_field.type == list[b_mod.BObject] def test_c_composition_by_name_with_resolvers(): @@ -128,8 +126,8 @@ def test_c_composition_by_name_with_resolvers(): a_field, b_field, ] = c_mod.CCompositionByNameWithResolvers.__strawberry_definition__.fields - assert a_field.type == List[a_mod.AObject] - assert b_field.type == List[b_mod.BObject] + assert a_field.type == list[a_mod.AObject] + assert b_field.type == list[b_mod.BObject] def test_c_composition_by_name_with_typeless_resolvers(): @@ -137,8 +135,8 @@ def test_c_composition_by_name_with_typeless_resolvers(): a_field, b_field, ] = c_mod.CCompositionByNameWithTypelessResolvers.__strawberry_definition__.fields - assert a_field.type == List[a_mod.AObject] - assert b_field.type == List[b_mod.BObject] + assert a_field.type == list[a_mod.AObject] + assert b_field.type == list[b_mod.BObject] def test_c_composition_only_resolvers(): @@ -146,14 +144,14 @@ def test_c_composition_only_resolvers(): a_field, b_field, ] = c_mod.CCompositionOnlyResolvers.__strawberry_definition__.fields - assert a_field.type == List[a_mod.AObject] - assert b_field.type == List[b_mod.BObject] + assert a_field.type == list[a_mod.AObject] + assert b_field.type == list[b_mod.BObject] def test_x_resolver(): @strawberry.type class Query: - c: List[a_mod.AObject] = strawberry.field(resolver=x_mod.typeless_resolver) + c: list[a_mod.AObject] = strawberry.field(resolver=x_mod.typeless_resolver) [c_field] = Query.__strawberry_definition__.fields - assert c_field.type == List[a_mod.AObject] + assert c_field.type == list[a_mod.AObject] diff --git a/tests/types/cross_module_resolvers/x_mod.py b/tests/types/cross_module_resolvers/x_mod.py index 5f816458fc..80be36ee10 100644 --- a/tests/types/cross_module_resolvers/x_mod.py +++ b/tests/types/cross_module_resolvers/x_mod.py @@ -1,5 +1,2 @@ -from typing import List - - -def typeless_resolver() -> List: +def typeless_resolver() -> list: # pragma: no cover return [] diff --git a/tests/types/resolving/test_generics.py b/tests/types/resolving/test_generics.py index 9c10f8cc73..bf5529c081 100644 --- a/tests/types/resolving/test_generics.py +++ b/tests/types/resolving/test_generics.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Generic, List, Optional, TypeVar, Union +from typing import Generic, Optional, TypeVar, Union import pytest @@ -34,14 +34,14 @@ def test_basic_generic(): def test_generic_lists(): T = TypeVar("T") - annotation = StrawberryAnnotation(List[T]) + annotation = StrawberryAnnotation(list[T]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryList) assert isinstance(resolved.of_type, StrawberryTypeVar) assert resolved.is_graphql_generic - assert resolved == List[T] + assert resolved == list[T] def test_generic_objects(): diff --git a/tests/types/resolving/test_lists.py b/tests/types/resolving/test_lists.py index c3e50cba1d..a09a634602 100644 --- a/tests/types/resolving/test_lists.py +++ b/tests/types/resolving/test_lists.py @@ -1,8 +1,5 @@ -import sys from collections.abc import Sequence -from typing import List, Optional, Tuple, Union - -import pytest +from typing import Optional, Union import strawberry from strawberry.annotation import StrawberryAnnotation @@ -10,31 +7,27 @@ def test_basic_list(): - annotation = StrawberryAnnotation(List[str]) + annotation = StrawberryAnnotation(list[str]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryList) assert resolved.of_type is str assert resolved == StrawberryList(of_type=str) - assert resolved == List[str] + assert resolved == list[str] def test_basic_tuple(): - annotation = StrawberryAnnotation(Tuple[str]) + annotation = StrawberryAnnotation(tuple[str]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryList) assert resolved.of_type is str assert resolved == StrawberryList(of_type=str) - assert resolved == Tuple[str] + assert resolved == tuple[str] -@pytest.mark.skipif( - sys.version_info < (3, 9), - reason="collections.abc.Sequence supporting [] was added in python 3.9", -) def test_basic_sequence(): annotation = StrawberryAnnotation(Sequence[str]) resolved = annotation.resolve() @@ -47,20 +40,16 @@ def test_basic_sequence(): def test_list_of_optional(): - annotation = StrawberryAnnotation(List[Optional[int]]) + annotation = StrawberryAnnotation(list[Optional[int]]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryList) assert resolved.of_type == Optional[int] assert resolved == StrawberryList(of_type=Optional[int]) - assert resolved == List[Optional[int]] + assert resolved == list[Optional[int]] -@pytest.mark.skipif( - sys.version_info < (3, 9), - reason="collections.abc.Sequence supporting [] was added in python 3.9", -) def test_sequence_of_optional(): annotation = StrawberryAnnotation(Sequence[Optional[int]]) resolved = annotation.resolve() @@ -73,31 +62,27 @@ def test_sequence_of_optional(): def test_tuple_of_optional(): - annotation = StrawberryAnnotation(Tuple[Optional[int]]) + annotation = StrawberryAnnotation(tuple[Optional[int]]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryList) assert resolved.of_type == Optional[int] assert resolved == StrawberryList(of_type=Optional[int]) - assert resolved == Tuple[Optional[int]] + assert resolved == tuple[Optional[int]] def test_list_of_lists(): - annotation = StrawberryAnnotation(List[List[float]]) + annotation = StrawberryAnnotation(list[list[float]]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryList) - assert resolved.of_type == List[float] + assert resolved.of_type == list[float] - assert resolved == StrawberryList(of_type=List[float]) - assert resolved == List[List[float]] + assert resolved == StrawberryList(of_type=list[float]) + assert resolved == list[list[float]] -@pytest.mark.skipif( - sys.version_info < (3, 9), - reason="collections.abc.Sequence supporting [] was added in python 3.9", -) def test_sequence_of_sequence(): annotation = StrawberryAnnotation(Sequence[Sequence[float]]) resolved = annotation.resolve() @@ -110,14 +95,14 @@ def test_sequence_of_sequence(): def test_tuple_of_tuple(): - annotation = StrawberryAnnotation(Tuple[Tuple[float]]) + annotation = StrawberryAnnotation(tuple[tuple[float]]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryList) - assert resolved.of_type == Tuple[float] + assert resolved.of_type == tuple[float] - assert resolved == StrawberryList(of_type=Tuple[float]) - assert resolved == Tuple[Tuple[float]] + assert resolved == StrawberryList(of_type=tuple[float]) + assert resolved == tuple[tuple[float]] def test_list_of_union(): @@ -129,20 +114,16 @@ class Animal: class Fungus: spore: bool - annotation = StrawberryAnnotation(List[Union[Animal, Fungus]]) + annotation = StrawberryAnnotation(list[Union[Animal, Fungus]]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryList) assert resolved.of_type == Union[Animal, Fungus] assert resolved == StrawberryList(of_type=Union[Animal, Fungus]) - assert resolved == List[Union[Animal, Fungus]] + assert resolved == list[Union[Animal, Fungus]] -@pytest.mark.skipif( - sys.version_info < (3, 9), - reason="collections.abc.Sequence supporting [] was added in python 3.9", -) def test_sequence_of_union(): @strawberry.type class Animal: @@ -162,10 +143,6 @@ class Fungus: assert resolved == Sequence[Union[Animal, Fungus]] -@pytest.mark.skipif( - sys.version_info < (3, 9), - reason="built-in generic annotations where added in python 3.9", -) def test_list_builtin(): annotation = StrawberryAnnotation(list[str]) resolved = annotation.resolve() @@ -174,14 +151,10 @@ def test_list_builtin(): assert resolved.of_type is str assert resolved == StrawberryList(of_type=str) - assert resolved == List[str] + assert resolved == list[str] assert resolved == list[str] -@pytest.mark.skipif( - sys.version_info < (3, 9), - reason="built-in generic annotations where added in python 3.9", -) def test_tuple_builtin(): annotation = StrawberryAnnotation(tuple[str]) resolved = annotation.resolve() @@ -190,5 +163,5 @@ def test_tuple_builtin(): assert resolved.of_type is str assert resolved == StrawberryList(of_type=str) - assert resolved == Tuple[str] + assert resolved == tuple[str] assert resolved == tuple[str] diff --git a/tests/types/resolving/test_optionals.py b/tests/types/resolving/test_optionals.py index 3127d0caaf..f01f8b74d4 100644 --- a/tests/types/resolving/test_optionals.py +++ b/tests/types/resolving/test_optionals.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import Optional, Union import strawberry from strawberry.annotation import StrawberryAnnotation @@ -40,14 +40,14 @@ def test_optional_with_unset_as_union(): def test_optional_list(): - annotation = StrawberryAnnotation(Optional[List[bool]]) + annotation = StrawberryAnnotation(Optional[list[bool]]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryOptional) - assert resolved.of_type == List[bool] + assert resolved.of_type == list[bool] - assert resolved == StrawberryOptional(of_type=List[bool]) - assert resolved == Optional[List[bool]] + assert resolved == StrawberryOptional(of_type=list[bool]) + assert resolved == Optional[list[bool]] def test_optional_optional(): diff --git a/tests/types/resolving/test_string_annotations.py b/tests/types/resolving/test_string_annotations.py index 321384b7e4..ce187896c7 100644 --- a/tests/types/resolving/test_string_annotations.py +++ b/tests/types/resolving/test_string_annotations.py @@ -1,4 +1,4 @@ -from typing import List, Optional, TypeVar +from typing import Optional, TypeVar import strawberry from strawberry.annotation import StrawberryAnnotation @@ -17,14 +17,14 @@ def test_basic_string(): def test_list_of_string(): - annotation = StrawberryAnnotation(List["int"]) + annotation = StrawberryAnnotation(list["int"]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryList) assert resolved.of_type is int assert resolved == StrawberryList(of_type=int) - assert resolved == List[int] + assert resolved == list[int] def test_list_of_string_of_type(): @@ -32,14 +32,14 @@ def test_list_of_string_of_type(): class NameGoesHere: foo: bool - annotation = StrawberryAnnotation(List["NameGoesHere"], namespace=locals()) + annotation = StrawberryAnnotation(list["NameGoesHere"], namespace=locals()) resolved = annotation.resolve() assert isinstance(resolved, StrawberryList) assert resolved.of_type is NameGoesHere assert resolved == StrawberryList(of_type=NameGoesHere) - assert resolved == List[NameGoesHere] + assert resolved == list[NameGoesHere] def test_optional_of_string(): @@ -79,14 +79,14 @@ def test_string_of_type_var(): def test_string_of_list(): namespace = {**locals(), **globals()} - annotation = StrawberryAnnotation("List[float]", namespace=namespace) + annotation = StrawberryAnnotation("list[float]", namespace=namespace) resolved = annotation.resolve() assert isinstance(resolved, StrawberryList) assert resolved.of_type is float assert resolved == StrawberryList(of_type=float) - assert resolved == List[float] + assert resolved == list[float] def test_string_of_list_of_type(): @@ -96,14 +96,14 @@ class BlahBlah: namespace = {**locals(), **globals()} - annotation = StrawberryAnnotation("List[BlahBlah]", namespace=namespace) + annotation = StrawberryAnnotation("list[BlahBlah]", namespace=namespace) resolved = annotation.resolve() assert isinstance(resolved, StrawberryList) assert resolved.of_type is BlahBlah assert resolved == StrawberryList(of_type=BlahBlah) - assert resolved == List[BlahBlah] + assert resolved == list[BlahBlah] def test_string_of_optional(): @@ -163,7 +163,7 @@ class Query: def test_basic_list(): @strawberry.type class Query: - names: "List[str]" + names: "list[str]" definition = Query.__strawberry_definition__ assert definition.name == "Query" @@ -185,7 +185,7 @@ class User: @strawberry.type class Query: - users: "List[User]" + users: "list[User]" definition = Query.__strawberry_definition__ assert definition.name == "Query" diff --git a/tests/types/resolving/test_unions.py b/tests/types/resolving/test_unions.py index dee6ecb39d..afb7efc4f9 100644 --- a/tests/types/resolving/test_unions.py +++ b/tests/types/resolving/test_unions.py @@ -1,6 +1,5 @@ import sys -from typing import Generic, TypeVar, Union -from typing_extensions import Annotated +from typing import Annotated, Generic, TypeVar, Union import pytest diff --git a/tests/types/test_argument_types.py b/tests/types/test_argument_types.py index 6ee0aeb4b5..b9d3c9ce2c 100644 --- a/tests/types/test_argument_types.py +++ b/tests/types/test_argument_types.py @@ -1,6 +1,6 @@ import warnings from enum import Enum -from typing import List, Optional, TypeVar +from typing import Optional, TypeVar import pytest @@ -45,12 +45,12 @@ class SearchInput: def test_list(): @strawberry.field - def get_longest_word(words: List[str]) -> str: + def get_longest_word(words: list[str]) -> str: _ = words return "I cheated" argument = get_longest_word.arguments[0] - assert argument.type == List[str] + assert argument.type == list[str] def test_literal(): diff --git a/tests/types/test_convert_to_dictionary.py b/tests/types/test_convert_to_dictionary.py index 4281e30712..496cc69641 100644 --- a/tests/types/test_convert_to_dictionary.py +++ b/tests/types/test_convert_to_dictionary.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List, Optional +from typing import Optional import strawberry from strawberry import asdict @@ -32,7 +32,7 @@ class Animal: @strawberry.type class People: name: str - animals: List[Animal] + animals: list[Animal] lorem = People( name="Kevin", animals=[Animal(legs=Count.TWO), Animal(legs=Count.FOUR)] @@ -52,7 +52,7 @@ def test_convert_input_to_dictionary(): class QnaInput: title: str description: str - tags: Optional[List[str]] = strawberry.field(default=None) + tags: Optional[list[str]] = strawberry.field(default=None) title = "Where is the capital of United Kingdom?" description = "London is the capital of United Kingdom." diff --git a/tests/types/test_field_types.py b/tests/types/test_field_types.py index fab4df79ef..f50e9d16db 100644 --- a/tests/types/test_field_types.py +++ b/tests/types/test_field_types.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List, Optional, TypeVar +from typing import Optional, TypeVar import strawberry from strawberry.annotation import StrawberryAnnotation @@ -36,10 +36,10 @@ class RefForward: def test_list(): - annotation = StrawberryAnnotation(List[int]) + annotation = StrawberryAnnotation(list[int]) field = StrawberryField(type_annotation=annotation) - assert field.type == List[int] + assert field.type == list[int] def test_literal(): diff --git a/tests/types/test_lazy_types.py b/tests/types/test_lazy_types.py index 3dfc55bcc5..b03fb38c3c 100644 --- a/tests/types/test_lazy_types.py +++ b/tests/types/test_lazy_types.py @@ -2,8 +2,8 @@ import enum import sys import textwrap -from typing import Generic, TypeVar, Union -from typing_extensions import Annotated, TypeAlias +from typing import Annotated, Generic, TypeVar, Union +from typing_extensions import TypeAlias import pytest diff --git a/tests/types/test_lazy_types_future_annotations.py b/tests/types/test_lazy_types_future_annotations.py index 09f7f6a6d0..fb708bfcd0 100644 --- a/tests/types/test_lazy_types_future_annotations.py +++ b/tests/types/test_lazy_types_future_annotations.py @@ -1,7 +1,7 @@ from __future__ import annotations import textwrap -from typing_extensions import Annotated +from typing import Annotated import strawberry diff --git a/tests/types/test_object_types.py b/tests/types/test_object_types.py index ed31aa1423..eb9e304354 100644 --- a/tests/types/test_object_types.py +++ b/tests/types/test_object_types.py @@ -2,8 +2,7 @@ import dataclasses import re from enum import Enum -from typing import List, Optional, TypeVar, Union -from typing_extensions import Annotated +from typing import Annotated, Optional, TypeVar, Union import pytest @@ -49,11 +48,11 @@ class FromTheFuture: def test_list(): @strawberry.type class Santa: - making_a: List[str] + making_a: list[str] field: StrawberryField = get_object_definition(Santa).fields[0] - assert field.type == List[str] + assert field.type == list[str] def test_literal(): @@ -145,11 +144,11 @@ class A: @strawberry.type class B(A): - attachments: Optional[List[A]] = None + attachments: Optional[list[A]] = None @strawberry.type class C(A): - fields: List[B] + fields: list[B] c_inst = C( text="some text", diff --git a/tests/types/test_resolver_types.py b/tests/types/test_resolver_types.py index c5e49727de..3e036a393f 100644 --- a/tests/types/test_resolver_types.py +++ b/tests/types/test_resolver_types.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List, Optional, TypeVar, Union +from typing import Optional, TypeVar, Union from asgiref.sync import sync_to_async @@ -39,11 +39,11 @@ class FutureUmpire: def test_list(): - def get_collection_types() -> List[str]: + def get_collection_types() -> list[str]: return ["list", "tuple", "dict", "set"] resolver = StrawberryResolver(get_collection_types) - assert resolver.type == List[str] + assert resolver.type == list[str] def test_literal(): diff --git a/tests/utils/test_arguments_converter.py b/tests/utils/test_arguments_converter.py index cb8f9e26ed..36a744df9f 100644 --- a/tests/utils/test_arguments_converter.py +++ b/tests/utils/test_arguments_converter.py @@ -1,6 +1,5 @@ from enum import Enum -from typing import List, Optional -from typing_extensions import Annotated +from typing import Annotated, Optional import pytest @@ -63,12 +62,12 @@ def test_list(): StrawberryArgument( graphql_name="integerList", python_name="integer_list", - type_annotation=StrawberryAnnotation(List[int]), + type_annotation=StrawberryAnnotation(list[int]), ), StrawberryArgument( graphql_name="stringList", python_name="string_list", - type_annotation=StrawberryAnnotation(List[str]), + type_annotation=StrawberryAnnotation(list[str]), ), ] @@ -196,7 +195,7 @@ class MyInput: StrawberryArgument( graphql_name="inputList", python_name="input_list", - type_annotation=StrawberryAnnotation(List[MyInput]), + type_annotation=StrawberryAnnotation(list[MyInput]), ), ] @@ -219,7 +218,7 @@ class MyInput: StrawberryArgument( graphql_name="inputList", python_name="input_list", - type_annotation=StrawberryAnnotation(Optional[List[MyInput]]), + type_annotation=StrawberryAnnotation(Optional[list[MyInput]]), ), ] assert convert_arguments( @@ -321,7 +320,7 @@ class Number: @strawberry.input class Input: - numbers: List[Number] + numbers: list[Number] args = {"input": {"numbers": [{"value": 1}, {"value": 2}]}} @@ -425,7 +424,7 @@ class Input: ) def test_fails_when_passing_non_strawberry_classes(): class Input: - numbers: List[int] + numbers: list[int] args = { "input": { diff --git a/tests/utils/test_typing.py b/tests/utils/test_typing.py index bf8809bc3d..0a6fe06122 100644 --- a/tests/utils/test_typing.py +++ b/tests/utils/test_typing.py @@ -1,9 +1,5 @@ -import sys import typing -from typing import ClassVar, ForwardRef, Optional, Union -from typing_extensions import Annotated - -import pytest +from typing import Annotated, ClassVar, ForwardRef, Optional, Union import strawberry from strawberry.types.lazy_type import LazyType @@ -78,10 +74,6 @@ class Foo: ... ) -@pytest.mark.skipif( - sys.version_info < (3, 9), - reason="python 3.8 resolves Annotated differently", -) def test_eval_type_with_deferred_annotations(): assert ( eval_type( @@ -109,37 +101,6 @@ def test_eval_type_with_deferred_annotations(): ) -@pytest.mark.skipif( - sys.version_info >= (3, 9), - reason="python 3.8 resolves Annotated differently", -) -def test_eval_type_with_deferred_annotations_3_8(): - assert ( - eval_type( - ForwardRef( - "Annotated['Fruit', strawberry.lazy('tests.utils.test_typing')]" - ), - {"strawberry": strawberry, "Annotated": Annotated}, - None, - ) - == Annotated[ - ForwardRef("Fruit"), - strawberry.lazy("tests.utils.test_typing"), - ] - ) - assert ( - eval_type( - ForwardRef("Annotated['datetime', strawberry.lazy('datetime')]"), - {"strawberry": strawberry, "Annotated": Annotated}, - None, - ) - == Annotated[ - ForwardRef("datetime"), - strawberry.lazy("datetime"), - ] - ) - - def test_is_classvar(): class Foo: attr1: str diff --git a/tests/utils/test_typing_forward_refs.py b/tests/utils/test_typing_forward_refs.py index 5ae315cdcc..a14b1174c5 100644 --- a/tests/utils/test_typing_forward_refs.py +++ b/tests/utils/test_typing_forward_refs.py @@ -1,10 +1,7 @@ from __future__ import annotations -import sys import typing -from typing import ClassVar, ForwardRef, List, Optional, Union - -import pytest +from typing import ClassVar, ForwardRef, Optional, Union from strawberry.scalars import JSON from strawberry.utils.typing import eval_type, is_classvar @@ -20,20 +17,16 @@ class Foo: ... == Union[Foo, str, None] ) assert ( - eval_type(ForwardRef("List[Foo | str] | None"), globals(), locals()) - == Union[List[Union[Foo, str]], None] + eval_type(ForwardRef("list[Foo | str] | None"), globals(), locals()) + == Union[list[Union[Foo, str]], None] ) assert ( - eval_type(ForwardRef("List[Foo | str] | None | int"), globals(), locals()) - == Union[List[Union[Foo, str]], int, None] + eval_type(ForwardRef("list[Foo | str] | None | int"), globals(), locals()) + == Union[list[Union[Foo, str]], int, None] ) assert eval_type(ForwardRef("JSON | None"), globals(), locals()) == Optional[JSON] -@pytest.mark.skipif( - sys.version_info < (3, 9), - reason="generic type alias only available on python 3.9+", -) def test_eval_type_generic_type_alias(): class Foo: ... diff --git a/tests/views/schema.py b/tests/views/schema.py index 311fccbb5f..7cd69e581f 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -1,7 +1,8 @@ import asyncio import contextlib +from collections.abc import AsyncGenerator from enum import Enum -from typing import Any, AsyncGenerator, Dict, List, Optional, Union +from typing import Any, Optional, Union from graphql import GraphQLError @@ -21,7 +22,7 @@ def has_permission(self, source: Any, info: strawberry.Info, **kwargs: Any) -> b class MyExtension(SchemaExtension): - def get_results(self) -> Dict[str, str]: + def get_results(self) -> dict[str, str]: return {"example": "example"} @@ -52,7 +53,7 @@ class Flavor(Enum): @strawberry.input class FolderInput: - files: List[Upload] + files: list[Upload] @strawberry.type @@ -135,11 +136,11 @@ def read_text(self, text_file: Upload) -> str: return _read_file(text_file) @strawberry.mutation - def read_files(self, files: List[Upload]) -> List[str]: + def read_files(self, files: list[Upload]) -> list[str]: return list(map(_read_file, files)) @strawberry.mutation - def read_folder(self, folder: FolderInput) -> List[str]: + def read_folder(self, folder: FolderInput) -> list[str]: return list(map(_read_file, folder.files)) @strawberry.mutation @@ -266,7 +267,7 @@ async def long_finalizer( class Schema(strawberry.Schema): def process_errors( - self, errors: List, execution_context: Optional[ExecutionContext] = None + self, errors: list, execution_context: Optional[ExecutionContext] = None ) -> None: import traceback diff --git a/tests/websockets/conftest.py b/tests/websockets/conftest.py index 98c257502a..e12dfa8632 100644 --- a/tests/websockets/conftest.py +++ b/tests/websockets/conftest.py @@ -1,5 +1,6 @@ import importlib -from typing import Any, Generator, Type +from collections.abc import Generator +from typing import Any import pytest @@ -33,10 +34,10 @@ def _get_http_client_classes() -> Generator[Any, None, None]: @pytest.fixture(params=_get_http_client_classes()) -def http_client_class(request: Any) -> Type[HttpClient]: +def http_client_class(request: Any) -> type[HttpClient]: return request.param @pytest.fixture() -def http_client(http_client_class: Type[HttpClient]) -> HttpClient: +def http_client(http_client_class: type[HttpClient]) -> HttpClient: return http_client_class() diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 3a7b5849a3..787d215aac 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -4,8 +4,9 @@ import contextlib import json import time +from collections.abc import AsyncGenerator from datetime import timedelta -from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, Type, Union +from typing import TYPE_CHECKING, Optional, Union from unittest.mock import AsyncMock, Mock, patch import pytest @@ -51,8 +52,8 @@ async def ws(ws_raw: WebSocketClient) -> WebSocketClient: def assert_next( next_message: NextMessage, id: str, - data: Dict[str, object], - extensions: Optional[Dict[str, object]] = None, + data: dict[str, object], + extensions: Optional[dict[str, object]] = None, ): """ Assert that the NextMessage payload contains the provided data. @@ -153,7 +154,7 @@ async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): async def test_connection_init_timeout( - request: object, http_client_class: Type[HttpClient] + request: object, http_client_class: type[HttpClient] ): with contextlib.suppress(ImportError): from tests.http.clients.aiohttp import AioHttpClient @@ -206,7 +207,7 @@ async def test_connection_init_timeout_cancellation( @pytest.mark.xfail(reason="This test is flaky") async def test_close_twice( - mocker: MockerFixture, request: object, http_client_class: Type[HttpClient] + mocker: MockerFixture, request: object, http_client_class: type[HttpClient] ): test_client = http_client_class() test_client.create_app(connection_init_wait_timeout=timedelta(seconds=0.25)) diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index 26488c9911..246caf76b5 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -2,7 +2,8 @@ import asyncio import json -from typing import TYPE_CHECKING, AsyncGenerator, Union +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Union from unittest import mock import pytest diff --git a/tests/websockets/test_websockets.py b/tests/websockets/test_websockets.py index d85eda42d5..2ca98e9138 100644 --- a/tests/websockets/test_websockets.py +++ b/tests/websockets/test_websockets.py @@ -1,5 +1,3 @@ -from typing import Type - from strawberry.http.async_base_view import AsyncBaseHTTPView from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( @@ -8,7 +6,7 @@ from tests.http.clients.base import HttpClient -async def test_turning_off_graphql_ws(http_client_class: Type[HttpClient]): +async def test_turning_off_graphql_ws(http_client_class: type[HttpClient]): http_client = http_client_class() http_client.create_app(subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]) @@ -21,7 +19,7 @@ async def test_turning_off_graphql_ws(http_client_class: Type[HttpClient]): assert ws.close_reason == "Subprotocol not acceptable" -async def test_turning_off_graphql_transport_ws(http_client_class: Type[HttpClient]): +async def test_turning_off_graphql_transport_ws(http_client_class: type[HttpClient]): http_client = http_client_class() http_client.create_app(subscription_protocols=[GRAPHQL_WS_PROTOCOL]) @@ -34,7 +32,7 @@ async def test_turning_off_graphql_transport_ws(http_client_class: Type[HttpClie assert ws.close_reason == "Subprotocol not acceptable" -async def test_turning_off_all_subprotocols(http_client_class: Type[HttpClient]): +async def test_turning_off_all_subprotocols(http_client_class: type[HttpClient]): http_client = http_client_class() http_client.create_app(subscription_protocols=[]) @@ -65,7 +63,7 @@ async def test_generally_unsupported_subprotocols_are_rejected(http_client: Http assert ws.close_reason == "Subprotocol not acceptable" -async def test_clients_can_prefer_subprotocols(http_client_class: Type[HttpClient]): +async def test_clients_can_prefer_subprotocols(http_client_class: type[HttpClient]): http_client = http_client_class() http_client.create_app( subscription_protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] diff --git a/tests/websockets/views.py b/tests/websockets/views.py index eec511131e..1f7b2eaee2 100644 --- a/tests/websockets/views.py +++ b/tests/websockets/views.py @@ -1,4 +1,4 @@ -from typing import Dict, Union +from typing import Union from strawberry import UNSET from strawberry.exceptions import ConnectionRejectionError @@ -8,8 +8,8 @@ class OnWSConnectMixin(AsyncBaseHTTPView): async def on_ws_connect( - self, context: Dict[str, object] - ) -> Union[UnsetType, None, Dict[str, object]]: + self, context: dict[str, object] + ) -> Union[UnsetType, None, dict[str, object]]: connection_params = context["connection_params"] if isinstance(connection_params, dict):