DEPRECATING method style access for several methods:

Connection.host() and port(), Database.connection() and name(),
and Collection.database(), name() and full_name(). Use attribute/property
style access for all of these instead now.
This commit is contained in:
Mike Dirolf 2009-12-04 15:07:42 -05:00
parent d7bfcff824
commit 99f08e541e
16 changed files with 237 additions and 112 deletions

View File

@ -14,8 +14,8 @@
Raises :class:`~pymongo.errors.InvalidName` if an invalid database name is used.
.. automethod:: host
.. automethod:: port
.. autoattribute:: host
.. autoattribute:: port
.. autoattribute:: slave_okay
.. automethod:: database_names
.. automethod:: drop_database

View File

@ -107,7 +107,8 @@ class GridFile(object):
raise ValueError("mode must be one of ('r', 'w')")
self.__collection = database[collection]
self.__collection.chunks.ensure_index([("files_id", ASCENDING), ("n", ASCENDING)], unique=True)
self.__collection.chunks.ensure_index([("files_id", ASCENDING),
("n", ASCENDING)], unique=True)
_files_lock.acquire()
@ -226,8 +227,8 @@ class GridFile(object):
self.__flush_write_buffer()
md5 = self.__collection.database()._command(SON([("filemd5", self.__id),
("root", self.__collection.name())]))["md5"]
md5 = self.__collection.database._command(SON([("filemd5", self.__id),
("root", self.__collection.name)]))["md5"]
grid_file = self.__collection.files.find_one({"_id": self.__id})
grid_file["md5"] = md5

View File

@ -65,7 +65,15 @@ class Collection(object):
"or end with '.': %r" % name)
self.__database = database
self.__collection_name = unicode(name)
self.__name = unicode(name)
self.__full_name = u"%s.%s" % (self.__database.name, self.__name)
# TODO remove the callable_value wrappers after deprecation is complete
self.__database_w = helpers.callable_value(self.__database,
"Collection.database")
self.__name_w = helpers.callable_value(self.__name,
"Collection.name")
self.__full_name_w = helpers.callable_value(self.__full_name,
"Collection.full_name")
if options is not None:
self.__create(options)
@ -78,7 +86,7 @@ class Collection(object):
if "size" in options:
options["size"] = float(options["size"])
command = SON({"create": self.__collection_name})
command = SON({"create": self.__name})
command.update(options)
self.__database._command(command)
@ -91,37 +99,52 @@ class Collection(object):
:Parameters:
- `name`: the name of the collection to get
"""
return Collection(self.__database, u"%s.%s" % (self.__collection_name,
name))
return Collection(self.__database, u"%s.%s" % (self.__name, name))
def __getitem__(self, name):
return self.__getattr__(name)
def __repr__(self):
return "Collection(%r, %r)" % (self.__database, self.__collection_name)
return "Collection(%r, %r)" % (self.__database, self.__name)
def __cmp__(self, other):
if isinstance(other, Collection):
return cmp((self.__database, self.__collection_name),
(other.__database, other.__collection_name))
return cmp((self.__database, self.__name),
(other.__database, other.__name))
return NotImplemented
def full_name(self):
"""Get the full name of this collection.
"""The full name of this :class:`Collection`.
The full name is of the form database_name.collection_name.
The full name is of the form `database_name.collection_name`.
.. versionchanged:: 1.1.2+
``full_name`` is now a property rather than a method. The
``full_name()`` method is deprecated.
"""
return u"%s.%s" % (self.__database.name(), self.__collection_name)
return self.__full_name_w
full_name = property(full_name)
def name(self):
"""Get the name of this collection.
"""The name of this :class:`Collection`.
.. versionchanged:: 1.1.2+
``name`` is now a property rather than a method. The
``name()`` method is deprecated.
"""
return self.__collection_name
return self.__name_w
name = property(name)
def database(self):
"""Get the database that this collection is a part of.
"""The :class:`~pymongo.database.Database` that this
:class:`Collection` is a part of.
.. versionchanged:: 1.1.2+
``database`` is now a property rather than a method. The
``database()`` method is deprecated.
"""
return self.__database
return self.__database_w
database = property(database)
def save(self, to_save, manipulate=True, safe=False):
"""Save a document in this collection.
@ -181,8 +204,8 @@ class Collection(object):
if manipulate:
docs = [self.__database._fix_incoming(doc, self) for doc in docs]
self.__database.connection()._send_message(
message.insert(self.full_name(), docs, check_keys, safe), safe)
self.__database.connection._send_message(
message.insert(self.__full_name, docs, check_keys, safe), safe)
ids = [doc.get("_id", None) for doc in docs]
return len(ids) == 1 and ids[0] or ids
@ -242,8 +265,8 @@ class Collection(object):
if upsert and manipulate:
document = self.__database._fix_incoming(document, self)
self.__database.connection()._send_message(
message.update(self.full_name(), upsert, multi,
self.__database.connection._send_message(
message.update(self.__full_name, upsert, multi,
spec, document, safe), safe)
def remove(self, spec_or_object_id=None, safe=False):
@ -291,8 +314,8 @@ class Collection(object):
raise TypeError("spec must be an instance of dict, not %s" %
type(spec))
self.__database.connection()._send_message(
message.delete(self.full_name(), spec, safe), safe)
self.__database.connection._send_message(
message.delete(self.__full_name, spec, safe), safe)
def find_one(self, spec_or_object_id=None, fields=None, slave_okay=None,
_sock=None, _must_use_master=False):
@ -394,7 +417,7 @@ class Collection(object):
if spec is None:
spec = SON()
if slave_okay is None:
slave_okay = self.__database.connection().slave_okay
slave_okay = self.__database.connection.slave_okay
else:
warnings.warn("The slave_okay option to find and find_one is "
"deprecated. Please set slave_okay on the Connection "
@ -471,15 +494,14 @@ class Collection(object):
keys = helpers._index_list(key_or_list)
name = self._gen_index_name(keys)
to_save["name"] = name
to_save["ns"] = self.full_name()
to_save["ns"] = self.__full_name
to_save["key"] = helpers._index_document(keys)
to_save["unique"] = unique
self.database().connection()._cache_index(self.__database.name(),
self.name(),
name, ttl)
self.__database.connection._cache_index(self.__database.name,
self.__name, name, ttl)
self.database().system.indexes.insert(to_save, manipulate=False,
self.__database.system.indexes.insert(to_save, manipulate=False,
check_keys=False)
return to_save["name"]
@ -528,9 +550,8 @@ class Collection(object):
keys = helpers._index_list(key_or_list)
name = self._gen_index_name(keys)
if self.database().connection()._cache_index(self.__database.name(),
self.name(),
name, ttl):
if self.__database.connection._cache_index(self.__database.name,
self.__name, name, ttl):
return self.create_index(key_or_list, unique=unique, ttl=ttl)
return None
@ -540,8 +561,8 @@ class Collection(object):
Can be used on non-existant collections or collections with no indexes.
Raises OperationFailure on an error.
"""
self.database().connection()._purge_index(self.database().name(),
self.name())
self.__database.connection._purge_index(self.__database.name,
self.__name)
self.drop_index(u"*")
def drop_index(self, index_or_name):
@ -564,10 +585,10 @@ class Collection(object):
if not isinstance(name, types.StringTypes):
raise TypeError("index_or_name must be an index name or list")
self.database().connection()._purge_index(self.database().name(),
self.name(), name)
self.__database.connection._purge_index(self.__database.name,
self.__name, name)
self.__database._command(SON([("deleteIndexes",
self.__collection_name),
self.__name),
("index", name)]),
["ns not found"])
@ -578,7 +599,7 @@ class Collection(object):
create_index()) and the values are lists of (key, direction) pairs
specifying the index (as passed to create_index()).
"""
raw = self.__database.system.indexes.find({"ns": self.full_name()})
raw = self.__database.system.indexes.find({"ns": self.__full_name})
info = {}
for index in raw:
info[index["name"]] = index["key"].items()
@ -593,7 +614,7 @@ class Collection(object):
has not been created yet.
"""
result = self.__database.system.namespaces.find_one(
{"name": self.full_name()})
{"name": self.__full_name})
if not result:
return {}
@ -639,7 +660,7 @@ class Collection(object):
if command:
if not isinstance(reduce, Code):
reduce = Code(reduce)
group = {"ns": self.__collection_name,
group = {"ns": self.__name,
"$reduce": reduce,
"key": self._fields_list_to_dict(keys),
"cond": condition,
@ -656,7 +677,7 @@ class Collection(object):
scope = {}
if isinstance(reduce, Code):
scope = reduce.scope
scope.update({"ns": self.__collection_name,
scope.update({"ns": self.__name,
"keys": keys,
"condition": condition,
"initial": initial})
@ -718,11 +739,11 @@ class Collection(object):
if new_name[0] == "." or new_name[-1] == ".":
raise InvalidName("collecion names must not start or end with '.'")
rename_command = SON([("renameCollection", self.full_name()),
("to", "%s.%s" % (self.__database.name(),
rename_command = SON([("renameCollection", self.__full_name),
("to", "%s.%s" % (self.__database.name,
new_name))])
self.__database.connection().admin._command(rename_command)
self.__database.connection.admin._command(rename_command)
def distinct(self, key):
"""Get a list of distinct values for `key` among all documents in this
@ -770,7 +791,7 @@ class Collection(object):
.. _map reduce command: http://www.mongodb.org/display/DOCS/MapReduce
"""
command = SON([("mapreduce", self.__collection_name),
command = SON([("mapreduce", self.__name),
("map", map), ("reduce", reduce)])
command.update(**kwargs)
@ -788,13 +809,13 @@ class Collection(object):
def __call__(self, *args, **kwargs):
"""This is only here so that some API misusages are easier to debug.
"""
if "." not in self.__collection_name:
if "." not in self.__name:
raise TypeError("'Collection' object is not callable. If you "
"meant to call the '%s' method on a 'Database' "
"object it is failing because no such method "
"exists." %
self.__collection_name)
self.__name)
raise TypeError("'Collection' object is not callable. If you meant to "
"call the '%s' method on a 'Collection' object it is "
"failing because no such method exists." %
self.__collection_name.split(".")[-1])
self.__name.split(".")[-1])

View File

@ -285,17 +285,25 @@ class Connection(object): # TODO support auth for pooling
if index_name in self.__index_cache[database_name][collection_name]:
del self.__index_cache[database_name][collection_name][index_name]
# TODO these really should be properties... Could be ugly to make that
# backwards compatible though...
def host(self):
"""Get the connection's current host.
"""Current connected host.
.. versionchanged:: 1.1.2+
``host`` is now a property rather than a method. The ``host()``
method is deprecated.
"""
return self.__host
return helpers.callable_value(self.__host, "Connection.host")
host = property(host)
def port(self):
"""Get the connection's current port.
"""Current connected port.
.. versionchanged:: 1.1.2+
``port`` is now a property rather than a method. The ``port()``
method is deprecated.
"""
return self.__port
return helpers.callable_value(self.__port, "Connection.port")
port = property(port)
def slave_okay(self):
"""Is it okay for this connection to connect directly to a slave?
@ -306,7 +314,7 @@ class Connection(object): # TODO support auth for pooling
def __find_master(self):
"""Create a new socket and use it to figure out who the master is.
Sets __host and __port so that `host()` and `port()` will return the
Sets __host and __port so that :attr:`host` and :attr:`port` will return the
address of the master.
"""
_logger.debug("finding master")
@ -365,7 +373,7 @@ class Connection(object): # TODO support auth for pooling
Connect to the master if this is a paired connection.
"""
if self.host() is None or self.port() is None:
if self.__host is None or self.__port is None:
self.__find_master()
_logger.debug("connecting socket %s..." % socket_number)
@ -377,7 +385,7 @@ class Connection(object): # TODO support auth for pooling
socket.TCP_NODELAY, 1)
sock = self.__sockets[socket_number]
sock.settimeout(_CONNECT_TIMEOUT)
sock.connect((self.host(), self.port()))
sock.connect((self.__host, self.__port))
sock.settimeout(self.__network_timeout)
_logger.debug("connected")
return
@ -751,7 +759,7 @@ class Connection(object): # TODO support auth for pooling
"""
name = name_or_database
if isinstance(name, Database):
name = name.name()
name = name.name
if not isinstance(name, types.StringTypes):
raise TypeError("name_or_database must be an instance of "

View File

@ -65,7 +65,8 @@ class Cursor(object):
self.__killed = False
def collection(self):
"""Get the collection for this cursor.
"""The :class:`~pymongo.collection.Collection` that this
:class:`Cursor` is iterating.
.. versionadded:: 1.1
"""
@ -114,7 +115,7 @@ class Cursor(object):
"""Closes this cursor.
"""
if self.__id and not self.__killed:
connection = self.__collection.database().connection()
connection = self.__collection.database.connection
if self.__connection_id is not None:
connection.close_cursor(self.__id, self.__connection_id)
else:
@ -293,7 +294,7 @@ class Cursor(object):
:meth:`~pymongo.cursor.Cursor.__len__` was deprecated in favor of
calling :meth:`count` with `with_limit_and_skip` set to ``True``.
"""
command = SON([("count", self.__collection.name()),
command = SON([("count", self.__collection.name),
("query", self.__spec),
("fields", self.__fields)])
@ -303,8 +304,8 @@ class Cursor(object):
if self.__skip:
command["skip"] = self.__skip
response = self.__collection.database()._command(command,
["ns missing"])
response = self.__collection.database._command(command,
["ns missing"])
if response.get("errmsg", "") == "ns missing":
return 0
return int(response["n"])
@ -328,12 +329,12 @@ class Cursor(object):
if not isinstance(key, types.StringTypes):
raise TypeError("key must be an instance of (str, unicode)")
command = SON([("distinct", self.__collection.name()), ("key", key)])
command = SON([("distinct", self.__collection.name), ("key", key)])
if self.__spec:
command["query"] = self.__spec
return self.__collection.database()._command(command)["values"]
return self.__collection.database._command(command)["values"]
def explain(self):
"""Returns an explain plan record for this cursor.
@ -403,14 +404,14 @@ class Cursor(object):
def __send_message(self, message):
"""Send a query or getmore message and handles the response.
"""
db = self.__collection.database()
db = self.__collection.database
kwargs = {"_sock": self.__socket,
"_must_use_master": self.__must_use_master}
if self.__connection_id is not None:
kwargs["_connection_to_use"] = self.__connection_id
response = db.connection()._send_message_with_response(message,
**kwargs)
response = db.connection._send_message_with_response(message,
**kwargs)
if isinstance(response, types.TupleType):
(connection_id, response) = response
@ -422,7 +423,7 @@ class Cursor(object):
try:
response = helpers._unpack_response(response, self.__id)
except AutoReconnect:
db.connection()._reset()
db.connection._reset()
raise
self.__id = response["cursor_id"]
@ -450,7 +451,7 @@ class Cursor(object):
# Query
self.__send_message(
message.query(self.__query_options(),
self.__collection.full_name(),
self.__collection.full_name,
self.__skip, self.__limit,
self.__query_spec(), self.__fields))
if not self.__id:
@ -466,7 +467,7 @@ class Cursor(object):
return 0
self.__send_message(
message.get_more(self.__collection.full_name(),
message.get_more(self.__collection.full_name,
limit, self.__id))
return len(self.__data)
@ -475,7 +476,7 @@ class Cursor(object):
return self
def next(self):
db = self.__collection.database()
db = self.__collection.database
if len(self.__data) or self._refresh():
next = db._fix_outgoing(self.__data.pop(0), self.__collection)
else:

View File

@ -52,6 +52,10 @@ class Database(object):
self.__name = unicode(name)
self.__connection = connection
# TODO remove the callable_value wrappers after deprecation is complete
self.__name_w = helpers.callable_value(self.__name, "Database.name")
self.__connection_w = helpers.callable_value(self.__connection, "Database.connection")
self.__incoming_manipulators = []
self.__incoming_copying_manipulators = []
self.__outgoing_manipulators = []
@ -90,14 +94,25 @@ class Database(object):
self.__outgoing_manipulators.insert(0, manipulator)
def connection(self):
"""Get the database connection.
"""The :class:`~pymongo.connection.Connection` instance for this
:class:`Database`.
.. versionchanged:: 1.1.2+
``connection`` is now a property rather than a method. The
``connection()`` method is deprecated.
"""
return self.__connection
return self.__connection_w
connection = property(connection)
def name(self):
"""Get the database name.
"""The name of this :class:`Database`.
.. versionchanged:: 1.1.2+
``name`` is now a property rather than a method. The
``name()`` method is deprecated.
"""
return self.__name
return self.__name_w
name = property(name)
def __cmp__(self, other):
if isinstance(other, Database):
@ -210,13 +225,13 @@ class Database(object):
"""
name = name_or_collection
if isinstance(name, Collection):
name = name.name()
name = name.name
if not isinstance(name, types.StringTypes):
raise TypeError("name_or_collection must be an instance of "
"(Collection, str, unicode)")
self.connection()._purge_index(self.name(), name)
self.__connection._purge_index(self.__name, name)
if name not in self.collection_names():
return
@ -231,7 +246,7 @@ class Database(object):
"""
name = name_or_collection
if isinstance(name, Collection):
name = name.name()
name = name.name
if not isinstance(name, types.StringTypes):
raise TypeError("name_or_collection must be an instance of "

View File

@ -16,6 +16,7 @@
import sys
import struct
import warnings
from son import SON
from errors import OperationFailure, AutoReconnect
@ -56,6 +57,7 @@ def _index_document(index_list):
index[key] = value
return index
def _reversed(l):
"""A version of the `reversed()` built-in for Python 2.3.
"""
@ -66,6 +68,7 @@ def _reversed(l):
if sys.version_info[:3] >= (2, 4, 0):
_reversed = reversed
def _unpack_response(response, cursor_id=None):
"""Unpack a response from the database.
@ -101,3 +104,38 @@ def _unpack_response(response, cursor_id=None):
result["data"] = bson._to_dicts(response[20:])
assert len(result["data"]) == result["number_returned"]
return result
# These two functions are some magic to get values we can use for deprecating
# method style access in favor of property style access while remaining
# backwards compatible.
def __prop_call(self, *args, **kwargs):
warnings.warn("'%s()' has been deprecated and will be removed. "
"Please use '%s' instead." %
(self.__prop_name, self.__prop_name),
DeprecationWarning)
return self
__class_cache = {}
def callable_value(value, prop_name):
t = type(value)
if "CallableVal" in str(t):
return value
if t in __class_cache:
cls = __class_cache[t]
else:
cls = type.__new__(type, "CallableVal", (t,),
{"__call__": __prop_call,
"__prop_name": prop_name})
__class_cache[t] = cls
try:
# This works for regular classes
value.__class__ = cls
return value
except:
# This works for builtins
return cls(value)

View File

@ -105,7 +105,7 @@ class NamespaceInjector(SONManipulator):
def transform_incoming(self, son, collection):
"""Add the _ns field to the incoming object
"""
son["_ns"] = collection.name()
son["_ns"] = collection.name
return son

View File

@ -134,8 +134,8 @@ class TestCollection(unittest.TestCase):
db.test.ensure_index("goodbye"))
self.assertEqual(None, db.test.ensure_index("goodbye"))
db_name = self.db.name()
self.connection.drop_database(self.db.name())
db_name = self.db.name
self.connection.drop_database(self.db.name)
self.assertEqual("goodbye_1",
db.test.ensure_index("goodbye"))
self.assertEqual(None, db.test.ensure_index("goodbye"))
@ -505,7 +505,7 @@ class TestCollection(unittest.TestCase):
def test_multi_update(self):
db = self.db
if not version.at_least(db.connection(), (1, 1, 3, -1)):
if not version.at_least(db.connection, (1, 1, 3, -1)):
raise SkipTest()
db.drop_collection("test")
@ -532,7 +532,7 @@ class TestCollection(unittest.TestCase):
def test_safe_update(self):
db = self.db
v113minus = version.at_least(db.connection(), (1, 1, 3, -1))
v113minus = version.at_least(db.connection, (1, 1, 3, -1))
db.drop_collection("test")
db.test.create_index("x", unique=True)
@ -571,7 +571,7 @@ class TestCollection(unittest.TestCase):
db.test.remove({"x": 1})
self.assertEqual(1, db.test.count())
if version.at_least(db.connection(), (1, 1, 3, -1)):
if version.at_least(db.connection, (1, 1, 3, -1)):
self.assertRaises(OperationFailure, db.test.remove, {"x": 1}, safe=True)
else: # Just test that it doesn't blow up
db.test.remove({"x": 1}, safe=True)
@ -676,7 +676,7 @@ class TestCollection(unittest.TestCase):
Code(reduce_function,
{"inc_value": 0.5}))[0]['count'])
if version.at_least(db.connection(), (1, 1)):
if version.at_least(db.connection, (1, 1)):
self.assertEqual(2, db.test.group([], {}, {"count": 0},
Code(reduce_function,
{"inc_value": 1}),
@ -806,7 +806,7 @@ class TestCollection(unittest.TestCase):
list(self.db.test.find(timeout=True))
def test_distinct(self):
if not version.at_least(self.db.connection(), (1, 1)):
if not version.at_least(self.db.connection, (1, 1)):
raise SkipTest()
self.db.drop_collection("test")
@ -851,7 +851,7 @@ class TestCollection(unittest.TestCase):
{"foo": "x" * 2 * 1024 * 1024}], safe=True)
def test_map_reduce(self):
if not version.at_least(self.db.connection(), (1, 1, 1)):
if not version.at_least(self.db.connection, (1, 1, 1)):
raise SkipTest()
db = self.db

View File

@ -16,6 +16,7 @@
import unittest
import os
import warnings
import sys
sys.path[0:0] = [""]
@ -72,8 +73,8 @@ class TestConnection(unittest.TestCase):
"Connection('%s', %s)" % (self.host, self.port))
def test_getters(self):
self.assertEqual(Connection(self.host, self.port).host(), self.host)
self.assertEqual(Connection(self.host, self.port).port(), self.port)
self.assertEqual(Connection(self.host, self.port).host, self.host)
self.assertEqual(Connection(self.host, self.port).port, self.port)
def test_get_db(self):
connection = Connection(self.host, self.port)
@ -154,6 +155,45 @@ class TestConnection(unittest.TestCase):
except AssertionError:
self.fail()
# NOTE this probably doesn't all belong in this file, but it's easier in
# one place
def test_deprecated_method_for_attr(self):
c = Connection(self.host, self.port)
db = c.foo
coll = db.bar
warnings.simplefilter("error")
self.assertRaises(DeprecationWarning, c.host)
self.assertRaises(DeprecationWarning, c.port)
self.assertRaises(DeprecationWarning, db.connection)
self.assertRaises(DeprecationWarning, db.name)
self.assertRaises(DeprecationWarning, coll.full_name)
self.assertRaises(DeprecationWarning, coll.name)
self.assertRaises(DeprecationWarning, coll.database)
warnings.resetwarnings()
warnings.simplefilter("ignore")
self.assertEqual(c.host, c.host())
self.assertEqual(c.port, c.port())
self.assertEqual(db.connection, db.connection())
self.assertEqual(db.name, db.name())
self.assertEqual(coll.full_name, coll.full_name())
self.assertEqual(coll.name, coll.name())
self.assertEqual(coll.database, coll.database())
warnings.resetwarnings()
warnings.simplefilter("default")
self.assertEqual(self.host, c.host)
self.assertEqual(self.port, c.port)
self.assertEqual(c, db.connection)
self.assertEqual("foo", db.name)
self.assertEqual("foo.bar", coll.full_name)
self.assertEqual("bar", coll.name)
self.assertEqual(db, coll.database)
# TODO come up with a different way to test `network_timeout`. This is just
# too sketchy.
#

View File

@ -321,8 +321,9 @@ class TestCursor(unittest.TestCase):
client_cursors = db._command({"cursorInfo": 1})["clientCursors_size"]
by_location = db._command({"cursorInfo": 1})["byLocation_size"]
test = db.test
for i in range(10000):
db.test.insert({"i": i})
test.insert({"i": i})
self.assertEqual(client_cursors,
db._command({"cursorInfo": 1})["clientCursors_size"])
@ -448,7 +449,7 @@ class TestCursor(unittest.TestCase):
self.db.test.remove({})
self.db.test.save({"x": 1})
if not version.at_least(self.db.connection(), (1, 1, 3, -1)):
if not version.at_least(self.db.connection, (1, 1, 3, -1)):
for _ in self.db.test.find({}, ["a"]):
self.fail()
@ -533,7 +534,7 @@ class TestCursor(unittest.TestCase):
self.assertRaises(IndexError, lambda x: self.db.test.find().skip(50)[x], 50)
def test_count_with_limit_and_skip(self):
if not version.at_least(self.db.connection(), (1, 1, 4, -1)):
if not version.at_least(self.db.connection, (1, 1, 4, -1)):
raise SkipTest()
def check_len(cursor, length):
@ -598,7 +599,7 @@ class TestCursor(unittest.TestCase):
self.assertEqual(3, db.test.count())
def test_distinct(self):
if not version.at_least(self.db.connection(), (1, 1, 3, 1)):
if not version.at_least(self.db.connection, (1, 1, 3, 1)):
raise SkipTest()
self.db.drop_collection("test")

View File

@ -43,7 +43,7 @@ class TestDatabase(unittest.TestCase):
def test_name(self):
self.assertRaises(TypeError, Database, self.connection, 4)
self.assertRaises(InvalidName, Database, self.connection, "my db")
self.assertEqual("name", Database(self.connection, "name").name())
self.assertEqual("name", Database(self.connection, "name").name)
def test_cmp(self):
self.assertNotEqual(Database(self.connection, "test"),

View File

@ -73,13 +73,13 @@ class TestPaired(unittest.TestCase):
connection = Connection.paired(self.left, self.right)
self.assert_(connection)
host = connection.host()
port = connection.port()
host = connection.host
port = connection.port
connection = Connection.paired(self.right, self.left)
self.assert_(connection)
self.assertEqual(host, connection.host())
self.assertEqual(port, connection.port())
self.assertEqual(host, connection.host)
self.assertEqual(port, connection.port)
slave = self.left == (host, port) and self.right or self.left
self.assertRaises(ConfigurationError, Connection.paired,

View File

@ -32,7 +32,7 @@ class TestPooling(unittest.TestCase):
default_connection = Connection(self.host, self.port)
no_auto_connection = Connection(self.host, self.port,
auto_start_request=False)
pooled_connection = Connection(self.host, self.port,
pooled_connection = Connection(self.host, self.port,
pool_size=10, timeout=-1)
no_auto_pooled_connection = Connection(self.host, self.port,
pool_size=10, timeout=-1,
@ -74,7 +74,7 @@ class TestPooling(unittest.TestCase):
for _ in range(100):
self.default_db.test.remove({})
self.default_db.test.insert({})
self.default_db.connection().end_request()
self.default_db.connection.end_request()
if not self.default_db.test.find_one():
count += 1
self.assertEqual(0, count)
@ -92,7 +92,7 @@ class TestPooling(unittest.TestCase):
# for _ in range(6000):
# self.pooled_db.test.remove({})
# self.pooled_db.test.insert({})
# self.pooled_db.connection().end_request()
# self.pooled_db.connection.end_request()
# if not self.pooled_db.test.find_one():
# count += 1
# self.assertNotEqual(0, count)
@ -118,22 +118,22 @@ class TestPooling(unittest.TestCase):
count = 0
for _ in range(100):
self.no_auto_db.connection().start_request()
self.no_auto_db.connection.start_request()
self.no_auto_db.test.remove({})
self.no_auto_db.test.insert({})
if not self.no_auto_db.test.find_one():
count += 1
self.no_auto_db.connection().end_request()
self.no_auto_db.connection.end_request()
self.assertEqual(0, count)
count = 0
for _ in range(100):
self.no_auto_pooled_db.connection().start_request()
self.no_auto_pooled_db.connection.start_request()
self.no_auto_pooled_db.test.remove({})
self.no_auto_pooled_db.test.insert({})
if not self.no_auto_pooled_db.test.find_one():
count += 1
self.no_auto_pooled_db.connection().end_request()
self.no_auto_pooled_db.connection.end_request()
self.assertEqual(0, count)
def test_multithread(self):
@ -154,7 +154,7 @@ class SaveAndFind(threading.Thread):
rand = random.randint(0, 100)
id = self.database.mt_test.save({"x": rand})
assert self.database.mt_test.find_one(id)["x"] == rand
self.database.connection().end_request()
self.database.connection.end_request()
if __name__ == "__main__":
unittest.main()

View File

@ -31,8 +31,8 @@ class TestPyMongo(unittest.TestCase):
def test_connection_alias(self):
c = pymongo.Connection(self.host, self.port)
self.assert_(c)
self.assertEqual(c.host(), self.host)
self.assertEqual(c.port(), self.port)
self.assertEqual(c.host, self.host)
self.assertEqual(c.port, self.port)
if __name__ == "__main__":
unittest.main()

View File

@ -94,7 +94,7 @@ class TestSONManipulator(unittest.TestCase):
def incoming_adds_ns(son):
son = manip.transform_incoming(son, collection)
assert "_ns" in son
return son["_ns"] == collection.name()
return son["_ns"] == collection.name
qcheck.check_unittest(self, incoming_adds_ns,
qcheck.gen_mongo_dict(3))