From 2f239be8786b93c573cae1ed15026dda7360c4c0 Mon Sep 17 00:00:00 2001 From: Sushain Cherivirala Date: Wed, 9 Oct 2024 12:56:58 -0500 Subject: [PATCH] Fix races on temporary file copies (#616) * Fix races on temporary file copies * Appease Windows * Fix ordering --- core/core.go | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/core/core.go b/core/core.go index 4aff4249..79534f43 100644 --- a/core/core.go +++ b/core/core.go @@ -404,10 +404,17 @@ func downloadBazelIfNecessary(version string, bazeliskHome string, bazelForkOrUR } func atomicWriteFile(path string, contents []byte, perm os.FileMode) error { - if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + parent := filepath.Dir(path) + if err := os.MkdirAll(parent, 0755); err != nil { return fmt.Errorf("failed to MkdirAll parent of %s: %w", path, err) } - tmpPath := path + ".tmp" + tmpFile, err := os.CreateTemp(parent, filepath.Base(path)+".tmp") + if err != nil { + return fmt.Errorf("failed to create temporary file in %s: %w", parent, err) + } + tmpFile.Close() + defer os.Remove(tmpFile.Name()) + tmpPath := tmpFile.Name() if err := os.WriteFile(tmpPath, contents, perm); err != nil { return fmt.Errorf("failed to write file %s: %w", tmpPath, err) } @@ -458,12 +465,20 @@ func downloadBazelToCAS(version string, bazeliskHome string, repos *Repositories f.Close() actualSha256 := strings.ToLower(fmt.Sprintf("%x", h.Sum(nil))) - pathToBazelInCAS := filepath.Join(casDir, actualSha256, "bin", "bazel"+platforms.DetermineExecutableFilenameSuffix()) - if err := os.MkdirAll(filepath.Dir(pathToBazelInCAS), 0755); err != nil { + bazelInCASBasename := "bazel" + platforms.DetermineExecutableFilenameSuffix() + pathToBazelInCAS := filepath.Join(casDir, actualSha256, "bin", bazelInCASBasename) + dirForBazelInCAS := filepath.Dir(pathToBazelInCAS) + if err := os.MkdirAll(dirForBazelInCAS, 0755); err != nil { return "", "", fmt.Errorf("failed to MkdirAll parent of %s: %w", pathToBazelInCAS, err) } - tmpPathInCorrectDirectory := pathToBazelInCAS + ".tmp" + tmpPathFile, err := os.CreateTemp(dirForBazelInCAS, bazelInCASBasename+".tmp") + if err != nil { + return "", "", fmt.Errorf("failed to create temporary file in %s: %w", dirForBazelInCAS, err) + } + tmpPathFile.Close() + defer os.Remove(tmpPathFile.Name()) + tmpPathInCorrectDirectory := tmpPathFile.Name() if err := os.Rename(tmpDestPath, tmpPathInCorrectDirectory); err != nil { return "", "", fmt.Errorf("failed to move %s to %s: %w", tmpDestPath, tmpPathInCorrectDirectory, err) }