diff --git a/pkg/cmd/portforward/portforward.go b/pkg/cmd/portforward/portforward.go index 3aefbd5c..6083132d 100644 --- a/pkg/cmd/portforward/portforward.go +++ b/pkg/cmd/portforward/portforward.go @@ -45,7 +45,7 @@ func NewCmdPortForwardSSH(pfStore PortforwardStore, t *terminal.Terminal) *cobra if port == "" { port = startInput(t) } - err := RunPortforward(pfStore, args[0], port. useHost) + err := RunPortforward(pfStore, args[0], port, useHost) if err != nil { return breverrors.WrapAndTrace(err) } @@ -108,21 +108,37 @@ func ConvertNametoSSHName(store PortforwardStore, workspaceNameOrID string, useH } func RunSSHPortForward(forwardType string, localPort string, remotePort string, sshName string) (*os.Process, error) { - signals := make(chan os.Signal, 1) - signal.Notify(signals, os.Interrupt) - defer signal.Stop(signals) - - portMapping := fmt.Sprintf("%s:127.0.0.1:%s", localPort, remotePort) - cmdSHH := exec.Command("ssh", "-T", forwardType, portMapping, sshName, "-N") //nolint:gosec // variables are sanitzed or user specified - cmdSHH.Stdin = os.Stdin - fmt.Println("portforwarding...") - fmt.Printf("localhost:%s -> %s:%s\n", localPort, sshName, remotePort) - out, err := cmdSHH.CombinedOutput() - if err != nil { - return nil, breverrors.Wrap(err, string(out)) - } + signals := make(chan os.Signal, 1) + signal.Notify(signals, os.Interrupt) + defer signal.Stop(signals) + + portMapping := fmt.Sprintf("%s:127.0.0.1:%s", localPort, remotePort) + + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, breverrors.Wrap(err, "failed to get user home directory") + } + + keyPath := filepath.Join(homeDir, ".brev", "brev.pem") + + if _, err := os.Stat(keyPath); os.IsNotExist(err) { + return nil, breverrors.Wrap(err, fmt.Sprintf("SSH key not found at %s. Please ensure your Brev SSH key is properly set up.", keyPath)) + } + + cmdSHH := exec.Command("ssh", "-i", keyPath, "-T", forwardType, portMapping, sshName, "-N") + cmdSHH.Stdin = os.Stdin + cmdSHH.Stdout = os.Stdout + cmdSHH.Stderr = os.Stderr + + fmt.Println("Port forwarding...") + fmt.Printf("localhost:%s -> %s:%s\n", localPort, sshName, remotePort) + + err = cmdSHH.Start() + if err != nil { + return nil, breverrors.Wrap(err, "Failed to start SSH command") + } - return cmdSHH.Process, nil + return cmdSHH.Process, nil } func startInput(t *terminal.Terminal) string {