Revert empty context for ASGI runs (#2911)
This commit is contained in:
parent
850d92656d
commit
45c0b568d3
@ -131,7 +131,7 @@ class MockLoop:
|
||||
self._tasks: list[asyncio.Task[Any]] = []
|
||||
self._later: list[MockTimerHandle] = []
|
||||
|
||||
def create_task(self, coroutine: Any, **kwargs: Any) -> Any:
|
||||
def create_task(self, coroutine: Any) -> Any:
|
||||
self._tasks.insert(0, coroutine)
|
||||
return MockTask()
|
||||
|
||||
|
||||
@ -226,7 +226,7 @@ class MockLoop:
|
||||
self._tasks: list[asyncio.Task[Any]] = []
|
||||
self._later: list[MockTimerHandle] = []
|
||||
|
||||
def create_task(self, coroutine: Any, **kwargs: Any) -> Any:
|
||||
def create_task(self, coroutine: Any) -> Any:
|
||||
self._tasks.insert(0, coroutine)
|
||||
return MockTask()
|
||||
|
||||
|
||||
@ -2,8 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
@ -13,14 +11,12 @@ from contextlib import AbstractContextManager
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from tests.protocols.test_http import SIMPLE_GET_REQUEST
|
||||
from tests.utils import run_server
|
||||
from uvicorn import Server
|
||||
from uvicorn._types import ASGIApplication, ASGIReceiveCallable, ASGISendCallable, Scope
|
||||
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.protocols.http.flow_control import HIGH_WATER_LIMIT
|
||||
from uvicorn.protocols.http.h11_impl import H11Protocol
|
||||
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
|
||||
from uvicorn.server import Server
|
||||
|
||||
pytestmark = pytest.mark.anyio
|
||||
|
||||
@ -154,81 +150,3 @@ async def test_limit_max_requests_jitter(
|
||||
await client.get(f"http://127.0.0.1:{unused_tcp_port}")
|
||||
await task
|
||||
assert f"Maximum request limit of {limit} exceeded. Terminating process." in caplog.text
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def server(*, app: ASGIApplication, port: int, http_protocol_cls: type[H11Protocol | HttpToolsProtocol]):
|
||||
config = Config(app=app, port=port, loop="asyncio", http=http_protocol_cls)
|
||||
server = Server(config=config)
|
||||
task = asyncio.create_task(server.serve())
|
||||
|
||||
while not server.started:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
reader, writer = await asyncio.open_connection("127.0.0.1", port)
|
||||
|
||||
async def extract_json_body(request: bytes):
|
||||
writer.write(request)
|
||||
await writer.drain()
|
||||
|
||||
status, *headers = (await reader.readuntil(b"\r\n\r\n")).split(b"\r\n")[:-2]
|
||||
assert status == b"HTTP/1.1 200 OK"
|
||||
|
||||
content_length = next(int(h.split(b":", 1)[1]) for h in headers if h.lower().startswith(b"content-length:"))
|
||||
return json.loads(await reader.readexactly(content_length))
|
||||
|
||||
try:
|
||||
yield extract_json_body
|
||||
finally:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
server.should_exit = True
|
||||
await task
|
||||
|
||||
|
||||
async def test_no_contextvars_pollution_asyncio(
|
||||
http_protocol_cls: type[H11Protocol | HttpToolsProtocol], unused_tcp_port: int
|
||||
):
|
||||
"""Non-regression test for https://github.com/encode/uvicorn/issues/2167."""
|
||||
default_contextvars = {c.name for c in contextvars.copy_context().keys()}
|
||||
ctx: contextvars.ContextVar[str] = contextvars.ContextVar("ctx")
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
assert scope["type"] == "http"
|
||||
|
||||
# initial context should be empty
|
||||
initial_context = {
|
||||
n: v for c, v in contextvars.copy_context().items() if (n := c.name) not in default_contextvars
|
||||
}
|
||||
# set any contextvar before the body is read
|
||||
ctx.set(scope["path"])
|
||||
|
||||
while True:
|
||||
message = await receive()
|
||||
assert message["type"] == "http.request"
|
||||
if not message["more_body"]:
|
||||
break
|
||||
|
||||
# return the initial context for empty assertion
|
||||
body = json.dumps(initial_context).encode("utf-8")
|
||||
headers = [(b"content-type", b"application/json"), (b"content-length", str(len(body)).encode("utf-8"))]
|
||||
await send({"type": "http.response.start", "status": 200, "headers": headers})
|
||||
await send({"type": "http.response.body", "body": body})
|
||||
|
||||
# body has to be larger than HIGH_WATER_LIMIT to trigger a reading pause on the main thread
|
||||
# and a resumption inside the ASGI task
|
||||
large_body = b"a" * (HIGH_WATER_LIMIT + 1)
|
||||
large_request = b"\r\n".join(
|
||||
[
|
||||
b"POST /large-body HTTP/1.1",
|
||||
b"Host: example.org",
|
||||
b"Content-Type: application/octet-stream",
|
||||
f"Content-Length: {len(large_body)}".encode(),
|
||||
b"",
|
||||
large_body,
|
||||
]
|
||||
)
|
||||
|
||||
async with server(app=app, http_protocol_cls=http_protocol_cls, port=unused_tcp_port) as extract_json_body:
|
||||
assert await extract_json_body(large_request) == {}
|
||||
assert await extract_json_body(SIMPLE_GET_REQUEST) == {}
|
||||
|
||||
@ -1,10 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import http
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import unquote
|
||||
@ -250,13 +248,7 @@ class H11Protocol(asyncio.Protocol):
|
||||
message_event=asyncio.Event(),
|
||||
on_response=self.on_response_complete,
|
||||
)
|
||||
# For the asyncio loop, we need to explicitly start with an empty context
|
||||
# as it can be polluted from previous ASGI runs.
|
||||
# See https://github.com/python/cpython/issues/140947 for details.
|
||||
if sys.version_info >= (3, 11): # pragma: py-lt-311
|
||||
task = self.loop.create_task(self.cycle.run_asgi(app), context=contextvars.Context())
|
||||
else: # pragma: py-gte-311
|
||||
task = contextvars.Context().run(self.loop.create_task, self.cycle.run_asgi(app))
|
||||
task = self.loop.create_task(self.cycle.run_asgi(app))
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
self.tasks.add(task)
|
||||
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import http
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
import urllib
|
||||
from asyncio.events import TimerHandle
|
||||
from collections import deque
|
||||
@ -289,13 +287,7 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
)
|
||||
if existing_cycle is None or existing_cycle.response_complete:
|
||||
# Standard case - start processing the request.
|
||||
# For the asyncio loop, we need to explicitly start with an empty context
|
||||
# as it can be polluted from previous ASGI runs.
|
||||
# See https://github.com/python/cpython/issues/140947 for details.
|
||||
if sys.version_info >= (3, 11): # pragma: py-lt-311
|
||||
task = self.loop.create_task(self.cycle.run_asgi(app), context=contextvars.Context())
|
||||
else: # pragma: py-gte-311
|
||||
task = contextvars.Context().run(self.loop.create_task, self.cycle.run_asgi(app))
|
||||
task = self.loop.create_task(self.cycle.run_asgi(app))
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
self.tasks.add(task)
|
||||
else:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user