PYTHON-1332 - Send lsid with all commands

This commit is contained in:
A. Jesse Jiryu Davis 2017-09-20 22:56:26 -04:00
parent c1ec855060
commit 6fa2e4047f
6 changed files with 155 additions and 107 deletions

View File

@ -149,6 +149,13 @@ class _ServerSessionPool(collections.deque):
This class is not thread-safe, access it while holding the Topology lock.
"""
def get_server_session(self, session_timeout_minutes):
# Although the Driver Sessions Spec says we only clear stale sessions
# in return_server_session, PyMongo can't take a lock when returning
# sessions from a __del__ method (like in Cursor.__die), so it can't
# clear stale sessions there. In case many sessions were returned via
# __del__, check for stale sessions here too.
self._clear_stale(session_timeout_minutes)
# The most recently used sessions are on the left.
while self:
s = self.popleft()
@ -158,6 +165,11 @@ class _ServerSessionPool(collections.deque):
return _ServerSession()
def return_server_session(self, server_session, session_timeout_minutes):
self._clear_stale(session_timeout_minutes)
if not server_session.timed_out(session_timeout_minutes):
self.appendleft(server_session)
def _clear_stale(self, session_timeout_minutes):
# Clear stale sessions. The least recently used are on the right.
while self:
if self[-1].timed_out(session_timeout_minutes):
@ -165,6 +177,3 @@ class _ServerSessionPool(collections.deque):
else:
# The remaining sessions also haven't timed out.
break
if not server_session.timed_out(session_timeout_minutes):
self.appendleft(server_session)

View File

@ -1398,6 +1398,9 @@ class Collection(common.BaseObject):
>>> for batch in cursor:
... print(bson.decode_all(batch))
Unlike most PyMongo methods, this method sends no session id to the
server.
.. versionadded:: 3.6
"""
return RawBatchCursor(self, *args, **kwargs)
@ -1443,6 +1446,9 @@ class Collection(common.BaseObject):
- `**kwargs`: additional options for the parallelCollectionScan
command can be passed as keyword arguments.
Unlike most PyMongo methods, this method sends no session id to the
server unless an explicit ``session`` parameter is passed.
.. note:: Requires server version **>= 2.5.5**.
.. versionchanged:: 3.6
@ -1461,14 +1467,19 @@ class Collection(common.BaseObject):
('numCursors', num_cursors)])
cmd.update(kwargs)
s = self.__database.client._ensure_session(session)
with self._socket_for_reads() as (sock_info, slave_ok):
result = self._command(sock_info, cmd, slave_ok,
read_concern=self.read_concern,
session=s)
# Avoid auto-injecting a session.
result = sock_info.command(
self.__database.name,
cmd,
slave_ok,
self.read_preference,
self.codec_options,
read_concern=self.read_concern,
session=session)
return [CommandCursor(self, cursor['cursor'], sock_info.address,
session=s, session_owned=session is None)
session=session, explicit_session=True)
for cursor in result['cursors']]
def _count(self, cmd, collation=None, session=None):
@ -1890,20 +1901,20 @@ class Collection(common.BaseObject):
with self._socket_for_primary_reads() as (sock_info, slave_ok):
cmd = SON([("listIndexes", self.__name), ("cursor", {})])
if sock_info.max_wire_version > 2:
s = self.__database.client._ensure_session(session)
try:
cursor = self._command(sock_info, cmd, slave_ok,
ReadPreference.PRIMARY,
codec_options,
session=s)["cursor"]
except OperationFailure as exc:
# Ignore NamespaceNotFound errors to match the behavior
# of reading from *.system.indexes.
if exc.code != 26:
raise
cursor = {'id': 0, 'firstBatch': []}
return CommandCursor(coll, cursor, sock_info.address,
session=s, session_owned=session is None)
with self.__database.client._tmp_session(session, False) as s:
try:
cursor = self._command(sock_info, cmd, slave_ok,
ReadPreference.PRIMARY,
codec_options,
session=s)["cursor"]
except OperationFailure as exc:
# Ignore NamespaceNotFound errors to match the behavior
# of reading from *.system.indexes.
if exc.code != 26:
raise
cursor = {'id': 0, 'firstBatch': []}
return CommandCursor(coll, cursor, sock_info.address, session=s,
explicit_session=session is not None)
else:
namespace = _UJOIN % (self.__database.name, "system.indexes")
res = helpers._first_batch(
@ -1995,7 +2006,7 @@ class Collection(common.BaseObject):
return options
def _aggregate(self, pipeline, cursor_class, first_batch_size, session,
**kwargs):
explicit_session, **kwargs):
common.validate_list('pipeline', pipeline)
if "explain" in kwargs:
@ -2028,47 +2039,56 @@ class Collection(common.BaseObject):
cmd['writeConcern'] = self.write_concern.document
cmd.update(kwargs)
session_owned = session is None
s = self.__database.client._ensure_session(session)
try:
# Apply this Collection's read concern if $out is not in the
# pipeline.
if sock_info.max_wire_version >= 4 and 'readConcern' not in cmd:
if dollar_out:
result = self._command(sock_info, cmd, slave_ok,
parse_write_concern_error=True,
collation=collation,
session=s)
else:
result = self._command(sock_info, cmd, slave_ok,
read_concern=self.read_concern,
collation=collation,
session=s)
# Apply this Collection's read concern if $out is not in the
# pipeline.
if sock_info.max_wire_version >= 4 and 'readConcern' not in cmd:
if dollar_out:
# Avoid auto-injecting a session.
result = sock_info.command(
self.__database.name,
cmd,
slave_ok,
self.read_preference,
self.codec_options,
parse_write_concern_error=True,
collation=collation,
session=session)
else:
result = self._command(sock_info, cmd, slave_ok,
parse_write_concern_error=dollar_out,
collation=collation,
session=s)
result = sock_info.command(
self.__database.name,
cmd,
slave_ok,
ReadPreference.PRIMARY,
self.codec_options,
read_concern=self.read_concern,
collation=collation,
session=session)
else:
result = sock_info.command(
self.__database.name,
cmd,
slave_ok,
self.read_preference,
self.codec_options,
parse_write_concern_error=dollar_out,
collation=collation,
session=session)
if "cursor" in result:
cursor = result["cursor"]
else:
# Pre-MongoDB 2.6. Fake a cursor.
cursor = {
"id": 0,
"firstBatch": result["result"],
"ns": self.full_name,
}
if "cursor" in result:
cursor = result["cursor"]
else:
# Pre-MongoDB 2.6. Fake a cursor.
cursor = {
"id": 0,
"firstBatch": result["result"],
"ns": self.full_name,
}
return cursor_class(
self, cursor, sock_info.address,
batch_size=batch_size or 0,
max_await_time_ms=max_await_time_ms,
session=s, session_owned=session_owned)
except Exception:
if session_owned:
s.end_session()
raise
return cursor_class(
self, cursor, sock_info.address,
batch_size=batch_size or 0,
max_await_time_ms=max_await_time_ms,
session=session, explicit_session=explicit_session)
def aggregate(self, pipeline, session=None, **kwargs):
"""Perform an aggregation using the aggregation framework on this
@ -2148,13 +2168,15 @@ class Collection(common.BaseObject):
.. _aggregate command:
https://docs.mongodb.com/manual/reference/command/aggregate
"""
return self._aggregate(pipeline,
CommandCursor,
kwargs.get('batchSize'),
session=session,
**kwargs)
with self.__database.client._tmp_session(session, close=False) as s:
return self._aggregate(pipeline,
CommandCursor,
kwargs.get('batchSize'),
session=s,
explicit_session=session is not None,
**kwargs)
def aggregate_raw_batches(self, pipeline, session=None, **kwargs):
def aggregate_raw_batches(self, pipeline, **kwargs):
"""Perform an aggregation and retrieve batches of raw BSON.
Takes the same parameters as :meth:`aggregate` but returns a
@ -2171,13 +2193,13 @@ class Collection(common.BaseObject):
>>> for batch in cursor:
... print(bson.decode_all(batch))
Unlike most PyMongo methods, this method sends no session id to the
server.
.. versionadded:: 3.6
"""
return self._aggregate(pipeline,
RawBatchCommandCursor,
0,
session=session,
**kwargs)
return self._aggregate(pipeline, RawBatchCommandCursor, 0,
None, False, **kwargs)
def watch(self, pipeline=None, full_document='default', resume_after=None,
max_await_time_ms=None, batch_size=None, collation=None):

View File

@ -36,7 +36,7 @@ class CommandCursor(object):
def __init__(self, collection, cursor_info, address, retrieved=0,
batch_size=0, max_await_time_ms=None, session=None,
session_owned=False):
explicit_session=False):
"""Create a new command cursor.
The parameter 'retrieved' is unused.
@ -47,21 +47,24 @@ class CommandCursor(object):
self.__address = address
self.__data = deque(cursor_info['firstBatch'])
self.__batch_size = batch_size
if (not isinstance(max_await_time_ms, integer_types)
and max_await_time_ms is not None):
raise TypeError("max_await_time_ms must be an integer or None")
self.__max_await_time_ms = max_await_time_ms
self.__session = session
self.__explicit_session = explicit_session
self.__killed = (self.__id == 0)
if self.__killed:
self.__end_session(True)
if "ns" in cursor_info:
self.__ns = cursor_info["ns"]
else:
self.__ns = collection.full_name
self.__session = session
self.__session_owned = session_owned
self.batch_size(batch_size)
if (not isinstance(max_await_time_ms, integer_types)
and max_await_time_ms is not None):
raise TypeError("max_await_time_ms must be an integer or None")
def __del__(self):
if self.__id and not self.__killed:
self.__die()
@ -80,7 +83,10 @@ class CommandCursor(object):
self.__collection.database.client.close_cursor(
self.__id, address)
self.__killed = True
if self.__session and self.__session_owned:
self.__end_session(synchronous)
def __end_session(self, synchronous):
if self.__session and not self.__explicit_session:
self.__session._end_session(lock=synchronous)
self.__session = None
@ -116,6 +122,10 @@ class CommandCursor(object):
def __send_message(self, operation):
"""Send a getmore message and handle the response.
"""
def kill():
self.__killed = True
self.__end_session(True)
client = self.__collection.database.client
listeners = client._event_listeners
publish = listeners.enabled_for_commands
@ -127,7 +137,7 @@ class CommandCursor(object):
# or to another server. It can cause a _pinValue
# assertion on some server releases if we get here
# due to a socket timeout.
self.__killed = True
kill()
raise
cmd_duration = response.duration
@ -144,7 +154,7 @@ class CommandCursor(object):
helpers._check_command_response(doc['data'][0])
except OperationFailure as exc:
self.__killed = True
kill()
if publish:
duration = (datetime.datetime.now() - start) + cmd_duration
@ -155,7 +165,7 @@ class CommandCursor(object):
except NotMasterError as exc:
# Don't send kill cursors to another server after a "not master"
# error. It's completely pointless.
self.__killed = True
kill()
if publish:
duration = (datetime.datetime.now() - start) + cmd_duration
@ -191,7 +201,7 @@ class CommandCursor(object):
duration, res, "getMore", rqst_id, self.__address)
if self.__id == 0:
self.__killed = True
kill()
self.__data = deque(documents)
def _unpack_response(self, response, cursor_id, codec_options):
@ -219,6 +229,7 @@ class CommandCursor(object):
self.__max_await_time_ms))
else: # Cursor id is zero nothing else to return
self.__killed = True
self.__end_session(True)
return len(self.__data)
@ -258,7 +269,7 @@ class CommandCursor(object):
.. versionadded:: 3.6
"""
if not self.__session_owned:
if self.__explicit_session:
return self.__session
def __iter__(self):
@ -288,7 +299,8 @@ class RawBatchCommandCursor(CommandCursor):
_getmore_class = _RawBatchGetMore
def __init__(self, collection, cursor_info, address, retrieved=0,
batch_size=0, max_await_time_ms=None, session=None):
batch_size=0, max_await_time_ms=None, session=None,
explicit_session=False):
"""Create a new cursor / iterator over raw batches of BSON data.
Should not be called directly by application developers -
@ -300,7 +312,7 @@ class RawBatchCommandCursor(CommandCursor):
assert not cursor_info.get('firstBatch')
super(RawBatchCommandCursor, self).__init__(
collection, cursor_info, address, retrieved, batch_size,
max_await_time_ms, session)
max_await_time_ms, session, explicit_session)
def _unpack_response(self, response, cursor_id, codec_options):
return helpers._raw_response(response, cursor_id)

View File

@ -127,7 +127,13 @@ class Cursor(object):
self.__id = None
self.__exhaust = False
self.__exhaust_mgr = None
self.__session = None
if session:
self.__session = session
self.__explicit_session = True
else:
self.__session = None
self.__explicit_session = False
spec = filter
if spec is None:
@ -178,12 +184,6 @@ class Cursor(object):
self.__return_key = return_key
self.__show_record_id = show_record_id
self.__snapshot = snapshot
if session:
self.__session = session
self.__session_owned = False
else:
self.__session = collection.database.client._ensure_session(session)
self.__session_owned = True
self.__set_hint(hint)
# Exhaust cursor support
@ -268,10 +268,10 @@ class Cursor(object):
def _clone(self, deepcopy=True, base=None):
"""Internal clone helper."""
if not base:
if self.__session_owned:
base = self._clone_base(None)
else:
if self.__explicit_session:
base = self._clone_base(self.__session)
else:
base = self._clone_base(None)
values_to_clone = ("spec", "projection", "skip", "limit",
"max_time_ms", "max_await_time_ms", "comment",
@ -312,7 +312,7 @@ class Cursor(object):
if self.__exhaust and self.__exhaust_mgr:
self.__exhaust_mgr.close()
self.__killed = True
if self.__session and self.__session_owned:
if self.__session and not self.__explicit_session:
self.__session._end_session(lock=synchronous)
self.__session = None
@ -1069,10 +1069,8 @@ class Cursor(object):
if len(self.__data) or self.__killed:
return len(self.__data)
# If a previous call to __die() has cleared the session.
if not self.__session:
self.__session = self.__collection.database.client._ensure_session()
self.__session_owned = True
if self.__id is None: # Query
q = self._query_class(self.__query_flags,
@ -1166,7 +1164,7 @@ class Cursor(object):
.. versionadded:: 3.6
"""
if not self.__session_owned:
if self.__explicit_session:
return self.__session
def __iter__(self):

View File

@ -551,11 +551,11 @@ class Database(common.BaseObject):
if sock_info.max_wire_version > 2:
coll = self["$cmd"]
s = self.__client._ensure_session(session)
cursor = self._command(
sock_info, cmd, slave_okay, session=s,
session_owned=session is None)["cursor"]
return CommandCursor(coll, cursor, sock_info.address)
with self.__client._tmp_session(session, close=False) as s:
cursor = self._command(
sock_info, cmd, slave_okay, session=s)["cursor"]
return CommandCursor(coll, cursor, sock_info.address, session=s,
explicit_session=session is not None)
else:
coll = self["system.namespaces"]
res = _first_batch(sock_info, coll.database.name, coll.name,

View File

@ -1246,7 +1246,7 @@ class MongoClient(common.BaseObject):
return None
@contextlib.contextmanager
def _tmp_session(self, session):
def _tmp_session(self, session, close=True):
"""If provided session is None, lend a temporary session."""
if session:
# Don't call end_session.
@ -1254,10 +1254,17 @@ class MongoClient(common.BaseObject):
return
s = self._ensure_session(session)
if s:
if s and close:
with s:
# Call end_session when we exit this scope.
yield s
elif s:
try:
# Only call end_session on error.
yield s
except Exception:
s.end_session()
raise
else:
yield None