diff --git a/Cargo.lock b/Cargo.lock index 552ae19..de63307 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2255,6 +2255,7 @@ dependencies = [ "serde_json", "tempfile", "test-log", + "thiserror", "tokio", "tonic", "tower 0.5.1", @@ -3573,9 +3574,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.85" +version = "2.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56" +checksum = "e89275301d38033efb81a6e60e3497e734dfcc62571f2854bf4b16690398824c" dependencies = [ "proc-macro2", "quote", @@ -3670,18 +3671,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.65" +version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" +checksum = "5d171f59dbaa811dbbb1aee1e73db92ec2b122911a48e1390dfe327a821ddede" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.65" +version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" +checksum = "b08be0f17bd307950653ce45db00cd31200d82b624b36e181337d9c7d92765b5" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 1fea656..795181c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ regex-syntax = "0.8.4" reqwest = { version = "0.12.5", default-features = false, features = ["rustls-tls", "rustls-tls-native-roots", "json", "gzip", "deflate"] } serde_json = "1.0.120" tempfile = "3.10.1" +thiserror = "1.0.66" tonic = "0.12.1" tokio = { version = "1.38.1", features = ["full"] } tower = { version = "0.5.1", features = [ "full" ] } diff --git a/src/dynamic_message.rs b/src/dynamic_message.rs index 5a8707c..ccbb89e 100644 --- a/src/dynamic_message.rs +++ b/src/dynamic_message.rs @@ -4,10 +4,11 @@ use std::collections::HashMap; use std::iter::Peekable; use std::slice::Iter; -use anyhow::anyhow; +use anyhow::{anyhow, bail}; use bytes::{BufMut, Bytes}; use itertools::Itertools; use pact_matching::generators::DefaultVariantMatcher; +use pact_models::expression_parser::DataValue; use pact_models::generators::{ GenerateValue, Generator, @@ -24,6 +25,7 @@ use tonic::Status; use tracing::{debug, error, instrument, trace, warn}; use crate::message_decoder::{decode_message, ProtobufField, ProtobufFieldData}; +use crate::message_decoder::generators::{data_value_to_proto_value, GeneratorError}; #[derive(Debug, Clone)] pub struct PactCodec { @@ -161,14 +163,14 @@ impl DynamicMessage { } /// Retrieve the value for a message field using the given path - #[instrument(ret, skip(self))] - pub fn fetch_field_value(&mut self, path: &DocPath) -> Option { + #[instrument(ret, skip(self), fields(path = %path))] + pub fn fetch_field_value(&mut self, path: &DocPath) -> Option> { let path_tokens = path.tokens().clone(); let mut iter = path_tokens.iter().peekable(); if let Some(PathToken::Root) = iter.peek() { iter.next(); let mut found = None; - let result = self.match_path(&mut iter, |v| { + let result = self.match_path(&mut iter, |v, _| { found.replace(v.clone()); }); if let Err(err) = result { @@ -181,37 +183,45 @@ impl DynamicMessage { } /// Update the value using the given path - #[instrument(ret, skip(self))] + #[instrument(ret, skip(self), fields(path = %path))] pub fn set_field_value(&mut self, path: &DocPath, value: ProtobufFieldData) -> anyhow::Result<()> { let path_tokens = path.tokens().clone(); let mut iter = path_tokens.iter().peekable(); if let Some(PathToken::Root) = iter.peek() { iter.next(); - self.match_path(&mut iter, |v| { - v.data = value.clone(); + self.match_path(&mut iter, |v, segment| { + if let Some(PathToken::Index(index)) = segment { + if index >= v.len() { + v.resize(index + 1, v[0].clone()); + } + v[index].data = value.clone(); + } else { + v[0].data = value.clone(); + } }) } else { Err(anyhow!("Path '{}' does not start with a root path marker ('$')", path)) } } - #[instrument(skip(self, callback))] fn match_path( &mut self, path_tokens: &mut Peekable>, callback: F - ) -> anyhow::Result<()> where F: FnOnce(&mut ProtobufField) { + ) -> anyhow::Result<()> where F: FnOnce(&mut Vec, Option) { let descriptors = self.descriptors.clone(); let fields = &mut self.fields; if let Some(next) = path_tokens.next() { match next { - PathToken::Root => Ok(()), - PathToken::Field(name) => if let Some(field) = find_field_value(fields, name.as_str()) { + PathToken::Root => {}, + PathToken::Field(name) => return if let Some(field) = find_field_values(fields, name.as_str()) { if path_tokens.peek().is_none() { - callback(field); + callback(field, None); Ok(()) } else { - match &mut field.data { + // OK to unwrap here, as if the vec was empty, find_field_values would have skipped it. + let first_entry = field.first_mut().unwrap(); + match &mut first_entry.data { ProtobufFieldData::Enum(_, _) => Err(anyhow!("Support for dynamically fetching enum values is not supported yet")), ProtobufFieldData::Message(data, descriptor) => { let mut buffer = Bytes::copy_from_slice(data); @@ -230,22 +240,39 @@ impl DynamicMessage { } } }, - _ => { - warn!("Ignoring field of type '{}'", field.data.type_name()); - Ok(()) + _ => match path_tokens.next() { + Some(PathToken::Star) | Some(PathToken::StarIndex) => { + if path_tokens.peek().is_none() { + callback(field, None); + Ok(()) + } else { + Err(anyhow!("Path does not match any field in the message (additional path \ + segments can only be applied to a child message, but field type is '{}')", first_entry.data.type_name())) + } + } + Some(PathToken::Index(index)) => if first_entry.repeated_field() && path_tokens.peek().is_none() { + callback(field, Some(PathToken::Index(*index))); + Ok(()) + } else { + Err(anyhow!("Path segment '{}' can only be applied to repeated fields", index)) + } + Some(segment) => Err(anyhow!("Path segment '{}' can not be applied any field in the message", segment)), + None => Err(anyhow!("Path name '{}' does not match any field in the message", name)) } } } } else { - Err(anyhow!("Path '{}' does not match any field int the message", name)) - } - PathToken::Index(_) => Err(anyhow!("Support for index paths is not supported yet")), - PathToken::Star => Err(anyhow!("Support for '*' in paths is not supported yet")), - PathToken::StarIndex => Err(anyhow!("Support for '[*]' in paths is not supported yet")), + Err(anyhow!("Path name '{}' does not match any field in the message", name)) + }, + PathToken::Index(_) => return Err(anyhow!("Support for index paths is not supported yet")), + PathToken::Star => return Err(anyhow!("Support for '*' in paths is not supported yet")), + PathToken::StarIndex => return Err(anyhow!("Support for '[*]' in paths is not supported yet")), } } else { - Err(anyhow!("Path does not match any field int the message")) + return Err(anyhow!("Path does not match any field in the message (end of path tokens reached)")) } + + Ok(()) } /// Mutates the message by applying the generators to any matching message fields @@ -263,8 +290,31 @@ impl DynamicMessage { let value = self.fetch_field_value(&path); if let Some(value) = value { if generator.corresponds_to_mode(mode) { - let generated_value = generator.generate_value(&value.data, &context, &vm_boxed)?; - self.set_field_value(&path, generated_value)?; + // OK to unwrap here, for if the vec was empty, fetch_field_value would have returned None. + let first_entry = value.first().unwrap(); + match generator.generate_value(&first_entry.data, &context, &vm_boxed) { + Ok(generated_value) => { + self.set_field_value(&path, generated_value)?; + } + Err(err) => { + warn!("Failed to apply generator '{}' for field {}: {}", path, first_entry, err); + if let Some(GeneratorError::ProviderStateValueIsCollection(val)) = err.downcast_ref::() { + if first_entry.repeated_field() && val.wrapped.is_array() { + let array = as_array(val)?; + trace!("Applying a array value ({} items) to repeated field '{}'", array.len(), first_entry.field_name); + for (index, dv) in array.iter().enumerate() { + let index_path = path_join_index(path, index); + let pv = data_value_to_proto_value(&first_entry.data, dv)?; + self.set_field_value(&index_path, pv)?; + } + } else { + bail!(err); + } + } else { + bail!(err); + } + } + } } } else { warn!("No matching field found for generator '{}'", path); @@ -276,20 +326,46 @@ impl DynamicMessage { } } -// TODO: This only supports the first value, needs to deal with repeated fields -fn find_field_value<'a>( +// TODO: Replace this with DocPath.join_index when pact_models 1.2.5 is released +fn path_join_index(path: &DocPath, index: usize) -> DocPath { + let mut new_path = path.clone(); + match path.tokens().last() { + Some(PathToken::Root) => { new_path.push_index(index); } + Some(PathToken::Field(_)) => { new_path.push_index(index); } + Some(PathToken::Index(_)) => { new_path.push_index(index); } + Some(PathToken::Star) | Some(PathToken::StarIndex) => { + let tokens = new_path.tokens().clone(); + new_path = DocPath::empty(); + for token in tokens.iter().dropping_back(1) { + new_path.push(token.clone()); + } + new_path.push_index(index); + } + None => { new_path.push_index(index); } + } + new_path +} + +fn as_array(data: &DataValue) -> anyhow::Result> { + if let Value::Array(values) = &data.wrapped { + Ok(values.iter() + .map(|v| DataValue { + wrapped: v.clone(), + data_type: data.data_type + }) + .collect()) + } else { + Err(anyhow!("Value {} is not an array", data.wrapped)) + } +} + +fn find_field_values<'a>( fields: &'a mut HashMap>, field_name: &str -) -> Option<&'a mut ProtobufField> { +) -> Option<&'a mut Vec> { fields.iter_mut() .find(|(_, fields)| fields.iter().any(|field| field.field_name == field_name)) - .map(|(_, fields)| { - if fields.len() > 1 { - warn!("There is more than one field value"); - } - fields.get_mut(0) - }) - .flatten() + .map(|(_, fields)| fields) } #[derive(Debug, Clone)] @@ -387,7 +463,7 @@ mod tests { let descriptor = DescriptorProto::default(); let mut message = DynamicMessage::new(&descriptor, fields.as_slice(), &descriptors); let path = DocPath::new("one").unwrap(); - expect!(message.fetch_field_value(&path)).to(be_some().value(field)); + expect!(message.fetch_field_value(&path)).to(be_some().value(fields)); } #[test] @@ -406,7 +482,7 @@ mod tests { let fields = vec![ field.clone() ]; let mut message = DynamicMessage::new(&descriptor, fields.as_slice(), &descriptors); let path = DocPath::new("$.one").unwrap(); - expect!(message.fetch_field_value(&path)).to(be_some().value(field)); + expect!(message.fetch_field_value(&path)).to(be_some().value(fields)); } #[test] @@ -483,7 +559,7 @@ mod tests { let fields = vec![ field.clone() ]; let mut message = DynamicMessage::new(&descriptor, fields.as_slice(), &descriptors); let path = DocPath::new("$.one.two").unwrap(); - expect!(message.fetch_field_value(&path)).to(be_some().value(child_field)); + expect!(message.fetch_field_value(&path)).to(be_some().value(vec![child_field])); } #[test] @@ -605,6 +681,6 @@ mod tests { }; expect!(message.apply_generators(Some(&generators), &GeneratorTestMode::Provider, &hashmap!{})).to(be_ok()); - expect!(message.fetch_field_value(&path).unwrap().data.as_i64().unwrap()).to_not(be_equal_to(100)); + expect!(message.fetch_field_value(&path).unwrap().first().unwrap().data.as_i64().unwrap()).to_not(be_equal_to(100)); } } diff --git a/src/message_decoder/generators.rs b/src/message_decoder/generators.rs index 24cb97a..d8f0b61 100644 --- a/src/message_decoder/generators.rs +++ b/src/message_decoder/generators.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use anyhow::anyhow; use chrono::{DateTime, Local}; +use pact_models::expression_parser::DataValue; use pact_models::generators::{ generate_ascii_string, generate_decimal, @@ -23,12 +24,19 @@ use rand::prelude::*; use regex::{Captures, Regex}; use serde_json::Value; use serde_json::Value::Object; +use thiserror::Error; use tracing::{debug, instrument, trace, warn}; use uuid::Uuid; use crate::message_decoder::ProtobufFieldData; use crate::metadata::MessageMetadataValue; +#[derive(Error, Debug)] +pub enum GeneratorError { + #[error("Provider state value is a collection (Array or Object), and can not be injected into a single field")] + ProviderStateValueIsCollection(DataValue) +} + impl GenerateValue for Generator { #[instrument(ret)] fn generate_value(&self, @@ -197,19 +205,11 @@ impl GenerateValue for Generator { } else { context.clone() }; - match generate_value_from_context(exp, &provider_state_config, dt) { - Ok(val) => match value { - ProtobufFieldData::String(_) => Ok(ProtobufFieldData::String(val.to_string())), - ProtobufFieldData::Boolean(_) => Ok(ProtobufFieldData::Boolean(bool::try_from(val)?)), - ProtobufFieldData::UInteger32(_) => Ok(ProtobufFieldData::UInteger32(u64::try_from(val)? as u32)), - ProtobufFieldData::Integer32(_) => Ok(ProtobufFieldData::Integer32(i64::try_from(val)? as i32)), - ProtobufFieldData::UInteger64(_) => Ok(ProtobufFieldData::UInteger64(u64::try_from(val)?)), - ProtobufFieldData::Integer64(_) => Ok(ProtobufFieldData::Integer64(i64::try_from(val)?)), - ProtobufFieldData::Float(_) => Ok(ProtobufFieldData::Float(f64::try_from(val)? as f32)), - ProtobufFieldData::Double(_) => Ok(ProtobufFieldData::Double(f64::try_from(val)?)), - _ => Err(anyhow!("Can not generate a value from the provider state for a field type {:?}", value)) - }, - Err(err) => Err(err) + let val = generate_value_from_context(exp, &provider_state_config, dt)?; + if val.wrapped.is_array() || val.wrapped.is_object() { + Err(anyhow!(GeneratorError::ProviderStateValueIsCollection(val.clone()))) + } else { + data_value_to_proto_value(value, &val) } } Generator::MockServerURL(example, regex) => { @@ -264,6 +264,21 @@ fn replace_with_regex(example: &String, url: String, re: Regex) -> String { }).to_string() } +pub fn data_value_to_proto_value(value: &ProtobufFieldData, val: &DataValue) -> anyhow::Result { + let val = val.clone(); + match value { + ProtobufFieldData::String(_) => Ok(ProtobufFieldData::String(val.to_string())), + ProtobufFieldData::Boolean(_) => Ok(ProtobufFieldData::Boolean(bool::try_from(val)?)), + ProtobufFieldData::UInteger32(_) => Ok(ProtobufFieldData::UInteger32(u64::try_from(val)? as u32)), + ProtobufFieldData::Integer32(_) => Ok(ProtobufFieldData::Integer32(i64::try_from(val)? as i32)), + ProtobufFieldData::UInteger64(_) => Ok(ProtobufFieldData::UInteger64(u64::try_from(val)?)), + ProtobufFieldData::Integer64(_) => Ok(ProtobufFieldData::Integer64(i64::try_from(val)?)), + ProtobufFieldData::Float(_) => Ok(ProtobufFieldData::Float(f64::try_from(val)? as f32)), + ProtobufFieldData::Double(_) => Ok(ProtobufFieldData::Double(f64::try_from(val)?)), + _ => Err(anyhow!("Can not generate a value from the provider state for a field type {:?}", value)) + } +} + impl GenerateValue for Generator { #[instrument] fn generate_value( diff --git a/src/message_decoder/mod.rs b/src/message_decoder/mod.rs index 6cf8245..ddae879 100644 --- a/src/message_decoder/mod.rs +++ b/src/message_decoder/mod.rs @@ -16,7 +16,7 @@ use crate::utils::{ as_hex, find_enum_by_name, find_enum_by_name_in_message, find_message_descriptor_for_type, is_repeated_field, last_name, should_be_packed_type }; -mod generators; +pub mod generators; /// Decoded Protobuf field #[derive(Clone, Debug, PartialEq)] @@ -66,6 +66,11 @@ impl ProtobufField { pub fn is_default_value(&self) -> bool { self.data.is_default_field_value() } + + /// If the field is a Protobuf repeated field + pub fn repeated_field(&self) -> bool { + is_repeated_field(&self.descriptor) + } } fn default_field_data( diff --git a/src/mock_service.rs b/src/mock_service.rs index 6c7bc3b..0d64708 100644 --- a/src/mock_service.rs +++ b/src/mock_service.rs @@ -6,10 +6,9 @@ use std::task::{Context, Poll}; use maplit::hashmap; use pact_matching::{CoreMatchingContext, DiffConfig}; -use pact_models::generators::{GenerateValue, GeneratorCategory, NoopVariantMatcher, VariantMatcher}; +use pact_models::generators::{GeneratorCategory, GeneratorTestMode}; use pact_models::json_utils::json_to_string; use pact_models::pact::Pact; -use pact_models::path_exp::DocPath; use pact_models::prelude::v4::V4Pact; use pact_models::v4::message_parts::MessageContents; use pact_models::v4::sync_message::SynchronousMessage; @@ -211,19 +210,10 @@ impl MockService { } fn apply_generators(&self, message: &mut DynamicMessage, contents: &MessageContents) -> anyhow::Result<()> { - let variant_matcher = NoopVariantMatcher {}; - let vm_boxed = variant_matcher.boxed(); let context = hashmap!{}; // TODO: This needs to be passed in via the start mock server call if let Some(generators) = contents.generators.categories.get(&GeneratorCategory::BODY) { - for (key, generator) in generators.iter() { - let path = DocPath::new(key)?; - let value = message.fetch_field_value(&path); - if let Some(value) = value { - let generated_value = generator.generate_value(&value.data, &context, &vm_boxed)?; - message.set_field_value(&path, generated_value)?; - } - } + message.apply_generators(Some(&generators), &GeneratorTestMode::Consumer, &context)?; } Ok(()) diff --git a/src/server.rs b/src/server.rs index 97a2f6b..75bf186 100644 --- a/src/server.rs +++ b/src/server.rs @@ -15,7 +15,6 @@ use pact_models::generators::{ Generator, GeneratorCategory, GeneratorTestMode, - NoopVariantMatcher, VariantMatcher }; use pact_models::json_utils::json_to_string; @@ -615,24 +614,18 @@ fn generate_protobuf_contents( mode: TestMode ) -> anyhow::Result { let mut message: DynamicMessage = DynamicMessage::new(message_descriptor, fields, all_descriptors); - let variant_matcher = NoopVariantMatcher {}; - let vm_boxed = variant_matcher.boxed(); let context = hashmap!{}; + let mut generator_map = hashmap!{}; for (key, generator) in generators { let path = DocPath::new(key)?; - let value = message.fetch_field_value(&path); - if let Some(value) = value { - let generator_values = generator.values.as_ref() - .map(proto_struct_to_json) - .unwrap_or_default(); - let generator = Generator::create(generator.r#type.as_str(), &generator_values)?; - if generator.corresponds_to_mode(&to_generator_mode(mode)) { - let generated_value = generator.generate_value(&value.data, &context, &vm_boxed)?; - message.set_field_value(&path, generated_value)?; - } - } + let generator_values = generator.values.as_ref() + .map(proto_struct_to_json) + .unwrap_or_default(); + let generator = Generator::create(generator.r#type.as_str(), &generator_values)?; + generator_map.insert(path, generator); } + message.apply_generators(Some(&generator_map), &to_generator_mode(mode), &context)?; trace!(?message, "Writing generated message"); let mut buffer = BytesMut::new();