diff --git a/pymongo/bulk.py b/pymongo/bulk.py index c056e80fd..5bc468096 100644 --- a/pymongo/bulk.py +++ b/pymongo/bulk.py @@ -240,8 +240,9 @@ class _Bulk(object): listeners = client._event_listeners with self.collection.database.client._tmp_session(session) as s: - # sock_info.command checks auth, but we use sock_info.write_command. - sock_info.check_session_auth_matches(s) + # sock_info.command validates the session, but we use + # sock_info.write_command. + sock_info.validate_session(client, s) for run in generator: cmd = SON([(_COMMANDS[run.op_type], self.collection.name), ('ordered', self.ordered)]) @@ -267,10 +268,7 @@ class _Bulk(object): if not to_send: raise InvalidOperation("cannot do an empty bulk write") result = bwc.write_command(request_id, msg, to_send) - client._receive_cluster_time(result) - if s is not None: - s._advance_cluster_time(result.get("$clusterTime")) - s._advance_operation_time(result.get("operationTime")) + client._receive_cluster_time(result, s) results.append((idx_offset, result)) if self.ordered and "writeErrors" in result: break diff --git a/pymongo/client_session.py b/pymongo/client_session.py index 151b701e3..841a79f60 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -91,7 +91,7 @@ class ClientSession(object): def _end_session(self, lock): if self._server_session is not None: - self.client._return_server_session(self._server_session, lock) + self._client._return_server_session(self._server_session, lock) self._server_session = None def __enter__(self): diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index 72ec2be6a..ba5ba9478 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -154,12 +154,7 @@ class CommandCursor(object): self.__collection.codec_options) if from_command: first = docs[0] - client._receive_cluster_time(first) - if self.__session is not None: - self.__session._advance_cluster_time( - first.get('$clusterTime')) - self.__session._advance_operation_time( - first.get('operationTime')) + client._receive_cluster_time(first, self.__session) helpers._check_command_response(first) except OperationFailure as exc: diff --git a/pymongo/cursor.py b/pymongo/cursor.py index bddd3497c..2c50e3d8b 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -968,12 +968,7 @@ class Cursor(object): codec_options=self.__codec_options) if from_command: first = docs[0] - client._receive_cluster_time(first) - if self.__session is not None: - self.__session._advance_cluster_time( - first.get("$clusterTime")) - self.__session._advance_operation_time( - first.get("operationTime")) + client._receive_cluster_time(first, self.__session) helpers._check_command_response(first) except OperationFailure as exc: self.__killed = True diff --git a/pymongo/message.py b/pymongo/message.py index cc18472c7..8773b8da5 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -295,7 +295,7 @@ class _Query(object): 'Specifying a collation is unsupported with a max wire ' 'version of %d.' % (sock_info.max_wire_version,)) - sock_info.check_session_auth_matches(self.session) + sock_info.validate_session(self.client, self.session) return use_find_cmd @@ -369,7 +369,7 @@ class _GetMore(object): self.max_await_time_ms = max_await_time_ms def use_command(self, sock_info, exhaust): - sock_info.check_session_auth_matches(self.session) + sock_info.validate_session(self.client, self.session) return sock_info.max_wire_version >= 4 and not exhaust def as_command(self): diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index f8764026f..f2a24e2f6 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1356,8 +1356,12 @@ class MongoClient(common.BaseObject): if cluster_time: command['$clusterTime'] = cluster_time - def _receive_cluster_time(self, reply): - self._topology.receive_cluster_time(reply.get('$clusterTime')) + def _receive_cluster_time(self, reply, session): + cluster_time = reply.get('$clusterTime') + self._topology.receive_cluster_time(cluster_time) + if session is not None: + session._advance_cluster_time(cluster_time) + session._advance_operation_time(reply.get("operationTime")) def server_info(self, session=None): """Get information about the MongoDB server we're connected to. diff --git a/pymongo/network.py b/pymongo/network.py index 09a77b020..8d76412dc 100644 --- a/pymongo/network.py +++ b/pymongo/network.py @@ -129,12 +129,7 @@ def command(sock, dbname, spec, slave_ok, is_mongos, response_doc = unpacked_docs[0] if client: - client._receive_cluster_time(response_doc) - if session: - session._advance_cluster_time( - response_doc.get('$clusterTime')) - session._advance_operation_time( - response_doc.get('operationTime')) + client._receive_cluster_time(response_doc, session) if check: helpers._check_command_response( response_doc, None, allowable_errors, diff --git a/pymongo/pool.py b/pymongo/pool.py index 2a340a2ed..289d72e5c 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -461,7 +461,7 @@ class SocketInfo(object): - `client`: optional MongoClient for gossipping $clusterTime. - `retryable_write`: True if this command is a retryable write. """ - self.check_session_auth_matches(session) + self.validate_session(client, session) if (read_concern and self.max_wire_version < 4 and not read_concern.ok_for_legacy): raise ConfigurationError( @@ -591,11 +591,21 @@ class SocketInfo(object): auth.authenticate(credentials, self) self.authset.add(credentials) - def check_session_auth_matches(self, session): - """Raise error if a ClientSession is logged in as a different user.""" - if session and session._authset != self.authset: - raise InvalidOperation('session was used after authenticating' - ' with different credentials') + def validate_session(self, client, session): + """Validate this session before use with client. + + Raises error if this session is logged in as a different user or + the client is not the one that created the session. + """ + if session: + if session._client is not client: + raise InvalidOperation( + 'Can only use session with the MongoClient that' + ' started it') + if session._authset != self.authset: + raise InvalidOperation( + 'Cannot use session after authenticating with different' + ' credentials') def close(self): self.closed = True diff --git a/test/test_session.py b/test/test_session.py index 597ac6b4f..15816b3f2 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -57,19 +57,25 @@ def session_ids(client): class TestSession(IntegrationTest): + + @classmethod @client_context.require_sessions - def setUp(self): - super(TestSession, self).setUp() + def setUpClass(cls): + super(TestSession, cls).setUpClass() + # Create a second client so we can make sure clients cannot share + # sessions. + cls.client2 = rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". - self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() + cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() monitoring._SENSITIVE_COMMANDS.clear() - def tearDown(self): - monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands) - super(TestSession, self).tearDown() + @classmethod + def tearDownClass(cls): + monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands) + super(TestSession, cls).tearDownClass() - def _test_ops(self, client, *ops, **kwargs): + def _test_ops(self, client, *ops): listener = client.event_listeners()[0][0] for f, args, kw in ops: @@ -103,6 +109,18 @@ class TestSession(IntegrationTest): with self.assertRaisesRegex(InvalidOperation, "ended session"): f(*args, **kw) + # Test a session cannot be used on another client. + with self.client2.start_session() as s: + # In case "f" modifies its inputs. + args = copy.copy(args) + kw = copy.copy(kw) + kw['session'] = s + with self.assertRaisesRegex( + InvalidOperation, + 'Can only use session with the MongoClient' + ' that started it'): + f(*args, **kw) + # No explicit session. for f, args, kw in ops: listener.results.clear() @@ -980,8 +998,8 @@ class TestSessionsMultiAuth(IntegrationTest): client.admin.logout() db.authenticate('second-user', 'pass') - err = 'session was used after authenticating with different' \ - ' credentials' + err = ('Cannot use session after authenticating with different' + ' credentials') with self.assertRaisesRegex(InvalidOperation, err): # Auth has changed between find and getMore.