diff --git a/httpx/_transports/wsgi.py b/httpx/_transports/wsgi.py index 58e8309d..e71f1604 100644 --- a/httpx/_transports/wsgi.py +++ b/httpx/_transports/wsgi.py @@ -16,12 +16,17 @@ def _skip_leading_empty_chunks(body: typing.Iterable) -> typing.Iterable: class WSGIByteStream(SyncByteStream): def __init__(self, result: typing.Iterable[bytes]) -> None: + self._close = getattr(result, "close", None) self._result = _skip_leading_empty_chunks(result) def __iter__(self) -> typing.Iterator[bytes]: for part in self._result: yield part + def close(self) -> None: + if self._close is not None: + self._close() + class WSGITransport(BaseTransport): """ diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py index b130e53c..164899b5 100644 --- a/tests/test_wsgi.py +++ b/tests/test_wsgi.py @@ -1,4 +1,5 @@ import sys +import wsgiref.validate from functools import partial import pytest @@ -19,7 +20,7 @@ def application_factory(output): for item in output: yield item - return application + return wsgiref.validate.validator(application) def echo_body(environ, start_response):