From f737ec0db916e75b114f8582231463be8ab14dec Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Wed, 11 Mar 2015 10:57:01 -0400 Subject: [PATCH] PYTHON-728 - Translate socket.error to ConnectionFailure in pool.py. SocketInfo and Pool are now responsible for catching all socket.errors and gaierrors and translating them to ConnectionFailure. Server and MongoClient need no longer worry about anything but ConnectionFailure. Functions in pool.py and network.py still throw socket.errors into SocketInfo and Pool. --- pymongo/cursor.py | 7 ++-- pymongo/monitor.py | 2 +- pymongo/pool.py | 96 ++++++++++++++++++++++++++----------------- pymongo/server.py | 48 ++++++++-------------- test/pymongo_mocks.py | 8 ++-- test/test_pooling.py | 15 ++++--- test/test_topology.py | 10 ++--- 7 files changed, 98 insertions(+), 88 deletions(-) diff --git a/pymongo/cursor.py b/pymongo/cursor.py index d6b0c39bf..ca75527e5 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -15,8 +15,6 @@ """Cursor class to iterate over Mongo query results.""" import copy -import socket - from collections import deque from bson import RE_TYPE @@ -28,6 +26,7 @@ from bson.son import SON from pymongo import helpers from pymongo.common import validate_boolean, validate_is_mapping from pymongo.errors import (AutoReconnect, + ConnectionFailure, InvalidOperation, NotMasterError, OperationFailure) @@ -839,9 +838,9 @@ class Cursor(object): # Exhaust cursor - no getMore message. try: data = self.__exhaust_mgr.sock.receive_message(1, None) - except socket.error as exc: + except ConnectionFailure: self.__die() - raise AutoReconnect(str(exc)) + raise try: doc = helpers._unpack_response(response=data, diff --git a/pymongo/monitor.py b/pymongo/monitor.py index 2a174208e..50175e8c3 100644 --- a/pymongo/monitor.py +++ b/pymongo/monitor.py @@ -151,7 +151,7 @@ class Monitor(object): def _check_with_socket(self, sock_info): """Return (IsMaster, round_trip_time). - Can raise socket.error or PyMongoError. + Can raise ConnectionFailure or OperationFailure. """ start = _time() request_id, msg, _ = message.query( diff --git a/pymongo/pool.py b/pymongo/pool.py index 34e490a00..e8ccc121b 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -19,7 +19,7 @@ import threading from bson.py3compat import u, itervalues from pymongo import auth, thread_util -from pymongo.errors import ConnectionFailure +from pymongo.errors import AutoReconnect, ConnectionFailure, NetworkTimeout from pymongo.ismaster import IsMaster from pymongo.monotonic import time as _time from pymongo.network import (command, @@ -40,6 +40,16 @@ except ImportError: from pymongo.ssl_match_hostname import match_hostname, CertificateError +def _raise_connection_failure(address, error): + """Convert a socket.error to ConnectionFailure and raise it.""" + host, port = address + msg = '%s:%d: %s' % (host, port, error) + if isinstance(error, socket.timeout): + raise NetworkTimeout(msg) + else: + raise AutoReconnect(msg) + + class PoolOptions(object): __slots__ = ('__max_pool_size', '__connect_timeout', '__socket_timeout', @@ -121,11 +131,11 @@ class SocketInfo(object): - `sock`: a raw socket object - `pool`: a Pool instance - `ismaster`: an IsMaster instance, response to ismaster call on `sock` - - `host`: a string, the server's hostname (without port) + - `address`: the server's (host, port) """ - def __init__(self, sock, pool, ismaster, host): + def __init__(self, sock, pool, ismaster, address): self.sock = sock - self.host = host + self.address = address self.authset = set() self.closed = False self.last_checkout = _time() @@ -136,7 +146,7 @@ class SocketInfo(object): self.pool_id = pool.pool_id def command(self, dbname, spec): - """Execute a command or raise socket.error or OperationFailure. + """Execute a command or raise ConnectionFailure or OperationFailure. :Parameters: - `dbname`: name of the database on which to run the command @@ -144,37 +154,34 @@ class SocketInfo(object): """ try: return command(self.sock, dbname, spec) - except socket.error: - self.close() - raise + except socket.error as error: + self._raise_connection_failure(error) def send_message(self, message): - """Send a raw BSON message or raise socket.error. + """Send a raw BSON message or raise ConnectionFailure. If a network exception is raised, the socket is closed. """ try: self.sock.sendall(message) - except: - self.close() - raise + except socket.error as error: + self._raise_connection_failure(error) def receive_message(self, operation, request_id): - """Receive a raw BSON message or raise socket.error. + """Receive a raw BSON message or raise ConnectionFailure. If any exception is raised, the socket is closed. """ try: return receive_message(self.sock, operation, request_id) - except: - self.close() - raise + except socket.error as error: + self._raise_connection_failure(error) def check_auth(self, all_credentials): """Update this socket's authentication. Log in or out to bring this socket's credentials up to date with - those provided. Can raise socket.error or OperationFailure. + those provided. Can raise ConnectionFailure or OperationFailure. :Parameters: - `all_credentials`: dict, maps auth source to MongoCredential. @@ -195,6 +202,8 @@ class SocketInfo(object): def authenticate(self, credentials): """Log in to the server and store these credentials in `authset`. + Can raise ConnectionFailure or OperationFailure. + :Parameters: - `credentials`: A MongoCredential. """ @@ -215,6 +224,10 @@ class SocketInfo(object): 'max_wire_version on a SocketInfo created without handshake') return self.ismaster.max_wire_version + def _raise_connection_failure(self, error): + self.close() + _raise_connection_failure(self.address, error) + def __eq__(self, other): return self.sock == other.sock @@ -235,6 +248,8 @@ class SocketInfo(object): def _create_connection(address, options): """Given (host, port) and PoolOptions, connect and return a socket object. + Can raise socket.error. + This is a modified version of create_connection from CPython >= 2.6. """ host, port = address @@ -248,9 +263,9 @@ def _create_connection(address, options): try: sock.connect(host) return sock - except socket.error as e: + except socket.error: sock.close() - raise e + raise # Don't try IPv6 if we don't support it. Also skip it if host # is 'localhost' (::1 is fine). Avoids slow connect issues @@ -287,6 +302,8 @@ def _create_connection(address, options): def _configured_socket(address, options): """Given (host, port) and PoolOptions, return a configured socket. + Can raise socket.error, ConnectionFailure, or CertificateError. + Sets socket's SSL and timeout options. """ sock = _create_connection(address, options) @@ -355,20 +372,25 @@ class Pool: sock_info.close() def connect(self): - """Connect to Mongo and return a new (connected) socket. Note that the - pool does not keep a reference to the socket -- you must call - return_socket() when you're done with it. + """Connect to Mongo and return a new SocketInfo. + + Can raise ConnectionFailure or CertificateError. + + Note that the pool does not keep a reference to the socket -- you + must call return_socket() when you're done with it. """ - sock = _configured_socket(self.address, self.opts) - if self.handshake: - try: + sock = None + try: + sock = _configured_socket(self.address, self.opts) + if self.handshake: ismaster = IsMaster(command(sock, 'admin', {'ismaster': 1})) - except: + else: + ismaster = None + return SocketInfo(sock, self, ismaster, self.address) + except socket.error as error: + if sock is not None: sock.close() - raise - else: - ismaster = None - return SocketInfo(sock, self, ismaster, host=self.address[0]) + _raise_connection_failure(self.address, error) @contextlib.contextmanager def get_socket(self, all_credentials, checkout=False): @@ -387,6 +409,8 @@ class Pool: using the correct authentication mechanism for the server's wire protocol version. + Can raise ConnectionFailure or OperationFailure. + :Parameters: - `all_credentials`: dict, maps auth source to MongoCredential. - `checkout` (optional): keep socket checked out. @@ -397,13 +421,8 @@ class Pool: try: sock_info.check_auth(all_credentials) yield sock_info - except (socket.error, ConnectionFailure): - sock_info.close() - - # Decrement semaphore. - self.return_socket(sock_info) - raise except: + # Exception in caller. Decrement semaphore. self.return_socket(sock_info) raise else: @@ -411,6 +430,7 @@ class Pool: self.return_socket(sock_info) def _get_socket_no_auth(self): + """Get or create a SocketInfo. Can raise ConnectionFailure.""" # We use the pid here to avoid issues with fork / multiprocessing. # See test.test_client:TestClient.test_fork for an example of # what could go wrong otherwise @@ -430,9 +450,11 @@ class Pool: with self.lock: sock_info, from_pool = self.sockets.pop(), True except KeyError: + # Can raise ConnectionFailure or CertificateError. sock_info, from_pool = self.connect(), False if from_pool: + # Can raise ConnectionFailure. sock_info = self._check(sock_info) except: @@ -460,7 +482,7 @@ class Pool: the last time this socket was used, or if the socket has been closed by some external network error, and if so, attempts to create a new socket. If this connection attempt fails we reset the pool and reraise the - error. + ConnectionFailure. Checking sockets lets us avoid seeing *some* :class:`~pymongo.errors.AutoReconnect` exceptions on server diff --git a/pymongo/server.py b/pymongo/server.py index 009016307..6a144f26e 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -15,9 +15,8 @@ """Communicate with one MongoDB server in a topology.""" import contextlib -import socket -from pymongo.errors import AutoReconnect, DocumentTooLarge, NetworkTimeout +from pymongo.errors import DocumentTooLarge from pymongo.response import Response, ExhaustResponse from pymongo.server_type import SERVER_TYPE @@ -62,11 +61,8 @@ class Server(object): - `all_credentials`: dict, maps auth source to MongoCredential. """ request_id, data = self._check_bson_size(message) - try: - with self.get_socket(all_credentials) as sock_info: - sock_info.send_message(data) - except socket.error as exc: - self._raise_connection_failure(exc) + with self.get_socket(all_credentials) as sock_info: + sock_info.send_message(data) def send_message_with_response( self, @@ -84,22 +80,19 @@ class Server(object): It is returned along with its Pool in the Response. """ request_id, data = self._check_bson_size(message) - try: - with self.get_socket(all_credentials, exhaust) as sock_info: - sock_info.send_message(data) - response_data = sock_info.receive_message(1, request_id) - if exhaust: - return ExhaustResponse( - data=response_data, - address=self._description.address, - socket_info=sock_info, - pool=self._pool) - else: - return Response( - data=response_data, - address=self._description.address) - except socket.error as exc: - self._raise_connection_failure(exc) + with self.get_socket(all_credentials, exhaust) as sock_info: + sock_info.send_message(data) + response_data = sock_info.receive_message(1, request_id) + if exhaust: + return ExhaustResponse( + data=response_data, + address=self._description.address, + socket_info=sock_info, + pool=self._pool) + else: + return Response( + data=response_data, + address=self._description.address) @contextlib.contextmanager def get_socket(self, all_credentials, checkout=False): @@ -140,15 +133,6 @@ class Server(object): # get_more and kill_cursors messages don't include BSON documents. return message - def _raise_connection_failure(self, exc): - host, port = self._description.address - msg = '%s:%d: %s' % (host, port, exc) - - if isinstance(exc, socket.timeout): - raise NetworkTimeout(msg) - else: - raise AutoReconnect(msg) - def __str__(self): d = self._description return '' % ( diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 7785d8b6f..e60cf883b 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -15,12 +15,12 @@ """Tools for mocking parts of PyMongo to test other parts.""" import contextlib -import socket from functools import partial import weakref from pymongo import common from pymongo import MongoClient +from pymongo.errors import AutoReconnect, NetworkTimeout from pymongo.ismaster import IsMaster from pymongo.monitor import Monitor from pymongo.pool import Pool, PoolOptions @@ -46,7 +46,7 @@ class MockPool(Pool): client = self.client host_and_port = '%s:%s' % (self.mock_host, self.mock_port) if host_and_port in client.mock_down_hosts: - raise socket.error('mock error') + raise AutoReconnect('mock error') assert host_and_port in ( client.mock_standalones @@ -152,7 +152,7 @@ class MockClient(MongoClient): # host is like 'a:1'. if host in self.mock_down_hosts: - raise socket.timeout('mock timeout') + raise NetworkTimeout('mock timeout') elif host in self.mock_standalones: response = { @@ -188,7 +188,7 @@ class MockClient(MongoClient): else: # In test_internal_ips(), we try to connect to a host listed # in ismaster['hosts'] but not publicly accessible. - raise socket.error('Unknown host: %s' % host) + raise AutoReconnect('Unknown host: %s' % host) return response, rtt diff --git a/test/test_pooling.py b/test/test_pooling.py index f2b2fc386..cd7c88234 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -16,13 +16,12 @@ import gc import random -import socket import sys import threading import time from pymongo import MongoClient -from pymongo.errors import (ConfigurationError, +from pymongo.errors import (AutoReconnect, ConnectionFailure, DuplicateKeyError, ExceededMaxWaiters) @@ -265,7 +264,7 @@ class TestPooling(_TestPoolingBase): # Swap pool's address with a bad one. address, cx_pool.address = cx_pool.address, ('foo.com', 1234) - with self.assertRaises(socket.error): + with self.assertRaises(AutoReconnect): with cx_pool.get_socket({}): pass @@ -495,12 +494,18 @@ class TestPoolMaxSize(_TestPoolingBase): # First call to get_socket fails; if pool doesn't release its semaphore # then the second call raises "ConnectionFailure: Timed out waiting for - # socket from pool" instead of the socket.error. + # socket from pool" instead of AutoReconnect. for i in range(2): - with self.assertRaises(socket.error): + with self.assertRaises(AutoReconnect) as context: with test_pool.get_socket({}, checkout=True): pass + # Testing for AutoReconnect instead of ConnectionFailure, above, + # is sufficient right *now* to catch a semaphore leak. But that + # seems error-prone, so check the message too. + self.assertNotIn('waiting for socket from pool', + str(context.exception)) + if __name__ == "__main__": unittest.main() diff --git a/test/test_topology.py b/test/test_topology.py index 1fe17599b..b3832bd18 100644 --- a/test/test_topology.py +++ b/test/test_topology.py @@ -18,7 +18,6 @@ import sys sys.path[0:0] = [""] -import socket import threading from bson.py3compat import imap @@ -27,7 +26,8 @@ from pymongo.read_preferences import ReadPreference, Secondary from pymongo.server_type import SERVER_TYPE from pymongo.topology import Topology from pymongo.topology_description import TOPOLOGY_TYPE -from pymongo.errors import (ConfigurationError, +from pymongo.errors import (AutoReconnect, + ConfigurationError, ConnectionFailure) from pymongo.ismaster import IsMaster from pymongo.monitor import Monitor @@ -239,7 +239,7 @@ class TestSingleServerTopology(TopologyTest): if available: return IsMaster({'ok': 1}), round_trip_time else: - raise socket.error() + raise AutoReconnect('mock monitor error') t = create_mock_topology(monitor_class=TestMonitor) s = t.select_server(writable_server_selector) @@ -543,7 +543,7 @@ class TestTopologyErrors(TopologyTest): if ismaster_count[0] == 1: return IsMaster({'ok': 1}), 0 else: - raise socket.error() + raise AutoReconnect('mock monitor error') t = create_mock_topology(monitor_class=TestMonitor) server = wait_for_master(t) @@ -564,7 +564,7 @@ class TestTopologyErrors(TopologyTest): if ismaster_count[0] in (1, 3): return IsMaster({'ok': 1}), 0 else: - raise socket.error() + raise AutoReconnect('mock monitor error') t = create_mock_topology(monitor_class=TestMonitor) server = wait_for_master(t)