Skip to content

Commit

Permalink
Refactor parsing of packb and unpackb arguments
Browse files Browse the repository at this point in the history
Signed-off-by: Emanuele Giaquinta <emanuele.giaquinta@gmail.com>
  • Loading branch information
exg committed Dec 15, 2024
1 parent f2659fe commit 7933610
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 105 deletions.
123 changes: 56 additions & 67 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ mod typeref;
mod unicode;

use pyo3::ffi::*;
use std::borrow::Cow;
use std::os::raw::c_char;
use std::os::raw::c_int;
use std::os::raw::c_long;
Expand Down Expand Up @@ -143,8 +142,7 @@ pub unsafe extern "C" fn ormsgpack_exec(mptr: *mut PyObject) -> c_int {

#[cold]
#[inline(never)]
fn raise_unpackb_exception(err: deserialize::DeserializeError) -> *mut PyObject {
let msg = err.message;
fn raise_unpackb_exception(msg: &str) -> *mut PyObject {
unsafe {
let err_msg =
PyUnicode_FromStringAndSize(msg.as_ptr() as *const c_char, msg.len() as isize);
Expand All @@ -158,7 +156,7 @@ fn raise_unpackb_exception(err: deserialize::DeserializeError) -> *mut PyObject

#[cold]
#[inline(never)]
fn raise_packb_exception(msg: Cow<str>) -> *mut PyObject {
fn raise_packb_exception(msg: &str) -> *mut PyObject {
unsafe {
let err_msg =
PyUnicode_FromStringAndSize(msg.as_ptr() as *const c_char, msg.len() as isize);
Expand All @@ -168,6 +166,21 @@ fn raise_packb_exception(msg: Cow<str>) -> *mut PyObject {
std::ptr::null_mut()
}

/// Parses the `option` argument of `packb()`/`unpackb()` into its bit flags.
///
/// Returns:
/// * `Ok(bits)` when the type of `opts` is exactly `typeref::INT_TYPE`, its
///   value fits in an `i32`, and it has no bits set outside `mask`;
/// * `Ok(0)` when `opts` is the `typeref::NONE` object;
/// * `Err(())` in every other case. The type comparison is by pointer
///   equality, so objects whose type merely derives from `int` are rejected.
///
/// # Safety
///
/// `opts` must be a valid, non-null pointer to a live Python object, and the
/// caller must be in a context where calling the CPython API is permitted.
unsafe fn parse_option_arg(opts: *mut PyObject, mask: i32) -> Result<i32, ()> {
    if Py_TYPE(opts) == typeref::INT_TYPE {
        // `PyLong_AsLong` yields a C long, which is 64 bits wide on LP64
        // platforms. A plain `as i32` cast would drop the upper bits, so a value
        // such as `1 << 32` would truncate to 0 and be accepted. Use a checked
        // conversion so anything that does not fit in an `i32` is rejected.
        //
        // When the Python int does not fit in a C long, `PyLong_AsLong` returns
        // -1; that value fails the mask test below for any mask that is not all
        // ones, so it is reported as invalid as well.
        match i32::try_from(PyLong_AsLong(opts)) {
            Ok(val) if val & !mask == 0 => Ok(val),
            _ => Err(()),
        }
    } else if opts == typeref::NONE {
        Ok(0)
    } else {
        Err(())
    }
}

#[no_mangle]
pub unsafe extern "C" fn unpackb(
_self: *mut PyObject,
Expand All @@ -181,50 +194,37 @@ pub unsafe extern "C" fn unpackb(
let num_args = PyVectorcall_NARGS(nargs as usize);
if unlikely!(num_args != 1) {
let msg = if num_args > 1 {
Cow::Borrowed("unpackb() accepts only 1 positional argument")
"unpackb() accepts only 1 positional argument"
} else {
Cow::Borrowed("unpackb() missing 1 required positional argument: 'obj'")
"unpackb() missing 1 required positional argument: 'obj'"
};
return raise_unpackb_exception(deserialize::DeserializeError::new(msg));
return raise_unpackb_exception(msg);
}
if !kwnames.is_null() {
let tuple_size = PyTuple_GET_SIZE(kwnames);
if tuple_size > 0 {
for i in 0..=tuple_size - 1 {
let arg = PyTuple_GET_ITEM(kwnames, i as Py_ssize_t);
if arg == typeref::EXT_HOOK {
ext_hook = Some(NonNull::new_unchecked(*args.offset(num_args + i)));
} else if arg == typeref::OPTION {
optsptr = Some(NonNull::new_unchecked(*args.offset(num_args + i)));
} else {
return raise_unpackb_exception(deserialize::DeserializeError::new(
Cow::Borrowed("unpackb() got an unexpected keyword argument"),
));
}
for i in 0..tuple_size {
let arg = PyTuple_GET_ITEM(kwnames, i as Py_ssize_t);
if arg == typeref::EXT_HOOK {
ext_hook = Some(NonNull::new_unchecked(*args.offset(num_args + i)));
} else if arg == typeref::OPTION {
optsptr = Some(NonNull::new_unchecked(*args.offset(num_args + i)));
} else {
return raise_unpackb_exception("unpackb() got an unexpected keyword argument");
}
}
}

let mut optsbits: i32 = 0;
if let Some(opts) = optsptr {
let ob_type = (*opts.as_ptr()).ob_type;
if ob_type == typeref::INT_TYPE {
optsbits = PyLong_AsLong(optsptr.unwrap().as_ptr()) as i32;
if !(0..=opt::MAX_UNPACKB_OPT).contains(&optsbits) {
return raise_unpackb_exception(deserialize::DeserializeError::new(Cow::Borrowed(
"Invalid opts",
)));
}
} else if ob_type != typeref::NONE_TYPE {
return raise_unpackb_exception(deserialize::DeserializeError::new(Cow::Borrowed(
"Invalid opts",
)));
match parse_option_arg(opts.as_ptr(), opt::UNPACKB_OPT_MASK) {
Ok(val) => optsbits = val,
Err(()) => return raise_unpackb_exception("Invalid opts"),
}
}

match crate::deserialize::deserialize(*args, ext_hook, optsbits as opt::Opt) {
Ok(val) => val.as_ptr(),
Err(err) => raise_unpackb_exception(err),
Err(err) => raise_unpackb_exception(&err.message),
}
}

Expand All @@ -240,59 +240,48 @@ pub unsafe extern "C" fn packb(

let num_args = PyVectorcall_NARGS(nargs as usize);
if unlikely!(num_args == 0) {
return raise_packb_exception(Cow::Borrowed(
"packb() missing 1 required positional argument: 'obj'",
));
return raise_packb_exception("packb() missing 1 required positional argument: 'obj'");
}
if num_args & 2 == 2 {
if num_args >= 2 {
default = Some(NonNull::new_unchecked(*args.offset(1)));
}
if num_args & 3 == 3 {
if num_args >= 3 {
optsptr = Some(NonNull::new_unchecked(*args.offset(2)));
}
if !kwnames.is_null() {
let tuple_size = PyTuple_GET_SIZE(kwnames);
if tuple_size > 0 {
for i in 0..=tuple_size - 1 {
let arg = PyTuple_GET_ITEM(kwnames, i as Py_ssize_t);
if arg == typeref::DEFAULT {
if unlikely!(num_args & 2 == 2) {
return raise_packb_exception(Cow::Borrowed(
"packb() got multiple values for argument: 'default'",
));
}
default = Some(NonNull::new_unchecked(*args.offset(num_args + i)));
} else if arg == typeref::OPTION {
if unlikely!(num_args & 3 == 3) {
return raise_packb_exception(Cow::Borrowed(
"packb() got multiple values for argument: 'option'",
));
}
optsptr = Some(NonNull::new_unchecked(*args.offset(num_args + i)));
} else {
return raise_packb_exception(Cow::Borrowed(
"packb() got an unexpected keyword argument",
));
for i in 0..tuple_size {
let arg = PyTuple_GET_ITEM(kwnames, i as Py_ssize_t);
if arg == typeref::DEFAULT {
if unlikely!(default.is_some()) {
return raise_packb_exception(
"packb() got multiple values for argument: 'default'",
);
}
default = Some(NonNull::new_unchecked(*args.offset(num_args + i)));
} else if arg == typeref::OPTION {
if unlikely!(optsptr.is_some()) {
return raise_packb_exception(
"packb() got multiple values for argument: 'option'",
);
}
optsptr = Some(NonNull::new_unchecked(*args.offset(num_args + i)));
} else {
return raise_packb_exception("packb() got an unexpected keyword argument");
}
}
}

let mut optsbits: i32 = 0;
if let Some(opts) = optsptr {
let ob_type = (*opts.as_ptr()).ob_type;
if ob_type == typeref::INT_TYPE {
optsbits = PyLong_AsLong(optsptr.unwrap().as_ptr()) as i32;
if !(0..=opt::MAX_PACKB_OPT).contains(&optsbits) {
return raise_packb_exception(Cow::Borrowed("Invalid opts"));
}
} else if ob_type != typeref::NONE_TYPE {
return raise_packb_exception(Cow::Borrowed("Invalid opts"));
match parse_option_arg(opts.as_ptr(), opt::PACKB_OPT_MASK) {
Ok(val) => optsbits = val,
Err(()) => return raise_packb_exception("Invalid opts"),
}
}

match crate::serialize::serialize(*args, default, optsbits as opt::Opt) {
Ok(val) => val.as_ptr(),
Err(err) => raise_packb_exception(Cow::Borrowed(&err)),
Err(err) => raise_packb_exception(&err),
}
}
4 changes: 2 additions & 2 deletions src/opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub const NOT_PASSTHROUGH: Opt = !(PASSTHROUGH_BIG_INT
| PASSTHROUGH_SUBCLASS
| PASSTHROUGH_TUPLE);

pub const MAX_PACKB_OPT: i32 = (NAIVE_UTC
pub const PACKB_OPT_MASK: i32 = (NAIVE_UTC
| NON_STR_KEYS
| OMIT_MICROSECONDS
| PASSTHROUGH_BIG_INT
Expand All @@ -34,4 +34,4 @@ pub const MAX_PACKB_OPT: i32 = (NAIVE_UTC
| SORT_KEYS
| UTC_Z) as i32;

pub const MAX_UNPACKB_OPT: i32 = NON_STR_KEYS as i32;
pub const UNPACKB_OPT_MASK: i32 = NON_STR_KEYS as i32;
70 changes: 34 additions & 36 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,44 +82,42 @@ def test_valueerror() -> None:
ormsgpack.unpackb(b"\x91")


def test_option_not_int() -> None:
"""
packb/unpackb() option not int or None
"""
with pytest.raises(ormsgpack.MsgpackEncodeError):
ormsgpack.packb(True, option=True)
with pytest.raises(ormsgpack.MsgpackDecodeError):
ormsgpack.unpackb(b"\x00", option=True)


def test_option_invalid_int() -> None:
"""
packb/unpackb() option invalid 64-bit number
"""
with pytest.raises(ormsgpack.MsgpackEncodeError):
ormsgpack.packb(True, option=9223372036854775809)
with pytest.raises(ormsgpack.MsgpackDecodeError):
ormsgpack.unpackb(b"\x00", option=9223372036854775809)


def test_option_range_low() -> None:
"""
packb/unpackb() option out of range low
"""
with pytest.raises(ormsgpack.MsgpackEncodeError):
ormsgpack.packb(True, option=-1)
with pytest.raises(ormsgpack.MsgpackDecodeError):
ormsgpack.unpackb(b"\x00", option=-1)


def test_option_range_high() -> None:
"""
packb/unpackb() option out of range high
"""
@pytest.mark.parametrize(
"option",
(
1 << 12,
True,
-1,
9223372036854775809,
),
)
def test_packb_invalid_option(option: int) -> None:
with pytest.raises(ormsgpack.MsgpackEncodeError):
ormsgpack.packb(True, option=1 << 14)
ormsgpack.packb(True, option=option)


@pytest.mark.parametrize(
"option",
(
ormsgpack.OPT_NAIVE_UTC,
ormsgpack.OPT_OMIT_MICROSECONDS,
ormsgpack.OPT_PASSTHROUGH_BIG_INT,
ormsgpack.OPT_PASSTHROUGH_DATACLASS,
ormsgpack.OPT_PASSTHROUGH_DATETIME,
ormsgpack.OPT_PASSTHROUGH_SUBCLASS,
ormsgpack.OPT_PASSTHROUGH_TUPLE,
ormsgpack.OPT_SERIALIZE_NUMPY,
ormsgpack.OPT_SERIALIZE_PYDANTIC,
ormsgpack.OPT_SORT_KEYS,
ormsgpack.OPT_UTC_Z,
True,
-1,
9223372036854775809,
),
)
def test_unpackb_invalid_option(option: int) -> None:
with pytest.raises(ormsgpack.MsgpackDecodeError):
ormsgpack.unpackb(b"\x00", option=1 << 14)
ormsgpack.unpackb(b"\x00", option=option)


def test_opts_multiple() -> None:
Expand Down

0 comments on commit 7933610

Please sign in to comment.