Skip to content

Commit

Permalink
fix: Take into account package names when looking for message types i…
Browse files Browse the repository at this point in the history
…n the descriptors
  • Loading branch information
rholshausen committed Apr 11, 2024
1 parent 4587ac7 commit d9ce8fb
Show file tree
Hide file tree
Showing 8 changed files with 3,728 additions and 21 deletions.
3,333 changes: 3,333 additions & 0 deletions integrated_tests/imported_message/Cargo.lock

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions integrated_tests/imported_message/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
[workspace]

[package]
name = "imported_message"
version = "0.1.0"
edition = "2021"

[dependencies]
tokio = { version = "1", features = ["full"] }
anyhow = "1.0.43"
tonic = "0.8.3"
prost = "0.11.9"
prost-types = "0.11.9"
tracing = { version = "0.1", features = [ "log-always" ] }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

[dev-dependencies]
expectest = "0.12.0"
env_logger = "0.10.1"
pact-plugin-driver = "0.4.6"
pact_consumer = "1.0.5"
serde_json = "1.0.66"
maplit = "1.0.2"

[build-dependencies]
tonic-build = "0.8.4"
7 changes: 7 additions & 0 deletions integrated_tests/imported_message/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::configure().include_file("mod.rs").compile(
&["primary/primary.proto", "imported/imported.proto"],
&["."],
)?;
Ok(())
}
39 changes: 39 additions & 0 deletions integrated_tests/imported_message/imported/imported.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

syntax = "proto3";

option go_package = "github.com/pact-foundation/pact-go/v2/examples/grpc/imported";
option java_multiple_files = true;
option java_package = "io.grpc.examples.imported";
option java_outer_classname = "ImportedProto";

package imported;

service Imported {
rpc GetRectangle(RectangleLocationRequest) returns (RectangleLocationResponse) {}
}

message Rectangle {
// The width of the rectangle.
int32 width = 1;

// The length of the rectangle.
int32 length = 2;
}

// Request message for GetRectangle method. This message has different fields,
// but the same name as a message defined in primary.proto
message RectangleLocationRequest {
int32 width = 1;
int32 length = 2;
}

// Response message for GetRectangle method. This message has different fields,
// but the same name as a message defined in primary.proto
message RectangleLocationResponse {
Point location = 1;
}

message Point {
int32 latitude = 1;
int32 longitude = 2;
}
42 changes: 42 additions & 0 deletions integrated_tests/imported_message/primary/primary.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@

// Copyright 2015 gRPC authors.
//
syntax = "proto3";

option go_package = "github.com/pact-foundation/pact-go/v2/examples/grpc/primary";
option java_multiple_files = true;
option java_package = "io.grpc.examples.primary";
option java_outer_classname = "PrimaryProto";

import "imported/imported.proto";

package primary;

service Primary {
rpc GetRectangle(RectangleLocationRequest) returns (RectangleLocationResponse) {}
}

// A latitude-longitude rectangle, represented as two diagonally opposite
// points "lo" and "hi".
message Rectangle {
// One corner of the rectangle.
imported.Point lo = 1;

// The other corner of the rectangle.
imported.Point hi = 2;
}

// A request payload to get a Rectangle.
message RectangleLocationRequest {
// The width of the rectangle.
int32 x = 1;
int32 y = 2;
int32 width = 3;
int32 length = 4;
}

// A response payload containing a Rectangle.
message RectangleLocationResponse {
// The location of the rectangle.
Rectangle rectangle = 1;
}
86 changes: 86 additions & 0 deletions integrated_tests/imported_message/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
tonic::include_proto!("mod");

#[cfg(test)]
mod tests {
use std::path::Path;

use crate::primary::primary_client::PrimaryClient;
use pact_consumer::mock_server::StartMockServerAsync;
use pact_consumer::prelude::*;
use serde_json::json;
use tonic::IntoRequest;
use tracing::info;

use super::*;

#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_proto_client() {
let _ = env_logger::builder().is_test(true).try_init();

let mut pact_builder = PactBuilderAsync::new_v4("grpc-consumer-rust", "imported_message");
let mock_server = pact_builder
.using_plugin("protobuf", None)
.await
.synchronous_message_interaction(
"package namespace not respected",
|mut i| async move {
let proto_file = Path::new("primary/primary.proto")
.canonicalize()
.unwrap()
.to_string_lossy()
.to_string();
let proto_include = Path::new(".")
.canonicalize()
.unwrap()
.to_string_lossy()
.to_string();
info!("proto_file: {}", proto_file);
info!("proto_include: {}", proto_include);
i.contents_from(json!({
"pact:proto": proto_file,
"pact:proto-service": "Primary/GetRectangle",
"pact:content-type": "application/protobuf",
"pact:protobuf-config": {
"additionalIncludes": [ proto_include ]
},
"request": {
"x": "matching(number, 180)",
"y": "matching(number, 200)",
"width": "matching(number, 10)",
"length": "matching(number, 20)"
},
"response": {
"rectangle": {
"lo": {
"latitude": "matching(number, 180)",
"longitude": "matching(number, 99)"
},
"hi": {
"latitude": "matching(number, 200)",
"longitude": "matching(number, 99)"
}
}
}
}))
.await;
i
},
)
.await
.start_mock_server_async(Some("protobuf/transport/grpc"))
.await;

let url = mock_server.url();

let mut client = PrimaryClient::connect(url.to_string()).await.unwrap();
let request_message = primary::RectangleLocationRequest {
x: 180,
y: 200,
width: 10,
length: 20,
};

let response = client.get_rectangle(request_message.into_request()).await;
let _response_message = response.unwrap();
}
}
83 changes: 66 additions & 17 deletions src/protobuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ use crate::utils::{
find_enum_value_by_name,
find_enum_value_by_name_in_message,
find_message_type_in_file_descriptors,
find_message_with_package_in_file_descriptors,
find_nested_type,
is_map_field,
is_repeated_field,
last_name,
prost_string
prost_string,
split_name
};

/// Process the provided protobuf file and configure the interaction
Expand Down Expand Up @@ -136,6 +138,7 @@ fn configure_protobuf_service(
let service_descriptor = descriptor.service
.iter().find(|p| p.name.clone().unwrap_or_default() == service)
.ok_or_else(|| anyhow!("Did not find a descriptor for service '{}'", service_name))?;
trace!("service_descriptor = {:?}", service_descriptor);
construct_protobuf_interaction_for_service(service_descriptor, config, service,
proc_name, all_descriptors, descriptor)
.map(|(request, response)| {
Expand Down Expand Up @@ -178,16 +181,21 @@ fn construct_protobuf_interaction_for_service(
.find(|m| m.name.clone().unwrap_or_default() == method_name)
.ok_or_else(|| anyhow!("Did not find a method descriptor for method '{}' in service '{}'", method_name, service_name))?;

let input_name = method_descriptor.input_type.as_ref().ok_or_else(|| anyhow!("Input message name is empty for service {}/{}", service_name, method_name))?;
let output_name = method_descriptor.output_type.as_ref().ok_or_else(|| anyhow!("Input message name is empty for service {}/{}", service_name, method_name))?;
let input_message_name = last_name(input_name.as_str());
let output_message_name = last_name(output_name.as_str());
let input_name = method_descriptor.input_type.as_ref()
.ok_or_else(|| anyhow!("Input message name is empty for service {}/{}", service_name, method_name))?;
let output_name = method_descriptor.output_type.as_ref()
.ok_or_else(|| anyhow!("Input message name is empty for service {}/{}", service_name, method_name))?;
let (input_message_name, input_package) = split_name(input_name.as_str());
let (output_message_name, output_package) = split_name(output_name.as_str());

trace!(input_name = input_name.as_str(), input_message_name, "Input message");
trace!(output_name = output_name.as_str(), output_message_name, "Output message");
trace!(%input_name, ?input_package, input_message_name, "Input message");
trace!(%output_name, ?output_package, output_message_name, "Output message");

let request_descriptor = find_message_descriptor(input_message_name, all_descriptors)?;
let response_descriptor = find_message_descriptor(output_message_name, all_descriptors)?;
let request_descriptor = find_message_descriptor(input_message_name, input_package, file_descriptor, all_descriptors)?;
let response_descriptor = find_message_descriptor(output_message_name, output_package, file_descriptor, all_descriptors)?;

trace!("request_descriptor = {:?}", request_descriptor);
trace!("response_descriptor = {:?}", response_descriptor);

let request_part_config = request_part(config, service_part)?;
trace!(config = ?request_part_config, service_part, "Processing request part config");
Expand Down Expand Up @@ -276,14 +284,18 @@ fn request_part(
}
}

fn find_message_descriptor(message_name: &str, all_descriptors: &HashMap<String, &FileDescriptorProto>) -> anyhow::Result<DescriptorProto> {
all_descriptors.values().map(|descriptor| {
descriptor.message_type.iter()
.find(|p| p.name.clone().unwrap_or_default() == message_name)
}).find(|d| d.is_some())
.flatten()
.cloned()
.ok_or_else(|| anyhow!("Did not find the descriptor for message {}", message_name))
// Search for a message by name, first in the current file descriptor, then in all descriptors.
fn find_message_descriptor(
message_name: &str,
package: Option<&str>,
file_descriptor: &FileDescriptorProto,
all_descriptors: &HashMap<String, &FileDescriptorProto>
) -> anyhow::Result<DescriptorProto> {
if let Some(package) = package {
find_message_with_package_in_file_descriptors(message_name, package, file_descriptor, all_descriptors)
} else {
find_message_type_in_file_descriptors(message_name, file_descriptor, all_descriptors)
}
}

/// Configure the interaction for a single Protobuf message
Expand Down Expand Up @@ -327,6 +339,7 @@ fn construct_protobuf_interaction_for_message(
) -> anyhow::Result<InteractionResponse> {
trace!(">> construct_protobuf_interaction_for_message({}, {}, {:?}, {:?}, {:?})", message_name,
message_part, file_descriptor.name, config.keys(), metadata);
trace!("message_descriptor = {:?}", message_descriptor);

let mut message_builder = MessageBuilder::new(message_descriptor, message_name, file_descriptor);
let mut matching_rules = MatchingRuleCategory::empty("body");
Expand Down Expand Up @@ -2706,4 +2719,40 @@ pub(crate) mod tests {
]
));
}

#[test]
fn find_message_descriptor_test() {
let descriptors = "CpAEChdpbXBvcnRlZC9pbXBvcnRlZC5wcm90bxIIaW1wb3J0ZWQiOQoJUmVjdGFuZ2x\
lEhQKBXdpZHRoGAEgASgFUgV3aWR0aBIWCgZsZW5ndGgYAiABKAVSBmxlbmd0aCJIChhSZWN0YW5nbGVMb2NhdGlvblJ\
lcXVlc3QSFAoFd2lkdGgYASABKAVSBXdpZHRoEhYKBmxlbmd0aBgCIAEoBVIGbGVuZ3RoIkgKGVJlY3RhbmdsZUxvY2F0\
aW9uUmVzcG9uc2USKwoIbG9jYXRpb24YASABKAsyDy5pbXBvcnRlZC5Qb2ludFIIbG9jYXRpb24iQQoFUG9pbnQSGgoIb\
GF0aXR1ZGUYASABKAVSCGxhdGl0dWRlEhwKCWxvbmdpdHVkZRgCIAEoBVIJbG9uZ2l0dWRlMmUKCEltcG9ydGVkElkKDE\
dldFJlY3RhbmdsZRIiLmltcG9ydGVkLlJlY3RhbmdsZUxvY2F0aW9uUmVxdWVzdBojLmltcG9ydGVkLlJlY3RhbmdsZUxv\
Y2F0aW9uUmVzcG9uc2UiAEJqChlpby5ncnBjLmV4YW1wbGVzLmltcG9ydGVkQg1JbXBvcnRlZFByb3RvUAFaPGdpdGh1Y\
i5jb20vcGFjdC1mb3VuZGF0aW9uL3BhY3QtZ28vdjIvZXhhbXBsZXMvZ3JwYy9pbXBvcnRlZGIGcHJvdG8zCooECg1wcm\
ltYXJ5LnByb3RvEgdwcmltYXJ5GhdpbXBvcnRlZC9pbXBvcnRlZC5wcm90byJNCglSZWN0YW5nbGUSHwoCbG8YASABKAs\
yDy5pbXBvcnRlZC5Qb2ludFICbG8SHwoCaGkYAiABKAsyDy5pbXBvcnRlZC5Qb2ludFICaGkiZAoYUmVjdGFuZ2xlTG9j\
YXRpb25SZXF1ZXN0EgwKAXgYASABKAVSAXgSDAoBeRgCIAEoBVIBeRIUCgV3aWR0aBgDIAEoBVIFd2lkdGgSFgoGbGVuZ\
3RoGAQgASgFUgZsZW5ndGgiTQoZUmVjdGFuZ2xlTG9jYXRpb25SZXNwb25zZRIwCglyZWN0YW5nbGUYASABKAsyEi5wcml\
tYXJ5LlJlY3RhbmdsZVIJcmVjdGFuZ2xlMmIKB1ByaW1hcnkSVwoMR2V0UmVjdGFuZ2xlEiEucHJpbWFyeS5SZWN0YW5nb\
GVMb2NhdGlvblJlcXVlc3QaIi5wcmltYXJ5LlJlY3RhbmdsZUxvY2F0aW9uUmVzcG9uc2UiAEJnChhpby5ncnBjLmV4YW1\
wbGVzLnByaW1hcnlCDFByaW1hcnlQcm90b1ABWjtnaXRodWIuY29tL3BhY3QtZm91bmRhdGlvbi9wYWN0LWdvL3YyL2V4Y\
W1wbGVzL2dycGMvcHJpbWFyeWIGcHJvdG8z";
let decoded = BASE64.decode(descriptors).unwrap();
let bytes = Bytes::copy_from_slice(decoded.as_slice());
let fds = FileDescriptorSet::decode(bytes).unwrap();
let all: HashMap<String, &FileDescriptorProto> = fds.file
.iter().map(|des| (des.name.clone().unwrap_or_default(), des))
.collect();
let file_descriptor = &fds.file[0];

let result = super::find_message_descriptor("RectangleLocationRequest", None, file_descriptor, &all).unwrap();
expect!(result.field.len()).to(be_equal_to(2));
let result = super::find_message_descriptor("RectangleLocationRequest", Some("primary"), file_descriptor, &all).unwrap();
expect!(result.field.len()).to(be_equal_to(4));
let result = super::find_message_descriptor("RectangleLocationRequest", Some(".primary"), file_descriptor, &all).unwrap();
expect!(result.field.len()).to(be_equal_to(4));
let result = super::find_message_descriptor("RectangleLocationRequest", Some("imported"), file_descriptor, &all).unwrap();
expect!(result.field.len()).to(be_equal_to(2));
}
}
Loading

0 comments on commit d9ce8fb

Please sign in to comment.