Skip to content

Commit

Permalink
Show Context feature in Shell
Browse files Browse the repository at this point in the history
Signed-off-by: ytimocin <ytimocin@microsoft.com>
  • Loading branch information
ytimocin committed Jan 20, 2025
1 parent 196517a commit a57f819
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 18 deletions.
6 changes: 6 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ func init() {

// loadConnections reads the connections from the JSON configuration file.
func loadConnections() error {
if _, statErr := os.Stat(connectionsConfigFilePath); os.IsNotExist(statErr) {
if createErr := writeConnections(); createErr != nil {
return createErr
}
}

file, err := os.Open(connectionsConfigFilePath)
if err != nil {
return err
Expand Down
22 changes: 22 additions & 0 deletions pkg/conn/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ func (c *BaseCloudConnection) FormatResultAsTable(result []byte) (string, error)
return buffer.String(), nil
}

var _ ConnectionInterface = &AzureConnection{}

type AzureConnection struct {
BaseCloudConnection

Expand Down Expand Up @@ -225,6 +227,26 @@ func (a *AzureConnection) GetContext() string {
return context
}

func (a *AzureConnection) GetFormattedContext() (string, error) {
if a.ResourceGroups == nil {
// Call SetContext to populate the resource groups.
// This is a fallback in case SetContext is not called.
if err := a.SetContext(); err != nil {
return "", fmt.Errorf("error getting context: %v", err)
}
}

var buffer bytes.Buffer
table := tablewriter.NewWriter(&buffer)
table.SetHeader([]string{"Resource Group"})
for _, rg := range a.ResourceGroups {
table.Append([]string{rg.Name})
}
table.Render()

return buffer.String(), nil
}

func NewAzureConnection(connnection *Connection) *AzureConnection {
return &AzureConnection{
BaseCloudConnection: BaseCloudConnection{
Expand Down
56 changes: 50 additions & 6 deletions pkg/conn/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"database/sql"
"encoding/json"
"fmt"
"regexp"

"github.com/charmbracelet/lipgloss"
_ "github.com/lib/pq"
"github.com/olekukonko/tablewriter"
"github.com/prompt-ops/pops/pkg/ai"
Expand Down Expand Up @@ -186,7 +186,7 @@ func (b *BaseRDBMSConnection) SetContext() error {
column = AddQuotesIfNeeded(column)
dataType = AddQuotesIfNeeded(dataType)

fullTableName := fmt.Sprintf(`%s."%s"`, schema, table)
fullTableName := fmt.Sprintf(`%s.%s`, schema, table)
b.TablesAndColumns[fullTableName] = append(b.TablesAndColumns[fullTableName], ColumnDetail{
Name: column,
DataType: dataType,
Expand All @@ -201,10 +201,11 @@ func (b *BaseRDBMSConnection) SetContext() error {

// AddQuotesIfNeeded adds quotes around the name if it contains capital letters.
func AddQuotesIfNeeded(name string) string {
if regexp.MustCompile(`[A-Z]`).MatchString(name) {
return fmt.Sprintf(`"%s"`, name)
}
return name
// if regexp.MustCompile(`[A-Z]`).MatchString(name) {
// return fmt.Sprintf(`"%s"`, name)
// }
// return name
return fmt.Sprintf(`"%s"`, name)
}

// GetContext returns the tables and columns set by SetContext.
Expand All @@ -218,6 +219,8 @@ func (b *BaseRDBMSConnection) GetContext() string {
}

context := fmt.Sprintf("%s Connection Details:\n", b.Connection.Type.GetSubtype())
context += "Note to the AI: Please use all columns and table with double quotes as defined below.\n"
context += "Note to the AI: And please always use tables with aliases where possible.\n"
context += "Database Schema:\n"

// If still no tables found, return an error message.
Expand All @@ -237,6 +240,45 @@ func (b *BaseRDBMSConnection) GetContext() string {
return context
}

// GetFormattedContext generates a pretty-printed string of the tables and columns.
func (b *BaseRDBMSConnection) GetFormattedContext() (string, error) {
if b.TablesAndColumns == nil {
// Call SetContext to populate the tables and columns.
if err := b.SetContext(); err != nil {
return "", fmt.Errorf("error getting context: %v", err)
}
}

if len(b.TablesAndColumns) == 0 {
return "No tables found or SetContext() not called.", nil
}

var buffer bytes.Buffer
tableStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("240")).
Padding(1, 2).
Margin(1, 0)

columnStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("212"))

for tableName, columns := range b.TablesAndColumns {
var tableBuffer bytes.Buffer
tableBuffer.WriteString(fmt.Sprintf("Table: %s\n", tableName))
tableBuffer.WriteString("Columns:\n")
for _, column := range columns {
columnContent := fmt.Sprintf("%s (%s)\n", column.Name, column.DataType)
tableBuffer.WriteString(columnStyle.Render(columnContent))
}
tableContent := tableBuffer.String()
buffer.WriteString(tableStyle.Render(tableContent))
buffer.WriteString("\n")
}

return buffer.String(), nil
}

func (b *BaseRDBMSConnection) ExecuteCommand(command string) ([]byte, error) {
connectionDetails, err := GetDatabaseConnectionDetails(b.Connection)
if err != nil {
Expand Down Expand Up @@ -344,6 +386,8 @@ type PostgreSQLConnection struct {
BaseRDBMSConnection
}

var _ ConnectionInterface = &PostgreSQLConnection{}

func NewPostgreSQLConnection(connnection *Connection) *PostgreSQLConnection {
if connnection.Type.GetSubtype() != "PostgreSQL" {
panic("Connection type is not PostgreSQL")
Expand Down
40 changes: 40 additions & 0 deletions pkg/conn/kubernetes.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ func NewKubernetesConnectionImpl(connection *Connection) *KubernetesConnectionIm
}
}

var _ ConnectionInterface = &KubernetesConnectionImpl{}

func (k *KubernetesConnectionImpl) GetConnection() Connection {
return k.Connection
}
Expand Down Expand Up @@ -168,6 +170,44 @@ func (k *KubernetesConnectionImpl) GetContext() string {
return sb.String()
}

func (k *KubernetesConnectionImpl) GetFormattedContext() (string, error) {
var buffer bytes.Buffer
table := tablewriter.NewWriter(&buffer)

// Namespaces
table.SetHeader([]string{"Namespaces"})
for _, ns := range k.Namespaces {
table.Append([]string{ns.Name})
}
table.Render()

// Pods
table = tablewriter.NewWriter(&buffer)
table.SetHeader([]string{"Pods", "Namespace"})
for _, pod := range k.Pods {
table.Append([]string{pod.Name, pod.Namespace})
}
table.Render()

// Deployments
table = tablewriter.NewWriter(&buffer)
table.SetHeader([]string{"Deployments", "Namespace"})
for _, dep := range k.Deployments {
table.Append([]string{dep.Name, dep.Namespace})
}
table.Render()

// Services
table = tablewriter.NewWriter(&buffer)
table.SetHeader([]string{"Services", "Namespace"})
for _, svc := range k.Services {
table.Append([]string{svc.Name, svc.Namespace})
}
table.Render()

return buffer.String(), nil
}

func (k *KubernetesConnectionImpl) GetCommand(prompt string) (string, error) {
aiModel, err := ai.NewOpenAIModel(k.CommandType(), k.GetContext())
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions pkg/conn/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ type ConnectionInterface interface {
// This information will be sent to the AI model which will use it to generate the queries/commands.
GetContext() string

// GetFormattedContext returns the formatted context for the AI model.
GetFormattedContext() (string, error)

// ExecuteCommand executes the given command and returns the output as byte array.
ExecuteCommand(command string) ([]byte, error)

Expand Down
2 changes: 1 addition & 1 deletion pkg/ui/conn/db/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ func (m *createModel) View() string {
if m.err != nil {
return clearScreen + fmt.Sprintf("❌ Error: %v\n\nPress 'q', 'esc', or Ctrl+C to quit.", m.err)
}
return clearScreen + fmt.Sprintf("Saving conn... %s", m.spinner.View())
return clearScreen + fmt.Sprintf("Saving connection... %s", m.spinner.View())

case stepCreateDone:
if m.err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/ui/conn/k8s/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func (m *createModel) View() string {
return clearScreen + s

case stepCreateSpinner:
return clearScreen + outputStyle.Render("Saving conn... ") + m.spinner.View()
return clearScreen + outputStyle.Render("Saving connection... ") + m.spinner.View()

case stepCreateDone:
if m.err != nil {
Expand Down
1 change: 0 additions & 1 deletion pkg/ui/conn/open.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ func (m *openRootModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {

case ui.TransitionToShellMsg:
fmt.Println("Selected connection:", msg.Connection.Name)
m.shellModel = ui.NewShellModel(msg.Connection)
m.step = stepShell
return m.shellModel, m.shellModel.Init()
Expand Down
39 changes: 30 additions & 9 deletions pkg/ui/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,6 @@ func (m shellModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {

case checkPassedMsg:
m.checkPassed = true
m.step = stepShowContext
m.output = "Will be added here"
m.step = stepEnterPrompt
return m, textinput.Blink

Expand All @@ -165,6 +163,12 @@ func (m shellModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, cmd

case stepShowContext:
switch msg := msg.(type) {
case tea.KeyMsg:
if msg.Type == tea.KeyF1 {
m.step = stepEnterPrompt
}
}
return m, nil

case stepEnterPrompt:
Expand Down Expand Up @@ -215,6 +219,16 @@ func (m shellModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {

case tea.KeyCtrlC, tea.KeyEsc:
return m, tea.Quit

case tea.KeyF1:
m.step = stepShowContext
output, err := m.popsConnection.GetFormattedContext()
if err != nil {
m.err = err
return m, nil
}
m.output = output
return m, nil
}
}
return m, cmd
Expand Down Expand Up @@ -322,6 +336,9 @@ func (m shellModel) View() string {
case stepInitialChecks:
content = m.viewInitialChecks()

case stepShowContext:
content = m.viewShowContext()

case stepEnterPrompt:
content = m.viewEnterPrompt()

Expand Down Expand Up @@ -370,7 +387,7 @@ func (m shellModel) viewEnterPrompt() string {
modeStr = "answer"
}

footer := "Use ←/→ to switch between modes (currently " + modeStr + "). Press Enter when ready."
footer := "Use ←/→ to switch between modes (currently " + modeStr + "). Press Enter when ready.\n\nPress F1 to show context."

return fmt.Sprintf(
"%s\n\n%s\n\n%s",
Expand All @@ -380,6 +397,16 @@ func (m shellModel) viewEnterPrompt() string {
)
}

func (m shellModel) viewShowContext() string {
footer := "Press F1 to return to prompt."

return fmt.Sprintf(
"%s\n\n%s",
titleStyle.Render("ℹ️ Current Context"),
outputStyle.Render(m.output),
) + "\n\n" + lipgloss.NewStyle().Foreground(lipgloss.Color("8")).Render(footer)
}

func (m shellModel) viewGenerateCommand() string {
return titleStyle.Render("🤖 Generating command...")
}
Expand Down Expand Up @@ -498,17 +525,11 @@ func (m shellModel) runCommand(command string) tea.Cmd {
return errMsg{err}
}

fmt.Println("Output:")
fmt.Println(string(out))

outStr, err := m.popsConnection.FormatResultAsTable(out)
if err != nil {
return errMsg{err}
}

fmt.Println("Formatted Output:")
fmt.Println(outStr)

return outputMsg{
output: outStr,
}
Expand Down

0 comments on commit a57f819

Please sign in to comment.