PYTHON-1278, PYTHON-1553 - Support OP_COMPRESSED

This commit is contained in:
Bernie Hackett 2018-05-21 11:56:44 -07:00
parent 505b6ebc37
commit 335cb97a34
23 changed files with 934 additions and 286 deletions

View File

@ -415,7 +415,7 @@ functions:
if [ -n "${MONGODB_STARTED}" ]; then
export PYMONGO_MUST_CONNECT=1
fi
PYTHON_BINARY=${PYTHON_BINARY} GREEN_FRAMEWORK=${GREEN_FRAMEWORK} C_EXTENSIONS=${C_EXTENSIONS} COVERAGE=${COVERAGE} AUTH=${AUTH} SSL=${SSL} sh ${PROJECT_DIRECTORY}/.evergreen/run-tests.sh
PYTHON_BINARY=${PYTHON_BINARY} GREEN_FRAMEWORK=${GREEN_FRAMEWORK} C_EXTENSIONS=${C_EXTENSIONS} COVERAGE=${COVERAGE} COMPRESSORS=${COMPRESSORS} AUTH=${AUTH} SSL=${SSL} sh ${PROJECT_DIRECTORY}/.evergreen/run-tests.sh
"run enterprise auth tests":
- command: shell.exec
@ -968,6 +968,17 @@ axes:
variables:
AUTH: "noauth"
SSL: "nossl"
- id: compression
display_name: Compression
values:
- id: snappy
display_name: snappy compression
variables:
COMPRESSORS: "snappy"
- id: zlib
display_name: zlib compression
variables:
COMPRESSORS: "zlib"
- id: python-version
display_name: "Python"
values:
@ -1049,6 +1060,10 @@ axes:
display_name: "Without C Extensions"
variables:
C_EXTENSIONS: "--no_ext"
- id: "with-c-extensions"
display_name: "With C Extensions"
variables:
C_EXTENSIONS: ""
- id: storage-engine
display_name: Storage
values:
@ -1186,7 +1201,7 @@ buildvariants:
exclude_spec:
# These interpreters are always tested without extensions.
- python-version: ["pypy", "pypy3", "pypy3.5", "jython2.7"]
c-extensions: "*"
c-extensions: "without-c-extensions"
auth: "*"
ssl: "*"
coverage: "*"
@ -1200,6 +1215,35 @@ buildvariants:
- ".3.0"
- ".2.6"
- matrix_name: "tests-python-version-ubuntu16-compression"
matrix_spec: {"python-version": "*", "c-extensions": "*", "compression": "*"}
exclude_spec:
# These interpreters are always tested without extensions.
- python-version: ["pypy", "pypy3", "pypy3.5", "jython2.7"]
c-extensions: "with-c-extensions"
compression: "*"
- python-version: ["jython2.7"]
c-extensions: "*"
compression: "snappy"
batchtime: 10080 # 7 days
display_name: "${compression} ${c-extensions} ${python-version} (x86_64)"
# Ubuntu 16.04 images have libsnappy-dev installed
run_on: ubuntu1604-test
tasks:
- "test-latest-standalone"
- "test-3.6-standalone"
- matrix_name: "tests-python-version-py37-plus-compression"
matrix_spec: {"python-version-requires-openssl-102-plus": "*", "c-extensions": "*", "compression": "*"}
exclude_spec:
display_name: "${compression} ${c-extensions} ${python-version-requires-openssl-102-plus} (x86_64)"
batchtime: 10080 # 7 days
# Ubuntu 16.04 images have libsnappy-dev installed, and provides OpenSSL 1.0.2
run_on: ubuntu1604-test
tasks:
- "test-latest-standalone"
- "test-3.6-standalone"
- matrix_name: "tests-python-version-green-framework-rhel62"
matrix_spec: {"python-version": "*", "green-framework": "*", auth-ssl: "*"}
exclude_spec:

View File

@ -17,6 +17,11 @@ PYTHON_BINARY=${PYTHON_BINARY:-}
GREEN_FRAMEWORK=${GREEN_FRAMEWORK:-}
C_EXTENSIONS=${C_EXTENSIONS:-}
COVERAGE=${COVERAGE:-}
COMPRESSORS=${COMPRESSORS:-}
if [ -n $COMPRESSORS ]; then
export COMPRESSORS=$COMPRESSORS
fi
export JAVA_HOME=/opt/java/jdk8
@ -47,6 +52,13 @@ if [ -z "$PYTHON_BINARY" ]; then
PYTHON=python
trap "deactivate; rm -rf pymongotestvenv" EXIT HUP
fi
elif [ $COMPRESSORS = "snappy" ]; then
$PYTHON_BINARY -m virtualenv --system-site-packages --never-download snappytest
. snappytest/bin/activate
trap "deactivate; rm -rf snappytest" EXIT HUP
# 0.5.2 has issues in pypy3(.5)
pip install python-snappy==0.5.1
PYTHON=python
else
PYTHON="$PYTHON_BINARY"
fi

View File

@ -113,10 +113,15 @@ PyMongo::
$ python -m pip install pymongo[tls]
Wire protocol compression with snappy requires `python-snappy
<https://pypi.org/project/python-snappy>`_::
$ python -m pip install pymongo[snappy]
You can install all dependencies automatically with the following
command::
$ python -m pip install pymongo[gssapi,srv,tls]
$ python -m pip install pymongo[snappy,gssapi,srv,tls]
Other optional packages:

View File

@ -9,6 +9,8 @@ Version 3.7 adds support for MongoDB 4.0. Highlights include:
- Support for multi-document transactions, see :ref:`transactions-ref`.
- Support for the SCRAM-SHA-256 authentication mechanism.
- Support for Python 3.7.
- Support for wire protocol compression. See
:meth:`~pymongo.mongo_client.MongoClient` for details.
- MD5 is now optional in GridFS.
- If not specified, the authSource for the PLAIN authentication mechanism
defaults to $external.

View File

@ -70,10 +70,15 @@ PyMongo::
$ python -m pip install pymongo[tls]
Wire protocol compression with snappy requires `python-snappy
<https://pypi.org/project/python-snappy>`_::
$ python -m pip install pymongo[snappy]
You can install all dependencies automatically with the following
command::
$ python -m pip install pymongo[gssapi,srv,tls]
$ python -m pip install pymongo[snappy,gssapi,srv,tls]
Other optional packages:

View File

@ -142,19 +142,25 @@ static int add_last_error(PyObject* self, buffer_t buffer,
}
static int init_insert_buffer(buffer_t buffer, int request_id, int options,
const char* coll_name, int coll_name_len) {
/* Save space for message length */
int length_location = buffer_save_space(buffer, 4);
if (length_location == -1) {
PyErr_NoMemory();
return length_location;
const char* coll_name, int coll_name_len,
int compress) {
int length_location = 0;
if (!compress) {
/* Save space for message length */
int length_location = buffer_save_space(buffer, 4);
if (length_location == -1) {
PyErr_NoMemory();
return length_location;
}
if (!buffer_write_int32(buffer, (int32_t)request_id) ||
!buffer_write_bytes(buffer,
"\x00\x00\x00\x00"
"\xd2\x07\x00\x00",
8)) {
return -1;
}
}
if (!buffer_write_int32(buffer, (int32_t)request_id) ||
!buffer_write_bytes(buffer,
"\x00\x00\x00\x00"
"\xd2\x07\x00\x00",
8) ||
!buffer_write_int32(buffer, (int32_t)options) ||
if (!buffer_write_int32(buffer, (int32_t)options) ||
!buffer_write_bytes(buffer,
coll_name,
coll_name_len + 1)) {
@ -211,7 +217,8 @@ static PyObject* _cbson_insert_message(PyObject* self, PyObject* args) {
request_id,
flags,
collection_name,
collection_name_length);
collection_name_length,
0);
if (length_location == -1) {
destroy_codec_options(&options);
PyMem_Free(collection_name);
@ -660,7 +667,7 @@ static PyObject*
_send_insert(PyObject* self, PyObject* ctx,
PyObject* gle_args, buffer_t buffer,
char* coll_name, int coll_len, int request_id, int safe,
codec_options_t* options, PyObject* to_publish) {
codec_options_t* options, PyObject* to_publish, int compress) {
if (safe) {
if (!add_last_error(self, buffer, request_id,
@ -669,16 +676,16 @@ _send_insert(PyObject* self, PyObject* ctx,
}
}
/* The max_doc_size parameter for legacy_write is the max size of any
* document in buffer. We enforced max size already, pass 0 here. */
return PyObject_CallMethod(ctx, "legacy_write",
"i" BYTES_FORMAT_STRING "iNO",
/* The max_doc_size parameter for legacy_bulk_insert is the max size of
* any document in buffer. We enforced max size already, pass 0 here. */
return PyObject_CallMethod(ctx, "legacy_bulk_insert",
"i" BYTES_FORMAT_STRING "iNOi",
request_id,
buffer_get_buffer(buffer),
buffer_get_position(buffer),
0,
PyBool_FromLong((long)safe),
to_publish);
to_publish, compress);
}
static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) {
@ -689,6 +696,7 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) {
int send_safe, flags = 0;
int length_location, message_length;
int collection_name_length;
int compress;
char* collection_name = NULL;
PyObject* docs;
PyObject* doc;
@ -698,6 +706,7 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) {
PyObject* result;
PyObject* max_bson_size_obj;
PyObject* max_message_size_obj;
PyObject* compress_obj;
PyObject* to_publish = NULL;
unsigned char check_keys;
unsigned char safe;
@ -754,6 +763,17 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) {
return NULL;
}
compress_obj = PyObject_GetAttrString(ctx, "compress");
compress = PyObject_IsTrue(compress_obj);
Py_XDECREF(compress_obj);
if (compress == -1) {
destroy_codec_options(&options);
PyMem_Free(collection_name);
return NULL;
}
compress = compress && !(safe || send_safe);
buffer = buffer_new();
if (!buffer) {
destroy_codec_options(&options);
@ -766,7 +786,8 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) {
request_id,
flags,
collection_name,
collection_name_length);
collection_name_length,
compress);
if (length_location == -1) {
goto insertfail;
}
@ -797,13 +818,15 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) {
/* If we've encoded anything send it before raising. */
if (!empty) {
buffer_update_position(buffer, before);
message_length = buffer_get_position(buffer) - length_location;
buffer_write_int32_at_position(
buffer, length_location, (int32_t)message_length);
if (!compress) {
message_length = buffer_get_position(buffer) - length_location;
buffer_write_int32_at_position(
buffer, length_location, (int32_t)message_length);
}
result = _send_insert(self, ctx, last_error_args, buffer,
collection_name, collection_name_length,
request_id, send_safe, &options,
to_publish);
to_publish, compress);
if (!result)
goto iterfail;
Py_DECREF(result);
@ -826,7 +849,8 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) {
new_request_id,
flags,
collection_name,
collection_name_length);
collection_name_length,
compress);
if (message_start == -1) {
buffer_free(new_buffer);
goto iterfail;
@ -841,13 +865,16 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) {
/* Roll back to the beginning of this document. */
buffer_update_position(buffer, before);
message_length = buffer_get_position(buffer) - length_location;
buffer_write_int32_at_position(
buffer, length_location, (int32_t)message_length);
if (!compress) {
message_length = buffer_get_position(buffer) - length_location;
buffer_write_int32_at_position(
buffer, length_location, (int32_t)message_length);
}
result = _send_insert(self, ctx, last_error_args, buffer,
collection_name, collection_name_length,
request_id, send_safe, &options, to_publish);
request_id, send_safe, &options, to_publish,
compress);
buffer_free(buffer);
buffer = new_buffer;
@ -927,14 +954,16 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) {
goto insertfail;
}
message_length = buffer_get_position(buffer) - length_location;
buffer_write_int32_at_position(
buffer, length_location, (int32_t)message_length);
if (!compress) {
message_length = buffer_get_position(buffer) - length_location;
buffer_write_int32_at_position(
buffer, length_location, (int32_t)message_length);
}
/* Send the last (or only) batch */
result = _send_insert(self, ctx, last_error_args, buffer,
collection_name, collection_name_length,
request_id, safe, &options, to_publish);
request_id, safe, &options, to_publish, compress);
Py_DECREF(to_publish);
PyMem_Free(collection_name);
@ -971,43 +1000,16 @@ insertfail:
return NULL;
}
static buffer_t
_command_buffer_new(char* ns, int ns_len) {
buffer_t buffer;
if (!(buffer = buffer_new())) {
PyErr_NoMemory();
return NULL;
}
/* Save space for message length and request id */
if ((buffer_save_space(buffer, 8)) == -1) {
PyErr_NoMemory();
buffer_free(buffer);
return NULL;
}
if (!buffer_write_bytes(buffer,
"\x00\x00\x00\x00" /* responseTo */
"\xd4\x07\x00\x00" /* opcode */
"\x00\x00\x00\x00", /* flags */
12) ||
!buffer_write_bytes(buffer,
ns, ns_len + 1) || /* namespace */
!buffer_write_bytes(buffer,
"\x00\x00\x00\x00" /* skip */
"\xFF\xFF\xFF\xFF", /* limit (-1) */
8)) {
buffer_free(buffer);
return NULL;
}
return buffer;
}
#define _INSERT 0
#define _UPDATE 1
#define _DELETE 2
static PyObject*
_cbson_do_batched_write_command(PyObject* self, PyObject* args) {
struct module_state *state = GETSTATE(self);
static int
_batched_write_command(
char* ns, int ns_len, unsigned char op, int check_keys,
PyObject* command, PyObject* docs, PyObject* ctx,
PyObject* to_publish, codec_options_t options,
buffer_t buffer, struct module_state *state) {
long max_bson_size;
long max_cmd_size;
@ -1015,31 +1017,12 @@ _cbson_do_batched_write_command(PyObject* self, PyObject* args) {
int idx = 0;
int cmd_len_loc;
int lst_len_loc;
int ns_len;
int request_id;
int position;
int length;
char *ns = NULL;
PyObject* max_bson_size_obj;
PyObject* max_write_batch_size_obj;
PyObject* command;
PyObject* doc;
PyObject* docs;
PyObject* ctx;
PyObject* iterator;
PyObject* result;
PyObject* to_publish = NULL;
unsigned char op;
unsigned char check_keys;
codec_options_t options;
buffer_t buffer;
if (!PyArg_ParseTuple(args, "et#bOObO&O", "utf-8",
&ns, &ns_len, &op, &command, &docs, &check_keys,
convert_codec_options, &options,
&ctx)) {
return NULL;
}
max_bson_size_obj = PyObject_GetAttrString(ctx, "max_bson_size");
#if PY_MAJOR_VERSION >= 3
@ -1049,9 +1032,7 @@ _cbson_do_batched_write_command(PyObject* self, PyObject* args) {
#endif
Py_XDECREF(max_bson_size_obj);
if (max_bson_size == -1) {
destroy_codec_options(&options);
PyMem_Free(ns);
return NULL;
return 0;
}
/*
* Max BSON object size + 16k - 2 bytes for ending NUL bytes
@ -1067,31 +1048,26 @@ _cbson_do_batched_write_command(PyObject* self, PyObject* args) {
#endif
Py_XDECREF(max_write_batch_size_obj);
if (max_write_batch_size == -1) {
destroy_codec_options(&options);
PyMem_Free(ns);
return NULL;
return 0;
}
if (!(to_publish = PyList_New(0))) {
destroy_codec_options(&options);
PyMem_Free(ns);
return NULL;
if (!buffer_write_bytes(buffer,
"\x00\x00\x00\x00", /* flags */
4) ||
!buffer_write_bytes(buffer,
ns, ns_len + 1) || /* namespace */
!buffer_write_bytes(buffer,
"\x00\x00\x00\x00" /* skip */
"\xFF\xFF\xFF\xFF", /* limit (-1) */
8)) {
return 0;
}
if (!(buffer = _command_buffer_new(ns, ns_len))) {
destroy_codec_options(&options);
PyMem_Free(ns);
Py_DECREF(to_publish);
return NULL;
}
PyMem_Free(ns);
/* Position of command document length */
cmd_len_loc = buffer_get_position(buffer);
if (!write_dict(state->_cbson, buffer, command, 0,
&options, 0)) {
goto cmdfail;
return 0;
}
/* Write type byte for array */
@ -1127,7 +1103,7 @@ _cbson_do_batched_write_command(PyObject* self, PyObject* args) {
PyErr_SetString(InvalidOperation, "Unknown command");
Py_DECREF(InvalidOperation);
}
goto cmdfail;
return 0;
}
}
@ -1135,7 +1111,7 @@ _cbson_do_batched_write_command(PyObject* self, PyObject* args) {
lst_len_loc = buffer_save_space(buffer, 4);
if (lst_len_loc == -1) {
PyErr_NoMemory();
goto cmdfail;
return 0;
}
iterator = PyObject_GetIter(docs);
@ -1145,7 +1121,7 @@ _cbson_do_batched_write_command(PyObject* self, PyObject* args) {
PyErr_SetString(InvalidOperation, "input is not iterable");
Py_DECREF(InvalidOperation);
}
goto cmdfail;
return 0;
}
while ((doc = PyIter_Next(iterator)) != NULL) {
int sub_doc_begin = buffer_get_position(buffer);
@ -1214,32 +1190,151 @@ _cbson_do_batched_write_command(PyObject* self, PyObject* args) {
goto cmdfail;
request_id = rand();
position = buffer_get_position(buffer);
length = position - lst_len_loc - 1;
buffer_write_int32_at_position(buffer, lst_len_loc, (int32_t)length);
length = position - cmd_len_loc;
buffer_write_int32_at_position(buffer, cmd_len_loc, (int32_t)length);
return 1;
cmditerfail:
Py_XDECREF(doc);
Py_DECREF(iterator);
cmdfail:
return 0;
}
static PyObject*
_cbson_encode_batched_write_command(PyObject* self, PyObject* args) {
char *ns = NULL;
unsigned char op;
unsigned char check_keys;
int ns_len;
PyObject* command;
PyObject* docs;
PyObject* ctx = NULL;
PyObject* to_publish = NULL;
PyObject* result = NULL;
codec_options_t options;
struct module_state *state = GETSTATE(self);
if (!PyArg_ParseTuple(args, "et#bOObO&O", "utf-8",
&ns, &ns_len, &op, &command, &docs, &check_keys,
convert_codec_options, &options,
&ctx)) {
return NULL;
}
buffer_t buffer;
if (!(buffer = buffer_new())) {
PyErr_NoMemory();
PyMem_Free(ns);
destroy_codec_options(&options);
return NULL;
}
if (!(to_publish = PyList_New(0))) {
goto fail;
}
if (!_batched_write_command(
ns,
ns_len,
op,
check_keys,
command,
docs,
ctx,
to_publish,
options,
buffer,
state)) {
goto fail;
}
result = Py_BuildValue(BYTES_FORMAT_STRING "O",
buffer_get_buffer(buffer),
buffer_get_position(buffer),
to_publish);
fail:
PyMem_Free(ns);
destroy_codec_options(&options);
buffer_free(buffer);
Py_XDECREF(to_publish);
return result;
}
static PyObject*
_cbson_do_batched_write_command(PyObject* self, PyObject* args) {
char *ns = NULL;
unsigned char op;
unsigned char check_keys;
int ns_len;
int request_id;
int position;
PyObject* command;
PyObject* docs;
PyObject* ctx = NULL;
PyObject* to_publish = NULL;
PyObject* result = NULL;
codec_options_t options;
struct module_state *state = GETSTATE(self);
if (!PyArg_ParseTuple(args, "et#bOObO&O", "utf-8",
&ns, &ns_len, &op, &command, &docs, &check_keys,
convert_codec_options, &options,
&ctx)) {
return NULL;
}
buffer_t buffer;
if (!(buffer = buffer_new())) {
PyErr_NoMemory();
PyMem_Free(ns);
destroy_codec_options(&options);
return NULL;
}
/* Save space for message length and request id */
if ((buffer_save_space(buffer, 8)) == -1) {
PyErr_NoMemory();
goto fail;
}
if (!buffer_write_bytes(buffer,
"\x00\x00\x00\x00" /* responseTo */
"\xd4\x07\x00\x00", /* opcode */
8)) {
goto fail;
}
if (!(to_publish = PyList_New(0))) {
goto fail;
}
if (!_batched_write_command(
ns,
ns_len,
op,
check_keys,
command,
docs,
ctx,
to_publish,
options,
buffer,
state)) {
goto fail;
}
request_id = rand();
position = buffer_get_position(buffer);
buffer_write_int32_at_position(buffer, 0, (int32_t)position);
buffer_write_int32_at_position(buffer, 4, (int32_t)request_id);
result = Py_BuildValue("i" BYTES_FORMAT_STRING "O", request_id,
buffer_get_buffer(buffer),
buffer_get_position(buffer),
to_publish);
Py_DECREF(to_publish);
fail:
PyMem_Free(ns);
destroy_codec_options(&options);
buffer_free(buffer);
destroy_codec_options(&options);
return result;
cmditerfail:
Py_XDECREF(doc);
Py_DECREF(iterator);
cmdfail:
destroy_codec_options(&options);
Py_XDECREF(to_publish);
buffer_free(buffer);
return NULL;
return result;
}
static PyMethodDef _CMessageMethods[] = {
@ -1255,6 +1350,8 @@ static PyMethodDef _CMessageMethods[] = {
"insert a batch of documents, splitting the batch as needed"},
{"_do_batched_write_command", _cbson_do_batched_write_command, METH_VARARGS,
"Create the next batched insert, update, or delete command"},
{"_encode_batched_write_command", _cbson_encode_batched_write_command, METH_VARARGS,
"Encode the next batched insert, update, or delete command"},
{NULL, NULL, 0, NULL}
};

View File

@ -35,6 +35,7 @@ from pymongo.errors import (BulkWriteError,
from pymongo.message import (_INSERT, _UPDATE, _DELETE,
_do_batched_insert,
_do_batched_write_command,
_do_batched_write_command_compressed,
_randint,
_BulkWriteContext)
from pymongo.read_preferences import ReadPreference
@ -251,6 +252,11 @@ class _Bulk(object):
self.current_run = next(generator)
run = self.current_run
if sock_info.compression_context:
do_writes = _do_batched_write_command_compressed
else:
do_writes = _do_batched_write_command
# sock_info.command validates the session, but we use
# sock_info.write_command.
sock_info.validate_session(client, session)
@ -272,7 +278,7 @@ class _Bulk(object):
check_keys = run.op_type == _INSERT
ops = islice(run.ops, run.idx_offset, None)
# Run as many ops as possible.
request_id, msg, to_send = _do_batched_write_command(
request_id, msg, to_send = do_writes(
self.namespace, run.op_type, cmd, ops, check_keys,
self.collection.codec_options, bwc)
if not to_send:

View File

@ -18,6 +18,7 @@ from bson.codec_options import _parse_codec_options
from pymongo.auth import _build_credentials_tuple
from pymongo.common import validate_boolean
from pymongo import common
from pymongo.compression_support import CompressionSettings
from pymongo.errors import ConfigurationError
from pymongo.monitoring import _EventListeners
from pymongo.pool import PoolOptions
@ -120,6 +121,9 @@ def _parse_pool_options(options):
wait_queue_multiple = options.get('waitqueuemultiple')
event_listeners = options.get('event_listeners')
appname = options.get('appname')
compression_settings = CompressionSettings(
options.get('compressors', []),
options.get('zlibcompressionlevel', -1))
ssl_context, ssl_match_hostname = _parse_ssl_options(options)
return PoolOptions(max_pool_size,
min_pool_size,
@ -128,7 +132,8 @@ def _parse_pool_options(options):
wait_queue_timeout, wait_queue_multiple,
ssl_context, ssl_match_hostname, socket_keepalive,
_EventListeners(event_listeners),
appname)
appname,
compression_settings)
class ClientOptions(object):

View File

@ -504,6 +504,7 @@ class Collection(common.BaseObject):
if publish:
start = datetime.datetime.now()
args = args + (sock_info.compression_context,)
rqst_id, msg, max_size = func(*args)
if publish:
duration = datetime.datetime.now() - start

View File

@ -26,6 +26,8 @@ from bson.codec_options import CodecOptions
from bson.py3compat import abc, integer_types, iteritems, string_type
from bson.raw_bson import RawBSONDocument
from pymongo.auth import MECHANISMS
from pymongo.compression_support import (validate_compressors,
validate_zlib_compression_level)
from pymongo.errors import ConfigurationError
from pymongo.monitoring import _validate_event_listeners
from pymongo.read_concern import ReadConcern
@ -539,6 +541,8 @@ URI_VALIDATORS = {
'appname': validate_appname_or_none,
'unicode_decode_error_handler': validate_unicode_decode_error_handler,
'retrywrites': validate_boolean_or_string,
'compressors': validate_compressors,
'zlibcompressionlevel': validate_zlib_compression_level
}
TIMEOUT_VALIDATORS = {

View File

@ -0,0 +1,124 @@
# Copyright 2018 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
try:
import snappy
_HAVE_SNAPPY = True
except ImportError:
# python-snappy isn't available.
_HAVE_SNAPPY = False
try:
import zlib
_HAVE_ZLIB = True
except ImportError:
# Python built without zlib support.
_HAVE_ZLIB = False
from pymongo.monitoring import _SENSITIVE_COMMANDS
_SUPPORTED_COMPRESSORS = set(["snappy", "zlib"])
_NO_COMPRESSION = set(['ismaster'])
_NO_COMPRESSION.update(_SENSITIVE_COMMANDS)
def validate_compressors(dummy, value):
compressors = value.split(",")
for compressor in compressors[:]:
if compressor not in _SUPPORTED_COMPRESSORS:
compressors.remove(compressor)
warnings.warn("Unsupported compressor: %s" % (compressor,))
elif compressor == "snappy" and not _HAVE_SNAPPY:
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with snappy is not available. "
"You must install the python-snappy module for snappy support.")
elif compressor == "zlib" and not _HAVE_ZLIB:
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with zlib is not available. "
"The zlib module is not available.")
return compressors
def validate_zlib_compression_level(option, value):
try:
level = int(value)
except:
raise TypeError("%s must be an integer, not %r." % (option, value))
if level < -1 or level > 9:
raise ValueError(
"%s must be between -1 and 9, not %d." % (option, level))
return level
class CompressionSettings(object):
def __init__(self, compressors, zlib_compression_level):
self.compressors = compressors
self.zlib_compression_level = zlib_compression_level
def get_compression_context(self, compressors):
if compressors:
chosen = compressors[0]
if chosen == "snappy":
return SnappyContext()
elif chosen == "zlib":
return ZlibContext(self.zlib_compression_level)
def _zlib_no_compress(data):
"""Compress data with zlib level 0."""
cobj = zlib.compressobj(0)
return b"".join([cobj.compress(data), cobj.flush()])
class SnappyContext(object):
compressor_id = 1
@staticmethod
def compress(data):
return snappy.compress(data)
class ZlibContext(object):
compressor_id = 2
def __init__(self, level):
# Jython zlib.compress doesn't support -1
if level == -1:
self.compress = zlib.compress
# Jython zlib.compress also doesn't support 0
elif level == 0:
self.compress = _zlib_no_compress
else:
self.compress = lambda data: zlib.compress(data, level)
def decompress(data, compressor_id):
if compressor_id == SnappyContext.compressor_id:
# python-snappy doesn't support the buffer interface.
# https://github.com/andrix/python-snappy/issues/65
# This only matters when data is a memoryview since
# id(bytes(data)) == id(data) when data is a bytes.
# NOTE: bytes(memoryview) returns the memoryview repr
# in Python 2.7. The right thing to do in 2.7 is call
# memoryview.tobytes(), but we currently only use
# memoryview in Python 3.x.
return snappy.uncompress(bytes(data))
elif compressor_id == ZlibContext.compressor_id:
return zlib.decompress(data)
else:
raise ValueError("Unknown compressorId %d" % (compressor_id,))

View File

@ -152,3 +152,7 @@ class IsMaster(object):
@property
def last_write_date(self):
return self._doc.get('lastWrite', {}).get('lastWriteDate')
@property
def compressors(self):
return self._doc.get('compression')

View File

@ -25,7 +25,7 @@ import random
import struct
import bson
from bson import CodecOptions
from bson import CodecOptions, _make_c_string, _dict_to_bson
from bson.codec_options import DEFAULT_CODEC_OPTIONS
from bson.py3compat import b, StringIO
from bson.son import SON
@ -326,7 +326,8 @@ class _Query(object):
self.read_preference)
return query(flags, ns, self.ntoskip, ntoreturn,
spec, None if use_cmd else self.fields, self.codec_options)
spec, None if use_cmd else self.fields,
self.codec_options, ctx=sock_info.compression_context)
class _GetMore(object):
@ -375,14 +376,15 @@ class _GetMore(object):
"""Get a getmore message."""
ns = _UJOIN % (self.db, self.coll)
ctx = sock_info.compression_context
if use_cmd:
ns = _UJOIN % (self.db, "$cmd")
spec = self.as_command(sock_info)[0]
return query(0, ns, 0, -1, spec, None, self.codec_options)
return query(0, ns, 0, -1, spec, None, self.codec_options, ctx=ctx)
return get_more(ns, self.ntoreturn, self.cursor_id)
return get_more(ns, self.ntoreturn, self.cursor_id, ctx)
class _RawBatchQuery(_Query):
@ -436,6 +438,25 @@ class _CursorAddress(tuple):
return not self == other
_pack_compression_header = struct.Struct("<iiiiiiB").pack
_COMPRESSION_HEADER_SIZE = 25
def _compress(operation, data, ctx):
"""Takes message data, compresses it, and adds an OP_COMPRESSED header."""
compressed = ctx.compress(data)
request_id = _randint()
header = _pack_compression_header(
_COMPRESSION_HEADER_SIZE + len(compressed), # Total message length
request_id, # Request id
0, # responseTo
2012, # operation id
operation, # original operation id
len(data), # uncompressed message length
ctx.compressor_id) # compressor id
return request_id, header + compressed
def __last_error(namespace, args):
"""Data to send to do a lastError.
"""
@ -446,119 +467,265 @@ def __last_error(namespace, args):
None, DEFAULT_CODEC_OPTIONS)
_pack_header = struct.Struct("<iiii").pack
def __pack_message(operation, data):
"""Takes message data and adds a message header based on the operation.
Returns the resultant message string.
"""
request_id = _randint()
message = struct.pack("<i", 16 + len(data))
message += struct.pack("<i", request_id)
message += _ZERO_32 # responseTo
message += struct.pack("<i", operation)
return (request_id, message + data)
rid = _randint()
message = _pack_header(16 + len(data), rid, 0, operation)
return rid, message + data
_pack_int = struct.Struct("<i").pack
def _insert(collection_name, docs, check_keys, flags, opts):
"""Get an OP_INSERT message"""
encode = _dict_to_bson # Make local. Uses extensions.
if len(docs) == 1:
encoded = encode(docs[0], check_keys, opts)
return b"".join([
b"\x00\x00\x00\x00", # Flags don't matter for one doc.
_make_c_string(collection_name),
encoded]), len(encoded)
encoded = [encode(doc, check_keys, opts) for doc in docs]
if not encoded:
raise InvalidOperation("cannot do an empty bulk insert")
return b"".join([
_pack_int(flags),
_make_c_string(collection_name),
b"".join(encoded)]), max(map(len, encoded))
def _insert_compressed(
collection_name, docs, check_keys, continue_on_error, opts, ctx):
"""Internal compressed unacknowledged insert message helper."""
op_insert, max_bson_size = _insert(
collection_name, docs, check_keys, continue_on_error, opts)
rid, msg = _compress(2002, op_insert, ctx)
return rid, msg, max_bson_size
def _insert_uncompressed(collection_name, docs, check_keys,
safe, last_error_args, continue_on_error, opts):
"""Internal insert message helper."""
op_insert, max_bson_size = _insert(
collection_name, docs, check_keys, continue_on_error, opts)
rid, msg = __pack_message(2002, op_insert)
if safe:
rid, gle, _ = __last_error(collection_name, last_error_args)
return rid, msg + gle, max_bson_size
return rid, msg, max_bson_size
if _use_c:
_insert_uncompressed = _cmessage._insert_message
def insert(collection_name, docs, check_keys,
safe, last_error_args, continue_on_error, opts):
safe, last_error_args, continue_on_error, opts, ctx=None):
"""Get an **insert** message."""
options = 0
if continue_on_error:
options += 1
data = struct.pack("<i", options)
data += bson._make_c_string(collection_name)
encoded = [bson.BSON.encode(doc, check_keys, opts) for doc in docs]
if not encoded:
raise InvalidOperation("cannot do an empty bulk insert")
max_bson_size = max(map(len, encoded))
data += _EMPTY.join(encoded)
if safe:
(_, insert_message) = __pack_message(2002, data)
(request_id, error_message, _) = __last_error(collection_name,
last_error_args)
return (request_id, insert_message + error_message, max_bson_size)
else:
(request_id, insert_message) = __pack_message(2002, data)
return (request_id, insert_message, max_bson_size)
if _use_c:
insert = _cmessage._insert_message
if ctx:
return _insert_compressed(
collection_name, docs, check_keys, continue_on_error, opts, ctx)
return _insert_uncompressed(collection_name, docs, check_keys, safe,
last_error_args, continue_on_error, opts)
def update(collection_name, upsert, multi,
spec, doc, safe, last_error_args, check_keys, opts):
"""Get an **update** message.
"""
options = 0
def _update(collection_name, upsert, multi, spec, doc, check_keys, opts):
"""Get an OP_UPDATE message."""
flags = 0
if upsert:
options += 1
flags += 1
if multi:
options += 2
flags += 2
encode = _dict_to_bson # Make local. Uses extensions.
encoded_update = encode(doc, check_keys, opts)
return b"".join([
_ZERO_32,
_make_c_string(collection_name),
_pack_int(flags),
encode(spec, False, opts),
encoded_update]), len(encoded_update)
data = _ZERO_32
data += bson._make_c_string(collection_name)
data += struct.pack("<i", options)
data += bson.BSON.encode(spec, False, opts)
encoded = bson.BSON.encode(doc, check_keys, opts)
data += encoded
def _update_compressed(
collection_name, upsert, multi, spec, doc, check_keys, opts, ctx):
"""Internal compressed unacknowledged update message helper."""
op_update, max_bson_size = _update(
collection_name, upsert, multi, spec, doc, check_keys, opts)
rid, msg = _compress(2001, op_update, ctx)
return rid, msg, max_bson_size
def _update_uncompressed(collection_name, upsert, multi, spec,
doc, safe, last_error_args, check_keys, opts):
"""Internal update message helper."""
op_update, max_bson_size = _update(
collection_name, upsert, multi, spec, doc, check_keys, opts)
rid, msg = __pack_message(2001, op_update)
if safe:
(_, update_message) = __pack_message(2001, data)
(request_id, error_message, _) = __last_error(collection_name,
last_error_args)
return (request_id, update_message + error_message, len(encoded))
else:
(request_id, update_message) = __pack_message(2001, data)
return (request_id, update_message, len(encoded))
rid, gle, _ = __last_error(collection_name, last_error_args)
return rid, msg + gle, max_bson_size
return rid, msg, max_bson_size
if _use_c:
update = _cmessage._update_message
_update_uncompressed = _cmessage._update_message
def query(options, collection_name, num_to_skip,
num_to_return, query, field_selector, opts, check_keys=False):
"""Get a **query** message.
"""
data = struct.pack("<I", options)
data += bson._make_c_string(collection_name)
data += struct.pack("<i", num_to_skip)
data += struct.pack("<i", num_to_return)
if check_keys:
def update(collection_name, upsert, multi, spec,
doc, safe, last_error_args, check_keys, opts, ctx=None):
"""Get an **update** message."""
if ctx:
return _update_compressed(
collection_name, upsert, multi, spec, doc, check_keys, opts, ctx)
return _update_uncompressed(collection_name, upsert, multi, spec,
doc, safe, last_error_args, check_keys, opts)
def _query(options, collection_name, num_to_skip,
num_to_return, query, field_selector, opts, check_keys):
"""Get an OP_QUERY message."""
encode = _dict_to_bson # Make local. Uses extensions.
if check_keys and "$clusterTime" in query:
# Temporarily remove $clusterTime to avoid an error from the $-prefix.
cluster_time = query.pop('$clusterTime', None)
encoded = bson.BSON.encode(query, True, opts)
if cluster_time is not None:
extra = bson._name_value_to_bson(
b"$clusterTime\x00", cluster_time, False, opts)
encoded = (
bson._PACK_INT(len(encoded) + len(extra))
+ encoded[4:-1] + extra + b'\x00')
query['$clusterTime'] = cluster_time
cluster_time = query.pop('$clusterTime')
encoded = encode(query, True, opts)
extra = bson._name_value_to_bson(
b"$clusterTime\x00", cluster_time, False, opts)
encoded = (
_pack_int(len(encoded) + len(extra))
+ encoded[4:-1] + extra + b'\x00')
query['$clusterTime'] = cluster_time
else:
encoded = bson.BSON.encode(query, False, opts)
data += encoded
max_bson_size = len(encoded)
if field_selector is not None:
encoded = bson.BSON.encode(field_selector, False, opts)
data += encoded
max_bson_size = max(len(encoded), max_bson_size)
(request_id, query_message) = __pack_message(2004, data)
return (request_id, query_message, max_bson_size)
encoded = encode(query, check_keys, opts)
if field_selector:
efs = encode(field_selector, False, opts)
else:
efs = b""
max_bson_size = max(len(encoded), len(efs))
return b"".join([
_pack_int(options),
_make_c_string(collection_name),
_pack_int(num_to_skip),
_pack_int(num_to_return),
encoded,
efs]), max_bson_size
def _query_compressed(options, collection_name, num_to_skip,
num_to_return, query, field_selector,
opts, check_keys=False, ctx=None):
"""Internal compressed query message helper."""
op_query, max_bson_size = _query(
options,
collection_name,
num_to_skip,
num_to_return,
query,
field_selector,
opts,
check_keys)
rid, msg = _compress(2004, op_query, ctx)
return rid, msg, max_bson_size
def _query_uncompressed(options, collection_name, num_to_skip,
num_to_return, query, field_selector, opts, check_keys=False):
"""Internal query message helper."""
op_query, max_bson_size = _query(
options,
collection_name,
num_to_skip,
num_to_return,
query,
field_selector,
opts,
check_keys)
rid, msg = __pack_message(2004, op_query)
return rid, msg, max_bson_size
if _use_c:
query = _cmessage._query_message
_query_uncompressed = _cmessage._query_message
def get_more(collection_name, num_to_return, cursor_id):
"""Get a **getMore** message.
"""
data = _ZERO_32
data += bson._make_c_string(collection_name)
data += struct.pack("<i", num_to_return)
data += struct.pack("<q", cursor_id)
return __pack_message(2005, data)
def query(options, collection_name, num_to_skip, num_to_return,
query, field_selector, opts, check_keys=False, ctx=None):
"""Get a **query** message."""
if ctx:
return _query_compressed(options, collection_name, num_to_skip,
num_to_return, query, field_selector,
opts, check_keys, ctx)
return _query_uncompressed(options, collection_name, num_to_skip,
num_to_return, query, field_selector, opts,
check_keys)
_pack_long_long = struct.Struct("<q").pack
def _get_more(collection_name, num_to_return, cursor_id):
"""Get an OP_GET_MORE message."""
return b"".join([
_ZERO_32,
_make_c_string(collection_name),
_pack_int(num_to_return),
_pack_long_long(cursor_id)])
def _get_more_compressed(collection_name, num_to_return, cursor_id, ctx):
"""Internal compressed getMore message helper."""
return _compress(
2005, _get_more(collection_name, num_to_return, cursor_id), ctx)
def _get_more_uncompressed(collection_name, num_to_return, cursor_id):
"""Internal getMore message helper."""
return __pack_message(
2005, _get_more(collection_name, num_to_return, cursor_id))
if _use_c:
get_more = _cmessage._get_more_message
_get_more_uncompressed = _cmessage._get_more_message
def delete(collection_name, spec, safe,
last_error_args, opts, flags=0):
def get_more(collection_name, num_to_return, cursor_id, ctx=None):
"""Get a **getMore** message."""
if ctx:
return _get_more_compressed(
collection_name, num_to_return, cursor_id, ctx)
return _get_more_uncompressed(collection_name, num_to_return, cursor_id)
def _delete(collection_name, spec, opts, flags):
"""Get an OP_DELETE message."""
encoded = _dict_to_bson(spec, False, opts) # Uses extensions.
return b"".join([
_ZERO_32,
_make_c_string(collection_name),
_pack_int(flags),
encoded]), len(encoded)
def _delete_compressed(collection_name, spec, opts, flags, ctx):
"""Internal compressed unacknowledged delete message helper."""
op_delete, max_bson_size = _delete(collection_name, spec, opts, flags)
rid, msg = _compress(2006, op_delete, ctx)
return rid, msg, max_bson_size
def _delete_uncompressed(
collection_name, spec, safe, last_error_args, opts, flags=0):
"""Internal delete message helper."""
op_delete, max_bson_size = _delete(collection_name, spec, opts, flags)
rid, msg = __pack_message(2006, op_delete)
if safe:
rid, gle, _ = __last_error(collection_name, last_error_args)
return rid, msg + gle, max_bson_size
return rid, msg, max_bson_size
def delete(
collection_name, spec, safe, last_error_args, opts, flags=0, ctx=None):
"""Get a **delete** message.
`opts` is a CodecOptions. `flags` is a bit vector that may contain
@ -566,29 +733,19 @@ def delete(collection_name, spec, safe,
http://docs.mongodb.org/meta-driver/latest/legacy/mongodb-wire-protocol/#op-delete
"""
data = _ZERO_32
data += bson._make_c_string(collection_name)
data += struct.pack("<I", flags)
encoded = bson.BSON.encode(spec, False, opts)
data += encoded
if safe:
(_, remove_message) = __pack_message(2006, data)
(request_id, error_message, _) = __last_error(collection_name,
last_error_args)
return (request_id, remove_message + error_message, len(encoded))
else:
(request_id, remove_message) = __pack_message(2006, data)
return (request_id, remove_message, len(encoded))
if ctx:
return _delete_compressed(collection_name, spec, opts, flags, ctx)
return _delete_uncompressed(
collection_name, spec, safe, last_error_args, opts, flags)
def kill_cursors(cursor_ids):
"""Get a **killCursors** message.
"""
data = _ZERO_32
data += struct.pack("<i", len(cursor_ids))
for cursor_id in cursor_ids:
data += struct.pack("<q", cursor_id)
return __pack_message(2007, data)
num_cursors = len(cursor_ids)
pack = struct.Struct("<ii" + ("q" * num_cursors)).pack
op_kill_cursors = pack(0, num_cursors, *cursor_ids)
return __pack_message(2007, op_kill_cursors)
_FIELD_MAP = {
@ -603,7 +760,7 @@ class _BulkWriteContext(object):
__slots__ = ('db_name', 'command', 'sock_info', 'op_id',
'name', 'field', 'publish', 'start_time', 'listeners',
'session')
'session', 'compress')
def __init__(self, database_name, command, sock_info, operation_id,
listeners, session):
@ -617,6 +774,7 @@ class _BulkWriteContext(object):
self.field = _FIELD_MAP[self.name]
self.start_time = datetime.datetime.now() if self.publish else None
self.session = session
self.compress = True if sock_info.compression_context else False
@property
def max_bson_size(self):
@ -633,6 +791,14 @@ class _BulkWriteContext(object):
"""A proxy for SockInfo.max_write_batch_size."""
return self.sock_info.max_write_batch_size
def legacy_bulk_insert(
self, request_id, msg, max_doc_size, acknowledged, docs, compress):
if compress:
request_id, msg = _compress(
2002, msg, self.sock_info.compression_context)
return self.legacy_write(
request_id, msg, max_doc_size, acknowledged, docs)
def legacy_write(self, request_id, msg, max_doc_size, acknowledged, docs):
"""A proxy for SocketInfo.legacy_write that handles event publishing.
"""
@ -739,12 +905,14 @@ def _do_batched_insert(collection_name, docs, check_keys,
last_error = None
data = StringIO()
data.write(struct.pack("<i", int(continue_on_error)))
data.write(bson._make_c_string(collection_name))
data.write(_make_c_string(collection_name))
message_length = begin_loc = data.tell()
has_docs = False
to_send = []
encode = _dict_to_bson # Make local
compress = ctx.compress and not (safe or send_safe)
for doc in docs:
encoded = bson.BSON.encode(doc, check_keys, opts)
encoded = encode(doc, check_keys, opts)
encoded_length = len(encoded)
too_large = (encoded_length > ctx.max_bson_size)
@ -758,8 +926,12 @@ def _do_batched_insert(collection_name, docs, check_keys,
if has_docs:
# We have enough data, send this message.
try:
request_id, msg = _insert_message(data.getvalue(), send_safe)
ctx.legacy_write(request_id, msg, 0, send_safe, to_send)
if compress:
rid, msg = None, data.getvalue()
else:
rid, msg = _insert_message(data.getvalue(), send_safe)
ctx.legacy_bulk_insert(
rid, msg, 0, send_safe, to_send, compress)
# Exception type could be OperationFailure or a subtype
# (e.g. DuplicateKeyError)
except OperationFailure as exc:
@ -787,8 +959,11 @@ def _do_batched_insert(collection_name, docs, check_keys,
if not has_docs:
raise InvalidOperation("cannot do an empty bulk insert")
request_id, msg = _insert_message(data.getvalue(), safe)
ctx.legacy_write(request_id, msg, 0, safe, to_send)
if compress:
request_id, msg = None, data.getvalue()
else:
request_id, msg = _insert_message(data.getvalue(), safe)
ctx.legacy_bulk_insert(request_id, msg, 0, safe, to_send, compress)
# Re-raise any exception stored due to continue_on_error
if last_error is not None:
@ -797,21 +972,69 @@ if _use_c:
_do_batched_insert = _cmessage._do_batched_insert
def _do_batched_write_command(namespace, operation, command,
docs, check_keys, opts, ctx):
def _do_batched_write_command_compressed(
namespace, operation, command, docs, check_keys, opts, ctx):
"""Create the next batched insert, update, or delete command, compressed.
"""
data, to_send = _encode_batched_write_command(
namespace, operation, command, docs, check_keys, opts, ctx)
request_id, msg = _compress(
2004,
data,
ctx.sock_info.compression_context)
return request_id, msg, to_send
def _encode_batched_write_command(
namespace, operation, command, docs, check_keys, opts, ctx):
"""Encode the next batched insert, update, or delete command.
"""
buf = StringIO()
to_send, _ = _batched_write_command(
namespace, operation, command, docs, check_keys, opts, ctx, buf)
return buf.getvalue(), to_send
if _use_c:
_encode_batched_write_command = _cmessage._encode_batched_write_command
def _do_batched_write_command(
namespace, operation, command, docs, check_keys, opts, ctx):
"""Create the next batched insert, update, or delete command.
"""
buf = StringIO()
# Save space for message length and request id
buf.write(_ZERO_64)
# responseTo, opCode
buf.write(b"\x00\x00\x00\x00\xd4\x07\x00\x00")
# Write OP_QUERY write command
to_send, length = _batched_write_command(
namespace, operation, command, docs, check_keys, opts, ctx, buf)
# Header - request id and message length
buf.seek(4)
request_id = _randint()
buf.write(struct.pack('<i', request_id))
buf.seek(0)
buf.write(struct.pack('<i', length))
return request_id, buf.getvalue(), to_send
if _use_c:
_do_batched_write_command = _cmessage._do_batched_write_command
def _batched_write_command(
namespace, operation, command, docs, check_keys, opts, ctx, buf):
"""Create a batched OP_QUERY write command."""
max_bson_size = ctx.max_bson_size
max_write_batch_size = ctx.max_write_batch_size
# Max BSON object size + 16k - 2 bytes for ending NUL bytes.
# Server guarantees there is enough room: SERVER-10643.
max_cmd_size = max_bson_size + _COMMAND_OVERHEAD
buf = StringIO()
# Save space for message length and request id
buf.write(_ZERO_64)
# responseTo, opCode
buf.write(b"\x00\x00\x00\x00\xd4\x07\x00\x00")
# No options
buf.write(_ZERO_32)
# Namespace as C string
@ -871,15 +1094,8 @@ def _do_batched_write_command(namespace, operation, command,
buf.write(struct.pack('<i', length - list_start - 1))
buf.seek(command_start)
buf.write(struct.pack('<i', length - command_start))
buf.seek(4)
request_id = _randint()
buf.write(struct.pack('<i', request_id))
buf.seek(0)
buf.write(struct.pack('<i', length))
return request_id, buf.getvalue(), to_send
if _use_c:
_do_batched_write_command = _cmessage._do_batched_write_command
return to_send, length
class _OpReply(object):

View File

@ -252,6 +252,20 @@ class MongoClient(common.BaseObject):
periodic keep-alive packets on connected sockets. Defaults to
``True``. Disabling it is not recommended, see
https://docs.mongodb.com/manual/faq/diagnostics/#does-tcp-keepalive-time-affect-mongodb-deployments",
- `compressors`: Comma separated list of compressors for wire
protocol compression. The list is used to negotiate a compressor
with the server. Currently supported options are "snappy" and
"zlib". Support for snappy requires the
`python-snappy <https://pypi.org/project/python-snappy/>`_ package.
zlib support requires the Python standard library zlib module.
By default no compression is used. Compression support must also be
enabled on the server. MongoDB 3.4+ supports snappy compression.
MongoDB 3.6+ supports snappy and zlib.
- `zlibCompressionLevel`: (int) The zlib compression level to use
when zlib is used as the wire protocol compressor. Supported values
are -1 through 9. -1 tells the zlib library to use its default
compression level (usually 6). 0 means no compression. 1 is best
speed. 9 is best compression. Defaults to -1.
| **Write Concern options:**
| (Only set if passed. No default values.)

View File

@ -38,6 +38,7 @@ from bson.py3compat import PY3
from pymongo import helpers, message
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.compression_support import decompress, _NO_COMPRESSION
from pymongo.errors import (AutoReconnect,
NotMasterError,
OperationFailure,
@ -54,7 +55,8 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
check_keys=False, listeners=None, max_bson_size=None,
read_concern=None,
parse_write_concern_error=False,
collation=None):
collation=None,
compression_ctx=None):
"""Execute a command over the socket, or raise socket.error.
:Parameters:
@ -101,8 +103,12 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
if publish:
start = datetime.datetime.now()
request_id, msg, size = message.query(flags, ns, 0, -1, spec,
None, codec_options, check_keys)
if name.lower() not in _NO_COMPRESSION and compression_ctx:
request_id, msg, size = message.query(
flags, ns, 0, -1, spec, None, codec_options, check_keys, compression_ctx)
else:
request_id, msg, size = message.query(
flags, ns, 0, -1, spec, None, codec_options, check_keys)
if (max_bson_size is not None
and size > max_bson_size + message._COMMAND_OVERHEAD):
@ -142,15 +148,13 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
duration, response_doc, name, request_id, address)
return response_doc
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
def receive_message(sock, request_id, max_message_size=MAX_MESSAGE_SIZE):
"""Receive a raw BSON message or raise socket.error."""
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(
_receive_data_on_socket(sock, 16))
if op_code != _OpReply.OP_CODE:
raise ProtocolError("Got opcode %r but expected "
"%r" % (op_code, _OpReply.OP_CODE))
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
@ -162,8 +166,18 @@ def receive_message(sock, request_id, max_message_size=MAX_MESSAGE_SIZE):
if length > max_message_size:
raise ProtocolError("Message length (%r) is larger than server max "
"message size (%r)" % (length, max_message_size))
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
_receive_data_on_socket(sock, 9))
data = decompress(
_receive_data_on_socket(sock, length - 25), compressor_id)
else:
data = _receive_data_on_socket(sock, length - 16)
if op_code != _OpReply.OP_CODE:
raise ProtocolError("Got opcode %r but expected "
"%r" % (op_code, _OpReply.OP_CODE))
return _OpReply.unpack(_receive_data_on_socket(sock, length - 16))
return _OpReply.unpack(data)
# memoryview was introduced in Python 2.7 but we only use it on Python 3

View File

@ -280,14 +280,16 @@ class PoolOptions(object):
'__connect_timeout', '__socket_timeout',
'__wait_queue_timeout', '__wait_queue_multiple',
'__ssl_context', '__ssl_match_hostname', '__socket_keepalive',
'__event_listeners', '__appname', '__metadata')
'__event_listeners', '__appname', '__metadata',
'__compression_settings')
def __init__(self, max_pool_size=100, min_pool_size=0,
max_idle_time_seconds=None, connect_timeout=None,
socket_timeout=None, wait_queue_timeout=None,
wait_queue_multiple=None, ssl_context=None,
ssl_match_hostname=True, socket_keepalive=True,
event_listeners=None, appname=None):
event_listeners=None, appname=None,
compression_settings=None):
self.__max_pool_size = max_pool_size
self.__min_pool_size = min_pool_size
@ -301,6 +303,7 @@ class PoolOptions(object):
self.__socket_keepalive = socket_keepalive
self.__event_listeners = event_listeners
self.__appname = appname
self.__compression_settings = compression_settings
self.__metadata = _METADATA.copy()
if appname:
self.__metadata['application'] = {'name': appname}
@ -392,6 +395,10 @@ class PoolOptions(object):
"""
return self.__appname
@property
def compression_settings(self):
return self.__compression_settings
@property
def metadata(self):
"""A dict of metadata about the application, driver, os, and platform.
@ -423,6 +430,8 @@ class SocketInfo(object):
self.is_mongos = False
self.listeners = pool.opts.event_listeners
self.compression_settings = pool.opts.compression_settings
self.compression_context = None
# The pool's pool_id changes with each reset() so we can close sockets
# created before the last reset.
@ -432,7 +441,8 @@ class SocketInfo(object):
cmd = SON([('ismaster', 1)])
if not self.performed_handshake:
cmd['client'] = metadata
self.performed_handshake = True
if self.compression_settings:
cmd['compression'] = self.compression_settings.compressors
if self.max_wire_version >= 6 and cluster_time is not None:
cmd['$clusterTime'] = cluster_time
@ -446,6 +456,12 @@ class SocketInfo(object):
self.supports_sessions = (
ismaster.logical_session_timeout_minutes is not None)
self.is_mongos = ismaster.server_type == SERVER_TYPE.Mongos
if not self.performed_handshake and self.compression_settings:
ctx = self.compression_settings.get_compression_context(
ismaster.compressors)
self.compression_context = ctx
self.performed_handshake = True
return ismaster
def command(self, dbname, spec, slave_ok=False,
@ -512,7 +528,8 @@ class SocketInfo(object):
self.address, check_keys, listeners,
self.max_bson_size, read_concern,
parse_write_concern_error=parse_write_concern_error,
collation=collation)
collation=collation,
compression_ctx=self.compression_context)
except OperationFailure:
raise
# Catch socket.error, KeyboardInterrupt, etc. and close ourselves.

View File

@ -323,11 +323,14 @@ ext_modules = [Extension('bson._cbson',
sources=['pymongo/_cmessagemodule.c',
'bson/buffer.c'])]
extras_require = {'snappy': ["python-snappy"]}
vi = sys.version_info
if vi[0] == 2:
extras_require = {'tls': ["ipaddress"], 'srv': ["dnspython>=1.8.0,<2.0.0"]}
extras_require.update(
{'tls': ["ipaddress"], 'srv': ["dnspython>=1.8.0,<2.0.0"]})
else:
extras_require = {'tls': [], 'srv': ["dnspython>=1.13.0,<2.0.0"]}
extras_require.update(
{'tls': [], 'srv': ["dnspython>=1.13.0,<2.0.0"]})
if sys.platform == 'win32':
extras_require['gssapi'] = ["winkerberos>=0.5.0"]
if vi[0] == 2 and vi < (2, 7, 9) or vi[0] == 3 and vi < (3, 4):

View File

@ -72,6 +72,7 @@ if CA_PEM:
if CERT_REQS is not None:
_SSL_OPTIONS['ssl_cert_reqs'] = CERT_REQS
COMPRESSORS = os.environ.get("COMPRESSORS")
def is_server_resolvable():
"""Returns True if 'server' is resolvable."""
@ -186,16 +187,19 @@ class ClientContext(object):
self.ssl_cert_none = False
self.ssl_certfile = False
self.server_is_resolvable = is_server_resolvable()
self.ssl_client_options = {}
self.default_client_options = {}
self.sessions_enabled = False
self.client = self._connect(host, port)
if COMPRESSORS:
self.default_client_options["compressors"] = COMPRESSORS
if HAVE_SSL and not self.client:
# Is MongoDB configured for SSL?
self.client = self._connect(host, port, **_SSL_OPTIONS)
if self.client:
self.ssl = True
self.ssl_client_options = _SSL_OPTIONS
self.default_client_options.update(_SSL_OPTIONS)
self.ssl_certfile = True
if _SSL_OPTIONS.get('ssl_cert_reqs') == ssl.CERT_NONE:
self.ssl_cert_none = True
@ -223,7 +227,7 @@ class ClientContext(object):
self.client = self._connect(
host, port, username=db_user, password=db_pwd,
replicaSet=self.replica_set_name,
**self.ssl_client_options)
**self.default_client_options)
# May not have this if OperationFailure was raised earlier.
self.cmd_line = self.client.admin.command('getCmdLineOpts')
@ -242,13 +246,13 @@ class ClientContext(object):
username=db_user,
password=db_pwd,
replicaSet=self.replica_set_name,
**self.ssl_client_options)
**self.default_client_options)
else:
self.client = pymongo.MongoClient(
host,
port,
replicaSet=self.replica_set_name,
**self.ssl_client_options)
**self.default_client_options)
# Get the authoritative ismaster result from the primary.
self.ismaster = self.client.admin.command('ismaster')
@ -285,6 +289,8 @@ class ClientContext(object):
timeout_ms = 10000
else:
timeout_ms = 500
if COMPRESSORS:
kwargs["compressors"] = COMPRESSORS
client = pymongo.MongoClient(
host, port, serverSelectionTimeoutMS=timeout_ms, **kwargs)
try:
@ -339,7 +345,7 @@ class ClientContext(object):
username=db_user,
password=db_pwd,
serverSelectionTimeoutMS=100,
**self.ssl_client_options)
**self.default_client_options)
try:
return db_user in _all_users(client.admin)

View File

@ -123,7 +123,7 @@ class MockClient(MongoClient):
kwargs['_pool_class'] = partial(MockPool, self)
kwargs['_monitor_class'] = partial(MockMonitor, self)
client_options = client_context.ssl_client_options.copy()
client_options = client_context.default_client_options.copy()
client_options.update(kwargs)
super(MockClient, self).__init__(*args, **client_options)

View File

@ -18,6 +18,7 @@ import contextlib
import datetime
import gc
import os
import platform
import signal
import socket
import struct
@ -35,6 +36,7 @@ from bson.tz_util import utc
from pymongo import auth, message
from pymongo.common import _UUID_REPRESENTATIONS
from pymongo.command_cursor import CommandCursor
from pymongo.compression_support import _HAVE_SNAPPY
from pymongo.cursor import CursorType
from pymongo.database import Database
from pymongo.errors import (AutoReconnect,
@ -382,7 +384,7 @@ class TestClient(IntegrationTest):
port are not overloaded.
"""
host, port = client_context.host, client_context.port
kwargs = client_context.ssl_client_options.copy()
kwargs = client_context.default_client_options.copy()
if client_context.auth_enabled:
kwargs['username'] = db_user
kwargs['password'] = db_pwd
@ -1180,6 +1182,76 @@ class TestClient(IntegrationTest):
self.assertIn('heartbeatFrequencyMS', str(context.exception))
def test_compression(self):
def compression_settings(client):
pool_options = client._MongoClient__options.pool_options
return pool_options.compression_settings
uri = "mongodb://localhost:27017/?compressors=zlib"
client = MongoClient(uri, connect=False)
opts = compression_settings(client)
self.assertEqual(opts.compressors, ['zlib'])
uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4"
client = MongoClient(uri, connect=False)
opts = compression_settings(client)
self.assertEqual(opts.compressors, ['zlib'])
self.assertEqual(opts.zlib_compression_level, 4)
uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-1"
client = MongoClient(uri, connect=False)
opts = compression_settings(client)
self.assertEqual(opts.compressors, ['zlib'])
self.assertEqual(opts.zlib_compression_level, -1)
uri = "mongodb://localhost:27017"
client = MongoClient(uri, connect=False)
opts = compression_settings(client)
self.assertEqual(opts.compressors, [])
self.assertEqual(opts.zlib_compression_level, -1)
uri = "mongodb://localhost:27017/?compressors=foobar"
client = MongoClient(uri, connect=False)
opts = compression_settings(client)
self.assertEqual(opts.compressors, [])
self.assertEqual(opts.zlib_compression_level, -1)
uri = "mongodb://localhost:27017/?compressors=foobar,zlib"
client = MongoClient(uri, connect=False)
opts = compression_settings(client)
self.assertEqual(opts.compressors, ['zlib'])
self.assertEqual(opts.zlib_compression_level, -1)
# According to the connection string spec, unsupported values
# just raise a warning and are ignored.
uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=10"
client = MongoClient(uri, connect=False)
opts = compression_settings(client)
self.assertEqual(opts.compressors, ['zlib'])
self.assertEqual(opts.zlib_compression_level, -1)
uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-2"
client = MongoClient(uri, connect=False)
opts = compression_settings(client)
self.assertEqual(opts.compressors, ['zlib'])
self.assertEqual(opts.zlib_compression_level, -1)
if not _HAVE_SNAPPY:
uri = "mongodb://localhost:27017/?compressors=snappy"
client = MongoClient(uri, connect=False)
opts = compression_settings(client)
self.assertEqual(opts.compressors, [])
else:
uri = "mongodb://localhost:27017/?compressors=snappy"
client = MongoClient(uri, connect=False)
opts = compression_settings(client)
self.assertEqual(opts.compressors, ['snappy'])
uri = "mongodb://localhost:27017/?compressors=snappy,zlib"
client = MongoClient(uri, connect=False)
opts = compression_settings(client)
self.assertEqual(opts.compressors, ['snappy', 'zlib'])
options = client_context.default_client_options
if "compressors" in options and "zlib" in options["compressors"]:
for level in range(-1, 10):
client = single_client(zlibcompressionlevel=level)
# No error
client.pymongo_test.test.find_one()
class TestExhaustCursor(IntegrationTest):
"""Test that clients properly handle errors from exhaust cursors."""

View File

@ -32,7 +32,7 @@ from test.utils import wait_until
_TEST_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)), 'dns')
_SSL_OPTS = client_context.ssl_client_options.copy()
_SSL_OPTS = client_context.default_client_options.copy()
if client_context.ssl is True:
# Our test certs don't support the SRV hosts used in these tests.
_SSL_OPTS['ssl_match_hostname'] = False

View File

@ -310,7 +310,7 @@ class TestReadPreferences(TestReadPreferencesBase):
class ReadPrefTester(MongoClient):
def __init__(self, *args, **kwargs):
self.has_read_from = set()
client_options = client_context.ssl_client_options.copy()
client_options = client_context.default_client_options.copy()
client_options.update(kwargs)
super(ReadPrefTester, self).__init__(*args, **client_options)

View File

@ -33,9 +33,7 @@ from pymongo.monitoring import _SENSITIVE_COMMANDS
from pymongo.server_selectors import (any_server_selector,
writable_server_selector)
from pymongo.write_concern import WriteConcern
from test import (client_context,
db_user,
db_pwd)
from test import (client_context, db_user, db_pwd)
IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=1000)
@ -134,8 +132,7 @@ def _connection_string(h, p, authenticate):
def _mongo_client(host, port, authenticate=True, direct=False, **kwargs):
"""Create a new client over SSL/TLS if necessary."""
client_options = client_context.ssl_client_options.copy()
client_options = client_context.default_client_options.copy()
if client_context.replica_set_name and not direct:
client_options['replicaSet'] = client_context.replica_set_name
client_options.update(kwargs)