use code_w_scope instead of code for eval. change Code class to support scopes
This commit is contained in:
parent
238dc58b85
commit
5eda427cbc
@ -108,7 +108,6 @@ def _validate_ref(data):
|
||||
|
||||
_validate_code = _validate_string
|
||||
|
||||
# still not sure what is actually stored here, but i know how big it is...
|
||||
def _validate_code_w_scope(data):
|
||||
(length, data) = _get_int(data)
|
||||
assert len(data) >= length + 1
|
||||
@ -258,7 +257,7 @@ _element_getter = {
|
||||
"\x0C": _get_ref,
|
||||
"\x0D": _get_string, # code
|
||||
"\x0E": _get_string, # symbol
|
||||
# "\x0F": _get_code_w_scope
|
||||
# "\x0F": _get_code_w_scope
|
||||
"\x10": _get_int, # number_int
|
||||
}
|
||||
|
||||
@ -297,8 +296,10 @@ def _element_to_bson(key, value):
|
||||
return "\x05" + name + struct.pack("<i", len(value)) + chr(subtype) + value
|
||||
if isinstance(value, Code):
|
||||
cstring = _make_c_string(value)
|
||||
scope = _dict_to_bson(value.scope)
|
||||
full_length = struct.pack("<i", 8 + len(cstring) + len(scope))
|
||||
length = struct.pack("<i", len(cstring))
|
||||
return "\x0D" + name + length + cstring
|
||||
return "\x0F" + name + full_length + length + cstring + scope
|
||||
if isinstance(value, str):
|
||||
cstring = _make_c_string(value)
|
||||
length = struct.pack("<i", len(cstring))
|
||||
|
||||
@ -12,10 +12,47 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Representation of code to be stored in or retrieved from Mongo.
|
||||
"""Representation of JavaScript code to be evaluated by MongoDB.
|
||||
"""
|
||||
|
||||
import types
|
||||
|
||||
class Code(str):
|
||||
"""Code to be stored in or retrieved from Mongo.
|
||||
"""JavaScript code to be evaluated by MongoDB.
|
||||
"""
|
||||
def __new__(cls, code, scope=None):
|
||||
"""Initialize a new code object.
|
||||
|
||||
`code` is a string containing JavaScript code.
|
||||
|
||||
`scope` is a dictionary representing the scope in which `code` should
|
||||
be evaluated. It should be a mapping from identifiers (as strings) to
|
||||
values.
|
||||
|
||||
Raises TypeError if `code` is not an instance of (str, unicode) or
|
||||
`scope` is not an instance of dict.
|
||||
|
||||
:Parameters:
|
||||
- `code`: JavaScript code to be evaluated
|
||||
- `scope` (optional): dictionary representing the scope for evaluation
|
||||
"""
|
||||
if not isinstance(code, types.StringTypes):
|
||||
raise TypeError("code must be an instance of (str, unicode)")
|
||||
|
||||
if scope is None:
|
||||
scope = {}
|
||||
if not isinstance(scope, types.DictType):
|
||||
raise TypeError("scope must be an instance of dict")
|
||||
|
||||
self = str.__new__(cls, code)
|
||||
self.__scope = scope
|
||||
return self
|
||||
|
||||
@property
|
||||
def scope(self):
|
||||
"""Get the scope dictionary.
|
||||
"""
|
||||
return self.__scope
|
||||
|
||||
def __repr__(self):
|
||||
return "Code(%s)" % str.__repr__(self)
|
||||
return "Code(%s, %r)" % (str.__repr__(self), self.__scope)
|
||||
|
||||
@ -361,8 +361,8 @@ class Database(object):
|
||||
arguments will be passed to that function when it is run on the
|
||||
server.
|
||||
|
||||
Raises TypeError if `code` is not an instance of (str, unicode). Raises
|
||||
OperationFailure if the eval fails. Returns the result of the
|
||||
Raises TypeError if `code` is not an instance of (str, unicode, `Code`).
|
||||
Raises OperationFailure if the eval fails. Returns the result of the
|
||||
evaluation.
|
||||
|
||||
:Parameters:
|
||||
@ -370,9 +370,9 @@ class Database(object):
|
||||
- `args` (optional): additional positional arguments are passed to
|
||||
the `code` being evaluated
|
||||
"""
|
||||
if not isinstance(code, types.StringTypes):
|
||||
raise TypeError("code must be an instance of (str, unicode)")
|
||||
if not isinstance(code, Code):
|
||||
code = Code(code)
|
||||
|
||||
command = SON([("$eval", Code(code)), ("args", list(args))])
|
||||
command = SON([("$eval", code), ("args", list(args))])
|
||||
result = self._command(command)
|
||||
return result.get("retval", None)
|
||||
|
||||
@ -92,7 +92,7 @@ class TestBSON(unittest.TestCase):
|
||||
self.assertEqual(BSON.from_dict({"regex": re.compile("a*b", re.IGNORECASE)}),
|
||||
"\x12\x00\x00\x00\x0B\x72\x65\x67\x65\x78\x00\x61\x2A\x62\x00\x69\x00\x00")
|
||||
self.assertEqual(BSON.from_dict({"$where": Code("test")}),
|
||||
"\x16\x00\x00\x00\x0D\x24\x77\x68\x65\x72\x65\x00\x05\x00\x00\x00\x74\x65\x73\x74\x00\x00")
|
||||
"\x1F\x00\x00\x00\x0F\x24\x77\x68\x65\x72\x65\x00\x12\x00\x00\x00\x05\x00\x00\x00\x74\x65\x73\x74\x00\x05\x00\x00\x00\x00\x00")
|
||||
a = ObjectId("\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B")
|
||||
self.assertEqual(BSON.from_dict({"oid": a}),
|
||||
"\x16\x00\x00\x00\x07\x6F\x69\x64\x00\x07\x06\x05\x04\x03\x02\x01\x00\x0B\x0A\x09\x08\x00")
|
||||
|
||||
@ -24,6 +24,16 @@ class TestCode(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_types(self):
|
||||
self.assertRaises(TypeError, Code, 5)
|
||||
self.assertRaises(TypeError, Code, None)
|
||||
self.assertRaises(TypeError, Code, "aoeu", 5)
|
||||
self.assertRaises(TypeError, Code, u"aoeu", 5)
|
||||
self.assert_(Code("aoeu"))
|
||||
self.assert_(Code(u"aoeu"))
|
||||
self.assert_(Code("aoeu", {}))
|
||||
self.assert_(Code(u"aoeu", {}))
|
||||
|
||||
def test_code(self):
|
||||
a_string = "hello world"
|
||||
a_code = Code("hello world")
|
||||
@ -31,12 +41,19 @@ class TestCode(unittest.TestCase):
|
||||
self.assert_(a_code.endswith("world"))
|
||||
self.assert_(isinstance(a_code, Code))
|
||||
self.failIf(isinstance(a_string, Code))
|
||||
self.assertEqual(a_code.scope, {})
|
||||
a_code.scope["my_var"] = 5
|
||||
self.assertEqual(a_code.scope, {"my_var": 5})
|
||||
|
||||
def test_repr(self):
|
||||
c = Code("hello world")
|
||||
self.assertEqual(repr(c), "Code('hello world')")
|
||||
self.assertEqual(repr(c), "Code('hello world', {})")
|
||||
c.scope["foo"] = "bar"
|
||||
self.assertEqual(repr(c), "Code('hello world', {'foo': 'bar'})")
|
||||
c = Code("hello world", {"blah": 3})
|
||||
self.assertEqual(repr(c), "Code('hello world', {'blah': 3})")
|
||||
c = Code("\x08\xFF")
|
||||
self.assertEqual(repr(c), "Code('\\x08\\xff')")
|
||||
self.assertEqual(repr(c), "Code('\\x08\\xff', {})")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -29,6 +29,7 @@ from pymongo import ASCENDING, DESCENDING, OFF, SLOW_ONLY, ALL
|
||||
from pymongo.connection import Connection
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.dbref import DBRef
|
||||
from pymongo.code import Code
|
||||
from test_connection import get_connection
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
@ -265,6 +266,11 @@ class TestDatabase(unittest.TestCase):
|
||||
self.assertEqual(5, db.eval("function () {return 5;}"))
|
||||
self.assertEqual(5, db.eval("2 + 3;"))
|
||||
|
||||
self.assertEqual(5, db.eval(Code("2 + 3;")))
|
||||
self.assertEqual(None, db.eval(Code("return i;")))
|
||||
self.assertEqual(2, db.eval(Code("return i;", {"i": 2})))
|
||||
self.assertEqual(5, db.eval(Code("i + 3;", {"i": 2})))
|
||||
|
||||
self.assertRaises(OperationFailure, db.eval, "5 ++ 5;")
|
||||
|
||||
# TODO some of these tests belong in the collection level testing.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user