Skip to content

Commit

Permalink
chore(coderd): extract fileszip to package archive for reuse (coder#1…
Browse files Browse the repository at this point in the history
…5229)

Related to coder#15087
As part of sniffing the workspace tags from an uploaded file, we need to
be able to handle both zip and tar files. Extracting the functions to
a separate `archive` package will be helpful here.
  • Loading branch information
johnstcn authored Oct 25, 2024
1 parent 5ad4747 commit df34858
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 126 deletions.
25 changes: 14 additions & 11 deletions coderd/fileszip.go → archive/archive.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package coderd
package archive

import (
"archive/tar"
Expand All @@ -10,29 +10,30 @@ import (
"strings"
)

func CreateTarFromZip(zipReader *zip.Reader) ([]byte, error) {
// CreateTarFromZip converts the given zipReader to a tar archive, copying at
// most maxSize bytes from each file in the zip.
func CreateTarFromZip(zipReader *zip.Reader, maxSize int64) ([]byte, error) {
var tarBuffer bytes.Buffer
err := writeTarArchive(&tarBuffer, zipReader)
err := writeTarArchive(&tarBuffer, zipReader, maxSize)
if err != nil {
return nil, err
}
return tarBuffer.Bytes(), nil
}

func writeTarArchive(w io.Writer, zipReader *zip.Reader) error {
func writeTarArchive(w io.Writer, zipReader *zip.Reader, maxSize int64) error {
tarWriter := tar.NewWriter(w)
defer tarWriter.Close()

for _, file := range zipReader.File {
err := processFileInZipArchive(file, tarWriter)
err := processFileInZipArchive(file, tarWriter, maxSize)
if err != nil {
return err
}
}
return nil
}

func processFileInZipArchive(file *zip.File, tarWriter *tar.Writer) error {
func processFileInZipArchive(file *zip.File, tarWriter *tar.Writer, maxSize int64) error {
fileReader, err := file.Open()
if err != nil {
return err
Expand All @@ -52,24 +53,26 @@ func processFileInZipArchive(file *zip.File, tarWriter *tar.Writer) error {
return err
}

n, err := io.CopyN(tarWriter, fileReader, httpFileMaxBytes)
n, err := io.CopyN(tarWriter, fileReader, maxSize)
log.Println(file.Name, n, err)
if errors.Is(err, io.EOF) {
err = nil
}
return err
}

func CreateZipFromTar(tarReader *tar.Reader) ([]byte, error) {
// CreateZipFromTar converts the given tarReader to a zip archive, copying at
// most maxSize bytes from each tar entry.
func CreateZipFromTar(tarReader *tar.Reader, maxSize int64) ([]byte, error) {
var zipBuffer bytes.Buffer
err := WriteZipArchive(&zipBuffer, tarReader)
err := WriteZip(&zipBuffer, tarReader, maxSize)
if err != nil {
return nil, err
}
return zipBuffer.Bytes(), nil
}

func WriteZipArchive(w io.Writer, tarReader *tar.Reader) error {
// WriteZip writes the contents of the given tarReader to w as a zip archive,
// copying at most maxSize bytes from each tar entry.
func WriteZip(w io.Writer, tarReader *tar.Reader, maxSize int64) error {
zipWriter := zip.NewWriter(w)
defer zipWriter.Close()

Expand Down Expand Up @@ -100,7 +103,7 @@ func WriteZipArchive(w io.Writer, tarReader *tar.Reader) error {
return err
}

_, err = io.CopyN(zipEntry, tarReader, httpFileMaxBytes)
_, err = io.CopyN(zipEntry, tarReader, maxSize)
if errors.Is(err, io.EOF) {
err = nil
}
Expand Down
109 changes: 12 additions & 97 deletions coderd/fileszip_test.go → archive/archive_test.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
package coderd_test
package archive_test

import (
"archive/tar"
"archive/zip"
"bytes"
"io"
"io/fs"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"

"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/archive"
"github.com/coder/coder/v2/archive/archivetest"
"github.com/coder/coder/v2/testutil"
)

Expand All @@ -30,18 +28,17 @@ func TestCreateTarFromZip(t *testing.T) {

// Read a zip file we prepared earlier
ctx := testutil.Context(t, testutil.WaitShort)
zipBytes, err := os.ReadFile(filepath.Join("testdata", "test.zip"))
require.NoError(t, err, "failed to read sample zip file")
zipBytes := archivetest.TestZipFileBytes()
// Assert invariant
assertSampleZipFile(t, zipBytes)
archivetest.AssertSampleZipFile(t, zipBytes)

zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
require.NoError(t, err, "failed to parse sample zip file")

tarBytes, err := coderd.CreateTarFromZip(zr)
tarBytes, err := archive.CreateTarFromZip(zr, int64(len(zipBytes)))
require.NoError(t, err, "failed to convert zip to tar")

assertSampleTarFile(t, tarBytes)
archivetest.AssertSampleTarFile(t, tarBytes)

tempDir := t.TempDir()
tempFilePath := filepath.Join(tempDir, "test.tar")
Expand All @@ -60,14 +57,13 @@ func TestCreateZipFromTar(t *testing.T) {
}
t.Run("OK", func(t *testing.T) {
t.Parallel()
tarBytes, err := os.ReadFile(filepath.Join(".", "testdata", "test.tar"))
require.NoError(t, err, "failed to read sample tar file")
tarBytes := archivetest.TestTarFileBytes()

tr := tar.NewReader(bytes.NewReader(tarBytes))
zipBytes, err := coderd.CreateZipFromTar(tr)
zipBytes, err := archive.CreateZipFromTar(tr, int64(len(tarBytes)))
require.NoError(t, err)

assertSampleZipFile(t, zipBytes)
archivetest.AssertSampleZipFile(t, zipBytes)

tempDir := t.TempDir()
tempFilePath := filepath.Join(tempDir, "test.zip")
Expand Down Expand Up @@ -99,7 +95,7 @@ func TestCreateZipFromTar(t *testing.T) {

// When: we convert this to a zip
tr := tar.NewReader(&tarBytes)
zipBytes, err := coderd.CreateZipFromTar(tr)
zipBytes, err := archive.CreateZipFromTar(tr, int64(tarBytes.Len()))
require.NoError(t, err)

// Then: the resulting zip should contain a corresponding directory
Expand Down Expand Up @@ -133,7 +129,7 @@ func assertExtractedFiles(t *testing.T, dir string, checkModePerm bool) {
if checkModePerm {
assert.Equal(t, fs.ModePerm&0o755, stat.Mode().Perm(), "expected mode 0755 on directory")
}
assert.Equal(t, archiveRefTime(t).UTC(), stat.ModTime().UTC(), "unexpected modtime of %q", path)
assert.Equal(t, archivetest.ArchiveRefTime(t).UTC(), stat.ModTime().UTC(), "unexpected modtime of %q", path)
case "/test/hello.txt":
stat, err := os.Stat(path)
assert.NoError(t, err, "failed to stat path %q", path)
Expand Down Expand Up @@ -168,84 +164,3 @@ func assertExtractedFiles(t *testing.T, dir string, checkModePerm bool) {
return nil
})
}

func assertSampleTarFile(t *testing.T, tarBytes []byte) {
t.Helper()

tr := tar.NewReader(bytes.NewReader(tarBytes))
for {
hdr, err := tr.Next()
if err != nil {
if err == io.EOF {
return
}
require.NoError(t, err)
}

// Note: ignoring timezones here.
require.Equal(t, archiveRefTime(t).UTC(), hdr.ModTime.UTC())

switch hdr.Name {
case "test/":
require.Equal(t, hdr.Typeflag, byte(tar.TypeDir))
case "test/hello.txt":
require.Equal(t, hdr.Typeflag, byte(tar.TypeReg))
bs, err := io.ReadAll(tr)
if err != nil && !xerrors.Is(err, io.EOF) {
require.NoError(t, err)
}
require.Equal(t, "hello", string(bs))
case "test/dir/":
require.Equal(t, hdr.Typeflag, byte(tar.TypeDir))
case "test/dir/world.txt":
require.Equal(t, hdr.Typeflag, byte(tar.TypeReg))
bs, err := io.ReadAll(tr)
if err != nil && !xerrors.Is(err, io.EOF) {
require.NoError(t, err)
}
require.Equal(t, "world", string(bs))
default:
require.Failf(t, "unexpected file in tar", hdr.Name)
}
}
}

// assertSampleZipFile asserts that zipBytes is a readable zip archive whose
// entries all belong to the sample layout: the directories "test/" and
// "test/dir/", and the files "test/hello.txt" (content "hello") and
// "test/dir/world.txt" (content "world"). Every entry must carry the
// modification time returned by archiveRefTime. Any other entry name fails
// the test immediately.
func assertSampleZipFile(t *testing.T, zipBytes []byte) {
	t.Helper()

	zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
	require.NoError(t, err)

	for _, f := range zr.File {
		// Note: ignoring timezones here.
		require.Equal(t, archiveRefTime(t).UTC(), f.Modified.UTC())

		var want string
		switch f.Name {
		case "test/", "test/dir/":
			// Directory entries carry no content to verify.
			continue
		case "test/hello.txt":
			want = "hello"
		case "test/dir/world.txt":
			want = "world"
		default:
			// Pass the name as an argument, not as the format string, so a
			// '%' in an entry name cannot garble the failure message.
			require.Failf(t, "unexpected file in zip", "%q", f.Name)
		}

		rc, err := f.Open()
		require.NoError(t, err)
		bs, err := io.ReadAll(rc)
		// Close before asserting so the entry reader is released even when
		// the read failed; the close error is irrelevant to this check.
		_ = rc.Close()
		require.NoError(t, err)
		require.Equal(t, want, string(bs))
	}
}

// archiveRefTime returns a fixed timestamp on the Go reference date,
// 2006-01-02 03:04:05, in the "MST" location. The contents of the sample tar
// and zip files in testdata/ all have their modtimes set to this value in some
// timezone. The test fails immediately if the "MST" location cannot be loaded.
func archiveRefTime(t *testing.T) time.Time {
	// Report a failed timezone lookup at the caller's line, like the other
	// assertion helpers in this file.
	t.Helper()

	locMST, err := time.LoadLocation("MST")
	require.NoError(t, err, "failed to load MST timezone")
	return time.Date(2006, 1, 2, 3, 4, 5, 0, locMST)
}
113 changes: 113 additions & 0 deletions archive/archivetest/archivetest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package archivetest

import (
"archive/tar"
"archive/zip"
"bytes"
_ "embed"
"io"
"testing"
"time"

"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
)

//go:embed testdata/test.tar
var testTarFileBytes []byte

//go:embed testdata/test.zip
var testZipFileBytes []byte

// TestTarFileBytes returns a freshly allocated copy of the embedded
// testdata/test.tar, so callers may modify the result without affecting the
// embedded bytes.
func TestTarFileBytes() []byte {
	dup := make([]byte, len(testTarFileBytes))
	copy(dup, testTarFileBytes)
	return dup
}

// TestZipFileBytes returns a freshly allocated copy of the embedded
// testdata/test.zip, so callers may modify the result without affecting the
// embedded bytes.
func TestZipFileBytes() []byte {
	dup := make([]byte, len(testZipFileBytes))
	copy(dup, testZipFileBytes)
	return dup
}

// AssertSampleTarfile compares the content of tarBytes against testdata/test.tar.
func AssertSampleTarFile(t *testing.T, tarBytes []byte) {
t.Helper()

tr := tar.NewReader(bytes.NewReader(tarBytes))
for {
hdr, err := tr.Next()
if err != nil {
if err == io.EOF {
return
}
require.NoError(t, err)
}

// Note: ignoring timezones here.
require.Equal(t, ArchiveRefTime(t).UTC(), hdr.ModTime.UTC())

switch hdr.Name {
case "test/":
require.Equal(t, hdr.Typeflag, byte(tar.TypeDir))
case "test/hello.txt":
require.Equal(t, hdr.Typeflag, byte(tar.TypeReg))
bs, err := io.ReadAll(tr)
if err != nil && !xerrors.Is(err, io.EOF) {
require.NoError(t, err)
}
require.Equal(t, "hello", string(bs))
case "test/dir/":
require.Equal(t, hdr.Typeflag, byte(tar.TypeDir))
case "test/dir/world.txt":
require.Equal(t, hdr.Typeflag, byte(tar.TypeReg))
bs, err := io.ReadAll(tr)
if err != nil && !xerrors.Is(err, io.EOF) {
require.NoError(t, err)
}
require.Equal(t, "world", string(bs))
default:
require.Failf(t, "unexpected file in tar", hdr.Name)
}
}
}

// AssertSampleZipFile asserts that zipBytes is a readable zip archive whose
// entries all belong to the sample layout of testdata/test.zip: the
// directories "test/" and "test/dir/", and the files "test/hello.txt"
// (content "hello") and "test/dir/world.txt" (content "world"). Every entry
// must carry the modification time returned by ArchiveRefTime. Any other entry
// name fails the test immediately.
func AssertSampleZipFile(t *testing.T, zipBytes []byte) {
	t.Helper()

	zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
	require.NoError(t, err)

	for _, f := range zr.File {
		// Note: ignoring timezones here.
		require.Equal(t, ArchiveRefTime(t).UTC(), f.Modified.UTC())

		var want string
		switch f.Name {
		case "test/", "test/dir/":
			// Directory entries carry no content to verify.
			continue
		case "test/hello.txt":
			want = "hello"
		case "test/dir/world.txt":
			want = "world"
		default:
			// Pass the name as an argument, not as the format string, so a
			// '%' in an entry name cannot garble the failure message.
			require.Failf(t, "unexpected file in zip", "%q", f.Name)
		}

		rc, err := f.Open()
		require.NoError(t, err)
		bs, err := io.ReadAll(rc)
		// Close before asserting so the entry reader is released even when
		// the read failed; the close error is irrelevant to this check.
		_ = rc.Close()
		require.NoError(t, err)
		require.Equal(t, want, string(bs))
	}
}

// ArchiveRefTime returns a fixed timestamp on the Go reference date,
// 2006-01-02 03:04:05, in the "MST" location. The contents of the sample tar
// and zip files in testdata/ all have their modtimes set to this value in some
// timezone. The test fails immediately if the "MST" location cannot be loaded.
func ArchiveRefTime(t *testing.T) time.Time {
	// Report a failed timezone lookup at the caller's line, like the other
	// assertion helpers in this package.
	t.Helper()

	locMST, err := time.LoadLocation("MST")
	require.NoError(t, err, "failed to load MST timezone")
	return time.Date(2006, 1, 2, 3, 4, 5, 0, locMST)
}
File renamed without changes.
File renamed without changes.
3 changes: 2 additions & 1 deletion cli/templatepull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/archive"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest"
Expand Down Expand Up @@ -95,7 +96,7 @@ func TestTemplatePull_Stdout(t *testing.T) {

// Verify .zip format
tarReader := tar.NewReader(bytes.NewReader(expected))
expectedZip, err := coderd.CreateZipFromTar(tarReader)
expectedZip, err := archive.CreateZipFromTar(tarReader, coderd.HTTPFileMaxBytes)
require.NoError(t, err)

inv, root = clitest.New(t, "templates", "pull", "--zip", template.Name)
Expand Down
Loading

0 comments on commit df34858

Please sign in to comment.