From e4f1861509a996408baa540cda1f88c290214ce1 Mon Sep 17 00:00:00 2001 From: Preston Evans Date: Sat, 9 Sep 2023 14:28:05 -0500 Subject: [PATCH] Check that reader empty --- src/de.rs | 35 ++++++++++++++++++++++++----------- tests/serde.rs | 4 ++++ 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/src/de.rs b/src/de.rs index 9d4f028..2c7f738 100644 --- a/src/de.rs +++ b/src/de.rs @@ -59,7 +59,8 @@ where T: DeserializeOwned, { let mut deserializer = Deserializer::from_reader(reader, crate::MAX_CONTAINER_DEPTH); - T::deserialize(&mut deserializer) + let t = T::deserialize(&mut deserializer)?; + deserializer.end().map(move |_| t) } /// Deserialize a type from an implementation of [`Read`] using the provided seed @@ -71,7 +72,8 @@ where for<'a> T: DeserializeSeed<'a>, { let mut deserializer = Deserializer::from_reader(reader, crate::MAX_CONTAINER_DEPTH); - seed.deserialize(&mut deserializer) + let t = seed.deserialize(&mut deserializer)?; + deserializer.end().map(move |_| t) } /// Deserialization implementation for BCS @@ -146,6 +148,11 @@ trait BcsDeserializer<'de> { seed: K, ) -> Result<(K::Value, Self::MaybeBorrowedBytes), Error>; + /// The `Deserializer::end` method should be called after a type has been + /// fully deserialized. This allows the `Deserializer` to validate that + /// the there are no more bytes remaining in the input stream. + fn end(&mut self) -> Result<()>; + fn parse_bool(&mut self) -> Result { let byte = self.next()?; @@ -266,6 +273,15 @@ impl<'de, R: Read> BcsDeserializer<'de> for Deserializer> { let key_bytes = self.input.capture_buffer.take().unwrap(); Ok((key_value, key_bytes)) } + + fn end(&mut self) -> Result<()> { + let mut byte = [0u8; 1]; + match self.input.read_exact(&mut byte) { + Ok(_) => Err(Error::RemainingInput), + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Ok(()), + Err(e) => Err(e.into()), + } + } } impl<'de> BcsDeserializer<'de> for Deserializer<&'de [u8]> { @@ -307,16 +323,7 @@ impl<'de> BcsDeserializer<'de> for Deserializer<&'de [u8]> { let key_bytes = &previous_input_slice[..key_len]; Ok((key_value, key_bytes)) } -} - -impl<'de> Deserializer<&'de [u8]> { - fn peek(&mut self) -> Result { - self.input.first().copied().ok_or(Error::Eof) - } - /// The `Deserializer::end` method should be called after a type has been - /// fully deserialized. This allows the `Deserializer` to validate that - /// the there are no more bytes remaining in the input stream. fn end(&mut self) -> Result<()> { if self.input.is_empty() { Ok(()) @@ -324,6 +331,12 @@ impl<'de> Deserializer<&'de [u8]> { Err(Error::RemainingInput) } } +} + +impl<'de> Deserializer<&'de [u8]> { + fn peek(&mut self) -> Result { + self.input.first().copied().ok_or(Error::Eof) + } fn parse_bytes(&mut self) -> Result<&'de [u8]> { let len = self.parse_length()?; diff --git a/tests/serde.rs b/tests/serde.rs index 0d49e0b..caa26e9 100644 --- a/tests/serde.rs +++ b/tests/serde.rs @@ -463,6 +463,10 @@ fn by_default_btreesets_are_serialized_as_sequences() { fn leftover_bytes() { let seq = vec![5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; // 5 extra elements assert_eq!(from_bytes::>(&seq), Err(Error::RemainingInput)); + assert_eq!( + from_bytes_via_reader::>(&seq), + Err(Error::RemainingInput) + ); } #[test]