diff --git a/.gitignore b/.gitignore index baa1df9..183ce61 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ target Cargo.lock tests/*.npy +benches/*.npy diff --git a/.travis.yml b/.travis.yml index 42c29fa..5ee3e3b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,3 +7,7 @@ matrix: allow_failures: - rust: nightly +script: + - cargo build --verbose + - cargo build --verbose --features derive --examples + - cargo test --verbose --features derive diff --git a/Cargo.toml b/Cargo.toml index d7f0d30..d5b1a1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,51 @@ members = [ "npy-derive" ] byteorder = "1" nom = "3" +[dependencies.npy-derive] +path = "npy-derive" +version = "0.4" +optional = true +default-features = false + [dev-dependencies] memmap = "0.6" -npy-derive = { path = "npy-derive", version = "0.4" } + +[features] +default = [] + +# Reexports the derive macros so that you can use them qualified under `npy::`: +# +# #[derive(npy::Serialize, npy::Deserialize, npy::AutoSerialize)] +# struct Struct { ... } +# +# This is a nicer alternative to `#[macro_use] extern crate npy_derive`, which +# directly imports things like `#[derive(Serialize)]` that may conflict with +# other crates (e.g. `serde`). 
+derive = ["npy-derive"] + +[[example]] +name = "plain" + +[[example]] +name = "large" +required-features = ["derive"] + +[[example]] +name = "simple" +required-features = ["derive"] + +[[example]] +name = "roundtrip" +required-features = ["derive"] + +[[test]] +name = "derive_hygiene" +required-features = ["derive"] + +[[test]] +name = "roundtrip" +required-features = ["derive"] + +[[test]] +name = "serialize_array" +required-features = ["derive"] diff --git a/benches/bench.rs b/benches/bench.rs index 35c40e2..001641a 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -1,44 +1,107 @@ #![feature(test)] -#[macro_use] -extern crate npy_derive; extern crate npy; extern crate test; -use npy::Serializable; +use npy::{Serialize, Deserialize, AutoSerialize, TypeWrite, TypeRead}; use test::Bencher; use test::black_box as bb; -#[derive(Serializable, Debug, PartialEq)] -struct Array { - a: i32, - b: f32, +const NITER: usize = 100_000; + +macro_rules! gen_benches { + ($T:ty, $new:expr) => { + #[inline(never)] + fn test_data() -> Vec { + let mut raw = Vec::new(); + let writer = <$T>::writer(&<$T>::default_dtype()).unwrap(); + for i in 0usize..NITER { + writer.write_one(&mut raw, &$new(i)).unwrap(); + } + raw + } + + #[bench] + fn read(b: &mut Bencher) { + let raw = test_data(); + b.iter(|| { + let dtype = <$T>::default_dtype(); + let reader = <$T>::reader(&dtype).unwrap(); + + let mut remainder = &raw[..]; + for _ in 0usize..NITER { + let (value, new_remainder) = reader.read_one(remainder); + bb(value); + remainder = new_remainder; + } + assert_eq!(remainder.len(), 0); + }); + } + + #[bench] + fn read_to_vec(b: &mut Bencher) { + // FIXME: Write to a Cursor> once #16 is merged + let path = concat!("benches/bench_", stringify!($T), ".npy"); + + npy::to_file(path, (0usize..NITER).map($new)).unwrap(); + let bytes = std::fs::read(path).unwrap(); + + b.iter(|| { + bb(npy::NpyData::<$T>::from_bytes(&bytes).unwrap().to_vec()) + }); + } + + #[bench] + fn write(b: &mut Bencher) { + 
b.iter(|| { + bb(test_data()) + }); + } + }; } -const NITER: usize = 100_000; +#[cfg(feature = "derive")] +mod simple { + use super::*; -fn test_data() -> Vec { - let mut raw = Vec::new(); - for i in 0..NITER { - let arr = Array { a: i as i32, b: i as f32 }; - arr.write(&mut raw).unwrap(); + #[derive(npy::Serialize, npy::Deserialize, npy::AutoSerialize)] + #[derive(Debug, PartialEq)] + struct Simple { + a: i32, + b: f32, } - raw + + gen_benches!(Simple, |i| Simple { a: i as i32, b: i as f32 }); } -#[bench] -fn read(b: &mut Bencher) { - let raw = test_data(); - b.iter(|| { - for i in 0..NITER { - bb(Array::read(&raw[i*8..])); - } - }); +#[cfg(feature = "derive")] +mod one_field { + use super::*; + + #[derive(npy::Serialize, npy::Deserialize, npy::AutoSerialize)] + #[derive(Debug, PartialEq)] + struct OneField { + a: i32, + } + + gen_benches!(OneField, |i| OneField { a: i as i32 }); } -#[bench] -fn write(b: &mut Bencher) { - b.iter(|| { - bb(test_data()) - }); +#[cfg(feature = "derive")] +mod array { + use super::*; + + #[derive(npy::Serialize, npy::Deserialize, npy::AutoSerialize)] + #[derive(Debug, PartialEq)] + struct Array { + a: [f32; 8], + } + + gen_benches!(Array, |i| Array { a: [i as f32; 8] }); +} + +mod plain_f32 { + use super::*; + + gen_benches!(f32, |i| i as f32); } diff --git a/examples/large.rs b/examples/large.rs index c0767c3..3652e49 100644 --- a/examples/large.rs +++ b/examples/large.rs @@ -1,14 +1,11 @@ extern crate memmap; -#[macro_use] -extern crate npy_derive; extern crate npy; use std::fs::File; use memmap::MmapOptions; - -#[derive(Serializable, Debug, Default)] +#[derive(npy::Serialize, npy::Deserialize, Debug, Default)] struct Array { a: i32, b: f32, diff --git a/examples/roundtrip.rs b/examples/roundtrip.rs index ef75eaf..b188611 100644 --- a/examples/roundtrip.rs +++ b/examples/roundtrip.rs @@ -1,11 +1,10 @@ - #[macro_use] extern crate npy_derive; extern crate npy; use std::io::Read; -#[derive(Serializable, Debug, PartialEq, Clone)] 
+#[derive(Serialize, Deserialize, AutoSerialize, Debug, PartialEq, Clone)] struct Array { a: i32, b: f32, diff --git a/examples/simple.rs b/examples/simple.rs index 2a7eb43..c8b76a6 100644 --- a/examples/simple.rs +++ b/examples/simple.rs @@ -1,5 +1,4 @@ -#[macro_use] extern crate npy_derive; extern crate npy; @@ -12,7 +11,7 @@ use npy::NpyData; // a = np.array([(1,2.5,4), (2,3.1,5)], dtype=[('a', 'i4'),('b', 'f4'),('c', 'i8')]) // np.save('examples/simple.npy', a) -#[derive(Serializable, Debug)] +#[derive(npy::Deserialize, Debug)] struct Array { a: i32, b: f32, diff --git a/npy-derive/Cargo.toml b/npy-derive/Cargo.toml index eced55c..e355b55 100644 --- a/npy-derive/Cargo.toml +++ b/npy-derive/Cargo.toml @@ -10,5 +10,6 @@ repository = "https://github.com/potocpav/npy-rs" proc-macro = true [dependencies] +proc-macro2 = "0.2" quote = "0.4" syn = "0.12" diff --git a/npy-derive/src/lib.rs b/npy-derive/src/lib.rs index beb8999..4333c03 100644 --- a/npy-derive/src/lib.rs +++ b/npy-derive/src/lib.rs @@ -1,102 +1,279 @@ -#![recursion_limit = "128"] +#![recursion_limit = "256"] /*! Derive `trait Serializable` for a structure. -Using this crate, it is enough to `#[derive(Serializable)]` on a struct to be able to serialize and -deserialize it. All the fields must implement [`Serializable`](../npy/trait.Serializable.html). +Using this crate, it is enough to `#[derive(npy::Serialize, npy::Deserialize)]` on a struct to be able to +serialize and deserialize it. All of the fields must implement [`Serialize`](../npy/trait.Serialize.html) +and [`Deserialize`](../npy/trait.Deserialize.html) respectively. 
*/ extern crate proc_macro; +extern crate proc_macro2; extern crate syn; #[macro_use] extern crate quote; use proc_macro::TokenStream; -use syn::Data; -use quote::{Tokens, ToTokens}; +use proc_macro2::Span; +use quote::Tokens; /// Macros 1.1-based custom derive function -#[proc_macro_derive(Serializable)] -pub fn npy_data(input: TokenStream) -> TokenStream { - // Construct a string representation of the type definition - // let s = input.to_string(); +#[proc_macro_derive(Serialize)] +pub fn npy_serialize(input: TokenStream) -> TokenStream { + // Parse the string representation + let ast = syn::parse(input).unwrap(); + + // Build the impl + let expanded = impl_npy_serialize(&ast); + // Return the generated impl + expanded.into() +} + +#[proc_macro_derive(Deserialize)] +pub fn npy_deserialize(input: TokenStream) -> TokenStream { // Parse the string representation let ast = syn::parse(input).unwrap(); // Build the impl - let expanded = impl_npy_data(&ast); + let expanded = impl_npy_deserialize(&ast); // Return the generated impl expanded.into() } -fn impl_npy_data(ast: &syn::DeriveInput) -> quote::Tokens { +#[proc_macro_derive(AutoSerialize)] +pub fn npy_auto_serialize(input: TokenStream) -> TokenStream { + // Parse the string representation + let ast = syn::parse(input).unwrap(); + + // Build the impl + let expanded = impl_npy_auto_serialize(&ast); + + // Return the generated impl + expanded.into() +} + +struct FieldData { + idents: Vec, + idents_str: Vec, + types: Vec, +} + +impl FieldData { + fn extract(ast: &syn::DeriveInput) -> Self { + let fields = match ast.data { + syn::Data::Struct(ref data) => &data.fields, + _ => panic!("npy derive macros can only be used with structs"), + }; + + let idents: Vec = fields.iter().map(|f| { + f.ident.clone().expect("Tuple structs not supported") + }).collect(); + let idents_str = idents.iter().map(|t| unraw(t)).collect::>(); + + let types: Vec = fields.iter().map(|f| { + let ty = &f.ty; + quote!( #ty ) + }).collect::>(); + + 
FieldData { idents, idents_str, types } + } +} + +fn impl_npy_serialize(ast: &syn::DeriveInput) -> Tokens { let name = &ast.ident; - let fields = match ast.data { - Data::Struct(ref data) => &data.fields, - _ => panic!("#[derive(Serializable)] can only be used with structs"), - }; - // Helper is provided for handling complex generic types correctly and effortlessly + let vis = &ast.vis; + let FieldData { ref idents, ref idents_str, ref types } = FieldData::extract(ast); + let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl(); + let field_dtypes_struct = gen_field_dtypes_struct(idents, idents_str); - let idents = fields.iter().map(|f| { - let mut t = Tokens::new(); - f.ident.clone().expect("Tuple structs not supported").to_tokens(&mut t); - t - }).collect::>(); - let types = fields.iter().map(|f| { - let mut t = Tokens::new(); - f.ty.to_tokens(&mut t); - t - }).collect::>(); - - let idents_c = idents.clone(); - let idents_str = idents.clone().into_iter().map(|t| t.to_string()).collect::>(); - let idents_str_c1 = idents_str.clone(); - let types_c1 = types.clone(); - let types_c2 = types.clone(); - let types_c3 = types.clone(); - - let nats_0 = 0usize..; - let nats_1 = 0usize..; - let n_fields = types.len(); + let idents_1 = idents; - quote! { - impl #impl_generics ::npy::Serializable for #name #ty_generics #where_clause { - fn dtype() -> ::npy::DType { - ::npy::DType::Record(vec![#( - ::npy::Field { - name: #idents_str_c1.to_string(), - dtype: <#types_c1 as ::npy::Serializable>::dtype() - } - ),*]) + wrap_in_const("Serialize", &name, quote! 
{ + use ::std::io; + + #vis struct GeneratedWriter #ty_generics #where_clause { + writers: FieldWriters #ty_generics, + } + + struct FieldWriters #ty_generics #where_clause { + #( #idents: <#types as _npy::Serialize>::Writer ,)* + } + + #field_dtypes_struct + + impl #impl_generics _npy::TypeWrite for GeneratedWriter #ty_generics #where_clause { + type Value = #name #ty_generics; + + #[allow(unused_mut)] + fn write_one(&self, mut w: W, value: &Self::Value) -> io::Result<()> { + #({ // braces for pre-NLL + let method = <<#types as _npy::Serialize>::Writer as _npy::TypeWrite>::write_one; + method(&self.writers.#idents, &mut w, &value.#idents_1)?; + })* + p::Ok(()) } + } + + impl #impl_generics _npy::Serialize for #name #ty_generics #where_clause { + type Writer = GeneratedWriter #ty_generics; + + fn writer(dtype: &_npy::DType) -> p::Result { + let dtypes = FieldDTypes::extract(dtype)?; + let writers = FieldWriters { + #( #idents: <#types as _npy::Serialize>::writer(&dtypes.#idents_1)? ,)* + }; - fn n_bytes() -> usize { - #( <#types_c2 as ::npy::Serializable>::n_bytes() )+* + p::Ok(GeneratedWriter { writers }) } + } + }) +} - #[allow(unused_assignments)] - fn read(buf: &[u8]) -> Self { - let mut offset = 0; - let mut offsets = [0; #n_fields + 1]; +fn impl_npy_deserialize(ast: &syn::DeriveInput) -> Tokens { + let name = &ast.ident; + let vis = &ast.vis; + let FieldData { ref idents, ref idents_str, ref types } = FieldData::extract(ast); + + let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl(); + let field_dtypes_struct = gen_field_dtypes_struct(idents, idents_str); + + let idents_1 = idents; + + wrap_in_const("Deserialize", &name, quote! 
{ + #vis struct GeneratedReader #ty_generics #where_clause { + readers: FieldReaders #ty_generics, + } + + struct FieldReaders #ty_generics #where_clause { + #( #idents: <#types as _npy::Deserialize>::Reader ,)* + } + + #field_dtypes_struct + + impl #impl_generics _npy::TypeRead for GeneratedReader #ty_generics #where_clause { + type Value = #name #ty_generics; + + #[allow(unused_mut)] + fn read_one<'a>(&self, mut remainder: &'a [u8]) -> (Self::Value, &'a [u8]) { #( - offset += <#types_c3 as ::npy::Serializable>::n_bytes(); - offsets[#nats_0 + 1] = offset; + let func = <<#types as _npy::Deserialize>::Reader as _npy::TypeRead>::read_one; + let (#idents, new_remainder) = func(&self.readers.#idents_1, remainder); + remainder = new_remainder; )* + (#name { #( #idents ),* }, remainder) + } + } + + impl #impl_generics _npy::Deserialize for #name #ty_generics #where_clause { + type Reader = GeneratedReader #ty_generics; - #name { #( - #idents: ::npy::Serializable::read(&buf[offsets[#nats_1]..]) - ),* } + fn reader(dtype: &_npy::DType) -> p::Result { + let dtypes = FieldDTypes::extract(dtype)?; + let readers = FieldReaders { + #( #idents: <#types as _npy::Deserialize>::reader(&dtypes.#idents_1)? ,)* + }; + + p::Ok(GeneratedReader { readers }) } + } + }) +} + +fn impl_npy_auto_serialize(ast: &syn::DeriveInput) -> Tokens { + let name = &ast.ident; + let FieldData { idents: _, ref idents_str, ref types } = FieldData::extract(ast); + + let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl(); - fn write(&self, writer: &mut W) -> ::std::io::Result<()> { - #( ::npy::Serializable::write(&self.#idents_c, writer)?; )* - Ok(()) + wrap_in_const("AutoSerialize", &name, quote! 
{ + impl #impl_generics _npy::AutoSerialize for #name #ty_generics #where_clause { + fn default_dtype() -> _npy::DType { + _npy::DType::Record(vec![#( + _npy::Field { + name: #idents_str.to_string(), + dtype: <#types as _npy::AutoSerialize>::default_dtype() + } + ),*]) + } + } + }) +} + +fn gen_field_dtypes_struct( + idents: &[syn::Ident], + idents_str: &[String], +) -> Tokens { + assert_eq!(idents.len(), idents_str.len()); + quote!{ + struct FieldDTypes { + #( #idents : _npy::DType ,)* + } + + impl FieldDTypes { + fn extract(dtype: &_npy::DType) -> p::Result { + let fields = match dtype { + _npy::DType::Record(fields) => fields, + _npy::DType::Plain { ty, .. } => return p::Err(_npy::DTypeError::expected_record(ty)), + }; + + let correct_names: &[&str] = &[ #(#idents_str),* ]; + + if p::Iterator::ne( + p::Iterator::map(fields.iter(), |f| &f.name[..]), + p::Iterator::cloned(correct_names.iter()), + ) { + let actual_names = p::Iterator::map(fields.iter(), |f| &f.name[..]); + return p::Err(_npy::DTypeError::wrong_fields(actual_names, correct_names)); + } + + #[allow(unused_mut)] + let mut fields = p::IntoIterator::into_iter(fields); + p::Result::Ok(FieldDTypes { + #( #idents : { + let field = p::Iterator::next(&mut fields).unwrap(); + p::Clone::clone(&field.dtype) + },)* + }) } } } } + +// from the wonderful folks working on serde +fn wrap_in_const( + trait_: &str, + ty: &syn::Ident, + code: Tokens, +) -> Tokens { + let dummy_const = syn::Ident::new( + &format!("__IMPL_npy_{}_FOR_{}", trait_, unraw(ty)), + Span::call_site(), + ); + + quote! 
{ + #[allow(non_upper_case_globals, unused_attributes, unused_qualifications)] + const #dummy_const: () = { + #[allow(unknown_lints)] + #[cfg_attr(feature = "cargo-clippy", allow(useless_attribute))] + #[allow(rust_2018_idioms)] + extern crate npy as _npy; + + // if our generated code directly imports any traits, then the #[no_implicit_prelude] + // test won't catch accidental use of method syntax on trait methods (which can fail + // due to ambiguity with similarly-named methods on other traits). So if we want to + // abbreviate paths, we need to do this instead: + use ::std::prelude::v1 as p; + + #code + }; + } +} + +fn unraw(ident: &syn::Ident) -> String { + ident.to_string().trim_start_matches("r#").to_owned() +} diff --git a/src/header.rs b/src/header.rs index 580221c..dc881c8 100644 --- a/src/header.rs +++ b/src/header.rs @@ -2,16 +2,19 @@ use nom::IResult; use std::collections::HashMap; use std::io::Result; +use type_str::TypeStr; /// Representation of a Numpy type -#[derive(PartialEq, Eq, Debug)] +#[derive(PartialEq, Eq, Debug, Clone)] pub enum DType { /// A simple array with only a single field Plain { - /// Numpy type string. First character is `'>'` for big endian, `'<'` for little endian. + /// Numpy type string. First character is `'>'` for big endian, `'<'` for little endian, + /// or can be `'|'` if it doesn't matter. /// - /// Examples: `>i4`, `f8`. The number corresponds to the number of bytes. - ty: String, + /// Examples: `>i4`, `f8`, `|S7`. The number usually corresponds to the number of + /// bytes (with the single exception of unicode strings `|U3`). + ty: TypeStr, /// Shape of a type. 
/// @@ -24,7 +27,7 @@ pub enum DType { Record(Vec) } -#[derive(PartialEq, Eq, Debug)] +#[derive(PartialEq, Eq, Debug, Clone)] /// A field of a record dtype pub struct Field { /// The name of the field @@ -64,11 +67,53 @@ impl DType { pub fn from_descr(descr: Value) -> Result { use DType::*; match descr { - Value::String(string) => Ok(Plain { ty: string, shape: vec![] }), + Value::String(ref string) => Ok(Self::new_scalar(convert_string_to_type_str(string)?)), Value::List(ref list) => Ok(Record(convert_list_to_record_fields(list)?)), _ => invalid_data("must be string or list") } } + + // not part of stable API, but needed by the serialize_array test + #[doc(hidden)] + pub fn parse(source: &str) -> Result { + let descr = match parser::item(source.as_bytes()) { + IResult::Done(_, header) => { + Ok(header) + }, + IResult::Incomplete(needed) => { + invalid_data(&format!("could not parse Python expression: {:?}", needed)) + }, + IResult::Error(err) => { + invalid_data(&format!("could not parse Python expression: {:?}", err)) + }, + }?; + Self::from_descr(descr) + } + + /// Construct a scalar `DType`. (one which is not a nested array or record type) + pub fn new_scalar(ty: TypeStr) -> Self { + DType::Plain { ty, shape: vec![] } + } + + /// Return a `TypeStr` only if the `DType` is a primitive scalar. (no arrays or record types) + pub(crate) fn as_scalar(&self) -> Option<&TypeStr> { + match self { + DType::Plain { ty, shape } if shape.is_empty() => Some(ty), + _ => None, + } + } + + /// Get the number of bytes that each item of this type occupies. 
+ pub fn num_bytes(&self) -> usize { + match self { + DType::Plain { ty, shape } => { + ty.num_bytes() * shape.iter().product::() as usize + }, + DType::Record(fields) => { + fields.iter().map(|field| field.dtype.num_bytes()).sum() + }, + } + } } fn convert_list_to_record_fields(values: &[Value]) -> Result> { @@ -86,7 +131,7 @@ fn convert_tuple_to_record_field(tuple: &[Value]) -> Result { 2 | 3 => match (&tuple[0], &tuple[1], tuple.get(2)) { (&String(ref name), &String(ref dtype), ref shape) => Ok(Field { name: name.clone(), dtype: DType::Plain { - ty: dtype.clone(), + ty: convert_string_to_type_str(dtype)?, shape: if let &Some(ref s) = shape { convert_value_to_shape(s)? } else { @@ -127,6 +172,13 @@ fn convert_value_to_positive_integer(number: &Value) -> Result { } } +fn convert_string_to_type_str(string: &str) -> Result { + match string.parse() { + Ok(ty) => Ok(ty), + Err(e) => invalid_data(&format!("invalid type string: {}", e)), + } +} + fn first_error(results: I) -> Result> where I: IntoIterator> { @@ -244,84 +296,103 @@ mod parser { #[cfg(test)] mod tests { use super::*; + use std::error::Error; + + type TestResult = std::result::Result<(), Box>; #[test] - fn description_of_record_array_as_python_list_of_tuples() { + fn description_of_record_array_as_python_list_of_tuples() -> TestResult { let dtype = DType::Record(vec![ Field { name: "float".to_string(), - dtype: DType::Plain { ty: ">f4".to_string(), shape: vec![] } + dtype: DType::Plain { ty: ">f4".parse()?, shape: vec![] } }, Field { name: "byte".to_string(), - dtype: DType::Plain { ty: "f8".to_string(), shape: vec![] }; + fn description_of_unstructured_primitive_array() -> TestResult { + let dtype = DType::Plain { ty: ">f8".parse()?, shape: vec![] }; assert_eq!(dtype.descr(), "'>f8'"); + Ok(()) } #[test] - fn description_of_nested_record_dtype() { + fn description_of_nested_record_dtype() -> TestResult { let dtype = DType::Record(vec![ Field { name: "parent".to_string(), dtype: DType::Record(vec![ Field { 
name: "child".to_string(), - dtype: DType::Plain { ty: " TestResult { + let dtype = ">f8"; + assert_eq!( + DType::from_descr(Value::String(dtype.to_string())).unwrap(), + DType::Plain { ty: dtype.parse()?, shape: vec![] } + ); + Ok(()) } #[test] - fn converts_simple_description_to_record_dtype() { - let dtype = ">f8".to_string(); + fn converts_non_endian_description_to_record_dtype() -> TestResult { + let dtype = "|u1"; assert_eq!( - DType::from_descr(Value::String(dtype.clone())).unwrap(), - DType::Plain { ty: dtype, shape: vec![] } + DType::from_descr(Value::String(dtype.to_string())).unwrap(), + DType::Plain { ty: dtype.parse()?, shape: vec![] } ); + Ok(()) } #[test] - fn converts_record_description_to_record_dtype() { + fn converts_record_description_to_record_dtype() -> TestResult { let descr = parse("[('a', ' TestResult { let descr = parse("[('a', '>f8', (1,))]"); let expected_dtype = DType::Record(vec![ Field { name: "a".to_string(), - dtype: DType::Plain { ty: ">f8".to_string(), shape: vec![1] } + dtype: DType::Plain { ty: ">f8".parse()?, shape: vec![1] } } ]); assert_eq!(DType::from_descr(descr).unwrap(), expected_dtype); + Ok(()) } #[test] - fn record_description_with_nested_record_field() { + fn record_description_with_nested_record_field() -> TestResult { let descr = parse("[('parent', [('child', ' { +pub struct NpyData<'a, T: Deserialize> { data: &'a [u8], + dtype: DType, n_records: usize, - _t: PhantomData, + item_size: usize, + reader: ::Reader, } -impl<'a, T: Serializable> NpyData<'a, T> { +impl<'a, T: Deserialize> NpyData<'a, T> { /// Deserialize a NPY file represented as bytes pub fn from_bytes(bytes: &'a [u8]) -> ::std::io::Result> { - let (data_slice, ns) = Self::get_data_slice(bytes)?; - Ok(NpyData { data: data_slice, n_records: ns as usize, _t: PhantomData }) + let (dtype, data, ns) = Self::get_data_slice(bytes)?; + let reader = match T::reader(&dtype) { + Ok(reader) => reader, + Err(e) => return Err(Error::new(ErrorKind::InvalidData, 
e.to_string())), + }; + let item_size = dtype.num_bytes(); + Ok(NpyData { data, dtype, n_records: ns as usize, item_size, reader }) + } + + /// Get the dtype as written in the file. + pub fn dtype(&self) -> &DType { + &self.dtype } /// Gets a single data-record with the specified index. Returns None, if the index is @@ -46,9 +55,9 @@ impl<'a, T: Serializable> NpyData<'a, T> { self.n_records == 0 } - /// Gets a single data-record wit the specified index. Panics, if the index is out of bounds. + /// Gets a single data-record with the specified index. Panics if the index is out of bounds. pub fn get_unchecked(&self, i: usize) -> T { - T::read(&self.data[i * T::n_bytes()..]) + self.reader.read_one(&self.data[i * self.item_size..]).0 } /// Construct a vector with the deserialized contents of the whole file @@ -60,7 +69,7 @@ impl<'a, T: Serializable> NpyData<'a, T> { v } - fn get_data_slice(bytes: &[u8]) -> Result<(&[u8], i64)> { + fn get_data_slice(bytes: &[u8]) -> Result<(DType, &[u8], i64)> { let (data, header) = match parse_header(bytes) { IResult::Done(data, header) => { Ok((data, header)) @@ -95,35 +104,28 @@ impl<'a, T: Serializable> NpyData<'a, T> { "\'descr\' field is not present or doesn't contain a list."))?; if let Ok(dtype) = DType::from_descr(descr.clone()) { - let expected_dtype = T::dtype(); - if dtype != expected_dtype { - return Err(Error::new(ErrorKind::InvalidData, - format!("Types don't match! found: {:?}, expected: {:?}", dtype, expected_dtype) - )); - } + Ok((dtype, data, ns)) } else { - return Err(Error::new(ErrorKind::InvalidData, format!("fail?!?"))); + Err(Error::new(ErrorKind::InvalidData, format!("fail?!?"))) } - - Ok((data, ns)) } } /// A result of NPY file deserialization. /// /// It is an iterator to offer a lazy interface in case the data don't fit into memory. 
-pub struct IntoIter<'a, T: 'a> { +pub struct IntoIter<'a, T: 'a + Deserialize> { data: NpyData<'a, T>, i: usize, } -impl<'a, T> IntoIter<'a, T> { +impl<'a, T> IntoIter<'a, T> where T: Deserialize { fn new(data: NpyData<'a, T>) -> Self { IntoIter { data, i: 0 } } } -impl<'a, T: 'a + Serializable> IntoIterator for NpyData<'a, T> { +impl<'a, T: 'a> IntoIterator for NpyData<'a, T> where T: Deserialize { type Item = T; type IntoIter = IntoIter<'a, T>; @@ -132,7 +134,7 @@ impl<'a, T: 'a + Serializable> IntoIterator for NpyData<'a, T> { } } -impl<'a, T> Iterator for IntoIter<'a, T> where T: Serializable { +impl<'a, T> Iterator for IntoIter<'a, T> where T: Deserialize { type Item = T; fn next(&mut self) -> Option { @@ -145,4 +147,4 @@ impl<'a, T> Iterator for IntoIter<'a, T> where T: Serializable { } } -impl<'a, T> ExactSizeIterator for IntoIter<'a, T> where T: Serializable {} +impl<'a, T> ExactSizeIterator for IntoIter<'a, T> where T: Deserialize {} diff --git a/src/out_file.rs b/src/out_file.rs index f11fc21..6e69ec2 100644 --- a/src/out_file.rs +++ b/src/out_file.rs @@ -1,30 +1,34 @@ - use std::io::{self,Write,BufWriter,Seek,SeekFrom}; use std::fs::File; use std::path::Path; -use std::marker::PhantomData; use byteorder::{WriteBytesExt, LittleEndian}; -use serializable::Serializable; +use serialize::{AutoSerialize, Serialize, TypeWrite}; use header::DType; const FILLER: &'static [u8] = &[42; 19]; /// Serialize into a file one row at a time. To serialize an iterator, use the /// [`to_file`](fn.to_file.html) function. -pub struct OutFile { +pub struct OutFile { shape_pos: usize, len: usize, fw: BufWriter, - _t: PhantomData + writer: ::Writer, } -impl OutFile { - /// Open a file +impl OutFile { + /// Create a file, using the default format for the given type. pub fn open>(path: P) -> io::Result { - let dtype = Row::dtype(); - if let &DType::Plain { ref shape, .. 
} = &dtype { + Self::open_with_dtype(&Row::default_dtype(), path) + } +} + +impl OutFile { + /// Create a file, using the provided dtype. + pub fn open_with_dtype>(dtype: &DType, path: P) -> io::Result { + if let &DType::Plain { ref shape, .. } = dtype { assert!(shape.len() == 0, "plain non-scalar dtypes not supported"); } let mut fw = BufWriter::new(File::create(path)?); @@ -32,7 +36,12 @@ impl OutFile { fw.write_all(b"NUMPY")?; fw.write_all(&[0x01u8, 0x00])?; - let (header, shape_pos) = create_header(&dtype); + let (header, shape_pos) = create_header(dtype); + + let writer = match Row::writer(dtype) { + Ok(writer) => writer, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), + }; let mut padding: Vec = vec![]; padding.extend(&::std::iter::repeat(b' ').take(15 - ((header.len() + 10) % 16)).collect::>()); @@ -44,21 +53,21 @@ impl OutFile { fw.write_u16::(len as u16)?; fw.write_all(&header)?; - // Padding to 8 bytes + // Padding to 16 bytes fw.write_all(&padding)?; Ok(OutFile { shape_pos: shape_pos, len: 0, fw: fw, - _t: PhantomData, + writer: writer, }) } /// Append a single row to the file pub fn push(&mut self, row: &Row) -> io::Result<()> { self.len += 1; - row.write(&mut self.fw) + self.writer.write_one(&mut self.fw, row) } fn close_(&mut self) -> io::Result<()> { @@ -90,7 +99,7 @@ fn create_header(dtype: &DType) -> (Vec, usize) { (header, shape_pos) } -impl Drop for OutFile { +impl Drop for OutFile { fn drop(&mut self) { let _ = self.close_(); // Ignore the errors } @@ -101,9 +110,9 @@ impl Drop for OutFile { /// Serialize an iterator over a struct to a NPY file /// /// A single-statement alternative to saving row by row using the [`OutFile`](struct.OutFile.html). 
-pub fn to_file<'a, S, T, P>(filename: P, data: T) -> ::std::io::Result<()> where +pub fn to_file(filename: P, data: T) -> ::std::io::Result<()> where P: AsRef, - S: Serializable + 'a, + S: AutoSerialize, T: IntoIterator { let mut of = OutFile::open(filename)?; diff --git a/src/serializable.rs b/src/serializable.rs deleted file mode 100644 index 7b37221..0000000 --- a/src/serializable.rs +++ /dev/null @@ -1,234 +0,0 @@ - -use std::io::{Write,Result}; -use byteorder::{WriteBytesExt, LittleEndian}; -use header::DType; -use byteorder::ByteOrder; - -/// This trait contains information on how to serialize and deserialize a type. -/// -/// An example illustrating a `Serializable` implementation for a fixed-size vector is in -/// [the roundtrip test](https://github.com/potocpav/npy-rs/tree/master/tests/roundtrip.rs). -/// It is strongly advised to annotate the `Serializable` functions as `#[inline]` for good -/// performance. -pub trait Serializable : Sized { - /// Convert a type to a structure representing a Numpy type - fn dtype() -> DType; - - /// Get the number of bytes of the binary repr - fn n_bytes() -> usize; - - /// Deserialize a single data field, advancing the cursor in the process. - fn read(c: &[u8]) -> Self; - - /// Serialize a single data field into a writer. 
- fn write(&self, writer: &mut W) -> Result<()>; -} - -impl Serializable for i8 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 1 } - #[inline] - fn read(buf: &[u8]) -> Self { - unsafe { ::std::mem::transmute(buf[0]) } // TODO: a better way - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_i8(*self) - } -} - -impl Serializable for i16 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 2 } - #[inline] - fn read(buf: &[u8]) -> Self { - LittleEndian::read_i16(buf) - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_i16::(*self) - } -} - -impl Serializable for i32 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 4 } - #[inline] - fn read(buf: &[u8]) -> Self { - LittleEndian::read_i32(buf) - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_i32::(*self) - } -} - -impl Serializable for i64 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 8 } - #[inline] - fn read(buf: &[u8]) -> Self { - LittleEndian::read_i64(buf) - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_i64::(*self) - } -} - -impl Serializable for u8 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 1 } - #[inline] - fn read(buf: &[u8]) -> Self { - buf[0] - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_u8(*self) - } -} - -impl Serializable for u16 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 2 } - #[inline] - fn read(buf: &[u8]) -> Self { - LittleEndian::read_u16(buf) - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_u16::(*self) - } -} - -impl Serializable for u32 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 4 } - #[inline] - fn read(buf: &[u8]) -> Self { - LittleEndian::read_u32(buf) - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_u32::(*self) - } -} - 
-impl Serializable for u64 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 8 } - #[inline] - fn read(buf: &[u8]) -> Self { - LittleEndian::read_u64(buf) - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_u64::(*self) - } -} - -impl Serializable for f32 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 4 } - #[inline] - fn read(buf: &[u8]) -> Self { - LittleEndian::read_f32(buf) - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_f32::(*self) - } -} - -impl Serializable for f64 { - #[inline] - fn dtype() -> DType { - DType::Plain { ty: " usize { 8 } - #[inline] - fn read(buf: &[u8]) -> Self { - LittleEndian::read_f64(buf) - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - writer.write_f64::(*self) - } -} - -macro_rules! gen_array_serializable { - ($($n:tt),+) => { $( - impl Serializable for [T; $n] { - #[inline] - fn dtype() -> DType { - use DType::*; - match T::dtype() { - Plain { ref ty, ref shape } => DType::Plain { - ty: ty.clone(), - shape: shape.clone().into_iter().chain(Some($n)).collect() - }, - Record(_) => unimplemented!("arrays of nested records") - } - } - #[inline] - fn n_bytes() -> usize { T::n_bytes() * $n } - #[inline] - fn read(buf: &[u8]) -> Self { - let mut a = [T::default(); $n]; - let mut off = 0; - for x in &mut a { - *x = T::read(&buf[off..]); - off += T::n_bytes(); - } - a - } - #[inline] - fn write(&self, writer: &mut W) -> Result<()> { - for item in self { - item.write(writer)?; - } - Ok(()) - } - } - )+ } -} - -gen_array_serializable!(1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16); diff --git a/src/serialize.rs b/src/serialize.rs new file mode 100644 index 0000000..452b97c --- /dev/null +++ b/src/serialize.rs @@ -0,0 +1,954 @@ +use header::DType; +use type_str::{TypeStr, Endianness, TypeKind}; +use byteorder::{ByteOrder, NativeEndian, WriteBytesExt}; +use self::{TypeKind::*}; +use std::io; +use std::fmt; +use std::convert::TryFrom; 
+ +/// Trait that permits reading a type from an `.npy` file. +/// +/// For an example of how to implement this, please see the +/// [roundtrip test](https://github.com/potocpav/npy-rs/tree/master/tests/roundtrip.rs). +pub trait Deserialize: Sized { + /// Think of this as like a `Fn(&[u8]) -> (Self, &[u8])`. + /// + /// There is no closure-like sugar for these; you must manually define a type that + /// implements [`TypeRead`]. + type Reader: TypeRead; + + /// Get a function that deserializes a single data field at a time + /// + /// The function receives a byte buffer containing at least + /// `dtype.num_bytes()` bytes. + /// + /// # Errors + /// + /// Returns `Err` if the `DType` is not compatible with `Self`. + fn reader(dtype: &DType) -> Result; +} + +/// Trait that permits writing a type to an `.npy` file. +/// +/// For an example of how to implement this, please see the +/// [roundtrip test](https://github.com/potocpav/npy-rs/tree/master/tests/roundtrip.rs). +pub trait Serialize { + /// Think of this as some sort of `for Fn(W, &Self) -> io::Result<()>`. + /// + /// There is no closure-like sugar for these; you must manually define a type that + /// implements [`TypeWrite`]. + type Writer: TypeWrite; + + /// Get a function that serializes a single data field at a time. + /// + /// # Errors + /// + /// Returns `Err` if the `DType` is not compatible with `Self`. + fn writer(dtype: &DType) -> Result; +} + +/// Subtrait of [`Serialize`] for types which have a reasonable default [`DType`]. +/// +/// This opens up some simpler APIs for serialization. (e.g. [`::to_file`]) +/// +/// For an example of how to implement this, please see the +/// [roundtrip test](https://github.com/potocpav/npy-rs/tree/master/tests/roundtrip.rs). +pub trait AutoSerialize: Serialize { + /// A suggested format for serialization. + /// + /// The builtin implementations for primitive types generally prefer `|` endianness if possible, + /// else the machine endian format. 
+ fn default_dtype() -> DType; +} + +/// Like a `Fn(&[u8]) -> (T, &[u8])`. +/// +/// It is a separate trait from `Fn` for consistency with [`TypeWrite`], and so that +/// default methods can potentially be added in the future that may be overridden +/// for efficiency. +/// +/// For an example of how to implement this, please see the +/// [roundtrip test](https://github.com/potocpav/npy-rs/tree/master/tests/roundtrip.rs). +pub trait TypeRead { + /// Type returned by the function. + type Value; + + /// The function. + /// + /// Receives *at least* enough bytes to read `Self::Value`, and returns the remainder. + fn read_one<'a>(&self, bytes: &'a [u8]) -> (Self::Value, &'a [u8]); +} + +/// Like some sort of `for<W: io::Write> Fn(W, &T) -> io::Result<()>`. +/// +/// For an example of how to implement this, please see the +/// [roundtrip test](https://github.com/potocpav/npy-rs/tree/master/tests/roundtrip.rs). +pub trait TypeWrite { + /// Type accepted by the function. + type Value: ?Sized; + + /// The function. + /// + /// Writes a single `value` to `writer`. + fn write_one(&self, writer: W, value: &Self::Value) -> io::Result<()> + where Self: Sized; +} + +/// The proper trait to use for trait objects of [`TypeWrite`]. +/// +/// `Box<dyn TypeWrite>` is useless because `dyn TypeWrite` has no object-safe methods. +/// The workaround is to use `Box<dyn TypeWriteDyn>` instead, which itself implements `TypeWrite`. +pub trait TypeWriteDyn: TypeWrite { + #[doc(hidden)] + fn write_one_dyn(&self, writer: &mut dyn io::Write, value: &Self::Value) -> io::Result<()>; +} + +// Blanket impl: `write_one_dyn` simply forwards to `write_one` with a `&mut dyn io::Write`. +impl TypeWriteDyn for T { + #[inline(always)] + fn write_one_dyn(&self, writer: &mut dyn io::Write, value: &Self::Value) -> io::Result<()> { + self.write_one(writer, value) + } +} + +/// Indicates that a particular rust type does not support serialization or deserialization +/// as a given [`DType`]. 
+#[derive(Debug, Clone)] +pub struct DTypeError(ErrorKind); + +#[derive(Debug, Clone)] +enum ErrorKind { + Custom(String), + ExpectedScalar { + dtype: String, + rust_type: &'static str, + }, + ExpectedArray { + got: &'static str, // "a scalar", "a record" + }, + WrongArrayLen { + expected: u64, + actual: u64, + }, + ExpectedRecord { + type_str: TypeStr, + }, + WrongFields { + expected: Vec, + actual: Vec, + }, + BadScalar { + type_str: TypeStr, + rust_type: &'static str, + verb: &'static str, + }, + UsizeOverflow(u64), +} + +impl std::error::Error for DTypeError {} + +impl DTypeError { + /// Construct with a custom error message. + pub fn custom>(msg: S) -> Self { + DTypeError(ErrorKind::Custom(msg.as_ref().to_string())) + } + + // verb should be "read" or "write" + fn bad_scalar(verb: &'static str, type_str: &TypeStr, rust_type: &'static str) -> Self { + let type_str = type_str.clone(); + DTypeError(ErrorKind::BadScalar { type_str, rust_type, verb }) + } + + fn expected_scalar(dtype: &DType, rust_type: &'static str) -> Self { + let dtype = dtype.descr(); + DTypeError(ErrorKind::ExpectedScalar { dtype, rust_type }) + } + + fn bad_usize(x: u64) -> Self { + DTypeError(ErrorKind::UsizeOverflow(x)) + } + + // used by derives + #[doc(hidden)] + pub fn expected_record(type_str: &TypeStr) -> Self { + let type_str = type_str.clone(); + DTypeError(ErrorKind::ExpectedRecord { type_str }) + } + + // used by derives + #[doc(hidden)] + pub fn wrong_fields, S2: AsRef>( + expected: impl IntoIterator, + actual: impl IntoIterator, + ) -> Self { + DTypeError(ErrorKind::WrongFields { + expected: expected.into_iter().map(|s| s.as_ref().to_string()).collect(), + actual: actual.into_iter().map(|s| s.as_ref().to_string()).collect(), + }) + } +} + +impl fmt::Display for DTypeError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match &self.0 { + ErrorKind::Custom(msg) => { + write!(f, "{}", msg) + }, + ErrorKind::ExpectedScalar { dtype, rust_type } => { + write!(f, "type {} 
requires a scalar (string) dtype, not {}", rust_type, dtype) + }, + ErrorKind::ExpectedRecord { type_str } => { + write!(f, "expected a record type; got a scalar type '{}'", type_str) + }, + ErrorKind::ExpectedArray { got } => { + write!(f, "rust array types require an array dtype (got {})", got) + }, + ErrorKind::WrongArrayLen { actual, expected } => { + write!(f, "wrong array size (expected {}, got {})", expected, actual) + }, + ErrorKind::WrongFields { actual, expected } => { + write!(f, "field names do not match (expected {:?}, got {:?})", expected, actual) + }, + ErrorKind::BadScalar { type_str, rust_type, verb } => { + write!(f, "cannot {} type {} with type-string '{}'", verb, rust_type, type_str) + }, + ErrorKind::UsizeOverflow(value) => { + write!(f, "cannot cast {} as usize", value) + }, + } + } +} + +impl TypeRead for Box> { + type Value = T; + + #[inline(always)] + fn read_one<'a>(&self, bytes: &'a [u8]) -> (T, &'a [u8]) { + (**self).read_one(bytes) + } +} + +impl TypeWrite for Box> { + type Value = T; + + #[inline(always)] + fn write_one(&self, mut writer: W, value: &T) -> io::Result<()> + where Self: Sized, + { + // Boxes must always go through two virtual dispatches. + // + // (one on the TypeWrite trait object, and one on the Writer which must be + // cast to the monomorphic type `&mut dyn io::write`) + (**self).write_one_dyn(&mut writer, value) + } +} + +fn invalid_data(message: &str) -> io::Result { + Err(io::Error::new(io::ErrorKind::InvalidData, message.to_string())) +} + +macro_rules! 
impl_integer_serializable { + ( + meta: [ (main_ty: $Int:ident) ] + types: [ $( + [ $size:literal $int:ident $read_int:ident $write_int:ident ] + )* ] + ) => { $( + mod $int { + use super::*; + + pub struct AnyEndianReader { pub(super) swap_byteorder: bool } + pub struct AnyEndianWriter { pub(super) swap_byteorder: bool } + + pub(super) fn expect_scalar_dtype(dtype: &DType) -> Result<&TypeStr, DTypeError> { + dtype.as_scalar().ok_or_else(|| { + DTypeError::expected_scalar(dtype, stringify!($int)) + }) + } + + #[inline] + fn maybe_swap(swap: bool, x: $int) -> $int { + match swap { + true => x.swap_bytes(), + false => x, + } + } + + impl TypeRead for AnyEndianReader { + type Value = $int; + + #[inline(always)] + fn read_one<'a>(&self, bytes: &'a [u8]) -> (Self::Value, &'a [u8]) { + let value = maybe_swap(self.swap_byteorder, NativeEndian::$read_int(bytes)); + (value, &bytes[$size..]) + } + } + + impl TypeWrite for AnyEndianWriter { + type Value = $int; + + #[inline(always)] + fn write_one(&self, mut writer: W, &value: &Self::Value) -> io::Result<()> { + writer.$write_int::(maybe_swap(self.swap_byteorder, value)) + } + } + } + + impl Deserialize for $int { + type Reader = $int::AnyEndianReader; + + fn reader(dtype: &DType) -> Result { + match $int::expect_scalar_dtype(dtype)? { + // Read an integer of the correct size and signedness. + TypeStr { size: $size, endianness, type_kind: $Int, .. } => { + assert!($size == 1 || endianness != &Endianness::Irrelevant, "(BUG) invalid dtype constructed?"); + + let swap_byteorder = endianness.requires_swap(Endianness::of_machine()); + Ok($int::AnyEndianReader { swap_byteorder }) + }, + type_str => Err(DTypeError::bad_scalar("read", type_str, stringify!($int))), + } + } + } + + impl Serialize for $int { + type Writer = $int::AnyEndianWriter; + + fn writer(dtype: &DType) -> Result { + match $int::expect_scalar_dtype(dtype)? { + // Write an integer of the correct size and signedness. 
+ TypeStr { size: $size, endianness, type_kind: $Int, .. } => { + assert!($size == 1 || endianness != &Endianness::Irrelevant, "(BUG) invalid dtype constructed?"); + + let swap_byteorder = endianness.requires_swap(Endianness::of_machine()); + Ok($int::AnyEndianWriter { swap_byteorder }) + }, + type_str => Err(DTypeError::bad_scalar("write", type_str, stringify!($int))), + } + } + } + + impl AutoSerialize for $int { + fn default_dtype() -> DType { + DType::new_scalar(TypeStr::with_auto_endianness($Int, $size, None)) + } + } + )*}; +} + +// Needed by the macro: Methods missing from byteorder +// (single-byte reads from a byte slice, named like the multi-byte `read_*` methods). +trait ReadSingleByteExt { + #[inline(always)] fn read_u8_(bytes: &[u8]) -> u8 { bytes[0] } + #[inline(always)] fn read_i8_(bytes: &[u8]) -> i8 { i8::from_ne_bytes([bytes[0]]) } +} + +impl ReadSingleByteExt for E {} + +/// Needed by the macro: Methods modified to take a generic type param +/// (they forward to `write_u8` / `write_i8` unchanged). +trait WriteSingleByteExt: WriteBytesExt { + #[inline(always)] fn write_u8_(&mut self, value: u8) -> io::Result<()> { self.write_u8(value) } + #[inline(always)] fn write_i8_(&mut self, value: i8) -> io::Result<()> { self.write_i8(value) } +} + +impl WriteSingleByteExt for W {} + +impl_integer_serializable! { + meta: [ (main_ty: Int) ] + types: [ + // numpy doesn't support i128 + [ 8 i64 read_i64 write_i64 ] + [ 4 i32 read_i32 write_i32 ] + [ 2 i16 read_i16 write_i16 ] + [ 1 i8 read_i8_ write_i8_ ] + ] +} + +impl_integer_serializable! { + meta: [ (main_ty: Uint) ] + types: [ + // numpy doesn't support u128 + [ 8 u64 read_u64 write_u64 ] + [ 4 u32 read_u32 write_u32 ] + [ 2 u16 read_u16 write_u16 ] + [ 1 u8 read_u8_ write_u8_ ] + ] +} + +// Takes info about each data size, from largest to smallest. +macro_rules! 
impl_float_serializable { + ( $( [ $size:literal $float:ident $read_float:ident $write_float:ident ] )+ ) => { $( + mod $float { + use super::*; + + pub struct AnyEndianReader { pub(super) swap_byteorder: bool } + pub struct AnyEndianWriter { pub(super) swap_byteorder: bool } + + #[inline] + fn maybe_swap(swap: bool, x: $float) -> $float { + match swap { + true => $float::from_bits(x.to_bits().swap_bytes()), + false => x, + } + } + + pub(super) fn expect_scalar_dtype(dtype: &DType) -> Result<&TypeStr, DTypeError> { + dtype.as_scalar().ok_or_else(|| { + DTypeError::expected_scalar(dtype, stringify!($float)) + }) + } + + impl TypeRead for AnyEndianReader { + type Value = $float; + + #[inline(always)] + fn read_one<'a>(&self, bytes: &'a [u8]) -> ($float, &'a [u8]) { + let value = maybe_swap(self.swap_byteorder, NativeEndian::$read_float(bytes)); + (value, &bytes[$size..]) + } + } + + impl TypeWrite for AnyEndianWriter { + type Value = $float; + + #[inline(always)] + fn write_one(&self, mut writer: W, &value: &$float) -> io::Result<()> { + writer.$write_float::(maybe_swap(self.swap_byteorder, value)) + } + } + } + + impl Deserialize for $float { + type Reader = $float::AnyEndianReader; + + fn reader(dtype: &DType) -> Result { + match $float::expect_scalar_dtype(dtype)? { + // Read a float of the correct size + TypeStr { size: $size, endianness, type_kind: Float, .. } => { + assert_ne!(endianness, &Endianness::Irrelevant, "(BUG) invalid dtype constructed?"); + + let swap_byteorder = endianness.requires_swap(Endianness::of_machine()); + Ok($float::AnyEndianReader { swap_byteorder }) + }, + type_str => Err(DTypeError::bad_scalar("read", type_str, stringify!($float))), + } + } + } + + impl Serialize for $float { + type Writer = $float::AnyEndianWriter; + + fn writer(dtype: &DType) -> Result { + match $float::expect_scalar_dtype(dtype)? { + // Write a float of the correct size + TypeStr { size: $size, endianness, type_kind: Float, .. 
} => { + assert_ne!(endianness, &Endianness::Irrelevant, "(BUG) invalid dtype constructed?"); + + let swap_byteorder = endianness.requires_swap(Endianness::of_machine()); + Ok($float::AnyEndianWriter { swap_byteorder }) + }, + type_str => Err(DTypeError::bad_scalar("write", type_str, stringify!($float))), + } + } + } + + impl AutoSerialize for $float { + fn default_dtype() -> DType { + DType::new_scalar(TypeStr::with_auto_endianness(Float, $size, None)) + } + } + )+}; +} + +impl_float_serializable! { + // TODO: numpy supports f16, f128 + [ 8 f64 read_f64 write_f64 ] + [ 4 f32 read_f32 write_f32 ] +} + +pub struct BytesReader { + size: usize, + is_byte_str: bool, +} + +impl TypeRead for BytesReader { + type Value = Vec; + + fn read_one<'a>(&self, bytes: &'a [u8]) -> (Vec, &'a [u8]) { + let mut vec = vec![]; + + let (src, remainder) = bytes.split_at(self.size); + vec.resize(self.size, 0); + vec.copy_from_slice(src); + + // truncate trailing zeros for type 'S' + if self.is_byte_str { + let end = vec.iter().rposition(|x| x != &0).map_or(0, |ind| ind + 1); + vec.truncate(end); + } + + (vec, remainder) + } +} + +impl Deserialize for Vec { + type Reader = BytesReader; + + fn reader(type_str: &DType) -> Result { + let type_str = type_str.as_scalar().ok_or_else(|| DTypeError::expected_scalar(type_str, "Vec"))?; + let size = match usize::try_from(type_str.size) { + Ok(size) => size, + Err(_) => return Err(DTypeError::bad_usize(type_str.size)), + }; + + let is_byte_str = match *type_str { + TypeStr { type_kind: ByteStr, .. } => true, + TypeStr { type_kind: RawData, .. 
} => false, + _ => return Err(DTypeError::bad_scalar("read", type_str, "Vec")), + }; + Ok(BytesReader { size, is_byte_str }) + } +} + +/// Writer for `[u8]` values stored as fixed-size `S` (byte string) or `V` (raw data) items. +pub struct BytesWriter { + // Used when formatting the "bad item length" error. + type_str: TypeStr, + // Item size in bytes, taken from the type-string. + size: usize, + // `true` for kind `S` (short input is zero-padded), `false` for kind `V`. + is_byte_str: bool, +} + +impl TypeWrite for BytesWriter { + type Value = [u8]; + + /// Writes exactly `self.size` bytes. + /// + /// # Errors + /// + /// Returns an `InvalidData` error if `bytes` is longer than the item size, or if it is + /// shorter and the dtype is raw data (`V`), which cannot be padded. + fn write_one(&self, mut w: W, bytes: &[u8]) -> io::Result<()> { + use std::cmp::Ordering; + + match (bytes.len().cmp(&self.size), self.is_byte_str) { + (Ordering::Greater, _) | + (Ordering::Less, false) => return invalid_data( + &format!("bad item length {} for type-string '{}'", bytes.len(), self.type_str), + ), + _ => {}, + } + + w.write_all(bytes)?; + // Byte strings are zero-padded on the right up to the item size. + if self.is_byte_str { + w.write_all(&vec![0; self.size - bytes.len()])?; + } + Ok(()) + } +} + +impl Serialize for [u8] { + type Writer = BytesWriter; + + /// # Errors + /// + /// Returns `Err` if `dtype` is not a scalar, if its size does not fit in `usize`, + /// or if its kind is neither `S` nor `V`. + fn writer(dtype: &DType) -> Result { + let type_str = dtype.as_scalar().ok_or_else(|| DTypeError::expected_scalar(dtype, "[u8]"))?; + + let size = match usize::try_from(type_str.size) { + Ok(size) => size, + Err(_) => return Err(DTypeError::bad_usize(type_str.size)), + }; + + let type_str = type_str.clone(); + let is_byte_str = match type_str { + TypeStr { type_kind: ByteStr, .. } => true, + TypeStr { type_kind: RawData, .. 
} => false, + // This is the write path, so the error must say "write" (it previously said "read"). + _ => return Err(DTypeError::bad_scalar("write", &type_str, "[u8]")), + }; + Ok(BytesWriter { type_str, size, is_byte_str }) + } +} + +#[macro_use] +mod helper { + use super::*; + use std::ops::Deref; + + pub struct TypeWriteViaDeref + where + T: Deref, + ::Target: Serialize, + { + pub(crate) inner: <::Target as Serialize>::Writer, + } + + impl TypeWrite for TypeWriteViaDeref + where + T: Deref, + U: Serialize, + { + type Value = T; + + #[inline(always)] + fn write_one(&self, writer: W, value: &T) -> io::Result<()> { + self.inner.write_one(writer, value) + } + } + + macro_rules! 
impl_serialize_by_deref { + ([$($generics:tt)*] $T:ty => $Target:ty $(where $($bounds:tt)+)*) => { + impl<$($generics)*> Serialize for $T + $(where $($bounds)+)* + { + type Writer = helper::TypeWriteViaDeref<$T>; + + #[inline(always)] + fn writer(dtype: &DType) -> Result { + Ok(helper::TypeWriteViaDeref { inner: <$Target>::writer(dtype)? }) + } + } + }; + } + + // Implements `AutoSerialize` for `$T` by delegating `default_dtype()` to `$Delegate`. + macro_rules! impl_auto_serialize { + ([$($generics:tt)*] $T:ty as $Delegate:ty $(where $($bounds:tt)+)*) => { + impl<$($generics)*> AutoSerialize for $T + $(where $($bounds)+)* + { + #[inline(always)] + fn default_dtype() -> DType { + <$Delegate>::default_dtype() + } + } + }; + } +} + +impl_serialize_by_deref!{[] Vec => [u8]} + +impl_serialize_by_deref!{['a, T: ?Sized] &'a T => T where T: Serialize} +impl_serialize_by_deref!{['a, T: ?Sized] &'a mut T => T where T: Serialize} +impl_serialize_by_deref!{[T: ?Sized] Box => T where T: Serialize} +impl_serialize_by_deref!{[T: ?Sized] std::rc::Rc => T where T: Serialize} +impl_serialize_by_deref!{[T: ?Sized] std::sync::Arc => T where T: Serialize} +impl_serialize_by_deref!{['a, T: ?Sized] std::borrow::Cow<'a, T> => T where T: Serialize + std::borrow::ToOwned} +impl_auto_serialize!{[T: ?Sized] &T as T where T: AutoSerialize} +impl_auto_serialize!{[T: ?Sized] &mut T as T where T: AutoSerialize} +impl_auto_serialize!{[T: ?Sized] Box as T where T: AutoSerialize} +impl_auto_serialize!{[T: ?Sized] std::rc::Rc as T where T: AutoSerialize} +impl_auto_serialize!{[T: ?Sized] std::sync::Arc as T where T: AutoSerialize} +impl_auto_serialize!{[T: ?Sized] std::borrow::Cow<'_, T> as T where T: AutoSerialize + std::borrow::ToOwned} + +impl DType { + /// Expect an array dtype whose first dimension equals `expected_len`, and get the inner + /// dtype (the same dtype with that first dimension removed). 
+ /// + /// Returns an error for record dtypes, for scalars (empty shape), and when the first + /// dimension differs from `expected_len`. + fn array_inner_dtype(&self, expected_len: u64) -> Result { + match *self { + DType::Record { .. 
} => Err(DTypeError(ErrorKind::ExpectedArray { got: "a record" })), + DType::Plain { ref ty, ref shape } => { + let ty = ty.clone(); + let mut shape = shape.to_vec(); + + let len = match shape.is_empty() { + true => return Err(DTypeError(ErrorKind::ExpectedArray { got: "a scalar" })), + false => shape.remove(0), + }; + + if len != expected_len { + return Err(DTypeError(ErrorKind::WrongArrayLen { + actual: len, + expected: expected_len, + })); + } + + Ok(DType::Plain { ty, shape }) + }, + } + } +} + +macro_rules! gen_array_serializable { + ($([$n:tt in mod $mod_name:ident])+) => { $( + mod $mod_name { + use super::*; + + pub struct ArrayReader{ inner: I } + pub struct ArrayWriter{ inner: I } + + impl TypeRead for ArrayReader + where I::Value: Copy + Default, + { + type Value = [I::Value; $n]; + + #[inline] + fn read_one<'a>(&self, bytes: &'a [u8]) -> (Self::Value, &'a [u8]) { + let mut value = [I::Value::default(); $n]; + + let mut remainder = bytes; + for place in &mut value { + let (item, new_remainder) = self.inner.read_one(remainder); + *place = item; + remainder = new_remainder; + } + + (value, remainder) + } + } + + impl TypeWrite for ArrayWriter + where I::Value: Sized, + { + type Value = [I::Value; $n]; + + #[inline] + fn write_one(&self, mut writer: W, value: &Self::Value) -> io::Result<()> + where Self: Sized, + { + for item in value { + self.inner.write_one(&mut writer, item)?; + } + Ok(()) + } + } + + impl AutoSerialize for [T; $n] { + #[inline] + fn default_dtype() -> DType { + use DType::*; + + match T::default_dtype() { + Plain { ty, mut shape } => DType::Plain { + ty, + shape: { shape.insert(0, $n); shape }, + }, + Record(_) => unimplemented!("arrays of nested records") + } + } + } + + impl Deserialize for [T; $n] { + type Reader = ArrayReader<::Reader>; + + #[inline] + fn reader(dtype: &DType) -> Result { + let inner_dtype = dtype.array_inner_dtype($n)?; + let inner = ::reader(&inner_dtype)?; + Ok(ArrayReader { inner }) + } + } + + impl Serialize 
for [T; $n] { + type Writer = ArrayWriter<::Writer>; + + #[inline] + fn writer(dtype: &DType) -> Result { + let inner = ::writer(&dtype.array_inner_dtype($n)?)?; + Ok(ArrayWriter { inner }) + } + } + } + )+ } +} + +gen_array_serializable!{ + /* no size 0 */ [ 1 in mod arr1] [ 2 in mod arr2] [ 3 in mod arr3] + [ 4 in mod arr4] [ 5 in mod arr5] [ 6 in mod arr6] [ 7 in mod arr7] + [ 8 in mod arr8] [ 9 in mod arr9] [10 in mod arr10] [11 in mod arr11] + [12 in mod arr12] [13 in mod arr13] [14 in mod arr14] [15 in mod arr15] + [16 in mod arr16] +} + +#[cfg(test)] +#[deny(unused)] +mod tests { + use super::*; + + // NOTE: Tests for arrays are in tests/serialize_array.rs because they require derives + + fn reader_output(dtype: &DType, bytes: &[u8]) -> T { + T::reader(dtype).unwrap_or_else(|e| panic!("{}", e)).read_one(bytes).0 + } + + fn reader_expect_err(dtype: &DType) { + T::reader(dtype).err().expect("reader_expect_err failed!"); + } + + fn writer_output(dtype: &DType, value: &T) -> Vec { + let mut vec = vec![]; + T::writer(dtype).unwrap_or_else(|e| panic!("{}", e)) + .write_one(&mut vec, value).unwrap(); + vec + } + + fn writer_expect_err(dtype: &DType) { + T::writer(dtype).err().expect("writer_expect_err failed!"); + } + + fn writer_expect_write_err(dtype: &DType, value: &T) { + let mut vec = vec![]; + T::writer(dtype).unwrap_or_else(|e| panic!("{}", e)) + .write_one(&mut vec, value) + .err().expect("writer_expect_write_err failed!"); + } + + const BE_ONE_32: &[u8] = &[0, 0, 0, 1]; + const LE_ONE_32: &[u8] = &[1, 0, 0, 0]; + + #[test] + fn native_int_types() { + let be = DType::parse("'>i4'").unwrap(); + let le = DType::parse("'(&be, BE_ONE_32), 1); + assert_eq!(reader_output::(&le, LE_ONE_32), 1); + assert_eq!(writer_output::(&be, &1), BE_ONE_32); + assert_eq!(writer_output::(&le, &1), LE_ONE_32); + + let be = DType::parse("'>u4'").unwrap(); + let le = DType::parse("'(&be, BE_ONE_32), 1); + assert_eq!(reader_output::(&le, LE_ONE_32), 1); + 
assert_eq!(writer_output::(&be, &1), BE_ONE_32); + assert_eq!(writer_output::(&le, &1), LE_ONE_32); + + for &dtype in &["'>i1'", "'(&dtype, &[1]), 1); + assert_eq!(writer_output::(&dtype, &1), &[1][..]); + } + + for &dtype in &["'>u1'", "'(&dtype, &[1]), 1); + assert_eq!(writer_output::(&dtype, &1), &[1][..]); + } + } + + #[test] + fn native_float_types() { + let be_bytes = 42.0_f64.to_bits().to_be_bytes(); + let le_bytes = 42.0_f64.to_bits().to_le_bytes(); + let be = DType::parse("'>f8'").unwrap(); + let le = DType::parse("'(&be, &be_bytes), 42.0); + assert_eq!(reader_output::(&le, &le_bytes), 42.0); + assert_eq!(writer_output::(&be, &42.0), &be_bytes); + assert_eq!(writer_output::(&le, &42.0), &le_bytes); + + let be_bytes = 42.0_f32.to_bits().to_be_bytes(); + let le_bytes = 42.0_f32.to_bits().to_le_bytes(); + let be = DType::parse("'>f4'").unwrap(); + let le = DType::parse("'(&be, &be_bytes), 42.0); + assert_eq!(reader_output::(&le, &le_bytes), 42.0); + assert_eq!(writer_output::(&be, &42.0), &be_bytes); + assert_eq!(writer_output::(&le, &42.0), &le_bytes); + } + + #[test] + fn illegal_endianness() { + // There is currently no need to test that each type rejects '|' endianness in their + // (De)Serialize impls, because this is checked up-front during DType construction. 
+ assert!(DType::parse("'|i4'").is_err()); + } + + #[test] + fn wrong_size_int() { + let t_i32 = DType::parse("'(&t_i32); + reader_expect_err::(&t_i32); + reader_expect_err::(&t_u32); + reader_expect_err::(&t_u32); + writer_expect_err::(&t_i32); + writer_expect_err::(&t_i32); + writer_expect_err::(&t_u32); + writer_expect_err::(&t_u32); + } + + #[test] + fn bytes_any_endianness() { + for ty in vec!["'S3'", "'|S3'"] { + let ty = DType::parse(ty).unwrap(); + assert_eq!(writer_output(&ty, &[1, 3, 5][..]), vec![1, 3, 5]); + assert_eq!(reader_output::>(&ty, &[1, 3, 5][..]), vec![1, 3, 5]); + } + } + + #[test] + fn bytes_size_zero() { + let ts = DType::parse("'|S0'").unwrap(); + assert_eq!(reader_output::>(&ts, &[]), vec![]); + assert_eq!(writer_output(&ts, &[][..]), vec![]); + + let ts = DType::parse("'|V0'").unwrap(); + assert_eq!(reader_output::>(&ts, &[]), vec![]); + assert_eq!(writer_output::<[u8]>(&ts, &[]), vec![]); + } + + #[test] + fn wrong_size_bytes() { + let s_3 = DType::parse("'|S3'").unwrap(); + let v_3 = DType::parse("'|V3'").unwrap(); + + assert_eq!(writer_output(&s_3, &[1, 3, 5][..]), vec![1, 3, 5]); + assert_eq!(writer_output(&v_3, &[1, 3, 5][..]), vec![1, 3, 5]); + + assert_eq!(writer_output(&s_3, &[1][..]), vec![1, 0, 0]); + writer_expect_write_err(&v_3, &[1][..]); + + assert_eq!(writer_output(&s_3, &[][..]), vec![0, 0, 0]); + writer_expect_write_err(&v_3, &[][..]); + + writer_expect_write_err(&s_3, &[1, 3, 5, 7][..]); + writer_expect_write_err(&v_3, &[1, 3, 5, 7][..]); + } + + #[test] + fn read_bytes_with_trailing_zeros() { + let ts = DType::parse("'|S2'").unwrap(); + assert_eq!(reader_output::>(&ts, &[1, 3]), vec![1, 3]); + assert_eq!(reader_output::>(&ts, &[1, 0]), vec![1]); + assert_eq!(reader_output::>(&ts, &[0, 0]), vec![]); + + let ts = DType::parse("'|V2'").unwrap(); + assert_eq!(reader_output::>(&ts, &[1, 3]), vec![1, 3]); + assert_eq!(reader_output::>(&ts, &[1, 0]), vec![1, 0]); + assert_eq!(reader_output::>(&ts, &[0, 0]), vec![0, 0]); + } + 
+ #[test] + fn bytestr_preserves_interior_zeros() { + const DATA: &[u8] = &[0, 1, 0, 0, 3, 5]; + + let ts = DType::parse("'|S6'").unwrap(); + + assert_eq!(reader_output::>(&ts, DATA), DATA.to_vec()); + assert_eq!(writer_output(&ts, DATA), DATA.to_vec()); + } + + #[test] + fn default_simple_type_strs() { + assert_eq!(i8::default_dtype().descr(), "'|i1'"); + assert_eq!(u8::default_dtype().descr(), "'|u1'"); + + if 1 == i32::from_be(1) { + assert_eq!(i16::default_dtype().descr(), "'>i2'"); + assert_eq!(i32::default_dtype().descr(), "'>i4'"); + assert_eq!(i64::default_dtype().descr(), "'>i8'"); + assert_eq!(u32::default_dtype().descr(), "'>u4'"); + } else { + assert_eq!(i16::default_dtype().descr(), "'>(&ts, &vec![1, 3, 5]), vec![1, 3, 5]); + assert_eq!(writer_output::<&[u8]>(&ts, &&[1, 3, 5][..]), vec![1, 3, 5]); + } + + #[test] + fn dynamic_readers_and_writers() { + let writer: Box> = Box::new(i32::writer(&i32::default_dtype()).unwrap()); + let reader: Box> = Box::new(i32::reader(&i32::default_dtype()).unwrap()); + + let mut buf = vec![]; + writer.write_one(&mut buf, &4000).unwrap(); + assert_eq!(reader.read_one(&buf).0, 4000); + } +} diff --git a/src/type_str.rs b/src/type_str.rs new file mode 100644 index 0000000..d71f316 --- /dev/null +++ b/src/type_str.rs @@ -0,0 +1,641 @@ +use std::fmt; + +/// Represents an Array Interface type-string. +/// +/// This is more or less the `DType` of a scalar type. +/// Exposes a `FromStr` impl for construction, and a `Display` impl for writing. +/// +/// ``` +/// # fn main() -> Result<(), Box> { +/// use npy::TypeStr; +/// +/// let ts = "|i1".parse::()?; +/// +/// assert_eq!(format!("{}", ts), "|i1"); +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct TypeStr { + pub(crate) endianness: Endianness, + pub(crate) type_kind: TypeKind, + pub(crate) size: u64, + pub(crate) time_units: Option, +} + +/// Represents the first character in a type-string. 
+#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub(crate) enum Endianness { + /// Code `<`. + Little, + /// Code `>`. + Big, + /// Code `|`. Used when endianness is irrelevant. + /// + /// Only valid when the size is `1`, or when `kind` is `TypeKind::RawData` + /// or `TypeKind::ByteStr`. + Irrelevant, +} + +impl Endianness { + /// Parses the endianness character of a type-string; `None` for any other character. + fn from_char(s: char) -> Option { + match s { + '<' => Some(Endianness::Little), + '>' => Some(Endianness::Big), + '|' => Some(Endianness::Irrelevant), + _ => None, + } + } + + /// The endianness character as it appears in a type-string. + fn to_str(self) -> &'static str { + match self { + Endianness::Little => "<", + Endianness::Big => ">", + Endianness::Irrelevant => "|", + } + } +} + +impl Endianness { + /// The byte order of the machine running this code (never `Irrelevant`). + pub(crate) fn of_machine() -> Self { + // `from_be` is a no-op on big-endian machines and a byte swap on little-endian ones. + match i32::from_be(0x00_00_00_01) { + 0x00_00_00_01 => Endianness::Big, + 0x01_00_00_00 => Endianness::Little, + _ => unreachable!(), + } + } + + /// Returns `true` if byteorder swapping is necessary between two types. + /// + /// Only a `Little`/`Big` mismatch requires a swap; `Irrelevant` on either side never does. + pub(crate) fn requires_swap(self, other: Endianness) -> bool { + match (self, other) { + (Endianness::Little, Endianness::Big) | + (Endianness::Big, Endianness::Little) => true, + + _ => false, + } + } +} + +/// Represents the second character in a type-string. +/// +/// Indicates the type of data stored. Affects the interpretation of `size` and `endianness`. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub(crate) enum TypeKind { + /// Code `b`. + /// + /// `size` must be 1, and legal values are `0x00` (`false`) or `0x01` (`true`). + Bool, + /// Code `i`. 
+ /// + /// Notice that numpy does not support 128-bit integers. + Int, + /// Code `u`. + /// + /// Notice that numpy does not support 128-bit integers. + Uint, + /// Code `f`. + /// + /// Notice that numpy supports half-precision floats (`np.float16`), as well as possibly + /// `` for serialization. + /// + /// A `bytes` of length `size`. Strings shorter than this length are zero-padded on the right. 
+ /// This implies that they cannot contain trailing `NUL`s. (They can, however, contain interior + /// `NUL`s). To preserve trailing `NUL`s, use `RawData` (`V`) instead. + ByteStr, + /// Code `U`. Represents a Python 3 `str` (`unicode` in Python 2). + /// + /// A `str` that contains `size` code points (**not bytes!**). Each code unit is encoded as a + /// 32-bit integer of the given endianness. Strings with fewer than `size` code units are + /// zero-padded on the right. (thus they cannot contain trailing copies of U+0000 'NULL'; + /// they can, however, contain interior copies) + /// + /// Like Rust's `char`, the code points must have a value in `[0, 0x110000)`. However, unlike + /// `char`, surrogate code points are allowed. + UnicodeStr, + /// Code `V`. Represents a binary blob of `size` bytes. + /// + /// Can use `Vec` for serialization. + RawData, +} + +impl TypeKind { + fn from_char(s: char) -> Option { + match s { + 'b' => Some(TypeKind::Bool), + 'i' => Some(TypeKind::Int), + 'u' => Some(TypeKind::Uint), + 'f' => Some(TypeKind::Float), + 'c' => Some(TypeKind::Complex), + 'm' => Some(TypeKind::TimeDelta), + 'M' => Some(TypeKind::DateTime), + 'S' => Some(TypeKind::ByteStr), + 'U' => Some(TypeKind::UnicodeStr), + 'V' => Some(TypeKind::RawData), + _ => None, + } + } + + fn to_str(self) -> &'static str { + match self { + TypeKind::Bool => "b", + TypeKind::Int => "i", + TypeKind::Uint => "u", + TypeKind::Float => "f", + TypeKind::Complex => "c", + TypeKind::TimeDelta => "m", + TypeKind::DateTime => "M", + TypeKind::ByteStr => "S", + TypeKind::UnicodeStr => "U", + TypeKind::RawData => "V", + } + } +} + +impl TypeKind { + // `None` means all sizes are valid. 
+ fn valid_sizes(self) -> Option<&'static [u64]> { + match self { + TypeKind::Bool => Some(&[1]), + + // numpy doesn't actually support 128-bit ints + TypeKind::Int | + TypeKind::Uint => Some(&[1, 2, 4, 8]), + + // yes, 128-bit floats are supported by numpy + TypeKind::Float => Some(&[2, 4, 8, 16]), + + // 4-byte complex numbers are mysteriously missing from numpy + TypeKind::Complex => Some(&[8, 16, 32]), + + TypeKind::TimeDelta | + TypeKind::DateTime => Some(&[8]), + + // (Note: numpy does support types `|S0` and `|U0`, though for some reason `numpy.save` + // changes them to `|S1` and `|U1`.) + TypeKind::ByteStr | + TypeKind::UnicodeStr | + TypeKind::RawData => None, + } + } + + /// Returns `true` if `|` endianness is illegal. + fn requires_endianness(self, size: u64) -> bool { + match self { + TypeKind::Bool | + TypeKind::Int | + TypeKind::Uint | + TypeKind::Float | + TypeKind::TimeDelta | + TypeKind::DateTime | + TypeKind::Complex => size != 1, + + TypeKind::UnicodeStr => true, + + TypeKind::ByteStr | + TypeKind::RawData => false, + } + } + + /// Returns `true` if unit specification is required. + fn has_units(self) -> bool { + match self { + TypeKind::TimeDelta | + TypeKind::DateTime => true, + + _ => false, + } + } +} + +impl TypeStr { + pub(crate) fn with_auto_endianness(type_kind: TypeKind, size: u64, time_units: Option) -> Self { + let endianness = match type_kind.requires_endianness(size) { + true => Endianness::of_machine(), + false => Endianness::Irrelevant, + }; + TypeStr { endianness, type_kind, size, time_units }.validate().unwrap() + } + + /// The number of bytes for a single scalar value. 
+ pub(crate) fn num_bytes(&self) -> usize { + match self.type_kind { + TypeKind::Bool | + TypeKind::Int | + TypeKind::Uint | + TypeKind::Float | + TypeKind::Complex | + TypeKind::TimeDelta | + TypeKind::DateTime | + TypeKind::ByteStr | + TypeKind::RawData => self.size as usize, + + TypeKind::UnicodeStr => self.size as usize * 4, + } + } +} + +/// Represents the units of the `m` and `M` datatypes. +/// +/// These appear inside square brackets at the end of the `descr` string for these datatypes. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub(crate) enum TimeUnits { + /// Code `Y`. + Year, + /// Code `M`. + Month, + /// Code `W`. + Week, + /// Code `D`. + Day, + /// Code `h`. + Hour, + /// Code `m`. + Minute, + /// Code `s`. + Second, + /// Code `ms`. + Millisecond, + /// Code `us`. + Microsecond, + /// Code `ns`. + Nanosecond, + /// Code `ps`. + Picosecond, + /// Code `fs`. + Femtosecond, + /// Code `as`. + Attosecond, +} + +impl TimeUnits { + fn from_str(s: &str) -> Option { + match s { + "Y" => Some(TimeUnits::Year), + "M" => Some(TimeUnits::Month), + "W" => Some(TimeUnits::Week), + "D" => Some(TimeUnits::Day), + "h" => Some(TimeUnits::Hour), + "m" => Some(TimeUnits::Minute), + "s" => Some(TimeUnits::Second), + "ms" => Some(TimeUnits::Millisecond), + "us" => Some(TimeUnits::Microsecond), + "ns" => Some(TimeUnits::Nanosecond), + "ps" => Some(TimeUnits::Picosecond), + "fs" => Some(TimeUnits::Femtosecond), + "as" => Some(TimeUnits::Attosecond), + _ => None, + } + } + + fn to_str(self) -> &'static str { + match self { + TimeUnits::Year => "Y", + TimeUnits::Month => "M", + TimeUnits::Week => "W", + TimeUnits::Day => "D", + TimeUnits::Hour => "h", + TimeUnits::Minute => "m", + TimeUnits::Second => "s", + TimeUnits::Millisecond => "ms", + TimeUnits::Microsecond => "us", + TimeUnits::Nanosecond => "ns", + TimeUnits::Picosecond => "ps", + TimeUnits::Femtosecond => "fs", + TimeUnits::Attosecond => "as", + } + } +} + +impl fmt::Display for 
Endianness { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(self.to_str(), f) + } +} + +impl fmt::Display for TypeKind { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(self.to_str(), f) + } +} + +impl fmt::Display for TimeUnits { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(self.to_str(), f) + } +} + +impl fmt::Display for TypeStr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}{}{}", self.endianness, self.type_kind, self.size)?; + if let Some(time_units) = self.time_units { + write!(f, "[{}]", time_units)?; + } + Ok(()) + } +} + +pub use self::parse::ParseTypeStrError; +mod parse { + use super::*; + + /// Error type returned by `::parse`. + #[derive(Debug, Clone)] + pub struct ParseTypeStrError(ErrorKind); + + #[derive(Debug, Clone)] + enum ErrorKind { + SyntaxError, + ParseIntError(std::num::ParseIntError), + InvalidEndianness(TypeStr), + InvalidSize(TypeStr), + MissingOrUnexpectedUnits(TypeStr), + } + + impl fmt::Display for ParseTypeStrError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use self::ErrorKind::*; + + match &self.0 { + SyntaxError => write!(f, "Invalid type-string"), + InvalidEndianness(ty) => write!(f, "Type string '{}' has invalid endianness", ty), + InvalidSize(ty) => { + write!(f, "Type string '{}' has invalid size.", ty)?; + write!(f, " Valid sizes are: {:?}", ty.type_kind.valid_sizes().unwrap())?; + Ok(()) + }, + MissingOrUnexpectedUnits(ty) => { + if ty.type_kind.has_units() { + write!(f, "Type string '{}' is missing time units.", ty) + } else { + write!(f, "Unexpected time units in type string '{}'.", ty) + } + }, + ParseIntError(e) => write!(f, "{}", e), + } + } + } + + macro_rules! 
bail { + ($variant:expr) => { + return Err(ParseTypeStrError($variant)) + }; + } + + impl std::error::Error for ParseTypeStrError {} + + impl std::str::FromStr for TypeStr { + type Err = ParseTypeStrError; + + fn from_str(input: &str) -> Result { + use self::ErrorKind::*; + + if input.len() < 3 { + bail!(SyntaxError); + } + + let mut chars = input.chars(); + + let c = chars.next().unwrap(); + let endianness = match Endianness::from_char(c) { + None => bail!(SyntaxError), + Some(v) => v, + }; + + let c = chars.next().unwrap(); + let type_kind = match TypeKind::from_char(c) { + None => bail!(SyntaxError), + Some(v) => v, + }; + + let remainder = chars.as_str(); + let size_end = { + remainder.bytes().position(|b| !b.is_ascii_digit()) + .unwrap_or(remainder.len()) + }; + if size_end == 0 { + bail!(SyntaxError); + } + let (size, remainder) = remainder.split_at(size_end); + let size = match size.parse() { + Err(e) => bail!(ParseIntError(e)), // probably overflow + Ok(v) => v, + }; + + let time_units = if remainder.is_empty() { + None + } else { + let mut chars = remainder.chars(); + match (chars.next(), chars.next_back()) { + (Some('['), Some(']')) => {}, + _ => bail!(SyntaxError), + } + + match TimeUnits::from_str(chars.as_str()) { + None => bail!(SyntaxError), + Some(v) => Some(v), + } + }; + + TypeStr { endianness, type_kind, size, time_units } + .validate() + } + } + + impl TypeStr { + pub(crate) fn validate(self) -> Result { + use self::ErrorKind::*; + + let TypeStr { endianness, type_kind, size, time_units } = self; + + if type_kind.requires_endianness(size) && endianness == Endianness::Irrelevant { + bail!(InvalidEndianness(self)); + } + + if let Some(valid_sizes) = type_kind.valid_sizes() { + if !valid_sizes.contains(&size) { + bail!(InvalidSize(self)); + } + } + + if type_kind.has_units() != time_units.is_some() { + bail!(MissingOrUnexpectedUnits(self)); + } + + Ok(self) + } + } + + #[cfg(test)] + #[deny(unused)] + mod tests { + use super::*; + + macro_rules! 
assert_matches { + ($expr:expr, $pat:pat) => { + match $expr { + $pat => {}, + actual => panic!("Expected: {}\nGot: {:?}", stringify!($pat), actual), + } + }; + } + + macro_rules! check_ok { + ($s:expr) => { + assert_matches!($s.parse::(), Ok(_)); + }; + } + macro_rules! check_err { + ($s:expr, $p:pat) => { + assert_matches!($s.parse::(), Err(ParseTypeStrError($p))); + }; + } + + #[test] + fn errors() { + use self::ErrorKind::*; + + check_err!("", SyntaxError); + check_err!(">", SyntaxError); + check_err!(">m", SyntaxError); + check_err!(">m8[", SyntaxError); + check_err!(">m8[us", SyntaxError); + check_ok!(">m8[us]"); + check_ok!(">m8[D]"); + check_err!(">m8[us]garbage", SyntaxError); + check_err!(">m8[us]]", SyntaxError); + + + check_err!("", SyntaxError); + check_err!(">", SyntaxError); + check_err!(">i", SyntaxError); + check_ok!(">i8"); + check_ok!(">c16"); + check_err!(">i8garbage", SyntaxError); + + // length-zero integer + check_err!(">m[us]", SyntaxError); + check_err!(">i", SyntaxError); + + // make sure integer overflow doesn't panic + check_err!(">m999999999999999999999999999999[us]", _); + check_err!(">i999999999999999999999999999999", _); + + // Unrecognized specifiers + check_ok!("m8[us]"); + check_err!(">m8[bus]", _); + check_err!(">m8[usb]", _); + check_err!(">m8[xq]", _); + + // Required endianness + check_ok!("|i1"); + check_ok!("|S7"); + check_ok!("|V7"); + check_err!("|i8", InvalidEndianness { .. }); + check_err!("|U1", InvalidEndianness { .. }); + + // Size + check_ok!(">i8"); + check_err!(">i9", InvalidSize { .. }); + check_err!(">m4[us]", InvalidSize { .. }); + check_err!(">b4", InvalidSize { .. }); + check_ok!("|S0"); + check_ok!(">U0"); + check_ok!("|V0"); + check_ok!("|V7"); + + // Presence or absence of units + check_ok!(">i8"); + check_ok!(">m8[us]"); + check_err!(">i8[us]", MissingOrUnexpectedUnits { .. }); + check_err!(">m8", MissingOrUnexpectedUnits { .. 
}); + } + } +} + +#[cfg(test)] +#[deny(unused)] +mod tests { + use super::*; + + #[test] + fn display_simple() { + assert_eq!( + TypeStr { + endianness: Endianness::Little, + type_kind: TypeKind::Int, + size: 8, + time_units: None, + }.to_string(), + "m8[ns]", + ); + } + + #[test] + fn roundtrip() { + macro_rules! check_roundtrip { + ($text:expr) => { + let text = $text.to_string(); + match text.parse::() { + Err(e) => panic!("Failed to parse {:?}: {}", text, e), + Ok(v) => assert_eq!(text, v.to_string()), + } + }; + } + + check_roundtrip!(">i8"); + check_roundtrip!(">f16"); + check_roundtrip!("i1"); + check_roundtrip!("|i1"); + check_roundtrip!("|S7"); + check_roundtrip!("|S0"); + check_roundtrip!("U3"); + check_roundtrip!("m8[ms]"); + } +} diff --git a/tests/derive_hygiene.rs b/tests/derive_hygiene.rs new file mode 100644 index 0000000..b8eed08 --- /dev/null +++ b/tests/derive_hygiene.rs @@ -0,0 +1,18 @@ +extern crate npy_derive; +extern crate npy as lol; + +#[no_implicit_prelude] +mod not_root { + use ::npy_derive; + + #[derive(npy_derive::Serialize, npy_derive::Deserialize)] + struct Struct { + foo: i32, + bar: LocalType, + } + + #[derive(npy_derive::Serialize, npy_derive::Deserialize)] + struct LocalType; +} + +fn main() {} diff --git a/tests/roundtrip.rs b/tests/roundtrip.rs index 334a616..f4f3935 100644 --- a/tests/roundtrip.rs +++ b/tests/roundtrip.rs @@ -1,20 +1,20 @@ -#[macro_use] -extern crate npy_derive; extern crate npy; extern crate byteorder; use byteorder::ByteOrder; use std::io::{Read, Write}; use byteorder::{WriteBytesExt, LittleEndian}; -use npy::{DType, Serializable}; +use npy::{DType, Field, OutFile, Serialize, Deserialize, AutoSerialize}; -#[derive(Serializable, Debug, PartialEq, Clone)] +#[derive(Serialize, Deserialize, AutoSerialize)] +#[derive(Debug, PartialEq, Clone)] struct Nested { v1: f32, v2: f32, } -#[derive(Serializable, Debug, PartialEq, Clone)] +#[derive(Serialize, Deserialize, AutoSerialize)] +#[derive(Debug, PartialEq, Clone)] 
struct Array { v_i8: i8, v_i16: i16, @@ -35,35 +35,66 @@ struct Array { #[derive(Debug, PartialEq, Clone)] struct Vector5(Vec); -impl Serializable for Vector5 { +impl AutoSerialize for Vector5 { #[inline] - fn dtype() -> DType { - DType::Plain { ty: " DType { + DType::Plain { ty: " usize { 5 * 4 } +impl Serialize for Vector5 { + type Writer = Vector5Writer; - #[inline] - fn read(buf: &[u8]) -> Self { - let mut ret = Vector5(vec![]); - let mut off = 0; - for _ in 0..5 { - ret.0.push(LittleEndian::read_i32(&buf[off..])); - off += i32::n_bytes(); + fn writer(dtype: &DType) -> Result { + if dtype == &Self::default_dtype() { + Ok(Vector5Writer) + } else { + Err(npy::DTypeError::custom("Vector5 only supports ' Result { + if dtype == &Self::default_dtype() { + Ok(Vector5Reader) + } else { + Err(npy::DTypeError::custom("Vector5 only supports '(&self, writer: &mut W) -> std::io::Result<()> { + fn write_one(&self, mut writer: W, value: &Self::Value) -> std::io::Result<()> { for i in 0..5 { - writer.write_i32::(self.0[i])? + writer.write_i32::(value.0[i])? 
} Ok(()) } } +impl npy::TypeRead for Vector5Reader { + type Value = Vector5; + + #[inline] + fn read_one<'a>(&self, mut remainder: &'a [u8]) -> (Self::Value, &'a [u8]) { + let mut ret = Vector5(vec![]); + for _ in 0..5 { + ret.0.push(LittleEndian::read_i32(remainder)); + remainder = &remainder[4..]; + } + (ret, remainder) + } +} + #[test] fn roundtrip() { let n = 100i64; @@ -101,16 +132,254 @@ fn roundtrip() { assert_eq!(arrays, arrays2); } +fn plain_field(name: &str, dtype: &str) -> Field { + Field { + name: name.to_string(), + dtype: DType::new_scalar(dtype.parse().unwrap()), + } +} + #[test] -fn roundtrip_with_simple_dtype() { +fn roundtrip_with_plain_dtype() { let array_written = vec![2., 3., 4., 5.]; - npy::to_file("tests/roundtrip_simple.npy", array_written.clone()).unwrap(); + npy::to_file("tests/roundtrip_plain.npy", array_written.clone()).unwrap(); let mut buffer = vec![]; - std::fs::File::open("tests/roundtrip_simple.npy").unwrap() + std::fs::File::open("tests/roundtrip_plain.npy").unwrap() .read_to_end(&mut buffer).unwrap(); let array_read = npy::NpyData::from_bytes(&buffer).unwrap().to_vec(); assert_eq!(array_written, array_read); } + +#[test] +fn roundtrip_byteorder() { + let path = "tests/roundtrip_byteorder.npy"; + + #[derive(npy::Serialize, npy::Deserialize)] + #[derive(Debug, PartialEq, Clone)] + struct Row { + be_u32: u32, + le_u32: u32, + be_f32: f32, + le_f32: f32, + be_i8: i8, + le_i8: i8, + na_i8: i8, + } + + let dtype = DType::Record(vec![ + plain_field("be_u32", ">u4"), + plain_field("le_u32", "f4"), + plain_field("le_f32", "i1"), + plain_field("le_i8", "::from_bytes(&buffer).unwrap(); + assert_eq!(data.to_vec(), vec![row]); + assert_eq!(data.dtype(), &dtype); +} + +#[test] +fn roundtrip_datetime() { + let path = "tests/roundtrip_datetime.npy"; + + // Similar to: + // + // ``` + // import numpy.datetime64 as dt + // import numpy as np + // + // arr = np.array([( + // dt('2011-01-01', 'ns'), + // dt('2011-01-02') - dt('2011-01-01'), + // 
dt('2011-01-02') - dt('2011-01-01'), + // )], dtype=[ + // ('datetime', 'm8[D]'), + // ]) + // ``` + #[derive(npy::Serialize, npy::Deserialize)] + #[derive(Debug, PartialEq, Clone)] + struct Row { + datetime: u64, + timedelta_le: i64, + timedelta_be: i64, + } + + let dtype = DType::Record(vec![ + plain_field("datetime", "m8[D]"), + ]); + + let row = Row { + datetime: 1_293_840_000_000_000_000, + timedelta_le: 1, + timedelta_be: 1, + }; + + let expected_data_bytes = { + let mut buf = vec![]; + buf.extend_from_slice(&i64::to_le_bytes(1_293_840_000_000_000_000)); + buf.extend_from_slice(&i64::to_le_bytes(1)); + buf.extend_from_slice(&i64::to_be_bytes(1)); + buf + }; + + let mut out_file = OutFile::open_with_dtype(&dtype, path).unwrap(); + out_file.push(&row).unwrap(); + out_file.close().unwrap(); + + let buffer = std::fs::read(path).unwrap(); + assert!(buffer.ends_with(&expected_data_bytes)); + + let data = npy::NpyData::::from_bytes(&buffer).unwrap(); + assert_eq!(data.to_vec(), vec![row]); + assert_eq!(data.dtype(), &dtype); +} + +#[test] +fn roundtrip_bytes() { + let path = "tests/roundtrip_bytes.npy"; + + // Similar to: + // + // ``` + // import numpy as np + // + // arr = np.array([( + // b"\x00such\x00wow", + // b"\x00such\x00wow\x00\x00\x00", + // )], dtype=[ + // ('bytestr', '|S12'), + // ('raw', '|V12'), + // ]) + // ``` + #[derive(npy::Serialize, npy::Deserialize)] + #[derive(Debug, PartialEq, Clone)] + struct Row { + bytestr: Vec, + raw: Vec, + } + + let dtype = DType::Record(vec![ + plain_field("bytestr", "|S12"), + plain_field("raw", "|V12"), + ]); + + let row = Row { + // checks that: + // * bytestr can be shorter than the len + // * bytestr can contain non-trailing NULs + bytestr: b"\x00lol\x00lol".to_vec(), + // * raw can contain trailing NULs + raw: b"\x00lol\x00lol\x00\x00\x00\x00".to_vec(), + }; + + let expected_data_bytes = { + let mut buf = vec![]; + // check that bytestr is nul-padded + buf.extend_from_slice(b"\x00lol\x00lol\x00\x00\x00\x00"); + 
buf.extend_from_slice(b"\x00lol\x00lol\x00\x00\x00\x00"); + buf + }; + + let mut out_file = OutFile::open_with_dtype(&dtype, path).unwrap(); + out_file.push(&row).unwrap(); + out_file.close().unwrap(); + + let buffer = std::fs::read(path).unwrap(); + assert!(buffer.ends_with(&expected_data_bytes)); + + let data = npy::NpyData::::from_bytes(&buffer).unwrap(); + assert_eq!(data.to_vec(), vec![row]); + assert_eq!(data.dtype(), &dtype); +} + +// check that all byte orders are identical for bytestrings +// (i.e. don't accidentally reverse the bytestrings) +#[test] +fn roundtrip_bytes_byteorder() { + let path = "tests/roundtrip_bytes_byteorder.npy"; + + #[derive(npy::Serialize, npy::Deserialize)] + #[derive(Debug, PartialEq, Clone)] + struct Row { + s_le: Vec, + s_be: Vec, + s_na: Vec, + v_le: Vec, + v_be: Vec, + v_na: Vec, + }; + + let dtype = DType::Record(vec![ + plain_field("s_le", "S4"), + plain_field("s_na", "|S4"), + plain_field("v_le", "V4"), + plain_field("v_na", "|V4"), + ]); + + let row = Row { + s_le: b"abcd".to_vec(), + s_be: b"abcd".to_vec(), + s_na: b"abcd".to_vec(), + v_le: b"abcd".to_vec(), + v_be: b"abcd".to_vec(), + v_na: b"abcd".to_vec(), + }; + + let expected_data_bytes = { + let mut buf = vec![]; + for _ in 0..6 { + buf.extend_from_slice(b"abcd"); + } + buf + }; + + let mut out_file = OutFile::open_with_dtype(&dtype, path).unwrap(); + out_file.push(&row).unwrap(); + out_file.close().unwrap(); + + let buffer = std::fs::read(path).unwrap(); + assert!(buffer.ends_with(&expected_data_bytes)); + + let data = npy::NpyData::::from_bytes(&buffer).unwrap(); + assert_eq!(data.to_vec(), vec![row]); + assert_eq!(data.dtype(), &dtype); +} diff --git a/tests/serialize_array.rs b/tests/serialize_array.rs new file mode 100644 index 0000000..4f746b1 --- /dev/null +++ b/tests/serialize_array.rs @@ -0,0 +1,125 @@ +extern crate npy; + +use npy::{Deserialize, Serialize, AutoSerialize, DType, TypeStr, Field}; +use npy::{TypeRead, TypeWrite}; + +// These tests ideally 
would be in npy::serialize::tests, but they require "derive" +// because arrays can only exist as record fields. + +fn reader_output(dtype: &DType, bytes: &[u8]) -> T { + T::reader(dtype).unwrap_or_else(|e| panic!("{}", e)).read_one(bytes).0 +} + +fn reader_expect_err(dtype: &DType) { + T::reader(dtype).err().expect("reader_expect_err failed!"); +} + +fn writer_output(dtype: &DType, value: &T) -> Vec { + let mut vec = vec![]; + T::writer(dtype).unwrap_or_else(|e| panic!("{}", e)) + .write_one(&mut vec, value).unwrap(); + vec +} + +fn writer_expect_err(dtype: &DType) { + T::writer(dtype).err().expect("writer_expect_err failed!"); +} + +#[derive(npy::Serialize, npy::Deserialize, npy::AutoSerialize)] +#[derive(Debug, PartialEq)] +struct Array3 { + field: [i32; 3], +} + +#[derive(npy::Serialize, npy::Deserialize, npy::AutoSerialize)] +#[derive(Debug, PartialEq)] +struct Array23 { + field: [[i32; 3]; 2], +} + +const ARRAY3_DESCR_LE: &str = "[('field', '(&dtype, &bytes), value); + assert_eq!(writer_output::(&dtype, &value), bytes); + reader_expect_err::(&dtype); + writer_expect_err::(&dtype); +} + +#[test] +fn read_write_nested() { + let dtype = DType::parse(ARRAY23_DESCR_LE).unwrap(); + let value = Array23 { field: [[1, 3, 5], [7, 9, 11]] }; + let mut bytes = vec![]; + for n in vec![1, 3, 5, 7, 9, 11] { + bytes.extend_from_slice(&i32::to_le_bytes(n)); + } + + assert_eq!(reader_output::(&dtype, &bytes), value); + assert_eq!(writer_output::(&dtype, &value), bytes); + reader_expect_err::(&dtype); + writer_expect_err::(&dtype); +} + +#[test] +fn incompatible() { + // wrong size + let dtype = DType::parse(ARRAY2_DESCR_LE).unwrap(); + writer_expect_err::(&dtype); + reader_expect_err::(&dtype); + + // scalar instead of array + let dtype = DType::parse(ARRAY_SCALAR_DESCR_LE).unwrap(); + writer_expect_err::(&dtype); + reader_expect_err::(&dtype); + + // record instead of array + let dtype = DType::parse(ARRAY_RECORD_DESCR_LE).unwrap(); + writer_expect_err::(&dtype); + 
reader_expect_err::(&dtype); +} + +#[test] +fn default_dtype() { + let int_ty: TypeStr = { + if 1 == i32::from_be(1) { + ">i4".parse().unwrap() + } else { + "