Skip to content

Commit

Permalink
New provider: rds (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
fubarhouse authored Aug 30, 2024
1 parent 2cbe98e commit 416a3dc
Show file tree
Hide file tree
Showing 18 changed files with 383 additions and 144 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/packages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ on:
push:
branches:
- 'main'
- 's3-export-import'
tags:
- 'v[0-9]+.[0-9]+.[0-9]+[0-9A-Za-z]?'
- 'v[0-9]+.[0-9]+.[0-9]+*'

env:
REGISTRY: ghcr.io
Expand Down
18 changes: 11 additions & 7 deletions cmd/mtk/dump/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/spf13/cobra"

"github.com/skpr/mtk/internal/mysql"
"github.com/skpr/mtk/internal/mysql/provider"
"github.com/skpr/mtk/pkg/config"
"github.com/skpr/mtk/pkg/envar"
)
Expand All @@ -31,17 +32,19 @@ const cmdExample = `
# List all database tables and dump each table to a file.
mtk table list <database> | xargs -I {} sh -c "mtk dump <database> '{}' > '{}.sql'"`

// Options is the commandline options for 'config' sub command
// Options is the commandline options for 'dump' sub command
type Options struct {
ConfigFile string
ExtendedInsertRows int
}

// NewOptions will return a new Options.
func NewOptions() Options {
return Options{}
}

func NewCommand(conn *mysql.Connection) *cobra.Command {
// NewCommand will return a new Cobra command.
func NewCommand(conn *mysql.Connection, provider, rdsRegion, rdsS3uri string) *cobra.Command {
o := NewOptions()

cmd := &cobra.Command{
Expand All @@ -68,7 +71,7 @@ func NewCommand(conn *mysql.Connection) *cobra.Command {
panic(err)
}

if err := o.Run(os.Stdout, logger, conn, database, table, cfg); err != nil {
if err := o.Run(os.Stdout, logger, conn, database, table, provider, rdsRegion, rdsS3uri, cfg); err != nil {
panic(err)
}
},
Expand All @@ -80,15 +83,16 @@ func NewCommand(conn *mysql.Connection) *cobra.Command {
return cmd
}

func (o *Options) Run(w io.Writer, logger *log.Logger, conn *mysql.Connection, database, table string, cfg config.Rules) error {
// Run will execute the dump command.
func (o *Options) Run(w io.Writer, logger *log.Logger, conn *mysql.Connection, database, table, provider, region, uri string, cfg config.Rules) error {
db, err := conn.Open(database)
if err != nil {
return fmt.Errorf("failed to open database connection: %w", err)
}

defer db.Close()

client := mysql.NewClient(db, logger)
client := mysql.NewClient(db, logger, provider, region, uri)

if table != "" {
return o.runDumpTable(w, client, table, cfg)
Expand All @@ -110,7 +114,7 @@ func (o *Options) runDumpTables(w io.Writer, client *mysql.Client, cfg config.Ru
return err
}

params := mysql.DumpParams{
params := provider.DumpParams{
ExtendedInsertRows: o.ExtendedInsertRows,
}

Expand Down Expand Up @@ -154,7 +158,7 @@ func (o *Options) runDumpTables(w io.Writer, client *mysql.Client, cfg config.Ru
//
// eg. runDumpTables has to perform ListTablesByGlobal for each table, which is slow.
func (o *Options) runDumpTable(w io.Writer, client *mysql.Client, table string, cfg config.Rules) error {
params := mysql.DumpParams{
params := provider.DumpParams{
ExtendedInsertRows: o.ExtendedInsertRows,
}

Expand Down
15 changes: 13 additions & 2 deletions cmd/mtk/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@ import (
"github.com/skpr/mtk/pkg/envar"
)

var conn = new(mysql.Connection)
var (
conn = new(mysql.Connection)
providerName string
rdsRegion string
rdsS3Uri string
)

const cmdExample = `
export MTK_HOSTNAME=localhost
Expand Down Expand Up @@ -51,6 +56,12 @@ func init() {
cmd.PersistentFlags().StringVar(&conn.Protocol, "protocol", envar.GetStringWithFallback("tcp", envar.Protocol, envar.MySQLProtocol), "Connection protocol to use when connecting to MySQL instance")
cmd.PersistentFlags().Int32Var(&conn.Port, "port", int32(envar.GetIntWithFallback(3306, envar.Port, envar.MySQLPort)), "Port to connect to the MySQL instance on")
cmd.PersistentFlags().IntVar(&conn.MaxConn, "max-conn", envar.GetIntWithFallback(50, envar.MaxConn), "Sets the maximum number of open connections to the database")

cmd.PersistentFlags().StringVar(&providerName, "provider", envar.GetStringWithFallback("stdout", envar.Provider), "The provider to use (either 'stdout' or 'rds')")

// RDS Provider Flags.
cmd.PersistentFlags().StringVar(&rdsRegion, "rds-region", envar.GetStringWithFallback("", envar.RDSRegion), "The AWS region to use for S3 when connecting to the MySQL RDS instance")
cmd.PersistentFlags().StringVar(&rdsS3Uri, "rds-s3-uri", envar.GetStringWithFallback("", envar.RDSS3Uri), "The S3 URI to use for exporting to S3 when exporting data from the MySQL RDS instance")
}

func main() {
Expand All @@ -68,7 +79,7 @@ func main() {
usageTemplate = re.ReplaceAllLiteralString(usageTemplate, `{{StyleHeading "Flags:"}}`)
cmd.SetUsageTemplate(usageTemplate)

cmd.AddCommand(dump.NewCommand(conn))
cmd.AddCommand(dump.NewCommand(conn, providerName, rdsRegion, rdsS3Uri))
cmd.AddCommand(table.NewCommand(conn))

if err := cmd.Execute(); err != nil {
Expand Down
1 change: 1 addition & 0 deletions cmd/mtk/table/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ const cmdExample = `
# List all database tables.
mtk table list <database>`

// NewCommand will return a new Cobra command for the 'table' sub command.
func NewCommand(conn *mysql.Connection) *cobra.Command {
cmd := &cobra.Command{
Use: "table",
Expand Down
4 changes: 2 additions & 2 deletions cmd/mtk/table/list/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ const cmdExample = `
# List all database tables and dump each table to a file.
mtk table list <database> | xargs -I {} sh -c "mtk dump <database> '{}' > '{}.sql'"`

// Options is the commandline options for 'config' sub command
// Options is the commandline options for 'list' sub command
type Options struct {
ConfigFile string
}
Expand Down Expand Up @@ -84,7 +84,7 @@ func (o *Options) Run(logger *log.Logger, conn *mysql.Connection, database strin

defer db.Close()

client := mysql.NewClient(db, logger)
client := mysql.NewClient(db, logger, "", "", "")

tables, err := client.QueryTables()
if err != nil {
Expand Down
34 changes: 20 additions & 14 deletions internal/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"log"

"github.com/go-sql-driver/mysql"

"github.com/skpr/mtk/internal/mysql/provider"
)

const (
Expand All @@ -16,6 +18,7 @@ const (
OperationNoData = "nodata"
)

// Connection is a struct containing metadata for the database connection.
type Connection struct {
Hostname string
Username string
Expand All @@ -25,6 +28,7 @@ type Connection struct {
MaxConn int
}

// Open will open a new database connection.
func (o Connection) Open(database string) (*sql.DB, error) {
cfg := mysql.Config{
User: o.Username,
Expand All @@ -49,29 +53,31 @@ func (o Connection) Open(database string) (*sql.DB, error) {
type Client struct {
DB *sql.DB
Logger *log.Logger

// A field for caching a list of tables for this database.
cachedTables []string

// Provider configuration.
Provider string
// For the AWS RDS Provider, specify the AWS Region.
Region string
// For the AWS RDS Provider, specify the S3 URI.
URI string
}

// NewClient for dumping a full or single table from a database.
func NewClient(db *sql.DB, logger *log.Logger) *Client {
func NewClient(db *sql.DB, logger *log.Logger, provider, region, uri string) *Client {
return &Client{
DB: db,
Logger: logger,
DB: db,
Logger: logger,
Provider: provider,
Region: region,
URI: uri,
}
}

// DumpParams is used to pass parameters to the Dump function.
type DumpParams struct {
SelectMap map[string]map[string]string
WhereMap map[string]string
FilterMap map[string]string
UseTableLock bool
ExtendedInsertRows int
}

// DumpTables will write all table data to a single writer.
func (d *Client) DumpTables(w io.Writer, params DumpParams) error {
func (d *Client) DumpTables(w io.Writer, params provider.DumpParams) error {
if err := d.WriteHeader(w); err != nil {
return fmt.Errorf("failed to write header: %w", err)
}
Expand All @@ -92,7 +98,7 @@ func (d *Client) DumpTables(w io.Writer, params DumpParams) error {
}

// DumpTable is convenient if you wish to coordinate a dump eg. Single file per table.
func (d *Client) DumpTable(w io.Writer, table string, params DumpParams) error {
func (d *Client) DumpTable(w io.Writer, table string, params provider.DumpParams) error {
if err := d.WriteHeader(w); err != nil {
return fmt.Errorf("failed to write header: %w", err)
}
Expand Down
15 changes: 15 additions & 0 deletions internal/mysql/provider/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package provider

// Interface describes the functionality a dump provider is required to supply.
type Interface interface {
// GetSelectQueryForTable returns the SELECT statement used to read the
// given table, built from the supplied dump parameters, or an error if the
// statement could not be built.
GetSelectQueryForTable(table string, params DumpParams) (string, error)
}

// DumpParams is used to pass parameters to the Dump function.
type DumpParams struct {
// SelectMap is keyed by table name, then by column name; the value is the
// SQL expression selected in place of that column (e.g. "NOW()" is emitted
// as "NOW() AS `c2`" in the rds provider test).
SelectMap map[string]map[string]string
// WhereMap is keyed by lower-cased table name; the value is appended to the
// table's SELECT statement as its WHERE condition.
WhereMap map[string]string
// NOTE(review): FilterMap is not read by the provider code visible here;
// presumably per-table filter rules — confirm against the dump client.
FilterMap map[string]string
// NOTE(review): presumably controls whether tables are locked while they
// are dumped — confirm against the dump client.
UseTableLock bool
// NOTE(review): presumably the number of rows grouped into each extended
// INSERT statement — confirm against the dump client.
ExtendedInsertRows int
}
74 changes: 74 additions & 0 deletions internal/mysql/provider/rds/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package rds

import (
"database/sql"
"fmt"
"log"
"strings"

"github.com/skpr/mtk/internal/mysql/provider"
providerutils "github.com/skpr/mtk/internal/mysql/provider/utils"
)

// Client builds the SQL statements used to export a table to S3
// (SELECT ... INTO OUTFILE S3) and to load it back (LOAD DATA FROM S3).
type Client struct {
// Embedded provider interface. NewClient leaves it nil; Client supplies its
// own GetSelectQueryForTable method.
provider.Interface
// DB is the connection passed to the column lookup when building queries.
DB *sql.DB
// Logger is stored by NewClient; it is not used by the methods in this file.
Logger *log.Logger

Region string // AWS region, used in the "S3-<region>://" manifest location of the load query
URI string // S3 URI under which "<table>.csv" exports are written
}

// NewClient returns a Client which builds S3 export and load queries using
// the given database connection, logger, AWS region and S3 URI.
func NewClient(db *sql.DB, logger *log.Logger, region, uri string) *Client {
	client := new(Client)
	client.DB = db
	client.Logger = logger
	client.Region = region
	client.URI = uri
	return client
}

// GetSelectQueryForTable will return a complete SELECT ... INTO OUTFILE S3
// query which exports the given table to "<URI>/<table>.csv".
//
// The column list comes from providerutils.QueryColumnsForTable and, when
// params.WhereMap holds an entry for the lower-cased table name, that entry is
// used as the WHERE condition. The matching LOAD DATA FROM S3 query is printed
// to standard output as a side effect; if it cannot be built (empty table or
// badly configured region) the error is returned and no query is produced.
func (d *Client) GetSelectQueryForTable(table string, params provider.DumpParams) (string, error) {
	columns, err := providerutils.QueryColumnsForTable(d.DB, table, params)
	if err != nil {
		return "", err
	}

	var b strings.Builder
	fmt.Fprintf(&b, "SELECT %s", strings.Join(columns, ", "))
	fmt.Fprintf(&b, " FROM `%s`", table)

	if condition, found := params.WhereMap[strings.ToLower(table)]; found {
		fmt.Fprintf(&b, " WHERE %s", condition)
	}

	// Export options: CSV formatting, plus a manifest for the later import.
	fmt.Fprintf(&b, " INTO OUTFILE S3 '%s/%s.csv'", d.URI, table)
	b.WriteString(" FIELDS TERMINATED BY ',' ENCLOSED BY '\"' LINES TERMINATED BY '\\n'")
	b.WriteString(" MANIFEST ON")
	b.WriteString(" OVERWRITE ON")

	loadQuery, err := d.GetLoadQueryForTable(table)
	if err != nil {
		return "", err
	}

	// Emit the import statement alongside the export query.
	fmt.Println(loadQuery)
	return b.String(), nil
}

// GetLoadQueryForTable will return a complete LOAD DATA FROM S3 query which
// imports the given table from its S3 manifest.
//
// The manifest location is "S3-<Region>://<path>/<table>.csv.manifest", where
// <path> is the client's URI with any leading "s3://" removed. This mirrors
// the "<URI>/<table>.csv" target used by GetSelectQueryForTable.
//
// An error is returned when table is empty, or when Region does not look like
// an AWS region name (see isRegionConfigured).
func (d *Client) GetLoadQueryForTable(table string) (string, error) {
	if table == "" {
		return "", fmt.Errorf("error: no table specified")
	}
	if !isRegionConfigured(d.Region) {
		return "", fmt.Errorf("error: region is not configured correctly")
	}

	path := strings.TrimPrefix(d.URI, "s3://")
	query := fmt.Sprintf("LOAD DATA FROM S3 MANIFEST 'S3-%s://%s/%s.csv.manifest' INTO TABLE `%s`", d.Region, path, table, table)
	query = fmt.Sprintf("%s FIELDS TERMINATED BY ',' ENCLOSED BY '\"' LINES TERMINATED BY '\\n'", query)

	return query, nil
}

// isRegionConfigured reports whether region is shaped like an AWS region name:
// at least three non-empty, dash separated segments. Most regions have exactly
// three (e.g. "ap-southeast-2"), but some have more (e.g. "us-gov-west-1"), so
// an exact count of three would wrongly reject them.
func isRegionConfigured(region string) bool {
	segments := strings.Split(region, "-")
	if len(segments) < 3 {
		return false
	}
	for _, segment := range segments {
		if segment == "" {
			return false
		}
	}
	return true
}
35 changes: 35 additions & 0 deletions internal/mysql/provider/rds/provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package rds

import (
"log"
"os"
"testing"

"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"

"github.com/skpr/mtk/internal/mysql/mock"
"github.com/skpr/mtk/internal/mysql/provider"
)

// TestMySQLGetExportSelectQueryFor checks the SELECT ... INTO OUTFILE S3 query
// built for a table which has both a column override and a WHERE condition.
func TestMySQLGetExportSelectQueryFor(t *testing.T) {
	// The local is named sqlMock so that it does not shadow the mock package.
	db, sqlMock := mock.GetDB(t)
	dumper := NewClient(db, log.New(os.Stdout, "", 0), "ap-southeast-2", "s3://path/to/bucket")

	// Column discovery query expected by the provider.
	sqlMock.ExpectQuery("SELECT \\* FROM `table` LIMIT 1").WillReturnRows(
		sqlmock.NewRows([]string{"c1", "c2"}).AddRow("a", "b"))

	query, err := dumper.GetSelectQueryForTable("table", provider.DumpParams{
		SelectMap: map[string]map[string]string{"table": {"c2": "NOW()"}},
		WhereMap:  map[string]string{"table": "c1 > 0"},
	})
	assert.NoError(t, err)
	assert.Equal(t, "SELECT `c1`, NOW() AS `c2` FROM `table` WHERE c1 > 0 INTO OUTFILE S3 's3://path/to/bucket/table.csv' FIELDS TERMINATED BY ',' ENCLOSED BY '\"' LINES TERMINATED BY '\\n' MANIFEST ON OVERWRITE ON", query)
}

// TestMySQLGetLoadQueryFor checks the LOAD DATA FROM S3 query built for a
// table from the configured region and S3 URI.
func TestMySQLGetLoadQueryFor(t *testing.T) {
	const want = "LOAD DATA FROM S3 MANIFEST 'S3-ap-southeast-4://path/to/bucket/table_name.csv.manifest' INTO TABLE `table_name` FIELDS TERMINATED BY ',' ENCLOSED BY '\"' LINES TERMINATED BY '\\n'"

	db, _ := mock.GetDB(t)
	logger := log.New(os.Stdout, "", 0)
	client := NewClient(db, logger, "ap-southeast-4", "s3://path/to/bucket")

	got, err := client.GetLoadQueryForTable("table_name")
	assert.Nil(t, err)
	assert.Equal(t, want, got)
}
42 changes: 42 additions & 0 deletions internal/mysql/provider/stdout/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package stdout

import (
"database/sql"
"fmt"
"log"
"strings"

"github.com/skpr/mtk/internal/mysql/provider"
providerutils "github.com/skpr/mtk/internal/mysql/provider/utils"
)

// Client builds the plain SELECT statements used to read a table's rows.
type Client struct {
// Embedded provider interface. NewClient leaves it nil; Client supplies its
// own GetSelectQueryForTable method.
provider.Interface
// DB is the connection passed to the column lookup when building queries.
DB *sql.DB
// Logger is stored by NewClient; it is not used by the methods in this file.
Logger *log.Logger
}

// NewClient returns a Client which builds SELECT queries using the given
// database connection and logger.
func NewClient(db *sql.DB, logger *log.Logger) *Client {
	client := new(Client)
	client.DB = db
	client.Logger = logger
	return client
}

// GetSelectQueryForTable will return a complete SELECT query to fetch data
// from a table.
//
// The column list comes from providerutils.QueryColumnsForTable. When
// params.WhereMap holds an entry for the lower-cased table name, that entry is
// appended as the WHERE condition.
func (d *Client) GetSelectQueryForTable(table string, params provider.DumpParams) (string, error) {
	columns, err := providerutils.QueryColumnsForTable(d.DB, table, params)
	if err != nil {
		return "", err
	}

	columnList := strings.Join(columns, ", ")
	query := fmt.Sprintf("SELECT %s FROM `%s`", columnList, table)

	condition, found := params.WhereMap[strings.ToLower(table)]
	if !found {
		return query, nil
	}

	return fmt.Sprintf("%s WHERE %s", query, condition), nil
}
Loading

0 comments on commit 416a3dc

Please sign in to comment.