PYTHON-2956 Drivers should check out an implicit session only after checking out a connection (#876)

This commit is contained in:
Julius Park 2022-03-01 15:44:05 -08:00 committed by GitHub
parent 782c5517e0
commit b737b843e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 106 additions and 11 deletions

View File

@ -947,9 +947,16 @@ class ClientSession(Generic[_DocumentType]):
return self._transaction.opts.read_preference
return None
def _materialize(self):
if isinstance(self._server_session, _EmptyServerSession):
old = self._server_session
self._server_session = self._client._topology.get_server_session()
if old.started_retryable_write:
self._server_session.inc_transaction_id()
def _apply_to(self, command, is_retryable, read_preference, sock_info):
self._check_ended()
self._materialize()
if self.options.snapshot:
self._update_read_concern(command, sock_info)
@ -1000,6 +1007,20 @@ class ClientSession(Generic[_DocumentType]):
raise TypeError("A ClientSession cannot be copied, create a new session instead")
class _EmptyServerSession:
__slots__ = "dirty", "started_retryable_write"
def __init__(self):
self.dirty = False
self.started_retryable_write = False
def mark_dirty(self):
self.dirty = True
def inc_transaction_id(self):
self.started_retryable_write = True
class _ServerSession(object):
def __init__(self, generation):
# Ensure id is type 4, regardless of CodecOptions.uuid_representation.

View File

@ -66,6 +66,7 @@ from pymongo import (
)
from pymongo.change_stream import ChangeStream, ClusterChangeStream
from pymongo.client_options import ClientOptions
from pymongo.client_session import _EmptyServerSession
from pymongo.command_cursor import CommandCursor
from pymongo.errors import (
AutoReconnect,
@ -1601,7 +1602,11 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
def __start_session(self, implicit, **kwargs):
# Raises ConfigurationError if sessions are not supported.
server_session = self._get_server_session()
if implicit:
self._topology._check_implicit_session_support()
server_session = _EmptyServerSession()
else:
server_session = self._get_server_session()
opts = client_session.SessionOptions(**kwargs)
return client_session.ClientSession(self, server_session, opts, implicit)
@ -1641,6 +1646,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
def _return_server_session(self, server_session, lock):
"""Internal: return a _ServerSession to the pool."""
if isinstance(server_session, _EmptyServerSession):
return
return self._topology.return_server_session(server_session, lock)
def _ensure_session(self, session=None):

View File

@ -514,8 +514,15 @@ class Topology(object):
with self._lock:
return self._session_pool.pop_all()
def _check_implicit_session_support(self):
with self._lock:
self._check_session_support()
def _check_session_support(self):
"""Internal check for session support on non-load balanced clusters."""
"""Internal check for session support on clusters."""
if self._settings.load_balanced:
# Sessions never time out in load balanced mode.
return float("inf")
session_timeout = self._description.logical_session_timeout_minutes
if session_timeout is None:
# Maybe we need an initial scan? Can raise ServerSelectionError.
@ -537,12 +544,7 @@ class Topology(object):
def get_server_session(self):
"""Start or resume a server session, or raise ConfigurationError."""
with self._lock:
# Sessions are always supported in load balanced mode.
if not self._settings.load_balanced:
session_timeout = self._check_session_support()
else:
# Sessions never time out in load balanced mode.
session_timeout = float("inf")
session_timeout = self._check_session_support()
return self._session_pool.get_server_session(session_timeout)
def return_server_session(self, server_session, lock):

View File

@ -18,20 +18,28 @@ import copy
import sys
import time
from io import BytesIO
from typing import Set
from typing import Any, Callable, List, Set, Tuple
from pymongo.mongo_client import MongoClient
sys.path[0:0] = [""]
from test import IntegrationTest, SkipTest, client_context, unittest
from test.utils import EventListener, rs_or_single_client, wait_until
from test.utils import (
EventListener,
ExceptionCatchingThread,
rs_or_single_client,
wait_until,
)
from bson import DBRef
from gridfs import GridFS, GridFSBucket
from pymongo import ASCENDING, IndexModel, InsertOne, monitoring
from pymongo.command_cursor import CommandCursor
from pymongo.common import _MAX_END_SESSIONS
from pymongo.cursor import Cursor
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
from pymongo.operations import UpdateOne
from pymongo.read_concern import ReadConcern
@ -171,6 +179,63 @@ class TestSession(IntegrationTest):
"%s did not return implicit session to pool" % (f.__name__,),
)
def test_implicit_sessions_checkout(self):
# "To confirm that implicit sessions only allocate their server session after a
# successful connection checkout" test from Driver Sessions Spec.
succeeded = False
failures = 0
for _ in range(5):
listener = EventListener()
client = rs_or_single_client(
event_listeners=[listener], maxPoolSize=1, retryWrites=True
)
cursor = client.db.test.find({})
ops: List[Tuple[Callable, List[Any]]] = [
(client.db.test.find_one, [{"_id": 1}]),
(client.db.test.delete_one, [{}]),
(client.db.test.update_one, [{}, {"$set": {"x": 2}}]),
(client.db.test.bulk_write, [[UpdateOne({}, {"$set": {"x": 2}})]]),
(client.db.test.find_one_and_delete, [{}]),
(client.db.test.find_one_and_update, [{}, {"$set": {"x": 1}}]),
(client.db.test.find_one_and_replace, [{}, {}]),
(client.db.test.aggregate, [[{"$limit": 1}]]),
(client.db.test.find, []),
(client.server_info, [{}]),
(client.db.aggregate, [[{"$listLocalSessions": {}}, {"$limit": 1}]]),
(cursor.distinct, ["_id"]),
(client.db.list_collections, []),
]
threads = []
listener.results.clear()
def thread_target(op, *args):
res = op(*args)
if isinstance(res, (Cursor, CommandCursor)):
list(res)
for op, args in ops:
threads.append(
ExceptionCatchingThread(
target=thread_target, args=[op, *args], name=op.__name__
)
)
threads[-1].start()
self.assertEqual(len(threads), len(ops))
for thread in threads:
thread.join()
self.assertIsNone(thread.exc)
client.close()
lsid_set = set()
for i in listener.results["started"]:
if i.command.get("lsid"):
lsid_set.add(i.command.get("lsid")["id"])
if len(lsid_set) == 1:
succeeded = True
else:
failures += 1
print(failures)
self.assertTrue(succeeded)
def test_pool_lifo(self):
# "Pool is LIFO" test from Driver Sessions Spec.
a = self.client.start_session()