Skip to content

Commit

Permalink
Fix Serde implementation (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
james7132 authored Feb 26, 2024
1 parent 84c1324 commit fab28de
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
rust: [1.39.0, stable, nightly]
rust: [1.56.0, stable, nightly]

steps:
- uses: actions/checkout@v2
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
target/
.idea/
Cargo.lock
.vscode
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ version = "0.4.2"
authors = ["bluss"]
license = "MIT OR Apache-2.0"
readme = "README.md"
rust-version = "1.56"
edition = "2021"

description = "FixedBitSet is a simple bitset collection"
documentation = "https://docs.rs/fixedbitset/"
Expand All @@ -21,7 +23,7 @@ no-dev-version = true
tag-name = "{{version}}"

[dependencies]
serde = { version = "1.0", features = ["derive"], optional = true }
serde = { version = "1.0", optional = true }

[dev-dependencies]
serde_json = "1.0"
2 changes: 1 addition & 1 deletion benches/benches/benches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,12 @@ fn count_ones(c: &mut Criterion) {

criterion_group!(
benches,
bitchange,
iter_ones_using_contains_all_zeros,
iter_ones_using_contains_all_ones,
iter_ones_all_zeros,
iter_ones_sparse,
iter_ones_all_ones,
iter_ones_all_ones_rev,
insert_range,
insert,
intersect_with,
Expand Down
17 changes: 10 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ mod range;
#[cfg(feature = "serde")]
extern crate serde;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
mod serde_impl;

use std::fmt::Write;
use std::fmt::{Binary, Display, Error, Formatter};
Expand All @@ -38,7 +38,10 @@ use std::cmp::{Ord, Ordering};
use std::iter::{Chain, ExactSizeIterator, FromIterator, FusedIterator};
use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Index};

const BITS: usize = std::mem::size_of::<Block>() * 8;
pub(crate) const BITS: usize = std::mem::size_of::<Block>() * 8;
#[cfg(feature = "serde")]
pub(crate) const BYTES: usize = std::mem::size_of::<Block>();

pub type Block = usize;

#[inline]
Expand All @@ -55,11 +58,10 @@ fn div_rem(x: usize) -> (usize, usize) {
/// Derived traits depend on both the zeros and ones, so [0,1] is not equal to
/// [0,1,0].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct FixedBitSet {
data: Vec<Block>,
pub(crate) data: Vec<Block>,
/// length in bits
length: usize,
pub(crate) length: usize,
}

impl FixedBitSet {
Expand Down Expand Up @@ -2022,16 +2024,17 @@ mod tests {
assert_eq!(format!("{:#}", fb), "0b00101000");
}

// TODO: Rewite this test to be platform agnostic.
#[test]
#[cfg(feature = "serde")]
#[cfg(all(feature = "serde", target_pointer_width = "64"))]
fn test_serialize() {
let mut fb = FixedBitSet::with_capacity(10);
fb.put(2);
fb.put(3);
fb.put(6);
fb.put(8);
let serialized = serde_json::to_string(&fb).unwrap();
assert_eq!(r#"{"data":[332],"length":10}"#, serialized);
assert_eq!(r#"{"length":10,"data":[76,1,0,0,0,0,0,0]}"#, serialized);
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/range.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(not(feature = "std"))]
use core as std;
use std::ops::{Range, RangeFrom, RangeFull, RangeTo};

// Taken from https://github.com/bluss/odds/blob/master/src/range.rs.
Expand Down
143 changes: 143 additions & 0 deletions src/serde_impl.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
#[cfg(not(feature = "std"))]
use core as std;

use crate::{FixedBitSet, BYTES};
use serde::de::{self, Deserialize, Deserializer, MapAccess, SeqAccess, Visitor};
use serde::ser::{Serialize, SerializeStruct, Serializer};
use std::{convert::TryFrom, fmt};

struct BitSetByteSerializer<'a>(&'a FixedBitSet);

impl Serialize for FixedBitSet {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut struct_serializer = serializer.serialize_struct("FixedBitset", 2)?;
struct_serializer.serialize_field("length", &(self.length as u64))?;
struct_serializer.serialize_field("data", &BitSetByteSerializer(self))?;
struct_serializer.end()
}
}

impl<'a> Serialize for BitSetByteSerializer<'a> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let len = self.0.data.len() * BYTES;
// PERF: Figure out a way to do this without allocating.
let mut temp = Vec::with_capacity(len);
for block in &self.0.data {
temp.extend(&block.to_le_bytes());
}
serializer.serialize_bytes(&temp)
}
}

impl<'de> Deserialize<'de> for FixedBitSet {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
enum Field {
Length,
Data,
}

impl<'de> Deserialize<'de> for Field {
fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
where
D: Deserializer<'de>,
{
struct FieldVisitor;

impl<'de> Visitor<'de> for FieldVisitor {
type Value = Field;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("`length` or `data`")
}

fn visit_str<E>(self, value: &str) -> Result<Field, E>
where
E: de::Error,
{
match value {
"length" => Ok(Field::Length),
"data" => Ok(Field::Data),
_ => Err(de::Error::unknown_field(value, FIELDS)),
}
}
}

deserializer.deserialize_identifier(FieldVisitor)
}
}

struct FixedBitSetVisitor;

impl<'de> Visitor<'de> for FixedBitSetVisitor {
type Value = FixedBitSet;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct Duration")
}

fn visit_seq<V>(self, mut seq: V) -> Result<FixedBitSet, V::Error>
where
V: SeqAccess<'de>,
{
let length = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(0, &self))?;
let data = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(1, &self))?;
Ok(FixedBitSet { length, data })
}

fn visit_map<V>(self, mut map: V) -> Result<FixedBitSet, V::Error>
where
V: MapAccess<'de>,
{
let mut length = None;
let mut temp: Option<&[u8]> = None;
while let Some(key) = map.next_key()? {
match key {
Field::Length => {
if length.is_some() {
return Err(de::Error::duplicate_field("length"));
}
length = Some(map.next_value()?);
}
Field::Data => {
if temp.is_some() {
return Err(de::Error::duplicate_field("data"));
}
temp = Some(map.next_value()?);
}
}
}
let length = length.ok_or_else(|| de::Error::missing_field("length"))?;
let temp = temp.ok_or_else(|| de::Error::missing_field("data"))?;
let block_len = length / BYTES + 1;
let mut data = Vec::with_capacity(block_len);
for chunk in temp.chunks(BYTES) {
match <&[u8; BYTES]>::try_from(chunk) {
Ok(bytes) => data.push(usize::from_le_bytes(*bytes)),
Err(_) => {
let mut bytes = [0u8; BYTES];
bytes[0..BYTES].copy_from_slice(chunk);
data.push(usize::from_le_bytes(bytes));
}
}
}
Ok(FixedBitSet { length, data })
}
}

const FIELDS: &'static [&'static str] = &["length", "data"];
deserializer.deserialize_struct("Duration", FIELDS, FixedBitSetVisitor)
}
}

0 comments on commit fab28de

Please sign in to comment.