diff --git a/pymongo/pool.py b/pymongo/pool.py index 8bd6602cc..972a8c1db 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -487,6 +487,7 @@ class SocketInfo(object): self.enabled_for_cmap = pool.enabled_for_cmap self.compression_settings = pool.opts.compression_settings self.compression_context = None + self.socket_checker = SocketChecker() # The pool's generation changes with each reset() so we can close # sockets created before the last reset. @@ -752,6 +753,10 @@ class SocketInfo(object): self.listeners.publish_connection_closed( self.address, self.id, reason) + def socket_closed(self): + """Return True if we know socket has been closed, False otherwise.""" + return self.socket_checker.socket_closed(self.sock) + def send_cluster_time(self, command, session, client): """Add cluster time for MongoDB >= 3.6.""" if self.max_wire_version >= 6 and client: @@ -976,7 +981,6 @@ class Pool: self._socket_semaphore = thread_util.create_semaphore( self.opts.max_pool_size, max_waiters) - self.socket_checker = SocketChecker() if self.enabled_for_cmap: self.opts.event_listeners.publish_pool_created( self.address, self.opts.non_default_options) @@ -1244,7 +1248,7 @@ class Pool: if (self._check_interval_seconds is not None and ( 0 == self._check_interval_seconds or idle_time_seconds > self._check_interval_seconds)): - if self.socket_checker.socket_closed(sock_info.sock): + if sock_info.socket_closed(): sock_info.close_socket(ConnectionClosedReason.ERROR) return True diff --git a/pymongo/socket_checker.py b/pymongo/socket_checker.py index 9b21e69d2..93d6cb5a4 100644 --- a/pymongo/socket_checker.py +++ b/pymongo/socket_checker.py @@ -16,7 +16,6 @@ import errno import select -import threading _HAVE_POLL = hasattr(select, "poll") _SelectError = getattr(select, "error", OSError) @@ -34,10 +33,8 @@ class SocketChecker(object): def __init__(self): if _HAVE_POLL: - self._lock = threading.Lock() self._poller = select.poll() else: - self._lock = None self._poller = None def select(self, sock, read=False, write=False, timeout=0): @@ -50,14 +47,13 @@ class SocketChecker(object): mask = mask | select.POLLIN | select.POLLPRI if write: mask = mask | select.POLLOUT - with self._lock: - self._poller.register(sock, mask) - try: - # poll() timeout is in milliseconds. select() - # timeout is in seconds. - res = self._poller.poll(timeout * 1000) - finally: - self._poller.unregister(sock) + self._poller.register(sock, mask) + try: + # poll() timeout is in milliseconds. select() + # timeout is in seconds. + res = self._poller.poll(timeout * 1000) + finally: + self._poller.unregister(sock) else: rlist = [sock] if read else [] wlist = [sock] if write else [] diff --git a/test/test_pooling.py b/test/test_pooling.py index bfe5abc11..8ed4068a6 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -236,8 +236,7 @@ class TestPooling(_TestPoolingBase): # Simulate a closed socket without telling the SocketInfo it's # closed. sock_info.sock.close() - self.assertTrue( - cx_pool.socket_checker.socket_closed(sock_info.sock)) + self.assertTrue(sock_info.socket_closed()) with cx_pool.get_socket({}) as new_sock_info: self.assertEqual(0, len(cx_pool.sockets)) @@ -257,25 +256,6 @@ class TestPooling(_TestPoolingBase): s.close() self.assertTrue(socket_checker.socket_closed(s)) - def test_socket_closed_thread_safe(self): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.connect((client_context.host, client_context.port)) - self.addCleanup(s.close) - socket_checker = SocketChecker() - - def check_socket(): - for _ in range(1000): - self.assertFalse(socket_checker.socket_closed(s)) - - threads = [] - for i in range(3): - thread = threading.Thread(target=check_socket) - thread.start() - threads.append(thread) - - for thread in threads: - thread.join() - def test_return_socket_after_reset(self): pool = self.create_pool() with pool.get_socket({}) as sock: