diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 8091d6c5c..8bd7fb9d6 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -415,7 +415,7 @@ functions: if [ -n "${MONGODB_STARTED}" ]; then export PYMONGO_MUST_CONNECT=1 fi - PYTHON_BINARY=${PYTHON_BINARY} GREEN_FRAMEWORK=${GREEN_FRAMEWORK} C_EXTENSIONS=${C_EXTENSIONS} COVERAGE=${COVERAGE} AUTH=${AUTH} SSL=${SSL} sh ${PROJECT_DIRECTORY}/.evergreen/run-tests.sh + PYTHON_BINARY=${PYTHON_BINARY} GREEN_FRAMEWORK=${GREEN_FRAMEWORK} C_EXTENSIONS=${C_EXTENSIONS} COVERAGE=${COVERAGE} COMPRESSORS=${COMPRESSORS} AUTH=${AUTH} SSL=${SSL} sh ${PROJECT_DIRECTORY}/.evergreen/run-tests.sh "run enterprise auth tests": - command: shell.exec @@ -968,6 +968,17 @@ axes: variables: AUTH: "noauth" SSL: "nossl" + - id: compression + display_name: Compression + values: + - id: snappy + display_name: snappy compression + variables: + COMPRESSORS: "snappy" + - id: zlib + display_name: zlib compression + variables: + COMPRESSORS: "zlib" - id: python-version display_name: "Python" values: @@ -1049,6 +1060,10 @@ axes: display_name: "Without C Extensions" variables: C_EXTENSIONS: "--no_ext" + - id: "with-c-extensions" + display_name: "With C Extensions" + variables: + C_EXTENSIONS: "" - id: storage-engine display_name: Storage values: @@ -1186,7 +1201,7 @@ buildvariants: exclude_spec: # These interpreters are always tested without extensions. - python-version: ["pypy", "pypy3", "pypy3.5", "jython2.7"] - c-extensions: "*" + c-extensions: "without-c-extensions" auth: "*" ssl: "*" coverage: "*" @@ -1200,6 +1215,35 @@ buildvariants: - ".3.0" - ".2.6" +- matrix_name: "tests-python-version-ubuntu16-compression" + matrix_spec: {"python-version": "*", "c-extensions": "*", "compression": "*"} + exclude_spec: + # These interpreters are always tested without extensions. + - python-version: ["pypy", "pypy3", "pypy3.5", "jython2.7"] + c-extensions: "with-c-extensions" + compression: "*" + - python-version: ["jython2.7"] + c-extensions: "*" + compression: "snappy" + batchtime: 10080 # 7 days + display_name: "${compression} ${c-extensions} ${python-version} (x86_64)" + # Ubuntu 16.04 images have libsnappy-dev installed + run_on: ubuntu1604-test + tasks: + - "test-latest-standalone" + - "test-3.6-standalone" + +- matrix_name: "tests-python-version-py37-plus-compression" + matrix_spec: {"python-version-requires-openssl-102-plus": "*", "c-extensions": "*", "compression": "*"} + exclude_spec: + display_name: "${compression} ${c-extensions} ${python-version-requires-openssl-102-plus} (x86_64)" + batchtime: 10080 # 7 days + # Ubuntu 16.04 images have libsnappy-dev installed, and provides OpenSSL 1.0.2 + run_on: ubuntu1604-test + tasks: + - "test-latest-standalone" + - "test-3.6-standalone" + - matrix_name: "tests-python-version-green-framework-rhel62" matrix_spec: {"python-version": "*", "green-framework": "*", auth-ssl: "*"} exclude_spec: diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 3b70adf5b..c69586acc 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -17,6 +17,11 @@ PYTHON_BINARY=${PYTHON_BINARY:-} GREEN_FRAMEWORK=${GREEN_FRAMEWORK:-} C_EXTENSIONS=${C_EXTENSIONS:-} COVERAGE=${COVERAGE:-} +COMPRESSORS=${COMPRESSORS:-} + +if [ -n $COMPRESSORS ]; then + export COMPRESSORS=$COMPRESSORS +fi export JAVA_HOME=/opt/java/jdk8 @@ -47,6 +52,13 @@ if [ -z "$PYTHON_BINARY" ]; then PYTHON=python trap "deactivate; rm -rf pymongotestvenv" EXIT HUP fi +elif [ $COMPRESSORS = "snappy" ]; then + $PYTHON_BINARY -m virtualenv --system-site-packages --never-download snappytest + . snappytest/bin/activate + trap "deactivate; rm -rf snappytest" EXIT HUP + # 0.5.2 has issues in pypy3(.5) + pip install python-snappy==0.5.1 + PYTHON=python else PYTHON="$PYTHON_BINARY" fi diff --git a/README.rst b/README.rst index 43ba301e6..72a2d2ae8 100644 --- a/README.rst +++ b/README.rst @@ -113,10 +113,15 @@ PyMongo:: $ python -m pip install pymongo[tls] +Wire protocol compression with snappy requires `python-snappy +`_:: + + $ python -m pip install pymongo[snappy] + You can install all dependencies automatically with the following command:: - $ python -m pip install pymongo[gssapi,srv,tls] + $ python -m pip install pymongo[snappy,gssapi,srv,tls] Other optional packages: diff --git a/doc/changelog.rst b/doc/changelog.rst index dd7d8417a..9af491ec8 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -9,6 +9,8 @@ Version 3.7 adds support for MongoDB 4.0. Highlights include: - Support for multi-document transactions, see :ref:`transactions-ref`. - Support for the SCRAM-SHA-256 authentication mechanism. - Support for Python 3.7. +- Support for wire protocol compression. See + :meth:`~pymongo.mongo_client.MongoClient` for details. - MD5 is now optional in GridFS. - If not specified, the authSource for the PLAIN authentication mechanism defaults to $external. diff --git a/doc/installation.rst b/doc/installation.rst index b57db4de8..00c10d9a6 100644 --- a/doc/installation.rst +++ b/doc/installation.rst @@ -70,10 +70,15 @@ PyMongo:: $ python -m pip install pymongo[tls] +Wire protocol compression with snappy requires `python-snappy +`_:: + + $ python -m pip install pymongo[snappy] + You can install all dependencies automatically with the following command:: - $ python -m pip install pymongo[gssapi,srv,tls] + $ python -m pip install pymongo[snappy,gssapi,srv,tls] Other optional packages: diff --git a/pymongo/_cmessagemodule.c b/pymongo/_cmessagemodule.c index 17de8b266..27d2daf30 100644 --- a/pymongo/_cmessagemodule.c +++ b/pymongo/_cmessagemodule.c @@ -142,19 +142,25 @@ static int add_last_error(PyObject* self, buffer_t buffer, } static int init_insert_buffer(buffer_t buffer, int request_id, int options, - const char* coll_name, int coll_name_len) { - /* Save space for message length */ - int length_location = buffer_save_space(buffer, 4); - if (length_location == -1) { - PyErr_NoMemory(); - return length_location; + const char* coll_name, int coll_name_len, + int compress) { + int length_location = 0; + if (!compress) { + /* Save space for message length */ + int length_location = buffer_save_space(buffer, 4); + if (length_location == -1) { + PyErr_NoMemory(); + return length_location; + } + if (!buffer_write_int32(buffer, (int32_t)request_id) || + !buffer_write_bytes(buffer, + "\x00\x00\x00\x00" + "\xd2\x07\x00\x00", + 8)) { + return -1; + } } - if (!buffer_write_int32(buffer, (int32_t)request_id) || - !buffer_write_bytes(buffer, - "\x00\x00\x00\x00" - "\xd2\x07\x00\x00", - 8) || - !buffer_write_int32(buffer, (int32_t)options) || + if (!buffer_write_int32(buffer, (int32_t)options) || !buffer_write_bytes(buffer, coll_name, coll_name_len + 1)) { @@ -211,7 +217,8 @@ static PyObject* _cbson_insert_message(PyObject* self, PyObject* args) { request_id, flags, collection_name, - collection_name_length); + collection_name_length, + 0); if (length_location == -1) { destroy_codec_options(&options); PyMem_Free(collection_name); @@ -660,7 +667,7 @@ static PyObject* _send_insert(PyObject* self, PyObject* ctx, PyObject* gle_args, buffer_t buffer, char* coll_name, int coll_len, int request_id, int safe, - codec_options_t* options, PyObject* to_publish) { + codec_options_t* options, PyObject* to_publish, int compress) { if (safe) { if (!add_last_error(self, buffer, request_id, @@ -669,16 +676,16 @@ _send_insert(PyObject* self, PyObject* ctx, } } - /* The max_doc_size parameter for legacy_write is the max size of any - * document in buffer. We enforced max size already, pass 0 here. */ - return PyObject_CallMethod(ctx, "legacy_write", - "i" BYTES_FORMAT_STRING "iNO", + /* The max_doc_size parameter for legacy_bulk_insert is the max size of + * any document in buffer. We enforced max size already, pass 0 here. */ + return PyObject_CallMethod(ctx, "legacy_bulk_insert", + "i" BYTES_FORMAT_STRING "iNOi", request_id, buffer_get_buffer(buffer), buffer_get_position(buffer), 0, PyBool_FromLong((long)safe), - to_publish); + to_publish, compress); } static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) { @@ -689,6 +696,7 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) { int send_safe, flags = 0; int length_location, message_length; int collection_name_length; + int compress; char* collection_name = NULL; PyObject* docs; PyObject* doc; @@ -698,6 +706,7 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) { PyObject* result; PyObject* max_bson_size_obj; PyObject* max_message_size_obj; + PyObject* compress_obj; PyObject* to_publish = NULL; unsigned char check_keys; unsigned char safe; @@ -754,6 +763,17 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) { return NULL; } + compress_obj = PyObject_GetAttrString(ctx, "compress"); + compress = PyObject_IsTrue(compress_obj); + Py_XDECREF(compress_obj); + if (compress == -1) { + destroy_codec_options(&options); + PyMem_Free(collection_name); + return NULL; + } + + compress = compress && !(safe || send_safe); + buffer = buffer_new(); if (!buffer) { destroy_codec_options(&options); @@ -766,7 +786,8 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) { request_id, flags, collection_name, - collection_name_length); + collection_name_length, + compress); if (length_location == -1) { goto insertfail; } @@ -797,13 +818,15 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) { /* If we've encoded anything send it before raising. */ if (!empty) { buffer_update_position(buffer, before); - message_length = buffer_get_position(buffer) - length_location; - buffer_write_int32_at_position( - buffer, length_location, (int32_t)message_length); + if (!compress) { + message_length = buffer_get_position(buffer) - length_location; + buffer_write_int32_at_position( + buffer, length_location, (int32_t)message_length); + } result = _send_insert(self, ctx, last_error_args, buffer, collection_name, collection_name_length, request_id, send_safe, &options, - to_publish); + to_publish, compress); if (!result) goto iterfail; Py_DECREF(result); @@ -826,7 +849,8 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) { new_request_id, flags, collection_name, - collection_name_length); + collection_name_length, + compress); if (message_start == -1) { buffer_free(new_buffer); goto iterfail; @@ -841,13 +865,16 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) { /* Roll back to the beginning of this document. */ buffer_update_position(buffer, before); - message_length = buffer_get_position(buffer) - length_location; - buffer_write_int32_at_position( - buffer, length_location, (int32_t)message_length); + if (!compress) { + message_length = buffer_get_position(buffer) - length_location; + buffer_write_int32_at_position( + buffer, length_location, (int32_t)message_length); + } result = _send_insert(self, ctx, last_error_args, buffer, collection_name, collection_name_length, - request_id, send_safe, &options, to_publish); + request_id, send_safe, &options, to_publish, + compress); buffer_free(buffer); buffer = new_buffer; @@ -927,14 +954,16 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) { goto insertfail; } - message_length = buffer_get_position(buffer) - length_location; - buffer_write_int32_at_position( - buffer, length_location, (int32_t)message_length); + if (!compress) { + message_length = buffer_get_position(buffer) - length_location; + buffer_write_int32_at_position( + buffer, length_location, (int32_t)message_length); + } /* Send the last (or only) batch */ result = _send_insert(self, ctx, last_error_args, buffer, collection_name, collection_name_length, - request_id, safe, &options, to_publish); + request_id, safe, &options, to_publish, compress); Py_DECREF(to_publish); PyMem_Free(collection_name); @@ -971,43 +1000,16 @@ insertfail: return NULL; } -static buffer_t -_command_buffer_new(char* ns, int ns_len) { - buffer_t buffer; - if (!(buffer = buffer_new())) { - PyErr_NoMemory(); - return NULL; - } - /* Save space for message length and request id */ - if ((buffer_save_space(buffer, 8)) == -1) { - PyErr_NoMemory(); - buffer_free(buffer); - return NULL; - } - if (!buffer_write_bytes(buffer, - "\x00\x00\x00\x00" /* responseTo */ - "\xd4\x07\x00\x00" /* opcode */ - "\x00\x00\x00\x00", /* flags */ - 12) || - !buffer_write_bytes(buffer, - ns, ns_len + 1) || /* namespace */ - !buffer_write_bytes(buffer, - "\x00\x00\x00\x00" /* skip */ - "\xFF\xFF\xFF\xFF", /* limit (-1) */ - 8)) { - buffer_free(buffer); - return NULL; - } - return buffer; -} - #define _INSERT 0 #define _UPDATE 1 #define _DELETE 2 -static PyObject* -_cbson_do_batched_write_command(PyObject* self, PyObject* args) { - struct module_state *state = GETSTATE(self); +static int +_batched_write_command( + char* ns, int ns_len, unsigned char op, int check_keys, + PyObject* command, PyObject* docs, PyObject* ctx, + PyObject* to_publish, codec_options_t options, + buffer_t buffer, struct module_state *state) { long max_bson_size; long max_cmd_size; @@ -1015,31 +1017,12 @@ _cbson_do_batched_write_command(PyObject* self, PyObject* args) { int idx = 0; int cmd_len_loc; int lst_len_loc; - int ns_len; - int request_id; int position; int length; - char *ns = NULL; PyObject* max_bson_size_obj; PyObject* max_write_batch_size_obj; - PyObject* command; PyObject* doc; - PyObject* docs; - PyObject* ctx; PyObject* iterator; - PyObject* result; - PyObject* to_publish = NULL; - unsigned char op; - unsigned char check_keys; - codec_options_t options; - buffer_t buffer; - - if (!PyArg_ParseTuple(args, "et#bOObO&O", "utf-8", - &ns, &ns_len, &op, &command, &docs, &check_keys, - convert_codec_options, &options, - &ctx)) { - return NULL; - } max_bson_size_obj = PyObject_GetAttrString(ctx, "max_bson_size"); #if PY_MAJOR_VERSION >= 3 @@ -1049,9 +1032,7 @@ _cbson_do_batched_write_command(PyObject* self, PyObject* args) { #endif Py_XDECREF(max_bson_size_obj); if (max_bson_size == -1) { - destroy_codec_options(&options); - PyMem_Free(ns); - return NULL; + return 0; } /* * Max BSON object size + 16k - 2 bytes for ending NUL bytes @@ -1067,31 +1048,26 @@ _cbson_do_batched_write_command(PyObject* self, PyObject* args) { #endif Py_XDECREF(max_write_batch_size_obj); if (max_write_batch_size == -1) { - destroy_codec_options(&options); - PyMem_Free(ns); - return NULL; + return 0; } - if (!(to_publish = PyList_New(0))) { - destroy_codec_options(&options); - PyMem_Free(ns); - return NULL; + if (!buffer_write_bytes(buffer, + "\x00\x00\x00\x00", /* flags */ + 4) || + !buffer_write_bytes(buffer, + ns, ns_len + 1) || /* namespace */ + !buffer_write_bytes(buffer, + "\x00\x00\x00\x00" /* skip */ + "\xFF\xFF\xFF\xFF", /* limit (-1) */ + 8)) { + return 0; } - if (!(buffer = _command_buffer_new(ns, ns_len))) { - destroy_codec_options(&options); - PyMem_Free(ns); - Py_DECREF(to_publish); - return NULL; - } - - PyMem_Free(ns); - /* Position of command document length */ cmd_len_loc = buffer_get_position(buffer); if (!write_dict(state->_cbson, buffer, command, 0, &options, 0)) { - goto cmdfail; + return 0; } /* Write type byte for array */ @@ -1127,7 +1103,7 @@ _cbson_do_batched_write_command(PyObject* self, PyObject* args) { PyErr_SetString(InvalidOperation, "Unknown command"); Py_DECREF(InvalidOperation); } - goto cmdfail; + return 0; } } @@ -1135,7 +1111,7 @@ _cbson_do_batched_write_command(PyObject* self, PyObject* args) { lst_len_loc = buffer_save_space(buffer, 4); if (lst_len_loc == -1) { PyErr_NoMemory(); - goto cmdfail; + return 0; } iterator = PyObject_GetIter(docs); @@ -1145,7 +1121,7 @@ _cbson_do_batched_write_command(PyObject* self, PyObject* args) { PyErr_SetString(InvalidOperation, "input is not iterable"); Py_DECREF(InvalidOperation); } - goto cmdfail; + return 0; } while ((doc = PyIter_Next(iterator)) != NULL) { int sub_doc_begin = buffer_get_position(buffer); @@ -1214,32 +1190,151 @@ _cbson_do_batched_write_command(PyObject* self, PyObject* args) { goto cmdfail; - request_id = rand(); position = buffer_get_position(buffer); length = position - lst_len_loc - 1; buffer_write_int32_at_position(buffer, lst_len_loc, (int32_t)length); length = position - cmd_len_loc; buffer_write_int32_at_position(buffer, cmd_len_loc, (int32_t)length); + return 1; + +cmditerfail: + Py_XDECREF(doc); + Py_DECREF(iterator); +cmdfail: + return 0; +} + +static PyObject* +_cbson_encode_batched_write_command(PyObject* self, PyObject* args) { + char *ns = NULL; + unsigned char op; + unsigned char check_keys; + int ns_len; + PyObject* command; + PyObject* docs; + PyObject* ctx = NULL; + PyObject* to_publish = NULL; + PyObject* result = NULL; + codec_options_t options; + struct module_state *state = GETSTATE(self); + + if (!PyArg_ParseTuple(args, "et#bOObO&O", "utf-8", + &ns, &ns_len, &op, &command, &docs, &check_keys, + convert_codec_options, &options, + &ctx)) { + return NULL; + } + buffer_t buffer; + if (!(buffer = buffer_new())) { + PyErr_NoMemory(); + PyMem_Free(ns); + destroy_codec_options(&options); + return NULL; + } + if (!(to_publish = PyList_New(0))) { + goto fail; + } + + if (!_batched_write_command( + ns, + ns_len, + op, + check_keys, + command, + docs, + ctx, + to_publish, + options, + buffer, + state)) { + goto fail; + } + + result = Py_BuildValue(BYTES_FORMAT_STRING "O", + buffer_get_buffer(buffer), + buffer_get_position(buffer), + to_publish); +fail: + PyMem_Free(ns); + destroy_codec_options(&options); + buffer_free(buffer); + Py_XDECREF(to_publish); + return result; +} + +static PyObject* +_cbson_do_batched_write_command(PyObject* self, PyObject* args) { + char *ns = NULL; + unsigned char op; + unsigned char check_keys; + int ns_len; + int request_id; + int position; + PyObject* command; + PyObject* docs; + PyObject* ctx = NULL; + PyObject* to_publish = NULL; + PyObject* result = NULL; + codec_options_t options; + struct module_state *state = GETSTATE(self); + + if (!PyArg_ParseTuple(args, "et#bOObO&O", "utf-8", + &ns, &ns_len, &op, &command, &docs, &check_keys, + convert_codec_options, &options, + &ctx)) { + return NULL; + } + buffer_t buffer; + if (!(buffer = buffer_new())) { + PyErr_NoMemory(); + PyMem_Free(ns); + destroy_codec_options(&options); + return NULL; + } + /* Save space for message length and request id */ + if ((buffer_save_space(buffer, 8)) == -1) { + PyErr_NoMemory(); + goto fail; + } + if (!buffer_write_bytes(buffer, + "\x00\x00\x00\x00" /* responseTo */ + "\xd4\x07\x00\x00", /* opcode */ + 8)) { + goto fail; + } + if (!(to_publish = PyList_New(0))) { + goto fail; + } + + if (!_batched_write_command( + ns, + ns_len, + op, + check_keys, + command, + docs, + ctx, + to_publish, + options, + buffer, + state)) { + goto fail; + } + + request_id = rand(); + position = buffer_get_position(buffer); buffer_write_int32_at_position(buffer, 0, (int32_t)position); buffer_write_int32_at_position(buffer, 4, (int32_t)request_id); result = Py_BuildValue("i" BYTES_FORMAT_STRING "O", request_id, buffer_get_buffer(buffer), buffer_get_position(buffer), to_publish); - - Py_DECREF(to_publish); +fail: + PyMem_Free(ns); + destroy_codec_options(&options); buffer_free(buffer); - destroy_codec_options(&options); - return result; - -cmditerfail: - Py_XDECREF(doc); - Py_DECREF(iterator); -cmdfail: - destroy_codec_options(&options); Py_XDECREF(to_publish); - buffer_free(buffer); - return NULL; + return result; } static PyMethodDef _CMessageMethods[] = { @@ -1255,6 +1350,8 @@ static PyMethodDef _CMessageMethods[] = { "insert a batch of documents, splitting the batch as needed"}, {"_do_batched_write_command", _cbson_do_batched_write_command, METH_VARARGS, "Create the next batched insert, update, or delete command"}, + {"_encode_batched_write_command", _cbson_encode_batched_write_command, METH_VARARGS, + "Encode the next batched insert, update, or delete command"}, {NULL, NULL, 0, NULL} }; diff --git a/pymongo/bulk.py b/pymongo/bulk.py index 7df1c16fb..fdcef627d 100644 --- a/pymongo/bulk.py +++ b/pymongo/bulk.py @@ -35,6 +35,7 @@ from pymongo.errors import (BulkWriteError, from pymongo.message import (_INSERT, _UPDATE, _DELETE, _do_batched_insert, _do_batched_write_command, + _do_batched_write_command_compressed, _randint, _BulkWriteContext) from pymongo.read_preferences import ReadPreference @@ -251,6 +252,11 @@ class _Bulk(object): self.current_run = next(generator) run = self.current_run + if sock_info.compression_context: + do_writes = _do_batched_write_command_compressed + else: + do_writes = _do_batched_write_command + # sock_info.command validates the session, but we use # sock_info.write_command. sock_info.validate_session(client, session) @@ -272,7 +278,7 @@ class _Bulk(object): check_keys = run.op_type == _INSERT ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible. - request_id, msg, to_send = _do_batched_write_command( + request_id, msg, to_send = do_writes( self.namespace, run.op_type, cmd, ops, check_keys, self.collection.codec_options, bwc) if not to_send: diff --git a/pymongo/client_options.py b/pymongo/client_options.py index 8c5c275ba..b4232b8c1 100644 --- a/pymongo/client_options.py +++ b/pymongo/client_options.py @@ -18,6 +18,7 @@ from bson.codec_options import _parse_codec_options from pymongo.auth import _build_credentials_tuple from pymongo.common import validate_boolean from pymongo import common +from pymongo.compression_support import CompressionSettings from pymongo.errors import ConfigurationError from pymongo.monitoring import _EventListeners from pymongo.pool import PoolOptions @@ -120,6 +121,9 @@ def _parse_pool_options(options): wait_queue_multiple = options.get('waitqueuemultiple') event_listeners = options.get('event_listeners') appname = options.get('appname') + compression_settings = CompressionSettings( + options.get('compressors', []), + options.get('zlibcompressionlevel', -1)) ssl_context, ssl_match_hostname = _parse_ssl_options(options) return PoolOptions(max_pool_size, min_pool_size, @@ -128,7 +132,8 @@ def _parse_pool_options(options): wait_queue_timeout, wait_queue_multiple, ssl_context, ssl_match_hostname, socket_keepalive, _EventListeners(event_listeners), - appname) + appname, + compression_settings) class ClientOptions(object): diff --git a/pymongo/collection.py b/pymongo/collection.py index f83a172ff..f8c3f4cf7 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -504,6 +504,7 @@ class Collection(common.BaseObject): if publish: start = datetime.datetime.now() + args = args + (sock_info.compression_context,) rqst_id, msg, max_size = func(*args) if publish: duration = datetime.datetime.now() - start diff --git a/pymongo/common.py b/pymongo/common.py index 374cb7213..9e05f02ef 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -26,6 +26,8 @@ from bson.codec_options import CodecOptions from bson.py3compat import abc, integer_types, iteritems, string_type from bson.raw_bson import RawBSONDocument from pymongo.auth import MECHANISMS +from pymongo.compression_support import (validate_compressors, + validate_zlib_compression_level) from pymongo.errors import ConfigurationError from pymongo.monitoring import _validate_event_listeners from pymongo.read_concern import ReadConcern @@ -539,6 +541,8 @@ URI_VALIDATORS = { 'appname': validate_appname_or_none, 'unicode_decode_error_handler': validate_unicode_decode_error_handler, 'retrywrites': validate_boolean_or_string, + 'compressors': validate_compressors, + 'zlibcompressionlevel': validate_zlib_compression_level } TIMEOUT_VALIDATORS = { diff --git a/pymongo/compression_support.py b/pymongo/compression_support.py new file mode 100644 index 000000000..a4987df08 --- /dev/null +++ b/pymongo/compression_support.py @@ -0,0 +1,124 @@ +# Copyright 2018 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +try: + import snappy + _HAVE_SNAPPY = True +except ImportError: + # python-snappy isn't available. + _HAVE_SNAPPY = False + +try: + import zlib + _HAVE_ZLIB = True +except ImportError: + # Python built without zlib support. + _HAVE_ZLIB = False + +from pymongo.monitoring import _SENSITIVE_COMMANDS + +_SUPPORTED_COMPRESSORS = set(["snappy", "zlib"]) +_NO_COMPRESSION = set(['ismaster']) +_NO_COMPRESSION.update(_SENSITIVE_COMMANDS) + + +def validate_compressors(dummy, value): + compressors = value.split(",") + for compressor in compressors[:]: + if compressor not in _SUPPORTED_COMPRESSORS: + compressors.remove(compressor) + warnings.warn("Unsupported compressor: %s" % (compressor,)) + elif compressor == "snappy" and not _HAVE_SNAPPY: + compressors.remove(compressor) + warnings.warn( + "Wire protocol compression with snappy is not available. " + "You must install the python-snappy module for snappy support.") + elif compressor == "zlib" and not _HAVE_ZLIB: + compressors.remove(compressor) + warnings.warn( + "Wire protocol compression with zlib is not available. " + "The zlib module is not available.") + return compressors + + +def validate_zlib_compression_level(option, value): + try: + level = int(value) + except: + raise TypeError("%s must be an integer, not %r." % (option, value)) + if level < -1 or level > 9: + raise ValueError( + "%s must be between -1 and 9, not %d." % (option, level)) + return level + + +class CompressionSettings(object): + def __init__(self, compressors, zlib_compression_level): + self.compressors = compressors + self.zlib_compression_level = zlib_compression_level + + def get_compression_context(self, compressors): + if compressors: + chosen = compressors[0] + if chosen == "snappy": + return SnappyContext() + elif chosen == "zlib": + return ZlibContext(self.zlib_compression_level) + + +def _zlib_no_compress(data): + """Compress data with zlib level 0.""" + cobj = zlib.compressobj(0) + return b"".join([cobj.compress(data), cobj.flush()]) + + +class SnappyContext(object): + compressor_id = 1 + + @staticmethod + def compress(data): + return snappy.compress(data) + + +class ZlibContext(object): + compressor_id = 2 + + def __init__(self, level): + # Jython zlib.compress doesn't support -1 + if level == -1: + self.compress = zlib.compress + # Jython zlib.compress also doesn't support 0 + elif level == 0: + self.compress = _zlib_no_compress + else: + self.compress = lambda data: zlib.compress(data, level) + + +def decompress(data, compressor_id): + if compressor_id == SnappyContext.compressor_id: + # python-snappy doesn't support the buffer interface. + # https://github.com/andrix/python-snappy/issues/65 + # This only matters when data is a memoryview since + # id(bytes(data)) == id(data) when data is a bytes. + # NOTE: bytes(memoryview) returns the memoryview repr + # in Python 2.7. The right thing to do in 2.7 is call + # memoryview.tobytes(), but we currently only use + # memoryview in Python 3.x. + return snappy.uncompress(bytes(data)) + elif compressor_id == ZlibContext.compressor_id: + return zlib.decompress(data) + else: + raise ValueError("Unknown compressorId %d" % (compressor_id,)) diff --git a/pymongo/ismaster.py b/pymongo/ismaster.py index a20ef0d12..e723ff0a9 100644 --- a/pymongo/ismaster.py +++ b/pymongo/ismaster.py @@ -152,3 +152,7 @@ class IsMaster(object): @property def last_write_date(self): return self._doc.get('lastWrite', {}).get('lastWriteDate') + + @property + def compressors(self): + return self._doc.get('compression') diff --git a/pymongo/message.py b/pymongo/message.py index bb2a0f63e..1717bb627 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -25,7 +25,7 @@ import random import struct import bson -from bson import CodecOptions +from bson import CodecOptions, _make_c_string, _dict_to_bson from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.py3compat import b, StringIO from bson.son import SON @@ -326,7 +326,8 @@ class _Query(object): self.read_preference) return query(flags, ns, self.ntoskip, ntoreturn, - spec, None if use_cmd else self.fields, self.codec_options) + spec, None if use_cmd else self.fields, + self.codec_options, ctx=sock_info.compression_context) class _GetMore(object): @@ -375,14 +376,15 @@ class _GetMore(object): """Get a getmore message.""" ns = _UJOIN % (self.db, self.coll) + ctx = sock_info.compression_context if use_cmd: ns = _UJOIN % (self.db, "$cmd") spec = self.as_command(sock_info)[0] - return query(0, ns, 0, -1, spec, None, self.codec_options) + return query(0, ns, 0, -1, spec, None, self.codec_options, ctx=ctx) - return get_more(ns, self.ntoreturn, self.cursor_id) + return get_more(ns, self.ntoreturn, self.cursor_id, ctx) class _RawBatchQuery(_Query): @@ -436,6 +438,25 @@ class _CursorAddress(tuple): return not self == other +_pack_compression_header = struct.Struct(" ctx.max_bson_size) @@ -758,8 +926,12 @@ def _do_batched_insert(collection_name, docs, check_keys, if has_docs: # We have enough data, send this message. try: - request_id, msg = _insert_message(data.getvalue(), send_safe) - ctx.legacy_write(request_id, msg, 0, send_safe, to_send) + if compress: + rid, msg = None, data.getvalue() + else: + rid, msg = _insert_message(data.getvalue(), send_safe) + ctx.legacy_bulk_insert( + rid, msg, 0, send_safe, to_send, compress) # Exception type could be OperationFailure or a subtype # (e.g. DuplicateKeyError) except OperationFailure as exc: @@ -787,8 +959,11 @@ def _do_batched_insert(collection_name, docs, check_keys, if not has_docs: raise InvalidOperation("cannot do an empty bulk insert") - request_id, msg = _insert_message(data.getvalue(), safe) - ctx.legacy_write(request_id, msg, 0, safe, to_send) + if compress: + request_id, msg = None, data.getvalue() + else: + request_id, msg = _insert_message(data.getvalue(), safe) + ctx.legacy_bulk_insert(request_id, msg, 0, safe, to_send, compress) # Re-raise any exception stored due to continue_on_error if last_error is not None: @@ -797,21 +972,69 @@ if _use_c: _do_batched_insert = _cmessage._do_batched_insert -def _do_batched_write_command(namespace, operation, command, - docs, check_keys, opts, ctx): +def _do_batched_write_command_compressed( + namespace, operation, command, docs, check_keys, opts, ctx): + """Create the next batched insert, update, or delete command, compressed. + """ + data, to_send = _encode_batched_write_command( + namespace, operation, command, docs, check_keys, opts, ctx) + + request_id, msg = _compress( + 2004, + data, + ctx.sock_info.compression_context) + return request_id, msg, to_send + + +def _encode_batched_write_command( + namespace, operation, command, docs, check_keys, opts, ctx): + """Encode the next batched insert, update, or delete command. + """ + buf = StringIO() + + to_send, _ = _batched_write_command( + namespace, operation, command, docs, check_keys, opts, ctx, buf) + return buf.getvalue(), to_send +if _use_c: + _encode_batched_write_command = _cmessage._encode_batched_write_command + + +def _do_batched_write_command( + namespace, operation, command, docs, check_keys, opts, ctx): """Create the next batched insert, update, or delete command. """ + buf = StringIO() + + # Save space for message length and request id + buf.write(_ZERO_64) + # responseTo, opCode + buf.write(b"\x00\x00\x00\x00\xd4\x07\x00\x00") + + # Write OP_QUERY write command + to_send, length = _batched_write_command( + namespace, operation, command, docs, check_keys, opts, ctx, buf) + + # Header - request id and message length + buf.seek(4) + request_id = _randint() + buf.write(struct.pack('`_ package. + zlib support requires the Python standard library zlib module. + By default no compression is used. Compression support must also be + enabled on the server. MongoDB 3.4+ supports snappy compression. + MongoDB 3.6+ supports snappy and zlib. + - `zlibCompressionLevel`: (int) The zlib compression level to use + when zlib is used as the wire protocol compressor. Supported values + are -1 through 9. -1 tells the zlib library to use its default + compression level (usually 6). 0 means no compression. 1 is best + speed. 9 is best compression. Defaults to -1. | **Write Concern options:** | (Only set if passed. No default values.) diff --git a/pymongo/network.py b/pymongo/network.py index e50caea25..444e4a71c 100644 --- a/pymongo/network.py +++ b/pymongo/network.py @@ -38,6 +38,7 @@ from bson.py3compat import PY3 from pymongo import helpers, message from pymongo.common import MAX_MESSAGE_SIZE +from pymongo.compression_support import decompress, _NO_COMPRESSION from pymongo.errors import (AutoReconnect, NotMasterError, OperationFailure, @@ -54,7 +55,8 @@ def command(sock, dbname, spec, slave_ok, is_mongos, check_keys=False, listeners=None, max_bson_size=None, read_concern=None, parse_write_concern_error=False, - collation=None): + collation=None, + compression_ctx=None): """Execute a command over the socket, or raise socket.error. :Parameters: @@ -101,8 +103,12 @@ def command(sock, dbname, spec, slave_ok, is_mongos, if publish: start = datetime.datetime.now() - request_id, msg, size = message.query(flags, ns, 0, -1, spec, - None, codec_options, check_keys) + if name.lower() not in _NO_COMPRESSION and compression_ctx: + request_id, msg, size = message.query( + flags, ns, 0, -1, spec, None, codec_options, check_keys, compression_ctx) + else: + request_id, msg, size = message.query( + flags, ns, 0, -1, spec, None, codec_options, check_keys) if (max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD): @@ -142,15 +148,13 @@ def command(sock, dbname, spec, slave_ok, is_mongos, duration, response_doc, name, request_id, address) return response_doc +_UNPACK_COMPRESSION_HEADER = struct.Struct(" max_message_size: raise ProtocolError("Message length (%r) is larger than server max " "message size (%r)" % (length, max_message_size)) + if op_code == 2012: + op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( + _receive_data_on_socket(sock, 9)) + data = decompress( + _receive_data_on_socket(sock, length - 25), compressor_id) + else: + data = _receive_data_on_socket(sock, length - 16) + if op_code != _OpReply.OP_CODE: + raise ProtocolError("Got opcode %r but expected " + "%r" % (op_code, _OpReply.OP_CODE)) - return _OpReply.unpack(_receive_data_on_socket(sock, length - 16)) + return _OpReply.unpack(data) # memoryview was introduced in Python 2.7 but we only use it on Python 3 diff --git a/pymongo/pool.py b/pymongo/pool.py index 738a3f80f..c6a111248 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -280,14 +280,16 @@ class PoolOptions(object): '__connect_timeout', '__socket_timeout', '__wait_queue_timeout', '__wait_queue_multiple', '__ssl_context', '__ssl_match_hostname', '__socket_keepalive', - '__event_listeners', '__appname', '__metadata') + '__event_listeners', '__appname', '__metadata', + '__compression_settings') def __init__(self, max_pool_size=100, min_pool_size=0, max_idle_time_seconds=None, connect_timeout=None, socket_timeout=None, wait_queue_timeout=None, wait_queue_multiple=None, ssl_context=None, ssl_match_hostname=True, socket_keepalive=True, - event_listeners=None, appname=None): + event_listeners=None, appname=None, + compression_settings=None): self.__max_pool_size = max_pool_size self.__min_pool_size = min_pool_size @@ -301,6 +303,7 @@ class PoolOptions(object): self.__socket_keepalive = socket_keepalive self.__event_listeners = event_listeners self.__appname = appname + self.__compression_settings = compression_settings self.__metadata = _METADATA.copy() if appname: self.__metadata['application'] = {'name': appname} @@ -392,6 +395,10 @@ class PoolOptions(object): """ return self.__appname + @property + def compression_settings(self): + return self.__compression_settings + @property def metadata(self): """A dict of metadata about the application, driver, os, and platform. @@ -423,6 +430,8 @@ class SocketInfo(object): self.is_mongos = False self.listeners = pool.opts.event_listeners + self.compression_settings = pool.opts.compression_settings + self.compression_context = None # The pool's pool_id changes with each reset() so we can close sockets # created before the last reset. @@ -432,7 +441,8 @@ class SocketInfo(object): cmd = SON([('ismaster', 1)]) if not self.performed_handshake: cmd['client'] = metadata - self.performed_handshake = True + if self.compression_settings: + cmd['compression'] = self.compression_settings.compressors if self.max_wire_version >= 6 and cluster_time is not None: cmd['$clusterTime'] = cluster_time @@ -446,6 +456,12 @@ class SocketInfo(object): self.supports_sessions = ( ismaster.logical_session_timeout_minutes is not None) self.is_mongos = ismaster.server_type == SERVER_TYPE.Mongos + if not self.performed_handshake and self.compression_settings: + ctx = self.compression_settings.get_compression_context( + ismaster.compressors) + self.compression_context = ctx + + self.performed_handshake = True return ismaster def command(self, dbname, spec, slave_ok=False, @@ -512,7 +528,8 @@ class SocketInfo(object): self.address, check_keys, listeners, self.max_bson_size, read_concern, parse_write_concern_error=parse_write_concern_error, - collation=collation) + collation=collation, + compression_ctx=self.compression_context) except OperationFailure: raise # Catch socket.error, KeyboardInterrupt, etc. and close ourselves. diff --git a/setup.py b/setup.py index 492224c5d..692a14fa8 100755 --- a/setup.py +++ b/setup.py @@ -323,11 +323,14 @@ ext_modules = [Extension('bson._cbson', sources=['pymongo/_cmessagemodule.c', 'bson/buffer.c'])] +extras_require = {'snappy': ["python-snappy"]} vi = sys.version_info if vi[0] == 2: - extras_require = {'tls': ["ipaddress"], 'srv': ["dnspython>=1.8.0,<2.0.0"]} + extras_require.update( + {'tls': ["ipaddress"], 'srv': ["dnspython>=1.8.0,<2.0.0"]}) else: - extras_require = {'tls': [], 'srv': ["dnspython>=1.13.0,<2.0.0"]} + extras_require.update( + {'tls': [], 'srv': ["dnspython>=1.13.0,<2.0.0"]}) if sys.platform == 'win32': extras_require['gssapi'] = ["winkerberos>=0.5.0"] if vi[0] == 2 and vi < (2, 7, 9) or vi[0] == 3 and vi < (3, 4): diff --git a/test/__init__.py b/test/__init__.py index 349f30a96..3e54523ec 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -72,6 +72,7 @@ if CA_PEM: if CERT_REQS is not None: _SSL_OPTIONS['ssl_cert_reqs'] = CERT_REQS +COMPRESSORS = os.environ.get("COMPRESSORS") def is_server_resolvable(): """Returns True if 'server' is resolvable.""" @@ -186,16 +187,19 @@ class ClientContext(object): self.ssl_cert_none = False self.ssl_certfile = False self.server_is_resolvable = is_server_resolvable() - self.ssl_client_options = {} + self.default_client_options = {} self.sessions_enabled = False self.client = self._connect(host, port) + if COMPRESSORS: + self.default_client_options["compressors"] = COMPRESSORS + if HAVE_SSL and not self.client: # Is MongoDB configured for SSL? self.client = self._connect(host, port, **_SSL_OPTIONS) if self.client: self.ssl = True - self.ssl_client_options = _SSL_OPTIONS + self.default_client_options.update(_SSL_OPTIONS) self.ssl_certfile = True if _SSL_OPTIONS.get('ssl_cert_reqs') == ssl.CERT_NONE: self.ssl_cert_none = True @@ -223,7 +227,7 @@ class ClientContext(object): self.client = self._connect( host, port, username=db_user, password=db_pwd, replicaSet=self.replica_set_name, - **self.ssl_client_options) + **self.default_client_options) # May not have this if OperationFailure was raised earlier. self.cmd_line = self.client.admin.command('getCmdLineOpts') @@ -242,13 +246,13 @@ class ClientContext(object): username=db_user, password=db_pwd, replicaSet=self.replica_set_name, - **self.ssl_client_options) + **self.default_client_options) else: self.client = pymongo.MongoClient( host, port, replicaSet=self.replica_set_name, - **self.ssl_client_options) + **self.default_client_options) # Get the authoritative ismaster result from the primary. self.ismaster = self.client.admin.command('ismaster') @@ -285,6 +289,8 @@ class ClientContext(object): timeout_ms = 10000 else: timeout_ms = 500 + if COMPRESSORS: + kwargs["compressors"] = COMPRESSORS client = pymongo.MongoClient( host, port, serverSelectionTimeoutMS=timeout_ms, **kwargs) try: @@ -339,7 +345,7 @@ class ClientContext(object): username=db_user, password=db_pwd, serverSelectionTimeoutMS=100, - **self.ssl_client_options) + **self.default_client_options) try: return db_user in _all_users(client.admin) diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 542e78f64..63ee695c3 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -123,7 +123,7 @@ class MockClient(MongoClient): kwargs['_pool_class'] = partial(MockPool, self) kwargs['_monitor_class'] = partial(MockMonitor, self) - client_options = client_context.ssl_client_options.copy() + client_options = client_context.default_client_options.copy() client_options.update(kwargs) super(MockClient, self).__init__(*args, **client_options) diff --git a/test/test_client.py b/test/test_client.py index e647a4788..df24737ac 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -18,6 +18,7 @@ import contextlib import datetime import gc import os +import platform import signal import socket import struct @@ -35,6 +36,7 @@ from bson.tz_util import utc from pymongo import auth, message from pymongo.common import _UUID_REPRESENTATIONS from pymongo.command_cursor import CommandCursor +from pymongo.compression_support import _HAVE_SNAPPY from pymongo.cursor import CursorType from pymongo.database import Database from pymongo.errors import (AutoReconnect, @@ -382,7 +384,7 @@ class TestClient(IntegrationTest): port are not overloaded. """ host, port = client_context.host, client_context.port - kwargs = client_context.ssl_client_options.copy() + kwargs = client_context.default_client_options.copy() if client_context.auth_enabled: kwargs['username'] = db_user kwargs['password'] = db_pwd @@ -1180,6 +1182,76 @@ class TestClient(IntegrationTest): self.assertIn('heartbeatFrequencyMS', str(context.exception)) + def test_compression(self): + def compression_settings(client): + pool_options = client._MongoClient__options.pool_options + return pool_options.compression_settings + + uri = "mongodb://localhost:27017/?compressors=zlib" + client = MongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ['zlib']) + uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4" + client = MongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ['zlib']) + self.assertEqual(opts.zlib_compression_level, 4) + uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-1" + client = MongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ['zlib']) + self.assertEqual(opts.zlib_compression_level, -1) + uri = "mongodb://localhost:27017" + client = MongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) + self.assertEqual(opts.zlib_compression_level, -1) + uri = "mongodb://localhost:27017/?compressors=foobar" + client = MongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) + self.assertEqual(opts.zlib_compression_level, -1) + uri = "mongodb://localhost:27017/?compressors=foobar,zlib" + client = MongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ['zlib']) + self.assertEqual(opts.zlib_compression_level, -1) + + # According to the connection string spec, unsupported values + # just raise a warning and are ignored. + uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=10" + client = MongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ['zlib']) + self.assertEqual(opts.zlib_compression_level, -1) + uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-2" + client = MongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ['zlib']) + self.assertEqual(opts.zlib_compression_level, -1) + + if not _HAVE_SNAPPY: + uri = "mongodb://localhost:27017/?compressors=snappy" + client = MongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) + else: + uri = "mongodb://localhost:27017/?compressors=snappy" + client = MongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ['snappy']) + uri = "mongodb://localhost:27017/?compressors=snappy,zlib" + client = MongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ['snappy', 'zlib']) + + options = client_context.default_client_options + if "compressors" in options and "zlib" in options["compressors"]: + for level in range(-1, 10): + client = single_client(zlibcompressionlevel=level) + # No error + client.pymongo_test.test.find_one() + class TestExhaustCursor(IntegrationTest): """Test that clients properly handle errors from exhaust cursors.""" diff --git a/test/test_dns.py b/test/test_dns.py index fa15fcbc8..2fec78d0f 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -32,7 +32,7 @@ from test.utils import wait_until _TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'dns') -_SSL_OPTS = client_context.ssl_client_options.copy() +_SSL_OPTS = client_context.default_client_options.copy() if client_context.ssl is True: # Our test certs don't support the SRV hosts used in these tests. _SSL_OPTS['ssl_match_hostname'] = False diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index f2e0e5c83..cd37db9ed 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -310,7 +310,7 @@ class TestReadPreferences(TestReadPreferencesBase): class ReadPrefTester(MongoClient): def __init__(self, *args, **kwargs): self.has_read_from = set() - client_options = client_context.ssl_client_options.copy() + client_options = client_context.default_client_options.copy() client_options.update(kwargs) super(ReadPrefTester, self).__init__(*args, **client_options) diff --git a/test/utils.py b/test/utils.py index aa13f2556..8571e0d4d 100644 --- a/test/utils.py +++ b/test/utils.py @@ -33,9 +33,7 @@ from pymongo.monitoring import _SENSITIVE_COMMANDS from pymongo.server_selectors import (any_server_selector, writable_server_selector) from pymongo.write_concern import WriteConcern -from test import (client_context, - db_user, - db_pwd) +from test import (client_context, db_user, db_pwd) IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=1000) @@ -134,8 +132,7 @@ def _connection_string(h, p, authenticate): def _mongo_client(host, port, authenticate=True, direct=False, **kwargs): """Create a new client over SSL/TLS if necessary.""" - client_options = client_context.ssl_client_options.copy() - + client_options = client_context.default_client_options.copy() if client_context.replica_set_name and not direct: client_options['replicaSet'] = client_context.replica_set_name client_options.update(kwargs)