diff --git a/.github/scripts/modules/ollama/install-dependencies.sh b/.github/scripts/modules/ollama/install-dependencies.sh
new file mode 100755
index 0000000000..d699158806
--- /dev/null
+++ b/.github/scripts/modules/ollama/install-dependencies.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+
+curl -fsSL https://ollama.com/install.sh | sh
+
+# kill any running ollama process so that the tests can start from a clean state
+sudo systemctl stop ollama.service
diff --git a/.github/workflows/ci-test-go.yml b/.github/workflows/ci-test-go.yml
index 82be78435f..0d6af15880 100644
--- a/.github/workflows/ci-test-go.yml
+++ b/.github/workflows/ci-test-go.yml
@@ -107,6 +107,16 @@ jobs:
working-directory: ./${{ inputs.project-directory }}
run: go build
+ - name: Install dependencies
+ shell: bash
+ run: |
+ SCRIPT_PATH="./.github/scripts/${{ inputs.project-directory }}/install-dependencies.sh"
+ if [ -f "$SCRIPT_PATH" ]; then
+ $SCRIPT_PATH
+ else
+ echo "No dependencies script found at $SCRIPT_PATH - skipping installation"
+ fi
+
- name: go test
# only run tests on linux, there are a number of things that won't allow the tests to run on anything else
# many (maybe, all?) images used can only be build on Linux, they don't have Windows in their manifest, and
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index b2972da7d5..fe1ae1da37 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -140,7 +140,7 @@ jobs:
merge-multiple: true
- name: Analyze with SonarCloud
- uses: sonarsource/sonarcloud-github-action@49e6cd3b187936a73b8280d59ffd9da69df63ec9 # v2.1.1
+ uses: sonarsource/sonarcloud-github-action@02ef91109b2d589e757aefcfb2854c2783fd7b19 # v4.0.0
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }}
diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml
index 15f56f2eb5..2c899be9d4 100644
--- a/.github/workflows/codeql.yml
+++ b/.github/workflows/codeql.yml
@@ -53,7 +53,7 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
- uses: github/codeql-action/init@afb54ba388a7dca6ecae48f608c4ff05ff4cc77a # v3.25.15
+ uses: github/codeql-action/init@48ab28a6f5dbc2a99bf1e0131198dd8f1df78169 # v3.28.0
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
@@ -67,7 +67,7 @@ jobs:
# Autobuild attempts to build any compiled languages (C/C++, C#, Go, Java, or Swift).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
- uses: github/codeql-action/autobuild@afb54ba388a7dca6ecae48f608c4ff05ff4cc77a # v3.25.15
+ uses: github/codeql-action/autobuild@48ab28a6f5dbc2a99bf1e0131198dd8f1df78169 # v3.28.0
# âšī¸ Command-line programs to run using the OS shell.
# đ See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
@@ -80,6 +80,6 @@ jobs:
# ./location_of_script_within_repo/buildscript.sh
- name: Perform CodeQL Analysis
- uses: github/codeql-action/analyze@afb54ba388a7dca6ecae48f608c4ff05ff4cc77a # v3.25.15
+ uses: github/codeql-action/analyze@48ab28a6f5dbc2a99bf1e0131198dd8f1df78169 # v3.28.0
with:
category: "/language:${{matrix.language}}"
diff --git a/.github/workflows/docker-moby-latest.yml b/.github/workflows/docker-moby-latest.yml
index dc06fb49e7..86ad12c6e0 100644
--- a/.github/workflows/docker-moby-latest.yml
+++ b/.github/workflows/docker-moby-latest.yml
@@ -70,8 +70,9 @@ jobs:
- name: Notify to Slack on failures
if: failure()
id: slack
- uses: slackapi/slack-github-action@v1.26.0
+ uses: slackapi/slack-github-action@v2.0.0
with:
+ payload-templated: true
payload-file-path: "./payload-slack-content.json"
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_DOCKER_LATEST_WEBHOOK }}
diff --git a/.github/workflows/scorecards.yml b/.github/workflows/scorecards.yml
index 51b8a535c5..1d141de781 100644
--- a/.github/workflows/scorecards.yml
+++ b/.github/workflows/scorecards.yml
@@ -51,6 +51,6 @@ jobs:
# required for Code scanning alerts
- name: "Upload SARIF results to code scanning"
- uses: github/codeql-action/upload-sarif@afb54ba388a7dca6ecae48f608c4ff05ff4cc77a # v3.25.15
+ uses: github/codeql-action/upload-sarif@48ab28a6f5dbc2a99bf1e0131198dd8f1df78169 # v3.28.0
with:
sarif_file: results.sarif
diff --git a/Pipfile.lock b/Pipfile.lock
index 9a2f6d24c8..d08964ab4c 100644
--- a/Pipfile.lock
+++ b/Pipfile.lock
@@ -178,11 +178,12 @@
},
"jinja2": {
"hashes": [
- "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369",
- "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"
+ "sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb",
+ "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb"
],
+ "index": "pypi",
"markers": "python_version >= '3.7'",
- "version": "==3.1.4"
+ "version": "==3.1.5"
},
"markdown": {
"hashes": [
diff --git a/cleanup.go b/cleanup.go
index e2d52440b9..d676b42bdb 100644
--- a/cleanup.go
+++ b/cleanup.go
@@ -8,20 +8,65 @@ import (
"time"
)
-// terminateOptions is a type that holds the options for terminating a container.
-type terminateOptions struct {
- ctx context.Context
- timeout *time.Duration
- volumes []string
+// TerminateOptions is a type that holds the options for terminating a container.
+type TerminateOptions struct {
+ ctx context.Context
+ stopTimeout *time.Duration
+ volumes []string
}
// TerminateOption is a type that represents an option for terminating a container.
-type TerminateOption func(*terminateOptions)
+type TerminateOption func(*TerminateOptions)
+
+// NewTerminateOptions returns a fully initialised TerminateOptions.
+// Defaults: StopTimeout: 10 seconds.
+func NewTerminateOptions(ctx context.Context, opts ...TerminateOption) *TerminateOptions {
+ timeout := time.Second * 10
+ options := &TerminateOptions{
+ stopTimeout: &timeout,
+ ctx: ctx,
+ }
+ for _, opt := range opts {
+ opt(options)
+ }
+ return options
+}
+
+// Context returns the context to use during a Terminate.
+func (o *TerminateOptions) Context() context.Context {
+ return o.ctx
+}
+
+// StopTimeout returns the stop timeout to use during a Terminate.
+func (o *TerminateOptions) StopTimeout() *time.Duration {
+ return o.stopTimeout
+}
+
+// Cleanup performs any clean up needed
+func (o *TerminateOptions) Cleanup() error {
+	// TODO: simplify this when we perform the client refactor.
+ if len(o.volumes) == 0 {
+ return nil
+ }
+ client, err := NewDockerClientWithOpts(o.ctx)
+ if err != nil {
+ return fmt.Errorf("docker client: %w", err)
+ }
+ defer client.Close()
+ // Best effort to remove all volumes.
+ var errs []error
+ for _, volume := range o.volumes {
+ if errRemove := client.VolumeRemove(o.ctx, volume, true); errRemove != nil {
+ errs = append(errs, fmt.Errorf("volume remove %q: %w", volume, errRemove))
+ }
+ }
+ return errors.Join(errs...)
+}
// StopContext returns a TerminateOption that sets the context.
// Default: context.Background().
func StopContext(ctx context.Context) TerminateOption {
- return func(c *terminateOptions) {
+ return func(c *TerminateOptions) {
c.ctx = ctx
}
}
@@ -29,8 +74,8 @@ func StopContext(ctx context.Context) TerminateOption {
// StopTimeout returns a TerminateOption that sets the timeout.
// Default: See [Container.Stop].
func StopTimeout(timeout time.Duration) TerminateOption {
- return func(c *terminateOptions) {
- c.timeout = &timeout
+ return func(c *TerminateOptions) {
+ c.stopTimeout = &timeout
}
}
@@ -39,7 +84,7 @@ func StopTimeout(timeout time.Duration) TerminateOption {
// which are not removed by default.
// Default: nil.
func RemoveVolumes(volumes ...string) TerminateOption {
- return func(c *terminateOptions) {
+ return func(c *TerminateOptions) {
c.volumes = volumes
}
}
@@ -54,41 +99,12 @@ func TerminateContainer(container Container, options ...TerminateOption) error {
return nil
}
- c := &terminateOptions{
- ctx: context.Background(),
- }
-
- for _, opt := range options {
- opt(c)
- }
-
- // TODO: Add a timeout when terminate supports it.
- err := container.Terminate(c.ctx)
+ err := container.Terminate(context.Background(), options...)
if !isCleanupSafe(err) {
return fmt.Errorf("terminate: %w", err)
}
- // Remove additional volumes if any.
- if len(c.volumes) == 0 {
- return nil
- }
-
- client, err := NewDockerClientWithOpts(c.ctx)
- if err != nil {
- return fmt.Errorf("docker client: %w", err)
- }
-
- defer client.Close()
-
- // Best effort to remove all volumes.
- var errs []error
- for _, volume := range c.volumes {
- if errRemove := client.VolumeRemove(c.ctx, volume, true); errRemove != nil {
- errs = append(errs, fmt.Errorf("volume remove %q: %w", volume, errRemove))
- }
- }
-
- return errors.Join(errs...)
+ return nil
}
// isNil returns true if val is nil or an nil instance false otherwise.
diff --git a/container.go b/container.go
index 5ee0aac881..50fc656e7e 100644
--- a/container.go
+++ b/container.go
@@ -50,7 +50,7 @@ type Container interface {
Stop(context.Context, *time.Duration) error // stop the container
// Terminate stops and removes the container and its image if it was built and not flagged as kept.
- Terminate(ctx context.Context) error
+ Terminate(ctx context.Context, opts ...TerminateOption) error
Logs(context.Context) (io.ReadCloser, error) // Get logs of the container
FollowOutput(LogConsumer) // Deprecated: it will be removed in the next major release
@@ -77,7 +77,7 @@ type ImageBuildInfo interface {
GetDockerfile() string // the relative path to the Dockerfile, including the file itself
GetRepo() string // get repo label for image
GetTag() string // get tag label for image
- ShouldPrintBuildLog() bool // allow build log to be printed to stdout
+ BuildLogWriter() io.Writer // for output of build log, use io.Discard to disable the output
ShouldBuildImage() bool // return true if the image needs to be built
GetBuildArgs() map[string]*string // return the environment args used to build the from Dockerfile
GetAuthConfigs() map[string]registry.AuthConfig // Deprecated. Testcontainers will detect registry credentials automatically. Return the auth configs to be able to pull from an authenticated docker registry
@@ -92,7 +92,8 @@ type FromDockerfile struct {
Repo string // the repo label for image, defaults to UUID
Tag string // the tag label for image, defaults to UUID
BuildArgs map[string]*string // enable user to pass build args to docker daemon
- PrintBuildLog bool // enable user to print build log
+ PrintBuildLog bool // Deprecated: Use BuildLogWriter instead
+ BuildLogWriter io.Writer // for output of build log, defaults to io.Discard
AuthConfigs map[string]registry.AuthConfig // Deprecated. Testcontainers will detect registry credentials automatically. Enable auth configs to be able to pull from an authenticated docker registry
// KeepImage describes whether DockerContainer.Terminate should not delete the
// container image. Useful for images that are built from a Dockerfile and take a
@@ -410,8 +411,20 @@ func (c *ContainerRequest) ShouldKeepBuiltImage() bool {
return c.FromDockerfile.KeepImage
}
-func (c *ContainerRequest) ShouldPrintBuildLog() bool {
- return c.FromDockerfile.PrintBuildLog
+// BuildLogWriter returns the io.Writer for output of log when building a Docker image from
+// a Dockerfile. It returns the BuildLogWriter from the ContainerRequest, defaults to io.Discard.
+// For backward compatibility, if BuildLogWriter is default and PrintBuildLog is true,
+// the function returns os.Stderr.
+func (c *ContainerRequest) BuildLogWriter() io.Writer {
+ if c.FromDockerfile.BuildLogWriter != nil {
+ return c.FromDockerfile.BuildLogWriter
+ }
+ if c.FromDockerfile.PrintBuildLog {
+ c.FromDockerfile.BuildLogWriter = os.Stderr
+ } else {
+ c.FromDockerfile.BuildLogWriter = io.Discard
+ }
+ return c.FromDockerfile.BuildLogWriter
}
// BuildOptions returns the image build options when building a Docker image from a Dockerfile.
diff --git a/docker.go b/docker.go
index 296fe6743c..2ce849be50 100644
--- a/docker.go
+++ b/docker.go
@@ -303,12 +303,11 @@ func (c *DockerContainer) Stop(ctx context.Context, timeout *time.Duration) erro
// The following hooks are called in order:
// - [ContainerLifecycleHooks.PreTerminates]
// - [ContainerLifecycleHooks.PostTerminates]
-func (c *DockerContainer) Terminate(ctx context.Context) error {
- // ContainerRemove hardcodes stop timeout to 3 seconds which is too short
- // to ensure that child containers are stopped so we manually call stop.
- // TODO: make this configurable via a functional option.
- timeout := 10 * time.Second
- err := c.Stop(ctx, &timeout)
+//
+// Default: timeout is 10 seconds.
+func (c *DockerContainer) Terminate(ctx context.Context, opts ...TerminateOption) error {
+ options := NewTerminateOptions(ctx, opts...)
+ err := c.Stop(options.Context(), options.StopTimeout())
if err != nil && !isCleanupSafe(err) {
return fmt.Errorf("stop: %w", err)
}
@@ -343,6 +342,10 @@ func (c *DockerContainer) Terminate(ctx context.Context) error {
c.sessionID = ""
c.isRunning = false
+ if err = options.Cleanup(); err != nil {
+ errs = append(errs, err)
+ }
+
return errors.Join(errs...)
}
@@ -1004,10 +1007,7 @@ func (p *DockerProvider) BuildImage(ctx context.Context, img ImageBuildInfo) (st
}
defer resp.Body.Close()
- output := io.Discard
- if img.ShouldPrintBuildLog() {
- output = os.Stderr
- }
+ output := img.BuildLogWriter()
// Always process the output, even if it is not printed
// to ensure that errors during the build process are
@@ -1498,7 +1498,11 @@ func (p *DockerProvider) daemonHostLocked(ctx context.Context) (string, error) {
p.hostCache = daemonURL.Hostname()
case "unix", "npipe":
if core.InAContainer() {
- ip, err := p.GetGatewayIP(ctx)
+ defaultNetwork, err := p.ensureDefaultNetworkLocked(ctx)
+ if err != nil {
+ return "", fmt.Errorf("ensure default network: %w", err)
+ }
+ ip, err := p.getGatewayIP(ctx, defaultNetwork)
if err != nil {
ip, err = core.DefaultGatewayIP()
if err != nil {
@@ -1598,7 +1602,10 @@ func (p *DockerProvider) GetGatewayIP(ctx context.Context) (string, error) {
if err != nil {
return "", fmt.Errorf("ensure default network: %w", err)
}
+ return p.getGatewayIP(ctx, defaultNetwork)
+}
+func (p *DockerProvider) getGatewayIP(ctx context.Context, defaultNetwork string) (string, error) {
nw, err := p.GetNetwork(ctx, NetworkRequest{Name: defaultNetwork})
if err != nil {
return "", err
@@ -1624,7 +1631,10 @@ func (p *DockerProvider) GetGatewayIP(ctx context.Context) (string, error) {
func (p *DockerProvider) ensureDefaultNetwork(ctx context.Context) (string, error) {
p.mtx.Lock()
defer p.mtx.Unlock()
+ return p.ensureDefaultNetworkLocked(ctx)
+}
+func (p *DockerProvider) ensureDefaultNetworkLocked(ctx context.Context) (string, error) {
if p.defaultNetwork != "" {
// Already set.
return p.defaultNetwork, nil
diff --git a/docker_test.go b/docker_test.go
index 3fa686632f..0dd60f6db9 100644
--- a/docker_test.go
+++ b/docker_test.go
@@ -25,6 +25,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "github.com/testcontainers/testcontainers-go/internal/core"
"github.com/testcontainers/testcontainers-go/wait"
)
@@ -35,6 +36,7 @@ const (
nginxAlpineImage = "nginx:alpine"
nginxDefaultPort = "80/tcp"
nginxHighPort = "8080/tcp"
+ golangImage = "golang"
daemonMaxVersion = "1.41"
)
@@ -279,6 +281,18 @@ func TestContainerStateAfterTermination(t *testing.T) {
require.Nil(t, state, "expected nil container inspect.")
})
+ t.Run("termination-timeout", func(t *testing.T) {
+ ctx := context.Background()
+ nginx, err := createContainerFn(ctx)
+ require.NoError(t, err)
+
+ err = nginx.Start(ctx)
+ require.NoError(t, err, "expected no error from container start.")
+
+ err = nginx.Terminate(ctx, StopTimeout(5*time.Microsecond))
+ require.NoError(t, err)
+ })
+
t.Run("Nil State after termination if raw as already set", func(t *testing.T) {
ctx := context.Background()
nginx, err := createContainerFn(ctx)
@@ -705,6 +719,37 @@ func Test_BuildContainerFromDockerfileWithBuildLog(t *testing.T) {
assert.Regexpf(t, `^Step\s*1/\d+\s*:\s*FROM alpine$`, temp[0], "Expected stdout first line to be %s. Got '%s'.", "Step 1/* : FROM alpine", temp[0])
}
+func Test_BuildContainerFromDockerfileWithBuildLogWriter(t *testing.T) {
+ var buffer bytes.Buffer
+
+ ctx := context.Background()
+
+ // fromDockerfile {
+ req := ContainerRequest{
+ FromDockerfile: FromDockerfile{
+ Context: filepath.Join(".", "testdata"),
+ Dockerfile: "buildlog.Dockerfile",
+ BuildLogWriter: &buffer,
+ },
+ }
+ // }
+
+ genContainerReq := GenericContainerRequest{
+ ProviderType: providerType,
+ ContainerRequest: req,
+ Started: true,
+ }
+
+ c, err := GenericContainer(ctx, genContainerReq)
+ CleanupContainer(t, c)
+ require.NoError(t, err)
+
+ out := buffer.String()
+ temp := strings.Split(out, "\n")
+ require.NotEmpty(t, temp)
+ require.Regexpf(t, `^Step\s*1/\d+\s*:\s*FROM alpine$`, temp[0], "Expected stdout first line to be %s. Got '%s'.", "Step 1/* : FROM alpine", temp[0])
+}
+
func TestContainerCreationWaitsForLogAndPortContextTimeout(t *testing.T) {
ctx := context.Background()
req := ContainerRequest{
@@ -1044,6 +1089,7 @@ func TestContainerCreationWithVolumeAndFileWritingToIt(t *testing.T) {
{
HostFilePath: absPath,
ContainerFilePath: "/hello.sh",
+ FileMode: 700,
},
},
Mounts: Mounts(VolumeMount(volumeName, "/data")),
@@ -1056,6 +1102,68 @@ func TestContainerCreationWithVolumeAndFileWritingToIt(t *testing.T) {
require.NoError(t, err)
}
+func TestContainerCreationWithVolumeCleaning(t *testing.T) {
+ absPath, err := filepath.Abs(filepath.Join(".", "testdata", "hello.sh"))
+ require.NoError(t, err)
+ ctx, cnl := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cnl()
+
+ // Create the volume.
+ volumeName := "volumeName"
+
+ // Create the container that writes into the mounted volume.
+ bashC, err := GenericContainer(ctx, GenericContainerRequest{
+ ProviderType: providerType,
+ ContainerRequest: ContainerRequest{
+ Image: "bash:5.2.26",
+ Files: []ContainerFile{
+ {
+ HostFilePath: absPath,
+ ContainerFilePath: "/hello.sh",
+ FileMode: 700,
+ },
+ },
+ Mounts: Mounts(VolumeMount(volumeName, "/data")),
+ Cmd: []string{"bash", "/hello.sh"},
+ WaitingFor: wait.ForLog("done"),
+ },
+ Started: true,
+ })
+ require.NoError(t, err)
+ err = bashC.Terminate(ctx, RemoveVolumes(volumeName))
+ CleanupContainer(t, bashC, RemoveVolumes(volumeName))
+ require.NoError(t, err)
+}
+
+func TestContainerTerminationOptions(t *testing.T) {
+ t.Run("volumes", func(t *testing.T) {
+ var options TerminateOptions
+ RemoveVolumes("vol1", "vol2")(&options)
+ require.Equal(t, TerminateOptions{
+ volumes: []string{"vol1", "vol2"},
+ }, options)
+ })
+ t.Run("stop-timeout", func(t *testing.T) {
+ var options TerminateOptions
+ timeout := 11 * time.Second
+ StopTimeout(timeout)(&options)
+ require.Equal(t, TerminateOptions{
+ stopTimeout: &timeout,
+ }, options)
+ })
+
+ t.Run("all", func(t *testing.T) {
+ var options TerminateOptions
+ timeout := 9 * time.Second
+ StopTimeout(timeout)(&options)
+ RemoveVolumes("vol1", "vol2")(&options)
+ require.Equal(t, TerminateOptions{
+ stopTimeout: &timeout,
+ volumes: []string{"vol1", "vol2"},
+ }, options)
+ })
+}
+
func TestContainerWithTmpFs(t *testing.T) {
ctx := context.Background()
req := ContainerRequest{
@@ -2125,3 +2233,39 @@ func TestCustomPrefixTrailingSlashIsProperlyRemovedIfPresent(t *testing.T) {
dockerContainer := c.(*DockerContainer)
require.Equal(t, fmt.Sprintf("%s%s", hubPrefixWithTrailingSlash, dockerImage), dockerContainer.Image)
}
+
+// TODO: remove this skip check when context rework is merged alongside [core.DockerEnvFile] removal.
+func Test_Provider_DaemonHost_Issue2897(t *testing.T) {
+ ctx := context.Background()
+ provider, err := NewDockerProvider()
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ require.NoError(t, provider.Close())
+ })
+
+ orig := core.DockerEnvFile
+ core.DockerEnvFile = filepath.Join(t.TempDir(), ".dockerenv")
+ t.Cleanup(func() {
+ core.DockerEnvFile = orig
+ })
+
+ f, err := os.Create(core.DockerEnvFile)
+ require.NoError(t, err)
+ require.NoError(t, f.Close())
+ t.Cleanup(func() {
+ require.NoError(t, os.Remove(f.Name()))
+ })
+
+ errCh := make(chan error, 1)
+ go func() {
+ _, err := provider.DaemonHost(ctx)
+ errCh <- err
+ }()
+
+ select {
+ case <-time.After(1 * time.Second):
+ t.Fatal("timeout waiting for DaemonHost")
+ case err := <-errCh:
+ require.NoError(t, err)
+ }
+}
diff --git a/docs/features/garbage_collector.md b/docs/features/garbage_collector.md
index e725f5a9bd..4712c59748 100644
--- a/docs/features/garbage_collector.md
+++ b/docs/features/garbage_collector.md
@@ -17,6 +17,47 @@ The primary method is to use the `Terminate(context.Context)` function that is
available when a container is created. Use `defer` to ensure that it is called
on test completion.
+The `Terminate` function can be customised with termination options to determine how a container is removed: termination timeout, and the ability to remove container volumes are supported at the moment. You can build the default options using the `testcontainers.NewTerminateOptions` function.
+
+#### NewTerminateOptions
+
+- Not available until the next release of testcontainers-go :material-tag: main
+
+If you want to attach options to container termination, you can use the `testcontainers.NewTerminateOptions(ctx context.Context, opts ...TerminateOption) *TerminateOptions` function, which receives a variadic list of `TerminateOption` values, creating custom termination options to be passed on the container termination.
+
+##### Terminate Options
+
+###### [StopContext](../../cleanup.go)
+Sets the context for the Container termination.
+
+- **Function**: `StopContext(ctx context.Context) TerminateOption`
+- **Default**: The context passed in `Terminate()`
+- **Usage**:
+```go
+err := container.Terminate(ctx, StopContext(context.Background()))
+```
+
+###### [StopTimeout](../../cleanup.go)
+Sets the timeout for stopping the Container.
+
+- **Function**: `StopTimeout(timeout time.Duration) TerminateOption`
+- **Default**: 10 seconds
+- **Usage**:
+```go
+err := container.Terminate(ctx, StopTimeout(20 * time.Second))
+```
+
+###### [RemoveVolumes](../../cleanup.go)
+Sets the volumes to be removed during Container termination.
+
+- **Function**: `RemoveVolumes(volumes ...string) TerminateOption`
+- **Default**: Empty (no volumes removed)
+- **Usage**:
+```go
+err := container.Terminate(ctx, RemoveVolumes("vol1", "vol2"))
+```
+
+
!!!tip
Remember to `defer` as soon as possible so you won't forget. The best time
diff --git a/docs/features/wait/log.md b/docs/features/wait/log.md
index f1d40ff360..8466d68511 100644
--- a/docs/features/wait/log.md
+++ b/docs/features/wait/log.md
@@ -3,10 +3,11 @@
The Log wait strategy will check if a string occurs in the container logs for a desired number of times, and allows to set the following conditions:
- the string to be waited for in the container log.
-- the number of occurrences of the string to wait for, default is `1`.
+- the number of occurrences of the string to wait for, default is `1` (ignored for Submatch).
- look for the string using a regular expression, default is `false`.
- the startup timeout to be used in seconds, default is 60 seconds.
- the poll interval to be used in milliseconds, default is 100 milliseconds.
+- the regular expression submatch callback, default nil (occurrences is ignored).
```golang
req := ContainerRequest{
@@ -33,3 +34,40 @@ req := ContainerRequest{
WaitingFor: wait.ForLog(`.*MySQL Community Server`).AsRegexp(),
}
```
+
+Using regular expression with submatch:
+
+```golang
+var host, port string
+req := ContainerRequest{
+ Image: "ollama/ollama:0.1.25",
+ ExposedPorts: []string{"11434/tcp"},
+ WaitingFor: wait.ForLog(`Listening on (.*:\d+) \(version\s(.*)\)`).Submatch(func(pattern string, submatches [][][]byte) error {
+ var err error
+ for _, matches := range submatches {
+ if len(matches) != 3 {
+ err = fmt.Errorf("`%s` matched %d times, expected %d", pattern, len(matches), 3)
+ continue
+ }
+ host, port, err = net.SplitHostPort(string(matches[1]))
+ if err != nil {
+ return wait.NewPermanentError(fmt.Errorf("split host port: %w", err))
+ }
+
+ // Host and port successfully extracted from log.
+ return nil
+ }
+
+ if err != nil {
+ // Return the last error encountered.
+ return err
+ }
+
+ return fmt.Errorf("address and version not found: `%s` no matches", pattern)
+ }),
+}
+```
+
+If the return from a Submatch callback function is a `wait.PermanentError` the
+wait will stop and the error will be returned. Use `wait.NewPermanentError(err error)`
+to achieve this.
diff --git a/docs/modules/ollama.md b/docs/modules/ollama.md
index c16e612142..18cb08b47a 100644
--- a/docs/modules/ollama.md
+++ b/docs/modules/ollama.md
@@ -16,10 +16,15 @@ go get github.com/testcontainers/testcontainers-go/modules/ollama
## Usage example
+The module allows you to run the Ollama container or the local Ollama binary.
+
[Creating a Ollama container](../../modules/ollama/examples_test.go) inside_block:runOllamaContainer
+[Running the local Ollama binary](../../modules/ollama/examples_test.go) inside_block:localOllama
+If the local Ollama binary fails to execute, the module will fall back to the container version of Ollama.
+
## Module Reference
### Run function
@@ -48,6 +53,51 @@ When starting the Ollama container, you can pass options in a variadic way to co
If you need to set a different Ollama Docker image, you can set a valid Docker image as the second argument in the `Run` function.
E.g. `Run(context.Background(), "ollama/ollama:0.1.25")`.
+#### Use Local
+
+- Not available until the next release of testcontainers-go :material-tag: main
+
+!!!warning
+ Please make sure the local Ollama binary is not running when using the local version of the module:
+ Ollama can be started as a system service, or as part of the Ollama application,
+ and interacting with the logs of a running Ollama process not managed by the module is not supported.
+
+If you need to run the local Ollama binary, you can set the `UseLocal` option in the `Run` function.
+This option accepts a list of environment variables as a string, that will be applied to the Ollama binary when executing commands.
+
+E.g. `Run(context.Background(), "ollama/ollama:0.1.25", WithUseLocal("OLLAMA_DEBUG=true"))`.
+
+All the container methods are available when using the local Ollama binary, but will be executed locally instead of inside the container.
+Please consider the following differences when using the local Ollama binary:
+
+- The local Ollama binary will create a log file in the current working directory, identified by the session ID. E.g. `local-ollama-<session-id>.log`. It's possible to set the log file name using the `OLLAMA_LOGFILE` environment variable. So if you're running Ollama yourself, from the Ollama app, or the standalone binary, you could use this environment variable to set the same log file name.
+ - For the Ollama app, the default log file resides in the `$HOME/.ollama/logs/server.log`.
+ - For the standalone binary, you should start it redirecting the logs to a file. E.g. `ollama serve > /tmp/ollama.log 2>&1`.
+- `ConnectionString` returns the connection string to connect to the local Ollama binary started by the module instead of the container.
+- `ContainerIP` returns the bound host IP `127.0.0.1` by default.
+- `ContainerIPs` returns the bound host IP `["127.0.0.1"]` by default.
+- `CopyToContainer`, `CopyDirToContainer`, `CopyFileToContainer` and `CopyFileFromContainer` return an error if called.
+- `GetLogProductionErrorChannel` returns a nil channel.
+- `Endpoint` returns the endpoint to connect to the local Ollama binary started by the module instead of the container.
+- `Exec` passes the command to the local Ollama binary started by the module instead of inside the container. First argument is the command to execute, and the second argument is the list of arguments, else, an error is returned.
+- `GetContainerID` returns the container ID of the local Ollama binary started by the module instead of the container, which maps to `local-ollama-<session-id>`.
+- `Host` returns the bound host IP `127.0.0.1` by default.
+- `Inspect` returns a ContainerJSON with the state of the local Ollama binary started by the module.
+- `IsRunning` returns true if the local Ollama binary process started by the module is running.
+- `Logs` returns the logs from the local Ollama binary started by the module instead of the container.
+- `MappedPort` returns the port mapping for the local Ollama binary started by the module instead of the container.
+- `Start` starts the local Ollama binary process.
+- `State` returns the current state of the local Ollama binary process, `stopped` or `running`.
+- `Stop` stops the local Ollama binary process.
+- `Terminate` calls the `Stop` method and then removes the log file.
+
+The local Ollama binary will create a log file in the current working directory, and it will be available in the container's `Logs` method.
+
+!!!info
+ The local Ollama binary will use the `OLLAMA_HOST` environment variable to set the host and port to listen on.
+ If the environment variable is not set, it will default to `localhost:0`
+    which binds to a loopback address on an ephemeral port to avoid port conflicts.
+
{% include "../features/common_functional_options.md" %}
### Container Methods
diff --git a/docs/modules/postgres.md b/docs/modules/postgres.md
index 930de50c15..4192cf7eca 100644
--- a/docs/modules/postgres.md
+++ b/docs/modules/postgres.md
@@ -74,9 +74,35 @@ An example of a `*.sh` script that creates a user and database is shown below:
In the case you have a custom config file for Postgres, it's possible to copy that file into the container before it's started, using the `WithConfigFile(cfgPath string)` function.
+This function can be used with `WithSSLSettings` but requires your configuration correctly sets the SSL properties. See the below section for more information.
+
!!!tip
For information on what is available to configure, see the [PostgreSQL docs](https://www.postgresql.org/docs/14/runtime-config.html) for the specific version of PostgreSQL that you are running.
+#### SSL Configuration
+
+- Not available until the next release of testcontainers-go :material-tag: main
+
+If you would like to use SSL with the container you can use the `WithSSLSettings`. This function accepts a `SSLSettings` which has the required secret material, namely the ca-certificate, server certificate and key. The container will copy this material to `/tmp/testcontainers-go/postgres/ca_cert.pem`, `/tmp/testcontainers-go/postgres/server.cert` and `/tmp/testcontainers-go/postgres/server.key`
+
+This function requires a custom postgres configuration file that enables SSL and correctly sets the paths on the key material.
+
+If you use this function by itself or in conjunction with `WithConfigFile` your custom conf must set the required SSL fields. The configuration must correctly align the key material provided via `SSLSettings` with the server configuration, namely the paths. Your configuration will need to contain the following:
+
+```
+ssl = on
+ssl_ca_file = '/tmp/testcontainers-go/postgres/ca_cert.pem'
+ssl_cert_file = '/tmp/testcontainers-go/postgres/server.cert'
+ssl_key_file = '/tmp/testcontainers-go/postgres/server.key'
+```
+
+!!!warning
+ This function assumes the postgres user in the container is `postgres`
+
+ There is no current support for mutual authentication.
+
+ The `SSLSettings` function will modify the container `entrypoint`. This is done so that key material copied over to the container is chowned by `postgres`. All other container arguments will be passed through to the original container entrypoint.
+
### Container Methods
#### ConnectionString
diff --git a/internal/core/docker_host.go b/internal/core/docker_host.go
index 3088a3742b..765626da57 100644
--- a/internal/core/docker_host.go
+++ b/internal/core/docker_host.go
@@ -309,10 +309,15 @@ func testcontainersHostFromProperties(ctx context.Context) (string, error) {
return "", ErrTestcontainersHostNotSetInProperties
}
+// DockerEnvFile is the file that is created when running inside a container.
+// It's a variable to allow testing.
+// TODO: Remove this once context rework is done, which eliminates need for the default network creation.
+var DockerEnvFile = "/.dockerenv"
+
// InAContainer returns true if the code is running inside a container
// See https://github.com/docker/docker/blob/a9fa38b1edf30b23cae3eade0be48b3d4b1de14b/daemon/initlayer/setup_unix.go#L25
func InAContainer() bool {
- return inAContainer("/.dockerenv")
+ return inAContainer(DockerEnvFile)
}
func inAContainer(path string) bool {
diff --git a/modules/etcd/etcd.go b/modules/etcd/etcd.go
index 7ea78b4385..a715150bf1 100644
--- a/modules/etcd/etcd.go
+++ b/modules/etcd/etcd.go
@@ -29,18 +29,18 @@ type EtcdContainer struct {
// Terminate terminates the etcd container, its child nodes, and the network in which the cluster is running
// to communicate between the nodes.
-func (c *EtcdContainer) Terminate(ctx context.Context) error {
+func (c *EtcdContainer) Terminate(ctx context.Context, opts ...testcontainers.TerminateOption) error {
var errs []error
// child nodes has no other children
for i, child := range c.childNodes {
- if err := child.Terminate(ctx); err != nil {
+ if err := child.Terminate(ctx, opts...); err != nil {
errs = append(errs, fmt.Errorf("terminate child node(%d): %w", i, err))
}
}
if c.Container != nil {
- if err := c.Container.Terminate(ctx); err != nil {
+ if err := c.Container.Terminate(ctx, opts...); err != nil {
errs = append(errs, fmt.Errorf("terminate cluster node: %w", err))
}
}
diff --git a/modules/ollama/examples_test.go b/modules/ollama/examples_test.go
index 741db846be..188be45bbb 100644
--- a/modules/ollama/examples_test.go
+++ b/modules/ollama/examples_test.go
@@ -173,3 +173,73 @@ func ExampleRun_withModel_llama2_langchain() {
// Intentionally not asserting the output, as we don't want to run this example in the tests.
}
+
+// ExampleRun_withLocal demonstrates running the Ollama module against a local
+// Ollama binary (via WithUseLocal) instead of a Docker container, then using
+// langchaingo to query the served model.
+func ExampleRun_withLocal() {
+	ctx := context.Background()
+
+	// localOllama {
+	ollamaContainer, err := tcollama.Run(ctx, "ollama/ollama:0.3.13", tcollama.WithUseLocal("OLLAMA_DEBUG=true"))
+	defer func() {
+		if err := testcontainers.TerminateContainer(ollamaContainer); err != nil {
+			log.Printf("failed to terminate container: %s", err)
+		}
+	}()
+	if err != nil {
+		log.Printf("failed to start container: %s", err)
+		return
+	}
+	// }
+
+	model := "llama3.2:1b"
+
+	_, _, err = ollamaContainer.Exec(ctx, []string{"ollama", "pull", model})
+	if err != nil {
+		log.Printf("failed to pull model %s: %s", model, err)
+		return
+	}
+
+	_, _, err = ollamaContainer.Exec(ctx, []string{"ollama", "run", model})
+	if err != nil {
+		log.Printf("failed to run model %s: %s", model, err)
+		return
+	}
+
+	connectionStr, err := ollamaContainer.ConnectionString(ctx)
+	if err != nil {
+		log.Printf("failed to get connection string: %s", err)
+		return
+	}
+
+	var llm *langchainollama.LLM
+	if llm, err = langchainollama.New(
+		langchainollama.WithModel(model),
+		langchainollama.WithServerURL(connectionStr),
+	); err != nil {
+		log.Printf("failed to create langchain ollama: %s", err)
+		return
+	}
+
+	completion, err := llm.Call(
+		context.Background(),
+		"how can Testcontainers help with testing?",
+		llms.WithSeed(42),         // a fixed seed makes the completion reproducible
+		llms.WithTemperature(0.0), // the lower the temperature, the more deterministic the completion
+	)
+	if err != nil {
+		log.Printf("failed to generate completion: %s", err)
+		return
+	}
+
+	words := []string{
+		"easy", "isolation", "consistency",
+	}
+	lwCompletion := strings.ToLower(completion)
+
+	for _, word := range words {
+		if strings.Contains(lwCompletion, word) {
+			fmt.Println(true)
+		}
+	}
+
+	// Intentionally not asserting the output, as we don't want to run this example in the tests.
+}
diff --git a/modules/ollama/go.mod b/modules/ollama/go.mod
index e22b801031..2aab83b978 100644
--- a/modules/ollama/go.mod
+++ b/modules/ollama/go.mod
@@ -4,6 +4,7 @@ go 1.22
require (
github.com/docker/docker v27.1.1+incompatible
+ github.com/docker/go-connections v0.5.0
github.com/google/uuid v1.6.0
github.com/stretchr/testify v1.9.0
github.com/testcontainers/testcontainers-go v0.34.0
@@ -22,7 +23,6 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/dlclark/regexp2 v1.8.1 // indirect
- github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.1 // indirect
diff --git a/modules/ollama/local.go b/modules/ollama/local.go
new file mode 100644
index 0000000000..5751ceee07
--- /dev/null
+++ b/modules/ollama/local.go
@@ -0,0 +1,755 @@
+package ollama
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "io/fs"
+ "net"
+ "os"
+ "os/exec"
+ "reflect"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+
+ "github.com/docker/docker/api/types"
+ "github.com/docker/docker/api/types/container"
+ "github.com/docker/docker/api/types/network"
+ "github.com/docker/docker/errdefs"
+ "github.com/docker/docker/pkg/stdcopy"
+ "github.com/docker/go-connections/nat"
+
+ "github.com/testcontainers/testcontainers-go"
+ tcexec "github.com/testcontainers/testcontainers-go/exec"
+ "github.com/testcontainers/testcontainers-go/wait"
+)
+
+const (
+	// localPort is the port the local Ollama binary is expected to expose.
+	localPort = "11434"
+	// localBinary is the name of the Ollama binary looked up on the PATH.
+	localBinary = "ollama"
+	// localServeArg is the argument used to start the Ollama server.
+	localServeArg = "serve"
+	// localLogRegex extracts the listening address and version from the serve log.
+	localLogRegex = `Listening on (.*:\d+) \(version\s(.*)\)`
+	// localNamePrefix prefixes the emulated container name and image.
+	localNamePrefix = "local-ollama"
+	// localHostVar is the environment variable Ollama reads for its listen address.
+	localHostVar = "OLLAMA_HOST"
+	// localLogVar is the environment variable that overrides the log file location.
+	localLogVar = "OLLAMA_LOGFILE"
+)
+
+var (
+	// Ensure localProcess implements the required interfaces.
+	_ testcontainers.Container           = (*localProcess)(nil)
+	_ testcontainers.ContainerCustomizer = (*localProcess)(nil)
+
+	// zeroTime is the zero time value, used to reset finishedAt between runs.
+	zeroTime time.Time
+)
+
+// localProcess emulates the Ollama container using a local process to improve performance.
+type localProcess struct {
+	sessionID string
+
+	// env is the combined environment variables passed to the Ollama binary.
+	env []string
+
+	// cmd is the command that runs the Ollama binary, not valid externally if nil.
+	cmd *exec.Cmd
+
+	// logName and logFile are the file where the Ollama logs are written.
+	logName string
+	logFile *os.File
+
+	// host, port and version are extracted from log on startup.
+	host    string
+	port    string
+	version string
+
+	// waitFor is the strategy to wait for the process to be ready.
+	waitFor wait.Strategy
+
+	// done is closed when the process is finished.
+	done chan struct{}
+
+	// wg is used to wait for the process to finish.
+	wg sync.WaitGroup
+
+	// startedAt is the time when the process started.
+	startedAt time.Time
+
+	// mtx is used to synchronize access to the process state fields below
+	// (finishedAt and exitErr), which are written by the waiter goroutine.
+	mtx sync.Mutex
+
+	// finishedAt is the time when the process finished.
+	finishedAt time.Time
+
+	// exitErr is the error returned by the process.
+	exitErr error
+
+	// binary is the name of the Ollama binary. It is set during request
+	// validation, before the process starts, so it is not guarded by mtx.
+	binary string
+}
+
+// run starts the local Ollama binary configured by req and returns an
+// OllamaContainer that uses it instead of a Docker container.
+//
+// On start failure it still returns a non-nil container when the process was
+// created, so the caller can terminate it and release its resources.
+func (c *localProcess) run(ctx context.Context, req testcontainers.GenericContainerRequest) (*OllamaContainer, error) {
+	if err := c.validateRequest(req); err != nil {
+		return nil, fmt.Errorf("validate request: %w", err)
+	}
+
+	// Apply the updated details from the request.
+	c.waitFor = req.WaitingFor
+	c.env = c.env[:0]
+	for k, v := range req.Env {
+		c.env = append(c.env, k+"="+v)
+		if k == localLogVar {
+			// Allow the log file location to be overridden via the environment.
+			c.logName = v
+		}
+	}
+
+	err := c.Start(ctx)
+	var container *OllamaContainer
+	if c.cmd != nil {
+		// The process was created; wrap it so the caller can terminate it even on error.
+		container = &OllamaContainer{Container: c}
+	}
+
+	if err != nil {
+		return container, fmt.Errorf("start ollama: %w", err)
+	}
+
+	return container, nil
+}
+
+// validateRequest checks that req is valid for the local Ollama binary.
+// Only the fields this emulation understands may be set; any other non-zero
+// field in req results in an error.
+func (c *localProcess) validateRequest(req testcontainers.GenericContainerRequest) error {
+	var errs []error
+	if req.WaitingFor == nil {
+		errs = append(errs, errors.New("ContainerRequest.WaitingFor must be set"))
+	}
+
+	if !req.Started {
+		errs = append(errs, errors.New("Started must be true"))
+	}
+
+	if !reflect.DeepEqual(req.ExposedPorts, []string{localPort + "/tcp"}) {
+		errs = append(errs, fmt.Errorf("ContainerRequest.ExposedPorts must be %s/tcp got: %s", localPort, req.ExposedPorts))
+	}
+
+	// Validate the image and extract the binary name.
+	// The image must be in the format "[<path>/]<binary>[:latest]".
+	if binary := req.Image; binary != "" {
+		// Check if the version is "latest" or not specified.
+		if idx := strings.IndexByte(binary, ':'); idx != -1 {
+			if binary[idx+1:] != "latest" {
+				errs = append(errs, fmt.Errorf(`ContainerRequest.Image version must be blank or "latest", got: %q`, binary[idx+1:]))
+			}
+			binary = binary[:idx]
+		}
+
+		// Trim the path if present.
+		if idx := strings.LastIndexByte(binary, '/'); idx != -1 {
+			binary = binary[idx+1:]
+		}
+
+		// The binary must be resolvable on the PATH to be usable.
+		if _, err := exec.LookPath(binary); err != nil {
+			errs = append(errs, fmt.Errorf("invalid image %q: %w", req.Image, err))
+		} else {
+			c.binary = binary
+		}
+	}
+
+	// Reset fields we support to their zero values.
+	// req is a copy, so this does not affect the caller's request.
+	req.Env = nil
+	req.ExposedPorts = nil
+	req.WaitingFor = nil
+	req.Image = ""
+	req.Started = false
+	req.Logger = nil // We don't need the logger.
+
+	// Walk the remaining visible fields; any leaf field that is still non-zero
+	// is an option this emulation does not support.
+	parts := make([]string, 0, 3)
+	value := reflect.ValueOf(req)
+	typ := value.Type()
+	fields := reflect.VisibleFields(typ)
+	for _, f := range fields {
+		field := value.FieldByIndex(f.Index)
+		if field.Kind() == reflect.Struct {
+			// Only check the leaf fields.
+			continue
+		}
+
+		if !field.IsZero() {
+			// Build the dotted path to the offending field for the error message.
+			parts = parts[:0]
+			for i := range f.Index {
+				parts = append(parts, typ.FieldByIndex(f.Index[:i+1]).Name)
+			}
+			errs = append(errs, fmt.Errorf("unsupported field: %s = %q", strings.Join(parts, "."), field))
+		}
+	}
+
+	return errors.Join(errs...)
+}
+
+// Start implements testcontainers.Container interface for the local Ollama binary.
+// It creates the log file, starts "ollama serve" in the background, launches a
+// goroutine that records when and how the process exits, and finally runs the
+// configured wait strategy. It returns an error if the process is already running.
+func (c *localProcess) Start(ctx context.Context) error {
+	if c.IsRunning() {
+		return errors.New("already running")
+	}
+
+	cmd := exec.CommandContext(ctx, c.binary, localServeArg)
+	cmd.Env = c.env
+
+	var err error
+	c.logFile, err = os.Create(c.logName)
+	if err != nil {
+		return fmt.Errorf("create ollama log file: %w", err)
+	}
+
+	// Multiplex stdout and stderr to the log file matching the Docker API.
+	cmd.Stdout = stdcopy.NewStdWriter(c.logFile, stdcopy.Stdout)
+	cmd.Stderr = stdcopy.NewStdWriter(c.logFile, stdcopy.Stderr)
+
+	// Run the ollama serve command in background.
+	if err = cmd.Start(); err != nil {
+		return fmt.Errorf("start ollama serve: %w", errors.Join(err, c.cleanup()))
+	}
+
+	// Past this point, the process was started successfully.
+	c.cmd = cmd
+	c.startedAt = time.Now()
+
+	// Reset the details to allow multiple start / stop cycles.
+	c.done = make(chan struct{})
+	c.mtx.Lock()
+	c.finishedAt = zeroTime
+	c.exitErr = nil
+	c.mtx.Unlock()
+
+	// Wait for the process to finish in a goroutine.
+	c.wg.Add(1)
+	go func() {
+		defer func() {
+			c.wg.Done()
+			close(c.done) // Signal watchers (IsRunning, Stop) that the process exited.
+		}()
+
+		err := c.cmd.Wait()
+		c.mtx.Lock()
+		defer c.mtx.Unlock()
+		if err != nil {
+			c.exitErr = fmt.Errorf("process wait: %w", err)
+		}
+		c.finishedAt = time.Now()
+	}()
+
+	if err = c.waitStrategy(ctx); err != nil {
+		return fmt.Errorf("wait strategy: %w", err)
+	}
+
+	return nil
+}
+
+// waitStrategy waits until the Ollama process is ready using the configured
+// wait strategy. On failure, the returned error is decorated with the process
+// stdout and stderr read back from the log file to aid debugging.
+func (c *localProcess) waitStrategy(ctx context.Context) error {
+	if err := c.waitFor.WaitUntilReady(ctx, c); err != nil {
+		logs, lerr := c.Logs(ctx)
+		if lerr != nil {
+			return errors.Join(err, lerr)
+		}
+		defer logs.Close()
+
+		// Demultiplex the Docker-style stream back into stdout and stderr.
+		var stderr, stdout bytes.Buffer
+		_, cerr := stdcopy.StdCopy(&stdout, &stderr, logs)
+
+		return fmt.Errorf(
+			"%w (stdout: %s, stderr: %s)",
+			errors.Join(err, cerr),
+			strings.TrimSpace(stdout.String()),
+			strings.TrimSpace(stderr.String()),
+		)
+	}
+
+	return nil
+}
+
+// extractLogDetails is a [wait.ForLog] submatch callback that extracts the
+// listening address and version from the serve log. On success it records the
+// host, port and version, and appends OLLAMA_HOST to the environment so that
+// subsequently Exec'd commands target the started server.
+func (c *localProcess) extractLogDetails(pattern string, submatches [][][]byte) error {
+	var err error
+	for _, matches := range submatches {
+		if len(matches) != 3 {
+			// Not the expected address+version match; remember why and keep looking.
+			err = fmt.Errorf("`%s` matched %d times, expected %d", pattern, len(matches), 3)
+			continue
+		}
+
+		c.host, c.port, err = net.SplitHostPort(string(matches[1]))
+		if err != nil {
+			// A malformed address will never improve on retry, so stop waiting.
+			return wait.NewPermanentError(fmt.Errorf("split host port: %w", err))
+		}
+
+		// Set OLLAMA_HOST variable to the extracted host so Exec can use it.
+		c.env = append(c.env, localHostVar+"="+string(matches[1]))
+		c.version = string(matches[2])
+
+		return nil
+	}
+
+	if err != nil {
+		// Return the last error encountered.
+		return err
+	}
+
+	return fmt.Errorf("address and version not found: `%s` no matches", pattern)
+}
+
+// ContainerIP implements testcontainers.Container interface for the local Ollama binary.
+// It returns the host extracted from the serve log; the error is always nil.
+func (c *localProcess) ContainerIP(ctx context.Context) (string, error) {
+	return c.host, nil
+}
+
+// ContainerIPs returns a slice with the IP address of the local Ollama binary.
+func (c *localProcess) ContainerIPs(ctx context.Context) ([]string, error) {
+	return []string{c.host}, nil
+}
+
+// CopyToContainer implements testcontainers.Container interface for the local Ollama binary.
+// Returns [errors.ErrUnsupported] as file transfer does not apply to a local process.
+func (c *localProcess) CopyToContainer(ctx context.Context, fileContent []byte, containerFilePath string, fileMode int64) error {
+	return errors.ErrUnsupported
+}
+
+// CopyDirToContainer implements testcontainers.Container interface for the local Ollama binary.
+// Returns [errors.ErrUnsupported].
+func (c *localProcess) CopyDirToContainer(ctx context.Context, hostDirPath string, containerParentPath string, fileMode int64) error {
+	return errors.ErrUnsupported
+}
+
+// CopyFileToContainer implements testcontainers.Container interface for the local Ollama binary.
+// Returns [errors.ErrUnsupported].
+func (c *localProcess) CopyFileToContainer(ctx context.Context, hostFilePath string, containerFilePath string, fileMode int64) error {
+	return errors.ErrUnsupported
+}
+
+// CopyFileFromContainer implements testcontainers.Container interface for the local Ollama binary.
+// Returns [errors.ErrUnsupported].
+func (c *localProcess) CopyFileFromContainer(ctx context.Context, filePath string) (io.ReadCloser, error) {
+	return nil, errors.ErrUnsupported
+}
+
+// GetLogProductionErrorChannel implements testcontainers.Container interface for the local Ollama binary.
+// It returns a nil channel because the local Ollama binary doesn't have a production error channel.
+func (c *localProcess) GetLogProductionErrorChannel() <-chan error {
+	return nil
+}
+
+// Exec implements testcontainers.Container interface for the local Ollama binary.
+// It executes a command using the local Ollama binary and returns the exit status
+// of the executed command, an [io.Reader] containing the combined stdout and stderr,
+// and any encountered error.
+//
+// Only commands invoking the Ollama binary itself are supported; anything else
+// returns [errors.ErrUnsupported].
+//
+// Reading directly from the [io.Reader] may result in unexpected bytes due to custom
+// stream multiplexing headers. Use [tcexec.Multiplexed] option to read the combined output
+// without the multiplexing headers.
+// Alternatively, to separate the stdout and stderr from [io.Reader] and interpret these
+// headers properly, [stdcopy.StdCopy] from the Docker API should be used.
+func (c *localProcess) Exec(ctx context.Context, cmd []string, options ...tcexec.ProcessOption) (int, io.Reader, error) {
+	if len(cmd) == 0 {
+		return 1, nil, errors.New("no command provided")
+	} else if cmd[0] != c.binary {
+		return 1, nil, fmt.Errorf("command %q: %w", cmd[0], errors.ErrUnsupported)
+	}
+
+	command := exec.CommandContext(ctx, cmd[0], cmd[1:]...)
+	command.Env = c.env
+
+	// Multiplex stdout and stderr to the buffer so they can be read separately later.
+	var buf bytes.Buffer
+	command.Stdout = stdcopy.NewStdWriter(&buf, stdcopy.Stdout)
+	command.Stderr = stdcopy.NewStdWriter(&buf, stdcopy.Stderr)
+
+	// Use process options to customize the command execution
+	// emulating the Docker API behaviour.
+	processOptions := tcexec.NewProcessOptions(cmd)
+	processOptions.Reader = &buf
+	for _, o := range options {
+		o.Apply(processOptions)
+	}
+
+	if err := c.validateExecOptions(processOptions.ExecConfig); err != nil {
+		return 1, nil, fmt.Errorf("validate exec option: %w", err)
+	}
+
+	if !processOptions.ExecConfig.AttachStderr {
+		command.Stderr = io.Discard
+	}
+	if !processOptions.ExecConfig.AttachStdout {
+		command.Stdout = io.Discard
+	}
+	if processOptions.ExecConfig.AttachStdin {
+		command.Stdin = os.Stdin
+	}
+
+	command.Dir = processOptions.ExecConfig.WorkingDir
+	command.Env = append(command.Env, processOptions.ExecConfig.Env...)
+
+	if err := command.Run(); err != nil {
+		// ProcessState is nil when the command failed to start, for example due
+		// to an invalid working directory, so guard against a nil pointer
+		// dereference and report a generic failure exit code instead.
+		code := 1
+		if command.ProcessState != nil {
+			code = command.ProcessState.ExitCode()
+		}
+		return code, processOptions.Reader, fmt.Errorf("exec %v: %w", cmd, err)
+	}
+
+	return command.ProcessState.ExitCode(), processOptions.Reader, nil
+}
+
+// validateExecOptions checks if the given exec options are supported by the
+// local Ollama binary. Each unsupported option contributes its own error, and
+// all of them are joined into the returned error.
+func (c *localProcess) validateExecOptions(options container.ExecOptions) error {
+	unsupported := []struct {
+		set  bool
+		name string
+	}{
+		{options.User != "", "user"},
+		{options.Privileged, "privileged"},
+		{options.Tty, "tty"},
+		{options.Detach, "detach"},
+		{options.DetachKeys != "", "detach keys"},
+	}
+
+	var errs []error
+	for _, opt := range unsupported {
+		if opt.set {
+			errs = append(errs, fmt.Errorf("%s: %w", opt.name, errors.ErrUnsupported))
+		}
+	}
+
+	return errors.Join(errs...)
+}
+
+// Inspect implements testcontainers.Container interface for the local Ollama binary.
+// It returns a ContainerJSON with the state of the local Ollama binary, with the
+// image, ports and network settings synthesized from the extracted log details
+// to mimic what the Docker API would report for a real container.
+func (c *localProcess) Inspect(ctx context.Context) (*types.ContainerJSON, error) {
+	state, err := c.State(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("state: %w", err)
+	}
+
+	return &types.ContainerJSON{
+		ContainerJSONBase: &types.ContainerJSONBase{
+			ID:    c.GetContainerID(),
+			Name:  localNamePrefix + "-" + c.sessionID,
+			State: state,
+		},
+		Config: &container.Config{
+			Image: localNamePrefix + ":" + c.version,
+			ExposedPorts: nat.PortSet{
+				nat.Port(localPort + "/tcp"): struct{}{},
+			},
+			Hostname:   c.host,
+			Entrypoint: []string{c.binary, localServeArg},
+		},
+		NetworkSettings: &types.NetworkSettings{
+			Networks: map[string]*network.EndpointSettings{},
+			NetworkSettingsBase: types.NetworkSettingsBase{
+				Bridge: "bridge",
+				Ports: nat.PortMap{
+					nat.Port(localPort + "/tcp"): {
+						{HostIP: c.host, HostPort: c.port},
+					},
+				},
+			},
+			DefaultNetworkSettings: types.DefaultNetworkSettings{
+				IPAddress: c.host,
+			},
+		},
+	}, nil
+}
+
+// IsRunning implements testcontainers.Container interface for the local Ollama binary.
+// It returns true if the local Ollama process is running, false otherwise.
+// The done channel is re-created on every Start, so this works across
+// multiple start / stop cycles.
+func (c *localProcess) IsRunning() bool {
+	if c.startedAt.IsZero() {
+		// The process hasn't started yet.
+		return false
+	}
+
+	select {
+	case <-c.done:
+		// The process exited.
+		return false
+	default:
+		// The process is still running.
+		return true
+	}
+}
+
+// Logs implements testcontainers.Container interface for the local Ollama binary.
+// It returns the logs from the local Ollama binary by opening a fresh read
+// handle on the log file, so the caller's reads don't interfere with writes.
+// NOTE(review): c.logFile is nil after cleanup has run — calling Logs on a
+// terminated instance would dereference nil; confirm callers can't do that.
+func (c *localProcess) Logs(ctx context.Context) (io.ReadCloser, error) {
+	file, err := os.Open(c.logFile.Name())
+	if err != nil {
+		return nil, fmt.Errorf("open log file: %w", err)
+	}
+
+	return file, nil
+}
+
+// State implements testcontainers.Container interface for the local Ollama binary.
+// It returns the current state of the Ollama process, simulating a container state.
+// NOTE(review): assumes c.cmd is non-nil, i.e. Start has been called at least
+// once — confirm no caller can reach this on a never-started instance.
+func (c *localProcess) State(ctx context.Context) (*types.ContainerState, error) {
+	c.mtx.Lock()
+	defer c.mtx.Unlock()
+
+	if !c.IsRunning() {
+		state := &types.ContainerState{
+			Status:     "exited",
+			ExitCode:   c.cmd.ProcessState.ExitCode(),
+			StartedAt:  c.startedAt.Format(time.RFC3339Nano),
+			FinishedAt: c.finishedAt.Format(time.RFC3339Nano),
+		}
+		if c.exitErr != nil {
+			state.Error = c.exitErr.Error()
+		}
+
+		return state, nil
+	}
+
+	// Setting the Running field because it's required by the wait strategy
+	// to check if the given log message is present.
+	return &types.ContainerState{
+		Status:     "running",
+		Running:    true,
+		Pid:        c.cmd.Process.Pid,
+		StartedAt:  c.startedAt.Format(time.RFC3339Nano),
+		FinishedAt: c.finishedAt.Format(time.RFC3339Nano),
+	}, nil
+}
+
+// Stop implements testcontainers.Container interface for the local Ollama binary.
+// It gracefully stops the local Ollama process by sending SIGTERM and then
+// waiting for it to exit, bounded by d (if non-nil) and by ctx.
+func (c *localProcess) Stop(ctx context.Context, d *time.Duration) error {
+	if err := c.cmd.Process.Signal(syscall.SIGTERM); err != nil {
+		return fmt.Errorf("signal ollama: %w", err)
+	}
+
+	if d != nil {
+		// Cap how long we wait for the process to exit.
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, *d)
+		defer cancel()
+	}
+
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-c.done:
+		// The process exited; report how it exited, if with an error.
+		c.mtx.Lock()
+		defer c.mtx.Unlock()
+
+		return c.exitErr
+	}
+}
+
+// Terminate implements testcontainers.Container interface for the local Ollama binary.
+// It stops the local Ollama process, removing the log file. It first tries a
+// graceful stop; if the process is still running it force-kills it, then runs
+// cleanup. Errors that are safe to ignore (already finished, deadline) don't abort.
+func (c *localProcess) Terminate(ctx context.Context, opts ...testcontainers.TerminateOption) error {
+	options := testcontainers.NewTerminateOptions(ctx, opts...)
+	// First try to stop gracefully.
+	if err := c.Stop(options.Context(), options.StopTimeout()); !c.isCleanupSafe(err) {
+		return fmt.Errorf("stop: %w", err)
+	}
+
+	var errs []error
+	if c.IsRunning() {
+		// Still running, force kill.
+		if err := c.cmd.Process.Kill(); !c.isCleanupSafe(err) {
+			// Best effort so we can continue with the cleanup.
+			errs = append(errs, fmt.Errorf("kill: %w", err))
+		}
+
+		// Wait for the process to exit so we can capture any error.
+		c.wg.Wait()
+	}
+
+	errs = append(errs, c.cleanup(), options.Cleanup())
+
+	return errors.Join(errs...)
+}
+
+// cleanup performs all clean up, closing and removing the log file if set.
+// It is safe to call multiple times: logFile is nilled out on the first call,
+// and subsequent calls just return the recorded exit error.
+func (c *localProcess) cleanup() error {
+	c.mtx.Lock()
+	defer c.mtx.Unlock()
+
+	if c.logFile == nil {
+		// Already cleaned up; surface only the process exit error, if any.
+		return c.exitErr
+	}
+
+	var errs []error
+	if c.exitErr != nil {
+		errs = append(errs, fmt.Errorf("exit: %w", c.exitErr))
+	}
+
+	if err := c.logFile.Close(); err != nil {
+		errs = append(errs, fmt.Errorf("close log: %w", err))
+	}
+
+	// A missing file is fine: it may never have been created or was removed already.
+	if err := os.Remove(c.logFile.Name()); err != nil && !errors.Is(err, fs.ErrNotExist) {
+		errs = append(errs, fmt.Errorf("remove log: %w", err))
+	}
+
+	c.logFile = nil // Prevent double cleanup.
+
+	return errors.Join(errs...)
+}
+
+// Endpoint implements testcontainers.Container interface for the local Ollama binary.
+// It returns proto://host:port string for the Ollama port.
+// It returns just host:port if proto is blank.
+func (c *localProcess) Endpoint(ctx context.Context, proto string) (string, error) {
+	return c.PortEndpoint(ctx, localPort, proto)
+}
+
+// GetContainerID implements testcontainers.Container interface for the local Ollama binary.
+// The ID is synthesized from the name prefix and session ID, as there is no real container.
+func (c *localProcess) GetContainerID() string {
+	return localNamePrefix + "-" + c.sessionID
+}
+
+// Host implements testcontainers.Container interface for the local Ollama binary.
+// It returns the host extracted from the serve log; the error is always nil.
+func (c *localProcess) Host(ctx context.Context) (string, error) {
+	return c.host, nil
+}
+
+// MappedPort implements testcontainers.Container interface for the local Ollama binary.
+// It returns the actual port the local process is listening on when the Ollama
+// TCP port is requested, or a not-found error for any other port.
+func (c *localProcess) MappedPort(ctx context.Context, port nat.Port) (nat.Port, error) {
+	if port.Port() != localPort || port.Proto() != "tcp" {
+		return "", errdefs.NotFound(fmt.Errorf("port %q not found", port))
+	}
+
+	return nat.Port(c.port + "/tcp"), nil
+}
+
+// Networks implements testcontainers.Container interface for the local Ollama binary.
+// It returns a nil slice, as a local process is not attached to any Docker network.
+func (c *localProcess) Networks(ctx context.Context) ([]string, error) {
+	return nil, nil
+}
+
+// NetworkAliases implements testcontainers.Container interface for the local Ollama binary.
+// It returns a nil map, as a local process has no network aliases.
+func (c *localProcess) NetworkAliases(ctx context.Context) (map[string][]string, error) {
+	return nil, nil
+}
+
+// PortEndpoint implements testcontainers.Container interface for the local Ollama binary.
+// It returns proto://host:port string for the given exposed port.
+// It returns just host:port if proto is blank.
+func (c *localProcess) PortEndpoint(ctx context.Context, port nat.Port, proto string) (string, error) {
+	host, err := c.Host(ctx)
+	if err != nil {
+		return "", fmt.Errorf("host: %w", err)
+	}
+
+	mapped, err := c.MappedPort(ctx, port)
+	if err != nil {
+		return "", fmt.Errorf("mapped port: %w", err)
+	}
+
+	endpoint := host + ":" + mapped.Port()
+	if proto != "" {
+		endpoint = proto + "://" + endpoint
+	}
+
+	return endpoint, nil
+}
+
+// SessionID implements testcontainers.Container interface for the local Ollama binary.
+func (c *localProcess) SessionID() string {
+	return c.sessionID
+}
+
+// Deprecated: it will be removed in the next major release.
+// FollowOutput is not implemented for the local Ollama binary.
+// It panics if called.
+func (c *localProcess) FollowOutput(consumer testcontainers.LogConsumer) {
+	panic("not implemented")
+}
+
+// Deprecated: use c.Inspect(ctx).NetworkSettings.Ports instead.
+// Ports gets the exposed ports for the container.
+func (c *localProcess) Ports(ctx context.Context) (nat.PortMap, error) {
+	inspect, err := c.Inspect(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	return inspect.NetworkSettings.Ports, nil
+}
+
+// Deprecated: it will be removed in the next major release.
+// StartLogProducer implements testcontainers.Container interface for the local Ollama binary.
+// It returns an error because the local Ollama binary doesn't have a log producer.
+func (c *localProcess) StartLogProducer(context.Context, ...testcontainers.LogProductionOption) error {
+	return errors.ErrUnsupported
+}
+
+// Deprecated: it will be removed in the next major release.
+// StopLogProducer implements testcontainers.Container interface for the local Ollama binary.
+// It returns an error because the local Ollama binary doesn't have a log producer.
+func (c *localProcess) StopLogProducer() error {
+	return errors.ErrUnsupported
+}
+
+// Deprecated: Use c.Inspect(ctx).Name instead.
+// Name returns the name for the local Ollama binary.
+func (c *localProcess) Name(context.Context) (string, error) {
+	return localNamePrefix + "-" + c.sessionID, nil
+}
+
+// Customize implements the [testcontainers.ContainerCustomizer] interface.
+// It configures the environment variables set by [WithUseLocal] and sets up
+// the wait strategy to extract the host, port and version from the log.
+//
+// Environment precedence, lowest to highest: the "localhost:0" default for
+// OLLAMA_HOST, then the OS environment, then the user-provided variables.
+func (c *localProcess) Customize(req *testcontainers.GenericContainerRequest) error {
+	// Replace the default host port strategy with one that waits for a log entry
+	// and extracts the host, port and version from it.
+	if err := wait.Walk(&req.WaitingFor, func(w wait.Strategy) error {
+		if _, ok := w.(*wait.HostPortStrategy); ok {
+			return wait.VisitRemove
+		}
+
+		return nil
+	}); err != nil {
+		return fmt.Errorf("walk strategies: %w", err)
+	}
+
+	logStrategy := wait.ForLog(localLogRegex).Submatch(c.extractLogDetails)
+	if req.WaitingFor == nil {
+		req.WaitingFor = logStrategy
+	} else {
+		req.WaitingFor = wait.ForAll(req.WaitingFor, logStrategy)
+	}
+
+	// Setup the environment variables using a random port by default
+	// to avoid conflicts.
+	osEnv := os.Environ()
+	env := make(map[string]string, len(osEnv)+len(c.env)+1)
+	env[localHostVar] = "localhost:0"
+	for _, kv := range append(osEnv, c.env...) {
+		parts := strings.SplitN(kv, "=", 2)
+		if len(parts) != 2 {
+			return fmt.Errorf("invalid environment variable: %q", kv)
+		}
+
+		// Later entries (user-provided c.env) override earlier OS values.
+		env[parts[0]] = parts[1]
+	}
+
+	return testcontainers.WithEnv(env)(req)
+}
+
+// isCleanupSafe reports whether err can safely be ignored during cleanup,
+// which is the case when it is one of:
+//   - nil
+//   - os: process already finished
+//   - context deadline exceeded
+func (c *localProcess) isCleanupSafe(err error) bool {
+	return err == nil ||
+		errors.Is(err, os.ErrProcessDone) ||
+		errors.Is(err, context.DeadlineExceeded)
+}
diff --git a/modules/ollama/local_test.go b/modules/ollama/local_test.go
new file mode 100644
index 0000000000..3e0376d4de
--- /dev/null
+++ b/modules/ollama/local_test.go
@@ -0,0 +1,636 @@
+package ollama_test
+
+import (
+ "context"
+ "errors"
+ "io"
+ "io/fs"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "regexp"
+ "testing"
+ "time"
+
+ "github.com/docker/docker/api/types"
+ "github.com/docker/docker/api/types/strslice"
+ "github.com/stretchr/testify/require"
+
+ "github.com/testcontainers/testcontainers-go"
+ tcexec "github.com/testcontainers/testcontainers-go/exec"
+ "github.com/testcontainers/testcontainers-go/modules/ollama"
+)
+
+const (
+ testImage = "ollama/ollama:latest"
+ testNatPort = "11434/tcp"
+ testHost = "127.0.0.1"
+ testBinary = "ollama"
+)
+
+var (
+ // reLogDetails matches the log details of the local ollama binary and should match localLogRegex.
+ reLogDetails = regexp.MustCompile(`Listening on (.*:\d+) \(version\s(.*)\)`)
+ zeroTime = time.Time{}.Format(time.RFC3339Nano)
+)
+
+func TestRun_local(t *testing.T) {
+ // check if the local ollama binary is available
+ if _, err := exec.LookPath(testBinary); err != nil {
+ t.Skip("local ollama binary not found, skipping")
+ }
+
+ ctx := context.Background()
+ ollamaContainer, err := ollama.Run(
+ ctx,
+ testImage,
+ ollama.WithUseLocal("FOO=BAR"),
+ )
+ testcontainers.CleanupContainer(t, ollamaContainer)
+ require.NoError(t, err)
+
+ t.Run("state", func(t *testing.T) {
+ state, err := ollamaContainer.State(ctx)
+ require.NoError(t, err)
+ require.NotEmpty(t, state.StartedAt)
+ require.NotEqual(t, zeroTime, state.StartedAt)
+ require.NotZero(t, state.Pid)
+ require.Equal(t, &types.ContainerState{
+ Status: "running",
+ Running: true,
+ Pid: state.Pid,
+ StartedAt: state.StartedAt,
+ FinishedAt: time.Time{}.Format(time.RFC3339Nano),
+ }, state)
+ })
+
+ t.Run("connection-string", func(t *testing.T) {
+ connectionStr, err := ollamaContainer.ConnectionString(ctx)
+ require.NoError(t, err)
+ require.NotEmpty(t, connectionStr)
+ })
+
+ t.Run("container-id", func(t *testing.T) {
+ id := ollamaContainer.GetContainerID()
+ require.Equal(t, "local-ollama-"+testcontainers.SessionID(), id)
+ })
+
+ t.Run("container-ips", func(t *testing.T) {
+ ip, err := ollamaContainer.ContainerIP(ctx)
+ require.NoError(t, err)
+ require.Equal(t, testHost, ip)
+
+ ips, err := ollamaContainer.ContainerIPs(ctx)
+ require.NoError(t, err)
+ require.Equal(t, []string{testHost}, ips)
+ })
+
+ t.Run("copy", func(t *testing.T) {
+ err := ollamaContainer.CopyToContainer(ctx, []byte("test"), "/tmp", 0o755)
+ require.Error(t, err)
+
+ err = ollamaContainer.CopyDirToContainer(ctx, ".", "/tmp", 0o755)
+ require.Error(t, err)
+
+ err = ollamaContainer.CopyFileToContainer(ctx, ".", "/tmp", 0o755)
+ require.Error(t, err)
+
+ reader, err := ollamaContainer.CopyFileFromContainer(ctx, "/tmp")
+ require.Error(t, err)
+ require.Nil(t, reader)
+ })
+
+ t.Run("log-production-error-channel", func(t *testing.T) {
+ ch := ollamaContainer.GetLogProductionErrorChannel()
+ require.Nil(t, ch)
+ })
+
+ t.Run("endpoint", func(t *testing.T) {
+ endpoint, err := ollamaContainer.Endpoint(ctx, "")
+ require.NoError(t, err)
+ require.Contains(t, endpoint, testHost+":")
+
+ endpoint, err = ollamaContainer.Endpoint(ctx, "http")
+ require.NoError(t, err)
+ require.Contains(t, endpoint, "http://"+testHost+":")
+ })
+
+ t.Run("is-running", func(t *testing.T) {
+ require.True(t, ollamaContainer.IsRunning())
+
+ err = ollamaContainer.Stop(ctx, nil)
+ require.NoError(t, err)
+ require.False(t, ollamaContainer.IsRunning())
+
+ // return it to the running state
+ err = ollamaContainer.Start(ctx)
+ require.NoError(t, err)
+ require.True(t, ollamaContainer.IsRunning())
+ })
+
+ t.Run("host", func(t *testing.T) {
+ host, err := ollamaContainer.Host(ctx)
+ require.NoError(t, err)
+ require.Equal(t, testHost, host)
+ })
+
+ t.Run("inspect", func(t *testing.T) {
+ inspect, err := ollamaContainer.Inspect(ctx)
+ require.NoError(t, err)
+
+ require.Equal(t, "local-ollama-"+testcontainers.SessionID(), inspect.ContainerJSONBase.ID)
+ require.Equal(t, "local-ollama-"+testcontainers.SessionID(), inspect.ContainerJSONBase.Name)
+ require.True(t, inspect.ContainerJSONBase.State.Running)
+
+ require.NotEmpty(t, inspect.Config.Image)
+ _, exists := inspect.Config.ExposedPorts[testNatPort]
+ require.True(t, exists)
+ require.Equal(t, testHost, inspect.Config.Hostname)
+ require.Equal(t, strslice.StrSlice(strslice.StrSlice{testBinary, "serve"}), inspect.Config.Entrypoint)
+
+ require.Empty(t, inspect.NetworkSettings.Networks)
+ require.Equal(t, "bridge", inspect.NetworkSettings.NetworkSettingsBase.Bridge)
+
+ ports := inspect.NetworkSettings.NetworkSettingsBase.Ports
+ port, exists := ports[testNatPort]
+ require.True(t, exists)
+ require.Len(t, port, 1)
+ require.Equal(t, testHost, port[0].HostIP)
+ require.NotEmpty(t, port[0].HostPort)
+ })
+
+ t.Run("logfile", func(t *testing.T) {
+ file, err := os.Open("local-ollama-" + testcontainers.SessionID() + ".log")
+ require.NoError(t, err)
+ require.NoError(t, file.Close())
+ })
+
+ t.Run("logs", func(t *testing.T) {
+ logs, err := ollamaContainer.Logs(ctx)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ require.NoError(t, logs.Close())
+ })
+
+ bs, err := io.ReadAll(logs)
+ require.NoError(t, err)
+ require.Regexp(t, reLogDetails, string(bs))
+ })
+
+ t.Run("mapped-port", func(t *testing.T) {
+ port, err := ollamaContainer.MappedPort(ctx, testNatPort)
+ require.NoError(t, err)
+ require.NotEmpty(t, port.Port())
+ require.Equal(t, "tcp", port.Proto())
+ })
+
+ t.Run("networks", func(t *testing.T) {
+ networks, err := ollamaContainer.Networks(ctx)
+ require.NoError(t, err)
+ require.Nil(t, networks)
+ })
+
+ t.Run("network-aliases", func(t *testing.T) {
+ aliases, err := ollamaContainer.NetworkAliases(ctx)
+ require.NoError(t, err)
+ require.Nil(t, aliases)
+ })
+
+ t.Run("port-endpoint", func(t *testing.T) {
+ endpoint, err := ollamaContainer.PortEndpoint(ctx, testNatPort, "")
+ require.NoError(t, err)
+ require.Regexp(t, regexp.MustCompile(`^127.0.0.1:\d+$`), endpoint)
+
+ endpoint, err = ollamaContainer.PortEndpoint(ctx, testNatPort, "http")
+ require.NoError(t, err)
+ require.Regexp(t, regexp.MustCompile(`^http://127.0.0.1:\d+$`), endpoint)
+ })
+
+ t.Run("session-id", func(t *testing.T) {
+ require.Equal(t, testcontainers.SessionID(), ollamaContainer.SessionID())
+ })
+
+ t.Run("stop-start", func(t *testing.T) {
+ d := time.Second * 5
+ err := ollamaContainer.Stop(ctx, &d)
+ require.NoError(t, err)
+
+ state, err := ollamaContainer.State(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "exited", state.Status)
+ require.NotEmpty(t, state.StartedAt)
+ require.NotEqual(t, zeroTime, state.StartedAt)
+ require.NotEmpty(t, state.FinishedAt)
+ require.NotEqual(t, zeroTime, state.FinishedAt)
+ require.Zero(t, state.ExitCode)
+
+ err = ollamaContainer.Start(ctx)
+ require.NoError(t, err)
+
+ state, err = ollamaContainer.State(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "running", state.Status)
+
+ logs, err := ollamaContainer.Logs(ctx)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ require.NoError(t, logs.Close())
+ })
+
+ bs, err := io.ReadAll(logs)
+ require.NoError(t, err)
+ require.Regexp(t, reLogDetails, string(bs))
+ })
+
+ t.Run("start-start", func(t *testing.T) {
+ state, err := ollamaContainer.State(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "running", state.Status)
+
+ err = ollamaContainer.Start(ctx)
+ require.Error(t, err)
+ })
+
+ t.Run("terminate", func(t *testing.T) {
+ err := ollamaContainer.Terminate(ctx)
+ require.NoError(t, err)
+
+ _, err = os.Stat("ollama-" + testcontainers.SessionID() + ".log")
+ require.ErrorIs(t, err, fs.ErrNotExist)
+
+ state, err := ollamaContainer.State(ctx)
+ require.NoError(t, err)
+ require.NotEmpty(t, state.StartedAt)
+ require.NotEqual(t, zeroTime, state.StartedAt)
+ require.NotEmpty(t, state.FinishedAt)
+ require.NotEqual(t, zeroTime, state.FinishedAt)
+ require.Equal(t, &types.ContainerState{
+ Status: "exited",
+ StartedAt: state.StartedAt,
+ FinishedAt: state.FinishedAt,
+ }, state)
+ })
+
+ t.Run("deprecated", func(t *testing.T) {
+ t.Run("ports", func(t *testing.T) {
+ inspect, err := ollamaContainer.Inspect(ctx)
+ require.NoError(t, err)
+
+ ports, err := ollamaContainer.Ports(ctx)
+ require.NoError(t, err)
+ require.Equal(t, inspect.NetworkSettings.Ports, ports)
+ })
+
+ t.Run("follow-output", func(t *testing.T) {
+ require.Panics(t, func() {
+ ollamaContainer.FollowOutput(&testcontainers.StdoutLogConsumer{})
+ })
+ })
+
+ t.Run("start-log-producer", func(t *testing.T) {
+ err := ollamaContainer.StartLogProducer(ctx)
+ require.ErrorIs(t, err, errors.ErrUnsupported)
+ })
+
+ t.Run("stop-log-producer", func(t *testing.T) {
+ err := ollamaContainer.StopLogProducer()
+ require.ErrorIs(t, err, errors.ErrUnsupported)
+ })
+
+ t.Run("name", func(t *testing.T) {
+ name, err := ollamaContainer.Name(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "local-ollama-"+testcontainers.SessionID(), name)
+ })
+ })
+}
+
+func TestRun_localWithCustomLogFile(t *testing.T) {
+ ctx := context.Background()
+ logFile := filepath.Join(t.TempDir(), "server.log")
+
+ t.Run("parent-env", func(t *testing.T) {
+ t.Setenv("OLLAMA_LOGFILE", logFile)
+
+ ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal())
+ testcontainers.CleanupContainer(t, ollamaContainer)
+ require.NoError(t, err)
+
+ logs, err := ollamaContainer.Logs(ctx)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ require.NoError(t, logs.Close())
+ })
+
+ bs, err := io.ReadAll(logs)
+ require.NoError(t, err)
+ require.Regexp(t, reLogDetails, string(bs))
+
+ file, ok := logs.(*os.File)
+ require.True(t, ok)
+ require.Equal(t, logFile, file.Name())
+ })
+
+ t.Run("local-env", func(t *testing.T) {
+ ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal("OLLAMA_LOGFILE="+logFile))
+ testcontainers.CleanupContainer(t, ollamaContainer)
+ require.NoError(t, err)
+
+ logs, err := ollamaContainer.Logs(ctx)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ require.NoError(t, logs.Close())
+ })
+
+ bs, err := io.ReadAll(logs)
+ require.NoError(t, err)
+ require.Regexp(t, reLogDetails, string(bs))
+
+ file, ok := logs.(*os.File)
+ require.True(t, ok)
+ require.Equal(t, logFile, file.Name())
+ })
+}
+
+func TestRun_localWithCustomHost(t *testing.T) {
+ ctx := context.Background()
+
+ t.Run("parent-env", func(t *testing.T) {
+ t.Setenv("OLLAMA_HOST", "127.0.0.1:1234")
+
+ ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal())
+ testcontainers.CleanupContainer(t, ollamaContainer)
+ require.NoError(t, err)
+
+ testRun_localWithCustomHost(ctx, t, ollamaContainer)
+ })
+
+ t.Run("local-env", func(t *testing.T) {
+ ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal("OLLAMA_HOST=127.0.0.1:1234"))
+ testcontainers.CleanupContainer(t, ollamaContainer)
+ require.NoError(t, err)
+
+ testRun_localWithCustomHost(ctx, t, ollamaContainer)
+ })
+}
+
+func testRun_localWithCustomHost(ctx context.Context, t *testing.T, ollamaContainer *ollama.OllamaContainer) {
+ t.Helper()
+
+ t.Run("connection-string", func(t *testing.T) {
+ connectionStr, err := ollamaContainer.ConnectionString(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "http://127.0.0.1:1234", connectionStr)
+ })
+
+ t.Run("endpoint", func(t *testing.T) {
+ endpoint, err := ollamaContainer.Endpoint(ctx, "http")
+ require.NoError(t, err)
+ require.Equal(t, "http://127.0.0.1:1234", endpoint)
+ })
+
+ t.Run("inspect", func(t *testing.T) {
+ inspect, err := ollamaContainer.Inspect(ctx)
+ require.NoError(t, err)
+ require.Regexp(t, regexp.MustCompile(`^local-ollama:\d+\.\d+\.\d+$`), inspect.Config.Image)
+
+ _, exists := inspect.Config.ExposedPorts[testNatPort]
+ require.True(t, exists)
+ require.Equal(t, testHost, inspect.Config.Hostname)
+ require.Equal(t, strslice.StrSlice(strslice.StrSlice{testBinary, "serve"}), inspect.Config.Entrypoint)
+
+ require.Empty(t, inspect.NetworkSettings.Networks)
+ require.Equal(t, "bridge", inspect.NetworkSettings.NetworkSettingsBase.Bridge)
+
+ ports := inspect.NetworkSettings.NetworkSettingsBase.Ports
+ port, exists := ports[testNatPort]
+ require.True(t, exists)
+ require.Len(t, port, 1)
+ require.Equal(t, testHost, port[0].HostIP)
+ require.Equal(t, "1234", port[0].HostPort)
+ })
+
+ t.Run("logs", func(t *testing.T) {
+ logs, err := ollamaContainer.Logs(ctx)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ require.NoError(t, logs.Close())
+ })
+
+ bs, err := io.ReadAll(logs)
+ require.NoError(t, err)
+
+ require.Contains(t, string(bs), "Listening on 127.0.0.1:1234")
+ })
+
+ t.Run("mapped-port", func(t *testing.T) {
+ port, err := ollamaContainer.MappedPort(ctx, testNatPort)
+ require.NoError(t, err)
+ require.Equal(t, "1234", port.Port())
+ require.Equal(t, "tcp", port.Proto())
+ })
+}
+
+func TestRun_localExec(t *testing.T) {
+ // check if the local ollama binary is available
+ if _, err := exec.LookPath(testBinary); err != nil {
+ t.Skip("local ollama binary not found, skipping")
+ }
+
+ ctx := context.Background()
+
+ ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal())
+ testcontainers.CleanupContainer(t, ollamaContainer)
+ require.NoError(t, err)
+
+ t.Run("no-command", func(t *testing.T) {
+ code, r, err := ollamaContainer.Exec(ctx, nil)
+ require.Error(t, err)
+ require.Equal(t, 1, code)
+ require.Nil(t, r)
+ })
+
+ t.Run("unsupported-command", func(t *testing.T) {
+ code, r, err := ollamaContainer.Exec(ctx, []string{"cat", "/etc/hosts"})
+ require.ErrorIs(t, err, errors.ErrUnsupported)
+ require.Equal(t, 1, code)
+ require.Nil(t, r)
+ })
+
+ t.Run("unsupported-option-user", func(t *testing.T) {
+ code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.WithUser("root"))
+ require.ErrorIs(t, err, errors.ErrUnsupported)
+ require.Equal(t, 1, code)
+ require.Nil(t, r)
+ })
+
+ t.Run("unsupported-option-privileged", func(t *testing.T) {
+ code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.ProcessOptionFunc(func(opts *tcexec.ProcessOptions) {
+ opts.ExecConfig.Privileged = true
+ }))
+ require.ErrorIs(t, err, errors.ErrUnsupported)
+ require.Equal(t, 1, code)
+ require.Nil(t, r)
+ })
+
+ t.Run("unsupported-option-tty", func(t *testing.T) {
+ code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.ProcessOptionFunc(func(opts *tcexec.ProcessOptions) {
+ opts.ExecConfig.Tty = true
+ }))
+ require.ErrorIs(t, err, errors.ErrUnsupported)
+ require.Equal(t, 1, code)
+ require.Nil(t, r)
+ })
+
+ t.Run("unsupported-option-detach", func(t *testing.T) {
+ code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.ProcessOptionFunc(func(opts *tcexec.ProcessOptions) {
+ opts.ExecConfig.Detach = true
+ }))
+ require.ErrorIs(t, err, errors.ErrUnsupported)
+ require.Equal(t, 1, code)
+ require.Nil(t, r)
+ })
+
+ t.Run("unsupported-option-detach-keys", func(t *testing.T) {
+ code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.ProcessOptionFunc(func(opts *tcexec.ProcessOptions) {
+ opts.ExecConfig.DetachKeys = "ctrl-p,ctrl-q"
+ }))
+ require.ErrorIs(t, err, errors.ErrUnsupported)
+ require.Equal(t, 1, code)
+ require.Nil(t, r)
+ })
+
+ t.Run("pull-and-run-model", func(t *testing.T) {
+ const model = "llama3.2:1b"
+
+ code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "pull", model})
+ require.NoError(t, err)
+ require.Zero(t, code)
+
+ bs, err := io.ReadAll(r)
+ require.NoError(t, err)
+ require.Contains(t, string(bs), "success")
+
+ code, r, err = ollamaContainer.Exec(ctx, []string{testBinary, "run", model}, tcexec.Multiplexed())
+ require.NoError(t, err)
+ require.Zero(t, code)
+
+ bs, err = io.ReadAll(r)
+ require.NoError(t, err)
+ require.Empty(t, bs)
+
+ logs, err := ollamaContainer.Logs(ctx)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ require.NoError(t, logs.Close())
+ })
+
+ bs, err = io.ReadAll(logs)
+ require.NoError(t, err)
+ require.Contains(t, string(bs), "llama runner started")
+ })
+}
+
+func TestRun_localValidateRequest(t *testing.T) {
+ // check if the local ollama binary is available
+ if _, err := exec.LookPath(testBinary); err != nil {
+ t.Skip("local ollama binary not found, skipping")
+ }
+
+ ctx := context.Background()
+ t.Run("waiting-for-nil", func(t *testing.T) {
+ ollamaContainer, err := ollama.Run(
+ ctx,
+ testImage,
+ ollama.WithUseLocal("FOO=BAR"),
+ testcontainers.CustomizeRequestOption(func(req *testcontainers.GenericContainerRequest) error {
+ req.WaitingFor = nil
+ return nil
+ }),
+ )
+ testcontainers.CleanupContainer(t, ollamaContainer)
+ require.EqualError(t, err, "validate request: ContainerRequest.WaitingFor must be set")
+ })
+
+ t.Run("started-false", func(t *testing.T) {
+ ollamaContainer, err := ollama.Run(
+ ctx,
+ testImage,
+ ollama.WithUseLocal("FOO=BAR"),
+ testcontainers.CustomizeRequestOption(func(req *testcontainers.GenericContainerRequest) error {
+ req.Started = false
+ return nil
+ }),
+ )
+ testcontainers.CleanupContainer(t, ollamaContainer)
+ require.EqualError(t, err, "validate request: Started must be true")
+ })
+
+ t.Run("exposed-ports-empty", func(t *testing.T) {
+ ollamaContainer, err := ollama.Run(
+ ctx,
+ testImage,
+ ollama.WithUseLocal("FOO=BAR"),
+ testcontainers.CustomizeRequestOption(func(req *testcontainers.GenericContainerRequest) error {
+ req.ExposedPorts = req.ExposedPorts[:0]
+ return nil
+ }),
+ )
+ testcontainers.CleanupContainer(t, ollamaContainer)
+ require.EqualError(t, err, "validate request: ContainerRequest.ExposedPorts must be 11434/tcp got: []")
+ })
+
+ t.Run("dockerfile-set", func(t *testing.T) {
+ ollamaContainer, err := ollama.Run(
+ ctx,
+ testImage,
+ ollama.WithUseLocal("FOO=BAR"),
+ testcontainers.CustomizeRequestOption(func(req *testcontainers.GenericContainerRequest) error {
+ req.Dockerfile = "FROM scratch"
+ return nil
+ }),
+ )
+ testcontainers.CleanupContainer(t, ollamaContainer)
+ require.EqualError(t, err, "validate request: unsupported field: ContainerRequest.FromDockerfile.Dockerfile = \"FROM scratch\"")
+ })
+
+ t.Run("image-only", func(t *testing.T) {
+ ollamaContainer, err := ollama.Run(
+ ctx,
+ testBinary,
+ ollama.WithUseLocal(),
+ )
+ testcontainers.CleanupContainer(t, ollamaContainer)
+ require.NoError(t, err)
+ })
+
+ t.Run("image-path", func(t *testing.T) {
+ ollamaContainer, err := ollama.Run(
+ ctx,
+ "prefix-path/"+testBinary,
+ ollama.WithUseLocal(),
+ )
+ testcontainers.CleanupContainer(t, ollamaContainer)
+ require.NoError(t, err)
+ })
+
+ t.Run("image-bad-version", func(t *testing.T) {
+ ollamaContainer, err := ollama.Run(
+ ctx,
+ testBinary+":bad-version",
+ ollama.WithUseLocal(),
+ )
+ testcontainers.CleanupContainer(t, ollamaContainer)
+ require.EqualError(t, err, `validate request: ContainerRequest.Image version must be blank or "latest", got: "bad-version"`)
+ })
+
+ t.Run("image-not-found", func(t *testing.T) {
+ ollamaContainer, err := ollama.Run(
+ ctx,
+ "ollama/ollama-not-found",
+ ollama.WithUseLocal(),
+ )
+ testcontainers.CleanupContainer(t, ollamaContainer)
+ require.EqualError(t, err, `validate request: invalid image "ollama/ollama-not-found": exec: "ollama-not-found": executable file not found in $PATH`)
+ })
+}
diff --git a/modules/ollama/ollama.go b/modules/ollama/ollama.go
index 203d80103f..4d78fa171e 100644
--- a/modules/ollama/ollama.go
+++ b/modules/ollama/ollama.go
@@ -27,12 +27,12 @@ type OllamaContainer struct {
func (c *OllamaContainer) ConnectionString(ctx context.Context) (string, error) {
host, err := c.Host(ctx)
if err != nil {
- return "", err
+ return "", fmt.Errorf("host: %w", err)
}
port, err := c.MappedPort(ctx, "11434/tcp")
if err != nil {
- return "", err
+ return "", fmt.Errorf("mapped port: %w", err)
}
return fmt.Sprintf("http://%s:%d", host, port.Int()), nil
@@ -43,6 +43,10 @@ func (c *OllamaContainer) ConnectionString(ctx context.Context) (string, error)
// of the container into a new image with the given name, so it doesn't override existing images.
// It should be used for creating an image that contains a loaded model.
func (c *OllamaContainer) Commit(ctx context.Context, targetImage string) error {
+ if _, ok := c.Container.(*localProcess); ok {
+ return nil
+ }
+
cli, err := testcontainers.NewDockerClientWithOpts(context.Background())
if err != nil {
return err
@@ -80,27 +84,34 @@ func RunContainer(ctx context.Context, opts ...testcontainers.ContainerCustomize
// Run creates an instance of the Ollama container type
func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustomizer) (*OllamaContainer, error) {
- req := testcontainers.ContainerRequest{
- Image: img,
- ExposedPorts: []string{"11434/tcp"},
- WaitingFor: wait.ForListeningPort("11434/tcp").WithStartupTimeout(60 * time.Second),
- }
-
- genericContainerReq := testcontainers.GenericContainerRequest{
- ContainerRequest: req,
- Started: true,
+ req := testcontainers.GenericContainerRequest{
+ ContainerRequest: testcontainers.ContainerRequest{
+ Image: img,
+ ExposedPorts: []string{"11434/tcp"},
+ WaitingFor: wait.ForListeningPort("11434/tcp").WithStartupTimeout(60 * time.Second),
+ },
+ Started: true,
}
- // always request a GPU if the host supports it
+ // Always request a GPU if the host supports it.
opts = append(opts, withGpu())
+ var local *localProcess
for _, opt := range opts {
- if err := opt.Customize(&genericContainerReq); err != nil {
+ if err := opt.Customize(&req); err != nil {
return nil, fmt.Errorf("customize: %w", err)
}
+ if l, ok := opt.(*localProcess); ok {
+ local = l
+ }
+ }
+
+ // Now we have processed all the options, we can check if we need to use the local process.
+ if local != nil {
+ return local.run(ctx, req)
}
- container, err := testcontainers.GenericContainer(ctx, genericContainerReq)
+ container, err := testcontainers.GenericContainer(ctx, req)
var c *OllamaContainer
if container != nil {
c = &OllamaContainer{Container: container}
diff --git a/modules/ollama/options.go b/modules/ollama/options.go
index 605768a379..1cf29453fe 100644
--- a/modules/ollama/options.go
+++ b/modules/ollama/options.go
@@ -11,7 +11,7 @@ import (
var noopCustomizeRequestOption = func(req *testcontainers.GenericContainerRequest) error { return nil }
// withGpu requests a GPU for the container, which could improve performance for some models.
-// This option will be automaticall added to the Ollama container to check if the host supports nvidia.
+// This option will be automatically added to the Ollama container to check if the host supports nvidia.
func withGpu() testcontainers.CustomizeRequestOption {
cli, err := testcontainers.NewDockerClientWithOpts(context.Background())
if err != nil {
@@ -37,3 +37,30 @@ func withGpu() testcontainers.CustomizeRequestOption {
}
})
}
+
+// WithUseLocal starts a local Ollama process with the given environment in
+// format KEY=VALUE instead of a Docker container, which can be more performant
+// as it has direct access to the GPU.
+// By default `OLLAMA_HOST=localhost:0` is set to avoid port conflicts.
+//
+// When using this option, the container request will be validated to ensure
+// that only the options that are compatible with the local process are used.
+//
+// Supported fields are:
+// - [testcontainers.GenericContainerRequest.Started] must be set to true
+// - [testcontainers.GenericContainerRequest.ExposedPorts] must be set to ["11434/tcp"]
+// - [testcontainers.ContainerRequest.WaitingFor] should not be changed from the default
+// - [testcontainers.ContainerRequest.Image] used to determine the local process binary as [path/]binary[:latest] if not blank.
+// - [testcontainers.ContainerRequest.Env] applied to all local process executions
+// - [testcontainers.GenericContainerRequest.Logger] is unused
+//
+// Any other leaf field not set to the type's zero value will result in an error.
+func WithUseLocal(envKeyValues ...string) *localProcess {
+ sessionID := testcontainers.SessionID()
+ return &localProcess{
+ sessionID: sessionID,
+ logName: localNamePrefix + "-" + sessionID + ".log",
+ env: envKeyValues,
+ binary: localBinary,
+ }
+}
diff --git a/modules/ollama/options_test.go b/modules/ollama/options_test.go
new file mode 100644
index 0000000000..f842d15a17
--- /dev/null
+++ b/modules/ollama/options_test.go
@@ -0,0 +1,49 @@
+package ollama_test
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "github.com/testcontainers/testcontainers-go"
+ "github.com/testcontainers/testcontainers-go/modules/ollama"
+)
+
+func TestWithUseLocal(t *testing.T) {
+ req := testcontainers.GenericContainerRequest{}
+
+ t.Run("keyVal/valid", func(t *testing.T) {
+ opt := ollama.WithUseLocal("OLLAMA_MODELS=/path/to/models")
+ err := opt.Customize(&req)
+ require.NoError(t, err)
+ require.Equal(t, "/path/to/models", req.Env["OLLAMA_MODELS"])
+ })
+
+ t.Run("keyVal/invalid", func(t *testing.T) {
+ opt := ollama.WithUseLocal("OLLAMA_MODELS")
+ err := opt.Customize(&req)
+ require.Error(t, err)
+ })
+
+ t.Run("keyVal/valid/multiple", func(t *testing.T) {
+ opt := ollama.WithUseLocal("OLLAMA_MODELS=/path/to/models", "OLLAMA_HOST=localhost")
+ err := opt.Customize(&req)
+ require.NoError(t, err)
+ require.Equal(t, "/path/to/models", req.Env["OLLAMA_MODELS"])
+ require.Equal(t, "localhost", req.Env["OLLAMA_HOST"])
+ })
+
+ t.Run("keyVal/valid/multiple-equals", func(t *testing.T) {
+ opt := ollama.WithUseLocal("OLLAMA_MODELS=/path/to/models", "OLLAMA_HOST=localhost=127.0.0.1")
+ err := opt.Customize(&req)
+ require.NoError(t, err)
+ require.Equal(t, "/path/to/models", req.Env["OLLAMA_MODELS"])
+ require.Equal(t, "localhost=127.0.0.1", req.Env["OLLAMA_HOST"])
+ })
+
+ t.Run("keyVal/invalid/multiple", func(t *testing.T) {
+ opt := ollama.WithUseLocal("OLLAMA_MODELS=/path/to/models", "OLLAMA_HOST")
+ err := opt.Customize(&req)
+ require.Error(t, err)
+ })
+}
diff --git a/modules/postgres/go.mod b/modules/postgres/go.mod
index 01ee22b8bd..955edef1de 100644
--- a/modules/postgres/go.mod
+++ b/modules/postgres/go.mod
@@ -6,6 +6,7 @@ require (
github.com/docker/go-connections v0.5.0
github.com/jackc/pgx/v5 v5.5.4
github.com/lib/pq v1.10.9
+ github.com/mdelapenya/tlscert v0.1.0
github.com/stretchr/testify v1.9.0
github.com/testcontainers/testcontainers-go v0.34.0
diff --git a/modules/postgres/go.sum b/modules/postgres/go.sum
index ba337a188c..9d1c34d3df 100644
--- a/modules/postgres/go.sum
+++ b/modules/postgres/go.sum
@@ -71,6 +71,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
+github.com/mdelapenya/tlscert v0.1.0 h1:YTpF579PYUX475eOL+6zyEO3ngLTOUWck78NBuJVXaM=
+github.com/mdelapenya/tlscert v0.1.0/go.mod h1:wrbyM/DwbFCeCeqdPX/8c6hNOqQgbf0rUDErE1uD+64=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
diff --git a/modules/postgres/postgres.go b/modules/postgres/postgres.go
index ce3719575d..e25bc667a6 100644
--- a/modules/postgres/postgres.go
+++ b/modules/postgres/postgres.go
@@ -3,6 +3,7 @@ package postgres
import (
"context"
"database/sql"
+ _ "embed"
"errors"
"fmt"
"io"
@@ -19,6 +20,9 @@ const (
defaultSnapshotName = "migrated_template"
)
+//go:embed resources/customEntrypoint.sh
+var embeddedCustomEntrypoint string
+
// PostgresContainer represents the postgres container type used in the module
type PostgresContainer struct {
testcontainers.Container
@@ -205,6 +209,43 @@ func WithSnapshotName(name string) SnapshotOption {
}
}
+// WithSSLCert configures the Postgres server to run with the provided CA certificate chain.
+// This will not function if the corresponding postgres conf is not correctly configured.
+// Namely the paths below must match what is set in the conf file.
+func WithSSLCert(caCertFile string, certFile string, keyFile string) testcontainers.CustomizeRequestOption {
+ const defaultPermission = 0o600
+
+ return func(req *testcontainers.GenericContainerRequest) error {
+ const entrypointPath = "/usr/local/bin/docker-entrypoint-ssl.bash"
+
+ req.Files = append(req.Files,
+ testcontainers.ContainerFile{
+ HostFilePath: caCertFile,
+ ContainerFilePath: "/tmp/testcontainers-go/postgres/ca_cert.pem",
+ FileMode: defaultPermission,
+ },
+ testcontainers.ContainerFile{
+ HostFilePath: certFile,
+ ContainerFilePath: "/tmp/testcontainers-go/postgres/server.cert",
+ FileMode: defaultPermission,
+ },
+ testcontainers.ContainerFile{
+ HostFilePath: keyFile,
+ ContainerFilePath: "/tmp/testcontainers-go/postgres/server.key",
+ FileMode: defaultPermission,
+ },
+ testcontainers.ContainerFile{
+ Reader: strings.NewReader(embeddedCustomEntrypoint),
+ ContainerFilePath: entrypointPath,
+ FileMode: defaultPermission,
+ },
+ )
+ req.Entrypoint = []string{"sh", entrypointPath}
+
+ return nil
+ }
+}
+
// Snapshot takes a snapshot of the current state of the database as a template, which can then be restored using
// the Restore method. By default, the snapshot will be created under a database called migrated_template, you can
// customize the snapshot name with the options.
diff --git a/modules/postgres/postgres_test.go b/modules/postgres/postgres_test.go
index 8c02e68476..e83b8e1454 100644
--- a/modules/postgres/postgres_test.go
+++ b/modules/postgres/postgres_test.go
@@ -3,7 +3,9 @@ package postgres_test
import (
"context"
"database/sql"
+ "errors"
"fmt"
+ "os"
"path/filepath"
"testing"
"time"
@@ -12,6 +14,8 @@ import (
"github.com/jackc/pgx/v5"
_ "github.com/jackc/pgx/v5/stdlib"
_ "github.com/lib/pq"
+ "github.com/mdelapenya/tlscert"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
@@ -25,6 +29,40 @@ const (
password = "password"
)
+func createSSLCerts(t *testing.T) (*tlscert.Certificate, *tlscert.Certificate, error) {
+ t.Helper()
+ tmpDir := t.TempDir()
+ certsDir := tmpDir + "/certs"
+
+ require.NoError(t, os.MkdirAll(certsDir, 0o755))
+
+ t.Cleanup(func() {
+ require.NoError(t, os.RemoveAll(tmpDir))
+ })
+
+ caCert := tlscert.SelfSignedFromRequest(tlscert.Request{
+ Host: "localhost",
+ Name: "ca-cert",
+ ParentDir: certsDir,
+ })
+
+ if caCert == nil {
+ return caCert, nil, errors.New("unable to create CA Authority")
+ }
+
+ cert := tlscert.SelfSignedFromRequest(tlscert.Request{
+ Host: "localhost",
+ Name: "client-cert",
+ Parent: caCert,
+ ParentDir: certsDir,
+ })
+ if cert == nil {
+ return caCert, cert, errors.New("unable to create Server Certificates")
+ }
+
+ return caCert, cert, nil
+}
+
func TestPostgres(t *testing.T) {
ctx := context.Background()
@@ -171,6 +209,56 @@ func TestWithConfigFile(t *testing.T) {
defer db.Close()
}
+func TestWithSSL(t *testing.T) {
+ ctx := context.Background()
+
+ caCert, serverCerts, err := createSSLCerts(t)
+ require.NoError(t, err)
+
+ ctr, err := postgres.Run(ctx,
+ "postgres:16-alpine",
+ postgres.WithConfigFile(filepath.Join("testdata", "postgres-ssl.conf")),
+ postgres.WithInitScripts(filepath.Join("testdata", "init-user-db.sh")),
+ postgres.WithDatabase(dbname),
+ postgres.WithUsername(user),
+ postgres.WithPassword(password),
+ testcontainers.WithWaitStrategy(wait.ForLog("database system is ready to accept connections").WithOccurrence(2).WithStartupTimeout(5*time.Second)),
+ postgres.WithSSLCert(caCert.CertPath, serverCerts.CertPath, serverCerts.KeyPath),
+ )
+
+ testcontainers.CleanupContainer(t, ctr)
+ require.NoError(t, err)
+
+ connStr, err := ctr.ConnectionString(ctx, "sslmode=require")
+ require.NoError(t, err)
+
+ db, err := sql.Open("postgres", connStr)
+ require.NoError(t, err)
+ assert.NotNil(t, db)
+ defer db.Close()
+
+ result, err := db.Exec("SELECT * FROM testdb;")
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+}
+
+func TestSSLValidatesKeyMaterialPath(t *testing.T) {
+ ctx := context.Background()
+
+ _, err := postgres.Run(ctx,
+ "postgres:16-alpine",
+ postgres.WithConfigFile(filepath.Join("testdata", "postgres-ssl.conf")),
+ postgres.WithInitScripts(filepath.Join("testdata", "init-user-db.sh")),
+ postgres.WithDatabase(dbname),
+ postgres.WithUsername(user),
+ postgres.WithPassword(password),
+ testcontainers.WithWaitStrategy(wait.ForLog("database system is ready to accept connections").WithOccurrence(2).WithStartupTimeout(5*time.Second)),
+ postgres.WithSSLCert("", "", ""),
+ )
+
+ require.Error(t, err, "Error should not have been nil. Container creation should have failed due to empty key material")
+}
+
func TestWithInitScript(t *testing.T) {
ctx := context.Background()
diff --git a/modules/postgres/resources/customEntrypoint.sh b/modules/postgres/resources/customEntrypoint.sh
new file mode 100644
index 0000000000..ff4ffa4291
--- /dev/null
+++ b/modules/postgres/resources/customEntrypoint.sh
@@ -0,0 +1,25 @@
+#!/usr/bin/env bash
+set -Eeo pipefail
+
+
+pUID=$(id -u postgres)
+pGID=$(id -g postgres)
+
+if [ -z "$pUID" ]
+then
+ echo "Unable to find postgres user id, required in order to chown key material"
+ exit 1
+fi
+
+if [ -z "$pGID" ]
+then
+ echo "Unable to find postgres group id, required in order to chown key material"
+ exit 1
+fi
+
+chown "$pUID":"$pGID" \
+ /tmp/testcontainers-go/postgres/ca_cert.pem \
+ /tmp/testcontainers-go/postgres/server.cert \
+ /tmp/testcontainers-go/postgres/server.key
+
+/usr/local/bin/docker-entrypoint.sh "$@"
diff --git a/modules/postgres/testdata/postgres-ssl.conf b/modules/postgres/testdata/postgres-ssl.conf
new file mode 100644
index 0000000000..5e49f16a4f
--- /dev/null
+++ b/modules/postgres/testdata/postgres-ssl.conf
@@ -0,0 +1,80 @@
+# -----------------------------
+# PostgreSQL configuration file
+# -----------------------------
+#
+# This file consists of lines of the form:
+#
+# name = value
+#
+# (The "=" is optional.) Whitespace may be used. Comments are introduced with
+# "#" anywhere on a line. The complete list of parameter names and allowed
+# values can be found in the PostgreSQL documentation.
+#
+# The commented-out settings shown in this file represent the default values.
+# Re-commenting a setting is NOT sufficient to revert it to the default value;
+# you need to reload the server.
+#
+# This file is read on server startup and when the server receives a SIGHUP
+# signal. If you edit the file on a running system, you have to SIGHUP the
+# server for the changes to take effect, run "pg_ctl reload", or execute
+# "SELECT pg_reload_conf()". Some parameters, which are marked below,
+# require a server shutdown and restart to take effect.
+#
+# Any parameter can also be given as a command-line option to the server, e.g.,
+# "postgres -c log_connections=on". Some parameters can be changed at run time
+# with the "SET" SQL command.
+#
+# Memory units: B = bytes Time units: ms = milliseconds
+# kB = kilobytes s = seconds
+# MB = megabytes min = minutes
+# GB = gigabytes h = hours
+# TB = terabytes d = days
+
+
+#------------------------------------------------------------------------------
+# FILE LOCATIONS
+#------------------------------------------------------------------------------
+
+# The default values of these variables are driven from the -D command-line
+# option or PGDATA environment variable, represented here as ConfigDir.
+
+#data_directory = 'ConfigDir' # use data in another directory
+ # (change requires restart)
+#hba_file = 'ConfigDir/pg_hba.conf' # host-based authentication file
+ # (change requires restart)
+#ident_file = 'ConfigDir/pg_ident.conf' # ident configuration file
+ # (change requires restart)
+
+# If external_pid_file is not explicitly set, no extra PID file is written.
+#external_pid_file = '' # write an extra PID file
+ # (change requires restart)
+
+
+#------------------------------------------------------------------------------
+# CONNECTIONS AND AUTHENTICATION
+#------------------------------------------------------------------------------
+
+# - Connection Settings -
+
+listen_addresses = '*'
+ # comma-separated list of addresses;
+ # defaults to 'localhost'; use '*' for all
+ # (change requires restart)
+#port = 5432 # (change requires restart)
+#max_connections = 100 # (change requires restart)
+
+# - SSL -
+
+ssl = on
+ssl_ca_file = '/tmp/testcontainers-go/postgres/ca_cert.pem'
+ssl_cert_file = '/tmp/testcontainers-go/postgres/server.cert'
+#ssl_crl_file = ''
+ssl_key_file = '/tmp/testcontainers-go/postgres/server.key'
+#ssl_ciphers = 'HIGH:MEDIUM:+3DES:!aNULL' # allowed SSL ciphers
+#ssl_prefer_server_ciphers = on
+#ssl_ecdh_curve = 'prime256v1'
+#ssl_dh_params_file = ''
+#ssl_passphrase_command = ''
+#ssl_passphrase_command_supports_reload = off
+
+
diff --git a/network/network_test.go b/network/network_test.go
index bbe5d45c7c..8b83056f43 100644
--- a/network/network_test.go
+++ b/network/network_test.go
@@ -440,3 +440,8 @@ func TestWithNewNetworkContextTimeout(t *testing.T) {
require.Empty(t, req.Networks)
require.Empty(t, req.NetworkAliases)
}
+
+func TestCleanupWithNil(t *testing.T) {
+ var network *testcontainers.DockerNetwork
+ testcontainers.CleanupNetwork(t, network)
+}
diff --git a/port_forwarding.go b/port_forwarding.go
index bb6bae2393..3411ff0c1f 100644
--- a/port_forwarding.go
+++ b/port_forwarding.go
@@ -225,10 +225,10 @@ type sshdContainer struct {
}
// Terminate stops the container and closes the SSH session
-func (sshdC *sshdContainer) Terminate(ctx context.Context) error {
+func (sshdC *sshdContainer) Terminate(ctx context.Context, opts ...TerminateOption) error {
return errors.Join(
sshdC.closePorts(),
- sshdC.Container.Terminate(ctx),
+ sshdC.Container.Terminate(ctx, opts...),
)
}
diff --git a/testing.go b/testing.go
index 35ce4f0a39..8502f018d9 100644
--- a/testing.go
+++ b/testing.go
@@ -83,7 +83,9 @@ func CleanupNetwork(tb testing.TB, network Network) {
tb.Helper()
tb.Cleanup(func() {
- noErrorOrIgnored(tb, network.Remove(context.Background()))
+ if !isNil(network) {
+ noErrorOrIgnored(tb, network.Remove(context.Background()))
+ }
})
}
diff --git a/wait/log.go b/wait/log.go
index 530077f909..41c96e3eb9 100644
--- a/wait/log.go
+++ b/wait/log.go
@@ -1,10 +1,12 @@
package wait
import (
+ "bytes"
"context"
+ "errors"
+ "fmt"
"io"
"regexp"
- "strings"
"time"
)
@@ -14,6 +16,21 @@ var (
_ StrategyTimeout = (*LogStrategy)(nil)
)
+// PermanentError is a special error that will stop the wait and return an error.
+type PermanentError struct {
+ err error
+}
+
+// Error implements the error interface.
+func (e *PermanentError) Error() string {
+ return e.err.Error()
+}
+
+// NewPermanentError creates a new PermanentError.
+func NewPermanentError(err error) *PermanentError {
+ return &PermanentError{err: err}
+}
+
// LogStrategy will wait until a given log entry shows up in the docker logs
type LogStrategy struct {
// all Strategies should have a startupTimeout to avoid waiting infinitely
@@ -24,6 +41,18 @@ type LogStrategy struct {
IsRegexp bool
Occurrence int
PollInterval time.Duration
+
+ // check is the function that will be called to check if the log entry is present.
+ check func([]byte) error
+
+ // submatchCallback is a callback that will be called with the sub matches of the regexp.
+ submatchCallback func(pattern string, matches [][][]byte) error
+
+ // re is the optional compiled regexp.
+ re *regexp.Regexp
+
+ // log byte slice version of [LogStrategy.Log] used for count checks.
+ log []byte
}
// NewLogStrategy constructs with polling interval of 100 milliseconds and startup timeout of 60 seconds by default
@@ -46,6 +75,18 @@ func (ws *LogStrategy) AsRegexp() *LogStrategy {
return ws
}
+// Submatch configures a function that will be called with the result of
+// [regexp.Regexp.FindAllSubmatch], allowing the caller to process the results.
+// If the callback returns nil, the strategy will be considered successful.
+// Returning a [PermanentError] will stop the wait and return an error, otherwise
+// it will retry until the timeout is reached.
+// [LogStrategy.Occurrence] is ignored if this option is set.
+func (ws *LogStrategy) Submatch(callback func(pattern string, matches [][][]byte) error) *LogStrategy {
+ ws.submatchCallback = callback
+
+ return ws
+}
+
// WithStartupTimeout can be used to change the default startup timeout
func (ws *LogStrategy) WithStartupTimeout(timeout time.Duration) *LogStrategy {
ws.timeout = &timeout
@@ -89,57 +130,85 @@ func (ws *LogStrategy) WaitUntilReady(ctx context.Context, target StrategyTarget
timeout = *ws.timeout
}
+ switch {
+ case ws.submatchCallback != nil:
+ ws.re = regexp.MustCompile(ws.Log)
+ ws.check = ws.checkSubmatch
+ case ws.IsRegexp:
+ ws.re = regexp.MustCompile(ws.Log)
+ ws.check = ws.checkRegexp
+ default:
+ ws.log = []byte(ws.Log)
+ ws.check = ws.checkCount
+ }
+
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
- length := 0
-
-LOOP:
+ var lastLen int
+ var lastError error
for {
select {
case <-ctx.Done():
- return ctx.Err()
+ return errors.Join(lastError, ctx.Err())
default:
checkErr := checkTarget(ctx, target)
reader, err := target.Logs(ctx)
if err != nil {
+ // TODO: fix as this will wait for timeout if the logs are not available.
time.Sleep(ws.PollInterval)
continue
}
b, err := io.ReadAll(reader)
if err != nil {
+ // TODO: fix as this will wait for timeout if the logs are not readable.
time.Sleep(ws.PollInterval)
continue
}
- logs := string(b)
-
- switch {
- case length == len(logs) && checkErr != nil:
+ if lastLen == len(b) && checkErr != nil {
+ // Log length hasn't changed so we're not making progress.
return checkErr
- case checkLogsFn(ws, b):
- break LOOP
- default:
- length = len(logs)
+ }
+
+ if err := ws.check(b); err != nil {
+ var errPermanent *PermanentError
+ if errors.As(err, &errPermanent) {
+ return err
+ }
+
+ lastError = err
+ lastLen = len(b)
time.Sleep(ws.PollInterval)
continue
}
+
+ return nil
}
}
+}
+
+// checkCount checks if the log entry is present in the logs using a string count.
+func (ws *LogStrategy) checkCount(b []byte) error {
+ if count := bytes.Count(b, ws.log); count < ws.Occurrence {
+ return fmt.Errorf("%q matched %d times, expected %d", ws.Log, count, ws.Occurrence)
+ }
return nil
}
-func checkLogsFn(ws *LogStrategy, b []byte) bool {
- if ws.IsRegexp {
- re := regexp.MustCompile(ws.Log)
- occurrences := re.FindAll(b, -1)
-
- return len(occurrences) >= ws.Occurrence
+// checkRegexp checks if the log entry is present in the logs using a regexp count.
+func (ws *LogStrategy) checkRegexp(b []byte) error {
+ if matches := ws.re.FindAll(b, -1); len(matches) < ws.Occurrence {
+ return fmt.Errorf("`%s` matched %d times, expected %d", ws.Log, len(matches), ws.Occurrence)
}
- logs := string(b)
- return strings.Count(logs, ws.Log) >= ws.Occurrence
+ return nil
+}
+
+// checkSubmatch checks if the log entry is present in the logs using a regexp sub match callback.
+func (ws *LogStrategy) checkSubmatch(b []byte) error {
+ return ws.submatchCallback(ws.Log, ws.re.FindAllSubmatch(b, -1))
}
diff --git a/wait/log_test.go b/wait/log_test.go
index 7c767c0e25..4bfbc26438 100644
--- a/wait/log_test.go
+++ b/wait/log_test.go
@@ -1,14 +1,17 @@
-package wait
+package wait_test
import (
- "bytes"
"context"
+ "fmt"
"io"
+ "strings"
"testing"
"time"
"github.com/docker/docker/api/types"
"github.com/stretchr/testify/require"
+
+ "github.com/testcontainers/testcontainers-go/wait"
)
const logTimeout = time.Second
@@ -25,107 +28,164 @@ Donec ut libero sed arcu vehicula ultricies a non tortor.
Lorem ipsum dolor sit amet, consectetur adipiscing elit.`
func TestWaitForLog(t *testing.T) {
- t.Run("no regexp", func(t *testing.T) {
- target := NopStrategyTarget{
- ReaderCloser: io.NopCloser(bytes.NewReader([]byte("docker"))),
+ t.Run("string", func(t *testing.T) {
+ target := wait.NopStrategyTarget{
+ ReaderCloser: readCloser("docker"),
}
- wg := NewLogStrategy("docker").WithStartupTimeout(100 * time.Millisecond)
+ wg := wait.NewLogStrategy("docker").WithStartupTimeout(100 * time.Millisecond)
err := wg.WaitUntilReady(context.Background(), target)
require.NoError(t, err)
})
- t.Run("no regexp", func(t *testing.T) {
- target := NopStrategyTarget{
- ReaderCloser: io.NopCloser(bytes.NewReader([]byte(loremIpsum))),
+ t.Run("regexp", func(t *testing.T) {
+ target := wait.NopStrategyTarget{
+ ReaderCloser: readCloser(loremIpsum),
}
// get all words that start with "ip", end with "m" and has a whitespace before the "ip"
- wg := NewLogStrategy(`\sip[\w]+m`).WithStartupTimeout(100 * time.Millisecond).AsRegexp()
+ wg := wait.NewLogStrategy(`\sip[\w]+m`).WithStartupTimeout(100 * time.Millisecond).AsRegexp()
+ err := wg.WaitUntilReady(context.Background(), target)
+ require.NoError(t, err)
+ })
+
+ t.Run("submatch/valid", func(t *testing.T) {
+ target := wait.NopStrategyTarget{
+ ReaderCloser: readCloser("three matches: ip1m, ip2m, ip3m"),
+ }
+
+ wg := wait.NewLogStrategy(`ip(\d)m`).WithStartupTimeout(100 * time.Millisecond).Submatch(func(pattern string, submatches [][][]byte) error {
+ if len(submatches) != 3 {
+ return wait.NewPermanentError(fmt.Errorf("%q matched %d times, expected %d", pattern, len(submatches), 3))
+ }
+ return nil
+ })
+ err := wg.WaitUntilReady(context.Background(), target)
+ require.NoError(t, err)
+ })
+
+ t.Run("submatch/permanent-error", func(t *testing.T) {
+ target := wait.NopStrategyTarget{
+ ReaderCloser: readCloser("single matches: ip1m"),
+ }
+
+ wg := wait.NewLogStrategy(`ip(\d)m`).WithStartupTimeout(100 * time.Millisecond).Submatch(func(pattern string, submatches [][][]byte) error {
+ if len(submatches) != 3 {
+ return wait.NewPermanentError(fmt.Errorf("%q matched %d times, expected %d", pattern, len(submatches), 3))
+ }
+ return nil
+ })
+ err := wg.WaitUntilReady(context.Background(), target)
+ require.Error(t, err)
+ var permanentError *wait.PermanentError
+ require.ErrorAs(t, err, &permanentError)
+ })
+
+ t.Run("submatch/temporary-error", func(t *testing.T) {
+ target := newRunningTarget()
+ expect := target.EXPECT()
+ expect.Logs(anyContext).Return(readCloser(""), nil).Once() // No matches.
+ expect.Logs(anyContext).Return(readCloser("ip1m, ip2m"), nil).Once() // Two matches.
+ expect.Logs(anyContext).Return(readCloser("ip1m, ip2m, ip3m"), nil).Once() // Three matches.
+ expect.Logs(anyContext).Return(readCloser("ip1m, ip2m, ip3m, ip4m"), nil) // Four matches.
+
+ wg := wait.NewLogStrategy(`ip(\d)m`).WithStartupTimeout(400 * time.Second).Submatch(func(pattern string, submatches [][][]byte) error {
+ switch len(submatches) {
+ case 0, 2:
+ // Too few matches.
+ return fmt.Errorf("`%s` matched %d times, expected %d (temporary)", pattern, len(submatches), 3)
+ case 3:
+ // Expected number of matches should stop the wait.
+ return nil
+ default:
+ // Should not be triggered.
+ return wait.NewPermanentError(fmt.Errorf("`%s` matched %d times, expected %d (permanent)", pattern, len(submatches), 3))
+ }
+ })
err := wg.WaitUntilReady(context.Background(), target)
require.NoError(t, err)
})
}
func TestWaitWithExactNumberOfOccurrences(t *testing.T) {
- t.Run("no regexp", func(t *testing.T) {
- target := NopStrategyTarget{
- ReaderCloser: io.NopCloser(bytes.NewReader([]byte("kubernetes\r\ndocker\n\rdocker"))),
+ t.Run("string", func(t *testing.T) {
+ target := wait.NopStrategyTarget{
+ ReaderCloser: readCloser("kubernetes\r\ndocker\n\rdocker"),
}
- wg := NewLogStrategy("docker").
+ wg := wait.NewLogStrategy("docker").
WithStartupTimeout(100 * time.Millisecond).
WithOccurrence(2)
err := wg.WaitUntilReady(context.Background(), target)
require.NoError(t, err)
})
- t.Run("as regexp", func(t *testing.T) {
- target := NopStrategyTarget{
- ReaderCloser: io.NopCloser(bytes.NewReader([]byte(loremIpsum))),
+ t.Run("regexp", func(t *testing.T) {
+ target := wait.NopStrategyTarget{
+ ReaderCloser: readCloser(loremIpsum),
}
// get texts from "ip" to the next "m".
// there are three occurrences of this pattern in the string:
// one "ipsum mauris" and two "ipsum dolor sit am"
- wg := NewLogStrategy(`ip(.*)m`).WithStartupTimeout(100 * time.Millisecond).AsRegexp().WithOccurrence(3)
+ wg := wait.NewLogStrategy(`ip(.*)m`).WithStartupTimeout(100 * time.Millisecond).AsRegexp().WithOccurrence(3)
err := wg.WaitUntilReady(context.Background(), target)
require.NoError(t, err)
})
}
func TestWaitWithExactNumberOfOccurrencesButItWillNeverHappen(t *testing.T) {
- t.Run("no regexp", func(t *testing.T) {
- target := NopStrategyTarget{
- ReaderCloser: io.NopCloser(bytes.NewReader([]byte("kubernetes\r\ndocker"))),
+ t.Run("string", func(t *testing.T) {
+ target := wait.NopStrategyTarget{
+ ReaderCloser: readCloser("kubernetes\r\ndocker"),
}
- wg := NewLogStrategy("containerd").
+ wg := wait.NewLogStrategy("containerd").
WithStartupTimeout(logTimeout).
WithOccurrence(2)
err := wg.WaitUntilReady(context.Background(), target)
require.Error(t, err)
})
- t.Run("as regexp", func(t *testing.T) {
- target := NopStrategyTarget{
- ReaderCloser: io.NopCloser(bytes.NewReader([]byte(loremIpsum))),
+ t.Run("regexp", func(t *testing.T) {
+ target := wait.NopStrategyTarget{
+ ReaderCloser: readCloser(loremIpsum),
}
// get texts from "ip" to the next "m".
// there are only three occurrences matching
- wg := NewLogStrategy(`do(.*)ck.+`).WithStartupTimeout(100 * time.Millisecond).AsRegexp().WithOccurrence(4)
+ wg := wait.NewLogStrategy(`do(.*)ck.+`).WithStartupTimeout(100 * time.Millisecond).AsRegexp().WithOccurrence(4)
err := wg.WaitUntilReady(context.Background(), target)
require.Error(t, err)
})
}
func TestWaitShouldFailWithExactNumberOfOccurrences(t *testing.T) {
- t.Run("no regexp", func(t *testing.T) {
- target := NopStrategyTarget{
- ReaderCloser: io.NopCloser(bytes.NewReader([]byte("kubernetes\r\ndocker"))),
+ t.Run("string", func(t *testing.T) {
+ target := wait.NopStrategyTarget{
+ ReaderCloser: readCloser("kubernetes\r\ndocker"),
}
- wg := NewLogStrategy("docker").
+ wg := wait.NewLogStrategy("docker").
WithStartupTimeout(logTimeout).
WithOccurrence(2)
err := wg.WaitUntilReady(context.Background(), target)
require.Error(t, err)
})
- t.Run("as regexp", func(t *testing.T) {
- target := NopStrategyTarget{
- ReaderCloser: io.NopCloser(bytes.NewReader([]byte(loremIpsum))),
+ t.Run("regexp", func(t *testing.T) {
+ target := wait.NopStrategyTarget{
+ ReaderCloser: readCloser(loremIpsum),
}
// get "Maecenas".
// there are only one occurrence matching
- wg := NewLogStrategy(`^Mae[\w]?enas\s`).WithStartupTimeout(100 * time.Millisecond).AsRegexp().WithOccurrence(2)
+ wg := wait.NewLogStrategy(`^Mae[\w]?enas\s`).WithStartupTimeout(100 * time.Millisecond).AsRegexp().WithOccurrence(2)
err := wg.WaitUntilReady(context.Background(), target)
require.Error(t, err)
})
}
func TestWaitForLogFailsDueToOOMKilledContainer(t *testing.T) {
- target := &MockStrategyTarget{
+ target := &wait.MockStrategyTarget{
LogsImpl: func(_ context.Context) (io.ReadCloser, error) {
- return io.NopCloser(bytes.NewReader([]byte(""))), nil
+ return readCloser(""), nil
},
StateImpl: func(_ context.Context) (*types.ContainerState, error) {
return &types.ContainerState{
@@ -134,16 +194,16 @@ func TestWaitForLogFailsDueToOOMKilledContainer(t *testing.T) {
},
}
- t.Run("no regexp", func(t *testing.T) {
- wg := ForLog("docker").WithStartupTimeout(logTimeout)
+ t.Run("string", func(t *testing.T) {
+ wg := wait.ForLog("docker").WithStartupTimeout(logTimeout)
err := wg.WaitUntilReady(context.Background(), target)
expected := "container crashed with out-of-memory (OOMKilled)"
require.EqualError(t, err, expected)
})
- t.Run("as regexp", func(t *testing.T) {
- wg := ForLog("docker").WithStartupTimeout(logTimeout).AsRegexp()
+ t.Run("regexp", func(t *testing.T) {
+ wg := wait.ForLog("docker").WithStartupTimeout(logTimeout).AsRegexp()
err := wg.WaitUntilReady(context.Background(), target)
expected := "container crashed with out-of-memory (OOMKilled)"
@@ -152,9 +212,9 @@ func TestWaitForLogFailsDueToOOMKilledContainer(t *testing.T) {
}
func TestWaitForLogFailsDueToExitedContainer(t *testing.T) {
- target := &MockStrategyTarget{
+ target := &wait.MockStrategyTarget{
LogsImpl: func(_ context.Context) (io.ReadCloser, error) {
- return io.NopCloser(bytes.NewReader([]byte(""))), nil
+ return readCloser(""), nil
},
StateImpl: func(_ context.Context) (*types.ContainerState, error) {
return &types.ContainerState{
@@ -164,16 +224,16 @@ func TestWaitForLogFailsDueToExitedContainer(t *testing.T) {
},
}
- t.Run("no regexp", func(t *testing.T) {
- wg := ForLog("docker").WithStartupTimeout(logTimeout)
+ t.Run("string", func(t *testing.T) {
+ wg := wait.ForLog("docker").WithStartupTimeout(logTimeout)
err := wg.WaitUntilReady(context.Background(), target)
expected := "container exited with code 1"
require.EqualError(t, err, expected)
})
- t.Run("as regexp", func(t *testing.T) {
- wg := ForLog("docker").WithStartupTimeout(logTimeout).AsRegexp()
+ t.Run("regexp", func(t *testing.T) {
+ wg := wait.ForLog("docker").WithStartupTimeout(logTimeout).AsRegexp()
err := wg.WaitUntilReady(context.Background(), target)
expected := "container exited with code 1"
@@ -182,9 +242,9 @@ func TestWaitForLogFailsDueToExitedContainer(t *testing.T) {
}
func TestWaitForLogFailsDueToUnexpectedContainerStatus(t *testing.T) {
- target := &MockStrategyTarget{
+ target := &wait.MockStrategyTarget{
LogsImpl: func(_ context.Context) (io.ReadCloser, error) {
- return io.NopCloser(bytes.NewReader([]byte(""))), nil
+ return readCloser(""), nil
},
StateImpl: func(_ context.Context) (*types.ContainerState, error) {
return &types.ContainerState{
@@ -193,19 +253,24 @@ func TestWaitForLogFailsDueToUnexpectedContainerStatus(t *testing.T) {
},
}
- t.Run("no regexp", func(t *testing.T) {
- wg := ForLog("docker").WithStartupTimeout(logTimeout)
+ t.Run("string", func(t *testing.T) {
+ wg := wait.ForLog("docker").WithStartupTimeout(logTimeout)
err := wg.WaitUntilReady(context.Background(), target)
expected := "unexpected container status \"dead\""
require.EqualError(t, err, expected)
})
- t.Run("as regexp", func(t *testing.T) {
- wg := ForLog("docker").WithStartupTimeout(logTimeout).AsRegexp()
+ t.Run("regexp", func(t *testing.T) {
+ wg := wait.ForLog("docker").WithStartupTimeout(logTimeout).AsRegexp()
err := wg.WaitUntilReady(context.Background(), target)
expected := "unexpected container status \"dead\""
require.EqualError(t, err, expected)
})
}
+
+// readCloser returns an io.ReadCloser that reads from s.
+func readCloser(s string) io.ReadCloser {
+	return io.NopCloser(strings.NewReader(s))
+}