diff --git a/.github/workflows/packages.yml b/.github/workflows/packages.yml index 5702192..6af6a0e 100644 --- a/.github/workflows/packages.yml +++ b/.github/workflows/packages.yml @@ -4,8 +4,9 @@ on: push: branches: - 'main' + - 's3-export-import' tags: - - 'v[0-9]+.[0-9]+.[0-9]+[0-9A-Za-z]?' + - 'v[0-9]+.[0-9]+.[0-9]+*' env: REGISTRY: ghcr.io diff --git a/cmd/mtk/dump/command.go b/cmd/mtk/dump/command.go index 96a9ed1..1414ac5 100644 --- a/cmd/mtk/dump/command.go +++ b/cmd/mtk/dump/command.go @@ -10,6 +10,7 @@ import ( "github.com/spf13/cobra" "github.com/skpr/mtk/internal/mysql" + "github.com/skpr/mtk/internal/mysql/provider" "github.com/skpr/mtk/pkg/config" "github.com/skpr/mtk/pkg/envar" ) @@ -31,17 +32,19 @@ const cmdExample = ` # List all database tables and dump each table to a file. mtk table list | xargs -I {} sh -c "mtk dump '{}' > '{}.sql'"` -// Options is the commandline options for 'config' sub command +// Options is the commandline options for 'dump' sub command type Options struct { ConfigFile string ExtendedInsertRows int } +// NewOptions will return a new Options. func NewOptions() Options { return Options{} } -func NewCommand(conn *mysql.Connection) *cobra.Command { +// NewCommand will return a new Cobra command. +func NewCommand(conn *mysql.Connection, provider, rdsRegion, rdsS3uri string) *cobra.Command { o := NewOptions() cmd := &cobra.Command{ @@ -68,7 +71,7 @@ func NewCommand(conn *mysql.Connection) *cobra.Command { panic(err) } - if err := o.Run(os.Stdout, logger, conn, database, table, cfg); err != nil { + if err := o.Run(os.Stdout, logger, conn, database, table, provider, rdsRegion, rdsS3uri, cfg); err != nil { panic(err) } }, @@ -80,7 +83,8 @@ func NewCommand(conn *mysql.Connection) *cobra.Command { return cmd } -func (o *Options) Run(w io.Writer, logger *log.Logger, conn *mysql.Connection, database, table string, cfg config.Rules) error { +// Run will execute the dump command. +func (o *Options) Run(w io.Writer, logger *log.Logger, conn *mysql.Connection, database, table, provider, region, uri string, cfg config.Rules) error { db, err := conn.Open(database) if err != nil { return fmt.Errorf("failed to open database connection: %w", err) @@ -88,7 +92,7 @@ func (o *Options) Run(w io.Writer, logger *log.Logger, conn *mysql.Connection, d defer db.Close() - client := mysql.NewClient(db, logger) + client := mysql.NewClient(db, logger, provider, region, uri) if table != "" { return o.runDumpTable(w, client, table, cfg) @@ -110,7 +114,7 @@ func (o *Options) runDumpTables(w io.Writer, client *mysql.Client, cfg config.Ru return err } - params := mysql.DumpParams{ + params := provider.DumpParams{ ExtendedInsertRows: o.ExtendedInsertRows, } @@ -154,7 +158,7 @@ func (o *Options) runDumpTables(w io.Writer, client *mysql.Client, cfg config.Ru // // eg. runDumpTables has to perform ListTablesByGlobal for each table, which is slow. func (o *Options) runDumpTable(w io.Writer, client *mysql.Client, table string, cfg config.Rules) error { - params := mysql.DumpParams{ + params := provider.DumpParams{ ExtendedInsertRows: o.ExtendedInsertRows, } diff --git a/cmd/mtk/main.go b/cmd/mtk/main.go index 42f7b3c..bb7091f 100644 --- a/cmd/mtk/main.go +++ b/cmd/mtk/main.go @@ -15,7 +15,12 @@ import ( "github.com/skpr/mtk/pkg/envar" ) -var conn = new(mysql.Connection) +var ( + conn = new(mysql.Connection) + providerName string + rdsRegion string + rdsS3Uri string +) const cmdExample = ` export MTK_HOSTNAME=localhost @@ -51,6 +56,12 @@ func init() { cmd.PersistentFlags().StringVar(&conn.Protocol, "protocol", envar.GetStringWithFallback("tcp", envar.Protocol, envar.MySQLProtocol), "Connection protocol to use when connecting to MySQL instance") cmd.PersistentFlags().Int32Var(&conn.Port, "port", int32(envar.GetIntWithFallback(3306, envar.Port, envar.MySQLPort)), "Port to connect to the MySQL instance on") cmd.PersistentFlags().IntVar(&conn.MaxConn, "max-conn", envar.GetIntWithFallback(50, envar.MaxConn), "Sets the maximum number of open connections to the database") + + cmd.PersistentFlags().StringVar(&providerName, "provider", envar.GetStringWithFallback("stdout", envar.Provider), "The provider to use (either 'stdout' or 'rds')") + + // RDS Provider Flags. + cmd.PersistentFlags().StringVar(&rdsRegion, "rds-region", envar.GetStringWithFallback("", envar.RDSRegion), "The AWS region to use for S3 when connecting to the MySQL RDS instance") + cmd.PersistentFlags().StringVar(&rdsS3Uri, "rds-s3-uri", envar.GetStringWithFallback("", envar.RDSS3Uri), "The S3 URI to use for exporting to S3 when exporting data from the MySQL RDS instance") } func main() { @@ -68,7 +79,7 @@ func main() { usageTemplate = re.ReplaceAllLiteralString(usageTemplate, `{{StyleHeading "Flags:"}}`) cmd.SetUsageTemplate(usageTemplate) - cmd.AddCommand(dump.NewCommand(conn)) + cmd.AddCommand(dump.NewCommand(conn, providerName, rdsRegion, rdsS3Uri)) cmd.AddCommand(table.NewCommand(conn)) if err := cmd.Execute(); err != nil { diff --git a/cmd/mtk/table/command.go b/cmd/mtk/table/command.go index bc23e0b..17a1290 100644 --- a/cmd/mtk/table/command.go +++ b/cmd/mtk/table/command.go @@ -14,6 +14,7 @@ const cmdExample = ` # List all database tables. mtk table list ` +// NewCommand will execute the table command. func NewCommand(conn *mysql.Connection) *cobra.Command { cmd := &cobra.Command{ Use: "table", diff --git a/cmd/mtk/table/list/command.go b/cmd/mtk/table/list/command.go index 340b114..401db62 100644 --- a/cmd/mtk/table/list/command.go +++ b/cmd/mtk/table/list/command.go @@ -32,7 +32,7 @@ const cmdExample = ` # List all database tables and dump each table to a file. mtk table list | xargs -I {} sh -c "mtk dump '{}' > '{}.sql'"` -// Options is the commandline options for 'config' sub command +// Options is the commandline options for 'list' sub command type Options struct { ConfigFile string } @@ -84,7 +84,7 @@ func (o *Options) Run(logger *log.Logger, conn *mysql.Connection, database strin defer db.Close() - client := mysql.NewClient(db, logger) + client := mysql.NewClient(db, logger, "", "", "") tables, err := client.QueryTables() if err != nil { diff --git a/internal/mysql/mysql.go b/internal/mysql/mysql.go index 475df62..7a433ba 100644 --- a/internal/mysql/mysql.go +++ b/internal/mysql/mysql.go @@ -7,6 +7,8 @@ import ( "log" "github.com/go-sql-driver/mysql" + + "github.com/skpr/mtk/internal/mysql/provider" ) const ( @@ -16,6 +18,7 @@ const ( OperationNoData = "nodata" ) +// Connection is a struct containing metadata for the database connection. type Connection struct { Hostname string Username string @@ -25,6 +28,7 @@ type Connection struct { MaxConn int } +// Open will Open a new database connection. func (o Connection) Open(database string) (*sql.DB, error) { cfg := mysql.Config{ User: o.Username, @@ -49,29 +53,31 @@ func (o Connection) Open(database string) (*sql.DB, error) { type Client struct { DB *sql.DB Logger *log.Logger + // A field for caching a list of tables for this database. cachedTables []string + + // Provider configuration. + Provider string + // For the AWS RDS Provider, specify the AWS Region. + Region string + // For the AWS RDS Provider, specify the S3 URI. + URI string } // NewClient for dumping a full or single table from a database. -func NewClient(db *sql.DB, logger *log.Logger) *Client { +func NewClient(db *sql.DB, logger *log.Logger, provider, region, uri string) *Client { return &Client{ - DB: db, - Logger: logger, + DB: db, + Logger: logger, + Provider: provider, + Region: region, + URI: uri, } } -// DumpParams is used to pass parameters to the Dump function. -type DumpParams struct { - SelectMap map[string]map[string]string - WhereMap map[string]string - FilterMap map[string]string - UseTableLock bool - ExtendedInsertRows int -} - // DumpTables will write all table data to a single writer. -func (d *Client) DumpTables(w io.Writer, params DumpParams) error { +func (d *Client) DumpTables(w io.Writer, params provider.DumpParams) error { if err := d.WriteHeader(w); err != nil { return fmt.Errorf("failed to write header: %w", err) } @@ -92,7 +98,7 @@ func (d *Client) DumpTables(w io.Writer, params DumpParams) error { } // DumpTable is convenient if you wish to coordinate a dump eg. Single file per table. -func (d *Client) DumpTable(w io.Writer, table string, params DumpParams) error { +func (d *Client) DumpTable(w io.Writer, table string, params provider.DumpParams) error { if err := d.WriteHeader(w); err != nil { return fmt.Errorf("failed to write header: %w", err) } diff --git a/internal/mysql/provider/provider.go b/internal/mysql/provider/provider.go new file mode 100644 index 0000000..7e66fbd --- /dev/null +++ b/internal/mysql/provider/provider.go @@ -0,0 +1,15 @@ +package provider + +// Interface implements the required functionality for a Provider. +type Interface interface { + GetSelectQueryForTable(table string, params DumpParams) (string, error) +} + +// DumpParams is used to pass parameters to the Dump function. +type DumpParams struct { + SelectMap map[string]map[string]string + WhereMap map[string]string + FilterMap map[string]string + UseTableLock bool + ExtendedInsertRows int +} diff --git a/internal/mysql/provider/rds/provider.go b/internal/mysql/provider/rds/provider.go new file mode 100644 index 0000000..96efd66 --- /dev/null +++ b/internal/mysql/provider/rds/provider.go @@ -0,0 +1,74 @@ +package rds + +import ( + "database/sql" + "fmt" + "log" + "strings" + + "github.com/skpr/mtk/internal/mysql/provider" + providerutils "github.com/skpr/mtk/internal/mysql/provider/utils" +) + +// Client used for dumping a database and/or table. +type Client struct { + provider.Interface + DB *sql.DB + Logger *log.Logger + + Region string // Region configuration + URI string // S3 URI configuration +} + +// NewClient for dumping a full or single table from a database. +func NewClient(db *sql.DB, logger *log.Logger, region, uri string) *Client { + return &Client{ + DB: db, + Logger: logger, + Region: region, + URI: uri, + } +} + +// GetSelectQueryForTable will return a complete SELECT query to export data from a table. +func (d *Client) GetSelectQueryForTable(table string, params provider.DumpParams) (string, error) { + cols, err := providerutils.QueryColumnsForTable(d.DB, table, params) + if err != nil { + return "", err + } + + query := fmt.Sprintf("SELECT %s", strings.Join(cols, ", ")) + query = fmt.Sprintf("%s FROM `%s`", query, table) + + if where, ok := params.WhereMap[strings.ToLower(table)]; ok { + query = fmt.Sprintf("%s WHERE %s", query, where) + } + + query = fmt.Sprintf("%s INTO OUTFILE S3 '%s/%s.csv'", query, d.URI, table) + query = fmt.Sprintf("%s FIELDS TERMINATED BY ',' ENCLOSED BY '\"' LINES TERMINATED BY '\\n'", query) + query = fmt.Sprintf("%s MANIFEST ON", query) + query = fmt.Sprintf("%s OVERWRITE ON", query) + + importQuery, err := d.GetLoadQueryForTable(table) + if err != nil { + return "", err + } + + fmt.Println(importQuery) + return query, nil +} + +// GetLoadQueryForTable will return a complete SELECT query to fetch data from a table. +func (d *Client) GetLoadQueryForTable(table string) (string, error) { + if table == "" { + return "", fmt.Errorf("error: no table specified") + } + if d.Region == "" || len(strings.Split(d.Region, "-")) != 3 { + return "", fmt.Errorf("error: region is not configured correctly") + } + path := strings.TrimPrefix(d.URI, "s3://") + query := fmt.Sprintf("LOAD DATA FROM S3 MANIFEST 'S3-%s://%s/%s.csv.manifest' INTO TABLE `%s`", d.Region, path, table, table) + query = fmt.Sprintf("%s FIELDS TERMINATED BY ',' ENCLOSED BY '\"' LINES TERMINATED BY '\\n'", query) + + return query, nil +} diff --git a/internal/mysql/provider/rds/provider_test.go b/internal/mysql/provider/rds/provider_test.go new file mode 100644 index 0000000..d132cb7 --- /dev/null +++ b/internal/mysql/provider/rds/provider_test.go @@ -0,0 +1,35 @@ +package rds + +import ( + "log" + "os" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + + "github.com/skpr/mtk/internal/mysql/mock" + "github.com/skpr/mtk/internal/mysql/provider" +) + +func TestMySQLGetExportSelectQueryFor(t *testing.T) { + db, mock := mock.GetDB(t) + dumper := NewClient(db, log.New(os.Stdout, "", 0), "ap-southheast-2", "s3://path/to/bucket") + mock.ExpectQuery("SELECT \\* FROM `table` LIMIT 1").WillReturnRows( + sqlmock.NewRows([]string{"c1", "c2"}).AddRow("a", "b")) + query, err := dumper.GetSelectQueryForTable("table", provider.DumpParams{ + SelectMap: map[string]map[string]string{"table": {"c2": "NOW()"}}, + WhereMap: map[string]string{"table": "c1 > 0"}, + }) + assert.Nil(t, err) + assert.Equal(t, "SELECT `c1`, NOW() AS `c2` FROM `table` WHERE c1 > 0 INTO OUTFILE S3 's3://path/to/bucket/table.csv' FIELDS TERMINATED BY ',' ENCLOSED BY '\"' LINES TERMINATED BY '\\n' MANIFEST ON OVERWRITE ON", query) +} + +func TestMySQLGetLoadQueryFor(t *testing.T) { + db, _ := mock.GetDB(t) + dumper := NewClient(db, log.New(os.Stdout, "", 0), "ap-southeast-4", "s3://path/to/bucket") + query, err := dumper.GetLoadQueryForTable("table_name") + assert.Nil(t, err) + assert.Equal(t, "LOAD DATA FROM S3 MANIFEST 'S3-ap-southeast-4://path/to/bucket/table_name.csv.manifest' INTO TABLE `table_name` FIELDS TERMINATED BY ',' ENCLOSED BY '\"' LINES TERMINATED BY '\\n'", query) + +} diff --git a/internal/mysql/provider/stdout/provider.go b/internal/mysql/provider/stdout/provider.go new file mode 100644 index 0000000..4d6a325 --- /dev/null +++ b/internal/mysql/provider/stdout/provider.go @@ -0,0 +1,42 @@ +package stdout + +import ( + "database/sql" + "fmt" + "log" + "strings" + + "github.com/skpr/mtk/internal/mysql/provider" + providerutils "github.com/skpr/mtk/internal/mysql/provider/utils" +) + +// Client used for dumping a database and/or table. +type Client struct { + provider.Interface + DB *sql.DB + Logger *log.Logger +} + +// NewClient for dumping a full or single table from a database. +func NewClient(db *sql.DB, logger *log.Logger) *Client { + return &Client{ + DB: db, + Logger: logger, + } +} + +// GetSelectQueryForTable will return a complete SELECT query to fetch data from a table. +func (d *Client) GetSelectQueryForTable(table string, params provider.DumpParams) (string, error) { + cols, err := providerutils.QueryColumnsForTable(d.DB, table, params) + if err != nil { + return "", err + } + + query := fmt.Sprintf("SELECT %s FROM `%s`", strings.Join(cols, ", "), table) + + if where, ok := params.WhereMap[strings.ToLower(table)]; ok { + query = fmt.Sprintf("%s WHERE %s", query, where) + } + + return query, nil +} diff --git a/internal/mysql/provider/stdout/provider_test.go b/internal/mysql/provider/stdout/provider_test.go new file mode 100644 index 0000000..d8b6e50 --- /dev/null +++ b/internal/mysql/provider/stdout/provider_test.go @@ -0,0 +1,40 @@ +package stdout + +import ( + "errors" + "log" + "os" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + + "github.com/skpr/mtk/internal/mysql/mock" + "github.com/skpr/mtk/internal/mysql/provider" +) + +func TestMySQLGetSelectQueryFor(t *testing.T) { + db, mock := mock.GetDB(t) + dumper := NewClient(db, log.New(os.Stdout, "", 0)) + mock.ExpectQuery("SELECT \\* FROM `table` LIMIT 1").WillReturnRows( + sqlmock.NewRows([]string{"c1", "c2"}).AddRow("a", "b")) + query, err := dumper.GetSelectQueryForTable("table", provider.DumpParams{ + SelectMap: map[string]map[string]string{"table": {"c2": "NOW()"}}, + WhereMap: map[string]string{"table": "c1 > 0"}, + }) + assert.Nil(t, err) + assert.Equal(t, "SELECT `c1`, NOW() AS `c2` FROM `table` WHERE c1 > 0", query) +} + +func TestMySQLGetSelectQueryForHandlingError(t *testing.T) { + db, mock := mock.GetDB(t) + dumper := NewClient(db, log.New(os.Stdout, "", 0)) + error := errors.New("broken") + mock.ExpectQuery("SELECT \\* FROM `table` LIMIT 1").WillReturnError(error) + query, err := dumper.GetSelectQueryForTable("table", provider.DumpParams{ + SelectMap: map[string]map[string]string{"table": {"c2": "NOW()"}}, + WhereMap: map[string]string{"table": "c1 > 0"}, + }) + assert.Equal(t, error, err) + assert.Equal(t, "", query) +} diff --git a/internal/mysql/provider/utils/utils.go b/internal/mysql/provider/utils/utils.go new file mode 100644 index 0000000..15ebe42 --- /dev/null +++ b/internal/mysql/provider/utils/utils.go @@ -0,0 +1,37 @@ +package utils + +import ( + "database/sql" + "fmt" + "strings" + + "github.com/skpr/mtk/internal/mysql/provider" +) + +// QueryColumnsForTable for a given table. +func QueryColumnsForTable(database *sql.DB, table string, params provider.DumpParams) ([]string, error) { + var rows *sql.Rows + + rows, err := database.Query(fmt.Sprintf("SELECT * FROM `%s` LIMIT 1", table)) + if err != nil { + return nil, err + } + + defer rows.Close() + + columns, err := rows.Columns() + if err != nil { + return nil, err + } + + for k, column := range columns { + replacement, ok := params.SelectMap[strings.ToLower(table)][strings.ToLower(column)] + if ok { + columns[k] = fmt.Sprintf("%s AS `%s`", replacement, column) + } else { + columns[k] = fmt.Sprintf("`%s`", column) + } + } + + return columns, nil +} diff --git a/internal/mysql/provider/utils/utils_test.go b/internal/mysql/provider/utils/utils_test.go new file mode 100644 index 0000000..7a0bc8d --- /dev/null +++ b/internal/mysql/provider/utils/utils_test.go @@ -0,0 +1,34 @@ +package utils + +import ( + "errors" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + + "github.com/skpr/mtk/internal/mysql/mock" + "github.com/skpr/mtk/internal/mysql/provider" +) + +func TestMySQLGetColumnsForSelect(t *testing.T) { + db, mock := mock.GetDB(t) + mock.ExpectQuery("SELECT \\* FROM `table` LIMIT 1").WillReturnRows( + sqlmock.NewRows([]string{"col1", "col2", "col3"}).AddRow("a", "b", "c")) + columns, err := QueryColumnsForTable(db, "table", provider.DumpParams{ + SelectMap: map[string]map[string]string{"table": {"col2": "NOW()"}}, + }) + assert.Nil(t, err) + assert.Equal(t, []string{"`col1`", "NOW() AS `col2`", "`col3`"}, columns) +} + +func TestMySQLGetColumnsForSelectHandlingErrorWhenQuerying(t *testing.T) { + db, mock := mock.GetDB(t) + error := errors.New("broken") + mock.ExpectQuery("SELECT \\* FROM `table` LIMIT 1").WillReturnError(error) + columns, err := QueryColumnsForTable(db, "table", provider.DumpParams{ + SelectMap: map[string]map[string]string{"table": {"col2": "NOW()"}}, + }) + assert.Equal(t, err, error) + assert.Empty(t, columns) +} diff --git a/internal/mysql/tables.go b/internal/mysql/tables.go index 847af6d..443fc44 100644 --- a/internal/mysql/tables.go +++ b/internal/mysql/tables.go @@ -2,11 +2,15 @@ package mysql import ( "database/sql" + "errors" "fmt" "strings" "github.com/gobwas/glob" + "github.com/skpr/mtk/internal/mysql/provider" + "github.com/skpr/mtk/internal/mysql/provider/rds" + "github.com/skpr/mtk/internal/mysql/provider/stdout" "github.com/skpr/mtk/internal/sliceutils" ) @@ -67,53 +71,27 @@ func (d *Client) QueryTables() ([]string, error) { return tables, nil } -// QueryColumnsForTable for a given table. -func (d *Client) QueryColumnsForTable(table string, params DumpParams) ([]string, error) { - var rows *sql.Rows - - rows, err := d.DB.Query(fmt.Sprintf("SELECT * FROM `%s` LIMIT 1", table)) - if err != nil { - return nil, err - } - - defer rows.Close() - - columns, err := rows.Columns() - if err != nil { - return nil, err - } - - for k, column := range columns { - replacement, ok := params.SelectMap[strings.ToLower(table)][strings.ToLower(column)] - if ok { - columns[k] = fmt.Sprintf("%s AS `%s`", replacement, column) - } else { - columns[k] = fmt.Sprintf("`%s`", column) - } +func (d *Client) getProviderClient() (provider.Interface, error) { + switch d.Provider { + case "rds": + client := rds.NewClient(d.DB, d.Logger, d.Region, d.URI) + return client, nil + case "stdout": + return stdout.NewClient(d.DB, d.Logger), nil + default: + return nil, errors.New("invalid provider") } - - return columns, nil } -// GetSelectQueryForTable will return a complete SELECT query to fetch data from a table. -func (d *Client) GetSelectQueryForTable(table string, params DumpParams) (string, error) { - cols, err := d.QueryColumnsForTable(table, params) - if err != nil { - return "", err - } - - query := fmt.Sprintf("SELECT %s FROM `%s`", strings.Join(cols, ", "), table) +// Helper function to get all data for a table. +func (d *Client) selectAllDataForTable(table string, params provider.DumpParams) (*sql.Rows, []string, error) { - if where, ok := params.WhereMap[strings.ToLower(table)]; ok { - query = fmt.Sprintf("%s WHERE %s", query, where) + client, err := d.getProviderClient() + if err != nil { + return nil, nil, err } - return query, nil -} - -// Helper function to get all data for a table. -func (d *Client) selectAllDataForTable(table string, params DumpParams) (*sql.Rows, []string, error) { - query, err := d.GetSelectQueryForTable(table, params) + query, err := client.GetSelectQueryForTable(table, params) if err != nil { return nil, nil, err } @@ -132,7 +110,7 @@ func (d *Client) selectAllDataForTable(table string, params DumpParams) (*sql.Ro } // GetRowCountForTable will return the number of rows using a SELECT statement. -func (d *Client) GetRowCountForTable(table string, params DumpParams) (uint64, error) { +func (d *Client) GetRowCountForTable(table string, params provider.DumpParams) (uint64, error) { query := fmt.Sprintf("SELECT COUNT(*) FROM `%s`", table) if where, ok := params.WhereMap[strings.ToLower(table)]; ok { diff --git a/internal/mysql/tables_test.go b/internal/mysql/tables_test.go index 1ce93f0..d5a615e 100644 --- a/internal/mysql/tables_test.go +++ b/internal/mysql/tables_test.go @@ -11,11 +11,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/skpr/mtk/internal/mysql/mock" + "github.com/skpr/mtk/internal/mysql/provider" ) func TestMySQLFlushTable(t *testing.T) { db, mock := mock.GetDB(t) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) + dumper := NewClient(db, log.New(os.Stdout, "", 0), "", "", "") mock.ExpectExec("FLUSH TABLES `table`").WillReturnResult(sqlmock.NewResult(0, 1)) _, err := dumper.FlushTable("table") assert.Nil(t, err) @@ -23,7 +24,7 @@ func TestMySQLFlushTable(t *testing.T) { func TestMySQLUnlockTables(t *testing.T) { db, mock := mock.GetDB(t) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) + dumper := NewClient(db, log.New(os.Stdout, "", 0), "", "", "") mock.ExpectExec("UNLOCK TABLES").WillReturnResult(sqlmock.NewResult(0, 1)) _, err := dumper.UnlockTables() assert.Nil(t, err) @@ -31,7 +32,7 @@ func TestMySQLUnlockTables(t *testing.T) { func TestMySQLQueryTables(t *testing.T) { db, mock := mock.GetDB(t) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) + dumper := NewClient(db, log.New(os.Stdout, "", 0), "", "", "") mock.ExpectQuery("SHOW FULL TABLES").WillReturnRows( sqlmock.NewRows([]string{"Tables_in_database", "Table_type"}). AddRow("table1", "BASE TABLE"). @@ -44,7 +45,7 @@ func TestMySQLQueryTables(t *testing.T) { func TestMySQLLockTableRead(t *testing.T) { db, mock := mock.GetDB(t) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) + dumper := NewClient(db, log.New(os.Stdout, "", 0), "", "", "") mock.ExpectExec("LOCK TABLES `table` READ").WillReturnResult(sqlmock.NewResult(0, 1)) _, err := dumper.LockTableReading("table") assert.Nil(t, err) @@ -52,7 +53,7 @@ func TestMySQLLockTableRead(t *testing.T) { func TestMySQLGetTablesHandlingErrorWhenListingTables(t *testing.T) { db, mock := mock.GetDB(t) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) + dumper := NewClient(db, log.New(os.Stdout, "", 0), "", "", "") expectedErr := errors.New("broken") mock.ExpectQuery("SHOW FULL TABLES").WillReturnError(expectedErr) tables, err := dumper.QueryTables() @@ -62,7 +63,7 @@ func TestMySQLGetTablesHandlingErrorWhenListingTables(t *testing.T) { func TestMySQLGetTablesHandlingErrorWhenScanningRow(t *testing.T) { db, mock := mock.GetDB(t) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) + dumper := NewClient(db, log.New(os.Stdout, "", 0), "", "", "") mock.ExpectQuery("SHOW FULL TABLES").WillReturnRows( sqlmock.NewRows([]string{"Tables_in_database", "Table_type"}).AddRow(1, nil)) tables, err := dumper.QueryTables() @@ -77,7 +78,7 @@ func TestMySQLDumpCreateTable(t *testing.T) { "PRIMARY KEY (`id`), KEY `idx_name` (`name`) " + ") ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8" db, mock := mock.GetDB(t) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) + dumper := NewClient(db, log.New(os.Stdout, "", 0), "", "", "") mock.ExpectQuery("SHOW CREATE TABLE `table`").WillReturnRows( sqlmock.NewRows([]string{"Table", "Create Table"}). AddRow("table", ddl), @@ -90,69 +91,19 @@ func TestMySQLDumpCreateTable(t *testing.T) { func TestMySQLDumpCreateTableHandlingErrorWhenScanningRows(t *testing.T) { db, mock := mock.GetDB(t) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) + dumper := NewClient(db, log.New(os.Stdout, "", 0), "", "", "") mock.ExpectQuery("SHOW CREATE TABLE `table`").WillReturnRows( sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow("table", nil)) buffer := bytes.NewBuffer(make([]byte, 0)) assert.NotNil(t, dumper.WriteCreateTable(buffer, "table")) } -func TestMySQLGetColumnsForSelect(t *testing.T) { - db, mock := mock.GetDB(t) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) - mock.ExpectQuery("SELECT \\* FROM `table` LIMIT 1").WillReturnRows( - sqlmock.NewRows([]string{"col1", "col2", "col3"}).AddRow("a", "b", "c")) - columns, err := dumper.QueryColumnsForTable("table", DumpParams{ - SelectMap: map[string]map[string]string{"table": {"col2": "NOW()"}}, - }) - assert.Nil(t, err) - assert.Equal(t, []string{"`col1`", "NOW() AS `col2`", "`col3`"}, columns) -} - -func TestMySQLGetColumnsForSelectHandlingErrorWhenQuerying(t *testing.T) { - db, mock := mock.GetDB(t) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) - error := errors.New("broken") - mock.ExpectQuery("SELECT \\* FROM `table` LIMIT 1").WillReturnError(error) - columns, err := dumper.QueryColumnsForTable("table", DumpParams{ - SelectMap: map[string]map[string]string{"table": {"col2": "NOW()"}}, - }) - assert.Equal(t, err, error) - assert.Empty(t, columns) -} - -func TestMySQLGetSelectQueryFor(t *testing.T) { - db, mock := mock.GetDB(t) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) - mock.ExpectQuery("SELECT \\* FROM `table` LIMIT 1").WillReturnRows( - sqlmock.NewRows([]string{"c1", "c2"}).AddRow("a", "b")) - query, err := dumper.GetSelectQueryForTable("table", DumpParams{ - SelectMap: map[string]map[string]string{"table": {"c2": "NOW()"}}, - WhereMap: map[string]string{"table": "c1 > 0"}, - }) - assert.Nil(t, err) - assert.Equal(t, "SELECT `c1`, NOW() AS `c2` FROM `table` WHERE c1 > 0", query) -} - -func TestMySQLGetSelectQueryForHandlingError(t *testing.T) { - db, mock := mock.GetDB(t) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) - error := errors.New("broken") - mock.ExpectQuery("SELECT \\* FROM `table` LIMIT 1").WillReturnError(error) - query, err := dumper.GetSelectQueryForTable("table", DumpParams{ - SelectMap: map[string]map[string]string{"table": {"c2": "NOW()"}}, - WhereMap: map[string]string{"table": "c1 > 0"}, - }) - assert.Equal(t, error, err) - assert.Equal(t, "", query) -} - func TestMySQLGetRowCount(t *testing.T) { db, mock := mock.GetDB(t) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) + dumper := NewClient(db, log.New(os.Stdout, "", 0), "", "", "") mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM `table` WHERE c1 > 0").WillReturnRows( sqlmock.NewRows([]string{"COUNT(*)"}).AddRow(1234)) - count, err := dumper.GetRowCountForTable("table", DumpParams{ + count, err := dumper.GetRowCountForTable("table", provider.DumpParams{ WhereMap: map[string]string{"table": "c1 > 0"}, }) assert.Nil(t, err) @@ -161,10 +112,10 @@ func TestMySQLGetRowCount(t *testing.T) { func TestMySQLGetRowCountHandlingError(t *testing.T) { db, mock := mock.GetDB(t) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) + dumper := NewClient(db, log.New(os.Stdout, "", 0), "", "", "") mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM `table` WHERE c1 > 0").WillReturnRows( sqlmock.NewRows([]string{"COUNT(*)"}).AddRow(nil)) - count, err := dumper.GetRowCountForTable("table", DumpParams{ + count, err := dumper.GetRowCountForTable("table", provider.DumpParams{ WhereMap: map[string]string{"table": "c1 > 0"}, }) assert.NotNil(t, err) diff --git a/internal/mysql/write.go b/internal/mysql/write.go index 8a0c9da..fdd9279 100644 --- a/internal/mysql/write.go +++ b/internal/mysql/write.go @@ -6,6 +6,8 @@ import ( "io" "strings" "time" + + "github.com/skpr/mtk/internal/mysql/provider" ) // WriteHeader is intended to be added at the beginning of a dump to manage database configuration. @@ -101,7 +103,7 @@ func (d *Client) WriteCreateTable(w io.Writer, table string) error { } // WriteTableHeader which contains debug information. -func (d *Client) WriteTableHeader(w io.Writer, table string, params DumpParams) (uint64, error) { +func (d *Client) WriteTableHeader(w io.Writer, table string, params provider.DumpParams) (uint64, error) { fmt.Fprintf(w, "\n--\n-- Data for table `%s`", table) count, err := d.GetRowCountForTable(table, params) @@ -115,7 +117,7 @@ func (d *Client) WriteTableHeader(w io.Writer, table string, params DumpParams) } // WriteTableData for a specific table. -func (d *Client) WriteTableData(w io.Writer, table string, params DumpParams) error { +func (d *Client) WriteTableData(w io.Writer, table string, params provider.DumpParams) error { d.Logger.Println("Dumping data for table:", table) rows, columns, err := d.selectAllDataForTable(table, params) @@ -184,7 +186,7 @@ func (d *Client) WriteTableData(w io.Writer, table string, params DumpParams) er } // WriteTables will create a script for all tables. -func (d *Client) writeTables(w io.Writer, params DumpParams) error { +func (d *Client) writeTables(w io.Writer, params provider.DumpParams) error { tables, err := d.QueryTables() if err != nil { return err @@ -200,7 +202,7 @@ func (d *Client) writeTables(w io.Writer, params DumpParams) error { } // WriteTable allows for a single table dump script. -func (d *Client) writeTable(w io.Writer, table string, params DumpParams) error { +func (d *Client) writeTable(w io.Writer, table string, params provider.DumpParams) error { if params.FilterMap[strings.ToLower(table)] == OperationIgnore { return nil } diff --git a/internal/mysql/write_test.go b/internal/mysql/write_test.go index 853af2e..c5f048b 100644 --- a/internal/mysql/write_test.go +++ b/internal/mysql/write_test.go @@ -12,15 +12,16 @@ import ( "github.com/stretchr/testify/assert" "github.com/skpr/mtk/internal/mysql/mock" + "github.com/skpr/mtk/internal/mysql/provider" ) func TestMySQLDumpTableHeader(t *testing.T) { db, mock := mock.GetDB(t) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) + dumper := NewClient(db, log.New(os.Stdout, "", 0), "stdout", "", "") mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM `table`").WillReturnRows( sqlmock.NewRows([]string{"COUNT(*)"}).AddRow(1234)) buffer := bytes.NewBuffer(make([]byte, 0)) - count, err := dumper.WriteTableHeader(buffer, "table", DumpParams{}) + count, err := dumper.WriteTableHeader(buffer, "table", provider.DumpParams{}) assert.Equal(t, uint64(1234), count) assert.Nil(t, err) assert.Contains(t, buffer.String(), "Data for table `table`") @@ -29,25 +30,25 @@ func TestMySQLDumpTableHeader(t *testing.T) { func TestMySQLDumpTableHeaderHandlingError(t *testing.T) { db, mock := mock.GetDB(t) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) + dumper := NewClient(db, log.New(os.Stdout, "", 0), "stdout", "", "") mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM `table`").WillReturnRows( sqlmock.NewRows([]string{"COUNT(*)"}).AddRow(nil)) buffer := bytes.NewBuffer(make([]byte, 0)) - count, err := dumper.WriteTableHeader(buffer, "table", DumpParams{}) + count, err := dumper.WriteTableHeader(buffer, "table", provider.DumpParams{}) assert.Equal(t, uint64(0), count) assert.NotNil(t, err) } func TestMySQLDumpTableLockWrite(t *testing.T) { buffer := bytes.NewBuffer(make([]byte, 0)) - dumper := NewClient(nil, log.New(os.Stdout, "", 0)) + dumper := NewClient(nil, log.New(os.Stdout, "", 0), "stdout", "", "") dumper.WriteTableLockWrite(buffer, "table") assert.Contains(t, buffer.String(), "LOCK TABLES `table` WRITE;") } func TestMySQLDumpUnlockTables(t *testing.T) { buffer := bytes.NewBuffer(make([]byte, 0)) - dumper := NewClient(nil, log.New(os.Stdout, "", 0)) + dumper := NewClient(nil, log.New(os.Stdout, "", 0), "stdout", "", "") dumper.WriteUnlockTables(buffer) assert.Contains(t, buffer.String(), "UNLOCK TABLES;") } @@ -55,7 +56,7 @@ func TestMySQLDumpUnlockTables(t *testing.T) { func TestMySQLDumpTableData(t *testing.T) { db, mock := mock.GetDB(t) buffer := bytes.NewBuffer(make([]byte, 0)) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) + dumper := NewClient(db, log.New(os.Stdout, "", 0), "stdout", "", "") mock.ExpectQuery("SELECT \\* FROM `table` LIMIT 1").WillReturnRows( sqlmock.NewRows([]string{"id", "language"}). @@ -70,7 +71,8 @@ func TestMySQLDumpTableData(t *testing.T) { AddRow(5, "Rust"). AddRow(6, "Closure")) - assert.Nil(t, dumper.WriteTableData(buffer, "table", DumpParams{ExtendedInsertRows: 2})) + assert.Nil(t, dumper.WriteTableData(buffer, "table", provider.DumpParams{ + ExtendedInsertRows: 2})) assert.Equal(t, strings.Count(buffer.String(), "INSERT INTO `table` VALUES"), 3) assert.Equal(t, buffer.String(), "INSERT INTO `table` VALUES (1,'Go'),(2,'Java');\nINSERT INTO `table` VALUES (3,'C'),(4,'C++');\nINSERT INTO `table` VALUES (5,'Rust'),(6,'Closure');\n") @@ -79,8 +81,8 @@ func TestMySQLDumpTableData(t *testing.T) { func TestMySQLDumpTableDataHandlingErrorFromSelectAllDataFor(t *testing.T) { db, mock := mock.GetDB(t) buffer := bytes.NewBuffer(make([]byte, 0)) - dumper := NewClient(db, log.New(os.Stdout, "", 0)) + dumper := NewClient(db, log.New(os.Stdout, "", 0), "stdout", "", "") error := errors.New("fail") mock.ExpectQuery("SELECT \\* FROM `table` LIMIT 1").WillReturnError(error) - assert.Equal(t, error, dumper.WriteTableData(buffer, "table", DumpParams{})) + assert.Equal(t, error, dumper.WriteTableData(buffer, "table", provider.DumpParams{})) } diff --git a/pkg/envar/const.go b/pkg/envar/const.go index fb23e45..b294c3c 100644 --- a/pkg/envar/const.go +++ b/pkg/envar/const.go @@ -15,6 +15,8 @@ const ( Protocol = "MTK_PROTOCOL" // Port defines the environment variable when using the command line. Port = "MTK_PORT" + // Provider defines the provider to use when using the command line. + Provider = "MTK_PROVIDER" // ExtendedInsertRows defines the environment variable when using the command line. ExtendedInsertRows = "MTK_EXTENDED_INSERT_ROWS" // MySQLHostname defines the environment variable when using the command line. @@ -27,4 +29,8 @@ const ( MySQLProtocol = "MYSQL_PROTOCOL" // MySQLPort defines the environment variable when using the command line. MySQLPort = "MYSQL_PORT" + // RDSRegion defines the default AWS Region configuration for use with RDS from the command line. + RDSRegion = "MTK_RDS_REGION" + // RDSS3Uri defines the URI of the bucket path for use with the RDS provider from the command line. + RDSS3Uri = "MTK_RDS_S3_URI" )