diff --git a/pymongo/cursor.py b/pymongo/cursor.py index d6b0c39bf..ca75527e5 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -15,8 +15,6 @@ """Cursor class to iterate over Mongo query results.""" import copy -import socket - from collections import deque from bson import RE_TYPE @@ -28,6 +26,7 @@ from bson.son import SON from pymongo import helpers from pymongo.common import validate_boolean, validate_is_mapping from pymongo.errors import (AutoReconnect, + ConnectionFailure, InvalidOperation, NotMasterError, OperationFailure) @@ -839,9 +838,9 @@ class Cursor(object): # Exhaust cursor - no getMore message. try: data = self.__exhaust_mgr.sock.receive_message(1, None) - except socket.error as exc: + except ConnectionFailure: self.__die() - raise AutoReconnect(str(exc)) + raise try: doc = helpers._unpack_response(response=data, diff --git a/pymongo/monitor.py b/pymongo/monitor.py index 2a174208e..50175e8c3 100644 --- a/pymongo/monitor.py +++ b/pymongo/monitor.py @@ -151,7 +151,7 @@ class Monitor(object): def _check_with_socket(self, sock_info): """Return (IsMaster, round_trip_time). - Can raise socket.error or PyMongoError. + Can raise ConnectionFailure or OperationFailure. """ start = _time() request_id, msg, _ = message.query( diff --git a/pymongo/pool.py b/pymongo/pool.py index 34e490a00..e8ccc121b 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -19,7 +19,7 @@ import threading from bson.py3compat import u, itervalues from pymongo import auth, thread_util -from pymongo.errors import ConnectionFailure +from pymongo.errors import AutoReconnect, ConnectionFailure, NetworkTimeout from pymongo.ismaster import IsMaster from pymongo.monotonic import time as _time from pymongo.network import (command, @@ -40,6 +40,16 @@ except ImportError: from pymongo.ssl_match_hostname import match_hostname, CertificateError +def _raise_connection_failure(address, error): + """Convert a socket.error to ConnectionFailure and raise it.""" + host, port = address + msg = '%s:%d: %s' % (host, port, error) + if isinstance(error, socket.timeout): + raise NetworkTimeout(msg) + else: + raise AutoReconnect(msg) + + class PoolOptions(object): __slots__ = ('__max_pool_size', '__connect_timeout', '__socket_timeout', @@ -121,11 +131,11 @@ class SocketInfo(object): - `sock`: a raw socket object - `pool`: a Pool instance - `ismaster`: an IsMaster instance, response to ismaster call on `sock` - - `host`: a string, the server's hostname (without port) + - `address`: the server's (host, port) """ - def __init__(self, sock, pool, ismaster, host): + def __init__(self, sock, pool, ismaster, address): self.sock = sock - self.host = host + self.address = address self.authset = set() self.closed = False self.last_checkout = _time() @@ -136,7 +146,7 @@ class SocketInfo(object): self.pool_id = pool.pool_id def command(self, dbname, spec): - """Execute a command or raise socket.error or OperationFailure. + """Execute a command or raise ConnectionFailure or OperationFailure. :Parameters: - `dbname`: name of the database on which to run the command @@ -144,37 +154,34 @@ class SocketInfo(object): """ try: return command(self.sock, dbname, spec) - except socket.error: - self.close() - raise + except socket.error as error: + self._raise_connection_failure(error) def send_message(self, message): - """Send a raw BSON message or raise socket.error. + """Send a raw BSON message or raise ConnectionFailure. If a network exception is raised, the socket is closed. """ try: self.sock.sendall(message) - except: - self.close() - raise + except socket.error as error: + self._raise_connection_failure(error) def receive_message(self, operation, request_id): - """Receive a raw BSON message or raise socket.error. + """Receive a raw BSON message or raise ConnectionFailure. If any exception is raised, the socket is closed. """ try: return receive_message(self.sock, operation, request_id) - except: - self.close() - raise + except socket.error as error: + self._raise_connection_failure(error) def check_auth(self, all_credentials): """Update this socket's authentication. Log in or out to bring this socket's credentials up to date with - those provided. Can raise socket.error or OperationFailure. + those provided. Can raise ConnectionFailure or OperationFailure. :Parameters: - `all_credentials`: dict, maps auth source to MongoCredential. @@ -195,6 +202,8 @@ class SocketInfo(object): def authenticate(self, credentials): """Log in to the server and store these credentials in `authset`. + Can raise ConnectionFailure or OperationFailure. + :Parameters: - `credentials`: A MongoCredential. """ @@ -215,6 +224,10 @@ class SocketInfo(object): 'max_wire_version on a SocketInfo created without handshake') return self.ismaster.max_wire_version + def _raise_connection_failure(self, error): + self.close() + _raise_connection_failure(self.address, error) + def __eq__(self, other): return self.sock == other.sock @@ -235,6 +248,8 @@ class SocketInfo(object): def _create_connection(address, options): """Given (host, port) and PoolOptions, connect and return a socket object. + Can raise socket.error. + This is a modified version of create_connection from CPython >= 2.6. """ host, port = address @@ -248,9 +263,9 @@ def _create_connection(address, options): try: sock.connect(host) return sock - except socket.error as e: + except socket.error: sock.close() - raise e + raise # Don't try IPv6 if we don't support it. Also skip it if host # is 'localhost' (::1 is fine). Avoids slow connect issues @@ -287,6 +302,8 @@ def _create_connection(address, options): def _configured_socket(address, options): """Given (host, port) and PoolOptions, return a configured socket. + Can raise socket.error, ConnectionFailure, or CertificateError. + Sets socket's SSL and timeout options. """ sock = _create_connection(address, options) @@ -355,20 +372,25 @@ class Pool: sock_info.close() def connect(self): - """Connect to Mongo and return a new (connected) socket. Note that the - pool does not keep a reference to the socket -- you must call - return_socket() when you're done with it. + """Connect to Mongo and return a new SocketInfo. + + Can raise ConnectionFailure or CertificateError. + + Note that the pool does not keep a reference to the socket -- you + must call return_socket() when you're done with it. """ - sock = _configured_socket(self.address, self.opts) - if self.handshake: - try: + sock = None + try: + sock = _configured_socket(self.address, self.opts) + if self.handshake: ismaster = IsMaster(command(sock, 'admin', {'ismaster': 1})) - except: + else: + ismaster = None + return SocketInfo(sock, self, ismaster, self.address) + except socket.error as error: + if sock is not None: sock.close() - raise - else: - ismaster = None - return SocketInfo(sock, self, ismaster, host=self.address[0]) + _raise_connection_failure(self.address, error) @contextlib.contextmanager def get_socket(self, all_credentials, checkout=False): @@ -387,6 +409,8 @@ class Pool: using the correct authentication mechanism for the server's wire protocol version. + Can raise ConnectionFailure or OperationFailure. + :Parameters: - `all_credentials`: dict, maps auth source to MongoCredential. - `checkout` (optional): keep socket checked out. @@ -397,13 +421,8 @@ class Pool: try: sock_info.check_auth(all_credentials) yield sock_info - except (socket.error, ConnectionFailure): - sock_info.close() - - # Decrement semaphore. - self.return_socket(sock_info) - raise except: + # Exception in caller. Decrement semaphore. self.return_socket(sock_info) raise else: @@ -411,6 +430,7 @@ class Pool: self.return_socket(sock_info) def _get_socket_no_auth(self): + """Get or create a SocketInfo. Can raise ConnectionFailure.""" # We use the pid here to avoid issues with fork / multiprocessing. # See test.test_client:TestClient.test_fork for an example of # what could go wrong otherwise @@ -430,9 +450,11 @@ class Pool: with self.lock: sock_info, from_pool = self.sockets.pop(), True except KeyError: + # Can raise ConnectionFailure or CertificateError. sock_info, from_pool = self.connect(), False if from_pool: + # Can raise ConnectionFailure. sock_info = self._check(sock_info) except: @@ -460,7 +482,7 @@ class Pool: the last time this socket was used, or if the socket has been closed by some external network error, and if so, attempts to create a new socket. If this connection attempt fails we reset the pool and reraise the - error. + ConnectionFailure. Checking sockets lets us avoid seeing *some* :class:`~pymongo.errors.AutoReconnect` exceptions on server diff --git a/pymongo/server.py b/pymongo/server.py index 009016307..6a144f26e 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -15,9 +15,8 @@ """Communicate with one MongoDB server in a topology.""" import contextlib -import socket -from pymongo.errors import AutoReconnect, DocumentTooLarge, NetworkTimeout +from pymongo.errors import DocumentTooLarge from pymongo.response import Response, ExhaustResponse from pymongo.server_type import SERVER_TYPE @@ -62,11 +61,8 @@ class Server(object): - `all_credentials`: dict, maps auth source to MongoCredential. """ request_id, data = self._check_bson_size(message) - try: - with self.get_socket(all_credentials) as sock_info: - sock_info.send_message(data) - except socket.error as exc: - self._raise_connection_failure(exc) + with self.get_socket(all_credentials) as sock_info: + sock_info.send_message(data) def send_message_with_response( self, @@ -84,22 +80,19 @@ class Server(object): It is returned along with its Pool in the Response. """ request_id, data = self._check_bson_size(message) - try: - with self.get_socket(all_credentials, exhaust) as sock_info: - sock_info.send_message(data) - response_data = sock_info.receive_message(1, request_id) - if exhaust: - return ExhaustResponse( - data=response_data, - address=self._description.address, - socket_info=sock_info, - pool=self._pool) - else: - return Response( - data=response_data, - address=self._description.address) - except socket.error as exc: - self._raise_connection_failure(exc) + with self.get_socket(all_credentials, exhaust) as sock_info: + sock_info.send_message(data) + response_data = sock_info.receive_message(1, request_id) + if exhaust: + return ExhaustResponse( + data=response_data, + address=self._description.address, + socket_info=sock_info, + pool=self._pool) + else: + return Response( + data=response_data, + address=self._description.address) @contextlib.contextmanager def get_socket(self, all_credentials, checkout=False): @@ -140,15 +133,6 @@ class Server(object): # get_more and kill_cursors messages don't include BSON documents. return message - def _raise_connection_failure(self, exc): - host, port = self._description.address - msg = '%s:%d: %s' % (host, port, exc) - - if isinstance(exc, socket.timeout): - raise NetworkTimeout(msg) - else: - raise AutoReconnect(msg) - def __str__(self): d = self._description return '' % ( diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 7785d8b6f..e60cf883b 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -15,12 +15,12 @@ """Tools for mocking parts of PyMongo to test other parts.""" import contextlib -import socket from functools import partial import weakref from pymongo import common from pymongo import MongoClient +from pymongo.errors import AutoReconnect, NetworkTimeout from pymongo.ismaster import IsMaster from pymongo.monitor import Monitor from pymongo.pool import Pool, PoolOptions @@ -46,7 +46,7 @@ class MockPool(Pool): client = self.client host_and_port = '%s:%s' % (self.mock_host, self.mock_port) if host_and_port in client.mock_down_hosts: - raise socket.error('mock error') + raise AutoReconnect('mock error') assert host_and_port in ( client.mock_standalones @@ -152,7 +152,7 @@ class MockClient(MongoClient): # host is like 'a:1'. if host in self.mock_down_hosts: - raise socket.timeout('mock timeout') + raise NetworkTimeout('mock timeout') elif host in self.mock_standalones: response = { @@ -188,7 +188,7 @@ class MockClient(MongoClient): else: # In test_internal_ips(), we try to connect to a host listed # in ismaster['hosts'] but not publicly accessible. - raise socket.error('Unknown host: %s' % host) + raise AutoReconnect('Unknown host: %s' % host) return response, rtt diff --git a/test/test_pooling.py b/test/test_pooling.py index f2b2fc386..cd7c88234 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -16,13 +16,12 @@ import gc import random -import socket import sys import threading import time from pymongo import MongoClient -from pymongo.errors import (ConfigurationError, +from pymongo.errors import (AutoReconnect, ConnectionFailure, DuplicateKeyError, ExceededMaxWaiters) @@ -265,7 +264,7 @@ class TestPooling(_TestPoolingBase): # Swap pool's address with a bad one. address, cx_pool.address = cx_pool.address, ('foo.com', 1234) - with self.assertRaises(socket.error): + with self.assertRaises(AutoReconnect): with cx_pool.get_socket({}): pass @@ -495,12 +494,18 @@ class TestPoolMaxSize(_TestPoolingBase): # First call to get_socket fails; if pool doesn't release its semaphore # then the second call raises "ConnectionFailure: Timed out waiting for - # socket from pool" instead of the socket.error. + # socket from pool" instead of AutoReconnect. for i in range(2): - with self.assertRaises(socket.error): + with self.assertRaises(AutoReconnect) as context: with test_pool.get_socket({}, checkout=True): pass + # Testing for AutoReconnect instead of ConnectionFailure, above, + # is sufficient right *now* to catch a semaphore leak. But that + # seems error-prone, so check the message too. + self.assertNotIn('waiting for socket from pool', + str(context.exception)) + if __name__ == "__main__": unittest.main() diff --git a/test/test_topology.py b/test/test_topology.py index 1fe17599b..b3832bd18 100644 --- a/test/test_topology.py +++ b/test/test_topology.py @@ -18,7 +18,6 @@ import sys sys.path[0:0] = [""] -import socket import threading from bson.py3compat import imap @@ -27,7 +26,8 @@ from pymongo.read_preferences import ReadPreference, Secondary from pymongo.server_type import SERVER_TYPE from pymongo.topology import Topology from pymongo.topology_description import TOPOLOGY_TYPE -from pymongo.errors import (ConfigurationError, +from pymongo.errors import (AutoReconnect, + ConfigurationError, ConnectionFailure) from pymongo.ismaster import IsMaster from pymongo.monitor import Monitor @@ -239,7 +239,7 @@ class TestSingleServerTopology(TopologyTest): if available: return IsMaster({'ok': 1}), round_trip_time else: - raise socket.error() + raise AutoReconnect('mock monitor error') t = create_mock_topology(monitor_class=TestMonitor) s = t.select_server(writable_server_selector) @@ -543,7 +543,7 @@ class TestTopologyErrors(TopologyTest): if ismaster_count[0] == 1: return IsMaster({'ok': 1}), 0 else: - raise socket.error() + raise AutoReconnect('mock monitor error') t = create_mock_topology(monitor_class=TestMonitor) server = wait_for_master(t) @@ -564,7 +564,7 @@ class TestTopologyErrors(TopologyTest): if ismaster_count[0] in (1, 3): return IsMaster({'ok': 1}), 0 else: - raise socket.error() + raise AutoReconnect('mock monitor error') t = create_mock_topology(monitor_class=TestMonitor) server = wait_for_master(t)