diff --git a/pymongo/client_session.py b/pymongo/client_session.py
index 52e886823..3323c5116 100644
--- a/pymongo/client_session.py
+++ b/pymongo/client_session.py
@@ -149,6 +149,13 @@ class _ServerSessionPool(collections.deque):
     This class is not thread-safe, access it while holding the Topology lock.
     """
     def get_server_session(self, session_timeout_minutes):
+        # Although the Driver Sessions Spec says we only clear stale sessions
+        # in return_server_session, PyMongo can't take a lock when returning
+        # sessions from a __del__ method (like in Cursor.__die), so it can't
+        # clear stale sessions there. In case many sessions were returned via
+        # __del__, check for stale sessions here too.
+        self._clear_stale(session_timeout_minutes)
+
         # The most recently used sessions are on the left.
         while self:
             s = self.popleft()
@@ -158,6 +165,11 @@ class _ServerSessionPool(collections.deque):
         return _ServerSession()
 
     def return_server_session(self, server_session, session_timeout_minutes):
+        self._clear_stale(session_timeout_minutes)
+        if not server_session.timed_out(session_timeout_minutes):
+            self.appendleft(server_session)
+
+    def _clear_stale(self, session_timeout_minutes):
         # Clear stale sessions. The least recently used are on the right.
         while self:
             if self[-1].timed_out(session_timeout_minutes):
@@ -165,6 +177,3 @@ class _ServerSessionPool(collections.deque):
             else:
                 # The remaining sessions also haven't timed out.
                 break
-
-        if not server_session.timed_out(session_timeout_minutes):
-            self.appendleft(server_session)
diff --git a/pymongo/collection.py b/pymongo/collection.py
index 01519a8de..03ddc3fca 100644
--- a/pymongo/collection.py
+++ b/pymongo/collection.py
@@ -1398,6 +1398,9 @@ class Collection(common.BaseObject):
           >>> for batch in cursor:
           ...     print(bson.decode_all(batch))
 
+        Unlike most PyMongo methods, this method sends no session id to the
+        server.
+
         .. versionadded:: 3.6
         """
         return RawBatchCursor(self, *args, **kwargs)
@@ -1443,6 +1446,9 @@ class Collection(common.BaseObject):
           - `**kwargs`: additional options for the parallelCollectionScan
             command can be passed as keyword arguments.
 
+        Unlike most PyMongo methods, this method sends no session id to the
+        server unless an explicit ``session`` parameter is passed.
+
         .. note:: Requires server version **>= 2.5.5**.
 
         .. versionchanged:: 3.6
@@ -1461,14 +1467,19 @@ class Collection(common.BaseObject):
                    ('numCursors', num_cursors)])
         cmd.update(kwargs)
 
-        s = self.__database.client._ensure_session(session)
         with self._socket_for_reads() as (sock_info, slave_ok):
-            result = self._command(sock_info, cmd, slave_ok,
-                                   read_concern=self.read_concern,
-                                   session=s)
+            # Avoid auto-injecting a session.
+            result = sock_info.command(
+                self.__database.name,
+                cmd,
+                slave_ok,
+                self.read_preference,
+                self.codec_options,
+                read_concern=self.read_concern,
+                session=session)
 
         return [CommandCursor(self, cursor['cursor'], sock_info.address,
-                              session=s, session_owned=session is None)
+                              session=session, explicit_session=True)
                 for cursor in result['cursors']]
 
     def _count(self, cmd, collation=None, session=None):
@@ -1890,20 +1901,20 @@ class Collection(common.BaseObject):
         with self._socket_for_primary_reads() as (sock_info, slave_ok):
             cmd = SON([("listIndexes", self.__name), ("cursor", {})])
             if sock_info.max_wire_version > 2:
-                s = self.__database.client._ensure_session(session)
-                try:
-                    cursor = self._command(sock_info, cmd, slave_ok,
-                                           ReadPreference.PRIMARY,
-                                           codec_options,
-                                           session=s)["cursor"]
-                except OperationFailure as exc:
-                    # Ignore NamespaceNotFound errors to match the behavior
-                    # of reading from *.system.indexes.
-                    if exc.code != 26:
-                        raise
-                    cursor = {'id': 0, 'firstBatch': []}
-                return CommandCursor(coll, cursor, sock_info.address,
-                                     session=s, session_owned=session is None)
+                with self.__database.client._tmp_session(session, False) as s:
+                    try:
+                        cursor = self._command(sock_info, cmd, slave_ok,
+                                               ReadPreference.PRIMARY,
+                                               codec_options,
+                                               session=s)["cursor"]
+                    except OperationFailure as exc:
+                        # Ignore NamespaceNotFound errors to match the behavior
+                        # of reading from *.system.indexes.
+                        if exc.code != 26:
+                            raise
+                        cursor = {'id': 0, 'firstBatch': []}
+                return CommandCursor(coll, cursor, sock_info.address, session=s,
+                                     explicit_session=session is not None)
             else:
                 namespace = _UJOIN % (self.__database.name, "system.indexes")
                 res = helpers._first_batch(
@@ -1995,7 +2006,7 @@ class Collection(common.BaseObject):
         return options
 
     def _aggregate(self, pipeline, cursor_class, first_batch_size, session,
-                   **kwargs):
+                   explicit_session, **kwargs):
         common.validate_list('pipeline', pipeline)
 
         if "explain" in kwargs:
@@ -2028,47 +2039,56 @@ class Collection(common.BaseObject):
                 cmd['writeConcern'] = self.write_concern.document
 
             cmd.update(kwargs)
-            session_owned = session is None
-            s = self.__database.client._ensure_session(session)
-            try:
-                # Apply this Collection's read concern if $out is not in the
-                # pipeline.
-                if sock_info.max_wire_version >= 4 and 'readConcern' not in cmd:
-                    if dollar_out:
-                        result = self._command(sock_info, cmd, slave_ok,
-                                               parse_write_concern_error=True,
-                                               collation=collation,
-                                               session=s)
-                    else:
-                        result = self._command(sock_info, cmd, slave_ok,
-                                               read_concern=self.read_concern,
-                                               collation=collation,
-                                               session=s)
+            # Apply this Collection's read concern if $out is not in the
+            # pipeline.
+            if sock_info.max_wire_version >= 4 and 'readConcern' not in cmd:
+                if dollar_out:
+                    # Avoid auto-injecting a session.
+                    result = sock_info.command(
+                        self.__database.name,
+                        cmd,
+                        slave_ok,
+                        self.read_preference,
+                        self.codec_options,
+                        parse_write_concern_error=True,
+                        collation=collation,
+                        session=session)
                 else:
-                    result = self._command(sock_info, cmd, slave_ok,
-                                           parse_write_concern_error=dollar_out,
-                                           collation=collation,
-                                           session=s)
+                    result = sock_info.command(
+                        self.__database.name,
+                        cmd,
+                        slave_ok,
+                        self.read_preference,
+                        self.codec_options,
+                        read_concern=self.read_concern,
+                        collation=collation,
+                        session=session)
+            else:
+                result = sock_info.command(
+                    self.__database.name,
+                    cmd,
+                    slave_ok,
+                    self.read_preference,
+                    self.codec_options,
+                    parse_write_concern_error=dollar_out,
+                    collation=collation,
+                    session=session)
 
-                if "cursor" in result:
-                    cursor = result["cursor"]
-                else:
-                    # Pre-MongoDB 2.6. Fake a cursor.
-                    cursor = {
-                        "id": 0,
-                        "firstBatch": result["result"],
-                        "ns": self.full_name,
-                    }
+            if "cursor" in result:
+                cursor = result["cursor"]
+            else:
+                # Pre-MongoDB 2.6. Fake a cursor.
+                cursor = {
+                    "id": 0,
+                    "firstBatch": result["result"],
+                    "ns": self.full_name,
+                }
 
-                return cursor_class(
-                    self, cursor, sock_info.address,
-                    batch_size=batch_size or 0,
-                    max_await_time_ms=max_await_time_ms,
-                    session=s, session_owned=session_owned)
-            except Exception:
-                if session_owned:
-                    s.end_session()
-                raise
+            return cursor_class(
+                self, cursor, sock_info.address,
+                batch_size=batch_size or 0,
+                max_await_time_ms=max_await_time_ms,
+                session=session, explicit_session=explicit_session)
 
     def aggregate(self, pipeline, session=None, **kwargs):
         """Perform an aggregation using the aggregation framework on this
@@ -2148,13 +2168,15 @@ class Collection(common.BaseObject):
         .. _aggregate command:
             https://docs.mongodb.com/manual/reference/command/aggregate
         """
-        return self._aggregate(pipeline,
-                               CommandCursor,
-                               kwargs.get('batchSize'),
-                               session=session,
-                               **kwargs)
+        with self.__database.client._tmp_session(session, close=False) as s:
+            return self._aggregate(pipeline,
+                                   CommandCursor,
+                                   kwargs.get('batchSize'),
+                                   session=s,
+                                   explicit_session=session is not None,
+                                   **kwargs)
 
-    def aggregate_raw_batches(self, pipeline, session=None, **kwargs):
+    def aggregate_raw_batches(self, pipeline, **kwargs):
         """Perform an aggregation and retrieve batches of raw BSON.
 
         Takes the same parameters as :meth:`aggregate` but returns a
@@ -2171,13 +2193,13 @@ class Collection(common.BaseObject):
           >>> for batch in cursor:
           ...     print(bson.decode_all(batch))
 
+        Unlike most PyMongo methods, this method sends no session id to the
+        server.
+
         .. versionadded:: 3.6
         """
-        return self._aggregate(pipeline,
-                               RawBatchCommandCursor,
-                               0,
-                               session=session,
-                               **kwargs)
+        return self._aggregate(pipeline, RawBatchCommandCursor, 0,
+                               None, False, **kwargs)
 
     def watch(self, pipeline=None, full_document='default', resume_after=None,
               max_await_time_ms=None, batch_size=None, collation=None):
diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py
index e16d1920a..3cf1e9206 100644
--- a/pymongo/command_cursor.py
+++ b/pymongo/command_cursor.py
@@ -36,7 +36,7 @@ class CommandCursor(object):
 
     def __init__(self, collection, cursor_info, address, retrieved=0,
                  batch_size=0, max_await_time_ms=None, session=None,
-                 session_owned=False):
+                 explicit_session=False):
         """Create a new command cursor.
 
         The parameter 'retrieved' is unused.
@@ -47,21 +47,24 @@ class CommandCursor(object):
         self.__address = address
         self.__data = deque(cursor_info['firstBatch'])
         self.__batch_size = batch_size
-        if (not isinstance(max_await_time_ms, integer_types)
-                and max_await_time_ms is not None):
-            raise TypeError("max_await_time_ms must be an integer or None")
         self.__max_await_time_ms = max_await_time_ms
+        self.__session = session
+        self.__explicit_session = explicit_session
         self.__killed = (self.__id == 0)
+        if self.__killed:
+            self.__end_session(True)
 
         if "ns" in cursor_info:
             self.__ns = cursor_info["ns"]
         else:
             self.__ns = collection.full_name
 
-        self.__session = session
-        self.__session_owned = session_owned
         self.batch_size(batch_size)
 
+        if (not isinstance(max_await_time_ms, integer_types)
+                and max_await_time_ms is not None):
+            raise TypeError("max_await_time_ms must be an integer or None")
+
     def __del__(self):
         if self.__id and not self.__killed:
             self.__die()
@@ -80,7 +83,10 @@ class CommandCursor(object):
                 self.__collection.database.client.close_cursor(
                     self.__id, address)
         self.__killed = True
-        if self.__session and self.__session_owned:
+        self.__end_session(synchronous)
+
+    def __end_session(self, synchronous):
+        if self.__session and not self.__explicit_session:
             self.__session._end_session(lock=synchronous)
             self.__session = None
 
@@ -116,6 +122,10 @@ class CommandCursor(object):
     def __send_message(self, operation):
         """Send a getmore message and handle the response.
         """
+        def kill():
+            self.__killed = True
+            self.__end_session(True)
+
         client = self.__collection.database.client
         listeners = client._event_listeners
         publish = listeners.enabled_for_commands
@@ -127,7 +137,7 @@ class CommandCursor(object):
             # or to another server. It can cause a _pinValue
             # assertion on some server releases if we get here
             # due to a socket timeout.
-            self.__killed = True
+            kill()
             raise
 
         cmd_duration = response.duration
@@ -144,7 +154,7 @@ class CommandCursor(object):
                 helpers._check_command_response(doc['data'][0])
 
         except OperationFailure as exc:
-            self.__killed = True
+            kill()
 
             if publish:
                 duration = (datetime.datetime.now() - start) + cmd_duration
@@ -155,7 +165,7 @@ class CommandCursor(object):
         except NotMasterError as exc:
             # Don't send kill cursors to another server after a "not master"
             # error. It's completely pointless.
-            self.__killed = True
+            kill()
 
             if publish:
                 duration = (datetime.datetime.now() - start) + cmd_duration
@@ -191,7 +201,7 @@ class CommandCursor(object):
                     duration, res, "getMore", rqst_id, self.__address)
 
         if self.__id == 0:
-            self.__killed = True
+            kill()
         self.__data = deque(documents)
 
     def _unpack_response(self, response, cursor_id, codec_options):
@@ -219,6 +229,7 @@ class CommandCursor(object):
                                     self.__max_await_time_ms))
         else:  # Cursor id is zero nothing else to return
             self.__killed = True
+            self.__end_session(True)
 
         return len(self.__data)
 
@@ -258,7 +269,7 @@ class CommandCursor(object):
 
         .. versionadded:: 3.6
         """
-        if not self.__session_owned:
+        if self.__explicit_session:
             return self.__session
 
     def __iter__(self):
@@ -288,7 +299,8 @@ class RawBatchCommandCursor(CommandCursor):
     _getmore_class = _RawBatchGetMore
 
     def __init__(self, collection, cursor_info, address, retrieved=0,
-                 batch_size=0, max_await_time_ms=None, session=None):
+                 batch_size=0, max_await_time_ms=None, session=None,
+                 explicit_session=False):
         """Create a new cursor / iterator over raw batches of BSON data.
 
         Should not be called directly by application developers -
@@ -300,7 +312,7 @@ class RawBatchCommandCursor(CommandCursor):
         assert not cursor_info.get('firstBatch')
         super(RawBatchCommandCursor, self).__init__(
             collection, cursor_info, address, retrieved, batch_size,
-            max_await_time_ms, session)
+            max_await_time_ms, session, explicit_session)
 
     def _unpack_response(self, response, cursor_id, codec_options):
         return helpers._raw_response(response, cursor_id)
diff --git a/pymongo/cursor.py b/pymongo/cursor.py
index 86f3f3dc4..737ea7bb5 100644
--- a/pymongo/cursor.py
+++ b/pymongo/cursor.py
@@ -127,7 +127,13 @@ class Cursor(object):
         self.__id = None
         self.__exhaust = False
         self.__exhaust_mgr = None
-        self.__session = None
+
+        if session:
+            self.__session = session
+            self.__explicit_session = True
+        else:
+            self.__session = None
+            self.__explicit_session = False
 
         spec = filter
         if spec is None:
@@ -178,12 +184,6 @@ class Cursor(object):
         self.__return_key = return_key
         self.__show_record_id = show_record_id
         self.__snapshot = snapshot
-        if session:
-            self.__session = session
-            self.__session_owned = False
-        else:
-            self.__session = collection.database.client._ensure_session(session)
-            self.__session_owned = True
         self.__set_hint(hint)
 
         # Exhaust cursor support
@@ -268,10 +268,10 @@ class Cursor(object):
     def _clone(self, deepcopy=True, base=None):
         """Internal clone helper."""
         if not base:
-            if self.__session_owned:
-                base = self._clone_base(None)
-            else:
+            if self.__explicit_session:
                 base = self._clone_base(self.__session)
+            else:
+                base = self._clone_base(None)
 
         values_to_clone = ("spec", "projection", "skip", "limit",
                            "max_time_ms", "max_await_time_ms", "comment",
@@ -312,7 +312,7 @@ class Cursor(object):
         if self.__exhaust and self.__exhaust_mgr:
             self.__exhaust_mgr.close()
         self.__killed = True
-        if self.__session and self.__session_owned:
+        if self.__session and not self.__explicit_session:
             self.__session._end_session(lock=synchronous)
             self.__session = None
 
@@ -1069,10 +1069,8 @@ class Cursor(object):
         if len(self.__data) or self.__killed:
             return len(self.__data)
 
-        # If a previous call to __die() has cleared the session.
         if not self.__session:
             self.__session = self.__collection.database.client._ensure_session()
-            self.__session_owned = True
 
         if self.__id is None:  # Query
             q = self._query_class(self.__query_flags,
@@ -1166,7 +1164,7 @@ class Cursor(object):
 
         .. versionadded:: 3.6
         """
-        if not self.__session_owned:
+        if self.__explicit_session:
             return self.__session
 
     def __iter__(self):
diff --git a/pymongo/database.py b/pymongo/database.py
index f4e02d1e6..cec0cd06f 100644
--- a/pymongo/database.py
+++ b/pymongo/database.py
@@ -551,11 +551,11 @@ class Database(common.BaseObject):
 
         if sock_info.max_wire_version > 2:
             coll = self["$cmd"]
-            s = self.__client._ensure_session(session)
-            cursor = self._command(
-                sock_info, cmd, slave_okay, session=s,
-                session_owned=session is None)["cursor"]
-            return CommandCursor(coll, cursor, sock_info.address)
+            with self.__client._tmp_session(session, close=False) as s:
+                cursor = self._command(
+                    sock_info, cmd, slave_okay, session=s)["cursor"]
+                return CommandCursor(coll, cursor, sock_info.address, session=s,
+                                     explicit_session=session is not None)
         else:
             coll = self["system.namespaces"]
             res = _first_batch(sock_info, coll.database.name, coll.name,
diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py
index 0a7922ae0..0fab06e17 100644
--- a/pymongo/mongo_client.py
+++ b/pymongo/mongo_client.py
@@ -1246,7 +1246,7 @@ class MongoClient(common.BaseObject):
             return None
 
     @contextlib.contextmanager
-    def _tmp_session(self, session):
+    def _tmp_session(self, session, close=True):
         """If provided session is None, lend a temporary session."""
         if session:
             # Don't call end_session.
@@ -1254,10 +1254,17 @@ class MongoClient(common.BaseObject):
             return
 
         s = self._ensure_session(session)
-        if s:
+        if s and close:
             with s:
                 # Call end_session when we exit this scope.
                 yield s
+        elif s:
+            try:
+                # Only call end_session on error.
+                yield s
+            except Exception:
+                s.end_session()
+                raise
         else:
             yield None
 