diff --git a/RELEASE.md b/RELEASE.md index 0f25f59a1..97c68fd52 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -11,6 +11,8 @@ used by mlmd powered systems (e.g., orchestrator). * Add support to pass migration options as command line parameters to the MLMD gRPC server. +* Adding a new Python API get_context_by_type_and_name to allow querying a + context by its type and context name at the same time. ## Bug Fixes and Other Changes diff --git a/ml_metadata/metadata_store/metadata_store.py b/ml_metadata/metadata_store/metadata_store.py index 11d6501ae..50971e9c9 100644 --- a/ml_metadata/metadata_store/metadata_store.py +++ b/ml_metadata/metadata_store/metadata_store.py @@ -811,19 +811,15 @@ def get_context_by_type_and_name( The Context matching the type and context name. None if no matched Context found. """ - # TODO(b/139092990): Change get logic to the new C++ API once implemented. - request = metadata_store_service_pb2.GetContextsByTypeRequest() + request = metadata_store_service_pb2.GetContextByTypeAndNameRequest() request.type_name = type_name - response = metadata_store_service_pb2.GetContextsByTypeResponse() - - self._call('GetContextsByType', request, response) - result = [c for c in response.contexts - if c.HasField('name') and c.name == context_name] + request.context_name = context_name + response = metadata_store_service_pb2.GetContextByTypeAndNameResponse() - assert len(result) <= 1, 'Found more than one contexts with input.' - if not result: + self._call('GetContextByTypeAndName', request, response) + if not response.HasField('context'): return None - return result[0] + return response.context def get_artifact_types_by_id( self, type_ids: Sequence[int]) -> List[metadata_store_pb2.ArtifactType]: diff --git a/ml_metadata/metadata_store/metadata_store_service_impl.cc b/ml_metadata/metadata_store/metadata_store_service_impl.cc index d8ac4cb57..e437dda7f 100644 --- a/ml_metadata/metadata_store/metadata_store_service_impl.cc +++ b/ml_metadata/metadata_store/metadata_store_service_impl.cc @@ -416,6 +416,20 @@ ::grpc::Status MetadataStoreServiceImpl::GetContextsByType( return status; } +::grpc::Status MetadataStoreServiceImpl::GetContextByTypeAndName( + ::grpc::ServerContext* context, + const ::ml_metadata::GetContextByTypeAndNameRequest* request, + ::ml_metadata::GetContextByTypeAndNameResponse* response) { + absl::WriterMutexLock l(&lock_); + const ::grpc::Status status = ToGRPCStatus( + metadata_store_->GetContextByTypeAndName(*request, response)); + if (!status.ok()) { + LOG(WARNING) << "GetContextByTypeAndName failed: " + << status.error_message(); + } + return status; +} + ::grpc::Status MetadataStoreServiceImpl::PutAttributionsAndAssociations( ::grpc::ServerContext* context, const ::ml_metadata::PutAttributionsAndAssociationsRequest* request, diff --git a/ml_metadata/metadata_store/metadata_store_service_impl.h b/ml_metadata/metadata_store/metadata_store_service_impl.h index b9f7d1fe0..8c2d4a6df 100644 --- a/ml_metadata/metadata_store/metadata_store_service_impl.h +++ b/ml_metadata/metadata_store/metadata_store_service_impl.h @@ -204,6 +204,12 @@ class MetadataStoreServiceImpl final ::ml_metadata::GetContextsByTypeResponse* response) override ABSL_LOCKS_EXCLUDED(lock_); + ::grpc::Status GetContextByTypeAndName( + ::grpc::ServerContext* context, + const ::ml_metadata::GetContextByTypeAndNameRequest* request, + ::ml_metadata::GetContextByTypeAndNameResponse* response) override + ABSL_LOCKS_EXCLUDED(lock_); + ::grpc::Status PutAttributionsAndAssociations( ::grpc::ServerContext* context, const ::ml_metadata::PutAttributionsAndAssociationsRequest* request, diff --git a/ml_metadata/metadata_store/tf_metadata_store_serialized.i b/ml_metadata/metadata_store/tf_metadata_store_serialized.i index 3490c5af1..60c0920c3 100644 --- a/ml_metadata/metadata_store/tf_metadata_store_serialized.i +++ b/ml_metadata/metadata_store/tf_metadata_store_serialized.i @@ -294,6 +294,12 @@ PyObject* GetContextsByType(ml_metadata::MetadataStore* metadata_store, &ml_metadata::MetadataStore::GetContextsByType); } +PyObject* GetContextByTypeAndName(ml_metadata::MetadataStore* metadata_store, + const string& request) { + return AccessMetadataStore(metadata_store, request, + &ml_metadata::MetadataStore::GetContextByTypeAndName); +} + PyObject* PutAttributionsAndAssociations( ml_metadata::MetadataStore* metadata_store, const string& request) { return AccessMetadataStore(metadata_store, request, @@ -440,6 +446,9 @@ PyObject* GetContexts(ml_metadata::MetadataStore* metadata_store, PyObject* GetContextsByType(ml_metadata::MetadataStore* metadata_store, const string& request); +PyObject* GetContextByTypeAndName(ml_metadata::MetadataStore* metadata_store, + const string& request); + PyObject* PutAttributionsAndAssociations( ml_metadata::MetadataStore* metadata_store, const string& request); diff --git a/ml_metadata/proto/metadata_store_service.proto b/ml_metadata/proto/metadata_store_service.proto index 93d2d639b..10e76c75d 100644 --- a/ml_metadata/proto/metadata_store_service.proto +++ b/ml_metadata/proto/metadata_store_service.proto @@ -752,6 +752,10 @@ service MetadataStoreService { rpc GetContextsByType(GetContextsByTypeRequest) returns (GetContextsByTypeResponse) {} + // Gets the context of the given type and context name. + rpc GetContextByTypeAndName(GetContextByTypeAndNameRequest) + returns (GetContextByTypeAndNameResponse) {} + // Gets all the artifacts of a given uri. rpc GetArtifactsByURI(GetArtifactsByURIRequest) returns (GetArtifactsByURIResponse) {}