mongo-python-driver/test/test_pooling.py

611 lines
20 KiB
Python

# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test built in connection-pooling with threads."""
from __future__ import annotations
import asyncio
import gc
import random
import socket
import sys
import time
from bson.codec_options import DEFAULT_CODEC_OPTIONS
from bson.son import SON
from pymongo import MongoClient, message, timeout
from pymongo.errors import AutoReconnect, ConnectionFailure, DuplicateKeyError
from pymongo.hello import HelloCompat
from pymongo.lock import _create_lock
sys.path[0:0] = [""]
from test import IntegrationTest, client_context, unittest
from test.helpers import ConcurrentRunner
from test.utils import delay, get_pool, joinall
from pymongo.socket_checker import SocketChecker
from pymongo.synchronous.pool import Pool, PoolOptions
# NOTE(review): presumably flags this module as the synchronous variant of the
# pooling tests; nothing in this file reads it -- confirm against the test tooling.
_IS_SYNC = True
# Loop count used by the MongoTask subclasses below; InsertOneAndFind also uses
# it as the inclusive upper bound for the random value it inserts.
N = 10
# Name of the database the tasks and the test fixtures operate on.
DB = "pymongo-pooling-tests"
def gc_collect_until_done(tasks, timeout=60):
    """Join every task, running a garbage collection after each polling pass.

    Each pass gives every still-running task up to 0.1 seconds to finish via
    ``join(0.1)`` and then calls ``gc.collect()``. Passes repeat until no task
    reports ``is_alive()``.

    :param tasks: iterable of objects exposing ``join(timeout)`` and
        ``is_alive()``.
    :param timeout: maximum number of seconds to keep polling.
    :raises AssertionError: if tasks are still running once ``timeout``
        seconds have elapsed.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    running = list(tasks)
    while running:
        # Raised explicitly (not via ``assert``) so the check survives ``-O``.
        if time.monotonic() >= deadline:
            raise AssertionError("Tasks timed out")
        # Build a new list rather than removing from ``running`` while
        # iterating it, which would skip the task after each finished one.
        still_running = []
        for t in running:
            t.join(0.1)
            if t.is_alive():
                still_running.append(t)
        running = still_running
        gc.collect()
class MongoTask(ConcurrentRunner):
    """Base runner that performs work against a shared MongoClient.

    Subclasses supply ``run_mongo_thread``. ``passed`` only becomes True when
    that method returns without raising.
    """

    def __init__(self, client):
        super().__init__()
        self.client = client
        self.db = client[DB]
        self.passed = False
        # Daemonize so a hung task cannot hang the whole test run.
        self.daemon = True

    def run(self):
        # An exception from the subclass hook propagates and leaves
        # ``passed`` False.
        self.run_mongo_thread()
        self.passed = True

    def run_mongo_thread(self):
        """Hook for subclasses; the base implementation always raises."""
        raise NotImplementedError
class InsertOneAndFind(MongoTask):
    """Task that inserts a random value and reads it back, N times."""

    def run_mongo_thread(self):
        for _ in range(N):
            value = random.randint(0, N)
            inserted = self.db.sf.insert_one({"x": value})
            # Look the document up by the _id the insert reported.
            fetched = self.db.sf.find_one(inserted.inserted_id)
            assert value == fetched["x"]
class Unique(MongoTask):
    """Task that inserts N empty documents into the ``unique`` collection."""

    def run_mongo_thread(self):
        remaining = N
        while remaining > 0:
            # No error is expected from any of these inserts.
            self.db.unique.insert_one({})
            remaining -= 1
class NonUnique(MongoTask):
    """Task that inserts the same ``_id`` N times, expecting each to collide."""

    def run_mongo_thread(self):
        for _ in range(N):
            try:
                self.db.unique.insert_one({"_id": "jesse"})
            except DuplicateKeyError:
                # The expected outcome: move on to the next attempt.
                continue
            # Only reached when the insert unexpectedly succeeded.
            raise AssertionError("Should have raised DuplicateKeyError")
class SocketGetter(MongoTask):
    """Helper task for TestPooling.

    Checks a socket out of ``pool`` and keeps a reference to it until
    ``release_conn`` is called. ``state`` moves from "init" to "get_socket"
    to "connection". Used in test_no_wait_queue_timeout.
    """

    def __init__(self, client, pool):
        super().__init__(client)
        self.pool = pool
        self.sock = None
        self.state = "init"

    def run_mongo_thread(self):
        self.state = "get_socket"
        with self.pool.checkout() as checked_out:
            # Call 'pin_cursor' so we can hold the socket.
            checked_out.pin_cursor()
            self.sock = checked_out
        self.state = "connection"

    def release_conn(self):
        """Unpin the held socket; return True if there was one to release."""
        if not self.sock:
            return False
        self.sock.unpin()
        self.sock = None
        return True
def run_cases(client, cases):
    """Run five instances of each case concurrently and require all to pass.

    Every case is called with ``client`` to build a runner, which is started
    immediately. After all runners have been joined, the first one whose
    ``passed`` flag is falsy triggers an AssertionError.
    """
    runs_per_case = 5

    def _launch(case):
        # Build and start a single runner for ``case``.
        runner = case(client)
        runner.start()
        return runner

    started = [_launch(case) for case in cases for _ in range(runs_per_case)]

    for runner in started:
        runner.join()

    failed = next((runner for runner in started if not runner.passed), None)
    assert failed is None, "%s.run() threw an exception" % repr(failed)
class _TestPoolingBase(IntegrationTest):
    """Shared fixture for the connection-pool test cases."""

    @client_context.require_connection
    def setUp(self):
        super().setUp()
        self.c = self.rs_or_single_client()
        database = self.c[DB]
        # Reset both collections before seeding them.
        for coll in (database.unique, database.test):
            coll.drop()
        # Seed the _id that the NonUnique task tries to insert again.
        database.unique.insert_one({"_id": "jesse"})
        database.test.insert_many([{} for _ in range(10)])

    def create_pool(self, pair=None, *args, **kwargs):
        """Build a Pool, mark it ready, and return it.

        ``pair`` defaults to the test context's host and port. Remaining
        arguments are forwarded to PoolOptions.
        """
        address = (client_context.host, client_context.port) if pair is None else pair
        # Start the pool with the same ssl and server API options as the
        # shared test client.
        shared_options = client_context.client._topology_settings.pool_options
        kwargs.update(
            ssl_context=shared_options._ssl_context,
            tls_allow_invalid_hostnames=shared_options.tls_allow_invalid_hostnames,
            server_api=shared_options.server_api,
        )
        new_pool = Pool(address, PoolOptions(*args, **kwargs))
        new_pool.ready()
        return new_pool
class TestPooling(_TestPoolingBase):
    """Tests for Pool checkout behaviour, socket checks and timeout messages."""

    def test_max_pool_size_validation(self):
        # Invalid maxPoolSize values are rejected; a valid one is stored in
        # the client's pool options.
        host, port = client_context.host, client_context.port
        self.assertRaises(ValueError, MongoClient, host=host, port=port, maxPoolSize=-1)

        self.assertRaises(ValueError, MongoClient, host=host, port=port, maxPoolSize="foo")

        c = MongoClient(host=host, port=port, maxPoolSize=100, connect=False)
        # Close the client when the test ends so it is not leaked.
        self.addCleanup(c.close)
        self.assertEqual(c.options.pool_options.max_pool_size, 100)

    def test_no_disconnect(self):
        # Run all three task types concurrently against the shared client.
        run_cases(self.c, [NonUnique, Unique, InsertOneAndFind])

    def test_pool_reuses_open_socket(self):
        # Test Pool's _check_closed() method doesn't close a healthy socket.
        cx_pool = self.create_pool(max_pool_size=10)
        cx_pool._check_interval_seconds = 0  # Always check.
        with cx_pool.checkout() as conn:
            pass

        with cx_pool.checkout() as new_connection:
            self.assertEqual(conn, new_connection)

        self.assertEqual(1, len(cx_pool.conns))

    def test_get_socket_and_exception(self):
        # get_socket() returns socket after a non-network error.
        cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
        with self.assertRaises(ZeroDivisionError):
            with cx_pool.checkout() as conn:
                1 / 0

        # Socket was returned, not closed.
        with cx_pool.checkout() as new_connection:
            self.assertEqual(conn, new_connection)

        self.assertEqual(1, len(cx_pool.conns))

    def test_pool_removes_closed_socket(self):
        # Test that Pool removes explicitly closed socket.
        cx_pool = self.create_pool()

        with cx_pool.checkout() as conn:
            # Use Connection's API to close the socket.
            conn.close_conn(None)

        self.assertEqual(0, len(cx_pool.conns))

    def test_pool_removes_dead_socket(self):
        # Test that Pool removes dead socket and the socket doesn't return
        # itself PYTHON-344
        cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
        cx_pool._check_interval_seconds = 0  # Always check.

        with cx_pool.checkout() as conn:
            # Simulate a closed socket without telling the Connection it's
            # closed.
            conn.conn.close()
            self.assertTrue(conn.conn_closed())

        with cx_pool.checkout() as new_connection:
            self.assertEqual(0, len(cx_pool.conns))
            self.assertNotEqual(conn, new_connection)

        self.assertEqual(1, len(cx_pool.conns))

        # Semaphore was released.
        with cx_pool.checkout():
            pass

    def test_socket_closed(self):
        # SocketChecker reports an open socket as open and a closed one as
        # closed.
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Registered up front so the socket is released even if connect() or
        # an assertion fails; closing it a second time is harmless.
        self.addCleanup(s.close)
        s.connect((client_context.host, client_context.port))
        socket_checker = SocketChecker()
        self.assertFalse(socket_checker.socket_closed(s))
        s.close()
        self.assertTrue(socket_checker.socket_closed(s))

    def test_socket_checker(self):
        # Exercise SocketChecker.select() for readability and writability
        # with each supported timeout form.
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Registered up front so the socket is released even if connect() or
        # an assertion fails; closing it a second time is harmless.
        self.addCleanup(s.close)
        s.connect((client_context.host, client_context.port))
        socket_checker = SocketChecker()
        # Socket has nothing to read.
        self.assertFalse(socket_checker.select(s, read=True))
        self.assertFalse(socket_checker.select(s, read=True, timeout=0))
        self.assertFalse(socket_checker.select(s, read=True, timeout=0.05))
        # Socket is writable.
        self.assertTrue(socket_checker.select(s, write=True, timeout=None))
        self.assertTrue(socket_checker.select(s, write=True))
        self.assertTrue(socket_checker.select(s, write=True, timeout=0))
        self.assertTrue(socket_checker.select(s, write=True, timeout=0.05))
        # Make the socket readable
        _, msg, _ = message._query(
            0, "admin.$cmd", 0, -1, SON([("ping", 1)]), None, DEFAULT_CODEC_OPTIONS
        )
        s.sendall(msg)
        # Block until the socket is readable.
        self.assertTrue(socket_checker.select(s, read=True, timeout=None))
        self.assertTrue(socket_checker.select(s, read=True))
        self.assertTrue(socket_checker.select(s, read=True, timeout=0))
        self.assertTrue(socket_checker.select(s, read=True, timeout=0.05))
        # Socket is still writable.
        self.assertTrue(socket_checker.select(s, write=True, timeout=None))
        self.assertTrue(socket_checker.select(s, write=True))
        self.assertTrue(socket_checker.select(s, write=True, timeout=0))
        self.assertTrue(socket_checker.select(s, write=True, timeout=0.05))

        s.close()
        self.assertTrue(socket_checker.socket_closed(s))

    def test_return_socket_after_reset(self):
        # A socket checked in after pool.reset() is closed, not pooled, and
        # the pool's counters drop back to zero.
        pool = self.create_pool()
        with pool.checkout() as sock:
            self.assertEqual(pool.active_sockets, 1)
            self.assertEqual(pool.operation_count, 1)
            pool.reset()

        self.assertTrue(sock.closed)
        self.assertEqual(0, len(pool.conns))
        self.assertEqual(pool.active_sockets, 0)
        self.assertEqual(pool.operation_count, 0)

    def test_pool_check(self):
        # Test that Pool recovers from two connection failures in a row.
        # This exercises code at the end of Pool._check().
        cx_pool = self.create_pool(max_pool_size=1, connect_timeout=1, wait_queue_timeout=1)
        cx_pool._check_interval_seconds = 0  # Always check.
        self.addCleanup(cx_pool.close)

        with cx_pool.checkout() as conn:
            # Simulate a closed socket without telling the Connection it's
            # closed.
            conn.conn.close()

        # Swap pool's address with a bad one.
        address, cx_pool.address = cx_pool.address, ("foo.com", 1234)
        with self.assertRaises(AutoReconnect):
            with cx_pool.checkout():
                pass

        # Back to normal, semaphore was correctly released.
        cx_pool.address = address
        with cx_pool.checkout():
            pass

    def test_wait_queue_timeout(self):
        # A second checkout on a full pool fails after roughly
        # wait_queue_timeout seconds.
        wait_queue_timeout = 2  # Seconds
        pool = self.create_pool(max_pool_size=1, wait_queue_timeout=wait_queue_timeout)
        self.addCleanup(pool.close)

        with pool.checkout():
            # Monotonic clock: the measured wait must not be skewed by
            # wall-clock adjustments.
            start = time.monotonic()
            with self.assertRaises(ConnectionFailure):
                with pool.checkout():
                    pass

        duration = time.monotonic() - start
        self.assertLess(
            abs(wait_queue_timeout - duration),
            1,
            f"Waited {duration:.2f} seconds for a socket, expected {wait_queue_timeout:f}",
        )

    def test_no_wait_queue_timeout(self):
        # Verify get_socket() with no wait_queue_timeout blocks forever.
        pool = self.create_pool(max_pool_size=1)
        self.addCleanup(pool.close)

        # Reach max_size.
        with pool.checkout() as s1:
            t = SocketGetter(self.c, pool)
            t.start()
            while t.state != "get_socket":
                time.sleep(0.1)

            # Give the task a full second; it must still be waiting.
            time.sleep(1)
            self.assertEqual(t.state, "get_socket")

        # Once s1 is checked back in the task can proceed.
        while t.state != "connection":
            time.sleep(0.1)

        self.assertEqual(t.state, "connection")
        self.assertEqual(t.sock, s1)
        # Cleanup
        t.release_conn()
        t.join()
        pool.close()

    def test_checkout_more_than_max_pool_size(self):
        # With every socket of a two-socket pool pinned, additional
        # checkouts must block in the "get_socket" state.
        pool = self.create_pool(max_pool_size=2)
        # Ensure the pool is closed even when an assertion below fails.
        self.addCleanup(pool.close)

        socks = []
        for _ in range(2):
            # Call 'pin_cursor' so we can hold the socket.
            with pool.checkout() as sock:
                sock.pin_cursor()
                socks.append(sock)

        tasks = []
        for _ in range(10):
            t = SocketGetter(self.c, pool)
            t.start()
            tasks.append(t)
        time.sleep(1)
        for t in tasks:
            self.assertEqual(t.state, "get_socket")
        # Cleanup
        for socket_info in socks:
            socket_info.unpin()
        # Keep releasing sockets until every task has had one and finished.
        while tasks:
            to_remove = []
            for t in tasks:
                if t.release_conn():
                    to_remove.append(t)
                    t.join()
            for t in to_remove:
                tasks.remove(t)
            time.sleep(0.05)
        pool.close()

    def test_maxConnecting(self):
        # Fifty concurrent short operations must not need more than fifty
        # connections.
        client = self.rs_or_single_client()
        self.client.test.test.insert_one({})
        self.addCleanup(self.client.test.test.delete_many, {})
        pool = get_pool(client)
        docs = []

        # Run 50 short running operations
        def find_one():
            docs.append(client.test.test.find_one({}))

        tasks = [ConcurrentRunner(target=find_one) for _ in range(50)]
        for task in tasks:
            task.start()
        for task in tasks:
            task.join(10)

        self.assertEqual(len(docs), 50)
        self.assertLessEqual(len(pool.conns), 50)
        # TLS and auth make connection establishment more expensive than
        # the query which leads to more threads hitting maxConnecting.
        # The end result is fewer total connections and better latency.
        if client_context.tls and client_context.auth_enabled:
            self.assertLessEqual(len(pool.conns), 30)
        else:
            self.assertLessEqual(len(pool.conns), 50)
        # MongoDB 4.4.1 with auth + ssl:
        # maxConnecting = 2:         6 connections in ~0.231+ seconds
        # maxConnecting = unbounded: 50 connections in ~0.642+ seconds
        #
        # MongoDB 4.4.1 with no-auth no-ssl Python 3.8:
        # maxConnecting = 2:         15-22 connections in ~0.108+ seconds
        # maxConnecting = unbounded: 30+ connections in ~0.140+ seconds
        print(len(pool.conns))

    @client_context.require_failCommand_appName
    def test_csot_timeout_message(self):
        # The error for an operation cut short by pymongo.timeout() mentions
        # the configured timeoutMS.
        client = self.rs_or_single_client(appName="connectionTimeoutApp")
        # Mock an operation failing due to pymongo.timeout().
        mock_connection_timeout = {
            "configureFailPoint": "failCommand",
            "mode": "alwaysOn",
            "data": {
                "blockConnection": True,
                "blockTimeMS": 1000,
                "failCommands": ["find"],
                "appName": "connectionTimeoutApp",
            },
        }

        client.db.t.insert_one({"x": 1})

        with self.fail_point(mock_connection_timeout):
            with self.assertRaises(Exception) as error:
                with timeout(0.5):
                    client.db.t.find_one({"$where": delay(2)})

        # assertIn shows the actual message when the check fails.
        self.assertIn("(configured timeouts: timeoutMS: 500.0ms", str(error.exception))

    @client_context.require_failCommand_appName
    def test_socket_timeout_message(self):
        # The error for an operation that hits socketTimeoutMS lists the
        # configured socket and connect timeouts.
        client = self.rs_or_single_client(socketTimeoutMS=500, appName="connectionTimeoutApp")
        # Mock an operation failing due to socketTimeoutMS.
        mock_connection_timeout = {
            "configureFailPoint": "failCommand",
            "mode": "alwaysOn",
            "data": {
                "blockConnection": True,
                "blockTimeMS": 1000,
                "failCommands": ["find"],
                "appName": "connectionTimeoutApp",
            },
        }

        client.db.t.insert_one({"x": 1})

        with self.fail_point(mock_connection_timeout):
            with self.assertRaises(Exception) as error:
                client.db.t.find_one({"$where": delay(2)})

        # assertIn shows the actual message when the check fails.
        self.assertIn(
            "(configured timeouts: socketTimeoutMS: 500.0ms, connectTimeoutMS: 20000.0ms)",
            str(error.exception),
        )

    @client_context.require_failCommand_appName
    def test_connection_timeout_message(self):
        # The error for a connection handshake that times out lists the
        # configured socket and connect timeouts.
        # Mock a connection creation failing due to timeout.
        mock_connection_timeout = {
            "configureFailPoint": "failCommand",
            "mode": "alwaysOn",
            "data": {
                "blockConnection": True,
                "blockTimeMS": 1000,
                "failCommands": [HelloCompat.LEGACY_CMD, "hello"],
                "appName": "connectionTimeoutApp",
            },
        }

        client = self.rs_or_single_client(
            connectTimeoutMS=500,
            socketTimeoutMS=500,
            appName="connectionTimeoutApp",
            heartbeatFrequencyMS=1000000,
        )
        client.admin.command("ping")
        pool = get_pool(client)
        # Reset the pool so the next command has to create a new connection.
        pool.reset_without_pause()
        with self.fail_point(mock_connection_timeout):
            with self.assertRaises(Exception) as error:
                client.admin.command("ping")

        # assertIn shows the actual message when the check fails.
        self.assertIn(
            "(configured timeouts: socketTimeoutMS: 500.0ms, connectTimeoutMS: 500.0ms)",
            str(error.exception),
        )
class TestPoolMaxSize(_TestPoolingBase):
    """Tests for how the pool honours (or ignores) the maxPoolSize option."""

    def _run_find_tasks(self, collection, ntasks, after_find=None):
        # Start ``ntasks`` concurrent runners that each issue five delayed
        # find_one calls on ``collection`` and wait for all of them.
        # ``after_find``, when given, is called after every find_one.
        # ``self.n_passed`` ends up holding the number of runners that
        # completed without raising.
        lock = _create_lock()
        self.n_passed = 0

        def f():
            for _ in range(5):
                collection.find_one({"$where": delay(0.1)})
                if after_find is not None:
                    after_find()

            # Only reached when every iteration above succeeded.
            with lock:
                self.n_passed += 1

        tasks = []
        for _i in range(ntasks):
            t = ConcurrentRunner(target=f)
            tasks.append(t)
            t.start()

        joinall(tasks)

    def test_max_pool_size(self):
        # The pool never holds more than maxPoolSize connections while
        # concurrent operations run.
        max_pool_size = 4
        c = self.rs_or_single_client(maxPoolSize=max_pool_size)
        collection = c[DB].test

        # Need one document.
        collection.drop()
        collection.insert_one({})

        # ntasks had better be much larger than max_pool_size to ensure that
        # max_pool_size connections are actually required at some point in this
        # test's execution.
        cx_pool = get_pool(c)
        ntasks = 10

        def check_pool_size():
            # Runs inside the worker; a failure stops that worker before it
            # is counted in n_passed.
            assert len(cx_pool.conns) <= max_pool_size

        self._run_find_tasks(collection, ntasks, check_pool_size)
        self.assertEqual(ntasks, self.n_passed)
        self.assertGreater(len(cx_pool.conns), 1)
        self.assertEqual(0, cx_pool.requests)

    def test_max_pool_size_none(self):
        # maxPoolSize=None yields an unbounded pool.
        c = self.rs_or_single_client(maxPoolSize=None)
        collection = c[DB].test

        # Need one document.
        collection.drop()
        collection.insert_one({})

        cx_pool = get_pool(c)
        ntasks = 10

        self._run_find_tasks(collection, ntasks)
        self.assertEqual(ntasks, self.n_passed)
        self.assertGreater(len(cx_pool.conns), 1)
        self.assertEqual(cx_pool.max_pool_size, float("inf"))

    def test_max_pool_size_zero(self):
        # maxPoolSize=0 yields an unbounded pool.
        c = self.rs_or_single_client(maxPoolSize=0)
        pool = get_pool(c)
        self.assertEqual(pool.max_pool_size, float("inf"))

    def test_max_pool_size_with_connection_failure(self):
        # The pool acquires its semaphore before attempting to connect; ensure
        # it releases the semaphore on connection failure.
        test_pool = Pool(
            ("somedomainthatdoesntexist.org", 27017),
            PoolOptions(max_pool_size=1, connect_timeout=1, socket_timeout=1, wait_queue_timeout=1),
        )
        # Close the pool when the test ends so it is not leaked.
        self.addCleanup(test_pool.close)
        test_pool.ready()

        # First call to get_socket fails; if pool doesn't release its semaphore
        # then the second call raises "ConnectionFailure: Timed out waiting for
        # socket from pool" instead of AutoReconnect.
        for _i in range(2):
            with self.assertRaises(AutoReconnect) as context:
                with test_pool.checkout():
                    pass

            # Testing for AutoReconnect instead of ConnectionFailure, above,
            # is sufficient right *now* to catch a semaphore leak. But that
            # seems error-prone, so check the message too.
            self.assertNotIn("waiting for socket from pool", str(context.exception))
if __name__ == "__main__":
unittest.main()