PYTHON-2115 Remove threading.Lock() from SocketChecker

This commit is contained in:
Shane Harvey 2020-04-27 13:29:42 -07:00
parent 9cc3652ec3
commit 71d1227932
3 changed files with 14 additions and 34 deletions

View File

@ -487,6 +487,7 @@ class SocketInfo(object):
self.enabled_for_cmap = pool.enabled_for_cmap
self.compression_settings = pool.opts.compression_settings
self.compression_context = None
self.socket_checker = SocketChecker()
# The pool's generation changes with each reset() so we can close
# sockets created before the last reset.
@ -752,6 +753,10 @@ class SocketInfo(object):
self.listeners.publish_connection_closed(
self.address, self.id, reason)
def socket_closed(self):
"""Return True if we know socket has been closed, False otherwise."""
return self.socket_checker.socket_closed(self.sock)
def send_cluster_time(self, command, session, client):
"""Add cluster time for MongoDB >= 3.6."""
if self.max_wire_version >= 6 and client:
@ -976,7 +981,6 @@ class Pool:
self._socket_semaphore = thread_util.create_semaphore(
self.opts.max_pool_size, max_waiters)
self.socket_checker = SocketChecker()
if self.enabled_for_cmap:
self.opts.event_listeners.publish_pool_created(
self.address, self.opts.non_default_options)
@ -1244,7 +1248,7 @@ class Pool:
if (self._check_interval_seconds is not None and (
0 == self._check_interval_seconds or
idle_time_seconds > self._check_interval_seconds)):
if self.socket_checker.socket_closed(sock_info.sock):
if sock_info.socket_closed():
sock_info.close_socket(ConnectionClosedReason.ERROR)
return True

View File

@ -16,7 +16,6 @@
import errno
import select
import threading
_HAVE_POLL = hasattr(select, "poll")
_SelectError = getattr(select, "error", OSError)
@ -34,10 +33,8 @@ class SocketChecker(object):
def __init__(self):
if _HAVE_POLL:
self._lock = threading.Lock()
self._poller = select.poll()
else:
self._lock = None
self._poller = None
def select(self, sock, read=False, write=False, timeout=0):
@ -50,14 +47,13 @@ class SocketChecker(object):
mask = mask | select.POLLIN | select.POLLPRI
if write:
mask = mask | select.POLLOUT
with self._lock:
self._poller.register(sock, mask)
try:
# poll() timeout is in milliseconds. select()
# timeout is in seconds.
res = self._poller.poll(timeout * 1000)
finally:
self._poller.unregister(sock)
self._poller.register(sock, mask)
try:
# poll() timeout is in milliseconds. select()
# timeout is in seconds.
res = self._poller.poll(timeout * 1000)
finally:
self._poller.unregister(sock)
else:
rlist = [sock] if read else []
wlist = [sock] if write else []

View File

@ -236,8 +236,7 @@ class TestPooling(_TestPoolingBase):
# Simulate a closed socket without telling the SocketInfo it's
# closed.
sock_info.sock.close()
self.assertTrue(
cx_pool.socket_checker.socket_closed(sock_info.sock))
self.assertTrue(sock_info.socket_closed())
with cx_pool.get_socket({}) as new_sock_info:
self.assertEqual(0, len(cx_pool.sockets))
@ -257,25 +256,6 @@ class TestPooling(_TestPoolingBase):
s.close()
self.assertTrue(socket_checker.socket_closed(s))
def test_socket_closed_thread_safe(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect((client_context.host, client_context.port))
self.addCleanup(s.close)
socket_checker = SocketChecker()
def check_socket():
for _ in range(1000):
self.assertFalse(socket_checker.socket_closed(s))
threads = []
for i in range(3):
thread = threading.Thread(target=check_socket)
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
def test_return_socket_after_reset(self):
pool = self.create_pool()
with pool.get_socket({}) as sock: