diff --git a/pymongo/server.py b/pymongo/server.py index 985d45acb..f431fd014 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -132,6 +132,8 @@ class Server: if publish: cmd, dbn = operation.as_command(conn) + if "$db" not in cmd: + cmd["$db"] = dbn assert listeners is not None listeners.publish_command_start( cmd, dbn, request_id, conn.address, service_id=conn.service_id diff --git a/test/test_cursor.py b/test/test_cursor.py index f8820f8aa..37c0335fa 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -22,6 +22,8 @@ import sys import threading import time +import pymongo + sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest @@ -1586,6 +1588,28 @@ class TestRawBatchCommandCursor(IntegrationTest): n += 4 listener.reset() + @client_context.require_version_min(5, 0, -1) + @client_context.require_no_mongos + def test_exhaust_cursor_db_set(self): + listener = OvertCommandListener() + client = rs_or_single_client(event_listeners=[listener]) + self.addCleanup(client.close) + c = client.pymongo_test.test + c.delete_many({}) + c.insert_many([{"_id": i} for i in range(3)]) + + listener.reset() + + result = list(c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1)) + + self.assertEqual(len(result), 3) + + self.assertEqual( + listener.started_command_names(), ["find", "getMore", "getMore", "getMore"] + ) + for cmd in listener.started_events: + self.assertEqual(cmd.command["$db"], "pymongo_test") + if __name__ == "__main__": unittest.main()