Skip to content

Commit

Permalink
Fix race condition when restarting continuous jobs (#1849)
Browse files Browse the repository at this point in the history
## Changes
We don't need to cancel existing runs when the job is continuous and
unpaused. The `/jobs/run-now` command will cancel the existing run and
trigger a new one automatically.

Cancelling the job manually can cause a race condition where both the
manual trigger from the CLI and the continuous trigger from the job
configuration happen at the same time. This PR prevents that from
happening.

## Tests
Unit tests and manually
  • Loading branch information
shreyas-goenka authored Oct 22, 2024
1 parent 68d69d6 commit 3bab21e
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 8 deletions.
23 changes: 23 additions & 0 deletions bundle/run/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,29 @@ func (r *jobRunner) Cancel(ctx context.Context) error {
return errGroup.Wait()
}

// Restart starts a new run of the job, cancelling the active runs first only
// when the backend will not do so on its own.
//
// For a continuous job that is unpaused, the /jobs/run-now API automatically
// cancels any existing run before starting a new one. Cancelling manually in
// that case can race with the continuous trigger from the job configuration,
// so the cancel step is skipped and Run is called directly.
//
// For a continuous job that is paused, /jobs/run-now does not cancel existing
// runs; new runs are queued and wait for the existing ones to finish. In that
// case (and for all remaining jobs) the active runs are cancelled here before
// the new run is triggered.
//
// It returns the result of Run, or the error from Cancel if cancelling fails.
func (r *jobRunner) Restart(ctx context.Context, opts *Options) (output.RunOutput, error) {
	if continuous := r.job.JobSettings.Continuous; continuous != nil {
		switch continuous.PauseStatus {
		// An unset pause status is treated as unpaused: the Jobs API documents
		// UNPAUSED as the default for continuous jobs, so such a job has an
		// active continuous trigger and is subject to the same race.
		case jobs.PauseStatusUnpaused, "":
			return r.Run(ctx, opts)
		}
	}

	s := cmdio.Spinner(ctx)
	s <- "Cancelling all active job runs"
	err := r.Cancel(ctx)
	// Stop the spinner before inspecting the error so it never outlives the
	// cancel step, regardless of outcome.
	close(s)
	if err != nil {
		return nil, err
	}

	return r.Run(ctx, opts)
}

// ParseArgs delegates parsing of the positional arguments to the job's
// positional-argument handler, which records the result in opts.
func (r *jobRunner) ParseArgs(args []string, opts *Options) error {
	handler := r.posArgsHandler()
	return handler.ParseArgs(args, opts)
}
Expand Down
132 changes: 132 additions & 0 deletions bundle/run/job_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package run

import (
"bytes"
"context"
"testing"
"time"

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/flags"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -126,3 +129,132 @@ func TestJobRunnerCancelWithNoActiveRuns(t *testing.T) {
err := runner.Cancel(context.Background())
require.NoError(t, err)
}

// TestJobRunnerRestart verifies that Restart cancels every active run and then
// triggers a new run for jobs that are not unpaused continuous jobs.
//
// Each case runs as a named subtest with its own mock workspace client, so a
// failure (including an unmet mock expectation) is attributed to the case that
// caused it.
func TestJobRunnerRestart(t *testing.T) {
	testCases := []struct {
		name        string
		jobSettings *jobs.JobSettings
	}{
		{
			name:        "non-continuous job",
			jobSettings: &jobs.JobSettings{},
		},
		{
			name: "paused continuous job",
			jobSettings: &jobs.JobSettings{
				Continuous: &jobs.Continuous{
					PauseStatus: jobs.PauseStatusPaused,
				},
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			job := &resources.Job{
				ID:          "123",
				JobSettings: tc.jobSettings,
			}
			b := &bundle.Bundle{
				Config: config.Root{
					Resources: config.Resources{
						Jobs: map[string]*resources.Job{
							"test_job": job,
						},
					},
				},
			}

			runner := jobRunner{key: "test", bundle: b, job: job}

			m := mocks.NewMockWorkspaceClient(t)
			b.SetWorkpaceClient(m.WorkspaceClient)
			ctx := context.Background()
			ctx = cmdio.InContext(ctx, cmdio.NewIO(flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "", ""))
			ctx = cmdio.NewContext(ctx, cmdio.NewLogger(flags.ModeAppend))

			// Mock the runner listing the active runs of the job.
			jobApi := m.GetMockJobsAPI()
			jobApi.EXPECT().ListRunsAll(mock.Anything, jobs.ListRunsRequest{
				ActiveOnly: true,
				JobId:      123,
			}).Return([]jobs.BaseRun{
				{RunId: 1},
				{RunId: 2},
			}, nil)

			// Mock the runner cancelling existing job runs.
			mockWait := &jobs.WaitGetRunJobTerminatedOrSkipped[struct{}]{
				// The duration parameter is named d (not time) so it does not
				// shadow the time package inside the closure.
				Poll: func(d time.Duration, f func(j *jobs.Run)) (*jobs.Run, error) {
					return nil, nil
				},
			}
			jobApi.EXPECT().CancelRun(mock.Anything, jobs.CancelRun{
				RunId: 1,
			}).Return(mockWait, nil)
			jobApi.EXPECT().CancelRun(mock.Anything, jobs.CancelRun{
				RunId: 2,
			}).Return(mockWait, nil)

			// Mock the runner triggering a job run
			mockWaitForRun := &jobs.WaitGetRunJobTerminatedOrSkipped[jobs.RunNowResponse]{
				Poll: func(d time.Duration, f func(*jobs.Run)) (*jobs.Run, error) {
					return &jobs.Run{
						State: &jobs.RunState{
							ResultState: jobs.RunResultStateSuccess,
						},
					}, nil
				},
			}
			jobApi.EXPECT().RunNow(mock.Anything, jobs.RunNow{
				JobId: 123,
			}).Return(mockWaitForRun, nil)

			// Mock the runner getting the job output
			jobApi.EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{}).Return(&jobs.Run{}, nil)

			_, err := runner.Restart(ctx, &Options{})
			require.NoError(t, err)
		})
	}
}

// TestJobRunnerRestartForContinuousUnpausedJobs verifies that Restart on an
// unpaused continuous job triggers a new run without listing or cancelling the
// existing runs.
func TestJobRunnerRestartForContinuousUnpausedJobs(t *testing.T) {
	job := &resources.Job{
		ID: "123",
		JobSettings: &jobs.JobSettings{
			Continuous: &jobs.Continuous{
				PauseStatus: jobs.PauseStatusUnpaused,
			},
		},
	}
	b := &bundle.Bundle{
		Config: config.Root{
			Resources: config.Resources{
				Jobs: map[string]*resources.Job{
					"test_job": job,
				},
			},
		},
	}

	runner := jobRunner{key: "test", bundle: b, job: job}

	m := mocks.NewMockWorkspaceClient(t)
	b.SetWorkpaceClient(m.WorkspaceClient)
	ctx := context.Background()
	ctx = cmdio.InContext(ctx, cmdio.NewIO(flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "", "..."))
	ctx = cmdio.NewContext(ctx, cmdio.NewLogger(flags.ModeAppend))

	jobApi := m.GetMockJobsAPI()

	// Mock the runner triggering a job run
	mockWaitForRun := &jobs.WaitGetRunJobTerminatedOrSkipped[jobs.RunNowResponse]{
		Poll: func(d time.Duration, f func(*jobs.Run)) (*jobs.Run, error) {
			return &jobs.Run{
				State: &jobs.RunState{
					ResultState: jobs.RunResultStateSuccess,
				},
			}, nil
		},
	}
	jobApi.EXPECT().RunNow(mock.Anything, jobs.RunNow{
		JobId: 123,
	}).Return(mockWaitForRun, nil)

	// Mock the runner getting the job output
	jobApi.EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{}).Return(&jobs.Run{}, nil)

	_, err := runner.Restart(ctx, &Options{})
	require.NoError(t, err)

	// The runner should not try and cancel existing job runs for unpaused continuous jobs.
	// AssertNotCalled only inspects the calls recorded so far, so these checks
	// must run after Restart; placed before it they would pass vacuously.
	jobApi.AssertNotCalled(t, "ListRunsAll")
	jobApi.AssertNotCalled(t, "CancelRun")
}
12 changes: 12 additions & 0 deletions bundle/run/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,18 @@ func (r *pipelineRunner) Cancel(ctx context.Context) error {
return err
}

// Restart cancels the pipeline via Cancel while showing a spinner, and then
// starts a new run with the given options.
//
// If cancelling fails, that error is returned and no new run is started;
// otherwise the result of Run is returned as is.
func (r *pipelineRunner) Restart(ctx context.Context, opts *Options) (output.RunOutput, error) {
	spinner := cmdio.Spinner(ctx)
	spinner <- "Cancelling the active pipeline update"

	cancelErr := r.Cancel(ctx)

	// The spinner is stopped regardless of whether the cancel succeeded.
	close(spinner)

	if cancelErr != nil {
		return nil, cancelErr
	}
	return r.Run(ctx, opts)
}

func (r *pipelineRunner) ParseArgs(args []string, opts *Options) error {
if len(args) == 0 {
return nil
Expand Down
70 changes: 70 additions & 0 deletions bundle/run/pipeline_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
package run

import (
"bytes"
"context"
"testing"
"time"

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/flags"
sdk_config "github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/pipelines"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -47,3 +52,68 @@ func TestPipelineRunnerCancel(t *testing.T) {
err := runner.Cancel(context.Background())
require.NoError(t, err)
}

func TestPipelineRunnerRestart(t *testing.T) {
pipeline := &resources.Pipeline{
ID: "123",
}

b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Pipelines: map[string]*resources.Pipeline{
"test_pipeline": pipeline,
},
},
},
}

runner := pipelineRunner{key: "test", bundle: b, pipeline: pipeline}

m := mocks.NewMockWorkspaceClient(t)
m.WorkspaceClient.Config = &sdk_config.Config{
Host: "https://test.com",
}
b.SetWorkpaceClient(m.WorkspaceClient)
ctx := context.Background()
ctx = cmdio.InContext(ctx, cmdio.NewIO(flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "", "..."))
ctx = cmdio.NewContext(ctx, cmdio.NewLogger(flags.ModeAppend))

mockWait := &pipelines.WaitGetPipelineIdle[struct{}]{
Poll: func(time.Duration, func(*pipelines.GetPipelineResponse)) (*pipelines.GetPipelineResponse, error) {
return nil, nil
},
}

pipelineApi := m.GetMockPipelinesAPI()
pipelineApi.EXPECT().Stop(mock.Anything, pipelines.StopRequest{
PipelineId: "123",
}).Return(mockWait, nil)

pipelineApi.EXPECT().GetByPipelineId(mock.Anything, "123").Return(&pipelines.GetPipelineResponse{}, nil)

// Mock runner starting a new update
pipelineApi.EXPECT().StartUpdate(mock.Anything, pipelines.StartUpdate{
PipelineId: "123",
}).Return(&pipelines.StartUpdateResponse{
UpdateId: "456",
}, nil)

// Mock runner polling for events
pipelineApi.EXPECT().ListPipelineEventsAll(mock.Anything, pipelines.ListPipelineEventsRequest{
Filter: `update_id = '456'`,
MaxResults: 100,
PipelineId: "123",
}).Return([]pipelines.PipelineEvent{}, nil)

// Mock runner polling for update status
pipelineApi.EXPECT().GetUpdateByPipelineIdAndUpdateId(mock.Anything, "123", "456").
Return(&pipelines.GetUpdateResponse{
Update: &pipelines.UpdateInfo{
State: pipelines.UpdateInfoStateCompleted,
},
}, nil)

_, err := runner.Restart(ctx, &Options{})
require.NoError(t, err)
}
4 changes: 4 additions & 0 deletions bundle/run/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ type Runner interface {
// Run the underlying workflow.
Run(ctx context.Context, opts *Options) (output.RunOutput, error)

// Restart the underlying workflow by cancelling any existing runs before
// starting a new one.
Restart(ctx context.Context, opts *Options) (output.RunOutput, error)

// Cancel the underlying workflow.
Cancel(ctx context.Context) error

Expand Down
14 changes: 6 additions & 8 deletions cmd/bundle/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/databricks/cli/bundle/deploy/terraform"
"github.com/databricks/cli/bundle/phases"
"github.com/databricks/cli/bundle/run"
"github.com/databricks/cli/bundle/run/output"
"github.com/databricks/cli/cmd/bundle/utils"
"github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/libs/cmdio"
Expand Down Expand Up @@ -100,19 +101,16 @@ task or a Python wheel task, the second example applies.
}

runOptions.NoWait = noWait
var output output.RunOutput
if restart {
s := cmdio.Spinner(ctx)
s <- "Cancelling all runs"
err := runner.Cancel(ctx)
close(s)
if err != nil {
return err
}
output, err = runner.Restart(ctx, &runOptions)
} else {
output, err = runner.Run(ctx, &runOptions)
}
output, err := runner.Run(ctx, &runOptions)
if err != nil {
return err
}

if output != nil {
switch root.OutputType(cmd) {
case flags.OutputText:
Expand Down

0 comments on commit 3bab21e

Please sign in to comment.