Skip to content

Commit ecf7ac7

Browse files
authored
PYTHON-5013 Add NULL checks in InvalidDocument bson handling (#2049)
1 parent b9f4f79 commit ecf7ac7

File tree

2 files changed

+83
-34
lines changed

2 files changed

+83
-34
lines changed

bson/_cbsonmodule.c

+55-34
Original file line numberDiff line numberDiff line change
@@ -1644,6 +1644,56 @@ static int write_raw_doc(buffer_t buffer, PyObject* raw, PyObject* _raw_str) {
16441644
return bytes_written;
16451645
}
16461646

1647+
1648+
/* Update Invalid Document error message to include doc.
1649+
*/
1650+
void handle_invalid_doc_error(PyObject* dict) {
1651+
PyObject *etype = NULL, *evalue = NULL, *etrace = NULL;
1652+
PyObject *msg = NULL, *dict_str = NULL, *new_msg = NULL;
1653+
PyErr_Fetch(&etype, &evalue, &etrace);
1654+
PyObject *InvalidDocument = _error("InvalidDocument");
1655+
if (InvalidDocument == NULL) {
1656+
goto cleanup;
1657+
}
1658+
1659+
if (evalue && PyErr_GivenExceptionMatches(etype, InvalidDocument)) {
1660+
PyObject *msg = PyObject_Str(evalue);
1661+
if (msg) {
1662+
// Prepend doc to the existing message
1663+
PyObject *dict_str = PyObject_Str(dict);
1664+
if (dict_str == NULL) {
1665+
goto cleanup;
1666+
}
1667+
const char * dict_str_utf8 = PyUnicode_AsUTF8(dict_str);
1668+
if (dict_str_utf8 == NULL) {
1669+
goto cleanup;
1670+
}
1671+
const char * msg_utf8 = PyUnicode_AsUTF8(msg);
1672+
if (msg_utf8 == NULL) {
1673+
goto cleanup;
1674+
}
1675+
PyObject *new_msg = PyUnicode_FromFormat("Invalid document %s | %s", dict_str_utf8, msg_utf8);
1676+
Py_DECREF(evalue);
1677+
Py_DECREF(etype);
1678+
etype = InvalidDocument;
1679+
InvalidDocument = NULL;
1680+
if (new_msg) {
1681+
evalue = new_msg;
1682+
} else {
1683+
evalue = msg;
1684+
}
1685+
}
1686+
PyErr_NormalizeException(&etype, &evalue, &etrace);
1687+
}
1688+
cleanup:
1689+
PyErr_Restore(etype, evalue, etrace);
1690+
Py_XDECREF(msg);
1691+
Py_XDECREF(InvalidDocument);
1692+
Py_XDECREF(dict_str);
1693+
Py_XDECREF(new_msg);
1694+
}
1695+
1696+
16471697
/* returns the number of bytes written or 0 on failure */
16481698
int write_dict(PyObject* self, buffer_t buffer,
16491699
PyObject* dict, unsigned char check_keys,
@@ -1743,40 +1793,8 @@ int write_dict(PyObject* self, buffer_t buffer,
17431793
while (PyDict_Next(dict, &pos, &key, &value)) {
17441794
if (!decode_and_write_pair(self, buffer, key, value,
17451795
check_keys, options, top_level)) {
1746-
if (PyErr_Occurred()) {
1747-
PyObject *etype = NULL, *evalue = NULL, *etrace = NULL;
1748-
PyErr_Fetch(&etype, &evalue, &etrace);
1749-
PyObject *InvalidDocument = _error("InvalidDocument");
1750-
1751-
if (top_level && InvalidDocument && PyErr_GivenExceptionMatches(etype, InvalidDocument)) {
1752-
1753-
Py_DECREF(etype);
1754-
etype = InvalidDocument;
1755-
1756-
if (evalue) {
1757-
PyObject *msg = PyObject_Str(evalue);
1758-
Py_DECREF(evalue);
1759-
1760-
if (msg) {
1761-
// Prepend doc to the existing message
1762-
PyObject *dict_str = PyObject_Str(dict);
1763-
PyObject *new_msg = PyUnicode_FromFormat("Invalid document %s | %s", PyUnicode_AsUTF8(dict_str), PyUnicode_AsUTF8(msg));
1764-
Py_DECREF(dict_str);
1765-
1766-
if (new_msg) {
1767-
evalue = new_msg;
1768-
}
1769-
else {
1770-
evalue = msg;
1771-
}
1772-
}
1773-
}
1774-
PyErr_NormalizeException(&etype, &evalue, &etrace);
1775-
}
1776-
else {
1777-
Py_DECREF(InvalidDocument);
1778-
}
1779-
PyErr_Restore(etype, evalue, etrace);
1796+
if (PyErr_Occurred() && top_level) {
1797+
handle_invalid_doc_error(dict);
17801798
}
17811799
return 0;
17821800
}
@@ -1796,6 +1814,9 @@ int write_dict(PyObject* self, buffer_t buffer,
17961814
}
17971815
if (!decode_and_write_pair(self, buffer, key, value,
17981816
check_keys, options, top_level)) {
1817+
if (PyErr_Occurred() && top_level) {
1818+
handle_invalid_doc_error(dict);
1819+
}
17991820
Py_DECREF(key);
18001821
Py_DECREF(value);
18011822
Py_DECREF(iter);

test/test_bson.py

+28
Original file line numberDiff line numberDiff line change
@@ -1112,6 +1112,34 @@ def __repr__(self):
11121112
with self.assertRaisesRegex(InvalidDocument, f"Invalid document {doc}"):
11131113
encode(doc)
11141114

1115+
def test_doc_in_invalid_document_error_message_mapping(self):
    """Check that InvalidDocument names the document for non-dict mappings.

    Encodes a custom ``abc.Mapping`` whose single value is a ``Wrapper``
    instance and asserts that ``encode`` raises ``InvalidDocument`` with a
    message matching ``"Invalid document <str(doc)>"``.
    """

    class MyMapping(abc.Mapping):
        # ``keys`` must accept ``self``; without it any ``doc.keys()`` call
        # raises TypeError instead of returning the keys.
        def keys(self):
            return ["t"]

        def __getitem__(self, name):
            # NOTE(review): presumably lets the encoder's "_id" lookup
            # succeed without producing a Wrapper -- confirm against encoder.
            if name == "_id":
                return None
            return Wrapper(name)

        def __len__(self):
            return 1

        def __iter__(self):
            return iter(["t"])

    class Wrapper:
        """Plain object whose repr is the repr of the wrapped value."""

        def __init__(self, val):
            self.val = val

        def __repr__(self):
            return repr(self.val)

    self.assertEqual("1", repr(Wrapper(1)))
    doc = MyMapping()
    with self.assertRaisesRegex(InvalidDocument, f"Invalid document {doc}"):
        encode(doc)
1142+
11151143

11161144
class TestCodecOptions(unittest.TestCase):
11171145
def test_document_class(self):

0 commit comments

Comments
 (0)