From 29b358a7608e83e92fc613dbf55527c484f3a982 Mon Sep 17 00:00:00 2001 From: Thomas Montfort <61255722+tmonty12@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:56:13 -0800 Subject: [PATCH] autostop removal (#208) Co-authored-by: Alec Fong --- bin/install-latest-linux.sh | 30 - bin/install-latest.sh | 24 - pkg/autostartconf/autostartconf.go | 52 - pkg/cmd/autostop/autostop.go | 86 - pkg/cmd/background/background.go | 34 - pkg/cmd/bmon/bmon.go | 51 - pkg/cmd/cmd.go | 11 - pkg/cmd/envsetup/envsetup.go | 743 ------ pkg/cmd/envsetup/envsetup_test.go | 87 - pkg/cmd/envsetup/motd.sh | 60 - pkg/cmd/envsetup/speedtest.py | 2184 ----------------- .../optimizeinstances/optimizeinstances.go | 566 ----- pkg/cmd/postinstall/postinstall.go | 113 - pkg/cmd/upgrade/upgrade.go | 137 -- pkg/store/autostop.go | 65 - pkg/store/workspace.go | 17 - 16 files changed, 4260 deletions(-) delete mode 100644 bin/install-latest-linux.sh delete mode 100755 bin/install-latest.sh delete mode 100644 pkg/cmd/autostop/autostop.go delete mode 100644 pkg/cmd/bmon/bmon.go delete mode 100644 pkg/cmd/envsetup/envsetup.go delete mode 100644 pkg/cmd/envsetup/envsetup_test.go delete mode 100644 pkg/cmd/envsetup/motd.sh delete mode 100644 pkg/cmd/envsetup/speedtest.py delete mode 100644 pkg/cmd/optimizeinstances/optimizeinstances.go delete mode 100644 pkg/cmd/postinstall/postinstall.go delete mode 100644 pkg/cmd/upgrade/upgrade.go delete mode 100644 pkg/store/autostop.go diff --git a/bin/install-latest-linux.sh b/bin/install-latest-linux.sh deleted file mode 100644 index 7f832110..00000000 --- a/bin/install-latest-linux.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env bash -# Install the latest version of the Linux binary - -# Install the latest version of the Linux binary - -# Get THE DOWNLOAD URL -DOWNLOAD_URL=$(curl -s https://brevapi.us-west-2-prod.control-plane.brev.dev/api/autostop/cli-download-url) - -# download the tar to a tmp directory - -TMP_DIR=$(mktemp -d) - -wget --directory-prefix=$TMP_DIR $DOWNLOAD_URL - -# extract the tar to the bin directory -tar -xzf $TMP_DIR/brev* -C $TMP_DIR # glob is a hack to get the filename - -# move the binary to the bin directory -sudo mv $TMP_DIR/brev /usr/local/bin/brev - -# remove the tmp directory -rm -rf $TMP_DIR - -# make the binary executable -chmod +x /usr/local/bin/brev - -# run post install commands, write now creates a file in etc -# to store email so needs root - -sudo brev postinstall diff --git a/bin/install-latest.sh b/bin/install-latest.sh deleted file mode 100755 index 0afdad8f..00000000 --- a/bin/install-latest.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/usr/bin/env bash -# Install the latest version of the Linux binary - -set -eo pipefail - -# Get THE DOWNLOAD URL -DOWNLOAD_URL=$(curl -s https://brevapi.us-west-2-prod.control-plane.brev.dev/api/autostop/cli-download-url) - -# download the tar to a tmp directory - -TMP_DIR=$(mktemp -d) -curl -sL "$DOWNLOAD_URL" -o "$TMP_DIR/brev.tar.gz" - -# extract the tar to the bin directory -tar -xzf "$TMP_DIR/brev.tar.gz" -C "$TMP_DIR" - -# move the binary to the bin directory -mv "$TMP_DIR/brev" /usr/local/bin/brev - -# remove the tmp directory -rm -rf "$TMP_DIR" - -# make the binary executable -chmod +x /usr/local/bin/brev diff --git a/pkg/autostartconf/autostartconf.go b/pkg/autostartconf/autostartconf.go index bcf3c346..db1f1a8f 100644 --- a/pkg/autostartconf/autostartconf.go +++ b/pkg/autostartconf/autostartconf.go @@ -236,55 +236,3 @@ User=` + store.GetOSUser() + ` } return nil } - -func NewBrevMonConfigure( - store AutoStartStore, - disableAutostop bool, - reportInterval string, - portToCheckTrafficOn string, -) DaemonConfigurer { - configFile := fmt.Sprintf(`[Unit] - Description=brevmon - After=network.target - - [Service] - User=root - Type=exec - ExecStart=/usr/local/bin/brevmon %s - ExecReload=/usr/local/bin/brevmon %s - Restart=always - - [Install] - WantedBy=default.target - `, portToCheckTrafficOn, portToCheckTrafficOn) - if disableAutostop { - configFile = fmt.Sprintf(`[Unit] -Description=brevmon -After=network.target - -[Service] -User=root -Type=exec -ExecStart=/usr/local/bin/brevmon %s --disable-autostop --report-interval `+reportInterval+` -ExecReload=/usr/local/bin/brevmon %s --disable-autostop --report-interval `+reportInterval+` -Restart=always - -[Install] -WantedBy=default.target -`, portToCheckTrafficOn, portToCheckTrafficOn) - } - return AptBinaryConfigurer{ - LinuxSystemdConfigurer: LinuxSystemdConfigurer{ - Store: store, - ValueConfigFile: configFile, - ServiceName: "brevmon.service", - ServiceType: "system", - }, - - URL: "https://s3.amazonaws.com/brevmon.brev.dev/brevmon.tar.gz", - Name: "brevmon", - aptDependencies: []string{ - "libpcap-dev", - }, - } -} diff --git a/pkg/cmd/autostop/autostop.go b/pkg/cmd/autostop/autostop.go deleted file mode 100644 index 5911fb33..00000000 --- a/pkg/cmd/autostop/autostop.go +++ /dev/null @@ -1,86 +0,0 @@ -package autostop - -import ( - "github.com/hashicorp/go-multierror" - "github.com/samber/lo" - "github.com/samber/mo" - "github.com/spf13/cobra" - - "github.com/brevdev/brev-cli/pkg/entity" - breverrors "github.com/brevdev/brev-cli/pkg/errors" - "github.com/brevdev/brev-cli/pkg/terminal" -) - -var ( - short = "TODO" - long = "TODO" - example = "TODO" -) - -func NewCmdautostop(t *terminal.Terminal, store autostopStore) *cobra.Command { - cmd := &cobra.Command{ - Use: "autostop", - DisableFlagsInUseLine: true, - Short: short, - Long: long, - Example: example, - RunE: func(cmd *cobra.Command, args []string) error { - err := Runautostop( - runAutostopArgs{ - t: t, - args: args, - store: store, - }, - ) - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil - }, - } - return cmd -} - -type autostopStore interface { - AutoStopWorkspace(workspaceID string) (*entity.Workspace, error) - GetActiveOrganizationOrDefault() (*entity.Organization, error) - GetWorkspaceByNameOrID(orgID string, nameOrID string) ([]entity.Workspace, error) -} - -type runAutostopArgs struct { - t *terminal.Terminal - args []string - store autostopStore -} - -func Runautostop(args runAutostopArgs) error { - org, err := args.store.GetActiveOrganizationOrDefault() - if err != nil { - return breverrors.WrapAndTrace(err) - } - envs, err := args.store.GetWorkspaceByNameOrID(org.ID, args.args[0]) - if err != nil { - return breverrors.WrapAndTrace(err) - } - - asResults := lo.Map( - envs, - func(env entity.Workspace, _ int) mo.Result[*entity.Workspace] { - return mo.TupleToResult(args.store.AutoStopWorkspace(env.ID)) - }, - ) - err = lo.Reduce( - asResults, - func(acc error, res mo.Result[*entity.Workspace], _ int) error { - if res.IsError() { - return multierror.Append(acc, res.Error()) - } - return acc - }, - nil, - ) - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil -} diff --git a/pkg/cmd/background/background.go b/pkg/cmd/background/background.go index 5dd38501..273f1f44 100644 --- a/pkg/cmd/background/background.go +++ b/pkg/cmd/background/background.go @@ -28,28 +28,6 @@ type BackgroundStore interface { CreateWorkspace(organizationID string, options *store.CreateWorkspacesOptions) (*entity.Workspace, error) } -func DisableAutoStop(s BackgroundStore, workspaceID string) error { - isStoppable := false - _, err := s.ModifyWorkspace(workspaceID, &store.ModifyWorkspaceRequest{ - IsStoppable: &isStoppable, - }) - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil -} - -func EnableAutoStop(s BackgroundStore, workspaceID string) error { - isStoppable := true - _, err := s.ModifyWorkspace(workspaceID, &store.ModifyWorkspaceRequest{ - IsStoppable: &isStoppable, - }) - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil -} - func NewCmdBackground(t *terminal.Terminal, s BackgroundStore) *cobra.Command { cmd := &cobra.Command{ Annotations: map[string]string{"workspace": ""}, @@ -88,7 +66,6 @@ func NewCmdBackground(t *terminal.Terminal, s BackgroundStore) *cobra.Command { log.Fatal(err) } - checkAutoStop(s) // Run the command in the background using nohup c := exec.Command("nohup", "bash", "-c", command+">"+logsDir+"/log.txt 2>&1 &") // #nosec G204 err = c.Start() @@ -133,17 +110,6 @@ func NewCmdBackground(t *terminal.Terminal, s BackgroundStore) *cobra.Command { return cmd } -func checkAutoStop(s BackgroundStore) { - wsID, err := s.GetCurrentWorkspaceID() - if err == nil && wsID != "" { - // Disable auto stop - err = DisableAutoStop(s, wsID) - if err != nil { - log.Fatal(err) - } - } -} - func pushBackgroundAnalytics(s BackgroundStore) error { // Call analytics for open userID := "" diff --git a/pkg/cmd/bmon/bmon.go b/pkg/cmd/bmon/bmon.go deleted file mode 100644 index de1eafd2..00000000 --- a/pkg/cmd/bmon/bmon.go +++ /dev/null @@ -1,51 +0,0 @@ -package bmon - -import ( - "github.com/spf13/cobra" - - "github.com/brevdev/brev-cli/pkg/autostartconf" - breverrors "github.com/brevdev/brev-cli/pkg/errors" - "github.com/brevdev/brev-cli/pkg/terminal" -) - -var ( - short = "TODO" - long = "TODO" - example = "TODO" -) - -type bmonStore interface { - autostartconf.AutoStartStore -} - -func NewCmdbmon(t *terminal.Terminal, store bmonStore) *cobra.Command { - cmd := &cobra.Command{ - Use: "bmon", - DisableFlagsInUseLine: true, - Short: short, - Long: long, - Example: example, - RunE: func(cmd *cobra.Command, args []string) error { - err := Runbmon(t, args, store) - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil - }, - } - return cmd -} - -func Runbmon(_ *terminal.Terminal, _ []string, store bmonStore) error { - bmonConfig := autostartconf.NewBrevMonConfigure( - store, - false, - "10m", - "22", - ) - err := bmonConfig.Install() - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil -} diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index 07ad9b06..ddecd49d 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -5,15 +5,12 @@ import ( "fmt" "github.com/brevdev/brev-cli/pkg/auth" - "github.com/brevdev/brev-cli/pkg/cmd/autostop" "github.com/brevdev/brev-cli/pkg/cmd/background" - "github.com/brevdev/brev-cli/pkg/cmd/bmon" "github.com/brevdev/brev-cli/pkg/cmd/clipboard" "github.com/brevdev/brev-cli/pkg/cmd/configureenvvars" "github.com/brevdev/brev-cli/pkg/cmd/connect" "github.com/brevdev/brev-cli/pkg/cmd/create" "github.com/brevdev/brev-cli/pkg/cmd/delete" - "github.com/brevdev/brev-cli/pkg/cmd/envsetup" "github.com/brevdev/brev-cli/pkg/cmd/envvars" "github.com/brevdev/brev-cli/pkg/cmd/fu" "github.com/brevdev/brev-cli/pkg/cmd/healthcheck" @@ -29,7 +26,6 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/open" "github.com/brevdev/brev-cli/pkg/cmd/org" "github.com/brevdev/brev-cli/pkg/cmd/portforward" - "github.com/brevdev/brev-cli/pkg/cmd/postinstall" "github.com/brevdev/brev-cli/pkg/cmd/profile" "github.com/brevdev/brev-cli/pkg/cmd/proxy" "github.com/brevdev/brev-cli/pkg/cmd/recreate" @@ -48,7 +44,6 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/tasks" "github.com/brevdev/brev-cli/pkg/cmd/test" "github.com/brevdev/brev-cli/pkg/cmd/updatemodel" - "github.com/brevdev/brev-cli/pkg/cmd/upgrade" "github.com/brevdev/brev-cli/pkg/cmd/workspacegroups" "github.com/brevdev/brev-cli/pkg/cmd/writeconnectionevent" "github.com/brevdev/brev-cli/pkg/config" @@ -279,13 +274,7 @@ func createCmdTree(cmd *cobra.Command, t *terminal.Terminal, loginCmdStore *stor cmd.AddCommand(setupworkspace.NewCmdSetupWorkspace(noLoginCmdStore)) cmd.AddCommand(recreate.NewCmdRecreate(t, loginCmdStore)) - cmd.AddCommand(envsetup.NewCmdEnvSetup(loginCmdStore, loginAuth)) - cmd.AddCommand(postinstall.NewCmdpostinstall(t, loginCmdStore)) - cmd.AddCommand(postinstall.NewCMDOptimizeThis(t, loginCmdStore)) - cmd.AddCommand(bmon.NewCmdbmon(t, loginCmdStore)) - cmd.AddCommand(upgrade.NewCmdUpgrade(t, loginCmdStore)) cmd.AddCommand(writeconnectionevent.NewCmdwriteConnectionEvent(t, loginCmdStore)) - cmd.AddCommand(autostop.NewCmdautostop(t, loginCmdStore)) cmd.AddCommand(updatemodel.NewCmdupdatemodel(t, loginCmdStore)) } diff --git a/pkg/cmd/envsetup/envsetup.go b/pkg/cmd/envsetup/envsetup.go deleted file mode 100644 index 329a1be3..00000000 --- a/pkg/cmd/envsetup/envsetup.go +++ /dev/null @@ -1,743 +0,0 @@ -package envsetup - -import ( - "context" - "encoding/json" - "fmt" - "io" - "io/ioutil" - "log" - "os" - "os/user" - "path/filepath" - "strings" - "time" - - "github.com/hashicorp/go-multierror" - "github.com/spf13/afero" - "github.com/spf13/cobra" - - _ "embed" - - "github.com/brevdev/brev-cli/pkg/autostartconf" - "github.com/brevdev/brev-cli/pkg/cmd/updatemodel" - "github.com/brevdev/brev-cli/pkg/cmd/version" - "github.com/brevdev/brev-cli/pkg/collections" - "github.com/brevdev/brev-cli/pkg/entity" - breverrors "github.com/brevdev/brev-cli/pkg/errors" - "github.com/brevdev/brev-cli/pkg/featureflag" - "github.com/brevdev/brev-cli/pkg/setupworkspace" - "github.com/brevdev/brev-cli/pkg/store" - "github.com/brevdev/brev-cli/pkg/util" -) - -type envsetupStore interface { - GetEnvSetupParams(wsid string) (*store.SetupParamsV0, error) - WriteSetupScript(script string) error - GetSetupScriptPath() string - GetCurrentUser() (*entity.User, error) - GetCurrentWorkspaceID() (string, error) - GetOSUser() string - GetOrCreateSetupLogFile(path string) (afero.File, error) - GetBrevHomePath() (string, error) - BuildBrevHome() error - CopyBin(targetBin string) error - WriteString(path, data string) error - UserHomeDir() (string, error) - Remove(target string) error - FileExists(target string) (bool, error) - DownloadBinary(url string, target string) error - AppendString(path string, content string) error - Chmod(path string, mode os.FileMode) error - ChownFilePathToUser(path string) error - OverWriteString(path string, content string) error - GetFileAsString(path string) (string, error) -} - -type nologinEnvStore interface { - LoginWithToken(token string) error -} - -const name = "envsetup" - -func NewCmdEnvSetup(store envsetupStore, noLoginStore nologinEnvStore) *cobra.Command { - var forceEnableSetup bool - // add debugger flag to toggle features when running command through a debugger - // this is useful for debugging setup scripts - debugger := false - configureSystemSSHConfig := true - - // if a token flag is supplied, log in with it - var token string - - var datadogAPIKey string - - var disableAutostop bool - - var reportInterval string - - var autostopPort string - cmd := &cobra.Command{ - Use: name, - DisableFlagsInUseLine: true, - Short: "TODO", - Long: "TODO", - Example: "TODO", - RunE: func(cmd *cobra.Command, args []string) error { - var errors error - for _, arg := range args { - err := RunEnvSetup( - store, - name, - forceEnableSetup, - debugger, - configureSystemSSHConfig, - arg, - token, - noLoginStore, - datadogAPIKey, - disableAutostop, - reportInterval, - autostopPort, - ) - if err != nil { - errors = multierror.Append(err) - } - } - if errors != nil { - return breverrors.WrapAndTrace(errors) - } - return nil - }, - } - cmd.PersistentFlags().BoolVar(&forceEnableSetup, "force-enable", false, "force the setup script to run despite params") - cmd.PersistentFlags().BoolVar(&debugger, "debugger", debugger, "toggle features that don't play well with debuggers") - cmd.PersistentFlags().BoolVar(&configureSystemSSHConfig, "configure-system-ssh-config", configureSystemSSHConfig, "configure system ssh config") - cmd.PersistentFlags().StringVar(&token, "token", "", "token to use for login") - cmd.PersistentFlags().StringVar(&datadogAPIKey, "datadog-api-key", "", "datadog API key to use for logging") - cmd.PersistentFlags().BoolVar(&disableAutostop, "disable-autostop", false, "disable autostop") - cmd.PersistentFlags().StringVar(&reportInterval, "report-interval", "10m", "report interval") - cmd.PersistentFlags().StringVar(&autostopPort, "autostop-port", "22", "autostop port") - - return cmd -} - -func RunEnvSetup( - store envsetupStore, - name string, - forceEnableSetup, debugger, configureSystemSSHConfig bool, - workspaceid, token string, - noLoginStore nologinEnvStore, - datadogAPIKey string, - disableAutostop bool, - reportInterval string, - portToCheckAutostopTrafficOn string, -) error { - if token != "" { - err := noLoginStore.LoginWithToken(token) - if err != nil { - return breverrors.WrapAndTrace(err) - } - } - - breverrors.GetDefaultErrorReporter().AddTag("command", name) - _, err := store.GetCurrentWorkspaceID() // do this to error reporting - if err != nil { - return breverrors.WrapAndTrace(err) - } - fmt.Println("setting up instance") - - params, err := store.GetEnvSetupParams(workspaceid) - if err != nil { - return breverrors.WrapAndTrace(err) - } - res, err := json.MarshalIndent(params, "", "") - if err != nil { - return breverrors.WrapAndTrace(err) - } - fmt.Println(string(res)) - - if !featureflag.IsDev() && !debugger { - _, err = store.GetCurrentUser() // do this to set error user reporting - if err != nil { - fmt.Println(err) - if !params.DisableSetup { - breverrors.GetDefaultErrorReporter().ReportError(breverrors.Wrap(err, "setup continued")) - } - } - } - - if !forceEnableSetup && params.DisableSetup { - fmt.Printf("WARNING: setup script not running [params.DisableSetup=%v, forceEnableSetup=%v]", params.DisableSetup, forceEnableSetup) - return nil - } - - err = setupEnv( - store, - params, - configureSystemSSHConfig, - datadogAPIKey, - disableAutostop, - reportInterval, - portToCheckAutostopTrafficOn, - ) - if err != nil { - return breverrors.WrapAndTrace(err) - } - - fmt.Println("done setting up instance") - return nil -} - -type envInitier struct { - setupworkspace.WorkspaceIniter - ConfigureSystemSSHConfig bool - brevMonConfigurer autostartconf.DaemonConfigurer - datadogAPIKey string - store envsetupStore -} - -func appendLogToFile(content string, file string) error { - cmd := setupworkspace.CmdStringBuilder(fmt.Sprintf(`echo "%s" >> %s`, content, file)) - err := cmd.Run() - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil -} - -//go:embed motd.sh -var motd string - -func (e *envInitier) SetupMOTD() error { - err := e.store.OverWriteString("/etc/ssh/my_banner", motd) - if err != nil { - return breverrors.WrapAndTrace(err) - } - - fstring, err := e.store.GetFileAsString("/etc/ssh/sshd_config") - if err != nil { - return breverrors.WrapAndTrace(err) - } - - if !strings.Contains(fstring, "Banner /etc/ssh/my_banner") { - err = e.store.AppendString("/etc/ssh/sshd_config", "Banner /etc/ssh/my_banner") - if err != nil { - return breverrors.WrapAndTrace(err) - } - } - - err = setupworkspace.BuildAndRunCmd("systemctl", "reload", "ssh.service") - if err != nil { - return breverrors.WrapAndTrace(err) - } - - return nil -} - -//go:embed speedtest.py -var speedtest string - -func (e *envInitier) SetupSpeedTest() error { - err := e.store.WriteString("/usr/local/bin/speedtest", speedtest) - if err != nil { - return breverrors.WrapAndTrace(err) - } - err = e.store.Chmod("/usr/local/bin/speedtest", 0o755) - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil -} - -func (e *envInitier) SetupUpdateModel() error { - dc := updatemodel.DaemonConfigurer{ - Store: e.store, - } - err := dc.Configure() - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil -} - -func (e envInitier) Setup() error { //nolint:funlen,gocyclo // TODO - var setupErr error - - err := appendLogToFile("setup started", "/var/log/brev-setup-steps.log") - if err != nil { - setupErr = multierror.Append(setupErr, breverrors.WrapAndTrace(err)) - } - - err = setupworkspace.BuildAndRunCmd("systemctl", "stop", "unattended-upgrades") - if err != nil { - setupErr = multierror.Append(setupErr, breverrors.WrapAndTrace(err)) - } - out, err := setupworkspace.RunCMDWithOutput("apt-get", "-y", "remove", "unattended-upgrades") - if err != nil { - setupErr = multierror.Append(setupErr, - breverrors.Wrap(err, "apt-get -y remove unattended-upgrades"+out)) - } - - cmd := setupworkspace.CmdStringBuilder("echo user: $(whoami) && echo pwd: $(pwd)") - err = cmd.Run() - if err != nil { - setupErr = multierror.Append(setupErr, breverrors.WrapAndTrace(err)) - } - - postPrepare := util.RunEAsync( - e.SetupVsCodeExtensions, - e.SetupSpeedTest, - e.SetupMOTD, - e.SetupUpdateModel, - ) - - err = util.RunEAsync( - e.SetupSSH, - e.SetupGit, - ).Await() - if err != nil { - setupErr = multierror.Append(setupErr, breverrors.WrapAndTrace(err)) - } - - err = appendLogToFile("starting repo setup", "/var/log/brev-steps.log") - if err != nil { - setupErr = multierror.Append(setupErr, breverrors.WrapAndTrace(err)) - } - err = e.SetupRepos() - if err != nil { - setupErr = multierror.Append(breverrors.WrapAndTrace(err)) - } - fmt.Println("------ Git repo cloned ------") - err = appendLogToFile("repo setup done", "/var/log/brev-steps.log") - if err != nil { - setupErr = multierror.Append(setupErr, breverrors.WrapAndTrace(err)) - } - - err = e.SetupEnvVars() // here - if err != nil { - setupErr = multierror.Append(breverrors.WrapAndTrace(err)) - } - err = appendLogToFile("starting to run execs", "/var/log/brev-steps.log") - if err != nil { - setupErr = multierror.Append(setupErr, breverrors.WrapAndTrace(err)) - } - err = e.RunExecs() // here - if err != nil { - setupErr = multierror.Append(breverrors.WrapAndTrace(err)) - } - - err = e.CreateVerbYamlFile() // create - if err != nil { - setupErr = multierror.Append(breverrors.WrapAndTrace(err)) - } - - fmt.Println("------ Done running execs ------") - err = appendLogToFile("done running execs", "/var/log/brev-steps.log") - if err != nil { - setupErr = multierror.Append(breverrors.WrapAndTrace(err)) - } - - err = e.brevMonConfigurer.Install() - if err != nil { - setupErr = multierror.Append(breverrors.WrapAndTrace(err)) - } - - if e.datadogAPIKey != "" { - err = e.SetupDatadog() - if err != nil { - setupErr = multierror.Append(breverrors.WrapAndTrace(err)) - } - } - - err = postPrepare.Await() - if err != nil { - setupErr = multierror.Append(breverrors.WrapAndTrace(err)) - } - - err = appendLogToFile("setup done", "/var/log/brev-steps.log") - if err != nil { - setupErr = multierror.Append(breverrors.WrapAndTrace(err)) - } - - if setupErr != nil { - return breverrors.WrapAndTrace(setupErr) - } - - return nil -} - -func (e envInitier) SetupDatadog() error { - installScriptURL := "https://s3.amazonaws.com/dd-agent/scripts/install_script.sh" - var installScript string - - resp, err := collections.GetRequestWithContext(context.TODO(), installScriptURL) - if err != nil { - return breverrors.WrapAndTrace(err) - } - - defer resp.Body.Close() //nolint: errcheck // we don't care about the error here b/c defer - - if resp.StatusCode != 200 { - return breverrors.WrapAndTrace(fmt.Errorf("failed to download datadog install script")) - } - - bodyBytes, err := ioutil.ReadAll(resp.Body) - if err != nil { - return breverrors.WrapAndTrace(err) - } - - installScript = string(bodyBytes) - - cmd := setupworkspace.CmdStringBuilder(installScript) - - cmd.Env = append(cmd.Env, - append( - os.Environ(), - []string{ - "DD_API_KEY=" + e.datadogAPIKey, - "DD_AGENT_MAJOR_VERSION=7", - "DD_SITE=\"datadoghq.com\"", - }..., - )...) - - err = cmd.Run() - if err != nil { - out, err0 := cmd.CombinedOutput() - if err0 != nil { - return breverrors.WrapAndTrace(err0) - } - return breverrors.WrapAndTrace(fmt.Errorf("failed to install datadog agent: %s", string(out))) - } - - err = e.store.WriteString("/etc/datadog-agent/conf.d/systemd.d/conf.yaml", ` -init_config: -instances: - ## @param unit_names - list of strings - required - ## List of systemd units to monitor. - ## Full names must be used. Examples: ssh.service, docker.socket - # - - unit_names: - - ssh.service - - brevmon.service -`) - if err != nil { - return breverrors.WrapAndTrace(err) - } - - err = e.store.WriteString("/etc/datadog-agent/conf.d/journald.d/conf.yaml", ` -logs: - - type: journald - path: /var/log/journal/ - include_units: - - brevmon.service - - sshd.service -`) - if err != nil { - return breverrors.WrapAndTrace(err) - } - - err = setupworkspace.BuildAndRunCmd( - "/usr/sbin/usermod", - "-a", - "-G", - "systemd-journal", - "dd-agent", - ) - if err != nil { - return breverrors.WrapAndTrace(err) - } - - hostname, _ := os.Hostname() - stringToAppend := fmt.Sprintf("\nlogs_enabled: true\nhostname: %s\n", hostname) - // add logs_enabled: true to /etc/datadog-agent/datadog.yaml - err = e.store.AppendString("/etc/datadog-agent/datadog.yaml", stringToAppend) - if err != nil { - return breverrors.WrapAndTrace(err) - } - - err = setupworkspace.BuildAndRunCmd( - "systemctl", - "restart", - "datadog-agent", - ) - if err != nil { - return breverrors.WrapAndTrace(err) - } - - return nil -} - -type setupKeyI interface { - WriteString(path, content string) error - Chmod(path string, mode os.FileMode) error - ChownFilePathToUser(path string) error -} - -func setupKey(path, content string, perm os.FileMode, store setupKeyI) error { - err := store.WriteString(path, content) - if err != nil { - return breverrors.WrapAndTrace(err) - } - err = store.Chmod(path, perm) - if err != nil { - return breverrors.WrapAndTrace(err) - } - err = store.ChownFilePathToUser(path) - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil -} - -func (e envInitier) setupPrivateKey(content string) error { - pkpath := e.BuildHomePath(".ssh", "id_rsa") - err := setupKey(pkpath, content, 0o600, e.store) - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil -} - -func (e envInitier) setupPublicKey(content string) error { - pubkeypath := e.BuildHomePath(".ssh", "id_rsa.pub") - err := setupKey(pubkeypath, content, 0o644, e.store) - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil -} - -func (e envInitier) SetupSSHKeys(keys *store.KeyPair) error { - err := e.setupPrivateKey(keys.PrivateKeyData) - if err != nil { - return breverrors.WrapAndTrace(err) - } - err = e.setupPublicKey(keys.PublicKeyData) - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil -} - -func (e envInitier) SetupSSH() error { - keys := e.Params.WorkspaceKeyPair - err := e.SetupSSHKeys(keys) - if err != nil { - return breverrors.WrapAndTrace(err) - } - c := fmt.Sprintf(`eval "$(ssh-agent -s)" && ssh-add %s`, e.BuildHomePath(".ssh", "id_rsa")) - cmd := setupworkspace.CmdStringBuilder(c) - err = cmd.Run() - if err != nil { - return breverrors.WrapAndTrace(err) - } - - authorizedKeyPath := e.BuildHomePath(".ssh", "authorized_keys") - - err = e.store.AppendString(authorizedKeyPath, "\n"+keys.PublicKeyData) - if err != nil { - return breverrors.WrapAndTrace(err) - } - err = e.store.ChownFilePathToUser(authorizedKeyPath) - if err != nil { - return breverrors.WrapAndTrace(err) - } - - if e.ConfigureSystemSSHConfig { - err = e.store.WriteString( - filepath.Join( - "/etc", - "ssh", - "sshd_config.d", - fmt.Sprintf("%s.conf", e.User.Username), - ), - fmt.Sprintf( - `PubkeyAuthentication yes -AuthorizedKeysFile %s -PasswordAuthentication no`, authorizedKeyPath), - ) - if err != nil { - return breverrors.WrapAndTrace(err) - } - } - - return nil -} - -func (e envInitier) SetupEnvVars() error { - // set env vars - err := e.store.AppendString("/etc/bash.bashrc", ` -_brev_hook() { - local previous_exit_status=$?; - trap -- '' SIGINT; - eval "$(/usr/local/bin/brev configure-env-vars bash)"; - trap - SIGINT; - return $previous_exit_status; -}; -if ! [[ "${PROMPT_COMMAND:-}" =~ _brev_hook ]]; then - PROMPT_COMMAND="_brev_hook${PROMPT_COMMAND:+;$PROMPT_COMMAND}" -fi -`) - if err != nil { - return breverrors.WrapAndTrace(err) - } - - fileExists, err := e.store.FileExists("/etc/zsh/zshrc") - if err != nil { - return breverrors.WrapAndTrace(err) - } - if !fileExists { - err = e.store.WriteString("/etc/zsh/zshrc", "") - if err != nil { - return breverrors.WrapAndTrace(err) - } - } - - err = e.store.AppendString("/etc/zsh/zshrc", ` -_brev_hook() { - trap -- '' SIGINT; - eval "$(/usr/local/bin/brev configure-env-vars zsh)"; - trap - SIGINT; -} -typeset -ag precmd_functions; -if [[ -z "${precmd_functions[(r)_brev_hook]+1}" ]]; then - precmd_functions=( _brev_hook ${precmd_functions[@]} ) -fi -typeset -ag chpwd_functions; -if [[ -z "${chpwd_functions[(r)_brev_hook]+1}" ]]; then - chpwd_functions=( _brev_hook ${chpwd_functions[@]} ) -fi -export PATH="/opt/conda/bin:$PATH" -`) - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil -} - -func newEnvIniter( - user *user.User, - params *store.SetupParamsV0, - configureSystemSSHConfig bool, - store envsetupStore, - datadogAPIKey string, - disableAbleAUtosop bool, - reportInterval string, - portToCheckTrafficOn string, -) *envInitier { - workspaceIniter := setupworkspace.NewWorkspaceIniter(user.HomeDir, user, params) - return &envInitier{ - *workspaceIniter, - configureSystemSSHConfig, - autostartconf.NewBrevMonConfigure( - store, - disableAbleAUtosop, - reportInterval, - portToCheckTrafficOn, - ), - datadogAPIKey, - store, - } -} - -func setupEnv( - store envsetupStore, - params *store.SetupParamsV0, - configureSystemSSHConfig bool, - datadogAPIKey string, - disableAutoStop bool, - reportInterval string, - portToCheckAutostopTrafficOn string, -) error { - err := store.BuildBrevHome() - if err != nil { - return breverrors.WrapAndTrace(err) - } - user, err := setupworkspace.GetUserFromUserStr(store.GetOSUser()) - if err != nil { - return breverrors.WrapAndTrace(err) - } - wi := newEnvIniter( - user, - params, - configureSystemSSHConfig, - store, - datadogAPIKey, - disableAutoStop, - reportInterval, - portToCheckAutostopTrafficOn, - ) - // set logfile path to ~/.brev/envsetup.log - logFilePath := filepath.Join(user.HomeDir, ".brev", "envsetup.log") - done, err := mirrorPipesToFile(store, logFilePath) - if err != nil { - return breverrors.WrapAndTrace(err) - } - defer done() - fmt.Printf("brev %s\n", version.Version) - - fmt.Println("------ Setup Begin ------") - err = wi.Setup() - fmt.Println("------ Setup End ------") - if err != nil { - fmt.Println("------ Failure ------") - time.Sleep(time.Millisecond * 100) // wait for buffer to be written - //nolint:gosec // constant - logFile, errF := ioutil.ReadFile(logFilePath) - if errF != nil { - return multierror.Append(err, errF) - } - breverrors.GetDefaultErrorReporter().AddBreadCrumb(breverrors.ErrReportBreadCrumb{Type: "log-file", Message: string(logFile)}) - return breverrors.WrapAndTrace(err) - } else { - fmt.Println("------ Success ------") - } - return nil -} - -func mirrorPipesToFile(store envsetupStore, logFile string) (func(), error) { - f, err := store.GetOrCreateSetupLogFile(logFile) - if err != nil { - return nil, breverrors.WrapAndTrace(err) - } - - // save existing stdout | MultiWriter writes to saved stdout and file - stdOut := os.Stdout - stdErr := os.Stderr - mw := io.MultiWriter(stdOut, f) - - // get pipe reader and writer | writes to pipe writer come out pipe reader - r, w, err := os.Pipe() - if err != nil { - return nil, breverrors.WrapAndTrace(err) - } - - // replace stdout,stderr with pipe writer | all writes to stdout, stderr will go through pipe instead (fmt.print, log) - os.Stdout = w - os.Stderr = w - - // writes with log.Print should also write to mw - log.SetOutput(mw) - - // create channel to control exit | will block until all copies are finished - exit := make(chan bool) - - go func() { - // copy all reads from pipe to multiwriter, which writes to stdout and file - _, _ = io.Copy(mw, r) - // when r or w is closed copy will finish and true will be sent to channel - exit <- true - }() - - // function to be deferred in main until program exits - return func() { - // close writer then block on exit channel | this will let mw finish writing before the program exits - _ = w.Close() - <-exit - // close file after all writes have finished - _ = f.Close() - os.Stdout = stdOut - os.Stderr = stdErr - }, nil -} diff --git a/pkg/cmd/envsetup/envsetup_test.go b/pkg/cmd/envsetup/envsetup_test.go deleted file mode 100644 index 2819e074..00000000 --- a/pkg/cmd/envsetup/envsetup_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package envsetup - -import ( - _ "embed" - "os" - "os/user" - "testing" - - "github.com/brevdev/brev-cli/pkg/store" - "github.com/spf13/afero" - "github.com/tweekmonster/luser" -) - -func Test_appendLogToFile(t *testing.T) { - t.Skip() - err := appendLogToFile("test", "test") - if err != nil { - t.Errorf("error appending to file %s", err) - } -} - -func Test_MOTDExists(t *testing.T) { - if motd == "" { - t.Errorf("motd is empty") - } -} - -func Test_SpeedtestExists(t *testing.T) { - if speedtest == "" { - t.Errorf("speedtest is empty") - } -} - -func makeMockFS() setupKeyI { - bs := store.NewBasicStore().WithEnvGetter( - func(s string) string { - return "test" - }, - ) - fs := bs.WithFileSystem(afero.NewMemMapFs()) - - fs = fs.WithUserHomeDirGetter( - func() (string, error) { - return "/home/test", nil - }, - ) - fs.User = &luser.User{ - User: &user.User{ - Uid: "1000", - Gid: "1000", - }, - } - return fs -} - -func Test_setupKey(t *testing.T) { - type args struct { - path string - content string - perm os.FileMode - store setupKeyI - } - tests := []struct { - name string - args args - wantErr bool - }{ - // TODO: Add test cases. - { - name: "test", - args: args{ - path: "test", - content: "test", - perm: 0o644, - store: makeMockFS(), - }, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := setupKey(tt.args.path, tt.args.content, tt.args.perm, tt.args.store); (err != nil) != tt.wantErr { - t.Errorf("setupKey() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} diff --git a/pkg/cmd/envsetup/motd.sh b/pkg/cmd/envsetup/motd.sh deleted file mode 100644 index c64db4c6..00000000 --- a/pkg/cmd/envsetup/motd.sh +++ /dev/null @@ -1,60 +0,0 @@ -Welcome to brev.dev - - ##@@@#. - #@@@@@@@@@@. - #@@@. .@@@@% - #@% :@@@@- - #@ =@@@= -#@* +@@@= -=%@%= -#@: =@@@@ *@@@@@@@@ -#@: =@@@@ -@@@= +@+@@@@@@+ -#@: @@@@% @@@- @@: @@@@# @@@@@ -#@@= @@@@+ %@@- +@@* *@@@@@@@@. - #@@- @@@@.%@@- #@@ .@@@@.. @@@@@= @@@. - #@@@= @@@@@@@@ @@@@@@@@@@% *@@@@- @@# - #@@@@ -@@@@@@ -@# #@@+ @@@@# @@@ - #@@@% @@@@@@* -@ .@@ =@@@@: @@@ - #@@@- @@@@@@ @@@: -# @ @@@@* @@@ - #@@@- @@@@@@ :@@@@% -. @ @@@@ @@# - #@@@ +@@@@@= :@@@@@ - @@@@@ %@@ - #@@@ ..@@@@ =@@@@@ - @@@@+ .@@- - #@@: :@@ @@@@@@ @@@@: -@@@ - #@@ @ @@@@@@ @@@@@ .@@@ - #@@. *% @@@@@: #+ @@ @@@. - #@@ @ @@@@# @ @@ @@@ - #@@- = @@@@ @ +@# @ @@@. - #@@ . . @@@ . @@@ @@ *@@# - #@@ @@* - @+ #. -@@@ @@ @@+ - #@@@ @@@: - .@ =@@@ @+ @@@ - #@@@ *@@@ - @@ @@@@ @ %@@ - #@@. +@@@ - @@* @@@@ @ -@@: - #@@+ .@@@ - @@@. :%@ @@% - #@@@ @@@: - .@@@ =. #@@ - #@@@ @@@ - #@@ @@ - #@@@ @@@ - @# .@@+ - #@@@ @@@ . @: @@@ - #@@+ @ + @: +@@: - #@@@ @ @ @@@ - #@@@ % @ % = %@@ - #@@@ *% @ # @ .@@@ - #@@@. @ .@ .: +@ @@@+ - #@@+ :@ .@ # @..=@@ +@@@- - #@@@ @= .@ # @@@@@@ @@@@ - #@@ *@ @@ @ #@@@* %@@@@ - #@@# @@ @- : +@* @@@@ - #@@@. -@@ @- :: %* *@@@@ - #@@@%%%@@@ @- : * @@@@= - #@@@@@@@@ *@- @ * @@@@@ - #@@@@@@ @@ # @ =@@@@ - #@@@@@= .@@ -# @ =@@@@@ - #@@@@= @@@ # @@@@@@@@ - #@@@@@@@@@@: .@ @@@@@@@. - #@ @@ @@@@@@@= - #@@@@# #@@@@@@= - #@@@@@. .@@@@@@@ - #%@@@@@@@@@@@@@@- - ###@@@@@## - -Internet Speed: -Avg Upload: 1574.31 Mbit/s -Avg Download: 1103.68 Mbit/s \ No newline at end of file diff --git a/pkg/cmd/envsetup/speedtest.py b/pkg/cmd/envsetup/speedtest.py deleted file mode 100644 index b7e73eb1..00000000 --- a/pkg/cmd/envsetup/speedtest.py +++ /dev/null @@ -1,2184 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright 2012 Matt Martz -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import csv -import datetime -import errno -import math -import os -import platform -import re -import signal -import socket -import sys -import threading -import subprocess -import timeit -import xml.parsers.expat - -try: - import gzip - GZIP_BASE = gzip.GzipFile -except ImportError: - gzip = None - GZIP_BASE = object - -__version__ = '2.1.4b1' - -downloadspeed = "" -uploadspeed = "" - -class FakeShutdownEvent(object): - """Class to fake a threading.Event.isSet so that users of this module - are not required to register their own threading.Event() - """ - - @staticmethod - def isSet(): - "Dummy method to always return false""" - return False - - is_set = isSet - - -# Some global variables we use -DEBUG = False -_GLOBAL_DEFAULT_TIMEOUT = object() -PY25PLUS = sys.version_info[:2] >= (2, 5) -PY26PLUS = sys.version_info[:2] >= (2, 6) -PY32PLUS = sys.version_info[:2] >= (3, 2) -PY310PLUS = sys.version_info[:2] >= (3, 10) - -# Begin import game to handle Python 2 and Python 3 -try: - import json -except ImportError: - try: - import simplejson as json - except ImportError: - json = None - -try: - import xml.etree.ElementTree as ET - try: - from xml.etree.ElementTree import _Element as ET_Element - except ImportError: - pass -except ImportError: - from xml.dom import minidom as DOM - from xml.parsers.expat import ExpatError - ET = None - -try: - from urllib2 import (urlopen, Request, HTTPError, URLError, - AbstractHTTPHandler, ProxyHandler, - HTTPDefaultErrorHandler, HTTPRedirectHandler, - HTTPErrorProcessor, OpenerDirector) -except ImportError: - from urllib.request import (urlopen, Request, HTTPError, URLError, - AbstractHTTPHandler, ProxyHandler, - HTTPDefaultErrorHandler, HTTPRedirectHandler, - HTTPErrorProcessor, OpenerDirector) - -try: - from httplib import HTTPConnection, BadStatusLine -except ImportError: - from http.client import HTTPConnection, BadStatusLine - -try: - from httplib import HTTPSConnection -except ImportError: - try: - from http.client import HTTPSConnection - except ImportError: - HTTPSConnection = None - -try: - from httplib import FakeSocket -except ImportError: - FakeSocket = None - -try: - from Queue import Queue -except ImportError: - from queue import Queue - -try: - from urlparse import urlparse -except ImportError: - from urllib.parse import urlparse - -try: - from urlparse import parse_qs -except ImportError: - try: - from urllib.parse import parse_qs - except ImportError: - from cgi import parse_qs - -try: - from hashlib import md5 -except ImportError: - from md5 import md5 - -try: - from argparse import ArgumentParser as ArgParser - from argparse import SUPPRESS as ARG_SUPPRESS - PARSER_TYPE_INT = int - PARSER_TYPE_STR = str - PARSER_TYPE_FLOAT = float -except ImportError: - from optparse import OptionParser as ArgParser - from optparse import SUPPRESS_HELP as ARG_SUPPRESS - PARSER_TYPE_INT = 'int' - PARSER_TYPE_STR = 'string' - PARSER_TYPE_FLOAT = 'float' - -try: - from cStringIO import StringIO - BytesIO = None -except ImportError: - try: - from StringIO import StringIO - BytesIO = None - except ImportError: - from io import StringIO, BytesIO - -try: - import __builtin__ -except ImportError: - import builtins - from io import TextIOWrapper, FileIO - - class _Py3Utf8Output(TextIOWrapper): - """UTF-8 encoded wrapper around stdout for py3, to override - ASCII stdout - """ - def __init__(self, f, **kwargs): - buf = FileIO(f.fileno(), 'w') - super(_Py3Utf8Output, self).__init__( - buf, - encoding='utf8', - errors='strict' - ) - - def write(self, s): - super(_Py3Utf8Output, self).write(s) - self.flush() - - _py3_print = getattr(builtins, 'print') - try: - _py3_utf8_stdout = _Py3Utf8Output(sys.stdout) - _py3_utf8_stderr = _Py3Utf8Output(sys.stderr) - except OSError: - # sys.stdout/sys.stderr is not a compatible stdout/stderr object - # just use it and hope things go ok - _py3_utf8_stdout = sys.stdout - _py3_utf8_stderr = sys.stderr - - def to_utf8(v): - """No-op encode to utf-8 for py3""" - return v - - def print_(*args, **kwargs): - """Wrapper function for py3 to print, with a utf-8 encoded stdout""" - if kwargs.get('file') == sys.stderr: - kwargs['file'] = _py3_utf8_stderr - else: - kwargs['file'] = kwargs.get('file', _py3_utf8_stdout) - _py3_print(*args, **kwargs) -else: - del __builtin__ - - def to_utf8(v): - """Encode value to utf-8 if possible for py2""" - try: - return v.encode('utf8', 'strict') - except AttributeError: - return v - - def print_(*args, **kwargs): - """The new-style print function for Python 2.4 and 2.5. - - Taken from https://pypi.python.org/pypi/six/ - - Modified to set encoding to UTF-8 always, and to flush after write - """ - fp = kwargs.pop("file", sys.stdout) - if fp is None: - return - - def write(data): - if not isinstance(data, basestring): - data = str(data) - # If the file has an encoding, encode unicode with it. - encoding = 'utf8' # Always trust UTF-8 for output - if (isinstance(fp, file) and - isinstance(data, unicode) and - encoding is not None): - errors = getattr(fp, "errors", None) - if errors is None: - errors = "strict" - data = data.encode(encoding, errors) - fp.write(data) - fp.flush() - want_unicode = False - sep = kwargs.pop("sep", None) - if sep is not None: - if isinstance(sep, unicode): - want_unicode = True - elif not isinstance(sep, str): - raise TypeError("sep must be None or a string") - end = kwargs.pop("end", None) - if end is not None: - if isinstance(end, unicode): - want_unicode = True - elif not isinstance(end, str): - raise TypeError("end must be None or a string") - if kwargs: - raise TypeError("invalid keyword arguments to print()") - if not want_unicode: - for arg in args: - if isinstance(arg, unicode): - want_unicode = True - break - if want_unicode: - newline = unicode("\n") - space = unicode(" ") - else: - newline = "\n" - space = " " - if sep is None: - sep = space - if end is None: - end = newline - for i, arg in enumerate(args): - if i: - write(sep) - write(arg) - write(end) - -# Exception "constants" to support Python 2 through Python 3 -try: - import ssl - try: - CERT_ERROR = (ssl.CertificateError,) - except AttributeError: - CERT_ERROR = tuple() - - HTTP_ERRORS = ( - (HTTPError, URLError, socket.error, ssl.SSLError, BadStatusLine) + - CERT_ERROR - ) -except ImportError: - ssl = None - HTTP_ERRORS = (HTTPError, URLError, socket.error, BadStatusLine) - -if PY32PLUS: - etree_iter = ET.Element.iter -elif PY25PLUS: - etree_iter = ET_Element.getiterator - -if PY26PLUS: - thread_is_alive = threading.Thread.is_alive -else: - thread_is_alive = threading.Thread.isAlive - - -def event_is_set(event): - try: - return event.is_set() - except AttributeError: - return event.isSet() - - -class SpeedtestException(Exception): - """Base exception for this module""" - - -class SpeedtestCLIError(SpeedtestException): - """Generic exception for raising errors during CLI operation""" - - -class SpeedtestHTTPError(SpeedtestException): - """Base HTTP exception for this module""" - - -class SpeedtestConfigError(SpeedtestException): - """Configuration XML is invalid""" - - -class SpeedtestServersError(SpeedtestException): - """Servers XML is invalid""" - - -class ConfigRetrievalError(SpeedtestHTTPError): - """Could not retrieve config.php""" - - -class ServersRetrievalError(SpeedtestHTTPError): - """Could not retrieve speedtest-servers.php""" - - -class InvalidServerIDType(SpeedtestException): - """Server ID used for filtering was not an integer""" - - -class NoMatchedServers(SpeedtestException): - """No servers matched when filtering""" - - -class SpeedtestMiniConnectFailure(SpeedtestException): - """Could not connect to the provided speedtest mini server""" - - -class InvalidSpeedtestMiniServer(SpeedtestException): - """Server provided as a speedtest mini server does not actually appear - to be a speedtest mini server - """ - - -class ShareResultsConnectFailure(SpeedtestException): - """Could not connect to speedtest.net API to POST results""" - - -class ShareResultsSubmitFailure(SpeedtestException): - """Unable to successfully POST results to speedtest.net API after - connection - """ - - -class SpeedtestUploadTimeout(SpeedtestException): - """testlength configuration reached during upload - Used to ensure the upload halts when no additional data should be sent - """ - - -class SpeedtestBestServerFailure(SpeedtestException): - """Unable to determine best server""" - - -class SpeedtestMissingBestServer(SpeedtestException): - """get_best_server not called or not able to determine best server""" - - -def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, - source_address=None): - """Connect to *address* and return the socket object. - - Convenience function. Connect to *address* (a 2-tuple ``(host, - port)``) and return the socket object. Passing the optional - *timeout* parameter will set the timeout on the socket instance - before attempting to connect. If no *timeout* is supplied, the - global default timeout setting returned by :func:`getdefaulttimeout` - is used. If *source_address* is set it must be a tuple of (host, port) - for the socket to bind as a source address before making the connection. - An host of '' or port 0 tells the OS to use the default. - - Largely vendored from Python 2.7, modified to work with Python 2.4 - """ - - host, port = address - err = None - for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM): - af, socktype, proto, canonname, sa = res - sock = None - try: - sock = socket.socket(af, socktype, proto) - if timeout is not _GLOBAL_DEFAULT_TIMEOUT: - sock.settimeout(float(timeout)) - if source_address: - sock.bind(source_address) - sock.connect(sa) - return sock - - except socket.error: - err = get_exception() - if sock is not None: - sock.close() - - if err is not None: - raise err - else: - raise socket.error("getaddrinfo returns an empty list") - - -class SpeedtestHTTPConnection(HTTPConnection): - """Custom HTTPConnection to support source_address across - Python 2.4 - Python 3 - """ - def __init__(self, *args, **kwargs): - source_address = kwargs.pop('source_address', None) - timeout = kwargs.pop('timeout', 10) - - self._tunnel_host = None - - HTTPConnection.__init__(self, *args, **kwargs) - - self.source_address = source_address - self.timeout = timeout - - def connect(self): - """Connect to the host and port specified in __init__.""" - try: - self.sock = socket.create_connection( - (self.host, self.port), - self.timeout, - self.source_address - ) - except (AttributeError, TypeError): - self.sock = create_connection( - (self.host, self.port), - self.timeout, - self.source_address - ) - - if self._tunnel_host: - self._tunnel() - - -if HTTPSConnection: - class SpeedtestHTTPSConnection(HTTPSConnection): - """Custom HTTPSConnection to support source_address across - Python 2.4 - Python 3 - """ - default_port = 443 - - def __init__(self, *args, **kwargs): - source_address = kwargs.pop('source_address', None) - timeout = kwargs.pop('timeout', 10) - - self._tunnel_host = None - - HTTPSConnection.__init__(self, *args, **kwargs) - - self.timeout = timeout - self.source_address = source_address - - def connect(self): - "Connect to a host on a given (SSL) port." - try: - self.sock = socket.create_connection( - (self.host, self.port), - self.timeout, - self.source_address - ) - except (AttributeError, TypeError): - self.sock = create_connection( - (self.host, self.port), - self.timeout, - self.source_address - ) - - if self._tunnel_host: - self._tunnel() - - if ssl: - try: - kwargs = {} - if hasattr(ssl, 'SSLContext'): - if self._tunnel_host: - kwargs['server_hostname'] = self._tunnel_host - else: - kwargs['server_hostname'] = self.host - self.sock = self._context.wrap_socket(self.sock, **kwargs) - except AttributeError: - self.sock = ssl.wrap_socket(self.sock) - try: - self.sock.server_hostname = self.host - except AttributeError: - pass - elif FakeSocket: - # Python 2.4/2.5 support - try: - self.sock = FakeSocket(self.sock, socket.ssl(self.sock)) - except AttributeError: - raise SpeedtestException( - 'This version of Python does not support HTTPS/SSL ' - 'functionality' - ) - else: - raise SpeedtestException( - 'This version of Python does not support HTTPS/SSL ' - 'functionality' - ) - - -def _build_connection(connection, source_address, timeout, context=None): - """Cross Python 2.4 - Python 3 callable to build an ``HTTPConnection`` or - ``HTTPSConnection`` with the args we need - - Called from ``http(s)_open`` methods of ``SpeedtestHTTPHandler`` or - ``SpeedtestHTTPSHandler`` - """ - def inner(host, **kwargs): - kwargs.update({ - 'source_address': source_address, - 'timeout': timeout - }) - if context: - kwargs['context'] = context - return connection(host, **kwargs) - return inner - - -class SpeedtestHTTPHandler(AbstractHTTPHandler): - """Custom ``HTTPHandler`` that can build a ``HTTPConnection`` with the - args we need for ``source_address`` and ``timeout`` - """ - def __init__(self, debuglevel=0, source_address=None, timeout=10): - AbstractHTTPHandler.__init__(self, debuglevel) - self.source_address = source_address - self.timeout = timeout - - def http_open(self, req): - return self.do_open( - _build_connection( - SpeedtestHTTPConnection, - self.source_address, - self.timeout - ), - req - ) - - http_request = AbstractHTTPHandler.do_request_ - - -class SpeedtestHTTPSHandler(AbstractHTTPHandler): - """Custom ``HTTPSHandler`` that can build a ``HTTPSConnection`` with the - args we need for ``source_address`` and ``timeout`` - """ - def __init__(self, debuglevel=0, context=None, source_address=None, - timeout=10): - AbstractHTTPHandler.__init__(self, debuglevel) - self._context = context - self.source_address = source_address - self.timeout = timeout - - def https_open(self, req): - return self.do_open( - _build_connection( - SpeedtestHTTPSConnection, - self.source_address, - self.timeout, - context=self._context, - ), - req - ) - - https_request = AbstractHTTPHandler.do_request_ - - -def build_opener(source_address=None, timeout=10): - """Function similar to ``urllib2.build_opener`` that will build - an ``OpenerDirector`` with the explicit handlers we want, - ``source_address`` for binding, ``timeout`` and our custom - `User-Agent` - """ - - printer('Timeout set to %d' % timeout, debug=True) - - if source_address: - source_address_tuple = (source_address, 0) - printer('Binding to source address: %r' % (source_address_tuple,), - debug=True) - else: - source_address_tuple = None - - handlers = [ - ProxyHandler(), - SpeedtestHTTPHandler(source_address=source_address_tuple, - timeout=timeout), - SpeedtestHTTPSHandler(source_address=source_address_tuple, - timeout=timeout), - HTTPDefaultErrorHandler(), - HTTPRedirectHandler(), - HTTPErrorProcessor() - ] - - opener = OpenerDirector() - opener.addheaders = [('User-agent', build_user_agent())] - - for handler in handlers: - opener.add_handler(handler) - - return opener - - -class GzipDecodedResponse(GZIP_BASE): - """A file-like object to decode a response encoded with the gzip - method, as described in RFC 1952. - - Largely copied from ``xmlrpclib``/``xmlrpc.client`` and modified - to work for py2.4-py3 - """ - def __init__(self, response): - # response doesn't support tell() and read(), required by - # GzipFile - if not gzip: - raise SpeedtestHTTPError('HTTP response body is gzip encoded, ' - 'but gzip support is not available') - IO = BytesIO or StringIO - self.io = IO() - while 1: - chunk = response.read(1024) - if len(chunk) == 0: - break - self.io.write(chunk) - self.io.seek(0) - gzip.GzipFile.__init__(self, mode='rb', fileobj=self.io) - - def close(self): - try: - gzip.GzipFile.close(self) - finally: - self.io.close() - - -def get_exception(): - """Helper function to work with py2.4-py3 for getting the current - exception in a try/except block - """ - return sys.exc_info()[1] - - -def distance(origin, destination): - """Determine distance between 2 sets of [lat,lon] in km""" - - lat1, lon1 = origin - lat2, lon2 = destination - radius = 6371 # km - - dlat = math.radians(lat2 - lat1) - dlon = math.radians(lon2 - lon1) - a = (math.sin(dlat / 2) * math.sin(dlat / 2) + - math.cos(math.radians(lat1)) * - math.cos(math.radians(lat2)) * math.sin(dlon / 2) * - math.sin(dlon / 2)) - c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) - d = radius * c - - return d - - -def build_user_agent(): - """Build a Mozilla/5.0 compatible User-Agent string""" - - ua_tuple = ( - 'Mozilla/5.0', - '(%s; U; %s; en-us)' % (platform.platform(), - platform.architecture()[0]), - 'Python/%s' % platform.python_version(), - '(KHTML, like Gecko)', - 'speedtest-cli/%s' % __version__ - ) - user_agent = ' '.join(ua_tuple) - printer('User-Agent: %s' % user_agent, debug=True) - return user_agent - - -def build_request(url, data=None, headers=None, bump='0', secure=False): - """Build a urllib2 request object - - This function automatically adds a User-Agent header to all requests - - """ - - if not headers: - headers = {} - - if url[0] == ':': - scheme = ('http', 'https')[bool(secure)] - schemed_url = '%s%s' % (scheme, url) - else: - schemed_url = url - - if '?' in url: - delim = '&' - else: - delim = '?' - - # WHO YOU GONNA CALL? CACHE BUSTERS! - final_url = '%s%sx=%s.%s' % (schemed_url, delim, - int(timeit.time.time() * 1000), - bump) - - headers.update({ - 'Cache-Control': 'no-cache', - }) - - printer('%s %s' % (('GET', 'POST')[bool(data)], final_url), - debug=True) - - return Request(final_url, data=data, headers=headers) - - -def catch_request(request, opener=None): - """Helper function to catch common exceptions encountered when - establishing a connection with a HTTP/HTTPS request - - """ - - if opener: - _open = opener.open - else: - _open = urlopen - - try: - uh = _open(request) - if request.get_full_url() != uh.geturl(): - printer('Redirected to %s' % uh.geturl(), debug=True) - return uh, False - except HTTP_ERRORS: - e = get_exception() - return None, e - - -def get_response_stream(response): - """Helper function to return either a Gzip reader if - ``Content-Encoding`` is ``gzip`` otherwise the response itself - - """ - - try: - getheader = response.headers.getheader - except AttributeError: - getheader = response.getheader - - if getheader('content-encoding') == 'gzip': - return GzipDecodedResponse(response) - - return response - - -def get_attributes_by_tag_name(dom, tag_name): - """Retrieve an attribute from an XML document and return it in a - consistent format - - Only used with xml.dom.minidom, which is likely only to be used - with python versions older than 2.5 - """ - elem = dom.getElementsByTagName(tag_name)[0] - return dict(list(elem.attributes.items())) - - -def print_dots(shutdown_event): - """Built in callback function used by Thread classes for printing - status - """ - def inner(current, total, start=False, end=False): - if event_is_set(shutdown_event): - return - - sys.stdout.write('.') - if current + 1 == total and end is True: - sys.stdout.write('\n') - sys.stdout.flush() - return inner - - -def do_nothing(*args, **kwargs): - pass - - -class HTTPDownloader(threading.Thread): - """Thread class for retrieving a URL""" - - def __init__(self, i, request, start, timeout, opener=None, - shutdown_event=None): - threading.Thread.__init__(self) - self.request = request - self.result = [0] - self.starttime = start - self.timeout = timeout - self.i = i - if opener: - self._opener = opener.open - else: - self._opener = urlopen - - if shutdown_event: - self._shutdown_event = shutdown_event - else: - self._shutdown_event = FakeShutdownEvent() - - def run(self): - try: - if (timeit.default_timer() - self.starttime) <= self.timeout: - f = self._opener(self.request) - while (not event_is_set(self._shutdown_event) and - (timeit.default_timer() - self.starttime) <= - self.timeout): - self.result.append(len(f.read(10240))) - if self.result[-1] == 0: - break - f.close() - except IOError: - pass - except HTTP_ERRORS: - pass - - -class HTTPUploaderData(object): - """File like object to improve cutting off the upload once the timeout - has been reached - """ - - def __init__(self, length, start, timeout, shutdown_event=None): - self.length = length - self.start = start - self.timeout = timeout - - if shutdown_event: - self._shutdown_event = shutdown_event - else: - self._shutdown_event = FakeShutdownEvent() - - self._data = None - - self.total = [0] - - def pre_allocate(self): - chars = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ' - multiplier = int(round(int(self.length) / 36.0)) - IO = BytesIO or StringIO - try: - self._data = IO( - ('content1=%s' % - (chars * multiplier)[0:int(self.length) - 9] - ).encode() - ) - except MemoryError: - raise SpeedtestCLIError( - 'Insufficient memory to pre-allocate upload data. Please ' - 'use --no-pre-allocate' - ) - - @property - def data(self): - if not self._data: - self.pre_allocate() - return self._data - - def read(self, n=10240): - if ((timeit.default_timer() - self.start) <= self.timeout and - not event_is_set(self._shutdown_event)): - chunk = self.data.read(n) - self.total.append(len(chunk)) - return chunk - else: - raise SpeedtestUploadTimeout() - - def __len__(self): - return self.length - - -class HTTPUploader(threading.Thread): - """Thread class for putting a URL""" - - def __init__(self, i, request, start, size, timeout, opener=None, - shutdown_event=None): - threading.Thread.__init__(self) - self.request = request - self.request.data.start = self.starttime = start - self.size = size - self.result = 0 - self.timeout = timeout - self.i = i - - if opener: - self._opener = opener.open - else: - self._opener = urlopen - - if shutdown_event: - self._shutdown_event = shutdown_event - else: - self._shutdown_event = FakeShutdownEvent() - - def run(self): - request = self.request - try: - if ((timeit.default_timer() - self.starttime) <= self.timeout and - not event_is_set(self._shutdown_event)): - try: - f = self._opener(request) - except TypeError: - # PY24 expects a string or buffer - # This also causes issues with Ctrl-C, but we will concede - # for the moment that Ctrl-C on PY24 isn't immediate - request = build_request(self.request.get_full_url(), - data=request.data.read(self.size)) - f = self._opener(request) - f.read(11) - f.close() - self.result = sum(self.request.data.total) - else: - self.result = 0 - except (IOError, SpeedtestUploadTimeout): - self.result = sum(self.request.data.total) - except HTTP_ERRORS: - self.result = 0 - - -class SpeedtestResults(object): - """Class for holding the results of a speedtest, including: - - Download speed - Upload speed - Ping/Latency to test server - Data about server that the test was run against - - Additionally this class can return a result data as a dictionary or CSV, - as well as submit a POST of the result data to the speedtest.net API - to get a share results image link. - """ - - def __init__(self, download=0, upload=0, ping=0, server=None, client=None, - opener=None, secure=False): - self.download = download - self.upload = upload - self.ping = ping - if server is None: - self.server = {} - else: - self.server = server - self.client = client or {} - - self._share = None - self.timestamp = '%sZ' % datetime.datetime.utcnow().isoformat() - self.bytes_received = 0 - self.bytes_sent = 0 - - if opener: - self._opener = opener - else: - self._opener = build_opener() - - self._secure = secure - - def __repr__(self): - return repr(self.dict()) - - def share(self): - """POST data to the speedtest.net API to obtain a share results - link - """ - - if self._share: - return self._share - - download = int(round(self.download / 1000.0, 0)) - ping = int(round(self.ping, 0)) - upload = int(round(self.upload / 1000.0, 0)) - - # Build the request to send results back to speedtest.net - # We use a list instead of a dict because the API expects parameters - # in a certain order - api_data = [ - 'recommendedserverid=%s' % self.server['id'], - 'ping=%s' % ping, - 'screenresolution=', - 'promo=', - 'download=%s' % download, - 'screendpi=', - 'upload=%s' % upload, - 'testmethod=http', - 'hash=%s' % md5(('%s-%s-%s-%s' % - (ping, upload, download, '297aae72')) - .encode()).hexdigest(), - 'touchscreen=none', - 'startmode=pingselect', - 'accuracy=1', - 'bytesreceived=%s' % self.bytes_received, - 'bytessent=%s' % self.bytes_sent, - 'serverid=%s' % self.server['id'], - ] - - headers = {'Referer': 'http://c.speedtest.net/flash/speedtest.swf'} - request = build_request('://www.speedtest.net/api/api.php', - data='&'.join(api_data).encode(), - headers=headers, secure=self._secure) - f, e = catch_request(request, opener=self._opener) - if e: - raise ShareResultsConnectFailure(e) - - response = f.read() - code = f.code - f.close() - - if int(code) != 200: - raise ShareResultsSubmitFailure('Could not submit results to ' - 'speedtest.net') - - qsargs = parse_qs(response.decode()) - resultid = qsargs.get('resultid') - if not resultid or len(resultid) != 1: - raise ShareResultsSubmitFailure('Could not submit results to ' - 'speedtest.net') - - self._share = 'http://www.speedtest.net/result/%s.png' % resultid[0] - - return self._share - - def dict(self): - """Return dictionary of result data""" - - return { - 'download': self.download, - 'upload': self.upload, - 'ping': self.ping, - 'server': self.server, - 'timestamp': self.timestamp, - 'bytes_sent': self.bytes_sent, - 'bytes_received': self.bytes_received, - 'share': self._share, - 'client': self.client, - } - - @staticmethod - def csv_header(delimiter=','): - """Return CSV Headers""" - - row = ['Server ID', 'Sponsor', 'Server Name', 'Timestamp', 'Distance', - 'Ping', 'Download', 'Upload', 'Share', 'IP Address'] - out = StringIO() - writer = csv.writer(out, delimiter=delimiter, lineterminator='') - writer.writerow([to_utf8(v) for v in row]) - return out.getvalue() - - def csv(self, delimiter=','): - """Return data in CSV format""" - - data = self.dict() - out = StringIO() - writer = csv.writer(out, delimiter=delimiter, lineterminator='') - row = [data['server']['id'], data['server']['sponsor'], - data['server']['name'], data['timestamp'], - data['server']['d'], data['ping'], data['download'], - data['upload'], self._share or '', self.client['ip']] - writer.writerow([to_utf8(v) for v in row]) - return out.getvalue() - - def json(self, pretty=False): - """Return data in JSON format""" - - kwargs = {} - if pretty: - kwargs.update({ - 'indent': 4, - 'sort_keys': True - }) - return json.dumps(self.dict(), **kwargs) - - -class Speedtest(object): - """Class for performing standard speedtest.net testing operations""" - - def __init__(self, config=None, source_address=None, timeout=10, - secure=False, shutdown_event=None): - self.config = {} - - self._source_address = source_address - self._timeout = timeout - self._opener = build_opener(source_address, timeout) - - self._secure = secure - - if shutdown_event: - self._shutdown_event = shutdown_event - else: - self._shutdown_event = FakeShutdownEvent() - - self.get_config() - if config is not None: - self.config.update(config) - - self.servers = {} - self.closest = [] - self._best = {} - - self.results = SpeedtestResults( - client=self.config['client'], - opener=self._opener, - secure=secure, - ) - - @property - def best(self): - if not self._best: - self.get_best_server() - return self._best - - def get_config(self): - """Download the speedtest.net configuration and return only the data - we are interested in - """ - - headers = {} - if gzip: - headers['Accept-Encoding'] = 'gzip' - request = build_request('://www.speedtest.net/speedtest-config.php', - headers=headers, secure=self._secure) - uh, e = catch_request(request, opener=self._opener) - if e: - raise ConfigRetrievalError(e) - configxml_list = [] - - stream = get_response_stream(uh) - - while 1: - try: - configxml_list.append(stream.read(1024)) - except (OSError, EOFError): - raise ConfigRetrievalError(get_exception()) - if len(configxml_list[-1]) == 0: - break - stream.close() - uh.close() - - if int(uh.code) != 200: - return None - - configxml = ''.encode().join(configxml_list) - - printer('Config XML:\n%s' % configxml, debug=True) - - try: - try: - root = ET.fromstring(configxml) - except ET.ParseError: - e = get_exception() - raise SpeedtestConfigError( - 'Malformed speedtest.net configuration: %s' % e - ) - server_config = root.find('server-config').attrib - download = root.find('download').attrib - upload = root.find('upload').attrib - # times = root.find('times').attrib - client = root.find('client').attrib - - except AttributeError: - try: - root = DOM.parseString(configxml) - except ExpatError: - e = get_exception() - raise SpeedtestConfigError( - 'Malformed speedtest.net configuration: %s' % e - ) - server_config = get_attributes_by_tag_name(root, 'server-config') - download = get_attributes_by_tag_name(root, 'download') - upload = get_attributes_by_tag_name(root, 'upload') - # times = get_attributes_by_tag_name(root, 'times') - client = get_attributes_by_tag_name(root, 'client') - - ignore_servers = [ - int(i) for i in server_config['ignoreids'].split(',') if i - ] - - ratio = int(upload['ratio']) - upload_max = int(upload['maxchunkcount']) - up_sizes = [32768, 65536, 131072, 262144, 524288, 1048576, 7340032] - sizes = { - 'upload': up_sizes[ratio - 1:], - 'download': [350, 500, 750, 1000, 1500, 2000, 2500, - 3000, 3500, 4000] - } - - size_count = len(sizes['upload']) - - upload_count = int(math.ceil(upload_max / size_count)) - - counts = { - 'upload': upload_count, - 'download': int(download['threadsperurl']) - } - - threads = { - 'upload': int(upload['threads']), - 'download': int(server_config['threadcount']) * 2 - } - - length = { - 'upload': int(upload['testlength']), - 'download': int(download['testlength']) - } - - self.config.update({ - 'client': client, - 'ignore_servers': ignore_servers, - 'sizes': sizes, - 'counts': counts, - 'threads': threads, - 'length': length, - 'upload_max': upload_count * size_count - }) - - try: - self.lat_lon = (float(client['lat']), float(client['lon'])) - except ValueError: - raise SpeedtestConfigError( - 'Unknown location: lat=%r lon=%r' % - (client.get('lat'), client.get('lon')) - ) - - printer('Config:\n%r' % self.config, debug=True) - - return self.config - - def get_servers(self, servers=None, exclude=None): - """Retrieve a the list of speedtest.net servers, optionally filtered - to servers matching those specified in the ``servers`` argument - """ - if servers is None: - servers = [] - - if exclude is None: - exclude = [] - - self.servers.clear() - - for server_list in (servers, exclude): - for i, s in enumerate(server_list): - try: - server_list[i] = int(s) - except ValueError: - raise InvalidServerIDType( - '%s is an invalid server type, must be int' % s - ) - - urls = [ - '://www.speedtest.net/speedtest-servers-static.php', - 'http://c.speedtest.net/speedtest-servers-static.php', - '://www.speedtest.net/speedtest-servers.php', - 'http://c.speedtest.net/speedtest-servers.php', - ] - - headers = {} - if gzip: - headers['Accept-Encoding'] = 'gzip' - - errors = [] - for url in urls: - try: - request = build_request( - '%s?threads=%s' % (url, - self.config['threads']['download']), - headers=headers, - secure=self._secure - ) - uh, e = catch_request(request, opener=self._opener) - if e: - errors.append('%s' % e) - raise ServersRetrievalError() - - stream = get_response_stream(uh) - - serversxml_list = [] - while 1: - try: - serversxml_list.append(stream.read(1024)) - except (OSError, EOFError): - raise ServersRetrievalError(get_exception()) - if len(serversxml_list[-1]) == 0: - break - - stream.close() - uh.close() - - if int(uh.code) != 200: - raise ServersRetrievalError() - - serversxml = ''.encode().join(serversxml_list) - - printer('Servers XML:\n%s' % serversxml, debug=True) - - try: - try: - try: - root = ET.fromstring(serversxml) - except ET.ParseError: - e = get_exception() - raise SpeedtestServersError( - 'Malformed speedtest.net server list: %s' % e - ) - elements = etree_iter(root, 'server') - except AttributeError: - try: - root = DOM.parseString(serversxml) - except ExpatError: - e = get_exception() - raise SpeedtestServersError( - 'Malformed speedtest.net server list: %s' % e - ) - elements = root.getElementsByTagName('server') - except (SyntaxError, xml.parsers.expat.ExpatError): - raise ServersRetrievalError() - - for server in elements: - try: - attrib = server.attrib - except AttributeError: - attrib = dict(list(server.attributes.items())) - - if servers and int(attrib.get('id')) not in servers: - continue - - if (int(attrib.get('id')) in self.config['ignore_servers'] - or int(attrib.get('id')) in exclude): - continue - - try: - d = distance(self.lat_lon, - (float(attrib.get('lat')), - float(attrib.get('lon')))) - except Exception: - continue - - attrib['d'] = d - - try: - self.servers[d].append(attrib) - except KeyError: - self.servers[d] = [attrib] - - break - - except ServersRetrievalError: - continue - - if (servers or exclude) and not self.servers: - raise NoMatchedServers() - - return self.servers - - def set_mini_server(self, server): - """Instead of querying for a list of servers, set a link to a - speedtest mini server - """ - - urlparts = urlparse(server) - - name, ext = os.path.splitext(urlparts[2]) - if ext: - url = os.path.dirname(server) - else: - url = server - - request = build_request(url) - uh, e = catch_request(request, opener=self._opener) - if e: - raise SpeedtestMiniConnectFailure('Failed to connect to %s' % - server) - else: - text = uh.read() - uh.close() - - extension = re.findall('upload_?[Ee]xtension: "([^"]+)"', - text.decode()) - if not extension: - for ext in ['php', 'asp', 'aspx', 'jsp']: - try: - f = self._opener.open( - '%s/speedtest/upload.%s' % (url, ext) - ) - except Exception: - pass - else: - data = f.read().strip().decode() - if (f.code == 200 and - len(data.splitlines()) == 1 and - re.match('size=[0-9]', data)): - extension = [ext] - break - if not urlparts or not extension: - raise InvalidSpeedtestMiniServer('Invalid Speedtest Mini Server: ' - '%s' % server) - - self.servers = [{ - 'sponsor': 'Speedtest Mini', - 'name': urlparts[1], - 'd': 0, - 'url': '%s/speedtest/upload.%s' % (url.rstrip('/'), extension[0]), - 'latency': 0, - 'id': 0 - }] - - return self.servers - - def get_closest_servers(self, limit=5): - """Limit servers to the closest speedtest.net servers based on - geographic distance - """ - - if not self.servers: - self.get_servers() - - for d in sorted(self.servers.keys()): - for s in self.servers[d]: - self.closest.append(s) - if len(self.closest) == limit: - break - else: - continue - break - - printer('Closest Servers:\n%r' % self.closest, debug=True) - return self.closest - - def get_best_server(self, servers=None): - """Perform a speedtest.net "ping" to determine which speedtest.net - server has the lowest latency - """ - - if not servers: - if not self.closest: - servers = self.get_closest_servers() - servers = self.closest - - if self._source_address: - source_address_tuple = (self._source_address, 0) - else: - source_address_tuple = None - - user_agent = build_user_agent() - - results = {} - for server in servers: - cum = [] - url = os.path.dirname(server['url']) - stamp = int(timeit.time.time() * 1000) - latency_url = '%s/latency.txt?x=%s' % (url, stamp) - for i in range(0, 3): - this_latency_url = '%s.%s' % (latency_url, i) - printer('%s %s' % ('GET', this_latency_url), - debug=True) - urlparts = urlparse(latency_url) - try: - if urlparts[0] == 'https': - h = SpeedtestHTTPSConnection( - urlparts[1], - source_address=source_address_tuple - ) - else: - h = SpeedtestHTTPConnection( - urlparts[1], - source_address=source_address_tuple - ) - headers = {'User-Agent': user_agent} - path = '%s?%s' % (urlparts[2], urlparts[4]) - start = timeit.default_timer() - h.request("GET", path, headers=headers) - r = h.getresponse() - total = (timeit.default_timer() - start) - except HTTP_ERRORS: - e = get_exception() - printer('ERROR: %r' % e, debug=True) - cum.append(3600) - continue - - text = r.read(9) - if int(r.status) == 200 and text == 'test=test'.encode(): - cum.append(total) - else: - cum.append(3600) - h.close() - - avg = round((sum(cum) / 6) * 1000.0, 3) - results[avg] = server - - try: - fastest = sorted(results.keys())[0] - except IndexError: - raise SpeedtestBestServerFailure('Unable to connect to servers to ' - 'test latency.') - best = results[fastest] - best['latency'] = fastest - - self.results.ping = fastest - self.results.server = best - - self._best.update(best) - printer('Best Server:\n%r' % best, debug=True) - return best - - def download(self, callback=do_nothing, threads=None): - """Test download speed against speedtest.net - - A ``threads`` value of ``None`` will fall back to those dictated - by the speedtest.net configuration - """ - - urls = [] - for size in self.config['sizes']['download']: - for _ in range(0, self.config['counts']['download']): - urls.append('%s/random%sx%s.jpg' % - (os.path.dirname(self.best['url']), size, size)) - - request_count = len(urls) - requests = [] - for i, url in enumerate(urls): - requests.append( - build_request(url, bump=i, secure=self._secure) - ) - - max_threads = threads or self.config['threads']['download'] - in_flight = {'threads': 0} - - def producer(q, requests, request_count): - for i, request in enumerate(requests): - thread = HTTPDownloader( - i, - request, - start, - self.config['length']['download'], - opener=self._opener, - shutdown_event=self._shutdown_event - ) - while in_flight['threads'] >= max_threads: - timeit.time.sleep(0.001) - thread.start() - q.put(thread, True) - in_flight['threads'] += 1 - callback(i, request_count, start=True) - - finished = [] - - def consumer(q, request_count): - _is_alive = thread_is_alive - while len(finished) < request_count: - thread = q.get(True) - while _is_alive(thread): - thread.join(timeout=0.001) - in_flight['threads'] -= 1 - finished.append(sum(thread.result)) - callback(thread.i, request_count, end=True) - - q = Queue(max_threads) - prod_thread = threading.Thread(target=producer, - args=(q, requests, request_count)) - cons_thread = threading.Thread(target=consumer, - args=(q, request_count)) - start = timeit.default_timer() - prod_thread.start() - cons_thread.start() - _is_alive = thread_is_alive - while _is_alive(prod_thread): - prod_thread.join(timeout=0.001) - while _is_alive(cons_thread): - cons_thread.join(timeout=0.001) - - stop = timeit.default_timer() - self.results.bytes_received = sum(finished) - self.results.download = ( - (self.results.bytes_received / (stop - start)) * 8.0 - ) - if self.results.download > 100000: - self.config['threads']['upload'] = 8 - return self.results.download - - def upload(self, callback=do_nothing, pre_allocate=True, threads=None): - """Test upload speed against speedtest.net - - A ``threads`` value of ``None`` will fall back to those dictated - by the speedtest.net configuration - """ - - sizes = [] - - for size in self.config['sizes']['upload']: - for _ in range(0, self.config['counts']['upload']): - sizes.append(size) - - # request_count = len(sizes) - request_count = self.config['upload_max'] - - requests = [] - for i, size in enumerate(sizes): - # We set ``0`` for ``start`` and handle setting the actual - # ``start`` in ``HTTPUploader`` to get better measurements - data = HTTPUploaderData( - size, - 0, - self.config['length']['upload'], - shutdown_event=self._shutdown_event - ) - if pre_allocate: - data.pre_allocate() - - headers = {'Content-length': size} - requests.append( - ( - build_request(self.best['url'], data, secure=self._secure, - headers=headers), - size - ) - ) - - max_threads = threads or self.config['threads']['upload'] - in_flight = {'threads': 0} - - def producer(q, requests, request_count): - for i, request in enumerate(requests[:request_count]): - thread = HTTPUploader( - i, - request[0], - start, - request[1], - self.config['length']['upload'], - opener=self._opener, - shutdown_event=self._shutdown_event - ) - while in_flight['threads'] >= max_threads: - timeit.time.sleep(0.001) - thread.start() - q.put(thread, True) - in_flight['threads'] += 1 - callback(i, request_count, start=True) - - finished = [] - - def consumer(q, request_count): - _is_alive = thread_is_alive - while len(finished) < request_count: - thread = q.get(True) - while _is_alive(thread): - thread.join(timeout=0.001) - in_flight['threads'] -= 1 - finished.append(thread.result) - callback(thread.i, request_count, end=True) - - q = Queue(threads or self.config['threads']['upload']) - prod_thread = threading.Thread(target=producer, - args=(q, requests, request_count)) - cons_thread = threading.Thread(target=consumer, - args=(q, request_count)) - start = timeit.default_timer() - prod_thread.start() - cons_thread.start() - _is_alive = thread_is_alive - while _is_alive(prod_thread): - prod_thread.join(timeout=0.1) - while _is_alive(cons_thread): - cons_thread.join(timeout=0.1) - - stop = timeit.default_timer() - self.results.bytes_sent = sum(finished) - self.results.upload = ( - (self.results.bytes_sent / (stop - start)) * 8.0 - ) - return self.results.upload - - -def ctrl_c(shutdown_event): - """Catch Ctrl-C key sequence and set a SHUTDOWN_EVENT for our threaded - operations - """ - def inner(signum, frame): - shutdown_event.set() - printer('\nCancelling...', error=True) - sys.exit(0) - return inner - - -def version(): - """Print the version""" - - printer('speedtest-cli %s' % __version__) - printer('Python %s' % sys.version.replace('\n', '')) - sys.exit(0) - - -def csv_header(delimiter=','): - """Print the CSV Headers""" - - printer(SpeedtestResults.csv_header(delimiter=delimiter)) - sys.exit(0) - - -def parse_args(): - """Function to handle building and parsing of command line arguments""" - description = ( - 'Command line interface for testing internet bandwidth using ' - 'speedtest.net.\n' - '------------------------------------------------------------' - '--------------\n' - 'https://github.com/sivel/speedtest-cli') - - parser = ArgParser(description=description) - # Give optparse.OptionParser an `add_argument` method for - # compatibility with argparse.ArgumentParser - try: - parser.add_argument = parser.add_option - except AttributeError: - pass - parser.add_argument('--no-download', dest='download', default=True, - action='store_const', const=False, - help='Do not perform download test') - parser.add_argument('--no-upload', dest='upload', default=True, - action='store_const', const=False, - help='Do not perform upload test') - parser.add_argument('--single', default=False, action='store_true', - help='Only use a single connection instead of ' - 'multiple. This simulates a typical file ' - 'transfer.') - parser.add_argument('--bytes', dest='units', action='store_const', - const=('byte', 8), default=('bit', 1), - help='Display values in bytes instead of bits. Does ' - 'not affect the image generated by --share, nor ' - 'output from --json or --csv') - parser.add_argument('--share', action='store_true', - help='Generate and provide a URL to the speedtest.net ' - 'share results image, not displayed with --csv') - parser.add_argument('--simple', action='store_true', default=False, - help='Suppress verbose output, only show basic ' - 'information') - parser.add_argument('--csv', action='store_true', default=False, - help='Suppress verbose output, only show basic ' - 'information in CSV format. Speeds listed in ' - 'bit/s and not affected by --bytes') - parser.add_argument('--csv-delimiter', default=',', type=PARSER_TYPE_STR, - help='Single character delimiter to use in CSV ' - 'output. Default ","') - parser.add_argument('--csv-header', action='store_true', default=False, - help='Print CSV headers') - parser.add_argument('--json', action='store_true', default=False, - help='Suppress verbose output, only show basic ' - 'information in JSON format. Speeds listed in ' - 'bit/s and not affected by --bytes') - parser.add_argument('--list', action='store_true', - help='Display a list of speedtest.net servers ' - 'sorted by distance') - parser.add_argument('--server', type=PARSER_TYPE_INT, action='append', - help='Specify a server ID to test against. Can be ' - 'supplied multiple times') - parser.add_argument('--exclude', type=PARSER_TYPE_INT, action='append', - help='Exclude a server from selection. Can be ' - 'supplied multiple times') - parser.add_argument('--mini', help='URL of the Speedtest Mini server') - parser.add_argument('--source', help='Source IP address to bind to') - parser.add_argument('--timeout', default=10, type=PARSER_TYPE_FLOAT, - help='HTTP timeout in seconds. Default 10') - parser.add_argument('--secure', action='store_true', - help='Use HTTPS instead of HTTP when communicating ' - 'with speedtest.net operated servers') - parser.add_argument('--no-pre-allocate', dest='pre_allocate', - action='store_const', default=True, const=False, - help='Do not pre allocate upload data. Pre allocation ' - 'is enabled by default to improve upload ' - 'performance. To support systems with ' - 'insufficient memory, use this option to avoid a ' - 'MemoryError') - parser.add_argument('--version', action='store_true', - help='Show the version number and exit') - parser.add_argument('--debug', action='store_true', - help=ARG_SUPPRESS, default=ARG_SUPPRESS) - - options = parser.parse_args() - if isinstance(options, tuple): - args = options[0] - else: - args = options - return args - - -def validate_optional_args(args): - """Check if an argument was provided that depends on a module that may - not be part of the Python standard library. - - If such an argument is supplied, and the module does not exist, exit - with an error stating which module is missing. - """ - optional_args = { - 'json': ('json/simplejson python module', json), - 'secure': ('SSL support', HTTPSConnection), - } - - for arg, info in optional_args.items(): - if getattr(args, arg, False) and info[1] is None: - raise SystemExit('%s is not installed. --%s is ' - 'unavailable' % (info[0], arg)) - - -def printer(string, quiet=False, debug=False, error=False, **kwargs): - """Helper function print a string with various features""" - - if debug and not DEBUG: - return - - if debug: - if sys.stdout.isatty(): - out = '\033[1;30mDEBUG: %s\033[0m' % string - else: - out = 'DEBUG: %s' % string - else: - out = string - - if error: - kwargs['file'] = sys.stderr - - if not quiet: - print_(out, **kwargs) - - -def shell(): - """Run the full speedtest.net test""" - - global DEBUG - shutdown_event = threading.Event() - - signal.signal(signal.SIGINT, ctrl_c(shutdown_event)) - - args = parse_args() - - # Print the version and exit - if args.version: - version() - - if not args.download and not args.upload: - raise SpeedtestCLIError('Cannot supply both --no-download and ' - '--no-upload') - - if len(args.csv_delimiter) != 1: - raise SpeedtestCLIError('--csv-delimiter must be a single character') - - if args.csv_header: - csv_header(args.csv_delimiter) - - validate_optional_args(args) - - debug = getattr(args, 'debug', False) - if debug == 'SUPPRESSHELP': - debug = False - if debug: - DEBUG = True - - if args.simple or args.csv or args.json: - quiet = True - else: - quiet = False - - if args.csv or args.json: - machine_format = True - else: - machine_format = False - - # Don't set a callback if we are running quietly - if quiet or debug: - callback = do_nothing - else: - callback = print_dots(shutdown_event) - - printer('Retrieving speedtest.net configuration...', quiet) - try: - speedtest = Speedtest( - source_address=args.source, - timeout=args.timeout, - secure=args.secure - ) - except (ConfigRetrievalError,) + HTTP_ERRORS: - printer('Cannot retrieve speedtest configuration', error=True) - raise SpeedtestCLIError(get_exception()) - - if args.list: - try: - speedtest.get_servers() - except (ServersRetrievalError,) + HTTP_ERRORS: - printer('Cannot retrieve speedtest server list', error=True) - raise SpeedtestCLIError(get_exception()) - - for _, servers in sorted(speedtest.servers.items()): - for server in servers: - line = ('%(id)5s) %(sponsor)s (%(name)s, %(country)s) ' - '[%(d)0.2f km]' % server) - try: - printer(line) - except IOError: - e = get_exception() - if e.errno != errno.EPIPE: - raise - sys.exit(0) - - printer('Testing from %(isp)s (%(ip)s)...' % speedtest.config['client'], - quiet) - - if not args.mini: - printer('Retrieving speedtest.net server list...', quiet) - try: - speedtest.get_servers(servers=args.server, exclude=args.exclude) - except NoMatchedServers: - raise SpeedtestCLIError( - 'No matched servers: %s' % - ', '.join('%s' % s for s in args.server) - ) - except (ServersRetrievalError,) + HTTP_ERRORS: - printer('Cannot retrieve speedtest server list', error=True) - raise SpeedtestCLIError(get_exception()) - except InvalidServerIDType: - raise SpeedtestCLIError( - '%s is an invalid server type, must ' - 'be an int' % ', '.join('%s' % s for s in args.server) - ) - - if args.server and len(args.server) == 1: - printer('Retrieving information for the selected server...', quiet) - else: - printer('Selecting best server based on ping...', quiet) - speedtest.get_best_server() - elif args.mini: - speedtest.get_best_server(speedtest.set_mini_server(args.mini)) - - results = speedtest.results - - printer('Hosted by %(sponsor)s (%(name)s) [%(d)0.2f km]: ' - '%(latency)s ms' % results.server, quiet) - - if args.download: - printer('Testing download speed', quiet, - end=('', '\n')[bool(debug)]) - speedtest.download( - callback=callback, - threads=(None, 1)[args.single] - ) - printer('Download: %0.2f M%s/s' % - ((results.download / 1000.0 / 1000.0) / args.units[1], - args.units[0]), - quiet) - - - - - - else: - printer('Skipping download test', quiet) - - if args.upload: - printer('Testing upload speed', quiet, - end=('', '\n')[bool(debug)]) - speedtest.upload( - callback=callback, - pre_allocate=args.pre_allocate, - threads=(None, 1)[args.single] - ) - - printer('Upload: %0.2f M%s/s' % - ((results.upload / 1000.0 / 1000.0) / args.units[1], - args.units[0]), - quiet) - - - else: - printer('Skipping upload test', quiet) - - if args.upload and args.download: - uploadspeed = 'Upload: %0.2f M%s/s' %\ - ((results.upload / 1000.0 / 1000.0) / args.units[1], - args.units[0]) - downloadspeed = 'Download: %0.2f M%s/s' % \ - ((results.download / 1000.0 / 1000.0) / args.units[1], - args.units[0]) - - f = open("speedtest-output.sh", "w") - f.write("""#!/bin/sh -[ -r /etc/lsb-release ] && . /etc/lsb-release - -if [ -z "$DISTRIB_DESCRIPTION" ] && [ -x /usr/bin/lsb_release ]; then - # Fall back to using the very slow lsb_release utility - DISTRIB_DESCRIPTION=$(lsb_release -s -d) -fi - -printf "Welcome to brev.dev\n" -printf "\n" -printf " ##@@@#.\n" -printf " #@@@@@@@@@@.\n" -printf " #@@@. .@@@@%\n" -printf " #@% :@@@@-\n" -printf " #@ =@@@=\n" -printf "#@* +@@@= -=%@%=\n" -printf "#@: =@@@@ *@@@@@@@@\n" -printf "#@: =@@@@ -@@@= +@+@@@@@@+\n" -printf "#@: @@@@% @@@- @@: @@@@# @@@@@\n" -printf "#@@= @@@@+ %@@- +@@* *@@@@@@@@.\n" -printf " #@@- @@@@.%@@- #@@ .@@@@.. @@@@@= @@@.\n" -printf " #@@@= @@@@@@@@ @@@@@@@@@@% *@@@@- @@#\n" -printf " #@@@@ -@@@@@@ -@# #@@+ @@@@# @@@\n" -printf " #@@@% @@@@@@* -@ .@@ =@@@@: @@@\n" -printf " #@@@- @@@@@@ @@@: -# @ @@@@* @@@\n" -printf " #@@@- @@@@@@ :@@@@% -. @ @@@@ @@#\n" -printf " #@@@ +@@@@@= :@@@@@ - @@@@@ %@@\n" -printf " #@@@ ..@@@@ =@@@@@ - @@@@+ .@@-\n" -printf " #@@: :@@ @@@@@@ @@@@: -@@@\n" -printf " #@@ @ @@@@@@ @@@@@ .@@@\n" -printf " #@@. *% @@@@@: #+ @@ @@@.\n" -printf " #@@ @ @@@@# @ @@ @@@\n" -printf " #@@- = @@@@ @ +@# @ @@@.\n" -printf " #@@ . . @@@ . @@@ @@ *@@#\n" -printf " #@@ @@* - @+ #. -@@@ @@ @@+\n" -printf " #@@@ @@@: - .@ =@@@ @+ @@@\n" -printf " #@@@ *@@@ - @@ @@@@ @ %@@\n" -printf " #@@. +@@@ - @@* @@@@ @ -@@:\n" -printf " #@@+ .@@@ - @@@. :%@ @@%\n" -printf " #@@@ @@@: - .@@@ =. #@@\n" -printf " #@@@ @@@ - #@@ @@\n" -printf " #@@@ @@@ - @# .@@+\n" -printf " #@@@ @@@ . @: @@@\n" -printf " #@@+ @ + @: +@@:\n" -printf " #@@@ @ @ @@@\n" -printf " #@@@ % @ % = %@@\n" -printf " #@@@ *% @ # @ .@@@\n" -printf " #@@@. @ .@ .: +@ @@@+\n" -printf " #@@+ :@ .@ # @..=@@ +@@@-\n" -printf " #@@@ @= .@ # @@@@@@ @@@@\n" -printf " #@@ *@ @@ @ #@@@* %@@@@\n" -printf " #@@# @@ @- : +@* @@@@\n" -printf " #@@@. -@@ @- :: %* *@@@@\n" -printf " #@@@%%%@@@ @- : * @@@@=\n" -printf " #@@@@@@@@ *@- @ * @@@@@\n" -printf " #@@@@@@ @@ # @ =@@@@\n" -printf " #@@@@@= .@@ -# @ =@@@@@\n" -printf " #@@@@= @@@ # @@@@@@@@\n" -printf " #@@@@@@@@@@: .@ @@@@@@@.\n" -printf " #@ @@ @@@@@@@=\n" -printf " #@@@@# #@@@@@@=\n" -printf " #@@@@@. .@@@@@@@\n" -printf " #%@@@@@@@@@@@@@@-\n" -printf " ###@@@@@##\n" -printf "\n\nInternet Speed:\n" - - """) - f.write('printf "' + uploadspeed + '\n"\nprintf "' + downloadspeed + '\n"') - f.write(""" -printf "\nRunning %s (%s %s %s)\n" "$DISTRIB_DESCRIPTION" "$(uname -o)" "$(uname -r)" "$(uname -m)" - """) - - f.close() - - superstring = f""" -[ -r /etc/lsb-release ] && . /etc/lsb-release - -if [ -z "$DISTRIB_DESCRIPTION" ] && [ -x /usr/bin/lsb_release ]; then - # Fall back to using the very slow lsb_release utility - DISTRIB_DESCRIPTION=$(lsb_release -s -d) -fi - -printf "Welcome to brev.dev\n" -printf "\n" -printf " ##@@@#.\n" -printf " #@@@@@@@@@@.\n" -printf " #@@@. .@@@@%\n" -printf " #@% :@@@@-\n" -printf " #@ =@@@=\n" -printf "#@* +@@@= -=%@%=\n" -printf "#@: =@@@@ *@@@@@@@@\n" -printf "#@: =@@@@ -@@@= +@+@@@@@@+\n" -printf "#@: @@@@% @@@- @@: @@@@# @@@@@\n" -printf "#@@= @@@@+ %@@- +@@* *@@@@@@@@.\n" -printf " #@@- @@@@.%@@- #@@ .@@@@.. @@@@@= @@@.\n" -printf " #@@@= @@@@@@@@ @@@@@@@@@@% *@@@@- @@#\n" -printf " #@@@@ -@@@@@@ -@# #@@+ @@@@# @@@\n" -printf " #@@@% @@@@@@* -@ .@@ =@@@@: @@@\n" -printf " #@@@- @@@@@@ @@@: -# @ @@@@* @@@\n" -printf " #@@@- @@@@@@ :@@@@% -. @ @@@@ @@#\n" -printf " #@@@ +@@@@@= :@@@@@ - @@@@@ %@@\n" -printf " #@@@ ..@@@@ =@@@@@ - @@@@+ .@@-\n" -printf " #@@: :@@ @@@@@@ @@@@: -@@@\n" -printf " #@@ @ @@@@@@ @@@@@ .@@@\n" -printf " #@@. *% @@@@@: #+ @@ @@@.\n" -printf " #@@ @ @@@@# @ @@ @@@\n" -printf " #@@- = @@@@ @ +@# @ @@@.\n" -printf " #@@ . . @@@ . @@@ @@ *@@#\n" -printf " #@@ @@* - @+ #. -@@@ @@ @@+\n" -printf " #@@@ @@@: - .@ =@@@ @+ @@@\n" -printf " #@@@ *@@@ - @@ @@@@ @ %@@\n" -printf " #@@. +@@@ - @@* @@@@ @ -@@:\n" -printf " #@@+ .@@@ - @@@. :%@ @@%\n" -printf " #@@@ @@@: - .@@@ =. #@@\n" -printf " #@@@ @@@ - #@@ @@\n" -printf " #@@@ @@@ - @# .@@+\n" -printf " #@@@ @@@ . @: @@@\n" -printf " #@@+ @ + @: +@@:\n" -printf " #@@@ @ @ @@@\n" -printf " #@@@ % @ % = %@@\n" -printf " #@@@ *% @ # @ .@@@\n" -printf " #@@@. @ .@ .: +@ @@@+\n" -printf " #@@+ :@ .@ # @..=@@ +@@@-\n" -printf " #@@@ @= .@ # @@@@@@ @@@@\n" -printf " #@@ *@ @@ @ #@@@* %@@@@\n" -printf " #@@# @@ @- : +@* @@@@\n" -printf " #@@@. -@@ @- :: %* *@@@@\n" -printf " #@@@%%%@@@ @- : * @@@@=\n" -printf " #@@@@@@@@ *@- @ * @@@@@\n" -printf " #@@@@@@ @@ # @ =@@@@\n" -printf " #@@@@@= .@@ -# @ =@@@@@\n" -printf " #@@@@= @@@ # @@@@@@@@\n" -printf " #@@@@@@@@@@: .@ @@@@@@@.\n" -printf " #@ @@ @@@@@@@=\n" -printf " #@@@@# #@@@@@@=\n" -printf " #@@@@@. .@@@@@@@\n" -printf " #%@@@@@@@@@@@@@@-\n" -printf " ###@@@@@##\n" -printf "\n\nInternet Speed:\n" -printf "{uploadspeed}\n{downloadspeed}\n" -printf "\nRunning %s (%s %s %s)\n" "$DISTRIB_DESCRIPTION" "$(uname -o)" "$(uname -r)" "$(uname -m)" -""" - - printer('Results:\n%r' % results.dict(), debug=True) - - if not args.simple and args.share: - results.share() - - if args.simple: - printer('Ping: %s ms\nDownload: %0.2f M%s/s\nUpload: %0.2f M%s/s' % - (results.ping, - (results.download / 1000.0 / 1000.0) / args.units[1], - args.units[0], - (results.upload / 1000.0 / 1000.0) / args.units[1], - args.units[0])) - elif args.csv: - printer(results.csv(delimiter=args.csv_delimiter)) - elif args.json: - printer(results.json()) - - if args.share and not machine_format: - printer('Share results: %s' % results.share()) - - -def main(): - try: - shell() - except KeyboardInterrupt: - printer('\nCancelling...', error=True) - except (SpeedtestException, SystemExit): - e = get_exception() - # Ignore a successful exit, or argparse exit - if getattr(e, 'code', 1) not in (0, 2): - msg = '%s' % e - if not msg: - msg = '%r' % e - raise SystemExit('ERROR: %s' % msg) - - -if __name__ == '__main__': - main() - -def getRam(): - output=subprocess.run( - "free", - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - - return float(str(output.stdout).split()[7])/1000000 \ No newline at end of file diff --git a/pkg/cmd/optimizeinstances/optimizeinstances.go b/pkg/cmd/optimizeinstances/optimizeinstances.go deleted file mode 100644 index 24ef50b5..00000000 --- a/pkg/cmd/optimizeinstances/optimizeinstances.go +++ /dev/null @@ -1,566 +0,0 @@ -package optimizeinstances - -import ( - "context" - "encoding/base64" - "fmt" - "net/mail" - "strings" - "sync" - "time" - - "github.com/alecthomas/units" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/service/costexplorer" - "github.com/aws/aws-sdk-go-v2/service/dlm" - "github.com/aws/aws-sdk-go-v2/service/ec2" - ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" - "github.com/aws/aws-sdk-go-v2/service/iam" - "github.com/aws/aws-sdk-go-v2/service/pricing" - "github.com/aws/aws-sdk-go-v2/service/servicequotas" - - "github.com/brevdev/brev-cli/pkg/collections" - "github.com/brevdev/brev-cli/pkg/errors" - "github.com/brevdev/brev-cli/pkg/ids" - "github.com/brevdev/brev-cli/pkg/terminal" - "github.com/hashicorp/go-multierror" - "github.com/spf13/cobra" - "go.opentelemetry.io/otel" -) - -var ( - short = "Apply Brev cost optimizations to your instances. Enter the IDs of the instances you want to optimized." - long = "Apply Brev cost optimizations to your instances. Enter the IDs of the instances you want to optimized. We will apply everything from autostop to letting you scale between hardware configs" - example = "brev optimize-instances i-1234567890abcdef0 i-1234567890abcdef1" -) - -type optimizeInstancesStore interface{} - -type AWSClient struct { - location string - ec2Client *ec2.Client - quotaClient *servicequotas.Client - pricingClient *pricing.Client - costExplorerClient *costexplorer.Client - dlmClient *dlm.Client - iamClient *iam.Client -} -type LifecycleStatus string - -type Status struct { - LifecycleStatus LifecycleStatus -} - -type Instance struct { - ID ids.CloudProviderInstanceID - Hostname string - ImageID string - InstanceType string - DiskSize units.Base2Bytes - PubKeyFingerprint string - Status Status - MetaEndpointEnabled bool - MetaTagsEnabled bool - VPCID string - SubnetID string - Spot bool - Name string -} - -type KeyPair struct { - KeyFingerprint string - - // The name of the key pair. - KeyName string - - // The ID of the key pair. - KeyPairID string -} - -func NewCmdOptimizeInstances(t *terminal.Terminal, store optimizeInstancesStore) *cobra.Command { - cmd := &cobra.Command{ - Use: "optimize-instances", - Aliases: []string{"oi"}, - DisableFlagsInUseLine: true, - Short: short, - Long: long, - Example: example, - RunE: func(cmd *cobra.Command, args []string) error { - err := OptimizeInstances(t, args, store) - if err != nil { - return errors.WrapAndTrace(err) - } - return nil - }, - } - return cmd -} - -func NewCmdOptimize(t *terminal.Terminal, store optimizeInstancesStore) *cobra.Command { - cmd := &cobra.Command{ - Use: "optimize", - DisableFlagsInUseLine: true, - Short: short, - Long: long, - Example: example, - RunE: func(cmd *cobra.Command, args []string) error { - err := OptimizeInstances(t, args, store) - if err != nil { - return errors.WrapAndTrace(err) - } - return nil - }, - } - return cmd -} - -func OptimizeInstances(t *terminal.Terminal, args []string, _ optimizeInstancesStore) error { - userEmail, err := getUserEmail(t, "Please enter your email address: ") - if err != nil { - return errors.WrapAndTrace(err) - } - fmt.Print("This operation will modify your instances' user data. Do you want to continue? If you don't know what that is don't worry (y/n): ") - confirmed, err := askForConfirmation() - if err != nil { - return errors.WrapAndTrace(err) - } - if !confirmed { - fmt.Println("Aborting") - return nil - } - config, err := GetLiveFileConfig(context.TODO()) - if err != nil { - return errors.WrapAndTrace(err) - } - awsClient := GetAWSClient(config) - - var result error - var wg sync.WaitGroup - - for _, arg := range args { - wg.Add(1) - - arg := arg - go func() { - defer wg.Done() - err := updateInstance(t, awsClient, arg, userEmail) - if err != nil { - fmt.Println(t.Red("Failed to update instance: "), arg, t.Red(err.Error())) - } else { - fmt.Println(t.Green("Successfully updated: "), arg) - } - result = multierror.Append(result, err) - }() - } - wg.Wait() - - if merr, ok := result.(*multierror.Error); ok { - if len(merr.Errors) < len(args) { - return nil - } - } - return errors.WrapAndTrace(result) -} - -func getUserEmail(t *terminal.Terminal, promptText string) (string, error) { - fmt.Print(t.Yellow(promptText)) - var userEmail string - _, err := fmt.Scanln(&userEmail) - if err != nil { - return "", errors.WrapAndTrace(err) - } - if validEmail(userEmail) { - return userEmail, nil - } - return getUserEmail(t, "Please enter a valid email address: ") -} - -func validEmail(email string) bool { - _, err := mail.ParseAddress(email) - return err == nil -} - -func findAWSClientWithCorrectRegion(ctx context.Context, instanceID string) (AWSClient, error) { - regions := []string{"us-east-1", "us-east-2", "us-west-1", "us-west-2", "ap-east-1", "ap-south-1", "ap-northeast-3", "ap-northeast-2", "ap-southeast-1", "ap-southeast-2", "ap-northeast-1", "ca-central-1", "eu-central-1", "eu-west-1", "eu-west-2", "eu-west-3", "eu-north-1", "me-south-1", "sa-east-1", "af-south-1", "ap-southeast-3", "eu-south-1", "me-central-1", "us-gov-east-1", "us-gov-west-1"} - - config, err := GetLiveFileConfig(context.TODO()) - if err != nil { - return AWSClient{}, errors.WrapAndTrace(err) - } - for _, region := range regions { - config.Region = region - awsClient := GetAWSClient(config) - - _, err := awsClient.GetInstance(ctx, ids.CloudProviderInstanceID(instanceID)) - if err == nil { - return awsClient, nil - } - } - return AWSClient{}, errors.WrapAndTrace(errors.New("Could not find instance in any region")) -} - -func updateInstance(t *terminal.Terminal, awsClient AWSClient, instanceID string, userEmail string) error { - awsClient, err := findAWSClientWithCorrectRegion(context.TODO(), instanceID) - if err != nil { - return errors.WrapAndTrace(err) - } - fmt.Println(t.Red("Stopping instance: "), instanceID) - err = awsClient.StopInstance(context.Background(), ids.CloudProviderInstanceID(instanceID)) - if err != nil { - return errors.WrapAndTrace(err) - } - // TODO: append the actual current data - err = waitForEc2State(context.Background(), awsClient.ec2Client, ids.CloudProviderInstanceID(instanceID), ec2types.InstanceStateNameStopped, 300) - if err != nil { - return errors.WrapAndTrace(err) - } - userDataArgs := UserDataArgs{ - OnBootScript: `#!/bin/bash - bash -c "$(curl -fsSL https://raw.githubusercontent.com/brevdev/brev-cli/main/bin/install-latest.sh)" - brev postinstall ` + userEmail, - OnInstanceScript: ``, - OnOnceScript: ``, - } - newUserData := makeUserData(userDataArgs) - fmt.Println(t.Yellow("Updating instance: "), instanceID) - err = awsClient.UpdateInstanceUserData(context.TODO(), ids.CloudProviderInstanceID(instanceID), newUserData) - if err != nil { - return errors.WrapAndTrace(err) - } - time.Sleep(5 * time.Second) - err = awsClient.StartInstance(context.Background(), ids.CloudProviderInstanceID(instanceID)) - if err != nil { - return errors.WrapAndTrace(err) - } - return nil -} - -func GetLiveFileConfig(ctx context.Context) (aws.Config, error) { - cfg, err := config.LoadDefaultConfig(ctx) - if err != nil { - return aws.Config{}, errors.WrapAndTrace(err) - } - - return cfg, nil -} - -func GetAWSClient(cfg aws.Config) AWSClient { - ec2Client := ec2.NewFromConfig(cfg) - pricingConfig := cfg.Copy() - pricingClient := pricing.NewFromConfig(pricingConfig) - quotaClient := servicequotas.NewFromConfig(cfg) - costExplorerCLient := costexplorer.NewFromConfig(cfg) - dlmClient := dlm.NewFromConfig(cfg) - iamClient := iam.NewFromConfig(cfg) - return NewAWSClient(ec2Client, pricingClient, quotaClient, costExplorerCLient, dlmClient, iamClient, pricingConfig.Region) -} - -func NewAWSClient(ec2Client *ec2.Client, pricingClient *pricing.Client, quotaClient *servicequotas.Client, costExplorerClient *costexplorer.Client, dlmClient *dlm.Client, iamClient *iam.Client, location string) AWSClient { - return AWSClient{ - ec2Client: ec2Client, - quotaClient: quotaClient, - pricingClient: pricingClient, - costExplorerClient: costExplorerClient, - location: location, - dlmClient: dlmClient, - iamClient: iamClient, - } -} - -func (a AWSClient) GetInstanceUserData(ctx context.Context, instanceID ids.CloudProviderInstanceID) (string, error) { - result, err := a.ec2Client.DescribeInstanceAttribute(ctx, &ec2.DescribeInstanceAttributeInput{ - InstanceId: aws.String(string(instanceID)), - Attribute: ec2types.InstanceAttributeNameUserData, - }) - if err != nil { - return "", errors.WrapAndTrace(err) - } - if result.UserData == nil || result.UserData.Value == nil { - return "", nil - } - return *result.UserData.Value, nil -} - -func (a AWSClient) UpdateInstanceUserData(ctx context.Context, instanceID ids.CloudProviderInstanceID, userData string) error { - _, err := a.ec2Client.ModifyInstanceAttribute(ctx, &ec2.ModifyInstanceAttributeInput{ - InstanceId: aws.String(string(instanceID)), - UserData: &ec2types.BlobAttributeValue{ - Value: []byte(userData), - }, - }) - if err != nil { - return errors.WrapAndTrace(err) - } - return nil -} - -func waitForEc2State(ctx context.Context, client *ec2.Client, instanceID ids.CloudProviderInstanceID, state ec2types.InstanceStateName, maxWait int) error { - tracer := otel.GetTracerProvider().Tracer("") - ctx, span := tracer.Start(ctx, "ec2.waitForState") - defer span.End() - // create vars here b/c they are used outside the loop and go is lexically scoped - var res *ec2.DescribeInstancesOutput - var err error - for maxWait > 0 { - res, err = client.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ - InstanceIds: []string{string(instanceID)}, - }) - if err != nil { - return errors.WrapAndTrace(err) - } - if len(res.Reservations) == 0 { - return errors.New("no reservations found") - } - if len(res.Reservations[0].Instances) == 0 { - return errors.New("no instances found") - } - if res.Reservations[0].Instances[0].State.Name == state { - return nil - } - time.Sleep(time.Second) - maxWait-- - } - return fmt.Errorf("timeout waiting for state %s, current state %s", state, res.Reservations[0].Instances[0].State.Name) -} - -func (a AWSClient) StopInstance(ctx context.Context, instanceID ids.CloudProviderInstanceID) error { - _, err := a.ec2Client.StopInstances(ctx, &ec2.StopInstancesInput{ - InstanceIds: []string{string(instanceID)}, - Force: aws.Bool(true), - }) - return errors.WrapAndTrace(handleAWSError(err)) -} - -func (a AWSClient) StartInstance(ctx context.Context, instanceID ids.CloudProviderInstanceID) error { - _, err := a.ec2Client.StartInstances(ctx, &ec2.StartInstancesInput{ - InstanceIds: []string{string(instanceID)}, - }) - return errors.WrapAndTrace(handleAWSError(err)) -} - -var ErrInstanceNotFound = fmt.Errorf("instance not found") - -func handleAWSError(e error) error { - if e == nil { - return nil - } - if strings.Contains(e.Error(), "InvalidInstanceID.NotFound") { - return ErrInstanceNotFound - } else { - return e - } -} - -func (a AWSClient) GetInstance(ctx context.Context, instanceID ids.CloudProviderInstanceID) (*Instance, error) { - res, err := a.ec2Client.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ - InstanceIds: []string{string(instanceID)}, - }) - if err != nil { - return nil, errors.WrapAndTrace(handleAWSError(err)) - } - if len(res.Reservations) == 0 { - return nil, errors.WrapAndTrace(ErrInstanceNotFound) - } - if len(res.Reservations[0].Instances) == 0 { - return nil, errors.WrapAndTrace(ErrInstanceNotFound) - } - instance := res.Reservations[0].Instances[0] - - vols, err := a.GetInstanceVolumes(ctx, instanceID) - if err != nil { - return nil, errors.WrapAndTrace(err) - } - if len(vols) == 0 { - return nil, errors.New("no volumes found") - } - diskSize := Int32GiBToUnit(*vols[0].Size) - - keyFingerprint := "" - if instance.KeyName != nil { - var keyPair *KeyPair - keyPair, err = a.GetKeyPairByName(ctx, *instance.KeyName) - if err != nil { - return nil, errors.WrapAndTrace(err) - } - keyFingerprint = keyPair.KeyFingerprint - } - - status, err := AWSInstanceStateToLifecyclState(*instance.State) - if err != nil { - return nil, errors.WrapAndTrace(err) - } - - i := AWSInstanceToInstanceAttrs(instance, status, diskSize, keyFingerprint) - if err != nil { - return nil, errors.WrapAndTrace(err) - } - return &i, nil -} - -func AWSInstanceToInstanceAttrs(instance ec2types.Instance, status LifecycleStatus, diskSize units.Base2Bytes, publicKeyFingerprint string) Instance { - return Instance{ - ID: ids.CloudProviderInstanceID(*instance.InstanceId), - Hostname: collections.ValueOrZero(instance.PublicDnsName), - ImageID: collections.ValueOrZero(instance.ImageId), - InstanceType: string(instance.InstanceType), - DiskSize: diskSize, - PubKeyFingerprint: publicKeyFingerprint, - Status: Status{LifecycleStatus: status}, - MetaEndpointEnabled: collections.ValueOrZero(instance.MetadataOptions).HttpEndpoint == ec2types.InstanceMetadataEndpointStateEnabled, - MetaTagsEnabled: collections.ValueOrZero(instance.MetadataOptions).InstanceMetadataTags == ec2types.InstanceMetadataTagsStateEnabled, - VPCID: collections.ValueOrZero(instance.VpcId), - SubnetID: collections.ValueOrZero(instance.SubnetId), - Spot: collections.ValueOrZero(instance.SpotInstanceRequestId) != "", - Name: getNameTag(instance.Tags), - } -} - -func getNameTag(tags []ec2types.Tag) string { - for _, tag := range tags { - if *tag.Key == "Name" { - return *tag.Value - } - } - return "" -} - -func Int32GiBToUnit(i int32) units.Base2Bytes { - return units.Base2Bytes(i) * units.GiB -} - -func (a AWSClient) GetInstanceVolumes(ctx context.Context, instanceID ids.CloudProviderInstanceID) ([]ec2types.Volume, error) { - result, err := a.ec2Client.DescribeVolumes(ctx, &ec2.DescribeVolumesInput{ - Filters: []ec2types.Filter{ - { - Name: aws.String("attachment.instance-id"), - Values: []string{string(instanceID)}, - }, - }, - }) - if err != nil { - return nil, errors.WrapAndTrace(err) - } - return result.Volumes, nil -} - -func AWSInstanceStateToLifecyclState(status ec2types.InstanceState) (LifecycleStatus, error) { - switch status.Name { - case ec2types.InstanceStateNamePending: - return LifecycleStatePending, nil - case ec2types.InstanceStateNameRunning: - return LifecycleStateRunning, nil - case ec2types.InstanceStateNameStopping: - return LifecycleStateStopping, nil - case ec2types.InstanceStateNameStopped: - return LifecycleStateStopped, nil - case ec2types.InstanceStateNameShuttingDown: - return LifecycleStateTerminating, nil - case ec2types.InstanceStateNameTerminated: - return LifecycleStateTerminated, nil - } - return "", fmt.Errorf("unknown instance state: %s", status.Name) -} - -const ( - LifecycleStatePending LifecycleStatus = "pending" - LifecycleStateRunning LifecycleStatus = "running" - LifecycleStateStopping LifecycleStatus = "stopping" - LifecycleStateStopped LifecycleStatus = "stopped" - LifecycleStateSuspending LifecycleStatus = "suspending" - LifecycleStateSuspended LifecycleStatus = "suspended" - LifecycleStateTerminating LifecycleStatus = "terminating" - LifecycleStateTerminated LifecycleStatus = "terminated" -) - -type UserDataArgs struct { - OnBootScript string - OnInstanceScript string - OnOnceScript string -} - -func (a AWSClient) GetKeyPairByName(ctx context.Context, name string) (*KeyPair, error) { - result, err := a.ec2Client.DescribeKeyPairs(ctx, &ec2.DescribeKeyPairsInput{ - KeyNames: []string{name}, - }) - if err != nil { - return nil, errors.WrapAndTrace(err) - } - if len(result.KeyPairs) == 0 { - return nil, errors.New("no key pairs found") - } - keyPair := result.KeyPairs[0] - return &KeyPair{ - KeyFingerprint: *keyPair.KeyFingerprint, - KeyName: *keyPair.KeyName, - KeyPairID: collections.ValueOrZero(keyPair.KeyPairId), // undo when local stack fixes - }, nil -} - -const devPlaneUserDataTemplate = `Content-Type: multipart/mixed; boundary="===============7279599212584821875==" -MIME-Version: 1.0 - ---===============7279599212584821875== -Content-Type: text/x-shellscript-per-boot; charset="utf-8" -MIME-Version: 1.0 -Content-Transfer-Encoding: base64 -Content-Disposition: attachment; filename="always.sh" - - -%s ---===============7279599212584821875== -Content-Type: text/x-shellscript-per-instance; charset="utf-8" -MIME-Version: 1.0 -Content-Transfer-Encoding: base64 -Content-Disposition: attachment; filename="instance.sh" - -%s - ---===============7279599212584821875== -Content-Type: text/x-shellscript-per-once; charset="utf-8" -MIME-Version: 1.0 -Content-Transfer-Encoding: base64 -Content-Disposition: attachment; filename="once.sh" - -%s - ---===============7279599212584821875==--` - -func makeUserData(args UserDataArgs) string { - userdata := fmt.Sprintf(devPlaneUserDataTemplate, - base64.StdEncoding.EncodeToString([]byte(args.OnBootScript)), - base64.StdEncoding.EncodeToString([]byte(args.OnInstanceScript)), - base64.StdEncoding.EncodeToString([]byte(args.OnOnceScript))) - - return userdata // base64.StdEncoding.EncodeToString( -} - -func askForConfirmation() (bool, error) { - var response string - _, err := fmt.Scanln(&response) - if err != nil { - return false, errors.WrapAndTrace(err) - } - okayResponses := []string{"y", "Y", "yes", "Yes", "YES"} - nokayResponses := []string{"n", "N", "no", "No", "NO"} - if containsString(okayResponses, response) { - return true, nil - } else if containsString(nokayResponses, response) { - return false, nil - } - fmt.Println("Please type yes or no and then press enter:") - return askForConfirmation() -} - -func containsString(slice []string, element string) bool { - return !(posString(slice, element) == -1) -} - -func posString(slice []string, element string) int { - for index, elem := range slice { - if elem == element { - return index - } - } - return -1 -} diff --git a/pkg/cmd/postinstall/postinstall.go b/pkg/cmd/postinstall/postinstall.go deleted file mode 100644 index 50ecbeb1..00000000 --- a/pkg/cmd/postinstall/postinstall.go +++ /dev/null @@ -1,113 +0,0 @@ -package postinstall - -import ( - "github.com/spf13/cobra" - - "github.com/brevdev/brev-cli/pkg/autostartconf" - breverrors "github.com/brevdev/brev-cli/pkg/errors" - "github.com/brevdev/brev-cli/pkg/terminal" -) - -var ( - short = "TODO" - long = "TODO" - example = "TODO" -) - -type postinstallStore interface { - autostartconf.AutoStartStore - RegisterNotificationEmail(string) error - WriteEmail(email string) error -} - -func NewCmdpostinstall(_ *terminal.Terminal, store postinstallStore) *cobra.Command { - // var email string - - cmd := &cobra.Command{ - Use: "postinstall", - // DisableFlagsInUseLine: true, - Short: short, - Long: long, - Example: example, - RunE: func(cmd *cobra.Command, args []string) error { - email := "" - if len(args) > 0 { - email = args[0] - } - err := Runpostinstall( - store, - email, - ) - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil - }, - } - - return cmd -} - -func NewCMDOptimizeThis(_ *terminal.Terminal, store postinstallStore) *cobra.Command { - // var email string - - cmd := &cobra.Command{ - Use: "optimize-this", - // DisableFlagsInUseLine: true, - Short: short, - Long: long, - Example: example, - RunE: func(cmd *cobra.Command, args []string) error { - email := "" - if len(args) > 0 { - email = args[0] - } - err := Runpostinstall( - store, - email, - ) - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil - }, - } - - return cmd -} - -func Runpostinstall( - store postinstallStore, - email string, -) error { - if email == "" { - email = terminal.PromptGetInput(terminal.PromptContent{ - Label: "Email: ", - ErrorMsg: "error", - }) - } - - err := store.WriteEmail(email) - if err != nil { - return breverrors.WrapAndTrace(err) - } - - err = store.RegisterNotificationEmail(email) - if err != nil { - return breverrors.WrapAndTrace(err) - } - - brevmonConfigurer := autostartconf.NewBrevMonConfigure( - store, - false, - "10m", // todo pass brevmon args instead of individual args - "22", - ) - - err = brevmonConfigurer.Install() - if err != nil { - return breverrors.WrapAndTrace(err) - } - - return nil -} diff --git a/pkg/cmd/upgrade/upgrade.go b/pkg/cmd/upgrade/upgrade.go deleted file mode 100644 index 76dc4dee..00000000 --- a/pkg/cmd/upgrade/upgrade.go +++ /dev/null @@ -1,137 +0,0 @@ -package upgrade - -import ( - "os" - "os/exec" - "runtime" - - "github.com/samber/mo" - "github.com/spf13/cobra" - - breverrors "github.com/brevdev/brev-cli/pkg/errors" - "github.com/brevdev/brev-cli/pkg/terminal" -) - -var ( - short = "TODO" - long = "TODO" - example = "TODO" -) - -type upgradeStore interface { - GetOSUser() string - DownloadURL() (string, error) - DownloadBrevBinary(url, path string) error -} - -type uFunc func(ucmd upgradeCMD) error - -func NewCmdUpgrade(t *terminal.Terminal, store upgradeStore) *cobra.Command { - var debugger bool - cmd := &cobra.Command{ - Use: "upgrade", - DisableFlagsInUseLine: true, - Short: short, - Long: long, - Example: example, - RunE: func(cmd *cobra.Command, args []string) error { - err := RunUpgrade(upgradeCMD{ - t: t, - args: args, - store: store, - debugger: debugger, - }) - if err != nil { - return breverrors.WrapAndTrace(err) - } - t.Green("Upgrade complete!") - return nil - }, - } - cmd.Flags().BoolVarP(&debugger, "debugger", "d", false, "indicates command is being run in debugger") // todo remove -d - return cmd -} - -type saveOutput struct { - savedOutput []byte -} - -func (so *saveOutput) Write(p []byte) (n int, err error) { - so.savedOutput = append(so.savedOutput, p...) - n, err = os.Stdout.Write(p) - if err != nil { - return n, breverrors.WrapAndTrace(err) - } - return n, nil -} - -func runcmd(c string, args ...string) error { - var so saveOutput - cmd := exec.Command(c, args...) - cmd.Stdin = os.Stdin - cmd.Stdout = &so - cmd.Stderr = os.Stderr - err := cmd.Run() - if err != nil { - return breverrors.Wrap(err, string(so.savedOutput)) - } - - return nil -} - -var upgradeFuncs = map[string]uFunc{ - "darwin": func(ucmd upgradeCMD) error { - err := runcmd("brew", "upgrade", "brevdev/homebrew-brev/brev") - if err != nil { - return breverrors.WrapAndTrace(err) - } - - return nil - }, - "linux": func(ucmd upgradeCMD) error { - uid := ucmd.store.GetOSUser() - if uid != "0" && !ucmd.debugger { // root is uid 0 almost always - return breverrors.New("You must be root to upgrade, re run with: sudo brev upgrade") - } - // get cli download url - url, err := ucmd.store.DownloadURL() - if err != nil { - return breverrors.WrapAndTrace(err) - } - // download binary - err = ucmd.store.DownloadBrevBinary( - url, - "/usr/local/bin/brev", - ) - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil - }, -} - -func getUpgradeFunc() uFunc { - upgradeFunc, ok := upgradeFuncs[runtime.GOOS] - return mo.TupleToOption(upgradeFunc, ok).OrEmpty() -} - -type upgradeCMD struct { - t *terminal.Terminal - args []string - store upgradeStore - debugger bool -} - -func RunUpgrade(ucmd upgradeCMD) error { - upgradeFunc := getUpgradeFunc() - if upgradeFunc == nil { - return breverrors.New("unsupported OS") - } - - err := upgradeFunc(ucmd) - if err != nil { - return breverrors.WrapAndTrace(err) - } - - return nil -} diff --git a/pkg/store/autostop.go b/pkg/store/autostop.go deleted file mode 100644 index 8bd232eb..00000000 --- a/pkg/store/autostop.go +++ /dev/null @@ -1,65 +0,0 @@ -package store - -import ( - breverrors "github.com/brevdev/brev-cli/pkg/errors" -) - -const pathDownloadURL = "api/autostop/cli-download-url" - -func (n NoAuthHTTPStore) DownloadURL() (string, error) { - res, err := n.noAuthHTTPClient.restyClient.R(). - Get(pathDownloadURL) - if err != nil { - return "", breverrors.WrapAndTrace(err) - } - if res.IsError() { - return "", NewHTTPResponseError(res) - } - return res.String(), nil -} - -const pathRegisterNotificationEmail = "api/autostop/register" - -func (n NoAuthHTTPStore) RegisterNotificationEmail(email string) error { - res, err := n.noAuthHTTPClient.restyClient.R(). - SetHeader("Content-Type", "application/json"). - SetBody(map[string]any{ - "email": email, - }). - Post(pathRegisterNotificationEmail) - if err != nil { - return breverrors.WrapAndTrace(err) - } - if res.IsError() { - return NewHTTPResponseError(res) - } - - return nil -} - -const pathRecordAutoStop = "api/autostop/record" - -type RecordAutopstopBody struct { - Email string - InstanceType string - Region string - Name string - EnvID string -} - -func (n NoAuthHTTPStore) RecordAutoStop( - recordAutopstopBody RecordAutopstopBody, -) error { - res, err := n.noAuthHTTPClient.restyClient.R(). - SetHeader("Content-Type", "application/json"). - SetBody(recordAutopstopBody). - Post(pathRecordAutoStop) - if err != nil { - return breverrors.WrapAndTrace(err) - } - if res.IsError() { - return NewHTTPResponseError(res) - } - - return nil -} diff --git a/pkg/store/workspace.go b/pkg/store/workspace.go index be558d3e..8461431a 100644 --- a/pkg/store/workspace.go +++ b/pkg/store/workspace.go @@ -427,23 +427,6 @@ func (s AuthHTTPStore) StopWorkspace(workspaceID string) (*entity.Workspace, err return &result, nil } -func (s AuthHTTPStore) AutoStopWorkspace(workspaceID string) (*entity.Workspace, error) { - var result entity.Workspace - res, err := s.authHTTPClient.restyClient.R(). - SetHeader("Content-Type", "application/json"). - SetQueryParam("autoStop", "true"). - SetPathParam(workspaceIDParamName, workspaceID). - SetResult(&result). - Put(workspaceStopPath) - if err != nil { - return nil, breverrors.WrapAndTrace(err) - } - if res.IsError() { - return nil, NewHTTPResponseError(res) - } - return &result, nil -} - var ( workspaceStartPathPattern = fmt.Sprintf("%s/start", workspacePathPattern) workspaceStartPath = fmt.Sprintf(workspaceStartPathPattern, fmt.Sprintf("{%s}", workspaceIDParamName))