From f2416507f020996509ccfd83d05a12c3dfc57a37 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 23 Feb 2026 14:17:47 -0500 Subject: [PATCH] Fix test --- pymongo/asynchronous/mongo_client.py | 11 +++++------ pymongo/synchronous/mongo_client.py | 11 +++++------ test/asynchronous/test_session.py | 2 ++ test/test_session.py | 2 ++ 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index afd634a76..d70186586 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2268,15 +2268,14 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): self, session: Optional[client_session.AsyncClientSession] ) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None]: """If provided session is None, lend a temporary session.""" + if session is not None and not isinstance(session, client_session.AsyncClientSession): + raise ValueError( + f"'session' argument must be an AsyncClientSession or None, not {type(session)}" + ) # Check for a bound session. If one exists, treat it as an explicitly passed session. session = session or self._get_bound_session() - - if session is not None: - if not isinstance(session, client_session.AsyncClientSession): - raise ValueError( - f"'session' argument must be an AsyncClientSession or None, not {type(session)}" - ) + if session: # Don't call end_session. yield session return diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 863bce5be..943cd1f5b 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2264,15 +2264,14 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): self, session: Optional[client_session.ClientSession] ) -> Generator[Optional[client_session.ClientSession], None]: """If provided session is None, lend a temporary session.""" + if session is not None and not isinstance(session, client_session.ClientSession): + raise ValueError( + f"'session' argument must be a ClientSession or None, not {type(session)}" + ) # Check for a bound session. If one exists, treat it as an explicitly passed session. session = session or self._get_bound_session() - - if session is not None: - if not isinstance(session, client_session.ClientSession): - raise ValueError( - f"'session' argument must be a ClientSession or None, not {type(session)}" - ) + if session: # Don't call end_session. yield session return diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 1f1412b58..3ef2f7337 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -877,6 +877,8 @@ class TestSession(AsyncIntegrationTest): session1 = self.client.start_session() session2 = self.client.start_session() + session1._materialize() + session2._materialize() try: self.listener.reset() # Uses implicit session diff --git a/test/test_session.py b/test/test_session.py index 61bf4ef37..4c5859693 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -877,6 +877,8 @@ class TestSession(IntegrationTest): session1 = self.client.start_session() session2 = self.client.start_session() + session1._materialize() + session2._materialize() try: self.listener.reset() # Uses implicit session