From ed15a4da4fae2e9b5ae2a20d5ff39c00f1ecf6f2 Mon Sep 17 00:00:00 2001 From: mejrs <59372212+mejrs@users.noreply.github.com> Date: Mon, 6 Jan 2025 18:28:33 +0100 Subject: [PATCH] Improve diagnostic for invalid function passed to from_py_with --- pyo3-macros-backend/src/frompyobject.rs | 91 ++++++++++++--------- pyo3-macros-backend/src/params.rs | 10 ++- pyo3-macros-backend/src/pymethod.rs | 14 +++- tests/ui/invalid_argument_attributes.rs | 7 ++ tests/ui/invalid_argument_attributes.stderr | 11 +++ 5 files changed, 88 insertions(+), 45 deletions(-) diff --git a/pyo3-macros-backend/src/frompyobject.rs b/pyo3-macros-backend/src/frompyobject.rs index 565c54da1f3..497ebfcd015 100644 --- a/pyo3-macros-backend/src/frompyobject.rs +++ b/pyo3-macros-backend/src/frompyobject.rs @@ -1,7 +1,7 @@ use crate::attributes::{self, get_pyo3_options, CrateAttribute, FromPyWithAttribute}; use crate::utils::Ctx; use proc_macro2::TokenStream; -use quote::{format_ident, quote}; +use quote::{format_ident, quote, quote_spanned}; use syn::{ ext::IdentExt, parenthesized, @@ -264,31 +264,40 @@ impl<'a> Container<'a> { let struct_name = self.name(); if let Some(ident) = field_ident { let field_name = ident.to_string(); - match from_py_with { - None => quote! { + if let Some(FromPyWithAttribute { + kw, + value: expr_path, + }) = from_py_with + { + let extractor = quote_spanned! { kw.span => + { let from_py_with: fn(_) -> _ = #expr_path; from_py_with } + }; + quote! { Ok(#self_ty { - #ident: #pyo3_path::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)? + #ident: #pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, obj, #struct_name, #field_name)? }) - }, - Some(FromPyWithAttribute { - value: expr_path, .. - }) => quote! { + } + } else { + quote! { Ok(#self_ty { - #ident: #pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, #field_name)? + #ident: #pyo3_path::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)? }) - }, + } + } + } else if let Some(FromPyWithAttribute { + kw, + value: expr_path, + }) = from_py_with + { + let extractor = quote_spanned! { kw.span => + { let from_py_with: fn(_) -> _ = #expr_path; from_py_with } + }; + quote! { + #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, obj, #struct_name, 0).map(#self_ty) } } else { - match from_py_with { - None => quote! { - #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(obj, #struct_name, 0).map(#self_ty) - }, - - Some(FromPyWithAttribute { - value: expr_path, .. - }) => quote! { - #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, 0).map(#self_ty) - }, + quote! { + #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(obj, #struct_name, 0).map(#self_ty) } } } @@ -301,16 +310,20 @@ impl<'a> Container<'a> { .map(|i| format_ident!("arg{}", i)) .collect(); let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| { - match &field.from_py_with { - None => quote!( + if let Some(FromPyWithAttribute { + kw, + value: expr_path, .. + }) = &field.from_py_with { + let extractor = quote_spanned! { kw.span => + { let from_py_with: fn(_) -> _ = #expr_path; from_py_with } + }; + quote! { + #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, &#ident, #struct_name, #index)? + } + } else { + quote!{ #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)? - ), - Some(FromPyWithAttribute { - value: expr_path, .. - }) => quote! ( - #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, &#ident, #struct_name, #index)? - ), - } + }} }); quote!( @@ -346,15 +359,17 @@ impl<'a> Container<'a> { quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #field_name))) } }; - let extractor = match &field.from_py_with { - None => { - quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&#getter?, #struct_name, #field_name)?) - } - Some(FromPyWithAttribute { - value: expr_path, .. - }) => { - quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &#getter?, #struct_name, #field_name)?) - } + let extractor = if let Some(FromPyWithAttribute { + kw, + value: expr_path, + }) = &field.from_py_with + { + let extractor = quote_spanned! { kw.span => + { let from_py_with: fn(_) -> _ = #expr_path; from_py_with } + }; + quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, &#getter?, #struct_name, #field_name)?) + } else { + quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&#getter?, #struct_name, #field_name)?) }; fields.push(quote!(#ident: #extractor)); diff --git a/pyo3-macros-backend/src/params.rs b/pyo3-macros-backend/src/params.rs index f967149c725..9425b8d32b6 100644 --- a/pyo3-macros-backend/src/params.rs +++ b/pyo3-macros-backend/src/params.rs @@ -1,5 +1,6 @@ use crate::utils::Ctx; use crate::{ + attributes::FromPyWithAttribute, method::{FnArg, FnSpec, RegularArg}, pyfunction::FunctionSignature, quotes::some_wrap, @@ -248,13 +249,16 @@ pub(crate) fn impl_regular_arg_param( default = default.map(|tokens| some_wrap(tokens, ctx)); } - if arg.from_py_with.is_some() { + if let Some(FromPyWithAttribute { kw, .. }) = arg.from_py_with { + let extractor = quote_spanned! { kw.span => + { let from_py_with: fn(_) -> _ = #from_py_with; from_py_with } + }; if let Some(default) = default { quote_arg_span! { #pyo3_path::impl_::extract_argument::from_py_with_with_default( #arg_value, #name_str, - #from_py_with as fn(_) -> _, + #extractor, #[allow(clippy::redundant_closure)] { || #default @@ -267,7 +271,7 @@ pub(crate) fn impl_regular_arg_param( #pyo3_path::impl_::extract_argument::from_py_with( #unwrap, #name_str, - #from_py_with as fn(_) -> _, + #extractor, )? } } diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index c21f6d4556e..825a4addfd3 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -1,7 +1,7 @@ use std::borrow::Cow; use std::ffi::CString; -use crate::attributes::{NameAttribute, RenamingRule}; +use crate::attributes::{FromPyWithAttribute, NameAttribute, RenamingRule}; use crate::method::{CallingConvention, ExtractErrorMode, PyArg}; use crate::params::{impl_regular_arg_param, Holders}; use crate::utils::PythonDoc; @@ -1179,14 +1179,20 @@ fn extract_object( let Ctx { pyo3_path, .. } = ctx; let name = arg.name().unraw().to_string(); - let extract = if let Some(from_py_with) = - arg.from_py_with().map(|from_py_with| &from_py_with.value) + let extract = if let Some(FromPyWithAttribute { + kw, + value: extractor, + }) = arg.from_py_with() { + let extractor = quote_spanned! { kw.span => + { let from_py_with: fn(_) -> _ = #extractor; from_py_with } + }; + quote! { #pyo3_path::impl_::extract_argument::from_py_with( unsafe { #pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &#source_ptr).0 }, #name, - #from_py_with as fn(_) -> _, + #extractor, ) } } else { diff --git a/tests/ui/invalid_argument_attributes.rs b/tests/ui/invalid_argument_attributes.rs index 6797642d77b..819d6709ef8 100644 --- a/tests/ui/invalid_argument_attributes.rs +++ b/tests/ui/invalid_argument_attributes.rs @@ -15,4 +15,11 @@ fn from_py_with_value_not_a_string(#[pyo3(from_py_with = func)] _param: String) #[pyfunction] fn from_py_with_repeated(#[pyo3(from_py_with = "func", from_py_with = "func")] _param: String) {} +fn bytes_from_py(bytes: &Bound<'_, pyo3::types::PyBytes>) -> Vec { + bytes.as_bytes().to_vec() +} + +#[pyfunction] +fn f(#[pyo3(from_py_with = "bytes_from_py")] _bytes: Vec) {} + fn main() {} diff --git a/tests/ui/invalid_argument_attributes.stderr b/tests/ui/invalid_argument_attributes.stderr index e6c42f82a87..6679dd635f1 100644 --- a/tests/ui/invalid_argument_attributes.stderr +++ b/tests/ui/invalid_argument_attributes.stderr @@ -27,3 +27,14 @@ error: `from_py_with` may only be specified once per argument | 16 | fn from_py_with_repeated(#[pyo3(from_py_with = "func", from_py_with = "func")] _param: String) {} | ^^^^^^^^^^^^ + +error[E0308]: mismatched types + --> tests/ui/invalid_argument_attributes.rs:23:13 + | +22 | #[pyfunction] + | ------------- here the type of `from_py_with` is inferred to be `fn(&pyo3::Bound<'_, PyBytes>) -> Vec` +23 | fn f(#[pyo3(from_py_with = "bytes_from_py")] _bytes: Vec) {} + | ^^^^^^^^^^^^ expected `PyAny`, found `PyBytes` + | + = note: expected fn pointer `fn(&pyo3::Bound<'_, PyAny>) -> Result<_, PyErr>` + found fn pointer `fn(&pyo3::Bound<'_, PyBytes>) -> Vec`