Python 4542 - Improved sessions API (#2712)

This commit is contained in:
Noah Stapp 2026-03-05 08:04:37 -08:00 committed by GitHub
parent e028fe2a38
commit f533157981
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 364 additions and 14 deletions

View File

@ -1,6 +1,15 @@
Changelog
=========
Changes in Version 4.17.0 (2026/XX/XX)
--------------------------------------
PyMongo 4.17 brings a number of changes including:
- Added the :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.bind` and :meth:`~pymongo.client_session.ClientSession.bind` methods
that allow users to bind a session to all database operations within the scope of a context manager instead of having to explicitly pass the session to each individual operation.
See <PLACEHOLDER> for examples and more information.
Changes in Version 4.16.0 (2026/01/07)
--------------------------------------

View File

@ -139,6 +139,7 @@ import collections
import time
import uuid
from collections.abc import Mapping as _Mapping
from contextvars import ContextVar, Token
from typing import (
TYPE_CHECKING,
Any,
@ -181,6 +182,28 @@ if TYPE_CHECKING:
_IS_SYNC = False
_SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None)
class _AsyncBoundSessionContext:
"""Context manager returned by AsyncClientSession.bind() that manages bound state."""
def __init__(self, session: AsyncClientSession, end_session: bool) -> None:
self._session = session
self._session_token: Optional[Token[AsyncClientSession]] = None
self._end_session = end_session
async def __aenter__(self) -> AsyncClientSession:
self._session_token = _SESSION.set(self._session) # type: ignore[assignment]
return self._session
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._session_token:
_SESSION.reset(self._session_token) # type: ignore[arg-type]
self._session_token = None
if self._end_session:
await self._session.end_session()
class SessionOptions:
"""Options for a new :class:`AsyncClientSession`.
@ -547,6 +570,24 @@ class AsyncClientSession:
if self._server_session is None:
raise InvalidOperation("Cannot use ended session")
def bind(self, end_session: bool = True) -> _AsyncBoundSessionContext:
"""Bind this session so it is implicitly passed to all database operations within the returned context.
.. code-block:: python
async with client.start_session() as s:
async with s.bind():
# session=s is passed implicitly
await client.db.collection.insert_one({"x": 1})
:param end_session: Whether to end the session on exiting the returned context. Defaults to True.
If set to False, :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.end_session()` must be called
once the session is no longer used.
.. versionadded:: 4.17
"""
return _AsyncBoundSessionContext(self, end_session)
async def __aenter__(self) -> AsyncClientSession:
return self

View File

@ -65,7 +65,7 @@ from pymongo import _csot, common, helpers_shared, periodic_executor
from pymongo.asynchronous import client_session, database, uri_parser
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.client_session import _EmptyServerSession
from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology, _ErrorContext
@ -1408,7 +1408,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
def _ensure_session(
self, session: Optional[AsyncClientSession] = None
) -> Optional[AsyncClientSession]:
"""If provided session is None, lend a temporary session."""
"""If provided session and bound session are None, lend a temporary session."""
session = session or self._get_bound_session()
if session:
return session
@ -2267,11 +2268,14 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self, session: Optional[client_session.AsyncClientSession]
) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None]:
"""If provided session is None, lend a temporary session."""
if session is not None:
if not isinstance(session, client_session.AsyncClientSession):
raise ValueError(
f"'session' argument must be an AsyncClientSession or None, not {type(session)}"
)
if session is not None and not isinstance(session, client_session.AsyncClientSession):
raise ValueError(
f"'session' argument must be an AsyncClientSession or None, not {type(session)}"
)
# Check for a bound session. If one exists, treat it as an explicitly passed session.
session = session or self._get_bound_session()
if session:
# Don't call end_session.
yield session
return
@ -2301,6 +2305,18 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
if session is not None:
session._process_response(reply)
def _get_bound_session(self) -> Optional[AsyncClientSession]:
bound_session = _SESSION.get()
if bound_session:
if bound_session.client is self:
return bound_session
else:
raise InvalidOperation(
"Only the client that created the bound session can perform operations within its context block. See <PLACEHOLDER> for more information."
)
else:
return None
async def server_info(
self, session: Optional[client_session.AsyncClientSession] = None
) -> dict[str, Any]:

View File

@ -139,6 +139,7 @@ import collections
import time
import uuid
from collections.abc import Mapping as _Mapping
from contextvars import ContextVar, Token
from typing import (
TYPE_CHECKING,
Any,
@ -180,6 +181,28 @@ if TYPE_CHECKING:
_IS_SYNC = True
_SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None)
class _BoundSessionContext:
"""Context manager returned by ClientSession.bind() that manages bound state."""
def __init__(self, session: ClientSession, end_session: bool) -> None:
self._session = session
self._session_token: Optional[Token[ClientSession]] = None
self._end_session = end_session
def __enter__(self) -> ClientSession:
self._session_token = _SESSION.set(self._session) # type: ignore[assignment]
return self._session
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._session_token:
_SESSION.reset(self._session_token) # type: ignore[arg-type]
self._session_token = None
if self._end_session:
self._session.end_session()
class SessionOptions:
"""Options for a new :class:`ClientSession`.
@ -546,6 +569,24 @@ class ClientSession:
if self._server_session is None:
raise InvalidOperation("Cannot use ended session")
def bind(self, end_session: bool = True) -> _BoundSessionContext:
"""Bind this session so it is implicitly passed to all database operations within the returned context.
.. code-block:: python
with client.start_session() as s:
with s.bind():
# session=s is passed implicitly
client.db.collection.insert_one({"x": 1})
:param end_session: Whether to end the session on exiting the returned context. Defaults to True.
If set to False, :meth:`~pymongo.client_session.ClientSession.end_session()` must be called
once the session is no longer used.
.. versionadded:: 4.17
"""
return _BoundSessionContext(self, end_session)
def __enter__(self) -> ClientSession:
return self

View File

@ -108,7 +108,7 @@ from pymongo.server_type import SERVER_TYPE
from pymongo.synchronous import client_session, database, uri_parser
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
from pymongo.synchronous.client_bulk import _ClientBulk
from pymongo.synchronous.client_session import _EmptyServerSession
from pymongo.synchronous.client_session import _SESSION, _EmptyServerSession
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.settings import TopologySettings
from pymongo.synchronous.topology import Topology, _ErrorContext
@ -1406,7 +1406,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
)
def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]:
"""If provided session is None, lend a temporary session."""
"""If provided session and bound session are None, lend a temporary session."""
session = session or self._get_bound_session()
if session:
return session
@ -2263,11 +2264,14 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
self, session: Optional[client_session.ClientSession]
) -> Generator[Optional[client_session.ClientSession], None]:
"""If provided session is None, lend a temporary session."""
if session is not None:
if not isinstance(session, client_session.ClientSession):
raise ValueError(
f"'session' argument must be a ClientSession or None, not {type(session)}"
)
if session is not None and not isinstance(session, client_session.ClientSession):
raise ValueError(
f"'session' argument must be a ClientSession or None, not {type(session)}"
)
# Check for a bound session. If one exists, treat it as an explicitly passed session.
session = session or self._get_bound_session()
if session:
# Don't call end_session.
yield session
return
@ -2295,6 +2299,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
if session is not None:
session._process_response(reply)
def _get_bound_session(self) -> Optional[ClientSession]:
bound_session = _SESSION.get()
if bound_session:
if bound_session.client is self:
return bound_session
else:
raise InvalidOperation(
"Only the client that created the bound session can perform operations within its context block. See <PLACEHOLDER> for more information."
)
else:
return None
def server_info(self, session: Optional[client_session.ClientSession] = None) -> dict[str, Any]:
"""Get information about the MongoDB server we're connected to.

View File

@ -189,6 +189,52 @@ class TestSession(AsyncIntegrationTest):
f"{f.__name__} did not return implicit session to pool",
)
# Explicit bound session
for f, args, kw in ops:
async with client.start_session() as s:
async with s.bind():
listener.reset()
s._materialize()
last_use = s._server_session.last_use
start = time.monotonic()
self.assertLessEqual(last_use, start)
# In case "f" modifies its inputs.
args = copy.copy(args)
kw = copy.copy(kw)
await f(*args, **kw)
self.assertGreaterEqual(len(listener.started_events), 1)
for event in listener.started_events:
self.assertIn(
"lsid",
event.command,
f"{f.__name__} sent no lsid with {event.command_name}",
)
self.assertEqual(
s.session_id,
event.command["lsid"],
f"{f.__name__} sent wrong lsid with {event.command_name}",
)
self.assertFalse(s.has_ended)
self.assertTrue(s.has_ended)
with self.assertRaisesRegex(InvalidOperation, "ended session"):
async with s.bind():
await f(*args, **kw)
# Test a session cannot be used on another client.
async with self.client2.start_session() as s:
async with s.bind():
# In case "f" modifies its inputs.
args = copy.copy(args)
kw = copy.copy(kw)
with self.assertRaisesRegex(
InvalidOperation,
"Only the client that created the bound session can perform operations within its context block",
):
await f(*args, **kw)
async def test_implicit_sessions_checkout(self):
# "To confirm that implicit sessions only allocate their server session after a
# successful connection checkout" test from Driver Sessions Spec.
@ -825,6 +871,73 @@ class TestSession(AsyncIntegrationTest):
async with client.start_session() as s:
self.assertRaises(TypeError, lambda: copy.copy(s))
async def test_nested_session_binding(self):
coll = self.client.pymongo_test.test
await coll.insert_one({"x": 1})
session1 = self.client.start_session()
session2 = self.client.start_session()
session1._materialize()
session2._materialize()
try:
self.listener.reset()
# Uses implicit session
await coll.find_one()
implicit_lsid = self.listener.started_events[0].command.get("lsid")
self.assertIsNotNone(implicit_lsid)
self.assertNotEqual(implicit_lsid, session1.session_id)
self.assertNotEqual(implicit_lsid, session2.session_id)
async with session1.bind(end_session=False):
self.listener.reset()
# Uses bound session1
await coll.find_one()
session1_lsid = self.listener.started_events[0].command.get("lsid")
self.assertEqual(session1_lsid, session1.session_id)
async with session2.bind(end_session=False):
self.listener.reset()
# Uses bound session2
await coll.find_one()
session2_lsid = self.listener.started_events[0].command.get("lsid")
self.assertEqual(session2_lsid, session2.session_id)
self.assertNotEqual(session2_lsid, session1.session_id)
self.listener.reset()
# Use bound session1 again
await coll.find_one()
session1_lsid = self.listener.started_events[0].command.get("lsid")
self.assertEqual(session1_lsid, session1.session_id)
self.assertNotEqual(session1_lsid, session2.session_id)
self.listener.reset()
# Uses implicit session
await coll.find_one()
implicit_lsid = self.listener.started_events[0].command.get("lsid")
self.assertIsNotNone(implicit_lsid)
self.assertNotEqual(implicit_lsid, session1.session_id)
self.assertNotEqual(implicit_lsid, session2.session_id)
finally:
await session1.end_session()
await session2.end_session()
async def test_session_binding_end_session(self):
coll = self.client.pymongo_test.test
await coll.insert_one({"x": 1})
async with self.client.start_session().bind() as s1:
await coll.find_one()
self.assertTrue(s1.has_ended)
async with self.client.start_session().bind(end_session=False) as s2:
await coll.find_one()
self.assertFalse(s2.has_ended)
await s2.end_session()
class TestCausalConsistency(AsyncUnitTest):
listener: SessionTestListener

View File

@ -189,6 +189,52 @@ class TestSession(IntegrationTest):
f"{f.__name__} did not return implicit session to pool",
)
# Explicit bound session
for f, args, kw in ops:
with client.start_session() as s:
with s.bind():
listener.reset()
s._materialize()
last_use = s._server_session.last_use
start = time.monotonic()
self.assertLessEqual(last_use, start)
# In case "f" modifies its inputs.
args = copy.copy(args)
kw = copy.copy(kw)
f(*args, **kw)
self.assertGreaterEqual(len(listener.started_events), 1)
for event in listener.started_events:
self.assertIn(
"lsid",
event.command,
f"{f.__name__} sent no lsid with {event.command_name}",
)
self.assertEqual(
s.session_id,
event.command["lsid"],
f"{f.__name__} sent wrong lsid with {event.command_name}",
)
self.assertFalse(s.has_ended)
self.assertTrue(s.has_ended)
with self.assertRaisesRegex(InvalidOperation, "ended session"):
with s.bind():
f(*args, **kw)
# Test a session cannot be used on another client.
with self.client2.start_session() as s:
with s.bind():
# In case "f" modifies its inputs.
args = copy.copy(args)
kw = copy.copy(kw)
with self.assertRaisesRegex(
InvalidOperation,
"Only the client that created the bound session can perform operations within its context block",
):
f(*args, **kw)
def test_implicit_sessions_checkout(self):
# "To confirm that implicit sessions only allocate their server session after a
# successful connection checkout" test from Driver Sessions Spec.
@ -825,6 +871,73 @@ class TestSession(IntegrationTest):
with client.start_session() as s:
self.assertRaises(TypeError, lambda: copy.copy(s))
def test_nested_session_binding(self):
coll = self.client.pymongo_test.test
coll.insert_one({"x": 1})
session1 = self.client.start_session()
session2 = self.client.start_session()
session1._materialize()
session2._materialize()
try:
self.listener.reset()
# Uses implicit session
coll.find_one()
implicit_lsid = self.listener.started_events[0].command.get("lsid")
self.assertIsNotNone(implicit_lsid)
self.assertNotEqual(implicit_lsid, session1.session_id)
self.assertNotEqual(implicit_lsid, session2.session_id)
with session1.bind(end_session=False):
self.listener.reset()
# Uses bound session1
coll.find_one()
session1_lsid = self.listener.started_events[0].command.get("lsid")
self.assertEqual(session1_lsid, session1.session_id)
with session2.bind(end_session=False):
self.listener.reset()
# Uses bound session2
coll.find_one()
session2_lsid = self.listener.started_events[0].command.get("lsid")
self.assertEqual(session2_lsid, session2.session_id)
self.assertNotEqual(session2_lsid, session1.session_id)
self.listener.reset()
# Use bound session1 again
coll.find_one()
session1_lsid = self.listener.started_events[0].command.get("lsid")
self.assertEqual(session1_lsid, session1.session_id)
self.assertNotEqual(session1_lsid, session2.session_id)
self.listener.reset()
# Uses implicit session
coll.find_one()
implicit_lsid = self.listener.started_events[0].command.get("lsid")
self.assertIsNotNone(implicit_lsid)
self.assertNotEqual(implicit_lsid, session1.session_id)
self.assertNotEqual(implicit_lsid, session2.session_id)
finally:
session1.end_session()
session2.end_session()
def test_session_binding_end_session(self):
coll = self.client.pymongo_test.test
coll.insert_one({"x": 1})
with self.client.start_session().bind() as s1:
coll.find_one()
self.assertTrue(s1.has_ended)
with self.client.start_session().bind(end_session=False) as s2:
coll.find_one()
self.assertFalse(s2.has_ended)
s2.end_session()
class TestCausalConsistency(UnitTest):
listener: SessionTestListener

View File

@ -37,6 +37,7 @@ replacements = {
"AsyncRawBatchCursor": "RawBatchCursor",
"AsyncRawBatchCommandCursor": "RawBatchCommandCursor",
"AsyncClientSession": "ClientSession",
"_AsyncBoundSessionContext": "_BoundSessionContext",
"AsyncChangeStream": "ChangeStream",
"AsyncCollectionChangeStream": "CollectionChangeStream",
"AsyncDatabaseChangeStream": "DatabaseChangeStream",