Compare commits

...

51 Commits
master ... v3.8

Author SHA1 Message Date
Prashant Mital
1c69dd0860
BUMP 3.8.0 2019-04-22 16:37:29 -05:00
Shane Harvey
3383316fd5 PYTHON-1826 Restore proper not master handling in cursors 2019-04-19 17:53:20 -07:00
Shane Harvey
f359284ab2 PYTHON-1664 Include type in InvalidDocument error
(cherry picked from commit 6b6efd9b59)
2019-04-19 16:51:56 -07:00
Shane Harvey
353be17179 PYTHON-1818 Support custom type encoding in watch pipelines
(cherry picked from commit 9cca2a7d2c)
2019-04-19 15:54:54 -07:00
Prashant Mital
824b58ac60
PYTHON-1819 Documentation & examples for custom type encoding/decoding
functionality

(cherry picked from commit 0ea5a1542e)
2019-04-19 13:36:20 -07:00
Prashant Mital
286ab13a8a
PYTHON-1821 Preserve field ordering when iterating over RawBSONDocument
instances

(cherry picked from commit 3b29458015)
2019-04-19 12:38:36 -07:00
Prashant Mital
f54562a9a5
PYTHON-1818 TypeCodec support for ChangeStreams
(cherry picked from commit 2f2fe9db0d)
2019-04-18 18:08:57 -07:00
Shane Harvey
6476cb8816 PYTHON-1820 Validate bson size in RawBSONDocument init
Also fixes a bug where an empty bson document could not be represented
by RawBSONDocument.

(cherry picked from commit fbb56a2311)
2019-04-18 15:16:58 -07:00
Shane Harvey
ec8768c3ea PYTHON-1814 Support custom type decoder with distinct
Fix pure python custom type decoding of bson arrays.

(cherry picked from commit 2cb34e4efc)

Conflicts:
	pymongo/command_cursor.py
2019-04-17 15:48:54 -07:00
Bernie Hackett
bc8e8b2830 PYTHON-1808 - Document uuidRepresentation
(cherry picked from commit 2f06e8a441)
2019-04-16 16:27:16 -07:00
Prashant Mital
e1c319b041
PYTHON-1783 Fix issues resulting from 3-way merge 2019-04-16 16:17:33 -07:00
Prashant Mital
b14dede0fb
PYTHON-1783 Decode user-facing documents but not internal driver-server
communications.

(cherry picked from commit 749116287a)
2019-04-16 15:35:29 -07:00
Prashant Mital
3f7481fe3b
PYTHON-1783: disallow custom-encoding built-in types
(cherry picked from commit 4049b1493a)
2019-04-16 15:25:03 -07:00
Shane Harvey
a3fe8bc106 PYTHON-1491 Skip OP_KILL_CURSORS on old 3.6 versions
(cherry picked from commit 007aa6ba50)
2019-04-16 14:34:18 -07:00
Shane Harvey
69353270f7 PYTHON-1811 Deprecate running min/max queries without hint
Starting in MongoDB 4.2 a hint will be required when using min/max.

(cherry picked from commit ea8941ef5d)
2019-04-16 14:10:28 -07:00
Shane Harvey
50b905ea2f PYTHON-1801 Update transactions count test for MongoDB >= 4.0.7
(cherry picked from commit f09d6fa052)

Changes:
    Replace runOn with minServerVersion.
2019-04-08 12:58:41 -07:00
Shane Harvey
35a408c425 PYTHON-1799 Don't iterate _ENCODERS dict when encoding bson
(cherry picked from commit eb4a047278)
2019-04-08 10:59:42 -07:00
Jesse M. Holmes
170d217a78 Update docs to use list_collection_names method (#401)
db.collection_names() is deprecated.

(cherry picked from commit 66eb6da601)
2019-03-28 16:09:19 -07:00
Terence D. Honles
8cecd8eb86 PYTHON-1695 GridOut/GridIn more closely implement io.IOBase (#387)
Allows GridOut to be wrapped with zipfile.ZipFile from the stdlib.

(cherry picked from commit 481600b7fe)
2019-03-28 15:56:07 -07:00
jakirkham
b69f6d47ed Clarify that Binary expects data to be bytes (#399)
(cherry picked from commit 5950abf323)
2019-03-28 13:56:44 -07:00
Shane Harvey
7feb8ceee7 PYTHON-1792 More reliable tests for ChangeStream.try_next
(cherry picked from commit 5ebd2938bd)
2019-03-27 13:37:12 -07:00
Shane Harvey
d2cf5fca45 PYTHON-1791 Fix reference counting leaks
Fix batched op_msg/op_query reference leak of overflow doc.
Fix theoretically possible (but practically impossible) reference leak of
$clusterTime in op_query.
Optimization: Don't encode document past the batch size in batched op_query.

(cherry picked from commit cd787dbb2c)
2019-03-27 13:07:04 -07:00
Shane Harvey
c339077ec9 PYTHON-1662 Add ChangeStream.try_next API
(cherry picked from commit 92ddc09b7e)
2019-03-26 11:15:54 -07:00
Shane Harvey
86463345e6 PYTHON-1784 Add filter support to list_collection_names
Adhere to enumerate collection spec for setting nameOnly when filter
is provided to allow filtering based on collection options.

(cherry picked from commit 11967eb160)
2019-03-22 17:07:43 -07:00
Shane Harvey
0314ab18b5 PYTHON-1786 Send comment with Cursor.count and Cursor.distinct
(cherry picked from commit 4169a04821)
2019-03-22 17:03:20 -07:00
Shane Harvey
e7ddc291b1 PYTHON-1781 Raise a client side error when attempting a sharded transaction 2019-03-22 16:56:20 -07:00
Bernie Hackett
ea62ce57d7 PYTHON-1685 - Renovate get_default_database
(cherry picked from commit c55a66235d)
2019-03-22 16:50:32 -07:00
Prashant Mital
8f4b0c598a
PYTHON-1476 Add changelog entry for flexible BSON work
(cherry picked from commit 6bea39d7ca)
2019-03-22 13:05:10 -07:00
Prashant Mital
214f4ea5ea
PYTHON-1782 Restore backwards compatibility of MongoClient initialization when passing a type_registry
(cherry picked from commit 76ef2b473f)
2019-03-22 10:32:15 -07:00
Prashant Mital
f88770798b
PYTHON-1786 Skip test_cursor.TestCursor.test_comment until SERVER-40270 is resolved
(cherry picked from commit 1b0b17450a)
2019-03-22 10:27:39 -07:00
Prashant Mital
c1453562f6
PYTHON-1782 Allow MongoClient to be initialized with type_registry
(cherry picked from commit 599e2d7117)
2019-03-21 15:34:59 -07:00
Prashant Mital
2e8d897026
PYTHON-1696 Stop encouraging use of BSON.decode as a classmethod
(cherry picked from commit cda0b71b78)
2019-03-21 11:46:48 -07:00
Shane Harvey
ffd5bae06a PYTHON-1780 Test against Python 2.7 on linux-64-amzn
(cherry picked from commit 704905844d)
2019-03-21 10:11:45 -07:00
Prashant Mital
2867fe544c
PYTHON-1769 Re-define TypeCodecBase as an AbstractBaseClass
(cherry picked from commit 65f85f648c)
2019-03-19 18:03:25 -05:00
Bernie Hackett
713419ef4e PYTHON-1656 explain() uses server default verbosity
(cherry picked from commit fe307058c8)
2019-03-19 14:37:59 -07:00
Bernie Hackett
6a917904f4 PYTHON-1762 Document that limit=0 means no limit
(cherry picked from commit c0c7c05652)
2019-03-19 11:27:06 -07:00
Bernie Hackett
3c86686c0a PYTHON-1726 Disable TLS renegotiation when possible
(cherry picked from commit bb83a26082)
2019-03-19 11:22:12 -07:00
Prashant Mital
9093ddf365
PYTHON-1731 Implement callback for unencodable types
(cherry picked from commit e01efc7073)
2019-03-18 14:09:29 -05:00
Shane Harvey
3ef4aa982c PYTHON-1600 Avoid race condition in test_last_write_date
(cherry picked from commit a1b04628b9)
2019-03-15 17:29:51 -07:00
Shane Harvey
bdfa2919e0 PYTHON-1757 Properly skip tests that require enableTestCommands
(cherry picked from commit 923229de12)
2019-03-14 16:28:54 -07:00
Shane Harvey
2ab5d181b6 PYTHON-1773 Test against clusters with enableTestCommands=0
(cherry picked from commit 9f9b888111)
2019-03-14 16:26:16 -07:00
Shane Harvey
8069e13232 PYTHON-1721 Improve GridFS file download performance (#413)
This change uses a cursor to download all the chunks in a GridFS file
instead of using individual find_one operations to read each chunk.
Detect truncated/missing/extra chunks in _GridOutChunkIterator.
Only detect extra chunks after reading the final chunk, not on every
call to read().
Retry once after CursorNotFound for backward compatibility.

(cherry picked from commit 956fd92e82)
2019-03-13 15:48:38 -07:00
Shane Harvey
905c578fe6 PYTHON-1721 Add EventListener.started_command_names 2019-03-13 15:47:46 -07:00
Shane Harvey
61b38da1d2 PYTHON-1709 Always use codec_options in Database.current_op
(cherry picked from commit da2ba8d7ed)
2019-03-13 13:36:48 -07:00
Prashant Mital
7b6db40c7d
PYTHON-1750 Support callbacks for simple types (#405)
(cherry picked from commit 83755b8739)
2019-03-13 10:41:33 -07:00
Shane Harvey
a583eec290 PYTHON-1491 Enable OP_KILL_CURSORS test
(cherry picked from commit 59c3a22115)
2019-03-12 17:59:56 -07:00
Shane Harvey
c7f5b1b1fe PYTHON-1725 Fix TestThreadsAuth.test_auto_auth_login
Create the database upfront to avoid test failures on sharded clusters.

(cherry picked from commit fd34c1da2a)
2019-03-12 16:29:27 -07:00
Shane Harvey
4110828b08 PYTHON-1644 Only run doctests against standalone servers
(cherry picked from commit ddfa412064)
2019-03-12 16:29:21 -07:00
Shane Harvey
225bec46d4 PYTHON-1767 Ignore keyPattern/keyValue fields in doctest duplicate key error (#412)
(cherry picked from commit 57a9b62e9d)
2019-03-12 16:28:56 -07:00
Shane Harvey
3082649e89 PYTHON-1766 Use insert_many to reduce test runtime (#410)
(cherry picked from commit a84f50b998)
2019-03-12 16:28:53 -07:00
Shane Harvey
f24f3f0b6e
Add changelog for 3.7.2 (#407) 2019-03-01 10:52:48 -08:00
57 changed files with 3316 additions and 567 deletions

View File

@ -351,7 +351,13 @@ functions:
params: params:
script: | script: |
${PREPARE_SHELL} ${PREPARE_SHELL}
MONGODB_VERSION=${VERSION} TOPOLOGY=${TOPOLOGY} AUTH=${AUTH} SSL=${SSL} STORAGE_ENGINE=${STORAGE_ENGINE} sh ${DRIVERS_TOOLS}/.evergreen/run-orchestration.sh MONGODB_VERSION=${VERSION} \
TOPOLOGY=${TOPOLOGY} \
AUTH=${AUTH} \
SSL=${SSL} \
STORAGE_ENGINE=${STORAGE_ENGINE} \
DISABLE_TEST_COMMANDS=${DISABLE_TEST_COMMANDS} \
sh ${DRIVERS_TOOLS}/.evergreen/run-orchestration.sh
# run-orchestration generates expansion file with the MONGODB_URI for the cluster # run-orchestration generates expansion file with the MONGODB_URI for the cluster
- command: expansions.update - command: expansions.update
params: params:
@ -415,7 +421,18 @@ functions:
if [ -n "${MONGODB_STARTED}" ]; then if [ -n "${MONGODB_STARTED}" ]; then
export PYMONGO_MUST_CONNECT=1 export PYMONGO_MUST_CONNECT=1
fi fi
PYTHON_BINARY=${PYTHON_BINARY} GREEN_FRAMEWORK=${GREEN_FRAMEWORK} C_EXTENSIONS=${C_EXTENSIONS} COVERAGE=${COVERAGE} COMPRESSORS=${COMPRESSORS} AUTH=${AUTH} SSL=${SSL} sh ${PROJECT_DIRECTORY}/.evergreen/run-tests.sh if [ -n "${DISABLE_TEST_COMMANDS}" ]; then
export PYMONGO_DISABLE_TEST_COMMANDS=1
fi
PYTHON_BINARY=${PYTHON_BINARY} \
GREEN_FRAMEWORK=${GREEN_FRAMEWORK} \
C_EXTENSIONS=${C_EXTENSIONS} \
COVERAGE=${COVERAGE} \
COMPRESSORS=${COMPRESSORS} \
AUTH=${AUTH} \
SSL=${SSL} \
sh ${PROJECT_DIRECTORY}/.evergreen/run-tests.sh
"run enterprise auth tests": "run enterprise auth tests":
- command: shell.exec - command: shell.exec
@ -554,7 +571,7 @@ tasks:
- func: "bootstrap mongo-orchestration" - func: "bootstrap mongo-orchestration"
vars: vars:
VERSION: "latest" VERSION: "latest"
TOPOLOGY: "replica_set" TOPOLOGY: "server"
- func: "run doctests" - func: "run doctests"
- name: "test-2.6-standalone" - name: "test-2.6-standalone"
@ -855,11 +872,12 @@ axes:
- id: os-fully-featured - id: os-fully-featured
display_name: OS display_name: OS
values: values:
# https://jira.mongodb.org/browse/BUILD-5453 - id: linux-64-amzn-test
#- id: linux-64-amzn-test display_name: "Amazon Linux (Enterprise)"
# display_name: "Amazon Linux (Enterprise)" run_on: linux-64-amzn-test
# run_on: linux-64-amzn-test batchtime: 10080 # 7 days
# batchtime: 10080 # 7 days variables:
PYTHON_BINARY: "python2.7"
- id: ubuntu-14.04 - id: ubuntu-14.04
display_name: "Ubuntu 14.04" display_name: "Ubuntu 14.04"
@ -1098,6 +1116,13 @@ axes:
display_name: InMemory display_name: InMemory
variables: variables:
STORAGE_ENGINE: "inmemory" STORAGE_ENGINE: "inmemory"
- id: disableTestCommands
display_name: Disable test commands
values:
- id: disabled
display_name: disabled
variables:
DISABLE_TEST_COMMANDS: "1"
- id: windows-vs2010-python-version - id: windows-vs2010-python-version
display_name: "Windows Visual Studio 2010 Python" display_name: "Windows Visual Studio 2010 Python"
values: values:
@ -1381,6 +1406,14 @@ buildvariants:
add_tasks: add_tasks:
- "test-3.0-standalone" - "test-3.0-standalone"
# enableTestCommands=0 tests on RHEL 6.2 (x86_64) with Python 2.7.
- matrix_name: "test-disableTestCommands"
matrix_spec: {disableTestCommands: "*", python-version: "2.7"}
display_name: "Disable test commands ${python-version} RHEL 6.2 (x86_64)"
run_on: rhel62-small
tasks:
- ".latest"
- matrix_name: "test-linux-enterprise-auth" - matrix_name: "test-linux-enterprise-auth"
matrix_spec: {"python-version": "*", auth: "auth"} matrix_spec: {"python-version": "*", auth: "auth"}
display_name: "Enterprise Auth Linux ${python-version}" display_name: "Enterprise Auth Linux ${python-version}"

View File

@ -183,14 +183,26 @@ def _get_string(data, position, obj_end, opts, dummy):
opts.unicode_decode_error_handler, True)[0], end + 1 opts.unicode_decode_error_handler, True)[0], end + 1
def _get_object(data, position, obj_end, opts, dummy): def _get_object_size(data, position, obj_end):
"""Decode a BSON subdocument to opts.document_class or bson.dbref.DBRef.""" """Validate and return a BSON document's size."""
obj_size = _UNPACK_INT(data[position:position + 4])[0] try:
obj_size = _UNPACK_INT(data[position:position + 4])[0]
except struct.error as exc:
raise InvalidBSON(str(exc))
end = position + obj_size - 1 end = position + obj_size - 1
if data[end:position + obj_size] != b"\x00": if data[end:end + 1] != b"\x00":
raise InvalidBSON("bad eoo") raise InvalidBSON("bad eoo")
if end >= obj_end: if end >= obj_end:
raise InvalidBSON("invalid object length") raise InvalidBSON("invalid object length")
# If this is the top-level document, validate the total size too.
if position == 0 and obj_size != obj_end:
raise InvalidBSON("invalid object length")
return obj_size, end
def _get_object(data, position, obj_end, opts, dummy):
"""Decode a BSON subdocument to opts.document_class or bson.dbref.DBRef."""
obj_size, end = _get_object_size(data, position, obj_end)
if _raw_document_class(opts.document_class): if _raw_document_class(opts.document_class):
return (opts.document_class(data[position:end + 1], opts), return (opts.document_class(data[position:end + 1], opts),
position + obj_size) position + obj_size)
@ -215,10 +227,11 @@ def _get_array(data, position, obj_end, opts, element_name):
end -= 1 end -= 1
result = [] result = []
# Avoid doing global and attibute lookups in the loop. # Avoid doing global and attribute lookups in the loop.
append = result.append append = result.append
index = data.index index = data.index
getter = _ELEMENT_GETTER getter = _ELEMENT_GETTER
decoder_map = opts.type_registry._decoder_map
while position < end: while position < end:
element_type = data[position:position + 1] element_type = data[position:position + 1]
@ -229,6 +242,12 @@ def _get_array(data, position, obj_end, opts, element_name):
data, position, obj_end, opts, element_name) data, position, obj_end, opts, element_name)
except KeyError: except KeyError:
_raise_unknown_type(element_type, element_name) _raise_unknown_type(element_type, element_name)
if decoder_map:
custom_decoder = decoder_map.get(type(value))
if custom_decoder is not None:
value = custom_decoder(value)
append(value) append(value)
if position != end + 1: if position != end + 1:
@ -388,43 +407,37 @@ def _element_to_dict(data, position, obj_end, opts):
element_name) element_name)
except KeyError: except KeyError:
_raise_unknown_type(element_type, element_name) _raise_unknown_type(element_type, element_name)
if opts.type_registry._decoder_map:
custom_decoder = opts.type_registry._decoder_map.get(type(value))
if custom_decoder is not None:
value = custom_decoder(value)
return element_name, value, position return element_name, value, position
if _USE_C: if _USE_C:
_element_to_dict = _cbson._element_to_dict _element_to_dict = _cbson._element_to_dict
def _iterate_elements(data, position, obj_end, opts): def _elements_to_dict(data, position, obj_end, opts, result=None):
"""Decode a BSON document into result."""
if result is None:
result = opts.document_class()
end = obj_end - 1 end = obj_end - 1
while position < end: while position < end:
(key, value, position) = _element_to_dict(data, position, obj_end, opts) key, value, position = _element_to_dict(data, position, obj_end, opts)
yield key, value, position
def _elements_to_dict(data, position, obj_end, opts):
"""Decode a BSON document."""
result = opts.document_class()
pos = position
for key, value, pos in _iterate_elements(data, position, obj_end, opts):
result[key] = value result[key] = value
if pos != obj_end: if position != obj_end:
raise InvalidBSON('bad object or element length') raise InvalidBSON('bad object or element length')
return result return result
def _bson_to_dict(data, opts): def _bson_to_dict(data, opts):
"""Decode a BSON string to document_class.""" """Decode a BSON string to document_class."""
try:
obj_size = _UNPACK_INT(data[:4])[0]
except struct.error as exc:
raise InvalidBSON(str(exc))
if obj_size != len(data):
raise InvalidBSON("invalid object size")
if data[obj_size - 1:obj_size] != b"\x00":
raise InvalidBSON("bad eoo")
try: try:
if _raw_document_class(opts.document_class): if _raw_document_class(opts.document_class):
return opts.document_class(data, opts) return opts.document_class(data, opts)
return _elements_to_dict(data, 4, obj_size - 1, opts) _, end = _get_object_size(data, 0, len(data))
return _elements_to_dict(data, 4, end, opts)
except InvalidBSON: except InvalidBSON:
raise raise
except Exception: except Exception:
@ -746,9 +759,13 @@ if not PY3:
_ENCODERS[long] = _encode_long _ENCODERS[long] = _encode_long
def _name_value_to_bson(name, value, check_keys, opts): _BUILT_IN_TYPES = tuple(t for t in _ENCODERS)
"""Encode a single name, value pair."""
def _name_value_to_bson(name, value, check_keys, opts,
in_custom_call=False,
in_fallback_call=False):
"""Encode a single name, value pair."""
# First see if the type is already cached. KeyError will only ever # First see if the type is already cached. KeyError will only ever
# happen once per subtype. # happen once per subtype.
try: try:
@ -766,17 +783,36 @@ def _name_value_to_bson(name, value, check_keys, opts):
_ENCODERS[type(value)] = func _ENCODERS[type(value)] = func
return func(name, value, check_keys, opts) return func(name, value, check_keys, opts)
# If all else fails test each base type. This will only happen once for # Third, check if a type encoder is registered for this type.
# a subtype of a supported base type. # Note that subtypes of registered custom types are not auto-encoded.
for base in _ENCODERS: if not in_custom_call and opts.type_registry._encoder_map:
custom_encoder = opts.type_registry._encoder_map.get(type(value))
if custom_encoder is not None:
return _name_value_to_bson(
name, custom_encoder(value), check_keys, opts,
in_custom_call=True)
# Fourth, test each base type. This will only happen once for
# a subtype of a supported base type. Unlike in the C-extensions, this
# is done after trying the custom type encoder because checking for each
# subtype is expensive.
for base in _BUILT_IN_TYPES:
if isinstance(value, base): if isinstance(value, base):
func = _ENCODERS[base] func = _ENCODERS[base]
# Cache this type for faster subsequent lookup. # Cache this type for faster subsequent lookup.
_ENCODERS[type(value)] = func _ENCODERS[type(value)] = func
return func(name, value, check_keys, opts) return func(name, value, check_keys, opts)
raise InvalidDocument("cannot convert value of type %s to bson" % # As a last resort, try using the fallback encoder, if the user has
type(value)) # provided one.
fallback_encoder = opts.type_registry._fallback_encoder
if not in_fallback_call and fallback_encoder is not None:
return _name_value_to_bson(
name, fallback_encoder(value), check_keys, opts,
in_fallback_call=True)
raise InvalidDocument(
"cannot encode object: %r, of type: %r" % (value, type(value)))
def _element_to_bson(key, value, check_keys, opts): def _element_to_bson(key, value, check_keys, opts):
@ -911,6 +947,62 @@ if _USE_C:
decode_all = _cbson.decode_all decode_all = _cbson.decode_all
def _decode_selective(rawdoc, fields, codec_options):
if _raw_document_class(codec_options.document_class):
# If document_class is RawBSONDocument, use vanilla dictionary for
# decoding command response.
doc = {}
else:
# Else, use the specified document_class.
doc = codec_options.document_class()
for key, value in iteritems(rawdoc):
if key in fields:
if fields[key] == 1:
doc[key] = _bson_to_dict(rawdoc.raw, codec_options)[key]
else:
doc[key] = _decode_selective(value, fields[key], codec_options)
else:
doc[key] = value
return doc
def _decode_all_selective(data, codec_options, fields):
"""Decode BSON data to a single document while using user-provided
custom decoding logic.
`data` must be a string representing a valid, BSON-encoded document.
:Parameters:
- `data`: BSON data
- `codec_options`: An instance of
:class:`~bson.codec_options.CodecOptions` with user-specified type
decoders. If no decoders are found, this method is the same as
``decode_all``.
- `fields`: Map of document namespaces where data that needs
to be custom decoded lives or None. For example, to custom decode a
list of objects in 'field1.subfield1', the specified value should be
``{'field1': {'subfield1': 1}}``. If ``fields`` is an empty map or
None, this method is the same as ``decode_all``.
:Returns:
- `document_list`: Single-member list containing the decoded document.
.. versionadded:: 3.8
"""
if not codec_options.type_registry._decoder_map:
return decode_all(data, codec_options)
if not fields:
return decode_all(data, codec_options.with_options(type_registry=None))
# Decode documents for internal use.
from bson.raw_bson import RawBSONDocument
internal_codec_options = codec_options.with_options(
document_class=RawBSONDocument, type_registry=None)
_doc = _bson_to_dict(data, internal_codec_options)
return [_decode_selective(_doc, fields, codec_options,)]
def decode_iter(data, codec_options=DEFAULT_CODEC_OPTIONS): def decode_iter(data, codec_options=DEFAULT_CODEC_OPTIONS):
"""Decode BSON data to multiple documents as a generator. """Decode BSON data to multiple documents as a generator.
@ -1037,10 +1129,10 @@ class BSON(bytes):
>>> import bson >>> import bson
>>> from bson.codec_options import CodecOptions >>> from bson.codec_options import CodecOptions
>>> data = bson.BSON.encode({'a': 1}) >>> data = bson.BSON.encode({'a': 1})
>>> decoded_doc = bson.BSON.decode(data) >>> decoded_doc = bson.BSON(data).decode()
<type 'dict'> <type 'dict'>
>>> options = CodecOptions(document_class=collections.OrderedDict) >>> options = CodecOptions(document_class=collections.OrderedDict)
>>> decoded_doc = bson.BSON.decode(data, codec_options=options) >>> decoded_doc = bson.BSON(data).decode(codec_options=options)
>>> type(decoded_doc) >>> type(decoded_doc)
<class 'collections.OrderedDict'> <class 'collections.OrderedDict'>

View File

@ -119,7 +119,9 @@ static PyObject* elements_to_dict(PyObject* self, const char* string,
static int _write_element_to_buffer(PyObject* self, buffer_t buffer, static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
int type_byte, PyObject* value, int type_byte, PyObject* value,
unsigned char check_keys, unsigned char check_keys,
const codec_options_t* options); const codec_options_t* options,
unsigned char in_custom_call,
unsigned char in_fallback_call);
/* Date stuff */ /* Date stuff */
static PyObject* datetime_from_millis(long long millis) { static PyObject* datetime_from_millis(long long millis) {
@ -345,9 +347,9 @@ static int _load_object(PyObject** object, char* module_name, char* object_name)
* *
* Returns non-zero on failure. */ * Returns non-zero on failure. */
static int _load_python_objects(PyObject* module) { static int _load_python_objects(PyObject* module) {
PyObject* empty_string; PyObject* empty_string = NULL;
PyObject* re_compile; PyObject* re_compile = NULL;
PyObject* compiled; PyObject* compiled = NULL;
struct module_state *state = GETSTATE(module); struct module_state *state = GETSTATE(module);
if (_load_object(&state->Binary, "bson.binary", "Binary") || if (_load_object(&state->Binary, "bson.binary", "Binary") ||
@ -447,6 +449,46 @@ static long _type_marker(PyObject* object) {
return type; return type;
} }
/* Fill out a type_registry_t* from a TypeRegistry object.
*
* Return 1 on success. options->document_class is a new reference.
* Return 0 on failure.
*/
int convert_type_registry(PyObject* registry_obj, type_registry_t* registry) {
registry->encoder_map = NULL;
registry->decoder_map = NULL;
registry->fallback_encoder = NULL;
registry->registry_obj = NULL;
registry->encoder_map = PyObject_GetAttrString(registry_obj, "_encoder_map");
if (registry->encoder_map == NULL) {
goto fail;
}
registry->is_encoder_empty = (PyDict_Size(registry->encoder_map) == 0);
registry->decoder_map = PyObject_GetAttrString(registry_obj, "_decoder_map");
if (registry->decoder_map == NULL) {
goto fail;
}
registry->is_decoder_empty = (PyDict_Size(registry->decoder_map) == 0);
registry->fallback_encoder = PyObject_GetAttrString(registry_obj, "_fallback_encoder");
if (registry->fallback_encoder == NULL) {
goto fail;
}
registry->has_fallback_encoder = (registry->fallback_encoder != Py_None);
registry->registry_obj = registry_obj;
Py_INCREF(registry->registry_obj);
return 1;
fail:
Py_XDECREF(registry->encoder_map);
Py_XDECREF(registry->decoder_map);
Py_XDECREF(registry->fallback_encoder);
return 0;
}
/* Fill out a codec_options_t* from a CodecOptions object. Use with the "O&" /* Fill out a codec_options_t* from a CodecOptions object. Use with the "O&"
* format spec in PyArg_ParseTuple. * format spec in PyArg_ParseTuple.
* *
@ -455,25 +497,37 @@ static long _type_marker(PyObject* object) {
*/ */
int convert_codec_options(PyObject* options_obj, void* p) { int convert_codec_options(PyObject* options_obj, void* p) {
codec_options_t* options = (codec_options_t*)p; codec_options_t* options = (codec_options_t*)p;
PyObject* type_registry_obj = NULL;
long type_marker; long type_marker;
options->unicode_decode_error_handler = NULL; options->unicode_decode_error_handler = NULL;
if (!PyArg_ParseTuple(options_obj, "ObbzO",
if (!PyArg_ParseTuple(options_obj, "ObbzOO",
&options->document_class, &options->document_class,
&options->tz_aware, &options->tz_aware,
&options->uuid_rep, &options->uuid_rep,
&options->unicode_decode_error_handler, &options->unicode_decode_error_handler,
&options->tzinfo)) { &options->tzinfo,
&type_registry_obj))
return 0;
type_marker = _type_marker(options->document_class);
if (type_marker < 0) {
return 0; return 0;
} }
type_marker = _type_marker(options->document_class); if (!convert_type_registry(type_registry_obj,
if (type_marker < 0) return 0; &options->type_registry)) {
return 0;
}
options->is_raw_bson = (101 == type_marker);
options->options_obj = options_obj;
Py_INCREF(options->options_obj);
Py_INCREF(options->document_class); Py_INCREF(options->document_class);
Py_INCREF(options->tzinfo); Py_INCREF(options->tzinfo);
options->options_obj = options_obj;
Py_INCREF(options->options_obj);
options->is_raw_bson = (101 == type_marker);
return 1; return 1;
} }
@ -501,17 +555,25 @@ void destroy_codec_options(codec_options_t* options) {
Py_CLEAR(options->document_class); Py_CLEAR(options->document_class);
Py_CLEAR(options->tzinfo); Py_CLEAR(options->tzinfo);
Py_CLEAR(options->options_obj); Py_CLEAR(options->options_obj);
Py_CLEAR(options->type_registry.registry_obj);
Py_CLEAR(options->type_registry.encoder_map);
Py_CLEAR(options->type_registry.decoder_map);
Py_CLEAR(options->type_registry.fallback_encoder);
} }
static int write_element_to_buffer(PyObject* self, buffer_t buffer, static int write_element_to_buffer(PyObject* self, buffer_t buffer,
int type_byte, PyObject* value, int type_byte, PyObject* value,
unsigned char check_keys, unsigned char check_keys,
const codec_options_t* options) { const codec_options_t* options,
int result; unsigned char in_custom_call,
if(Py_EnterRecursiveCall(" while encoding an object to BSON ")) unsigned char in_fallback_call) {
int result = 0;
if(Py_EnterRecursiveCall(" while encoding an object to BSON ")) {
return 0; return 0;
}
result = _write_element_to_buffer(self, buffer, type_byte, result = _write_element_to_buffer(self, buffer, type_byte,
value, check_keys, options); value, check_keys, options,
in_custom_call, in_fallback_call);
Py_LeaveRecursiveCall(); Py_LeaveRecursiveCall();
return result; return result;
} }
@ -531,37 +593,53 @@ _fix_java(const char* in, char* out) {
static void static void
_set_cannot_encode(PyObject* value) { _set_cannot_encode(PyObject* value) {
PyObject* type = NULL;
PyObject* InvalidDocument = _error("InvalidDocument"); PyObject* InvalidDocument = _error("InvalidDocument");
if (InvalidDocument) { if (InvalidDocument == NULL) {
PyObject* repr = PyObject_Repr(value); goto error;
if (repr) {
#if PY_MAJOR_VERSION >= 3
PyObject* errmsg = PyUnicode_FromString("Cannot encode object: ");
#else
PyObject* errmsg = PyString_FromString("Cannot encode object: ");
#endif
if (errmsg) {
#if PY_MAJOR_VERSION >= 3
PyObject* error = PyUnicode_Concat(errmsg, repr);
if (error) {
PyErr_SetObject(InvalidDocument, error);
Py_DECREF(error);
}
Py_DECREF(errmsg);
Py_DECREF(repr);
#else
PyString_ConcatAndDel(&errmsg, repr);
if (errmsg) {
PyErr_SetObject(InvalidDocument, errmsg);
Py_DECREF(errmsg);
}
#endif
} else {
Py_DECREF(repr);
}
}
Py_DECREF(InvalidDocument);
} }
type = PyObject_Type(value);
if (type == NULL) {
goto error;
}
#if PY_MAJOR_VERSION >= 3
PyErr_Format(InvalidDocument, "cannot encode object: %R, of type: %R",
value, type);
#else
else {
PyObject* value_repr = NULL;
PyObject* type_repr = NULL;
char* value_str = NULL;
char* type_str = NULL;
value_repr = PyObject_Repr(value);
if (value_repr == NULL) {
goto py2error;
}
value_str = PyString_AsString(value_repr);
if (value_str == NULL) {
goto py2error;
}
type_repr = PyObject_Repr(type);
if (type_repr == NULL) {
goto py2error;
}
type_str = PyString_AsString(type_repr);
if (type_str == NULL) {
goto py2error;
}
PyErr_Format(InvalidDocument, "cannot encode object: %s, of type: %s",
value_str, type_str);
py2error:
Py_XDECREF(type_repr);
Py_XDECREF(value_repr);
}
#endif
error:
Py_XDECREF(type);
Py_XDECREF(InvalidDocument);
} }
/* /*
@ -697,9 +775,13 @@ static int _write_regex_to_buffer(
static int _write_element_to_buffer(PyObject* self, buffer_t buffer, static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
int type_byte, PyObject* value, int type_byte, PyObject* value,
unsigned char check_keys, unsigned char check_keys,
const codec_options_t* options) { const codec_options_t* options,
unsigned char in_custom_call,
unsigned char in_fallback_call) {
struct module_state *state = GETSTATE(self); struct module_state *state = GETSTATE(self);
PyObject* mapping_type; PyObject* mapping_type;
PyObject* new_value = NULL;
int retval;
PyObject* uuid_type; PyObject* uuid_type;
/* /*
* Don't use PyObject_IsInstance for our custom types. It causes * Don't use PyObject_IsInstance for our custom types. It causes
@ -1092,7 +1174,8 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
if (!(item_value = PySequence_GetItem(value, i))) if (!(item_value = PySequence_GetItem(value, i)))
return 0; return 0;
if (!write_element_to_buffer(self, buffer, list_type_byte, if (!write_element_to_buffer(self, buffer, list_type_byte,
item_value, check_keys, options)) { item_value, check_keys, options,
0, 0)) {
Py_DECREF(item_value); Py_DECREF(item_value);
return 0; return 0;
} }
@ -1290,6 +1373,47 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
} }
Py_XDECREF(mapping_type); Py_XDECREF(mapping_type);
Py_XDECREF(uuid_type); Py_XDECREF(uuid_type);
/* Try a custom encoder if one is provided and we have not already
* attempted to use a type encoder. */
if (!in_custom_call && !options->type_registry.is_encoder_empty) {
PyObject* value_type = NULL;
PyObject* converter = NULL;
value_type = PyObject_Type(value);
if (value_type == NULL) {
return 0;
}
converter = PyDict_GetItem(options->type_registry.encoder_map, value_type);
Py_XDECREF(value_type);
if (converter != NULL) {
/* Transform types that have a registered converter.
* A new reference is created upon transformation. */
new_value = PyObject_CallFunctionObjArgs(converter, value, NULL);
if (new_value == NULL) {
return 0;
}
retval = write_element_to_buffer(self, buffer, type_byte, new_value,
check_keys, options, 1, 0);
Py_XDECREF(new_value);
return retval;
}
}
/* Try the fallback encoder if one is provided and we have not already
* attempted to use the fallback encoder. */
if (!in_fallback_call && options->type_registry.has_fallback_encoder) {
new_value = PyObject_CallFunctionObjArgs(
options->type_registry.fallback_encoder, value, NULL);
if (new_value == NULL) {
// propagate any exception raised by the callback
return 0;
}
retval = write_element_to_buffer(self, buffer, type_byte, new_value,
check_keys, options, 0, 1);
Py_XDECREF(new_value);
return retval;
}
/* We can't determine value's type. Fail. */ /* We can't determine value's type. Fail. */
_set_cannot_encode(value); _set_cannot_encode(value);
return 0; return 0;
@ -1363,7 +1487,7 @@ int write_pair(PyObject* self, buffer_t buffer, const char* name, int name_lengt
return 0; return 0;
} }
if (!write_element_to_buffer(self, buffer, type_byte, if (!write_element_to_buffer(self, buffer, type_byte,
value, check_keys, options)) { value, check_keys, options, 0, 0)) {
return 0; return 0;
} }
return 1; return 1;
@ -2483,6 +2607,24 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
} }
if (value) { if (value) {
if (!options->type_registry.is_decoder_empty) {
PyObject* value_type = NULL;
PyObject* converter = NULL;
value_type = PyObject_Type(value);
if (value_type == NULL) {
goto invalid;
}
converter = PyDict_GetItem(options->type_registry.decoder_map, value_type);
if (converter != NULL) {
PyObject* new_value = PyObject_CallFunctionObjArgs(converter, value, NULL);
Py_DECREF(value_type);
Py_DECREF(value);
return new_value;
} else {
Py_DECREF(value_type);
return value;
}
}
return value; return value;
} }

View File

@ -51,12 +51,23 @@
#define BYTES_FORMAT_STRING "s#" #define BYTES_FORMAT_STRING "s#"
#endif #endif
typedef struct type_registry_t {
PyObject* encoder_map;
PyObject* decoder_map;
PyObject* fallback_encoder;
PyObject* registry_obj;
unsigned char is_encoder_empty;
unsigned char is_decoder_empty;
unsigned char has_fallback_encoder;
} type_registry_t;
typedef struct codec_options_t { typedef struct codec_options_t {
PyObject* document_class; PyObject* document_class;
unsigned char tz_aware; unsigned char tz_aware;
unsigned char uuid_rep; unsigned char uuid_rep;
char* unicode_decode_error_handler; char* unicode_decode_error_handler;
PyObject* tzinfo; PyObject* tzinfo;
type_registry_t type_registry;
PyObject* options_obj; PyObject* options_obj;
unsigned char is_raw_bson; unsigned char is_raw_bson;
} codec_options_t; } codec_options_t;

View File

@ -49,8 +49,7 @@ UUID_SUBTYPE = 4
"""BSON binary subtype for a UUID. """BSON binary subtype for a UUID.
This is the new BSON binary subtype for UUIDs. The This is the new BSON binary subtype for UUIDs. The
current default is :data:`OLD_UUID_SUBTYPE` but will current default is :data:`OLD_UUID_SUBTYPE`.
change to this in a future release.
.. versionchanged:: 2.1 .. versionchanged:: 2.1
Changed to subtype 4. Changed to subtype 4.
@ -125,8 +124,8 @@ class Binary(bytes):
the difference between what should be considered binary data and the difference between what should be considered binary data and
what should be considered a string when we encode to BSON. what should be considered a string when we encode to BSON.
Raises TypeError if `data` is not an instance of :class:`str` Raises TypeError if `data` is not an instance of :class:`bytes`
(:class:`bytes` in python 3) or `subtype` is not an instance of (:class:`str` in python 2) or `subtype` is not an instance of
:class:`int`. Raises ValueError if `subtype` is not in [0, 256). :class:`int`. Raises ValueError if `subtype` is not in [0, 256).
.. note:: .. note::

View File

@ -16,13 +16,16 @@
import datetime import datetime
from abc import abstractmethod
from collections import namedtuple from collections import namedtuple
from bson.py3compat import abc, string_type from bson.py3compat import ABC, abc, abstractproperty, string_type
from bson.binary import (ALL_UUID_REPRESENTATIONS, from bson.binary import (ALL_UUID_REPRESENTATIONS,
PYTHON_LEGACY, PYTHON_LEGACY,
UUID_REPRESENTATION_NAMES) UUID_REPRESENTATION_NAMES)
_RAW_BSON_DOCUMENT_MARKER = 101 _RAW_BSON_DOCUMENT_MARKER = 101
@ -32,10 +35,139 @@ def _raw_document_class(document_class):
return marker == _RAW_BSON_DOCUMENT_MARKER return marker == _RAW_BSON_DOCUMENT_MARKER
class TypeEncoder(ABC):
"""Base class for defining type codec classes which describe how a
custom type can be transformed to one of the types BSON understands.
Codec classes must implement the ``python_type`` attribute, and the
``transform_python`` method to support encoding.
See :ref:`custom-type-type-codec` documentation for an example.
"""
@abstractproperty
def python_type(self):
"""The Python type to be converted into something serializable."""
pass
@abstractmethod
def transform_python(self, value):
"""Convert the given Python object into something serializable."""
pass
class TypeDecoder(ABC):
"""Base class for defining type codec classes which describe how a
BSON type can be transformed to a custom type.
Codec classes must implement the ``bson_type`` attribute, and the
``transform_bson`` method to support decoding.
See :ref:`custom-type-type-codec` documentation for an example.
"""
@abstractproperty
def bson_type(self):
"""The BSON type to be converted into our own type."""
pass
@abstractmethod
def transform_bson(self, value):
"""Convert the given BSON value into our own type."""
pass
class TypeCodec(TypeEncoder, TypeDecoder):
"""Base class for defining type codec classes which describe how a
custom type can be transformed to/from one of the types :mod:`bson`
can already encode/decode.
Codec classes must implement the ``python_type`` attribute, and the
``transform_python`` method to support encoding, as well as the
``bson_type`` attribute, and the ``transform_bson`` method to support
decoding.
See :ref:`custom-type-type-codec` documentation for an example.
"""
pass
class TypeRegistry(object):
"""Encapsulates type codecs used in encoding and / or decoding BSON, as
well as the fallback encoder. Type registries cannot be modified after
instantiation.
``TypeRegistry`` can be initialized with an iterable of type codecs, and
a callable for the fallback encoder::
>>> from bson.codec_options import TypeRegistry
>>> type_registry = TypeRegistry([Codec1, Codec2, Codec3, ...],
... fallback_encoder)
See :ref:`custom-type-type-registry` documentation for an example.
:Parameters:
- `type_codecs` (optional): iterable of type codec instances. If
``type_codecs`` contains multiple codecs that transform a single
python or BSON type, the transformation specified by the type codec
occurring last prevails. A TypeError will be raised if one or more
type codecs modify the encoding behavior of a built-in :mod:`bson`
type.
- `fallback_encoder` (optional): callable that accepts a single,
unencodable python value and transforms it into a type that
:mod:`bson` can encode. See :ref:`fallback-encoder-callable`
documentation for an example.
"""
def __init__(self, type_codecs=None, fallback_encoder=None):
self.__type_codecs = list(type_codecs or [])
self._fallback_encoder = fallback_encoder
self._encoder_map = {}
self._decoder_map = {}
if self._fallback_encoder is not None:
if not callable(fallback_encoder):
raise TypeError("fallback_encoder %r is not a callable" % (
fallback_encoder))
for codec in self.__type_codecs:
is_valid_codec = False
if isinstance(codec, TypeEncoder):
self._validate_type_encoder(codec)
is_valid_codec = True
self._encoder_map[codec.python_type] = codec.transform_python
if isinstance(codec, TypeDecoder):
is_valid_codec = True
self._decoder_map[codec.bson_type] = codec.transform_bson
if not is_valid_codec:
raise TypeError(
"Expected an instance of %s, %s, or %s, got %r instead" % (
TypeEncoder.__name__, TypeDecoder.__name__,
TypeCodec.__name__, codec))
def _validate_type_encoder(self, codec):
from bson import _BUILT_IN_TYPES
for pytype in _BUILT_IN_TYPES:
if issubclass(codec.python_type, pytype):
err_msg = ("TypeEncoders cannot change how built-in types are "
"encoded (encoder %s transforms type %s)" %
(codec, pytype))
raise TypeError(err_msg)
def __repr__(self):
return ('%s(type_codecs=%r, fallback_encoder=%r)' % (
self.__class__.__name__, self.__type_codecs,
self._fallback_encoder))
def __eq__(self, other):
if not isinstance(other, type(self)):
return NotImplemented
return ((self._decoder_map == other._decoder_map) and
(self._encoder_map == other._encoder_map) and
(self._fallback_encoder == other._fallback_encoder))
_options_base = namedtuple( _options_base = namedtuple(
'CodecOptions', 'CodecOptions',
('document_class', 'tz_aware', 'uuid_representation', ('document_class', 'tz_aware', 'uuid_representation',
'unicode_decode_error_handler', 'tzinfo')) 'unicode_decode_error_handler', 'tzinfo', 'type_registry'))
class CodecOptions(_options_base): class CodecOptions(_options_base):
@ -93,6 +225,11 @@ class CodecOptions(_options_base):
- `tzinfo`: A :class:`~datetime.tzinfo` subclass that specifies the - `tzinfo`: A :class:`~datetime.tzinfo` subclass that specifies the
timezone to/from which :class:`~datetime.datetime` objects should be timezone to/from which :class:`~datetime.datetime` objects should be
encoded/decoded. encoded/decoded.
- `type_registry`: Instance of :class:`TypeRegistry` used to customize
encoding and decoding behavior.
.. versionadded:: 3.8
`type_registry` attribute.
.. warning:: Care must be taken when changing .. warning:: Care must be taken when changing
`unicode_decode_error_handler` from its default value ('strict'). `unicode_decode_error_handler` from its default value ('strict').
@ -104,7 +241,7 @@ class CodecOptions(_options_base):
def __new__(cls, document_class=dict, def __new__(cls, document_class=dict,
tz_aware=False, uuid_representation=PYTHON_LEGACY, tz_aware=False, uuid_representation=PYTHON_LEGACY,
unicode_decode_error_handler="strict", unicode_decode_error_handler="strict",
tzinfo=None): tzinfo=None, type_registry=None):
if not (issubclass(document_class, abc.MutableMapping) or if not (issubclass(document_class, abc.MutableMapping) or
_raw_document_class(document_class)): _raw_document_class(document_class)):
raise TypeError("document_class must be dict, bson.son.SON, " raise TypeError("document_class must be dict, bson.son.SON, "
@ -126,9 +263,14 @@ class CodecOptions(_options_base):
raise ValueError( raise ValueError(
"cannot specify tzinfo without also setting tz_aware=True") "cannot specify tzinfo without also setting tz_aware=True")
type_registry = type_registry or TypeRegistry()
if not isinstance(type_registry, TypeRegistry):
raise TypeError("type_registry must be an instance of TypeRegistry")
return tuple.__new__( return tuple.__new__(
cls, (document_class, tz_aware, uuid_representation, cls, (document_class, tz_aware, uuid_representation,
unicode_decode_error_handler, tzinfo)) unicode_decode_error_handler, tzinfo, type_registry))
def _arguments_repr(self): def _arguments_repr(self):
"""Representation of the arguments used to create this object.""" """Representation of the arguments used to create this object."""
@ -139,10 +281,12 @@ class CodecOptions(_options_base):
uuid_rep_repr = UUID_REPRESENTATION_NAMES.get(self.uuid_representation, uuid_rep_repr = UUID_REPRESENTATION_NAMES.get(self.uuid_representation,
self.uuid_representation) self.uuid_representation)
return ('document_class=%s, tz_aware=%r, uuid_representation=' return ('document_class=%s, tz_aware=%r, uuid_representation=%s, '
'%s, unicode_decode_error_handler=%r, tzinfo=%r' % 'unicode_decode_error_handler=%r, tzinfo=%r, '
'type_registry=%r' %
(document_class_repr, self.tz_aware, uuid_rep_repr, (document_class_repr, self.tz_aware, uuid_rep_repr,
self.unicode_decode_error_handler, self.tzinfo)) self.unicode_decode_error_handler, self.tzinfo,
self.type_registry))
def __repr__(self): def __repr__(self):
return '%s(%s)' % (self.__class__.__name__, self._arguments_repr()) return '%s(%s)' % (self.__class__.__name__, self._arguments_repr())
@ -165,7 +309,9 @@ class CodecOptions(_options_base):
kwargs.get('uuid_representation', self.uuid_representation), kwargs.get('uuid_representation', self.uuid_representation),
kwargs.get('unicode_decode_error_handler', kwargs.get('unicode_decode_error_handler',
self.unicode_decode_error_handler), self.unicode_decode_error_handler),
kwargs.get('tzinfo', self.tzinfo)) kwargs.get('tzinfo', self.tzinfo),
kwargs.get('type_registry', self.type_registry)
)
DEFAULT_CODEC_OPTIONS = CodecOptions() DEFAULT_CODEC_OPTIONS = CodecOptions()
@ -183,4 +329,6 @@ def _parse_codec_options(options):
unicode_decode_error_handler=options.get( unicode_decode_error_handler=options.get(
'unicode_decode_error_handler', 'unicode_decode_error_handler',
DEFAULT_CODEC_OPTIONS.unicode_decode_error_handler), DEFAULT_CODEC_OPTIONS.unicode_decode_error_handler),
tzinfo=options.get('tzinfo', DEFAULT_CODEC_OPTIONS.tzinfo)) tzinfo=options.get('tzinfo', DEFAULT_CODEC_OPTIONS.tzinfo),
type_registry=options.get(
'type_registry', DEFAULT_CODEC_OPTIONS.type_registry))

View File

@ -22,8 +22,12 @@ if PY3:
import codecs import codecs
import collections.abc as abc import collections.abc as abc
import _thread as thread import _thread as thread
from abc import ABC, abstractmethod
from io import BytesIO as StringIO from io import BytesIO as StringIO
def abstractproperty(func):
return property(abstractmethod(func))
MAXSIZE = sys.maxsize MAXSIZE = sys.maxsize
imap = map imap = map
@ -60,6 +64,7 @@ if PY3:
else: else:
import collections as abc import collections as abc
import thread import thread
from abc import ABCMeta, abstractproperty
from itertools import imap from itertools import imap
try: try:
@ -67,6 +72,8 @@ else:
except ImportError: except ImportError:
from StringIO import StringIO from StringIO import StringIO
ABC = ABCMeta('ABC', (object,), {})
MAXSIZE = sys.maxint MAXSIZE = sys.maxint
def b(s): def b(s):

View File

@ -15,11 +15,11 @@
"""Tools for representing raw BSON documents. """Tools for representing raw BSON documents.
""" """
from bson import _UNPACK_INT, _iterate_elements from bson import _elements_to_dict, _get_object_size
from bson.py3compat import abc, iteritems from bson.py3compat import abc, iteritems
from bson.codec_options import ( from bson.codec_options import (
DEFAULT_CODEC_OPTIONS as DEFAULT, _RAW_BSON_DOCUMENT_MARKER) DEFAULT_CODEC_OPTIONS as DEFAULT, _RAW_BSON_DOCUMENT_MARKER)
from bson.errors import InvalidBSON from bson.son import SON
class RawBSONDocument(abc.Mapping): class RawBSONDocument(abc.Mapping):
@ -34,12 +34,33 @@ class RawBSONDocument(abc.Mapping):
_type_marker = _RAW_BSON_DOCUMENT_MARKER _type_marker = _RAW_BSON_DOCUMENT_MARKER
def __init__(self, bson_bytes, codec_options=None): def __init__(self, bson_bytes, codec_options=None):
"""Create a new :class:`RawBSONDocument`. """Create a new :class:`RawBSONDocument`
:class:`RawBSONDocument` is a representation of a BSON document that
provides access to the underlying raw BSON bytes. Only when a field is
accessed or modified within the document does RawBSONDocument decode
its bytes.
:class:`RawBSONDocument` implements the ``Mapping`` abstract base
class from the standard library so it can be used like a read-only
``dict``::
>>> raw_doc = RawBSONDocument(BSON.encode({'_id': 'my_doc'}))
>>> raw_doc.raw
b'...'
>>> raw_doc['_id']
'my_doc'
:Parameters: :Parameters:
- `bson_bytes`: the BSON bytes that compose this document - `bson_bytes`: the BSON bytes that compose this document
- `codec_options` (optional): An instance of - `codec_options` (optional): An instance of
:class:`~bson.codec_options.CodecOptions`. :class:`~bson.codec_options.CodecOptions` whose ``document_class``
must be :class:`RawBSONDocument`. The default is
:attr:`DEFAULT_RAW_BSON_OPTIONS`.
.. versionchanged:: 3.8
:class:`RawBSONDocument` now validates that the ``bson_bytes``
passed in represent a single bson document.
.. versionchanged:: 3.5 .. versionchanged:: 3.5
If a :class:`~bson.codec_options.CodecOptions` is passed in, its If a :class:`~bson.codec_options.CodecOptions` is passed in, its
@ -56,6 +77,8 @@ class RawBSONDocument(abc.Mapping):
"RawBSONDocument cannot use CodecOptions with document " "RawBSONDocument cannot use CodecOptions with document "
"class %s" % (codec_options.document_class, )) "class %s" % (codec_options.document_class, ))
self.__codec_options = codec_options self.__codec_options = codec_options
# Validate the bson object size.
_get_object_size(bson_bytes, 0, len(bson_bytes))
@property @property
def raw(self): def raw(self):
@ -70,16 +93,10 @@ class RawBSONDocument(abc.Mapping):
def __inflated(self): def __inflated(self):
if self.__inflated_doc is None: if self.__inflated_doc is None:
# We already validated the object's size when this document was # We already validated the object's size when this document was
# created, so no need to do that again. We still need to check the # created, so no need to do that again.
# size of all the elements and compare to the document size. # Use SON to preserve ordering of elements.
object_size = _UNPACK_INT(self.__raw[:4])[0] - 1 self.__inflated_doc = _elements_to_dict(
position = 0 self.__raw, 4, len(self.__raw)-1, self.__codec_options, SON())
self.__inflated_doc = {}
for key, value, position in _iterate_elements(
self.__raw, 4, object_size, self.__codec_options):
self.__inflated_doc[key] = value
if position != object_size:
raise InvalidBSON('bad object or element length')
return self.__inflated_doc return self.__inflated_doc
def __getitem__(self, item): def __getitem__(self, item):
@ -102,3 +119,6 @@ class RawBSONDocument(abc.Mapping):
DEFAULT_RAW_BSON_OPTIONS = DEFAULT.with_options(document_class=RawBSONDocument) DEFAULT_RAW_BSON_OPTIONS = DEFAULT.with_options(document_class=RawBSONDocument)
"""The default :class:`~bson.codec_options.CodecOptions` for
:class:`RawBSONDocument`.
"""

View File

@ -1,5 +1,5 @@
:mod:`change_stream` -- Watch changes on a collection :mod:`change_stream` -- Watch changes on a collection, database, or cluster
===================================================== ===========================================================================
.. automodule:: pymongo.change_stream .. automodule:: pymongo.change_stream
:members: :members:

View File

@ -40,6 +40,7 @@
.. automethod:: list_database_names .. automethod:: list_database_names
.. automethod:: database_names .. automethod:: database_names
.. automethod:: drop_database .. automethod:: drop_database
.. automethod:: get_default_database
.. automethod:: get_database .. automethod:: get_database
.. automethod:: server_info .. automethod:: server_info
.. automethod:: close_cursor .. automethod:: close_cursor
@ -48,4 +49,3 @@
.. automethod:: watch .. automethod:: watch
.. automethod:: fsync .. automethod:: fsync
.. automethod:: unlock .. automethod:: unlock
.. automethod:: get_default_database

View File

@ -15,6 +15,55 @@ Changes in Version 3.8.0
- :class:`~bson.objectid.ObjectId` now implements the `ObjectID specification - :class:`~bson.objectid.ObjectId` now implements the `ObjectID specification
version 0.2 <https://github.com/mongodb/specifications/blob/master/source/objectid.rst>`_. version 0.2 <https://github.com/mongodb/specifications/blob/master/source/objectid.rst>`_.
- For better performance and to better follow the GridFS spec,
:class:`~gridfs.grid_file.GridOut` now uses a single cursor to read all the
chunks in the file. Previously, each chunk in the file was queried
individually using :meth:`~pymongo.collection.Collection.find_one`.
- :meth:`gridfs.grid_file.GridOut.read` now only checks for extra chunks after
reading the entire file. Previously, this method would check for extra
chunks on every call.
- :meth:`~pymongo.database.Database.current_op` now always uses the
``Database``'s :attr:`~pymongo.database.Database.codec_options`
when decoding the command response. Previously the codec_options
was only used when the MongoDB server version was <= 3.0.
- Undeprecated :meth:`~pymongo.mongo_client.MongoClient.get_default_database`
and added the ``default`` parameter.
- TLS Renegotiation is now disabled when possible.
- Custom types can now be directly encoded to, and decoded from MongoDB using
the :class:`~bson.codec_options.TypeCodec` and
:class:`~bson.codec_options.TypeRegistry` APIs. For more information, see
the :doc:`custom type example <examples/custom_type>`.
- Attempting a multi-document transaction on a sharded cluster now raises a
:exc:`~pymongo.errors.ConfigurationError`.
- :meth:`pymongo.cursor.Cursor.distinct` and
:meth:`pymongo.cursor.Cursor.count` now send the Cursor's
:meth:`~pymongo.cursor.Cursor.comment` as the "comment" top-level
command option instead of "$comment". Also, note that "comment" must be a
string.
- Add the ``filter`` parameter to
:meth:`~pymongo.database.Database.list_collection_names`.
- Changes can now be requested from a ``ChangeStream`` cursor without blocking
indefinitely using the new
:meth:`pymongo.change_stream.ChangeStream.try_next` method.
- Fixed a reference leak bug when splitting a batched write command based on
maxWriteBatchSize or the max message size.
- Deprecated running find queries that set :meth:`~pymongo.cursor.Cursor.min`
and/or :meth:`~pymongo.cursor.Cursor.max` but do not also set a
:meth:`~pymongo.cursor.Cursor.hint` of which index to use. The find command
is expected to require a :meth:`~pymongo.cursor.Cursor.hint` when using
min/max starting in MongoDB 4.2.
- Documented support for the uuidRepresentation URI option, which has been
supported since PyMongo 2.7. Valid values are `pythonLegacy` (the default),
`javaLegacy`, `csharpLegacy` and `standard`. New applications should consider
setting this to `standard` for cross language compatibility.
- :class:`~bson.raw_bson.RawBSONDocument` now validates that the ``bson_bytes``
passed in represent a single bson document. Earlier versions would mistakenly
accept multiple bson documents.
- Iterating over a :class:`~bson.raw_bson.RawBSONDocument` now maintains the
same field order of the underlying raw BSON document.
- Applications can now register a custom server selector. For more information
see :doc:`custom type example <examples/server_selection>`.
- The connection pool now implements a LIFO policy.
Issues Resolved Issues Resolved
............... ...............
@ -24,6 +73,30 @@ in this release.
.. _PyMongo 3.8 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=19904 .. _PyMongo 3.8 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=19904
Changes in Version 3.7.2
------------------------
Version 3.7.2 fixes a few issues discovered since the release of 3.7.1.
- Fixed a bug in retryable writes where a previous command's "txnNumber"
field could be sent leading to incorrect results.
- Fixed a memory leak of a few bytes on some insert, update, or delete
commands when running against MongoDB 3.6+.
- Fixed a bug that caused :meth:`pymongo.collection.Collection.ensure_index`
to only cache a single index per database.
- Updated the documentation examples to use
:meth:`pymongo.collection.Collection.count_documents` instead of
:meth:`pymongo.collection.Collection.count` and
:meth:`pymongo.cursor.Cursor.count`.
Issues Resolved
...............
See the `PyMongo 3.7.2 release notes in JIRA`_ for the list of resolved issues
in this release.
.. _PyMongo 3.7.2 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=21519
Changes in Version 3.7.1 Changes in Version 3.7.1
------------------------ ------------------------

View File

@ -166,6 +166,8 @@ latex_documents = [
# If false, no module index is generated. # If false, no module index is generated.
#latex_use_modindex = True #latex_use_modindex = True
intersphinx_mapping = { intersphinx_mapping = {
'gevent': ('http://www.gevent.org/', None), 'gevent': ('http://www.gevent.org/', None),
'py': ('https://docs.python.org/3/', None),
} }

View File

@ -85,3 +85,4 @@ The following is a list of people who have contributed to
- Jagrut Trivedi(Jagrut) - Jagrut Trivedi(Jagrut)
- Shrey Batra(shreybatra) - Shrey Batra(shreybatra)
- Felipe Rodrigues(fbidu) - Felipe Rodrigues(fbidu)
- Terence Honles (terencehonles)

View File

@ -109,7 +109,7 @@ the failure.
'writeConcernErrors': [], 'writeConcernErrors': [],
'writeErrors': [{u'code': 11000, 'writeErrors': [{u'code': 11000,
u'errmsg': u'...E11000...duplicate key error...', u'errmsg': u'...E11000...duplicate key error...',
u'index': 1, u'index': 1,...
u'op': {'_id': 4}}]} u'op': {'_id': 4}}]}
.. _unordered_bulk: .. _unordered_bulk:
@ -147,11 +147,11 @@ and fourth operations succeed.
'writeConcernErrors': [], 'writeConcernErrors': [],
'writeErrors': [{u'code': 11000, 'writeErrors': [{u'code': 11000,
u'errmsg': u'...E11000...duplicate key error...', u'errmsg': u'...E11000...duplicate key error...',
u'index': 0, u'index': 0,...
u'op': {'_id': 1}}, u'op': {'_id': 1}},
{u'code': 11000, {u'code': 11000,
u'errmsg': u'...E11000...duplicate key error...', u'errmsg': u'...E11000...duplicate key error...',
u'index': 2, u'index': 2,...
u'op': {'_id': 3}}]} u'op': {'_id': 3}}]}
Write Concern Write Concern

View File

@ -0,0 +1,421 @@
Custom Type Example
===================
This is an example of using a custom type with PyMongo. The example here shows
how to subclass :class:`~bson.codec_options.TypeCodec` to write a type
codec, which is used to populate a :class:`~bson.codec_options.TypeRegistry`.
The type registry can then be used to create a custom-type-aware
:class:`~pymongo.collection.Collection`. Read and write operations
issued against the resulting collection object transparently manipulate
documents as they are saved to or retrieved from MongoDB.
Setting Up
----------
We'll start by getting a clean database to use for the example:
.. doctest::
>>> from pymongo import MongoClient
>>> client = MongoClient()
>>> client.drop_database('custom_type_example')
>>> db = client.custom_type_example
Since the purpose of the example is to demonstrate working with custom types,
we'll need a custom data type to use. For this example, we will be working with
the :py:class:`~decimal.Decimal` type from Python's standard library. Since the
BSON library's :class:`~bson.decimal128.Decimal128` type (that implements
the IEEE 754 decimal128 decimal-based floating-point numbering format) is
distinct from Python's built-in :py:class:`~decimal.Decimal` type, attempting
to save an instance of ``Decimal`` with PyMongo, results in an
:exc:`~bson.errors.InvalidDocument` exception.
.. doctest::
>>> from decimal import Decimal
>>> num = Decimal("45.321")
>>> db.test.insert_one({'num': num})
Traceback (most recent call last):
...
bson.errors.InvalidDocument: cannot encode object: Decimal('45.321'), of type: <class 'decimal.Decimal'>
.. _custom-type-type-codec:
The :class:`~bson.codec_options.TypeCodec` Class
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. versionadded:: 3.8
In order to encode a custom type, we must first define a **type codec** for
that type. A type codec describes how an instance of a custom type can be
*transformed* to and/or from one of the types :mod:`~bson` already understands.
Depending on the desired functionality, users must choose from the following
base classes when defining type codecs:
* :class:`~bson.codec_options.TypeEncoder`: subclass this to define a codec that
encodes a custom Python type to a known BSON type. Users must implement the
``python_type`` property/attribute and the ``transform_python`` method.
* :class:`~bson.codec_options.TypeDecoder`: subclass this to define a codec that
decodes a specified BSON type into a custom Python type. Users must implement
the ``bson_type`` property/attribute and the ``transform_bson`` method.
* :class:`~bson.codec_options.TypeCodec`: subclass this to define a codec that
can both encode and decode a custom type. Users must implement the
``python_type`` and ``bson_type`` properties/attributes, as well as the
``transform_python`` and ``transform_bson`` methods.
The type codec for our custom type simply needs to define how a
:py:class:`~decimal.Decimal` instance can be converted into a
:class:`~bson.decimal128.Decimal128` instance and vice-versa. Since we are
interested in both encoding and decoding our custom type, we use the
``TypeCodec`` base class to define our codec:
.. doctest::
>>> from bson.decimal128 import Decimal128
>>> from bson.codec_options import TypeCodec
>>> class DecimalCodec(TypeCodec):
... python_type = Decimal # the Python type acted upon by this type codec
... bson_type = Decimal128 # the BSON type acted upon by this type codec
... def transform_python(self, value):
... """Function that transforms a custom type value into a type
... that BSON can encode."""
... return Decimal128(value)
... def transform_bson(self, value):
... """Function that transforms a vanilla BSON type value into our
... custom type."""
... return value.to_decimal()
>>> decimal_codec = DecimalCodec()
.. _custom-type-type-registry:
The :class:`~bson.codec_options.TypeRegistry` Class
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. versionadded:: 3.8
Before we can begin encoding and decoding our custom type objects, we must
first inform PyMongo about the corresponding codec. This is done by creating
a :class:`~bson.codec_options.TypeRegistry` instance:
.. doctest::
>>> from bson.codec_options import TypeRegistry
>>> type_registry = TypeRegistry([decimal_codec])
Note that type registries can be instantiated with any number of type codecs.
Once instantiated, registries are immutable and the only way to add codecs
to a registry is to create a new one.
Putting It Together
-------------------
Finally, we can define a :class:`~bson.codec_options.CodecOptions` instance
with our ``type_registry`` and use it to get a
:class:`~pymongo.collection.Collection` object that understands the
:py:class:`~decimal.Decimal` data type:
.. doctest::
>>> from bson.codec_options import CodecOptions
>>> codec_options = CodecOptions(type_registry=type_registry)
>>> collection = db.get_collection('test', codec_options=codec_options)
Now, we can seamlessly encode and decode instances of
:py:class:`~decimal.Decimal`:
.. doctest::
>>> collection.insert_one({'num': Decimal("45.321")})
<pymongo.results.InsertOneResult object at ...>
>>> mydoc = collection.find_one()
>>> import pprint
>>> pprint.pprint(mydoc)
{u'_id': ObjectId('...'), u'num': Decimal('45.321')}
We can see what's actually being saved to the database by creating a fresh
collection object without the customized codec options and using that to query
MongoDB:
.. doctest::
>>> vanilla_collection = db.get_collection('test')
>>> pprint.pprint(vanilla_collection.find_one())
{u'_id': ObjectId('...'), u'num': Decimal128('45.321')}
Encoding Subtypes
^^^^^^^^^^^^^^^^^
Consider the situation where, in addition to encoding
:py:class:`~decimal.Decimal`, we also need to encode a type that subclasses
``Decimal``. PyMongo does this automatically for types that inherit from
Python types that are BSON-encodable by default, but the type codec system
described above does not offer the same flexibility.
Consider this subtype of ``Decimal`` that has a method to return its value as
an integer:
.. doctest::
>>> class DecimalInt(Decimal):
... def my_method(self):
... """Method implementing some custom logic."""
... return int(self)
If we try to save an instance of this type without first registering a type
codec for it, we get an error:
.. doctest::
>>> collection.insert_one({'num': DecimalInt("45.321")})
Traceback (most recent call last):
...
bson.errors.InvalidDocument: cannot encode object: Decimal('45.321'), of type: <class 'decimal.Decimal'>
In order to proceed further, we must define a type codec for ``DecimalInt``.
This is trivial to do since the same transformation as the one used for
``Decimal`` is adequate for encoding ``DecimalInt`` as well:
.. doctest::
>>> class DecimalIntCodec(DecimalCodec):
... @property
... def python_type(self):
... """The Python type acted upon by this type codec."""
... return DecimalInt
>>> decimalint_codec = DecimalIntCodec()
.. note::
No attempt is made to modify decoding behavior because without additional
information, it is impossible to discern which incoming
:class:`~bson.decimal128.Decimal128` value needs to be decoded as ``Decimal``
and which needs to be decoded as ``DecimalInt``. This example only considers
the situation where a user wants to *encode* documents containing either
of these types.
After creating a new codec options object and using it to get a collection
object, we can seamlessly encode instances of ``DecimalInt``:
.. doctest::
>>> type_registry = TypeRegistry([decimal_codec, decimalint_codec])
>>> codec_options = CodecOptions(type_registry=type_registry)
>>> collection = db.get_collection('test', codec_options=codec_options)
>>> collection.drop()
>>> collection.insert_one({'num': DecimalInt("45.321")})
<pymongo.results.InsertOneResult object at ...>
>>> mydoc = collection.find_one()
>>> pprint.pprint(mydoc)
{u'_id': ObjectId('...'), u'num': Decimal('45.321')}
Note that the ``transform_bson`` method of the base codec class results in
these values being decoded as ``Decimal`` (and not ``DecimalInt``).
.. _decoding-binary-types:
Decoding :class:`~bson.binary.Binary` Types
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The decoding treatment of :class:`~bson.binary.Binary` types having
``subtype = 0`` by the :mod:`bson` module varies slightly depending on the
version of the Python runtime in use. This must be taken into account while
writing a ``TypeDecoder`` that modifies how this datatype is decoded.
On Python 3.x, :class:`~bson.binary.Binary` data (``subtype = 0``) is decoded
as a ``bytes`` instance:
.. code-block:: python
>>> # On Python 3.x.
>>> from bson.binary import Binary
>>> newcoll = db.get_collection('new')
>>> newcoll.insert_one({'_id': 1, 'data': Binary(b"123", subtype=0)})
>>> doc = newcoll.find_one()
>>> type(doc['data'])
bytes
On Python 2.7.x, the same data is decoded as a :class:`~bson.binary.Binary`
instance:
.. code-block:: python
>>> # On Python 2.7.x
>>> newcoll = db.get_collection('new')
>>> doc = newcoll.find_one()
>>> type(doc['data'])
bson.binary.Binary
As a consequence of this disparity, users must set the ``bson_type`` attribute
on their :class:`~bson.codec_options.TypeDecoder` classes differently,
depending on the python version in use.
.. note::
For codebases requiring compatibility with both Python 2 and 3, type
decoders will have to be registered for both possible ``bson_type`` values.
.. _fallback-encoder-callable:
The ``fallback_encoder`` Callable
---------------------------------
.. versionadded:: 3.8
In addition to type codecs, users can also register a callable to encode types
that BSON doesn't recognize and for which no type codec has been registered.
This callable is the **fallback encoder** and like the ``transform_python``
method, it accepts an unencodable value as a parameter and returns a
BSON-encodable value. The following fallback encoder encodes python's
:py:class:`~decimal.Decimal` type to a :class:`~bson.decimal128.Decimal128`:
.. doctest::
>>> def fallback_encoder(value):
... if isinstance(value, Decimal):
... return Decimal128(value)
... return value
After declaring the callback, we must create a type registry and codec options
with this fallback encoder before it can be used for initializing a collection:
.. doctest::
>>> type_registry = TypeRegistry(fallback_encoder=fallback_encoder)
>>> codec_options = CodecOptions(type_registry=type_registry)
>>> collection = db.get_collection('test', codec_options=codec_options)
>>> collection.drop()
We can now seamlessly encode instances of :py:class:`~decimal.Decimal`:
.. doctest::
>>> collection.insert_one({'num': Decimal("45.321")})
<pymongo.results.InsertOneResult object at ...>
>>> mydoc = collection.find_one()
>>> pprint.pprint(mydoc)
{u'_id': ObjectId('...'), u'num': Decimal128('45.321')}
.. note::
Fallback encoders are invoked *after* attempts to encode the given value
with standard BSON encoders and any configured type encoders have failed.
Therefore, in a type registry configured with a type encoder and fallback
encoder that both target the same custom type, the behavior specified in
the type encoder will prevail.
Because fallback encoders don't need to declare the types that they encode
beforehand, they can be used to support interesting use-cases that cannot be
serviced by ``TypeEncoder``. One such use-case is described in the next
section.
Encoding Unknown Types
^^^^^^^^^^^^^^^^^^^^^^
In this example, we demonstrate how a fallback encoder can be used to save
arbitrary objects to the database. We will use the the standard library's
:py:mod:`pickle` module to serialize the unknown types and so naturally, this
approach only works for types that are picklable.
We start by defining some arbitrary custom types:
.. code-block:: python
class MyStringType(object):
def __init__(self, value):
self.__value = value
def __repr__(self):
return "MyStringType('%s')" % (self.__value,)
class MyNumberType(object):
def __init__(self, value):
self.__value = value
def __repr__(self):
return "MyNumberType(%s)" % (self.__value,)
We also define a fallback encoder that pickles whatever objects it receives
and returns them as :class:`~bson.binary.Binary` instances with a custom
subtype. The custom subtype, in turn, allows us to write a TypeDecoder that
identifies pickled artifacts upon retrieval and transparently decodes them
back into Python objects:
.. code-block:: python
import pickle
from bson.binary import Binary, USER_DEFINED_SUBTYPE
def fallback_pickle_encoder(value):
return Binary(pickle.dumps(value), USER_DEFINED_SUBTYPE)
class PickledBinaryDecoder(TypeDecoder):
bson_type = Binary
def transform_bson(self, value):
if value.subtype == USER_DEFINED_SUBTYPE:
return pickle.loads(value)
return value
.. note::
The above example is written assuming the use of Python 3. If you are using
Python 2, ``bson_type`` must be set to ``Binary``. See the
:ref:`decoding-binary-types` section for a detailed explanation.
Finally, we create a ``CodecOptions`` instance:
.. code-block:: python
codec_options = CodecOptions(type_registry=TypeRegistry(
[PickledBinaryDecoder()], fallback_encoder=fallback_pickle_encoder))
We can now round trip our custom objects to MongoDB:
.. code-block:: python
collection = db.get_collection('test_fe', codec_options=codec_options)
collection.insert_one({'_id': 1, 'str': MyStringType("hello world"),
'num': MyNumberType(2)})
mydoc = collection.find_one()
assert isinstance(mydoc['str'], MyStringType)
assert isinstance(mydoc['num'], MyNumberType)
Limitations
-----------
PyMongo's type codec and fallback encoder features have the following
limitations:
#. Users cannot customize the encoding behavior of Python types that PyMongo
already understands like ``int`` and ``str`` (the 'built-in types').
Attempting to instantiate a type registry with one or more codecs that act
upon a built-in type results in a ``TypeError``. This limitation extends
to all subtypes of the standard types.
#. Chaining type encoders is not supported. A custom type value, once
transformed by a codec's ``transform_python`` method, *must* result in a
type that is either BSON-encodable by default, or can be
transformed by the fallback encoder into something BSON-encodable--it
*cannot* be transformed a second time by a different type codec.
#. The :meth:`~pymongo.database.Database.command` method does not apply the
user's TypeDecoders while decoding the command response document.
#. :mod:`gridfs` does not apply custom type encoding or decoding to any
documents received from or to returned to the user.

View File

@ -20,6 +20,7 @@ MongoDB, you can start it like so:
authentication authentication
collations collations
copydb copydb
custom_type
bulk bulk
datetimes datetimes
geo geo

View File

@ -248,7 +248,8 @@ collection, configured to use :class:`~bson.son.SON` instead of dict:
tz_aware=False, tz_aware=False,
uuid_representation=PYTHON_LEGACY, uuid_representation=PYTHON_LEGACY,
unicode_decode_error_handler='strict', unicode_decode_error_handler='strict',
tzinfo=None) tzinfo=None, type_registry=TypeRegistry(type_codecs=[],
fallback_encoder=None))
>>> collection_son = collection.with_options(codec_options=opts) >>> collection_son = collection.with_options(codec_options=opts)
Now, documents and subdocuments in query results are represented with Now, documents and subdocuments in query results are represented with

View File

@ -142,7 +142,7 @@ of the collections in our database:
.. doctest:: .. doctest::
>>> db.collection_names(include_system_collections=False) >>> db.list_collection_names()
[u'posts'] [u'posts']
Getting a Single Document With :meth:`~pymongo.collection.Collection.find_one` Getting a Single Document With :meth:`~pymongo.collection.Collection.find_one`

View File

@ -25,7 +25,8 @@ from gridfs.errors import NoFile
from gridfs.grid_file import (GridIn, from gridfs.grid_file import (GridIn,
GridOut, GridOut,
GridOutCursor, GridOutCursor,
DEFAULT_CHUNK_SIZE) DEFAULT_CHUNK_SIZE,
_clear_entity_type_registry)
from pymongo import (ASCENDING, from pymongo import (ASCENDING,
DESCENDING) DESCENDING)
from pymongo.common import UNAUTHORIZED_CODES, validate_string from pymongo.common import UNAUTHORIZED_CODES, validate_string
@ -61,6 +62,8 @@ class GridFS(object):
if not isinstance(database, Database): if not isinstance(database, Database):
raise TypeError("database must be an instance of Database") raise TypeError("database must be an instance of Database")
database = _clear_entity_type_registry(database)
if not database.write_concern.acknowledged: if not database.write_concern.acknowledged:
raise ConfigurationError('database must use ' raise ConfigurationError('database must use '
'acknowledged write_concern') 'acknowledged write_concern')
@ -443,6 +446,8 @@ class GridFSBucket(object):
if not isinstance(db, Database): if not isinstance(db, Database):
raise TypeError("database must be an instance of Database") raise TypeError("database must be an instance of Database")
db = _clear_entity_type_registry(db)
wtc = write_concern if write_concern is not None else db.write_concern wtc = write_concern if write_concern is not None else db.write_concern
if not wtc.acknowledged: if not wtc.acknowledged:
raise ConfigurationError('write concern must be acknowledged') raise ConfigurationError('write concern must be acknowledged')
@ -715,9 +720,9 @@ class GridFSBucket(object):
.. versionchanged:: 3.6 .. versionchanged:: 3.6
Added ``session`` parameter. Added ``session`` parameter.
""" """
gout = self.open_download_stream(file_id, session=session) with self.open_download_stream(file_id, session=session) as gout:
for chunk in gout: for chunk in gout:
destination.write(chunk) destination.write(chunk)
def delete(self, file_id, session=None): def delete(self, file_id, session=None):
"""Given an file_id, delete this stored file's files collection document """Given an file_id, delete this stored file's files collection document
@ -890,10 +895,10 @@ class GridFSBucket(object):
.. versionchanged:: 3.6 .. versionchanged:: 3.6
Added ``session`` parameter. Added ``session`` parameter.
""" """
gout = self.open_download_stream_by_name( with self.open_download_stream_by_name(
filename, revision, session=session) filename, revision, session=session) as gout:
for chunk in gout: for chunk in gout:
destination.write(chunk) destination.write(chunk)
def rename(self, file_id, new_filename, session=None): def rename(self, file_id, new_filename, session=None):
"""Renames the stored file with the specified file_id. """Renames the stored file with the specified file_id.

View File

@ -15,6 +15,7 @@
"""Tools for representing files stored in GridFS.""" """Tools for representing files stored in GridFS."""
import datetime import datetime
import hashlib import hashlib
import io
import math import math
import os import os
@ -27,6 +28,7 @@ from pymongo import ASCENDING
from pymongo.collection import Collection from pymongo.collection import Collection
from pymongo.cursor import Cursor from pymongo.cursor import Cursor
from pymongo.errors import (ConfigurationError, from pymongo.errors import (ConfigurationError,
CursorNotFound,
DuplicateKeyError, DuplicateKeyError,
OperationFailure) OperationFailure)
from pymongo.read_preferences import ReadPreference from pymongo.read_preferences import ReadPreference
@ -96,6 +98,12 @@ def _grid_out_property(field_name, docstring):
return property(getter, doc=docstring) return property(getter, doc=docstring)
def _clear_entity_type_registry(entity, **kwargs):
"""Clear the given database/collection object's type registry."""
codecopts = entity.codec_options.with_options(type_registry=None)
return entity.with_options(codec_options=codecopts, **kwargs)
class GridIn(object): class GridIn(object):
"""Class to write data to GridFS. """Class to write data to GridFS.
""" """
@ -166,8 +174,8 @@ class GridIn(object):
if "chunk_size" in kwargs: if "chunk_size" in kwargs:
kwargs["chunkSize"] = kwargs.pop("chunk_size") kwargs["chunkSize"] = kwargs.pop("chunk_size")
coll = root_collection.with_options( coll = _clear_entity_type_registry(
read_preference=ReadPreference.PRIMARY) root_collection, read_preference=ReadPreference.PRIMARY)
if not disable_md5: if not disable_md5:
kwargs["md5"] = hashlib.md5() kwargs["md5"] = hashlib.md5()
@ -311,6 +319,15 @@ class GridIn(object):
self.__flush() self.__flush()
object.__setattr__(self, "_closed", True) object.__setattr__(self, "_closed", True)
def read(self, size=-1):
raise io.UnsupportedOperation('read')
def readable(self):
return False
def seekable(self):
return False
def write(self, data): def write(self, data):
"""Write data to the file. There is no return value. """Write data to the file. There is no return value.
@ -378,6 +395,9 @@ class GridIn(object):
for line in sequence: for line in sequence:
self.write(line) self.write(line)
def writeable(self):
return True
def __enter__(self): def __enter__(self):
"""Support for the context manager protocol. """Support for the context manager protocol.
""" """
@ -419,6 +439,11 @@ class GridOut(object):
:class:`~pymongo.client_session.ClientSession` to use for all :class:`~pymongo.client_session.ClientSession` to use for all
commands commands
.. versionchanged:: 3.8
For better performance and to better follow the GridFS spec,
:class:`GridOut` now uses a single cursor to read all the chunks in
the file.
.. versionchanged:: 3.6 .. versionchanged:: 3.6
Added ``session`` parameter. Added ``session`` parameter.
@ -430,10 +455,13 @@ class GridOut(object):
raise TypeError("root_collection must be an " raise TypeError("root_collection must be an "
"instance of Collection") "instance of Collection")
root_collection = _clear_entity_type_registry(root_collection)
self.__chunks = root_collection.chunks self.__chunks = root_collection.chunks
self.__files = root_collection.files self.__files = root_collection.files
self.__file_id = file_id self.__file_id = file_id
self.__buffer = EMPTY self.__buffer = EMPTY
self.__chunk_iter = None
self.__position = 0 self.__position = 0
self._file = file_document self._file = file_document
self._session = session self._session = session
@ -465,6 +493,9 @@ class GridOut(object):
return self._file[name] return self._file[name]
raise AttributeError("GridOut object has no attribute '%s'" % name) raise AttributeError("GridOut object has no attribute '%s'" % name)
def readable(self):
return True
def readchunk(self): def readchunk(self):
"""Reads a chunk at a time. If the current position is within a """Reads a chunk at a time. If the current position is within a
chunk the remainder of the chunk is returned. chunk the remainder of the chunk is returned.
@ -477,12 +508,11 @@ class GridOut(object):
chunk_data = self.__buffer chunk_data = self.__buffer
elif self.__position < int(self.length): elif self.__position < int(self.length):
chunk_number = int((received + self.__position) / chunk_size) chunk_number = int((received + self.__position) / chunk_size)
chunk = self.__chunks.find_one({"files_id": self._id, if self.__chunk_iter is None:
"n": chunk_number}, self.__chunk_iter = _GridOutChunkIterator(
session=self._session) self, self.__chunks, self._session, chunk_number)
if not chunk:
raise CorruptGridFile("no chunk #%d" % chunk_number)
chunk = self.__chunk_iter.next()
chunk_data = chunk["data"][self.__position % chunk_size:] chunk_data = chunk["data"][self.__position % chunk_size:]
if not chunk_data: if not chunk_data:
@ -501,16 +531,21 @@ class GridOut(object):
:Parameters: :Parameters:
- `size` (optional): the number of bytes to read - `size` (optional): the number of bytes to read
.. versionchanged:: 3.8
This method now only checks for extra chunks after reading the
entire file. Previously, this method would check for extra chunks
on every call.
""" """
self._ensure_file() self._ensure_file()
if size == 0:
return EMPTY
remainder = int(self.length) - self.__position remainder = int(self.length) - self.__position
if size < 0 or size > remainder: if size < 0 or size > remainder:
size = remainder size = remainder
if size == 0:
return EMPTY
received = 0 received = 0
data = StringIO() data = StringIO()
while received < size: while received < size:
@ -518,16 +553,12 @@ class GridOut(object):
received += len(chunk_data) received += len(chunk_data)
data.write(chunk_data) data.write(chunk_data)
# Detect extra chunks. # Detect extra chunks after reading the entire file.
max_chunk_n = math.ceil(self.length / float(self.chunk_size)) if size == remainder and self.__chunk_iter:
chunk = self.__chunks.find_one({"files_id": self._id, try:
"n": {"$gte": max_chunk_n}}, self.__chunk_iter.next()
session=self._session) except StopIteration:
# According to spec, ignore extra chunks if they are empty. pass
if chunk is not None and len(chunk['data']):
raise CorruptGridFile(
"Extra chunk found: expected %i chunks but found "
"chunk with n=%i" % (max_chunk_n, chunk['n']))
self.__position -= received - size self.__position -= received - size
@ -543,13 +574,13 @@ class GridOut(object):
:Parameters: :Parameters:
- `size` (optional): the maximum number of bytes to read - `size` (optional): the maximum number of bytes to read
""" """
if size == 0:
return b''
remainder = int(self.length) - self.__position remainder = int(self.length) - self.__position
if size < 0 or size > remainder: if size < 0 or size > remainder:
size = remainder size = remainder
if size == 0:
return EMPTY
received = 0 received = 0
data = StringIO() data = StringIO()
while received < size: while received < size:
@ -600,8 +631,18 @@ class GridOut(object):
if new_pos < 0: if new_pos < 0:
raise IOError(22, "Invalid value for `pos` - must be positive") raise IOError(22, "Invalid value for `pos` - must be positive")
# Optimization, continue using the same buffer and chunk iterator.
if new_pos == self.__position:
return
self.__position = new_pos self.__position = new_pos
self.__buffer = EMPTY self.__buffer = EMPTY
if self.__chunk_iter:
self.__chunk_iter.close()
self.__chunk_iter = None
def seekable(self):
return True
def __iter__(self): def __iter__(self):
"""Return an iterator over all of this file's data. """Return an iterator over all of this file's data.
@ -610,12 +651,28 @@ class GridOut(object):
:class:`str` (:class:`bytes` in python 3). This can be :class:`str` (:class:`bytes` in python 3). This can be
useful when serving files using a webserver that handles useful when serving files using a webserver that handles
such an iterator efficiently. such an iterator efficiently.
.. note::
This is different from :py:class:`io.IOBase` which iterates over
*lines* in the file. Use :meth:`GridOut.readline` to read line by
line instead of chunk by chunk.
.. versionchanged:: 3.8
The iterator now raises :class:`CorruptGridFile` when encountering
any truncated, missing, or extra chunk in a file. The previous
behavior was to only raise :class:`CorruptGridFile` on a missing
chunk.
""" """
return GridOutIterator(self, self.__chunks, self._session) return GridOutIterator(self, self.__chunks, self._session)
def close(self): def close(self):
"""Make GridOut more generically file-like.""" """Make GridOut more generically file-like."""
pass if self.__chunk_iter:
self.__chunk_iter.close()
self.__chunk_iter = None
def write(self, value):
raise io.UnsupportedOperation('write')
def __enter__(self): def __enter__(self):
"""Makes it possible to use :class:`GridOut` files """Makes it possible to use :class:`GridOut` files
@ -627,30 +684,108 @@ class GridOut(object):
"""Makes it possible to use :class:`GridOut` files """Makes it possible to use :class:`GridOut` files
with the context manager protocol. with the context manager protocol.
""" """
self.close()
return False return False
class _GridOutChunkIterator(object):
"""Iterates over a file's chunks using a single cursor.
Raises CorruptGridFile when encountering any truncated, missing, or extra
chunk in a file.
"""
def __init__(self, grid_out, chunks, session, next_chunk):
self._id = grid_out._id
self._chunk_size = int(grid_out.chunk_size)
self._length = int(grid_out.length)
self._chunks = chunks
self._session = session
self._next_chunk = next_chunk
self._num_chunks = math.ceil(float(self._length) / self._chunk_size)
self._cursor = None
def expected_chunk_length(self, chunk_n):
if chunk_n < self._num_chunks - 1:
return self._chunk_size
return self._length - (self._chunk_size * (self._num_chunks - 1))
def __iter__(self):
return self
def _create_cursor(self):
filter = {"files_id": self._id}
if self._next_chunk > 0:
filter["n"] = {"$gte": self._next_chunk}
self._cursor = self._chunks.find(filter, sort=[("n", 1)],
session=self._session)
def _next_with_retry(self):
"""Return the next chunk and retry once on CursorNotFound.
We retry on CursorNotFound to maintain backwards compatibility in
cases where two calls to read occur more than 10 minutes apart (the
server's default cursor timeout).
"""
if self._cursor is None:
self._create_cursor()
try:
return self._cursor.next()
except CursorNotFound:
self._cursor.close()
self._create_cursor()
return self._cursor.next()
def next(self):
try:
chunk = self._next_with_retry()
except StopIteration:
if self._next_chunk >= self._num_chunks:
raise
raise CorruptGridFile("no chunk #%d" % self._next_chunk)
if chunk["n"] != self._next_chunk:
self.close()
raise CorruptGridFile(
"Missing chunk: expected chunk #%d but found "
"chunk with n=%d" % (self._next_chunk, chunk["n"]))
if chunk["n"] >= self._num_chunks:
# According to spec, ignore extra chunks if they are empty.
if len(chunk["data"]):
self.close()
raise CorruptGridFile(
"Extra chunk found: expected %d chunks but found "
"chunk with n=%d" % (self._num_chunks, chunk["n"]))
expected_length = self.expected_chunk_length(chunk["n"])
if len(chunk["data"]) != expected_length:
self.close()
raise CorruptGridFile(
"truncated chunk #%d: expected chunk length to be %d but "
"found chunk with length %d" % (
chunk["n"], expected_length, len(chunk["data"])))
self._next_chunk += 1
return chunk
__next__ = next
def close(self):
if self._cursor:
self._cursor.close()
self._cursor = None
class GridOutIterator(object): class GridOutIterator(object):
def __init__(self, grid_out, chunks, session): def __init__(self, grid_out, chunks, session):
self.__id = grid_out._id self.__chunk_iter = _GridOutChunkIterator(grid_out, chunks, session, 0)
self.__chunks = chunks
self.__session = session
self.__current_chunk = 0
self.__max_chunk = math.ceil(float(grid_out.length) /
grid_out.chunk_size)
def __iter__(self): def __iter__(self):
return self return self
def next(self): def next(self):
if self.__current_chunk >= self.__max_chunk: chunk = self.__chunk_iter.next()
raise StopIteration
chunk = self.__chunks.find_one({"files_id": self.__id,
"n": self.__current_chunk},
session=self.__session)
if not chunk:
raise CorruptGridFile("no chunk #%d" % self.__current_chunk)
self.__current_chunk += 1
return bytes(chunk["data"]) return bytes(chunk["data"])
__next__ = next __next__ = next
@ -673,6 +808,8 @@ class GridOutCursor(Cursor):
.. mongodoc:: cursors .. mongodoc:: cursors
""" """
collection = _clear_entity_type_registry(collection)
# Hold on to the base "fs" collection to create GridOut objects later. # Hold on to the base "fs" collection to create GridOut objects later.
self.__root_collection = collection self.__root_collection = collection

View File

@ -64,7 +64,7 @@ SLOW_ONLY = 1
ALL = 2 ALL = 2
"""Profile all operations.""" """Profile all operations."""
version_tuple = (3, 8, 0, '.dev0') version_tuple = (3, 8, 0)
def get_version_string(): def get_version_string():
if isinstance(version_tuple[-1], str): if isinstance(version_tuple[-1], str):

View File

@ -434,7 +434,7 @@ static PyObject* _cbson_query_message(PyObject* self, PyObject* args) {
buffer_t buffer; buffer_t buffer;
int length_location, message_length; int length_location, message_length;
unsigned char check_keys = 0; unsigned char check_keys = 0;
PyObject* result; PyObject* result = NULL;
if (!PyArg_ParseTuple(args, "Iet#iiOOO&|b", if (!PyArg_ParseTuple(args, "Iet#iiOOO&|b",
&flags, &flags,
@ -477,18 +477,14 @@ static PyObject* _cbson_query_message(PyObject* self, PyObject* args) {
/* PyDict_GetItemString returns a borrowed reference. */ /* PyDict_GetItemString returns a borrowed reference. */
Py_INCREF(cluster_time); Py_INCREF(cluster_time);
if (-1 == PyMapping_DelItemString(query, "$clusterTime")) { if (-1 == PyMapping_DelItemString(query, "$clusterTime")) {
destroy_codec_options(&options); goto fail;
PyMem_Free(collection_name);
return NULL;
} }
} }
} else if (PyMapping_HasKeyString(query, "$clusterTime")) { } else if (PyMapping_HasKeyString(query, "$clusterTime")) {
cluster_time = PyMapping_GetItemString(query, "$clusterTime"); cluster_time = PyMapping_GetItemString(query, "$clusterTime");
if (!cluster_time if (!cluster_time
|| -1 == PyMapping_DelItemString(query, "$clusterTime")) { || -1 == PyMapping_DelItemString(query, "$clusterTime")) {
destroy_codec_options(&options); goto fail;
PyMem_Free(collection_name);
return NULL;
} }
} }
if (!buffer_write_int32(buffer, (int32_t)request_id) || if (!buffer_write_int32(buffer, (int32_t)request_id) ||
@ -498,20 +494,12 @@ static PyObject* _cbson_query_message(PyObject* self, PyObject* args) {
collection_name_length + 1) || collection_name_length + 1) ||
!buffer_write_int32(buffer, (int32_t)num_to_skip) || !buffer_write_int32(buffer, (int32_t)num_to_skip) ||
!buffer_write_int32(buffer, (int32_t)num_to_return)) { !buffer_write_int32(buffer, (int32_t)num_to_return)) {
destroy_codec_options(&options); goto fail;
buffer_free(buffer);
PyMem_Free(collection_name);
Py_XDECREF(cluster_time);
return NULL;
} }
begin = buffer_get_position(buffer); begin = buffer_get_position(buffer);
if (!write_dict(state->_cbson, buffer, query, check_keys, &options, 1)) { if (!write_dict(state->_cbson, buffer, query, check_keys, &options, 1)) {
destroy_codec_options(&options); goto fail;
buffer_free(buffer);
PyMem_Free(collection_name);
Py_XDECREF(cluster_time);
return NULL;
} }
/* back up a byte and write $clusterTime */ /* back up a byte and write $clusterTime */
@ -522,19 +510,11 @@ static PyObject* _cbson_query_message(PyObject* self, PyObject* args) {
buffer_update_position(buffer, buffer_get_position(buffer) - 1); buffer_update_position(buffer, buffer_get_position(buffer) - 1);
if (!write_pair(state->_cbson, buffer, "$clusterTime", 12, cluster_time, if (!write_pair(state->_cbson, buffer, "$clusterTime", 12, cluster_time,
0, &options, 1)) { 0, &options, 1)) {
destroy_codec_options(&options); goto fail;
buffer_free(buffer);
PyMem_Free(collection_name);
Py_DECREF(cluster_time);
return NULL;
} }
if (!buffer_write_bytes(buffer, &zero, 1)) { if (!buffer_write_bytes(buffer, &zero, 1)) {
destroy_codec_options(&options); goto fail;
buffer_free(buffer);
PyMem_Free(collection_name);
Py_DECREF(cluster_time);
return NULL;
} }
length = buffer_get_position(buffer) - begin; length = buffer_get_position(buffer) - begin;
@ -543,14 +523,10 @@ static PyObject* _cbson_query_message(PyObject* self, PyObject* args) {
/* undo popping $clusterTime */ /* undo popping $clusterTime */
if (-1 == PyMapping_SetItemString( if (-1 == PyMapping_SetItemString(
query, "$clusterTime", cluster_time)) { query, "$clusterTime", cluster_time)) {
destroy_codec_options(&options); goto fail;
buffer_free(buffer);
PyMem_Free(collection_name);
Py_DECREF(cluster_time);
return NULL;
} }
Py_DECREF(cluster_time); Py_CLEAR(cluster_time);
} }
max_size = buffer_get_position(buffer) - begin; max_size = buffer_get_position(buffer) - begin;
@ -559,17 +535,12 @@ static PyObject* _cbson_query_message(PyObject* self, PyObject* args) {
begin = buffer_get_position(buffer); begin = buffer_get_position(buffer);
if (!write_dict(state->_cbson, buffer, field_selector, 0, if (!write_dict(state->_cbson, buffer, field_selector, 0,
&options, 1)) { &options, 1)) {
destroy_codec_options(&options); goto fail;
buffer_free(buffer);
PyMem_Free(collection_name);
return NULL;
} }
cur_size = buffer_get_position(buffer) - begin; cur_size = buffer_get_position(buffer) - begin;
max_size = (cur_size > max_size) ? cur_size : max_size; max_size = (cur_size > max_size) ? cur_size : max_size;
} }
PyMem_Free(collection_name);
message_length = buffer_get_position(buffer) - length_location; message_length = buffer_get_position(buffer) - length_location;
buffer_write_int32_at_position( buffer_write_int32_at_position(
buffer, length_location, (int32_t)message_length); buffer, length_location, (int32_t)message_length);
@ -579,8 +550,12 @@ static PyObject* _cbson_query_message(PyObject* self, PyObject* args) {
buffer_get_buffer(buffer), buffer_get_buffer(buffer),
buffer_get_position(buffer), buffer_get_position(buffer),
max_size); max_size);
fail:
PyMem_Free(collection_name);
destroy_codec_options(&options); destroy_codec_options(&options);
buffer_free(buffer); buffer_free(buffer);
Py_XDECREF(cluster_time);
return result; return result;
} }
@ -1142,11 +1117,11 @@ _batched_op_msg(
int size_location; int size_location;
int position; int position;
int length; int length;
PyObject* max_bson_size_obj; PyObject* max_bson_size_obj = NULL;
PyObject* max_write_batch_size_obj; PyObject* max_write_batch_size_obj = NULL;
PyObject* max_message_size_obj; PyObject* max_message_size_obj = NULL;
PyObject* doc; PyObject* doc = NULL;
PyObject* iterator; PyObject* iterator = NULL;
char* flags = ack ? "\x00\x00\x00\x00" : "\x02\x00\x00\x00"; char* flags = ack ? "\x00\x00\x00\x00" : "\x02\x00\x00\x00";
max_bson_size_obj = PyObject_GetAttrString(ctx, "max_bson_size"); max_bson_size_obj = PyObject_GetAttrString(ctx, "max_bson_size");
@ -1209,7 +1184,7 @@ _batched_op_msg(
case _INSERT: case _INSERT:
{ {
if (!buffer_write_bytes(buffer, "documents\x00", 10)) if (!buffer_write_bytes(buffer, "documents\x00", 10))
goto cmdfail; goto fail;
break; break;
} }
case _UPDATE: case _UPDATE:
@ -1217,7 +1192,7 @@ _batched_op_msg(
/* MongoDB does key validation for update. */ /* MongoDB does key validation for update. */
check_keys = 0; check_keys = 0;
if (!buffer_write_bytes(buffer, "updates\x00", 8)) if (!buffer_write_bytes(buffer, "updates\x00", 8))
goto cmdfail; goto fail;
break; break;
} }
case _DELETE: case _DELETE:
@ -1225,7 +1200,7 @@ _batched_op_msg(
/* Never check keys in a delete command. */ /* Never check keys in a delete command. */
check_keys = 0; check_keys = 0;
if (!buffer_write_bytes(buffer, "deletes\x00", 8)) if (!buffer_write_bytes(buffer, "deletes\x00", 8))
goto cmdfail; goto fail;
break; break;
} }
default: default:
@ -1255,7 +1230,7 @@ _batched_op_msg(
int unacked_doc_too_large = 0; int unacked_doc_too_large = 0;
if (!write_dict(state->_cbson, buffer, doc, check_keys, if (!write_dict(state->_cbson, buffer, doc, check_keys,
&options, 1)) { &options, 1)) {
goto cmditerfail; goto fail;
} }
cur_size = buffer_get_position(buffer) - cur_doc_begin; cur_size = buffer_get_position(buffer) - cur_doc_begin;
@ -1285,7 +1260,7 @@ _batched_op_msg(
Py_DECREF(DocumentTooLarge); Py_DECREF(DocumentTooLarge);
} }
} }
goto cmditerfail; goto fail;
} }
/* We have enough data, return this batch. */ /* We have enough data, return this batch. */
if (buffer_get_position(buffer) > max_message_size) { if (buffer_get_position(buffer) > max_message_size) {
@ -1294,10 +1269,11 @@ _batched_op_msg(
* of the last document encoded. * of the last document encoded.
*/ */
buffer_update_position(buffer, cur_doc_begin); buffer_update_position(buffer, cur_doc_begin);
Py_CLEAR(doc);
break; break;
} }
if (PyList_Append(to_publish, doc) < 0) { if (PyList_Append(to_publish, doc) < 0) {
goto cmditerfail; goto fail;
} }
Py_CLEAR(doc); Py_CLEAR(doc);
idx += 1; idx += 1;
@ -1306,10 +1282,10 @@ _batched_op_msg(
break; break;
} }
} }
Py_DECREF(iterator); Py_CLEAR(iterator);
if (PyErr_Occurred()) { if (PyErr_Occurred()) {
goto cmdfail; goto fail;
} }
position = buffer_get_position(buffer); position = buffer_get_position(buffer);
@ -1317,10 +1293,9 @@ _batched_op_msg(
buffer_write_int32_at_position(buffer, size_location, (int32_t)length); buffer_write_int32_at_position(buffer, size_location, (int32_t)length);
return 1; return 1;
cmditerfail: fail:
Py_XDECREF(doc); Py_XDECREF(doc);
Py_DECREF(iterator); Py_XDECREF(iterator);
cmdfail:
return 0; return 0;
} }
@ -1466,10 +1441,10 @@ _batched_write_command(
int lst_len_loc; int lst_len_loc;
int position; int position;
int length; int length;
PyObject* max_bson_size_obj; PyObject* max_bson_size_obj = NULL;
PyObject* max_write_batch_size_obj; PyObject* max_write_batch_size_obj = NULL;
PyObject* doc; PyObject* doc = NULL;
PyObject* iterator; PyObject* iterator = NULL;
max_bson_size_obj = PyObject_GetAttrString(ctx, "max_bson_size"); max_bson_size_obj = PyObject_GetAttrString(ctx, "max_bson_size");
#if PY_MAJOR_VERSION >= 3 #if PY_MAJOR_VERSION >= 3
@ -1524,7 +1499,7 @@ _batched_write_command(
case _INSERT: case _INSERT:
{ {
if (!buffer_write_bytes(buffer, "documents\x00", 10)) if (!buffer_write_bytes(buffer, "documents\x00", 10))
goto cmdfail; goto fail;
break; break;
} }
case _UPDATE: case _UPDATE:
@ -1532,7 +1507,7 @@ _batched_write_command(
/* MongoDB does key validation for update. */ /* MongoDB does key validation for update. */
check_keys = 0; check_keys = 0;
if (!buffer_write_bytes(buffer, "updates\x00", 8)) if (!buffer_write_bytes(buffer, "updates\x00", 8))
goto cmdfail; goto fail;
break; break;
} }
case _DELETE: case _DELETE:
@ -1540,7 +1515,7 @@ _batched_write_command(
/* Never check keys in a delete command. */ /* Never check keys in a delete command. */
check_keys = 0; check_keys = 0;
if (!buffer_write_bytes(buffer, "deletes\x00", 8)) if (!buffer_write_bytes(buffer, "deletes\x00", 8))
goto cmdfail; goto fail;
break; break;
} }
default: default:
@ -1575,25 +1550,23 @@ _batched_write_command(
int cur_doc_begin; int cur_doc_begin;
int cur_size; int cur_size;
int enough_data = 0; int enough_data = 0;
int enough_documents = 0;
char key[16]; char key[16];
INT2STRING(key, idx); INT2STRING(key, idx);
if (!buffer_write_bytes(buffer, "\x03", 1) || if (!buffer_write_bytes(buffer, "\x03", 1) ||
!buffer_write_bytes(buffer, key, (int)strlen(key) + 1)) { !buffer_write_bytes(buffer, key, (int)strlen(key) + 1)) {
goto cmditerfail; goto fail;
} }
cur_doc_begin = buffer_get_position(buffer); cur_doc_begin = buffer_get_position(buffer);
if (!write_dict(state->_cbson, buffer, doc, if (!write_dict(state->_cbson, buffer, doc,
check_keys, &options, 1)) { check_keys, &options, 1)) {
goto cmditerfail; goto fail;
} }
/* We have enough data, return this batch. /* We have enough data, return this batch.
* max_cmd_size accounts for the two trailing null bytes. * max_cmd_size accounts for the two trailing null bytes.
*/ */
enough_data = (buffer_get_position(buffer) > max_cmd_size); enough_data = (buffer_get_position(buffer) > max_cmd_size);
enough_documents = (idx >= max_write_batch_size); if (enough_data) {
if (enough_data || enough_documents) {
cur_size = buffer_get_position(buffer) - cur_doc_begin; cur_size = buffer_get_position(buffer) - cur_doc_begin;
/* This single document is too large for the command. */ /* This single document is too large for the command. */
@ -1614,30 +1587,35 @@ _batched_write_command(
Py_DECREF(DocumentTooLarge); Py_DECREF(DocumentTooLarge);
} }
} }
goto cmditerfail; goto fail;
} }
/* /*
* Roll the existing buffer back to the beginning * Roll the existing buffer back to the beginning
* of the last document encoded. * of the last document encoded.
*/ */
buffer_update_position(buffer, sub_doc_begin); buffer_update_position(buffer, sub_doc_begin);
Py_CLEAR(doc);
break; break;
} }
if (PyList_Append(to_publish, doc) < 0) { if (PyList_Append(to_publish, doc) < 0) {
goto cmditerfail; goto fail;
} }
Py_CLEAR(doc); Py_CLEAR(doc);
idx += 1; idx += 1;
/* We have enough documents, return this batch. */
if (idx == max_write_batch_size) {
break;
}
} }
Py_DECREF(iterator); Py_CLEAR(iterator);
if (PyErr_Occurred()) { if (PyErr_Occurred()) {
goto cmdfail; goto fail;
} }
if (!buffer_write_bytes(buffer, "\x00\x00", 2)) if (!buffer_write_bytes(buffer, "\x00\x00", 2)) {
goto cmdfail; goto fail;
}
position = buffer_get_position(buffer); position = buffer_get_position(buffer);
length = position - lst_len_loc - 1; length = position - lst_len_loc - 1;
@ -1646,10 +1624,9 @@ _batched_write_command(
buffer_write_int32_at_position(buffer, cmd_len_loc, (int32_t)length); buffer_write_int32_at_position(buffer, cmd_len_loc, (int32_t)length);
return 1; return 1;
cmditerfail: fail:
Py_XDECREF(doc); Py_XDECREF(doc);
Py_DECREF(iterator); Py_XDECREF(iterator);
cmdfail:
return 0; return 0;
} }

View File

@ -281,7 +281,8 @@ class _Bulk(object):
if retryable and not self.started_retryable_write: if retryable and not self.started_retryable_write:
session._start_retryable_write() session._start_retryable_write()
self.started_retryable_write = True self.started_retryable_write = True
session._apply_to(cmd, retryable, ReadPreference.PRIMARY) session._apply_to(cmd, retryable, ReadPreference.PRIMARY,
sock_info)
sock_info.send_cluster_time(cmd, session, client) sock_info.send_cluster_time(cmd, session, client)
check_keys = run.op_type == _INSERT check_keys = run.op_type == _INSERT
ops = islice(run.ops, run.idx_offset, None) ops = islice(run.ops, run.idx_offset, None)

View File

@ -12,10 +12,12 @@
# implied. See the License for the specific language governing # implied. See the License for the specific language governing
# permissions and limitations under the License. # permissions and limitations under the License.
"""ChangeStream cursor to iterate over changes on a collection.""" """Watch changes on a collection, a database, or the entire cluster."""
import copy import copy
from bson import _bson_to_dict
from bson.raw_bson import RawBSONDocument
from bson.son import SON from bson.son import SON
from pymongo import common from pymongo import common
@ -41,14 +43,12 @@ class ChangeStream(object):
"""The internal abstract base class for change stream cursors. """The internal abstract base class for change stream cursors.
Should not be called directly by application developers. Use Should not be called directly by application developers. Use
:meth:pymongo.collection.Collection.watch, :meth:`pymongo.collection.Collection.watch`,
:meth:pymongo.database.Database.watch, or :meth:`pymongo.database.Database.watch`, or
:meth:pymongo.mongo_client.MongoClient.watch instead. :meth:`pymongo.mongo_client.MongoClient.watch` instead.
Defines the interface for change streams. Should be subclassed to .. versionadded:: 3.6
implement the `ChangeStream._create_cursor` abstract method, and .. mongodoc:: changeStreams
the `ChangeStream._database`and ChangeStream._aggregation_target`
abstract properties.
""" """
def __init__(self, target, pipeline, full_document, resume_after, def __init__(self, target, pipeline, full_document, resume_after,
max_await_time_ms, batch_size, collation, max_await_time_ms, batch_size, collation,
@ -62,7 +62,18 @@ class ChangeStream(object):
validate_collation_or_none(collation) validate_collation_or_none(collation)
common.validate_non_negative_integer_or_none("batchSize", batch_size) common.validate_non_negative_integer_or_none("batchSize", batch_size)
self._target = target self._decode_custom = False
self._orig_codec_options = target.codec_options
if target.codec_options.type_registry._decoder_map:
self._decode_custom = True
# Keep the type registry so that we support encoding custom types
# in the pipeline.
self._target = target.with_options(
codec_options=target.codec_options.with_options(
document_class=RawBSONDocument))
else:
self._target = target
self._pipeline = copy.deepcopy(pipeline) self._pipeline = copy.deepcopy(pipeline)
self._full_document = full_document self._full_document = full_document
self._resume_token = copy.deepcopy(resume_after) self._resume_token = copy.deepcopy(resume_after)
@ -146,8 +157,7 @@ class ChangeStream(object):
aggregation_collection, cursor, sock_info.address, aggregation_collection, cursor, sock_info.address,
batch_size=self._batch_size or 0, batch_size=self._batch_size or 0,
max_await_time_ms=self._max_await_time_ms, max_await_time_ms=self._max_await_time_ms,
session=session, explicit_session=explicit_session session=session, explicit_session=explicit_session)
)
def _create_cursor(self): def _create_cursor(self):
with self._database.client._tmp_session(self._session, close=False) as s: with self._database.client._tmp_session(self._session, close=False) as s:
@ -175,34 +185,100 @@ class ChangeStream(object):
"""Advance the cursor. """Advance the cursor.
This method blocks until the next change document is returned or an This method blocks until the next change document is returned or an
unrecoverable error is raised. unrecoverable error is raised. This method is used when iterating over
all changes in the cursor. For example::
try:
with db.collection.watch(
[{'$match': {'operationType': 'insert'}}]) as stream:
for insert_change in stream:
print(insert_change)
except pymongo.errors.PyMongoError:
# The ChangeStream encountered an unrecoverable error or the
# resume attempt failed to recreate the cursor.
logging.error('...')
Raises :exc:`StopIteration` if this ChangeStream is closed. Raises :exc:`StopIteration` if this ChangeStream is closed.
""" """
while True: while self.alive:
try: doc = self.try_next()
change = self._cursor.next() if doc is not None:
except ConnectionFailure: return doc
self._resume()
continue raise StopIteration
except OperationFailure as exc:
if exc.code in _NON_RESUMABLE_GETMORE_ERRORS:
raise
self._resume()
continue
try:
resume_token = change['_id']
except KeyError:
self.close()
raise InvalidOperation(
"Cannot provide resume functionality when the resume "
"token is missing.")
self._resume_token = copy.copy(resume_token)
self._start_at_operation_time = None
return change
__next__ = next __next__ = next
@property
def alive(self):
"""Does this cursor have the potential to return more data?
.. note:: Even if :attr:`alive` is ``True``, :meth:`next` can raise
:exc:`StopIteration` and :meth:`try_next` can return ``None``.
.. versionadded:: 3.8
"""
return self._cursor.alive
def try_next(self):
"""Advance the cursor without blocking indefinitely.
This method returns the next change document without waiting
indefinitely for the next change. For example::
with db.collection.watch() as stream:
while stream.alive:
change = stream.try_next()
if change is not None:
print(change)
elif stream.alive:
# We end up here when there are no recent changes.
# Sleep for a while to avoid flooding the server with
# getMore requests when no changes are available.
time.sleep(10)
If no change document is cached locally then this method runs a single
getMore command. If the getMore yields any documents, the next
document is returned, otherwise, if the getMore returns no documents
(because there have been no changes) then ``None`` is returned.
:Returns:
The next change document or ``None`` when no document is available
after running a single getMore or when the cursor is closed.
.. versionadded:: 3.8
"""
# Attempt to get the next change with at most one getMore and at most
# one resume attempt.
try:
change = self._cursor._try_next(True)
except ConnectionFailure:
self._resume()
change = self._cursor._try_next(False)
except OperationFailure as exc:
if exc.code in _NON_RESUMABLE_GETMORE_ERRORS:
raise
self._resume()
change = self._cursor._try_next(False)
# No changes are available.
if change is None:
return None
try:
resume_token = change['_id']
except KeyError:
self.close()
raise InvalidOperation(
"Cannot provide resume functionality when the resume "
"token is missing.")
self._resume_token = copy.copy(resume_token)
self._start_at_operation_time = None
if self._decode_custom:
return _bson_to_dict(change.raw, self._orig_codec_options)
return change
def __enter__(self): def __enter__(self):
return self return self
@ -211,13 +287,12 @@ class ChangeStream(object):
class CollectionChangeStream(ChangeStream): class CollectionChangeStream(ChangeStream):
"""Class for creating a change stream on a collection. """A change stream that watches changes on a single collection.
Should not be called directly by application developers. Use Should not be called directly by application developers. Use
helper method :meth:`pymongo.collection.Collection.watch` instead. helper method :meth:`pymongo.collection.Collection.watch` instead.
.. versionadded: 3.6 .. versionadded:: 3.7
.. mongodoc:: changeStreams
""" """
@property @property
def _aggregation_target(self): def _aggregation_target(self):
@ -229,13 +304,12 @@ class CollectionChangeStream(ChangeStream):
class DatabaseChangeStream(ChangeStream): class DatabaseChangeStream(ChangeStream):
"""Class for creating a change stream on all collections in a database. """A change stream that watches changes on all collections in a database.
Should not be called directly by application developers. Use Should not be called directly by application developers. Use
helper method :meth:`pymongo.database.Database.watch` instead. helper method :meth:`pymongo.database.Database.watch` instead.
.. versionadded: 3.7 .. versionadded:: 3.7
.. mongodoc:: changeStreams
""" """
@property @property
def _aggregation_target(self): def _aggregation_target(self):
@ -247,13 +321,12 @@ class DatabaseChangeStream(ChangeStream):
class ClusterChangeStream(DatabaseChangeStream): class ClusterChangeStream(DatabaseChangeStream):
"""Class for creating a change stream on all collections on a cluster. """A change stream that watches changes on all collections in the cluster.
Should not be called directly by application developers. Use Should not be called directly by application developers. Use
helper method :meth:`pymongo.mongo_client.MongoClient.watch` instead. helper method :meth:`pymongo.mongo_client.MongoClient.watch` instead.
.. versionadded: 3.7 .. versionadded:: 3.7
.. mongodoc:: changeStreams
""" """
def _pipeline_options(self): def _pipeline_options(self):
options = super(ClusterChangeStream, self)._pipeline_options() options = super(ClusterChangeStream, self)._pipeline_options()

View File

@ -90,7 +90,7 @@ from bson.int64 import Int64
from bson.py3compat import abc, reraise_instance from bson.py3compat import abc, reraise_instance
from bson.timestamp import Timestamp from bson.timestamp import Timestamp
from pymongo import monotonic from pymongo import monotonic, __version__
from pymongo.errors import (ConfigurationError, from pymongo.errors import (ConfigurationError,
ConnectionFailure, ConnectionFailure,
InvalidOperation, InvalidOperation,
@ -263,6 +263,10 @@ _UNKNOWN_COMMIT_ERROR_CODES = _RETRYABLE_ERROR_CODES | frozenset([
64, # WriteConcernFailed 64, # WriteConcernFailed
]) ])
_MONGOS_NOT_SUPPORTED_MSG = (
'PyMongo %s does not support running multi-document transactions on '
'sharded clusters') % (__version__,)
class ClientSession(object): class ClientSession(object):
"""A session for ordering sequential operations.""" """A session for ordering sequential operations."""
@ -356,6 +360,9 @@ class ClientSession(object):
""" """
self._check_ended() self._check_ended()
if self._client._is_mongos_non_blocking():
raise ConfigurationError(_MONGOS_NOT_SUPPORTED_MSG)
if self._in_transaction: if self._in_transaction:
raise InvalidOperation("Transaction already in progress") raise InvalidOperation("Transaction already in progress")
@ -534,7 +541,7 @@ class ClientSession(object):
return self._transaction.opts.read_preference return self._transaction.opts.read_preference
return None return None
def _apply_to(self, command, is_retryable, read_preference): def _apply_to(self, command, is_retryable, read_preference, sock_info):
self._check_ended() self._check_ended()
self._server_session.last_use = monotonic.time() self._server_session.last_use = monotonic.time()
@ -548,6 +555,9 @@ class ClientSession(object):
return return
if self._in_transaction: if self._in_transaction:
if sock_info.is_mongos:
raise ConfigurationError(_MONGOS_NOT_SUPPORTED_MSG)
if read_preference != ReadPreference.PRIMARY: if read_preference != ReadPreference.PRIMARY:
raise InvalidOperation( raise InvalidOperation(
'read preference in a transaction must be primary, not: ' 'read preference in a transaction must be primary, not: '

View File

@ -53,6 +53,7 @@ from pymongo.write_concern import WriteConcern
_NO_OBJ_ERROR = "No matching object found" _NO_OBJ_ERROR = "No matching object found"
_UJOIN = u"%s.%s" _UJOIN = u"%s.%s"
_FIND_AND_MODIFY_DOC_FIELDS = {'value': 1}
class ReturnDocument(object): class ReturnDocument(object):
@ -202,7 +203,8 @@ class Collection(common.BaseObject):
write_concern=None, write_concern=None,
collation=None, collation=None,
session=None, session=None,
retryable_write=False): retryable_write=False,
user_fields=None):
"""Internal command helper. """Internal command helper.
:Parameters: :Parameters:
@ -222,6 +224,11 @@ class Collection(common.BaseObject):
:class:`~pymongo.collation.Collation`. :class:`~pymongo.collation.Collation`.
- `session` (optional): a - `session` (optional): a
:class:`~pymongo.client_session.ClientSession`. :class:`~pymongo.client_session.ClientSession`.
- `retryable_write` (optional): True if this command is a retryable
write.
- `user_fields` (optional): Response fields that should be decoded
using the TypeDecoders from codec_options, passed to
bson._decode_all_selective.
:Returns: :Returns:
The result document. The result document.
@ -241,7 +248,8 @@ class Collection(common.BaseObject):
collation=collation, collation=collation,
session=s, session=s,
client=self.__database.client, client=self.__database.client,
retryable_write=retryable_write) retryable_write=retryable_write,
user_fields=user_fields)
def __create(self, options, collation, session): def __create(self, options, collation, session):
"""Sends a create command with the given options. """Sends a create command with the given options.
@ -314,9 +322,8 @@ class Collection(common.BaseObject):
""" """
return self.__database return self.__database
def with_options( def with_options(self, codec_options=None, read_preference=None,
self, codec_options=None, read_preference=None, write_concern=None, read_concern=None):
write_concern=None, read_concern=None):
"""Get a clone of this collection changing the specified settings. """Get a clone of this collection changing the specified settings.
>>> coll1.read_preference >>> coll1.read_preference
@ -1299,7 +1306,8 @@ class Collection(common.BaseObject):
- `skip` (optional): the number of documents to omit (from - `skip` (optional): the number of documents to omit (from
the start of the result set) when returning the results the start of the result set) when returning the results
- `limit` (optional): the maximum number of results to - `limit` (optional): the maximum number of results to
return return. A limit of 0 (the default) is equivalent to setting no
limit.
- `no_cursor_timeout` (optional): if False (the default), any - `no_cursor_timeout` (optional): if False (the default), any
returned cursor is closed by the server after 10 minutes of returned cursor is closed by the server after 10 minutes of
inactivity. If set to True, the returned cursor will never inactivity. If set to True, the returned cursor will never
@ -1366,14 +1374,17 @@ class Collection(common.BaseObject):
- `min` (optional): A list of field, limit pairs specifying the - `min` (optional): A list of field, limit pairs specifying the
inclusive lower bound for all keys of a specific index in order. inclusive lower bound for all keys of a specific index in order.
Pass this as an alternative to calling Pass this as an alternative to calling
:meth:`~pymongo.cursor.Cursor.min` on the cursor. :meth:`~pymongo.cursor.Cursor.min` on the cursor. ``hint`` must
also be passed to ensure the query utilizes the correct index.
- `max` (optional): A list of field, limit pairs specifying the - `max` (optional): A list of field, limit pairs specifying the
exclusive upper bound for all keys of a specific index in order. exclusive upper bound for all keys of a specific index in order.
Pass this as an alternative to calling Pass this as an alternative to calling
:meth:`~pymongo.cursor.Cursor.max` on the cursor. :meth:`~pymongo.cursor.Cursor.max` on the cursor. ``hint`` must
- `comment` (optional): A string or document. Pass this as an also be passed to ensure the query utilizes the correct index.
alternative to calling :meth:`~pymongo.cursor.Cursor.comment` on the - `comment` (optional): A string to attach to the query to help
cursor. interpret and trace the operation in the server logs and in profile
data. Pass this as an alternative to calling
:meth:`~pymongo.cursor.Cursor.comment` on the cursor.
- `modifiers` (optional): **DEPRECATED** - A dict specifying - `modifiers` (optional): **DEPRECATED** - A dict specifying
additional MongoDB query modifiers. Use the keyword arguments listed additional MongoDB query modifiers. Use the keyword arguments listed
above instead. above instead.
@ -1623,7 +1634,8 @@ class Collection(common.BaseObject):
- `skip` (int): The number of matching documents to skip before - `skip` (int): The number of matching documents to skip before
returning results. returning results.
- `limit` (int): The maximum number of documents to count. - `limit` (int): The maximum number of documents to count. Must be
a positive integer. If not provided, no limit is imposed.
- `maxTimeMS` (int): The maximum amount of time to allow this - `maxTimeMS` (int): The maximum amount of time to allow this
operation to run, in milliseconds. operation to run, in milliseconds.
- `collation` (optional): An instance of - `collation` (optional): An instance of
@ -1699,7 +1711,8 @@ class Collection(common.BaseObject):
- `skip` (int): The number of matching documents to skip before - `skip` (int): The number of matching documents to skip before
returning results. returning results.
- `limit` (int): The maximum number of documents to count. - `limit` (int): The maximum number of documents to count. A limit
of 0 (the default) is equivalent to setting no limit.
- `maxTimeMS` (int): The maximum amount of time to allow the count - `maxTimeMS` (int): The maximum amount of time to allow the count
command to run, in milliseconds. command to run, in milliseconds.
- `collation` (optional): An instance of - `collation` (optional): An instance of
@ -2301,7 +2314,8 @@ class Collection(common.BaseObject):
write_concern=write_concern, write_concern=write_concern,
collation=collation, collation=collation,
session=session, session=session,
client=self.__database.client) client=self.__database.client,
user_fields={'cursor': {'firstBatch': 1}})
if "cursor" in result: if "cursor" in result:
cursor = result["cursor"] cursor = result["cursor"]
@ -2559,7 +2573,8 @@ class Collection(common.BaseObject):
with self._socket_for_reads(session=None) as (sock_info, slave_ok): with self._socket_for_reads(session=None) as (sock_info, slave_ok):
return self._command(sock_info, cmd, slave_ok, return self._command(sock_info, cmd, slave_ok,
collation=collation)["retval"] collation=collation,
user_fields={'retval': 1})["retval"]
def rename(self, new_name, session=None, **kwargs): def rename(self, new_name, session=None, **kwargs):
"""Rename this collection. """Rename this collection.
@ -2663,7 +2678,9 @@ class Collection(common.BaseObject):
with self._socket_for_reads(session) as (sock_info, slave_ok): with self._socket_for_reads(session) as (sock_info, slave_ok):
return self._command(sock_info, cmd, slave_ok, return self._command(sock_info, cmd, slave_ok,
read_concern=self.read_concern, read_concern=self.read_concern,
collation=collation, session=session)["values"] collation=collation,
session=session,
user_fields={"values": 1})["values"]
def map_reduce(self, map, reduce, out, full_response=False, session=None, def map_reduce(self, map, reduce, out, full_response=False, session=None,
**kwargs): **kwargs):
@ -2743,12 +2760,17 @@ class Collection(common.BaseObject):
write_concern = self._write_concern_for(session) write_concern = self._write_concern_for(session)
else: else:
write_concern = None write_concern = None
if inline:
user_fields = {'results': 1}
else:
user_fields = None
response = self._command( response = self._command(
sock_info, cmd, slave_ok, read_pref, sock_info, cmd, slave_ok, read_pref,
read_concern=read_concern, read_concern=read_concern,
write_concern=write_concern, write_concern=write_concern,
collation=collation, session=session) collation=collation, session=session,
user_fields=user_fields)
if full_response or not response.get('result'): if full_response or not response.get('result'):
return response return response
@ -2798,16 +2820,19 @@ class Collection(common.BaseObject):
("map", map), ("map", map),
("reduce", reduce), ("reduce", reduce),
("out", {"inline": 1})]) ("out", {"inline": 1})])
user_fields = {'results': 1}
collation = validate_collation_or_none(kwargs.pop('collation', None)) collation = validate_collation_or_none(kwargs.pop('collation', None))
cmd.update(kwargs) cmd.update(kwargs)
with self._socket_for_reads(session) as (sock_info, slave_ok): with self._socket_for_reads(session) as (sock_info, slave_ok):
if sock_info.max_wire_version >= 4 and 'readConcern' not in cmd: if sock_info.max_wire_version >= 4 and 'readConcern' not in cmd:
res = self._command(sock_info, cmd, slave_ok, res = self._command(sock_info, cmd, slave_ok,
read_concern=self.read_concern, read_concern=self.read_concern,
collation=collation, session=session) collation=collation, session=session,
user_fields=user_fields)
else: else:
res = self._command(sock_info, cmd, slave_ok, res = self._command(sock_info, cmd, slave_ok,
collation=collation, session=session) collation=collation, session=session,
user_fields=user_fields)
if full_response: if full_response:
return res return res
@ -2825,6 +2850,7 @@ class Collection(common.BaseObject):
return_document=ReturnDocument.BEFORE, return_document=ReturnDocument.BEFORE,
array_filters=None, session=None, **kwargs): array_filters=None, session=None, **kwargs):
"""Internal findAndModify helper.""" """Internal findAndModify helper."""
common.validate_is_mapping("filter", filter) common.validate_is_mapping("filter", filter)
if not isinstance(return_document, bool): if not isinstance(return_document, bool):
raise ValueError("return_document must be " raise ValueError("return_document must be "
@ -2864,8 +2890,10 @@ class Collection(common.BaseObject):
write_concern=write_concern, write_concern=write_concern,
allowable_errors=[_NO_OBJ_ERROR], allowable_errors=[_NO_OBJ_ERROR],
collation=collation, session=session, collation=collation, session=session,
retryable_write=retryable_write) retryable_write=retryable_write,
user_fields=_FIND_AND_MODIFY_DOC_FIELDS)
_check_write_command_response(out) _check_write_command_response(out)
return out.get("value") return out.get("value")
return self.__database.client._retryable_write( return self.__database.client._retryable_write(
@ -3281,7 +3309,8 @@ class Collection(common.BaseObject):
result = self._command( result = self._command(
sock_info, cmd, read_preference=ReadPreference.PRIMARY, sock_info, cmd, read_preference=ReadPreference.PRIMARY,
allowable_errors=[_NO_OBJ_ERROR], collation=collation, allowable_errors=[_NO_OBJ_ERROR], collation=collation,
session=session, retryable_write=retryable_write) session=session, retryable_write=retryable_write,
user_fields=_FIND_AND_MODIFY_DOC_FIELDS)
_check_write_command_response(result) _check_write_command_response(result)
return result return result

View File

@ -149,9 +149,14 @@ class CommandCursor(object):
reply = response.data reply = response.data
try: try:
docs = self._unpack_response(reply, user_fields = None
self.__id, legacy_response = True
self.__collection.codec_options) if from_command:
user_fields = {'cursor': {'nextBatch': 1}}
legacy_response = False
docs = self._unpack_response(
reply, self.__id, self.__collection.codec_options,
legacy_response=legacy_response, user_fields=user_fields)
if from_command: if from_command:
first = docs[0] first = docs[0]
client._receive_cluster_time(first, self.__session) client._receive_cluster_time(first, self.__session)
@ -174,7 +179,7 @@ class CommandCursor(object):
listeners.publish_command_failure( listeners.publish_command_failure(
duration(), exc.details, "getMore", rqst_id, self.__address) duration(), exc.details, "getMore", rqst_id, self.__address)
client._reset_server_and_request_check(self.address) client._reset_server_and_request_check(self.__address)
raise raise
except Exception as exc: except Exception as exc:
if publish: if publish:
@ -208,8 +213,10 @@ class CommandCursor(object):
kill() kill()
self.__data = deque(documents) self.__data = deque(documents)
def _unpack_response(self, response, cursor_id, codec_options): def _unpack_response(self, response, cursor_id, codec_options,
return response.unpack_response(cursor_id, codec_options) user_fields=None, legacy_response=False):
return response.unpack_response(cursor_id, codec_options, user_fields,
legacy_response)
def _refresh(self): def _refresh(self):
"""Refreshes the cursor with more data from the server. """Refreshes the cursor with more data from the server.
@ -285,15 +292,24 @@ class CommandCursor(object):
def next(self): def next(self):
"""Advance the cursor.""" """Advance the cursor."""
# Block until a document is returnable. # Block until a document is returnable.
while not len(self.__data) and not self.__killed: while self.alive:
doc = self._try_next(True)
if doc is not None:
return doc
raise StopIteration
__next__ = next
def _try_next(self, get_more_allowed):
"""Advance the cursor blocking for at most one getMore command."""
if not len(self.__data) and not self.__killed and get_more_allowed:
self._refresh() self._refresh()
if len(self.__data): if len(self.__data):
coll = self.__collection coll = self.__collection
return coll.database._fix_outgoing(self.__data.popleft(), coll) return coll.database._fix_outgoing(self.__data.popleft(), coll)
else: else:
raise StopIteration return None
__next__ = next
def __enter__(self): def __enter__(self):
return self return self
@ -321,7 +337,8 @@ class RawBatchCommandCursor(CommandCursor):
collection, cursor_info, address, retrieved, batch_size, collection, cursor_info, address, retrieved, batch_size,
max_await_time_ms, session, explicit_session) max_await_time_ms, session, explicit_session)
def _unpack_response(self, response, cursor_id, codec_options): def _unpack_response(self, response, cursor_id, codec_options,
user_fields=None, legacy_response=False):
return response.raw_response(cursor_id) return response.raw_response(cursor_id)
def __getitem__(self, index): def __getitem__(self, index):

View File

@ -21,7 +21,7 @@ import warnings
from bson import SON from bson import SON
from bson.binary import (STANDARD, PYTHON_LEGACY, from bson.binary import (STANDARD, PYTHON_LEGACY,
JAVA_LEGACY, CSHARP_LEGACY) JAVA_LEGACY, CSHARP_LEGACY)
from bson.codec_options import CodecOptions from bson.codec_options import CodecOptions, TypeRegistry
from bson.py3compat import abc, integer_types, iteritems, string_type from bson.py3compat import abc, integer_types, iteritems, string_type
from bson.raw_bson import RawBSONDocument from bson.raw_bson import RawBSONDocument
from pymongo.auth import MECHANISMS from pymongo.auth import MECHANISMS
@ -422,6 +422,14 @@ def validate_document_class(option, value):
return value return value
def validate_type_registry(option, value):
"""Validate the type_registry option."""
if value is not None and not isinstance(value, TypeRegistry):
raise TypeError("%s must be an instance of %s" % (
option, TypeRegistry))
return value
def validate_list(option, value): def validate_list(option, value):
"""Validates that 'value' is a list.""" """Validates that 'value' is a list."""
if not isinstance(value, list): if not isinstance(value, list):
@ -576,6 +584,7 @@ TIMEOUT_VALIDATORS = {
KW_VALIDATORS = { KW_VALIDATORS = {
'document_class': validate_document_class, 'document_class': validate_document_class,
'type_registry': validate_type_registry,
'read_preference': validate_read_preference, 'read_preference': validate_read_preference,
'event_listeners': _validate_event_listeners, 'event_listeners': _validate_event_listeners,
'tzinfo': validate_tzinfo, 'tzinfo': validate_tzinfo,

View File

@ -40,7 +40,7 @@ from pymongo.message import (_convert_exception,
_RawBatchGetMore, _RawBatchGetMore,
_Query, _Query,
_RawBatchQuery) _RawBatchQuery)
from pymongo.read_preferences import ReadPreference
_QUERY_OPTIONS = { _QUERY_OPTIONS = {
"tailable_cursor": 2, "tailable_cursor": 2,
@ -50,6 +50,7 @@ _QUERY_OPTIONS = {
"await_data": 32, "await_data": 32,
"exhaust": 64, "exhaust": 64,
"partial": 128} "partial": 128}
_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': 1, 'nextBatch': 1}}
class CursorType(object): class CursorType(object):
@ -634,12 +635,19 @@ class Cursor(object):
return self return self
def max(self, spec): def max(self, spec):
"""Adds `max` operator that specifies upper bound for specific index. """Adds ``max`` operator that specifies upper bound for specific index.
When using ``max``, :meth:`~hint` should also be configured to ensure
the query uses the expected index and starting in MongoDB 4.2
:meth:`~hint` will be required.
:Parameters: :Parameters:
- `spec`: a list of field, limit pairs specifying the exclusive - `spec`: a list of field, limit pairs specifying the exclusive
upper bound for all keys of a specific index in order. upper bound for all keys of a specific index in order.
.. versionchanged:: 3.8
Deprecated cursors that use ``max`` without a :meth:`~hint`.
.. versionadded:: 2.7 .. versionadded:: 2.7
""" """
if not isinstance(spec, (list, tuple)): if not isinstance(spec, (list, tuple)):
@ -650,12 +658,19 @@ class Cursor(object):
return self return self
def min(self, spec): def min(self, spec):
"""Adds `min` operator that specifies lower bound for specific index. """Adds ``min`` operator that specifies lower bound for specific index.
When using ``min``, :meth:`~hint` should also be configured to ensure
the query uses the expected index and starting in MongoDB 4.2
:meth:`~hint` will be required.
:Parameters: :Parameters:
- `spec`: a list of field, limit pairs specifying the inclusive - `spec`: a list of field, limit pairs specifying the inclusive
lower bound for all keys of a specific index in order. lower bound for all keys of a specific index in order.
.. versionchanged:: 3.8
Deprecated cursors that use ``min`` without a :meth:`~hint`.
.. versionadded:: 2.7 .. versionadded:: 2.7
""" """
if not isinstance(spec, (list, tuple)): if not isinstance(spec, (list, tuple)):
@ -754,7 +769,7 @@ class Cursor(object):
if self.__max_time_ms is not None: if self.__max_time_ms is not None:
cmd["maxTimeMS"] = self.__max_time_ms cmd["maxTimeMS"] = self.__max_time_ms
if self.__comment: if self.__comment:
cmd["$comment"] = self.__comment cmd["comment"] = self.__comment
if self.__hint is not None: if self.__hint is not None:
cmd["hint"] = self.__hint cmd["hint"] = self.__hint
@ -791,7 +806,7 @@ class Cursor(object):
if self.__max_time_ms is not None: if self.__max_time_ms is not None:
options['maxTimeMS'] = self.__max_time_ms options['maxTimeMS'] = self.__max_time_ms
if self.__comment: if self.__comment:
options['$comment'] = self.__comment options['comment'] = self.__comment
if self.__collation is not None: if self.__collation is not None:
options['collation'] = self.__collation options['collation'] = self.__collation
@ -801,6 +816,13 @@ class Cursor(object):
def explain(self): def explain(self):
"""Returns an explain plan record for this cursor. """Returns an explain plan record for this cursor.
.. note:: Starting with MongoDB 3.2 :meth:`explain` uses
the default verbosity mode of the `explain command
<https://docs.mongodb.com/manual/reference/command/explain/>`_,
``allPlansExecution``. To use a different verbosity use
:meth:`~pymongo.database.Database.command` to run the explain
command directly.
.. mongodoc:: explain .. mongodoc:: explain
""" """
c = self.clone() c = self.clone()
@ -855,7 +877,8 @@ class Cursor(object):
http://docs.mongodb.org/manual/reference/operator/comment/ http://docs.mongodb.org/manual/reference/operator/comment/
:Parameters: :Parameters:
- `comment`: A string or document - `comment`: A string to attach to the query to help interpret and
trace the operation in the server logs and in profile data.
.. versionadded:: 2.7 .. versionadded:: 2.7
""" """
@ -973,9 +996,14 @@ class Cursor(object):
raise raise
try: try:
docs = self._unpack_response(response=reply, user_fields = None
cursor_id=self.__id, legacy_response = True
codec_options=self.__codec_options) if from_command:
user_fields = _CURSOR_DOC_FIELDS
legacy_response = False
docs = self._unpack_response(
reply, self.__id, self.__collection.codec_options,
legacy_response=legacy_response, user_fields=user_fields)
if from_command: if from_command:
first = docs[0] first = docs[0]
client._receive_cluster_time(first, self.__session) client._receive_cluster_time(first, self.__session)
@ -1063,8 +1091,10 @@ class Cursor(object):
if self.__limit and self.__id and self.__limit <= self.__retrieved: if self.__limit and self.__id and self.__limit <= self.__retrieved:
self.__die() self.__die()
def _unpack_response(self, response, cursor_id, codec_options): def _unpack_response(self, response, cursor_id, codec_options,
return response.unpack_response(cursor_id, codec_options) user_fields=None, legacy_response=False):
return response.unpack_response(cursor_id, codec_options, user_fields,
legacy_response)
def _read_preference(self): def _read_preference(self):
if self.__read_preference is None: if self.__read_preference is None:
@ -1087,6 +1117,12 @@ class Cursor(object):
self.__session = self.__collection.database.client._ensure_session() self.__session = self.__collection.database.client._ensure_session()
if self.__id is None: # Query if self.__id is None: # Query
if (self.__min or self.__max) and not self.__hint:
warnings.warn("using a min/max query operator without "
"specifying a Cursor.hint is deprecated. A "
"hint will be required when using min/max in "
"PyMongo 4.0",
DeprecationWarning, stacklevel=3)
q = self._query_class(self.__query_flags, q = self._query_class(self.__query_flags,
self.__collection.database.name, self.__collection.database.name,
self.__collection.name, self.__collection.name,
@ -1275,7 +1311,8 @@ class RawBatchCursor(Cursor):
raise InvalidOperation( raise InvalidOperation(
"Cannot use RawBatchCursor with manipulate=True") "Cannot use RawBatchCursor with manipulate=True")
def _unpack_response(self, response, cursor_id, codec_options): def _unpack_response(self, response, cursor_id, codec_options,
user_fields=None, legacy_response=False):
return response.raw_response(cursor_id) return response.raw_response(cursor_id)
def explain(self): def explain(self):

View File

@ -222,6 +222,46 @@ class Database(common.BaseObject):
return [manipulator.__class__.__name__ return [manipulator.__class__.__name__
for manipulator in self.__outgoing_copying_manipulators] for manipulator in self.__outgoing_copying_manipulators]
def with_options(self, codec_options=None, read_preference=None,
write_concern=None, read_concern=None):
"""Get a clone of this database changing the specified settings.
>>> db1.read_preference
Primary()
>>> from pymongo import ReadPreference
>>> db2 = db1.with_options(read_preference=ReadPreference.SECONDARY)
>>> db1.read_preference
Primary()
>>> db2.read_preference
Secondary(tag_sets=None)
:Parameters:
- `codec_options` (optional): An instance of
:class:`~bson.codec_options.CodecOptions`. If ``None`` (the
default) the :attr:`codec_options` of this :class:`Collection`
is used.
- `read_preference` (optional): The read preference to use. If
``None`` (the default) the :attr:`read_preference` of this
:class:`Collection` is used. See :mod:`~pymongo.read_preferences`
for options.
- `write_concern` (optional): An instance of
:class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the
default) the :attr:`write_concern` of this :class:`Collection`
is used.
- `read_concern` (optional): An instance of
:class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the
default) the :attr:`read_concern` of this :class:`Collection`
is used.
.. versionadded:: 3.8
"""
return Database(self.client,
self.__name,
codec_options or self.codec_options,
read_preference or self.read_preference,
write_concern or self.write_concern,
read_concern or self.read_concern)
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, Database): if isinstance(other, Database):
return (self.__client == other.client and return (self.__client == other.client and
@ -361,7 +401,8 @@ class Database(common.BaseObject):
Removed deprecated argument: options Removed deprecated argument: options
""" """
with self.__client._tmp_session(session) as s: with self.__client._tmp_session(session) as s:
if name in self.list_collection_names(session=s): if name in self.list_collection_names(
filter={"name": name}, session=s):
raise CollectionInvalid("collection %s already exists" % name) raise CollectionInvalid("collection %s already exists" % name)
return Collection(self, name, True, codec_options, return Collection(self, name, True, codec_options,
@ -575,6 +616,9 @@ class Database(common.BaseObject):
:attr:`read_preference` or :attr:`codec_options`. You must use the :attr:`read_preference` or :attr:`codec_options`. You must use the
`read_preference` and `codec_options` parameters instead. `read_preference` and `codec_options` parameters instead.
.. note:: :meth:`command` does **not** apply any custom TypeDecoders
when decoding the command response.
.. versionchanged:: 3.6 .. versionchanged:: 3.6
Added ``session`` parameter. Added ``session`` parameter.
@ -651,12 +695,14 @@ class Database(common.BaseObject):
cursor = self._command(sock_info, cmd, slave_okay)["cursor"] cursor = self._command(sock_info, cmd, slave_okay)["cursor"]
return CommandCursor(coll, cursor, sock_info.address) return CommandCursor(coll, cursor, sock_info.address)
def list_collections(self, session=None, **kwargs): def list_collections(self, session=None, filter=None, **kwargs):
"""Get a cursor over the collectons of this database. """Get a cursor over the collectons of this database.
:Parameters: :Parameters:
- `session` (optional): a - `session` (optional): a
:class:`~pymongo.client_session.ClientSession`. :class:`~pymongo.client_session.ClientSession`.
- `filter` (optional): A query document to filter the list of
collections returned from the listCollections command.
- `**kwargs` (optional): Optional parameters of the - `**kwargs` (optional): Optional parameters of the
`listCollections command `listCollections command
<https://docs.mongodb.com/manual/reference/command/listCollections/>`_ <https://docs.mongodb.com/manual/reference/command/listCollections/>`_
@ -668,6 +714,8 @@ class Database(common.BaseObject):
.. versionadded:: 3.6 .. versionadded:: 3.6
""" """
if filter is not None:
kwargs['filter'] = filter
read_pref = ((session and session._txn_read_preference()) read_pref = ((session and session._txn_read_preference())
or ReadPreference.PRIMARY) or ReadPreference.PRIMARY)
with self.__client._socket_for_reads( with self.__client._socket_for_reads(
@ -676,18 +724,42 @@ class Database(common.BaseObject):
sock_info, slave_okay, session, read_preference=read_pref, sock_info, slave_okay, session, read_preference=read_pref,
**kwargs) **kwargs)
def list_collection_names(self, session=None): def list_collection_names(self, session=None, filter=None, **kwargs):
"""Get a list of all the collection names in this database. """Get a list of all the collection names in this database.
For example, to list all non-system collections::
filter = {"name": {"$regex": r"^(?!system\.)"}}
db.list_collection_names(filter=filter)
:Parameters: :Parameters:
- `session` (optional): a - `session` (optional): a
:class:`~pymongo.client_session.ClientSession`. :class:`~pymongo.client_session.ClientSession`.
- `filter` (optional): A query document to filter the list of
collections returned from the listCollections command.
- `**kwargs` (optional): Optional parameters of the
`listCollections command
<https://docs.mongodb.com/manual/reference/command/listCollections/>`_
can be passed as keyword arguments to this method. The supported
options differ by server version.
.. versionchanged:: 3.8
Added the ``filter`` and ``**kwargs`` parameters.
.. versionadded:: 3.6 .. versionadded:: 3.6
""" """
if filter is None:
kwargs["nameOnly"] = True
else:
# The enumerate collections spec states that "drivers MUST NOT set
# nameOnly if a filter specifies any keys other than name."
common.validate_is_mapping("filter", filter)
kwargs["filter"] = filter
if not filter or (len(filter) == 1 and "name" in filter):
kwargs["nameOnly"] = True
return [result["name"] return [result["name"]
for result in self.list_collections(session=session, for result in self.list_collections(session=session, **kwargs)]
nameOnly=True)]
def collection_names(self, include_system_collections=True, def collection_names(self, include_system_collections=True,
session=None): session=None):
@ -828,9 +900,9 @@ class Database(common.BaseObject):
cmd = SON([("currentOp", 1), ("$all", include_all)]) cmd = SON([("currentOp", 1), ("$all", include_all)])
with self.__client._socket_for_writes() as sock_info: with self.__client._socket_for_writes() as sock_info:
if sock_info.max_wire_version >= 4: if sock_info.max_wire_version >= 4:
with self.__client._tmp_session(session) as s: return self.__client.admin._command(
return sock_info.command("admin", cmd, session=s, sock_info, cmd, codec_options=self.codec_options,
client=self.__client) session=session)
else: else:
spec = {"$all": True} if include_all else {} spec = {"$all": True} if include_all else {}
return _first_batch(sock_info, "admin", "$cmd.sys.inprog", return _first_batch(sock_info, "admin", "$cmd.sys.inprog",

View File

@ -289,7 +289,7 @@ class _Query(object):
cmd = SON([('explain', cmd)]) cmd = SON([('explain', cmd)])
session = self.session session = self.session
if session: if session:
session._apply_to(cmd, False, self.read_preference) session._apply_to(cmd, False, self.read_preference, sock_info)
# Explain does not support readConcern. # Explain does not support readConcern.
if (not explain and session.options.causal_consistency if (not explain and session.options.causal_consistency
and session.operation_time is not None and session.operation_time is not None
@ -379,7 +379,7 @@ class _GetMore(object):
self.max_await_time_ms) self.max_await_time_ms)
if self.session: if self.session:
self.session._apply_to(cmd, False, self.read_preference) self.session._apply_to(cmd, False, self.read_preference, sock_info)
sock_info.send_cluster_time(cmd, self.session, self.client) sock_info.send_cluster_time(cmd, self.session, self.client)
self._as_command = cmd, self.db self._as_command = cmd, self.db
return self._as_command return self._as_command
@ -1398,7 +1398,8 @@ class _OpReply(object):
return [self.documents] return [self.documents]
def unpack_response(self, cursor_id=None, def unpack_response(self, cursor_id=None,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS): codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
user_fields=None, legacy_response=False):
"""Unpack a response from the database and decode the BSON document(s). """Unpack a response from the database and decode the BSON document(s).
Check the response for errors and unpack, returning a dictionary Check the response for errors and unpack, returning a dictionary
@ -1415,7 +1416,10 @@ class _OpReply(object):
:class:`~bson.codec_options.CodecOptions` :class:`~bson.codec_options.CodecOptions`
""" """
self.raw_response(cursor_id) self.raw_response(cursor_id)
return bson.decode_all(self.documents, codec_options) if legacy_response:
return bson.decode_all(self.documents, codec_options)
return bson._decode_all_selective(
self.documents, codec_options, user_fields)
def command_response(self): def command_response(self):
"""Unpack a command response.""" """Unpack a command response."""
@ -1451,7 +1455,8 @@ class _OpMsg(object):
raise NotImplementedError raise NotImplementedError
def unpack_response(self, cursor_id=None, def unpack_response(self, cursor_id=None,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS): codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
user_fields=None, legacy_response=False):
"""Unpack a OP_MSG command response. """Unpack a OP_MSG command response.
:Parameters: :Parameters:
@ -1459,7 +1464,10 @@ class _OpMsg(object):
- `codec_options` (optional): an instance of - `codec_options` (optional): an instance of
:class:`~bson.codec_options.CodecOptions` :class:`~bson.codec_options.CodecOptions`
""" """
return bson.decode_all(self.payload_document, codec_options) # If _OpMsg is in-use, this cannot be a legacy response.
assert not legacy_response
return bson._decode_all_selective(
self.payload_document, codec_options, user_fields)
def command_response(self): def command_response(self):
"""Unpack a command response.""" """Unpack a command response."""

View File

@ -39,7 +39,7 @@ import weakref
from collections import defaultdict from collections import defaultdict
from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.codec_options import DEFAULT_CODEC_OPTIONS, TypeRegistry
from bson.py3compat import (integer_types, from bson.py3compat import (integer_types,
string_type) string_type)
from bson.son import SON from bson.son import SON
@ -97,6 +97,7 @@ class MongoClient(common.BaseObject):
document_class=dict, document_class=dict,
tz_aware=None, tz_aware=None,
connect=None, connect=None,
type_registry=None,
**kwargs): **kwargs):
"""Client for a MongoDB instance, a replica set, or a set of mongoses. """Client for a MongoDB instance, a replica set, or a set of mongoses.
@ -187,6 +188,9 @@ class MongoClient(common.BaseObject):
- `port` (optional): port number on which to connect - `port` (optional): port number on which to connect
- `document_class` (optional): default class to use for - `document_class` (optional): default class to use for
documents returned from queries on this client documents returned from queries on this client
- `type_registry` (optional): instance of
:class:`~bson.codec_options.TypeRegistry` to enable encoding
and decoding of custom types.
- `tz_aware` (optional): if ``True``, - `tz_aware` (optional): if ``True``,
:class:`~datetime.datetime` instances returned as values :class:`~datetime.datetime` instances returned as values
in a document by this :class:`MongoClient` will be timezone in a document by this :class:`MongoClient` will be timezone
@ -287,6 +291,11 @@ class MongoClient(common.BaseObject):
are -1 through 9. -1 tells the zlib library to use its default are -1 through 9. -1 tells the zlib library to use its default
compression level (usually 6). 0 means no compression. 1 is best compression level (usually 6). 0 means no compression. 1 is best
speed. 9 is best compression. Defaults to -1. speed. 9 is best compression. Defaults to -1.
- `uuidRepresentation`: The BSON representation to use when encoding
from and decoding to instances of :class:`~uuid.UUID`. Valid
values are `pythonLegacy` (the default), `javaLegacy`,
`csharpLegacy` and `standard`. New applications should consider
setting this to `standard` for cross language compatibility.
| **Write Concern options:** | **Write Concern options:**
| (Only set if passed. No default values.) | (Only set if passed. No default values.)
@ -421,6 +430,7 @@ class MongoClient(common.BaseObject):
.. versionchanged:: 3.8 .. versionchanged:: 3.8
Added the ``server_selector`` keyword argument. Added the ``server_selector`` keyword argument.
Added the ``type_registry`` keyword argument.
.. versionchanged:: 3.7 .. versionchanged:: 3.7
Added the ``driver`` keyword argument. Added the ``driver`` keyword argument.
@ -530,6 +540,8 @@ class MongoClient(common.BaseObject):
keyword_opts = kwargs keyword_opts = kwargs
keyword_opts['document_class'] = document_class keyword_opts['document_class'] = document_class
if type_registry is not None:
keyword_opts['type_registry'] = type_registry
if tz_aware is None: if tz_aware is None:
tz_aware = opts.get('tz_aware', False) tz_aware = opts.get('tz_aware', False)
if connect is None: if connect is None:
@ -1523,6 +1535,9 @@ class MongoClient(common.BaseObject):
except Exception: except Exception:
helpers._handle_exception() helpers._handle_exception()
def _is_mongos_non_blocking(self):
return self._topology.is_mongos_non_blocking()
def __start_session(self, implicit, **kwargs): def __start_session(self, implicit, **kwargs):
# Driver Sessions Spec: "If startSession is called when multiple users # Driver Sessions Spec: "If startSession is called when multiple users
# are authenticated drivers MUST raise an error with the error message # are authenticated drivers MUST raise an error with the error message
@ -1745,8 +1760,9 @@ class MongoClient(common.BaseObject):
parse_write_concern_error=True, parse_write_concern_error=True,
session=session) session=session)
def get_default_database(self): def get_default_database(self, default=None, codec_options=None,
"""DEPRECATED - Get the database named in the MongoDB connection URI. read_preference=None, write_concern=None, read_concern=None):
"""Get the database named in the MongoDB connection URI.
>>> uri = 'mongodb://host/my_database' >>> uri = 'mongodb://host/my_database'
>>> client = MongoClient(uri) >>> client = MongoClient(uri)
@ -1758,15 +1774,41 @@ class MongoClient(common.BaseObject):
Useful in scripts where you want to choose which database to use Useful in scripts where you want to choose which database to use
based only on the URI in a configuration file. based only on the URI in a configuration file.
:Parameters:
- `default` (optional): the database name to use if no database name
was provided in the URI.
- `codec_options` (optional): An instance of
:class:`~bson.codec_options.CodecOptions`. If ``None`` (the
default) the :attr:`codec_options` of this :class:`MongoClient` is
used.
- `read_preference` (optional): The read preference to use. If
``None`` (the default) the :attr:`read_preference` of this
:class:`MongoClient` is used. See :mod:`~pymongo.read_preferences`
for options.
- `write_concern` (optional): An instance of
:class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the
default) the :attr:`write_concern` of this :class:`MongoClient` is
used.
- `read_concern` (optional): An instance of
:class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the
default) the :attr:`read_concern` of this :class:`MongoClient` is
used.
.. versionchanged:: 3.8
Undeprecated. Added the ``default``, ``codec_options``,
``read_preference``, ``write_concern`` and ``read_concern``
parameters.
.. versionchanged:: 3.5 .. versionchanged:: 3.5
Deprecated, use :meth:`get_database` instead. Deprecated, use :meth:`get_database` instead.
""" """
warnings.warn("get_default_database is deprecated. Use get_database " if self.__default_database_name is None and default is None:
"instead.", DeprecationWarning, stacklevel=2) raise ConfigurationError(
if self.__default_database_name is None: 'No default database name defined or provided.')
raise ConfigurationError('No default database defined')
return self[self.__default_database_name] return database.Database(
self, self.__default_database_name or default, codec_options,
read_preference, write_concern, read_concern)
def get_database(self, name=None, codec_options=None, read_preference=None, def get_database(self, name=None, codec_options=None, read_preference=None,
write_concern=None, read_concern=None): write_concern=None, read_concern=None):

View File

@ -58,7 +58,8 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
collation=None, collation=None,
compression_ctx=None, compression_ctx=None,
use_op_msg=False, use_op_msg=False,
unacknowledged=False): unacknowledged=False,
user_fields=None):
"""Execute a command over the socket, or raise socket.error. """Execute a command over the socket, or raise socket.error.
:Parameters: :Parameters:
@ -81,6 +82,12 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
- `parse_write_concern_error`: Whether to parse the ``writeConcernError`` - `parse_write_concern_error`: Whether to parse the ``writeConcernError``
field in the command response. field in the command response.
- `collation`: The collation for this command. - `collation`: The collation for this command.
- `compression_ctx`: optional compression Context.
- `use_op_msg`: True if we should use OP_MSG.
- `unacknowledged`: True if this is an unacknowledged command.
- `user_fields` (optional): Response fields that should be decoded
using the TypeDecoders from codec_options, passed to
bson._decode_all_selective.
""" """
name = next(iter(spec)) name = next(iter(spec))
ns = dbname + '.$cmd' ns = dbname + '.$cmd'
@ -139,7 +146,8 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
response_doc = {"ok": 1} response_doc = {"ok": 1}
else: else:
reply = receive_message(sock, request_id) reply = receive_message(sock, request_id)
unpacked_docs = reply.unpack_response(codec_options=codec_options) unpacked_docs = reply.unpack_response(
codec_options=codec_options, user_fields=user_fields)
response_doc = unpacked_docs[0] response_doc = unpacked_docs[0]
if client: if client:

View File

@ -512,7 +512,8 @@ class SocketInfo(object):
session=None, session=None,
client=None, client=None,
retryable_write=False, retryable_write=False,
publish_events=True): publish_events=True,
user_fields=None):
"""Execute a command or raise an error. """Execute a command or raise an error.
:Parameters: :Parameters:
@ -533,6 +534,9 @@ class SocketInfo(object):
- `client`: optional MongoClient for gossipping $clusterTime. - `client`: optional MongoClient for gossipping $clusterTime.
- `retryable_write`: True if this command is a retryable write. - `retryable_write`: True if this command is a retryable write.
- `publish_events`: Should we publish events for this command? - `publish_events`: Should we publish events for this command?
- `user_fields` (optional): Response fields that should be decoded
using the TypeDecoders from codec_options, passed to
bson._decode_all_selective.
""" """
self.validate_session(client, session) self.validate_session(client, session)
session = _validate_session_write_concern(session, write_concern) session = _validate_session_write_concern(session, write_concern)
@ -560,7 +564,7 @@ class SocketInfo(object):
'Must be connected to MongoDB 3.4+ to use a collation.') 'Must be connected to MongoDB 3.4+ to use a collation.')
if session: if session:
session._apply_to(spec, retryable_write, read_preference) session._apply_to(spec, retryable_write, read_preference, self)
self.send_cluster_time(spec, session, client) self.send_cluster_time(spec, session, client)
listeners = self.listeners if publish_events else None listeners = self.listeners if publish_events else None
unacknowledged = write_concern and not write_concern.acknowledged unacknowledged = write_concern and not write_concern.acknowledged
@ -576,7 +580,8 @@ class SocketInfo(object):
collation=collation, collation=collation,
compression_ctx=self.compression_context, compression_ctx=self.compression_context,
use_op_msg=self.op_msg_enabled, use_op_msg=self.op_msg_enabled,
unacknowledged=unacknowledged) unacknowledged=unacknowledged,
user_fields=user_fields)
except OperationFailure: except OperationFailure:
raise raise
# Catch socket.error, KeyboardInterrupt, etc. and close ourselves. # Catch socket.error, KeyboardInterrupt, etc. and close ourselves.

View File

@ -128,6 +128,8 @@ if HAVE_SSL:
ctx.options |= getattr(ssl, "OP_NO_SSLv3", 0) ctx.options |= getattr(ssl, "OP_NO_SSLv3", 0)
# OpenSSL >= 1.0.0 # OpenSSL >= 1.0.0
ctx.options |= getattr(ssl, "OP_NO_COMPRESSION", 0) ctx.options |= getattr(ssl, "OP_NO_COMPRESSION", 0)
# Python 3.7+ with OpenSSL >= 1.1.0h
ctx.options |= getattr(ssl, "OP_NO_RENEGOTIATION", 0)
if certfile is not None: if certfile is not None:
try: try:
if passphrase is not None: if passphrase is not None:

View File

@ -30,6 +30,7 @@ from pymongo import common
from pymongo import periodic_executor from pymongo import periodic_executor
from pymongo.pool import PoolOptions from pymongo.pool import PoolOptions
from pymongo.topology_description import (updated_topology_description, from pymongo.topology_description import (updated_topology_description,
SERVER_TYPE,
TOPOLOGY_TYPE, TOPOLOGY_TYPE,
TopologyDescription) TopologyDescription)
from pymongo.errors import ServerSelectionTimeoutError, ConfigurationError from pymongo.errors import ServerSelectionTimeoutError, ConfigurationError
@ -451,6 +452,22 @@ class Topology(object):
# Called from a __del__ method, can't use a lock. # Called from a __del__ method, can't use a lock.
self._session_pool.return_server_session_no_lock(server_session) self._session_pool.return_server_session_no_lock(server_session)
def is_mongos_non_blocking(self):
"""Return if we are connected to a Mongos without blocking.
If the state is unknown, return False.
"""
with self._lock:
if not self._opened:
return False
if self._description.topology_type == TOPOLOGY_TYPE.Sharded:
return True
server_descriptions = self._description.apply_selector(
writable_server_selector, None)
if not server_descriptions:
return False
return server_descriptions[0].server_type == SERVER_TYPE.Mongos
def _new_selection(self): def _new_selection(self):
"""A Selection object, initially including all known servers. """A Selection object, initially including all known servers.

View File

@ -39,7 +39,7 @@ except ImportError:
except ImportError: except ImportError:
_HAVE_SPHINX = False _HAVE_SPHINX = False
version = "3.8.0.dev0" version = "3.8.0"
f = open("README.rst") f = open("README.rst")
try: try:

View File

@ -604,6 +604,12 @@ class ClientContext(object):
"""Does the connected server support getpreverror?""" """Does the connected server support getpreverror?"""
return not (self.version.at_least(4, 1, 0) or self.is_mongos) return not (self.version.at_least(4, 1, 0) or self.is_mongos)
@property
def requires_hint_with_min_max_queries(self):
"""Does the server require a hint with min/max queries."""
# Changed in SERVER-39567.
return self.version.at_least(4, 1, 10)
# Reusable client context # Reusable client context
client_context = ClientContext() client_context = ClientContext()
@ -659,7 +665,7 @@ class IntegrationTest(PyMongoTestCase):
# Use assertRaisesRegex if available, otherwise use Python 2.7's # Use assertRaisesRegex if available, otherwise use Python 2.7's
# deprecated assertRaisesRegexp, with a 'p'. # deprecated assertRaisesRegexp, with a 'p'.
if not hasattr(unittest.TestCase, 'assertRaisesRegex'): if not hasattr(unittest.TestCase, 'assertRaisesRegex'):
IntegrationTest.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
class MockClientTest(unittest.TestCase): class MockClientTest(unittest.TestCase):

View File

@ -41,7 +41,6 @@ from bson.dbref import DBRef
from bson.py3compat import abc, iteritems, PY3, StringIO, text_type from bson.py3compat import abc, iteritems, PY3, StringIO, text_type
from bson.son import SON from bson.son import SON
from bson.timestamp import Timestamp from bson.timestamp import Timestamp
from bson.tz_util import FixedOffset
from bson.errors import (InvalidBSON, from bson.errors import (InvalidBSON,
InvalidDocument, InvalidDocument,
InvalidStringData) InvalidStringData)
@ -51,6 +50,7 @@ from bson.tz_util import (FixedOffset,
utc) utc)
from test import qcheck, SkipTest, unittest from test import qcheck, SkipTest, unittest
from test.utils import ExceptionCatchingThread
if PY3: if PY3:
long = int long = int
@ -586,12 +586,12 @@ class TestBSON(unittest.TestCase):
def test_small_long_encode_decode(self): def test_small_long_encode_decode(self):
encoded1 = BSON.encode({'x': 256}) encoded1 = BSON.encode({'x': 256})
decoded1 = BSON.decode(encoded1)['x'] decoded1 = BSON(encoded1).decode()['x']
self.assertEqual(256, decoded1) self.assertEqual(256, decoded1)
self.assertEqual(type(256), type(decoded1)) self.assertEqual(type(256), type(decoded1))
encoded2 = BSON.encode({'x': Int64(256)}) encoded2 = BSON.encode({'x': Int64(256)})
decoded2 = BSON.decode(encoded2)['x'] decoded2 = BSON(encoded2).decode()['x']
expected = Int64(256) expected = Int64(256)
self.assertEqual(expected, decoded2) self.assertEqual(expected, decoded2)
self.assertEqual(type(expected), type(decoded2)) self.assertEqual(type(expected), type(decoded2))
@ -905,6 +905,38 @@ class TestBSON(unittest.TestCase):
{"_id": {'$oid': "52d0b971b3ba219fdeb4170e"}}, True) {"_id": {'$oid': "52d0b971b3ba219fdeb4170e"}}, True)
BSON.encode({"_id": {'$oid': "52d0b971b3ba219fdeb4170e"}}) BSON.encode({"_id": {'$oid': "52d0b971b3ba219fdeb4170e"}})
def test_bson_encode_thread_safe(self):
def target(i):
for j in range(1000):
my_int = type('MyInt_%s_%s' % (i, j), (int,), {})
bson.BSON.encode({'my_int': my_int()})
threads = [ExceptionCatchingThread(target=target, args=(i,))
for i in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
for t in threads:
self.assertIsNone(t.exc)
def test_raise_invalid_document(self):
class Wrapper(object):
def __init__(self, val):
self.val = val
def __repr__(self):
return repr(self.val)
self.assertEqual('1', repr(Wrapper(1)))
with self.assertRaisesRegex(
InvalidDocument,
"cannot encode object: 1, of type: " + repr(Wrapper)):
BSON.encode({'t': Wrapper(1)})
class TestCodecOptions(unittest.TestCase): class TestCodecOptions(unittest.TestCase):
def test_document_class(self): def test_document_class(self):
@ -931,7 +963,8 @@ class TestCodecOptions(unittest.TestCase):
r = ("CodecOptions(document_class=dict, tz_aware=False, " r = ("CodecOptions(document_class=dict, tz_aware=False, "
"uuid_representation=PYTHON_LEGACY, " "uuid_representation=PYTHON_LEGACY, "
"unicode_decode_error_handler='strict', " "unicode_decode_error_handler='strict', "
"tzinfo=None)") "tzinfo=None, type_registry=TypeRegistry(type_codecs=[], "
"fallback_encoder=None))")
self.assertEqual(r, repr(CodecOptions())) self.assertEqual(r, repr(CodecOptions()))
def test_decode_all_defaults(self): def test_decode_all_defaults(self):
@ -956,63 +989,63 @@ class TestCodecOptions(unittest.TestCase):
replaced_key = b'ke\xe9str'.decode('utf-8', 'replace') replaced_key = b'ke\xe9str'.decode('utf-8', 'replace')
ignored_key = b'ke\xe9str'.decode('utf-8', 'ignore') ignored_key = b'ke\xe9str'.decode('utf-8', 'ignore')
dec = BSON.decode(invalid_key, CodecOptions( dec = BSON(invalid_key).decode(CodecOptions(
unicode_decode_error_handler="replace")) unicode_decode_error_handler="replace"))
self.assertEqual(dec, {replaced_key: u"foobar"}) self.assertEqual(dec, {replaced_key: u"foobar"})
dec = BSON.decode(invalid_key, CodecOptions( dec = BSON(invalid_key).decode(CodecOptions(
unicode_decode_error_handler="ignore")) unicode_decode_error_handler="ignore"))
self.assertEqual(dec, {ignored_key: u"foobar"}) self.assertEqual(dec, {ignored_key: u"foobar"})
self.assertRaises(InvalidBSON, BSON.decode, invalid_key, CodecOptions( self.assertRaises(InvalidBSON, BSON(invalid_key).decode, CodecOptions(
unicode_decode_error_handler="strict")) unicode_decode_error_handler="strict"))
self.assertRaises(InvalidBSON, BSON.decode, invalid_key, self.assertRaises(InvalidBSON, BSON(invalid_key).decode,
CodecOptions()) CodecOptions())
self.assertRaises(InvalidBSON, BSON.decode, invalid_key) self.assertRaises(InvalidBSON, BSON(invalid_key).decode)
# Test handing of bad string value. # Test handing of bad string value.
invalid_val = BSON(enc[:18] + b'\xe9' + enc[19:]) invalid_val = BSON(enc[:18] + b'\xe9' + enc[19:])
replaced_val = b'fo\xe9bar'.decode('utf-8', 'replace') replaced_val = b'fo\xe9bar'.decode('utf-8', 'replace')
ignored_val = b'fo\xe9bar'.decode('utf-8', 'ignore') ignored_val = b'fo\xe9bar'.decode('utf-8', 'ignore')
dec = BSON.decode(invalid_val, CodecOptions( dec = BSON(invalid_val).decode(CodecOptions(
unicode_decode_error_handler="replace")) unicode_decode_error_handler="replace"))
self.assertEqual(dec, {u"keystr": replaced_val}) self.assertEqual(dec, {u"keystr": replaced_val})
dec = BSON.decode(invalid_val, CodecOptions( dec = BSON(invalid_val).decode(CodecOptions(
unicode_decode_error_handler="ignore")) unicode_decode_error_handler="ignore"))
self.assertEqual(dec, {u"keystr": ignored_val}) self.assertEqual(dec, {u"keystr": ignored_val})
self.assertRaises(InvalidBSON, BSON.decode, invalid_val, CodecOptions( self.assertRaises(InvalidBSON, BSON(invalid_val).decode, CodecOptions(
unicode_decode_error_handler="strict")) unicode_decode_error_handler="strict"))
self.assertRaises(InvalidBSON, BSON.decode, invalid_val, self.assertRaises(InvalidBSON, BSON(invalid_val).decode,
CodecOptions()) CodecOptions())
self.assertRaises(InvalidBSON, BSON.decode, invalid_val) self.assertRaises(InvalidBSON, BSON(invalid_val).decode)
# Test handing bad key + bad value. # Test handing bad key + bad value.
invalid_both = BSON( invalid_both = BSON(
enc[:7] + b'\xe9' + enc[8:18] + b'\xe9' + enc[19:]) enc[:7] + b'\xe9' + enc[8:18] + b'\xe9' + enc[19:])
dec = BSON.decode(invalid_both, CodecOptions( dec = BSON(invalid_both).decode(CodecOptions(
unicode_decode_error_handler="replace")) unicode_decode_error_handler="replace"))
self.assertEqual(dec, {replaced_key: replaced_val}) self.assertEqual(dec, {replaced_key: replaced_val})
dec = BSON.decode(invalid_both, CodecOptions( dec = BSON(invalid_both).decode(CodecOptions(
unicode_decode_error_handler="ignore")) unicode_decode_error_handler="ignore"))
self.assertEqual(dec, {ignored_key: ignored_val}) self.assertEqual(dec, {ignored_key: ignored_val})
self.assertRaises(InvalidBSON, BSON.decode, invalid_both, CodecOptions( self.assertRaises(InvalidBSON, BSON(invalid_both).decode, CodecOptions(
unicode_decode_error_handler="strict")) unicode_decode_error_handler="strict"))
self.assertRaises(InvalidBSON, BSON.decode, invalid_both, self.assertRaises(InvalidBSON, BSON(invalid_both).decode,
CodecOptions()) CodecOptions())
self.assertRaises(InvalidBSON, BSON.decode, invalid_both) self.assertRaises(InvalidBSON, BSON(invalid_both).decode)
# Test handling bad error mode. # Test handling bad error mode.
dec = BSON.decode(enc, CodecOptions( dec = BSON(enc).decode(CodecOptions(
unicode_decode_error_handler="junk")) unicode_decode_error_handler="junk"))
self.assertEqual(dec, {"keystr": "foobar"}) self.assertEqual(dec, {"keystr": "foobar"})
self.assertRaises(InvalidBSON, BSON.decode, invalid_both, self.assertRaises(InvalidBSON, BSON(invalid_both).decode,
CodecOptions(unicode_decode_error_handler="junk")) CodecOptions(unicode_decode_error_handler="junk"))

View File

@ -36,13 +36,13 @@ from bson.binary import (ALL_UUID_REPRESENTATIONS,
from bson.py3compat import iteritems from bson.py3compat import iteritems
from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument
from pymongo import monitoring
from pymongo.change_stream import _NON_RESUMABLE_GETMORE_ERRORS from pymongo.change_stream import _NON_RESUMABLE_GETMORE_ERRORS
from pymongo.command_cursor import CommandCursor from pymongo.command_cursor import CommandCursor
from pymongo.errors import (InvalidOperation, OperationFailure, from pymongo.errors import (InvalidOperation, OperationFailure,
ServerSelectionTimeoutError) ServerSelectionTimeoutError)
from pymongo.message import _CursorAddress from pymongo.message import _CursorAddress
from pymongo.read_concern import ReadConcern from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern
from test import client_context, unittest, IntegrationTest from test import client_context, unittest, IntegrationTest
from test.utils import ( from test.utils import (
@ -50,7 +50,97 @@ from test.utils import (
) )
class TestClusterChangeStream(IntegrationTest): class ChangeStreamTryNextMixin(object):
def change_stream_with_client(self, client, *args, **kwargs):
raise NotImplementedError
def change_stream(self, *args, **kwargs):
return self.change_stream_with_client(self.client, *args, **kwargs)
def watched_collection(self):
"""Return a collection that is watched by self.change_stream()."""
raise NotImplementedError
def kill_change_stream_cursor(self, change_stream):
# Cause a cursor not found error on the next getMore.
cursor = change_stream._cursor
address = _CursorAddress(cursor.address, cursor._CommandCursor__ns)
client = self.watched_collection().database.client
client._close_cursor_now(cursor.cursor_id, address)
def test_try_next(self):
# ChangeStreams only read majority committed data so use w:majority.
coll = self.watched_collection().with_options(
write_concern=WriteConcern("majority"))
coll.drop()
coll.insert_one({})
self.addCleanup(coll.drop)
with self.change_stream(max_await_time_ms=250) as stream:
self.assertIsNone(stream.try_next())
self.assertIsNone(stream._resume_token)
coll.insert_one({})
change = stream.try_next()
self.assertEqual(change['_id'], stream._resume_token)
self.assertIsNone(stream.try_next())
self.assertEqual(change['_id'], stream._resume_token)
def test_try_next_runs_one_getmore(self):
listener = EventListener()
client = rs_or_single_client(event_listeners=[listener])
# Connect to the cluster.
client.admin.command('ping')
listener.results.clear()
# ChangeStreams only read majority committed data so use w:majority.
coll = self.watched_collection().with_options(
write_concern=WriteConcern("majority"))
coll.drop()
# Create the watched collection before starting the change stream to
# skip any "create" events.
coll.insert_one({'_id': 1})
self.addCleanup(coll.drop)
with self.change_stream_with_client(
client, max_await_time_ms=250) as stream:
self.assertEqual(listener.started_command_names(), ["aggregate"])
listener.results.clear()
# Confirm that only a single getMore is run even when no documents
# are returned.
self.assertIsNone(stream.try_next())
self.assertEqual(listener.started_command_names(), ["getMore"])
listener.results.clear()
self.assertIsNone(stream.try_next())
self.assertEqual(listener.started_command_names(), ["getMore"])
listener.results.clear()
# Get at least one change before resuming.
coll.insert_one({'_id': 2})
change = stream.try_next()
self.assertEqual(change['_id'], stream._resume_token)
listener.results.clear()
# Cause the next request to initiate the resume process.
self.kill_change_stream_cursor(stream)
listener.results.clear()
# The sequence should be:
# - getMore, fail
# - resume with aggregate command
# - no results, return immediately without another getMore
self.assertIsNone(stream.try_next())
self.assertEqual(
listener.started_command_names(), ["getMore", "aggregate"])
listener.results.clear()
# Stream still works after a resume.
coll.insert_one({'_id': 3})
change = stream.try_next()
self.assertEqual(change['_id'], stream._resume_token)
self.assertEqual(listener.started_command_names(), ["getMore"])
self.assertIsNone(stream.try_next())
class TestClusterChangeStream(IntegrationTest, ChangeStreamTryNextMixin):
@classmethod @classmethod
@client_context.require_version_min(4, 0, 0, -1) @client_context.require_version_min(4, 0, 0, -1)
@ -66,8 +156,11 @@ class TestClusterChangeStream(IntegrationTest):
cls.client.drop_database(db) cls.client.drop_database(db)
super(TestClusterChangeStream, cls).tearDownClass() super(TestClusterChangeStream, cls).tearDownClass()
def change_stream(self, *args, **kwargs): def change_stream_with_client(self, client, *args, **kwargs):
return self.client.watch(*args, **kwargs) return client.watch(*args, **kwargs)
def watched_collection(self):
return self.db.test
def generate_unique_collnames(self, numcolls): def generate_unique_collnames(self, numcolls):
# Generate N collection names unique to a test. # Generate N collection names unique to a test.
@ -94,7 +187,7 @@ class TestClusterChangeStream(IntegrationTest):
) )
class TestDatabaseChangeStream(IntegrationTest): class TestDatabaseChangeStream(IntegrationTest, ChangeStreamTryNextMixin):
@classmethod @classmethod
@client_context.require_version_min(4, 0, 0, -1) @client_context.require_version_min(4, 0, 0, -1)
@ -103,8 +196,11 @@ class TestDatabaseChangeStream(IntegrationTest):
def setUpClass(cls): def setUpClass(cls):
super(TestDatabaseChangeStream, cls).setUpClass() super(TestDatabaseChangeStream, cls).setUpClass()
def change_stream(self, *args, **kwargs): def change_stream_with_client(self, client, *args, **kwargs):
return self.db.watch(*args, **kwargs) return client[self.db.name].watch(*args, **kwargs)
def watched_collection(self):
return self.db.test
def generate_unique_collnames(self, numcolls): def generate_unique_collnames(self, numcolls):
# Generate N collection names unique to a test. # Generate N collection names unique to a test.
@ -145,7 +241,7 @@ class TestDatabaseChangeStream(IntegrationTest):
self.client.drop_database(other_db) self.client.drop_database(other_db)
class TestCollectionChangeStream(IntegrationTest): class TestCollectionChangeStream(IntegrationTest, ChangeStreamTryNextMixin):
@classmethod @classmethod
@client_context.require_version_min(3, 5, 11) @client_context.require_version_min(3, 5, 11)
@ -171,6 +267,12 @@ class TestCollectionChangeStream(IntegrationTest):
def tearDown(self): def tearDown(self):
self.coll.drop() self.coll.drop()
def change_stream_with_client(self, client, *args, **kwargs):
return client[self.db.name].test.watch(*args, **kwargs)
def watched_collection(self):
return self.db.test
def insert_and_check(self, change_stream, doc): def insert_and_check(self, change_stream, doc):
self.coll.insert_one(doc) self.coll.insert_one(doc)
change = next(change_stream) change = next(change_stream)
@ -319,9 +421,7 @@ class TestCollectionChangeStream(IntegrationTest):
with self.coll.watch([]) as change_stream: with self.coll.watch([]) as change_stream:
self.insert_and_check(change_stream, {'_id': 1}) self.insert_and_check(change_stream, {'_id': 1})
# Cause a cursor not found error on the next getMore. # Cause a cursor not found error on the next getMore.
cursor = change_stream._cursor self.kill_change_stream_cursor(change_stream)
address = _CursorAddress(cursor.address, self.coll.full_name)
self.client._close_cursor_now(cursor.cursor_id, address)
self.insert_and_check(change_stream, {'_id': 2}) self.insert_and_check(change_stream, {'_id': 2})
def test_does_not_resume_fatal_errors(self): def test_does_not_resume_fatal_errors(self):
@ -330,16 +430,16 @@ class TestCollectionChangeStream(IntegrationTest):
with self.coll.watch() as change_stream: with self.coll.watch() as change_stream:
self.coll.insert_one({}) self.coll.insert_one({})
def mock_next(*args, **kwargs): def mock_try_next(*args, **kwargs):
change_stream._cursor.close() change_stream._cursor.close()
raise OperationFailure('Mock server error', code=code) raise OperationFailure('Mock server error', code=code)
original_next = change_stream._cursor.next original_try_next = change_stream._cursor._try_next
change_stream._cursor.next = mock_next change_stream._cursor._try_next = mock_try_next
with self.assertRaises(OperationFailure): with self.assertRaises(OperationFailure):
next(change_stream) next(change_stream)
change_stream._cursor.next = original_next change_stream._cursor._try_next = original_try_next
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
next(change_stream) next(change_stream)
@ -368,8 +468,7 @@ class TestCollectionChangeStream(IntegrationTest):
self.insert_and_check(change_stream, {'_id': 1}) self.insert_and_check(change_stream, {'_id': 1})
# Cause a cursor not found error on the next getMore. # Cause a cursor not found error on the next getMore.
cursor = change_stream._cursor cursor = change_stream._cursor
address = _CursorAddress(cursor.address, self.coll.full_name) self.kill_change_stream_cursor(change_stream)
self.client._close_cursor_now(cursor.cursor_id, address)
cursor.close = raise_error cursor.close = raise_error
self.insert_and_check(change_stream, {'_id': 2}) self.insert_and_check(change_stream, {'_id': 2})

View File

@ -29,7 +29,7 @@ import warnings
sys.path[0:0] = [""] sys.path[0:0] = [""]
from bson import BSON from bson import BSON
from bson.codec_options import CodecOptions from bson.codec_options import CodecOptions, TypeEncoder, TypeRegistry
from bson.py3compat import thread from bson.py3compat import thread
from bson.son import SON from bson.son import SON
from bson.tz_util import utc from bson.tz_util import utc
@ -179,6 +179,42 @@ class ClientUnitTest(unittest.TestCase):
self.assertRaises(TypeError, iterate) self.assertRaises(TypeError, iterate)
def test_get_default_database(self):
c = rs_or_single_client("mongodb://%s:%d/foo" % (client_context.host,
client_context.port),
connect=False)
self.assertEqual(Database(c, 'foo'), c.get_default_database())
# Test that default doesn't override the URI value.
self.assertEqual(Database(c, 'foo'), c.get_default_database('bar'))
codec_options = CodecOptions(tz_aware=True)
write_concern = WriteConcern(w=2, j=True)
db = c.get_default_database(
None, codec_options, ReadPreference.SECONDARY, write_concern)
self.assertEqual('foo', db.name)
self.assertEqual(codec_options, db.codec_options)
self.assertEqual(ReadPreference.SECONDARY, db.read_preference)
self.assertEqual(write_concern, db.write_concern)
c = rs_or_single_client("mongodb://%s:%d/" % (client_context.host,
client_context.port),
connect=False)
self.assertEqual(Database(c, 'foo'), c.get_default_database('foo'))
def test_get_default_database_error(self):
# URI with no database.
c = rs_or_single_client("mongodb://%s:%d/" % (client_context.host,
client_context.port),
connect=False)
self.assertRaises(ConfigurationError, c.get_default_database)
def test_get_default_database_with_authsource(self):
# Ensure we distinguish database name from authSource.
uri = "mongodb://%s:%d/foo?authSource=src" % (
client_context.host, client_context.port)
c = rs_or_single_client(uri, connect=False)
self.assertEqual(Database(c, 'foo'), c.get_default_database())
def test_get_database_default(self): def test_get_database_default(self):
c = rs_or_single_client("mongodb://%s:%d/foo" % (client_context.host, c = rs_or_single_client("mongodb://%s:%d/foo" % (client_context.host,
client_context.port), client_context.port),
@ -248,14 +284,28 @@ class ClientUnitTest(unittest.TestCase):
self.assertEqual(options.pool_options.metadata, metadata) self.assertEqual(options.pool_options.metadata, metadata)
def test_kwargs_codec_options(self): def test_kwargs_codec_options(self):
class MyFloatType(object):
def __init__(self, x):
self.__x = x
@property
def x(self):
return self.__x
class MyFloatAsIntEncoder(TypeEncoder):
python_type = MyFloatType
def transform_python(self, value):
return int(value)
# Ensure codec options are passed in correctly # Ensure codec options are passed in correctly
document_class = SON document_class = SON
type_registry = TypeRegistry([MyFloatAsIntEncoder()])
tz_aware = True tz_aware = True
uuid_representation_label = 'javaLegacy' uuid_representation_label = 'javaLegacy'
unicode_decode_error_handler = 'ignore' unicode_decode_error_handler = 'ignore'
tzinfo = utc tzinfo = utc
c = MongoClient( c = MongoClient(
document_class=document_class, document_class=document_class,
type_registry=type_registry,
tz_aware=tz_aware, tz_aware=tz_aware,
uuidrepresentation=uuid_representation_label, uuidrepresentation=uuid_representation_label,
unicode_decode_error_handler=unicode_decode_error_handler, unicode_decode_error_handler=unicode_decode_error_handler,
@ -264,6 +314,7 @@ class ClientUnitTest(unittest.TestCase):
) )
self.assertEqual(c.codec_options.document_class, document_class) self.assertEqual(c.codec_options.document_class, document_class)
self.assertEqual(c.codec_options.type_registry, type_registry)
self.assertEqual(c.codec_options.tz_aware, tz_aware) self.assertEqual(c.codec_options.tz_aware, tz_aware)
self.assertEqual( self.assertEqual(
c.codec_options.uuid_representation, c.codec_options.uuid_representation,

View File

@ -30,6 +30,14 @@ class TestClientContext(unittest.TestCase):
'PYMONGO_MUST_CONNECT is set. Failed attempts:\n%s' % 'PYMONGO_MUST_CONNECT is set. Failed attempts:\n%s' %
(client_context.connection_attempt_info(),)) (client_context.connection_attempt_info(),))
def test_enableTestCommands_is_disabled(self):
if 'PYMONGO_DISABLE_TEST_COMMANDS' not in os.environ:
raise SkipTest('PYMONGO_DISABLE_TEST_COMMANDS is not set')
self.assertFalse(client_context.test_commands_enabled,
'enableTestCommands must be disabled when '
'PYMONGO_DISABLE_TEST_COMMANDS is set.')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -132,6 +132,12 @@ class TestCollection(IntegrationTest):
def tearDownClass(cls): def tearDownClass(cls):
cls.db.drop_collection("test_large_limit") cls.db.drop_collection("test_large_limit")
def setUp(self):
self.db.test.drop()
def tearDown(self):
self.db.test.drop()
@contextlib.contextmanager @contextlib.contextmanager
def write_concern_collection(self): def write_concern_collection(self):
if client_context.version.at_least(3, 3, 9) and client_context.is_rs: if client_context.version.at_least(3, 3, 9) and client_context.is_rs:
@ -1765,9 +1771,8 @@ class TestCollection(IntegrationTest):
db.test_large_limit.create_index([('x', 1)]) db.test_large_limit.create_index([('x', 1)])
my_str = "mongomongo" * 1000 my_str = "mongomongo" * 1000
for i in range(2000): db.test_large_limit.insert_many(
doc = {"x": i, "y": my_str} {"x": i, "y": my_str} for i in range(2000))
db.test_large_limit.insert_one(doc)
i = 0 i = 0
y = 0 y = 0
@ -1781,9 +1786,7 @@ class TestCollection(IntegrationTest):
def test_find_kwargs(self): def test_find_kwargs(self):
db = self.db db = self.db
db.drop_collection("test") db.drop_collection("test")
db.test.insert_many({"x": i} for i in range(10))
for i in range(10):
db.test.insert_one({"x": i})
self.assertEqual(10, db.test.count_documents({})) self.assertEqual(10, db.test.count_documents({}))
@ -1808,8 +1811,7 @@ class TestCollection(IntegrationTest):
self.assertEqual(0, db.test.count_documents({})) self.assertEqual(0, db.test.count_documents({}))
self.assertEqual(0, db.foo.count_documents({})) self.assertEqual(0, db.foo.count_documents({}))
for i in range(10): db.test.insert_many({"x": i} for i in range(10))
db.test.insert_one({"x": i})
self.assertEqual(10, db.test.count_documents({})) self.assertEqual(10, db.test.count_documents({}))
@ -2004,10 +2006,13 @@ class TestCollection(IntegrationTest):
self.db.test.insert_many([{"x": 1}, {"x": 2}]) self.db.test.insert_many([{"x": 1}, {"x": 2}])
self.db.test.create_index("x") self.db.test.create_index("x")
self.assertEqual(1, len(list(self.db.test.find({"$min": {"x": 2}, cursor = self.db.test.find({"$min": {"x": 2}, "$query": {}})
"$query": {}})))) if client_context.requires_hint_with_min_max_queries:
self.assertEqual(2, self.db.test.find({"$min": {"x": 2}, cursor = cursor.hint("x_1")
"$query": {}})[0]["x"])
docs = list(cursor)
self.assertEqual(1, len(docs))
self.assertEqual(2, docs[0]["x"])
def test_numerous_inserts(self): def test_numerous_inserts(self):
# Ensure we don't exceed server's 1000-document batch size limit. # Ensure we don't exceed server's 1000-document batch size limit.

View File

@ -21,6 +21,7 @@ import re
import sys import sys
import time import time
import threading import threading
import warnings
sys.path[0:0] = [""] sys.path[0:0] = [""]
@ -449,66 +450,102 @@ class TestCursor(IntegrationTest):
break break
self.assertRaises(InvalidOperation, a.limit, 5) self.assertRaises(InvalidOperation, a.limit, 5)
@ignore_deprecations # Ignore max without hint.
def test_max(self): def test_max(self):
db = self.db db = self.db
db.test.drop() db.test.drop()
db.test.create_index([("j", ASCENDING)]) j_index = [("j", ASCENDING)]
db.test.create_index(j_index)
db.test.insert_many([{"j": j, "k": j} for j in range(10)]) db.test.insert_many([{"j": j, "k": j} for j in range(10)])
cursor = db.test.find().max([("j", 3)]) def find(max_spec, expected_index):
cursor = db.test.find().max(max_spec)
if client_context.requires_hint_with_min_max_queries:
cursor = cursor.hint(expected_index)
return cursor
cursor = find([("j", 3)], j_index)
self.assertEqual(len(list(cursor)), 3) self.assertEqual(len(list(cursor)), 3)
# Tuple. # Tuple.
cursor = db.test.find().max((("j", 3), )) cursor = find((("j", 3),), j_index)
self.assertEqual(len(list(cursor)), 3) self.assertEqual(len(list(cursor)), 3)
# Compound index. # Compound index.
db.test.create_index([("j", ASCENDING), ("k", ASCENDING)]) index_keys = [("j", ASCENDING), ("k", ASCENDING)]
cursor = db.test.find().max([("j", 3), ("k", 3)]) db.test.create_index(index_keys)
cursor = find([("j", 3), ("k", 3)], index_keys)
self.assertEqual(len(list(cursor)), 3) self.assertEqual(len(list(cursor)), 3)
# Wrong order. # Wrong order.
cursor = db.test.find().max([("k", 3), ("j", 3)]) cursor = find([("k", 3), ("j", 3)], index_keys)
self.assertRaises(OperationFailure, list, cursor) self.assertRaises(OperationFailure, list, cursor)
# No such index. # No such index.
cursor = db.test.find().max([("k", 3)]) cursor = find([("k", 3)], "k")
self.assertRaises(OperationFailure, list, cursor) self.assertRaises(OperationFailure, list, cursor)
self.assertRaises(TypeError, db.test.find().max, 10) self.assertRaises(TypeError, db.test.find().max, 10)
self.assertRaises(TypeError, db.test.find().max, {"j": 10}) self.assertRaises(TypeError, db.test.find().max, {"j": 10})
@ignore_deprecations # Ignore min without hint.
def test_min(self): def test_min(self):
db = self.db db = self.db
db.test.drop() db.test.drop()
db.test.create_index([("j", ASCENDING)]) j_index = [("j", ASCENDING)]
db.test.create_index(j_index)
db.test.insert_many([{"j": j, "k": j} for j in range(10)]) db.test.insert_many([{"j": j, "k": j} for j in range(10)])
cursor = db.test.find().min([("j", 3)]) def find(min_spec, expected_index):
cursor = db.test.find().min(min_spec)
if client_context.requires_hint_with_min_max_queries:
cursor = cursor.hint(expected_index)
return cursor
cursor = find([("j", 3)], j_index)
self.assertEqual(len(list(cursor)), 7) self.assertEqual(len(list(cursor)), 7)
# Tuple. # Tuple.
cursor = db.test.find().min((("j", 3), )) cursor = find((("j", 3),), j_index)
self.assertEqual(len(list(cursor)), 7) self.assertEqual(len(list(cursor)), 7)
# Compound index. # Compound index.
db.test.create_index([("j", ASCENDING), ("k", ASCENDING)]) index_keys = [("j", ASCENDING), ("k", ASCENDING)]
cursor = db.test.find().min([("j", 3), ("k", 3)]) db.test.create_index(index_keys)
cursor = find([("j", 3), ("k", 3)], index_keys)
self.assertEqual(len(list(cursor)), 7) self.assertEqual(len(list(cursor)), 7)
# Wrong order. # Wrong order.
cursor = db.test.find().min([("k", 3), ("j", 3)]) cursor = find([("k", 3), ("j", 3)], index_keys)
self.assertRaises(OperationFailure, list, cursor) self.assertRaises(OperationFailure, list, cursor)
# No such index. # No such index.
cursor = db.test.find().min([("k", 3)]) cursor = find([("k", 3)], "k")
self.assertRaises(OperationFailure, list, cursor) self.assertRaises(OperationFailure, list, cursor)
self.assertRaises(TypeError, db.test.find().min, 10) self.assertRaises(TypeError, db.test.find().min, 10)
self.assertRaises(TypeError, db.test.find().min, {"j": 10}) self.assertRaises(TypeError, db.test.find().min, {"j": 10})
@client_context.require_version_max(4, 1, -1)
def test_min_max_without_hint(self):
coll = self.db.test
j_index = [("j", ASCENDING)]
coll.create_index(j_index)
with warnings.catch_warnings(record=True) as warns:
warnings.simplefilter("default", DeprecationWarning)
list(coll.find().min([("j", 3)]))
self.assertIn('using a min/max query operator', str(warns[0]))
# Ensure the warning is raised with the proper stack level.
del warns[:]
list(coll.find().min([("j", 3)]))
self.assertIn('using a min/max query operator', str(warns[0]))
del warns[:]
list(coll.find().max([("j", 3)]))
self.assertIn('using a min/max query operator', str(warns[0]))
def test_batch_size(self): def test_batch_size(self):
db = self.db db = self.db
db.test.drop() db.test.drop()
@ -1220,9 +1257,6 @@ class TestCursor(IntegrationTest):
@client_context.require_no_mongos @client_context.require_no_mongos
@ignore_deprecations @ignore_deprecations
def test_comment(self): def test_comment(self):
if client_context.auth_enabled:
raise SkipTest("SERVER-4754 - This test uses profiling.")
# MongoDB 3.1.5 changed the ns for commands. # MongoDB 3.1.5 changed the ns for commands.
regex = {'$regex': r'pymongo_test.(\$cmd|test)'} regex = {'$regex': r'pymongo_test.(\$cmd|test)'}
@ -1246,14 +1280,14 @@ class TestCursor(IntegrationTest):
op = self.db.system.profile.find({'ns': regex, op = self.db.system.profile.find({'ns': regex,
'op': 'command', 'op': 'command',
'command.count': 'test', 'command.count': 'test',
'command.$comment': 'foo'}) 'command.comment': 'foo'})
self.assertEqual(op.count(), 1) self.assertEqual(op.count(), 1)
self.db.test.find().comment('foo').distinct('type') self.db.test.find().comment('foo').distinct('type')
op = self.db.system.profile.find({'ns': regex, op = self.db.system.profile.find({'ns': regex,
'op': 'command', 'op': 'command',
'command.distinct': 'test', 'command.distinct': 'test',
'command.$comment': 'foo'}) 'command.comment': 'foo'})
self.assertEqual(op.count(), 1) self.assertEqual(op.count(), 1)
finally: finally:
self.db.set_profiling_level(OFF) self.db.set_profiling_level(OFF)

915
test/test_custom_types.py Normal file
View File

@ -0,0 +1,915 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test support for callbacks to encode/decode custom types."""
import datetime
import sys
import tempfile
from collections import OrderedDict
from decimal import Decimal
from random import random
sys.path[0:0] = [""]
from bson import (BSON,
Decimal128,
decode_all,
decode_file_iter,
decode_iter,
RE_TYPE,
_BUILT_IN_TYPES,
_dict_to_bson,
_bson_to_dict)
from bson.code import Code
from bson.codec_options import (CodecOptions, TypeCodec, TypeDecoder,
TypeEncoder, TypeRegistry)
from bson.errors import InvalidDocument
from bson.int64 import Int64
from bson.raw_bson import RawBSONDocument
from bson.py3compat import text_type
from gridfs import GridIn, GridOut
from pymongo.collection import ReturnDocument
from pymongo.errors import DuplicateKeyError
from pymongo.message import _CursorAddress
from test import client_context, unittest
from test.test_client import IntegrationTest
from test.utils import ignore_deprecations, rs_client
class DecimalEncoder(TypeEncoder):
@property
def python_type(self):
return Decimal
def transform_python(self, value):
return Decimal128(value)
class DecimalDecoder(TypeDecoder):
@property
def bson_type(self):
return Decimal128
def transform_bson(self, value):
return value.to_decimal()
class DecimalCodec(DecimalDecoder, DecimalEncoder):
pass
DECIMAL_CODECOPTS = CodecOptions(
type_registry=TypeRegistry([DecimalCodec()]))
class UndecipherableInt64Type(object):
def __init__(self, value):
self.value = value
def __eq__(self, other):
if isinstance(other, type(self)):
return self.value == other.value
# Does not compare equal to integers.
return False
class UndecipherableIntDecoder(TypeDecoder):
bson_type = Int64
def transform_bson(self, value):
return UndecipherableInt64Type(value)
class UndecipherableIntEncoder(TypeEncoder):
python_type = UndecipherableInt64Type
def transform_python(self, value):
return Int64(value.value)
UNINT_DECODER_CODECOPTS = CodecOptions(
type_registry=TypeRegistry([UndecipherableIntDecoder(), ]))
UNINT_CODECOPTS = CodecOptions(type_registry=TypeRegistry(
[UndecipherableIntDecoder(), UndecipherableIntEncoder()]))
class UppercaseTextDecoder(TypeDecoder):
bson_type = text_type
def transform_bson(self, value):
return value.upper()
UPPERSTR_DECODER_CODECOPTS = CodecOptions(type_registry=TypeRegistry(
[UppercaseTextDecoder(),]))
def type_obfuscating_decoder_factory(rt_type):
class ResumeTokenToNanDecoder(TypeDecoder):
bson_type = rt_type
def transform_bson(self, value):
return "NaN"
return ResumeTokenToNanDecoder
class CustomBSONTypeTests(object):
def roundtrip(self, doc):
bsonbytes = BSON().encode(doc, codec_options=self.codecopts)
rt_document = BSON(bsonbytes).decode(codec_options=self.codecopts)
self.assertEqual(doc, rt_document)
def test_encode_decode_roundtrip(self):
self.roundtrip({'average': Decimal('56.47')})
self.roundtrip({'average': {'b': Decimal('56.47')}})
self.roundtrip({'average': [Decimal('56.47')]})
self.roundtrip({'average': [[Decimal('56.47')]]})
self.roundtrip({'average': [{'b': Decimal('56.47')}]})
def test_decode_all(self):
documents = []
for dec in range(3):
documents.append({'average': Decimal('56.4%s' % (dec,))})
bsonstream = bytes()
for doc in documents:
bsonstream += BSON.encode(doc, codec_options=self.codecopts)
self.assertEqual(
decode_all(bsonstream, self.codecopts), documents)
def test__bson_to_dict(self):
document = {'average': Decimal('56.47')}
rawbytes = BSON.encode(document, codec_options=self.codecopts)
decoded_document = _bson_to_dict(rawbytes, self.codecopts)
self.assertEqual(document, decoded_document)
def test__dict_to_bson(self):
document = {'average': Decimal('56.47')}
rawbytes = BSON.encode(document, codec_options=self.codecopts)
encoded_document = _dict_to_bson(document, False, self.codecopts)
self.assertEqual(encoded_document, rawbytes)
def _generate_multidocument_bson_stream(self):
inp_num = [str(random() * 100)[:4] for _ in range(10)]
docs = [{'n': Decimal128(dec)} for dec in inp_num]
edocs = [{'n': Decimal(dec)} for dec in inp_num]
bsonstream = b""
for doc in docs:
bsonstream += BSON.encode(doc)
return edocs, bsonstream
def test_decode_iter(self):
expected, bson_data = self._generate_multidocument_bson_stream()
for expected_doc, decoded_doc in zip(
expected, decode_iter(bson_data, self.codecopts)):
self.assertEqual(expected_doc, decoded_doc)
def test_decode_file_iter(self):
expected, bson_data = self._generate_multidocument_bson_stream()
fileobj = tempfile.TemporaryFile()
fileobj.write(bson_data)
fileobj.seek(0)
for expected_doc, decoded_doc in zip(
expected, decode_file_iter(fileobj, self.codecopts)):
self.assertEqual(expected_doc, decoded_doc)
fileobj.close()
class TestCustomPythonBSONTypeToBSONMonolithicCodec(CustomBSONTypeTests,
unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.codecopts = DECIMAL_CODECOPTS
class TestCustomPythonBSONTypeToBSONMultiplexedCodec(CustomBSONTypeTests,
unittest.TestCase):
@classmethod
def setUpClass(cls):
codec_options = CodecOptions(
type_registry=TypeRegistry((DecimalEncoder(), DecimalDecoder())))
cls.codecopts = codec_options
class TestBSONFallbackEncoder(unittest.TestCase):
def _get_codec_options(self, fallback_encoder):
type_registry = TypeRegistry(fallback_encoder=fallback_encoder)
return CodecOptions(type_registry=type_registry)
def test_simple(self):
codecopts = self._get_codec_options(lambda x: Decimal128(x))
document = {'average': Decimal('56.47')}
bsonbytes = BSON().encode(document, codec_options=codecopts)
exp_document = {'average': Decimal128('56.47')}
exp_bsonbytes = BSON().encode(exp_document)
self.assertEqual(bsonbytes, exp_bsonbytes)
def test_erroring_fallback_encoder(self):
codecopts = self._get_codec_options(lambda _: 1/0)
# fallback converter should not be invoked when encoding known types.
BSON().encode(
{'a': 1, 'b': Decimal128('1.01'), 'c': {'arr': ['abc', 3.678]}},
codec_options=codecopts)
# expect an error when encoding a custom type.
document = {'average': Decimal('56.47')}
with self.assertRaises(ZeroDivisionError):
BSON().encode(document, codec_options=codecopts)
def test_noop_fallback_encoder(self):
codecopts = self._get_codec_options(lambda x: x)
document = {'average': Decimal('56.47')}
with self.assertRaises(InvalidDocument):
BSON().encode(document, codec_options=codecopts)
def test_type_unencodable_by_fallback_encoder(self):
def fallback_encoder(value):
try:
return Decimal128(value)
except:
raise TypeError("cannot encode type %s" % (type(value)))
codecopts = self._get_codec_options(fallback_encoder)
document = {'average': Decimal}
with self.assertRaises(TypeError):
BSON().encode(document, codec_options=codecopts)
class TestBSONTypeEnDeCodecs(unittest.TestCase):
def test_instantiation(self):
msg = "Can't instantiate abstract class .* with abstract methods .*"
def run_test(base, attrs, fail):
codec = type('testcodec', (base,), attrs)
if fail:
with self.assertRaisesRegex(TypeError, msg):
codec()
else:
codec()
class MyType(object):
pass
run_test(TypeEncoder, {'python_type': MyType,}, fail=True)
run_test(TypeEncoder, {'transform_python': lambda s, x: x}, fail=True)
run_test(TypeEncoder, {'transform_python': lambda s, x: x,
'python_type': MyType}, fail=False)
run_test(TypeDecoder, {'bson_type': Decimal128, }, fail=True)
run_test(TypeDecoder, {'transform_bson': lambda s, x: x}, fail=True)
run_test(TypeDecoder, {'transform_bson': lambda s, x: x,
'bson_type': Decimal128}, fail=False)
run_test(TypeCodec, {'bson_type': Decimal128,
'python_type': MyType}, fail=True)
run_test(TypeCodec, {'transform_bson': lambda s, x: x,
'transform_python': lambda s, x: x}, fail=True)
run_test(TypeCodec, {'python_type': MyType,
'transform_python': lambda s, x: x,
'transform_bson': lambda s, x: x,
'bson_type': Decimal128}, fail=False)
def test_type_checks(self):
self.assertTrue(issubclass(TypeCodec, TypeEncoder))
self.assertTrue(issubclass(TypeCodec, TypeDecoder))
self.assertFalse(issubclass(TypeDecoder, TypeEncoder))
self.assertFalse(issubclass(TypeEncoder, TypeDecoder))
class TestBSONCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase):
@classmethod
def setUpClass(cls):
class TypeA(object):
def __init__(self, x):
self.value = x
class TypeB(object):
def __init__(self, x):
self.value = x
# transforms A, and only A into B
def fallback_encoder_A2B(value):
assert isinstance(value, TypeA)
return TypeB(value.value)
# transforms A, and only A into something encodable
def fallback_encoder_A2BSON(value):
assert isinstance(value, TypeA)
return value.value
# transforms B into something encodable
class B2BSON(TypeEncoder):
python_type = TypeB
def transform_python(self, value):
return value.value
# transforms A into B
# technically, this isn't a proper type encoder as the output is not
# BSON-encodable.
class A2B(TypeEncoder):
python_type = TypeA
def transform_python(self, value):
return TypeB(value.value)
# transforms B into A
# technically, this isn't a proper type encoder as the output is not
# BSON-encodable.
class B2A(TypeEncoder):
python_type = TypeB
def transform_python(self, value):
return TypeA(value.value)
cls.TypeA = TypeA
cls.TypeB = TypeB
cls.fallback_encoder_A2B = staticmethod(fallback_encoder_A2B)
cls.fallback_encoder_A2BSON = staticmethod(fallback_encoder_A2BSON)
cls.B2BSON = B2BSON
cls.B2A = B2A
cls.A2B = A2B
def test_encode_fallback_then_custom(self):
codecopts = CodecOptions(type_registry=TypeRegistry(
[self.B2BSON()], fallback_encoder=self.fallback_encoder_A2B))
testdoc = {'x': self.TypeA(123)}
expected_bytes = BSON.encode({'x': 123})
self.assertEqual(BSON.encode(testdoc, codec_options=codecopts),
expected_bytes)
def test_encode_custom_then_fallback(self):
codecopts = CodecOptions(type_registry=TypeRegistry(
[self.B2A()], fallback_encoder=self.fallback_encoder_A2BSON))
testdoc = {'x': self.TypeB(123)}
expected_bytes = BSON.encode({'x': 123})
self.assertEqual(BSON.encode(testdoc, codec_options=codecopts),
expected_bytes)
def test_chaining_encoders_fails(self):
codecopts = CodecOptions(type_registry=TypeRegistry(
[self.A2B(), self.B2BSON()]))
with self.assertRaises(InvalidDocument):
BSON.encode({'x': self.TypeA(123)}, codec_options=codecopts)
def test_infinite_loop_exceeds_max_recursion_depth(self):
codecopts = CodecOptions(type_registry=TypeRegistry(
[self.B2A()], fallback_encoder=self.fallback_encoder_A2B))
# Raises max recursion depth exceeded error
with self.assertRaises(RuntimeError):
BSON.encode({'x': self.TypeA(100)}, codec_options=codecopts)
class TestTypeRegistry(unittest.TestCase):
@classmethod
def setUpClass(cls):
class MyIntType(object):
def __init__(self, x):
assert isinstance(x, int)
self.x = x
class MyStrType(object):
def __init__(self, x):
assert isinstance(x, str)
self.x = x
class MyIntCodec(TypeCodec):
@property
def python_type(self):
return MyIntType
@property
def bson_type(self):
return int
def transform_python(self, value):
return value.x
def transform_bson(self, value):
return MyIntType(value)
class MyStrCodec(TypeCodec):
@property
def python_type(self):
return MyStrType
@property
def bson_type(self):
return str
def transform_python(self, value):
return value.x
def transform_bson(self, value):
return MyStrType(value)
def fallback_encoder(value):
return value
cls.types = (MyIntType, MyStrType)
cls.codecs = (MyIntCodec, MyStrCodec)
cls.fallback_encoder = fallback_encoder
def test_simple(self):
codec_instances = [codec() for codec in self.codecs]
def assert_proper_initialization(type_registry, codec_instances):
self.assertEqual(type_registry._encoder_map, {
self.types[0]: codec_instances[0].transform_python,
self.types[1]: codec_instances[1].transform_python})
self.assertEqual(type_registry._decoder_map, {
int: codec_instances[0].transform_bson,
str: codec_instances[1].transform_bson})
self.assertEqual(
type_registry._fallback_encoder, self.fallback_encoder)
type_registry = TypeRegistry(codec_instances, self.fallback_encoder)
assert_proper_initialization(type_registry, codec_instances)
type_registry = TypeRegistry(
fallback_encoder=self.fallback_encoder, type_codecs=codec_instances)
assert_proper_initialization(type_registry, codec_instances)
# Ensure codec list held by the type registry doesn't change if we
# mutate the initial list.
codec_instances_copy = list(codec_instances)
codec_instances.pop(0)
self.assertListEqual(
type_registry._TypeRegistry__type_codecs, codec_instances_copy)
def test_simple_separate_codecs(self):
class MyIntEncoder(TypeEncoder):
python_type = self.types[0]
def transform_python(self, value):
return value.x
class MyIntDecoder(TypeDecoder):
bson_type = int
def transform_bson(self, value):
return self.types[0](value)
codec_instances = [MyIntDecoder(), MyIntEncoder()]
type_registry = TypeRegistry(codec_instances)
self.assertEqual(
type_registry._encoder_map,
{MyIntEncoder.python_type: codec_instances[1].transform_python})
self.assertEqual(
type_registry._decoder_map,
{MyIntDecoder.bson_type: codec_instances[0].transform_bson})
def test_initialize_fail(self):
err_msg = ("Expected an instance of TypeEncoder, TypeDecoder, "
"or TypeCodec, got .* instead")
with self.assertRaisesRegex(TypeError, err_msg):
TypeRegistry(self.codecs)
with self.assertRaisesRegex(TypeError, err_msg):
TypeRegistry([type('AnyType', (object,), {})()])
err_msg = "fallback_encoder %r is not a callable" % (True,)
with self.assertRaisesRegex(TypeError, err_msg):
TypeRegistry([], True)
err_msg = "fallback_encoder %r is not a callable" % ('hello',)
with self.assertRaisesRegex(TypeError, err_msg):
TypeRegistry(fallback_encoder='hello')
def test_type_registry_repr(self):
codec_instances = [codec() for codec in self.codecs]
type_registry = TypeRegistry(codec_instances)
r = ("TypeRegistry(type_codecs=%r, fallback_encoder=%r)" % (
codec_instances, None))
self.assertEqual(r, repr(type_registry))
def test_type_registry_eq(self):
codec_instances = [codec() for codec in self.codecs]
self.assertEqual(
TypeRegistry(codec_instances), TypeRegistry(codec_instances))
codec_instances_2 = [codec() for codec in self.codecs]
self.assertNotEqual(
TypeRegistry(codec_instances), TypeRegistry(codec_instances_2))
def test_builtin_types_override_fails(self):
def run_test(base, attrs):
msg = ("TypeEncoders cannot change how built-in types "
"are encoded \(encoder .* transforms type .*\)")
for pytype in _BUILT_IN_TYPES:
attrs.update({'python_type': pytype,
'transform_python': lambda x: x})
codec = type('testcodec', (base, ), attrs)
codec_instance = codec()
with self.assertRaisesRegex(TypeError, msg):
TypeRegistry([codec_instance,])
# Test only some subtypes as not all can be subclassed.
if pytype in [bool, type(None), RE_TYPE,]:
continue
class MyType(pytype):
pass
attrs.update({'python_type': MyType,
'transform_python': lambda x: x})
codec = type('testcodec', (base, ), attrs)
codec_instance = codec()
with self.assertRaisesRegex(TypeError, msg):
TypeRegistry([codec_instance,])
run_test(TypeEncoder, {})
run_test(TypeCodec, {'bson_type': Decimal128,
'transform_bson': lambda x: x})
class TestCollectionWCustomType(IntegrationTest):
def setUp(self):
self.db.test.drop()
def tearDown(self):
self.db.test.drop()
def test_command_errors_w_custom_type_decoder(self):
db = self.db
test_doc = {'_id': 1, 'data': 'a'}
test = db.get_collection('test',
codec_options=UNINT_DECODER_CODECOPTS)
result = test.insert_one(test_doc)
self.assertEqual(result.inserted_id, test_doc['_id'])
with self.assertRaises(DuplicateKeyError):
test.insert_one(test_doc)
def test_find_w_custom_type_decoder(self):
db = self.db
input_docs = [
{'x': Int64(k)} for k in [1, 2, 3]]
for doc in input_docs:
db.test.insert_one(doc)
test = db.get_collection(
'test', codec_options=UNINT_DECODER_CODECOPTS)
for doc in test.find({}, batch_size=1):
self.assertIsInstance(doc['x'], UndecipherableInt64Type)
def test_find_w_custom_type_decoder_and_document_class(self):
def run_test(doc_cls):
db = self.db
input_docs = [
{'x': Int64(k)} for k in [1, 2, 3]]
for doc in input_docs:
db.test.insert_one(doc)
test = db.get_collection('test', codec_options=CodecOptions(
type_registry=TypeRegistry([UndecipherableIntDecoder()]),
document_class=doc_cls))
for doc in test.find({}, batch_size=1):
self.assertIsInstance(doc, doc_cls)
self.assertIsInstance(doc['x'], UndecipherableInt64Type)
for doc_cls in [RawBSONDocument, OrderedDict]:
run_test(doc_cls)
@client_context.require_version_max(4, 1, 0, -1)
def test_group_w_custom_type(self):
db = self.db
test = db.get_collection('test', codec_options=UNINT_CODECOPTS)
test.insert_many([
{'sku': 'a', 'qty': UndecipherableInt64Type(2)},
{'sku': 'b', 'qty': UndecipherableInt64Type(5)},
{'sku': 'a', 'qty': UndecipherableInt64Type(1)}])
self.assertEqual([{'sku': 'b', 'qty': UndecipherableInt64Type(5)},],
test.group(["sku", "qty"], {"sku": "b"}, {},
"function (obj, prev) { }"))
def test_aggregate_w_custom_type_decoder(self):
db = self.db
db.test.insert_many([
{'status': 'in progress', 'qty': Int64(1)},
{'status': 'complete', 'qty': Int64(10)},
{'status': 'in progress', 'qty': Int64(1)},
{'status': 'complete', 'qty': Int64(10)},
{'status': 'in progress', 'qty': Int64(1)},])
test = db.get_collection(
'test', codec_options=UNINT_DECODER_CODECOPTS)
pipeline = [
{'$match': {'status': 'complete'}},
{'$group': {'_id': "$status", 'total_qty': {"$sum": "$qty"}}},]
result = test.aggregate(pipeline)
res = list(result)[0]
self.assertEqual(res['_id'], 'complete')
self.assertIsInstance(res['total_qty'], UndecipherableInt64Type)
self.assertEqual(res['total_qty'].value, 20)
def test_distinct_w_custom_type(self):
self.db.drop_collection("test")
test = self.db.get_collection('test', codec_options=UNINT_CODECOPTS)
values = [
UndecipherableInt64Type(1),
UndecipherableInt64Type(2),
UndecipherableInt64Type(3),
{"b": UndecipherableInt64Type(3)}]
test.insert_many({"a": val} for val in values)
self.assertEqual(values, test.distinct("a"))
def test_map_reduce_w_custom_type(self):
test = self.db.get_collection(
'test', codec_options=UPPERSTR_DECODER_CODECOPTS)
test.insert_many([
{'_id': 1, 'sku': 'abcd', 'qty': 1},
{'_id': 2, 'sku': 'abcd', 'qty': 2},
{'_id': 3, 'sku': 'abcd', 'qty': 3}])
map = Code("function () {"
" emit(this.sku, this.qty);"
"}")
reduce = Code("function (key, values) {"
" return Array.sum(values);"
"}")
result = test.map_reduce(map, reduce, out={'inline': 1})
self.assertTrue(isinstance(result, dict))
self.assertTrue('results' in result)
self.assertEqual(result['results'][0], {'_id': 'ABCD', 'value': 6})
result = test.inline_map_reduce(map, reduce)
self.assertTrue(isinstance(result, list))
self.assertEqual(1, len(result))
self.assertEqual(result[0]["_id"], 'ABCD')
full_result = test.inline_map_reduce(map, reduce,
full_response=True)
self.assertEqual(3, full_result["counts"]["emit"])
def test_find_one_and__w_custom_type_decoder(self):
db = self.db
c = db.get_collection('test', codec_options=UNINT_DECODER_CODECOPTS)
c.insert_one({'_id': 1, 'x': Int64(1)})
doc = c.find_one_and_update({'_id': 1}, {'$inc': {'x': 1}},
return_document=ReturnDocument.AFTER)
self.assertEqual(doc['_id'], 1)
self.assertIsInstance(doc['x'], UndecipherableInt64Type)
self.assertEqual(doc['x'].value, 2)
doc = c.find_one_and_replace({'_id': 1}, {'x': Int64(3), 'y': True},
return_document=ReturnDocument.AFTER)
self.assertEqual(doc['_id'], 1)
self.assertIsInstance(doc['x'], UndecipherableInt64Type)
self.assertEqual(doc['x'].value, 3)
self.assertEqual(doc['y'], True)
doc = c.find_one_and_delete({'y': True})
self.assertEqual(doc['_id'], 1)
self.assertIsInstance(doc['x'], UndecipherableInt64Type)
self.assertEqual(doc['x'].value, 3)
self.assertIsNone(c.find_one())
@ignore_deprecations
def test_find_and_modify_w_custom_type_decoder(self):
db = self.db
c = db.get_collection('test', codec_options=UNINT_DECODER_CODECOPTS)
c.insert_one({'_id': 1, 'x': Int64(1)})
doc = c.find_and_modify({'_id': 1}, {'$inc': {'x': Int64(10)}})
self.assertEqual(doc['_id'], 1)
self.assertIsInstance(doc['x'], UndecipherableInt64Type)
self.assertEqual(doc['x'].value, 1)
doc = c.find_one()
self.assertEqual(doc['_id'], 1)
self.assertIsInstance(doc['x'], UndecipherableInt64Type)
self.assertEqual(doc['x'].value, 11)
class TestGridFileCustomType(IntegrationTest):
def setUp(self):
self.db.drop_collection('fs.files')
self.db.drop_collection('fs.chunks')
def test_grid_out_custom_opts(self):
db = self.db.with_options(codec_options=UPPERSTR_DECODER_CODECOPTS)
one = GridIn(db.fs, _id=5, filename="my_file",
contentType="text/html", chunkSize=1000, aliases=["foo"],
metadata={"foo": 'red', "bar": 'blue'}, bar=3,
baz="hello")
one.write(b"hello world")
one.close()
two = GridOut(db.fs, 5)
self.assertEqual("my_file", two.name)
self.assertEqual("my_file", two.filename)
self.assertEqual(5, two._id)
self.assertEqual(11, two.length)
self.assertEqual("text/html", two.content_type)
self.assertEqual(1000, two.chunk_size)
self.assertTrue(isinstance(two.upload_date, datetime.datetime))
self.assertEqual(["foo"], two.aliases)
self.assertEqual({"foo": 'red', "bar": 'blue'}, two.metadata)
self.assertEqual(3, two.bar)
self.assertEqual("5eb63bbbe01eeed093cb22bb8f5acdc3", two.md5)
for attr in ["_id", "name", "content_type", "length", "chunk_size",
"upload_date", "aliases", "metadata", "md5"]:
self.assertRaises(AttributeError, setattr, two, attr, 5)
class ChangeStreamsWCustomTypesTestMixin(object):
def change_stream(self, *args, **kwargs):
return self.watched_target.watch(*args, **kwargs)
def insert_and_check(self, change_stream, insert_doc,
expected_doc):
self.input_target.insert_one(insert_doc)
change = next(change_stream)
self.assertEqual(change['fullDocument'], expected_doc)
def kill_change_stream_cursor(self, change_stream):
# Cause a cursor not found error on the next getMore.
cursor = change_stream._cursor
address = _CursorAddress(cursor.address, cursor._CommandCursor__ns)
client = self.input_target.database.client
client._close_cursor_now(cursor.cursor_id, address)
def test_simple(self):
codecopts = CodecOptions(type_registry=TypeRegistry([
UndecipherableIntEncoder(), UppercaseTextDecoder()]))
self.create_targets(codec_options=codecopts)
input_docs = [
{'_id': UndecipherableInt64Type(1), 'data': 'hello'},
{'_id': 2, 'data': 'world'},
{'_id': UndecipherableInt64Type(3), 'data': '!'},]
expected_docs = [
{'_id': 1, 'data': 'HELLO'},
{'_id': 2, 'data': 'WORLD'},
{'_id': 3, 'data': '!'},]
change_stream = self.change_stream()
self.insert_and_check(change_stream, input_docs[0], expected_docs[0])
self.kill_change_stream_cursor(change_stream)
self.insert_and_check(change_stream, input_docs[1], expected_docs[1])
self.kill_change_stream_cursor(change_stream)
self.insert_and_check(change_stream, input_docs[2], expected_docs[2])
def test_custom_type_in_pipeline(self):
codecopts = CodecOptions(type_registry=TypeRegistry([
UndecipherableIntEncoder(), UppercaseTextDecoder()]))
self.create_targets(codec_options=codecopts)
input_docs = [
{'_id': UndecipherableInt64Type(1), 'data': 'hello'},
{'_id': 2, 'data': 'world'},
{'_id': UndecipherableInt64Type(3), 'data': '!'}]
expected_docs = [
{'_id': 2, 'data': 'WORLD'},
{'_id': 3, 'data': '!'}]
# UndecipherableInt64Type should be encoded with the TypeRegistry.
change_stream = self.change_stream(
[{'$match': {'documentKey._id': {
'$gte': UndecipherableInt64Type(2)}}}])
self.input_target.insert_one(input_docs[0])
self.insert_and_check(change_stream, input_docs[1], expected_docs[0])
self.kill_change_stream_cursor(change_stream)
self.insert_and_check(change_stream, input_docs[2], expected_docs[1])
def test_break_resume_token(self):
# Get one document from a change stream to determine resumeToken type.
self.create_targets()
change_stream = self.change_stream()
self.input_target.insert_one({"data": "test"})
change = next(change_stream)
resume_token_decoder = type_obfuscating_decoder_factory(
type(change['_id']['_data']))
# Custom-decoding the resumeToken type breaks resume tokens.
codecopts = CodecOptions(type_registry=TypeRegistry([
resume_token_decoder(), UndecipherableIntEncoder()]))
# Re-create targets, change stream and proceed.
self.create_targets(codec_options=codecopts)
docs = [{'_id': 1}, {'_id': 2}, {'_id': 3}]
change_stream = self.change_stream()
self.insert_and_check(change_stream, docs[0], docs[0])
self.kill_change_stream_cursor(change_stream)
self.insert_and_check(change_stream, docs[1], docs[1])
self.kill_change_stream_cursor(change_stream)
self.insert_and_check(change_stream, docs[2], docs[2])
def test_document_class(self):
def run_test(doc_cls):
codecopts = CodecOptions(type_registry=TypeRegistry([
UppercaseTextDecoder(), UndecipherableIntEncoder()]),
document_class=doc_cls)
self.create_targets(codec_options=codecopts)
change_stream = self.change_stream()
doc = {'a': UndecipherableInt64Type(101), 'b': 'xyz'}
self.input_target.insert_one(doc)
change = next(change_stream)
self.assertIsInstance(change, doc_cls)
self.assertEqual(change['fullDocument']['a'], 101)
self.assertEqual(change['fullDocument']['b'], 'XYZ')
for doc_cls in [OrderedDict, RawBSONDocument]:
run_test(doc_cls)
class TestCollectionChangeStreamsWCustomTypes(
IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
@classmethod
@client_context.require_version_min(3, 6, 0)
@client_context.require_no_mmap
@client_context.require_no_standalone
def setUpClass(cls):
super(TestCollectionChangeStreamsWCustomTypes, cls).setUpClass()
def tearDown(self):
self.input_target.drop()
def create_targets(self, *args, **kwargs):
self.watched_target = self.db.get_collection(
'test', *args, **kwargs)
self.input_target = self.watched_target
# Insert a record to ensure db, coll are created.
self.input_target.insert_one({'data': 'dummy'})
class TestDatabaseChangeStreamsWCustomTypes(
IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
@classmethod
@client_context.require_version_min(4, 0, 0)
@client_context.require_no_mmap
@client_context.require_no_standalone
def setUpClass(cls):
super(TestDatabaseChangeStreamsWCustomTypes, cls).setUpClass()
def tearDown(self):
self.input_target.drop()
self.client.drop_database(self.watched_target)
def create_targets(self, *args, **kwargs):
self.watched_target = self.client.get_database(
self.db.name, *args, **kwargs)
self.input_target = self.watched_target.test
# Insert a record to ensure db, coll are created.
self.input_target.insert_one({'data': 'dummy'})
class TestClusterChangeStreamsWCustomTypes(
IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
@classmethod
@client_context.require_version_min(4, 0, 0)
@client_context.require_no_mmap
@client_context.require_no_standalone
def setUpClass(cls):
super(TestClusterChangeStreamsWCustomTypes, cls).setUpClass()
def tearDown(self):
self.input_target.drop()
self.client.drop_database(self.db)
def create_targets(self, *args, **kwargs):
codec_options = kwargs.pop('codec_options', None)
if codec_options:
kwargs['type_registry'] = codec_options.type_registry
kwargs['document_class'] = codec_options.document_class
self.watched_target = rs_client(*args, **kwargs)
self.input_target = self.watched_target[self.db.name].test
# Insert a record to ensure db, coll are created.
self.input_target.insert_one({'data': 'dummy'})
if __name__ == "__main__":
unittest.main()

View File

@ -56,7 +56,9 @@ from test.utils import (ignore_deprecations,
rs_or_single_client_noauth, rs_or_single_client_noauth,
rs_or_single_client, rs_or_single_client,
server_started_with_auth, server_started_with_auth,
IMPOSSIBLE_WRITE_CONCERN) IMPOSSIBLE_WRITE_CONCERN,
OvertCommandListener)
from test.test_custom_types import DECIMAL_CODECOPTS
if PY3: if PY3:
@ -156,7 +158,7 @@ class TestDatabase(IntegrationTest):
self.assertTrue(u"test.foo" in db.list_collection_names()) self.assertTrue(u"test.foo" in db.list_collection_names())
self.assertRaises(CollectionInvalid, db.create_collection, "test.foo") self.assertRaises(CollectionInvalid, db.create_collection, "test.foo")
def _test_collection_names(self, meth, test_no_system): def _test_collection_names(self, meth, **no_system_kwargs):
db = Database(self.client, "pymongo_test") db = Database(self.client, "pymongo_test")
db.test.insert_one({"dummy": u"object"}) db.test.insert_one({"dummy": u"object"})
db.test.mike.insert_one({"dummy": u"object"}) db.test.mike.insert_one({"dummy": u"object"})
@ -167,13 +169,11 @@ class TestDatabase(IntegrationTest):
for coll in colls: for coll in colls:
self.assertTrue("$" not in coll) self.assertTrue("$" not in coll)
if test_no_system: db.systemcoll.test.insert_one({})
db.systemcoll.test.insert_one({}) no_system_collections = getattr(db, meth)(**no_system_kwargs)
no_system_collections = getattr( for coll in no_system_collections:
db, meth)(include_system_collections=False) self.assertTrue(not coll.startswith("system."))
for coll in no_system_collections: self.assertIn("systemcoll.test", no_system_collections)
self.assertTrue(not coll.startswith("system."))
self.assertIn("systemcoll.test", no_system_collections)
# Force more than one batch. # Force more than one batch.
db = self.client.many_collections db = self.client.many_collections
@ -186,10 +186,45 @@ class TestDatabase(IntegrationTest):
self.client.drop_database("many_collections") self.client.drop_database("many_collections")
def test_collection_names(self): def test_collection_names(self):
self._test_collection_names('collection_names', True) self._test_collection_names(
'collection_names', include_system_collections=False)
def test_list_collection_names(self): def test_list_collection_names(self):
self._test_collection_names('list_collection_names', False) self._test_collection_names(
'list_collection_names', filter={
"name": {"$regex": r"^(?!system\.)"}})
def test_list_collection_names_filter(self):
listener = OvertCommandListener()
results = listener.results
client = rs_or_single_client(event_listeners=[listener])
db = client[self.db.name]
db.capped.drop()
db.create_collection("capped", capped=True, size=4096)
db.capped.insert_one({})
db.non_capped.insert_one({})
self.addCleanup(client.drop_database, db.name)
# Should not send nameOnly.
for filter in ({'options.capped': True},
{'options.capped': True, 'name': 'capped'}):
results.clear()
names = db.list_collection_names(filter=filter)
self.assertEqual(names, ["capped"])
self.assertNotIn("nameOnly", results["started"][0].command)
# Should send nameOnly (except on 2.6).
for filter in (None, {}, {'name': {'$in': ['capped', 'non_capped']}}):
results.clear()
names = db.list_collection_names(filter=filter)
self.assertIn("capped", names)
self.assertIn("non_capped", names)
command = results["started"][0].command
if client_context.version >= (3, 0):
self.assertIn("nameOnly", command)
self.assertTrue(command["nameOnly"])
else:
self.assertNotIn("nameOnly", command)
def test_list_collections(self): def test_list_collections(self):
self.client.drop_database("pymongo_test") self.client.drop_database("pymongo_test")
@ -942,6 +977,45 @@ class TestDatabase(IntegrationTest):
"maxTimeAlwaysTimeOut", "maxTimeAlwaysTimeOut",
mode="off") mode="off")
def test_with_options(self):
codec_options = DECIMAL_CODECOPTS
read_preference = ReadPreference.SECONDARY_PREFERRED
write_concern = WriteConcern(j=True)
read_concern = ReadConcern(level="majority")
# List of all options to compare.
allopts = ['name', 'client', 'codec_options',
'read_preference', 'write_concern', 'read_concern']
db1 = self.client.get_database(
'with_options_test', codec_options=codec_options,
read_preference=read_preference, write_concern=write_concern,
read_concern=read_concern)
# Case 1: swap no options
db2 = db1.with_options()
for opt in allopts:
self.assertEqual(getattr(db1, opt), getattr(db2, opt))
# Case 2: swap all options
newopts = {'codec_options': CodecOptions(),
'read_preference': ReadPreference.PRIMARY,
'write_concern': WriteConcern(w=1),
'read_concern': ReadConcern(level="local")}
db2 = db1.with_options(**newopts)
for opt in newopts:
self.assertEqual(
getattr(db2, opt), newopts.get(opt, getattr(db1, opt)))
def test_current_op_codec_options(self):
class MySON(SON):
pass
opts = CodecOptions(document_class=MySON)
db = self.client.get_database("pymongo_test", codec_options=opts)
current_op = db.current_op(True)
self.assertTrue(current_op['inprog'])
self.assertIsInstance(current_op, MySON)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -19,6 +19,7 @@
import datetime import datetime
import sys import sys
import zipfile
sys.path[0:0] = [""] sys.path[0:0] = [""]
from bson.objectid import ObjectId from bson.objectid import ObjectId
@ -33,10 +34,11 @@ from gridfs.grid_file import (DEFAULT_CHUNK_SIZE,
from gridfs.errors import NoFile from gridfs.errors import NoFile
from pymongo import MongoClient from pymongo import MongoClient
from pymongo.errors import ConfigurationError, ServerSelectionTimeoutError from pymongo.errors import ConfigurationError, ServerSelectionTimeoutError
from pymongo.message import _CursorAddress
from test import (IntegrationTest, from test import (IntegrationTest,
unittest, unittest,
qcheck) qcheck)
from test.utils import rs_or_single_client from test.utils import rs_or_single_client, EventListener
class TestGridFileNoConnect(unittest.TestCase): class TestGridFileNoConnect(unittest.TestCase):
@ -616,6 +618,51 @@ Bye"""))
with self.assertRaises(ConfigurationError): with self.assertRaises(ConfigurationError):
GridIn(rs_or_single_client(w=0).pymongo_test.fs) GridIn(rs_or_single_client(w=0).pymongo_test.fs)
def test_survive_cursor_not_found(self):
# By default the find command returns 101 documents in the first batch.
# Use 102 batches to cause a single getMore.
chunk_size = 1024
data = b'd' * (102 * chunk_size)
listener = EventListener()
client = rs_or_single_client(event_listeners=[listener])
db = client.pymongo_test
with GridIn(db.fs, chunk_size=chunk_size) as infile:
infile.write(data)
with GridOut(db.fs, infile._id) as outfile:
self.assertEqual(len(outfile.readchunk()), chunk_size)
# Kill the cursor to simulate the cursor timing out on the server
# when an application spends a long time between two calls to
# readchunk().
client._close_cursor_now(
outfile._GridOut__chunk_iter._cursor.cursor_id,
_CursorAddress(client.address, db.fs.chunks.full_name))
# Read the rest of the file without error.
self.assertEqual(len(outfile.read()), len(data) - chunk_size)
# Paranoid, ensure that a getMore was actually sent.
self.assertIn("getMore", listener.started_command_names())
def test_zip(self):
zf = StringIO()
z = zipfile.ZipFile(zf, "w")
z.writestr("test.txt", b"hello world")
z.close()
zf.seek(0)
f = GridIn(self.db.fs, filename="test.zip")
f.write(zf)
f.close()
self.assertEqual(1, self.db.fs.files.count_documents({}))
self.assertEqual(1, self.db.fs.chunks.count_documents({}))
g = GridOut(self.db.fs, f._id)
z = zipfile.ZipFile(g)
self.assertSequenceEqual(z.namelist(), ["test.txt"])
self.assertEqual(z.read("test.txt"), b"hello world")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -163,8 +163,8 @@ def create_test(scenario_def):
if test['assert'].get("error", False): if test['assert'].get("error", False):
self.assertIsNotNone(error) self.assertIsNotNone(error)
self.assertTrue(isinstance(error, self.assertIsInstance(error, errors[test['assert']['error']],
errors[test['assert']['error']])) test['description'])
else: else:
self.assertIsNone(error) self.assertIsNone(error)

View File

@ -1379,9 +1379,12 @@ class TestLegacy(IntegrationTest):
wait_until(raises_cursor_not_found, 'close cursor') wait_until(raises_cursor_not_found, 'close cursor')
def test_kill_cursors_with_tuple(self): def test_kill_cursors_with_tuple(self):
if (client_context.version[:2] == (3, 6) # Some evergreen distros (Debian 7.1) still test against 3.6.5 where
and client_context.auth_enabled): # OP_KILL_CURSORS does not work.
raise SkipTest("SERVER-33553") if (client_context.is_mongos and client_context.auth_enabled and
(3, 6, 0) <= client_context.version < (3, 6, 6)):
raise SkipTest("SERVER-33553 This server version does not support "
"OP_KILL_CURSORS")
coll = self.client.pymongo_test.test coll = self.client.pymongo_test.test
coll.drop() coll.drop()
@ -1407,26 +1410,6 @@ class TestLegacy(IntegrationTest):
wait_until(raises_cursor_not_found, 'close cursor') wait_until(raises_cursor_not_found, 'close cursor')
def test_get_default_database(self):
c = rs_or_single_client("mongodb://%s:%d/foo" % (client_context.host,
client_context.port),
connect=False)
self.assertEqual(Database(c, 'foo'), c.get_default_database())
def test_get_default_database_error(self):
# URI with no database.
c = rs_or_single_client("mongodb://%s:%d/" % (client_context.host,
client_context.port),
connect=False)
self.assertRaises(ConfigurationError, c.get_default_database)
def test_get_default_database_with_authsource(self):
# Ensure we distinguish database name from authSource.
uri = "mongodb://%s:%d/foo?authSource=src" % (
client_context.host, client_context.port)
c = rs_or_single_client(uri, connect=False)
self.assertEqual(Database(c, 'foo'), c.get_default_database())
class TestLegacyBulk(BulkTestBase): class TestLegacyBulk(BulkTestBase):

View File

@ -122,12 +122,12 @@ class TestMaxStaleness(unittest.TestCase):
# From max-staleness-tests.rst, "Parse lastWriteDate". # From max-staleness-tests.rst, "Parse lastWriteDate".
client = rs_or_single_client(heartbeatFrequencyMS=500) client = rs_or_single_client(heartbeatFrequencyMS=500)
client.pymongo_test.test.insert_one({}) client.pymongo_test.test.insert_one({})
time.sleep(1) time.sleep(2)
server = client._topology.select_server(writable_server_selector) server = client._topology.select_server(writable_server_selector)
last_write = server.description.last_write_date last_write = server.description.last_write_date
self.assertTrue(last_write) self.assertTrue(last_write)
client.pymongo_test.test.insert_one({}) client.pymongo_test.test.insert_one({})
time.sleep(1) time.sleep(2)
server = client._topology.select_server(writable_server_selector) server = client._topology.select_server(writable_server_selector)
self.assertGreater(server.description.last_write_date, last_write) self.assertGreater(server.description.last_write_date, last_write)
self.assertLess(server.description.last_write_date, last_write + 10) self.assertLess(server.description.last_write_date, last_write + 10)

View File

@ -18,7 +18,9 @@ import uuid
from bson import BSON from bson import BSON
from bson.binary import JAVA_LEGACY from bson.binary import JAVA_LEGACY
from bson.codec_options import CodecOptions from bson.codec_options import CodecOptions
from bson.errors import InvalidBSON
from bson.raw_bson import RawBSONDocument from bson.raw_bson import RawBSONDocument
from bson.son import SON
from test import client_context, unittest from test import client_context, unittest
@ -51,6 +53,21 @@ class TestRawBSONDocument(unittest.TestCase):
def test_raw(self): def test_raw(self):
self.assertEqual(self.bson_string, self.document.raw) self.assertEqual(self.bson_string, self.document.raw)
def test_empty_doc(self):
doc = RawBSONDocument(BSON.encode({}))
with self.assertRaises(KeyError):
doc['does-not-exist']
def test_invalid_bson_sequence(self):
bson_byte_sequence = BSON.encode({'a': 1})+BSON.encode({})
with self.assertRaisesRegex(InvalidBSON, 'invalid object length'):
RawBSONDocument(bson_byte_sequence)
def test_invalid_bson_eoo(self):
invalid_bson_eoo = BSON.encode({'a': 1})[:-1] + b'\x01'
with self.assertRaisesRegex(InvalidBSON, 'bad eoo'):
RawBSONDocument(invalid_bson_eoo)
@client_context.require_connection @client_context.require_connection
def test_round_trip(self): def test_round_trip(self):
db = self.client.get_database( db = self.client.get_database(
@ -141,3 +158,10 @@ class TestRawBSONDocument(unittest.TestCase):
coll.delete_many(self.document) coll.delete_many(self.document)
coll.update_one(self.document, {'$set': {'a': 'b'}}, upsert=True) coll.update_one(self.document, {'$set': {'a': 'b'}}, upsert=True)
coll.update_many(self.document, {'$set': {'b': 'c'}}) coll.update_many(self.document, {'$set': {'b': 'c'}})
def test_preserve_key_ordering(self):
keyvaluepairs = [('a', 1), ('b', 2), ('c', 3),]
rawdoc = RawBSONDocument(BSON.encode(SON(keyvaluepairs)))
for rkey, elt in zip(rawdoc, keyvaluepairs):
self.assertEqual(rkey, elt[0])

View File

@ -571,8 +571,11 @@ class TestSession(IntegrationTest):
for f in files: for f in files:
f.read() f.read()
with self.assertRaisesRegex(InvalidOperation, "ended session"): for f in files:
files[0].read() # Attempt to read the file again.
f.seek(0)
with self.assertRaisesRegex(InvalidOperation, "ended session"):
f.read()
def test_aggregate(self): def test_aggregate(self):
client = self.client client = self.client

View File

@ -210,6 +210,9 @@ class TestThreadsAuth(IntegrationTest):
super(TestThreadsAuth, cls).setUpClass() super(TestThreadsAuth, cls).setUpClass()
def test_auto_auth_login(self): def test_auto_auth_login(self):
# Create the database upfront to workaround SERVER-39167.
self.client.auth_test.test.insert_one({})
self.addCleanup(self.client.drop_database, "auth_test")
client = rs_or_single_client_noauth() client = rs_or_single_client_noauth()
self.assertRaises(OperationFailure, client.auth_test.test.find_one) self.assertRaises(OperationFailure, client.auth_test.test.find_one)

View File

@ -86,6 +86,23 @@ class TestTransactions(IntegrationTest):
cmd.update(command_args) cmd.update(command_args)
self.client.admin.command(cmd) self.client.admin.command(cmd)
@client_context.require_mongos
@client_context.require_version_min(4, 0)
def test_transactions_not_supported(self):
with self.client.start_session() as s:
with self.assertRaisesRegex(
ConfigurationError,
'does not support running multi-document transactions on '
'sharded clusters'):
s.start_transaction()
self.client.close()
with s.start_transaction():
with self.assertRaisesRegex(
ConfigurationError,
'does not support running multi-document transactions '
'on sharded clusters'):
self.client.test.test.insert_one({}, session=s)
@client_context.require_transactions @client_context.require_transactions
def test_transaction_options_validation(self): def test_transaction_options_validation(self):
default_options = TransactionOptions() default_options = TransactionOptions()
@ -390,6 +407,7 @@ def end_sessions(sessions):
def create_test(scenario_def, test): def create_test(scenario_def, test):
@client_context.require_test_commands
def run_scenario(self): def run_scenario(self):
if test.get('skipReason'): if test.get('skipReason'):
raise unittest.SkipTest(test.get('skipReason')) raise unittest.SkipTest(test.get('skipReason'))
@ -519,6 +537,12 @@ def create_test(scenario_def, test):
read_concern=ReadConcern('local')) read_concern=ReadConcern('local'))
self.assertEqual(list(primary_coll.find()), expected_c['data']) self.assertEqual(list(primary_coll.find()), expected_c['data'])
if 'minServerVersion' in scenario_def:
min_ver = tuple(
int(elt) for
elt in scenario_def['minServerVersion'].split('.'))
return client_context.require_version_min(*min_ver)(run_scenario)
return run_scenario return run_scenario

View File

@ -0,0 +1,107 @@
{
"minServerVersion": "4.0.2",
"database_name": "transaction-tests",
"collection_name": "test",
"data": [
{
"_id": 1
},
{
"_id": 2
},
{
"_id": 3
},
{
"_id": 4
}
],
"tests": [
{
"description": "count",
"operations": [
{
"name": "startTransaction",
"object": "session0"
},
{
"name": "count",
"object": "collection",
"arguments": {
"session": "session0",
"filter": {
"_id": 1
}
},
"result": {
"errorCodeName": "OperationNotSupportedInTransaction",
"errorLabelsOmit": [
"TransientTransactionError",
"UnknownTransactionCommitResult"
]
}
},
{
"name": "abortTransaction",
"object": "session0"
}
],
"expectations": [
{
"command_started_event": {
"command": {
"count": "test",
"query": {
"_id": 1
},
"readConcern": null,
"lsid": "session0",
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": true,
"autocommit": false,
"writeConcern": null
},
"command_name": "count",
"database_name": "transaction-tests"
}
},
{
"command_started_event": {
"command": {
"abortTransaction": 1,
"lsid": "session0",
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": null,
"autocommit": false,
"writeConcern": null
},
"command_name": "abortTransaction",
"database_name": "admin"
}
}
],
"outcome": {
"collection": {
"data": [
{
"_id": 1
},
{
"_id": 2
},
{
"_id": 3
},
{
"_id": 4
}
]
}
}
}
]
}

View File

@ -83,92 +83,6 @@
} }
} }
}, },
{
"description": "count",
"operations": [
{
"name": "startTransaction",
"object": "session0"
},
{
"name": "count",
"object": "collection",
"arguments": {
"session": "session0",
"filter": {
"_id": 1
}
},
"result": {
"errorContains": "Cannot run 'count' in a multi-document transaction",
"errorLabelsOmit": [
"TransientTransactionError",
"UnknownTransactionCommitResult"
]
}
},
{
"name": "abortTransaction",
"object": "session0"
}
],
"expectations": [
{
"command_started_event": {
"command": {
"count": "test",
"query": {
"_id": 1
},
"readConcern": null,
"lsid": "session0",
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": true,
"autocommit": false,
"writeConcern": null
},
"command_name": "count",
"database_name": "transaction-tests"
}
},
{
"command_started_event": {
"command": {
"abortTransaction": 1,
"lsid": "session0",
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": null,
"autocommit": false,
"writeConcern": null
},
"command_name": "abortTransaction",
"database_name": "admin"
}
}
],
"outcome": {
"collection": {
"data": [
{
"_id": 1
},
{
"_id": 2
},
{
"_id": 3
},
{
"_id": 4
}
]
}
}
},
{ {
"description": "find", "description": "find",
"operations": [ "operations": [

View File

@ -71,6 +71,10 @@ class EventListener(monitoring.CommandListener):
def failed(self, event): def failed(self, event):
self.results['failed'].append(event) self.results['failed'].append(event)
def started_command_names(self):
"""Return list of command names started."""
return [event.command_name for event in self.results['started']]
class OvertCommandListener(EventListener): class OvertCommandListener(EventListener):
"""A CommandListener that ignores sensitive commands.""" """A CommandListener that ignores sensitive commands."""
@ -267,9 +271,10 @@ def server_is_master_with_slave(client):
def drop_collections(db): def drop_collections(db):
for coll in db.list_collection_names(): # Drop all non-system collections in this database.
if not coll.startswith('system'): for coll in db.list_collection_names(
db.drop_collection(coll) filter={"name": {"$regex": r"^(?!system\.)"}}):
db.drop_collection(coll)
def remove_all_users(db): def remove_all_users(db):
@ -498,3 +503,17 @@ def enable_replication(client):
secondary = single_client(host, port) secondary = single_client(host, port)
secondary.admin.command('configureFailPoint', 'stopReplProducer', secondary.admin.command('configureFailPoint', 'stopReplProducer',
mode='off') mode='off')
class ExceptionCatchingThread(threading.Thread):
"""A thread that stores any exception encountered from run()."""
def __init__(self, *args, **kwargs):
self.exc = None
super(ExceptionCatchingThread, self).__init__(*args, **kwargs)
def run(self):
try:
super(ExceptionCatchingThread, self).run()
except BaseException as exc:
self.exc = exc
raise