PYTHON-2956 Drivers should check out an implicit session only after checking out a connection (#876)
This commit is contained in:
parent
782c5517e0
commit
b737b843e9
@ -947,9 +947,16 @@ class ClientSession(Generic[_DocumentType]):
|
||||
return self._transaction.opts.read_preference
|
||||
return None
|
||||
|
||||
def _materialize(self):
|
||||
if isinstance(self._server_session, _EmptyServerSession):
|
||||
old = self._server_session
|
||||
self._server_session = self._client._topology.get_server_session()
|
||||
if old.started_retryable_write:
|
||||
self._server_session.inc_transaction_id()
|
||||
|
||||
def _apply_to(self, command, is_retryable, read_preference, sock_info):
|
||||
self._check_ended()
|
||||
|
||||
self._materialize()
|
||||
if self.options.snapshot:
|
||||
self._update_read_concern(command, sock_info)
|
||||
|
||||
@ -1000,6 +1007,20 @@ class ClientSession(Generic[_DocumentType]):
|
||||
raise TypeError("A ClientSession cannot be copied, create a new session instead")
|
||||
|
||||
|
||||
class _EmptyServerSession:
|
||||
__slots__ = "dirty", "started_retryable_write"
|
||||
|
||||
def __init__(self):
|
||||
self.dirty = False
|
||||
self.started_retryable_write = False
|
||||
|
||||
def mark_dirty(self):
|
||||
self.dirty = True
|
||||
|
||||
def inc_transaction_id(self):
|
||||
self.started_retryable_write = True
|
||||
|
||||
|
||||
class _ServerSession(object):
|
||||
def __init__(self, generation):
|
||||
# Ensure id is type 4, regardless of CodecOptions.uuid_representation.
|
||||
|
||||
@ -66,6 +66,7 @@ from pymongo import (
|
||||
)
|
||||
from pymongo.change_stream import ChangeStream, ClusterChangeStream
|
||||
from pymongo.client_options import ClientOptions
|
||||
from pymongo.client_session import _EmptyServerSession
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
from pymongo.errors import (
|
||||
AutoReconnect,
|
||||
@ -1601,7 +1602,11 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
def __start_session(self, implicit, **kwargs):
|
||||
# Raises ConfigurationError if sessions are not supported.
|
||||
server_session = self._get_server_session()
|
||||
if implicit:
|
||||
self._topology._check_implicit_session_support()
|
||||
server_session = _EmptyServerSession()
|
||||
else:
|
||||
server_session = self._get_server_session()
|
||||
opts = client_session.SessionOptions(**kwargs)
|
||||
return client_session.ClientSession(self, server_session, opts, implicit)
|
||||
|
||||
@ -1641,6 +1646,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
def _return_server_session(self, server_session, lock):
|
||||
"""Internal: return a _ServerSession to the pool."""
|
||||
if isinstance(server_session, _EmptyServerSession):
|
||||
return
|
||||
return self._topology.return_server_session(server_session, lock)
|
||||
|
||||
def _ensure_session(self, session=None):
|
||||
|
||||
@ -514,8 +514,15 @@ class Topology(object):
|
||||
with self._lock:
|
||||
return self._session_pool.pop_all()
|
||||
|
||||
def _check_implicit_session_support(self):
|
||||
with self._lock:
|
||||
self._check_session_support()
|
||||
|
||||
def _check_session_support(self):
|
||||
"""Internal check for session support on non-load balanced clusters."""
|
||||
"""Internal check for session support on clusters."""
|
||||
if self._settings.load_balanced:
|
||||
# Sessions never time out in load balanced mode.
|
||||
return float("inf")
|
||||
session_timeout = self._description.logical_session_timeout_minutes
|
||||
if session_timeout is None:
|
||||
# Maybe we need an initial scan? Can raise ServerSelectionError.
|
||||
@ -537,12 +544,7 @@ class Topology(object):
|
||||
def get_server_session(self):
|
||||
"""Start or resume a server session, or raise ConfigurationError."""
|
||||
with self._lock:
|
||||
# Sessions are always supported in load balanced mode.
|
||||
if not self._settings.load_balanced:
|
||||
session_timeout = self._check_session_support()
|
||||
else:
|
||||
# Sessions never time out in load balanced mode.
|
||||
session_timeout = float("inf")
|
||||
session_timeout = self._check_session_support()
|
||||
return self._session_pool.get_server_session(session_timeout)
|
||||
|
||||
def return_server_session(self, server_session, lock):
|
||||
|
||||
@ -18,20 +18,28 @@ import copy
|
||||
import sys
|
||||
import time
|
||||
from io import BytesIO
|
||||
from typing import Set
|
||||
from typing import Any, Callable, List, Set, Tuple
|
||||
|
||||
from pymongo.mongo_client import MongoClient
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, SkipTest, client_context, unittest
|
||||
from test.utils import EventListener, rs_or_single_client, wait_until
|
||||
from test.utils import (
|
||||
EventListener,
|
||||
ExceptionCatchingThread,
|
||||
rs_or_single_client,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
from bson import DBRef
|
||||
from gridfs import GridFS, GridFSBucket
|
||||
from pymongo import ASCENDING, IndexModel, InsertOne, monitoring
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
from pymongo.common import _MAX_END_SESSIONS
|
||||
from pymongo.cursor import Cursor
|
||||
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
|
||||
from pymongo.operations import UpdateOne
|
||||
from pymongo.read_concern import ReadConcern
|
||||
|
||||
|
||||
@ -171,6 +179,63 @@ class TestSession(IntegrationTest):
|
||||
"%s did not return implicit session to pool" % (f.__name__,),
|
||||
)
|
||||
|
||||
def test_implicit_sessions_checkout(self):
|
||||
# "To confirm that implicit sessions only allocate their server session after a
|
||||
# successful connection checkout" test from Driver Sessions Spec.
|
||||
succeeded = False
|
||||
failures = 0
|
||||
for _ in range(5):
|
||||
listener = EventListener()
|
||||
client = rs_or_single_client(
|
||||
event_listeners=[listener], maxPoolSize=1, retryWrites=True
|
||||
)
|
||||
cursor = client.db.test.find({})
|
||||
ops: List[Tuple[Callable, List[Any]]] = [
|
||||
(client.db.test.find_one, [{"_id": 1}]),
|
||||
(client.db.test.delete_one, [{}]),
|
||||
(client.db.test.update_one, [{}, {"$set": {"x": 2}}]),
|
||||
(client.db.test.bulk_write, [[UpdateOne({}, {"$set": {"x": 2}})]]),
|
||||
(client.db.test.find_one_and_delete, [{}]),
|
||||
(client.db.test.find_one_and_update, [{}, {"$set": {"x": 1}}]),
|
||||
(client.db.test.find_one_and_replace, [{}, {}]),
|
||||
(client.db.test.aggregate, [[{"$limit": 1}]]),
|
||||
(client.db.test.find, []),
|
||||
(client.server_info, [{}]),
|
||||
(client.db.aggregate, [[{"$listLocalSessions": {}}, {"$limit": 1}]]),
|
||||
(cursor.distinct, ["_id"]),
|
||||
(client.db.list_collections, []),
|
||||
]
|
||||
threads = []
|
||||
listener.results.clear()
|
||||
|
||||
def thread_target(op, *args):
|
||||
res = op(*args)
|
||||
if isinstance(res, (Cursor, CommandCursor)):
|
||||
list(res)
|
||||
|
||||
for op, args in ops:
|
||||
threads.append(
|
||||
ExceptionCatchingThread(
|
||||
target=thread_target, args=[op, *args], name=op.__name__
|
||||
)
|
||||
)
|
||||
threads[-1].start()
|
||||
self.assertEqual(len(threads), len(ops))
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
self.assertIsNone(thread.exc)
|
||||
client.close()
|
||||
lsid_set = set()
|
||||
for i in listener.results["started"]:
|
||||
if i.command.get("lsid"):
|
||||
lsid_set.add(i.command.get("lsid")["id"])
|
||||
if len(lsid_set) == 1:
|
||||
succeeded = True
|
||||
else:
|
||||
failures += 1
|
||||
print(failures)
|
||||
self.assertTrue(succeeded)
|
||||
|
||||
def test_pool_lifo(self):
|
||||
# "Pool is LIFO" test from Driver Sessions Spec.
|
||||
a = self.client.start_session()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user