From 2371f32894445e7ee4f76e211197b7fccf262f12 Mon Sep 17 00:00:00 2001 From: Bernie Hackett Date: Mon, 18 Nov 2013 14:54:10 -0800 Subject: [PATCH] Clean up str/bytes/unicode encoding in C. Minor perf improvements included. --- bson/_cbsonmodule.c | 140 ++++++++++++++++++++------------------------ bson/_cbsonmodule.h | 2 +- 2 files changed, 65 insertions(+), 77 deletions(-) diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c index 811beeb1c..b22359ac3 100644 --- a/bson/_cbsonmodule.c +++ b/bson/_cbsonmodule.c @@ -212,7 +212,6 @@ int buffer_write_bytes(buffer_t buffer, const char* data, int size) { return 1; } -#if PY_MAJOR_VERSION >= 3 static int write_unicode(buffer_t buffer, PyObject* py_string) { int size; const char* data; @@ -220,27 +219,34 @@ static int write_unicode(buffer_t buffer, PyObject* py_string) { if (!encoded) { return 0; } - data = PyBytes_AsString(encoded); - if (!data) { - Py_DECREF(encoded); - return 0; - } - if ((size = _downcast_and_check(PyBytes_Size(encoded), 1)) == -1){ - Py_DECREF(encoded); - return 0; - } - if (!buffer_write_bytes(buffer, (const char*)&size, 4)) { - Py_DECREF(encoded); - return 0; - } - if (!buffer_write_bytes(buffer, data, size)) { - Py_DECREF(encoded); - return 0; - } +#if PY_MAJOR_VERSION >= 3 + data = PyBytes_AS_STRING(encoded); +#else + data = PyString_AS_STRING(encoded); +#endif + if (!data) + goto unicodefail; + +#if PY_MAJOR_VERSION >= 3 + if ((size = _downcast_and_check(PyBytes_GET_SIZE(encoded), 1)) == -1) +#else + if ((size = _downcast_and_check(PyString_GET_SIZE(encoded), 1)) == -1) +#endif + goto unicodefail; + + if (!buffer_write_bytes(buffer, (const char*)&size, 4)) + goto unicodefail; + + if (!buffer_write_bytes(buffer, data, size)) + goto unicodefail; + Py_DECREF(encoded); return 1; + +unicodefail: + Py_DECREF(encoded); + return 0; } -#endif /* returns 0 on failure */ static int write_string(buffer_t buffer, PyObject* py_string) { @@ -934,10 +940,10 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer, } else if (PyBytes_Check(value)) { int subtype = 0; int size; - const char* data = PyBytes_AsString(value); + const char* data = PyBytes_AS_STRING(value); if (!data) return 0; - if ((size = _downcast_and_check(PyBytes_Size(value), 0)) == -1) + if ((size = _downcast_and_check(PyBytes_GET_SIZE(value), 0)) == -1) return 0; *(buffer_get_buffer(buffer) + type_byte) = 0x05; if (!buffer_write_bytes(buffer, (const char*)&size, 4)) { @@ -956,12 +962,12 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer, result_t status; const char* data; int size; - if (!(data = PyString_AsString(value))) + if (!(data = PyString_AS_STRING(value))) return 0; - if ((size = _downcast_and_check(PyString_Size(value), 0)) == -1) + if ((size = _downcast_and_check(PyString_GET_SIZE(value), 1)) == -1) return 0; *(buffer_get_buffer(buffer) + type_byte) = 0x02; - status = check_string((const unsigned char*)data, size, 1, 0); + status = check_string((const unsigned char*)data, size - 1, 1, 0); if (status == NOT_UTF_8) { PyObject* InvalidStringData = _error("InvalidStringData"); @@ -988,20 +994,17 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer, } return 0; } - return write_string(buffer, value); -#endif - } else if (PyUnicode_Check(value)) { - PyObject* encoded; - int result; - - *(buffer_get_buffer(buffer) + type_byte) = 0x02; - encoded = PyUnicode_AsUTF8String(value); - if (!encoded) { + if (!buffer_write_bytes(buffer, (const char*)&size, 4)) { return 0; } - result = write_string(buffer, encoded); - Py_DECREF(encoded); - return result; + if (!buffer_write_bytes(buffer, data, size)) { + return 0; + } + return 1; +#endif + } else if (PyUnicode_Check(value)) { + *(buffer_get_buffer(buffer) + type_byte) = 0x02; + return write_unicode(buffer, value); } else if (PyDateTime_Check(value)) { long long millis; PyObject* utcoffset = PyObject_CallMethod(value, "utcoffset", NULL); @@ -1111,9 +1114,8 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer, return 0; } -static int check_key_name(const char* name, - const Py_ssize_t name_length) { - int i; +static int check_key_name(const char* name, int name_length) { + if (name_length > 0 && name[0] == '$') { PyObject* InvalidDocument = _error("InvalidDocument"); if (InvalidDocument) { @@ -1132,25 +1134,23 @@ static int check_key_name(const char* name, } return 0; } - for (i = 0; i < name_length; i++) { - if (name[i] == '.') { - PyObject* InvalidDocument = _error("InvalidDocument"); - if (InvalidDocument) { + if (strchr(name, '.')) { + PyObject* InvalidDocument = _error("InvalidDocument"); + if (InvalidDocument) { #if PY_MAJOR_VERSION >= 3 - PyObject* errmsg = PyUnicode_FromFormat( - "key '%s' must not contain '.'", name); + PyObject* errmsg = PyUnicode_FromFormat( + "key '%s' must not contain '.'", name); #else - PyObject* errmsg = PyString_FromFormat( - "key '%s' must not contain '.'", name); + PyObject* errmsg = PyString_FromFormat( + "key '%s' must not contain '.'", name); #endif - if (errmsg) { - PyErr_SetObject(InvalidDocument, errmsg); - Py_DECREF(errmsg); - } - Py_DECREF(InvalidDocument); + if (errmsg) { + PyErr_SetObject(InvalidDocument, errmsg); + Py_DECREF(errmsg); } - return 0; + Py_DECREF(InvalidDocument); } + return 0; } return 1; } @@ -1158,13 +1158,10 @@ static int check_key_name(const char* name, /* Write a (key, value) pair to the buffer. * * Returns 0 on failure */ -int write_pair(PyObject* self, buffer_t buffer, const char* name, Py_ssize_t name_length, +int write_pair(PyObject* self, buffer_t buffer, const char* name, int name_length, PyObject* value, unsigned char check_keys, unsigned char uuid_subtype, unsigned char allow_id) { int type_byte; - int length; - if ((length = _downcast_and_check(name_length, 1)) == -1) - return 0; /* Don't write any _id elements unless we're explicitly told to - * _id has to be written first so we do so, but don't bother @@ -1181,7 +1178,7 @@ int write_pair(PyObject* self, buffer_t buffer, const char* name, Py_ssize_t nam if (check_keys && !check_key_name(name, name_length)) { return 0; } - if (!buffer_write_bytes(buffer, name, length)) { + if (!buffer_write_bytes(buffer, name, name_length + 1)) { return 0; } if (!write_element_to_buffer(self, buffer, type_byte, @@ -1199,32 +1196,30 @@ int decode_and_write_pair(PyObject* self, buffer_t buffer, const char* data; int size; if (PyUnicode_Check(key)) { - result_t status; encoded = PyUnicode_AsUTF8String(key); if (!encoded) { return 0; } #if PY_MAJOR_VERSION >= 3 - if (!(data = PyBytes_AsString(encoded))) { + if (!(data = PyBytes_AS_STRING(encoded))) { Py_DECREF(encoded); return 0; } - if ((size = _downcast_and_check(PyBytes_Size(encoded), 0)) == -1) { + if ((size = _downcast_and_check(PyBytes_GET_SIZE(encoded), 1)) == -1) { Py_DECREF(encoded); return 0; } #else - if (!(data = PyString_AsString(encoded))) { + if (!(data = PyString_AS_STRING(encoded))) { Py_DECREF(encoded); return 0; } - if ((size = _downcast_and_check(PyString_Size(encoded), 0)) == -1) { + if ((size = _downcast_and_check(PyString_GET_SIZE(encoded), 1)) == -1) { Py_DECREF(encoded); return 0; } #endif - status = check_string((const unsigned char*)data, size, 0, 1); - if (status == HAS_NULL) { + if (strlen(data) != (size_t)(size - 1)) { PyObject* InvalidDocument = _error("InvalidDocument"); if (InvalidDocument) { PyErr_SetString(InvalidDocument, @@ -1240,15 +1235,15 @@ int decode_and_write_pair(PyObject* self, buffer_t buffer, encoded = key; Py_INCREF(encoded); - if (!(data = PyString_AsString(encoded))) { + if (!(data = PyString_AS_STRING(encoded))) { Py_DECREF(encoded); return 0; } - if ((size = _downcast_and_check(PyString_Size(encoded), 0)) == -1) { + if ((size = _downcast_and_check(PyString_GET_SIZE(encoded), 1)) == -1) { Py_DECREF(encoded); return 0; } - status = check_string((const unsigned char*)data, size, 1, 1); + status = check_string((const unsigned char*)data, size - 1, 1, 1); if (status == NOT_UTF_8) { PyObject* InvalidStringData = _error("InvalidStringData"); @@ -1308,15 +1303,8 @@ int decode_and_write_pair(PyObject* self, buffer_t buffer, } /* If top_level is True, don't allow writing _id here - it was already written. */ -#if PY_MAJOR_VERSION >= 3 if (!write_pair(self, buffer, data, - PyBytes_Size(encoded), value, - check_keys, uuid_subtype, !top_level)) { -#else - if (!write_pair(self, buffer, data, - PyString_Size(encoded), value, - check_keys, uuid_subtype, !top_level)) { -#endif + size - 1, value, check_keys, uuid_subtype, !top_level)) { Py_DECREF(encoded); return 0; } diff --git a/bson/_cbsonmodule.h b/bson/_cbsonmodule.h index be1d3a5ca..be69eafcb 100644 --- a/bson/_cbsonmodule.h +++ b/bson/_cbsonmodule.h @@ -37,7 +37,7 @@ typedef int Py_ssize_t; #define _cbson_write_pair_INDEX 2 #define _cbson_write_pair_RETURN int -#define _cbson_write_pair_PROTO (PyObject* self, buffer_t buffer, const char* name, Py_ssize_t name_length, PyObject* value, unsigned char check_keys, unsigned char uuid_subtype, unsigned char allow_id) +#define _cbson_write_pair_PROTO (PyObject* self, buffer_t buffer, const char* name, int name_length, PyObject* value, unsigned char check_keys, unsigned char uuid_subtype, unsigned char allow_id) #define _cbson_decode_and_write_pair_INDEX 3 #define _cbson_decode_and_write_pair_RETURN int