class Collection(object):
    """A Mongo collection.

    Wraps a (database, collection name) pair and offers helpers to save,
    update, remove and query documents, and to create / drop indexes.
    """

    def __init__(self, database, name):
        """Get / create a Mongo collection.

        Raises TypeError if name is not an instance of (str, unicode). Raises
        InvalidName if name is not a valid collection name.

        Arguments:
        - `database`: the database to get a collection from
        - `name`: the name of the collection to get
        """
        if not isinstance(name, types.StringTypes):
            raise TypeError("name must be an instance of (str, unicode)")

        # Each rule gets its own message so the error names the real problem.
        if not name:
            raise InvalidName("collection names cannot be empty")
        if ".." in name:
            raise InvalidName("collection names must not contain '..'")
        if "$" in name and name != "$cmd":
            raise InvalidName("collection names must not contain '$'")
        if name[0] == "." or name[-1] == ".":
            raise InvalidName("collection names must not start or end with '.'")

        self.__database = database
        self.__collection_name = unicode(name)

    def __getattr__(self, name):
        """Get a sub-collection of this collection by name.

        Raises InvalidName if an invalid collection name is used. Raises
        AttributeError for this class's own private (name-mangled) attributes.

        Arguments:
        - `name`: the name of the collection to get
        """
        # __getattr__ only runs when normal lookup fails. If our own private
        # attributes are missing (instance built without __init__, e.g. while
        # copying), reading them below would re-enter this method forever.
        if name.startswith("_Collection__"):
            raise AttributeError(name)
        return Collection(self.__database,
                          u"%s.%s" % (self.__collection_name, name))

    def __getitem__(self, name):
        """Get a sub-collection by name using item access (collection[name]).
        """
        return Collection(self.__database,
                          u"%s.%s" % (self.__collection_name, name))

    def __repr__(self):
        return "Collection(%r, %r)" % (self.__database, self.__collection_name)

    def __cmp__(self, other):
        """Order collections by (database, collection name).

        Returns NotImplemented when `other` is not a Collection.
        """
        if isinstance(other, Collection):
            return cmp((self.__database, self.__collection_name),
                       (other.__database, other.__collection_name))
        return NotImplemented

    def _full_name(self):
        """Get the "<database name>.<collection name>" string.
        """
        return u"%s.%s" % (self.__database.name(), self.__collection_name)

    def _name(self):
        """Get the name of this collection (without the database name).
        """
        return self.__collection_name

    def _send_message(self, operation, data):
        """Wrap up a message and send it.

        Returns whatever the connection's send_message returns.

        Arguments:
        - `operation`: the operation code to send
        - `data`: the already encoded message body
        """
        # reserved int, full collection name, message data
        message = _ZERO
        message += bson._make_c_string(self._full_name())
        message += data
        return self.__database.connection().send_message(operation, message)

    def database(self):
        """Get the database this collection belongs to.
        """
        return self.__database

    def save(self, to_save, add_meta=True):
        """Save a SON object in this collection.

        Raises TypeError if to_save is not an instance of (dict, SON). Returns
        the "_id" of the saved object, or None if it has none.

        Arguments:
        - `to_save`: the SON object to be saved
        - `add_meta` (optional): add meta information (like _id) to the object
          if it's missing
        """
        if not isinstance(to_save, (types.DictType, SON)):
            raise TypeError("cannot save object of type %s" % type(to_save))

        to_save = self.__database._fix_incoming(to_save, self, add_meta)

        if "_id" not in to_save:
            # no id at all: send as a plain insert (operation 2002)
            self._send_message(2002, bson.BSON.from_dict(to_save))
        elif to_save["_id"].is_new():
            # freshly generated id: mark it used, then insert
            to_save["_id"]._use()
            self._send_message(2002, bson.BSON.from_dict(to_save))
        else:
            # id already in use: upsert on that id
            self._update({"_id": to_save["_id"]}, to_save, True)

        return to_save.get("_id", None)

    def _update(self, spec, document, upsert=False):
        """Update an object(s) in this collection.

        Raises TypeError if either spec or document isn't an instance of
        (dict, SON) or upsert isn't an instance of bool.

        Arguments:
        - `spec`: a SON object specifying elements which must be present for a
          document to be updated
        - `document`: a SON object specifying the fields to be changed in the
          selected document(s), or (in the case of an upsert) the document to
          be inserted.
        - `upsert` (optional): perform an upsert operation
        """
        if not isinstance(spec, (types.DictType, SON)):
            raise TypeError("spec must be an instance of (dict, SON)")
        if not isinstance(document, (types.DictType, SON)):
            raise TypeError("document must be an instance of (dict, SON)")
        if not isinstance(upsert, types.BooleanType):
            raise TypeError("upsert must be an instance of bool")

        # upsert flag, then the selector, then the replacement document
        if upsert:
            message = _ONE
        else:
            message = _ZERO
        message += bson.BSON.from_dict(spec)
        message += bson.BSON.from_dict(document)

        self._send_message(2001, message)

    def remove(self, spec_or_object_id):
        """Remove an object(s) from this collection.

        Raises TypeError if the argument is not an instance of
        (dict, SON, ObjectId).

        Arguments:
        - `spec_or_object_id`: a SON object specifying elements which must be
          present for a document to be removed OR an instance of ObjectId to
          be used as the value for an _id element
        """
        spec = spec_or_object_id
        if isinstance(spec, ObjectId):
            spec = SON({"_id": spec})

        if not isinstance(spec, (types.DictType, SON)):
            raise TypeError("spec must be an instance of "
                            "(dict, SON, ObjectId)")

        self._send_message(2006, _ZERO + bson.BSON.from_dict(spec))

    def find_one(self, spec_or_object_id=None):
        """Get a single object from the database.

        Raises TypeError if the argument is of an improper type. Returns a
        single SON object, or None if no result is found.

        Arguments:
        - `spec_or_object_id` (optional): a SON object specifying elements
          which must be present for a document to be included in the result
          set OR an instance of ObjectId to be used as the value for an _id
          query. Defaults to an empty SON (match everything).
        """
        spec = spec_or_object_id
        if spec is None:
            # fresh object per call instead of a shared default instance
            spec = SON()
        if isinstance(spec, ObjectId):
            spec = SON({"_id": spec})

        for result in self.find(spec, limit=1):
            return result
        return None

    def find(self, spec=None, fields=None, skip=0, limit=0):
        """Query the database.

        Raises TypeError if any of the arguments are of improper type. Returns
        an instance of Cursor corresponding to this query.

        Arguments:
        - `spec` (optional): a SON object specifying elements which must be
          present for a document to be included in the result set. Defaults
          to an empty SON.
        - `fields` (optional): a list of field names that should be returned
          in the result set. Defaults to an empty list (no field selection).
        - `skip` (optional): the number of documents to omit (from the start of
          the result set) when returning the results
        - `limit` (optional): the maximum number of results to return in the
          first reply message, or 0 for the default return size
        """
        # Defaults are created per call; a default written in the signature
        # would be one object shared by every call.
        if spec is None:
            spec = SON()
        if fields is None:
            fields = []

        if not isinstance(spec, (types.DictType, SON)):
            raise TypeError("spec must be an instance of (dict, SON)")
        if not isinstance(fields, types.ListType):
            raise TypeError("fields must be an instance of list")
        if not isinstance(skip, types.IntType):
            raise TypeError("skip must be an instance of int")
        if not isinstance(limit, types.IntType):
            raise TypeError("limit must be an instance of int")

        # Explicit branch rather than `len(fields) and SON() or None`: that
        # form falls through to None whenever an empty SON() is falsy, and
        # the assignment below would then fail.
        return_fields = None
        if fields:
            return_fields = SON()
            for field in fields:
                if not isinstance(field, types.StringTypes):
                    raise TypeError("fields must be a list of key names as "
                                    "(string, unicode)")
                return_fields[field] = 1

        return Cursor(self, spec, return_fields, skip, limit)

    def _gen_index_name(self, keys):
        """Generate an index name from the set of fields it is over.

        Arguments:
        - `keys`: a list of (key, direction) pairs; each pair contributes
          "key_direction" and the parts are joined with "_"
        """
        return u"_".join([u"%s_%s" % item for item in keys])

    def create_index(self, key_or_list, direction=None):
        """Creates an index on this collection.

        Takes either a single key and a direction, or a list of (key, direction)
        pairs. The key(s) must be an instance of (str, unicode), and the
        direction(s) must be one of (Mongo.ASCENDING, Mongo.DESCENDING).

        Raises TypeError if the keys are of improper type and ValueError if
        the list of keys is empty.

        Arguments:
        - `key_or_list`: a single key or a list of (key, direction) pairs
          specifying the index to ensure
        - `direction` (optional): must be included if key_or_list is a single
          key, otherwise must be None
        """
        if direction:
            keys = [(key_or_list, direction)]
        else:
            keys = key_or_list

        if not isinstance(keys, types.ListType):
            raise TypeError("if no direction is specified, key_or_list must "
                            "be an instance of list")
        if not keys:
            raise ValueError("key_or_list must not be the empty list")

        to_save = SON()
        to_save["name"] = self._gen_index_name(keys)
        to_save["ns"] = self._full_name()

        key_object = SON()
        for (key, value) in keys:
            if not isinstance(key, types.StringTypes):
                raise TypeError("first item in each key pair must be a string")
            if not isinstance(value, types.IntType):
                raise TypeError("second item in each key pair must be "
                                "Mongo.ASCENDING or Mongo.DESCENDING")
            key_object[key] = value
        to_save["key"] = key_object

        # the index document is saved as-is: no meta information is added
        self.__database[SYSTEM_INDEX_COLLECTION].save(to_save, False)

    def drop_indexes(self):
        """Drops all indexes on this collection.

        Can be used on non-existent collections or collections with no indexes.
        Raises OperationFailure on an error.
        """
        response = self.__database.command(
            SON([("deleteIndexes", self.__collection_name),
                 ("index", u"*")]))
        if response["ok"] != 1:
            # a missing namespace simply means there is nothing to drop
            if response["errmsg"] == "ns not found":
                return
            raise OperationFailure("error dropping indexes: %s" %
                                   response["errmsg"])
+ + Arguments: + - `limit`: the number of results to return + """ + if not isinstance(limit, types.IntType): + raise TypeError("limit must be an int") + self.__check_okay_to_chain() + + self.__limit = limit + return self + + def skip(self, skip): + """Skips the first `skip` results of this cursor. + + Raises TypeError if skip is not an instance of int. Raises + InvalidOperation if this cursor has already been used. + + Arguments: + - `skip`: the number of results to skip + """ + if not isinstance(skip, types.IntType): + raise TypeError("skip must be an int") + self.__check_okay_to_chain() + + self.__skip = skip + return self + + def sort(self, key_or_list, direction=None): + """Sorts this cursors results. + + Takes either a single key and a direction, or a list of (key, direction) + pairs. The key(s) must be an instance of (str, unicode), and the + direction(s) must be one of (Mongo.ASCENDING, Mongo.DESCENDING). Raises + InvalidOperation if this cursor has already been used. + + Arguments: + - `key_or_list`: a single key or a list of (key, direction) pairs + specifying the keys to sort on + - `direction` (optional): must be included if key_or_list is a single + key, otherwise must be None + """ + self.__check_okay_to_chain() + + # TODO a lot of this logic could be shared with create_index() + if direction: + keys = [(key_or_list, direction)] + else: + keys = key_or_list + + if not isinstance(keys, types.ListType): + raise TypeError("if no direction is specified, key_or_list must be an instance of list") + if not len(keys): + raise ValueError("key_or_list must not be the empty list") + + orderby = SON() + for (key, value) in keys: + if not isinstance(key, types.StringTypes): + raise TypeError("first item in each key pair must be a string") + if not isinstance(value, types.IntType): + raise TypeError("second item in each key pair must be Mongo.ASCENDING or Mongo.DESCENDING") + orderby[key] = value + + self.__ordering = orderby + return self + + def count(self): + 
"""Get the size of the results set for this query. + + Returns the number of objects in the results set for this query. Does + not take limit and skip into account. Raises InvalidOperation if this + cursor has already been used. Raises OperationFailure on a database + error. + """ + self.__check_okay_to_chain() + + command = SON([("count", self.__collection._name()), + ("query", self.__spec)]) + response = self.__collection.database().command(command) + if response["ok"] != 1: + if response["errmsg"] == "ns does not exist": + return 0 + raise OperationFailure("error getting count: %s" % response["errmsg"]) + return int(response["n"]) + + def _refresh(self): + """Refreshes the cursor with more data from Mongo. + + Returns the length of self.__data after refresh. Will exit early if + self.__data is already non-empty. Raises OperationFailure when the + cursor cannot be refreshed due to an error on the query. + """ + if len(self.__data) or self.__killed: + return len(self.__data) + + def send_message(operation, message): + # TODO the send and receive here should be synchronized... + request_id = self.__collection._send_message(operation, message) + response = self.__collection.database().connection().receive_message(1, request_id) + + response_flag = struct.unpack(" self.__retrieved: + limit = self.__limit - self.__retrieved + else: + self.__die() + return 0 + + message = struct.pack("