mongo-python-driver/pymongo/helpers.py

383 lines
15 KiB
Python

# Copyright 2009-2015 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bits and pieces used by the driver that don't really fit elsewhere."""
import random
import struct
import warnings
import bson
import pymongo
from bson.binary import OLD_UUID_SUBTYPE
from bson.son import SON
from pymongo import auth
from pymongo.errors import (AutoReconnect,
CursorNotFound,
DuplicateKeyError,
InvalidName,
InvalidOperation,
OperationFailure,
ExecutionTimeout,
WTimeoutError)
from pymongo.read_preferences import _ServerMode
from pymongo.write_concern import WriteConcern as _WriteConcern
def _get_common_options(obj, codec_options, read_preference, write_concern):
"""Get the codec options, read preference mode and tags, and write concern
necessary to create a new Database of Collection instance.
"""
if codec_options is None:
codec_options = obj.codec_options
if read_preference is None:
rp_mode = obj.read_preference
rp_tags = obj.tag_sets
else:
if isinstance(read_preference, _ServerMode):
rp_mode = read_preference.mode
rp_tags = read_preference.tag_sets
else:
rp_mode = read_preference
rp_tags = [{}]
if write_concern is None:
wc_document = obj.write_concern
else:
if not isinstance(write_concern, _WriteConcern):
raise TypeError("write_concern must be an instance of "
"pymongo.write_concern.WriteConcern")
wc_document = write_concern.document
return codec_options, rp_mode, rp_tags, wc_document
def _index_list(key_or_list, direction=None):
"""Helper to generate a list of (key, direction) pairs.
Takes such a list, or a single key, or a single key and direction.
"""
if direction is not None:
return [(key_or_list, direction)]
else:
if isinstance(key_or_list, basestring):
return [(key_or_list, pymongo.ASCENDING)]
elif not isinstance(key_or_list, (list, tuple)):
raise TypeError("if no direction is specified, "
"key_or_list must be an instance of list")
return key_or_list
def _index_document(index_list):
"""Helper to generate an index specifying document.
Takes a list of (key, direction) pairs.
"""
if isinstance(index_list, dict):
raise TypeError("passing a dict to sort/create_index/hint is not "
"allowed - use a list of tuples instead. did you "
"mean %r?" % list(index_list.iteritems()))
elif not isinstance(index_list, (list, tuple)):
raise TypeError("must use a list of (key, direction) pairs, "
"not: " + repr(index_list))
if not len(index_list):
raise ValueError("key_or_list must not be the empty list")
index = SON()
for (key, value) in index_list:
if not isinstance(key, basestring):
raise TypeError("first item in each key pair must be a string")
if not isinstance(value, (basestring, int, dict)):
raise TypeError("second item in each key pair must be 1, -1, "
"'2d', 'geoHaystack', or another valid MongoDB "
"index specifier.")
index[key] = value
return index
def _unpack_response(response, cursor_id=None, as_class=dict,
                     tz_aware=False, uuid_subtype=OLD_UUID_SUBTYPE,
                     compile_re=True):
    """Parse a raw database reply, raising on error flags.

    The first 20 bytes are read as a little-endian header: an int32 flags
    word, an int64 cursor id, an int32 starting position and an int32
    document count. The documents follow from byte 20 onwards.

    :Parameters:
      - `response`: byte string as returned from the database.
      - `cursor_id` (optional): the cursor id we sent to get this response.
        Used to build an informative message when the server reports that
        the cursor id is not valid.
      - `as_class` (optional): class to use for resulting documents.
      - `tz_aware`, `uuid_subtype`, `compile_re` (optional): passed through
        to :func:`bson.decode_all`.

    :Returns:
      A dict with keys ``cursor_id``, ``starting_from``, ``number_returned``
      and ``data`` (the list of decoded documents).

    :Raises:
      - ``CursorNotFound`` when flag bit 1 is set.
      - ``AutoReconnect`` when flag bit 2 is set and ``$err`` starts with
        "not master".
      - ``ExecutionTimeout`` when flag bit 2 is set and the error code is 50.
      - ``OperationFailure`` for any other error signalled by flag bit 2.
    """
    flags = struct.unpack_from("<i", response)[0]

    if flags & 1:
        # Only a getMore (which supplies a cursor id) should get this reply.
        assert cursor_id is not None
        raise CursorNotFound("cursor id '%s' not valid at server" %
                             cursor_id)

    if flags & 2:
        # The reply body is a single error document.
        error = bson.BSON(response[20:]).decode()
        server_msg = error["$err"]
        if server_msg.startswith("not master"):
            raise AutoReconnect(server_msg)
        code = error.get("code")
        if code == 50:
            raise ExecutionTimeout(server_msg, code, error)
        raise OperationFailure("database error: %s" % server_msg,
                               code, error)

    cursor, start, count = struct.unpack_from("<qii", response, 4)
    documents = bson.decode_all(response[20:], as_class, tz_aware,
                                uuid_subtype, compile_re)
    assert len(documents) == count
    return {"cursor_id": cursor,
            "starting_from": start,
            "number_returned": count,
            "data": documents}
def _check_command_response(response, reset, msg=None, allowable_errors=None):
"""Check the response to a command for errors.
"""
if "ok" not in response:
# Server didn't recognize our message as a command.
raise OperationFailure(response.get("$err"),
response.get("code"),
response)
if response.get("wtimeout", False):
# MongoDB versions before 1.8.0 return the error message in an "errmsg"
# field. If "errmsg" exists "err" will also exist set to None, so we
# have to check for "errmsg" first.
raise WTimeoutError(response.get("errmsg", response.get("err")),
response.get("code"),
response)
if not response["ok"]:
details = response
# Mongos returns the error details in a 'raw' object
# for some errors.
if "raw" in response:
for shard in response["raw"].itervalues():
# Grab the first non-empty raw error from a shard.
if shard.get("errmsg") and not shard.get("ok"):
details = shard
break
errmsg = details["errmsg"]
if allowable_errors is None or errmsg not in allowable_errors:
# Server is "not master" or "recovering"
if (errmsg.startswith("not master")
or errmsg.startswith("node is recovering")):
if reset is not None:
reset()
raise AutoReconnect(errmsg)
# Server assertion failures
if errmsg == "db assertion failure":
errmsg = ("db assertion failure, assertion: '%s'" %
details.get("assertion", ""))
raise OperationFailure(errmsg,
details.get("assertionCode"),
response)
# Other errors
code = details.get("code")
# findAndModify with upsert can raise duplicate key error
if code in (11000, 11001, 12582):
raise DuplicateKeyError(errmsg, code, response)
elif code == 50:
raise ExecutionTimeout(errmsg, code, response)
msg = msg or "%s"
raise OperationFailure(msg % errmsg, code, response)
def _check_write_command_response(results):
"""Backward compatibility helper for write command error handling.
"""
errors = [res for res in results
if "writeErrors" in res[1] or "writeConcernError" in res[1]]
if errors:
# If multiple batches had errors
# raise from the last batch.
offset, result = errors[-1]
# Prefer write errors over write concern errors
write_errors = result.get("writeErrors")
if write_errors:
# If the last batch had multiple errors only report
# the last error to emulate continue_on_error.
error = write_errors[-1]
error["index"] += offset
if error.get("code") == 11000:
raise DuplicateKeyError(error.get("errmsg"), 11000, error)
else:
error = result["writeConcernError"]
if "errInfo" in error and error["errInfo"].get('wtimeout'):
# Make sure we raise WTimeoutError
raise WTimeoutError(error.get("errmsg"),
error.get("code"), error)
raise OperationFailure(error.get("errmsg"), error.get("code"), error)
def _fields_list_to_dict(fields):
"""Takes a list of field names and returns a matching dictionary.
["a", "b"] becomes {"a": 1, "b": 1}
and
["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1}
"""
as_dict = {}
for field in fields:
if not isinstance(field, basestring):
raise TypeError("fields must be a list of key names, "
"each an instance of %s" % (basestring.__name__,))
as_dict[field] = 1
return as_dict
def _check_database_name(name):
"""Check if a database name is valid."""
if not name:
raise InvalidName("database name cannot be the empty string")
for invalid_char in [' ', '.', '$', '/', '\\', '\x00', '"']:
if invalid_char in name:
raise InvalidName("database names cannot contain the "
"character %r" % invalid_char)
def _copy_database(
        fromdb,
        todb,
        fromhost,
        mechanism,
        username,
        password,
        sock_info,
        cmd_func):
    """Copy a database, perhaps from a remote host, using the copydb command.

    Emits a ``DeprecationWarning`` on every call.

    :Parameters:
      - `fromdb`: name of the source database.
      - `todb`: name of the target database.
      - `fromhost`: source host such as 'foo.com' or 'foo.com:27017', or None.
      - `mechanism`: authentication mechanism. 'DEFAULT' selects SCRAM-SHA-1
        when ``sock_info.max_wire_version >= 3`` and MONGODB-CR otherwise.
      - `username`: a str or unicode, or None to copy without credentials.
      - `password`: a str or unicode, or None.
      - `sock_info`: a SocketInfo instance.
      - `cmd_func`: callback taking ``(sock_info, database, command_doc)``.
        On the MONGODB-CR path its return value is unpacked as
        ``(response, _)``.

    :Raises:
      - ``TypeError`` if `fromdb` or `todb` is not a string.
      - ``InvalidName`` (from ``_check_database_name``) for a bad `todb`.
      - ``InvalidOperation`` if a username is given with an unsupported
        mechanism.
      - ``OperationFailure`` with an explanatory message when the SCRAM-SHA-1
        exchange fails with "no such cmd: saslStart". Other
        ``OperationFailure`` errors are re-raised unchanged.
    """
    if not isinstance(fromdb, basestring):
        raise TypeError('from_name must be an instance '
                        'of %s' % (basestring.__name__,))
    if not isinstance(todb, basestring):
        raise TypeError('to_name must be an instance '
                        'of %s' % (basestring.__name__,))
    _check_database_name(todb)

    warnings.warn("copy_database is deprecated. Use the raw 'copydb' command"
                  " or db.copyDatabase() in the mongo shell. See"
                  " doc/examples/copydb.",
                  DeprecationWarning, stacklevel=2)

    # For backwards compatibility the caller need not name a mechanism, so
    # guess from the target server and hope 'fromhost' runs the same version.
    if mechanism == 'DEFAULT':
        mechanism = ('SCRAM-SHA-1' if sock_info.max_wire_version >= 3
                     else 'MONGODB-CR')

    if username is None:
        # Unauthenticated copy.
        copydb_cmd = SON([('copydb', 1),
                          ('fromdb', fromdb),
                          ('todb', todb)])
        # NOTE(review): this path tests fromhost for truthiness, whereas the
        # MONGODB-CR path tests `is not None`.
        if fromhost:
            copydb_cmd['fromhost'] = fromhost
        cmd_func(sock_info, 'admin', copydb_cmd)
        return

    if mechanism == 'SCRAM-SHA-1':
        credentials = auth._build_credentials_tuple(mech=mechanism,
                                                    source='admin',
                                                    user=username,
                                                    passwd=password,
                                                    extra=None)
        try:
            auth._copydb_scram_sha1(credentials=credentials,
                                    sock_info=sock_info,
                                    cmd_func=cmd_func,
                                    fromdb=fromdb,
                                    todb=todb,
                                    fromhost=fromhost)
        except OperationFailure as exc:
            errmsg = (exc.details or {}).get('errmsg') or ''
            if 'no such cmd: saslStart' not in errmsg:
                raise
            raise OperationFailure(
                "%s doesn't support SCRAM-SHA-1, pass"
                " mechanism='MONGODB-CR' to copy_database" % fromhost,
                exc.code,
                exc.details)
    elif mechanism == 'MONGODB-CR':
        nonce_reply, _ = cmd_func(sock_info, 'admin',
                                  SON([('copydbgetnonce', 1),
                                       ('fromhost', fromhost)]))
        nonce = nonce_reply['nonce']
        copydb_cmd = SON([('copydb', 1),
                          ('fromdb', fromdb),
                          ('todb', todb),
                          ('username', username),
                          ('nonce', nonce),
                          ('key', auth._auth_key(nonce, username, password))])
        if fromhost is not None:
            copydb_cmd['fromhost'] = fromhost
        cmd_func(sock_info, 'admin', copydb_cmd)
    else:
        raise InvalidOperation('Authentication mechanism %r not supported'
                               ' for copy_database' % mechanism)
def shuffled(sequence):
    """Return a new :class:`list` holding the items of `sequence` in random order.

    `sequence` itself is left untouched. The ordering is produced by
    :func:`random.shuffle`.
    """
    items = [element for element in sequence]
    random.shuffle(items)
    return items