Skip to content

Commit

Permalink
Improve Binders, fix ClientSession ssl bug
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertoPrevato authored Feb 17, 2023
1 parent 5913265 commit 3bdad1c
Show file tree
Hide file tree
Showing 27 changed files with 104 additions and 60 deletions.
11 changes: 10 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,18 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [2.0a1] - 2023-02-17 :heart:

- Improves how custom binders can be defined, reducing code verbosity for
custom types. This is an important feature to implement common validation of
common parameters across multiple endpoints.
- Adds support for binder types defining OpenAPI Specification for their
parameters.
- Fixes bug #305 (`ClientSession ssl=False` not working as intended).

## [2.0a0] - 2023-01-08 :hourglass_flowing_sand:

- Renames the `plugins` namespace to `settings`
- Renames the `plugins` namespace to `settings`.
- Upgrades `rodi` to v2, which includes improvements.
- Adds support for alternative implementation of containers for dependency
injection, using the new `ContainerProtocol` in `rodi`.
Expand Down
6 changes: 2 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ import asyncio
from blacksheep.client import ClientSession


async def client_example(loop):
async def client_example():
async with ClientSession() as client:
response = await client.get("https://docs.python.org/3/")

Expand All @@ -249,9 +249,7 @@ async def client_example(loop):
print(text)


loop = asyncio.get_event_loop()
loop.run_until_complete(client_example(loop))

asyncio.run(client_example())
```

## Supported platforms and runtimes
Expand Down
2 changes: 1 addition & 1 deletion blacksheep/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

INSECURE_SSLCONTEXT = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
INSECURE_SSLCONTEXT.check_hostname = False
INSECURE_SSLCONTEXT.verify_mode = ssl.CERT_NONE


class IncomingContent(Content):
Expand Down Expand Up @@ -80,7 +81,6 @@ def __init__(self, response, transport):


class ClientConnection(asyncio.Protocol):

__slots__ = (
"loop",
"pool",
Expand Down
2 changes: 0 additions & 2 deletions blacksheep/client/cookies.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def not_ip_address(value: str):


class StoredCookie:

__slots__ = ("cookie", "persistent", "creation_time", "expiry_time")

def __init__(self, cookie: Cookie):
Expand Down Expand Up @@ -203,7 +202,6 @@ def _get_cookies_checking_exp(
schema: str, cookies: Dict[str, StoredCookie]
) -> Iterable[Cookie]:
for cookie_name, stored_cookie in cookies.copy().items():

if stored_cookie.is_expired():
del cookies[cookie_name]
continue
Expand Down
2 changes: 2 additions & 0 deletions blacksheep/client/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ def get_ssl_context(
"Invalid ssl argument, expected one of: "
"None, False, True, instance of ssl.SSLContext."
)

if ssl:
raise InvalidArgument("SSL argument specified for non-https scheme.")

return None


Expand Down
7 changes: 3 additions & 4 deletions blacksheep/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def __contains__(self, item: Any) -> bool:


class ClientRequestContext:

__slots__ = ("path", "cookies")

def __init__(self, request, cookies: Optional[CookieJar] = None):
Expand Down Expand Up @@ -345,17 +344,17 @@ async def _send_core(self, request: Request) -> Response:

return response

async def _send_using_connection(self, request) -> Response:
async def _send_using_connection(self, request, attempt: int = 1) -> Response:
connection = await self.get_connection(request.url)

try:
return await asyncio.wait_for(
connection.send(request), self.request_timeout
)
except ConnectionClosedError as connection_closed_error:
if connection_closed_error.can_retry:
if connection_closed_error.can_retry and attempt < 4:
await asyncio.sleep(self.delay_before_retry)
return await self._send_using_connection(request)
return await self._send_using_connection(request, attempt + 1)
raise
except TimeoutError:
raise RequestTimeout(request.url, self.request_timeout)
Expand Down
1 change: 0 additions & 1 deletion blacksheep/common/files/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@


class FileInfo:

__slots__ = ("etag", "size", "mime", "modified_time")

def __init__(self, size: int, etag: str, mime: str, modified_time: str):
Expand Down
1 change: 0 additions & 1 deletion blacksheep/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def parse_multipart(value: bytes) -> Generator[FormPart, None, None]:
default_charset = None

for part_bytes in split_multipart(value):

try:
yield parse_part(part_bytes, default_charset)
except CharsetPart as charset:
Expand Down
1 change: 0 additions & 1 deletion blacksheep/ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def _parse_range_value(range_value: str):


class Range:

__slots__ = ("_unit", "_parts")

def __init__(self, unit: str, parts: Sequence[RangePart]):
Expand Down
21 changes: 16 additions & 5 deletions blacksheep/server/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,15 @@ class RequestMethod(BoundValue[str]):
"""


def _implicit_default(obj: "Binder"):
try:
return issubclass(obj.handle, BoundValue)
except (AttributeError, TypeError):
return False


class Binder(metaclass=BinderMeta): # type: ignore
handle: ClassVar[Type[BoundValue]]
handle: ClassVar[Type[Any]]
name_alias: ClassVar[str] = ""
type_alias: ClassVar[Any] = None

Expand All @@ -252,7 +259,7 @@ def __init__(
required: bool = True,
converter: Optional[Callable] = None,
):
self._implicit = implicit
self._implicit = implicit or not _implicit_default(self)
self.parameter_name = name
self.expected_type = expected_type
self.required = required
Expand Down Expand Up @@ -316,7 +323,10 @@ def example(id: str):
# applied implicitly
...
"""
value = await self.get_value(request)
try:
value = await self.get_value(request)
except ValueError as value_error:
raise BadRequest("Invalid parameter.") from value_error

if value is None and self.default is not empty:
return self.default
Expand All @@ -334,6 +344,7 @@ def example(id: str):
@abstractmethod
async def get_value(self, request: Request) -> Any:
"""Gets a value from the given request object."""
raise NotImplementedError()


def get_binder_by_type(bound_value_type: Type[BoundValue]) -> Type[Binder]:
Expand Down Expand Up @@ -405,7 +416,7 @@ class BodyBinder(Binder):

def __init__(
self,
expected_type: T,
expected_type,
name: str = "body",
implicit: bool = False,
required: bool = False,
Expand Down Expand Up @@ -592,7 +603,7 @@ class SyncBinder(Binder):

def __init__(
self,
expected_type: T = List[str],
expected_type: Any = List[str],
name: str = "",
implicit: bool = False,
required: bool = False,
Expand Down
13 changes: 10 additions & 3 deletions blacksheep/server/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def __init__(self, parameter_name, route):


def _check_union(
parameter: inspect.Parameter, annotation: Any, method: Callable[..., Any]
parameter: ParamInfo, annotation: Any, method: Callable[..., Any]
) -> Tuple[bool, Any]:
"""
Checks if the given annotation is Optional[] - in such case unwraps it
Expand Down Expand Up @@ -292,7 +292,7 @@ def _get_bound_value_type(bound_type: Type[BoundValue]) -> Type[Any]:


def _get_parameter_binder(
parameter: inspect.Parameter,
parameter: ParamInfo,
services: ContainerProtocol,
route: Optional[Route],
method: Callable[..., Any],
Expand All @@ -316,6 +316,13 @@ def _get_parameter_binder(
if annotation in Binder.aliases:
return Binder.aliases[annotation](services)

if (
annotation in Binder.handlers
and annotation not in services
and not issubclass(annotation, BoundValue)
):
return Binder.handlers[annotation](annotation, parameter.name)

# 1. is the type annotation of BoundValue[T] type?
if _is_bound_value_annotation(annotation):
binder_type = get_binder_by_type(annotation)
Expand Down Expand Up @@ -377,7 +384,7 @@ def _get_parameter_binder(


def get_parameter_binder(
parameter: inspect.Parameter,
parameter: ParamInfo,
services: ContainerProtocol,
route: Optional[Route],
method: Callable[..., Any],
Expand Down
1 change: 0 additions & 1 deletion blacksheep/server/openapi/docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def parse_docstring(self, docstring: str) -> DocstringInfo:


def type_repr_to_type(type_repr: str) -> Optional[Type]:

array_match = _array_rx.match(type_repr)

if array_match:
Expand Down
60 changes: 58 additions & 2 deletions blacksheep/server/openapi/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass, fields, is_dataclass
from datetime import date, datetime
from enum import Enum, IntEnum
from typing import Any, Dict, List, Mapping, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
from typing import _GenericAlias as GenericAlias
from typing import get_type_hints
from uuid import UUID
Expand Down Expand Up @@ -295,6 +295,9 @@ def __init__(
DataClassTypeHandler(),
PydanticModelTypeHandler(),
]
self._binder_docs: Dict[
Type[Binder], Iterable[Union[Parameter, Reference]]
] = {}

@property
def object_types_handlers(self) -> List[ObjectTypeHandler]:
Expand All @@ -304,6 +307,7 @@ def get_ui_page_title(self) -> str:
return self.info.title

def generate_documentation(self, app: Application) -> OpenAPI:
self._optimize_binders_docs()
return OpenAPI(
info=self.info, paths=self.get_paths(app), components=self.components
)
Expand Down Expand Up @@ -697,12 +701,16 @@ def get_parameters(
if not hasattr(handler, "binders"):
return None
binders: List[Binder] = handler.binders
parameters: Mapping[str, Union[Parameter, Reference]] = {}
parameters: Dict[str, Union[Parameter, Reference]] = {}

docs = self.get_handler_docs(handler)
parameters_info = (docs.parameters if docs else None) or dict()

for binder in binders:
if binder.__class__ in self._binder_docs:
self._handle_binder_docs(binder, parameters)
continue

location = self.get_parameter_location_for_binder(binder)

if not location:
Expand Down Expand Up @@ -971,3 +979,51 @@ def get_routes_docs(

self.events.on_paths_created.fire_sync(paths_doc)
return paths_doc

def set_binder_docs(
self,
binder_type: Type[Binder],
params_docs: Iterable[Union[Parameter, Reference]],
):
"""
Configures parameters documentation for a given binder type. A binder can
read values from one or more input parameters, this is why this method supports
an iterable of Parameter or Reference objects. In most use cases, it is
desirable to use a Parameter here. Reference objects are configured
automatically when the documentation is built.
"""
self._binder_docs[binder_type] = params_docs

def _handle_binder_docs(
self, binder: Binder, parameters: Dict[str, Union[Parameter, Reference]]
):
params_docs = self._binder_docs[binder.__class__]

for i, param_doc in enumerate(params_docs):
parameters[f"{binder.__class__.__qualname__}_{i}"] = param_doc

def _optimize_binders_docs(self):
"""
Optimizes the documentation for custom binders to use references and
components.parameters, instead of duplicating parameters documentation in each
operation where they are used.
"""
new_dict = {}
params_docs: Iterable[Union[Parameter, Reference]]

for key, params_docs in self._binder_docs.items():
new_docs: List[Reference] = []

for param in params_docs:
if isinstance(param, Reference):
new_docs.append(param)
else:
if self.components.parameters is None:
self.components.parameters = {}

self.components.parameters[param.name] = param
new_docs.append(Reference(f"#/components/parameters/{param.name}"))

new_dict[key] = new_docs

self._binder_docs = new_dict
4 changes: 0 additions & 4 deletions blacksheep/server/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def __init__(self, parameter_pattern_name: str, matched_parameter: str) -> None:


class RouteMatch:

__slots__ = ("values", "pattern", "handler")

def __init__(self, route: "Route", values: Optional[Dict[str, bytes]]):
Expand All @@ -98,7 +97,6 @@ def _get_parameter_pattern_fragment(


class Route:

__slots__ = (
"handler",
"pattern",
Expand Down Expand Up @@ -378,7 +376,6 @@ def ws(self, pattern) -> Callable[..., Any]:


class Router(RouterBase):

__slots__ = ("routes", "_map", "_fallback")

def __init__(self):
Expand Down Expand Up @@ -478,7 +475,6 @@ def get_matching_route(self, method: AnyStr, value: AnyStr) -> Optional[Route]:


class RegisteredRoute:

__slots__ = ("method", "pattern", "handler")

def __init__(self, method: str, pattern: str, handler: Callable):
Expand Down
2 changes: 0 additions & 2 deletions itests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ async def test_post_form(session, data):

@pytest.mark.asyncio
async def test_post_multipart_form_with_files(session):

if os.path.exists("out"):
shutil.rmtree("out")

Expand Down Expand Up @@ -192,7 +191,6 @@ async def test_post_multipart_form_with_files(session):

@pytest.mark.asyncio
async def test_post_multipart_form_with_images(session):

if os.path.exists("out"):
shutil.rmtree("out")

Expand Down
Loading

0 comments on commit 3bdad1c

Please sign in to comment.