Skip to content

Commit

Permalink
derive(FromPyObject): adds default option (#4829)
Browse files Browse the repository at this point in the history
* derive(FromPyObject): adds default option

Takes an optional expression to set a custom value that is not the one from the Default trait

* Documentation, testing and hygiene

* Support enum variant named fields and cover failures
  • Loading branch information
Tpt authored Jan 10, 2025
1 parent 1840bc5 commit 21132a8
Show file tree
Hide file tree
Showing 8 changed files with 258 additions and 14 deletions.
42 changes: 42 additions & 0 deletions guide/src/conversions/traits.md
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,48 @@ If the input is neither a string nor an integer, the error message will be:
- apply a custom function to convert the field from Python the desired Rust type.
- the argument must be the name of the function as a string.
- the function signature must be `fn(&Bound<PyAny>) -> PyResult<T>` where `T` is the Rust type of the argument.
- `pyo3(default)`, `pyo3(default = ...)`
- if the argument is set, uses the given default value.
- in this case, the argument must be a Rust expression returning a value of the desired Rust type.
- if the argument is not set, [`Default::default`](https://doc.rust-lang.org/std/default/trait.Default.html#tymethod.default) is used.
- note that the default value is only used if the field is not set.
If the field is set and the conversion function from Python to Rust fails, an exception is raised and the default value is not used.
- this attribute is only supported on named fields.

For example, the code below applies the given conversion function on the `"value"` dict item to compute its length or fall back to the type default value (0):

```rust
use pyo3::prelude::*;

#[derive(FromPyObject)]
struct RustyStruct {
#[pyo3(item("value"), default, from_py_with = "Bound::<'_, PyAny>::len")]
len: usize,
#[pyo3(item)]
other: usize,
}
#
# use pyo3::types::PyDict;
# fn main() -> PyResult<()> {
# Python::with_gil(|py| -> PyResult<()> {
# // Filled case
# let dict = PyDict::new(py);
# dict.set_item("value", (1,)).unwrap();
# dict.set_item("other", 1).unwrap();
# let result = dict.extract::<RustyStruct>()?;
# assert_eq!(result.len, 1);
# assert_eq!(result.other, 1);
#
# // Empty case
# let dict = PyDict::new(py);
# dict.set_item("other", 1).unwrap();
# let result = dict.extract::<RustyStruct>()?;
# assert_eq!(result.len, 0);
# assert_eq!(result.other, 1);
# Ok(())
# })
# }
```

### `IntoPyObject`
The ['IntoPyObject'] trait defines the to-python conversion for a Rust type. All types in PyO3 implement this trait,
Expand Down
1 change: 1 addition & 0 deletions newsfragments/4829.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`derive(FromPyObject)` allow a `default` attribute to set a default value for extracted fields of named structs. The default value is either provided explicitly or fetched via `Default::default()`.
2 changes: 2 additions & 0 deletions pyo3-macros-backend/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,8 @@ impl<K: ToTokens, V: ToTokens> ToTokens for OptionalKeywordAttribute<K, V> {

pub type FromPyWithAttribute = KeywordAttribute<kw::from_py_with, LitStrValue<ExprPath>>;

pub type DefaultAttribute = OptionalKeywordAttribute<Token![default], Expr>;

/// For specifying the path to the pyo3 crate.
pub type CrateAttribute = KeywordAttribute<Token![crate], LitStrValue<Path>>;

Expand Down
66 changes: 53 additions & 13 deletions pyo3-macros-backend/src/frompyobject.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::attributes::{self, get_pyo3_options, CrateAttribute, FromPyWithAttribute};
use crate::attributes::{
self, get_pyo3_options, CrateAttribute, DefaultAttribute, FromPyWithAttribute,
};
use crate::utils::Ctx;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use quote::{format_ident, quote, ToTokens};
use syn::{
ext::IdentExt,
parenthesized,
Expand Down Expand Up @@ -90,6 +92,7 @@ struct NamedStructField<'a> {
ident: &'a syn::Ident,
getter: Option<FieldGetter>,
from_py_with: Option<FromPyWithAttribute>,
default: Option<DefaultAttribute>,
}

struct TupleStructField {
Expand Down Expand Up @@ -144,6 +147,10 @@ impl<'a> Container<'a> {
attrs.getter.is_none(),
field.span() => "`getter` is not permitted on tuple struct elements."
);
ensure_spanned!(
attrs.default.is_none(),
field.span() => "`default` is not permitted on tuple struct elements."
);
Ok(TupleStructField {
from_py_with: attrs.from_py_with,
})
Expand Down Expand Up @@ -193,10 +200,15 @@ impl<'a> Container<'a> {
ident,
getter: attrs.getter,
from_py_with: attrs.from_py_with,
default: attrs.default,
})
})
.collect::<Result<Vec<_>>>()?;
if options.transparent {
if struct_fields.iter().all(|field| field.default.is_some()) {
bail_spanned!(
fields.span() => "cannot derive FromPyObject for structs and variants with only default values"
)
} else if options.transparent {
ensure_spanned!(
struct_fields.len() == 1,
fields.span() => "transparent structs and variants can only have 1 field"
Expand Down Expand Up @@ -346,18 +358,33 @@ impl<'a> Container<'a> {
quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #field_name)))
}
};
let extractor = match &field.from_py_with {
None => {
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&#getter?, #struct_name, #field_name)?)
}
Some(FromPyWithAttribute {
value: expr_path, ..
}) => {
quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &#getter?, #struct_name, #field_name)?)
}
let extractor = if let Some(FromPyWithAttribute {
value: expr_path, ..
}) = &field.from_py_with
{
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &value, #struct_name, #field_name)?)
} else {
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&value, #struct_name, #field_name)?)
};
let extracted = if let Some(default) = &field.default {
let default_expr = if let Some(default_expr) = &default.value {
default_expr.to_token_stream()
} else {
quote!(::std::default::Default::default())
};
quote!(if let ::std::result::Result::Ok(value) = #getter {
#extractor
} else {
#default_expr
})
} else {
quote!({
let value = #getter?;
#extractor
})
};

fields.push(quote!(#ident: #extractor));
fields.push(quote!(#ident: #extracted));
}

quote!(::std::result::Result::Ok(#self_ty{#fields}))
Expand Down Expand Up @@ -458,6 +485,7 @@ impl ContainerOptions {
struct FieldPyO3Attributes {
getter: Option<FieldGetter>,
from_py_with: Option<FromPyWithAttribute>,
default: Option<DefaultAttribute>,
}

#[derive(Clone, Debug)]
Expand All @@ -469,6 +497,7 @@ enum FieldGetter {
enum FieldPyO3Attribute {
Getter(FieldGetter),
FromPyWith(FromPyWithAttribute),
Default(DefaultAttribute),
}

impl Parse for FieldPyO3Attribute {
Expand Down Expand Up @@ -512,6 +541,8 @@ impl Parse for FieldPyO3Attribute {
}
} else if lookahead.peek(attributes::kw::from_py_with) {
input.parse().map(FieldPyO3Attribute::FromPyWith)
} else if lookahead.peek(Token![default]) {
input.parse().map(FieldPyO3Attribute::Default)
} else {
Err(lookahead.error())
}
Expand All @@ -523,6 +554,7 @@ impl FieldPyO3Attributes {
fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
let mut getter = None;
let mut from_py_with = None;
let mut default = None;

for attr in attrs {
if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
Expand All @@ -542,6 +574,13 @@ impl FieldPyO3Attributes {
);
from_py_with = Some(from_py_with_attr);
}
FieldPyO3Attribute::Default(default_attr) => {
ensure_spanned!(
default.is_none(),
attr.span() => "`default` may only be provided once"
);
default = Some(default_attr);
}
}
}
}
Expand All @@ -550,6 +589,7 @@ impl FieldPyO3Attributes {
Ok(FieldPyO3Attributes {
getter,
from_py_with,
default,
})
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/tests/hygiene/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct Derive3 {
f: i32,
#[pyo3(item(42))]
g: i32,
#[pyo3(default)]
h: i32,
} // struct case

#[derive(crate::FromPyObject)]
Expand Down
114 changes: 114 additions & 0 deletions tests/test_frompyobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,3 +686,117 @@ fn test_with_keyword_item() {
assert_eq!(result, expected);
});
}

#[derive(Debug, FromPyObject, PartialEq, Eq)]
pub struct WithDefaultItem {
#[pyo3(item, default)]
opt: Option<usize>,
#[pyo3(item)]
value: usize,
}

#[test]
fn test_with_default_item() {
Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item("value", 3).unwrap();
let result = dict.extract::<WithDefaultItem>().unwrap();
let expected = WithDefaultItem {
value: 3,
opt: None,
};
assert_eq!(result, expected);
});
}

#[derive(Debug, FromPyObject, PartialEq, Eq)]
pub struct WithExplicitDefaultItem {
#[pyo3(item, default = 1)]
opt: usize,
#[pyo3(item)]
value: usize,
}

#[test]
fn test_with_explicit_default_item() {
Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item("value", 3).unwrap();
let result = dict.extract::<WithExplicitDefaultItem>().unwrap();
let expected = WithExplicitDefaultItem { value: 3, opt: 1 };
assert_eq!(result, expected);
});
}

#[derive(Debug, FromPyObject, PartialEq, Eq)]
pub struct WithDefaultItemAndConversionFunction {
#[pyo3(item, default, from_py_with = "Bound::<'_, PyAny>::len")]
opt: usize,
#[pyo3(item)]
value: usize,
}

#[test]
fn test_with_default_item_and_conversion_function() {
Python::with_gil(|py| {
// Filled case
let dict = PyDict::new(py);
dict.set_item("opt", (1,)).unwrap();
dict.set_item("value", 3).unwrap();
let result = dict
.extract::<WithDefaultItemAndConversionFunction>()
.unwrap();
let expected = WithDefaultItemAndConversionFunction { opt: 1, value: 3 };
assert_eq!(result, expected);

// Empty case
let dict = PyDict::new(py);
dict.set_item("value", 3).unwrap();
let result = dict
.extract::<WithDefaultItemAndConversionFunction>()
.unwrap();
let expected = WithDefaultItemAndConversionFunction { opt: 0, value: 3 };
assert_eq!(result, expected);

// Error case
let dict = PyDict::new(py);
dict.set_item("value", 3).unwrap();
dict.set_item("opt", 1).unwrap();
assert!(dict
.extract::<WithDefaultItemAndConversionFunction>()
.is_err());
});
}

#[derive(Debug, FromPyObject, PartialEq, Eq)]
pub enum WithDefaultItemEnum {
#[pyo3(from_item_all)]
Foo {
a: usize,
#[pyo3(default)]
b: usize,
},
NeverUsedA {
a: usize,
},
}

#[test]
fn test_with_default_item_enum() {
Python::with_gil(|py| {
// A and B filled
let dict = PyDict::new(py);
dict.set_item("a", 1).unwrap();
dict.set_item("b", 2).unwrap();
let result = dict.extract::<WithDefaultItemEnum>().unwrap();
let expected = WithDefaultItemEnum::Foo { a: 1, b: 2 };
assert_eq!(result, expected);

// A filled
let dict = PyDict::new(py);
dict.set_item("a", 1).unwrap();
let result = dict.extract::<WithDefaultItemEnum>().unwrap();
let expected = WithDefaultItemEnum::Foo { a: 1, b: 0 };
assert_eq!(result, expected);
});
}
17 changes: 17 additions & 0 deletions tests/ui/invalid_frompy_derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,4 +213,21 @@ struct FromItemAllConflictAttrWithArgs {
field: String,
}

#[derive(FromPyObject)]
struct StructWithOnlyDefaultValues {
#[pyo3(default)]
field: String,
}

#[derive(FromPyObject)]
enum EnumVariantWithOnlyDefaultValues {
Foo {
#[pyo3(default)]
field: String,
},
}

#[derive(FromPyObject)]
struct NamedTuplesWithDefaultValues(#[pyo3(default)] String);

fn main() {}
28 changes: 27 additions & 1 deletion tests/ui/invalid_frompy_derive.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ error: transparent structs and variants can only have 1 field
70 | | },
| |_____^

error: expected one of: `attribute`, `item`, `from_py_with`
error: expected one of: `attribute`, `item`, `from_py_with`, `default`
--> tests/ui/invalid_frompy_derive.rs:76:12
|
76 | #[pyo3(attr)]
Expand Down Expand Up @@ -223,3 +223,29 @@ error: The struct is already annotated with `from_item_all`, `attribute` is not
|
210 | #[pyo3(from_item_all)]
| ^^^^^^^^^^^^^

error: cannot derive FromPyObject for structs and variants with only default values
--> tests/ui/invalid_frompy_derive.rs:217:36
|
217 | struct StructWithOnlyDefaultValues {
| ____________________________________^
218 | | #[pyo3(default)]
219 | | field: String,
220 | | }
| |_^

error: cannot derive FromPyObject for structs and variants with only default values
--> tests/ui/invalid_frompy_derive.rs:224:9
|
224 | Foo {
| _________^
225 | | #[pyo3(default)]
226 | | field: String,
227 | | },
| |_____^

error: `default` is not permitted on tuple struct elements.
--> tests/ui/invalid_frompy_derive.rs:231:37
|
231 | struct NamedTuplesWithDefaultValues(#[pyo3(default)] String);
| ^

0 comments on commit 21132a8

Please sign in to comment.