use code_w_scope instead of code for eval. change Code class to support scopes

This commit is contained in:
Mike Dirolf 2009-03-12 11:16:03 -04:00
parent 238dc58b85
commit 5eda427cbc
6 changed files with 75 additions and 14 deletions

View File

@ -108,7 +108,6 @@ def _validate_ref(data):
_validate_code = _validate_string
# still not sure what is actually stored here, but i know how big it is...
def _validate_code_w_scope(data):
(length, data) = _get_int(data)
assert len(data) >= length + 1
@ -258,7 +257,7 @@ _element_getter = {
"\x0C": _get_ref,
"\x0D": _get_string, # code
"\x0E": _get_string, # symbol
# "\x0F": _get_code_w_scope
# "\x0F": _get_code_w_scope
"\x10": _get_int, # number_int
}
@ -297,8 +296,10 @@ def _element_to_bson(key, value):
return "\x05" + name + struct.pack("<i", len(value)) + chr(subtype) + value
if isinstance(value, Code):
cstring = _make_c_string(value)
scope = _dict_to_bson(value.scope)
full_length = struct.pack("<i", 8 + len(cstring) + len(scope))
length = struct.pack("<i", len(cstring))
return "\x0D" + name + length + cstring
return "\x0F" + name + full_length + length + cstring + scope
if isinstance(value, str):
cstring = _make_c_string(value)
length = struct.pack("<i", len(cstring))

View File

@ -12,10 +12,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Representation of code to be stored in or retrieved from Mongo.
"""Representation of JavaScript code to be evaluated by MongoDB.
"""
import types
class Code(str):
"""Code to be stored in or retrieved from Mongo.
"""JavaScript code to be evaluated by MongoDB.
"""
def __new__(cls, code, scope=None):
"""Initialize a new code object.
`code` is a string containing JavaScript code.
`scope` is a dictionary representing the scope in which `code` should
be evaluated. It should be a mapping from identifiers (as strings) to
values.
Raises TypeError if `code` is not an instance of (str, unicode) or
`scope` is not an instance of dict.
:Parameters:
- `code`: JavaScript code to be evaluated
- `scope` (optional): dictionary representing the scope for evaluation
"""
if not isinstance(code, types.StringTypes):
raise TypeError("code must be an instance of (str, unicode)")
if scope is None:
scope = {}
if not isinstance(scope, types.DictType):
raise TypeError("scope must be an instance of dict")
self = str.__new__(cls, code)
self.__scope = scope
return self
@property
def scope(self):
"""Get the scope dictionary.
"""
return self.__scope
def __repr__(self):
return "Code(%s)" % str.__repr__(self)
return "Code(%s, %r)" % (str.__repr__(self), self.__scope)

View File

@ -361,8 +361,8 @@ class Database(object):
arguments will be passed to that function when it is run on the
server.
Raises TypeError if `code` is not an instance of (str, unicode). Raises
OperationFailure if the eval fails. Returns the result of the
Raises TypeError if `code` is not an instance of (str, unicode, `Code`).
Raises OperationFailure if the eval fails. Returns the result of the
evaluation.
:Parameters:
@ -370,9 +370,9 @@ class Database(object):
- `args` (optional): additional positional arguments are passed to
the `code` being evaluated
"""
if not isinstance(code, types.StringTypes):
raise TypeError("code must be an instance of (str, unicode)")
if not isinstance(code, Code):
code = Code(code)
command = SON([("$eval", Code(code)), ("args", list(args))])
command = SON([("$eval", code), ("args", list(args))])
result = self._command(command)
return result.get("retval", None)

View File

@ -92,7 +92,7 @@ class TestBSON(unittest.TestCase):
self.assertEqual(BSON.from_dict({"regex": re.compile("a*b", re.IGNORECASE)}),
"\x12\x00\x00\x00\x0B\x72\x65\x67\x65\x78\x00\x61\x2A\x62\x00\x69\x00\x00")
self.assertEqual(BSON.from_dict({"$where": Code("test")}),
"\x16\x00\x00\x00\x0D\x24\x77\x68\x65\x72\x65\x00\x05\x00\x00\x00\x74\x65\x73\x74\x00\x00")
"\x1F\x00\x00\x00\x0F\x24\x77\x68\x65\x72\x65\x00\x12\x00\x00\x00\x05\x00\x00\x00\x74\x65\x73\x74\x00\x05\x00\x00\x00\x00\x00")
a = ObjectId("\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B")
self.assertEqual(BSON.from_dict({"oid": a}),
"\x16\x00\x00\x00\x07\x6F\x69\x64\x00\x07\x06\x05\x04\x03\x02\x01\x00\x0B\x0A\x09\x08\x00")

View File

@ -24,6 +24,16 @@ class TestCode(unittest.TestCase):
def setUp(self):
pass
def test_types(self):
self.assertRaises(TypeError, Code, 5)
self.assertRaises(TypeError, Code, None)
self.assertRaises(TypeError, Code, "aoeu", 5)
self.assertRaises(TypeError, Code, u"aoeu", 5)
self.assert_(Code("aoeu"))
self.assert_(Code(u"aoeu"))
self.assert_(Code("aoeu", {}))
self.assert_(Code(u"aoeu", {}))
def test_code(self):
a_string = "hello world"
a_code = Code("hello world")
@ -31,12 +41,19 @@ class TestCode(unittest.TestCase):
self.assert_(a_code.endswith("world"))
self.assert_(isinstance(a_code, Code))
self.failIf(isinstance(a_string, Code))
self.assertEqual(a_code.scope, {})
a_code.scope["my_var"] = 5
self.assertEqual(a_code.scope, {"my_var": 5})
def test_repr(self):
c = Code("hello world")
self.assertEqual(repr(c), "Code('hello world')")
self.assertEqual(repr(c), "Code('hello world', {})")
c.scope["foo"] = "bar"
self.assertEqual(repr(c), "Code('hello world', {'foo': 'bar'})")
c = Code("hello world", {"blah": 3})
self.assertEqual(repr(c), "Code('hello world', {'blah': 3})")
c = Code("\x08\xFF")
self.assertEqual(repr(c), "Code('\\x08\\xff')")
self.assertEqual(repr(c), "Code('\\x08\\xff', {})")
if __name__ == "__main__":
unittest.main()

View File

@ -29,6 +29,7 @@ from pymongo import ASCENDING, DESCENDING, OFF, SLOW_ONLY, ALL
from pymongo.connection import Connection
from pymongo.collection import Collection
from pymongo.dbref import DBRef
from pymongo.code import Code
from test_connection import get_connection
class TestDatabase(unittest.TestCase):
@ -265,6 +266,11 @@ class TestDatabase(unittest.TestCase):
self.assertEqual(5, db.eval("function () {return 5;}"))
self.assertEqual(5, db.eval("2 + 3;"))
self.assertEqual(5, db.eval(Code("2 + 3;")))
self.assertEqual(None, db.eval(Code("return i;")))
self.assertEqual(2, db.eval(Code("return i;", {"i": 2})))
self.assertEqual(5, db.eval(Code("i + 3;", {"i": 2})))
self.assertRaises(OperationFailure, db.eval, "5 ++ 5;")
# TODO some of these tests belong in the collection level testing.