diff --git a/guide/src/conversions/traits.md b/guide/src/conversions/traits.md index c4e8f14866c..1aa445cce41 100644 --- a/guide/src/conversions/traits.md +++ b/guide/src/conversions/traits.md @@ -488,6 +488,48 @@ If the input is neither a string nor an integer, the error message will be: - apply a custom function to convert the field from Python the desired Rust type. - the argument must be the name of the function as a string. - the function signature must be `fn(&Bound) -> PyResult` where `T` is the Rust type of the argument. +- `pyo3(default)`, `pyo3(default = ...)` + - if the argument is set, uses the given default value. + - in this case, the argument must be a Rust expression returning a value of the desired Rust type. + - if the argument is not set, [`Default::default`](https://doc.rust-lang.org/std/default/trait.Default.html#tymethod.default) is used. + - note that the default value is only used if the field is not set. + If the field is set and the conversion function from Python to Rust fails, an exception is raised and the default value is not used. + - this attribute is only supported on named fields. + +For example, the code below applies the given conversion function on the `"value"` dict item to compute its length or fall back to the type default value (0): + +```rust +use pyo3::prelude::*; + +#[derive(FromPyObject)] +struct RustyStruct { + #[pyo3(item("value"), default, from_py_with = "Bound::<'_, PyAny>::len")] + len: usize, + #[pyo3(item)] + other: usize, +} +# +# use pyo3::types::PyDict; +# fn main() -> PyResult<()> { +# Python::with_gil(|py| -> PyResult<()> { +# // Filled case +# let dict = PyDict::new(py); +# dict.set_item("value", (1,)).unwrap(); +# dict.set_item("other", 1).unwrap(); +# let result = dict.extract::()?; +# assert_eq!(result.len, 1); +# assert_eq!(result.other, 1); +# +# // Empty case +# let dict = PyDict::new(py); +# dict.set_item("other", 1).unwrap(); +# let result = dict.extract::()?; +# assert_eq!(result.len, 0); +# assert_eq!(result.other, 1); +# Ok(()) +# }) +# } +``` ### `IntoPyObject` The ['IntoPyObject'] trait defines the to-python conversion for a Rust type. All types in PyO3 implement this trait, diff --git a/newsfragments/4829.added.md b/newsfragments/4829.added.md new file mode 100644 index 00000000000..9400501a799 --- /dev/null +++ b/newsfragments/4829.added.md @@ -0,0 +1 @@ +`derive(FromPyObject)` allow a `default` attribute to set a default value for extracted fields of named structs. The default value is either provided explicitly or fetched via `Default::default()`. \ No newline at end of file diff --git a/pyo3-macros-backend/src/attributes.rs b/pyo3-macros-backend/src/attributes.rs index 6fe75e44302..bd5da377121 100644 --- a/pyo3-macros-backend/src/attributes.rs +++ b/pyo3-macros-backend/src/attributes.rs @@ -351,6 +351,8 @@ impl ToTokens for OptionalKeywordAttribute { pub type FromPyWithAttribute = KeywordAttribute>; +pub type DefaultAttribute = OptionalKeywordAttribute; + /// For specifying the path to the pyo3 crate. pub type CrateAttribute = KeywordAttribute>; diff --git a/pyo3-macros-backend/src/frompyobject.rs b/pyo3-macros-backend/src/frompyobject.rs index 565c54da1f3..b353e2dc16d 100644 --- a/pyo3-macros-backend/src/frompyobject.rs +++ b/pyo3-macros-backend/src/frompyobject.rs @@ -1,7 +1,9 @@ -use crate::attributes::{self, get_pyo3_options, CrateAttribute, FromPyWithAttribute}; +use crate::attributes::{ + self, get_pyo3_options, CrateAttribute, DefaultAttribute, FromPyWithAttribute, +}; use crate::utils::Ctx; use proc_macro2::TokenStream; -use quote::{format_ident, quote}; +use quote::{format_ident, quote, ToTokens}; use syn::{ ext::IdentExt, parenthesized, @@ -90,6 +92,7 @@ struct NamedStructField<'a> { ident: &'a syn::Ident, getter: Option, from_py_with: Option, + default: Option, } struct TupleStructField { @@ -144,6 +147,10 @@ impl<'a> Container<'a> { attrs.getter.is_none(), field.span() => "`getter` is not permitted on tuple struct elements." ); + ensure_spanned!( + attrs.default.is_none(), + field.span() => "`default` is not permitted on tuple struct elements." + ); Ok(TupleStructField { from_py_with: attrs.from_py_with, }) @@ -193,10 +200,15 @@ impl<'a> Container<'a> { ident, getter: attrs.getter, from_py_with: attrs.from_py_with, + default: attrs.default, }) }) .collect::>>()?; - if options.transparent { + if struct_fields.iter().all(|field| field.default.is_some()) { + bail_spanned!( + fields.span() => "cannot derive FromPyObject for structs and variants with only default values" + ) + } else if options.transparent { ensure_spanned!( struct_fields.len() == 1, fields.span() => "transparent structs and variants can only have 1 field" @@ -346,18 +358,33 @@ impl<'a> Container<'a> { quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #field_name))) } }; - let extractor = match &field.from_py_with { - None => { - quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&#getter?, #struct_name, #field_name)?) - } - Some(FromPyWithAttribute { - value: expr_path, .. - }) => { - quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &#getter?, #struct_name, #field_name)?) - } + let extractor = if let Some(FromPyWithAttribute { + value: expr_path, .. + }) = &field.from_py_with + { + quote!(#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &value, #struct_name, #field_name)?) + } else { + quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&value, #struct_name, #field_name)?) + }; + let extracted = if let Some(default) = &field.default { + let default_expr = if let Some(default_expr) = &default.value { + default_expr.to_token_stream() + } else { + quote!(::std::default::Default::default()) + }; + quote!(if let ::std::result::Result::Ok(value) = #getter { + #extractor + } else { + #default_expr + }) + } else { + quote!({ + let value = #getter?; + #extractor + }) }; - fields.push(quote!(#ident: #extractor)); + fields.push(quote!(#ident: #extracted)); } quote!(::std::result::Result::Ok(#self_ty{#fields})) @@ -458,6 +485,7 @@ impl ContainerOptions { struct FieldPyO3Attributes { getter: Option, from_py_with: Option, + default: Option, } #[derive(Clone, Debug)] @@ -469,6 +497,7 @@ enum FieldGetter { enum FieldPyO3Attribute { Getter(FieldGetter), FromPyWith(FromPyWithAttribute), + Default(DefaultAttribute), } impl Parse for FieldPyO3Attribute { @@ -512,6 +541,8 @@ impl Parse for FieldPyO3Attribute { } } else if lookahead.peek(attributes::kw::from_py_with) { input.parse().map(FieldPyO3Attribute::FromPyWith) + } else if lookahead.peek(Token![default]) { + input.parse().map(FieldPyO3Attribute::Default) } else { Err(lookahead.error()) } @@ -523,6 +554,7 @@ impl FieldPyO3Attributes { fn from_attrs(attrs: &[Attribute]) -> Result { let mut getter = None; let mut from_py_with = None; + let mut default = None; for attr in attrs { if let Some(pyo3_attrs) = get_pyo3_options(attr)? { @@ -542,6 +574,13 @@ impl FieldPyO3Attributes { ); from_py_with = Some(from_py_with_attr); } + FieldPyO3Attribute::Default(default_attr) => { + ensure_spanned!( + default.is_none(), + attr.span() => "`default` may only be provided once" + ); + default = Some(default_attr); + } } } } @@ -550,6 +589,7 @@ impl FieldPyO3Attributes { Ok(FieldPyO3Attributes { getter, from_py_with, + default, }) } } diff --git a/src/tests/hygiene/misc.rs b/src/tests/hygiene/misc.rs index 6e00167ddb6..a953cea4a24 100644 --- a/src/tests/hygiene/misc.rs +++ b/src/tests/hygiene/misc.rs @@ -12,6 +12,8 @@ struct Derive3 { f: i32, #[pyo3(item(42))] g: i32, + #[pyo3(default)] + h: i32, } // struct case #[derive(crate::FromPyObject)] diff --git a/tests/test_frompyobject.rs b/tests/test_frompyobject.rs index 2192caf1f7c..d72a215814c 100644 --- a/tests/test_frompyobject.rs +++ b/tests/test_frompyobject.rs @@ -686,3 +686,117 @@ fn test_with_keyword_item() { assert_eq!(result, expected); }); } + +#[derive(Debug, FromPyObject, PartialEq, Eq)] +pub struct WithDefaultItem { + #[pyo3(item, default)] + opt: Option, + #[pyo3(item)] + value: usize, +} + +#[test] +fn test_with_default_item() { + Python::with_gil(|py| { + let dict = PyDict::new(py); + dict.set_item("value", 3).unwrap(); + let result = dict.extract::().unwrap(); + let expected = WithDefaultItem { + value: 3, + opt: None, + }; + assert_eq!(result, expected); + }); +} + +#[derive(Debug, FromPyObject, PartialEq, Eq)] +pub struct WithExplicitDefaultItem { + #[pyo3(item, default = 1)] + opt: usize, + #[pyo3(item)] + value: usize, +} + +#[test] +fn test_with_explicit_default_item() { + Python::with_gil(|py| { + let dict = PyDict::new(py); + dict.set_item("value", 3).unwrap(); + let result = dict.extract::().unwrap(); + let expected = WithExplicitDefaultItem { value: 3, opt: 1 }; + assert_eq!(result, expected); + }); +} + +#[derive(Debug, FromPyObject, PartialEq, Eq)] +pub struct WithDefaultItemAndConversionFunction { + #[pyo3(item, default, from_py_with = "Bound::<'_, PyAny>::len")] + opt: usize, + #[pyo3(item)] + value: usize, +} + +#[test] +fn test_with_default_item_and_conversion_function() { + Python::with_gil(|py| { + // Filled case + let dict = PyDict::new(py); + dict.set_item("opt", (1,)).unwrap(); + dict.set_item("value", 3).unwrap(); + let result = dict + .extract::() + .unwrap(); + let expected = WithDefaultItemAndConversionFunction { opt: 1, value: 3 }; + assert_eq!(result, expected); + + // Empty case + let dict = PyDict::new(py); + dict.set_item("value", 3).unwrap(); + let result = dict + .extract::() + .unwrap(); + let expected = WithDefaultItemAndConversionFunction { opt: 0, value: 3 }; + assert_eq!(result, expected); + + // Error case + let dict = PyDict::new(py); + dict.set_item("value", 3).unwrap(); + dict.set_item("opt", 1).unwrap(); + assert!(dict + .extract::() + .is_err()); + }); +} + +#[derive(Debug, FromPyObject, PartialEq, Eq)] +pub enum WithDefaultItemEnum { + #[pyo3(from_item_all)] + Foo { + a: usize, + #[pyo3(default)] + b: usize, + }, + NeverUsedA { + a: usize, + }, +} + +#[test] +fn test_with_default_item_enum() { + Python::with_gil(|py| { + // A and B filled + let dict = PyDict::new(py); + dict.set_item("a", 1).unwrap(); + dict.set_item("b", 2).unwrap(); + let result = dict.extract::().unwrap(); + let expected = WithDefaultItemEnum::Foo { a: 1, b: 2 }; + assert_eq!(result, expected); + + // A filled + let dict = PyDict::new(py); + dict.set_item("a", 1).unwrap(); + let result = dict.extract::().unwrap(); + let expected = WithDefaultItemEnum::Foo { a: 1, b: 0 }; + assert_eq!(result, expected); + }); +} diff --git a/tests/ui/invalid_frompy_derive.rs b/tests/ui/invalid_frompy_derive.rs index f123b149fb8..d3a778e686b 100644 --- a/tests/ui/invalid_frompy_derive.rs +++ b/tests/ui/invalid_frompy_derive.rs @@ -213,4 +213,21 @@ struct FromItemAllConflictAttrWithArgs { field: String, } +#[derive(FromPyObject)] +struct StructWithOnlyDefaultValues { + #[pyo3(default)] + field: String, +} + +#[derive(FromPyObject)] +enum EnumVariantWithOnlyDefaultValues { + Foo { + #[pyo3(default)] + field: String, + }, +} + +#[derive(FromPyObject)] +struct NamedTuplesWithDefaultValues(#[pyo3(default)] String); + fn main() {} diff --git a/tests/ui/invalid_frompy_derive.stderr b/tests/ui/invalid_frompy_derive.stderr index 8ed03caafb4..5b8c1fc718b 100644 --- a/tests/ui/invalid_frompy_derive.stderr +++ b/tests/ui/invalid_frompy_derive.stderr @@ -84,7 +84,7 @@ error: transparent structs and variants can only have 1 field 70 | | }, | |_____^ -error: expected one of: `attribute`, `item`, `from_py_with` +error: expected one of: `attribute`, `item`, `from_py_with`, `default` --> tests/ui/invalid_frompy_derive.rs:76:12 | 76 | #[pyo3(attr)] @@ -223,3 +223,29 @@ error: The struct is already annotated with `from_item_all`, `attribute` is not | 210 | #[pyo3(from_item_all)] | ^^^^^^^^^^^^^ + +error: cannot derive FromPyObject for structs and variants with only default values + --> tests/ui/invalid_frompy_derive.rs:217:36 + | +217 | struct StructWithOnlyDefaultValues { + | ____________________________________^ +218 | | #[pyo3(default)] +219 | | field: String, +220 | | } + | |_^ + +error: cannot derive FromPyObject for structs and variants with only default values + --> tests/ui/invalid_frompy_derive.rs:224:9 + | +224 | Foo { + | _________^ +225 | | #[pyo3(default)] +226 | | field: String, +227 | | }, + | |_____^ + +error: `default` is not permitted on tuple struct elements. + --> tests/ui/invalid_frompy_derive.rs:231:37 + | +231 | struct NamedTuplesWithDefaultValues(#[pyo3(default)] String); + | ^