diff --git a/.github/workflows/release.yml b/.github/workflows/release.yaml similarity index 92% rename from .github/workflows/release.yml rename to .github/workflows/release.yaml index 0b446ba..a86a9c6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yaml @@ -19,6 +19,8 @@ jobs: go-version-file: go.mod - name: Test application run: go test ./... + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 - name: Login to Docker Hub uses: docker/login-action@v3 with: diff --git a/.github/workflows/testing-pr.yaml b/.github/workflows/testing-pr.yaml new file mode 100644 index 0000000..1279c0b --- /dev/null +++ b/.github/workflows/testing-pr.yaml @@ -0,0 +1,31 @@ +name: Testing Pull Request +on: + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + test-app: + name: Test Goreleaser Application + runs-on: ubuntu-latest + steps: + - name: Clone repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + - name: Test application + run: go test ./... + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + - name: Dry-run goreleaser application + uses: goreleaser/goreleaser-action@v6 + with: + distribution: goreleaser + version: ~> v2 + args: release --snapshot --skip=publish --clean diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yaml similarity index 79% rename from .github/workflows/testing.yml rename to .github/workflows/testing.yaml index 0f89b50..e5787c1 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yaml @@ -1,7 +1,6 @@ -name: Testing +name: Testing Commit on: push: - pull_request: jobs: test-app: @@ -10,11 +9,11 @@ jobs: steps: - name: Clone repository uses: actions/checkout@v4 + with: + fetch-depth: 0 - name: Set up Go uses: actions/setup-go@v5 with: go-version-file: go.mod - name: Test application run: go test ./... - - name: Compile application - run: go build diff --git a/.gitignore b/.gitignore index 849ddff..f75cb67 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ dist/ +wait-for diff --git a/.goreleaser.yml b/.goreleaser.yml index 76c64fc..73cf929 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -31,7 +31,7 @@ archives: checksum: name_template: "checksums.txt" snapshot: - name_template: "{{ incpatch .Version }}-next" + version_template: "{{ incpatch .Version }}-next" dockers: - image_templates: - "ghcr.io/patrickdappollonio/wait-for:{{ .Tag }}-amd64" diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..e962648 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Patrick D'appollonio + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 75ab526..56fa87f 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # `wait-for` -A tiny Go application with zero dependencies. Given a number of TCP or UDP `host:port` pairs, the app will wait until either all are available or a timeout is reached. `wait-for` supports pinging TCP or UDP hosts, by prefixing the host with `tcp://` or `udp://`, respectively. If no prefix is provided, the app will default to TCP. +A Go application with zero dependencies. Given a number of hosts, the app will wait until either all are available or a timeout is reached. `wait-for` supports pinging several host types (see [supported probes](#supported-probes)), by prefixing the host with a specific protocol. If no prefix is provided, the app will default to TCP. Kudos to @vishnubob for the [original implementation in Bash](https://github.com/vishnubob/wait-for-it). @@ -19,20 +19,37 @@ This will ping both `google.com` on port `443` and `mysql.example.com` on port ` All the parameters accepted by the application are shown in the help section, as shown below. +### Command-line help + ```text -wait-for allows you to wait for a TCP resource to respond to requests. +wait-for allows you to wait for a resource to respond to requests. + +It does this by performing a connection to the specified host and port. If +there's no resource behind it and the connection cannot be established, the +request is retried until either the timeout is reached or the resource becomes +available. -It does this by performing a TCP connection to the specified host and port. If there's -no resource behind it and the connection cannot be established, the request is retried -until either the timeout is reached or the resource becomes available. +Each protocol defines its own way of checking for the resource. For example, a +TCP connection will attempt to connect to the host and port specified, while a +MySQL connection will attempt to connect to the host and port, and then ping the +database. -By default, the standard timeout is 10 seconds. +By default, the standard timeout is 10 seconds but it can be customized for all +requests. The time between each request is 1 second, but this can also be +customized. For documentation, visit: https://github.com/patrickdappollonio/wait-for. Usage: wait-for [flags] +Examples: + wait-for -s localhost:80 wait for a web server to accept connections + wait-for -s mysql.example.local:3306 wait for a MySQL database to accept connections + wait-for -s udp://localhost:53 wait for a DNS server to accept connections + wait-for --host localhost:80 --host localhost:81 wait for multiple resources to accept connections + wait-for --host mysql://localhost:3306 wait until a MySQL database is ready to accept connections and responds to pings + Flags: -e, --every duration time to wait between each request attempt against the host (default 1s) -h, --help help for wait-for @@ -42,6 +59,18 @@ Flags: --version version for wait-for ``` +### Supported probes + +The following probes are supported: + +* [TCP probe](docs/tcp-probe.md) +* [UDP probe](docs/udp-probe.md) +* [HTTP & HTTPS probe](docs/http-https-probe.md) +* [MySQL probe](docs/mysql-probe.md) *(experimental)* +* [PostgreSQL probe](docs/postgres-probe.md) *(experimental)* + +If you're interested in adding a new probe, please refer to the [Adding new probes documentation](docs/readme.md#adding-new-probes). + ### Usage with Kubernetes Simply use this tool as an `initContainer` before your application runs, and validate whether your databases or any TCP-accessible resource (such as websites, too) are up and running, or fail early with proper knowledge of the situation. @@ -71,3 +100,15 @@ spec: - name: nginx-container image: nginx ``` + +### Validating connectivity to a MySQL or Postgres database + +If you want to validate that a MySQL database is up and running, you can use the `mysql://` or `postgres://` prefix. This will attempt to connect to the host and port specified, and then ping the database as well. This is different than the default TCP probe, which only checks if the server is accepting connections on the specified port. + +For more details, check the [MySQL probe documentation](docs/mysql-probe.md) and the [PostgreSQL probe documentation](docs/postgres-probe.md). + +### Validating connectivity to an HTTP or HTTPS endpoint + +If you want to validate that an HTTP or HTTPS endpoint is up and running, you can use the `http://` or `https://` prefix. This will attempt to connect to the host and port specified, and then perform an HTTP GET request to the root path (`/`) of the server where the server must respond within 1 second. This is different than the default TCP probe, which only checks if the server is accepting connections on the specified port. + +For HTTPS requests, the certificate is also validated. For more details, check the [HTTP & HTTPS probe documentation](docs/http-https-probe.md). diff --git a/app.go b/app.go index 8764b70..b3bd0ff 100644 --- a/app.go +++ b/app.go @@ -1,52 +1,77 @@ package main import ( + "errors" "fmt" + "io/fs" "time" "github.com/patrickdappollonio/wait-for/wait" "github.com/spf13/cobra" + "github.com/spf13/viper" ) var version = "development" const ( - helpShort = "wait-for allows you to wait for a TCP resource to respond to requests." + helpShort = "wait-for allows you to wait for a resource to respond to requests." - helpLong = `wait-for allows you to wait for a TCP resource to respond to requests. + helpLong = `wait-for allows you to wait for a resource to respond to requests. -It does this by performing a TCP connection to the specified host and port. If there's -no resource behind it and the connection cannot be established, the request is retried -until either the timeout is reached or the resource becomes available. +It does this by performing a connection to the specified host and port. If there's no resource behind it and the connection cannot be established, the request is retried until either the timeout is reached or the resource becomes available. -By default, the standard timeout is 10 seconds. +Each protocol defines its own way of checking for the resource. For example, a TCP connection will attempt to connect to the host and port specified, while a MySQL connection will attempt to connect to the host and port, and then ping the database. + +By default, the standard timeout is 10 seconds but it can be customized for all requests. The time between each request is 1 second, but this can also be customized. For documentation, visit: https://github.com/patrickdappollonio/wait-for.` ) func root() *cobra.Command { - var ( - hosts []string - timeout time.Duration - step time.Duration - verbose bool - ) + var cfgFile string + var hosts []string rootCommand := &cobra.Command{ - Use: "wait-for", - Short: helpShort, - Long: helpLong, - Version: version, + Use: "wait-for", + Short: helpShort, + Long: wrap(helpLong, 80), + Version: version, + Example: exampleCommands("wait-for", []example{ + {command: "-s localhost:80", helper: "wait for a web server to accept connections"}, + {command: "-s mysql.example.local:3306", helper: "wait for a MySQL database to accept connections"}, + {command: "-s udp://localhost:53", helper: "wait for a DNS server to accept connections"}, + {command: "--host localhost:80 --host localhost:81", helper: "wait for multiple resources to accept connections"}, + {command: "--host mysql://localhost:3306", helper: "wait until a MySQL database is ready to accept connections and responds to pings"}, + {command: "--host postgres://localhost:5432", helper: "wait until a PostgreSQL database is ready to accept connections and responds to pings"}, + {command: "--host http://localhost:8080", helper: "wait until an HTTP server is ready to accept connections and responds to requests with a 200-299 status code"}, + {command: "--host https://localhost:443", helper: "wait until an HTTPS server is ready to accept connections and responds to requests with a 200-299 status code and a valid certificate"}, + {command: "--config targets.yaml", helper: "load hosts and settings from a YAML file"}, + }), SilenceUsage: true, SilenceErrors: true, RunE: func(_ *cobra.Command, args []string) error { - w, err := wait.New(hosts, step, timeout, verbose) - if err != nil { - return err + // Read config file if available + viper.SetConfigFile(cfgFile) + if err := viper.ReadInConfig(); err != nil { + // If config not found, it's not fatal unless we rely on it + if errors.Is(err, &fs.PathError{}) { + return fmt.Errorf("error reading config file: %w", err) + } + } + + // Merge hosts flag with viper flags + hosts = append(hosts, viper.GetStringSlice("host")...) + + // Retrieve final values from viper after merging CLI flags and config + app := &wait.App{ + Hosts: hosts, + Timeout: viper.GetDuration("timeout"), + Every: viper.GetDuration("every"), + Verbose: viper.GetBool("verbose"), } - fmt.Println(w.String()) - if err := w.PingAll(); err != nil { + // Run the application + if err := app.Run(); err != nil { return err } @@ -55,10 +80,17 @@ func root() *cobra.Command { }, } - rootCommand.Flags().StringSliceVarP(&hosts, "host", "s", []string{}, "hosts to connect to in the format \"host:port\" with optional protocol prefix (tcp:// or udp://)") - rootCommand.Flags().DurationVarP(&timeout, "timeout", "t", time.Second*10, "maximum time to wait for the endpoints to respond before giving up") - rootCommand.Flags().DurationVarP(&step, "every", "e", time.Second*1, "time to wait between each request attempt against the host") - rootCommand.Flags().BoolVarP(&verbose, "verbose", "v", false, "enable verbose output -- will print every time a request is made") + // Flags for the program + rootCommand.Flags().StringSliceVarP(&hosts, "host", "s", []string{}, `hosts to connect to in the format "host:port" or with protocol prefix for one of the supported protocols (e.g. "udp://host:port")`) + rootCommand.Flags().DurationP("timeout", "t", 10*time.Second, "maximum time to wait for the endpoints to respond before giving up") + rootCommand.Flags().DurationP("every", "e", 1*time.Second, "time to wait between each request attempt against the host") + rootCommand.Flags().BoolP("verbose", "v", false, "enable verbose output -- will print every time a request is made") + rootCommand.Flags().StringVar(&cfgFile, "config", "targets.yaml", "config file to load hosts and settings from") + + // Bind flags to viper except hosts and config file + viper.BindPFlag("timeout", rootCommand.Flags().Lookup("timeout")) + viper.BindPFlag("every", rootCommand.Flags().Lookup("every")) + viper.BindPFlag("verbose", rootCommand.Flags().Lookup("verbose")) return rootCommand } diff --git a/docs/configuration-file.md b/docs/configuration-file.md new file mode 100644 index 0000000..0baa85b --- /dev/null +++ b/docs/configuration-file.md @@ -0,0 +1,55 @@ +# Configuring `wait-for` with a configuration file + +`wait-for` supports reading configuration from a file from a file called `targets.yaml` (configurable with `--config`). This is useful when you have a lot of hosts to check and you don't want to pass them all as command-line arguments. + +The following is an example YAML configuration file: + +```yaml +# file: targets.yaml +hosts: + - "tcp://localhost:8080" + - "udp://localhost:53" +timeout: 30s +every: 2s +verbose: true +``` + +This is equal to calling the CLI with the following arguments: + +```bash +wait-for \ + --host "tcp://localhost:8080" \ + --host "udp://localhost:53" \ + --timeout 30s \ + --every 2s \ + --verbose +``` + +You can mix-and-match hosts: any hosts provided via the configuration file will be merged with the hosts provided via the command line argument `--host` or `-s`, for example, the following config file and the following command will ping all endpoints (both from the config file and the command line): + +```bash +$ cat targets.yaml +# file: targets.yaml +hosts: + - "tcp://localhost:8080" + - "udp://localhost:53" +timeout: 30s +every: 2s +verbose: true + +$ wait-for \ + --host "localhost:80" \ + --host "localhost:81" \ + --timeout 10s +``` + +The above command will ping the following hosts by merging the two sources (configuration file and command-line flags): + +```text +tcp://localhost:8080 +udp://localhost:53 +tcp://localhost:80 +tcp://localhost:81 +``` + +A host present both in the command-line arguments and in the configuration file will be pinged twice. diff --git a/docs/http-https-probe.md b/docs/http-https-probe.md new file mode 100644 index 0000000..aa409c6 --- /dev/null +++ b/docs/http-https-probe.md @@ -0,0 +1,23 @@ +# HTTP & HTTPS + +The HTTP and HTTPS probes are used to send an HTTP or HTTPS `GET` request to a server and check the response. A request is successful not only if the HTTP server was able to provide a connection but also if the response status code is within the range of 200 to 299. If the request responds within this range, the probe will exit successfully. + +If the connection cannot be established or the response status code is outside the range of 200 to 299, the probe will retry until either the timeout is reached or the resource becomes available. + +An example request to `http://localhost:80` would look like this: + +```bash +wait-for --host "http://localhost:80" +``` + +An example request to `https://localhost:443` would look like this: + +```bash +wait-for --host "https://localhost:443" +``` + +## Certificate Validation + +The HTTPS probe (that is, where a target host is configured to use `https://` protocol) will attempt to validate the certificate chain and the hostname. If the certificate chain is invalid or the hostname doesn't match, the probe will exit with an error and the resource will be considered unavailable. + +A valid HTTPS request on resources with custom certificates will require you to provide the CA certificate to the probe. By default, any certs stored in `/etc/ssl/certs/ca-certificates.crt` will be used to validate the connection. diff --git a/docs/mysql-probe.md b/docs/mysql-probe.md new file mode 100644 index 0000000..6d79b23 --- /dev/null +++ b/docs/mysql-probe.md @@ -0,0 +1,25 @@ +# MySQL + +The MySQL probe will attempt to connect to the host and port specified. Once connected, it will attempt to perform a "ping" with a 1 second timeout from establishing the connection. If the connection can be established successfully and the database responds to the ping, the probe will exit successfully. + +If the connection cannot be established or the ping fails, the probe will retry until either the timeout is reached or the resource becomes available. + +The probe makes no guarantees about the existence of a table or the validity of the data in the database. It merely checks if the server is accepting connections on the specified port and if the database responds to the ping. + +Internally, [the probe uses the `github.com/go-sql-driver/mysql` package](https://github.com/go-sql-driver/mysql) to connect to the database and perform the ping. This means that the connection string must be in the format `mysql://user:password@host:port/dbname`. + +## Security + +Since the credentials have to be provided plain text in the command line, it's recommended to use this probe in a secure environment or dynamically create the configuration file for such purpose (we recommend [using something like `tgen` to generate the configuration file](https://github.com/patrickdappollonio/tgen) on-the-fly). + +It is not possible to provide *just the password* via environment variables, since the application supports specifying multiple hosts and ports, and each one can have a different password. However, nothing prevents you from using environment variables as part of the connection string when providing command-line flags, like: + +```bash +wait-for --host "mysql://$MYSQL_USER:$MYSQL_PASSWORD@localhost:3306" +``` + +## TLS Support + +To perform TLS connections, the container or host running the probe must have the necessary certificates to establish the connection. The probe will not attempt to validate the certificate chain or the hostname, so it's recommended to use this probe in a secure environment or to validate the certificates in another way. + +By default, any certs stored in `/etc/ssl/certs/ca-certificates.crt` will be used to validate the connection. diff --git a/docs/postgres-probe.md b/docs/postgres-probe.md new file mode 100644 index 0000000..2d5335e --- /dev/null +++ b/docs/postgres-probe.md @@ -0,0 +1,25 @@ +# PostgreSQL + +The PostgreSQL probe will attempt to connect to the host and port specified. Once connected, it will attempt to perform a "ping" with a 1 second timeout from establishing the connection. If the connection can be established successfully and the database responds to the ping, the probe will exit successfully. + +If the connection cannot be established or the ping fails, the probe will retry until either the timeout is reached or the resource becomes available. + +The probe makes no guarantees about the existence of a table or the validity of the data in the database. It merely checks if the server is accepting connections on the specified port and if the database responds to the ping. + +Internally, [the probe uses the `github.com/jackc/pgx/v5` package](https://github.com/jackc/pgx/v5) to connect to the database and perform the ping. This means that the connection string must be in the format `postgres://username:password@localhost:5432/database_name`. + +## Security + +Since the credentials have to be provided plain text in the command line, it's recommended to use this probe in a secure environment or dynamically create the configuration file for such purpose (we recommend [using something like `tgen` to generate the configuration file](https://github.com/patrickdappollonio/tgen) on-the-fly). + +It is not possible to provide *just the password* via environment variables, since the application supports specifying multiple hosts and ports, and each one can have a different password. However, nothing prevents you from using environment variables as part of the connection string when providing command-line flags, like: + +```bash +wait-for --host "postgres://$PG_USER:$PG_PASSWORD@localhost:5432" +``` + +## TLS Support + +To perform TLS connections, the container or host running the probe must have the necessary certificates to establish the connection. The probe will not attempt to validate the certificate chain or the hostname, so it's recommended to use this probe in a secure environment or to validate the certificates in another way. + +By default, any certs stored in `/etc/ssl/certs/ca-certificates.crt` will be used to validate the connection. diff --git a/docs/readme.md b/docs/readme.md new file mode 100644 index 0000000..8025b2a --- /dev/null +++ b/docs/readme.md @@ -0,0 +1,59 @@ +# `wait-for` documentation + +`wait-for` allows you to wait for a resource to respond to requests. It does this by performing a connection to the specified host provided either by a configuration file or by command line arguments. If there's no resource behind it and the connection cannot be established, the request is retried until either the timeout is reached or the resource becomes available. + +## Configuration + +The application can be configured either by command line arguments or by a configuration file. The configuration file is a YAML file that can be passed to the application using the `--config` flag. + +Host flags (those with `--host` or `-s`) can be specified both via the command line and the configuration file. If a host is specified in both, it will be pinged twice. + +For more information on how to use the configuration file, please refer to the [configuration file documentation](configuration-file.md). + +## Supported probes + +"Probes" are the way `wait-for` checks for the availability of a resource. Each probe maps to a specific protocol and checks for the availability of a resource in a specific way. + +The following probes are supported: + +* [TCP probe](tcp-probe.md) +* [UDP probe](udp-probe.md) +* [HTTP & HTTPS probe](http-https-probe.md) +* [MySQL probe](mysql-probe.md) *(experimental)* +* [PostgreSQL probe](postgres-probe.md) *(experimental)* + +## Adding new probes + +If you want to add a new probe, you can do so by implementing the `Probe` interface. The interface is defined as follows: + +```go +// Pinger defines the interface for a pinger. +type Pinger interface { + Bootstrap(host string) error + Ping(ctx context.Context) error +} +``` + +Then, the probe has to be matched to a protocol stored in the `pingerRegistry` variable in the `wait` package. This is done by adding a new entry to the map, where the key is the protocol and the value is the probe implementation. Currently, the following protocols are supported: + +```go +// pingerRegistry holds the mapping from protocol to pinger handler. +// Add your own pinger here. +var pingerRegistry = map[string]func() Pinger{ + "tcp": func() Pinger { return &probes.TCPPinger{} }, + "udp": func() Pinger { return &probes.UDPPinger{} }, + "mysql": func() Pinger { return &probes.MySQLPinger{} }, + "postgres": func() Pinger { return &probes.PostgresPinger{} }, + "http": func() Pinger { return &probes.HTTPPinger{} }, + "https": func() Pinger { return &probes.HTTPSPinger{} }, +} +``` + +When creating your own probes, the following rules apply: + +* A protocol must be unique. +* The protocol must be lowercase. +* A probe `struct` should accept no parameters. +* A probe `Bootstrap` method should accept a `host` parameter which should validate the host and set up the probe, or return an error if the host is invalid. +* A probe `Ping` method should accept a `context.Context` parameter and return an error if the ping fails or the context is canceled. +* I reserve the discretion to accept or reject any pull request that adds a new probe. diff --git a/docs/tcp-probe.md b/docs/tcp-probe.md new file mode 100644 index 0000000..15fcd1c --- /dev/null +++ b/docs/tcp-probe.md @@ -0,0 +1,14 @@ +# TCP + +The TCP probe will attempt to connect to the host and port specified. If the connection can be established successfully, the probe will exit successfully. + +If the connection cannot be established, the probe will retry until either the timeout is reached or the resource becomes available. + +By default, if the protocol isn't specified when providing a `--host` flag, TCP is assumed. This means that the following two commands are equivalent: + +```bash +wait-for --host "localhost:80" +wait-for --host "tcp://localhost:80" +``` + +TCP probes make no guarantees the response received from the server has any sort of validity. They merely check if the server is accepting connections on the specified port. diff --git a/docs/udp-probe.md b/docs/udp-probe.md new file mode 100644 index 0000000..0305f3e --- /dev/null +++ b/docs/udp-probe.md @@ -0,0 +1,13 @@ +# UDP + +The UDP probe will attempt to connect to the host and port specified. If the connection can be established successfully and at least one zero-length packet can be sent, the probe will exit successfully. + +If the connection cannot be established, the probe will retry until either the timeout is reached or the resource becomes available. + +By default, if the protocol isn't specified when providing a `--host` flag, TCP is assumed. To use UDP, you must prefix the host with `udp://`: + +```bash +wait-for --host "udp://localhost:53" +``` + +UDP probes make no guarantees the response received from the server has any sort of validity. They merely check if the server is accepting connections on the specified port and if a zero-length packet can be sent. diff --git a/go.mod b/go.mod index 8079017..3a46872 100644 --- a/go.mod +++ b/go.mod @@ -2,9 +2,37 @@ module github.com/patrickdappollonio/wait-for go 1.22 -require github.com/spf13/cobra v1.8.1 +require ( + github.com/go-sql-driver/mysql v1.8.1 + github.com/jackc/pgx/v5 v5.7.1 + github.com/spf13/cobra v1.8.1 + github.com/spf13/viper v1.19.0 + golang.org/x/sync v0.10.0 +) require ( + filippo.io/edwards25519 v1.1.0 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/sagikazarmark/locafero v0.4.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.11.0 // indirect + github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/multierr v1.9.0 // indirect + golang.org/x/crypto v0.27.0 // indirect + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect + golang.org/x/sys v0.25.0 // indirect + golang.org/x/text v0.18.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 912390a..72828b7 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,94 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs= +github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= +github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= +github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= +github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= +github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI= +github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= +go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= +golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= +golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..f457753 --- /dev/null +++ b/utils.go @@ -0,0 +1,106 @@ +package main + +import ( + "fmt" + "strings" +) + +// wrap wraps the input text to the specified width without splitting words. +// It preserves existing line breaks. +// If width is less than or equal to 0, it defaults to 80. +func wrap(text string, width int) string { + if width <= 0 { + width = 80 + } + + var builder strings.Builder + + // Split the text into lines based on newline characters. + // Handles both Unix (\n) and Windows (\r\n) line endings. + lines := splitIntoLines(text) + + for i, line := range lines { + // For lines after the first, prepend a newline to preserve line breaks. + if i > 0 { + builder.WriteByte('\n') + } + + // If the line is empty, preserve the empty line. + if strings.TrimSpace(line) == "" { + continue + } + + // Wrap the individual line and write to the builder. + wrapped := wrapSingleLine(line, width) + builder.WriteString(wrapped) + } + + return builder.String() +} + +// splitIntoLines splits the input text into lines, handling both \n and \r\n. +func splitIntoLines(text string) []string { + // Normalize line endings to \n + text = strings.ReplaceAll(text, "\r\n", "\n") + // Split by \n + return strings.Split(text, "\n") +} + +// wrapSingleLine wraps a single line of text to the specified width without splitting words. +func wrapSingleLine(line string, width int) string { + words := strings.Fields(line) + if len(words) == 0 { + return "" + } + + var wrapped strings.Builder + currentLine := words[0] + + for _, word := range words[1:] { + // Check if adding the next word exceeds the width + if len(currentLine)+1+len(word) > width { + // Write the current line to the builder + wrapped.WriteString(currentLine) + wrapped.WriteByte('\n') + // Start a new line with the current word + currentLine = word + } else { + // Add the word to the current line + currentLine += " " + word + } + } + + // Append the last line + wrapped.WriteString(currentLine) + + return wrapped.String() +} + +type example struct { + command string + helper string +} + +func exampleCommands(cmdname string, examples []example) string { + padding := 0 + for _, v := range examples { + if len(v.command) > padding { + padding = len(v.command) + } + } + + var sb strings.Builder + + for i, v := range examples { + if i > 0 { + sb.WriteString("\n") + } + fmt.Fprintf( + &sb, + " %s %s %s %s", + cmdname, v.command, strings.Repeat(" ", padding-len(v.command)+3), v.helper, + ) + } + + return sb.String() +} diff --git a/wait/ping.go b/wait/ping.go deleted file mode 100644 index c481ded..0000000 --- a/wait/ping.go +++ /dev/null @@ -1,53 +0,0 @@ -package wait - -import ( - "fmt" - "net" - "sync" - "time" -) - -func (w *Wait) PingAll() error { - var wg sync.WaitGroup - - startTime := time.Now() - finished := make(chan struct{}, 1) - - go func() { - for _, host := range w.hosts { - wg.Add(1) - go w.ping(startTime, host, &wg) - } - - wg.Wait() - finished <- struct{}{} - }() - - select { - case <-finished: - return nil - case <-time.After(w.timeout): - return fmt.Errorf("%s timeout reached before all hosts were up", w.timeout) - } -} - -func (w *Wait) ping(startTime time.Time, host host, wg *sync.WaitGroup) { - defer wg.Done() - - for { - conn, err := net.Dial(host.GetProtocol(), host.String()) - if err == nil { - conn.Close() - w.log.Printf("> up: %s (after %s)", w.pad(host.String()), time.Since(startTime)) - return - } - - w.log.Printf("> down: %s", w.pad(host.String())) - time.Sleep(w.step) - } -} - -func (w *Wait) pad(str string) string { - format := fmt.Sprintf("%%-%ds", w.padding) - return fmt.Sprintf(format, str) -} diff --git a/wait/pinger.go b/wait/pinger.go new file mode 100644 index 0000000..bff73dd --- /dev/null +++ b/wait/pinger.go @@ -0,0 +1,229 @@ +package wait + +import ( + "context" + "fmt" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/patrickdappollonio/wait-for/wait/probes" + "golang.org/x/sync/errgroup" +) + +// App represents the application configuration. +type App struct { + Hosts []string + Timeout time.Duration + Every time.Duration + Verbose bool + + padding int +} + +// Pinger defines the interface for a pinger. +type Pinger interface { + Bootstrap(host string) error + Ping(ctx context.Context) error +} + +// pingerRegistry holds the mapping from protocol to pinger handler. +// Add your own pinger here. +var pingerRegistry = map[string]func() Pinger{ + "tcp": func() Pinger { return &probes.TCPPinger{} }, + "udp": func() Pinger { return &probes.UDPPinger{} }, + "mysql": func() Pinger { return &probes.MySQLPinger{} }, + "postgres": func() Pinger { return &probes.PostgresPinger{} }, + "http": func() Pinger { return &probes.HTTPPinger{} }, + "https": func() Pinger { return &probes.HTTPSPinger{} }, +} + +// matchedURLItem is a helper struct to hold the URL and the raw string. +type matchedURLItem struct { + Raw string + Pinger Pinger +} + +// String returns the string representation of the URL. +func (u *matchedURLItem) String() string { + return u.Raw +} + +// stringifyHosts returns a string representation of the hosts, with all +// the URLs quoted and separated by commas. +func stringifyHosts(urls []matchedURLItem) string { + var sb strings.Builder + + for i, v := range urls { + if i > 0 { + sb.WriteString(", ") + } + + sb.WriteString(`"` + v.String() + `"`) + } + + return sb.String() +} + +// Run executes the application. +func (app *App) Run() error { + if len(app.Hosts) == 0 { + return fmt.Errorf("no hosts specified") + } + + hostItems := make([]matchedURLItem, 0, len(app.Hosts)) + for _, rawURL := range app.Hosts { + // Parse the host URL + matched, err := parseHost(rawURL) + if err != nil { + return fmt.Errorf("failed to parse host %q: %v", rawURL, err) + } + + // Calculate the padding for the output + if len(rawURL) > app.padding { + app.padding = len(matched.Raw) + } + + // Bootstrap the pinger and validate its URL + if err := matched.Pinger.Bootstrap(rawURL); err != nil { + return fmt.Errorf("failed to bootstrap host %q: %v", rawURL, err) + } + + // Append the host to the list + hostItems = append(hostItems, *matched) + } + + // Register signal handlers for early termination. + sigterm, done := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer done() + + // Create a context with a timeout for the maximum time allowed. + ctx, cancel := context.WithTimeout(context.Background(), app.Timeout) + defer cancel() + + // Document what are we doing to the user. + fmt.Fprintln( + os.Stdout, + "Waiting for hosts:", stringifyHosts(hostItems), + fmt.Sprintf("(timeout: %s, attempting every %s)", app.Timeout, app.Every), + ) + + // Create an error group. + var eg errgroup.Group + + // Iterate over all hosts and ping them. + for _, host := range hostItems { + eg.Go(app.handlePing(ctx, sigterm, host)) + } + + // Create a channel to signal when all goroutines are done. + doneChan := make(chan error) + + go func() { + // Wait for all goroutines or for the first error. + if err := eg.Wait(); err != nil { + doneChan <- err + } + + // All goroutines finished successfully. + close(doneChan) + }() + + select { + case err := <-doneChan: + // Immediately return the first error encountered. + return err + case <-ctx.Done(): + // Global timeout triggered. + return fmt.Errorf("%s timeout reached before all hosts were up", app.Timeout) + } +} + +// handlePing pings the host asynchronously and returns an error if the host +// is not reachable. +func (app *App) handlePing(ctx, sigterm context.Context, h matchedURLItem) func() error { + return func() error { + startTime := time.Now() + + // Ping right away the first time + if err := h.Pinger.Ping(ctx); err == nil { + app.printOnVerbose("> up: %s (after %s)", app.pad(h.Raw), time.Since(startTime)) + return nil // Host is reachable, break the loop. + } else { + app.printOnVerbose("> down: %s -- %s", app.pad(h.Raw), err.Error()) + } + + // Create a ticker to ping the host every `app.Every` duration. + ticker := time.NewTicker(app.Every) + defer ticker.Stop() + + for { + select { + case <-sigterm.Done(): + // User requested early termination. + return fmt.Errorf("user requested early termination") + case <-ctx.Done(): + // Timeout reached. + return fmt.Errorf("timeout reached while waiting for %q", h) + case <-ticker.C: + // Ping the host and check if it's reachable. + if err := h.Pinger.Ping(ctx); err == nil { + app.printOnVerbose("> up: %s (after %s)", app.pad(h.Raw), time.Since(startTime)) + return nil // Host is reachable, break the loop. + } else { + app.printOnVerbose("> down: %s -- %s", app.pad(h.Raw), err.Error()) + } + } + } + } +} + +// printOnVerbose prints the message to the standard output if the verbose +// flag is enabled. +func (app *App) printOnVerbose(format string, args ...interface{}) { + if app.Verbose { + fmt.Fprintf(os.Stdout, format+"\n", args...) + } +} + +// parseHost parses the host string and returns a URL. +func parseHost(hostStr string) (*matchedURLItem, error) { + // If no scheme, assume tcp + if !strings.Contains(hostStr, "://") { + hostStr = "tcp://" + hostStr + } + + // Parse the URL without url.Parse + if !strings.Contains(hostStr, "://") { + return nil, fmt.Errorf("invalid URL: %s", hostStr) + } + + // Find the scheme + schemeEnd := strings.Index(hostStr, "://") + if schemeEnd == -1 { + return nil, fmt.Errorf("invalid URL: %s", hostStr) + } + + scheme := hostStr[:schemeEnd] + + // Find the Pinger + pingerCtor, ok := pingerRegistry[scheme] + if !ok { + return nil, fmt.Errorf("no handler registered for scheme %q (host: %q)", scheme, hostStr) + } + + // Return the URL and the Pinger + return &matchedURLItem{ + Raw: hostStr, + Pinger: pingerCtor(), + }, nil +} + +// pad pads the string to the configured padding based on the longest host +// full string URL representation (including protocol). +func (app *App) pad(str string) string { + format := fmt.Sprintf("%%-%ds", app.padding) + return fmt.Sprintf(format, str) +} diff --git a/wait/pinger_test.go b/wait/pinger_test.go new file mode 100644 index 0000000..8478842 --- /dev/null +++ b/wait/pinger_test.go @@ -0,0 +1,109 @@ +package wait + +import ( + "testing" +) + +func TestParseHost(t *testing.T) { + tests := []struct { + name string + hostStr string + want *matchedURLItem + wantErr bool + }{ + { + name: "Valid TCP URL", + hostStr: "example.com", + want: &matchedURLItem{ + Raw: "tcp://example.com", + Pinger: pingerRegistry["tcp"](), + }, + wantErr: false, + }, + { + name: "Valid HTTP URL", + hostStr: "http://example.com", + want: &matchedURLItem{ + Raw: "http://example.com", + Pinger: pingerRegistry["http"](), + }, + wantErr: false, + }, + { + name: "Invalid URL without scheme", + hostStr: "://example.com", + wantErr: true, + }, + { + name: "Invalid URL with unknown scheme", + hostStr: "unknown://example.com", + wantErr: true, + }, + { + name: "Invalid URL with no scheme and no host", + hostStr: "://", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseHost(tt.hostStr) + if (err != nil) != tt.wantErr { + t.Errorf("parseHost() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && got.Raw != tt.want.Raw { + t.Errorf("parseHost() got = %v, want %v", got.Raw, tt.want.Raw) + } + + if !tt.wantErr && got.Pinger == nil { + t.Errorf("parseHost() got Pinger = nil, want non-nil") + } + }) + } +} + +func TestStringifyHosts(t *testing.T) { + tests := []struct { + name string + urls []matchedURLItem + want string + }{ + { + name: "Single URL", + urls: []matchedURLItem{ + { + Raw: "http://example.com", + }, + }, + want: `"http://example.com"`, + }, + { + name: "Multiple URLs", + urls: []matchedURLItem{ + { + Raw: "http://example.com", + }, + { + Raw: "https://example.org", + }, + }, + want: `"http://example.com", "https://example.org"`, + }, + { + name: "No URLs", + urls: []matchedURLItem{}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := stringifyHosts(tt.urls); got != tt.want { + t.Errorf("stringifyHosts() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/wait/probes/helper.go b/wait/probes/helper.go new file mode 100644 index 0000000..017163d --- /dev/null +++ b/wait/probes/helper.go @@ -0,0 +1,64 @@ +package probes + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" +) + +// oneOf returns true if the first argument is equal to any of the +// following arguments. +func oneOf[T comparable](s T, values ...T) bool { + for _, v := range values { + if s == v { + return true + } + } + + return false +} + +// unwrapError recursively unwraps the error to get the root cause. +func unwrapError(err error) error { + if unwrapped := errors.Unwrap(err); unwrapped != nil { + return unwrapError(unwrapped) + } + + return err +} + +// doGet performs a GET request to the given URL with the provided client +// and context, then checks the status code to ensure it is in the 2xx range. +func doGet(ctx context.Context, client *http.Client, url string) error { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return fmt.Errorf("error creating request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + return unwrapError(err) + } + resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode > 299 { + return fmt.Errorf("received non-2xx status code: %d %s", resp.StatusCode, http.StatusText(resp.StatusCode)) + } + + return nil +} + +// extractProtocol extracts the protocol from the host string. +// If no protocol is found, an empty string is returned. +func extractProtocol(host string) string { + // Find if there's a "://" in the host string. + // If there is, extract the protocol. + // If there isn't, assume it's a hostname and return an empty string. + if i := strings.Index(host, "://"); i >= 0 { + return host[:i] + } + + return "" +} diff --git a/wait/probes/helper_test.go b/wait/probes/helper_test.go new file mode 100644 index 0000000..7198e57 --- /dev/null +++ b/wait/probes/helper_test.go @@ -0,0 +1,215 @@ +package probes + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func TestOneOf(t *testing.T) { + tests := []struct { + name string + s interface{} + values []interface{} + want bool + }{ + { + name: "String present in list", + s: "apple", + values: []interface{}{"banana", "apple", "cherry"}, + want: true, + }, + { + name: "String not present in list", + s: "grape", + values: []interface{}{"banana", "apple", "cherry"}, + want: false, + }, + { + name: "Integer present in list", + s: 42, + values: []interface{}{1, 2, 42, 100}, + want: true, + }, + { + name: "Integer not present in list", + s: 99, + values: []interface{}{1, 2, 42, 100}, + want: false, + }, + { + name: "Empty list", + s: "test", + values: []interface{}{}, + want: false, + }, + { + name: "Single element list, present", + s: "single", + values: []interface{}{"single"}, + want: true, + }, + { + name: "Single element list, not present", + s: "single", + values: []interface{}{"not_single"}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := oneOf(tt.s, tt.values...) + if got != tt.want { + t.Errorf("oneOf() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnwrapError(t *testing.T) { + tests := []struct { + name string + err error + want error + }{ + { + name: "No wrapping", + err: errors.New("root error"), + want: errors.New("root error"), + }, + { + name: "Single wrapping", + err: fmt.Errorf("wrapped: %w", errors.New("root error")), + want: errors.New("root error"), + }, + { + name: "Double wrapping", + err: fmt.Errorf("wrapped again: %w", fmt.Errorf("wrapped: %w", errors.New("root error"))), + want: errors.New("root error"), + }, + { + name: "Nil error", + err: nil, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := unwrapError(tt.err) + if got == nil && tt.want != nil { + t.Errorf("unwrapError() = nil, want %v", tt.want) + } else if got != nil && tt.want == nil { + t.Errorf("unwrapError() = %v, want nil", got) + } else if got != nil && got.Error() != tt.want.Error() { + t.Errorf("unwrapError() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDoGet(t *testing.T) { + tests := []struct { + name string + url string + statusCode int + wantErr bool + errMsg string + }{ + { + name: "Status OK", + url: "/", + statusCode: http.StatusOK, + wantErr: false, + }, + { + name: "Status Created", + url: "/", + statusCode: http.StatusCreated, + wantErr: false, + }, + { + name: "Status Bad Request", + url: "/", + statusCode: http.StatusBadRequest, + wantErr: true, + errMsg: "received non-2xx status code: 400 Bad Request", + }, + { + name: "Status Internal Server Error", + url: "/", + statusCode: http.StatusInternalServerError, + wantErr: true, + errMsg: "received non-2xx status code: 500 Internal Server Error", + }, + { + name: "Invalid URL", + url: "http://[::1]:namedport", + wantErr: true, + errMsg: "error creating request: parse \"http://[::1]:namedport\": invalid port \":namedport\" after host", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a test server that responds with the specified status code + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.statusCode) + })) + defer server.Close() + + client := server.Client() + ctx := context.Background() + + url := server.URL + tt.url + if tt.name == "Invalid URL" { + url = tt.url // Use the invalid URL directly + } + + err := doGet(ctx, client, url) + if (err != nil) != tt.wantErr { + t.Errorf("doGet() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err != nil && err.Error() != tt.errMsg { + t.Errorf("doGet() error = %v, errMsg %v", err, tt.errMsg) + } + }) + } +} + +func Test_extractProtocol(t *testing.T) { + tests := []struct { + name string + host string + def string + want string + }{ + { + name: "No protocol", + host: "example.com", + want: "", + }, + { + name: "TCP protocol", + host: "tcp://example.com", + want: "tcp", + }, + { + name: "HTTP protocol", + host: "http://example.com", + want: "http", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := extractProtocol(tt.host); got != tt.want { + t.Errorf("extractProtocol() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/wait/probes/https.go b/wait/probes/https.go new file mode 100644 index 0000000..eebd0ee --- /dev/null +++ b/wait/probes/https.go @@ -0,0 +1,101 @@ +// File: probes/https_pinger.go +package probes + +import ( + "context" + "crypto/tls" + "fmt" + "net/http" + "net/url" + "time" +) + +// validateURL checks if the URL is valid. There are no checks +// to validate if the URL is nil. +func validateURL(scheme string, u *url.URL) error { + // Validate the URL scheme + if u.Scheme != scheme { + return fmt.Errorf("invalid scheme: %s", u.Scheme) + } + + if !u.IsAbs() { + return fmt.Errorf("invalid URL: %s", u.String()) + } + + if u.Hostname() == "" { + return fmt.Errorf("no host specified: %s", u.String()) + } + + return nil +} + +// HTTPSPinger is a pinger for HTTPS connections. +type HTTPSPinger struct { + url *url.URL + httpClient *http.Client +} + +// Bootstrap sets up the pinger with the HTTPS URL. +func (h *HTTPSPinger) Bootstrap(host string) error { + u, err := url.Parse(host) + if err != nil { + return fmt.Errorf("failed to parse host %q: %v", host, err) + } + + if err := validateURL("https", u); err != nil { + return err + } + + h.url = u + + // Initialize HTTPS client with timeout and TLS configuration + h.httpClient = &http.Client{ + Timeout: 1 * time.Second, // 1 second timeout per request + + // Ensure TLS certificate verification is enabled + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: false, // Do not skip TLS verification + }, + }, + } + + return nil +} + +// Ping performs an HTTPS GET request and checks the status code. +func (h *HTTPSPinger) Ping(ctx context.Context) error { + return doGet(ctx, h.httpClient, h.url.String()) +} + +// HTTPPinger is a pinger for HTTP connections. +type HTTPPinger struct { + url *url.URL + HTTPClient *http.Client +} + +// Bootstrap sets up the pinger with the HTTP URL. +func (h *HTTPPinger) Bootstrap(host string) error { + u, err := url.Parse(host) + if err != nil { + return fmt.Errorf("failed to parse host %q: %v", host, err) + } + + if err := validateURL("http", u); err != nil { + return err + } + + h.url = u + + // Initialize HTTP client with timeout for each request + h.HTTPClient = &http.Client{ + Timeout: 1 * time.Second, // 1 second timeout per request + } + + return nil +} + +// Ping performs an HTTP GET request and checks the status code. +func (h *HTTPPinger) Ping(ctx context.Context) error { + return doGet(ctx, h.HTTPClient, h.url.String()) +} diff --git a/wait/probes/https_test.go b/wait/probes/https_test.go new file mode 100644 index 0000000..bc7cf58 --- /dev/null +++ b/wait/probes/https_test.go @@ -0,0 +1,227 @@ +package probes + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestHTTPPinger_Bootstrap(t *testing.T) { + tests := []struct { + name string + urlStr string + wantErr bool + }{ + { + name: "Valid URL", + urlStr: "http://example.com", + wantErr: false, + }, + { + name: "No host specified", + urlStr: "http://", + wantErr: true, + }, + { + name: "Invalid URL", + urlStr: "://example.com", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pinger := &HTTPPinger{} + if err := pinger.Bootstrap(tt.urlStr); (err != nil) != tt.wantErr { + t.Errorf("HTTPPinger.Bootstrap() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestHTTPPinger_Ping(t *testing.T) { + tests := []struct { + name string + statusCode int + wantErr bool + }{ + { + name: "Status OK", + statusCode: http.StatusOK, + wantErr: false, + }, + { + name: "Status Not Found", + statusCode: http.StatusNotFound, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a test server that returns the specified status code + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.statusCode) + })) + defer ts.Close() + + pinger := &HTTPPinger{} + if err := pinger.Bootstrap(ts.URL); err != nil { + t.Fatalf("HTTPPinger.Bootstrap() error = %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + if err := pinger.Ping(ctx); (err != nil) != tt.wantErr { + t.Errorf("HTTPPinger.Ping() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func generateSelfSignedCert() (tls.Certificate, error) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, err + } + + notBefore := time.Now() + notAfter := notBefore.Add(365 * 24 * time.Hour) + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return tls.Certificate{}, err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Example Co"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return tls.Certificate{}, err + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + keyDER, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return tls.Certificate{}, err + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + + return tls.X509KeyPair(certPEM, keyPEM) +} + +func TestHTTPSPinger_Bootstrap(t *testing.T) { + tests := []struct { + name string + urlStr string + wantErr bool + }{ + { + name: "Valid URL", + urlStr: "https://example.com", + wantErr: false, + }, + { + name: "No host specified", + urlStr: "https://", + wantErr: true, + }, + { + name: "Invalid URL", + urlStr: "://example.com", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pinger := &HTTPSPinger{} + if err := pinger.Bootstrap(tt.urlStr); (err != nil) != tt.wantErr { + t.Errorf("HTTPSPinger.Bootstrap() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestHTTPSPinger_Ping(t *testing.T) { + cert, err := generateSelfSignedCert() + if err != nil { + t.Fatalf("Failed to generate self-signed certificate: %v", err) + } + + tests := []struct { + name string + statusCode int + wantErr bool + }{ + { + name: "Status OK", + statusCode: http.StatusOK, + wantErr: false, + }, + { + name: "Status Not Found", + statusCode: http.StatusNotFound, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.statusCode) + })) + ts.TLS = &tls.Config{Certificates: []tls.Certificate{cert}} + ts.StartTLS() + defer ts.Close() + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Certificate[0]})) + + pinger := &HTTPSPinger{} + if err := pinger.Bootstrap(ts.URL); err != nil { + t.Fatalf("HTTPSPinger.Bootstrap() error = %v", err) + } + + // Override the bootstrapped client + pinger.httpClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: caCertPool, + }, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + if err := pinger.Ping(ctx); (err != nil) != tt.wantErr { + t.Errorf("HTTPSPinger.Ping() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/wait/probes/mysql.go b/wait/probes/mysql.go new file mode 100644 index 0000000..529e383 --- /dev/null +++ b/wait/probes/mysql.go @@ -0,0 +1,51 @@ +package probes + +import ( + "context" + "database/sql" + "fmt" + "net/url" + "time" + + _ "github.com/go-sql-driver/mysql" +) + +// MySQLPinger is a pinger for MySQL connections. +type MySQLPinger struct { + DSN string +} + +// Bootstrap sets up the pinger with the URL. +// Expected URL format: mysql://user:password@host:port/dbname +func (m *MySQLPinger) Bootstrap(host string) error { + u, err := url.Parse(host) + if err != nil { + return fmt.Errorf("failed to parse host %q: %v", host, err) + } + + hostname := u.Host + user := u.User.Username() + pass, _ := u.User.Password() + + if user == "" { + user = "root" + } + + // We use the "tcp(host:port)" format for MySQL driver. + m.DSN = fmt.Sprintf("%s:%s@tcp(%s)/", user, pass, hostname) + return nil +} + +// Ping attempts to connect to the host and ping the database. +func (m *MySQLPinger) Ping(ctx context.Context) error { + db, err := sql.Open("mysql", m.DSN) + if err != nil { + return err + } + defer db.Close() + + ctx, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + + return db.PingContext(ctx) +} diff --git a/wait/probes/postgres.go b/wait/probes/postgres.go new file mode 100644 index 0000000..0998082 --- /dev/null +++ b/wait/probes/postgres.go @@ -0,0 +1,69 @@ +package probes + +import ( + "context" + "fmt" + "net/url" + "strings" + "time" + + "github.com/jackc/pgx/v5" +) + +// PostgresPinger is a pinger for PostgreSQL connections. +type PostgresPinger struct { + DSN string +} + +// Bootstrap sets up the pinger with the PostgreSQL URL. +// Expected URL format: postgres://user:password@host:port/dbname +func (p *PostgresPinger) Bootstrap(host string) error { + u, err := url.Parse(host) + if err != nil { + return fmt.Errorf("failed to parse host %q: %v", host, err) + } + + // Extract user credentials + user := u.User.Username() + pass, _ := u.User.Password() + + // Extract host (includes port if specified) + hostname := u.Host + + // Extract database name, trimming the leading '/' + dbname := strings.TrimPrefix(u.Path, "/") + if dbname == "" { + return fmt.Errorf("no database name specified in the URL") + } + + // Handle query parameters + queryParams := u.Query() + + // Construct the DSN (Data Source Name) + // Example: postgres://user:password@host:port/dbname + p.DSN = fmt.Sprintf("postgres://%s:%s@%s/%s?%s", + user, pass, hostname, dbname, queryParams.Encode()) + + return nil +} + +// Ping attempts to connect to the PostgreSQL database and ping it. +func (p *PostgresPinger) Ping(ctx context.Context) error { + // Open a connection to the database + db, err := pgx.Connect(ctx, p.DSN) + if err != nil { + return fmt.Errorf("error opening PostgreSQL connection: %w", err) + } + defer db.Close(ctx) + + // Set a short timeout for the ping + ctx, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + + // Attempt to ping the database + if err := db.Ping(ctx); err != nil { + return fmt.Errorf("error pinging PostgreSQL database: %w", err) + } + + return nil +} diff --git a/wait/probes/tcp.go b/wait/probes/tcp.go new file mode 100644 index 0000000..fef5246 --- /dev/null +++ b/wait/probes/tcp.go @@ -0,0 +1,48 @@ +package probes + +import ( + "context" + "fmt" + "net" + "net/url" + "time" +) + +// TCPPinger is a pinger for TCP connections. +type TCPPinger struct { + Host string +} + +// Bootstrap sets up the pinger with the URL. +func (t *TCPPinger) Bootstrap(host string) error { + if proto := extractProtocol(host); proto == "" { + host = "tcp://" + host + } + + u, err := url.Parse(host) + if err != nil { + return fmt.Errorf("failed to parse host %q: %v", host, err) + } + + if u.Host == "" { + return fmt.Errorf("no host specified for tcp scheme") + } + + if !oneOf(u.Scheme, "tcp", "tcp4", "tcp6") { + return fmt.Errorf("invalid scheme for tcp probe: %s", u.Scheme) + } + + t.Host = u.Host + return nil +} + +// Ping attempts to connect to the host. +func (t *TCPPinger) Ping(ctx context.Context) error { + d := net.Dialer{Timeout: 1 * time.Second} + conn, err := d.DialContext(ctx, "tcp", t.Host) + if err != nil { + return err + } + conn.Close() + return nil +} diff --git a/wait/probes/tcp_test.go b/wait/probes/tcp_test.go new file mode 100644 index 0000000..b9f6792 --- /dev/null +++ b/wait/probes/tcp_test.go @@ -0,0 +1,97 @@ +package probes + +import ( + "context" + "net" + "testing" + "time" +) + +func TestTCPPinger_Bootstrap(t *testing.T) { + tests := []struct { + name string + urlStr string + wantErr bool + }{ + { + name: "Valid URL", + urlStr: "tcp://example.com:80", + wantErr: false, + }, + { + name: "No host specified", + urlStr: "tcp://", + wantErr: true, + }, + { + name: "Invalid scheme", + urlStr: "http://example.com:80", + wantErr: true, + }, + { + name: "No scheme", + urlStr: "example.com:80", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pinger := &TCPPinger{} + if err := pinger.Bootstrap(tt.urlStr); (err != nil) != tt.wantErr { + t.Errorf("TCPPinger.Bootstrap() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestTCPPinger_Ping(t *testing.T) { + // Launch a local server to test the pinger + chPort := make(chan string, 1) + + go func() { + srv, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Errorf("net.Listen() error = %v", err) + } + defer srv.Close() + + chPort <- srv.Addr().String() + + for { + conn, err := srv.Accept() + if err != nil { + return + } + conn.Close() + } + }() + + tests := []struct { + name string + host string + wantErr bool + }{ + { + name: "Valid host", + host: <-chPort, + wantErr: false, + }, + { + name: "Invalid host", + host: "invalidhost:80", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pinger := &TCPPinger{Host: tt.host} + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := pinger.Ping(ctx); (err != nil) != tt.wantErr { + t.Errorf("TCPPinger.Ping() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/wait/probes/udp.go b/wait/probes/udp.go new file mode 100644 index 0000000..7d40195 --- /dev/null +++ b/wait/probes/udp.go @@ -0,0 +1,48 @@ +package probes + +import ( + "context" + "fmt" + "net" + "net/url" + "time" +) + +// UDPPinger is a pinger for UDP connections. +type UDPPinger struct { + Host string +} + +// Bootstrap sets up the pinger with the URL. +func (u *UDPPinger) Bootstrap(host string) error { + url, err := url.Parse(host) + if err != nil { + return fmt.Errorf("failed to parse host %q: %v", host, err) + } + + if url.Host == "" { + return fmt.Errorf("no host specified for udp scheme") + } + + if !oneOf(url.Scheme, "udp", "udp4", "udp6") { + return fmt.Errorf("invalid scheme for udp probe: %s", url.Scheme) + } + + u.Host = url.Host + return nil +} + +// Ping attempts to send a datagram to the host. +func (u *UDPPinger) Ping(ctx context.Context) error { + // For UDP "ping", we can attempt to send a datagram and check for error. + // Unlike TCP, we don't get a "connected" state just by dialing. + conn, err := net.DialTimeout("udp", u.Host, 1*time.Second) + if err != nil { + return err + } + defer conn.Close() + + // Send a zero-length packet just to see if it errors out. + _, err = conn.Write([]byte{}) + return err +} diff --git a/wait/probes/udp_test.go b/wait/probes/udp_test.go new file mode 100644 index 0000000..4153754 --- /dev/null +++ b/wait/probes/udp_test.go @@ -0,0 +1,97 @@ +package probes + +import ( + "context" + "net" + "testing" + "time" +) + +func TestUDPPinger_Bootstrap(t *testing.T) { + tests := []struct { + name string + urlStr string + wantErr bool + }{ + { + name: "Valid URL", + urlStr: "udp://example.com:80", + wantErr: false, + }, + { + name: "No host specified", + urlStr: "udp://", + wantErr: true, + }, + { + name: "Invalid scheme", + urlStr: "http://example.com:80", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pinger := &UDPPinger{} + if err := pinger.Bootstrap(tt.urlStr); (err != nil) != tt.wantErr { + t.Errorf("UDPPinger.Bootstrap() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestUDPPinger_Ping(t *testing.T) { + // Launch a local server to test the pinger + chPort := make(chan string, 1) + + go func() { + srv, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0}) + if err != nil { + t.Errorf("net.Listen() error = %v", err) + } + defer srv.Close() + + chPort <- srv.LocalAddr().String() + + for { + buf := make([]byte, 1024) + n, _, err := srv.ReadFromUDP(buf) + if err != nil { + t.Errorf("srv.ReadFrom() error = %v", err) + } + + if n >= 0 { + srv.Close() + break + } + } + }() + + tests := []struct { + name string + host string + wantErr bool + }{ + { + name: "Valid host", + host: <-chPort, + wantErr: false, + }, + { + name: "Invalid host", + host: "invalidhost:80", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pinger := &UDPPinger{Host: tt.host} + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := pinger.Ping(ctx); (err != nil) != tt.wantErr { + t.Errorf("UDPPinger.Ping() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/wait/wait.go b/wait/wait.go deleted file mode 100644 index cd05d1e..0000000 --- a/wait/wait.go +++ /dev/null @@ -1,124 +0,0 @@ -package wait - -import ( - "fmt" - "io" - "log" - "net" - "os" - "regexp" - "strings" - "time" -) - -type proto string - -const ( - tcp proto = "tcp" - udp proto = "udp" -) - -type host struct { - host string - port string - protocol proto -} - -func (h host) String() string { - return fmt.Sprintf("%s:%s", h.host, h.port) -} - -func (h host) GetProtocol() string { - if h.protocol == "" { - return string(tcp) - } - - return string(h.protocol) -} - -func stringifyHosts(hosts []host) string { - var sb strings.Builder - - for i, v := range hosts { - if i > 0 { - sb.WriteString(", ") - } - - sb.WriteString(`"` + fmt.Sprintf("%s://%s:%s", v.GetProtocol(), v.host, v.port) + `"`) - } - - return sb.String() -} - -type Wait struct { - hosts []host - timeout time.Duration - step time.Duration - log *log.Logger - padding int -} - -var reLooksLikeProtocol = regexp.MustCompile(`^(\w+)://`) - -func New(hosts []string, step, timeout time.Duration, verbose bool) (*Wait, error) { - w := &Wait{ - timeout: timeout, - step: step, - } - - if len(hosts) == 0 { - return nil, fmt.Errorf("no hosts specified") - } - - full := make([]host, 0, len(hosts)) - for _, v := range hosts { - if len(v) > w.padding { - w.padding = len(v) - } - - var proto proto - - if strings.HasPrefix(v, "tcp://") { - proto = tcp - v = strings.TrimPrefix(v, "tcp://") - } - - if strings.HasPrefix(v, "udp://") { - proto = udp - v = strings.TrimPrefix(v, "udp://") - } - - if proto == "" && reLooksLikeProtocol.MatchString(v) { - return nil, fmt.Errorf("invalid protocol specified: %q -- only \"tcp\" and \"udp\" are supported", v) - } - - parsedHost, parsedPort, err := net.SplitHostPort(v) - if err != nil { - return nil, fmt.Errorf("invalid host format: %q -- must be in the format \"host:port\" or \"(tcp|udp)://host:port\"", v) - } - - full = append(full, host{ - host: parsedHost, - port: parsedPort, - protocol: proto, - }) - } - - w.hosts = full - w.log = log.New(io.Discard, "", 0) - - if verbose { - w.log.SetOutput(os.Stdout) - } - - return w, nil -} - -func (w *Wait) String() string { - return fmt.Sprintf( - "Waiting for hosts: %s (timeout: %s, attempting every %s)", - stringifyHosts(w.hosts), - w.timeout, - w.step, - ) -}