PYTHON-1332 Session may only be used by the client that started it.
Centralize $clusterTime receiving.
This commit is contained in:
parent
27d94755df
commit
b669cd86dc
@ -240,8 +240,9 @@ class _Bulk(object):
|
||||
listeners = client._event_listeners
|
||||
|
||||
with self.collection.database.client._tmp_session(session) as s:
|
||||
# sock_info.command checks auth, but we use sock_info.write_command.
|
||||
sock_info.check_session_auth_matches(s)
|
||||
# sock_info.command validates the session, but we use
|
||||
# sock_info.write_command.
|
||||
sock_info.validate_session(client, s)
|
||||
for run in generator:
|
||||
cmd = SON([(_COMMANDS[run.op_type], self.collection.name),
|
||||
('ordered', self.ordered)])
|
||||
@ -267,10 +268,7 @@ class _Bulk(object):
|
||||
if not to_send:
|
||||
raise InvalidOperation("cannot do an empty bulk write")
|
||||
result = bwc.write_command(request_id, msg, to_send)
|
||||
client._receive_cluster_time(result)
|
||||
if s is not None:
|
||||
s._advance_cluster_time(result.get("$clusterTime"))
|
||||
s._advance_operation_time(result.get("operationTime"))
|
||||
client._receive_cluster_time(result, s)
|
||||
results.append((idx_offset, result))
|
||||
if self.ordered and "writeErrors" in result:
|
||||
break
|
||||
|
||||
@ -91,7 +91,7 @@ class ClientSession(object):
|
||||
|
||||
def _end_session(self, lock):
|
||||
if self._server_session is not None:
|
||||
self.client._return_server_session(self._server_session, lock)
|
||||
self._client._return_server_session(self._server_session, lock)
|
||||
self._server_session = None
|
||||
|
||||
def __enter__(self):
|
||||
|
||||
@ -154,12 +154,7 @@ class CommandCursor(object):
|
||||
self.__collection.codec_options)
|
||||
if from_command:
|
||||
first = docs[0]
|
||||
client._receive_cluster_time(first)
|
||||
if self.__session is not None:
|
||||
self.__session._advance_cluster_time(
|
||||
first.get('$clusterTime'))
|
||||
self.__session._advance_operation_time(
|
||||
first.get('operationTime'))
|
||||
client._receive_cluster_time(first, self.__session)
|
||||
helpers._check_command_response(first)
|
||||
|
||||
except OperationFailure as exc:
|
||||
|
||||
@ -968,12 +968,7 @@ class Cursor(object):
|
||||
codec_options=self.__codec_options)
|
||||
if from_command:
|
||||
first = docs[0]
|
||||
client._receive_cluster_time(first)
|
||||
if self.__session is not None:
|
||||
self.__session._advance_cluster_time(
|
||||
first.get("$clusterTime"))
|
||||
self.__session._advance_operation_time(
|
||||
first.get("operationTime"))
|
||||
client._receive_cluster_time(first, self.__session)
|
||||
helpers._check_command_response(first)
|
||||
except OperationFailure as exc:
|
||||
self.__killed = True
|
||||
|
||||
@ -295,7 +295,7 @@ class _Query(object):
|
||||
'Specifying a collation is unsupported with a max wire '
|
||||
'version of %d.' % (sock_info.max_wire_version,))
|
||||
|
||||
sock_info.check_session_auth_matches(self.session)
|
||||
sock_info.validate_session(self.client, self.session)
|
||||
|
||||
return use_find_cmd
|
||||
|
||||
@ -369,7 +369,7 @@ class _GetMore(object):
|
||||
self.max_await_time_ms = max_await_time_ms
|
||||
|
||||
def use_command(self, sock_info, exhaust):
|
||||
sock_info.check_session_auth_matches(self.session)
|
||||
sock_info.validate_session(self.client, self.session)
|
||||
return sock_info.max_wire_version >= 4 and not exhaust
|
||||
|
||||
def as_command(self):
|
||||
|
||||
@ -1356,8 +1356,12 @@ class MongoClient(common.BaseObject):
|
||||
if cluster_time:
|
||||
command['$clusterTime'] = cluster_time
|
||||
|
||||
def _receive_cluster_time(self, reply):
|
||||
self._topology.receive_cluster_time(reply.get('$clusterTime'))
|
||||
def _receive_cluster_time(self, reply, session):
|
||||
cluster_time = reply.get('$clusterTime')
|
||||
self._topology.receive_cluster_time(cluster_time)
|
||||
if session is not None:
|
||||
session._advance_cluster_time(cluster_time)
|
||||
session._advance_operation_time(reply.get("operationTime"))
|
||||
|
||||
def server_info(self, session=None):
|
||||
"""Get information about the MongoDB server we're connected to.
|
||||
|
||||
@ -129,12 +129,7 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
|
||||
|
||||
response_doc = unpacked_docs[0]
|
||||
if client:
|
||||
client._receive_cluster_time(response_doc)
|
||||
if session:
|
||||
session._advance_cluster_time(
|
||||
response_doc.get('$clusterTime'))
|
||||
session._advance_operation_time(
|
||||
response_doc.get('operationTime'))
|
||||
client._receive_cluster_time(response_doc, session)
|
||||
if check:
|
||||
helpers._check_command_response(
|
||||
response_doc, None, allowable_errors,
|
||||
|
||||
@ -461,7 +461,7 @@ class SocketInfo(object):
|
||||
- `client`: optional MongoClient for gossipping $clusterTime.
|
||||
- `retryable_write`: True if this command is a retryable write.
|
||||
"""
|
||||
self.check_session_auth_matches(session)
|
||||
self.validate_session(client, session)
|
||||
if (read_concern and self.max_wire_version < 4
|
||||
and not read_concern.ok_for_legacy):
|
||||
raise ConfigurationError(
|
||||
@ -591,11 +591,21 @@ class SocketInfo(object):
|
||||
auth.authenticate(credentials, self)
|
||||
self.authset.add(credentials)
|
||||
|
||||
def check_session_auth_matches(self, session):
|
||||
"""Raise error if a ClientSession is logged in as a different user."""
|
||||
if session and session._authset != self.authset:
|
||||
raise InvalidOperation('session was used after authenticating'
|
||||
' with different credentials')
|
||||
def validate_session(self, client, session):
|
||||
"""Validate this session before use with client.
|
||||
|
||||
Raises error if this session is logged in as a different user or
|
||||
the client is not the one that created the session.
|
||||
"""
|
||||
if session:
|
||||
if session._client is not client:
|
||||
raise InvalidOperation(
|
||||
'Can only use session with the MongoClient that'
|
||||
' started it')
|
||||
if session._authset != self.authset:
|
||||
raise InvalidOperation(
|
||||
'Cannot use session after authenticating with different'
|
||||
' credentials')
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
@ -57,19 +57,25 @@ def session_ids(client):
|
||||
|
||||
|
||||
class TestSession(IntegrationTest):
|
||||
|
||||
@classmethod
|
||||
@client_context.require_sessions
|
||||
def setUp(self):
|
||||
super(TestSession, self).setUp()
|
||||
def setUpClass(cls):
|
||||
super(TestSession, cls).setUpClass()
|
||||
# Create a second client so we can make sure clients cannot share
|
||||
# sessions.
|
||||
cls.client2 = rs_or_single_client()
|
||||
|
||||
# Redact no commands, so we can test user-admin commands have "lsid".
|
||||
self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
|
||||
cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
|
||||
monitoring._SENSITIVE_COMMANDS.clear()
|
||||
|
||||
def tearDown(self):
|
||||
monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands)
|
||||
super(TestSession, self).tearDown()
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands)
|
||||
super(TestSession, cls).tearDownClass()
|
||||
|
||||
def _test_ops(self, client, *ops, **kwargs):
|
||||
def _test_ops(self, client, *ops):
|
||||
listener = client.event_listeners()[0][0]
|
||||
|
||||
for f, args, kw in ops:
|
||||
@ -103,6 +109,18 @@ class TestSession(IntegrationTest):
|
||||
with self.assertRaisesRegex(InvalidOperation, "ended session"):
|
||||
f(*args, **kw)
|
||||
|
||||
# Test a session cannot be used on another client.
|
||||
with self.client2.start_session() as s:
|
||||
# In case "f" modifies its inputs.
|
||||
args = copy.copy(args)
|
||||
kw = copy.copy(kw)
|
||||
kw['session'] = s
|
||||
with self.assertRaisesRegex(
|
||||
InvalidOperation,
|
||||
'Can only use session with the MongoClient'
|
||||
' that started it'):
|
||||
f(*args, **kw)
|
||||
|
||||
# No explicit session.
|
||||
for f, args, kw in ops:
|
||||
listener.results.clear()
|
||||
@ -980,8 +998,8 @@ class TestSessionsMultiAuth(IntegrationTest):
|
||||
client.admin.logout()
|
||||
db.authenticate('second-user', 'pass')
|
||||
|
||||
err = 'session was used after authenticating with different' \
|
||||
' credentials'
|
||||
err = ('Cannot use session after authenticating with different'
|
||||
' credentials')
|
||||
|
||||
with self.assertRaisesRegex(InvalidOperation, err):
|
||||
# Auth has changed between find and getMore.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user