Also PYTHON-162 and PYTHON-189. Credit goes to James Murty for most of the patch. With this change we cache auth credentials (user, password) in the driver so that each new socket can be automatically authenticated. This solves the problem of each new spawned thread having to re-authenticate.
457 lines
16 KiB
Python
457 lines
16 KiB
Python
# Copyright 2009-2010 10gen, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Test the connection module."""
|
|
|
|
import datetime
|
|
import os
|
|
import sys
|
|
import time
|
|
import unittest
|
|
import warnings
|
|
sys.path[0:0] = [""]
|
|
|
|
from nose.plugins.skip import SkipTest
|
|
|
|
from bson.son import SON
|
|
from bson.tz_util import utc
|
|
from pymongo.connection import Connection
|
|
from pymongo.database import Database
|
|
from pymongo.errors import (AutoReconnect,
|
|
ConfigurationError,
|
|
ConnectionFailure,
|
|
InvalidName,
|
|
InvalidURI,
|
|
OperationFailure)
|
|
from test import version
|
|
|
|
|
|
def get_connection(*args, **kwargs):
|
|
host = os.environ.get("DB_IP", "localhost")
|
|
port = int(os.environ.get("DB_PORT", 27017))
|
|
return Connection(host, port, *args, **kwargs)
|
|
|
|
|
|
class TestConnection(unittest.TestCase):
|
|
|
|
def setUp(self):
|
|
self.host = os.environ.get("DB_IP", "localhost")
|
|
self.port = int(os.environ.get("DB_PORT", 27017))
|
|
|
|
def test_types(self):
|
|
self.assertRaises(TypeError, Connection, 1)
|
|
self.assertRaises(TypeError, Connection, 1.14)
|
|
self.assertRaises(TypeError, Connection, "localhost", "27017")
|
|
self.assertRaises(TypeError, Connection, "localhost", 1.14)
|
|
self.assertRaises(TypeError, Connection, "localhost", [])
|
|
|
|
self.assertRaises(ConfigurationError, Connection, [])
|
|
|
|
def test_constants(self):
|
|
Connection.HOST = self.host
|
|
Connection.PORT = self.port
|
|
self.assert_(Connection())
|
|
|
|
Connection.HOST = "somedomainthatdoesntexist.org"
|
|
Connection.PORT = 123456789
|
|
self.assertRaises(ConnectionFailure, Connection)
|
|
self.assert_(Connection(self.host, self.port))
|
|
|
|
Connection.HOST = self.host
|
|
Connection.PORT = self.port
|
|
self.assert_(Connection())
|
|
|
|
def test_connect(self):
|
|
self.assertRaises(ConnectionFailure, Connection,
|
|
"somedomainthatdoesntexist.org")
|
|
self.assertRaises(ConnectionFailure, Connection, self.host, 123456789)
|
|
|
|
self.assert_(Connection(self.host, self.port))
|
|
|
|
def test_host_w_port(self):
|
|
self.assert_(Connection("%s:%d" % (self.host, self.port)))
|
|
self.assertRaises(ConnectionFailure, Connection,
|
|
"%s:1234567" % self.host, self.port)
|
|
|
|
def test_repr(self):
|
|
self.assertEqual(repr(Connection(self.host, self.port)),
|
|
"Connection('%s', %s)" % (self.host, self.port))
|
|
|
|
def test_getters(self):
|
|
self.assertEqual(Connection(self.host, self.port).host, self.host)
|
|
self.assertEqual(Connection(self.host, self.port).port, self.port)
|
|
self.assertEqual(set([(self.host, self.port)]),
|
|
Connection(self.host, self.port).nodes)
|
|
|
|
def test_get_db(self):
|
|
connection = Connection(self.host, self.port)
|
|
|
|
def make_db(base, name):
|
|
return base[name]
|
|
|
|
self.assertRaises(InvalidName, make_db, connection, "")
|
|
self.assertRaises(InvalidName, make_db, connection, "te$t")
|
|
self.assertRaises(InvalidName, make_db, connection, "te.t")
|
|
self.assertRaises(InvalidName, make_db, connection, "te\\t")
|
|
self.assertRaises(InvalidName, make_db, connection, "te/t")
|
|
self.assertRaises(InvalidName, make_db, connection, "te st")
|
|
|
|
self.assert_(isinstance(connection.test, Database))
|
|
self.assertEqual(connection.test, connection["test"])
|
|
self.assertEqual(connection.test, Database(connection, "test"))
|
|
|
|
def test_database_names(self):
|
|
connection = Connection(self.host, self.port)
|
|
|
|
connection.pymongo_test.test.save({"dummy": u"object"})
|
|
connection.pymongo_test_mike.test.save({"dummy": u"object"})
|
|
|
|
dbs = connection.database_names()
|
|
self.assert_("pymongo_test" in dbs)
|
|
self.assert_("pymongo_test_mike" in dbs)
|
|
|
|
def test_drop_database(self):
|
|
connection = Connection(self.host, self.port)
|
|
|
|
self.assertRaises(TypeError, connection.drop_database, 5)
|
|
self.assertRaises(TypeError, connection.drop_database, None)
|
|
|
|
connection.pymongo_test.test.save({"dummy": u"object"})
|
|
dbs = connection.database_names()
|
|
self.assert_("pymongo_test" in dbs)
|
|
connection.drop_database("pymongo_test")
|
|
dbs = connection.database_names()
|
|
self.assert_("pymongo_test" not in dbs)
|
|
|
|
connection.pymongo_test.test.save({"dummy": u"object"})
|
|
dbs = connection.database_names()
|
|
self.assert_("pymongo_test" in dbs)
|
|
connection.drop_database(connection.pymongo_test)
|
|
dbs = connection.database_names()
|
|
self.assert_("pymongo_test" not in dbs)
|
|
|
|
def test_copy_db(self):
|
|
c = Connection(self.host, self.port)
|
|
|
|
self.assertRaises(TypeError, c.copy_database, 4, "foo")
|
|
self.assertRaises(TypeError, c.copy_database, "foo", 4)
|
|
|
|
self.assertRaises(InvalidName, c.copy_database, "foo", "$foo")
|
|
|
|
c.pymongo_test.test.drop()
|
|
c.drop_database("pymongo_test1")
|
|
c.drop_database("pymongo_test2")
|
|
|
|
c.pymongo_test.test.insert({"foo": "bar"})
|
|
|
|
self.assertFalse("pymongo_test1" in c.database_names())
|
|
self.assertFalse("pymongo_test2" in c.database_names())
|
|
|
|
c.copy_database("pymongo_test", "pymongo_test1")
|
|
|
|
self.assert_("pymongo_test1" in c.database_names())
|
|
self.assertEqual("bar", c.pymongo_test1.test.find_one()["foo"])
|
|
|
|
c.copy_database("pymongo_test", "pymongo_test2",
|
|
"%s:%s" % (self.host, self.port))
|
|
|
|
self.assert_("pymongo_test2" in c.database_names())
|
|
self.assertEqual("bar", c.pymongo_test2.test.find_one()["foo"])
|
|
|
|
if version.at_least(c, (1, 3, 3, 1)):
|
|
c.drop_database("pymongo_test1")
|
|
|
|
c.pymongo_test.add_user("mike", "password")
|
|
|
|
self.assertRaises(OperationFailure, c.copy_database,
|
|
"pymongo_test", "pymongo_test1",
|
|
username="foo", password="bar")
|
|
self.assertFalse("pymongo_test1" in c.database_names())
|
|
|
|
self.assertRaises(OperationFailure, c.copy_database,
|
|
"pymongo_test", "pymongo_test1",
|
|
username="mike", password="bar")
|
|
self.assertFalse("pymongo_test1" in c.database_names())
|
|
|
|
c.copy_database("pymongo_test", "pymongo_test1",
|
|
username="mike", password="password")
|
|
self.assert_("pymongo_test1" in c.database_names())
|
|
self.assertEqual("bar", c.pymongo_test1.test.find_one()["foo"])
|
|
|
|
def test_iteration(self):
|
|
connection = Connection(self.host, self.port)
|
|
|
|
def iterate():
|
|
[a for a in connection]
|
|
|
|
self.assertRaises(TypeError, iterate)
|
|
|
|
# TODO this test is probably very dependent on the machine its running on
|
|
# due to timing issues, but I want to get something in here.
|
|
def test_low_network_timeout(self):
|
|
c = None
|
|
i = 0
|
|
n = 10
|
|
while c is None and i < n:
|
|
try:
|
|
c = Connection(self.host, self.port, network_timeout=0.0001)
|
|
except AutoReconnect:
|
|
i += 1
|
|
if i == n:
|
|
raise SkipTest()
|
|
|
|
coll = c.pymongo_test.test
|
|
|
|
for _ in range(1000):
|
|
try:
|
|
coll.find_one()
|
|
except AutoReconnect:
|
|
pass
|
|
except AssertionError:
|
|
self.fail()
|
|
|
|
def test_disconnect(self):
|
|
c = Connection(self.host, self.port)
|
|
coll = c.foo.bar
|
|
|
|
c.disconnect()
|
|
c.disconnect()
|
|
|
|
coll.count()
|
|
|
|
c.disconnect()
|
|
c.disconnect()
|
|
|
|
coll.count()
|
|
|
|
def test_from_uri(self):
|
|
c = Connection(self.host, self.port)
|
|
|
|
self.assertEqual(c, Connection("mongodb://%s:%s" %
|
|
(self.host, self.port)))
|
|
|
|
c.admin.system.users.remove({})
|
|
c.pymongo_test.system.users.remove({})
|
|
|
|
c.admin.add_user("admin", "pass")
|
|
c.admin.authenticate("admin", "pass")
|
|
c.pymongo_test.add_user("user", "pass")
|
|
|
|
self.assertRaises(ConfigurationError, Connection,
|
|
"mongodb://foo:bar@%s:%s" % (self.host, self.port))
|
|
self.assertRaises(ConfigurationError, Connection,
|
|
"mongodb://admin:bar@%s:%s" % (self.host, self.port))
|
|
self.assertRaises(ConfigurationError, Connection,
|
|
"mongodb://user:pass@%s:%s" % (self.host, self.port))
|
|
Connection("mongodb://admin:pass@%s:%s" % (self.host, self.port))
|
|
|
|
self.assertRaises(ConfigurationError, Connection,
|
|
"mongodb://admin:pass@%s:%s/pymongo_test" %
|
|
(self.host, self.port))
|
|
self.assertRaises(ConfigurationError, Connection,
|
|
"mongodb://user:foo@%s:%s/pymongo_test" %
|
|
(self.host, self.port))
|
|
Connection("mongodb://user:pass@%s:%s/pymongo_test" %
|
|
(self.host, self.port))
|
|
|
|
self.assert_(Connection("mongodb://%s:%s" %
|
|
(self.host, self.port),
|
|
slave_okay=True).slave_okay)
|
|
self.assert_(Connection("mongodb://%s:%s/?slaveok=true;w=2" %
|
|
(self.host, self.port)).slave_okay)
|
|
c.admin.system.users.remove({})
|
|
c.pymongo_test.system.users.remove({})
|
|
|
|
def test_fork(self):
|
|
"""Test using a connection before and after a fork.
|
|
"""
|
|
if sys.platform == "win32":
|
|
raise SkipTest()
|
|
|
|
try:
|
|
from multiprocessing import Process, Pipe
|
|
except ImportError:
|
|
raise SkipTest()
|
|
|
|
db = Connection(self.host, self.port).pymongo_test
|
|
|
|
# Failure occurs if the connection is used before the fork
|
|
db.test.find_one()
|
|
db.connection.end_request()
|
|
|
|
def loop(pipe):
|
|
while True:
|
|
try:
|
|
db.test.insert({"a": "b"}, safe=True)
|
|
for _ in db.test.find():
|
|
pass
|
|
except:
|
|
pipe.send(True)
|
|
os._exit(1)
|
|
|
|
cp1, cc1 = Pipe()
|
|
cp2, cc2 = Pipe()
|
|
|
|
p1 = Process(target=loop, args=(cc1,))
|
|
p2 = Process(target=loop, args=(cc2,))
|
|
|
|
p1.start()
|
|
p2.start()
|
|
|
|
p1.join(1)
|
|
p2.join(1)
|
|
|
|
p1.terminate()
|
|
p2.terminate()
|
|
|
|
p1.join()
|
|
p2.join()
|
|
|
|
cc1.close()
|
|
cc2.close()
|
|
|
|
# recv will only have data if the subprocess failed
|
|
try:
|
|
cp1.recv()
|
|
self.fail()
|
|
except EOFError:
|
|
pass
|
|
try:
|
|
cp2.recv()
|
|
self.fail()
|
|
except EOFError:
|
|
pass
|
|
|
|
def test_document_class(self):
|
|
c = Connection(self.host, self.port)
|
|
db = c.pymongo_test
|
|
db.test.insert({"x": 1})
|
|
|
|
self.assertEqual(dict, c.document_class)
|
|
self.assert_(isinstance(db.test.find_one(), dict))
|
|
self.assertFalse(isinstance(db.test.find_one(), SON))
|
|
|
|
c.document_class = SON
|
|
|
|
self.assertEqual(SON, c.document_class)
|
|
self.assert_(isinstance(db.test.find_one(), SON))
|
|
self.assertFalse(isinstance(db.test.find_one(as_class=dict), SON))
|
|
|
|
c = Connection(self.host, self.port, document_class=SON)
|
|
db = c.pymongo_test
|
|
|
|
self.assertEqual(SON, c.document_class)
|
|
self.assert_(isinstance(db.test.find_one(), SON))
|
|
self.assertFalse(isinstance(db.test.find_one(as_class=dict), SON))
|
|
|
|
c.document_class = dict
|
|
|
|
self.assertEqual(dict, c.document_class)
|
|
self.assert_(isinstance(db.test.find_one(), dict))
|
|
self.assertFalse(isinstance(db.test.find_one(), SON))
|
|
|
|
def test_network_timeout(self):
|
|
no_timeout = Connection(self.host, self.port)
|
|
timeout = Connection(self.host, self.port, network_timeout=0.1)
|
|
|
|
no_timeout.pymongo_test.drop_collection("test")
|
|
no_timeout.pymongo_test.test.insert({"x": 1}, safe=True)
|
|
|
|
where_func = """function (doc) {
|
|
var d = new Date().getTime() + 200;
|
|
var x = new Date().getTime();
|
|
while (x < d) {
|
|
x = new Date().getTime();
|
|
}
|
|
return true;
|
|
}"""
|
|
|
|
def get_x(db):
|
|
return db.test.find().where(where_func).next()["x"]
|
|
self.assertEqual(1, get_x(no_timeout.pymongo_test))
|
|
self.assertRaises(ConnectionFailure, get_x, timeout.pymongo_test)
|
|
|
|
def get_x_timeout(db, t):
|
|
return db.test.find(
|
|
network_timeout=t).where(where_func).next()["x"]
|
|
self.assertEqual(1, get_x_timeout(timeout.pymongo_test, None))
|
|
self.assertRaises(ConnectionFailure, get_x_timeout,
|
|
no_timeout.pymongo_test, 0.1)
|
|
|
|
def test_tz_aware(self):
|
|
aware = Connection(self.host, self.port, tz_aware=True)
|
|
naive = Connection(self.host, self.port)
|
|
aware.pymongo_test.drop_collection("test")
|
|
|
|
now = datetime.datetime.utcnow()
|
|
aware.pymongo_test.test.insert({"x": now}, safe=True)
|
|
|
|
self.assertEqual(None, naive.pymongo_test.test.find_one()["x"].tzinfo)
|
|
self.assertEqual(utc, aware.pymongo_test.test.find_one()["x"].tzinfo)
|
|
self.assertEqual(
|
|
aware.pymongo_test.test.find_one()["x"].replace(tzinfo=None),
|
|
naive.pymongo_test.test.find_one()["x"])
|
|
|
|
def test_ipv6(self):
|
|
self.assertRaises(AutoReconnect, Connection, 'foo')
|
|
try:
|
|
connection = Connection("[::1]")
|
|
except:
|
|
# Either mongod was started without --ipv6
|
|
# or the OS doesn't support it (or both).
|
|
raise SkipTest()
|
|
|
|
# Try a few simple things
|
|
connection = Connection("mongodb://[::1]:27017")
|
|
connection = Connection("mongodb://[::1]:27017/?slaveOk=true")
|
|
connection = Connection("[::1]:27017,localhost:27017")
|
|
connection = Connection("localhost:27017,[::1]:27017")
|
|
connection.pymongo_test.test.save({"dummy": u"object"})
|
|
connection.pymongo_test_bernie.test.save({"dummy": u"object"})
|
|
|
|
dbs = connection.database_names()
|
|
self.assert_("pymongo_test" in dbs)
|
|
self.assert_("pymongo_test_bernie" in dbs)
|
|
|
|
def test_autoreconnect(self):
|
|
def find_one(conn):
|
|
return conn.test.stuff.find_one()
|
|
# Simulate a temporary connection failure
|
|
c = Connection('foo', _connect=False)
|
|
self.assertRaises(AutoReconnect, find_one, c)
|
|
c._Connection__nodes = set([('localhost', 27017)])
|
|
self.assert_(find_one, c)
|
|
|
|
def test_fsync_lock_unlock(self):
|
|
c = get_connection()
|
|
self.assertFalse(c.is_locked)
|
|
# async flushing not supported on windows...
|
|
if sys.platform not in ('cygwin', 'win32'):
|
|
c.fsync(async=True)
|
|
self.assertFalse(c.is_locked)
|
|
c.fsync(lock=True)
|
|
self.assertTrue(c.is_locked)
|
|
locked = True
|
|
c.unlock()
|
|
for _ in xrange(5):
|
|
locked = c.is_locked
|
|
if not locked:
|
|
break
|
|
time.sleep(1)
|
|
self.assertFalse(locked)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|