Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Move caching to node executor for fast cache hits (#485)
Browse files Browse the repository at this point in the history
* fast cache working-ish

Signed-off-by: Daniel Rammer <daniel@union.ai>

* processing downstream immediately on cache hit

Signed-off-by: Daniel Rammer <daniel@union.ai>

* moved cache write to node executor

Signed-off-by: Daniel Rammer <daniel@union.ai>

* working cache and cache serialize

Signed-off-by: Daniel Rammer <daniel@union.ai>

* starting to clean up

Signed-off-by: Daniel Rammer <daniel@union.ai>

* removed commented out code

Signed-off-by: Daniel Rammer <daniel@union.ai>

* removed separate IsCacheable and IsCacheSerializable functions from CacheableNode interface

Signed-off-by: Daniel Rammer <daniel@union.ai>

* refactored reservation owner id to new function to remove duplication

Signed-off-by: Daniel Rammer <daniel@union.ai>

* added cache metrics to the node executor

Signed-off-by: Daniel Rammer <daniel@union.ai>

* cleaned up node cache.go

Signed-off-by: Daniel Rammer <daniel@union.ai>

* more cleanup

Signed-off-by: Daniel Rammer <daniel@union.ai>

* setting cache information in phase info so that it is available in events

Signed-off-by: Daniel Rammer <daniel@union.ai>

* minor refactoring and bug fixes

Signed-off-by: Daniel Rammer <daniel@union.ai>

* doing an outputs lookup on cache to ensure correctness during failures

Signed-off-by: Daniel Rammer <daniel@union.ai>

* fix unit tests

Signed-off-by: Daniel Rammer <daniel@union.ai>

* fixed lint issues

Signed-off-by: Daniel Rammer <daniel@union.ai>

* moved catalog package to the node level

Signed-off-by: Daniel Rammer <daniel@union.ai>

* refactored task handler

Signed-off-by: Daniel Rammer <daniel@union.ai>

* fixed catalog imports on unit testes

Signed-off-by: Daniel Rammer <daniel@union.ai>

* started cache unit tests

Signed-off-by: Daniel Rammer <daniel@union.ai>

* added CheckCatalogCache unit tests

Signed-off-by: Daniel Rammer <daniel@union.ai>

* unit tests for node cache file

Signed-off-by: Daniel Rammer <daniel@union.ai>

* added node executor cache unit tests

Signed-off-by: Daniel Rammer <daniel@union.ai>

* fixed cache unit tets

Signed-off-by: Daniel Rammer <daniel@union.ai>

* fixed lint issues

Signed-off-by: Daniel Rammer <daniel@union.ai>

* transitioning to 'Succeeded' immediately on cache hit

Signed-off-by: Daniel Rammer <daniel@union.ai>

* supporting cache overwrite

Signed-off-by: Daniel Rammer <daniel@union.ai>

* fixed lint issues

Signed-off-by: Daniel Rammer <daniel@union.ai>

* removed automatic downstream on cache hit

Signed-off-by: Daniel Rammer <daniel@union.ai>

* bumping boilerplate support tools to go 1.19 to fix generate

Signed-off-by: Daniel Rammer <daniel@union.ai>

* fixed unit tests and linter

Signed-off-by: Daniel Rammer <daniel@union.ai>

* removed unnecessary async catalog client from nodeExecutor

Signed-off-by: Daniel Rammer <daniel@union.ai>

* general refactoring

Signed-off-by: Daniel Rammer <daniel@union.ai>

* fastcache working with arraynode

Signed-off-by: Daniel Rammer <daniel@union.ai>

* fixed unit tests - no longer checking for output existance on first execution of cached

Signed-off-by: Daniel Rammer <daniel@union.ai>

* updating documentation TODOs

Signed-off-by: Daniel Rammer <daniel@union.ai>

* updated arraynode fastcache to correctly report cache hits

Signed-off-by: Daniel Rammer <daniel@union.ai>

* remove print statement

Signed-off-by: Daniel Rammer <daniel@union.ai>

* fixed cache serialize

Signed-off-by: Daniel Rammer <daniel@union.ai>

* fixed unit tests

Signed-off-by: Daniel Rammer <daniel@union.ai>

---------

Signed-off-by: Daniel Rammer <daniel@union.ai>
  • Loading branch information
hamersaw authored Aug 14, 2023
1 parent cbfcdf3 commit 6e06386
Show file tree
Hide file tree
Showing 30 changed files with 1,711 additions and 1,116 deletions.
2 changes: 1 addition & 1 deletion pkg/apis/flyteworkflow/v1alpha1/node_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ func (in *NodeStatus) GetOrCreateArrayNodeStatus() MutableArrayNodeStatus {
}

func (in *NodeStatus) UpdatePhase(p NodePhase, occurredAt metav1.Time, reason string, err *core.ExecutionError) {
if in.Phase == p {
if in.Phase == p && in.Message == reason {
// We will not update the phase multiple times. This prevents the comparison from returning false positive
return
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/apis/flyteworkflow/v1alpha1/node_status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ func TestNodeStatus_UpdatePhase(t *testing.T) {
t.Run("identical-phase", func(t *testing.T) {
p := NodePhaseQueued
ns := NodeStatus{
Phase: p,
Phase: p,
Message: queued,
}
msg := queued
ns.UpdatePhase(p, n, msg, nil)
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ import (
"github.com/flyteorg/flytepropeller/pkg/controller/config"
"github.com/flyteorg/flytepropeller/pkg/controller/executors"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/catalog"
errors3 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/factory"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog"
"github.com/flyteorg/flytepropeller/pkg/controller/workflow"
"github.com/flyteorg/flytepropeller/pkg/controller/workflowstore"
leader "github.com/flyteorg/flytepropeller/pkg/leaderelection"
Expand Down
24 changes: 17 additions & 7 deletions pkg/controller/nodes/array/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,18 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu

retryAttempt := subNodeStatus.GetAttempts()

// fastcache will not emit task events for cache hits. we need to manually detect a
// transition to `SUCCEEDED` and add an `ExternalResourceInfo` for it.
if cacheStatus == idlcore.CatalogCacheStatus_CACHE_HIT && len(arrayEventRecorder.TaskEvents()) == 0 {
externalResources = append(externalResources, &event.ExternalResourceInfo{
ExternalId: buildSubNodeID(nCtx, i, retryAttempt),
Index: uint32(i),
RetryAttempt: retryAttempt,
Phase: idlcore.TaskExecution_SUCCEEDED,
CacheStatus: cacheStatus,
})
}

for _, taskExecutionEvent := range arrayEventRecorder.TaskEvents() {
for _, log := range taskExecutionEvent.Logs {
log.Name = fmt.Sprintf("%s-%d", log.Name, i)
Expand Down Expand Up @@ -543,19 +555,17 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter

inputReader := newStaticInputReader(nCtx.InputReader(), inputLiteralMap)

// if node has not yet started we automatically set to NodePhaseQueued to skip input resolution
if nodePhase == v1alpha1.NodePhaseNotYetStarted {
// TODO - to supprt fastcache we'll need to override the bindings to BindingScalars for the input resolution on the nCtx
// that way resolution is just reading a literal ... but does this still write a file then?!?
nodePhase = v1alpha1.NodePhaseQueued
}

// wrap node lookup
subNodeSpec := *arrayNode.GetSubNodeSpec()

subNodeID := fmt.Sprintf("%s-n%d", nCtx.NodeID(), subNodeIndex)
subNodeSpec.ID = subNodeID
subNodeSpec.Name = subNodeID
// mock the input bindings for the subNode to nil to bypass input resolution in the
// `nodeExecutor.preExecute` function. this is required because this function is the entrypoint
// for initial cache lookups. an alternative solution would be to mock the datastore to bypass
// writing the inputFile.
subNodeSpec.InputBindings = nil

// TODO - if we want to support more plugin types we need to figure out the best way to store plugin state
// currently just mocking based on node phase -> which works for all k8s plugins
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller/nodes/array/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ import (
"github.com/flyteorg/flytepropeller/pkg/controller/config"
execmocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/catalog"
gatemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate/mocks"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks"
recoverymocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery/mocks"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
pluginmocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"
Expand Down
8 changes: 8 additions & 0 deletions pkg/controller/nodes/array/node_lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ type arrayNodeLookup struct {
subNodeStatus *v1alpha1.NodeStatus
}

func (a *arrayNodeLookup) ToNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) {
if id == a.subNodeID {
return nil, nil
}

return a.NodeLookup.ToNode(id)
}

func (a *arrayNodeLookup) GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) {
if nodeID == a.subNodeID {
return a.subNodeSpec, true
Expand Down
234 changes: 234 additions & 0 deletions pkg/controller/nodes/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
package nodes

import (
"context"
"strconv"
"time"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/encoding"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"

"github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/common"
nodeserrors "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/task"

"github.com/flyteorg/flytestdlib/logger"
"github.com/flyteorg/flytestdlib/storage"

"github.com/pkg/errors"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// computeCatalogReservationOwnerID constructs a unique identifier which includes the nodes
// parent information, node ID, and retry attempt number. This is used to uniquely identify a task
// when the cache reservation API to serialize cached executions.
func computeCatalogReservationOwnerID(nCtx interfaces.NodeExecutionContext) (string, error) {
currentNodeUniqueID, err := common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeID())
if err != nil {
return "", err
}

ownerID, err := encoding.FixedLengthUniqueIDForParts(task.IDMaxLength,
[]string{nCtx.NodeExecutionMetadata().GetOwnerID().Name, currentNodeUniqueID, strconv.Itoa(int(nCtx.CurrentAttempt()))})
if err != nil {
return "", err
}

return ownerID, nil
}

// updatePhaseCacheInfo adds the cache and catalog reservation metadata to the PhaseInfo. This
// ensures this information is reported in events and available within FlyteAdmin.
func updatePhaseCacheInfo(phaseInfo handler.PhaseInfo, cacheStatus *catalog.Status, reservationStatus *core.CatalogReservation_Status) handler.PhaseInfo {
if cacheStatus == nil && reservationStatus == nil {
return phaseInfo
}

info := phaseInfo.GetInfo()
if info == nil {
info = &handler.ExecutionInfo{}
}

if info.TaskNodeInfo == nil {
info.TaskNodeInfo = &handler.TaskNodeInfo{}
}

if info.TaskNodeInfo.TaskNodeMetadata == nil {
info.TaskNodeInfo.TaskNodeMetadata = &event.TaskNodeMetadata{}
}

if cacheStatus != nil {
info.TaskNodeInfo.TaskNodeMetadata.CacheStatus = cacheStatus.GetCacheStatus()
info.TaskNodeInfo.TaskNodeMetadata.CatalogKey = cacheStatus.GetMetadata()
}

if reservationStatus != nil {
info.TaskNodeInfo.TaskNodeMetadata.ReservationStatus = *reservationStatus
}

return phaseInfo.WithInfo(info)
}

// CheckCatalogCache uses the handler and contexts to check if cached outputs for the current node
// exist. If the exist, this function also copies the outputs to this node.
func (n *nodeExecutor) CheckCatalogCache(ctx context.Context, nCtx interfaces.NodeExecutionContext, cacheHandler interfaces.CacheableNodeHandler) (catalog.Entry, error) {
catalogKey, err := cacheHandler.GetCatalogKey(ctx, nCtx)
if err != nil {
return catalog.Entry{}, errors.Wrapf(err, "failed to initialize the catalogKey")
}

entry, err := n.catalog.Get(ctx, catalogKey)
if err != nil {
causeErr := errors.Cause(err)
if taskStatus, ok := status.FromError(causeErr); ok && taskStatus.Code() == codes.NotFound {
n.metrics.catalogMissCount.Inc(ctx)
logger.Infof(ctx, "Catalog CacheMiss: Artifact not found in Catalog. Executing Task.")
return catalog.NewCatalogEntry(nil, catalog.NewStatus(core.CatalogCacheStatus_CACHE_MISS, nil)), nil
}

n.metrics.catalogGetFailureCount.Inc(ctx)
logger.Errorf(ctx, "Catalog Failure: memoization check failed. err: %v", err.Error())
return catalog.Entry{}, errors.Wrapf(err, "Failed to check Catalog for previous results")
}

if entry.GetStatus().GetCacheStatus() != core.CatalogCacheStatus_CACHE_HIT {
logger.Errorf(ctx, "No CacheHIT and no Error received. Illegal state, Cache State: %s", entry.GetStatus().GetCacheStatus().String())
// TODO should this be an error?
return entry, nil
}

logger.Infof(ctx, "Catalog CacheHit: for task [%s/%s/%s/%s]", catalogKey.Identifier.Project,
catalogKey.Identifier.Domain, catalogKey.Identifier.Name, catalogKey.Identifier.Version)
n.metrics.catalogHitCount.Inc(ctx)

iface := catalogKey.TypedInterface
if iface.Outputs != nil && len(iface.Outputs.Variables) > 0 {
// copy cached outputs to node outputs
o, ee, err := entry.GetOutputs().Read(ctx)
if err != nil {
logger.Errorf(ctx, "failed to read from catalog, err: %s", err.Error())
return catalog.Entry{}, err
} else if ee != nil {
logger.Errorf(ctx, "got execution error from catalog output reader? This should not happen, err: %s", ee.String())
return catalog.Entry{}, nodeserrors.Errorf(nodeserrors.IllegalStateError, nCtx.NodeID(), "execution error from a cache output, bad state: %s", ee.String())
}

outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir())
if err := nCtx.DataStore().WriteProtobuf(ctx, outputFile, storage.Options{}, o); err != nil {
logger.Errorf(ctx, "failed to write cached value to datastore, err: %s", err.Error())
return catalog.Entry{}, err
}
}

return entry, nil
}

// GetOrExtendCatalogReservation attempts to acquire an artifact reservation if the task is
// cachable and cache serializable. If the reservation already exists for this owner, the
// reservation is extended.
func (n *nodeExecutor) GetOrExtendCatalogReservation(ctx context.Context, nCtx interfaces.NodeExecutionContext,
cacheHandler interfaces.CacheableNodeHandler, heartbeatInterval time.Duration) (catalog.ReservationEntry, error) {

catalogKey, err := cacheHandler.GetCatalogKey(ctx, nCtx)
if err != nil {
return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_DISABLED),
errors.Wrapf(err, "failed to initialize the catalogKey")
}

ownerID, err := computeCatalogReservationOwnerID(nCtx)
if err != nil {
return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_DISABLED),
errors.Wrapf(err, "failed to initialize the cache reservation ownerID")
}

reservation, err := n.catalog.GetOrExtendReservation(ctx, catalogKey, ownerID, heartbeatInterval)
if err != nil {
n.metrics.reservationGetFailureCount.Inc(ctx)
logger.Errorf(ctx, "Catalog Failure: reservation get or extend failed. err: %v", err.Error())
return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_FAILURE), err
}

var status core.CatalogReservation_Status
if reservation.OwnerId == ownerID {
status = core.CatalogReservation_RESERVATION_ACQUIRED
} else {
status = core.CatalogReservation_RESERVATION_EXISTS
}

n.metrics.reservationGetSuccessCount.Inc(ctx)
return catalog.NewReservationEntry(reservation.ExpiresAt.AsTime(),
reservation.HeartbeatInterval.AsDuration(), reservation.OwnerId, status), nil
}

// ReleaseCatalogReservation attempts to release an artifact reservation if the task is cachable
// and cache serializable. If the reservation does not exist for this owner (e.x. it never existed
// or has been acquired by another owner) this call is still successful.
func (n *nodeExecutor) ReleaseCatalogReservation(ctx context.Context, nCtx interfaces.NodeExecutionContext,
cacheHandler interfaces.CacheableNodeHandler) (catalog.ReservationEntry, error) {

catalogKey, err := cacheHandler.GetCatalogKey(ctx, nCtx)
if err != nil {
return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_DISABLED),
errors.Wrapf(err, "failed to initialize the catalogKey")
}

ownerID, err := computeCatalogReservationOwnerID(nCtx)
if err != nil {
return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_DISABLED),
errors.Wrapf(err, "failed to initialize the cache reservation ownerID")
}

err = n.catalog.ReleaseReservation(ctx, catalogKey, ownerID)
if err != nil {
n.metrics.reservationReleaseFailureCount.Inc(ctx)
logger.Errorf(ctx, "Catalog Failure: release reservation failed. err: %v", err.Error())
return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_FAILURE), err
}

n.metrics.reservationReleaseSuccessCount.Inc(ctx)
return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_RELEASED), nil
}

// WriteCatalogCache relays the outputs of this node to the cache. This allows future executions
// to reuse these data to avoid recomputation.
func (n *nodeExecutor) WriteCatalogCache(ctx context.Context, nCtx interfaces.NodeExecutionContext, cacheHandler interfaces.CacheableNodeHandler) (catalog.Status, error) {
catalogKey, err := cacheHandler.GetCatalogKey(ctx, nCtx)
if err != nil {
return catalog.NewStatus(core.CatalogCacheStatus_CACHE_DISABLED, nil), errors.Wrapf(err, "failed to initialize the catalogKey")
}

iface := catalogKey.TypedInterface
if iface.Outputs != nil && len(iface.Outputs.Variables) == 0 {
return catalog.NewStatus(core.CatalogCacheStatus_CACHE_DISABLED, nil), nil
}

logger.Infof(ctx, "Catalog CacheEnabled. recording execution [%s/%s/%s/%s]", catalogKey.Identifier.Project,
catalogKey.Identifier.Domain, catalogKey.Identifier.Name, catalogKey.Identifier.Version)

outputPaths := ioutils.NewReadOnlyOutputFilePaths(ctx, nCtx.DataStore(), nCtx.NodeStatus().GetOutputDir())
outputReader := ioutils.NewRemoteFileOutputReader(ctx, nCtx.DataStore(), outputPaths, nCtx.MaxDatasetSizeBytes())
metadata := catalog.Metadata{
TaskExecutionIdentifier: task.GetTaskExecutionIdentifier(nCtx),
}

// ignores discovery write failures
status, err := n.catalog.Put(ctx, catalogKey, outputReader, metadata)
if err != nil {
n.metrics.catalogPutFailureCount.Inc(ctx)
logger.Errorf(ctx, "Failed to write results to catalog for Task [%v]. Error: %v", catalogKey.Identifier, err)
return catalog.NewStatus(core.CatalogCacheStatus_CACHE_PUT_FAILURE, status.GetMetadata()), nil
}

n.metrics.catalogPutSuccessCount.Inc(ctx)
logger.Infof(ctx, "Successfully cached results to catalog - Task [%v]", catalogKey.Identifier)
return status, nil
}
Loading

0 comments on commit 6e06386

Please sign in to comment.