diff --git a/src/types/dict.rs b/src/types/dict.rs index 129f32dc9e1..905d31b39e6 100644 --- a/src/types/dict.rs +++ b/src/types/dict.rs @@ -4,7 +4,7 @@ use crate::ffi_ptr_ext::FfiPtrExt; use crate::instance::{Borrowed, Bound}; use crate::py_result_ext::PyResultExt; use crate::types::{PyAny, PyAnyMethods, PyList, PyMapping}; -use crate::{ffi, BoundObject, IntoPyObject, Python}; +use crate::{ffi, BoundObject, IntoPyObject, Python, TryExtend}; /// Represents a Python `dict`. /// @@ -408,6 +408,98 @@ impl<'py> PyDictMethods<'py> for Bound<'py, PyDict> { } } +impl<'py, I> TryExtend, Bound<'py, PyAny>)> for Bound<'_, PyDict> +where + I: IntoIterator, Bound<'py, PyAny>)>, +{ + #[cfg(not(feature = "nightly"))] + fn try_extend(&mut self, iter: I) -> PyResult<()> { + iter.into_iter() + .try_for_each(|(key, value)| self.set_item(key, value)) + } + + #[cfg(feature = "nightly")] + default fn try_extend(&mut self, iter: I) -> PyResult<()> { + iter.into_iter() + .try_for_each(|(key, value)| self.set_item(key, value)) + } +} + +impl<'py, I> TryExtend, Bound<'py, PyAny>)>> for Bound<'_, PyDict> +where + I: IntoIterator, Bound<'py, PyAny>)>>, +{ + #[cfg(not(feature = "nightly"))] + fn try_extend(&mut self, iter: I) -> PyResult<()> { + iter.into_iter().try_for_each(|item| { + let (key, value) = item?; + self.set_item(key, value) + }) + } + + #[cfg(feature = "nightly")] + default fn try_extend(&mut self, iter: I) -> PyResult<()> { + iter.into_iter().try_for_each(|item| { + let (key, value) = item?; + self.set_item(key, value) + }) + } +} + +impl<'py, I> TryExtend> for Bound<'_, PyDict> +where + I: IntoIterator>, +{ + #[cfg(not(feature = "nightly"))] + fn try_extend(&mut self, iter: I) -> PyResult<()> { + iter.into_iter().try_for_each(|item| { + let (key, value): (Bound<'py, PyAny>, Bound<'py, PyAny>) = item.extract()?; + self.set_item(key, value) + }) + } + + #[cfg(feature = "nightly")] + default fn try_extend(&mut self, iter: I) -> PyResult<()> { + iter.into_iter().try_for_each(|item| { + let (key, value): (Bound<'py, PyAny>, Bound<'py, PyAny>) = item.extract()?; + self.set_item(key, value) + }) + } +} + +#[cfg(feature = "nightly")] +impl<'py> TryExtend, (Bound<'py, PyAny>, Bound<'py, PyAny>)> + for Bound<'_, PyDict> +{ + #[cfg(feature = "nightly")] + fn try_extend(&mut self, iter: Bound<'py, PyDict>) -> PyResult<()> { + err::error_on_minusone(iter.py(), unsafe { + ffi::PyDict_Merge(self.as_ptr(), iter.as_ptr(), 1) + }) + } +} + +macro_rules! impl_try_extend_specialization( + ($i:ty, $t:ty) => { + #[cfg(feature = "nightly")] + impl<'py> TryExtend<$i, $t> for Bound<'_, PyDict> { + fn try_extend(&mut self, iter: $i) -> PyResult<()> { + err::error_on_minusone(iter.py(), unsafe { + ffi::PyDict_MergeFromSeq2(self.as_ptr(), iter.as_ptr(), 1) + }) + } + } + } +); + +impl_try_extend_specialization!( + Bound<'py, crate::types::PyIterator>, + PyResult> +); +impl_try_extend_specialization!(Bound<'py, crate::types::PyList>, Bound<'py, PyAny>); +impl_try_extend_specialization!(Bound<'py, crate::types::PySet>, Bound<'py, PyAny>); +impl_try_extend_specialization!(Bound<'py, crate::types::PyTuple>, Bound<'py, PyAny>); + impl<'a, 'py> Borrowed<'a, 'py, PyDict> { /// Iterates over the contents of this dictionary without incrementing reference counts. /// @@ -1652,4 +1744,42 @@ mod tests { .is_err()); }); } + + #[test] + fn test_dict_extend() { + Python::with_gil::<_, PyResult<()>>(|py| { + let mut dict = PyDict::new(py); + + let vec = vec![( + Bound::into_any(1.into_pyobject(py)?), + Bound::into_any(1.into_pyobject(py)?), + )]; + dict.try_extend(vec)?; + + let slice = [( + Bound::into_any(2.into_pyobject(py)?), + Bound::into_any(2.into_pyobject(py)?), + )]; + dict.try_extend(slice)?; + + let other_dict = [(3, 3)].into_py_dict(py)?; + dict.try_extend(other_dict)?; + + let list = PyList::new(py, [(4, 4)])?; + dict.try_extend(list)?; + + let tuple = PyTuple::new(py, [(5, 5)])?; + dict.try_extend(tuple)?; + + assert_eq!(dict.len(), 5); + assert!(dict.iter().all(|(k, v)| { + let k = k.extract::().unwrap(); + let v = v.extract::().unwrap(); + k == v + })); + + Ok(()) + }) + .unwrap(); + } }