From a57f8196d2326925f769d1dff0b2aa7be789e08f Mon Sep 17 00:00:00 2001 From: ytimocin Date: Sat, 18 Jan 2025 14:45:24 -0800 Subject: [PATCH] Show Context feature in Shell Signed-off-by: ytimocin --- pkg/config/config.go | 6 +++++ pkg/conn/cloud.go | 22 +++++++++++++++ pkg/conn/db.go | 56 ++++++++++++++++++++++++++++++++++----- pkg/conn/kubernetes.go | 40 ++++++++++++++++++++++++++++ pkg/conn/types.go | 3 +++ pkg/ui/conn/db/create.go | 2 +- pkg/ui/conn/k8s/create.go | 2 +- pkg/ui/conn/open.go | 1 - pkg/ui/shell.go | 39 ++++++++++++++++++++------- 9 files changed, 153 insertions(+), 18 deletions(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index 3036d03..8d1570c 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -41,6 +41,12 @@ func init() { // loadConnections reads the connections from the JSON configuration file. func loadConnections() error { + if _, statErr := os.Stat(connectionsConfigFilePath); os.IsNotExist(statErr) { + if createErr := writeConnections(); createErr != nil { + return createErr + } + } + file, err := os.Open(connectionsConfigFilePath) if err != nil { return err diff --git a/pkg/conn/cloud.go b/pkg/conn/cloud.go index c44107c..5ae8507 100644 --- a/pkg/conn/cloud.go +++ b/pkg/conn/cloud.go @@ -169,6 +169,8 @@ func (c *BaseCloudConnection) FormatResultAsTable(result []byte) (string, error) return buffer.String(), nil } +var _ ConnectionInterface = &AzureConnection{} + type AzureConnection struct { BaseCloudConnection @@ -225,6 +227,26 @@ func (a *AzureConnection) GetContext() string { return context } +func (a *AzureConnection) GetFormattedContext() (string, error) { + if a.ResourceGroups == nil { + // Call SetContext to populate the resource groups. + // This is a fallback in case SetContext is not called. + if err := a.SetContext(); err != nil { + return "", fmt.Errorf("error getting context: %v", err) + } + } + + var buffer bytes.Buffer + table := tablewriter.NewWriter(&buffer) + table.SetHeader([]string{"Resource Group"}) + for _, rg := range a.ResourceGroups { + table.Append([]string{rg.Name}) + } + table.Render() + + return buffer.String(), nil +} + func NewAzureConnection(connnection *Connection) *AzureConnection { return &AzureConnection{ BaseCloudConnection: BaseCloudConnection{ diff --git a/pkg/conn/db.go b/pkg/conn/db.go index 4e6c0ae..d1bdd55 100644 --- a/pkg/conn/db.go +++ b/pkg/conn/db.go @@ -5,8 +5,8 @@ import ( "database/sql" "encoding/json" "fmt" - "regexp" + "github.com/charmbracelet/lipgloss" _ "github.com/lib/pq" "github.com/olekukonko/tablewriter" "github.com/prompt-ops/pops/pkg/ai" @@ -186,7 +186,7 @@ func (b *BaseRDBMSConnection) SetContext() error { column = AddQuotesIfNeeded(column) dataType = AddQuotesIfNeeded(dataType) - fullTableName := fmt.Sprintf(`%s."%s"`, schema, table) + fullTableName := fmt.Sprintf(`%s.%s`, schema, table) b.TablesAndColumns[fullTableName] = append(b.TablesAndColumns[fullTableName], ColumnDetail{ Name: column, DataType: dataType, @@ -201,10 +201,11 @@ func (b *BaseRDBMSConnection) SetContext() error { // AddQuotesIfNeeded adds quotes around the name if it contains capital letters. func AddQuotesIfNeeded(name string) string { - if regexp.MustCompile(`[A-Z]`).MatchString(name) { - return fmt.Sprintf(`"%s"`, name) - } - return name + // if regexp.MustCompile(`[A-Z]`).MatchString(name) { + // return fmt.Sprintf(`"%s"`, name) + // } + // return name + return fmt.Sprintf(`"%s"`, name) } // GetContext returns the tables and columns set by SetContext. @@ -218,6 +219,8 @@ func (b *BaseRDBMSConnection) GetContext() string { } context := fmt.Sprintf("%s Connection Details:\n", b.Connection.Type.GetSubtype()) + context += "Note to the AI: Please use all columns and table with double quotes as defined below.\n" + context += "Note to the AI: And please always use tables with aliases where possible.\n" context += "Database Schema:\n" // If still no tables found, return an error message. @@ -237,6 +240,45 @@ func (b *BaseRDBMSConnection) GetContext() string { return context } +// GetFormattedContext generates a pretty-printed string of the tables and columns. +func (b *BaseRDBMSConnection) GetFormattedContext() (string, error) { + if b.TablesAndColumns == nil { + // Call SetContext to populate the tables and columns. + if err := b.SetContext(); err != nil { + return "", fmt.Errorf("error getting context: %v", err) + } + } + + if len(b.TablesAndColumns) == 0 { + return "No tables found or SetContext() not called.", nil + } + + var buffer bytes.Buffer + tableStyle := lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color("240")). + Padding(1, 2). + Margin(1, 0) + + columnStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("212")) + + for tableName, columns := range b.TablesAndColumns { + var tableBuffer bytes.Buffer + tableBuffer.WriteString(fmt.Sprintf("Table: %s\n", tableName)) + tableBuffer.WriteString("Columns:\n") + for _, column := range columns { + columnContent := fmt.Sprintf("%s (%s)\n", column.Name, column.DataType) + tableBuffer.WriteString(columnStyle.Render(columnContent)) + } + tableContent := tableBuffer.String() + buffer.WriteString(tableStyle.Render(tableContent)) + buffer.WriteString("\n") + } + + return buffer.String(), nil +} + func (b *BaseRDBMSConnection) ExecuteCommand(command string) ([]byte, error) { connectionDetails, err := GetDatabaseConnectionDetails(b.Connection) if err != nil { @@ -344,6 +386,8 @@ type PostgreSQLConnection struct { BaseRDBMSConnection } +var _ ConnectionInterface = &PostgreSQLConnection{} + func NewPostgreSQLConnection(connnection *Connection) *PostgreSQLConnection { if connnection.Type.GetSubtype() != "PostgreSQL" { panic("Connection type is not PostgreSQL") diff --git a/pkg/conn/kubernetes.go b/pkg/conn/kubernetes.go index e988b4a..18193a8 100644 --- a/pkg/conn/kubernetes.go +++ b/pkg/conn/kubernetes.go @@ -90,6 +90,8 @@ func NewKubernetesConnectionImpl(connection *Connection) *KubernetesConnectionIm } } +var _ ConnectionInterface = &KubernetesConnectionImpl{} + func (k *KubernetesConnectionImpl) GetConnection() Connection { return k.Connection } @@ -168,6 +170,44 @@ func (k *KubernetesConnectionImpl) GetContext() string { return sb.String() } +func (k *KubernetesConnectionImpl) GetFormattedContext() (string, error) { + var buffer bytes.Buffer + table := tablewriter.NewWriter(&buffer) + + // Namespaces + table.SetHeader([]string{"Namespaces"}) + for _, ns := range k.Namespaces { + table.Append([]string{ns.Name}) + } + table.Render() + + // Pods + table = tablewriter.NewWriter(&buffer) + table.SetHeader([]string{"Pods", "Namespace"}) + for _, pod := range k.Pods { + table.Append([]string{pod.Name, pod.Namespace}) + } + table.Render() + + // Deployments + table = tablewriter.NewWriter(&buffer) + table.SetHeader([]string{"Deployments", "Namespace"}) + for _, dep := range k.Deployments { + table.Append([]string{dep.Name, dep.Namespace}) + } + table.Render() + + // Services + table = tablewriter.NewWriter(&buffer) + table.SetHeader([]string{"Services", "Namespace"}) + for _, svc := range k.Services { + table.Append([]string{svc.Name, svc.Namespace}) + } + table.Render() + + return buffer.String(), nil +} + func (k *KubernetesConnectionImpl) GetCommand(prompt string) (string, error) { aiModel, err := ai.NewOpenAIModel(k.CommandType(), k.GetContext()) if err != nil { diff --git a/pkg/conn/types.go b/pkg/conn/types.go index c3aafec..3e40a43 100644 --- a/pkg/conn/types.go +++ b/pkg/conn/types.go @@ -138,6 +138,9 @@ type ConnectionInterface interface { // This information will be sent to the AI model which will use it to generate the queries/commands. GetContext() string + // GetFormattedContext returns the formatted context for the AI model. + GetFormattedContext() (string, error) + // ExecuteCommand executes the given command and returns the output as byte array. ExecuteCommand(command string) ([]byte, error) diff --git a/pkg/ui/conn/db/create.go b/pkg/ui/conn/db/create.go index 8cfa64f..f311ecb 100644 --- a/pkg/ui/conn/db/create.go +++ b/pkg/ui/conn/db/create.go @@ -287,7 +287,7 @@ func (m *createModel) View() string { if m.err != nil { return clearScreen + fmt.Sprintf("❌ Error: %v\n\nPress 'q', 'esc', or Ctrl+C to quit.", m.err) } - return clearScreen + fmt.Sprintf("Saving conn... %s", m.spinner.View()) + return clearScreen + fmt.Sprintf("Saving connection... %s", m.spinner.View()) case stepCreateDone: if m.err != nil { diff --git a/pkg/ui/conn/k8s/create.go b/pkg/ui/conn/k8s/create.go index dab0ec4..7aff91f 100644 --- a/pkg/ui/conn/k8s/create.go +++ b/pkg/ui/conn/k8s/create.go @@ -269,7 +269,7 @@ func (m *createModel) View() string { return clearScreen + s case stepCreateSpinner: - return clearScreen + outputStyle.Render("Saving conn... ") + m.spinner.View() + return clearScreen + outputStyle.Render("Saving connection... ") + m.spinner.View() case stepCreateDone: if m.err != nil { diff --git a/pkg/ui/conn/open.go b/pkg/ui/conn/open.go index 19f399b..8431e65 100644 --- a/pkg/ui/conn/open.go +++ b/pkg/ui/conn/open.go @@ -97,7 +97,6 @@ func (m *openRootModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case ui.TransitionToShellMsg: - fmt.Println("Selected connection:", msg.Connection.Name) m.shellModel = ui.NewShellModel(msg.Connection) m.step = stepShell return m.shellModel, m.shellModel.Init() diff --git a/pkg/ui/shell.go b/pkg/ui/shell.go index 758e111..4483e8f 100644 --- a/pkg/ui/shell.go +++ b/pkg/ui/shell.go @@ -147,8 +147,6 @@ func (m shellModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case checkPassedMsg: m.checkPassed = true - m.step = stepShowContext - m.output = "Will be added here" m.step = stepEnterPrompt return m, textinput.Blink @@ -165,6 +163,12 @@ func (m shellModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, cmd case stepShowContext: + switch msg := msg.(type) { + case tea.KeyMsg: + if msg.Type == tea.KeyF1 { + m.step = stepEnterPrompt + } + } return m, nil case stepEnterPrompt: @@ -215,6 +219,16 @@ func (m shellModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case tea.KeyCtrlC, tea.KeyEsc: return m, tea.Quit + + case tea.KeyF1: + m.step = stepShowContext + output, err := m.popsConnection.GetFormattedContext() + if err != nil { + m.err = err + return m, nil + } + m.output = output + return m, nil } } return m, cmd @@ -322,6 +336,9 @@ func (m shellModel) View() string { case stepInitialChecks: content = m.viewInitialChecks() + case stepShowContext: + content = m.viewShowContext() + case stepEnterPrompt: content = m.viewEnterPrompt() @@ -370,7 +387,7 @@ func (m shellModel) viewEnterPrompt() string { modeStr = "answer" } - footer := "Use ←/→ to switch between modes (currently " + modeStr + "). Press Enter when ready." + footer := "Use ←/→ to switch between modes (currently " + modeStr + "). Press Enter when ready.\n\nPress F1 to show context." return fmt.Sprintf( "%s\n\n%s\n\n%s", @@ -380,6 +397,16 @@ func (m shellModel) viewEnterPrompt() string { ) } +func (m shellModel) viewShowContext() string { + footer := "Press F1 to return to prompt." + + return fmt.Sprintf( + "%s\n\n%s", + titleStyle.Render("ℹ️ Current Context"), + outputStyle.Render(m.output), + ) + "\n\n" + lipgloss.NewStyle().Foreground(lipgloss.Color("8")).Render(footer) +} + func (m shellModel) viewGenerateCommand() string { return titleStyle.Render("🤖 Generating command...") } @@ -498,17 +525,11 @@ func (m shellModel) runCommand(command string) tea.Cmd { return errMsg{err} } - fmt.Println("Output:") - fmt.Println(string(out)) - outStr, err := m.popsConnection.FormatResultAsTable(out) if err != nil { return errMsg{err} } - fmt.Println("Formatted Output:") - fmt.Println(outStr) - return outputMsg{ output: outStr, }