PYTHON-2390 - Retryable reads use the same implicit session (#2544)

This commit is contained in:
Noah Stapp 2025-09-24 13:23:28 -04:00 committed by GitHub
parent 51f7b408f3
commit 0049dc8896
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 198 additions and 130 deletions

View File

@ -50,7 +50,6 @@ class _AggregationCommand:
cursor_class: type[AsyncCommandCursor[Any]],
pipeline: _Pipeline,
options: MutableMapping[str, Any],
explicit_session: bool,
let: Optional[Mapping[str, Any]] = None,
user_fields: Optional[MutableMapping[str, Any]] = None,
result_processor: Optional[Callable[[Mapping[str, Any], AsyncConnection], None]] = None,
@ -92,7 +91,6 @@ class _AggregationCommand:
self._options["cursor"]["batchSize"] = self._batch_size
self._cursor_class = cursor_class
self._explicit_session = explicit_session
self._user_fields = user_fields
self._result_processor = result_processor
@ -197,7 +195,6 @@ class _AggregationCommand:
batch_size=self._batch_size or 0,
max_await_time_ms=self._max_await_time_ms,
session=session,
explicit_session=self._explicit_session,
comment=self._options.get("comment"),
)
await cmd_cursor._maybe_pin_connection(conn)

View File

@ -236,7 +236,7 @@ class AsyncChangeStream(Generic[_DocumentType]):
)
async def _run_aggregation_cmd(
self, session: Optional[AsyncClientSession], explicit_session: bool
self, session: Optional[AsyncClientSession]
) -> AsyncCommandCursor: # type: ignore[type-arg]
"""Run the full aggregation pipeline for this AsyncChangeStream and return
the corresponding AsyncCommandCursor.
@ -246,7 +246,6 @@ class AsyncChangeStream(Generic[_DocumentType]):
AsyncCommandCursor,
self._aggregation_pipeline(),
self._command_options(),
explicit_session,
result_processor=self._process_result,
comment=self._comment,
)
@ -258,10 +257,8 @@ class AsyncChangeStream(Generic[_DocumentType]):
)
async def _create_cursor(self) -> AsyncCommandCursor: # type: ignore[type-arg]
async with self._client._tmp_session(self._session, close=False) as s:
return await self._run_aggregation_cmd(
session=s, explicit_session=self._session is not None
)
async with self._client._tmp_session(self._session) as s:
return await self._run_aggregation_cmd(session=s)
async def _resume(self) -> None:
"""Reestablish this change stream after a resumable error."""

View File

@ -440,6 +440,8 @@ class _AsyncClientBulk:
) -> None:
"""Internal helper for processing the server reply command cursor."""
if result.get("cursor"):
if session:
session._leave_alive = True
coll = AsyncCollection(
database=AsyncDatabase(self.client, "admin"),
name="$cmd.bulkWrite",
@ -449,7 +451,6 @@ class _AsyncClientBulk:
result["cursor"],
conn.address,
session=session,
explicit_session=session is not None,
comment=self.comment,
)
await cmd_cursor._maybe_pin_connection(conn)

View File

@ -513,6 +513,10 @@ class AsyncClientSession:
# Is this an implicitly created session?
self._implicit = implicit
self._transaction = _Transaction(None, client)
# Is this session attached to a cursor?
self._attached_to_cursor = False
# Should we leave the session alive when the cursor is closed?
self._leave_alive = False
async def end_session(self) -> None:
"""Finish this session. If a transaction has started, abort it.
@ -535,7 +539,7 @@ class AsyncClientSession:
def _end_implicit_session(self) -> None:
# Implicit sessions can't be part of transactions or pinned connections
if self._server_session is not None:
if not self._leave_alive and self._server_session is not None:
self._client._return_server_session(self._server_session)
self._server_session = None

View File

@ -2549,7 +2549,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY),
)
read_pref = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
explicit_session = session is not None
async def _cmd(
session: Optional[AsyncClientSession],
@ -2576,13 +2575,12 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cursor,
conn.address,
session=session,
explicit_session=explicit_session,
comment=cmd.get("comment"),
)
await cmd_cursor._maybe_pin_connection(conn)
return cmd_cursor
async with self._database.client._tmp_session(session, False) as s:
async with self._database.client._tmp_session(session) as s:
return await self._database.client._retryable_read(
_cmd, read_pref, s, operation=_Op.LIST_INDEXES
)
@ -2678,7 +2676,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
AsyncCommandCursor,
pipeline,
kwargs,
explicit_session=session is not None,
comment=comment,
user_fields={"cursor": {"firstBatch": 1}},
)
@ -2900,7 +2897,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
pipeline: _Pipeline,
cursor_class: Type[AsyncCommandCursor], # type: ignore[type-arg]
session: Optional[AsyncClientSession],
explicit_session: bool,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
**kwargs: Any,
@ -2912,7 +2908,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cursor_class,
pipeline,
kwargs,
explicit_session,
let,
user_fields={"cursor": {"firstBatch": 1}},
)
@ -3018,13 +3013,12 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
.. _aggregate command:
https://mongodb.com/docs/manual/reference/command/aggregate
"""
async with self._database.client._tmp_session(session, close=False) as s:
async with self._database.client._tmp_session(session) as s:
return await self._aggregate(
_CollectionAggregationCommand,
pipeline,
AsyncCommandCursor,
session=s,
explicit_session=session is not None,
let=let,
comment=comment,
**kwargs,
@ -3065,7 +3059,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
raise InvalidOperation("aggregate_raw_batches does not support auto encryption")
if comment is not None:
kwargs["comment"] = comment
async with self._database.client._tmp_session(session, close=False) as s:
async with self._database.client._tmp_session(session) as s:
return cast(
AsyncRawBatchCursor[_DocumentType],
await self._aggregate(
@ -3073,7 +3067,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
pipeline,
AsyncRawBatchCommandCursor,
session=s,
explicit_session=session is not None,
**kwargs,
),
)

View File

@ -64,7 +64,6 @@ class AsyncCommandCursor(Generic[_DocumentType]):
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[AsyncClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new command cursor."""
@ -80,7 +79,8 @@ class AsyncCommandCursor(Generic[_DocumentType]):
self._max_await_time_ms = max_await_time_ms
self._timeout = self._collection.database.client.options.timeout
self._session = session
self._explicit_session = explicit_session
if self._session is not None:
self._session._attached_to_cursor = True
self._killed = self._id == 0
self._comment = comment
if self._killed:
@ -197,7 +197,7 @@ class AsyncCommandCursor(Generic[_DocumentType]):
.. versionadded:: 3.6
"""
if self._explicit_session:
if self._session and not self._session._implicit:
return self._session
return None
@ -218,9 +218,10 @@ class AsyncCommandCursor(Generic[_DocumentType]):
"""Closes this cursor without acquiring a lock."""
cursor_id, address = self._prepare_to_die()
self._collection.database.client._cleanup_cursor_no_lock(
cursor_id, address, self._sock_mgr, self._session, self._explicit_session
cursor_id, address, self._sock_mgr, self._session
)
if not self._explicit_session:
if self._session and self._session._implicit:
self._session._attached_to_cursor = False
self._session = None
self._sock_mgr = None
@ -232,14 +233,15 @@ class AsyncCommandCursor(Generic[_DocumentType]):
address,
self._sock_mgr,
self._session,
self._explicit_session,
)
if not self._explicit_session:
if self._session and self._session._implicit:
self._session._attached_to_cursor = False
self._session = None
self._sock_mgr = None
def _end_session(self) -> None:
if self._session and not self._explicit_session:
if self._session and self._session._implicit:
self._session._attached_to_cursor = False
self._session._end_implicit_session()
self._session = None
@ -430,7 +432,6 @@ class AsyncRawBatchCommandCursor(AsyncCommandCursor[_DocumentType]):
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[AsyncClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new cursor / iterator over raw batches of BSON data.
@ -449,7 +450,6 @@ class AsyncRawBatchCommandCursor(AsyncCommandCursor[_DocumentType]):
batch_size,
max_await_time_ms,
session,
explicit_session,
comment,
)

View File

@ -138,10 +138,9 @@ class AsyncCursor(Generic[_DocumentType]):
if session:
self._session = session
self._explicit_session = True
self._session._attached_to_cursor = True
else:
self._session = None
self._explicit_session = False
spec: Mapping[str, Any] = filter or {}
validate_is_mapping("filter", spec)
@ -150,7 +149,7 @@ class AsyncCursor(Generic[_DocumentType]):
if not isinstance(limit, int):
raise TypeError(f"limit must be an instance of int, not {type(limit)}")
validate_boolean("no_cursor_timeout", no_cursor_timeout)
if no_cursor_timeout and not self._explicit_session:
if no_cursor_timeout and self._session and self._session._implicit:
warnings.warn(
"use an explicit session with no_cursor_timeout=True "
"otherwise the cursor may still timeout after "
@ -283,7 +282,7 @@ class AsyncCursor(Generic[_DocumentType]):
def _clone(self, deepcopy: bool = True, base: Optional[AsyncCursor] = None) -> AsyncCursor: # type: ignore[type-arg]
"""Internal clone helper."""
if not base:
if self._explicit_session:
if self._session and not self._session._implicit:
base = self._clone_base(self._session)
else:
base = self._clone_base(None)
@ -945,7 +944,7 @@ class AsyncCursor(Generic[_DocumentType]):
.. versionadded:: 3.6
"""
if self._explicit_session:
if self._session and not self._session._implicit:
return self._session
return None
@ -1034,9 +1033,10 @@ class AsyncCursor(Generic[_DocumentType]):
cursor_id, address = self._prepare_to_die(already_killed)
self._collection.database.client._cleanup_cursor_no_lock(
cursor_id, address, self._sock_mgr, self._session, self._explicit_session
cursor_id, address, self._sock_mgr, self._session
)
if not self._explicit_session:
if self._session and self._session._implicit:
self._session._attached_to_cursor = False
self._session = None
self._sock_mgr = None
@ -1054,9 +1054,9 @@ class AsyncCursor(Generic[_DocumentType]):
address,
self._sock_mgr,
self._session,
self._explicit_session,
)
if not self._explicit_session:
if self._session and self._session._implicit:
self._session._attached_to_cursor = False
self._session = None
self._sock_mgr = None

View File

@ -611,6 +611,8 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
common.validate_is_mapping("clusteredIndex", clustered_index)
async with self._client._tmp_session(session) as s:
if s and not s.in_transaction:
s._leave_alive = True
# Skip this check in a transaction where listCollections is not
# supported.
if (
@ -619,6 +621,8 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
and name in await self._list_collection_names(filter={"name": name}, session=s)
):
raise CollectionInvalid("collection %s already exists" % name)
if s:
s._leave_alive = False
coll = AsyncCollection(
self,
name,
@ -699,13 +703,12 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
.. _aggregate command:
https://mongodb.com/docs/manual/reference/command/aggregate
"""
async with self.client._tmp_session(session, close=False) as s:
async with self.client._tmp_session(session) as s:
cmd = _DatabaseAggregationCommand(
self,
AsyncCommandCursor,
pipeline,
kwargs,
session is not None,
user_fields={"cursor": {"firstBatch": 1}},
)
return await self.client._retryable_read(
@ -1011,7 +1014,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
else:
command_name = next(iter(command))
async with self._client._tmp_session(session, close=False) as tmp_session:
async with self._client._tmp_session(session) as tmp_session:
opts = codec_options or DEFAULT_CODEC_OPTIONS
if read_preference is None:
@ -1043,7 +1046,6 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
conn.address,
max_await_time_ms=max_await_time_ms,
session=tmp_session,
explicit_session=session is not None,
comment=comment,
)
await cmd_cursor._maybe_pin_connection(conn)
@ -1089,7 +1091,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
)
cmd = {"listCollections": 1, "cursor": {}}
cmd.update(kwargs)
async with self._client._tmp_session(session, close=False) as tmp_session:
async with self._client._tmp_session(session) as tmp_session:
cursor = (
await self._command(conn, cmd, read_preference=read_preference, session=tmp_session)
)["cursor"]
@ -1098,7 +1100,6 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
cursor,
conn.address,
session=tmp_session,
explicit_session=session is not None,
comment=cmd.get("comment"),
)
await cmd_cursor._maybe_pin_connection(conn)

View File

@ -2048,17 +2048,18 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
retryable = bool(
retryable and self.options.retry_reads and not (session and session.in_transaction)
)
return await self._retry_internal(
func,
session,
None,
operation,
is_read=True,
address=address,
read_pref=read_pref,
retryable=retryable,
operation_id=operation_id,
)
async with self._tmp_session(session) as s:
return await self._retry_internal(
func,
s,
None,
operation,
is_read=True,
address=address,
read_pref=read_pref,
retryable=retryable,
operation_id=operation_id,
)
async def _retryable_write(
self,
@ -2091,7 +2092,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
address: Optional[_CursorAddress],
conn_mgr: _ConnectionManager,
session: Optional[AsyncClientSession],
explicit_session: bool,
) -> None:
"""Cleanup a cursor from __del__ without locking.
@ -2106,7 +2106,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
# The cursor will be closed later in a different session.
if cursor_id or conn_mgr:
self._close_cursor_soon(cursor_id, address, conn_mgr)
if session and not explicit_session:
if session and session._implicit and not session._leave_alive:
session._end_implicit_session()
async def _cleanup_cursor_lock(
@ -2115,7 +2115,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
address: Optional[_CursorAddress],
conn_mgr: _ConnectionManager,
session: Optional[AsyncClientSession],
explicit_session: bool,
) -> None:
"""Cleanup a cursor from cursor.close() using a lock.
@ -2127,7 +2126,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
:param address: The _CursorAddress.
:param conn_mgr: The _ConnectionManager for the pinned connection or None.
:param session: The cursor's session.
:param explicit_session: True if the session was passed explicitly.
"""
if cursor_id:
if conn_mgr and conn_mgr.more_to_come:
@ -2140,7 +2138,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
await self._close_cursor_now(cursor_id, address, session=session, conn_mgr=conn_mgr)
if conn_mgr:
await conn_mgr.close()
if session and not explicit_session:
if session and session._implicit and not session._leave_alive:
session._end_implicit_session()
async def _close_cursor_now(
@ -2221,7 +2219,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_id, conn_mgr in pinned_cursors:
try:
await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None)
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
# Raise the exception when client is closed so that it
@ -2266,7 +2264,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
@contextlib.asynccontextmanager
async def _tmp_session(
self, session: Optional[client_session.AsyncClientSession], close: bool = True
self, session: Optional[client_session.AsyncClientSession]
) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None]:
"""If provided session is None, lend a temporary session."""
if session is not None:
@ -2291,7 +2289,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
raise
finally:
# Call end_session when we exit this scope.
if close:
if not s._attached_to_cursor:
await s.end_session()
else:
yield None

View File

@ -50,7 +50,6 @@ class _AggregationCommand:
cursor_class: type[CommandCursor[Any]],
pipeline: _Pipeline,
options: MutableMapping[str, Any],
explicit_session: bool,
let: Optional[Mapping[str, Any]] = None,
user_fields: Optional[MutableMapping[str, Any]] = None,
result_processor: Optional[Callable[[Mapping[str, Any], Connection], None]] = None,
@ -92,7 +91,6 @@ class _AggregationCommand:
self._options["cursor"]["batchSize"] = self._batch_size
self._cursor_class = cursor_class
self._explicit_session = explicit_session
self._user_fields = user_fields
self._result_processor = result_processor
@ -197,7 +195,6 @@ class _AggregationCommand:
batch_size=self._batch_size or 0,
max_await_time_ms=self._max_await_time_ms,
session=session,
explicit_session=self._explicit_session,
comment=self._options.get("comment"),
)
cmd_cursor._maybe_pin_connection(conn)

View File

@ -235,9 +235,7 @@ class ChangeStream(Generic[_DocumentType]):
f"response : {result!r}"
)
def _run_aggregation_cmd(
self, session: Optional[ClientSession], explicit_session: bool
) -> CommandCursor: # type: ignore[type-arg]
def _run_aggregation_cmd(self, session: Optional[ClientSession]) -> CommandCursor: # type: ignore[type-arg]
"""Run the full aggregation pipeline for this ChangeStream and return
the corresponding CommandCursor.
"""
@ -246,7 +244,6 @@ class ChangeStream(Generic[_DocumentType]):
CommandCursor,
self._aggregation_pipeline(),
self._command_options(),
explicit_session,
result_processor=self._process_result,
comment=self._comment,
)
@ -258,8 +255,8 @@ class ChangeStream(Generic[_DocumentType]):
)
def _create_cursor(self) -> CommandCursor: # type: ignore[type-arg]
with self._client._tmp_session(self._session, close=False) as s:
return self._run_aggregation_cmd(session=s, explicit_session=self._session is not None)
with self._client._tmp_session(self._session) as s:
return self._run_aggregation_cmd(session=s)
def _resume(self) -> None:
"""Reestablish this change stream after a resumable error."""

View File

@ -438,6 +438,8 @@ class _ClientBulk:
) -> None:
"""Internal helper for processing the server reply command cursor."""
if result.get("cursor"):
if session:
session._leave_alive = True
coll = Collection(
database=Database(self.client, "admin"),
name="$cmd.bulkWrite",
@ -447,7 +449,6 @@ class _ClientBulk:
result["cursor"],
conn.address,
session=session,
explicit_session=session is not None,
comment=self.comment,
)
cmd_cursor._maybe_pin_connection(conn)

View File

@ -512,6 +512,10 @@ class ClientSession:
# Is this an implicitly created session?
self._implicit = implicit
self._transaction = _Transaction(None, client)
# Is this session attached to a cursor?
self._attached_to_cursor = False
# Should we leave the session alive when the cursor is closed?
self._leave_alive = False
def end_session(self) -> None:
"""Finish this session. If a transaction has started, abort it.
@ -534,7 +538,7 @@ class ClientSession:
def _end_implicit_session(self) -> None:
# Implicit sessions can't be part of transactions or pinned connections
if self._server_session is not None:
if not self._leave_alive and self._server_session is not None:
self._client._return_server_session(self._server_session)
self._server_session = None

View File

@ -2546,7 +2546,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY),
)
read_pref = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
explicit_session = session is not None
def _cmd(
session: Optional[ClientSession],
@ -2573,13 +2572,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cursor,
conn.address,
session=session,
explicit_session=explicit_session,
comment=cmd.get("comment"),
)
cmd_cursor._maybe_pin_connection(conn)
return cmd_cursor
with self._database.client._tmp_session(session, False) as s:
with self._database.client._tmp_session(session) as s:
return self._database.client._retryable_read(
_cmd, read_pref, s, operation=_Op.LIST_INDEXES
)
@ -2675,7 +2673,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
CommandCursor,
pipeline,
kwargs,
explicit_session=session is not None,
comment=comment,
user_fields={"cursor": {"firstBatch": 1}},
)
@ -2893,7 +2890,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
pipeline: _Pipeline,
cursor_class: Type[CommandCursor], # type: ignore[type-arg]
session: Optional[ClientSession],
explicit_session: bool,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
**kwargs: Any,
@ -2905,7 +2901,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cursor_class,
pipeline,
kwargs,
explicit_session,
let,
user_fields={"cursor": {"firstBatch": 1}},
)
@ -3011,13 +3006,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
.. _aggregate command:
https://mongodb.com/docs/manual/reference/command/aggregate
"""
with self._database.client._tmp_session(session, close=False) as s:
with self._database.client._tmp_session(session) as s:
return self._aggregate(
_CollectionAggregationCommand,
pipeline,
CommandCursor,
session=s,
explicit_session=session is not None,
let=let,
comment=comment,
**kwargs,
@ -3058,7 +3052,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
raise InvalidOperation("aggregate_raw_batches does not support auto encryption")
if comment is not None:
kwargs["comment"] = comment
with self._database.client._tmp_session(session, close=False) as s:
with self._database.client._tmp_session(session) as s:
return cast(
RawBatchCursor[_DocumentType],
self._aggregate(
@ -3066,7 +3060,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
pipeline,
RawBatchCommandCursor,
session=s,
explicit_session=session is not None,
**kwargs,
),
)

View File

@ -64,7 +64,6 @@ class CommandCursor(Generic[_DocumentType]):
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[ClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new command cursor."""
@ -80,7 +79,8 @@ class CommandCursor(Generic[_DocumentType]):
self._max_await_time_ms = max_await_time_ms
self._timeout = self._collection.database.client.options.timeout
self._session = session
self._explicit_session = explicit_session
if self._session is not None:
self._session._attached_to_cursor = True
self._killed = self._id == 0
self._comment = comment
if self._killed:
@ -197,7 +197,7 @@ class CommandCursor(Generic[_DocumentType]):
.. versionadded:: 3.6
"""
if self._explicit_session:
if self._session and not self._session._implicit:
return self._session
return None
@ -218,9 +218,10 @@ class CommandCursor(Generic[_DocumentType]):
"""Closes this cursor without acquiring a lock."""
cursor_id, address = self._prepare_to_die()
self._collection.database.client._cleanup_cursor_no_lock(
cursor_id, address, self._sock_mgr, self._session, self._explicit_session
cursor_id, address, self._sock_mgr, self._session
)
if not self._explicit_session:
if self._session and self._session._implicit:
self._session._attached_to_cursor = False
self._session = None
self._sock_mgr = None
@ -232,14 +233,15 @@ class CommandCursor(Generic[_DocumentType]):
address,
self._sock_mgr,
self._session,
self._explicit_session,
)
if not self._explicit_session:
if self._session and self._session._implicit:
self._session._attached_to_cursor = False
self._session = None
self._sock_mgr = None
def _end_session(self) -> None:
if self._session and not self._explicit_session:
if self._session and self._session._implicit:
self._session._attached_to_cursor = False
self._session._end_implicit_session()
self._session = None
@ -430,7 +432,6 @@ class RawBatchCommandCursor(CommandCursor[_DocumentType]):
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[ClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new cursor / iterator over raw batches of BSON data.
@ -449,7 +450,6 @@ class RawBatchCommandCursor(CommandCursor[_DocumentType]):
batch_size,
max_await_time_ms,
session,
explicit_session,
comment,
)

View File

@ -138,10 +138,9 @@ class Cursor(Generic[_DocumentType]):
if session:
self._session = session
self._explicit_session = True
self._session._attached_to_cursor = True
else:
self._session = None
self._explicit_session = False
spec: Mapping[str, Any] = filter or {}
validate_is_mapping("filter", spec)
@ -150,7 +149,7 @@ class Cursor(Generic[_DocumentType]):
if not isinstance(limit, int):
raise TypeError(f"limit must be an instance of int, not {type(limit)}")
validate_boolean("no_cursor_timeout", no_cursor_timeout)
if no_cursor_timeout and not self._explicit_session:
if no_cursor_timeout and self._session and self._session._implicit:
warnings.warn(
"use an explicit session with no_cursor_timeout=True "
"otherwise the cursor may still timeout after "
@ -283,7 +282,7 @@ class Cursor(Generic[_DocumentType]):
def _clone(self, deepcopy: bool = True, base: Optional[Cursor] = None) -> Cursor: # type: ignore[type-arg]
"""Internal clone helper."""
if not base:
if self._explicit_session:
if self._session and not self._session._implicit:
base = self._clone_base(self._session)
else:
base = self._clone_base(None)
@ -943,7 +942,7 @@ class Cursor(Generic[_DocumentType]):
.. versionadded:: 3.6
"""
if self._explicit_session:
if self._session and not self._session._implicit:
return self._session
return None
@ -1032,9 +1031,10 @@ class Cursor(Generic[_DocumentType]):
cursor_id, address = self._prepare_to_die(already_killed)
self._collection.database.client._cleanup_cursor_no_lock(
cursor_id, address, self._sock_mgr, self._session, self._explicit_session
cursor_id, address, self._sock_mgr, self._session
)
if not self._explicit_session:
if self._session and self._session._implicit:
self._session._attached_to_cursor = False
self._session = None
self._sock_mgr = None
@ -1052,9 +1052,9 @@ class Cursor(Generic[_DocumentType]):
address,
self._sock_mgr,
self._session,
self._explicit_session,
)
if not self._explicit_session:
if self._session and self._session._implicit:
self._session._attached_to_cursor = False
self._session = None
self._sock_mgr = None

View File

@ -611,6 +611,8 @@ class Database(common.BaseObject, Generic[_DocumentType]):
common.validate_is_mapping("clusteredIndex", clustered_index)
with self._client._tmp_session(session) as s:
if s and not s.in_transaction:
s._leave_alive = True
# Skip this check in a transaction where listCollections is not
# supported.
if (
@ -619,6 +621,8 @@ class Database(common.BaseObject, Generic[_DocumentType]):
and name in self._list_collection_names(filter={"name": name}, session=s)
):
raise CollectionInvalid("collection %s already exists" % name)
if s:
s._leave_alive = False
coll = Collection(
self,
name,
@ -699,13 +703,12 @@ class Database(common.BaseObject, Generic[_DocumentType]):
.. _aggregate command:
https://mongodb.com/docs/manual/reference/command/aggregate
"""
with self.client._tmp_session(session, close=False) as s:
with self.client._tmp_session(session) as s:
cmd = _DatabaseAggregationCommand(
self,
CommandCursor,
pipeline,
kwargs,
session is not None,
user_fields={"cursor": {"firstBatch": 1}},
)
return self.client._retryable_read(
@ -1009,7 +1012,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
else:
command_name = next(iter(command))
with self._client._tmp_session(session, close=False) as tmp_session:
with self._client._tmp_session(session) as tmp_session:
opts = codec_options or DEFAULT_CODEC_OPTIONS
if read_preference is None:
@ -1039,7 +1042,6 @@ class Database(common.BaseObject, Generic[_DocumentType]):
conn.address,
max_await_time_ms=max_await_time_ms,
session=tmp_session,
explicit_session=session is not None,
comment=comment,
)
cmd_cursor._maybe_pin_connection(conn)
@ -1085,7 +1087,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
)
cmd = {"listCollections": 1, "cursor": {}}
cmd.update(kwargs)
with self._client._tmp_session(session, close=False) as tmp_session:
with self._client._tmp_session(session) as tmp_session:
cursor = (
self._command(conn, cmd, read_preference=read_preference, session=tmp_session)
)["cursor"]
@ -1094,7 +1096,6 @@ class Database(common.BaseObject, Generic[_DocumentType]):
cursor,
conn.address,
session=tmp_session,
explicit_session=session is not None,
comment=cmd.get("comment"),
)
cmd_cursor._maybe_pin_connection(conn)

View File

@ -2044,17 +2044,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
retryable = bool(
retryable and self.options.retry_reads and not (session and session.in_transaction)
)
return self._retry_internal(
func,
session,
None,
operation,
is_read=True,
address=address,
read_pref=read_pref,
retryable=retryable,
operation_id=operation_id,
)
with self._tmp_session(session) as s:
return self._retry_internal(
func,
s,
None,
operation,
is_read=True,
address=address,
read_pref=read_pref,
retryable=retryable,
operation_id=operation_id,
)
def _retryable_write(
self,
@ -2087,7 +2088,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
address: Optional[_CursorAddress],
conn_mgr: _ConnectionManager,
session: Optional[ClientSession],
explicit_session: bool,
) -> None:
"""Cleanup a cursor from __del__ without locking.
@ -2102,7 +2102,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
# The cursor will be closed later in a different session.
if cursor_id or conn_mgr:
self._close_cursor_soon(cursor_id, address, conn_mgr)
if session and not explicit_session:
if session and session._implicit and not session._leave_alive:
session._end_implicit_session()
def _cleanup_cursor_lock(
@ -2111,7 +2111,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
address: Optional[_CursorAddress],
conn_mgr: _ConnectionManager,
session: Optional[ClientSession],
explicit_session: bool,
) -> None:
"""Cleanup a cursor from cursor.close() using a lock.
@ -2123,7 +2122,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
:param address: The _CursorAddress.
:param conn_mgr: The _ConnectionManager for the pinned connection or None.
:param session: The cursor's session.
:param explicit_session: True if the session was passed explicitly.
"""
if cursor_id:
if conn_mgr and conn_mgr.more_to_come:
@ -2136,7 +2134,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
self._close_cursor_now(cursor_id, address, session=session, conn_mgr=conn_mgr)
if conn_mgr:
conn_mgr.close()
if session and not explicit_session:
if session and session._implicit and not session._leave_alive:
session._end_implicit_session()
def _close_cursor_now(
@ -2217,7 +2215,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_id, conn_mgr in pinned_cursors:
try:
self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None)
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
# Raise the exception when client is closed so that it
@ -2262,7 +2260,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
@contextlib.contextmanager
def _tmp_session(
self, session: Optional[client_session.ClientSession], close: bool = True
self, session: Optional[client_session.ClientSession]
) -> Generator[Optional[client_session.ClientSession], None]:
"""If provided session is None, lend a temporary session."""
if session is not None:
@ -2287,7 +2285,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
raise
finally:
# Call end_session when we exit this scope.
if close:
if not s._attached_to_cursor:
s.end_session()
else:
yield None

View File

@ -218,6 +218,49 @@ class TestRetryableReads(AsyncIntegrationTest):
# Assert that both events occurred on the same mongos.
assert listener.succeeded_events[0].connection_id == listener.failed_events[0].connection_id
@async_client_context.require_failCommand_fail_point
async def test_retryable_reads_are_retried_on_the_same_implicit_session(self):
listener = OvertCommandListener()
client = await self.async_rs_or_single_client(
directConnection=False,
event_listeners=[listener],
retryReads=True,
)
await client.t.t.insert_one({"x": 1})
commands = [
("aggregate", lambda: client.t.t.count_documents({})),
("aggregate", lambda: client.t.t.aggregate([{"$match": {}}])),
("count", lambda: client.t.t.estimated_document_count()),
("distinct", lambda: client.t.t.distinct("x")),
("find", lambda: client.t.t.find_one({})),
("listDatabases", lambda: client.list_databases()),
("listCollections", lambda: client.t.list_collections()),
("listIndexes", lambda: client.t.t.list_indexes()),
]
for command_name, operation in commands:
listener.reset()
fail_command = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {"failCommands": [command_name], "errorCode": 6},
}
async with self.fail_point(fail_command):
await operation()
# Assert that both events occurred on the same session.
command_docs = [
event.command
for event in listener.started_events
if event.command_name == command_name
]
self.assertEqual(len(command_docs), 2)
self.assertEqual(command_docs[0]["lsid"], command_docs[1]["lsid"])
self.assertIsNot(command_docs[0], command_docs[1])
if __name__ == "__main__":
unittest.main()

View File

@ -216,6 +216,49 @@ class TestRetryableReads(IntegrationTest):
# Assert that both events occurred on the same mongos.
assert listener.succeeded_events[0].connection_id == listener.failed_events[0].connection_id
@client_context.require_failCommand_fail_point
def test_retryable_reads_are_retried_on_the_same_implicit_session(self):
listener = OvertCommandListener()
client = self.rs_or_single_client(
directConnection=False,
event_listeners=[listener],
retryReads=True,
)
client.t.t.insert_one({"x": 1})
commands = [
("aggregate", lambda: client.t.t.count_documents({})),
("aggregate", lambda: client.t.t.aggregate([{"$match": {}}])),
("count", lambda: client.t.t.estimated_document_count()),
("distinct", lambda: client.t.t.distinct("x")),
("find", lambda: client.t.t.find_one({})),
("listDatabases", lambda: client.list_databases()),
("listCollections", lambda: client.t.list_collections()),
("listIndexes", lambda: client.t.t.list_indexes()),
]
for command_name, operation in commands:
listener.reset()
fail_command = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {"failCommands": [command_name], "errorCode": 6},
}
with self.fail_point(fail_command):
operation()
# Assert that both events occurred on the same session.
command_docs = [
event.command
for event in listener.started_events
if event.command_name == command_name
]
self.assertEqual(len(command_docs), 2)
self.assertEqual(command_docs[0]["lsid"], command_docs[1]["lsid"])
self.assertIsNot(command_docs[0], command_docs[1])
if __name__ == "__main__":
unittest.main()