Python 4542 - Improved sessions API (#2712)
This commit is contained in:
parent
e028fe2a38
commit
f533157981
@ -1,6 +1,15 @@
|
||||
Changelog
|
||||
=========
|
||||
|
||||
Changes in Version 4.17.0 (2026/XX/XX)
|
||||
--------------------------------------
|
||||
|
||||
PyMongo 4.17 brings a number of changes including:
|
||||
|
||||
- Added the :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.bind` and :meth:`~pymongo.client_session.ClientSession.bind` methods
|
||||
that allow users to bind a session to all database operations within the scope of a context manager instead of having to explicitly pass the session to each individual operation.
|
||||
See <PLACEHOLDER> for examples and more information.
|
||||
|
||||
Changes in Version 4.16.0 (2026/01/07)
|
||||
--------------------------------------
|
||||
|
||||
|
||||
@ -139,6 +139,7 @@ import collections
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Mapping as _Mapping
|
||||
from contextvars import ContextVar, Token
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -181,6 +182,28 @@ if TYPE_CHECKING:
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
_SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None)
|
||||
|
||||
|
||||
class _AsyncBoundSessionContext:
|
||||
"""Context manager returned by AsyncClientSession.bind() that manages bound state."""
|
||||
|
||||
def __init__(self, session: AsyncClientSession, end_session: bool) -> None:
|
||||
self._session = session
|
||||
self._session_token: Optional[Token[AsyncClientSession]] = None
|
||||
self._end_session = end_session
|
||||
|
||||
async def __aenter__(self) -> AsyncClientSession:
|
||||
self._session_token = _SESSION.set(self._session) # type: ignore[assignment]
|
||||
return self._session
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
if self._session_token:
|
||||
_SESSION.reset(self._session_token) # type: ignore[arg-type]
|
||||
self._session_token = None
|
||||
if self._end_session:
|
||||
await self._session.end_session()
|
||||
|
||||
|
||||
class SessionOptions:
|
||||
"""Options for a new :class:`AsyncClientSession`.
|
||||
@ -547,6 +570,24 @@ class AsyncClientSession:
|
||||
if self._server_session is None:
|
||||
raise InvalidOperation("Cannot use ended session")
|
||||
|
||||
def bind(self, end_session: bool = True) -> _AsyncBoundSessionContext:
|
||||
"""Bind this session so it is implicitly passed to all database operations within the returned context.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
async with client.start_session() as s:
|
||||
async with s.bind():
|
||||
# session=s is passed implicitly
|
||||
await client.db.collection.insert_one({"x": 1})
|
||||
|
||||
:param end_session: Whether to end the session on exiting the returned context. Defaults to True.
|
||||
If set to False, :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.end_session()` must be called
|
||||
once the session is no longer used.
|
||||
|
||||
.. versionadded:: 4.17
|
||||
"""
|
||||
return _AsyncBoundSessionContext(self, end_session)
|
||||
|
||||
async def __aenter__(self) -> AsyncClientSession:
|
||||
return self
|
||||
|
||||
|
||||
@ -65,7 +65,7 @@ from pymongo import _csot, common, helpers_shared, periodic_executor
|
||||
from pymongo.asynchronous import client_session, database, uri_parser
|
||||
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
|
||||
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
|
||||
from pymongo.asynchronous.client_session import _EmptyServerSession
|
||||
from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
from pymongo.asynchronous.settings import TopologySettings
|
||||
from pymongo.asynchronous.topology import Topology, _ErrorContext
|
||||
@ -1408,7 +1408,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
def _ensure_session(
|
||||
self, session: Optional[AsyncClientSession] = None
|
||||
) -> Optional[AsyncClientSession]:
|
||||
"""If provided session is None, lend a temporary session."""
|
||||
"""If provided session and bound session are None, lend a temporary session."""
|
||||
session = session or self._get_bound_session()
|
||||
if session:
|
||||
return session
|
||||
|
||||
@ -2267,11 +2268,14 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self, session: Optional[client_session.AsyncClientSession]
|
||||
) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None]:
|
||||
"""If provided session is None, lend a temporary session."""
|
||||
if session is not None:
|
||||
if not isinstance(session, client_session.AsyncClientSession):
|
||||
raise ValueError(
|
||||
f"'session' argument must be an AsyncClientSession or None, not {type(session)}"
|
||||
)
|
||||
if session is not None and not isinstance(session, client_session.AsyncClientSession):
|
||||
raise ValueError(
|
||||
f"'session' argument must be an AsyncClientSession or None, not {type(session)}"
|
||||
)
|
||||
|
||||
# Check for a bound session. If one exists, treat it as an explicitly passed session.
|
||||
session = session or self._get_bound_session()
|
||||
if session:
|
||||
# Don't call end_session.
|
||||
yield session
|
||||
return
|
||||
@ -2301,6 +2305,18 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
if session is not None:
|
||||
session._process_response(reply)
|
||||
|
||||
def _get_bound_session(self) -> Optional[AsyncClientSession]:
|
||||
bound_session = _SESSION.get()
|
||||
if bound_session:
|
||||
if bound_session.client is self:
|
||||
return bound_session
|
||||
else:
|
||||
raise InvalidOperation(
|
||||
"Only the client that created the bound session can perform operations within its context block. See <PLACEHOLDER> for more information."
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def server_info(
|
||||
self, session: Optional[client_session.AsyncClientSession] = None
|
||||
) -> dict[str, Any]:
|
||||
|
||||
@ -139,6 +139,7 @@ import collections
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Mapping as _Mapping
|
||||
from contextvars import ContextVar, Token
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -180,6 +181,28 @@ if TYPE_CHECKING:
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
_SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None)
|
||||
|
||||
|
||||
class _BoundSessionContext:
|
||||
"""Context manager returned by ClientSession.bind() that manages bound state."""
|
||||
|
||||
def __init__(self, session: ClientSession, end_session: bool) -> None:
|
||||
self._session = session
|
||||
self._session_token: Optional[Token[ClientSession]] = None
|
||||
self._end_session = end_session
|
||||
|
||||
def __enter__(self) -> ClientSession:
|
||||
self._session_token = _SESSION.set(self._session) # type: ignore[assignment]
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
if self._session_token:
|
||||
_SESSION.reset(self._session_token) # type: ignore[arg-type]
|
||||
self._session_token = None
|
||||
if self._end_session:
|
||||
self._session.end_session()
|
||||
|
||||
|
||||
class SessionOptions:
|
||||
"""Options for a new :class:`ClientSession`.
|
||||
@ -546,6 +569,24 @@ class ClientSession:
|
||||
if self._server_session is None:
|
||||
raise InvalidOperation("Cannot use ended session")
|
||||
|
||||
def bind(self, end_session: bool = True) -> _BoundSessionContext:
|
||||
"""Bind this session so it is implicitly passed to all database operations within the returned context.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
with client.start_session() as s:
|
||||
with s.bind():
|
||||
# session=s is passed implicitly
|
||||
client.db.collection.insert_one({"x": 1})
|
||||
|
||||
:param end_session: Whether to end the session on exiting the returned context. Defaults to True.
|
||||
If set to False, :meth:`~pymongo.client_session.ClientSession.end_session()` must be called
|
||||
once the session is no longer used.
|
||||
|
||||
.. versionadded:: 4.17
|
||||
"""
|
||||
return _BoundSessionContext(self, end_session)
|
||||
|
||||
def __enter__(self) -> ClientSession:
|
||||
return self
|
||||
|
||||
|
||||
@ -108,7 +108,7 @@ from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.synchronous import client_session, database, uri_parser
|
||||
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
|
||||
from pymongo.synchronous.client_bulk import _ClientBulk
|
||||
from pymongo.synchronous.client_session import _EmptyServerSession
|
||||
from pymongo.synchronous.client_session import _SESSION, _EmptyServerSession
|
||||
from pymongo.synchronous.command_cursor import CommandCursor
|
||||
from pymongo.synchronous.settings import TopologySettings
|
||||
from pymongo.synchronous.topology import Topology, _ErrorContext
|
||||
@ -1406,7 +1406,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
)
|
||||
|
||||
def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]:
|
||||
"""If provided session is None, lend a temporary session."""
|
||||
"""If provided session and bound session are None, lend a temporary session."""
|
||||
session = session or self._get_bound_session()
|
||||
if session:
|
||||
return session
|
||||
|
||||
@ -2263,11 +2264,14 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self, session: Optional[client_session.ClientSession]
|
||||
) -> Generator[Optional[client_session.ClientSession], None]:
|
||||
"""If provided session is None, lend a temporary session."""
|
||||
if session is not None:
|
||||
if not isinstance(session, client_session.ClientSession):
|
||||
raise ValueError(
|
||||
f"'session' argument must be a ClientSession or None, not {type(session)}"
|
||||
)
|
||||
if session is not None and not isinstance(session, client_session.ClientSession):
|
||||
raise ValueError(
|
||||
f"'session' argument must be a ClientSession or None, not {type(session)}"
|
||||
)
|
||||
|
||||
# Check for a bound session. If one exists, treat it as an explicitly passed session.
|
||||
session = session or self._get_bound_session()
|
||||
if session:
|
||||
# Don't call end_session.
|
||||
yield session
|
||||
return
|
||||
@ -2295,6 +2299,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
if session is not None:
|
||||
session._process_response(reply)
|
||||
|
||||
def _get_bound_session(self) -> Optional[ClientSession]:
|
||||
bound_session = _SESSION.get()
|
||||
if bound_session:
|
||||
if bound_session.client is self:
|
||||
return bound_session
|
||||
else:
|
||||
raise InvalidOperation(
|
||||
"Only the client that created the bound session can perform operations within its context block. See <PLACEHOLDER> for more information."
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
def server_info(self, session: Optional[client_session.ClientSession] = None) -> dict[str, Any]:
|
||||
"""Get information about the MongoDB server we're connected to.
|
||||
|
||||
|
||||
@ -189,6 +189,52 @@ class TestSession(AsyncIntegrationTest):
|
||||
f"{f.__name__} did not return implicit session to pool",
|
||||
)
|
||||
|
||||
# Explicit bound session
|
||||
for f, args, kw in ops:
|
||||
async with client.start_session() as s:
|
||||
async with s.bind():
|
||||
listener.reset()
|
||||
s._materialize()
|
||||
last_use = s._server_session.last_use
|
||||
start = time.monotonic()
|
||||
self.assertLessEqual(last_use, start)
|
||||
# In case "f" modifies its inputs.
|
||||
args = copy.copy(args)
|
||||
kw = copy.copy(kw)
|
||||
await f(*args, **kw)
|
||||
self.assertGreaterEqual(len(listener.started_events), 1)
|
||||
for event in listener.started_events:
|
||||
self.assertIn(
|
||||
"lsid",
|
||||
event.command,
|
||||
f"{f.__name__} sent no lsid with {event.command_name}",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
s.session_id,
|
||||
event.command["lsid"],
|
||||
f"{f.__name__} sent wrong lsid with {event.command_name}",
|
||||
)
|
||||
|
||||
self.assertFalse(s.has_ended)
|
||||
|
||||
self.assertTrue(s.has_ended)
|
||||
with self.assertRaisesRegex(InvalidOperation, "ended session"):
|
||||
async with s.bind():
|
||||
await f(*args, **kw)
|
||||
|
||||
# Test a session cannot be used on another client.
|
||||
async with self.client2.start_session() as s:
|
||||
async with s.bind():
|
||||
# In case "f" modifies its inputs.
|
||||
args = copy.copy(args)
|
||||
kw = copy.copy(kw)
|
||||
with self.assertRaisesRegex(
|
||||
InvalidOperation,
|
||||
"Only the client that created the bound session can perform operations within its context block",
|
||||
):
|
||||
await f(*args, **kw)
|
||||
|
||||
async def test_implicit_sessions_checkout(self):
|
||||
# "To confirm that implicit sessions only allocate their server session after a
|
||||
# successful connection checkout" test from Driver Sessions Spec.
|
||||
@ -825,6 +871,73 @@ class TestSession(AsyncIntegrationTest):
|
||||
async with client.start_session() as s:
|
||||
self.assertRaises(TypeError, lambda: copy.copy(s))
|
||||
|
||||
async def test_nested_session_binding(self):
|
||||
coll = self.client.pymongo_test.test
|
||||
await coll.insert_one({"x": 1})
|
||||
|
||||
session1 = self.client.start_session()
|
||||
session2 = self.client.start_session()
|
||||
session1._materialize()
|
||||
session2._materialize()
|
||||
try:
|
||||
self.listener.reset()
|
||||
# Uses implicit session
|
||||
await coll.find_one()
|
||||
implicit_lsid = self.listener.started_events[0].command.get("lsid")
|
||||
self.assertIsNotNone(implicit_lsid)
|
||||
self.assertNotEqual(implicit_lsid, session1.session_id)
|
||||
self.assertNotEqual(implicit_lsid, session2.session_id)
|
||||
|
||||
async with session1.bind(end_session=False):
|
||||
self.listener.reset()
|
||||
# Uses bound session1
|
||||
await coll.find_one()
|
||||
session1_lsid = self.listener.started_events[0].command.get("lsid")
|
||||
self.assertEqual(session1_lsid, session1.session_id)
|
||||
|
||||
async with session2.bind(end_session=False):
|
||||
self.listener.reset()
|
||||
# Uses bound session2
|
||||
await coll.find_one()
|
||||
session2_lsid = self.listener.started_events[0].command.get("lsid")
|
||||
self.assertEqual(session2_lsid, session2.session_id)
|
||||
self.assertNotEqual(session2_lsid, session1.session_id)
|
||||
|
||||
self.listener.reset()
|
||||
# Use bound session1 again
|
||||
await coll.find_one()
|
||||
session1_lsid = self.listener.started_events[0].command.get("lsid")
|
||||
self.assertEqual(session1_lsid, session1.session_id)
|
||||
self.assertNotEqual(session1_lsid, session2.session_id)
|
||||
|
||||
self.listener.reset()
|
||||
# Uses implicit session
|
||||
await coll.find_one()
|
||||
implicit_lsid = self.listener.started_events[0].command.get("lsid")
|
||||
self.assertIsNotNone(implicit_lsid)
|
||||
self.assertNotEqual(implicit_lsid, session1.session_id)
|
||||
self.assertNotEqual(implicit_lsid, session2.session_id)
|
||||
|
||||
finally:
|
||||
await session1.end_session()
|
||||
await session2.end_session()
|
||||
|
||||
async def test_session_binding_end_session(self):
|
||||
coll = self.client.pymongo_test.test
|
||||
await coll.insert_one({"x": 1})
|
||||
|
||||
async with self.client.start_session().bind() as s1:
|
||||
await coll.find_one()
|
||||
|
||||
self.assertTrue(s1.has_ended)
|
||||
|
||||
async with self.client.start_session().bind(end_session=False) as s2:
|
||||
await coll.find_one()
|
||||
|
||||
self.assertFalse(s2.has_ended)
|
||||
|
||||
await s2.end_session()
|
||||
|
||||
|
||||
class TestCausalConsistency(AsyncUnitTest):
|
||||
listener: SessionTestListener
|
||||
|
||||
@ -189,6 +189,52 @@ class TestSession(IntegrationTest):
|
||||
f"{f.__name__} did not return implicit session to pool",
|
||||
)
|
||||
|
||||
# Explicit bound session
|
||||
for f, args, kw in ops:
|
||||
with client.start_session() as s:
|
||||
with s.bind():
|
||||
listener.reset()
|
||||
s._materialize()
|
||||
last_use = s._server_session.last_use
|
||||
start = time.monotonic()
|
||||
self.assertLessEqual(last_use, start)
|
||||
# In case "f" modifies its inputs.
|
||||
args = copy.copy(args)
|
||||
kw = copy.copy(kw)
|
||||
f(*args, **kw)
|
||||
self.assertGreaterEqual(len(listener.started_events), 1)
|
||||
for event in listener.started_events:
|
||||
self.assertIn(
|
||||
"lsid",
|
||||
event.command,
|
||||
f"{f.__name__} sent no lsid with {event.command_name}",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
s.session_id,
|
||||
event.command["lsid"],
|
||||
f"{f.__name__} sent wrong lsid with {event.command_name}",
|
||||
)
|
||||
|
||||
self.assertFalse(s.has_ended)
|
||||
|
||||
self.assertTrue(s.has_ended)
|
||||
with self.assertRaisesRegex(InvalidOperation, "ended session"):
|
||||
with s.bind():
|
||||
f(*args, **kw)
|
||||
|
||||
# Test a session cannot be used on another client.
|
||||
with self.client2.start_session() as s:
|
||||
with s.bind():
|
||||
# In case "f" modifies its inputs.
|
||||
args = copy.copy(args)
|
||||
kw = copy.copy(kw)
|
||||
with self.assertRaisesRegex(
|
||||
InvalidOperation,
|
||||
"Only the client that created the bound session can perform operations within its context block",
|
||||
):
|
||||
f(*args, **kw)
|
||||
|
||||
def test_implicit_sessions_checkout(self):
|
||||
# "To confirm that implicit sessions only allocate their server session after a
|
||||
# successful connection checkout" test from Driver Sessions Spec.
|
||||
@ -825,6 +871,73 @@ class TestSession(IntegrationTest):
|
||||
with client.start_session() as s:
|
||||
self.assertRaises(TypeError, lambda: copy.copy(s))
|
||||
|
||||
def test_nested_session_binding(self):
|
||||
coll = self.client.pymongo_test.test
|
||||
coll.insert_one({"x": 1})
|
||||
|
||||
session1 = self.client.start_session()
|
||||
session2 = self.client.start_session()
|
||||
session1._materialize()
|
||||
session2._materialize()
|
||||
try:
|
||||
self.listener.reset()
|
||||
# Uses implicit session
|
||||
coll.find_one()
|
||||
implicit_lsid = self.listener.started_events[0].command.get("lsid")
|
||||
self.assertIsNotNone(implicit_lsid)
|
||||
self.assertNotEqual(implicit_lsid, session1.session_id)
|
||||
self.assertNotEqual(implicit_lsid, session2.session_id)
|
||||
|
||||
with session1.bind(end_session=False):
|
||||
self.listener.reset()
|
||||
# Uses bound session1
|
||||
coll.find_one()
|
||||
session1_lsid = self.listener.started_events[0].command.get("lsid")
|
||||
self.assertEqual(session1_lsid, session1.session_id)
|
||||
|
||||
with session2.bind(end_session=False):
|
||||
self.listener.reset()
|
||||
# Uses bound session2
|
||||
coll.find_one()
|
||||
session2_lsid = self.listener.started_events[0].command.get("lsid")
|
||||
self.assertEqual(session2_lsid, session2.session_id)
|
||||
self.assertNotEqual(session2_lsid, session1.session_id)
|
||||
|
||||
self.listener.reset()
|
||||
# Use bound session1 again
|
||||
coll.find_one()
|
||||
session1_lsid = self.listener.started_events[0].command.get("lsid")
|
||||
self.assertEqual(session1_lsid, session1.session_id)
|
||||
self.assertNotEqual(session1_lsid, session2.session_id)
|
||||
|
||||
self.listener.reset()
|
||||
# Uses implicit session
|
||||
coll.find_one()
|
||||
implicit_lsid = self.listener.started_events[0].command.get("lsid")
|
||||
self.assertIsNotNone(implicit_lsid)
|
||||
self.assertNotEqual(implicit_lsid, session1.session_id)
|
||||
self.assertNotEqual(implicit_lsid, session2.session_id)
|
||||
|
||||
finally:
|
||||
session1.end_session()
|
||||
session2.end_session()
|
||||
|
||||
def test_session_binding_end_session(self):
|
||||
coll = self.client.pymongo_test.test
|
||||
coll.insert_one({"x": 1})
|
||||
|
||||
with self.client.start_session().bind() as s1:
|
||||
coll.find_one()
|
||||
|
||||
self.assertTrue(s1.has_ended)
|
||||
|
||||
with self.client.start_session().bind(end_session=False) as s2:
|
||||
coll.find_one()
|
||||
|
||||
self.assertFalse(s2.has_ended)
|
||||
|
||||
s2.end_session()
|
||||
|
||||
|
||||
class TestCausalConsistency(UnitTest):
|
||||
listener: SessionTestListener
|
||||
|
||||
@ -37,6 +37,7 @@ replacements = {
|
||||
"AsyncRawBatchCursor": "RawBatchCursor",
|
||||
"AsyncRawBatchCommandCursor": "RawBatchCommandCursor",
|
||||
"AsyncClientSession": "ClientSession",
|
||||
"_AsyncBoundSessionContext": "_BoundSessionContext",
|
||||
"AsyncChangeStream": "ChangeStream",
|
||||
"AsyncCollectionChangeStream": "CollectionChangeStream",
|
||||
"AsyncDatabaseChangeStream": "DatabaseChangeStream",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user