Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unclosed generator on trio #657

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,8 @@ def __init__(
self._pool = pool
self._status = status

async def __aiter__(self) -> AsyncIterator[bytes]:
async for part in self._stream:
yield part
def __aiter__(self) -> AsyncIterator[bytes]:
return self._stream.__aiter__()

async def aclose(self) -> None:
try:
Expand Down
32 changes: 24 additions & 8 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import enum
import time
from contextlib import AsyncExitStack
from types import TracebackType
from typing import (
AsyncGenerator,
AsyncIterable,
AsyncIterator,
List,
Expand Down Expand Up @@ -173,7 +175,9 @@ async def _receive_response_headers(

return http_version, event.status_code, event.reason, headers

async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]:
async def _receive_response_body(
self, request: Request
) -> AsyncGenerator[bytes, None]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)

Expand Down Expand Up @@ -304,22 +308,34 @@ def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None:
self._connection = connection
self._request = request
self._closed = False

async def __aiter__(self) -> AsyncIterator[bytes]:
kwargs = {"request": self._request}
self._stream = self._connection._receive_response_body(**kwargs)
self._trace = Trace("http11.receive_response_body", request, kwargs)

def __aiter__(self) -> AsyncIterator[bytes]:
return self

async def __anext__(self) -> bytes:
if not hasattr(self, "_trace_exit_stack"):
self._trace_exit_stack = AsyncExitStack()
await self._trace_exit_stack.enter_async_context(self._trace)

try:
async with Trace("http11.receive_response_body", self._request, kwargs):
async for chunk in self._connection._receive_response_body(**kwargs):
yield chunk
except BaseException as exc:
return await self._stream.__anext__()
except BaseException:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
await self._stream.aclose()
await self.aclose()
raise exc
raise

async def aclose(self) -> None:
if hasattr(self, "_trace_exit_stack"):
await self._trace_exit_stack.aclose()

if not self._closed:
await self._stream.aclose()
self._closed = True
async with Trace("http11.response_closed", self._request):
await self._connection._response_closed()
3 changes: 1 addition & 2 deletions httpcore/_sync/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,7 @@ def __init__(
self._status = status

def __iter__(self) -> Iterator[bytes]:
for part in self._stream:
yield part
return self._stream.__iter__()

def close(self) -> None:
try:
Expand Down
30 changes: 23 additions & 7 deletions httpcore/_sync/http11.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import enum
import time
from contextlib import ExitStack
from types import TracebackType
from typing import (
Generator,
Iterable,
Iterator,
List,
Expand Down Expand Up @@ -173,7 +175,9 @@ def _receive_response_headers(

return http_version, event.status_code, event.reason, headers

def _receive_response_body(self, request: Request) -> Iterator[bytes]:
def _receive_response_body(
self, request: Request
) -> Generator[bytes, None]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)

Expand Down Expand Up @@ -304,22 +308,34 @@ def __init__(self, connection: HTTP11Connection, request: Request) -> None:
self._connection = connection
self._request = request
self._closed = False
kwargs = {"request": self._request}
self._stream = self._connection._receive_response_body(**kwargs)
self._trace = Trace("http11.receive_response_body", request, kwargs)

def __iter__(self) -> Iterator[bytes]:
kwargs = {"request": self._request}
return self

def __anext__(self) -> bytes:
if not hasattr(self, "_trace_exit_stack"):
self._trace_exit_stack = ExitStack()
self._trace_exit_stack.enter_async_context(self._trace)

try:
with Trace("http11.receive_response_body", self._request, kwargs):
for chunk in self._connection._receive_response_body(**kwargs):
yield chunk
except BaseException as exc:
return self._stream.__anext__()
except BaseException:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
self._stream.close()
self.close()
raise exc
raise

def close(self) -> None:
if hasattr(self, "_trace_exit_stack"):
self._trace_exit_stack.close()

if not self._closed:
self._stream.close()
self._closed = True
with Trace("http11.response_closed", self._request):
self._connection._response_closed()
1 change: 1 addition & 0 deletions unasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
('import trio as concurrency', 'from tests import concurrency'),
('AsyncByteStream', 'SyncByteStream'),
('AsyncIterator', 'Iterator'),
(r'AsyncGenerator\[bytes, None\]', r'Generator\[bytes, None, None\]'),
('AutoBackend', 'SyncBackend'),
('Async([A-Z][A-Za-z0-9_]*)', r'\2'),
('async def', 'def'),
Expand Down