1169 lines
41 KiB
Python
1169 lines
41 KiB
Python
"""Driver for Mongo.
|
|
|
|
The database is accessed through an instance of the Mongo class."""
|
|
|
|
import unittest
|
|
import socket
|
|
import types
|
|
import traceback
|
|
import os
|
|
import struct
|
|
import random
|
|
|
|
from son import SON
|
|
import bson
|
|
from objectid import ObjectId
|
|
from dbref import DBRef
|
|
|
|
class DatabaseException(Exception):
    """Signals that the server reported a failed database operation."""
|
|
|
|
class InvalidOperation(Exception):
    """Signals that the client tried something not allowed in its current state."""
|
|
|
|
class ConnectionException(IOError):
    """Signals that the database connection could not be opened or was dropped."""
|
|
|
|
class InvalidCollection(ValueError):
    """Signals that a collection name does not satisfy the naming rules."""
|
|
|
|
# Little-endian int32 encodings of 0 and 1, spliced into wire messages as the
# reserved field and as the upsert flag.
_ZERO = "\x00" * 4
_ONE = "\x01" + "\x00" * 3
# Abandoned cursor ids are queued and killed in a single message once more
# than this many are pending.
_MAX_DYING_CURSORS = 20
# Collection that index definitions are saved into.
_SYSTEM_INDEX_COLLECTION = "system.indexes"

# Directions accepted by create_index() and Cursor.sort().
ASCENDING = 1
DESCENDING = -1
|
|
|
|
class Mongo(object):
    """A connection to a Mongo database.

    Collections of the database are reached by attribute or item access,
    e.g. ``db.things`` or ``db["things"]``.
    """

    def __init__(self, name, host="localhost", port=27017, settings={}):
        """Open a new connection to the database at host:port.

        Raises TypeError if name or host is not an instance of string or port
        is not an instance of int. Raises ConnectionException if the
        connection cannot be made.

        Settings are passed in as a dictionary. Possible settings, along with
        their default values (in parens), are listed below:
          - "auto_dereference" (False): automatically dereference any `DBRef`s
            contained within SON objects being returned from queries
          - "auto_reference" (False): automatically create `DBRef`s out of any
            sub-objects that have already been saved in the database

        Arguments:
          - `name`: the name of the database to connect to
          - `host` (optional): the hostname or IPv4 address of the database to
            connect to
          - `port` (optional): the port number on which to connect
          - `settings` (optional): a dictionary of settings
        """
        if not isinstance(name, types.StringTypes):
            raise TypeError("name must be an instance of (str, unicode)")
        if not isinstance(host, types.StringType):
            raise TypeError("host must be an instance of str")
        if not isinstance(port, types.IntType):
            raise TypeError("port must be an instance of int")
        # The shared `{}` default is harmless: `settings` is only read here,
        # never mutated or stored. (Passing None is rejected on purpose.)
        if not isinstance(settings, types.DictType):
            raise TypeError("settings must be an instance of dict")

        self.__name = name
        self.__host = host
        self.__port = port
        # Request id stamped on the next outgoing message.
        self.__id = 1
        # Ids of abandoned server-side cursors waiting for a batched kill.
        self.__dying_cursors = []
        self.__auto_dereference = settings.get("auto_dereference", False)
        self.__auto_reference = settings.get("auto_reference", False)

        self.__connect()

    def name(self):
        """Return the name of the database this connection uses."""
        return self.__name

    def __connect(self):
        """(Re-)connect to the database.

        Raises ConnectionException if the socket cannot be created or the
        connection attempt fails. A socket whose connection attempt failed is
        closed rather than leaked.
        """
        connection = None
        try:
            connection = socket.socket()
            connection.connect((self.__host, self.__port))
        except (socket.error, OverflowError):
            # OverflowError: the socket module rejects ports outside 0-65535
            # with it instead of socket.error. Capture the traceback before
            # cleaning up so the message describes the original failure.
            details = traceback.format_exc()
            if connection is not None:
                connection.close()
            raise ConnectionException("could not connect to %s:%s, got: %s" %
                                      (self.__host, self.__port, details))
        self.__connection = connection

    def _send_message(self, operation, data):
        """Say something to the database.

        Prefixes `data` with the 16-byte message header and writes the whole
        message to the socket. Returns the request id used for the message.

        Raises ConnectionException if the socket stops accepting data.

        Arguments:
          - `operation`: the opcode of the message
          - `data`: the data to send
        """
        request_id = self.__id
        self.__id += 1

        # header: total message length, request id, responseTo (0), opcode
        to_send = struct.pack("<iiii", 16 + len(data), request_id, 0, operation)
        to_send += data

        # send() may accept only part of the buffer; keep going until done.
        total_sent = 0
        while total_sent < len(to_send):
            sent = self.__connection.send(to_send[total_sent:])
            if sent == 0:
                raise ConnectionException("connection closed")
            total_sent += sent

        return request_id

    def _receive_message(self, operation, request_id):
        """Receive a message from the database.

        Returns the message body (everything after the 16-byte header).
        Asserts that the message uses the given opcode and answers the given
        request id. Raises ConnectionException if the socket closes before the
        whole message has arrived.

        Arguments:
          - `operation`: the opcode of the message
          - `request_id`: the request id that the message should be in response to
        """
        def receive(length):
            # recv() may return fewer bytes than asked for; keep reading.
            message = ""
            while len(message) < length:
                chunk = self.__connection.recv(length - len(message))
                if chunk == "":
                    raise ConnectionException("connection closed")
                message += chunk
            return message

        header = receive(16)
        length = struct.unpack("<i", header[:4])[0]
        # NOTE(review): these protocol checks are skipped under `python -O`.
        assert request_id == struct.unpack("<i", header[8:12])[0]
        assert operation == struct.unpack("<i", header[12:])[0]

        return receive(length - 16)

    def _command(self, command):
        """Issue a DB command by querying the "$cmd" pseudo-collection.

        Returns the single result document (or None if nothing came back).
        """
        return self["$cmd"].find_one(command)

    def _kill_cursors(self):
        """Send one kill-cursors message (opcode 2007) for every queued id."""
        if not self.__dying_cursors:
            return
        message = _ZERO
        message += struct.pack("<i", len(self.__dying_cursors))
        for cursor_id in self.__dying_cursors:
            message += struct.pack("<q", cursor_id)
        self._send_message(2007, message)
        self.__dying_cursors = []

    def _kill_cursor(self, cursor_id):
        """Queue `cursor_id` to be killed; flush once the queue grows too big."""
        self.__dying_cursors.append(cursor_id)

        if len(self.__dying_cursors) > _MAX_DYING_CURSORS:
            self._kill_cursors()

    def __cmp__(self, other):
        """Order connections by (host, port, database name).

        The database name takes part so that collections of different
        databases on the same server do not compare equal.
        """
        if isinstance(other, Mongo):
            return cmp((self.__host, self.__port, self.__name),
                       (other.__host, other.__port, other.__name))
        return NotImplemented

    def __hash__(self):
        """Hash consistently with __cmp__."""
        return hash((self.__host, self.__port, self.__name))

    def __repr__(self):
        return "Mongo(%r, %r, %r)" % (self.__name, self.__host, self.__port)

    def __getattr__(self, name):
        """Get a collection of this database by name.

        Raises InvalidCollection if an invalid collection name is used.
        Special (dunder) names raise AttributeError instead: protocols such as
        copy, pickle and hasattr() probe for them and must not receive a
        Collection.

        Arguments:
          - `name`: the name of the collection to get
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return Collection(self, name)

    def __getitem__(self, name):
        """Get a collection of this database by name (any valid name)."""
        return Collection(self, name)

    def dereference(self, dbref):
        """Dereference a DBRef, getting the SON object it points to.

        Raises TypeError if dbref is not an instance of DBRef. Returns a SON
        object or None if the reference does not point to a valid object.

        Arguments:
          - `dbref`: the reference
        """
        if not isinstance(dbref, DBRef):
            raise TypeError("cannot dereference a %s" % type(dbref))
        return self[dbref.collection()].find_one(dbref.id())

    def _fix_outgoing(self, son):
        """Fixes an object coming out of the database.

        When auto dereferencing is enabled, every DBRef value (also inside
        nested dicts/SONs and lists) is replaced by the document it points to,
        if that document exists. `son` is modified in place and returned.

        NOTE(review): documents that reference each other in a cycle would
        make this recurse without bound.

        Arguments:
          - `son`: a SON object coming out of the database
        """
        if not self.__auto_dereference:
            return son

        def fix_value(value):
            if isinstance(value, DBRef):
                deref = self.dereference(value)
                if deref is None:
                    # Dangling reference: leave the DBRef in place.
                    return value
                return self._fix_outgoing(deref)
            elif isinstance(value, (SON, types.DictType)):
                return self._fix_outgoing(value)
            elif isinstance(value, types.ListType):
                return [fix_value(v) for v in value]
            return value

        for (key, value) in son.items():
            son[key] = fix_value(value)

        return son

    def _fix_incoming(self, to_save, collection, add_meta):
        """Fixes an object going in to the database.

        When `add_meta` is true, a missing _id is generated and _ns is set to
        the collection name; both are written into `to_save` itself. When auto
        referencing is enabled, a copy is returned in which top-level
        sub-documents that carry a non-new _id and an _ns are replaced by
        DBRefs; otherwise `to_save` itself is returned.

        Arguments:
          - `to_save`: a SON object going into the database
          - `collection`: collection into which this object is being saved
          - `add_meta`: should _id and other meta-fields be added to the object
        """
        if "_id" in to_save:
            # NOTE(review): assert-based validation is skipped under -O.
            assert isinstance(to_save["_id"], ObjectId), "'_id' must be an ObjectId"
        elif add_meta:
            to_save["_id"] = ObjectId()

        if add_meta:
            to_save["_ns"] = collection._name()

        if not self.__auto_reference:
            return to_save

        # make a copy, so only what is being saved gets auto-ref'ed
        to_save = SON(to_save)

        def fix_value(value):
            if isinstance(value, (SON, types.DictType)):
                if "_id" in value and not value["_id"].is_new() and "_ns" in value:
                    return DBRef(value["_ns"], value["_id"])
            return value

        for (key, value) in to_save.items():
            to_save[key] = fix_value(value)

        return to_save
|
|
|
|
class Collection(object):
|
|
"""A Mongo collection.
|
|
"""
|
|
def __init__(self, database, name):
|
|
"""Get / create a Mongo collection.
|
|
|
|
Raises TypeError if database is not an instance of Mongo or name is not
|
|
an instance of (str, unicode). Raises InvalidCollection if name is not a
|
|
valid collection name.
|
|
|
|
Arguments:
|
|
- `database`: the database to get a collection from
|
|
- `name`: the name of the collection to get
|
|
"""
|
|
if not isinstance(database, Mongo):
|
|
raise TypeError("database must be an instance of Mongo")
|
|
if not isinstance(name, types.StringTypes):
|
|
raise TypeError("name must be an instance of (str, unicode)")
|
|
|
|
if not name or ".." in name:
|
|
raise InvalidCollection("collection names cannot be empty")
|
|
if "$" in name and name not in ["$cmd"]:
|
|
raise InvalidCollection("collection names must not contain '$'")
|
|
if name[0] == "." or name[-1] == ".":
|
|
raise InvalidCollection("collecion names must not start or end with '.'")
|
|
|
|
self.__database = database
|
|
self.__collection_name = unicode(name)
|
|
|
|
def __getattr__(self, name):
|
|
"""Get a sub-collection of this collection by name.
|
|
|
|
Raises InvalidCollection if an invalid collection name is used.
|
|
|
|
Arguments:
|
|
- `name`: the name of the collection to get
|
|
"""
|
|
return Collection(self.__database, u"%s.%s" % (self.__collection_name, name))
|
|
|
|
def __getitem__(self, name):
|
|
return self.__getattr__(name)
|
|
|
|
def __repr__(self):
|
|
return "Collection(%r, %r)" % (self.__database, self.__collection_name)
|
|
|
|
def __cmp__(self, other):
|
|
if isinstance(other, Collection):
|
|
return cmp((self.__database, self.__collection_name),
|
|
(other.__database, other.__collection_name))
|
|
return NotImplemented
|
|
|
|
def _full_name(self):
|
|
return u"%s.%s" % (self.__database.name(), self.__collection_name)
|
|
|
|
def _name(self):
|
|
return self.__collection_name
|
|
|
|
def _send_message(self, operation, data):
|
|
"""Wrap up a message and send it.
|
|
"""
|
|
# reserved int, full collection name, message data
|
|
message = _ZERO
|
|
message += bson._make_c_string(self._full_name())
|
|
message += data
|
|
return self.__database._send_message(operation, message)
|
|
|
|
def database(self):
|
|
return self.__database
|
|
|
|
def save(self, to_save, add_meta=True):
|
|
"""Save a SON object in this collection.
|
|
|
|
Raises TypeError if to_save is not an instance of (dict, SON).
|
|
|
|
Arguments:
|
|
- `to_save`: the SON object to be saved
|
|
- `add_meta` (optional): add meta information (like _id) to the object
|
|
if it's missing
|
|
"""
|
|
if not isinstance(to_save, (types.DictType, SON)):
|
|
raise TypeError("cannot save object of type %s" % type(to_save))
|
|
|
|
to_save = self.__database._fix_incoming(to_save, self, add_meta)
|
|
|
|
if "_id" not in to_save:
|
|
self._send_message(2002, bson.BSON.from_dict(to_save))
|
|
elif to_save["_id"].is_new():
|
|
to_save["_id"]._use()
|
|
self._send_message(2002, bson.BSON.from_dict(to_save))
|
|
else:
|
|
self._update({"_id": to_save["_id"]}, to_save, True)
|
|
|
|
return to_save.get("_id", None)
|
|
|
|
def _update(self, spec, document, upsert=False):
|
|
"""Update an object(s) in this collection.
|
|
|
|
Raises TypeError if either spec or document isn't an instance of
|
|
(dict, SON) or upsert isn't an instance of bool.
|
|
|
|
- `spec`: a SON object specifying elements which must be present for a
|
|
document to be updated
|
|
- `document`: a SON object specifying the fields to be changed in the
|
|
selected document(s), or (in the case of an upsert) the document to
|
|
be inserted.
|
|
- `upsert` (optional): perform an upsert operation
|
|
"""
|
|
if not isinstance(spec, (types.DictType, SON)):
|
|
raise TypeError("spec must be an instance of (dict, SON)")
|
|
if not isinstance(document, (types.DictType, SON)):
|
|
raise TypeError("document must be an instance of (dict, SON)")
|
|
if not isinstance(upsert, types.BooleanType):
|
|
raise TypeError("upsert must be an instance of bool")
|
|
|
|
message = upsert and _ONE or _ZERO
|
|
message += bson.BSON.from_dict(spec)
|
|
message += bson.BSON.from_dict(document)
|
|
|
|
self._send_message(2001, message)
|
|
|
|
def remove(self, spec_or_object_id):
|
|
"""Remove an object(s) from this collection.
|
|
|
|
Raises TypeEror if the argument is not an instance of
|
|
(dict, SON, ObjectId).
|
|
|
|
Arguments:
|
|
- `spec_or_object_id` (optional): a SON object specifying elements
|
|
which must be present for a document to be removed OR an instance of
|
|
ObjectId to be used as the value for an _id element
|
|
"""
|
|
spec = spec_or_object_id
|
|
if isinstance(spec, ObjectId):
|
|
spec = SON({"_id": spec})
|
|
|
|
if not isinstance(spec, (types.DictType, SON)):
|
|
raise TypeError("spec must be an instance of (dict, SON)")
|
|
|
|
self._send_message(2006, _ZERO + bson.BSON.from_dict(spec))
|
|
|
|
def find_one(self, spec_or_object_id=SON()):
|
|
"""Get a single object from the database.
|
|
|
|
Raises TypeError if the argument is of an improper type. Returns a
|
|
single SON object, or None if no result is found.
|
|
|
|
Arguments:
|
|
- `spec_or_object_id` (optional): a SON object specifying elements
|
|
which must be present for a document to be included in the result
|
|
set OR an instance of ObjectId to be used as the value for an _id
|
|
query
|
|
"""
|
|
spec = spec_or_object_id
|
|
if isinstance(spec, ObjectId):
|
|
spec = SON({"_id": spec})
|
|
|
|
for result in self.find(spec, limit=1):
|
|
return result
|
|
return None
|
|
|
|
def find(self, spec=SON(), fields=[], skip=0, limit=0):
|
|
"""Query the database.
|
|
|
|
Raises TypeError if any of the arguments are of improper type. Returns
|
|
an instance of Cursor corresponding to this query.
|
|
|
|
Arguments:
|
|
- `spec` (optional): a SON object specifying elements which must be
|
|
present for a document to be included in the result set
|
|
- `fields` (optional): a list of field names that should be returned
|
|
in the result set
|
|
- `skip` (optional): the number of documents to omit (from the start of
|
|
the result set) when returning the results
|
|
- `limit` (optional): the maximum number of results to return in the
|
|
first reply message, or 0 for the default return size
|
|
"""
|
|
if not isinstance(spec, (types.DictType, SON)):
|
|
raise TypeError("spec must be an instance of (dict, SON)")
|
|
if not isinstance(fields, types.ListType):
|
|
raise TypeError("fields must be an instance of list")
|
|
if not isinstance(skip, types.IntType):
|
|
raise TypeError("skip must be an instance of int")
|
|
if not isinstance(limit, types.IntType):
|
|
raise TypeError("limit must be an instance of int")
|
|
|
|
return_fields = len(fields) and SON() or None
|
|
for field in fields:
|
|
if not isinstance(field, types.StringTypes):
|
|
raise TypeError("fields must be a list of key names as (string, unicode)")
|
|
return_fields[field] = 1
|
|
|
|
return Cursor(self, spec, return_fields, skip, limit)
|
|
|
|
def _gen_index_name(self, keys):
|
|
"""Generate an index name from the set of fields it is over.
|
|
"""
|
|
return u"_".join([u"%s_%s" % item for item in keys])
|
|
|
|
def create_index(self, key_or_list, direction=None):
|
|
"""Creates an index on this collection.
|
|
|
|
Takes either a single key and a direction, or a list of (key, direction)
|
|
pairs. The key(s) must be an instance of (str, unicode), and the
|
|
direction(s) must be one of (Mongo.ASCENDING, Mongo.DESCENDING).
|
|
|
|
Arguments:
|
|
- `key_or_list`: a single key or a list of (key, direction) pairs
|
|
specifying the index to ensure
|
|
- `direction` (optional): must be included if key_or_list is a single
|
|
key, otherwise must be None
|
|
"""
|
|
if direction:
|
|
keys = [(key_or_list, direction)]
|
|
else:
|
|
keys = key_or_list
|
|
|
|
if not isinstance(keys, types.ListType):
|
|
raise TypeError("if no direction is specified, key_or_list must be an instance of list")
|
|
if not len(keys):
|
|
raise ValueError("key_or_list must not be the empty list")
|
|
|
|
to_save = SON()
|
|
to_save["name"] = self._gen_index_name(keys)
|
|
to_save["ns"] = self._full_name()
|
|
|
|
key_object = SON()
|
|
for (key, value) in keys:
|
|
if not isinstance(key, types.StringTypes):
|
|
raise TypeError("first item in each key pair must be a string")
|
|
if not isinstance(value, types.IntType):
|
|
raise TypeError("second item in each key pair must be Mongo.ASCENDING or Mongo.DESCENDING")
|
|
key_object[key] = value
|
|
to_save["key"] = key_object
|
|
|
|
self.__database[_SYSTEM_INDEX_COLLECTION].save(to_save, False)
|
|
|
|
def drop_indexes(self):
|
|
"""Drops all indexes on this collection.
|
|
|
|
Can be used on non-existant collections or collections with no indexes.
|
|
Raises DatabaseException on an error.
|
|
"""
|
|
response = self.__database._command(SON([("deleteIndexes", self.__collection_name),
|
|
("index", u"*")]))
|
|
if response["ok"] != 1:
|
|
if response["errmsg"] == "ns not found":
|
|
return
|
|
raise DatabaseException("error dropping indexes: %s" % response["errmsg"])
|
|
|
|
class Cursor(object):
    """A cursor / iterator over Mongo query results.
    """

    def __init__(self, collection, spec, fields, skip, limit):
        """Create a new cursor.

        Generally not needed to be used by application developers. Nothing is
        sent to the server until the cursor is first iterated.
        """
        self.__collection = collection
        self.__spec = spec
        self.__fields = fields
        self.__skip = skip
        self.__limit = limit
        self.__ordering = None

        # Documents fetched but not yet handed out.
        self.__data = []
        # Server-side cursor id: None before the query is sent, 0 once the
        # server has nothing more to return.
        self.__id = None
        # Total number of documents received so far.
        self.__retrieved = 0
        self.__killed = False

    def __del__(self):
        if self.__id and not self.__killed:
            self._die()

    def _die(self):
        """Kills this cursor.

        A kill request is queued only when the server holds an open cursor
        for us (id neither None nor 0); either way the cursor is marked dead
        so it is never refreshed again.
        """
        if self.__id:
            self.__collection.database()._kill_cursor(self.__id)
        self.__killed = True

    def __query_spec(self):
        """Get the spec to use for a query.

        Just `self.__spec`, unless this cursor needs special query fields, like
        orderby.
        """
        if not self.__ordering:
            return self.__spec
        return SON([("query", self.__spec),
                    ("orderby", self.__ordering)])

    def __check_okay_to_chain(self):
        """Check if it is okay to chain more options onto this cursor.

        Raises InvalidOperation once the query has been sent.
        """
        if self.__retrieved or self.__id is not None:
            raise InvalidOperation("cannot set options after executing query")

    def limit(self, limit):
        """Limits the number of results to be returned by this cursor.

        Raises TypeError if limit is not an instance of int. Raises
        InvalidOperation if this cursor has already been used. Returns self
        so calls can be chained.

        Arguments:
          - `limit`: the number of results to return
        """
        if not isinstance(limit, types.IntType):
            raise TypeError("limit must be an int")
        self.__check_okay_to_chain()

        self.__limit = limit
        return self

    def skip(self, skip):
        """Skips the first `skip` results of this cursor.

        Raises TypeError if skip is not an instance of int. Raises
        InvalidOperation if this cursor has already been used. Returns self
        so calls can be chained.

        Arguments:
          - `skip`: the number of results to skip
        """
        if not isinstance(skip, types.IntType):
            raise TypeError("skip must be an int")
        self.__check_okay_to_chain()

        self.__skip = skip
        return self

    def sort(self, key_or_list, direction=None):
        """Sorts this cursor's results.

        Takes either a single key and a direction, or a list of (key, direction)
        pairs. The key(s) must be an instance of (str, unicode), and the
        direction(s) must be one of (Mongo.ASCENDING, Mongo.DESCENDING). Raises
        InvalidOperation if this cursor has already been used. Returns self.

        Arguments:
          - `key_or_list`: a single key or a list of (key, direction) pairs
            specifying the keys to sort on
          - `direction` (optional): must be included if key_or_list is a single
            key, otherwise must be None
        """
        self.__check_okay_to_chain()

        # TODO a lot of this logic could be shared with create_index()
        if direction:
            keys = [(key_or_list, direction)]
        else:
            keys = key_or_list

        if not isinstance(keys, types.ListType):
            raise TypeError("if no direction is specified, key_or_list must be an instance of list")
        if not len(keys):
            raise ValueError("key_or_list must not be the empty list")

        orderby = SON()
        for (key, value) in keys:
            if not isinstance(key, types.StringTypes):
                raise TypeError("first item in each key pair must be a string")
            if not isinstance(value, types.IntType):
                raise TypeError("second item in each key pair must be Mongo.ASCENDING or Mongo.DESCENDING")
            orderby[key] = value

        self.__ordering = orderby
        return self

    def count(self):
        """Get the size of the results set for this query.

        Returns the number of objects in the results set for this query. Does
        not take limit and skip into account. Raises InvalidOperation if this
        cursor has already been used. Raises DatabaseException on a database
        error.
        """
        self.__check_okay_to_chain()

        command = SON([("count", self.__collection._name()),
                       ("query", self.__spec)])
        response = self.__collection.database()._command(command)
        if response["ok"] != 1:
            if response["errmsg"] == "ns does not exist":
                return 0
            raise DatabaseException("error getting count: %s" % response["errmsg"])
        return int(response["n"])

    def _refresh(self):
        """Refreshes the cursor with more data from Mongo.

        Returns the length of self.__data after refresh. Will exit early if
        self.__data is already non-empty or the cursor is dead. Raises
        DatabaseException when the cursor cannot be refreshed due to an error
        on the query.
        """
        if len(self.__data) or self.__killed:
            return len(self.__data)

        def send_message(operation, message):
            # TODO the send and receive here should be synchronized...
            request_id = self.__collection._send_message(operation, message)
            # Replies are expected to carry opcode 1.
            response = self.__collection.database()._receive_message(1, request_id)

            # Reply body: int32 flags, int64 cursor id, int32 starting
            # position, int32 number returned, then the documents.
            response_flag = struct.unpack("<i", response[:4])[0]
            if response_flag == 1:
                raise DatabaseException("cursor id '%s' not valid at server" % self.__id)
            elif response_flag == 2:
                error_object = bson.to_dict(response[20:])
                raise DatabaseException("database error: %s" % error_object["$err"])
            else:
                assert response_flag == 0

            self.__id = struct.unpack("<q", response[4:12])[0]
            assert struct.unpack("<i", response[12:16])[0] == self.__retrieved

            number_returned = struct.unpack("<i", response[16:20])[0]
            self.__retrieved += number_returned
            self.__data = bson.to_dicts(response[20:])
            assert len(self.__data) == number_returned

        if self.__id is None:
            # First use: send the query (opcode 2004).
            message = struct.pack("<i", self.__skip)
            message += struct.pack("<i", self.__limit)
            message += bson.BSON.from_dict(self.__query_spec())
            if self.__fields:
                message += bson.BSON.from_dict(self.__fields)

            send_message(2004, message)
        elif self.__id != 0:
            # Server cursor still open: ask for more (opcode 2005), but never
            # for more than what remains of the limit.
            limit = 0
            if self.__limit:
                if self.__limit > self.__retrieved:
                    limit = self.__limit - self.__retrieved
                else:
                    self._die()
                    return 0

            message = struct.pack("<i", limit)
            message += struct.pack("<q", self.__id)

            send_message(2005, message)
        # Otherwise the id is 0: the server has no more results for us.

        length = len(self.__data)
        if not length:
            self._die()
        return length

    def __iter__(self):
        return self

    def next(self):
        """Return the next result, fetching more from the server if needed.

        Raises StopIteration once the results are exhausted.
        """
        if not self.__data and not self._refresh():
            raise StopIteration
        return self.__collection.database()._fix_outgoing(self.__data.pop(0))
|
|
|
|
class TestMongo(unittest.TestCase):
|
|
def setUp(self):
    """Read the address of the test server from the environment."""
    environment = os.environ
    self.host = environment.get("db_ip", "localhost")
    self.port = int(environment.get("db_port", 27017))
|
|
|
|
def test_connection(self):
    """Constructor argument validation and connection failures."""
    badly_typed_arguments = [
        (1,), (1.14,), (None,), ([],),
        ("test", 1), ("test", 1.14), ("test", None), ("test", []),
        ("test", "localhost", "27017"), ("test", "localhost", 1.14),
        ("test", "localhost", None), ("test", "localhost", []),
        ("test", "localhost", 27017, "settings"),
        ("test", "localhost", 27017, None),
    ]
    for arguments in badly_typed_arguments:
        self.assertRaises(TypeError, Mongo, *arguments)

    self.assertRaises(ConnectionException, Mongo, "test",
                      "somedomainthatdoesntexist.org")
    self.assertRaises(ConnectionException, Mongo, "test", self.host, 123456789)

    self.assertTrue(Mongo("test", self.host, self.port))
|
|
|
|
def test_repr(self):
    """repr() of a connection and of one of its collections."""
    connection_repr = "Mongo('test', '%s', %s)" % (self.host, self.port)
    self.assertEqual(repr(Mongo("test", self.host, self.port)), connection_repr)
    self.assertEqual(repr(Mongo("test", self.host, self.port).test),
                     "Collection(%s, u'test')" % connection_repr)
|
|
|
|
def test_collection(self):
    """Collection construction, name validation and equality."""
    db = Mongo(u"test", self.host, self.port)

    self.assertRaises(TypeError, Collection, db, 5)
    self.assertRaises(TypeError, Collection, 5, "test")

    def make_col(base, name):
        base[name]

    invalid_names = ("", "te$t", ".test", "test.", "tes..t")
    for parent in (db, db.test):
        for invalid_name in invalid_names:
            self.assertRaises(InvalidCollection, make_col, parent, invalid_name)

    self.assertTrue(isinstance(db.test, Collection))
    self.assertEqual(db.test, db["test"])
    self.assertEqual(db.test, Collection(db, "test"))
    self.assertEqual(db.test.mike, db["test.mike"])
    self.assertEqual(db.test["mike"], db["test.mike"])
|
|
|
|
def test_save_find_one(self):
    """save() assigns an _id, and find_one() finds the saved document."""
    db = Mongo("test", self.host, self.port)
    db.test.remove({})

    document = SON({"hello": u"world"})
    key = db.test.save(document)
    self.assertTrue(isinstance(document["_id"], ObjectId))
    self.assertEqual(document["_id"], key)

    lookups = [
        (document, {"_id": document["_id"]}),
        (document, key),
        (None, ObjectId()),
        (document, {"hello": u"world"}),
        (None, {"hello": u"test"}),
    ]
    for expected, spec in lookups:
        self.assertEqual(expected, db.test.find_one(spec))

    fetched = db.test.find_one()
    self.assertFalse(fetched["_id"].is_new())
    fetched["hello"] = u"mike"
    db.test.save(fetched)

    self.assertNotEqual(document, db.test.find_one(key))
    self.assertEqual(fetched, db.test.find_one(key))
    self.assertEqual(fetched, db.test.find_one())

    self.assertEqual(len(list(db.test.find())), 1)
|
|
|
|
def test_remove(self):
    """remove() by ObjectId, by document and by spec."""
    db = Mongo("test", self.host, self.port)
    db.test.remove({})

    for invalid_spec in (5, "test", []):
        self.assertRaises(TypeError, db.test.remove, invalid_spec)

    def document_count():
        return len(list(db.test.find()))

    first_id = db.test.save({"x": 1})
    db.test.save({"x": 2})
    db.test.save({"x": 3})
    self.assertEqual(document_count(), 3)

    db.test.remove(first_id)
    self.assertEqual(document_count(), 2)

    for _ in range(2):
        db.test.remove(db.test.find_one())
    self.assertEqual(db.test.find_one(), None)

    for value in (1, 2, 3):
        db.test.save({"x": value})

    self.assertTrue(db.test.find_one({"x": 2}))
    db.test.remove({"x": 2})
    self.assertFalse(db.test.find_one({"x": 2}))

    self.assertTrue(db.test.find_one())
    db.test.remove({})
    self.assertFalse(db.test.find_one())
|
|
|
|
def test_save_a_bunch(self):
    """Save many documents, read them back, and abandon many cursors."""
    db = Mongo("test", self.host, self.port)
    db.test.remove({})

    for value in xrange(1000):
        db.test.save({"x": value})

    self.assertEqual(1000, len(list(db.test.find())))

    # Abandon enough half-read cursors to trigger several batched kills and
    # make sure none of them trips an assertion.
    for _ in xrange(3 * _MAX_DYING_CURSORS + 2):
        for _ in db.test.find():
            break
|
|
|
|
def test_create_index(self):
    """create_index() argument validation and the stored index documents."""
    db = Mongo("test", self.host, self.port)

    invalid_calls = [
        (TypeError, (5,)),
        (TypeError, ("hello",)),
        (ValueError, ([],)),
        (TypeError, ([], ASCENDING)),
        (TypeError, ([("hello", DESCENDING)], DESCENDING)),
        (TypeError, ("hello", "world")),
    ]
    for error, arguments in invalid_calls:
        self.assertRaises(error, db.test.create_index, *arguments)

    def first_index():
        return db[_SYSTEM_INDEX_COLLECTION].find_one({"ns": u"test.test"})

    db.test.drop_indexes()
    self.assertFalse(first_index())

    db.test.create_index("hello", ASCENDING)
    db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)])
    index_count = len(list(db[_SYSTEM_INDEX_COLLECTION].find({"ns": u"test.test"})))
    self.assertEqual(index_count, 2)

    db.test.drop_indexes()
    self.assertFalse(first_index())
    db.test.create_index("hello", ASCENDING)
    self.assertEqual(first_index(),
                     SON([(u"name", u"hello_1"),
                          (u"ns", u"test.test"),
                          (u"key", SON([(u"hello", 1)]))]))

    db.test.drop_indexes()
    self.assertFalse(first_index())
    db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)])
    self.assertEqual(first_index(),
                     SON([(u"name", u"hello_-1_world_1"),
                          (u"ns", u"test.test"),
                          (u"key", SON([(u"hello", -1),
                                        (u"world", 1)]))]))
|
|
|
|
def test_limit(self):
    """Cursor.limit(): type checks, result counts, chaining, post-use lockout."""
    db = Mongo("test", self.host, self.port)

    for invalid_limit in (None, "hello", 5.5):
        self.assertRaises(TypeError, db.test.find().limit, invalid_limit)

    db.test.remove({})
    for i in range(100):
        db.test.save({"x": i})

    self.assertEqual(len(list(db.test.find())), 100)
    for limit, expected in ((20, 20), (99, 99), (1, 1), (0, 100)):
        self.assertEqual(len(list(db.test.find().limit(limit))), expected)
    self.assertEqual(len(list(db.test.find().limit(0).limit(50).limit(10))), 10)

    cursor = db.test.find()
    cursor.limit(10)
    for _ in cursor:
        break
    self.assertRaises(InvalidOperation, cursor.limit, 5)
|
|
|
|
def test_skip(self):
    """Check Cursor.skip().

    Non-integer arguments (None, a string, a float) must raise TypeError.
    Over documents saved with x = 0..99, skip(n) must start the results at
    x == n. In the chain skip(0).skip(50).skip(10), the results must start at
    x == 10. Skipping past the end must yield nothing, and calling skip()
    after iteration has started must raise InvalidOperation.

    This method must not be named ``test_limit``. A second ``def`` with that
    name would rebind it in the class body, and the limit test would never run.
    """
    db = Mongo("test", self.host, self.port)

    self.assertRaises(TypeError, db.test.find().skip, None)
    self.assertRaises(TypeError, db.test.find().skip, "hello")
    self.assertRaises(TypeError, db.test.find().skip, 5.5)

    db.test.remove({})
    for i in range(100):
        db.test.save({"x": i})

    def first_x(cursor):
        # Return the "x" of the first document the cursor yields, or None if
        # it yields nothing. Asserting on this value (rather than asserting
        # inside a `for ...: break` loop) makes an empty result fail the test
        # instead of passing silently.
        for doc in cursor:
            return doc["x"]
        return None

    # NOTE(review): like the original checks, these assume an unsorted find()
    # returns documents in insertion order.
    self.assertEqual(first_x(db.test.find()), 0)
    self.assertEqual(first_x(db.test.find().skip(20)), 20)
    self.assertEqual(first_x(db.test.find().skip(99)), 99)
    self.assertEqual(first_x(db.test.find().skip(1)), 1)
    self.assertEqual(first_x(db.test.find().skip(0)), 0)
    self.assertEqual(first_x(db.test.find().skip(0).skip(50).skip(10)), 10)

    # Skipping past the end of the collection yields no documents at all.
    self.assertEqual(None, first_x(db.test.find().skip(1000)))

    a = db.test.find()
    a.skip(10)
    for _ in a:
        break
    self.assertRaises(InvalidOperation, a.skip, 5)
|
|
|
|
def test_sort(self):
    """Check Cursor.sort().

    Malformed sort specifications must raise TypeError or ValueError. Sorting
    by ("x", direction), or by a list of (key, direction) pairs, must order
    the results. A later sort() call must override an earlier one, a compound
    key must order by "b" descending and then by "a" ascending, and calling
    sort() after iteration has started must raise InvalidOperation.
    """
    db = Mongo("test", self.host, self.port)

    bad_calls = [
        (TypeError, (5,)),
        (TypeError, ("hello",)),
        (ValueError, ([],)),
        (TypeError, ([], ASCENDING)),
        (TypeError, ([("hello", DESCENDING)], DESCENDING)),
        (TypeError, ("hello", "world")),
    ]
    for error, args in bad_calls:
        self.assertRaises(error, db.test.find().sort, *args)

    db.test.remove({})

    # Insert 0..9 in random order, so that any ordering in the results comes
    # from sort() and not from insertion order.
    shuffled_values = list(range(10))
    random.shuffle(shuffled_values)
    for value in shuffled_values:
        db.test.save({"x": value})

    def x_values(*spec):
        # Values of "x" from a fresh find() sorted with the given arguments.
        return [doc["x"] for doc in db.test.find().sort(*spec)]

    ascending = list(range(10))
    descending = ascending[::-1]

    self.assertEqual(x_values("x", ASCENDING), ascending)
    self.assertEqual(x_values([("x", ASCENDING)]), ascending)
    self.assertEqual(x_values("x", DESCENDING), descending)
    self.assertEqual(x_values([("x", DESCENDING)]), descending)

    resorted = db.test.find().sort("x", ASCENDING).sort("x", DESCENDING)
    self.assertEqual([doc["x"] for doc in resorted], descending)

    # Compound key: "b" descending, ties broken by "a" ascending.
    expected = [(1, 5), (2, 5), (0, 3), (7, 3), (9, 2), (2, 1), (3, 1)]
    shuffled_pairs = expected[:]
    random.shuffle(shuffled_pairs)

    db.test.remove({})
    for first, second in shuffled_pairs:
        db.test.save({"a": first, "b": second})

    compound = db.test.find().sort([("b", DESCENDING), ("a", ASCENDING)])
    self.assertEqual([(doc["a"], doc["b"]) for doc in compound], expected)

    # Once a cursor has been iterated, its sort order can no longer change.
    started = db.test.find()
    started.sort("x", ASCENDING)
    for _ in started:
        break
    self.assertRaises(InvalidOperation, started.sort, "x", ASCENDING)
|
|
|
|
def test_count(self):
    """Check Cursor.count().

    count() must report the number of documents matching the query, as an
    int, and the test expects it to ignore limit() and skip(). It must work
    on a fresh cursor, raise InvalidOperation once iteration has started,
    and return 0 for a collection that does not exist.
    """
    db = Mongo("test", self.host, self.port)
    db.test.remove({})

    self.assertEqual(0, db.test.find().count())

    for i in range(10):
        db.test.save({"x": i})

    self.assertEqual(10, db.test.find().count())
    self.assertTrue(isinstance(db.test.find().count(), types.IntType))
    # count() reports the size of the whole query result, regardless of any
    # limit()/skip() applied to the cursor.
    self.assertEqual(10, db.test.find().limit(5).count())
    self.assertEqual(10, db.test.find().skip(5).count())

    self.assertEqual(1, db.test.find({"x": 1}).count())
    self.assertEqual(5, db.test.find({"x": {"$lt": 5}}).count())

    a = db.test.find()
    # count() on a cursor that has not been iterated yet is allowed. Check its
    # value instead of discarding it.
    self.assertEqual(10, a.count())
    for _ in a:
        break
    self.assertRaises(InvalidOperation, a.count)

    self.assertEqual(0, db.test.acollectionthatdoesntexist.find().count())
|
|
|
|
def test_deref(self):
    """Check Mongo.dereference().

    Arguments that are not DBRefs (an int, a string, None) must raise
    TypeError. A DBRef to a document that does not exist must resolve to
    None. A DBRef to a saved document must resolve to an equal document.
    """
    db = Mongo("test", self.host, self.port)
    db.test.remove({})

    for not_a_dbref in (5, "hello", None):
        self.assertRaises(TypeError, db.dereference, not_a_dbref)

    # A freshly generated ObjectId matches no stored document.
    dangling = DBRef("test", ObjectId())
    self.assertEqual(None, db.dereference(dangling))

    document = {"x": True}
    document_id = db.test.save(document)
    self.assertEqual(document, db.dereference(DBRef("test", document_id)))
|
|
|
|
def test_auto_deref(self):
    """Check how stored DBRefs are read back, per the "auto_dereference" setting.

    With the default settings or {"auto_dereference": False}, a stored DBRef
    comes back as the DBRef. With {"auto_dereference": True}, it comes back
    as the referenced document. This also applies to DBRefs inside lists and
    to references reached through another referenced document. A DBRef whose
    target does not exist comes back unchanged.
    """
    db = Mongo("test", self.host, self.port)
    db.test.a.remove({})
    db.test.b.remove({})
    db.test.remove({})

    target = {"hello": u"world"}
    target_ref = DBRef("test.b", db.test.b.save(target))

    self.assertEqual(db.dereference(target_ref), target)

    db.test.a.save({"b_obj": target_ref})
    self.assertEqual(target_ref, db.test.a.find_one()["b_obj"])
    self.assertEqual(target, db.dereference(db.test.a.find_one()["b_obj"]))

    # Explicitly disabling auto_dereference behaves like the default.
    db = Mongo("test", self.host, self.port, {"auto_dereference": False})
    self.assertEqual(target_ref, db.test.a.find_one()["b_obj"])

    # Enabling auto_dereference replaces the DBRef with the referenced document.
    db = Mongo("test", self.host, self.port, {"auto_dereference": True})
    self.assertNotEqual(target_ref, db.test.a.find_one()["b_obj"])
    self.assertEqual(target, db.test.a.find_one()["b_obj"])

    # DBRefs inside a list are resolved too.
    list_holder_id = db.test.a.save({"x": [target_ref]})
    self.assertEqual(target, db.test.a.find_one(list_holder_id)["x"][0])

    # A reference to a document that itself holds references is resolved
    # all the way down.
    chain_id = db.test.b.save({"x": DBRef("test.a", list_holder_id)})
    self.assertEqual(target, db.test.b.find_one(chain_id)["x"]["x"][0])

    # A DBRef with a freshly generated id has no target and is returned as-is.
    dangling = DBRef("test.c", ObjectId())
    dangling_holder_id = db.test.save({"x": dangling})
    self.assertEqual(dangling, db.test.find_one(dangling_holder_id)["x"])
|
|
|
|
def test_auto_ref(self):
    """Check how an already-saved nested SON is stored, per "auto_reference".

    The test expects save() to record the collection name under "_ns" in the
    saved SON. With the default settings or {"auto_reference": False}, reading
    the outer document back yields a sub-document equal to the nested SON.
    With {"auto_reference": True}, it yields something unequal to it, and
    db.dereference() on that value gives back the nested SON.
    """
    db = Mongo("test", self.host, self.port)
    db.test.a.remove({})
    db.test.b.remove({})

    inner = SON({u"hello": u"world"})
    db.test.a.save(inner)
    self.assertEqual(inner["_ns"], "test.a")

    outer = SON({"ref?": inner})

    def save_outer(conn):
        # Save `outer` into test.b through `conn` and check the in-memory SON
        # afterwards. Returns whatever save() returned.
        saved_id = conn.test.b.save(outer)
        self.assertEqual(outer["_ns"], "test.b")
        self.assertEqual(outer["ref?"], inner)
        return saved_id

    saved_id = save_outer(db)
    self.assertEqual(db.test.b.find_one(saved_id)["ref?"], inner)

    db = Mongo("test", self.host, self.port, {"auto_reference": False})
    saved_id = save_outer(db)
    self.assertEqual(db.test.b.find_one(saved_id)["ref?"], inner)

    db = Mongo("test", self.host, self.port, {"auto_reference": True})
    saved_id = save_outer(db)
    self.assertNotEqual(db.test.b.find_one(saved_id)["ref?"], inner)
    self.assertEqual(db.dereference(db.test.b.find_one(saved_id)["ref?"]), inner)
|
|
|
|
def test_auto_ref_and_deref(self):
    """Check that nested SONs round-trip with both auto settings enabled.

    Three SONs are nested (root -> middle -> leaf) and each is saved to its
    own collection. Reading any of them back, at any nesting depth, must
    compare equal to the in-memory original.
    """
    settings = {"auto_reference": True, "auto_dereference": True}
    db = Mongo("test", self.host, self.port, settings)
    db.test.a.remove({})
    db.test.b.remove({})
    db.test.c.remove({})

    leaf = SON({"hello": u"world"})
    middle = SON({"test": leaf})
    root = SON({"another test": middle})

    db.test.a.save(leaf)
    db.test.b.save(middle)
    db.test.c.save(root)

    # The leaf, reached directly and through one or two levels of nesting.
    self.assertEqual(db.test.a.find_one(), leaf)
    self.assertEqual(db.test.b.find_one()["test"], leaf)
    self.assertEqual(db.test.c.find_one()["another test"]["test"], leaf)
    # The middle document, directly and through the root.
    self.assertEqual(db.test.b.find_one(), middle)
    self.assertEqual(db.test.c.find_one()["another test"], middle)
    # The root document as a whole.
    self.assertEqual(db.test.c.find_one(), root)
|
|
|
|
if __name__ == "__main__":
    # When this module is executed as a script, discover and run every
    # unittest.TestCase it defines.
    unittest.main()
|