From ea4feeaa408526be96e07cc2bdf9969a8fdc538b Mon Sep 17 00:00:00 2001 From: zenkovev <99416694+zenkovev@users.noreply.github.com> Date: Thu, 19 Dec 2024 15:36:39 +0300 Subject: [PATCH 01/11] feat!: build log writer for container request (#2925) * feat: build log writer for container request * fix: single BuildLogWriter method for ImageBuildInfo interface * fix: change BuildLogWriter default behavior * fix: require in Test_BuildContainerFromDockerfileWithBuildLogWriter --- container.go | 21 +++++++++++++++++---- docker.go | 5 +---- docker_test.go | 31 +++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 8 deletions(-) diff --git a/container.go b/container.go index 5ee0aac881..35be60fb81 100644 --- a/container.go +++ b/container.go @@ -77,7 +77,7 @@ type ImageBuildInfo interface { GetDockerfile() string // the relative path to the Dockerfile, including the file itself GetRepo() string // get repo label for image GetTag() string // get tag label for image - ShouldPrintBuildLog() bool // allow build log to be printed to stdout + BuildLogWriter() io.Writer // for output of build log, use io.Discard to disable the output ShouldBuildImage() bool // return true if the image needs to be built GetBuildArgs() map[string]*string // return the environment args used to build the from Dockerfile GetAuthConfigs() map[string]registry.AuthConfig // Deprecated. Testcontainers will detect registry credentials automatically. Return the auth configs to be able to pull from an authenticated docker registry @@ -92,7 +92,8 @@ type FromDockerfile struct { Repo string // the repo label for image, defaults to UUID Tag string // the tag label for image, defaults to UUID BuildArgs map[string]*string // enable user to pass build args to docker daemon - PrintBuildLog bool // enable user to print build log + PrintBuildLog bool // Deprecated: Use BuildLogWriter instead + BuildLogWriter io.Writer // for output of build log, defaults to io.Discard AuthConfigs map[string]registry.AuthConfig // Deprecated. Testcontainers will detect registry credentials automatically. Enable auth configs to be able to pull from an authenticated docker registry // KeepImage describes whether DockerContainer.Terminate should not delete the // container image. Useful for images that are built from a Dockerfile and take a @@ -410,8 +411,20 @@ func (c *ContainerRequest) ShouldKeepBuiltImage() bool { return c.FromDockerfile.KeepImage } -func (c *ContainerRequest) ShouldPrintBuildLog() bool { - return c.FromDockerfile.PrintBuildLog +// BuildLogWriter returns the io.Writer for output of log when building a Docker image from +// a Dockerfile. It returns the BuildLogWriter from the ContainerRequest, defaults to io.Discard. +// For backward compatibility, if BuildLogWriter is default and PrintBuildLog is true, +// the function returns os.Stderr. +func (c *ContainerRequest) BuildLogWriter() io.Writer { + if c.FromDockerfile.BuildLogWriter != nil { + return c.FromDockerfile.BuildLogWriter + } + if c.FromDockerfile.PrintBuildLog { + c.FromDockerfile.BuildLogWriter = os.Stderr + } else { + c.FromDockerfile.BuildLogWriter = io.Discard + } + return c.FromDockerfile.BuildLogWriter } // BuildOptions returns the image build options when building a Docker image from a Dockerfile. diff --git a/docker.go b/docker.go index 296fe6743c..b10b14b7ff 100644 --- a/docker.go +++ b/docker.go @@ -1004,10 +1004,7 @@ func (p *DockerProvider) BuildImage(ctx context.Context, img ImageBuildInfo) (st } defer resp.Body.Close() - output := io.Discard - if img.ShouldPrintBuildLog() { - output = os.Stderr - } + output := img.BuildLogWriter() // Always process the output, even if it is not printed // to ensure that errors during the build process are diff --git a/docker_test.go b/docker_test.go index 3fa686632f..eb92e15060 100644 --- a/docker_test.go +++ b/docker_test.go @@ -705,6 +705,37 @@ func Test_BuildContainerFromDockerfileWithBuildLog(t *testing.T) { assert.Regexpf(t, `^Step\s*1/\d+\s*:\s*FROM alpine$`, temp[0], "Expected stdout first line to be %s. Got '%s'.", "Step 1/* : FROM alpine", temp[0]) } +func Test_BuildContainerFromDockerfileWithBuildLogWriter(t *testing.T) { + var buffer bytes.Buffer + + ctx := context.Background() + + // fromDockerfile { + req := ContainerRequest{ + FromDockerfile: FromDockerfile{ + Context: filepath.Join(".", "testdata"), + Dockerfile: "buildlog.Dockerfile", + BuildLogWriter: &buffer, + }, + } + // } + + genContainerReq := GenericContainerRequest{ + ProviderType: providerType, + ContainerRequest: req, + Started: true, + } + + c, err := GenericContainer(ctx, genContainerReq) + CleanupContainer(t, c) + require.NoError(t, err) + + out := buffer.String() + temp := strings.Split(out, "\n") + require.NotEmpty(t, temp) + require.Regexpf(t, `^Step\s*1/\d+\s*:\s*FROM alpine$`, temp[0], "Expected stdout first line to be %s. Got '%s'.", "Step 1/* : FROM alpine", temp[0]) +} + func TestContainerCreationWaitsForLogAndPortContextTimeout(t *testing.T) { ctx := context.Background() req := ContainerRequest{ From abe0f8244bf210e4dcd15ab553f2cdcb034345be Mon Sep 17 00:00:00 2001 From: Viktor Stanchev Date: Fri, 20 Dec 2024 01:18:49 -0500 Subject: [PATCH 02/11] fix: avoid double lock in DockerProvider.DaemonHost() (#2900) * avoid double lock in DockerProvider.DaemonHost() * cleaner structure * put comment back * add regression test * use require * test improvements * better error output * try to fix rootless mode * pass on XDG_RUNTIME_DIR * fix: DaemonHost locking test Fix the DaemonHost locking test by implementing a way to change the location of the file the core library tests for. --------- Co-authored-by: Steven Hartland --- docker.go | 12 +++++++++++- docker_test.go | 38 ++++++++++++++++++++++++++++++++++++ internal/core/docker_host.go | 7 ++++++- 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/docker.go b/docker.go index b10b14b7ff..01b3d3d4d2 100644 --- a/docker.go +++ b/docker.go @@ -1495,7 +1495,11 @@ func (p *DockerProvider) daemonHostLocked(ctx context.Context) (string, error) { p.hostCache = daemonURL.Hostname() case "unix", "npipe": if core.InAContainer() { - ip, err := p.GetGatewayIP(ctx) + defaultNetwork, err := p.ensureDefaultNetworkLocked(ctx) + if err != nil { + return "", fmt.Errorf("ensure default network: %w", err) + } + ip, err := p.getGatewayIP(ctx, defaultNetwork) if err != nil { ip, err = core.DefaultGatewayIP() if err != nil { @@ -1595,7 +1599,10 @@ func (p *DockerProvider) GetGatewayIP(ctx context.Context) (string, error) { if err != nil { return "", fmt.Errorf("ensure default network: %w", err) } + return p.getGatewayIP(ctx, defaultNetwork) +} +func (p *DockerProvider) getGatewayIP(ctx context.Context, defaultNetwork string) (string, error) { nw, err := p.GetNetwork(ctx, NetworkRequest{Name: defaultNetwork}) if err != nil { return "", err @@ -1621,7 +1628,10 @@ func (p *DockerProvider) GetGatewayIP(ctx context.Context) (string, error) { func (p *DockerProvider) ensureDefaultNetwork(ctx context.Context) (string, error) { p.mtx.Lock() defer p.mtx.Unlock() + return p.ensureDefaultNetworkLocked(ctx) +} +func (p *DockerProvider) ensureDefaultNetworkLocked(ctx context.Context) (string, error) { if p.defaultNetwork != "" { // Already set. return p.defaultNetwork, nil diff --git a/docker_test.go b/docker_test.go index eb92e15060..8fcd60c558 100644 --- a/docker_test.go +++ b/docker_test.go @@ -25,6 +25,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go/internal/core" "github.com/testcontainers/testcontainers-go/wait" ) @@ -35,6 +36,7 @@ const ( nginxAlpineImage = "nginx:alpine" nginxDefaultPort = "80/tcp" nginxHighPort = "8080/tcp" + golangImage = "golang" daemonMaxVersion = "1.41" ) @@ -2156,3 +2158,39 @@ func TestCustomPrefixTrailingSlashIsProperlyRemovedIfPresent(t *testing.T) { dockerContainer := c.(*DockerContainer) require.Equal(t, fmt.Sprintf("%s%s", hubPrefixWithTrailingSlash, dockerImage), dockerContainer.Image) } + +// TODO: remove this skip check when context rework is merged alongside [core.DockerEnvFile] removal. +func Test_Provider_DaemonHost_Issue2897(t *testing.T) { + ctx := context.Background() + provider, err := NewDockerProvider() + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, provider.Close()) + }) + + orig := core.DockerEnvFile + core.DockerEnvFile = filepath.Join(t.TempDir(), ".dockerenv") + t.Cleanup(func() { + core.DockerEnvFile = orig + }) + + f, err := os.Create(core.DockerEnvFile) + require.NoError(t, err) + require.NoError(t, f.Close()) + t.Cleanup(func() { + require.NoError(t, os.Remove(f.Name())) + }) + + errCh := make(chan error, 1) + go func() { + _, err := provider.DaemonHost(ctx) + errCh <- err + }() + + select { + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for DaemonHost") + case err := <-errCh: + require.NoError(t, err) + } +} diff --git a/internal/core/docker_host.go b/internal/core/docker_host.go index 3088a3742b..765626da57 100644 --- a/internal/core/docker_host.go +++ b/internal/core/docker_host.go @@ -309,10 +309,15 @@ func testcontainersHostFromProperties(ctx context.Context) (string, error) { return "", ErrTestcontainersHostNotSetInProperties } +// DockerEnvFile is the file that is created when running inside a container. +// It's a variable to allow testing. +// TODO: Remove this once context rework is done, which eliminates need for the default network creation. +var DockerEnvFile = "/.dockerenv" + // InAContainer returns true if the code is running inside a container // See https://github.com/docker/docker/blob/a9fa38b1edf30b23cae3eade0be48b3d4b1de14b/daemon/initlayer/setup_unix.go#L25 func InAContainer() bool { - return inAContainer("/.dockerenv") + return inAContainer(DockerEnvFile) } func inAContainer(path string) bool { From 4f67ae08757f3b880691d2496e62cb9c696523af Mon Sep 17 00:00:00 2001 From: Emanuel Bennici Date: Fri, 20 Dec 2024 14:53:36 +0100 Subject: [PATCH 03/11] fix: Handle nil value in CleanupNetwork (#2928) The godoc of `CleanupNetwork` states that a `nil` network will result in a no-op. --- network/network_test.go | 5 +++++ testing.go | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/network/network_test.go b/network/network_test.go index bbe5d45c7c..8b83056f43 100644 --- a/network/network_test.go +++ b/network/network_test.go @@ -440,3 +440,8 @@ func TestWithNewNetworkContextTimeout(t *testing.T) { require.Empty(t, req.Networks) require.Empty(t, req.NetworkAliases) } + +func TestCleanupWithNil(t *testing.T) { + var network *testcontainers.DockerNetwork + testcontainers.CleanupNetwork(t, network) +} diff --git a/testing.go b/testing.go index 35ce4f0a39..8502f018d9 100644 --- a/testing.go +++ b/testing.go @@ -83,7 +83,9 @@ func CleanupNetwork(tb testing.TB, network Network) { tb.Helper() tb.Cleanup(func() { - noErrorOrIgnored(tb, network.Remove(context.Background())) + if !isNil(network) { + noErrorOrIgnored(tb, network.Remove(context.Background())) + } }) } From 63fad4d8bda2c92beccec9929b976e93401ab679 Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Fri, 20 Dec 2024 15:14:37 +0000 Subject: [PATCH 04/11] feat(wait): log sub match callback (#2929) Add support for a sub match callback to wait.LogStrategy which allows containers to process the matched pattern storing details or otherwise validating them. The callback can return a PermanentError if no more retries should be attempted. --- docs/features/wait/log.md | 40 ++++++++- wait/log.go | 111 ++++++++++++++++++++----- wait/log_test.go | 169 ++++++++++++++++++++++++++------------ 3 files changed, 246 insertions(+), 74 deletions(-) diff --git a/docs/features/wait/log.md b/docs/features/wait/log.md index f1d40ff360..8466d68511 100644 --- a/docs/features/wait/log.md +++ b/docs/features/wait/log.md @@ -3,10 +3,11 @@ The Log wait strategy will check if a string occurs in the container logs for a desired number of times, and allows to set the following conditions: - the string to be waited for in the container log. -- the number of occurrences of the string to wait for, default is `1`. +- the number of occurrences of the string to wait for, default is `1` (ignored for Submatch). - look for the string using a regular expression, default is `false`. - the startup timeout to be used in seconds, default is 60 seconds. - the poll interval to be used in milliseconds, default is 100 milliseconds. +- the regular expression submatch callback, default nil (occurrences is ignored). ```golang req := ContainerRequest{ @@ -33,3 +34,40 @@ req := ContainerRequest{ WaitingFor: wait.ForLog(`.*MySQL Community Server`).AsRegexp(), } ``` + +Using regular expression with submatch: + +```golang +var host, port string +req := ContainerRequest{ + Image: "ollama/ollama:0.1.25", + ExposedPorts: []string{"11434/tcp"}, + WaitingFor: wait.ForLog(`Listening on (.*:\d+) \(version\s(.*)\)`).Submatch(func(pattern string, submatches [][][]byte) error { + var err error + for _, matches := range submatches { + if len(matches) != 3 { + err = fmt.Errorf("`%s` matched %d times, expected %d", pattern, len(matches), 3) + continue + } + host, port, err = net.SplitHostPort(string(matches[1])) + if err != nil { + return wait.NewPermanentError(fmt.Errorf("split host port: %w", err)) + } + + // Host and port successfully extracted from log. + return nil + } + + if err != nil { + // Return the last error encountered. + return err + } + + return fmt.Errorf("address and version not found: `%s` no matches", pattern) + }), +} +``` + +If the return from a Submatch callback function is a `wait.PermanentError` the +wait will stop and the error will be returned. Use `wait.NewPermanentError(err error)` +to achieve this. diff --git a/wait/log.go b/wait/log.go index 530077f909..41c96e3eb9 100644 --- a/wait/log.go +++ b/wait/log.go @@ -1,10 +1,12 @@ package wait import ( + "bytes" "context" + "errors" + "fmt" "io" "regexp" - "strings" "time" ) @@ -14,6 +16,21 @@ var ( _ StrategyTimeout = (*LogStrategy)(nil) ) +// PermanentError is a special error that will stop the wait and return an error. +type PermanentError struct { + err error +} + +// Error implements the error interface. +func (e *PermanentError) Error() string { + return e.err.Error() +} + +// NewPermanentError creates a new PermanentError. +func NewPermanentError(err error) *PermanentError { + return &PermanentError{err: err} +} + // LogStrategy will wait until a given log entry shows up in the docker logs type LogStrategy struct { // all Strategies should have a startupTimeout to avoid waiting infinitely @@ -24,6 +41,18 @@ type LogStrategy struct { IsRegexp bool Occurrence int PollInterval time.Duration + + // check is the function that will be called to check if the log entry is present. + check func([]byte) error + + // submatchCallback is a callback that will be called with the sub matches of the regexp. + submatchCallback func(pattern string, matches [][][]byte) error + + // re is the optional compiled regexp. + re *regexp.Regexp + + // log byte slice version of [LogStrategy.Log] used for count checks. + log []byte } // NewLogStrategy constructs with polling interval of 100 milliseconds and startup timeout of 60 seconds by default @@ -46,6 +75,18 @@ func (ws *LogStrategy) AsRegexp() *LogStrategy { return ws } +// Submatch configures a function that will be called with the result of +// [regexp.Regexp.FindAllSubmatch], allowing the caller to process the results. +// If the callback returns nil, the strategy will be considered successful. +// Returning a [PermanentError] will stop the wait and return an error, otherwise +// it will retry until the timeout is reached. +// [LogStrategy.Occurrence] is ignored if this option is set. +func (ws *LogStrategy) Submatch(callback func(pattern string, matches [][][]byte) error) *LogStrategy { + ws.submatchCallback = callback + + return ws +} + // WithStartupTimeout can be used to change the default startup timeout func (ws *LogStrategy) WithStartupTimeout(timeout time.Duration) *LogStrategy { ws.timeout = &timeout @@ -89,57 +130,85 @@ func (ws *LogStrategy) WaitUntilReady(ctx context.Context, target StrategyTarget timeout = *ws.timeout } + switch { + case ws.submatchCallback != nil: + ws.re = regexp.MustCompile(ws.Log) + ws.check = ws.checkSubmatch + case ws.IsRegexp: + ws.re = regexp.MustCompile(ws.Log) + ws.check = ws.checkRegexp + default: + ws.log = []byte(ws.Log) + ws.check = ws.checkCount + } + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - length := 0 - -LOOP: + var lastLen int + var lastError error for { select { case <-ctx.Done(): - return ctx.Err() + return errors.Join(lastError, ctx.Err()) default: checkErr := checkTarget(ctx, target) reader, err := target.Logs(ctx) if err != nil { + // TODO: fix as this will wait for timeout if the logs are not available. time.Sleep(ws.PollInterval) continue } b, err := io.ReadAll(reader) if err != nil { + // TODO: fix as this will wait for timeout if the logs are not readable. time.Sleep(ws.PollInterval) continue } - logs := string(b) - - switch { - case length == len(logs) && checkErr != nil: + if lastLen == len(b) && checkErr != nil { + // Log length hasn't changed so we're not making progress. return checkErr - case checkLogsFn(ws, b): - break LOOP - default: - length = len(logs) + } + + if err := ws.check(b); err != nil { + var errPermanent *PermanentError + if errors.As(err, &errPermanent) { + return err + } + + lastError = err + lastLen = len(b) time.Sleep(ws.PollInterval) continue } + + return nil } } +} + +// checkCount checks if the log entry is present in the logs using a string count. +func (ws *LogStrategy) checkCount(b []byte) error { + if count := bytes.Count(b, ws.log); count < ws.Occurrence { + return fmt.Errorf("%q matched %d times, expected %d", ws.Log, count, ws.Occurrence) + } return nil } -func checkLogsFn(ws *LogStrategy, b []byte) bool { - if ws.IsRegexp { - re := regexp.MustCompile(ws.Log) - occurrences := re.FindAll(b, -1) - - return len(occurrences) >= ws.Occurrence +// checkRegexp checks if the log entry is present in the logs using a regexp count. +func (ws *LogStrategy) checkRegexp(b []byte) error { + if matches := ws.re.FindAll(b, -1); len(matches) < ws.Occurrence { + return fmt.Errorf("`%s` matched %d times, expected %d", ws.Log, len(matches), ws.Occurrence) } - logs := string(b) - return strings.Count(logs, ws.Log) >= ws.Occurrence + return nil +} + +// checkSubmatch checks if the log entry is present in the logs using a regexp sub match callback. +func (ws *LogStrategy) checkSubmatch(b []byte) error { + return ws.submatchCallback(ws.Log, ws.re.FindAllSubmatch(b, -1)) } diff --git a/wait/log_test.go b/wait/log_test.go index 7c767c0e25..4bfbc26438 100644 --- a/wait/log_test.go +++ b/wait/log_test.go @@ -1,14 +1,17 @@ -package wait +package wait_test import ( - "bytes" "context" + "fmt" "io" + "strings" "testing" "time" "github.com/docker/docker/api/types" "github.com/stretchr/testify/require" + + "github.com/testcontainers/testcontainers-go/wait" ) const logTimeout = time.Second @@ -25,107 +28,164 @@ Donec ut libero sed arcu vehicula ultricies a non tortor. Lorem ipsum dolor sit amet, consectetur adipiscing elit.` func TestWaitForLog(t *testing.T) { - t.Run("no regexp", func(t *testing.T) { - target := NopStrategyTarget{ - ReaderCloser: io.NopCloser(bytes.NewReader([]byte("docker"))), + t.Run("string", func(t *testing.T) { + target := wait.NopStrategyTarget{ + ReaderCloser: readCloser("docker"), } - wg := NewLogStrategy("docker").WithStartupTimeout(100 * time.Millisecond) + wg := wait.NewLogStrategy("docker").WithStartupTimeout(100 * time.Millisecond) err := wg.WaitUntilReady(context.Background(), target) require.NoError(t, err) }) - t.Run("no regexp", func(t *testing.T) { - target := NopStrategyTarget{ - ReaderCloser: io.NopCloser(bytes.NewReader([]byte(loremIpsum))), + t.Run("regexp", func(t *testing.T) { + target := wait.NopStrategyTarget{ + ReaderCloser: readCloser(loremIpsum), } // get all words that start with "ip", end with "m" and has a whitespace before the "ip" - wg := NewLogStrategy(`\sip[\w]+m`).WithStartupTimeout(100 * time.Millisecond).AsRegexp() + wg := wait.NewLogStrategy(`\sip[\w]+m`).WithStartupTimeout(100 * time.Millisecond).AsRegexp() + err := wg.WaitUntilReady(context.Background(), target) + require.NoError(t, err) + }) + + t.Run("submatch/valid", func(t *testing.T) { + target := wait.NopStrategyTarget{ + ReaderCloser: readCloser("three matches: ip1m, ip2m, ip3m"), + } + + wg := wait.NewLogStrategy(`ip(\d)m`).WithStartupTimeout(100 * time.Millisecond).Submatch(func(pattern string, submatches [][][]byte) error { + if len(submatches) != 3 { + return wait.NewPermanentError(fmt.Errorf("%q matched %d times, expected %d", pattern, len(submatches), 3)) + } + return nil + }) + err := wg.WaitUntilReady(context.Background(), target) + require.NoError(t, err) + }) + + t.Run("submatch/permanent-error", func(t *testing.T) { + target := wait.NopStrategyTarget{ + ReaderCloser: readCloser("single matches: ip1m"), + } + + wg := wait.NewLogStrategy(`ip(\d)m`).WithStartupTimeout(100 * time.Millisecond).Submatch(func(pattern string, submatches [][][]byte) error { + if len(submatches) != 3 { + return wait.NewPermanentError(fmt.Errorf("%q matched %d times, expected %d", pattern, len(submatches), 3)) + } + return nil + }) + err := wg.WaitUntilReady(context.Background(), target) + require.Error(t, err) + var permanentError *wait.PermanentError + require.ErrorAs(t, err, &permanentError) + }) + + t.Run("submatch/temporary-error", func(t *testing.T) { + target := newRunningTarget() + expect := target.EXPECT() + expect.Logs(anyContext).Return(readCloser(""), nil).Once() // No matches. + expect.Logs(anyContext).Return(readCloser("ip1m, ip2m"), nil).Once() // Two matches. + expect.Logs(anyContext).Return(readCloser("ip1m, ip2m, ip3m"), nil).Once() // Three matches. + expect.Logs(anyContext).Return(readCloser("ip1m, ip2m, ip3m, ip4m"), nil) // Four matches. + + wg := wait.NewLogStrategy(`ip(\d)m`).WithStartupTimeout(400 * time.Second).Submatch(func(pattern string, submatches [][][]byte) error { + switch len(submatches) { + case 0, 2: + // Too few matches. + return fmt.Errorf("`%s` matched %d times, expected %d (temporary)", pattern, len(submatches), 3) + case 3: + // Expected number of matches should stop the wait. + return nil + default: + // Should not be triggered. + return wait.NewPermanentError(fmt.Errorf("`%s` matched %d times, expected %d (permanent)", pattern, len(submatches), 3)) + } + }) err := wg.WaitUntilReady(context.Background(), target) require.NoError(t, err) }) } func TestWaitWithExactNumberOfOccurrences(t *testing.T) { - t.Run("no regexp", func(t *testing.T) { - target := NopStrategyTarget{ - ReaderCloser: io.NopCloser(bytes.NewReader([]byte("kubernetes\r\ndocker\n\rdocker"))), + t.Run("string", func(t *testing.T) { + target := wait.NopStrategyTarget{ + ReaderCloser: readCloser("kubernetes\r\ndocker\n\rdocker"), } - wg := NewLogStrategy("docker"). + wg := wait.NewLogStrategy("docker"). WithStartupTimeout(100 * time.Millisecond). WithOccurrence(2) err := wg.WaitUntilReady(context.Background(), target) require.NoError(t, err) }) - t.Run("as regexp", func(t *testing.T) { - target := NopStrategyTarget{ - ReaderCloser: io.NopCloser(bytes.NewReader([]byte(loremIpsum))), + t.Run("regexp", func(t *testing.T) { + target := wait.NopStrategyTarget{ + ReaderCloser: readCloser(loremIpsum), } // get texts from "ip" to the next "m". // there are three occurrences of this pattern in the string: // one "ipsum mauris" and two "ipsum dolor sit am" - wg := NewLogStrategy(`ip(.*)m`).WithStartupTimeout(100 * time.Millisecond).AsRegexp().WithOccurrence(3) + wg := wait.NewLogStrategy(`ip(.*)m`).WithStartupTimeout(100 * time.Millisecond).AsRegexp().WithOccurrence(3) err := wg.WaitUntilReady(context.Background(), target) require.NoError(t, err) }) } func TestWaitWithExactNumberOfOccurrencesButItWillNeverHappen(t *testing.T) { - t.Run("no regexp", func(t *testing.T) { - target := NopStrategyTarget{ - ReaderCloser: io.NopCloser(bytes.NewReader([]byte("kubernetes\r\ndocker"))), + t.Run("string", func(t *testing.T) { + target := wait.NopStrategyTarget{ + ReaderCloser: readCloser("kubernetes\r\ndocker"), } - wg := NewLogStrategy("containerd"). + wg := wait.NewLogStrategy("containerd"). WithStartupTimeout(logTimeout). WithOccurrence(2) err := wg.WaitUntilReady(context.Background(), target) require.Error(t, err) }) - t.Run("as regexp", func(t *testing.T) { - target := NopStrategyTarget{ - ReaderCloser: io.NopCloser(bytes.NewReader([]byte(loremIpsum))), + t.Run("regexp", func(t *testing.T) { + target := wait.NopStrategyTarget{ + ReaderCloser: readCloser(loremIpsum), } // get texts from "ip" to the next "m". // there are only three occurrences matching - wg := NewLogStrategy(`do(.*)ck.+`).WithStartupTimeout(100 * time.Millisecond).AsRegexp().WithOccurrence(4) + wg := wait.NewLogStrategy(`do(.*)ck.+`).WithStartupTimeout(100 * time.Millisecond).AsRegexp().WithOccurrence(4) err := wg.WaitUntilReady(context.Background(), target) require.Error(t, err) }) } func TestWaitShouldFailWithExactNumberOfOccurrences(t *testing.T) { - t.Run("no regexp", func(t *testing.T) { - target := NopStrategyTarget{ - ReaderCloser: io.NopCloser(bytes.NewReader([]byte("kubernetes\r\ndocker"))), + t.Run("string", func(t *testing.T) { + target := wait.NopStrategyTarget{ + ReaderCloser: readCloser("kubernetes\r\ndocker"), } - wg := NewLogStrategy("docker"). + wg := wait.NewLogStrategy("docker"). WithStartupTimeout(logTimeout). WithOccurrence(2) err := wg.WaitUntilReady(context.Background(), target) require.Error(t, err) }) - t.Run("as regexp", func(t *testing.T) { - target := NopStrategyTarget{ - ReaderCloser: io.NopCloser(bytes.NewReader([]byte(loremIpsum))), + t.Run("regexp", func(t *testing.T) { + target := wait.NopStrategyTarget{ + ReaderCloser: readCloser(loremIpsum), } // get "Maecenas". // there are only one occurrence matching - wg := NewLogStrategy(`^Mae[\w]?enas\s`).WithStartupTimeout(100 * time.Millisecond).AsRegexp().WithOccurrence(2) + wg := wait.NewLogStrategy(`^Mae[\w]?enas\s`).WithStartupTimeout(100 * time.Millisecond).AsRegexp().WithOccurrence(2) err := wg.WaitUntilReady(context.Background(), target) require.Error(t, err) }) } func TestWaitForLogFailsDueToOOMKilledContainer(t *testing.T) { - target := &MockStrategyTarget{ + target := &wait.MockStrategyTarget{ LogsImpl: func(_ context.Context) (io.ReadCloser, error) { - return io.NopCloser(bytes.NewReader([]byte(""))), nil + return readCloser(""), nil }, StateImpl: func(_ context.Context) (*types.ContainerState, error) { return &types.ContainerState{ @@ -134,16 +194,16 @@ func TestWaitForLogFailsDueToOOMKilledContainer(t *testing.T) { }, } - t.Run("no regexp", func(t *testing.T) { - wg := ForLog("docker").WithStartupTimeout(logTimeout) + t.Run("string", func(t *testing.T) { + wg := wait.ForLog("docker").WithStartupTimeout(logTimeout) err := wg.WaitUntilReady(context.Background(), target) expected := "container crashed with out-of-memory (OOMKilled)" require.EqualError(t, err, expected) }) - t.Run("as regexp", func(t *testing.T) { - wg := ForLog("docker").WithStartupTimeout(logTimeout).AsRegexp() + t.Run("regexp", func(t *testing.T) { + wg := wait.ForLog("docker").WithStartupTimeout(logTimeout).AsRegexp() err := wg.WaitUntilReady(context.Background(), target) expected := "container crashed with out-of-memory (OOMKilled)" @@ -152,9 +212,9 @@ func TestWaitForLogFailsDueToOOMKilledContainer(t *testing.T) { } func TestWaitForLogFailsDueToExitedContainer(t *testing.T) { - target := &MockStrategyTarget{ + target := &wait.MockStrategyTarget{ LogsImpl: func(_ context.Context) (io.ReadCloser, error) { - return io.NopCloser(bytes.NewReader([]byte(""))), nil + return readCloser(""), nil }, StateImpl: func(_ context.Context) (*types.ContainerState, error) { return &types.ContainerState{ @@ -164,16 +224,16 @@ func TestWaitForLogFailsDueToExitedContainer(t *testing.T) { }, } - t.Run("no regexp", func(t *testing.T) { - wg := ForLog("docker").WithStartupTimeout(logTimeout) + t.Run("string", func(t *testing.T) { + wg := wait.ForLog("docker").WithStartupTimeout(logTimeout) err := wg.WaitUntilReady(context.Background(), target) expected := "container exited with code 1" require.EqualError(t, err, expected) }) - t.Run("as regexp", func(t *testing.T) { - wg := ForLog("docker").WithStartupTimeout(logTimeout).AsRegexp() + t.Run("regexp", func(t *testing.T) { + wg := wait.ForLog("docker").WithStartupTimeout(logTimeout).AsRegexp() err := wg.WaitUntilReady(context.Background(), target) expected := "container exited with code 1" @@ -182,9 +242,9 @@ func TestWaitForLogFailsDueToExitedContainer(t *testing.T) { } func TestWaitForLogFailsDueToUnexpectedContainerStatus(t *testing.T) { - target := &MockStrategyTarget{ + target := &wait.MockStrategyTarget{ LogsImpl: func(_ context.Context) (io.ReadCloser, error) { - return io.NopCloser(bytes.NewReader([]byte(""))), nil + return readCloser(""), nil }, StateImpl: func(_ context.Context) (*types.ContainerState, error) { return &types.ContainerState{ @@ -193,19 +253,24 @@ func TestWaitForLogFailsDueToUnexpectedContainerStatus(t *testing.T) { }, } - t.Run("no regexp", func(t *testing.T) { - wg := ForLog("docker").WithStartupTimeout(logTimeout) + t.Run("string", func(t *testing.T) { + wg := wait.ForLog("docker").WithStartupTimeout(logTimeout) err := wg.WaitUntilReady(context.Background(), target) expected := "unexpected container status \"dead\"" require.EqualError(t, err, expected) }) - t.Run("as regexp", func(t *testing.T) { - wg := ForLog("docker").WithStartupTimeout(logTimeout).AsRegexp() + t.Run("regexp", func(t *testing.T) { + wg := wait.ForLog("docker").WithStartupTimeout(logTimeout).AsRegexp() err := wg.WaitUntilReady(context.Background(), target) expected := "unexpected container status \"dead\"" require.EqualError(t, err, expected) }) } + +// readCloser returns an io.ReadCloser that reads from s. +func readCloser(s string) io.ReadCloser { + return io.NopCloser(strings.NewReader((s))) +} From cc55f13bf9480645bb4b1d6a8e89800c057e7f6b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 1 Jan 2025 14:52:23 +0000 Subject: [PATCH 05/11] chore(deps): bump github/codeql-action from 3.25.15 to 3.28.0 (#2932) Bumps [github/codeql-action](https://github.com/github/codeql-action) from 3.25.15 to 3.28.0. - [Release notes](https://github.com/github/codeql-action/releases) - [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/github/codeql-action/compare/afb54ba388a7dca6ecae48f608c4ff05ff4cc77a...48ab28a6f5dbc2a99bf1e0131198dd8f1df78169) --- updated-dependencies: - dependency-name: github/codeql-action dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/codeql.yml | 6 +++--- .github/workflows/scorecards.yml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 15f56f2eb5..2c899be9d4 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -53,7 +53,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@afb54ba388a7dca6ecae48f608c4ff05ff4cc77a # v3.25.15 + uses: github/codeql-action/init@48ab28a6f5dbc2a99bf1e0131198dd8f1df78169 # v3.28.0 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -67,7 +67,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, Go, Java, or Swift). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@afb54ba388a7dca6ecae48f608c4ff05ff4cc77a # v3.25.15 + uses: github/codeql-action/autobuild@48ab28a6f5dbc2a99bf1e0131198dd8f1df78169 # v3.28.0 # ℹ️ Command-line programs to run using the OS shell. # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun @@ -80,6 +80,6 @@ jobs: # ./location_of_script_within_repo/buildscript.sh - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@afb54ba388a7dca6ecae48f608c4ff05ff4cc77a # v3.25.15 + uses: github/codeql-action/analyze@48ab28a6f5dbc2a99bf1e0131198dd8f1df78169 # v3.28.0 with: category: "/language:${{matrix.language}}" diff --git a/.github/workflows/scorecards.yml b/.github/workflows/scorecards.yml index 51b8a535c5..1d141de781 100644 --- a/.github/workflows/scorecards.yml +++ b/.github/workflows/scorecards.yml @@ -51,6 +51,6 @@ jobs: # required for Code scanning alerts - name: "Upload SARIF results to code scanning" - uses: github/codeql-action/upload-sarif@afb54ba388a7dca6ecae48f608c4ff05ff4cc77a # v3.25.15 + uses: github/codeql-action/upload-sarif@48ab28a6f5dbc2a99bf1e0131198dd8f1df78169 # v3.28.0 with: sarif_file: results.sarif From eb5b8ed3597dfe7fe38c7910b446397ee4921d2a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 2 Jan 2025 09:00:10 +0100 Subject: [PATCH 06/11] chore(deps): bump slackapi/slack-github-action from 1.26.0 to 2.0.0 (#2934) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore(deps): bump slackapi/slack-github-action from 1.26.0 to 2.0.0 Bumps [slackapi/slack-github-action](https://github.com/slackapi/slack-github-action) from 1.26.0 to 2.0.0. - [Release notes](https://github.com/slackapi/slack-github-action/releases) - [Commits](https://github.com/slackapi/slack-github-action/compare/v1.26.0...v2.0.0) --- updated-dependencies: - dependency-name: slackapi/slack-github-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] * chore: fix breaking change See https://github.com/slackapi/slack-github-action/releases/tag/v2.0.0 --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Manuel de la Peña --- .github/workflows/docker-moby-latest.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/docker-moby-latest.yml b/.github/workflows/docker-moby-latest.yml index dc06fb49e7..86ad12c6e0 100644 --- a/.github/workflows/docker-moby-latest.yml +++ b/.github/workflows/docker-moby-latest.yml @@ -70,8 +70,9 @@ jobs: - name: Notify to Slack on failures if: failure() id: slack - uses: slackapi/slack-github-action@v1.26.0 + uses: slackapi/slack-github-action@v2.0.0 with: + payload-templated: true payload-file-path: "./payload-slack-content.json" env: SLACK_WEBHOOK_URL: ${{ secrets.SLACK_DOCKER_LATEST_WEBHOOK }} From 6f718ee2f04205e5534af556384963ef871a508d Mon Sep 17 00:00:00 2001 From: Mohamed Badawi Date: Thu, 2 Jan 2025 09:59:14 +0100 Subject: [PATCH 07/11] feat(termination)!: make container termination timeout configurable (#2926) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(termination): make container termination timeout configurable * revert TerminateOption to terminateOptions, seperate from DockerContainer and introduce configuration test * rename to defaultOptions and pass ctx * address comments * resolve more suggestions * don't export terminateOptions * address more comments * avoid global default * address more comments * add docs for new termination options * move docs to garbage_collector/#terminate-function * resolve doc comments * docs: typo --------- Co-authored-by: Manuel de la Peña --- cleanup.go | 98 +++++++++++++++++------------- container.go | 2 +- docker.go | 15 +++-- docker_test.go | 75 +++++++++++++++++++++++ docs/features/garbage_collector.md | 41 +++++++++++++ modules/etcd/etcd.go | 6 +- port_forwarding.go | 4 +- 7 files changed, 188 insertions(+), 53 deletions(-) diff --git a/cleanup.go b/cleanup.go index e2d52440b9..d676b42bdb 100644 --- a/cleanup.go +++ b/cleanup.go @@ -8,20 +8,65 @@ import ( "time" ) -// terminateOptions is a type that holds the options for terminating a container. -type terminateOptions struct { - ctx context.Context - timeout *time.Duration - volumes []string +// TerminateOptions is a type that holds the options for terminating a container. +type TerminateOptions struct { + ctx context.Context + stopTimeout *time.Duration + volumes []string } // TerminateOption is a type that represents an option for terminating a container. -type TerminateOption func(*terminateOptions) +type TerminateOption func(*TerminateOptions) + +// NewTerminateOptions returns a fully initialised TerminateOptions. +// Defaults: StopTimeout: 10 seconds. +func NewTerminateOptions(ctx context.Context, opts ...TerminateOption) *TerminateOptions { + timeout := time.Second * 10 + options := &TerminateOptions{ + stopTimeout: &timeout, + ctx: ctx, + } + for _, opt := range opts { + opt(options) + } + return options +} + +// Context returns the context to use during a Terminate. +func (o *TerminateOptions) Context() context.Context { + return o.ctx +} + +// StopTimeout returns the stop timeout to use during a Terminate. +func (o *TerminateOptions) StopTimeout() *time.Duration { + return o.stopTimeout +} + +// Cleanup performs any clean up needed +func (o *TerminateOptions) Cleanup() error { + // TODO: simplify this when when perform the client refactor. + if len(o.volumes) == 0 { + return nil + } + client, err := NewDockerClientWithOpts(o.ctx) + if err != nil { + return fmt.Errorf("docker client: %w", err) + } + defer client.Close() + // Best effort to remove all volumes. + var errs []error + for _, volume := range o.volumes { + if errRemove := client.VolumeRemove(o.ctx, volume, true); errRemove != nil { + errs = append(errs, fmt.Errorf("volume remove %q: %w", volume, errRemove)) + } + } + return errors.Join(errs...) +} // StopContext returns a TerminateOption that sets the context. // Default: context.Background(). func StopContext(ctx context.Context) TerminateOption { - return func(c *terminateOptions) { + return func(c *TerminateOptions) { c.ctx = ctx } } @@ -29,8 +74,8 @@ func StopContext(ctx context.Context) TerminateOption { // StopTimeout returns a TerminateOption that sets the timeout. // Default: See [Container.Stop]. func StopTimeout(timeout time.Duration) TerminateOption { - return func(c *terminateOptions) { - c.timeout = &timeout + return func(c *TerminateOptions) { + c.stopTimeout = &timeout } } @@ -39,7 +84,7 @@ func StopTimeout(timeout time.Duration) TerminateOption { // which are not removed by default. // Default: nil. func RemoveVolumes(volumes ...string) TerminateOption { - return func(c *terminateOptions) { + return func(c *TerminateOptions) { c.volumes = volumes } } @@ -54,41 +99,12 @@ func TerminateContainer(container Container, options ...TerminateOption) error { return nil } - c := &terminateOptions{ - ctx: context.Background(), - } - - for _, opt := range options { - opt(c) - } - - // TODO: Add a timeout when terminate supports it. - err := container.Terminate(c.ctx) + err := container.Terminate(context.Background(), options...) if !isCleanupSafe(err) { return fmt.Errorf("terminate: %w", err) } - // Remove additional volumes if any. - if len(c.volumes) == 0 { - return nil - } - - client, err := NewDockerClientWithOpts(c.ctx) - if err != nil { - return fmt.Errorf("docker client: %w", err) - } - - defer client.Close() - - // Best effort to remove all volumes. - var errs []error - for _, volume := range c.volumes { - if errRemove := client.VolumeRemove(c.ctx, volume, true); errRemove != nil { - errs = append(errs, fmt.Errorf("volume remove %q: %w", volume, errRemove)) - } - } - - return errors.Join(errs...) + return nil } // isNil returns true if val is nil or an nil instance false otherwise. diff --git a/container.go b/container.go index 35be60fb81..50fc656e7e 100644 --- a/container.go +++ b/container.go @@ -50,7 +50,7 @@ type Container interface { Stop(context.Context, *time.Duration) error // stop the container // Terminate stops and removes the container and its image if it was built and not flagged as kept. - Terminate(ctx context.Context) error + Terminate(ctx context.Context, opts ...TerminateOption) error Logs(context.Context) (io.ReadCloser, error) // Get logs of the container FollowOutput(LogConsumer) // Deprecated: it will be removed in the next major release diff --git a/docker.go b/docker.go index 01b3d3d4d2..2ce849be50 100644 --- a/docker.go +++ b/docker.go @@ -303,12 +303,11 @@ func (c *DockerContainer) Stop(ctx context.Context, timeout *time.Duration) erro // The following hooks are called in order: // - [ContainerLifecycleHooks.PreTerminates] // - [ContainerLifecycleHooks.PostTerminates] -func (c *DockerContainer) Terminate(ctx context.Context) error { - // ContainerRemove hardcodes stop timeout to 3 seconds which is too short - // to ensure that child containers are stopped so we manually call stop. - // TODO: make this configurable via a functional option. - timeout := 10 * time.Second - err := c.Stop(ctx, &timeout) +// +// Default: timeout is 10 seconds. +func (c *DockerContainer) Terminate(ctx context.Context, opts ...TerminateOption) error { + options := NewTerminateOptions(ctx, opts...) + err := c.Stop(options.Context(), options.StopTimeout()) if err != nil && !isCleanupSafe(err) { return fmt.Errorf("stop: %w", err) } @@ -343,6 +342,10 @@ func (c *DockerContainer) Terminate(ctx context.Context) error { c.sessionID = "" c.isRunning = false + if err = options.Cleanup(); err != nil { + errs = append(errs, err) + } + return errors.Join(errs...) } diff --git a/docker_test.go b/docker_test.go index 8fcd60c558..0dd60f6db9 100644 --- a/docker_test.go +++ b/docker_test.go @@ -281,6 +281,18 @@ func TestContainerStateAfterTermination(t *testing.T) { require.Nil(t, state, "expected nil container inspect.") }) + t.Run("termination-timeout", func(t *testing.T) { + ctx := context.Background() + nginx, err := createContainerFn(ctx) + require.NoError(t, err) + + err = nginx.Start(ctx) + require.NoError(t, err, "expected no error from container start.") + + err = nginx.Terminate(ctx, StopTimeout(5*time.Microsecond)) + require.NoError(t, err) + }) + t.Run("Nil State after termination if raw as already set", func(t *testing.T) { ctx := context.Background() nginx, err := createContainerFn(ctx) @@ -1077,6 +1089,38 @@ func TestContainerCreationWithVolumeAndFileWritingToIt(t *testing.T) { { HostFilePath: absPath, ContainerFilePath: "/hello.sh", + FileMode: 700, + }, + }, + Mounts: Mounts(VolumeMount(volumeName, "/data")), + Cmd: []string{"bash", "/hello.sh"}, + WaitingFor: wait.ForLog("done"), + }, + Started: true, + }) + CleanupContainer(t, bashC, RemoveVolumes(volumeName)) + require.NoError(t, err) +} + +func TestContainerCreationWithVolumeCleaning(t *testing.T) { + absPath, err := filepath.Abs(filepath.Join(".", "testdata", "hello.sh")) + require.NoError(t, err) + ctx, cnl := context.WithTimeout(context.Background(), 30*time.Second) + defer cnl() + + // Create the volume. + volumeName := "volumeName" + + // Create the container that writes into the mounted volume. + bashC, err := GenericContainer(ctx, GenericContainerRequest{ + ProviderType: providerType, + ContainerRequest: ContainerRequest{ + Image: "bash:5.2.26", + Files: []ContainerFile{ + { + HostFilePath: absPath, + ContainerFilePath: "/hello.sh", + FileMode: 700, }, }, Mounts: Mounts(VolumeMount(volumeName, "/data")), @@ -1085,10 +1129,41 @@ func TestContainerCreationWithVolumeAndFileWritingToIt(t *testing.T) { }, Started: true, }) + require.NoError(t, err) + err = bashC.Terminate(ctx, RemoveVolumes(volumeName)) CleanupContainer(t, bashC, RemoveVolumes(volumeName)) require.NoError(t, err) } +func TestContainerTerminationOptions(t *testing.T) { + t.Run("volumes", func(t *testing.T) { + var options TerminateOptions + RemoveVolumes("vol1", "vol2")(&options) + require.Equal(t, TerminateOptions{ + volumes: []string{"vol1", "vol2"}, + }, options) + }) + t.Run("stop-timeout", func(t *testing.T) { + var options TerminateOptions + timeout := 11 * time.Second + StopTimeout(timeout)(&options) + require.Equal(t, TerminateOptions{ + stopTimeout: &timeout, + }, options) + }) + + t.Run("all", func(t *testing.T) { + var options TerminateOptions + timeout := 9 * time.Second + StopTimeout(timeout)(&options) + RemoveVolumes("vol1", "vol2")(&options) + require.Equal(t, TerminateOptions{ + stopTimeout: &timeout, + volumes: []string{"vol1", "vol2"}, + }, options) + }) +} + func TestContainerWithTmpFs(t *testing.T) { ctx := context.Background() req := ContainerRequest{ diff --git a/docs/features/garbage_collector.md b/docs/features/garbage_collector.md index e725f5a9bd..4712c59748 100644 --- a/docs/features/garbage_collector.md +++ b/docs/features/garbage_collector.md @@ -17,6 +17,47 @@ The primary method is to use the `Terminate(context.Context)` function that is available when a container is created. Use `defer` to ensure that it is called on test completion. +The `Terminate` function can be customised with termination options to determine how a container is removed: termination timeout, and the ability to remove container volumes are supported at the moment. You can build the default options using the `testcontainers.NewTerminationOptions` function. + +#### NewTerminateOptions + +- Not available until the next release of testcontainers-go :material-tag: main + +If you want to attach option to container termination, you can use the `testcontainers.NewTerminateOptions(ctx context.Context, opts ...TerminateOption) *TerminateOptions` option, which receives a TerminateOption as parameter, creating custom termination options to be passed on the container termination. + +##### Terminate Options + +###### [StopContext](../../cleanup.go) +Sets the context for the Container termination. + +- **Function**: `StopContext(ctx context.Context) TerminateOption` +- **Default**: The context passed in `Terminate()` +- **Usage**: +```go +err := container.Terminate(ctx,StopContext(context.Background())) +``` + +###### [StopTimeout](../../cleanup.go) +Sets the timeout for stopping the Container. + +- **Function**: ` StopTimeout(timeout time.Duration) TerminateOption` +- **Default**: 10 seconds +- **Usage**: +```go +err := container.Terminate(ctx, StopTimeout(20 * time.Second)) +``` + +###### [RemoveVolumes](../../cleanup.go) +Sets the volumes to be removed during Container termination. + +- **Function**: ` RemoveVolumes(volumes ...string) TerminateOption` +- **Default**: Empty (no volumes removed) +- **Usage**: +```go +err := container.Terminate(ctx, RemoveVolumes("vol1", "vol2")) +``` + + !!!tip Remember to `defer` as soon as possible so you won't forget. The best time diff --git a/modules/etcd/etcd.go b/modules/etcd/etcd.go index 7ea78b4385..a715150bf1 100644 --- a/modules/etcd/etcd.go +++ b/modules/etcd/etcd.go @@ -29,18 +29,18 @@ type EtcdContainer struct { // Terminate terminates the etcd container, its child nodes, and the network in which the cluster is running // to communicate between the nodes. -func (c *EtcdContainer) Terminate(ctx context.Context) error { +func (c *EtcdContainer) Terminate(ctx context.Context, opts ...testcontainers.TerminateOption) error { var errs []error // child nodes has no other children for i, child := range c.childNodes { - if err := child.Terminate(ctx); err != nil { + if err := child.Terminate(ctx, opts...); err != nil { errs = append(errs, fmt.Errorf("terminate child node(%d): %w", i, err)) } } if c.Container != nil { - if err := c.Container.Terminate(ctx); err != nil { + if err := c.Container.Terminate(ctx, opts...); err != nil { errs = append(errs, fmt.Errorf("terminate cluster node: %w", err)) } } diff --git a/port_forwarding.go b/port_forwarding.go index bb6bae2393..3411ff0c1f 100644 --- a/port_forwarding.go +++ b/port_forwarding.go @@ -225,10 +225,10 @@ type sshdContainer struct { } // Terminate stops the container and closes the SSH session -func (sshdC *sshdContainer) Terminate(ctx context.Context) error { +func (sshdC *sshdContainer) Terminate(ctx context.Context, opts ...TerminateOption) error { return errors.Join( sshdC.closePorts(), - sshdC.Container.Terminate(ctx), + sshdC.Container.Terminate(ctx, opts...), ) } From 7ca837dc011ed788a1ae159fb924c19df9a5a1d5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 2 Jan 2025 10:01:16 +0100 Subject: [PATCH 08/11] chore(deps): bump sonarsource/sonarcloud-github-action (#2933) Bumps [sonarsource/sonarcloud-github-action](https://github.com/sonarsource/sonarcloud-github-action) from 2.1.1 to 4.0.0. - [Release notes](https://github.com/sonarsource/sonarcloud-github-action/releases) - [Commits](https://github.com/sonarsource/sonarcloud-github-action/compare/49e6cd3b187936a73b8280d59ffd9da69df63ec9...02ef91109b2d589e757aefcfb2854c2783fd7b19) --- updated-dependencies: - dependency-name: sonarsource/sonarcloud-github-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Steven Hartland --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 10fb7f4949..1538b1e645 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -140,7 +140,7 @@ jobs: merge-multiple: true - name: Analyze with SonarCloud - uses: sonarsource/sonarcloud-github-action@49e6cd3b187936a73b8280d59ffd9da69df63ec9 # v2.1.1 + uses: sonarsource/sonarcloud-github-action@02ef91109b2d589e757aefcfb2854c2783fd7b19 # v4.0.0 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} From 632249a3d322b046d2bd999ef0c2874a12c5324f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 2 Jan 2025 15:05:53 +0100 Subject: [PATCH 09/11] chore(deps): bump jinja2 from 3.1.4 to 3.1.5 (#2935) Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.4 to 3.1.5. - [Release notes](https://github.com/pallets/jinja/releases) - [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst) - [Commits](https://github.com/pallets/jinja/compare/3.1.4...3.1.5) --- updated-dependencies: - dependency-name: jinja2 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Pipfile.lock | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/Pipfile.lock b/Pipfile.lock index 9a2f6d24c8..d08964ab4c 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -178,11 +178,12 @@ }, "jinja2": { "hashes": [ - "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369", - "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d" + "sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb", + "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb" ], + "index": "pypi", "markers": "python_version >= '3.7'", - "version": "==3.1.4" + "version": "==3.1.5" }, "markdown": { "hashes": [ From 6ec91f1ea81779b2859aaae53a7f553be282efcf Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Thu, 2 Jan 2025 15:05:10 +0000 Subject: [PATCH 10/11] feat(ollama): support calling the Ollama local process (#2923) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: support running ollama from the local binary * fix: wrong working dir at CI * chore: extract wait to a function * chore: print local binary logs on error * chore: remove debug logs * fix(ci): kill ollama before the tests * chore: stop ollama using systemctl * chore: support setting log file from the env * chore: support running ollama commands, only * fix: release lock on error * chore: add more test coverage for the option * chore: simplify useLocal checks * chore: simpolify * chore: pass context to runLocal * chore: move ctx to the right scope * chore: remove not needed * chore: use a container function * chore: support reading OLLAMA_HOST * chore: return error with copy APIs * chore: simply execute the script * chore: simplify var initialisation * chore: return nil * fix: return errors on terminate * chore: remove options type * chore: use a map * chor: simplify error on wait * chore: wrap start logic around the localContext * chor: fold * chore: merge wait into start * fix: use proper ContainersState * fix: remove extra conversion * chore: handle remove log file errors properly * chore: go back to string in env vars * refactor(ollama): local process Refactor local process handling for Ollama using a container implementation avoiding the wrapping methods. This defaults to running the binary with an ephemeral port to avoid port conflicts. This behaviour can be overridden my setting OLLAMA_HOST either in the parent environment or in the values passed via WithUseLocal. Improve API compatibility with: - Multiplexed output streams - State reporting - Exec option processing - WaitingFor customisation Fix Container implementation: - Port management - Running checks - Terminate processing - Endpoint argument definition - Add missing methods - Consistent environment handling * chore(ollama): refactor local to use log sub match. Refactor local processing to use the new log sub match functionality. * feat(ollama): validate container request Validate the container request to ensure the user configuration can be processed and no fields that would be ignored are present. * chore(ollama): remove temporary test Remove temporary simple test. * feat(ollama): configurable local process binary Allow the local ollama binary name to be configured using the image name. * docs(ollama): detail local process supported fields Detail the container request supported fields. * docs(ollama): update local process site docs Update local process site docs to match recent changes. * chore: refactor to support TerminateOption Refactor Terminate to support testcontainers.TerminateOption. * fix: remove unused var --------- Co-authored-by: Manuel de la Peña --- .../modules/ollama/install-dependencies.sh | 6 + .github/workflows/ci-test-go.yml | 10 + docs/modules/ollama.md | 50 ++ modules/ollama/examples_test.go | 70 ++ modules/ollama/go.mod | 2 +- modules/ollama/local.go | 755 ++++++++++++++++++ modules/ollama/local_test.go | 636 +++++++++++++++ modules/ollama/ollama.go | 39 +- modules/ollama/options.go | 29 +- modules/ollama/options_test.go | 49 ++ 10 files changed, 1630 insertions(+), 16 deletions(-) create mode 100755 .github/scripts/modules/ollama/install-dependencies.sh create mode 100644 modules/ollama/local.go create mode 100644 modules/ollama/local_test.go create mode 100644 modules/ollama/options_test.go diff --git a/.github/scripts/modules/ollama/install-dependencies.sh b/.github/scripts/modules/ollama/install-dependencies.sh new file mode 100755 index 0000000000..d699158806 --- /dev/null +++ b/.github/scripts/modules/ollama/install-dependencies.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +curl -fsSL https://ollama.com/install.sh | sh + +# kill any running ollama process so that the tests can start from a clean state +sudo systemctl stop ollama.service diff --git a/.github/workflows/ci-test-go.yml b/.github/workflows/ci-test-go.yml index 82be78435f..0d6af15880 100644 --- a/.github/workflows/ci-test-go.yml +++ b/.github/workflows/ci-test-go.yml @@ -107,6 +107,16 @@ jobs: working-directory: ./${{ inputs.project-directory }} run: go build + - name: Install dependencies + shell: bash + run: | + SCRIPT_PATH="./.github/scripts/${{ inputs.project-directory }}/install-dependencies.sh" + if [ -f "$SCRIPT_PATH" ]; then + $SCRIPT_PATH + else + echo "No dependencies script found at $SCRIPT_PATH - skipping installation" + fi + - name: go test # only run tests on linux, there are a number of things that won't allow the tests to run on anything else # many (maybe, all?) images used can only be build on Linux, they don't have Windows in their manifest, and diff --git a/docs/modules/ollama.md b/docs/modules/ollama.md index c16e612142..18cb08b47a 100644 --- a/docs/modules/ollama.md +++ b/docs/modules/ollama.md @@ -16,10 +16,15 @@ go get github.com/testcontainers/testcontainers-go/modules/ollama ## Usage example +The module allows you to run the Ollama container or the local Ollama binary. + [Creating a Ollama container](../../modules/ollama/examples_test.go) inside_block:runOllamaContainer +[Running the local Ollama binary](../../modules/ollama/examples_test.go) inside_block:localOllama +If the local Ollama binary fails to execute, the module will fallback to the container version of Ollama. + ## Module Reference ### Run function @@ -48,6 +53,51 @@ When starting the Ollama container, you can pass options in a variadic way to co If you need to set a different Ollama Docker image, you can set a valid Docker image as the second argument in the `Run` function. E.g. `Run(context.Background(), "ollama/ollama:0.1.25")`. +#### Use Local + +- Not available until the next release of testcontainers-go :material-tag: main + +!!!warning + Please make sure the local Ollama binary is not running when using the local version of the module: + Ollama can be started as a system service, or as part of the Ollama application, + and interacting with the logs of a running Ollama process not managed by the module is not supported. + +If you need to run the local Ollama binary, you can set the `UseLocal` option in the `Run` function. +This option accepts a list of environment variables as a string, that will be applied to the Ollama binary when executing commands. + +E.g. `Run(context.Background(), "ollama/ollama:0.1.25", WithUseLocal("OLLAMA_DEBUG=true"))`. + +All the container methods are available when using the local Ollama binary, but will be executed locally instead of inside the container. +Please consider the following differences when using the local Ollama binary: + +- The local Ollama binary will create a log file in the current working directory, identified by the session ID. E.g. `local-ollama-.log`. It's possible to set the log file name using the `OLLAMA_LOGFILE` environment variable. So if you're running Ollama yourself, from the Ollama app, or the standalone binary, you could use this environment variable to set the same log file name. + - For the Ollama app, the default log file resides in the `$HOME/.ollama/logs/server.log`. + - For the standalone binary, you should start it redirecting the logs to a file. E.g. `ollama serve > /tmp/ollama.log 2>&1`. +- `ConnectionString` returns the connection string to connect to the local Ollama binary started by the module instead of the container. +- `ContainerIP` returns the bound host IP `127.0.0.1` by default. +- `ContainerIPs` returns the bound host IP `["127.0.0.1"]` by default. +- `CopyToContainer`, `CopyDirToContainer`, `CopyFileToContainer` and `CopyFileFromContainer` return an error if called. +- `GetLogProductionErrorChannel` returns a nil channel. +- `Endpoint` returns the endpoint to connect to the local Ollama binary started by the module instead of the container. +- `Exec` passes the command to the local Ollama binary started by the module instead of inside the container. First argument is the command to execute, and the second argument is the list of arguments, else, an error is returned. +- `GetContainerID` returns the container ID of the local Ollama binary started by the module instead of the container, which maps to `local-ollama-`. +- `Host` returns the bound host IP `127.0.0.1` by default. +- `Inspect` returns a ContainerJSON with the state of the local Ollama binary started by the module. +- `IsRunning` returns true if the local Ollama binary process started by the module is running. +- `Logs` returns the logs from the local Ollama binary started by the module instead of the container. +- `MappedPort` returns the port mapping for the local Ollama binary started by the module instead of the container. +- `Start` starts the local Ollama binary process. +- `State` returns the current state of the local Ollama binary process, `stopped` or `running`. +- `Stop` stops the local Ollama binary process. +- `Terminate` calls the `Stop` method and then removes the log file. + +The local Ollama binary will create a log file in the current working directory, and it will be available in the container's `Logs` method. + +!!!info + The local Ollama binary will use the `OLLAMA_HOST` environment variable to set the host and port to listen on. + If the environment variable is not set, it will default to `localhost:0` + which bind to a loopback address on an ephemeral port to avoid port conflicts. + {% include "../features/common_functional_options.md" %} ### Container Methods diff --git a/modules/ollama/examples_test.go b/modules/ollama/examples_test.go index 741db846be..188be45bbb 100644 --- a/modules/ollama/examples_test.go +++ b/modules/ollama/examples_test.go @@ -173,3 +173,73 @@ func ExampleRun_withModel_llama2_langchain() { // Intentionally not asserting the output, as we don't want to run this example in the tests. } + +func ExampleRun_withLocal() { + ctx := context.Background() + + // localOllama { + ollamaContainer, err := tcollama.Run(ctx, "ollama/ollama:0.3.13", tcollama.WithUseLocal("OLLAMA_DEBUG=true")) + defer func() { + if err := testcontainers.TerminateContainer(ollamaContainer); err != nil { + log.Printf("failed to terminate container: %s", err) + } + }() + if err != nil { + log.Printf("failed to start container: %s", err) + return + } + // } + + model := "llama3.2:1b" + + _, _, err = ollamaContainer.Exec(ctx, []string{"ollama", "pull", model}) + if err != nil { + log.Printf("failed to pull model %s: %s", model, err) + return + } + + _, _, err = ollamaContainer.Exec(ctx, []string{"ollama", "run", model}) + if err != nil { + log.Printf("failed to run model %s: %s", model, err) + return + } + + connectionStr, err := ollamaContainer.ConnectionString(ctx) + if err != nil { + log.Printf("failed to get connection string: %s", err) + return + } + + var llm *langchainollama.LLM + if llm, err = langchainollama.New( + langchainollama.WithModel(model), + langchainollama.WithServerURL(connectionStr), + ); err != nil { + log.Printf("failed to create langchain ollama: %s", err) + return + } + + completion, err := llm.Call( + context.Background(), + "how can Testcontainers help with testing?", + llms.WithSeed(42), // the lower the seed, the more deterministic the completion + llms.WithTemperature(0.0), // the lower the temperature, the more creative the completion + ) + if err != nil { + log.Printf("failed to create langchain ollama: %s", err) + return + } + + words := []string{ + "easy", "isolation", "consistency", + } + lwCompletion := strings.ToLower(completion) + + for _, word := range words { + if strings.Contains(lwCompletion, word) { + fmt.Println(true) + } + } + + // Intentionally not asserting the output, as we don't want to run this example in the tests. +} diff --git a/modules/ollama/go.mod b/modules/ollama/go.mod index e22b801031..2aab83b978 100644 --- a/modules/ollama/go.mod +++ b/modules/ollama/go.mod @@ -4,6 +4,7 @@ go 1.22 require ( github.com/docker/docker v27.1.1+incompatible + github.com/docker/go-connections v0.5.0 github.com/google/uuid v1.6.0 github.com/stretchr/testify v1.9.0 github.com/testcontainers/testcontainers-go v0.34.0 @@ -22,7 +23,6 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/distribution/reference v0.6.0 // indirect github.com/dlclark/regexp2 v1.8.1 // indirect - github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.1 // indirect diff --git a/modules/ollama/local.go b/modules/ollama/local.go new file mode 100644 index 0000000000..5751ceee07 --- /dev/null +++ b/modules/ollama/local.go @@ -0,0 +1,755 @@ +package ollama + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "io/fs" + "net" + "os" + "os/exec" + "reflect" + "strings" + "sync" + "syscall" + "time" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/network" + "github.com/docker/docker/errdefs" + "github.com/docker/docker/pkg/stdcopy" + "github.com/docker/go-connections/nat" + + "github.com/testcontainers/testcontainers-go" + tcexec "github.com/testcontainers/testcontainers-go/exec" + "github.com/testcontainers/testcontainers-go/wait" +) + +const ( + localPort = "11434" + localBinary = "ollama" + localServeArg = "serve" + localLogRegex = `Listening on (.*:\d+) \(version\s(.*)\)` + localNamePrefix = "local-ollama" + localHostVar = "OLLAMA_HOST" + localLogVar = "OLLAMA_LOGFILE" +) + +var ( + // Ensure localProcess implements the required interfaces. + _ testcontainers.Container = (*localProcess)(nil) + _ testcontainers.ContainerCustomizer = (*localProcess)(nil) + + // zeroTime is the zero time value. + zeroTime time.Time +) + +// localProcess emulates the Ollama container using a local process to improve performance. +type localProcess struct { + sessionID string + + // env is the combined environment variables passed to the Ollama binary. + env []string + + // cmd is the command that runs the Ollama binary, not valid externally if nil. + cmd *exec.Cmd + + // logName and logFile are the file where the Ollama logs are written. + logName string + logFile *os.File + + // host, port and version are extracted from log on startup. + host string + port string + version string + + // waitFor is the strategy to wait for the process to be ready. + waitFor wait.Strategy + + // done is closed when the process is finished. + done chan struct{} + + // wg is used to wait for the process to finish. + wg sync.WaitGroup + + // startedAt is the time when the process started. + startedAt time.Time + + // mtx is used to synchronize access to the process state fields below. + mtx sync.Mutex + + // finishedAt is the time when the process finished. + finishedAt time.Time + + // exitErr is the error returned by the process. + exitErr error + + // binary is the name of the Ollama binary. + binary string +} + +// runLocal returns an OllamaContainer that uses the local Ollama binary instead of using a Docker container. +func (c *localProcess) run(ctx context.Context, req testcontainers.GenericContainerRequest) (*OllamaContainer, error) { + if err := c.validateRequest(req); err != nil { + return nil, fmt.Errorf("validate request: %w", err) + } + + // Apply the updated details from the request. + c.waitFor = req.WaitingFor + c.env = c.env[:0] + for k, v := range req.Env { + c.env = append(c.env, k+"="+v) + if k == localLogVar { + c.logName = v + } + } + + err := c.Start(ctx) + var container *OllamaContainer + if c.cmd != nil { + container = &OllamaContainer{Container: c} + } + + if err != nil { + return container, fmt.Errorf("start ollama: %w", err) + } + + return container, nil +} + +// validateRequest checks that req is valid for the local Ollama binary. +func (c *localProcess) validateRequest(req testcontainers.GenericContainerRequest) error { + var errs []error + if req.WaitingFor == nil { + errs = append(errs, errors.New("ContainerRequest.WaitingFor must be set")) + } + + if !req.Started { + errs = append(errs, errors.New("Started must be true")) + } + + if !reflect.DeepEqual(req.ExposedPorts, []string{localPort + "/tcp"}) { + errs = append(errs, fmt.Errorf("ContainerRequest.ExposedPorts must be %s/tcp got: %s", localPort, req.ExposedPorts)) + } + + // Validate the image and extract the binary name. + // The image must be in the format "[/][:latest]". + if binary := req.Image; binary != "" { + // Check if the version is "latest" or not specified. + if idx := strings.IndexByte(binary, ':'); idx != -1 { + if binary[idx+1:] != "latest" { + errs = append(errs, fmt.Errorf(`ContainerRequest.Image version must be blank or "latest", got: %q`, binary[idx+1:])) + } + binary = binary[:idx] + } + + // Trim the path if present. + if idx := strings.LastIndexByte(binary, '/'); idx != -1 { + binary = binary[idx+1:] + } + + if _, err := exec.LookPath(binary); err != nil { + errs = append(errs, fmt.Errorf("invalid image %q: %w", req.Image, err)) + } else { + c.binary = binary + } + } + + // Reset fields we support to their zero values. + req.Env = nil + req.ExposedPorts = nil + req.WaitingFor = nil + req.Image = "" + req.Started = false + req.Logger = nil // We don't need the logger. + + parts := make([]string, 0, 3) + value := reflect.ValueOf(req) + typ := value.Type() + fields := reflect.VisibleFields(typ) + for _, f := range fields { + field := value.FieldByIndex(f.Index) + if field.Kind() == reflect.Struct { + // Only check the leaf fields. + continue + } + + if !field.IsZero() { + parts = parts[:0] + for i := range f.Index { + parts = append(parts, typ.FieldByIndex(f.Index[:i+1]).Name) + } + errs = append(errs, fmt.Errorf("unsupported field: %s = %q", strings.Join(parts, "."), field)) + } + } + + return errors.Join(errs...) +} + +// Start implements testcontainers.Container interface for the local Ollama binary. +func (c *localProcess) Start(ctx context.Context) error { + if c.IsRunning() { + return errors.New("already running") + } + + cmd := exec.CommandContext(ctx, c.binary, localServeArg) + cmd.Env = c.env + + var err error + c.logFile, err = os.Create(c.logName) + if err != nil { + return fmt.Errorf("create ollama log file: %w", err) + } + + // Multiplex stdout and stderr to the log file matching the Docker API. + cmd.Stdout = stdcopy.NewStdWriter(c.logFile, stdcopy.Stdout) + cmd.Stderr = stdcopy.NewStdWriter(c.logFile, stdcopy.Stderr) + + // Run the ollama serve command in background. + if err = cmd.Start(); err != nil { + return fmt.Errorf("start ollama serve: %w", errors.Join(err, c.cleanup())) + } + + // Past this point, the process was started successfully. + c.cmd = cmd + c.startedAt = time.Now() + + // Reset the details to allow multiple start / stop cycles. + c.done = make(chan struct{}) + c.mtx.Lock() + c.finishedAt = zeroTime + c.exitErr = nil + c.mtx.Unlock() + + // Wait for the process to finish in a goroutine. + c.wg.Add(1) + go func() { + defer func() { + c.wg.Done() + close(c.done) + }() + + err := c.cmd.Wait() + c.mtx.Lock() + defer c.mtx.Unlock() + if err != nil { + c.exitErr = fmt.Errorf("process wait: %w", err) + } + c.finishedAt = time.Now() + }() + + if err = c.waitStrategy(ctx); err != nil { + return fmt.Errorf("wait strategy: %w", err) + } + + return nil +} + +// waitStrategy waits until the Ollama process is ready. +func (c *localProcess) waitStrategy(ctx context.Context) error { + if err := c.waitFor.WaitUntilReady(ctx, c); err != nil { + logs, lerr := c.Logs(ctx) + if lerr != nil { + return errors.Join(err, lerr) + } + defer logs.Close() + + var stderr, stdout bytes.Buffer + _, cerr := stdcopy.StdCopy(&stdout, &stderr, logs) + + return fmt.Errorf( + "%w (stdout: %s, stderr: %s)", + errors.Join(err, cerr), + strings.TrimSpace(stdout.String()), + strings.TrimSpace(stderr.String()), + ) + } + + return nil +} + +// extractLogDetails extracts the listening address and version from the log. +func (c *localProcess) extractLogDetails(pattern string, submatches [][][]byte) error { + var err error + for _, matches := range submatches { + if len(matches) != 3 { + err = fmt.Errorf("`%s` matched %d times, expected %d", pattern, len(matches), 3) + continue + } + + c.host, c.port, err = net.SplitHostPort(string(matches[1])) + if err != nil { + return wait.NewPermanentError(fmt.Errorf("split host port: %w", err)) + } + + // Set OLLAMA_HOST variable to the extracted host so Exec can use it. + c.env = append(c.env, localHostVar+"="+string(matches[1])) + c.version = string(matches[2]) + + return nil + } + + if err != nil { + // Return the last error encountered. + return err + } + + return fmt.Errorf("address and version not found: `%s` no matches", pattern) +} + +// ContainerIP implements testcontainers.Container interface for the local Ollama binary. +func (c *localProcess) ContainerIP(ctx context.Context) (string, error) { + return c.host, nil +} + +// ContainerIPs returns a slice with the IP address of the local Ollama binary. +func (c *localProcess) ContainerIPs(ctx context.Context) ([]string, error) { + return []string{c.host}, nil +} + +// CopyToContainer implements testcontainers.Container interface for the local Ollama binary. +// Returns [errors.ErrUnsupported]. +func (c *localProcess) CopyToContainer(ctx context.Context, fileContent []byte, containerFilePath string, fileMode int64) error { + return errors.ErrUnsupported +} + +// CopyDirToContainer implements testcontainers.Container interface for the local Ollama binary. +// Returns [errors.ErrUnsupported]. +func (c *localProcess) CopyDirToContainer(ctx context.Context, hostDirPath string, containerParentPath string, fileMode int64) error { + return errors.ErrUnsupported +} + +// CopyFileToContainer implements testcontainers.Container interface for the local Ollama binary. +// Returns [errors.ErrUnsupported]. +func (c *localProcess) CopyFileToContainer(ctx context.Context, hostFilePath string, containerFilePath string, fileMode int64) error { + return errors.ErrUnsupported +} + +// CopyFileFromContainer implements testcontainers.Container interface for the local Ollama binary. +// Returns [errors.ErrUnsupported]. +func (c *localProcess) CopyFileFromContainer(ctx context.Context, filePath string) (io.ReadCloser, error) { + return nil, errors.ErrUnsupported +} + +// GetLogProductionErrorChannel implements testcontainers.Container interface for the local Ollama binary. +// It returns a nil channel because the local Ollama binary doesn't have a production error channel. +func (c *localProcess) GetLogProductionErrorChannel() <-chan error { + return nil +} + +// Exec implements testcontainers.Container interface for the local Ollama binary. +// It executes a command using the local Ollama binary and returns the exit status +// of the executed command, an [io.Reader] containing the combined stdout and stderr, +// and any encountered error. +// +// Reading directly from the [io.Reader] may result in unexpected bytes due to custom +// stream multiplexing headers. Use [tcexec.Multiplexed] option to read the combined output +// without the multiplexing headers. +// Alternatively, to separate the stdout and stderr from [io.Reader] and interpret these +// headers properly, [stdcopy.StdCopy] from the Docker API should be used. +func (c *localProcess) Exec(ctx context.Context, cmd []string, options ...tcexec.ProcessOption) (int, io.Reader, error) { + if len(cmd) == 0 { + return 1, nil, errors.New("no command provided") + } else if cmd[0] != c.binary { + return 1, nil, fmt.Errorf("command %q: %w", cmd[0], errors.ErrUnsupported) + } + + command := exec.CommandContext(ctx, cmd[0], cmd[1:]...) + command.Env = c.env + + // Multiplex stdout and stderr to the buffer so they can be read separately later. + var buf bytes.Buffer + command.Stdout = stdcopy.NewStdWriter(&buf, stdcopy.Stdout) + command.Stderr = stdcopy.NewStdWriter(&buf, stdcopy.Stderr) + + // Use process options to customize the command execution + // emulating the Docker API behaviour. + processOptions := tcexec.NewProcessOptions(cmd) + processOptions.Reader = &buf + for _, o := range options { + o.Apply(processOptions) + } + + if err := c.validateExecOptions(processOptions.ExecConfig); err != nil { + return 1, nil, fmt.Errorf("validate exec option: %w", err) + } + + if !processOptions.ExecConfig.AttachStderr { + command.Stderr = io.Discard + } + if !processOptions.ExecConfig.AttachStdout { + command.Stdout = io.Discard + } + if processOptions.ExecConfig.AttachStdin { + command.Stdin = os.Stdin + } + + command.Dir = processOptions.ExecConfig.WorkingDir + command.Env = append(command.Env, processOptions.ExecConfig.Env...) + + if err := command.Run(); err != nil { + return command.ProcessState.ExitCode(), processOptions.Reader, fmt.Errorf("exec %v: %w", cmd, err) + } + + return command.ProcessState.ExitCode(), processOptions.Reader, nil +} + +// validateExecOptions checks if the given exec options are supported by the local Ollama binary. +func (c *localProcess) validateExecOptions(options container.ExecOptions) error { + var errs []error + if options.User != "" { + errs = append(errs, fmt.Errorf("user: %w", errors.ErrUnsupported)) + } + if options.Privileged { + errs = append(errs, fmt.Errorf("privileged: %w", errors.ErrUnsupported)) + } + if options.Tty { + errs = append(errs, fmt.Errorf("tty: %w", errors.ErrUnsupported)) + } + if options.Detach { + errs = append(errs, fmt.Errorf("detach: %w", errors.ErrUnsupported)) + } + if options.DetachKeys != "" { + errs = append(errs, fmt.Errorf("detach keys: %w", errors.ErrUnsupported)) + } + + return errors.Join(errs...) +} + +// Inspect implements testcontainers.Container interface for the local Ollama binary. +// It returns a ContainerJSON with the state of the local Ollama binary. +func (c *localProcess) Inspect(ctx context.Context) (*types.ContainerJSON, error) { + state, err := c.State(ctx) + if err != nil { + return nil, fmt.Errorf("state: %w", err) + } + + return &types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + ID: c.GetContainerID(), + Name: localNamePrefix + "-" + c.sessionID, + State: state, + }, + Config: &container.Config{ + Image: localNamePrefix + ":" + c.version, + ExposedPorts: nat.PortSet{ + nat.Port(localPort + "/tcp"): struct{}{}, + }, + Hostname: c.host, + Entrypoint: []string{c.binary, localServeArg}, + }, + NetworkSettings: &types.NetworkSettings{ + Networks: map[string]*network.EndpointSettings{}, + NetworkSettingsBase: types.NetworkSettingsBase{ + Bridge: "bridge", + Ports: nat.PortMap{ + nat.Port(localPort + "/tcp"): { + {HostIP: c.host, HostPort: c.port}, + }, + }, + }, + DefaultNetworkSettings: types.DefaultNetworkSettings{ + IPAddress: c.host, + }, + }, + }, nil +} + +// IsRunning implements testcontainers.Container interface for the local Ollama binary. +// It returns true if the local Ollama process is running, false otherwise. +func (c *localProcess) IsRunning() bool { + if c.startedAt.IsZero() { + // The process hasn't started yet. + return false + } + + select { + case <-c.done: + // The process exited. + return false + default: + // The process is still running. + return true + } +} + +// Logs implements testcontainers.Container interface for the local Ollama binary. +// It returns the logs from the local Ollama binary. +func (c *localProcess) Logs(ctx context.Context) (io.ReadCloser, error) { + file, err := os.Open(c.logFile.Name()) + if err != nil { + return nil, fmt.Errorf("open log file: %w", err) + } + + return file, nil +} + +// State implements testcontainers.Container interface for the local Ollama binary. +// It returns the current state of the Ollama process, simulating a container state. +func (c *localProcess) State(ctx context.Context) (*types.ContainerState, error) { + c.mtx.Lock() + defer c.mtx.Unlock() + + if !c.IsRunning() { + state := &types.ContainerState{ + Status: "exited", + ExitCode: c.cmd.ProcessState.ExitCode(), + StartedAt: c.startedAt.Format(time.RFC3339Nano), + FinishedAt: c.finishedAt.Format(time.RFC3339Nano), + } + if c.exitErr != nil { + state.Error = c.exitErr.Error() + } + + return state, nil + } + + // Setting the Running field because it's required by the wait strategy + // to check if the given log message is present. + return &types.ContainerState{ + Status: "running", + Running: true, + Pid: c.cmd.Process.Pid, + StartedAt: c.startedAt.Format(time.RFC3339Nano), + FinishedAt: c.finishedAt.Format(time.RFC3339Nano), + }, nil +} + +// Stop implements testcontainers.Container interface for the local Ollama binary. +// It gracefully stops the local Ollama process. +func (c *localProcess) Stop(ctx context.Context, d *time.Duration) error { + if err := c.cmd.Process.Signal(syscall.SIGTERM); err != nil { + return fmt.Errorf("signal ollama: %w", err) + } + + if d != nil { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, *d) + defer cancel() + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.done: + // The process exited. + c.mtx.Lock() + defer c.mtx.Unlock() + + return c.exitErr + } +} + +// Terminate implements testcontainers.Container interface for the local Ollama binary. +// It stops the local Ollama process, removing the log file. +func (c *localProcess) Terminate(ctx context.Context, opts ...testcontainers.TerminateOption) error { + options := testcontainers.NewTerminateOptions(ctx, opts...) + // First try to stop gracefully. + if err := c.Stop(options.Context(), options.StopTimeout()); !c.isCleanupSafe(err) { + return fmt.Errorf("stop: %w", err) + } + + var errs []error + if c.IsRunning() { + // Still running, force kill. + if err := c.cmd.Process.Kill(); !c.isCleanupSafe(err) { + // Best effort so we can continue with the cleanup. + errs = append(errs, fmt.Errorf("kill: %w", err)) + } + + // Wait for the process to exit so we can capture any error. + c.wg.Wait() + } + + errs = append(errs, c.cleanup(), options.Cleanup()) + + return errors.Join(errs...) +} + +// cleanup performs all clean up, closing and removing the log file if set. +func (c *localProcess) cleanup() error { + c.mtx.Lock() + defer c.mtx.Unlock() + + if c.logFile == nil { + return c.exitErr + } + + var errs []error + if c.exitErr != nil { + errs = append(errs, fmt.Errorf("exit: %w", c.exitErr)) + } + + if err := c.logFile.Close(); err != nil { + errs = append(errs, fmt.Errorf("close log: %w", err)) + } + + if err := os.Remove(c.logFile.Name()); err != nil && !errors.Is(err, fs.ErrNotExist) { + errs = append(errs, fmt.Errorf("remove log: %w", err)) + } + + c.logFile = nil // Prevent double cleanup. + + return errors.Join(errs...) +} + +// Endpoint implements testcontainers.Container interface for the local Ollama binary. +// It returns proto://host:port string for the Ollama port. +// It returns just host:port if proto is blank. +func (c *localProcess) Endpoint(ctx context.Context, proto string) (string, error) { + return c.PortEndpoint(ctx, localPort, proto) +} + +// GetContainerID implements testcontainers.Container interface for the local Ollama binary. +func (c *localProcess) GetContainerID() string { + return localNamePrefix + "-" + c.sessionID +} + +// Host implements testcontainers.Container interface for the local Ollama binary. +func (c *localProcess) Host(ctx context.Context) (string, error) { + return c.host, nil +} + +// MappedPort implements testcontainers.Container interface for the local Ollama binary. +func (c *localProcess) MappedPort(ctx context.Context, port nat.Port) (nat.Port, error) { + if port.Port() != localPort || port.Proto() != "tcp" { + return "", errdefs.NotFound(fmt.Errorf("port %q not found", port)) + } + + return nat.Port(c.port + "/tcp"), nil +} + +// Networks implements testcontainers.Container interface for the local Ollama binary. +// It returns a nil slice. +func (c *localProcess) Networks(ctx context.Context) ([]string, error) { + return nil, nil +} + +// NetworkAliases implements testcontainers.Container interface for the local Ollama binary. +// It returns a nil map. +func (c *localProcess) NetworkAliases(ctx context.Context) (map[string][]string, error) { + return nil, nil +} + +// PortEndpoint implements testcontainers.Container interface for the local Ollama binary. +// It returns proto://host:port string for the given exposed port. +// It returns just host:port if proto is blank. +func (c *localProcess) PortEndpoint(ctx context.Context, port nat.Port, proto string) (string, error) { + host, err := c.Host(ctx) + if err != nil { + return "", fmt.Errorf("host: %w", err) + } + + outerPort, err := c.MappedPort(ctx, port) + if err != nil { + return "", fmt.Errorf("mapped port: %w", err) + } + + if proto != "" { + proto += "://" + } + + return fmt.Sprintf("%s%s:%s", proto, host, outerPort.Port()), nil +} + +// SessionID implements testcontainers.Container interface for the local Ollama binary. +func (c *localProcess) SessionID() string { + return c.sessionID +} + +// Deprecated: it will be removed in the next major release. +// FollowOutput is not implemented for the local Ollama binary. +// It panics if called. +func (c *localProcess) FollowOutput(consumer testcontainers.LogConsumer) { + panic("not implemented") +} + +// Deprecated: use c.Inspect(ctx).NetworkSettings.Ports instead. +// Ports gets the exposed ports for the container. +func (c *localProcess) Ports(ctx context.Context) (nat.PortMap, error) { + inspect, err := c.Inspect(ctx) + if err != nil { + return nil, err + } + + return inspect.NetworkSettings.Ports, nil +} + +// Deprecated: it will be removed in the next major release. +// StartLogProducer implements testcontainers.Container interface for the local Ollama binary. +// It returns an error because the local Ollama binary doesn't have a log producer. +func (c *localProcess) StartLogProducer(context.Context, ...testcontainers.LogProductionOption) error { + return errors.ErrUnsupported +} + +// Deprecated: it will be removed in the next major release. +// StopLogProducer implements testcontainers.Container interface for the local Ollama binary. +// It returns an error because the local Ollama binary doesn't have a log producer. +func (c *localProcess) StopLogProducer() error { + return errors.ErrUnsupported +} + +// Deprecated: Use c.Inspect(ctx).Name instead. +// Name returns the name for the local Ollama binary. +func (c *localProcess) Name(context.Context) (string, error) { + return localNamePrefix + "-" + c.sessionID, nil +} + +// Customize implements the [testcontainers.ContainerCustomizer] interface. +// It configures the environment variables set by [WithUseLocal] and sets up +// the wait strategy to extract the host, port and version from the log. +func (c *localProcess) Customize(req *testcontainers.GenericContainerRequest) error { + // Replace the default host port strategy with one that waits for a log entry + // and extracts the host, port and version from it. + if err := wait.Walk(&req.WaitingFor, func(w wait.Strategy) error { + if _, ok := w.(*wait.HostPortStrategy); ok { + return wait.VisitRemove + } + + return nil + }); err != nil { + return fmt.Errorf("walk strategies: %w", err) + } + + logStrategy := wait.ForLog(localLogRegex).Submatch(c.extractLogDetails) + if req.WaitingFor == nil { + req.WaitingFor = logStrategy + } else { + req.WaitingFor = wait.ForAll(req.WaitingFor, logStrategy) + } + + // Setup the environment variables using a random port by default + // to avoid conflicts. + osEnv := os.Environ() + env := make(map[string]string, len(osEnv)+len(c.env)+1) + env[localHostVar] = "localhost:0" + for _, kv := range append(osEnv, c.env...) { + parts := strings.SplitN(kv, "=", 2) + if len(parts) != 2 { + return fmt.Errorf("invalid environment variable: %q", kv) + } + + env[parts[0]] = parts[1] + } + + return testcontainers.WithEnv(env)(req) +} + +// isCleanupSafe reports whether all errors in err's tree are one of the +// following, so can safely be ignored: +// - nil +// - os: process already finished +// - context deadline exceeded +func (c *localProcess) isCleanupSafe(err error) bool { + switch { + case err == nil, + errors.Is(err, os.ErrProcessDone), + errors.Is(err, context.DeadlineExceeded): + return true + default: + return false + } +} diff --git a/modules/ollama/local_test.go b/modules/ollama/local_test.go new file mode 100644 index 0000000000..3e0376d4de --- /dev/null +++ b/modules/ollama/local_test.go @@ -0,0 +1,636 @@ +package ollama_test + +import ( + "context" + "errors" + "io" + "io/fs" + "os" + "os/exec" + "path/filepath" + "regexp" + "testing" + "time" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/strslice" + "github.com/stretchr/testify/require" + + "github.com/testcontainers/testcontainers-go" + tcexec "github.com/testcontainers/testcontainers-go/exec" + "github.com/testcontainers/testcontainers-go/modules/ollama" +) + +const ( + testImage = "ollama/ollama:latest" + testNatPort = "11434/tcp" + testHost = "127.0.0.1" + testBinary = "ollama" +) + +var ( + // reLogDetails matches the log details of the local ollama binary and should match localLogRegex. + reLogDetails = regexp.MustCompile(`Listening on (.*:\d+) \(version\s(.*)\)`) + zeroTime = time.Time{}.Format(time.RFC3339Nano) +) + +func TestRun_local(t *testing.T) { + // check if the local ollama binary is available + if _, err := exec.LookPath(testBinary); err != nil { + t.Skip("local ollama binary not found, skipping") + } + + ctx := context.Background() + ollamaContainer, err := ollama.Run( + ctx, + testImage, + ollama.WithUseLocal("FOO=BAR"), + ) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + + t.Run("state", func(t *testing.T) { + state, err := ollamaContainer.State(ctx) + require.NoError(t, err) + require.NotEmpty(t, state.StartedAt) + require.NotEqual(t, zeroTime, state.StartedAt) + require.NotZero(t, state.Pid) + require.Equal(t, &types.ContainerState{ + Status: "running", + Running: true, + Pid: state.Pid, + StartedAt: state.StartedAt, + FinishedAt: time.Time{}.Format(time.RFC3339Nano), + }, state) + }) + + t.Run("connection-string", func(t *testing.T) { + connectionStr, err := ollamaContainer.ConnectionString(ctx) + require.NoError(t, err) + require.NotEmpty(t, connectionStr) + }) + + t.Run("container-id", func(t *testing.T) { + id := ollamaContainer.GetContainerID() + require.Equal(t, "local-ollama-"+testcontainers.SessionID(), id) + }) + + t.Run("container-ips", func(t *testing.T) { + ip, err := ollamaContainer.ContainerIP(ctx) + require.NoError(t, err) + require.Equal(t, testHost, ip) + + ips, err := ollamaContainer.ContainerIPs(ctx) + require.NoError(t, err) + require.Equal(t, []string{testHost}, ips) + }) + + t.Run("copy", func(t *testing.T) { + err := ollamaContainer.CopyToContainer(ctx, []byte("test"), "/tmp", 0o755) + require.Error(t, err) + + err = ollamaContainer.CopyDirToContainer(ctx, ".", "/tmp", 0o755) + require.Error(t, err) + + err = ollamaContainer.CopyFileToContainer(ctx, ".", "/tmp", 0o755) + require.Error(t, err) + + reader, err := ollamaContainer.CopyFileFromContainer(ctx, "/tmp") + require.Error(t, err) + require.Nil(t, reader) + }) + + t.Run("log-production-error-channel", func(t *testing.T) { + ch := ollamaContainer.GetLogProductionErrorChannel() + require.Nil(t, ch) + }) + + t.Run("endpoint", func(t *testing.T) { + endpoint, err := ollamaContainer.Endpoint(ctx, "") + require.NoError(t, err) + require.Contains(t, endpoint, testHost+":") + + endpoint, err = ollamaContainer.Endpoint(ctx, "http") + require.NoError(t, err) + require.Contains(t, endpoint, "http://"+testHost+":") + }) + + t.Run("is-running", func(t *testing.T) { + require.True(t, ollamaContainer.IsRunning()) + + err = ollamaContainer.Stop(ctx, nil) + require.NoError(t, err) + require.False(t, ollamaContainer.IsRunning()) + + // return it to the running state + err = ollamaContainer.Start(ctx) + require.NoError(t, err) + require.True(t, ollamaContainer.IsRunning()) + }) + + t.Run("host", func(t *testing.T) { + host, err := ollamaContainer.Host(ctx) + require.NoError(t, err) + require.Equal(t, testHost, host) + }) + + t.Run("inspect", func(t *testing.T) { + inspect, err := ollamaContainer.Inspect(ctx) + require.NoError(t, err) + + require.Equal(t, "local-ollama-"+testcontainers.SessionID(), inspect.ContainerJSONBase.ID) + require.Equal(t, "local-ollama-"+testcontainers.SessionID(), inspect.ContainerJSONBase.Name) + require.True(t, inspect.ContainerJSONBase.State.Running) + + require.NotEmpty(t, inspect.Config.Image) + _, exists := inspect.Config.ExposedPorts[testNatPort] + require.True(t, exists) + require.Equal(t, testHost, inspect.Config.Hostname) + require.Equal(t, strslice.StrSlice(strslice.StrSlice{testBinary, "serve"}), inspect.Config.Entrypoint) + + require.Empty(t, inspect.NetworkSettings.Networks) + require.Equal(t, "bridge", inspect.NetworkSettings.NetworkSettingsBase.Bridge) + + ports := inspect.NetworkSettings.NetworkSettingsBase.Ports + port, exists := ports[testNatPort] + require.True(t, exists) + require.Len(t, port, 1) + require.Equal(t, testHost, port[0].HostIP) + require.NotEmpty(t, port[0].HostPort) + }) + + t.Run("logfile", func(t *testing.T) { + file, err := os.Open("local-ollama-" + testcontainers.SessionID() + ".log") + require.NoError(t, err) + require.NoError(t, file.Close()) + }) + + t.Run("logs", func(t *testing.T) { + logs, err := ollamaContainer.Logs(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, logs.Close()) + }) + + bs, err := io.ReadAll(logs) + require.NoError(t, err) + require.Regexp(t, reLogDetails, string(bs)) + }) + + t.Run("mapped-port", func(t *testing.T) { + port, err := ollamaContainer.MappedPort(ctx, testNatPort) + require.NoError(t, err) + require.NotEmpty(t, port.Port()) + require.Equal(t, "tcp", port.Proto()) + }) + + t.Run("networks", func(t *testing.T) { + networks, err := ollamaContainer.Networks(ctx) + require.NoError(t, err) + require.Nil(t, networks) + }) + + t.Run("network-aliases", func(t *testing.T) { + aliases, err := ollamaContainer.NetworkAliases(ctx) + require.NoError(t, err) + require.Nil(t, aliases) + }) + + t.Run("port-endpoint", func(t *testing.T) { + endpoint, err := ollamaContainer.PortEndpoint(ctx, testNatPort, "") + require.NoError(t, err) + require.Regexp(t, regexp.MustCompile(`^127.0.0.1:\d+$`), endpoint) + + endpoint, err = ollamaContainer.PortEndpoint(ctx, testNatPort, "http") + require.NoError(t, err) + require.Regexp(t, regexp.MustCompile(`^http://127.0.0.1:\d+$`), endpoint) + }) + + t.Run("session-id", func(t *testing.T) { + require.Equal(t, testcontainers.SessionID(), ollamaContainer.SessionID()) + }) + + t.Run("stop-start", func(t *testing.T) { + d := time.Second * 5 + err := ollamaContainer.Stop(ctx, &d) + require.NoError(t, err) + + state, err := ollamaContainer.State(ctx) + require.NoError(t, err) + require.Equal(t, "exited", state.Status) + require.NotEmpty(t, state.StartedAt) + require.NotEqual(t, zeroTime, state.StartedAt) + require.NotEmpty(t, state.FinishedAt) + require.NotEqual(t, zeroTime, state.FinishedAt) + require.Zero(t, state.ExitCode) + + err = ollamaContainer.Start(ctx) + require.NoError(t, err) + + state, err = ollamaContainer.State(ctx) + require.NoError(t, err) + require.Equal(t, "running", state.Status) + + logs, err := ollamaContainer.Logs(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, logs.Close()) + }) + + bs, err := io.ReadAll(logs) + require.NoError(t, err) + require.Regexp(t, reLogDetails, string(bs)) + }) + + t.Run("start-start", func(t *testing.T) { + state, err := ollamaContainer.State(ctx) + require.NoError(t, err) + require.Equal(t, "running", state.Status) + + err = ollamaContainer.Start(ctx) + require.Error(t, err) + }) + + t.Run("terminate", func(t *testing.T) { + err := ollamaContainer.Terminate(ctx) + require.NoError(t, err) + + _, err = os.Stat("ollama-" + testcontainers.SessionID() + ".log") + require.ErrorIs(t, err, fs.ErrNotExist) + + state, err := ollamaContainer.State(ctx) + require.NoError(t, err) + require.NotEmpty(t, state.StartedAt) + require.NotEqual(t, zeroTime, state.StartedAt) + require.NotEmpty(t, state.FinishedAt) + require.NotEqual(t, zeroTime, state.FinishedAt) + require.Equal(t, &types.ContainerState{ + Status: "exited", + StartedAt: state.StartedAt, + FinishedAt: state.FinishedAt, + }, state) + }) + + t.Run("deprecated", func(t *testing.T) { + t.Run("ports", func(t *testing.T) { + inspect, err := ollamaContainer.Inspect(ctx) + require.NoError(t, err) + + ports, err := ollamaContainer.Ports(ctx) + require.NoError(t, err) + require.Equal(t, inspect.NetworkSettings.Ports, ports) + }) + + t.Run("follow-output", func(t *testing.T) { + require.Panics(t, func() { + ollamaContainer.FollowOutput(&testcontainers.StdoutLogConsumer{}) + }) + }) + + t.Run("start-log-producer", func(t *testing.T) { + err := ollamaContainer.StartLogProducer(ctx) + require.ErrorIs(t, err, errors.ErrUnsupported) + }) + + t.Run("stop-log-producer", func(t *testing.T) { + err := ollamaContainer.StopLogProducer() + require.ErrorIs(t, err, errors.ErrUnsupported) + }) + + t.Run("name", func(t *testing.T) { + name, err := ollamaContainer.Name(ctx) + require.NoError(t, err) + require.Equal(t, "local-ollama-"+testcontainers.SessionID(), name) + }) + }) +} + +func TestRun_localWithCustomLogFile(t *testing.T) { + ctx := context.Background() + logFile := filepath.Join(t.TempDir(), "server.log") + + t.Run("parent-env", func(t *testing.T) { + t.Setenv("OLLAMA_LOGFILE", logFile) + + ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal()) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + + logs, err := ollamaContainer.Logs(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, logs.Close()) + }) + + bs, err := io.ReadAll(logs) + require.NoError(t, err) + require.Regexp(t, reLogDetails, string(bs)) + + file, ok := logs.(*os.File) + require.True(t, ok) + require.Equal(t, logFile, file.Name()) + }) + + t.Run("local-env", func(t *testing.T) { + ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal("OLLAMA_LOGFILE="+logFile)) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + + logs, err := ollamaContainer.Logs(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, logs.Close()) + }) + + bs, err := io.ReadAll(logs) + require.NoError(t, err) + require.Regexp(t, reLogDetails, string(bs)) + + file, ok := logs.(*os.File) + require.True(t, ok) + require.Equal(t, logFile, file.Name()) + }) +} + +func TestRun_localWithCustomHost(t *testing.T) { + ctx := context.Background() + + t.Run("parent-env", func(t *testing.T) { + t.Setenv("OLLAMA_HOST", "127.0.0.1:1234") + + ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal()) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + + testRun_localWithCustomHost(ctx, t, ollamaContainer) + }) + + t.Run("local-env", func(t *testing.T) { + ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal("OLLAMA_HOST=127.0.0.1:1234")) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + + testRun_localWithCustomHost(ctx, t, ollamaContainer) + }) +} + +func testRun_localWithCustomHost(ctx context.Context, t *testing.T, ollamaContainer *ollama.OllamaContainer) { + t.Helper() + + t.Run("connection-string", func(t *testing.T) { + connectionStr, err := ollamaContainer.ConnectionString(ctx) + require.NoError(t, err) + require.Equal(t, "http://127.0.0.1:1234", connectionStr) + }) + + t.Run("endpoint", func(t *testing.T) { + endpoint, err := ollamaContainer.Endpoint(ctx, "http") + require.NoError(t, err) + require.Equal(t, "http://127.0.0.1:1234", endpoint) + }) + + t.Run("inspect", func(t *testing.T) { + inspect, err := ollamaContainer.Inspect(ctx) + require.NoError(t, err) + require.Regexp(t, regexp.MustCompile(`^local-ollama:\d+\.\d+\.\d+$`), inspect.Config.Image) + + _, exists := inspect.Config.ExposedPorts[testNatPort] + require.True(t, exists) + require.Equal(t, testHost, inspect.Config.Hostname) + require.Equal(t, strslice.StrSlice(strslice.StrSlice{testBinary, "serve"}), inspect.Config.Entrypoint) + + require.Empty(t, inspect.NetworkSettings.Networks) + require.Equal(t, "bridge", inspect.NetworkSettings.NetworkSettingsBase.Bridge) + + ports := inspect.NetworkSettings.NetworkSettingsBase.Ports + port, exists := ports[testNatPort] + require.True(t, exists) + require.Len(t, port, 1) + require.Equal(t, testHost, port[0].HostIP) + require.Equal(t, "1234", port[0].HostPort) + }) + + t.Run("logs", func(t *testing.T) { + logs, err := ollamaContainer.Logs(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, logs.Close()) + }) + + bs, err := io.ReadAll(logs) + require.NoError(t, err) + + require.Contains(t, string(bs), "Listening on 127.0.0.1:1234") + }) + + t.Run("mapped-port", func(t *testing.T) { + port, err := ollamaContainer.MappedPort(ctx, testNatPort) + require.NoError(t, err) + require.Equal(t, "1234", port.Port()) + require.Equal(t, "tcp", port.Proto()) + }) +} + +func TestRun_localExec(t *testing.T) { + // check if the local ollama binary is available + if _, err := exec.LookPath(testBinary); err != nil { + t.Skip("local ollama binary not found, skipping") + } + + ctx := context.Background() + + ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal()) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + + t.Run("no-command", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, nil) + require.Error(t, err) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("unsupported-command", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{"cat", "/etc/hosts"}) + require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("unsupported-option-user", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.WithUser("root")) + require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("unsupported-option-privileged", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.ProcessOptionFunc(func(opts *tcexec.ProcessOptions) { + opts.ExecConfig.Privileged = true + })) + require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("unsupported-option-tty", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.ProcessOptionFunc(func(opts *tcexec.ProcessOptions) { + opts.ExecConfig.Tty = true + })) + require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("unsupported-option-detach", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.ProcessOptionFunc(func(opts *tcexec.ProcessOptions) { + opts.ExecConfig.Detach = true + })) + require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("unsupported-option-detach-keys", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.ProcessOptionFunc(func(opts *tcexec.ProcessOptions) { + opts.ExecConfig.DetachKeys = "ctrl-p,ctrl-q" + })) + require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("pull-and-run-model", func(t *testing.T) { + const model = "llama3.2:1b" + + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "pull", model}) + require.NoError(t, err) + require.Zero(t, code) + + bs, err := io.ReadAll(r) + require.NoError(t, err) + require.Contains(t, string(bs), "success") + + code, r, err = ollamaContainer.Exec(ctx, []string{testBinary, "run", model}, tcexec.Multiplexed()) + require.NoError(t, err) + require.Zero(t, code) + + bs, err = io.ReadAll(r) + require.NoError(t, err) + require.Empty(t, bs) + + logs, err := ollamaContainer.Logs(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, logs.Close()) + }) + + bs, err = io.ReadAll(logs) + require.NoError(t, err) + require.Contains(t, string(bs), "llama runner started") + }) +} + +func TestRun_localValidateRequest(t *testing.T) { + // check if the local ollama binary is available + if _, err := exec.LookPath(testBinary); err != nil { + t.Skip("local ollama binary not found, skipping") + } + + ctx := context.Background() + t.Run("waiting-for-nil", func(t *testing.T) { + ollamaContainer, err := ollama.Run( + ctx, + testImage, + ollama.WithUseLocal("FOO=BAR"), + testcontainers.CustomizeRequestOption(func(req *testcontainers.GenericContainerRequest) error { + req.WaitingFor = nil + return nil + }), + ) + testcontainers.CleanupContainer(t, ollamaContainer) + require.EqualError(t, err, "validate request: ContainerRequest.WaitingFor must be set") + }) + + t.Run("started-false", func(t *testing.T) { + ollamaContainer, err := ollama.Run( + ctx, + testImage, + ollama.WithUseLocal("FOO=BAR"), + testcontainers.CustomizeRequestOption(func(req *testcontainers.GenericContainerRequest) error { + req.Started = false + return nil + }), + ) + testcontainers.CleanupContainer(t, ollamaContainer) + require.EqualError(t, err, "validate request: Started must be true") + }) + + t.Run("exposed-ports-empty", func(t *testing.T) { + ollamaContainer, err := ollama.Run( + ctx, + testImage, + ollama.WithUseLocal("FOO=BAR"), + testcontainers.CustomizeRequestOption(func(req *testcontainers.GenericContainerRequest) error { + req.ExposedPorts = req.ExposedPorts[:0] + return nil + }), + ) + testcontainers.CleanupContainer(t, ollamaContainer) + require.EqualError(t, err, "validate request: ContainerRequest.ExposedPorts must be 11434/tcp got: []") + }) + + t.Run("dockerfile-set", func(t *testing.T) { + ollamaContainer, err := ollama.Run( + ctx, + testImage, + ollama.WithUseLocal("FOO=BAR"), + testcontainers.CustomizeRequestOption(func(req *testcontainers.GenericContainerRequest) error { + req.Dockerfile = "FROM scratch" + return nil + }), + ) + testcontainers.CleanupContainer(t, ollamaContainer) + require.EqualError(t, err, "validate request: unsupported field: ContainerRequest.FromDockerfile.Dockerfile = \"FROM scratch\"") + }) + + t.Run("image-only", func(t *testing.T) { + ollamaContainer, err := ollama.Run( + ctx, + testBinary, + ollama.WithUseLocal(), + ) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + }) + + t.Run("image-path", func(t *testing.T) { + ollamaContainer, err := ollama.Run( + ctx, + "prefix-path/"+testBinary, + ollama.WithUseLocal(), + ) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + }) + + t.Run("image-bad-version", func(t *testing.T) { + ollamaContainer, err := ollama.Run( + ctx, + testBinary+":bad-version", + ollama.WithUseLocal(), + ) + testcontainers.CleanupContainer(t, ollamaContainer) + require.EqualError(t, err, `validate request: ContainerRequest.Image version must be blank or "latest", got: "bad-version"`) + }) + + t.Run("image-not-found", func(t *testing.T) { + ollamaContainer, err := ollama.Run( + ctx, + "ollama/ollama-not-found", + ollama.WithUseLocal(), + ) + testcontainers.CleanupContainer(t, ollamaContainer) + require.EqualError(t, err, `validate request: invalid image "ollama/ollama-not-found": exec: "ollama-not-found": executable file not found in $PATH`) + }) +} diff --git a/modules/ollama/ollama.go b/modules/ollama/ollama.go index 203d80103f..4d78fa171e 100644 --- a/modules/ollama/ollama.go +++ b/modules/ollama/ollama.go @@ -27,12 +27,12 @@ type OllamaContainer struct { func (c *OllamaContainer) ConnectionString(ctx context.Context) (string, error) { host, err := c.Host(ctx) if err != nil { - return "", err + return "", fmt.Errorf("host: %w", err) } port, err := c.MappedPort(ctx, "11434/tcp") if err != nil { - return "", err + return "", fmt.Errorf("mapped port: %w", err) } return fmt.Sprintf("http://%s:%d", host, port.Int()), nil @@ -43,6 +43,10 @@ func (c *OllamaContainer) ConnectionString(ctx context.Context) (string, error) // of the container into a new image with the given name, so it doesn't override existing images. // It should be used for creating an image that contains a loaded model. func (c *OllamaContainer) Commit(ctx context.Context, targetImage string) error { + if _, ok := c.Container.(*localProcess); ok { + return nil + } + cli, err := testcontainers.NewDockerClientWithOpts(context.Background()) if err != nil { return err @@ -80,27 +84,34 @@ func RunContainer(ctx context.Context, opts ...testcontainers.ContainerCustomize // Run creates an instance of the Ollama container type func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustomizer) (*OllamaContainer, error) { - req := testcontainers.ContainerRequest{ - Image: img, - ExposedPorts: []string{"11434/tcp"}, - WaitingFor: wait.ForListeningPort("11434/tcp").WithStartupTimeout(60 * time.Second), - } - - genericContainerReq := testcontainers.GenericContainerRequest{ - ContainerRequest: req, - Started: true, + req := testcontainers.GenericContainerRequest{ + ContainerRequest: testcontainers.ContainerRequest{ + Image: img, + ExposedPorts: []string{"11434/tcp"}, + WaitingFor: wait.ForListeningPort("11434/tcp").WithStartupTimeout(60 * time.Second), + }, + Started: true, } - // always request a GPU if the host supports it + // Always request a GPU if the host supports it. opts = append(opts, withGpu()) + var local *localProcess for _, opt := range opts { - if err := opt.Customize(&genericContainerReq); err != nil { + if err := opt.Customize(&req); err != nil { return nil, fmt.Errorf("customize: %w", err) } + if l, ok := opt.(*localProcess); ok { + local = l + } + } + + // Now we have processed all the options, we can check if we need to use the local process. + if local != nil { + return local.run(ctx, req) } - container, err := testcontainers.GenericContainer(ctx, genericContainerReq) + container, err := testcontainers.GenericContainer(ctx, req) var c *OllamaContainer if container != nil { c = &OllamaContainer{Container: container} diff --git a/modules/ollama/options.go b/modules/ollama/options.go index 605768a379..1cf29453fe 100644 --- a/modules/ollama/options.go +++ b/modules/ollama/options.go @@ -11,7 +11,7 @@ import ( var noopCustomizeRequestOption = func(req *testcontainers.GenericContainerRequest) error { return nil } // withGpu requests a GPU for the container, which could improve performance for some models. -// This option will be automaticall added to the Ollama container to check if the host supports nvidia. +// This option will be automatically added to the Ollama container to check if the host supports nvidia. func withGpu() testcontainers.CustomizeRequestOption { cli, err := testcontainers.NewDockerClientWithOpts(context.Background()) if err != nil { @@ -37,3 +37,30 @@ func withGpu() testcontainers.CustomizeRequestOption { } }) } + +// WithUseLocal starts a local Ollama process with the given environment in +// format KEY=VALUE instead of a Docker container, which can be more performant +// as it has direct access to the GPU. +// By default `OLLAMA_HOST=localhost:0` is set to avoid port conflicts. +// +// When using this option, the container request will be validated to ensure +// that only the options that are compatible with the local process are used. +// +// Supported fields are: +// - [testcontainers.GenericContainerRequest.Started] must be set to true +// - [testcontainers.GenericContainerRequest.ExposedPorts] must be set to ["11434/tcp"] +// - [testcontainers.ContainerRequest.WaitingFor] should not be changed from the default +// - [testcontainers.ContainerRequest.Image] used to determine the local process binary [/][:latest] if not blank. +// - [testcontainers.ContainerRequest.Env] applied to all local process executions +// - [testcontainers.GenericContainerRequest.Logger] is unused +// +// Any other leaf field not set to the type's zero value will result in an error. +func WithUseLocal(envKeyValues ...string) *localProcess { + sessionID := testcontainers.SessionID() + return &localProcess{ + sessionID: sessionID, + logName: localNamePrefix + "-" + sessionID + ".log", + env: envKeyValues, + binary: localBinary, + } +} diff --git a/modules/ollama/options_test.go b/modules/ollama/options_test.go new file mode 100644 index 0000000000..f842d15a17 --- /dev/null +++ b/modules/ollama/options_test.go @@ -0,0 +1,49 @@ +package ollama_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/ollama" +) + +func TestWithUseLocal(t *testing.T) { + req := testcontainers.GenericContainerRequest{} + + t.Run("keyVal/valid", func(t *testing.T) { + opt := ollama.WithUseLocal("OLLAMA_MODELS=/path/to/models") + err := opt.Customize(&req) + require.NoError(t, err) + require.Equal(t, "/path/to/models", req.Env["OLLAMA_MODELS"]) + }) + + t.Run("keyVal/invalid", func(t *testing.T) { + opt := ollama.WithUseLocal("OLLAMA_MODELS") + err := opt.Customize(&req) + require.Error(t, err) + }) + + t.Run("keyVal/valid/multiple", func(t *testing.T) { + opt := ollama.WithUseLocal("OLLAMA_MODELS=/path/to/models", "OLLAMA_HOST=localhost") + err := opt.Customize(&req) + require.NoError(t, err) + require.Equal(t, "/path/to/models", req.Env["OLLAMA_MODELS"]) + require.Equal(t, "localhost", req.Env["OLLAMA_HOST"]) + }) + + t.Run("keyVal/valid/multiple-equals", func(t *testing.T) { + opt := ollama.WithUseLocal("OLLAMA_MODELS=/path/to/models", "OLLAMA_HOST=localhost=127.0.0.1") + err := opt.Customize(&req) + require.NoError(t, err) + require.Equal(t, "/path/to/models", req.Env["OLLAMA_MODELS"]) + require.Equal(t, "localhost=127.0.0.1", req.Env["OLLAMA_HOST"]) + }) + + t.Run("keyVal/invalid/multiple", func(t *testing.T) { + opt := ollama.WithUseLocal("OLLAMA_MODELS=/path/to/models", "OLLAMA_HOST") + err := opt.Customize(&req) + require.Error(t, err) + }) +} From 3330dc1d098e155a1c14f061b8bb5875d8cc35e6 Mon Sep 17 00:00:00 2001 From: Barrett Strausser Date: Tue, 7 Jan 2025 07:16:03 -0500 Subject: [PATCH 11/11] feat(postgres): ssl for postgres (#2473) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * SSL for postgres * Add entrypoint wrapper * Add in init so we can test ssl+init path * Remove unused fields from options * Remove unused consts * Separate entrypoint from ssl * Use external cert generation * Make entrypoint not-optional * Add docstring * Spaces to tab in entrypoint * Add postgres ssl docs * Remove WithEntrypoint * Update docs/modules/postgres.md Co-authored-by: Manuel de la Peña * Update docs/modules/postgres.md Co-authored-by: Manuel de la Peña * Update docs/modules/postgres.md Co-authored-by: Manuel de la Peña * Update modules/postgres/postgres_test.go Co-authored-by: Manuel de la Peña * Update modules/postgres/postgres_test.go Co-authored-by: Manuel de la Peña * Embed resources + Use custom conf automatically * Update docs/modules/postgres.md Co-authored-by: Manuel de la Peña * Update docs/modules/postgres.md Co-authored-by: Manuel de la Peña * Update docs/modules/postgres.md Co-authored-by: Manuel de la Peña * Update modules/postgres/postgres_test.go Co-authored-by: Manuel de la Peña * Update modules/postgres/postgres_test.go Co-authored-by: Manuel de la Peña * Update modules/postgres/postgres_test.go Co-authored-by: Manuel de la Peña * Update modules/postgres/postgres_test.go Co-authored-by: Manuel de la Peña * Revert to use passed in conf * Update doc for required conf * Error checking in the customizer * Few formatting fix * Use non-nil error when err is nil * Update modules/postgres/postgres_test.go Co-authored-by: Steven Hartland * Update modules/postgres/postgres_test.go Co-authored-by: Steven Hartland * Update modules/postgres/postgres.go Co-authored-by: Steven Hartland * Update modules/postgres/postgres.go Co-authored-by: Steven Hartland * Update modules/postgres/postgres_test.go Co-authored-by: Steven Hartland * Addresses review modulo cleanup * Remove unused type * Use ContainerCleanup * Lint pass * Add t.Helper and Linting * Remove SSLSetting struct, use raw paths * Use single command for chown key material * docs: remove spaces * fix: use non-deprecated APIs * chore: rename variable --------- Co-authored-by: bstrausser Co-authored-by: Manuel de la Peña Co-authored-by: Steven Hartland Co-authored-by: Manuel de la Peña --- docs/modules/postgres.md | 26 ++++++ modules/postgres/go.mod | 1 + modules/postgres/go.sum | 2 + modules/postgres/postgres.go | 41 +++++++++ modules/postgres/postgres_test.go | 88 +++++++++++++++++++ .../postgres/resources/customEntrypoint.sh | 25 ++++++ modules/postgres/testdata/postgres-ssl.conf | 80 +++++++++++++++++ 7 files changed, 263 insertions(+) create mode 100644 modules/postgres/resources/customEntrypoint.sh create mode 100644 modules/postgres/testdata/postgres-ssl.conf diff --git a/docs/modules/postgres.md b/docs/modules/postgres.md index 930de50c15..4192cf7eca 100644 --- a/docs/modules/postgres.md +++ b/docs/modules/postgres.md @@ -74,9 +74,35 @@ An example of a `*.sh` script that creates a user and database is shown below: In the case you have a custom config file for Postgres, it's possible to copy that file into the container before it's started, using the `WithConfigFile(cfgPath string)` function. +This function can be used `WithSSLSettings` but requires your configuration correctly sets the SSL properties. See the below section for more information. + !!!tip For information on what is available to configure, see the [PostgreSQL docs](https://www.postgresql.org/docs/14/runtime-config.html) for the specific version of PostgreSQL that you are running. +#### SSL Configuration + +- Not available until the next release of testcontainers-go :material-tag: main + +If you would like to use SSL with the container you can use the `WithSSLSettings`. This function accepts a `SSLSettings` which has the required secret material, namely the ca-certificate, server certificate and key. The container will copy this material to `/tmp/testcontainers-go/postgres/ca_cert.pem`, `/tmp/testcontainers-go/postgres/server.cert` and `/tmp/testcontainers-go/postgres/server.key` + +This function requires a custom postgres configuration file that enables SSL and correctly sets the paths on the key material. + +If you use this function by itself or in conjuction with `WithConfigFile` your custom conf must set the require ssl fields. The configuration must correctly align the key material provided via `SSLSettings` with the server configuration, namely the paths. Your configuration will need to contain the following: + +``` +ssl = on +ssl_ca_file = '/tmp/testcontainers-go/postgres/ca_cert.pem' +ssl_cert_file = '/tmp/testcontainers-go/postgres/server.cert' +ssl_key_file = '/tmp/testcontainers-go/postgres/server.key' +``` + +!!!warning + This function assumes the postgres user in the container is `postgres` + + There is no current support for mutual authentication. + + The `SSLSettings` function will modify the container `entrypoint`. This is done so that key material copied over to the container is chowned by `postgres`. All other container arguments will be passed through to the original container entrypoint. + ### Container Methods #### ConnectionString diff --git a/modules/postgres/go.mod b/modules/postgres/go.mod index 01ee22b8bd..955edef1de 100644 --- a/modules/postgres/go.mod +++ b/modules/postgres/go.mod @@ -6,6 +6,7 @@ require ( github.com/docker/go-connections v0.5.0 github.com/jackc/pgx/v5 v5.5.4 github.com/lib/pq v1.10.9 + github.com/mdelapenya/tlscert v0.1.0 github.com/stretchr/testify v1.9.0 github.com/testcontainers/testcontainers-go v0.34.0 diff --git a/modules/postgres/go.sum b/modules/postgres/go.sum index ba337a188c..9d1c34d3df 100644 --- a/modules/postgres/go.sum +++ b/modules/postgres/go.sum @@ -71,6 +71,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mdelapenya/tlscert v0.1.0 h1:YTpF579PYUX475eOL+6zyEO3ngLTOUWck78NBuJVXaM= +github.com/mdelapenya/tlscert v0.1.0/go.mod h1:wrbyM/DwbFCeCeqdPX/8c6hNOqQgbf0rUDErE1uD+64= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk= diff --git a/modules/postgres/postgres.go b/modules/postgres/postgres.go index ce3719575d..e25bc667a6 100644 --- a/modules/postgres/postgres.go +++ b/modules/postgres/postgres.go @@ -3,6 +3,7 @@ package postgres import ( "context" "database/sql" + _ "embed" "errors" "fmt" "io" @@ -19,6 +20,9 @@ const ( defaultSnapshotName = "migrated_template" ) +//go:embed resources/customEntrypoint.sh +var embeddedCustomEntrypoint string + // PostgresContainer represents the postgres container type used in the module type PostgresContainer struct { testcontainers.Container @@ -205,6 +209,43 @@ func WithSnapshotName(name string) SnapshotOption { } } +// WithSSLSettings configures the Postgres server to run with the provided CA Chain +// This will not function if the corresponding postgres conf is not correctly configured. +// Namely the paths below must match what is set in the conf file +func WithSSLCert(caCertFile string, certFile string, keyFile string) testcontainers.CustomizeRequestOption { + const defaultPermission = 0o600 + + return func(req *testcontainers.GenericContainerRequest) error { + const entrypointPath = "/usr/local/bin/docker-entrypoint-ssl.bash" + + req.Files = append(req.Files, + testcontainers.ContainerFile{ + HostFilePath: caCertFile, + ContainerFilePath: "/tmp/testcontainers-go/postgres/ca_cert.pem", + FileMode: defaultPermission, + }, + testcontainers.ContainerFile{ + HostFilePath: certFile, + ContainerFilePath: "/tmp/testcontainers-go/postgres/server.cert", + FileMode: defaultPermission, + }, + testcontainers.ContainerFile{ + HostFilePath: keyFile, + ContainerFilePath: "/tmp/testcontainers-go/postgres/server.key", + FileMode: defaultPermission, + }, + testcontainers.ContainerFile{ + Reader: strings.NewReader(embeddedCustomEntrypoint), + ContainerFilePath: entrypointPath, + FileMode: defaultPermission, + }, + ) + req.Entrypoint = []string{"sh", entrypointPath} + + return nil + } +} + // Snapshot takes a snapshot of the current state of the database as a template, which can then be restored using // the Restore method. By default, the snapshot will be created under a database called migrated_template, you can // customize the snapshot name with the options. diff --git a/modules/postgres/postgres_test.go b/modules/postgres/postgres_test.go index 8c02e68476..e83b8e1454 100644 --- a/modules/postgres/postgres_test.go +++ b/modules/postgres/postgres_test.go @@ -3,7 +3,9 @@ package postgres_test import ( "context" "database/sql" + "errors" "fmt" + "os" "path/filepath" "testing" "time" @@ -12,6 +14,8 @@ import ( "github.com/jackc/pgx/v5" _ "github.com/jackc/pgx/v5/stdlib" _ "github.com/lib/pq" + "github.com/mdelapenya/tlscert" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/testcontainers/testcontainers-go" @@ -25,6 +29,40 @@ const ( password = "password" ) +func createSSLCerts(t *testing.T) (*tlscert.Certificate, *tlscert.Certificate, error) { + t.Helper() + tmpDir := t.TempDir() + certsDir := tmpDir + "/certs" + + require.NoError(t, os.MkdirAll(certsDir, 0o755)) + + t.Cleanup(func() { + require.NoError(t, os.RemoveAll(tmpDir)) + }) + + caCert := tlscert.SelfSignedFromRequest(tlscert.Request{ + Host: "localhost", + Name: "ca-cert", + ParentDir: certsDir, + }) + + if caCert == nil { + return caCert, nil, errors.New("unable to create CA Authority") + } + + cert := tlscert.SelfSignedFromRequest(tlscert.Request{ + Host: "localhost", + Name: "client-cert", + Parent: caCert, + ParentDir: certsDir, + }) + if cert == nil { + return caCert, cert, errors.New("unable to create Server Certificates") + } + + return caCert, cert, nil +} + func TestPostgres(t *testing.T) { ctx := context.Background() @@ -171,6 +209,56 @@ func TestWithConfigFile(t *testing.T) { defer db.Close() } +func TestWithSSL(t *testing.T) { + ctx := context.Background() + + caCert, serverCerts, err := createSSLCerts(t) + require.NoError(t, err) + + ctr, err := postgres.Run(ctx, + "postgres:16-alpine", + postgres.WithConfigFile(filepath.Join("testdata", "postgres-ssl.conf")), + postgres.WithInitScripts(filepath.Join("testdata", "init-user-db.sh")), + postgres.WithDatabase(dbname), + postgres.WithUsername(user), + postgres.WithPassword(password), + testcontainers.WithWaitStrategy(wait.ForLog("database system is ready to accept connections").WithOccurrence(2).WithStartupTimeout(5*time.Second)), + postgres.WithSSLCert(caCert.CertPath, serverCerts.CertPath, serverCerts.KeyPath), + ) + + testcontainers.CleanupContainer(t, ctr) + require.NoError(t, err) + + connStr, err := ctr.ConnectionString(ctx, "sslmode=require") + require.NoError(t, err) + + db, err := sql.Open("postgres", connStr) + require.NoError(t, err) + assert.NotNil(t, db) + defer db.Close() + + result, err := db.Exec("SELECT * FROM testdb;") + require.NoError(t, err) + assert.NotNil(t, result) +} + +func TestSSLValidatesKeyMaterialPath(t *testing.T) { + ctx := context.Background() + + _, err := postgres.Run(ctx, + "postgres:16-alpine", + postgres.WithConfigFile(filepath.Join("testdata", "postgres-ssl.conf")), + postgres.WithInitScripts(filepath.Join("testdata", "init-user-db.sh")), + postgres.WithDatabase(dbname), + postgres.WithUsername(user), + postgres.WithPassword(password), + testcontainers.WithWaitStrategy(wait.ForLog("database system is ready to accept connections").WithOccurrence(2).WithStartupTimeout(5*time.Second)), + postgres.WithSSLCert("", "", ""), + ) + + require.Error(t, err, "Error should not have been nil. Container creation should have failed due to empty key material") +} + func TestWithInitScript(t *testing.T) { ctx := context.Background() diff --git a/modules/postgres/resources/customEntrypoint.sh b/modules/postgres/resources/customEntrypoint.sh new file mode 100644 index 0000000000..ff4ffa4291 --- /dev/null +++ b/modules/postgres/resources/customEntrypoint.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +set -Eeo pipefail + + +pUID=$(id -u postgres) +pGID=$(id -g postgres) + +if [ -z "$pUID" ] +then + echo "Unable to find postgres user id, required in order to chown key material" + exit 1 +fi + +if [ -z "$pGID" ] +then + echo "Unable to find postgres group id, required in order to chown key material" + exit 1 +fi + +chown "$pUID":"$pGID" \ + /tmp/testcontainers-go/postgres/ca_cert.pem \ + /tmp/testcontainers-go/postgres/server.cert \ + /tmp/testcontainers-go/postgres/server.key + +/usr/local/bin/docker-entrypoint.sh "$@" diff --git a/modules/postgres/testdata/postgres-ssl.conf b/modules/postgres/testdata/postgres-ssl.conf new file mode 100644 index 0000000000..5e49f16a4f --- /dev/null +++ b/modules/postgres/testdata/postgres-ssl.conf @@ -0,0 +1,80 @@ +# ----------------------------- +# PostgreSQL configuration file +# ----------------------------- +# +# This file consists of lines of the form: +# +# name = value +# +# (The "=" is optional.) Whitespace may be used. Comments are introduced with +# "#" anywhere on a line. The complete list of parameter names and allowed +# values can be found in the PostgreSQL documentation. +# +# The commented-out settings shown in this file represent the default values. +# Re-commenting a setting is NOT sufficient to revert it to the default value; +# you need to reload the server. +# +# This file is read on server startup and when the server receives a SIGHUP +# signal. If you edit the file on a running system, you have to SIGHUP the +# server for the changes to take effect, run "pg_ctl reload", or execute +# "SELECT pg_reload_conf()". Some parameters, which are marked below, +# require a server shutdown and restart to take effect. +# +# Any parameter can also be given as a command-line option to the server, e.g., +# "postgres -c log_connections=on". Some parameters can be changed at run time +# with the "SET" SQL command. +# +# Memory units: B = bytes Time units: ms = milliseconds +# kB = kilobytes s = seconds +# MB = megabytes min = minutes +# GB = gigabytes h = hours +# TB = terabytes d = days + + +#------------------------------------------------------------------------------ +# FILE LOCATIONS +#------------------------------------------------------------------------------ + +# The default values of these variables are driven from the -D command-line +# option or PGDATA environment variable, represented here as ConfigDir. + +#data_directory = 'ConfigDir' # use data in another directory + # (change requires restart) +#hba_file = 'ConfigDir/pg_hba.conf' # host-based authentication file + # (change requires restart) +#ident_file = 'ConfigDir/pg_ident.conf' # ident configuration file + # (change requires restart) + +# If external_pid_file is not explicitly set, no extra PID file is written. +#external_pid_file = '' # write an extra PID file + # (change requires restart) + + +#------------------------------------------------------------------------------ +# CONNECTIONS AND AUTHENTICATION +#------------------------------------------------------------------------------ + +# - Connection Settings - + +listen_addresses = '*' + # comma-separated list of addresses; + # defaults to 'localhost'; use '*' for all + # (change requires restart) +#port = 5432 # (change requires restart) +#max_connections = 100 # (change requires restart) + +# - SSL - + +ssl = on +ssl_ca_file = '/tmp/testcontainers-go/postgres/ca_cert.pem' +ssl_cert_file = '/tmp/testcontainers-go/postgres/server.cert' +#ssl_crl_file = '' +ssl_key_file = '/tmp/testcontainers-go/postgres/server.key' +#ssl_ciphers = 'HIGH:MEDIUM:+3DES:!aNULL' # allowed SSL ciphers +#ssl_prefer_server_ciphers = on +#ssl_ecdh_curve = 'prime256v1' +#ssl_dh_params_file = '' +#ssl_passphrase_command = '' +#ssl_passphrase_command_supports_reload = off + +