diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 95c085585..fe2dd0725 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -139,7 +139,7 @@ import collections import time import uuid from collections.abc import Mapping as _Mapping -from contextvars import ContextVar +from contextvars import ContextVar, Token from typing import ( TYPE_CHECKING, Any, @@ -154,8 +154,6 @@ from typing import ( TypeVar, ) -from _contextvars import Token - from bson.binary import Binary from bson.int64 import Int64 from bson.timestamp import Timestamp @@ -193,6 +191,24 @@ class _AsyncBoundClientSession: self.client_id = client_id +class AsyncBoundSessionContext: + """Context manager returned by AsyncClientSession.bind() that manages bound state.""" + + def __init__(self, session: AsyncClientSession) -> None: + self._session = session + self._session_token: Optional[Token[_AsyncBoundClientSession]] = None + + async def __aenter__(self) -> AsyncClientSession: + bound_session = _AsyncBoundClientSession(self._session, id(self._session._client)) + self._session_token = _SESSION.set(bound_session) # type: ignore[assignment] + return self._session + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._session_token: + _SESSION.reset(self._session_token) # type: ignore[arg-type] + self._session_token = None + + class SessionOptions: """Options for a new :class:`AsyncClientSession`. @@ -528,9 +544,6 @@ class AsyncClientSession: self._attached_to_cursor = False # Should we leave the session alive when the cursor is closed? self._leave_alive = False - # Is this session bound to a context manager scope? - self._bound = False - self._session_token: Optional[Token[_AsyncBoundClientSession]] = None async def end_session(self) -> None: """Finish this session. If a transaction has started, abort it. @@ -561,23 +574,18 @@ class AsyncClientSession: if self._server_session is None: raise InvalidOperation("Cannot use ended session") - def bind(self) -> AsyncClientSession: - self._bound = True - return self + def bind(self) -> AsyncBoundSessionContext: + """Bind this session so it is implicitly passed to all database operations within the returned context. + + .. versionadded:: 4.17 + """ + return AsyncBoundSessionContext(self) async def __aenter__(self) -> AsyncClientSession: - if self._bound: - bound_session = _AsyncBoundClientSession(self, id(self._client)) - self._session_token = _SESSION.set(bound_session) # type: ignore[assignment] return self async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - if self._session_token: - _SESSION.reset(self._session_token) # type: ignore[arg-type] - self._session_token = None - self._bound = False - else: - await self._end_session(lock=True) + await self._end_session(lock=True) @property def client(self) -> AsyncMongoClient[Any]: diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 85ff79f99..7dbfb7fe9 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -139,7 +139,7 @@ import collections import time import uuid from collections.abc import Mapping as _Mapping -from contextvars import ContextVar +from contextvars import ContextVar, Token from typing import ( TYPE_CHECKING, Any, @@ -153,8 +153,6 @@ from typing import ( TypeVar, ) -from _contextvars import Token - from bson.binary import Binary from bson.int64 import Int64 from bson.timestamp import Timestamp @@ -192,6 +190,24 @@ class _BoundClientSession: self.client_id = client_id +class BoundSessionContext: + """Context manager returned by ClientSession.bind() that manages bound state.""" + + def __init__(self, session: ClientSession) -> None: + self._session = session + self._session_token: Optional[Token[_BoundClientSession]] = None + + def __enter__(self) -> ClientSession: + bound_session = _BoundClientSession(self._session, id(self._session._client)) + self._session_token = _SESSION.set(bound_session) # type: ignore[assignment] + return self._session + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._session_token: + _SESSION.reset(self._session_token) # type: ignore[arg-type] + self._session_token = None + + class SessionOptions: """Options for a new :class:`ClientSession`. @@ -527,9 +543,6 @@ class ClientSession: self._attached_to_cursor = False # Should we leave the session alive when the cursor is closed? self._leave_alive = False - # Is this session bound to a context manager scope? - self._bound = False - self._session_token: Optional[Token[_BoundClientSession]] = None def end_session(self) -> None: """Finish this session. If a transaction has started, abort it. @@ -560,23 +573,18 @@ class ClientSession: if self._server_session is None: raise InvalidOperation("Cannot use ended session") - def bind(self) -> ClientSession: - self._bound = True - return self + def bind(self) -> BoundSessionContext: + """Bind this session so it is implicitly passed to all database operations within the returned context. + + .. versionadded:: 4.17 + """ + return BoundSessionContext(self) def __enter__(self) -> ClientSession: - if self._bound: - bound_session = _BoundClientSession(self, id(self._client)) - self._session_token = _SESSION.set(bound_session) # type: ignore[assignment] return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - if self._session_token: - _SESSION.reset(self._session_token) # type: ignore[arg-type] - self._session_token = None - self._bound = False - else: - self._end_session(lock=True) + self._end_session(lock=True) @property def client(self) -> MongoClient[Any]: diff --git a/tools/synchro.py b/tools/synchro.py index d1056a41e..18fb852cc 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -38,6 +38,7 @@ replacements = { "AsyncRawBatchCommandCursor": "RawBatchCommandCursor", "AsyncClientSession": "ClientSession", "_AsyncBoundClientSession": "_BoundClientSession", + "AsyncBoundSessionContext": "BoundSessionContext", "AsyncChangeStream": "ChangeStream", "AsyncCollectionChangeStream": "CollectionChangeStream", "AsyncDatabaseChangeStream": "DatabaseChangeStream",