PYTHON-1783 Decode user-facing documents but not internal driver-server

communications.
This commit is contained in:
Prashant Mital 2019-04-02 16:47:51 -07:00
parent 007aa6ba50
commit 749116287a
No known key found for this signature in database
GPG Key ID: 3D2DAA9E483ABE51
13 changed files with 488 additions and 57 deletions

View File

@ -215,7 +215,7 @@ def _get_array(data, position, obj_end, opts, element_name):
end -= 1
result = []
# Avoid doing global and attibute lookups in the loop.
# Avoid doing global and attribute lookups in the loop.
append = result.append
index = data.index
getter = _ELEMENT_GETTER
@ -940,6 +940,59 @@ if _USE_C:
decode_all = _cbson.decode_all
def _decode_selective(rawdoc, fields, codec_options):
    """Build a document from `rawdoc`, fully decoding only the parts that
    are named in `fields`.

    :Parameters:
      - `rawdoc`: Mapping whose items are copied into the result. Values
        selected by `fields` must expose a ``raw`` attribute holding their
        BSON bytes (presumably :class:`~bson.raw_bson.RawBSONDocument`
        instances -- confirm against the caller).
      - `fields`: Mapping of key -> ``list``, ``dict`` or a nested mapping.
        ``list`` decodes every element of the value, ``dict`` decodes the
        value itself, and a nested mapping recurses into the value.
      - `codec_options`: Options used both to create the result document
        (``document_class``) and to decode the selected values.

    :Returns:
      A new ``codec_options.document_class`` instance. Keys that are not
      listed in `fields` keep their value from `rawdoc` unchanged.
    """
    decoded = codec_options.document_class()
    for name, raw_value in iteritems(rawdoc):
        if name not in fields:
            # Not user-facing: pass the value through as-is.
            decoded[name] = raw_value
            continue

        spec = fields[name]
        if spec == list:
            decoded[name] = [
                _bson_to_dict(item.raw, codec_options) for item in raw_value]
        elif spec == dict:
            decoded[name] = _bson_to_dict(raw_value.raw, codec_options)
        else:
            # Nested specification: descend one level.
            decoded[name] = _decode_selective(raw_value, spec, codec_options)
    return decoded
def _decode_all_selective(data, codec_options, fields):
    """Decode one BSON document, applying the user's custom type decoders
    only to the parts of the document named in `fields`.

    `data` must be a string representing a valid, BSON-encoded document.

    :Parameters:
      - `data`: BSON data
      - `codec_options`: An instance of
        :class:`~bson.codec_options.CodecOptions` with user-specified type
        decoders. When its type registry holds no decoders the data is
        simply passed to ``decode_all`` with these options.
      - `fields`: Map of document namespaces where data that needs
        to be custom decoded lives or None. For example, to custom decode a
        list of objects in 'field1.subfield1', the specified value should be
        ``{'field1': {'subfield1': list}}``. Use ``dict`` instead of ``list``
        if the field contains a single object to custom decode. If ``fields``
        is an empty map or None, the data is passed to ``decode_all`` with
        the type registry removed from `codec_options`.

    :Returns:
      - `document_list`: When `fields` is used, a single-member list
        containing the decoded document; otherwise whatever ``decode_all``
        returns.

    .. versionadded:: 3.8
    """
    if not codec_options.type_registry._decoder_map:
        # No custom decoders are registered, so a plain decode suffices.
        return decode_all(data, codec_options)

    if not fields:
        # Nothing in this document is user-facing: decode it without the
        # type registry.
        return decode_all(
            data, codec_options.with_options(type_registry=None))

    # First pass keeps sub-documents raw and ignores the type registry;
    # only the selected fields are decoded afterwards with the full options.
    from bson.raw_bson import RawBSONDocument
    raw_options = codec_options.with_options(
        document_class=RawBSONDocument, type_registry=None)
    raw_document = _bson_to_dict(data, raw_options)
    return [_decode_selective(raw_document, fields, codec_options)]
def decode_iter(data, codec_options=DEFAULT_CODEC_OPTIONS):
"""Decode BSON data to multiple documents as a generator.

View File

@ -25,7 +25,8 @@ from gridfs.errors import NoFile
from gridfs.grid_file import (GridIn,
GridOut,
GridOutCursor,
DEFAULT_CHUNK_SIZE)
DEFAULT_CHUNK_SIZE,
_clear_entity_type_registry)
from pymongo import (ASCENDING,
DESCENDING)
from pymongo.common import UNAUTHORIZED_CODES, validate_string
@ -61,6 +62,8 @@ class GridFS(object):
if not isinstance(database, Database):
raise TypeError("database must be an instance of Database")
database = _clear_entity_type_registry(database)
if not database.write_concern.acknowledged:
raise ConfigurationError('database must use '
'acknowledged write_concern')
@ -443,6 +446,8 @@ class GridFSBucket(object):
if not isinstance(db, Database):
raise TypeError("database must be an instance of Database")
db = _clear_entity_type_registry(db)
wtc = write_concern if write_concern is not None else db.write_concern
if not wtc.acknowledged:
raise ConfigurationError('write concern must be acknowledged')

View File

@ -98,6 +98,12 @@ def _grid_out_property(field_name, docstring):
return property(getter, doc=docstring)
def _clear_entity_type_registry(entity, **kwargs):
"""Clear the given database/collection object's type registry."""
codecopts = entity.codec_options.with_options(type_registry=None)
return entity.with_options(codec_options=codecopts, **kwargs)
class GridIn(object):
"""Class to write data to GridFS.
"""
@ -168,8 +174,8 @@ class GridIn(object):
if "chunk_size" in kwargs:
kwargs["chunkSize"] = kwargs.pop("chunk_size")
coll = root_collection.with_options(
read_preference=ReadPreference.PRIMARY)
coll = _clear_entity_type_registry(
root_collection, read_preference=ReadPreference.PRIMARY)
if not disable_md5:
kwargs["md5"] = hashlib.md5()
@ -449,6 +455,8 @@ class GridOut(object):
raise TypeError("root_collection must be an "
"instance of Collection")
root_collection = _clear_entity_type_registry(root_collection)
self.__chunks = root_collection.chunks
self.__files = root_collection.files
self.__file_id = file_id
@ -800,6 +808,8 @@ class GridOutCursor(Cursor):
.. mongodoc:: cursors
"""
collection = _clear_entity_type_registry(collection)
# Hold on to the base "fs" collection to create GridOut objects later.
self.__root_collection = collection

View File

@ -53,6 +53,7 @@ from pymongo.write_concern import WriteConcern
_NO_OBJ_ERROR = "No matching object found"
_UJOIN = u"%s.%s"
_FIND_AND_MODIFY_DOC_FIELDS = {'value': dict}
class ReturnDocument(object):
@ -203,7 +204,8 @@ class Collection(common.BaseObject):
write_concern=None,
collation=None,
session=None,
retryable_write=False):
retryable_write=False,
user_fields=None):
"""Internal command helper.
:Parameters:
@ -242,7 +244,8 @@ class Collection(common.BaseObject):
collation=collation,
session=s,
client=self.__database.client,
retryable_write=retryable_write)
retryable_write=retryable_write,
user_fields=user_fields)
def __create(self, options, collation, session):
"""Sends a create command with the given options.
@ -315,9 +318,8 @@ class Collection(common.BaseObject):
"""
return self.__database
def with_options(
self, codec_options=None, read_preference=None,
write_concern=None, read_concern=None):
def with_options(self, codec_options=None, read_preference=None,
write_concern=None, read_concern=None):
"""Get a clone of this collection changing the specified settings.
>>> coll1.read_preference
@ -2310,7 +2312,8 @@ class Collection(common.BaseObject):
write_concern=write_concern,
collation=collation,
session=session,
client=self.__database.client)
client=self.__database.client,
user_fields={'cursor': {'firstBatch': list}})
if "cursor" in result:
cursor = result["cursor"]
@ -2571,7 +2574,8 @@ class Collection(common.BaseObject):
with self._socket_for_reads(session=None) as (sock_info, slave_ok):
return self._command(sock_info, cmd, slave_ok,
collation=collation)["retval"]
collation=collation,
user_fields={'retval': list})["retval"]
def rename(self, new_name, session=None, **kwargs):
"""Rename this collection.
@ -2675,7 +2679,8 @@ class Collection(common.BaseObject):
with self._socket_for_reads(session) as (sock_info, slave_ok):
return self._command(sock_info, cmd, slave_ok,
read_concern=self.read_concern,
collation=collation, session=session)["values"]
collation=collation,
session=session)["values"]
def map_reduce(self, map, reduce, out, full_response=False, session=None,
**kwargs):
@ -2755,12 +2760,17 @@ class Collection(common.BaseObject):
write_concern = self._write_concern_for(session)
else:
write_concern = None
if inline:
user_fields = {'results': list}
else:
user_fields = None
response = self._command(
sock_info, cmd, slave_ok, read_pref,
read_concern=read_concern,
write_concern=write_concern,
collation=collation, session=session)
collation=collation, session=session,
user_fields=user_fields)
if full_response or not response.get('result'):
return response
@ -2810,16 +2820,19 @@ class Collection(common.BaseObject):
("map", map),
("reduce", reduce),
("out", {"inline": 1})])
user_fields = {'results': list}
collation = validate_collation_or_none(kwargs.pop('collation', None))
cmd.update(kwargs)
with self._socket_for_reads(session) as (sock_info, slave_ok):
if sock_info.max_wire_version >= 4 and 'readConcern' not in cmd:
res = self._command(sock_info, cmd, slave_ok,
read_concern=self.read_concern,
collation=collation, session=session)
collation=collation, session=session,
user_fields=user_fields)
else:
res = self._command(sock_info, cmd, slave_ok,
collation=collation, session=session)
collation=collation, session=session,
user_fields=user_fields)
if full_response:
return res
@ -2837,6 +2850,7 @@ class Collection(common.BaseObject):
return_document=ReturnDocument.BEFORE,
array_filters=None, session=None, **kwargs):
"""Internal findAndModify helper."""
common.validate_is_mapping("filter", filter)
if not isinstance(return_document, bool):
raise ValueError("return_document must be "
@ -2876,8 +2890,10 @@ class Collection(common.BaseObject):
write_concern=write_concern,
allowable_errors=[_NO_OBJ_ERROR],
collation=collation, session=session,
retryable_write=retryable_write)
retryable_write=retryable_write,
user_fields=_FIND_AND_MODIFY_DOC_FIELDS)
_check_write_command_response(out)
return out.get("value")
return self.__database.client._retryable_write(
@ -3293,7 +3309,8 @@ class Collection(common.BaseObject):
result = self._command(
sock_info, cmd, read_preference=ReadPreference.PRIMARY,
allowable_errors=[_NO_OBJ_ERROR], collation=collation,
session=session, retryable_write=retryable_write)
session=session, retryable_write=retryable_write,
user_fields=_FIND_AND_MODIFY_DOC_FIELDS)
_check_write_command_response(result)
return result

View File

@ -150,14 +150,18 @@ class CommandCursor(object):
try:
with client._reset_on_error(self.__address, self.__session):
docs = self._unpack_response(reply,
self.__id,
self.__collection.codec_options)
user_fields = None
legacy_response = True
if from_command:
user_fields = {'cursor': {'nextBatch': list}}
legacy_response = False
docs = self._unpack_response(
reply, self.__id, self.__collection.codec_options,
legacy_response=legacy_response, user_fields=user_fields)
if from_command:
first = docs[0]
client._process_response(first, self.__session)
helpers._check_command_response(first)
except OperationFailure as exc:
kill()
@ -208,8 +212,10 @@ class CommandCursor(object):
kill()
self.__data = deque(documents)
def _unpack_response(self, response, cursor_id, codec_options):
return response.unpack_response(cursor_id, codec_options)
def _unpack_response(self, response, cursor_id, codec_options,
user_fields=None, legacy_response=False):
return response.unpack_response(cursor_id, codec_options, user_fields,
legacy_response)
def _refresh(self):
"""Refreshes the cursor with more data from the server.
@ -330,7 +336,8 @@ class RawBatchCommandCursor(CommandCursor):
collection, cursor_info, address, retrieved, batch_size,
max_await_time_ms, session, explicit_session)
def _unpack_response(self, response, cursor_id, codec_options):
def _unpack_response(self, response, cursor_id, codec_options,
user_fields=None, legacy_response=False):
return response.raw_response(cursor_id)
def __getitem__(self, index):

View File

@ -40,7 +40,7 @@ from pymongo.message import (_convert_exception,
_RawBatchGetMore,
_Query,
_RawBatchQuery)
from pymongo.read_preferences import ReadPreference
_QUERY_OPTIONS = {
"tailable_cursor": 2,
@ -50,6 +50,7 @@ _QUERY_OPTIONS = {
"await_data": 32,
"exhaust": 64,
"partial": 128}
_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': list, 'nextBatch': list}}
class CursorType(object):
@ -996,9 +997,14 @@ class Cursor(object):
try:
with client._reset_on_error(self.__address, self.__session):
docs = self._unpack_response(reply,
self.__id,
self.__collection.codec_options)
user_fields = None
legacy_response = True
if from_command:
user_fields = _CURSOR_DOC_FIELDS
legacy_response = False
docs = self._unpack_response(
reply, self.__id, self.__collection.codec_options,
legacy_response=legacy_response, user_fields=user_fields)
if from_command:
first = docs[0]
client._process_response(first, self.__session)
@ -1085,8 +1091,10 @@ class Cursor(object):
if self.__limit and self.__id and self.__limit <= self.__retrieved:
self.__die()
def _unpack_response(self, response, cursor_id, codec_options):
return response.unpack_response(cursor_id, codec_options)
def _unpack_response(self, response, cursor_id, codec_options,
user_fields=None, legacy_response=False):
return response.unpack_response(cursor_id, codec_options, user_fields,
legacy_response)
def _read_preference(self):
if self.__read_preference is None:
@ -1303,7 +1311,8 @@ class RawBatchCursor(Cursor):
raise InvalidOperation(
"Cannot use RawBatchCursor with manipulate=True")
def _unpack_response(self, response, cursor_id, codec_options):
def _unpack_response(self, response, cursor_id, codec_options,
user_fields=None, legacy_response=False):
return response.raw_response(cursor_id)
def explain(self):

View File

@ -222,6 +222,46 @@ class Database(common.BaseObject):
return [manipulator.__class__.__name__
for manipulator in self.__outgoing_copying_manipulators]
def with_options(self, codec_options=None, read_preference=None,
                 write_concern=None, read_concern=None):
    """Return a copy of this database with some of its settings replaced.

      >>> db1.read_preference
      Primary()
      >>> from pymongo import ReadPreference
      >>> db2 = db1.with_options(read_preference=ReadPreference.SECONDARY)
      >>> db1.read_preference
      Primary()
      >>> db2.read_preference
      Secondary(tag_sets=None)

    :Parameters:
      - `codec_options` (optional): An instance of
        :class:`~bson.codec_options.CodecOptions`. If ``None`` (the
        default) the :attr:`codec_options` of this :class:`Database`
        is used.
      - `read_preference` (optional): The read preference to use. If
        ``None`` (the default) the :attr:`read_preference` of this
        :class:`Database` is used. See :mod:`~pymongo.read_preferences`
        for options.
      - `write_concern` (optional): An instance of
        :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the
        default) the :attr:`write_concern` of this :class:`Database`
        is used.
      - `read_concern` (optional): An instance of
        :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the
        default) the :attr:`read_concern` of this :class:`Database`
        is used.

    .. versionadded:: 3.8
    """
    # Any argument that is None (or otherwise falsy) is inherited from
    # this database.
    if not codec_options:
        codec_options = self.codec_options
    if not read_preference:
        read_preference = self.read_preference
    if not write_concern:
        write_concern = self.write_concern
    if not read_concern:
        read_concern = self.read_concern
    return Database(self.client, self.__name, codec_options,
                    read_preference, write_concern, read_concern)
def __eq__(self, other):
if isinstance(other, Database):
return (self.__client == other.client and

View File

@ -1398,7 +1398,8 @@ class _OpReply(object):
return [self.documents]
def unpack_response(self, cursor_id=None,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS):
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
user_fields=None, legacy_response=False):
"""Unpack a response from the database and decode the BSON document(s).
Check the response for errors and unpack, returning a dictionary
@ -1415,7 +1416,10 @@ class _OpReply(object):
:class:`~bson.codec_options.CodecOptions`
"""
self.raw_response(cursor_id)
return bson.decode_all(self.documents, codec_options)
if legacy_response:
return bson.decode_all(self.documents, codec_options)
return bson._decode_all_selective(
self.documents, codec_options, user_fields)
def command_response(self):
"""Unpack a command response."""
@ -1451,7 +1455,8 @@ class _OpMsg(object):
raise NotImplementedError
def unpack_response(self, cursor_id=None,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS):
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
user_fields=None, legacy_response=False):
"""Unpack a OP_MSG command response.
:Parameters:
@ -1459,7 +1464,10 @@ class _OpMsg(object):
- `codec_options` (optional): an instance of
:class:`~bson.codec_options.CodecOptions`
"""
return bson.decode_all(self.payload_document, codec_options)
# If _OpMsg is in-use, this cannot be a legacy response.
assert not legacy_response
return bson._decode_all_selective(
self.payload_document, codec_options, user_fields)
def command_response(self):
"""Unpack a command response."""

View File

@ -58,7 +58,8 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
collation=None,
compression_ctx=None,
use_op_msg=False,
unacknowledged=False):
unacknowledged=False,
user_fields=None):
"""Execute a command over the socket, or raise socket.error.
:Parameters:
@ -139,7 +140,8 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
response_doc = {"ok": 1}
else:
reply = receive_message(sock, request_id)
unpacked_docs = reply.unpack_response(codec_options=codec_options)
unpacked_docs = reply.unpack_response(
codec_options=codec_options, user_fields=user_fields)
response_doc = unpacked_docs[0]
if client:

View File

@ -512,7 +512,8 @@ class SocketInfo(object):
session=None,
client=None,
retryable_write=False,
publish_events=True):
publish_events=True,
user_fields=None):
"""Execute a command or raise an error.
:Parameters:
@ -576,7 +577,8 @@ class SocketInfo(object):
collation=collation,
compression_ctx=self.compression_context,
use_op_msg=self.op_msg_enabled,
unacknowledged=unacknowledged)
unacknowledged=unacknowledged,
user_fields=user_fields)
except OperationFailure:
raise
# Catch socket.error, KeyboardInterrupt, etc. and close ourselves.

View File

@ -132,6 +132,12 @@ class TestCollection(IntegrationTest):
def tearDownClass(cls):
cls.db.drop_collection("test_large_limit")
def setUp(self):
self.db.test.drop()
def tearDown(self):
self.db.test.drop()
@contextlib.contextmanager
def write_concern_collection(self):
if client_context.version.at_least(3, 3, 9) and client_context.is_rs:

View File

@ -14,6 +14,7 @@
"""Test support for callbacks to encode/decode custom types."""
import datetime
import sys
import tempfile
from decimal import Decimal
@ -30,11 +31,21 @@ from bson import (BSON,
_BUILT_IN_TYPES,
_dict_to_bson,
_bson_to_dict)
from bson.code import Code
from bson.codec_options import (CodecOptions, TypeCodec, TypeDecoder,
TypeEncoder, TypeRegistry)
from bson.errors import InvalidDocument
from bson.int64 import Int64
from bson.py3compat import text_type
from test import unittest
from gridfs import GridIn, GridOut
from pymongo.collection import ReturnDocument
from pymongo.errors import DuplicateKeyError
from test import client_context, unittest
from test.test_client import IntegrationTest
from test.utils import ignore_deprecations
class DecimalEncoder(TypeEncoder):
@ -59,7 +70,51 @@ class DecimalCodec(DecimalDecoder, DecimalEncoder):
pass
class CustomTypeTests(object):
DECIMAL_CODECOPTS = CodecOptions(
type_registry=TypeRegistry([DecimalCodec()]))
class UndecipherableInt64Type(object):
    """Opaque wrapper around an integer, used as a custom Python type.

    Instances compare equal to another wrapper holding an equal value and
    also to the bare value itself.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Unwrap the right-hand side when it is also a wrapper.
        rhs = other.value if isinstance(other, type(self)) else other
        return self.value == rhs
class UndecipherableIntDecoder(TypeDecoder):
    """Type decoder registered for ``Int64`` that wraps the decoded value
    in an :class:`UndecipherableInt64Type`."""

    bson_type = Int64

    def transform_bson(self, value):
        wrapped = UndecipherableInt64Type(value)
        return wrapped
class UndecipherableIntEncoder(TypeEncoder):
    """Type encoder that converts an :class:`UndecipherableInt64Type` into
    an ``Int64`` built from the wrapped value."""

    python_type = UndecipherableInt64Type

    def transform_python(self, value):
        inner = value.value
        return Int64(inner)
UNINT_DECODER_CODECOPTS = CodecOptions(
type_registry=TypeRegistry([UndecipherableIntDecoder(), ]))
UNINT_CODECOPTS = CodecOptions(type_registry=TypeRegistry(
[UndecipherableIntDecoder(), UndecipherableIntEncoder()]))
class UppercaseTextDecoder(TypeDecoder):
    """Type decoder registered for ``text_type`` that upper-cases the
    decoded string."""

    bson_type = text_type

    def transform_bson(self, value):
        uppercased = value.upper()
        return uppercased
UPPERSTR_DECODER_CODECOPTS = CodecOptions(type_registry=TypeRegistry(
[UppercaseTextDecoder(),]))
class CustomBSONTypeTests(object):
def test_encode_decode_roundtrip(self):
document = {'average': Decimal('56.47')}
bsonbytes = BSON().encode(document, codec_options=self.codecopts)
@ -117,25 +172,24 @@ class CustomTypeTests(object):
fileobj.close()
class TestCustomPythonTypeToBSONMonolithicCodec(CustomTypeTests,
unittest.TestCase):
class TestCustomPythonBSONTypeToBSONMonolithicCodec(CustomBSONTypeTests,
unittest.TestCase):
@classmethod
def setUpClass(cls):
type_registry = TypeRegistry((DecimalCodec(),))
codec_options = CodecOptions(type_registry=type_registry)
cls.codecopts = DECIMAL_CODECOPTS
class TestCustomPythonBSONTypeToBSONMultiplexedCodec(CustomBSONTypeTests,
unittest.TestCase):
@classmethod
def setUpClass(cls):
codec_options = CodecOptions(
type_registry=TypeRegistry((DecimalEncoder(), DecimalDecoder())))
cls.codecopts = codec_options
class TestCustomPythonTypeToBSONMultiplexedCodec(CustomTypeTests,
unittest.TestCase):
@classmethod
def setUpClass(cls):
type_registry = TypeRegistry((DecimalEncoder(), DecimalDecoder()))
codec_options = CodecOptions(type_registry=type_registry)
cls.codecopts = codec_options
class TestFallbackEncoder(unittest.TestCase):
class TestBSONFallbackEncoder(unittest.TestCase):
def _get_codec_options(self, fallback_encoder):
type_registry = TypeRegistry(fallback_encoder=fallback_encoder)
return CodecOptions(type_registry=type_registry)
@ -180,7 +234,7 @@ class TestFallbackEncoder(unittest.TestCase):
BSON().encode(document, codec_options=codecopts)
class TestTypeEnDeCodecs(unittest.TestCase):
class TestBSONTypeEnDeCodecs(unittest.TestCase):
def test_instantiation(self):
msg = "Can't instantiate abstract class .* with abstract methods .*"
def run_test(base, attrs, fail):
@ -220,7 +274,7 @@ class TestTypeEnDeCodecs(unittest.TestCase):
self.assertFalse(issubclass(TypeEncoder, TypeDecoder))
class TestCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase):
class TestBSONCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase):
@classmethod
def setUpClass(cls):
class TypeA(object):
@ -467,5 +521,192 @@ class TestTypeRegistry(unittest.TestCase):
'transform_bson': lambda x: x})
class TestCollectionWCustomType(IntegrationTest):
    """Integration tests for collection helpers on collections configured
    with custom type codecs (``UNINT_*`` / ``UPPERSTR_*`` codec options)."""

    def setUp(self):
        # Every test starts from an empty ``test`` collection.
        self.db.test.drop()

    def tearDown(self):
        self.db.test.drop()

    def test_command_errors_w_custom_type_decoder(self):
        """A duplicate insert still raises DuplicateKeyError when a custom
        type decoder is configured."""
        db = self.db
        test_doc = {'_id': 1, 'data': 'a'}
        test = db.get_collection('test',
                                 codec_options=UNINT_DECODER_CODECOPTS)

        result = test.insert_one(test_doc)
        self.assertEqual(result.inserted_id, test_doc['_id'])
        with self.assertRaises(DuplicateKeyError):
            test.insert_one(test_doc)

    def test_find_w_custom_type_decoder(self):
        """Documents returned by find() have Int64 values custom decoded."""
        db = self.db
        input_docs = [
            {'x': Int64(k)} for k in [1.0, 2.0, 3.0]]
        for doc in input_docs:
            db.test.insert_one(doc)

        test = db.get_collection(
            'test', codec_options=UNINT_DECODER_CODECOPTS)
        # batch_size=1 makes the cursor fetch the documents in several
        # batches instead of a single reply.
        for doc in test.find({}, batch_size=1):
            self.assertIsInstance(doc['x'], UndecipherableInt64Type)

    @client_context.require_version_max(4, 1, 0, -1)
    def test_group_w_custom_type(self):
        """group() results compare equal to the inserted custom values."""
        db = self.db
        test = db.get_collection('test', codec_options=UNINT_CODECOPTS)
        test.insert_many([
            {'sku': 'a', 'qty': UndecipherableInt64Type(2)},
            {'sku': 'b', 'qty': UndecipherableInt64Type(5)},
            {'sku': 'a', 'qty': UndecipherableInt64Type(1)}])

        self.assertEqual([{'sku': 'b', 'qty': UndecipherableInt64Type(5)},],
                         test.group(["sku", "qty"], {"sku": "b"}, {},
                                    "function (obj, prev) { }"))

    def test_aggregate_w_custom_type_decoder(self):
        """aggregate() results have Int64 values custom decoded."""
        db = self.db
        db.test.insert_many([
            {'status': 'in progress', 'qty': Int64(1)},
            {'status': 'complete', 'qty': Int64(10)},
            {'status': 'in progress', 'qty': Int64(1)},
            {'status': 'complete', 'qty': Int64(10)},
            {'status': 'in progress', 'qty': Int64(1)},])
        test = db.get_collection(
            'test', codec_options=UNINT_DECODER_CODECOPTS)

        pipeline = [
            {'$match': {'status': 'complete'}},
            {'$group': {'_id': "$status", 'total_qty': {"$sum": "$qty"}}},]
        result = test.aggregate(pipeline)

        res = list(result)[0]
        self.assertEqual(res['_id'], 'complete')
        self.assertIsInstance(res['total_qty'], UndecipherableInt64Type)
        self.assertEqual(res['total_qty'].value, 20)

    # collection.distinct does not support custom type decoding
    def test_distinct_w_custom_type(self):
        """distinct() values compare equal to the inserted custom values.

        The assertion relies on UndecipherableInt64Type.__eq__, which also
        accepts a bare value on the other side of the comparison.
        """
        self.db.drop_collection("test")

        test = self.db.get_collection('test', codec_options=UNINT_CODECOPTS)
        test.insert_many([
            {"a": UndecipherableInt64Type(1)},
            {"a": UndecipherableInt64Type(2)},
            {"a": UndecipherableInt64Type(2)},
            {"a": UndecipherableInt64Type(2)},
            {"a": UndecipherableInt64Type(3)}])

        distinct = test.distinct("a")
        distinct.sort()

        self.assertEqual([
            UndecipherableInt64Type(1), UndecipherableInt64Type(2),
            UndecipherableInt64Type(3)], distinct)

    def test_map_reduce_w_custom_type(self):
        """Inline map-reduce results have their strings custom decoded
        (upper-cased) by UppercaseTextDecoder."""
        test = self.db.get_collection(
            'test', codec_options=UPPERSTR_DECODER_CODECOPTS)

        test.insert_many([
            {'_id': 1, 'sku': 'abcd', 'qty': 1},
            {'_id': 2, 'sku': 'abcd', 'qty': 2},
            {'_id': 3, 'sku': 'abcd', 'qty': 3}])

        map = Code("function () {"
                   "  emit(this.sku, this.qty);"
                   "}")
        reduce = Code("function (key, values) {"
                      "  return Array.sum(values);"
                      "}")

        # map_reduce with inline output returns the full response document.
        result = test.map_reduce(map, reduce, out={'inline': 1})
        self.assertTrue(isinstance(result, dict))
        self.assertTrue('results' in result)
        self.assertEqual(result['results'][0], {'_id': 'ABCD', 'value': 6})

        # inline_map_reduce returns only the list of results by default.
        result = test.inline_map_reduce(map, reduce)
        self.assertTrue(isinstance(result, list))
        self.assertEqual(1, len(result))
        self.assertEqual(result[0]["_id"], 'ABCD')

        full_result = test.inline_map_reduce(map, reduce,
                                             full_response=True)
        self.assertEqual(3, full_result["counts"]["emit"])

    def test_find_one_and__w_custom_type_decoder(self):
        """find_one_and_update/replace/delete return custom decoded
        documents."""
        db = self.db
        c = db.get_collection('test', codec_options=UNINT_DECODER_CODECOPTS)
        c.insert_one({'_id': 1, 'x': Int64(1)})

        doc = c.find_one_and_update({'_id': 1}, {'$inc': {'x': 1}},
                                    return_document=ReturnDocument.AFTER)
        self.assertEqual(doc['_id'], 1)
        self.assertIsInstance(doc['x'], UndecipherableInt64Type)
        self.assertEqual(doc['x'].value, 2)

        doc = c.find_one_and_replace({'_id': 1}, {'x': Int64(3), 'y': True},
                                     return_document=ReturnDocument.AFTER)
        self.assertEqual(doc['_id'], 1)
        self.assertIsInstance(doc['x'], UndecipherableInt64Type)
        self.assertEqual(doc['x'].value, 3)
        self.assertEqual(doc['y'], True)

        doc = c.find_one_and_delete({'y': True})
        self.assertEqual(doc['_id'], 1)
        self.assertIsInstance(doc['x'], UndecipherableInt64Type)
        self.assertEqual(doc['x'].value, 3)
        # The only document was deleted, so the collection is empty again.
        self.assertIsNone(c.find_one())

    @ignore_deprecations
    def test_find_and_modify_w_custom_type_decoder(self):
        """The deprecated find_and_modify returns a custom decoded
        document."""
        db = self.db
        c = db.get_collection('test', codec_options=UNINT_DECODER_CODECOPTS)
        c.insert_one({'_id': 1, 'x': Int64(1)})

        # The returned document holds the value from before the update.
        doc = c.find_and_modify({'_id': 1}, {'$inc': {'x': Int64(10)}})
        self.assertEqual(doc['_id'], 1)
        self.assertIsInstance(doc['x'], UndecipherableInt64Type)
        self.assertEqual(doc['x'].value, 1)

        doc = c.find_one()
        self.assertEqual(doc['_id'], 1)
        self.assertIsInstance(doc['x'], UndecipherableInt64Type)
        self.assertEqual(doc['x'].value, 11)
class TestGridFileCustomType(IntegrationTest):
    """Integration test for GridIn/GridOut on a database configured with a
    custom type decoder."""

    def setUp(self):
        # Remove any files and chunks left in the default ``fs`` bucket.
        self.db.drop_collection('fs.files')
        self.db.drop_collection('fs.chunks')

    def test_grid_out_custom_opts(self):
        """File attributes read back through GridOut are not transformed by
        the database's custom decoder.

        The database uses UppercaseTextDecoder, yet every string asserted
        below is expected in its original lower-case form.
        """
        db = self.db.with_options(codec_options=UPPERSTR_DECODER_CODECOPTS)
        one = GridIn(db.fs, _id=5, filename="my_file",
                     contentType="text/html", chunkSize=1000, aliases=["foo"],
                     metadata={"foo": 'red', "bar": 'blue'}, bar=3,
                     baz="hello")
        one.write(b"hello world")
        one.close()

        two = GridOut(db.fs, 5)

        self.assertEqual("my_file", two.name)
        self.assertEqual("my_file", two.filename)
        self.assertEqual(5, two._id)
        self.assertEqual(11, two.length)
        self.assertEqual("text/html", two.content_type)
        self.assertEqual(1000, two.chunk_size)
        self.assertTrue(isinstance(two.upload_date, datetime.datetime))
        self.assertEqual(["foo"], two.aliases)
        self.assertEqual({"foo": 'red', "bar": 'blue'}, two.metadata)
        self.assertEqual(3, two.bar)
        self.assertEqual("5eb63bbbe01eeed093cb22bb8f5acdc3", two.md5)

        # These GridOut attributes must reject assignment.
        for attr in ["_id", "name", "content_type", "length", "chunk_size",
                     "upload_date", "aliases", "metadata", "md5"]:
            self.assertRaises(AttributeError, setattr, two, attr, 5)
if __name__ == "__main__":
unittest.main()

View File

@ -58,6 +58,7 @@ from test.utils import (ignore_deprecations,
server_started_with_auth,
IMPOSSIBLE_WRITE_CONCERN,
OvertCommandListener)
from test.test_custom_types import DECIMAL_CODECOPTS
if PY3:
@ -976,6 +977,36 @@ class TestDatabase(IntegrationTest):
"maxTimeAlwaysTimeOut",
mode="off")
def test_with_options(self):
    """Database.with_options() clones a database, replacing only the
    settings that were passed explicitly."""
    codec_options = DECIMAL_CODECOPTS
    read_preference = ReadPreference.SECONDARY_PREFERRED
    write_concern = WriteConcern(j=True)
    read_concern = ReadConcern(level="majority")

    # List of all options to compare.
    allopts = ['name', 'client', 'codec_options',
               'read_preference', 'write_concern', 'read_concern']

    db1 = self.client.get_database(
        'with_options_test', codec_options=codec_options,
        read_preference=read_preference, write_concern=write_concern,
        read_concern=read_concern)

    # Case 1: swap no options
    db2 = db1.with_options()
    for opt in allopts:
        self.assertEqual(getattr(db1, opt), getattr(db2, opt))

    # Case 2: swap all options
    newopts = {'codec_options': CodecOptions(),
               'read_preference': ReadPreference.PRIMARY,
               'write_concern': WriteConcern(w=1),
               'read_concern': ReadConcern(level="local")}
    db2 = db1.with_options(**newopts)
    # Iterate over every option, not just the swapped ones, so that the
    # settings with_options() cannot change (name, client) are verified
    # to be carried over from db1.
    for opt in allopts:
        self.assertEqual(
            getattr(db2, opt), newopts.get(opt, getattr(db1, opt)))
def test_current_op_codec_options(self):
class MySON(SON):
pass