PYTHON-1332 Session may only be used by the client that started it.

Centralize $clusterTime receiving.
This commit is contained in:
Shane Harvey 2017-11-16 16:10:43 -08:00
parent 27d94755df
commit b669cd86dc
9 changed files with 59 additions and 44 deletions

View File

@ -240,8 +240,9 @@ class _Bulk(object):
listeners = client._event_listeners
with self.collection.database.client._tmp_session(session) as s:
# sock_info.command checks auth, but we use sock_info.write_command.
sock_info.check_session_auth_matches(s)
# sock_info.command validates the session, but we use
# sock_info.write_command.
sock_info.validate_session(client, s)
for run in generator:
cmd = SON([(_COMMANDS[run.op_type], self.collection.name),
('ordered', self.ordered)])
@ -267,10 +268,7 @@ class _Bulk(object):
if not to_send:
raise InvalidOperation("cannot do an empty bulk write")
result = bwc.write_command(request_id, msg, to_send)
client._receive_cluster_time(result)
if s is not None:
s._advance_cluster_time(result.get("$clusterTime"))
s._advance_operation_time(result.get("operationTime"))
client._receive_cluster_time(result, s)
results.append((idx_offset, result))
if self.ordered and "writeErrors" in result:
break

View File

@ -91,7 +91,7 @@ class ClientSession(object):
def _end_session(self, lock):
if self._server_session is not None:
self.client._return_server_session(self._server_session, lock)
self._client._return_server_session(self._server_session, lock)
self._server_session = None
def __enter__(self):

View File

@ -154,12 +154,7 @@ class CommandCursor(object):
self.__collection.codec_options)
if from_command:
first = docs[0]
client._receive_cluster_time(first)
if self.__session is not None:
self.__session._advance_cluster_time(
first.get('$clusterTime'))
self.__session._advance_operation_time(
first.get('operationTime'))
client._receive_cluster_time(first, self.__session)
helpers._check_command_response(first)
except OperationFailure as exc:

View File

@ -968,12 +968,7 @@ class Cursor(object):
codec_options=self.__codec_options)
if from_command:
first = docs[0]
client._receive_cluster_time(first)
if self.__session is not None:
self.__session._advance_cluster_time(
first.get("$clusterTime"))
self.__session._advance_operation_time(
first.get("operationTime"))
client._receive_cluster_time(first, self.__session)
helpers._check_command_response(first)
except OperationFailure as exc:
self.__killed = True

View File

@ -295,7 +295,7 @@ class _Query(object):
'Specifying a collation is unsupported with a max wire '
'version of %d.' % (sock_info.max_wire_version,))
sock_info.check_session_auth_matches(self.session)
sock_info.validate_session(self.client, self.session)
return use_find_cmd
@ -369,7 +369,7 @@ class _GetMore(object):
self.max_await_time_ms = max_await_time_ms
def use_command(self, sock_info, exhaust):
sock_info.check_session_auth_matches(self.session)
sock_info.validate_session(self.client, self.session)
return sock_info.max_wire_version >= 4 and not exhaust
def as_command(self):

View File

@ -1356,8 +1356,12 @@ class MongoClient(common.BaseObject):
if cluster_time:
command['$clusterTime'] = cluster_time
def _receive_cluster_time(self, reply):
self._topology.receive_cluster_time(reply.get('$clusterTime'))
def _receive_cluster_time(self, reply, session):
cluster_time = reply.get('$clusterTime')
self._topology.receive_cluster_time(cluster_time)
if session is not None:
session._advance_cluster_time(cluster_time)
session._advance_operation_time(reply.get("operationTime"))
def server_info(self, session=None):
"""Get information about the MongoDB server we're connected to.

View File

@ -129,12 +129,7 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
response_doc = unpacked_docs[0]
if client:
client._receive_cluster_time(response_doc)
if session:
session._advance_cluster_time(
response_doc.get('$clusterTime'))
session._advance_operation_time(
response_doc.get('operationTime'))
client._receive_cluster_time(response_doc, session)
if check:
helpers._check_command_response(
response_doc, None, allowable_errors,

View File

@ -461,7 +461,7 @@ class SocketInfo(object):
- `client`: optional MongoClient for gossipping $clusterTime.
- `retryable_write`: True if this command is a retryable write.
"""
self.check_session_auth_matches(session)
self.validate_session(client, session)
if (read_concern and self.max_wire_version < 4
and not read_concern.ok_for_legacy):
raise ConfigurationError(
@ -591,11 +591,21 @@ class SocketInfo(object):
auth.authenticate(credentials, self)
self.authset.add(credentials)
def check_session_auth_matches(self, session):
"""Raise error if a ClientSession is logged in as a different user."""
if session and session._authset != self.authset:
raise InvalidOperation('session was used after authenticating'
' with different credentials')
def validate_session(self, client, session):
"""Validate this session before use with client.
Raises error if this session is logged in as a different user or
the client is not the one that created the session.
"""
if session:
if session._client is not client:
raise InvalidOperation(
'Can only use session with the MongoClient that'
' started it')
if session._authset != self.authset:
raise InvalidOperation(
'Cannot use session after authenticating with different'
' credentials')
def close(self):
self.closed = True

View File

@ -57,19 +57,25 @@ def session_ids(client):
class TestSession(IntegrationTest):
@classmethod
@client_context.require_sessions
def setUp(self):
super(TestSession, self).setUp()
def setUpClass(cls):
super(TestSession, cls).setUpClass()
# Create a second client so we can make sure clients cannot share
# sessions.
cls.client2 = rs_or_single_client()
# Redact no commands, so we can test user-admin commands have "lsid".
self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
monitoring._SENSITIVE_COMMANDS.clear()
def tearDown(self):
monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands)
super(TestSession, self).tearDown()
@classmethod
def tearDownClass(cls):
monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands)
super(TestSession, cls).tearDownClass()
def _test_ops(self, client, *ops, **kwargs):
def _test_ops(self, client, *ops):
listener = client.event_listeners()[0][0]
for f, args, kw in ops:
@ -103,6 +109,18 @@ class TestSession(IntegrationTest):
with self.assertRaisesRegex(InvalidOperation, "ended session"):
f(*args, **kw)
# Test a session cannot be used on another client.
with self.client2.start_session() as s:
# In case "f" modifies its inputs.
args = copy.copy(args)
kw = copy.copy(kw)
kw['session'] = s
with self.assertRaisesRegex(
InvalidOperation,
'Can only use session with the MongoClient'
' that started it'):
f(*args, **kw)
# No explicit session.
for f, args, kw in ops:
listener.results.clear()
@ -980,8 +998,8 @@ class TestSessionsMultiAuth(IntegrationTest):
client.admin.logout()
db.authenticate('second-user', 'pass')
err = 'session was used after authenticating with different' \
' credentials'
err = ('Cannot use session after authenticating with different'
' credentials')
with self.assertRaisesRegex(InvalidOperation, err):
# Auth has changed between find and getMore.