diff --git a/bundle/run/app_test.go b/bundle/run/app_test.go index 8e82f45ae8..77f197e8d4 100644 --- a/bundle/run/app_test.go +++ b/bundle/run/app_test.go @@ -1,7 +1,6 @@ package run import ( - "bytes" "context" "errors" "os" @@ -16,7 +15,6 @@ import ( "github.com/databricks/cli/bundle/internal/bundletest" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/dyn" - "github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/vfs" "github.com/databricks/databricks-sdk-go/experimental/mocks" "github.com/databricks/databricks-sdk-go/service/apps" @@ -76,9 +74,7 @@ func setupBundle(t *testing.T) (context.Context, *bundle.Bundle, *mocks.MockWork b.SetWorkpaceClient(mwc.WorkspaceClient) bundletest.SetLocation(b, "resources.apps.my_app", []dyn.Location{{File: "./databricks.yml"}}) - ctx := context.Background() - ctx = cmdio.InContext(ctx, cmdio.NewIO(ctx, flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "", "...")) - ctx = cmdio.NewContext(ctx, cmdio.NewLogger(flags.ModeAppend)) + ctx := cmdio.MockDiscard(context.Background()) diags := bundle.Apply(ctx, b, bundle.Seq( mutator.DefineDefaultWorkspacePaths(), diff --git a/bundle/run/job_test.go b/bundle/run/job_test.go index 72aecc8872..daf6cf063e 100644 --- a/bundle/run/job_test.go +++ b/bundle/run/job_test.go @@ -1,7 +1,6 @@ package run import ( - "bytes" "context" "testing" "time" @@ -159,8 +158,8 @@ func TestJobRunnerRestart(t *testing.T) { m := mocks.NewMockWorkspaceClient(t) b.SetWorkpaceClient(m.WorkspaceClient) - ctx := context.Background() - ctx = cmdio.InContext(ctx, cmdio.NewIO(ctx, flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "", "")) + + ctx := cmdio.MockDiscard(context.Background()) ctx = cmdio.NewContext(ctx, cmdio.NewLogger(flags.ModeAppend)) jobApi := m.GetMockJobsAPI() @@ -230,8 +229,8 @@ func TestJobRunnerRestartForContinuousUnpausedJobs(t *testing.T) { m := mocks.NewMockWorkspaceClient(t) b.SetWorkpaceClient(m.WorkspaceClient) - ctx := context.Background() - ctx = cmdio.InContext(ctx, cmdio.NewIO(ctx, flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "", "...")) + + ctx := cmdio.MockDiscard(context.Background()) ctx = cmdio.NewContext(ctx, cmdio.NewLogger(flags.ModeAppend)) jobApi := m.GetMockJobsAPI() diff --git a/bundle/run/pipeline_test.go b/bundle/run/pipeline_test.go index bfa0c5846c..56d800d353 100644 --- a/bundle/run/pipeline_test.go +++ b/bundle/run/pipeline_test.go @@ -1,7 +1,6 @@ package run import ( - "bytes" "context" "testing" "time" @@ -75,8 +74,8 @@ func TestPipelineRunnerRestart(t *testing.T) { Host: "https://test.com", } b.SetWorkpaceClient(m.WorkspaceClient) - ctx := context.Background() - ctx = cmdio.InContext(ctx, cmdio.NewIO(ctx, flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "", "...")) + + ctx := cmdio.MockDiscard(context.Background()) ctx = cmdio.NewContext(ctx, cmdio.NewLogger(flags.ModeAppend)) mockWait := &pipelines.WaitGetPipelineIdle[struct{}]{ diff --git a/cmd/bundle/init.go b/cmd/bundle/init.go index 687c141eca..fafdf45076 100644 --- a/cmd/bundle/init.go +++ b/cmd/bundle/init.go @@ -1,175 +1,15 @@ package bundle import ( - "context" "errors" "fmt" - "io/fs" - "os" - "path/filepath" - "slices" - "strings" "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/dbr" - "github.com/databricks/cli/libs/filer" - "github.com/databricks/cli/libs/git" "github.com/databricks/cli/libs/template" "github.com/spf13/cobra" ) -var gitUrlPrefixes = []string{ - "https://", - "git@", -} - -type nativeTemplate struct { - name string - gitUrl string - description string - aliases []string - hidden bool -} - -const customTemplate = "custom..." - -var nativeTemplates = []nativeTemplate{ - { - name: "default-python", - description: "The default Python template for Notebooks / Delta Live Tables / Workflows", - }, - { - name: "default-sql", - description: "The default SQL template for .sql files that run with Databricks SQL", - }, - { - name: "dbt-sql", - description: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)", - }, - { - name: "mlops-stacks", - gitUrl: "https://github.com/databricks/mlops-stacks", - description: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)", - aliases: []string{"mlops-stack"}, - }, - { - name: "default-pydabs", - gitUrl: "https://databricks.github.io/workflows-authoring-toolkit/pydabs-template.git", - hidden: true, - description: "The default PyDABs template", - }, - { - name: customTemplate, - description: "Bring your own template", - }, -} - -// Return template descriptions for command-line help -func nativeTemplateHelpDescriptions() string { - var lines []string - for _, template := range nativeTemplates { - if template.name != customTemplate && !template.hidden { - lines = append(lines, fmt.Sprintf("- %s: %s", template.name, template.description)) - } - } - return strings.Join(lines, "\n") -} - -// Return template options for an interactive prompt -func nativeTemplateOptions() []cmdio.Tuple { - names := make([]cmdio.Tuple, 0, len(nativeTemplates)) - for _, template := range nativeTemplates { - if template.hidden { - continue - } - tuple := cmdio.Tuple{ - Name: template.name, - Id: template.description, - } - names = append(names, tuple) - } - return names -} - -func getNativeTemplateByDescription(description string) string { - for _, template := range nativeTemplates { - if template.description == description { - return template.name - } - } - return "" -} - -func getUrlForNativeTemplate(name string) string { - for _, template := range nativeTemplates { - if template.name == name { - return template.gitUrl - } - if slices.Contains(template.aliases, name) { - return template.gitUrl - } - } - return "" -} - -func getFsForNativeTemplate(name string) (fs.FS, error) { - builtin, err := template.Builtin() - if err != nil { - return nil, err - } - - // If this is a built-in template, the return value will be non-nil. - var templateFS fs.FS - for _, entry := range builtin { - if entry.Name == name { - templateFS = entry.FS - break - } - } - - return templateFS, nil -} - -func isRepoUrl(url string) bool { - result := false - for _, prefix := range gitUrlPrefixes { - if strings.HasPrefix(url, prefix) { - result = true - break - } - } - return result -} - -// Computes the repo name from the repo URL. Treats the last non empty word -// when splitting at '/' as the repo name. For example: for url git@github.com:databricks/cli.git -// the name would be "cli.git" -func repoName(url string) string { - parts := strings.Split(strings.TrimRight(url, "/"), "/") - return parts[len(parts)-1] -} - -func constructOutputFiler(ctx context.Context, outputDir string) (filer.Filer, error) { - outputDir, err := filepath.Abs(outputDir) - if err != nil { - return nil, err - } - - // If the CLI is running on DBR and we're writing to the workspace file system, - // use the extension-aware workspace filesystem filer to instantiate the template. - // - // It is not possible to write notebooks through the workspace filesystem's FUSE mount. - // Therefore this is the only way we can initialize templates that contain notebooks - // when running the CLI on DBR and initializing a template to the workspace. - // - if strings.HasPrefix(outputDir, "/Workspace/") && dbr.RunsOnRuntime(ctx) { - return filer.NewWorkspaceFilesExtensionsClient(root.WorkspaceClient(ctx), outputDir) - } - - return filer.NewLocalClient(outputDir) -} - func newInitCommand() *cobra.Command { cmd := &cobra.Command{ Use: "init [TEMPLATE_PATH]", @@ -182,7 +22,7 @@ TEMPLATE_PATH optionally specifies which template to use. It can be one of the f - a local file system path with a template directory - a Git repository URL, e.g. https://github.com/my/repository -See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more information on templates.`, nativeTemplateHelpDescriptions()), +See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more information on templates.`, template.HelpDescriptions()), } var configFile string @@ -202,88 +42,32 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf return errors.New("only one of --tag or --branch can be specified") } - // Git ref to use for template initialization - ref := branch - if tag != "" { - ref = tag - } - - ctx := cmd.Context() - var templatePath string + var templatePathOrUrl string if len(args) > 0 { - templatePath = args[0] - } else { - var err error - if !cmdio.IsPromptSupported(ctx) { - return errors.New("please specify a template") - } - description, err := cmdio.SelectOrdered(ctx, nativeTemplateOptions(), "Template to use") - if err != nil { - return err - } - templatePath = getNativeTemplateByDescription(description) + templatePathOrUrl = args[0] } - - outputFiler, err := constructOutputFiler(ctx, outputDir) - if err != nil { - return err + r := template.Resolver{ + TemplatePathOrUrl: templatePathOrUrl, + ConfigFile: configFile, + OutputDir: outputDir, + TemplateDir: templateDir, + Tag: tag, + Branch: branch, } - if templatePath == customTemplate { + ctx := cmd.Context() + tmpl, err := r.Resolve(ctx) + if errors.Is(err, template.ErrCustomSelected) { cmdio.LogString(ctx, "Please specify a path or Git repository to use a custom template.") cmdio.LogString(ctx, "See https://docs.databricks.com/en/dev-tools/bundles/templates.html to learn more about custom templates.") return nil } - - // Expand templatePath to a git URL if it's an alias for a known native template - // and we know it's git URL. - if gitUrl := getUrlForNativeTemplate(templatePath); gitUrl != "" { - templatePath = gitUrl - } - - if !isRepoUrl(templatePath) { - if templateDir != "" { - return errors.New("--template-dir can only be used with a Git repository URL") - } - - templateFS, err := getFsForNativeTemplate(templatePath) - if err != nil { - return err - } - - // If this is not a built-in template, then it must be a local file system path. - if templateFS == nil { - templateFS = os.DirFS(templatePath) - } - - // skip downloading the repo because input arg is not a URL. We assume - // it's a path on the local file system in that case - return template.Materialize(ctx, configFile, templateFS, outputFiler) - } - - // Create a temporary directory with the name of the repository. The '*' - // character is replaced by a random string in the generated temporary directory. - repoDir, err := os.MkdirTemp("", repoName(templatePath)+"-*") - if err != nil { - return err - } - - // start the spinner - promptSpinner := cmdio.Spinner(ctx) - promptSpinner <- "Downloading the template\n" - - // TODO: Add automated test that the downloaded git repo is cleaned up. - // Clone the repository in the temporary directory - err = git.Clone(ctx, templatePath, ref, repoDir) - close(promptSpinner) if err != nil { return err } + defer tmpl.Reader.Cleanup() - // Clean up downloaded repository once the template is materialized. - defer os.RemoveAll(repoDir) - templateFS := os.DirFS(filepath.Join(repoDir, templateDir)) - return template.Materialize(ctx, configFile, templateFS, outputFiler) + return tmpl.Writer.Materialize(ctx, tmpl.Reader) } return cmd } diff --git a/integration/bundle/helpers_test.go b/integration/bundle/helpers_test.go index a537ca3517..e89f5e5f71 100644 --- a/integration/bundle/helpers_test.go +++ b/integration/bundle/helpers_test.go @@ -16,7 +16,6 @@ import ( "github.com/databricks/cli/internal/testutil" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/env" - "github.com/databricks/cli/libs/filer" "github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/folders" "github.com/databricks/cli/libs/template" @@ -40,10 +39,19 @@ func initTestTemplateWithBundleRoot(t testutil.TestingT, ctx context.Context, te cmd := cmdio.NewIO(ctx, flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "", "bundles") ctx = cmdio.InContext(ctx, cmd) - out, err := filer.NewLocalClient(bundleRoot) + r := template.Resolver{ + TemplatePathOrUrl: templateRoot, + ConfigFile: configFilePath, + OutputDir: bundleRoot, + } + + tmpl, err := r.Resolve(ctx) require.NoError(t, err) - err = template.Materialize(ctx, configFilePath, os.DirFS(templateRoot), out) + defer tmpl.Reader.Cleanup() + + err = tmpl.Writer.Materialize(ctx, tmpl.Reader) require.NoError(t, err) + return bundleRoot } diff --git a/libs/cmdio/io.go b/libs/cmdio/io.go index c0e9e868a8..11b75157d7 100644 --- a/libs/cmdio/io.go +++ b/libs/cmdio/io.go @@ -285,3 +285,14 @@ func fromContext(ctx context.Context) *cmdIO { } return io } + +// Mocks the context with a cmdio object that discards all output. +func MockDiscard(ctx context.Context) context.Context { + return InContext(ctx, &cmdIO{ + interactive: false, + outputFormat: flags.OutputText, + in: io.NopCloser(strings.NewReader("")), + out: io.Discard, + err: io.Discard, + }) +} diff --git a/libs/databrickscfg/cfgpickers/clusters_test.go b/libs/databrickscfg/cfgpickers/clusters_test.go index 29e190a935..840916e915 100644 --- a/libs/databrickscfg/cfgpickers/clusters_test.go +++ b/libs/databrickscfg/cfgpickers/clusters_test.go @@ -1,12 +1,10 @@ package cfgpickers import ( - "bytes" "context" "testing" "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/flags" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/qa" "github.com/databricks/databricks-sdk-go/service/compute" @@ -114,8 +112,8 @@ func TestFirstCompatibleCluster(t *testing.T) { defer server.Close() w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg))) - ctx := context.Background() - ctx = cmdio.InContext(ctx, cmdio.NewIO(ctx, flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "", "...")) + ctx := cmdio.MockDiscard(context.Background()) + clusterID, err := AskForCluster(ctx, w, WithDatabricksConnect("13.1")) require.NoError(t, err) require.Equal(t, "bcd-id", clusterID) @@ -161,8 +159,7 @@ func TestNoCompatibleClusters(t *testing.T) { defer server.Close() w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg))) - ctx := context.Background() - ctx = cmdio.InContext(ctx, cmdio.NewIO(ctx, flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "", "...")) + ctx := cmdio.MockDiscard(context.Background()) _, err := AskForCluster(ctx, w, WithDatabricksConnect("13.1")) require.Equal(t, ErrNoCompatibleClusters, err) } diff --git a/libs/filer/workspace_files_extensions_client.go b/libs/filer/workspace_files_extensions_client.go index 9ee2722e17..0127d180c6 100644 --- a/libs/filer/workspace_files_extensions_client.go +++ b/libs/filer/workspace_files_extensions_client.go @@ -16,7 +16,7 @@ import ( "github.com/databricks/databricks-sdk-go/service/workspace" ) -type workspaceFilesExtensionsClient struct { +type WorkspaceFilesExtensionsClient struct { workspaceClient *databricks.WorkspaceClient wsfs Filer @@ -32,7 +32,7 @@ type workspaceFileStatus struct { nameForWorkspaceAPI string } -func (w *workspaceFilesExtensionsClient) stat(ctx context.Context, name string) (wsfsFileInfo, error) { +func (w *WorkspaceFilesExtensionsClient) stat(ctx context.Context, name string) (wsfsFileInfo, error) { info, err := w.wsfs.Stat(ctx, name) if err != nil { return wsfsFileInfo{}, err @@ -42,7 +42,7 @@ func (w *workspaceFilesExtensionsClient) stat(ctx context.Context, name string) // This function returns the stat for the provided notebook. The stat object itself contains the path // with the extension since it is meant to be used in the context of a fs.FileInfo. -func (w *workspaceFilesExtensionsClient) getNotebookStatByNameWithExt(ctx context.Context, name string) (*workspaceFileStatus, error) { +func (w *WorkspaceFilesExtensionsClient) getNotebookStatByNameWithExt(ctx context.Context, name string) (*workspaceFileStatus, error) { ext := path.Ext(name) nameWithoutExt := strings.TrimSuffix(name, ext) @@ -104,7 +104,7 @@ func (w *workspaceFilesExtensionsClient) getNotebookStatByNameWithExt(ctx contex }, nil } -func (w *workspaceFilesExtensionsClient) getNotebookStatByNameWithoutExt(ctx context.Context, name string) (*workspaceFileStatus, error) { +func (w *WorkspaceFilesExtensionsClient) getNotebookStatByNameWithoutExt(ctx context.Context, name string) (*workspaceFileStatus, error) { stat, err := w.stat(ctx, name) if err != nil { return nil, err @@ -184,7 +184,7 @@ func newWorkspaceFilesExtensionsClient(w *databricks.WorkspaceClient, root strin filer = newWorkspaceFilesReadaheadCache(filer) } - return &workspaceFilesExtensionsClient{ + return &WorkspaceFilesExtensionsClient{ workspaceClient: w, wsfs: filer, @@ -193,7 +193,7 @@ func newWorkspaceFilesExtensionsClient(w *databricks.WorkspaceClient, root strin }, nil } -func (w *workspaceFilesExtensionsClient) ReadDir(ctx context.Context, name string) ([]fs.DirEntry, error) { +func (w *WorkspaceFilesExtensionsClient) ReadDir(ctx context.Context, name string) ([]fs.DirEntry, error) { entries, err := w.wsfs.ReadDir(ctx, name) if err != nil { return nil, err @@ -235,7 +235,7 @@ func (w *workspaceFilesExtensionsClient) ReadDir(ctx context.Context, name strin // Note: The import API returns opaque internal errors for namespace clashes // (e.g. a file and a notebook or a directory and a notebook). Thus users of this // method should be careful to avoid such clashes. -func (w *workspaceFilesExtensionsClient) Write(ctx context.Context, name string, reader io.Reader, mode ...WriteMode) error { +func (w *WorkspaceFilesExtensionsClient) Write(ctx context.Context, name string, reader io.Reader, mode ...WriteMode) error { if w.readonly { return ReadOnlyError{"write"} } @@ -244,7 +244,7 @@ func (w *workspaceFilesExtensionsClient) Write(ctx context.Context, name string, } // Try to read the file as a regular file. If the file is not found, try to read it as a notebook. -func (w *workspaceFilesExtensionsClient) Read(ctx context.Context, name string) (io.ReadCloser, error) { +func (w *WorkspaceFilesExtensionsClient) Read(ctx context.Context, name string) (io.ReadCloser, error) { // Ensure that the file / notebook exists. We do this check here to avoid reading // the content of a notebook called `foo` when the user actually wanted // to read the content of a file called `foo`. @@ -283,7 +283,7 @@ func (w *workspaceFilesExtensionsClient) Read(ctx context.Context, name string) } // Try to delete the file as a regular file. If the file is not found, try to delete it as a notebook. -func (w *workspaceFilesExtensionsClient) Delete(ctx context.Context, name string, mode ...DeleteMode) error { +func (w *WorkspaceFilesExtensionsClient) Delete(ctx context.Context, name string, mode ...DeleteMode) error { if w.readonly { return ReadOnlyError{"delete"} } @@ -320,7 +320,7 @@ func (w *workspaceFilesExtensionsClient) Delete(ctx context.Context, name string } // Try to stat the file as a regular file. If the file is not found, try to stat it as a notebook. -func (w *workspaceFilesExtensionsClient) Stat(ctx context.Context, name string) (fs.FileInfo, error) { +func (w *WorkspaceFilesExtensionsClient) Stat(ctx context.Context, name string) (fs.FileInfo, error) { info, err := w.wsfs.Stat(ctx, name) // If the file is not found, it might be a notebook. @@ -361,7 +361,7 @@ func (w *workspaceFilesExtensionsClient) Stat(ctx context.Context, name string) // Note: The import API returns opaque internal errors for namespace clashes // (e.g. a file and a notebook or a directory and a notebook). Thus users of this // method should be careful to avoid such clashes. -func (w *workspaceFilesExtensionsClient) Mkdir(ctx context.Context, name string) error { +func (w *WorkspaceFilesExtensionsClient) Mkdir(ctx context.Context, name string) error { if w.readonly { return ReadOnlyError{"mkdir"} } diff --git a/libs/filer/workspace_files_extensions_client_test.go b/libs/filer/workspace_files_extensions_client_test.go index 10a2bebf0a..9ea837fa99 100644 --- a/libs/filer/workspace_files_extensions_client_test.go +++ b/libs/filer/workspace_files_extensions_client_test.go @@ -181,7 +181,7 @@ func TestFilerWorkspaceFilesExtensionsErrorsOnDupName(t *testing.T) { root: NewWorkspaceRootPath("/dir"), } - workspaceFilesExtensionsClient := workspaceFilesExtensionsClient{ + workspaceFilesExtensionsClient := WorkspaceFilesExtensionsClient{ workspaceClient: mockedWorkspaceClient.WorkspaceClient, wsfs: &workspaceFilesClient, } diff --git a/libs/template/builtin.go b/libs/template/builtin.go index dcb3a88582..5b10534ef5 100644 --- a/libs/template/builtin.go +++ b/libs/template/builtin.go @@ -8,14 +8,14 @@ import ( //go:embed all:templates var builtinTemplates embed.FS -// BuiltinTemplate represents a template that is built into the CLI. -type BuiltinTemplate struct { +// builtinTemplate represents a template that is built into the CLI. +type builtinTemplate struct { Name string FS fs.FS } -// Builtin returns the list of all built-in templates. -func Builtin() ([]BuiltinTemplate, error) { +// builtin returns the list of all built-in templates. +func builtin() ([]builtinTemplate, error) { templates, err := fs.Sub(builtinTemplates, "templates") if err != nil { return nil, err @@ -26,7 +26,7 @@ func Builtin() ([]BuiltinTemplate, error) { return nil, err } - var out []BuiltinTemplate + var out []builtinTemplate for _, entry := range entries { if !entry.IsDir() { continue @@ -37,7 +37,7 @@ func Builtin() ([]BuiltinTemplate, error) { return nil, err } - out = append(out, BuiltinTemplate{ + out = append(out, builtinTemplate{ Name: entry.Name(), FS: templateFS, }) diff --git a/libs/template/builtin_test.go b/libs/template/builtin_test.go index 79e04cb841..162a227ea9 100644 --- a/libs/template/builtin_test.go +++ b/libs/template/builtin_test.go @@ -9,12 +9,12 @@ import ( ) func TestBuiltin(t *testing.T) { - out, err := Builtin() + out, err := builtin() require.NoError(t, err) assert.GreaterOrEqual(t, len(out), 3) // Create a map of templates by name for easier lookup - templates := make(map[string]*BuiltinTemplate) + templates := make(map[string]*builtinTemplate) for _, tmpl := range out { templates[tmpl.Name] = &tmpl } diff --git a/libs/template/materialize.go b/libs/template/materialize.go deleted file mode 100644 index 86a6a8c37a..0000000000 --- a/libs/template/materialize.go +++ /dev/null @@ -1,94 +0,0 @@ -package template - -import ( - "context" - "errors" - "fmt" - "io/fs" - - "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/filer" -) - -const ( - libraryDirName = "library" - templateDirName = "template" - schemaFileName = "databricks_template_schema.json" -) - -// This function materializes the input templates as a project, using user defined -// configurations. -// Parameters: -// -// ctx: context containing a cmdio object. This is used to prompt the user -// configFilePath: file path containing user defined config values -// templateFS: root of the template definition -// outputFiler: filer to use for writing the initialized template -func Materialize(ctx context.Context, configFilePath string, templateFS fs.FS, outputFiler filer.Filer) error { - if _, err := fs.Stat(templateFS, schemaFileName); errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("not a bundle template: expected to find a template schema file at %s", schemaFileName) - } - - config, err := newConfig(ctx, templateFS, schemaFileName) - if err != nil { - return err - } - - // Read and assign config values from file - if configFilePath != "" { - err = config.assignValuesFromFile(configFilePath) - if err != nil { - return err - } - } - - helpers := loadHelpers(ctx) - r, err := newRenderer(ctx, config.values, helpers, templateFS, templateDirName, libraryDirName) - if err != nil { - return err - } - - // Print welcome message - welcome := config.schema.WelcomeMessage - if welcome != "" { - welcome, err = r.executeTemplate(welcome) - if err != nil { - return err - } - cmdio.LogString(ctx, welcome) - } - - // Prompt user for any missing config values. Assign default values if - // terminal is not TTY - err = config.promptOrAssignDefaultValues(r) - if err != nil { - return err - } - err = config.validate() - if err != nil { - return err - } - - // Walk and render the template, since input configuration is complete - err = r.walk() - if err != nil { - return err - } - - err = r.persistToDisk(ctx, outputFiler) - if err != nil { - return err - } - - success := config.schema.SuccessMessage - if success == "" { - cmdio.LogString(ctx, "✨ Successfully initialized template") - } else { - success, err = r.executeTemplate(success) - if err != nil { - return err - } - cmdio.LogString(ctx, success) - } - return nil -} diff --git a/libs/template/materialize_test.go b/libs/template/materialize_test.go deleted file mode 100644 index c9331b43fe..0000000000 --- a/libs/template/materialize_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package template - -import ( - "context" - "os" - "testing" - - "github.com/databricks/cli/cmd/root" - "github.com/databricks/databricks-sdk-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMaterializeForNonTemplateDirectory(t *testing.T) { - tmpDir := t.TempDir() - w, err := databricks.NewWorkspaceClient(&databricks.Config{}) - require.NoError(t, err) - ctx := root.SetWorkspaceClient(context.Background(), w) - - // Try to materialize a non-template directory. - err = Materialize(ctx, "", os.DirFS(tmpDir), nil) - assert.EqualError(t, err, "not a bundle template: expected to find a template schema file at "+schemaFileName) -} diff --git a/libs/template/reader.go b/libs/template/reader.go new file mode 100644 index 0000000000..7dc135919e --- /dev/null +++ b/libs/template/reader.go @@ -0,0 +1,113 @@ +package template + +import ( + "context" + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + + "github.com/databricks/cli/libs/cmdio" +) + +type Reader interface { + // FS returns a file system that contains the template + // definition files. + FS(ctx context.Context) (fs.FS, error) + + // Cleanup releases any resources associated with the reader + // like cleaning up temporary directories. + Cleanup() +} + +type builtinReader struct { + name string +} + +func (r *builtinReader) FS(ctx context.Context) (fs.FS, error) { + builtin, err := builtin() + if err != nil { + return nil, err + } + + var templateFS fs.FS + for _, entry := range builtin { + if entry.Name == r.name { + templateFS = entry.FS + break + } + } + + if templateFS == nil { + return nil, fmt.Errorf("builtin template %s not found", r.name) + } + + return templateFS, nil +} + +func (r *builtinReader) Cleanup() {} + +type gitReader struct { + gitUrl string + // tag or branch to checkout + ref string + // subdirectory within the repository that contains the template + templateDir string + // temporary directory where the repository is cloned + tmpRepoDir string + + // Function to clone the repository. This is a function pointer to allow + // mocking in tests. + cloneFunc func(ctx context.Context, url, reference, targetPath string) error +} + +// Computes the repo name from the repo URL. Treats the last non empty word +// when splitting at '/' as the repo name. For example: for url git@github.com:databricks/cli.git +// the name would be "cli.git" +func repoName(url string) string { + parts := strings.Split(strings.TrimRight(url, "/"), "/") + return parts[len(parts)-1] +} + +func (r *gitReader) FS(ctx context.Context) (fs.FS, error) { + // Create a temporary directory with the name of the repository. The '*' + // character is replaced by a random string in the generated temporary directory. + repoDir, err := os.MkdirTemp("", repoName(r.gitUrl)+"-*") + if err != nil { + return nil, err + } + r.tmpRepoDir = repoDir + + // start the spinner + promptSpinner := cmdio.Spinner(ctx) + promptSpinner <- "Downloading the template\n" + + err = r.cloneFunc(ctx, r.gitUrl, r.ref, repoDir) + close(promptSpinner) + if err != nil { + return nil, err + } + + return os.DirFS(filepath.Join(repoDir, r.templateDir)), nil +} + +func (r *gitReader) Cleanup() { + if r.tmpRepoDir == "" { + return + } + + // Cleanup is best effort. Ignore errors. + os.RemoveAll(r.tmpRepoDir) +} + +type localReader struct { + // Path on the local filesystem that contains the template + path string +} + +func (r *localReader) FS(ctx context.Context) (fs.FS, error) { + return os.DirFS(r.path), nil +} + +func (r *localReader) Cleanup() {} diff --git a/libs/template/reader_test.go b/libs/template/reader_test.go new file mode 100644 index 0000000000..e304727717 --- /dev/null +++ b/libs/template/reader_test.go @@ -0,0 +1,107 @@ +package template + +import ( + "context" + "io" + "io/fs" + "path/filepath" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/databricks/cli/libs/cmdio" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuiltInReader(t *testing.T) { + exists := []string{ + "default-python", + "default-sql", + "dbt-sql", + } + + for _, name := range exists { + t.Run(name, func(t *testing.T) { + r := &builtinReader{name: name} + fs, err := r.FS(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, fs) + + // Assert file content returned is accurate and every template has a welcome + // message defined. + fd, err := fs.Open("databricks_template_schema.json") + require.NoError(t, err) + b, err := io.ReadAll(fd) + require.NoError(t, err) + assert.Contains(t, string(b), "welcome_message") + assert.NoError(t, fd.Close()) + }) + } + + t.Run("doesnotexist", func(t *testing.T) { + r := &builtinReader{name: "doesnotexist"} + _, err := r.FS(context.Background()) + assert.EqualError(t, err, "builtin template doesnotexist not found") + }) +} + +func TestGitUrlReader(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + + var args []string + numCalls := 0 + cloneFunc := func(ctx context.Context, url, reference, targetPath string) error { + numCalls++ + args = []string{url, reference, targetPath} + testutil.WriteFile(t, filepath.Join(targetPath, "a", "b", "c", "somefile"), "somecontent") + return nil + } + r := &gitReader{ + gitUrl: "someurl", + cloneFunc: cloneFunc, + ref: "sometag", + templateDir: "a/b/c", + } + + // Assert cloneFunc is called with the correct args. + fsys, err := r.FS(ctx) + require.NoError(t, err) + require.NotEmpty(t, r.tmpRepoDir) + assert.Equal(t, 1, numCalls) + assert.DirExists(t, r.tmpRepoDir) + assert.Equal(t, []string{"someurl", "sometag", r.tmpRepoDir}, args) + + // Assert the fs returned is rooted at the templateDir. + fd, err := fsys.Open("somefile") + require.NoError(t, err) + b, err := io.ReadAll(fd) + require.NoError(t, err) + assert.Equal(t, "somecontent", string(b)) + assert.NoError(t, fd.Close()) + + // Assert the downloaded repository is cleaned up. + fd, err = fsys.Open(".") + require.NoError(t, err) + require.NoError(t, fd.Close()) + r.Cleanup() + _, err = fsys.Open(".") + assert.ErrorIs(t, err, fs.ErrNotExist) +} + +func TestLocalReader(t *testing.T) { + tmpDir := t.TempDir() + testutil.WriteFile(t, filepath.Join(tmpDir, "somefile"), "somecontent") + ctx := context.Background() + + r := &localReader{path: tmpDir} + fs, err := r.FS(ctx) + require.NoError(t, err) + + // Assert the fs returned is rooted at correct location. + fd, err := fs.Open("somefile") + require.NoError(t, err) + b, err := io.ReadAll(fd) + require.NoError(t, err) + assert.Equal(t, "somecontent", string(b)) + assert.NoError(t, fd.Close()) +} diff --git a/libs/template/resolver.go b/libs/template/resolver.go new file mode 100644 index 0000000000..2cc8bf1c75 --- /dev/null +++ b/libs/template/resolver.go @@ -0,0 +1,122 @@ +package template + +import ( + "context" + "errors" + "strings" + + "github.com/databricks/cli/libs/git" +) + +var gitUrlPrefixes = []string{ + "https://", + "git@", +} + +func isRepoUrl(url string) bool { + result := false + for _, prefix := range gitUrlPrefixes { + if strings.HasPrefix(url, prefix) { + result = true + break + } + } + return result +} + +type Resolver struct { + // One of the following three: + // 1. Path to a local template directory. + // 2. URL to a Git repository containing a template. + // 3. Name of a built-in template. + TemplatePathOrUrl string + + // Path to a JSON file containing the configuration values to be used for + // template initialization. + ConfigFile string + + // Directory to write the initialized template to. + OutputDir string + + // Directory path within a Git repository containing the template. + TemplateDir string + + // Git tag or branch to download the template from. Only one of these can be + // specified. + Tag string + Branch string +} + +// ErrCustomSelected is returned when the user selects the "custom..." option +// in the prompt UI when they run `databricks bundle init`. This error signals +// the upstream callsite to show documentation to the user on how to use a custom +// template. +var ErrCustomSelected = errors.New("custom template selected") + +// Configures the reader and the writer for template and returns +// a handle to the template. +// Prompts the user if needed. +func (r Resolver) Resolve(ctx context.Context) (*Template, error) { + if r.Tag != "" && r.Branch != "" { + return nil, errors.New("only one of tag or branch can be specified") + } + + // Git ref to use for template initialization + ref := r.Branch + if r.Tag != "" { + ref = r.Tag + } + + var err error + var templateName TemplateName + + if r.TemplatePathOrUrl == "" { + // Prompt the user to select a template + // if a template path or URL is not provided. + templateName, err = SelectTemplate(ctx) + if err != nil { + return nil, err + } + } else { + templateName = TemplateName(r.TemplatePathOrUrl) + } + + tmpl := GetDatabricksTemplate(templateName) + + // If we could not find a databricks template with the name provided by the user, + // then we assume that the user provided us with a reference to a custom template. + // + // This reference could be one of: + // 1. Path to a local template directory. + // 2. URL to a Git repository containing a template. + // + // We resolve the appropriate reader according to the reference provided by the user. + if tmpl == nil { + tmpl = &Template{ + name: Custom, + // We use a writer that does not log verbose telemetry for custom templates. + // This is important because template definitions can contain PII that we + // do not want to centralize. + Writer: &defaultWriter{}, + } + + if isRepoUrl(r.TemplatePathOrUrl) { + tmpl.Reader = &gitReader{ + gitUrl: r.TemplatePathOrUrl, + ref: ref, + templateDir: r.TemplateDir, + cloneFunc: git.Clone, + } + } else { + tmpl.Reader = &localReader{ + path: r.TemplatePathOrUrl, + } + } + } + err = tmpl.Writer.Configure(ctx, r.ConfigFile, r.OutputDir) + if err != nil { + return nil, err + } + + return tmpl, nil +} diff --git a/libs/template/resolver_test.go b/libs/template/resolver_test.go new file mode 100644 index 0000000000..1dee1c45fe --- /dev/null +++ b/libs/template/resolver_test.go @@ -0,0 +1,110 @@ +package template + +import ( + "context" + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTemplateResolverBothTagAndBranch(t *testing.T) { + r := Resolver{ + Tag: "tag", + Branch: "branch", + } + + _, err := r.Resolve(context.Background()) + assert.EqualError(t, err, "only one of tag or branch can be specified") +} + +func TestTemplateResolverErrorsWhenPromptingIsNotSupported(t *testing.T) { + r := Resolver{} + ctx := cmdio.MockDiscard(context.Background()) + + _, err := r.Resolve(ctx) + assert.EqualError(t, err, "prompting is not supported. Please specify the path, name or URL of the template to use") +} + +func TestTemplateResolverForDefaultTemplates(t *testing.T) { + for _, name := range []string{ + "default-python", + "default-sql", + "dbt-sql", + } { + t.Run(name, func(t *testing.T) { + r := Resolver{ + TemplatePathOrUrl: name, + } + + tmpl, err := r.Resolve(context.Background()) + require.NoError(t, err) + + assert.Equal(t, &builtinReader{name: name}, tmpl.Reader) + assert.IsType(t, &writerWithFullTelemetry{}, tmpl.Writer) + }) + } + + t.Run("mlops-stacks", func(t *testing.T) { + r := Resolver{ + TemplatePathOrUrl: "mlops-stacks", + ConfigFile: "/config/file", + } + + tmpl, err := r.Resolve(context.Background()) + require.NoError(t, err) + + // Assert reader and writer configuration + assert.Equal(t, "https://github.com/databricks/mlops-stacks", tmpl.Reader.(*gitReader).gitUrl) + assert.Equal(t, "/config/file", tmpl.Writer.(*writerWithFullTelemetry).configPath) + }) +} + +func TestTemplateResolverForCustomUrl(t *testing.T) { + r := Resolver{ + TemplatePathOrUrl: "https://www.example.com/abc", + Tag: "tag", + TemplateDir: "/template/dir", + ConfigFile: "/config/file", + } + + tmpl, err := r.Resolve(context.Background()) + require.NoError(t, err) + + assert.Equal(t, Custom, tmpl.name) + + // Assert reader configuration + assert.Equal(t, "https://www.example.com/abc", tmpl.Reader.(*gitReader).gitUrl) + assert.Equal(t, "tag", tmpl.Reader.(*gitReader).ref) + assert.Equal(t, "/template/dir", tmpl.Reader.(*gitReader).templateDir) + + // Assert writer configuration + assert.Equal(t, "/config/file", tmpl.Writer.(*defaultWriter).configPath) +} + +func TestTemplateResolverForCustomPath(t *testing.T) { + r := Resolver{ + TemplatePathOrUrl: "/custom/path", + ConfigFile: "/config/file", + } + + tmpl, err := r.Resolve(context.Background()) + require.NoError(t, err) + + assert.Equal(t, Custom, tmpl.name) + + // Assert reader configuration + assert.Equal(t, "/custom/path", tmpl.Reader.(*localReader).path) + + // Assert writer configuration + assert.Equal(t, "/config/file", tmpl.Writer.(*defaultWriter).configPath) +} + +func TestBundleInitIsRepoUrl(t *testing.T) { + assert.True(t, isRepoUrl("git@github.com:databricks/cli.git")) + assert.True(t, isRepoUrl("https://github.com/databricks/cli.git")) + + assert.False(t, isRepoUrl("./local")) + assert.False(t, isRepoUrl("foo")) +} diff --git a/libs/template/template.go b/libs/template/template.go new file mode 100644 index 0000000000..e82e52840b --- /dev/null +++ b/libs/template/template.go @@ -0,0 +1,132 @@ +package template + +import ( + "context" + "errors" + "fmt" + "slices" + "strings" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/git" +) + +type Template struct { + Reader Reader + Writer Writer + + name TemplateName + description string + aliases []string + hidden bool +} + +type TemplateName string + +const ( + DefaultPython TemplateName = "default-python" + DefaultSql TemplateName = "default-sql" + DbtSql TemplateName = "dbt-sql" + MlopsStacks TemplateName = "mlops-stacks" + DefaultPydabs TemplateName = "default-pydabs" + Custom TemplateName = "custom" +) + +var databricksTemplates = []Template{ + { + name: DefaultPython, + description: "The default Python template for Notebooks / Delta Live Tables / Workflows", + Reader: &builtinReader{name: string(DefaultPython)}, + Writer: &writerWithFullTelemetry{}, + }, + { + name: DefaultSql, + description: "The default SQL template for .sql files that run with Databricks SQL", + Reader: &builtinReader{name: string(DefaultSql)}, + Writer: &writerWithFullTelemetry{}, + }, + { + name: DbtSql, + description: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)", + Reader: &builtinReader{name: string(DbtSql)}, + Writer: &writerWithFullTelemetry{}, + }, + { + name: MlopsStacks, + description: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)", + aliases: []string{"mlops-stack"}, + Reader: &gitReader{gitUrl: "https://github.com/databricks/mlops-stacks", cloneFunc: git.Clone}, + Writer: &writerWithFullTelemetry{}, + }, + { + name: DefaultPydabs, + hidden: true, + description: "The default PyDABs template", + Reader: &gitReader{gitUrl: "https://databricks.github.io/workflows-authoring-toolkit/pydabs-template.git", cloneFunc: git.Clone}, + Writer: &writerWithFullTelemetry{}, + }, +} + +func HelpDescriptions() string { + var lines []string + for _, template := range databricksTemplates { + if template.name != Custom && !template.hidden { + lines = append(lines, fmt.Sprintf("- %s: %s", template.name, template.description)) + } + } + return strings.Join(lines, "\n") +} + +var customTemplateDescription = "Bring your own template" + +func options() []cmdio.Tuple { + names := make([]cmdio.Tuple, 0, len(databricksTemplates)) + for _, template := range databricksTemplates { + if template.hidden { + continue + } + tuple := cmdio.Tuple{ + Name: string(template.name), + Id: template.description, + } + names = append(names, tuple) + } + + names = append(names, cmdio.Tuple{ + Name: "custom...", + Id: customTemplateDescription, + }) + return names +} + +func SelectTemplate(ctx context.Context) (TemplateName, error) { + if !cmdio.IsPromptSupported(ctx) { + return "", errors.New("prompting is not supported. Please specify the path, name or URL of the template to use") + } + description, err := cmdio.SelectOrdered(ctx, options(), "Template to use") + if err != nil { + return "", err + } + + if description == customTemplateDescription { + return TemplateName(""), ErrCustomSelected + } + + for _, template := range databricksTemplates { + if template.description == description { + return template.name, nil + } + } + + panic("this should never happen - template not found") +} + +func GetDatabricksTemplate(name TemplateName) *Template { + for _, template := range databricksTemplates { + if template.name == name || slices.Contains(template.aliases, string(name)) { + return &template + } + } + + return nil +} diff --git a/cmd/bundle/init_test.go b/libs/template/template_test.go similarity index 59% rename from cmd/bundle/init_test.go rename to libs/template/template_test.go index 475b2e1499..80391e58b7 100644 --- a/cmd/bundle/init_test.go +++ b/libs/template/template_test.go @@ -1,4 +1,4 @@ -package bundle +package template import ( "testing" @@ -7,12 +7,23 @@ import ( "github.com/stretchr/testify/assert" ) -func TestBundleInitIsRepoUrl(t *testing.T) { - assert.True(t, isRepoUrl("git@github.com:databricks/cli.git")) - assert.True(t, isRepoUrl("https://github.com/databricks/cli.git")) +func TestTemplateHelpDescriptions(t *testing.T) { + expected := `- default-python: The default Python template for Notebooks / Delta Live Tables / Workflows +- default-sql: The default SQL template for .sql files that run with Databricks SQL +- dbt-sql: The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks) +- mlops-stacks: The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)` + assert.Equal(t, expected, HelpDescriptions()) +} - assert.False(t, isRepoUrl("./local")) - assert.False(t, isRepoUrl("foo")) +func TestTemplateOptions(t *testing.T) { + expected := []cmdio.Tuple{ + {Name: "default-python", Id: "The default Python template for Notebooks / Delta Live Tables / Workflows"}, + {Name: "default-sql", Id: "The default SQL template for .sql files that run with Databricks SQL"}, + {Name: "dbt-sql", Id: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)"}, + {Name: "mlops-stacks", Id: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)"}, + {Name: "custom...", Id: "Bring your own template"}, + } + assert.Equal(t, expected, options()) } func TestBundleInitRepoName(t *testing.T) { @@ -27,28 +38,41 @@ func TestBundleInitRepoName(t *testing.T) { assert.Equal(t, "www.github.com", repoName("https://www.github.com")) } -func TestNativeTemplateOptions(t *testing.T) { - expected := []cmdio.Tuple{ - {Name: "default-python", Id: "The default Python template for Notebooks / Delta Live Tables / Workflows"}, - {Name: "default-sql", Id: "The default SQL template for .sql files that run with Databricks SQL"}, - {Name: "dbt-sql", Id: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)"}, - {Name: "mlops-stacks", Id: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)"}, - {Name: "custom...", Id: "Bring your own template"}, +func TestTemplateTelemetryIsCapturedForAllDefaultTemplates(t *testing.T) { + for _, tmpl := range databricksTemplates { + w := tmpl.Writer + + // Assert telemetry is captured for all databricks templates, i.e. templates + // owned by databricks. + assert.IsType(t, &writerWithFullTelemetry{}, w) } - assert.Equal(t, expected, nativeTemplateOptions()) } -func TestNativeTemplateHelpDescriptions(t *testing.T) { - expected := `- default-python: The default Python template for Notebooks / Delta Live Tables / Workflows -- default-sql: The default SQL template for .sql files that run with Databricks SQL -- dbt-sql: The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks) -- mlops-stacks: The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)` - assert.Equal(t, expected, nativeTemplateHelpDescriptions()) -} +func TestTemplateGetDatabricksTemplate(t *testing.T) { + names := []TemplateName{ + DefaultPython, + DefaultSql, + DbtSql, + MlopsStacks, + DefaultPydabs, + } + + for _, name := range names { + tmpl := GetDatabricksTemplate(name) + assert.Equal(t, tmpl.name, name) + } + + notExist := []string{ + "/some/path", + "doesnotexist", + "https://www.someurl.com", + } + + for _, name := range notExist { + tmpl := GetDatabricksTemplate(TemplateName(name)) + assert.Nil(t, tmpl) + } -func TestGetUrlForNativeTemplate(t *testing.T) { - assert.Equal(t, "https://github.com/databricks/mlops-stacks", getUrlForNativeTemplate("mlops-stacks")) - assert.Equal(t, "https://github.com/databricks/mlops-stacks", getUrlForNativeTemplate("mlops-stack")) - assert.Equal(t, "", getUrlForNativeTemplate("default-python")) - assert.Equal(t, "", getUrlForNativeTemplate("invalid")) + // Assert the alias works. + assert.Equal(t, MlopsStacks, GetDatabricksTemplate(TemplateName("mlops-stack")).name) } diff --git a/libs/template/writer.go b/libs/template/writer.go new file mode 100644 index 0000000000..e3d5af5835 --- /dev/null +++ b/libs/template/writer.go @@ -0,0 +1,171 @@ +package template + +import ( + "context" + "errors" + "fmt" + "io/fs" + "path/filepath" + "strings" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/dbr" + "github.com/databricks/cli/libs/filer" +) + +const ( + libraryDirName = "library" + templateDirName = "template" + schemaFileName = "databricks_template_schema.json" +) + +type Writer interface { + // Configure the writer with: + // 1. The path to the config file (if any) that contains input values for the + // template. + // 2. The output directory where the template will be materialized. + Configure(ctx context.Context, configPath, outputDir string) error + + // Materialize the template to the local file system. + Materialize(ctx context.Context, r Reader) error +} + +type defaultWriter struct { + configPath string + outputFiler filer.Filer + + // Internal state + config *config + renderer *renderer +} + +func constructOutputFiler(ctx context.Context, outputDir string) (filer.Filer, error) { + outputDir, err := filepath.Abs(outputDir) + if err != nil { + return nil, err + } + + // If the CLI is running on DBR and we're writing to the workspace file system, + // use the extension-aware workspace filesystem filer to instantiate the template. + // + // It is not possible to write notebooks through the workspace filesystem's FUSE mount. + // Therefore this is the only way we can initialize templates that contain notebooks + // when running the CLI on DBR and initializing a template to the workspace. + // + if strings.HasPrefix(outputDir, "/Workspace/") && dbr.RunsOnRuntime(ctx) { + return filer.NewWorkspaceFilesExtensionsClient(root.WorkspaceClient(ctx), outputDir) + } + + return filer.NewLocalClient(outputDir) +} + +func (tmpl *defaultWriter) Configure(ctx context.Context, configPath, outputDir string) error { + tmpl.configPath = configPath + + outputFiler, err := constructOutputFiler(ctx, outputDir) + if err != nil { + return err + } + + tmpl.outputFiler = outputFiler + return nil +} + +func (tmpl *defaultWriter) promptForInput(ctx context.Context, reader Reader) error { + readerFs, err := reader.FS(ctx) + if err != nil { + return err + } + if _, err := fs.Stat(readerFs, schemaFileName); errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("not a bundle template: expected to find a template schema file at %s", schemaFileName) + } + + tmpl.config, err = newConfig(ctx, readerFs, schemaFileName) + if err != nil { + return err + } + + // Read and assign config values from file + if tmpl.configPath != "" { + err = tmpl.config.assignValuesFromFile(tmpl.configPath) + if err != nil { + return err + } + } + + helpers := loadHelpers(ctx) + tmpl.renderer, err = newRenderer(ctx, tmpl.config.values, helpers, readerFs, templateDirName, libraryDirName) + if err != nil { + return err + } + + // Print welcome message + welcome := tmpl.config.schema.WelcomeMessage + if welcome != "" { + welcome, err = tmpl.renderer.executeTemplate(welcome) + if err != nil { + return err + } + cmdio.LogString(ctx, welcome) + } + + // Prompt user for any missing config values. Assign default values if + // terminal is not TTY + err = tmpl.config.promptOrAssignDefaultValues(tmpl.renderer) + if err != nil { + return err + } + return tmpl.config.validate() +} + +func (tmpl *defaultWriter) printSuccessMessage(ctx context.Context) error { + success := tmpl.config.schema.SuccessMessage + if success == "" { + cmdio.LogString(ctx, "✨ Successfully initialized template") + return nil + } + + success, err := tmpl.renderer.executeTemplate(success) + if err != nil { + return err + } + cmdio.LogString(ctx, success) + return nil +} + +func (tmpl *defaultWriter) Materialize(ctx context.Context, reader Reader) error { + err := tmpl.promptForInput(ctx, reader) + if err != nil { + return err + } + + // Walk the template file tree and compute in-memory representations of the + // output files. + err = tmpl.renderer.walk() + if err != nil { + return err + } + + // Flush the output files to disk. + err = tmpl.renderer.persistToDisk(ctx, tmpl.outputFiler) + if err != nil { + return err + } + + return tmpl.printSuccessMessage(ctx) +} + +func (tmpl *defaultWriter) LogTelemetry(ctx context.Context) error { + // TODO, only log the template name and uuid. + return nil +} + +type writerWithFullTelemetry struct { + defaultWriter +} + +func (tmpl *writerWithFullTelemetry) LogTelemetry(ctx context.Context) error { + // TODO, log template name, uuid and enum args as well. + return nil +} diff --git a/libs/template/writer_test.go b/libs/template/writer_test.go new file mode 100644 index 0000000000..9d57966ee2 --- /dev/null +++ b/libs/template/writer_test.go @@ -0,0 +1,58 @@ +package template + +import ( + "context" + "runtime" + "testing" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/dbr" + "github.com/databricks/cli/libs/filer" + "github.com/databricks/databricks-sdk-go" + workspaceConfig "github.com/databricks/databricks-sdk-go/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDefaultWriterConfigure(t *testing.T) { + // Test on local file system. + w := &defaultWriter{} + err := w.Configure(context.Background(), "/foo/bar", "/out/abc") + assert.NoError(t, err) + + assert.Equal(t, "/foo/bar", w.configPath) + assert.IsType(t, &filer.LocalClient{}, w.outputFiler) +} + +func TestDefaultWriterConfigureOnDBR(t *testing.T) { + // This test is not valid on windows because a DBR image is always based on + // Linux. + if runtime.GOOS == "windows" { + t.Skip("Skipping test on Windows") + } + + ctx := dbr.MockRuntime(context.Background(), true) + ctx = root.SetWorkspaceClient(ctx, &databricks.WorkspaceClient{ + Config: &workspaceConfig.Config{Host: "https://myhost.com"}, + }) + w := &defaultWriter{} + err := w.Configure(ctx, "/foo/bar", "/Workspace/out/abc") + assert.NoError(t, err) + + assert.Equal(t, "/foo/bar", w.configPath) + assert.IsType(t, &filer.WorkspaceFilesExtensionsClient{}, w.outputFiler) +} + +func TestMaterializeForNonTemplateDirectory(t *testing.T) { + tmpDir1 := t.TempDir() + tmpDir2 := t.TempDir() + ctx := context.Background() + + w := &defaultWriter{} + err := w.Configure(ctx, "/foo/bar", tmpDir1) + require.NoError(t, err) + + // Try to materialize a non-template directory. + err = w.Materialize(ctx, &localReader{path: tmpDir2}) + assert.EqualError(t, err, "not a bundle template: expected to find a template schema file at databricks_template_schema.json") +}