Skip to content

Commit

Permalink
Merge pull request #22 from asynchronics/feature/custom-deserialize
Browse files Browse the repository at this point in the history
Validate nanoseconds field on deserialize
  • Loading branch information
sbarral authored Jun 13, 2024
2 parents bf4c504 + 0a16f81 commit 3fbc2e9
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 1 deletion.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ serde = { version = "1", default-features = false, features = ["derive"], option
nix = { version = "0.26", default-features = false, features = ["time"], optional = true }
defmt = { version = "0.3", optional = true }

[dev-dependencies]
serde_json = "1"

[features]
default = ["std"]
std = []
Expand Down
160 changes: 159 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ pub type Tai1972Time = TaiTime<63_072_000>;
/// assert_eq!(timestamp.subsec_nanos(), 789_333_333);
/// ```
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct TaiTime<const EPOCH_REF: i64> {
/// The number of whole seconds in the future (if positive) or in the past
Expand Down Expand Up @@ -1416,6 +1416,116 @@ impl<const EPOCH_REF: i64> fmt::Display for TaiTime<EPOCH_REF> {
}
}

#[cfg(feature = "serde")]
impl<'de, const EPOCH_REF: i64> serde::Deserialize<'de> for TaiTime<EPOCH_REF> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
use serde::de::{self, Deserialize, Deserializer, MapAccess, SeqAccess, Visitor};

enum Field {
Secs,
Nanos,
}

impl<'de> Deserialize<'de> for Field {
fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
where
D: Deserializer<'de>,
{
struct FieldVisitor;

impl<'de> Visitor<'de> for FieldVisitor {
type Value = Field;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("`secs` or `nanos`")
}

fn visit_str<E>(self, value: &str) -> Result<Field, E>
where
E: de::Error,
{
match value {
"secs" => Ok(Field::Secs),
"nanos" => Ok(Field::Nanos),
_ => Err(de::Error::unknown_field(value, FIELDS)),
}
}
}

deserializer.deserialize_identifier(FieldVisitor)
}
}

struct DurationVisitor<const EPOCH_REF: i64>;

impl<'de, const EPOCH_REF: i64> Visitor<'de> for DurationVisitor<EPOCH_REF> {
type Value = TaiTime<EPOCH_REF>;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct TaiTime")
}

fn visit_seq<V>(self, mut seq: V) -> Result<TaiTime<EPOCH_REF>, V::Error>
where
V: SeqAccess<'de>,
{
let secs = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(0, &self))?;
let nanos = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(1, &self))?;

TaiTime::new(secs, nanos).ok_or_else(|| {
de::Error::invalid_value(
de::Unexpected::Unsigned(nanos as u64),
&"a number of nanoseconds between 0 and 999999999",
)
})
}

fn visit_map<V>(self, mut map: V) -> Result<TaiTime<EPOCH_REF>, V::Error>
where
V: MapAccess<'de>,
{
let mut secs = None;
let mut nanos = None;
while let Some(key) = map.next_key()? {
match key {
Field::Secs => {
if secs.is_some() {
return Err(de::Error::duplicate_field("secs"));
}
secs = Some(map.next_value()?);
}
Field::Nanos => {
if nanos.is_some() {
return Err(de::Error::duplicate_field("nanos"));
}
nanos = Some(map.next_value()?);
}
}
}
let secs = secs.ok_or_else(|| de::Error::missing_field("secs"))?;
let nanos = nanos.ok_or_else(|| de::Error::missing_field("nanos"))?;

TaiTime::new(secs, nanos).ok_or_else(|| {
de::Error::invalid_value(
de::Unexpected::Unsigned(nanos as u64),
&"a number of nanoseconds between 0 and 999999999",
)
})
}
}

const FIELDS: &[&str] = &["secs", "nanos"];
deserializer.deserialize_struct("TaiTime", FIELDS, DurationVisitor::<EPOCH_REF>)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -1927,4 +2037,52 @@ mod tests {
assert!(date_time_str.parse::<MonotonicTime>().is_err());
}
}

#[cfg(feature = "serde")]
#[test]
fn deserialize_from_seq() {
use serde_json;

let data = r#"[987654321, 123456789]"#;

let t: GpsTime = serde_json::from_str(data).unwrap();
assert_eq!(t, GpsTime::new(987654321, 123456789).unwrap());
}

#[cfg(feature = "serde")]
#[test]
fn deserialize_from_map() {
use serde_json;

let data = r#"{"secs": 987654321, "nanos": 123456789}"#;

let t: GpsTime = serde_json::from_str(data).unwrap();
assert_eq!(t, GpsTime::new(987654321, 123456789).unwrap());
}

#[cfg(feature = "serde")]
#[test]
fn deserialize_invalid_nanos() {
use serde_json;

let data = r#"{"secs": 987654321, "nanos": 1000000000}"#;

let t: Result<GpsTime, serde_json::Error> = serde_json::from_str(data);

assert!(t.is_err())
}

#[cfg(feature = "serde")]
#[test]
fn serialize_roundtrip() {
use serde_json;

let t0 = GpsTime::new(987654321, 123456789).unwrap();

let data = serde_json::to_string(&t0).unwrap();

let t1: GpsTime = serde_json::from_str(&data).unwrap();

assert_eq!(t0, t1);
}
}

0 comments on commit 3fbc2e9

Please sign in to comment.