diff --git a/ml_metadata/metadata_store/metadata_access_object.cc b/ml_metadata/metadata_store/metadata_access_object.cc index 83566da30..2c74bd757 100644 --- a/ml_metadata/metadata_store/metadata_access_object.cc +++ b/ml_metadata/metadata_store/metadata_access_object.cc @@ -1929,6 +1929,29 @@ tensorflow::Status MetadataAccessObject::FindContextsByTypeId( contexts); } +tensorflow::Status MetadataAccessObject::FindContextByTypeIdAndName( + int64 type_id, absl::string_view name, Context* context) { + std::vector contexts; + Query find_node_ids_query; + TF_RETURN_IF_ERROR( + ComposeParameterizedQuery( + query_config_.select_context_by_type_id_and_name(), + {Bind(type_id), Bind(metadata_source_, name)}, &find_node_ids_query)); + TF_RETURN_IF_ERROR( + FindNodeByIdsQueryImpl(find_node_ids_query, query_config_, + metadata_source_, &contexts)); + + // By design, a pair uniquely identifies a context. + // Returns ok status and updates the input context if one context is found. + // Returns NotFound error and does nothing if no context is found. + // Fails if multiple contexts are found. + CHECK(contexts.size() == 1) << absl::StrCat( + "Found more than one contexts with type_id: ", std::to_string(type_id), + " and context name: ", name); + *context = contexts[0]; + return tensorflow::Status::OK(); +} + tensorflow::Status MetadataAccessObject::FindArtifactsByURI( const absl::string_view uri, std::vector* artifacts) { Query find_node_ids_query; diff --git a/ml_metadata/metadata_store/metadata_access_object.h b/ml_metadata/metadata_store/metadata_access_object.h index 3d3602468..71d8a824e 100644 --- a/ml_metadata/metadata_store/metadata_access_object.h +++ b/ml_metadata/metadata_store/metadata_access_object.h @@ -252,11 +252,17 @@ class MetadataAccessObject { tensorflow::Status FindContexts(std::vector* contexts); // Queries contexts by a given type_id. - // Returns NOT_FOUND error, if the given context_type_id cannot be found. + // Returns NOT_FOUND error, if no context can be found. // Returns detailed INTERNAL error, if query execution fails. tensorflow::Status FindContextsByTypeId(int64 context_type_id, std::vector* contexts); + // Queries a context by a type_id and a context name. + // Returns NOT_FOUND error, if no context can be found. + // Returns detailed INTERNAL error, if query execution fails. + tensorflow::Status FindContextByTypeIdAndName( + int64 type_id, absl::string_view name, Context* context); + // Updates a context. // Returns INVALID_ARGUMENT error, if the id field is not given. // Returns INVALID_ARGUMENT error, if no context is found with the given id. diff --git a/ml_metadata/metadata_store/metadata_access_object_test.cc b/ml_metadata/metadata_store/metadata_access_object_test.cc index 48f121f79..131eb3d4d 100644 --- a/ml_metadata/metadata_store/metadata_access_object_test.cc +++ b/ml_metadata/metadata_store/metadata_access_object_test.cc @@ -1113,6 +1113,21 @@ TEST_P(MetadataAccessObjectTest, CreateAndFindContext) { type2_id, &got_type2_contexts)); EXPECT_EQ(got_type2_contexts.size(), 1); EXPECT_THAT(got_type2_contexts[0], EqualsProto(context2)); + + Context got_context_from_type_and_name1; + TF_EXPECT_OK(metadata_access_object_->FindContextByTypeIdAndName( + type1_id, "my_context1", &got_context_from_type_and_name1)); + EXPECT_THAT(got_context_from_type_and_name1, EqualsProto(context1)); + Context got_context_from_type_and_name2; + TF_EXPECT_OK(metadata_access_object_->FindContextByTypeIdAndName( + type2_id, "my_context2", &got_context_from_type_and_name2)); + EXPECT_THAT(got_context_from_type_and_name2, EqualsProto(context2)); + Context got_empty_context; + EXPECT_EQ(metadata_access_object_->FindContextByTypeIdAndName( + type1_id, "my_context2", &got_empty_context).code(), + tensorflow::error::NOT_FOUND); + EXPECT_THAT(got_empty_context, + EqualsProto(ParseTextProtoOrDie(R"()"))); } TEST_P(MetadataAccessObjectTest, CreateContextError) { diff --git a/ml_metadata/metadata_store/metadata_store.cc b/ml_metadata/metadata_store/metadata_store.cc index 04195586b..8fea620cf 100644 --- a/ml_metadata/metadata_store/metadata_store.cc +++ b/ml_metadata/metadata_store/metadata_store.cc @@ -843,6 +843,33 @@ tensorflow::Status MetadataStore::GetContextsByType( }); } +tensorflow::Status MetadataStore::GetContextByTypeAndName( + const GetContextByTypeAndNameRequest& request, + GetContextByTypeAndNameResponse* response) { + return ExecuteTransaction( + metadata_source_.get(), + [this, &request, &response]() -> tensorflow::Status { + ContextType context_type; + tensorflow::Status status = metadata_access_object_->FindTypeByName( + request.type_name(), &context_type); + if (tensorflow::errors::IsNotFound(status)) { + return tensorflow::Status::OK(); + } else if (!status.ok()) { + return status; + } + Context* context = new Context; + status = metadata_access_object_->FindContextByTypeIdAndName( + context_type.id(), request.context_name(), context); + if (tensorflow::errors::IsNotFound(status)) { + return tensorflow::Status::OK(); + } else if (!status.ok()) { + return status; + } + response->set_allocated_context(context); + return tensorflow::Status::OK(); + }); +} + tensorflow::Status MetadataStore::PutAttributionsAndAssociations( const PutAttributionsAndAssociationsRequest& request, PutAttributionsAndAssociationsResponse* response) { diff --git a/ml_metadata/metadata_store/metadata_store.h b/ml_metadata/metadata_store/metadata_store.h index faac50e55..6b122bc5a 100644 --- a/ml_metadata/metadata_store/metadata_store.h +++ b/ml_metadata/metadata_store/metadata_store.h @@ -383,6 +383,13 @@ class MetadataStore { tensorflow::Status GetContextsByType(const GetContextsByTypeRequest& request, GetContextsByTypeResponse* response); + // Gets the context of a given type and name. If no context found, it returns + // OK and empty response. If more than one contexts matchs the type and name, + // the query execution fails. + tensorflow::Status GetContextByTypeAndName( + const GetContextByTypeAndNameRequest& request, + GetContextByTypeAndNameResponse* response); + // Inserts attribution and association relationships in the database. // The context_id, artifact_id, and execution_id must already exist. // If the relationship exists, this call does nothing. Once added, the diff --git a/ml_metadata/metadata_store/metadata_store_test.cc b/ml_metadata/metadata_store/metadata_store_test.cc index 682fb7312..ac00c5b48 100644 --- a/ml_metadata/metadata_store/metadata_store_test.cc +++ b/ml_metadata/metadata_store/metadata_store_test.cc @@ -1956,6 +1956,62 @@ TEST_F(MetadataStoreTest, PutContextsUpdateGetContexts) { testing::EqualsProto(want_context3)); } +// Test creating a context and then getting it by its type and context name. +TEST_F(MetadataStoreTest, PutContextGetContextsByTypeAndName) { + // Create a context type + const PutContextTypeRequest put_context_type_request = + ParseTextProtoOrDie(R"( + all_fields_match: true + context_type: { + name: 'test_type' + } + )"); + PutContextTypeResponse put_context_type_response; + TF_ASSERT_OK(metadata_store_->PutContextType(put_context_type_request, + &put_context_type_response)); + ASSERT_TRUE(put_context_type_response.has_type_id()); + const int64 type_id = put_context_type_response.type_id(); + + // Create a context + PutContextsRequest put_contexts_request = + ParseTextProtoOrDie(R"( + contexts: { + name: 'context_name' + } + )"); + put_contexts_request.mutable_contexts(0)->set_type_id(type_id); + PutContextsResponse put_contexts_response; + TF_ASSERT_OK(metadata_store_->PutContexts(put_contexts_request, + &put_contexts_response)); + ASSERT_THAT(put_contexts_response.context_ids(), SizeIs(1)); + const int64 id1 = put_contexts_response.context_ids(0); + + // Test the returned context is the same as the created one. + Context want_context1 = *put_contexts_request.mutable_contexts(0); + want_context1.set_id(id1); + + GetContextByTypeAndNameRequest get_context_by_type_and_name_request; + get_context_by_type_and_name_request.set_type_name("test_type"); + get_context_by_type_and_name_request.set_context_name("context_name"); + GetContextByTypeAndNameResponse get_context_by_type_and_name_response; + TF_ASSERT_OK(metadata_store_->GetContextByTypeAndName( + get_context_by_type_and_name_request, + &get_context_by_type_and_name_response)); + ASSERT_TRUE(get_context_by_type_and_name_response.has_context()); + EXPECT_THAT(get_context_by_type_and_name_response.context(), + testing::EqualsProto(want_context1)); + + // Test that no context is found given the input type and name. + GetContextByTypeAndNameRequest get_no_context_by_type_and_name_request; + get_context_by_type_and_name_request.set_type_name("test_type1"); + get_context_by_type_and_name_request.set_context_name("context3"); + GetContextByTypeAndNameResponse get_no_context_by_type_and_name_response; + TF_ASSERT_OK(metadata_store_->GetContextByTypeAndName( + get_no_context_by_type_and_name_request, + &get_no_context_by_type_and_name_response)); + EXPECT_FALSE(get_no_context_by_type_and_name_response.has_context()); +} + TEST_F(MetadataStoreTest, PutAndUseAttributionsAndAssociations) { const PutTypesRequest put_types_request = ParseTextProtoOrDie(R"( diff --git a/ml_metadata/proto/metadata_source.proto b/ml_metadata/proto/metadata_source.proto index 6409404da..4dd3889fa 100644 --- a/ml_metadata/proto/metadata_source.proto +++ b/ml_metadata/proto/metadata_source.proto @@ -47,7 +47,7 @@ enum MetadataSourceType { // A config includes a set of SQL queries and the type of metadata source. // It is used by MetadataAccessObject to init backend and issue queries. -// Next ID: 93 +// Next ID: 94 message MetadataSourceQueryConfig { // the type of the metadata source MetadataSourceType metadata_source_type = 1; @@ -283,6 +283,12 @@ message MetadataSourceQueryConfig { // $0 is the context_type_id TemplateQuery select_contexts_by_type_id = 72; + // Queries a context from the Context table by its type_id and name. It has 2 + // parameters. + // $0 is the context_type_id + // $1 is the context_name + TemplateQuery select_context_by_type_id_and_name = 93; + // Updates a context in the Context table. It has 3 parameters. // $0 is the existing context id // $1 is the type_id diff --git a/ml_metadata/proto/metadata_store_service.proto b/ml_metadata/proto/metadata_store_service.proto index 707b89f3b..93d2d639b 100644 --- a/ml_metadata/proto/metadata_store_service.proto +++ b/ml_metadata/proto/metadata_store_service.proto @@ -397,6 +397,15 @@ message GetContextsByTypeResponse { repeated Context contexts = 1; } +message GetContextByTypeAndNameRequest { + optional string type_name = 1; + optional string context_name = 2; +} + +message GetContextByTypeAndNameResponse { + optional Context context = 1; +} + message GetContextsByIDRequest { // A list of context ids to retrieve. repeated int64 context_ids = 1; diff --git a/ml_metadata/util/metadata_source_query_config.cc b/ml_metadata/util/metadata_source_query_config.cc index ae9fe0a12..eb0b5ae57 100644 --- a/ml_metadata/util/metadata_source_query_config.cc +++ b/ml_metadata/util/metadata_source_query_config.cc @@ -291,6 +291,10 @@ constexpr char kBaseQueryConfig[] = R"pb( query: " SELECT `id` from `Context` WHERE `type_id` = $0; " parameter_num: 1 } + select_context_by_type_id_and_name { + query: " SELECT `id` from `Context` WHERE `type_id` = $0 and `name` = $1; " + parameter_num: 2 + } update_context { query: " UPDATE `Context` " " SET `type_id` = $1, `name` = $2"