Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion connexion/middleware/request_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
self._operation.parameters, self._operation.body_definition()
),
)
receive = await validator.wrap_receive(receive, scope=scope)
receive, scope = await validator.wrap_receive(receive, scope=scope)

await self.next_app(scope, receive, send)

Expand Down
18 changes: 11 additions & 7 deletions connexion/validators/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,16 @@ def _validate(self, body: t.Any) -> t.Optional[dict]:
:raises: :class:`connexion.exceptions.BadRequestProblem`
"""

def _insert_body(self, receive: Receive, *, body: t.Any, scope: Scope) -> Receive:
def _insert_body(
self, receive: Receive, *, body: t.Any, scope: Scope
) -> t.Tuple[Receive, Scope]:
"""
Insert messages transmitting the body at the start of the `receive` channel.

This method updates the provided `scope` in place with the right `Content-Length` header.
"""
if body is None:
return receive
return receive, scope

bytes_body = json.dumps(body).encode(self._encoding)

Expand All @@ -93,7 +95,7 @@ def _insert_body(self, receive: Receive, *, body: t.Any, scope: Scope) -> Receiv

receive = self._insert_messages(receive, messages=messages)

return receive
return receive, new_scope

@staticmethod
def _insert_messages(
Expand All @@ -111,7 +113,9 @@ async def receive_() -> t.MutableMapping[str, t.Any]:

return receive_

async def wrap_receive(self, receive: Receive, *, scope: Scope) -> Receive:
async def wrap_receive(
self, receive: Receive, *, scope: Scope
) -> t.Tuple[Receive, Scope]:
"""
Wrap the provided `receive` channel with request body validation.

Expand All @@ -124,7 +128,7 @@ async def wrap_receive(self, receive: Receive, *, scope: Scope) -> Receive:
if body is None and self._required:
raise BadRequestProblem("RequestBody is required")
# The default body is encoded as a `receive` channel to mimic an incoming body
receive = self._insert_body(receive, body=body, scope=scope)
receive, scope = self._insert_body(receive, body=body, scope=scope)

# The receive channel is converted to a stream for convenient access
messages = []
Expand All @@ -146,12 +150,12 @@ async def stream() -> t.AsyncGenerator[bytes, None]:
# If MUTABLE_VALIDATION is enabled, include any changes made during validation in the messages to send
if self.MUTABLE_VALIDATION:
# Include changes made during validation
receive = self._insert_body(receive, body=body, scope=scope)
receive, scope = self._insert_body(receive, body=body, scope=scope)
else:
# Serialize original messages
receive = self._insert_messages(receive, messages=messages)

return receive
return receive, scope


class AbstractResponseBodyValidator:
Expand Down
5 changes: 4 additions & 1 deletion tests/fixtures/json_validation/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ components:
password:
type: string
writeOnly: true
human:
type: boolean
default: true
X:
type: object
properties:
Expand Down Expand Up @@ -144,4 +147,4 @@ paths:
schema:
type: array
items:
$ref: "#/components/schemas/X"
$ref: "#/components/schemas/X"
4 changes: 3 additions & 1 deletion tests/fixtures/json_validation/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ definitions:
password:
type: string
x-writeOnly: true

human:
type: boolean
default: true
paths:
/minlength:
post:
Expand Down
25 changes: 24 additions & 1 deletion tests/test_json_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from connexion import App
from connexion.json_schema import Draft4RequestValidator
from connexion.spec import Specification
from connexion.validators import JSONRequestBodyValidator
from connexion.validators import (
DefaultsJSONRequestBodyValidator,
JSONRequestBodyValidator,
)
from jsonschema.validators import _utils, extend

from conftest import build_app_from_fixture
Expand Down Expand Up @@ -152,3 +155,23 @@ def test_multipart_form_json_array(json_validation_spec_dir, spec, app_class):
assert res.json()[0]["age"] == 30
assert res.json()[1]["name"] == "alena-reply"
assert res.json()[1]["age"] == 38


def test_defaults_body(json_validation_spec_dir, spec):
"""ensure that defaults applied that modify the body"""

class MyDefaultsJSONBodyValidator(DefaultsJSONRequestBodyValidator):
pass

validator_map = {"body": {"application/json": MyDefaultsJSONBodyValidator}}

app = App(__name__, specification_dir=json_validation_spec_dir)
app.add_api(spec, validate_responses=True, validator_map=validator_map)
app_client = app.test_client()

res = app_client.post(
"/v1.0/user",
json={"name": "foo"},
)
assert res.status_code == 200
assert res.json().get("human")
Loading