Skip to content

Commit

Permalink
Adds Contexts to PutExecution to attach context atomically.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 282819494
  • Loading branch information
hughmiao authored and tf-metadata-team committed Nov 27, 2019
1 parent 11f8863 commit c9747be
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 82 deletions.
2 changes: 1 addition & 1 deletion ml_metadata/metadata_store/metadata_access_object.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ class MetadataAccessObject {
// Creates an attribution, returns the assigned attribution id.
// Returns INVALID_ARGUMENT error, if no context matches the context_id.
// Returns INVALID_ARGUMENT error, if no artifact matches the artifact_id.
// Returns INTERNAL error, if the same attribution already exists.
// Returns AlreadyExists error, if the same attribution already exists.
tensorflow::Status CreateAttribution(const Attribution& attribution,
int64* attribution_id);

Expand Down
207 changes: 133 additions & 74 deletions ml_metadata/metadata_store/metadata_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ bool CheckFieldsConsistent(const T& stored_type, const T& other_type,
// Returns INVALID_ARGUMENT error, if any property type in `type` is unknown.
// Returns detailed INTERNAL error, if query execution fails.
template <typename T>
tensorflow::Status UpsertType(MetadataAccessObject* metadata_access_object,
const T& type, bool can_add_fields,
tensorflow::Status UpsertType(const T& type, bool can_add_fields,
MetadataAccessObject* metadata_access_object,
int64* type_id) {
T stored_type;
const tensorflow::Status status =
Expand All @@ -85,6 +85,86 @@ tensorflow::Status UpsertType(MetadataAccessObject* metadata_access_object,
return metadata_access_object->UpdateType(type);
}

// Updates or inserts an artifact. If the artifact.id is given, it updates the
// stored artifact, otherwise, it creates a new artifact.
tensorflow::Status UpsertArtifact(const Artifact& artifact,
MetadataAccessObject* metadata_access_object,
int64* artifact_id) {
CHECK(artifact_id) << "artifact_id should not be null";
if (artifact.has_id()) {
TF_RETURN_IF_ERROR(metadata_access_object->UpdateArtifact(artifact));
*artifact_id = artifact.id();
} else {
TF_RETURN_IF_ERROR(
metadata_access_object->CreateArtifact(artifact, artifact_id));
}
return tensorflow::Status::OK();
}

// Updates or inserts an execution. If the execution.id is given, it updates the
// stored execution, otherwise, it creates a new execution.
tensorflow::Status UpsertExecution(const Execution& execution,
MetadataAccessObject* metadata_access_object,
int64* execution_id) {
CHECK(execution_id) << "execution_id should not be null";
if (execution.has_id()) {
TF_RETURN_IF_ERROR(metadata_access_object->UpdateExecution(execution));
*execution_id = execution.id();
} else {
TF_RETURN_IF_ERROR(
metadata_access_object->CreateExecution(execution, execution_id));
}
return tensorflow::Status::OK();
}

// Updates or inserts a context. If the context.id is given, it updates the
// stored context, otherwise, it creates a new context.
tensorflow::Status UpsertContext(const Context& context,
MetadataAccessObject* metadata_access_object,
int64* context_id) {
CHECK(context_id) << "context_id should not be null";
if (context.has_id()) {
TF_RETURN_IF_ERROR(metadata_access_object->UpdateContext(context));
*context_id = context.id();
} else {
TF_RETURN_IF_ERROR(
metadata_access_object->CreateContext(context, context_id));
}
return tensorflow::Status::OK();
}

// Inserts an association. If the association already exists it returns OK.
tensorflow::Status InsertAssociationIfNotExist(
int64 context_id, int64 execution_id,
MetadataAccessObject* metadata_access_object) {
Association association;
association.set_execution_id(execution_id);
association.set_context_id(context_id);
int64 dummy_assocation_id;
tensorflow::Status status = metadata_access_object->CreateAssociation(
association, &dummy_assocation_id);
if (!status.ok() && !tensorflow::errors::IsAlreadyExists(status)) {
return status;
}
return tensorflow::Status::OK();
}

// Inserts an attribution. If the attribution already exists it returns OK.
tensorflow::Status InsertAttributionIfNotExist(
int64 context_id, int64 artifact_id,
MetadataAccessObject* metadata_access_object) {
Attribution attribution;
attribution.set_artifact_id(artifact_id);
attribution.set_context_id(context_id);
int64 dummy_attribution_id;
tensorflow::Status status = metadata_access_object->CreateAttribution(
attribution, &dummy_attribution_id);
if (!status.ok() && !tensorflow::errors::IsAlreadyExists(status)) {
return status;
}
return tensorflow::Status::OK();
}

} // namespace

tensorflow::Status MetadataStore::InitMetadataStore() {
Expand Down Expand Up @@ -117,22 +197,22 @@ tensorflow::Status MetadataStore::PutTypes(const PutTypesRequest& request,
[this, &request, &response]() -> tensorflow::Status {
for (const ArtifactType& artifact_type : request.artifact_types()) {
int64 artifact_type_id;
TF_RETURN_IF_ERROR(UpsertType(metadata_access_object_.get(),
artifact_type, request.can_add_fields(),
TF_RETURN_IF_ERROR(UpsertType(artifact_type, request.can_add_fields(),
metadata_access_object_.get(),
&artifact_type_id));
response->add_artifact_type_ids(artifact_type_id);
}
for (const ExecutionType& execution_type : request.execution_types()) {
int64 execution_type_id;
TF_RETURN_IF_ERROR(
UpsertType(metadata_access_object_.get(), execution_type,
request.can_add_fields(), &execution_type_id));
UpsertType(execution_type, request.can_add_fields(),
metadata_access_object_.get(), &execution_type_id));
response->add_execution_type_ids(execution_type_id);
}
for (const ContextType& context_type : request.context_types()) {
int64 context_type_id;
TF_RETURN_IF_ERROR(UpsertType(metadata_access_object_.get(),
context_type, request.can_add_fields(),
TF_RETURN_IF_ERROR(UpsertType(context_type, request.can_add_fields(),
metadata_access_object_.get(),
&context_type_id));
response->add_context_type_ids(context_type_id);
}
Expand All @@ -152,9 +232,9 @@ tensorflow::Status MetadataStore::PutArtifactType(
metadata_source_.get(),
[this, &request, &response]() -> tensorflow::Status {
int64 type_id;
TF_RETURN_IF_ERROR(UpsertType(metadata_access_object_.get(),
request.artifact_type(),
request.can_add_fields(), &type_id));
TF_RETURN_IF_ERROR(UpsertType(request.artifact_type(),
request.can_add_fields(),
metadata_access_object_.get(), &type_id));
response->set_type_id(type_id);
return tensorflow::Status::OK();
});
Expand All @@ -173,9 +253,9 @@ tensorflow::Status MetadataStore::PutExecutionType(
metadata_source_.get(),
[this, &request, &response]() -> tensorflow::Status {
int64 type_id;
TF_RETURN_IF_ERROR(UpsertType(metadata_access_object_.get(),
request.execution_type(),
request.can_add_fields(), &type_id));
TF_RETURN_IF_ERROR(UpsertType(request.execution_type(),
request.can_add_fields(),
metadata_access_object_.get(), &type_id));
response->set_type_id(type_id);
return tensorflow::Status::OK();
});
Expand All @@ -193,9 +273,9 @@ tensorflow::Status MetadataStore::PutContextType(
metadata_source_.get(),
[this, &request, &response]() -> tensorflow::Status {
int64 type_id;
TF_RETURN_IF_ERROR(UpsertType(metadata_access_object_.get(),
request.context_type(),
request.can_add_fields(), &type_id));
TF_RETURN_IF_ERROR(UpsertType(request.context_type(),
request.can_add_fields(),
metadata_access_object_.get(), &type_id));
response->set_type_id(type_id);
return tensorflow::Status::OK();
});
Expand Down Expand Up @@ -358,16 +438,10 @@ tensorflow::Status MetadataStore::PutArtifacts(
metadata_source_.get(),
[this, &request, &response]() -> tensorflow::Status {
for (const Artifact& artifact : request.artifacts()) {
if (artifact.has_id()) {
TF_RETURN_IF_ERROR(
metadata_access_object_->UpdateArtifact(artifact));
response->add_artifact_ids(artifact.id());
} else {
int64 artifact_id;
TF_RETURN_IF_ERROR(metadata_access_object_->CreateArtifact(
artifact, &artifact_id));
response->add_artifact_ids(artifact_id);
}
int64 artifact_id = -1;
TF_RETURN_IF_ERROR(UpsertArtifact(
artifact, metadata_access_object_.get(), &artifact_id));
response->add_artifact_ids(artifact_id);
}
return tensorflow::Status::OK();
});
Expand All @@ -379,16 +453,10 @@ tensorflow::Status MetadataStore::PutExecutions(
metadata_source_.get(),
[this, &request, &response]() -> tensorflow::Status {
for (const Execution& execution : request.executions()) {
if (execution.has_id()) {
TF_RETURN_IF_ERROR(
metadata_access_object_->UpdateExecution(execution));
response->add_execution_ids(execution.id());
} else {
int64 execution_id;
TF_RETURN_IF_ERROR(metadata_access_object_->CreateExecution(
execution, &execution_id));
response->add_execution_ids(execution_id);
}
int64 execution_id = -1;
TF_RETURN_IF_ERROR(UpsertExecution(
execution, metadata_access_object_.get(), &execution_id));
response->add_execution_ids(execution_id);
}
return tensorflow::Status::OK();
});
Expand All @@ -400,15 +468,10 @@ tensorflow::Status MetadataStore::PutContexts(const PutContextsRequest& request,
metadata_source_.get(),
[this, &request, &response]() -> tensorflow::Status {
for (const Context& context : request.contexts()) {
if (context.has_id()) {
TF_RETURN_IF_ERROR(metadata_access_object_->UpdateContext(context));
response->add_context_ids(context.id());
} else {
int64 context_id;
TF_RETURN_IF_ERROR(
metadata_access_object_->CreateContext(context, &context_id));
response->add_context_ids(context_id);
}
int64 context_id = -1;
TF_RETURN_IF_ERROR(UpsertContext(
context, metadata_access_object_.get(), &context_id));
response->add_context_ids(context_id);
}
return tensorflow::Status::OK();
});
Expand Down Expand Up @@ -467,14 +530,8 @@ tensorflow::Status MetadataStore::PutExecution(
// 1. Upsert Execution
const Execution& execution = request.execution();
int64 execution_id = -1;
if (execution.has_id()) {
TF_RETURN_IF_ERROR(
metadata_access_object_->UpdateExecution(execution));
execution_id = execution.id();
} else {
TF_RETURN_IF_ERROR(metadata_access_object_->CreateExecution(
execution, &execution_id));
}
TF_RETURN_IF_ERROR(UpsertExecution(
execution, metadata_access_object_.get(), &execution_id));
response->set_execution_id(execution_id);
// 2. Upsert Artifacts and insert events
for (const PutExecutionRequest::ArtifactAndEvent& artifact_and_event :
Expand All @@ -485,14 +542,8 @@ tensorflow::Status MetadataStore::PutExecution(
}
const Artifact& artifact = artifact_and_event.artifact();
int64 artifact_id = -1;
if (artifact.has_id()) {
TF_RETURN_IF_ERROR(
metadata_access_object_->UpdateArtifact(artifact));
artifact_id = artifact.id();
} else {
TF_RETURN_IF_ERROR(metadata_access_object_->CreateArtifact(
artifact, &artifact_id));
}
TF_RETURN_IF_ERROR(UpsertArtifact(
artifact, metadata_access_object_.get(), &artifact_id));
response->add_artifact_ids(artifact_id);
// insert event if any
if (!artifact_and_event.has_event()) {
Expand Down Expand Up @@ -520,6 +571,20 @@ tensorflow::Status MetadataStore::PutExecution(
TF_RETURN_IF_ERROR(
metadata_access_object_->CreateEvent(event, &dummy_event_id));
}
// 3. Upsert contexts and insert associations and attributions.
for (const Context& context : request.contexts()) {
int64 context_id = -1;
TF_RETURN_IF_ERROR(UpsertContext(
context, metadata_access_object_.get(), &context_id));
response->add_context_ids(context_id);
TF_RETURN_IF_ERROR(
InsertAssociationIfNotExist(context_id, response->execution_id(),
metadata_access_object_.get()));
for (const int64 artifact_id : response->artifact_ids()) {
TF_RETURN_IF_ERROR(InsertAttributionIfNotExist(
context_id, artifact_id, metadata_access_object_.get()));
}
}
return tensorflow::Status::OK();
});
}
Expand Down Expand Up @@ -784,20 +849,14 @@ tensorflow::Status MetadataStore::PutAttributionsAndAssociations(
return ExecuteTransaction(
metadata_source_.get(), [this, &request]() -> tensorflow::Status {
for (const Attribution& attribution : request.attributions()) {
int64 dummy_attribution_id;
tensorflow::Status status =
metadata_access_object_->CreateAttribution(attribution,
&dummy_attribution_id);
if (tensorflow::errors::IsAlreadyExists(status)) continue;
TF_RETURN_IF_ERROR(status);
TF_RETURN_IF_ERROR(InsertAttributionIfNotExist(
attribution.context_id(), attribution.artifact_id(),
metadata_access_object_.get()));
}
for (const Association& association : request.associations()) {
int64 dummy_assocation_id;
tensorflow::Status status =
metadata_access_object_->CreateAssociation(association,
&dummy_assocation_id);
if (tensorflow::errors::IsAlreadyExists(status)) continue;
TF_RETURN_IF_ERROR(status);
TF_RETURN_IF_ERROR(InsertAssociationIfNotExist(
association.context_id(), association.execution_id(),
metadata_access_object_.get()));
}
return tensorflow::Status::OK();
});
Expand Down
Loading

0 comments on commit c9747be

Please sign in to comment.