From 695614ffbce4041e72bb7360a5d77d57283979b6 Mon Sep 17 00:00:00 2001 From: Krzysztof Tomecki <152964795+chris-4chain@users.noreply.github.com> Date: Wed, 27 Mar 2024 09:33:55 +0100 Subject: [PATCH 01/12] fix(BUX-686): where builder --- engine/datastore/interface.go | 1 - engine/datastore/models.go | 26 +- engine/datastore/where.go | 168 +++++---- engine/datastore/where_test.go | 646 +++++++++++++++++---------------- go.mod | 2 + go.sum | 1 + 6 files changed, 439 insertions(+), 405 deletions(-) diff --git a/engine/datastore/interface.go b/engine/datastore/interface.go index bc9a2641e..f474e984e 100644 --- a/engine/datastore/interface.go +++ b/engine/datastore/interface.go @@ -12,7 +12,6 @@ import ( type StorageService interface { AutoMigrateDatabase(ctx context.Context, models ...interface{}) error CreateInBatches(ctx context.Context, models interface{}, batchSize int) error - CustomWhere(tx CustomWhereInterface, conditions map[string]interface{}, engine Engine) interface{} Execute(query string) *gorm.DB GetModel(ctx context.Context, model interface{}, conditions map[string]interface{}, timeout time.Duration, forceWriteDB bool) error diff --git a/engine/datastore/models.go b/engine/datastore/models.go index f63dfb540..e317e2468 100644 --- a/engine/datastore/models.go +++ b/engine/datastore/models.go @@ -189,10 +189,10 @@ func (c *Client) GetModel( tx = ctxDB.Select("*") } - // Add conditions if len(conditions) > 0 { - gtx := gormWhere{tx: tx} - return checkResult(c.CustomWhere(>x, conditions, c.Engine()).(*gorm.DB).Find(model)) + if err := ApplyCustomWhere(c, tx, conditions); err != nil { + return err + } } return checkResult(tx.Find(model)) @@ -292,13 +292,10 @@ func (c *Client) find(ctx context.Context, result interface{}, conditions map[st }) } - // Check for errors or no records found if len(conditions) > 0 { - gtx := gormWhere{tx: tx} - if fieldResults != nil { - return checkResult(c.CustomWhere(>x, conditions, c.Engine()).(*gorm.DB).Find(fieldResults)) + if err := ApplyCustomWhere(c, tx, conditions); err != nil { + return err } - return checkResult(c.CustomWhere(>x, conditions, c.Engine()).(*gorm.DB).Find(result)) } // Skip the conditions @@ -320,10 +317,9 @@ func (c *Client) count(ctx context.Context, model interface{}, conditions map[st // Check for errors or no records found if len(conditions) > 0 { - gtx := gormWhere{tx: tx} - var count int64 - err := checkResult(c.CustomWhere(>x, conditions, c.Engine()).(*gorm.DB).Model(model).Count(&count)) - return count, err + if err := ApplyCustomWhere(c, tx, conditions); err != nil { + return 0, err + } } var count int64 err := checkResult(tx.Count(&count)) @@ -350,8 +346,10 @@ func (c *Client) aggregate(ctx context.Context, model interface{}, conditions ma // Check for errors or no records found var aggregate []map[string]interface{} if len(conditions) > 0 { - gtx := gormWhere{tx: tx} - err := checkResult(c.CustomWhere(>x, conditions, c.Engine()).(*gorm.DB).Model(model).Group(aggregateColumn).Scan(&aggregate)) + if err := ApplyCustomWhere(c, tx, conditions); err != nil { + return nil, err + } + err := checkResult(tx.Group(aggregateColumn).Scan(&aggregate)) if err != nil { return nil, err } diff --git a/engine/datastore/where.go b/engine/datastore/where.go index 336432024..645d87d53 100644 --- a/engine/datastore/where.go +++ b/engine/datastore/where.go @@ -2,33 +2,21 @@ package datastore import ( "encoding/json" + "fmt" "reflect" "strconv" "strings" + "sync" customtypes "github.com/bitcoin-sv/spv-wallet/engine/datastore/customtypes" "gorm.io/gorm" + "gorm.io/gorm/schema" ) -// CustomWhereInterface is an interface for the CustomWhere clauses type CustomWhereInterface interface { - Where(query interface{}, args ...interface{}) - getGormTx() *gorm.DB + Where(query interface{}, args ...interface{}) *gorm.DB } -// CustomWhere add conditions -func (c *Client) CustomWhere(tx CustomWhereInterface, conditions map[string]interface{}, engine Engine) interface{} { - // Empty accumulator - varNum := 0 - - // Process the conditions - processConditions(c, tx, conditions, engine, &varNum, nil) - - // Return the GORM tx - return tx.getGormTx() -} - -// txAccumulator is the accumulator struct type txAccumulator struct { CustomWhereInterface WhereClauses []string @@ -36,7 +24,7 @@ type txAccumulator struct { } // Where is our custom where method -func (tx *txAccumulator) Where(query interface{}, args ...interface{}) { +func (tx *txAccumulator) Where(query interface{}, args ...interface{}) *gorm.DB { tx.WhereClauses = append(tx.WhereClauses, query.(string)) if len(args) > 0 { @@ -46,49 +34,97 @@ func (tx *txAccumulator) Where(query interface{}, args ...interface{}) { } } } + + return nil +} + +// WhereBuilder holds a state during custom where preparation +type WhereBuilder struct { + client ClientInterface + gdb *gorm.DB + varNum int } -// getGormTx will get the GORM tx -func (tx *txAccumulator) getGormTx() *gorm.DB { +// ApplyCustomWhere adds conditions to the gorm db instance +func ApplyCustomWhere(client ClientInterface, gdb *gorm.DB, conditions map[string]interface{}) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("error processing conditions: %v", r) + } + }() + + builder := &WhereBuilder{ + client: client, + gdb: gdb, + varNum: 0, + } + + builder.processConditions(gdb, conditions, nil) return nil } -// processConditions will process all conditions -func processConditions(client ClientInterface, tx CustomWhereInterface, conditions map[string]interface{}, - engine Engine, varNum *int, parentKey *string, -) map[string]interface{} { //nolint:nolintlint,unparam // ignore for now +func (builder *WhereBuilder) nextVarName() string { + varName := "var" + strconv.Itoa(builder.varNum) + builder.varNum++ + return varName +} +func getColumnName(columnName string, model interface{}) string { + sch, err := schema.Parse(model, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + panic(fmt.Errorf("cannot parse a model %v", model)) + } + if field, ok := sch.FieldsByDBName[columnName]; ok { + return field.DBName + } + + if field, ok := sch.FieldsByName[columnName]; ok { + return field.DBName + } + + panic(fmt.Errorf("column %s does not exist in the model", columnName)) +} + +func (builder *WhereBuilder) applyCondition(tx CustomWhereInterface, key string, operator string, condition interface{}) { + columnName := getColumnName(key, builder.gdb.Statement.Model) + + varName := builder.nextVarName() + query := fmt.Sprintf("%s %s @%s", columnName, operator, varName) + tx.Where(query, map[string]interface{}{varName: builder.formatCondition(condition)}) +} + +func (builder *WhereBuilder) applyExistsCondition(tx CustomWhereInterface, key string, condition bool) { + columnName := getColumnName(key, builder.gdb.Statement.Model) + + operator := "IS NULL" + if condition { + operator = "IS NOT NULL" + } + tx.Where(columnName + " " + operator) +} + +// processConditions will process all conditions +func (builder *WhereBuilder) processConditions(tx CustomWhereInterface, conditions map[string]interface{}, parentKey *string, +) { for key, condition := range conditions { if key == conditionAnd { - processWhereAnd(client, tx, condition, engine, varNum) + builder.processWhereAnd(tx, condition) } else if key == conditionOr { - processWhereOr(client, tx, conditions[conditionOr], engine, varNum) + builder.processWhereOr(tx, conditions[conditionOr]) } else if key == conditionGreaterThan { - varName := "var" + strconv.Itoa(*varNum) - tx.Where(*parentKey+" > @"+varName, map[string]interface{}{varName: formatCondition(condition, engine)}) - *varNum++ + builder.applyCondition(tx, *parentKey, ">", condition) } else if key == conditionLessThan { - varName := "var" + strconv.Itoa(*varNum) - tx.Where(*parentKey+" < @"+varName, map[string]interface{}{varName: formatCondition(condition, engine)}) - *varNum++ + builder.applyCondition(tx, *parentKey, "<", condition) } else if key == conditionGreaterThanOrEqual { - varName := "var" + strconv.Itoa(*varNum) - tx.Where(*parentKey+" >= @"+varName, map[string]interface{}{varName: formatCondition(condition, engine)}) - *varNum++ + builder.applyCondition(tx, *parentKey, ">=", condition) } else if key == conditionLessThanOrEqual { - varName := "var" + strconv.Itoa(*varNum) - tx.Where(*parentKey+" <= @"+varName, map[string]interface{}{varName: formatCondition(condition, engine)}) - *varNum++ + builder.applyCondition(tx, *parentKey, "<=", condition) } else if key == conditionExists { - if condition.(bool) { - tx.Where(*parentKey + " IS NOT NULL") - } else { - tx.Where(*parentKey + " IS NULL") - } - } else if StringInSlice(key, client.GetArrayFields()) { - tx.Where(whereSlice(engine, key, formatCondition(condition, engine))) - } else if StringInSlice(key, client.GetObjectFields()) { - tx.Where(whereObject(engine, key, formatCondition(condition, engine))) + builder.applyExistsCondition(tx, *parentKey, condition.(bool)) + } else if StringInSlice(key, builder.client.GetArrayFields()) { + tx.Where(builder.whereSlice(key, builder.formatCondition(condition))) + } else if StringInSlice(key, builder.client.GetObjectFields()) { + tx.Where(builder.whereObject(key, builder.formatCondition(condition))) } else { if condition == nil { tx.Where(key + " IS NULL") @@ -97,30 +133,27 @@ func processConditions(client ClientInterface, tx CustomWhereInterface, conditio switch v.Kind() { //nolint:exhaustive // not all cases are needed case reflect.Map: if _, ok := condition.(map[string]interface{}); ok { - processConditions(client, tx, condition.(map[string]interface{}), engine, varNum, &key) //nolint:scopelint // ignore for now + builder.processConditions(tx, condition.(map[string]interface{}), &key) //nolint:scopelint // ignore for now } else { c, _ := json.Marshal(condition) //nolint:errchkjson // this check might break the current code var cc map[string]interface{} _ = json.Unmarshal(c, &cc) - processConditions(client, tx, cc, engine, varNum, &key) //nolint:scopelint // ignore for now + builder.processConditions(tx, cc, &key) //nolint:scopelint // ignore for now } default: - varName := "var" + strconv.Itoa(*varNum) - tx.Where(key+" = @"+varName, map[string]interface{}{varName: formatCondition(condition, engine)}) - *varNum++ + builder.applyCondition(tx, key, "=", condition) } } } } - - return conditions } // formatCondition will format the conditions -func formatCondition(condition interface{}, engine Engine) interface{} { +func (builder *WhereBuilder) formatCondition(condition interface{}) interface{} { switch v := condition.(type) { case customtypes.NullTime: if v.Valid { + engine := builder.client.Engine() if engine == MySQL { return v.Time.Format("2006-01-02 15:04:05") } else if engine == PostgreSQL { @@ -136,24 +169,20 @@ func formatCondition(condition interface{}, engine Engine) interface{} { } // processWhereAnd will process the AND statements -func processWhereAnd(client ClientInterface, tx CustomWhereInterface, condition interface{}, engine Engine, varNum *int) { +func (builder *WhereBuilder) processWhereAnd(tx CustomWhereInterface, condition interface{}) { accumulator := &txAccumulator{ WhereClauses: make([]string, 0), Vars: make(map[string]interface{}), } for _, c := range condition.([]map[string]interface{}) { - processConditions(client, accumulator, c, engine, varNum, nil) + builder.processConditions(accumulator, c, nil) } - if len(accumulator.Vars) > 0 { - tx.Where(" ( "+strings.Join(accumulator.WhereClauses, " AND ")+" ) ", accumulator.Vars) - } else { - tx.Where(" ( " + strings.Join(accumulator.WhereClauses, " AND ") + " ) ") - } + tx.Where(" ( "+strings.Join(accumulator.WhereClauses, " AND ")+" ) ", accumulator.Vars) } // processWhereOr will process the OR statements -func processWhereOr(client ClientInterface, tx CustomWhereInterface, condition interface{}, engine Engine, varNum *int) { +func (builder *WhereBuilder) processWhereOr(tx CustomWhereInterface, condition interface{}) { or := make([]string, 0) orVars := make(map[string]interface{}) for _, cond := range condition.([]map[string]interface{}) { @@ -162,7 +191,7 @@ func processWhereOr(client ClientInterface, tx CustomWhereInterface, condition i WhereClauses: make([]string, 0), Vars: make(map[string]interface{}), } - processConditions(client, accumulator, cond, engine, varNum, nil) + builder.processConditions(accumulator, cond, nil) statement = append(statement, accumulator.WhereClauses...) for varName, varValue := range accumulator.Vars { orVars[varName] = varValue @@ -170,11 +199,7 @@ func processWhereOr(client ClientInterface, tx CustomWhereInterface, condition i or = append(or, strings.Join(statement[:], " AND ")) } - if len(orVars) > 0 { - tx.Where(" ( ("+strings.Join(or, ") OR (")+") ) ", orVars) - } else { - tx.Where(" ( (" + strings.Join(or, ") OR (") + ") ) ") - } + tx.Where(" ( ("+strings.Join(or, ") OR (")+") ) ", orVars) } // escapeDBString will escape the database string @@ -184,7 +209,7 @@ func escapeDBString(s string) string { } // whereObject generates the where object -func whereObject(engine Engine, k string, v interface{}) string { +func (builder *WhereBuilder) whereObject(k string, v interface{}) string { queryParts := make([]string, 0) // we don't know the type, we handle the rangeValue as a map[string]interface{} @@ -193,6 +218,8 @@ func whereObject(engine Engine, k string, v interface{}) string { var rangeV map[string]interface{} _ = json.Unmarshal(vJSON, &rangeV) + engine := builder.client.Engine() + for rangeKey, rangeValue := range rangeV { if engine == MySQL || engine == SQLite { switch vv := rangeValue.(type) { @@ -235,7 +262,8 @@ func whereObject(engine Engine, k string, v interface{}) string { } // whereSlice generates the where slice -func whereSlice(engine Engine, k string, v interface{}) string { +func (builder *WhereBuilder) whereSlice(k string, v interface{}) string { + engine := builder.client.Engine() if engine == MySQL { return "JSON_CONTAINS(" + k + ", CAST('[\"" + v.(string) + "\"]' AS JSON))" } else if engine == PostgreSQL { diff --git a/engine/datastore/where_test.go b/engine/datastore/where_test.go index aeab451ea..32ca3d62b 100644 --- a/engine/datastore/where_test.go +++ b/engine/datastore/where_test.go @@ -6,29 +6,88 @@ import ( "testing" "time" + "github.com/DATA-DOG/go-sqlmock" customtypes "github.com/bitcoin-sv/spv-wallet/engine/datastore/customtypes" "github.com/stretchr/testify/assert" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" "gorm.io/gorm" ) +func mockDialector(engine Engine) gorm.Dialector { + mockDb, _, _ := sqlmock.New() + switch engine { + case MySQL: + return mysql.New(mysql.Config{ + Conn: mockDb, + SkipInitializeWithVersion: true, + DriverName: "mysql", + }) + case PostgreSQL: + return postgres.New(postgres.Config{ + Conn: mockDb, + DriverName: "postgres", + }) + case SQLite: + return sqlite.Open("file::memory:?cache=shared") + case MongoDB, Empty: + // the where builder is not applicable for MongoDB + return nil + default: + return nil + } +} + +func mockClient(engine Engine) (*Client, *gorm.DB) { + clientInterface, _ := NewClient(context.Background()) + client, _ := clientInterface.(*Client) + client.options.engine = engine + dialector := mockDialector(engine) + gdb, _ := gorm.Open(dialector, &gorm.Config{}) + return client, gdb +} + +func makeWhereBuilder(client *Client, gdb *gorm.DB, model interface{}) *WhereBuilder { + return &WhereBuilder{ + client: client, + gdb: gdb.Model(model), + varNum: 0, + } +} + +type mockObject struct { + ID string + CreatedAt time.Time + UniqueFieldName string + Number int + ReferenceID string +} + // Test_whereObject test the SQL where selector func Test_whereSlice(t *testing.T) { t.Parallel() t.Run("MySQL", func(t *testing.T) { - query := whereSlice(MySQL, fieldInIDs, "id_1") + client, gdb := mockClient(MySQL) + builder := makeWhereBuilder(client, gdb, mockObject{}) + query := builder.whereSlice(fieldInIDs, "id_1") expected := `JSON_CONTAINS(` + fieldInIDs + `, CAST('["id_1"]' AS JSON))` assert.Equal(t, expected, query) }) t.Run("Postgres", func(t *testing.T) { - query := whereSlice(PostgreSQL, fieldInIDs, "id_1") + client, gdb := mockClient(PostgreSQL) + builder := makeWhereBuilder(client, gdb, mockObject{}) + query := builder.whereSlice(fieldInIDs, "id_1") expected := fieldInIDs + `::jsonb @> '["id_1"]'` assert.Equal(t, expected, query) }) t.Run("SQLite", func(t *testing.T) { - query := whereSlice(SQLite, fieldInIDs, "id_1") + client, gdb := mockClient(SQLite) + builder := makeWhereBuilder(client, gdb, mockObject{}) + query := builder.whereSlice(fieldInIDs, "id_1") expected := `EXISTS (SELECT 1 FROM json_each(` + fieldInIDs + `) WHERE value = "id_1")` assert.Equal(t, expected, query) }) @@ -40,13 +99,15 @@ func Test_processConditions(t *testing.T) { dateField := dateCreatedAt uniqueField := "unique_field_name" + theTime := time.Date(2022, 4, 4, 15, 12, 37, 651387237, time.UTC) + nullTime := sql.NullTime{ + Valid: true, + Time: theTime, + } conditions := map[string]interface{}{ dateField: map[string]interface{}{ - conditionGreaterThan: customtypes.NullTime{NullTime: sql.NullTime{ - Valid: true, - Time: time.Date(2022, 4, 4, 15, 12, 37, 651387237, time.UTC), - }}, + conditionGreaterThan: customtypes.NullTime{NullTime: nullTime}, }, uniqueField: map[string]interface{}{ conditionExists: true, @@ -54,51 +115,48 @@ func Test_processConditions(t *testing.T) { } t.Run("MySQL", func(t *testing.T) { - client, deferFunc := testClient(context.Background(), t) - defer deferFunc() - tx := &mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } - varNum := 0 - _ = processConditions(client, tx, conditions, MySQL, &varNum, nil) - // assert.Equal(t, "created_at > @var0", tx.WhereClauses[0]) - assert.Contains(t, tx.WhereClauses, dateField+" > @var0") - // assert.Equal(t, "unique_field_name IS NOT NULL", tx.WhereClauses[1]) - assert.Contains(t, tx.WhereClauses, uniqueField+" IS NOT NULL") - assert.Equal(t, "2022-04-04 15:12:37", tx.Vars["var0"]) + client, gdb := mockClient(MySQL) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx = tx.Model(mockObject{}) + err := ApplyCustomWhere(client, tx, conditions) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Contains(t, raw, "2022-04-04 15:12:37") + assert.Contains(t, raw, "AND") + assert.Regexp(t, "(.+)unique_field_name(.+)IS NOT NULL", raw) }) t.Run("Postgres", func(t *testing.T) { - client, deferFunc := testClient(context.Background(), t) - defer deferFunc() - tx := &mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } - varNum := 0 - _ = processConditions(client, tx, conditions, PostgreSQL, &varNum, nil) - // assert.Equal(t, "created_at > @var0", tx.WhereClauses[0]) - assert.Contains(t, tx.WhereClauses, dateField+" > @var0") - // assert.Equal(t, "unique_field_name IS NOT NULL", tx.WhereClauses[1]) - assert.Contains(t, tx.WhereClauses, uniqueField+" IS NOT NULL") - assert.Equal(t, "2022-04-04T15:12:37Z", tx.Vars["var0"]) + client, gdb := mockClient(PostgreSQL) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx = tx.Model(mockObject{}) + err := ApplyCustomWhere(client, tx, conditions) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Contains(t, raw, "2022-04-04T15:12:37Z") + assert.Contains(t, raw, "AND") + assert.Regexp(t, "(.+)unique_field_name(.+)IS NOT NULL", raw) }) t.Run("SQLite", func(t *testing.T) { - client, deferFunc := testClient(context.Background(), t) - defer deferFunc() - tx := &mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } - varNum := 0 - _ = processConditions(client, tx, conditions, SQLite, &varNum, nil) - // assert.Equal(t, "created_at > @var0", tx.WhereClauses[0]) - assert.Contains(t, tx.WhereClauses, dateField+" > @var0") - // assert.Equal(t, "unique_field_name IS NOT NULL", tx.WhereClauses[1]) - assert.Contains(t, tx.WhereClauses, uniqueField+" IS NOT NULL") - assert.Equal(t, "2022-04-04T15:12:37.651Z", tx.Vars["var0"]) + client, gdb := mockClient(SQLite) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx = tx.Model(mockObject{}) + err := ApplyCustomWhere(client, tx, conditions) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Contains(t, raw, "2022-04-04T15:12:37.651Z") + assert.Contains(t, raw, "AND") + assert.Regexp(t, "(.+)unique_field_name(.+)IS NOT NULL", raw) }) } @@ -107,17 +165,20 @@ func Test_whereObject(t *testing.T) { t.Parallel() t.Run("MySQL", func(t *testing.T) { + client, gdb := mockClient(MySQL) + builder := makeWhereBuilder(client, gdb, mockObject{}) + metadata := map[string]interface{}{ "test_key": "test-value", } - query := whereObject(MySQL, metadataField, metadata) + query := builder.whereObject(metadataField, metadata) expected := "JSON_EXTRACT(" + metadataField + ", '$.test_key') = \"test-value\"" assert.Equal(t, expected, query) metadata = map[string]interface{}{ "test_key": "test-'value'", } - query = whereObject(MySQL, metadataField, metadata) + query = builder.whereObject(metadataField, metadata) expected = "JSON_EXTRACT(" + metadataField + ", '$.test_key') = \"test-\\'value\\'\"" assert.Equal(t, expected, query) @@ -125,7 +186,7 @@ func Test_whereObject(t *testing.T) { "test_key1": "test-value", "test_key2": "test-value2", } - query = whereObject(MySQL, metadataField, metadata) + query = builder.whereObject(metadataField, metadata) assert.Contains(t, []string{ "(JSON_EXTRACT(" + metadataField + ", '$.test_key1') = \"test-value\" AND JSON_EXTRACT(" + metadataField + ", '$.test_key2') = \"test-value2\")", @@ -141,7 +202,7 @@ func Test_whereObject(t *testing.T) { "test_key2": "test-value2", }, } - query = whereObject(MySQL, "object_metadata", objectMetadata) + query = builder.whereObject("object_metadata", objectMetadata) assert.Contains(t, []string{ "(JSON_EXTRACT(object_metadata, '$.testId.test_key1') = \"test-value\" AND JSON_EXTRACT(object_metadata, '$.testId.test_key2') = \"test-value2\")", @@ -153,17 +214,20 @@ func Test_whereObject(t *testing.T) { }) t.Run("Postgres", func(t *testing.T) { + client, gdb := mockClient(PostgreSQL) + builder := makeWhereBuilder(client, gdb, mockObject{}) + metadata := map[string]interface{}{ "test_key": "test-value", } - query := whereObject(PostgreSQL, metadataField, metadata) + query := builder.whereObject(metadataField, metadata) expected := metadataField + "::jsonb @> '{\"test_key\":\"test-value\"}'::jsonb" assert.Equal(t, expected, query) metadata = map[string]interface{}{ "test_key": "test-'value'", } - query = whereObject(PostgreSQL, metadataField, metadata) + query = builder.whereObject(metadataField, metadata) expected = metadataField + "::jsonb @> '{\"test_key\":\"test-\\'value\\'\"}'::jsonb" assert.Equal(t, expected, query) @@ -171,7 +235,7 @@ func Test_whereObject(t *testing.T) { "test_key1": "test-value", "test_key2": "test-value2", } - query = whereObject(PostgreSQL, metadataField, metadata) + query = builder.whereObject(metadataField, metadata) assert.Contains(t, []string{ "(" + metadataField + "::jsonb @> '{\"test_key1\":\"test-value\"}'::jsonb AND " + metadataField + "::jsonb @> '{\"test_key2\":\"test-value2\"}'::jsonb)", @@ -187,7 +251,7 @@ func Test_whereObject(t *testing.T) { "test_key2": "test-value2", }, } - query = whereObject(PostgreSQL, "object_metadata", objectMetadata) + query = builder.whereObject("object_metadata", objectMetadata) assert.Contains(t, []string{ "object_metadata::jsonb @> '{\"testId\":{\"test_key1\":\"test-value\",\"test_key2\":\"test-value2\"}}'::jsonb", "object_metadata::jsonb @> '{\"testId\":{\"test_key2\":\"test-value2\",\"test_key1\":\"test-value\"}}'::jsonb", @@ -198,17 +262,20 @@ func Test_whereObject(t *testing.T) { }) t.Run("SQLite", func(t *testing.T) { + client, gdb := mockClient(SQLite) + builder := makeWhereBuilder(client, gdb, mockObject{}) + metadata := map[string]interface{}{ "test_key": "test-value", } - query := whereObject(SQLite, metadataField, metadata) + query := builder.whereObject(metadataField, metadata) expected := "JSON_EXTRACT(" + metadataField + ", '$.test_key') = \"test-value\"" assert.Equal(t, expected, query) metadata = map[string]interface{}{ "test_key": "test-'value'", } - query = whereObject(SQLite, metadataField, metadata) + query = builder.whereObject(metadataField, metadata) expected = "JSON_EXTRACT(" + metadataField + ", '$.test_key') = \"test-\\'value\\'\"" assert.Equal(t, expected, query) @@ -216,7 +283,7 @@ func Test_whereObject(t *testing.T) { "test_key1": "test-value", "test_key2": "test-value2", } - query = whereObject(SQLite, metadataField, metadata) + query = builder.whereObject(metadataField, metadata) assert.Contains(t, []string{ "(JSON_EXTRACT(" + metadataField + ", '$.test_key1') = \"test-value\" AND JSON_EXTRACT(" + metadataField + ", '$.test_key2') = \"test-value2\")", "(JSON_EXTRACT(" + metadataField + ", '$.test_key2') = \"test-value2\" AND JSON_EXTRACT(" + metadataField + ", '$.test_key1') = \"test-value\")", @@ -231,7 +298,7 @@ func Test_whereObject(t *testing.T) { "test_key2": "test-value2", }, } - query = whereObject(SQLite, "object_metadata", objectMetadata) + query = builder.whereObject("object_metadata", objectMetadata) assert.Contains(t, []string{ "(JSON_EXTRACT(object_metadata, '$.testId.test_key1') = \"test-value\" AND JSON_EXTRACT(object_metadata, '$.testId.test_key2') = \"test-value2\")", "(JSON_EXTRACT(object_metadata, '$.testId.test_key2') = \"test-value2\" AND JSON_EXTRACT(object_metadata, '$.testId.test_key1') = \"test-value\")", @@ -241,91 +308,50 @@ func Test_whereObject(t *testing.T) { }) } -// mockSQLCtx is used to mock the SQL -type mockSQLCtx struct { - WhereClauses []interface{} - Vars map[string]interface{} -} - -func (f *mockSQLCtx) Where(query interface{}, args ...interface{}) { - f.WhereClauses = append(f.WhereClauses, query) - if len(args) > 0 { - for _, variables := range args { - for key, value := range variables.(map[string]interface{}) { - f.Vars[key] = value - } - } - } -} - -func (f *mockSQLCtx) getGormTx() *gorm.DB { - return nil -} - // TestCustomWhere will test the method CustomWhere() func TestCustomWhere(t *testing.T) { t.Parallel() t.Run("SQLite empty select", func(t *testing.T) { - client, deferFunc := testClient(context.Background(), t) - defer deferFunc() - tx := mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } + client, gdb := mockClient(SQLite) + conditions := map[string]interface{}{} - _ = client.CustomWhere(&tx, conditions, SQLite) - assert.Equal(t, []interface{}{}, tx.WhereClauses) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx = tx.Model(mockObject{}) + err := ApplyCustomWhere(client, tx, conditions) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Regexp(t, "SELECT(.+)FROM(.+)ORDER BY(.+)LIMIT 1", raw) + assert.NotContains(t, raw, "WHERE") }) t.Run("SQLite simple select", func(t *testing.T) { - client, deferFunc := testClient(context.Background(), t) - defer deferFunc() - tx := mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } + client, gdb := mockClient(SQLite) + conditions := map[string]interface{}{ sqlIDFieldProper: "testID", } - _ = client.CustomWhere(&tx, conditions, SQLite) - assert.Len(t, tx.WhereClauses, 1) - assert.Equal(t, sqlIDFieldProper+" = @var0", tx.WhereClauses[0]) - assert.Equal(t, "testID", tx.Vars["var0"]) - }) - t.Run("SQLite "+conditionOr, func(t *testing.T) { - arrayField1 := fieldInIDs - arrayField2 := fieldOutIDs + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx = tx.Model(mockObject{}) + err := ApplyCustomWhere(client, tx, conditions) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) - client, deferFunc := testClient(context.Background(), t, WithCustomFields([]string{arrayField1, arrayField2}, nil)) - defer deferFunc() - tx := mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } - conditions := map[string]interface{}{ - conditionOr: []map[string]interface{}{{ - arrayField1: "value_id", - }, { - arrayField2: "value_id", - }}, - } - _ = client.CustomWhere(&tx, conditions, SQLite) - assert.Len(t, tx.WhereClauses, 1) - assert.Equal(t, " ( (EXISTS (SELECT 1 FROM json_each("+arrayField1+") WHERE value = \"value_id\")) OR (EXISTS (SELECT 1 FROM json_each("+arrayField2+") WHERE value = \"value_id\")) ) ", tx.WhereClauses[0]) + assert.Regexp(t, "SELECT(.+)FROM(.+)WHERE(.+)id(.*)\\=(.*)testID(.+)ORDER BY(.+)LIMIT 1", raw) }) - t.Run("MySQL "+conditionOr, func(t *testing.T) { + t.Run("SQLite $or in json", func(t *testing.T) { arrayField1 := fieldInIDs arrayField2 := fieldOutIDs - client, deferFunc := testClient(context.Background(), t, WithCustomFields([]string{arrayField1, arrayField2}, nil)) - defer deferFunc() - tx := mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } + client, gdb := mockClient(SQLite) + WithCustomFields([]string{arrayField1, arrayField2}, nil)(client.options) + conditions := map[string]interface{}{ conditionOr: []map[string]interface{}{{ arrayField1: "value_id", @@ -333,21 +359,26 @@ func TestCustomWhere(t *testing.T) { arrayField2: "value_id", }}, } - _ = client.CustomWhere(&tx, conditions, MySQL) - assert.Len(t, tx.WhereClauses, 1) - assert.Equal(t, " ( (JSON_CONTAINS("+arrayField1+", CAST('[\"value_id\"]' AS JSON))) OR (JSON_CONTAINS("+arrayField2+", CAST('[\"value_id\"]' AS JSON))) ) ", tx.WhereClauses[0]) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx = tx.Model(mockObject{}) + err := ApplyCustomWhere(client, tx, conditions) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Contains(t, raw, "json_each(field_in_ids) WHERE value = \"value_id\"") + assert.Contains(t, raw, "json_each(field_out_ids) WHERE value = \"value_id\"") + assert.Regexp(t, "SELECT(.+)FROM(.+)WHERE(.+)EXISTS(.+)SELECT 1(.+)FROM(.+)json_each(.+)WHERE(.+)OR(.+)EXISTS(.+)SELECT 1(.+)FROM(.+)json_each(.+)WHERE(.+)ORDER BY(.+)LIMIT 1", raw) }) - t.Run("PostgreSQL "+conditionOr, func(t *testing.T) { + t.Run("PostgreSQL $or in json", func(t *testing.T) { arrayField1 := fieldInIDs arrayField2 := fieldOutIDs - client, deferFunc := testClient(context.Background(), t, WithCustomFields([]string{arrayField1, arrayField2}, nil)) - defer deferFunc() - tx := mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } + client, gdb := mockClient(PostgreSQL) + WithCustomFields([]string{arrayField1, arrayField2}, nil)(client.options) + conditions := map[string]interface{}{ conditionOr: []map[string]interface{}{{ arrayField1: "value_id", @@ -355,72 +386,80 @@ func TestCustomWhere(t *testing.T) { arrayField2: "value_id", }}, } - _ = client.CustomWhere(&tx, conditions, PostgreSQL) - assert.Len(t, tx.WhereClauses, 1) - assert.Equal(t, " ( ("+arrayField1+"::jsonb @> '[\"value_id\"]') OR ("+arrayField2+"::jsonb @> '[\"value_id\"]') ) ", tx.WhereClauses[0]) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx = tx.Model(mockObject{}) + err := ApplyCustomWhere(client, tx, conditions) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Contains(t, raw, "field_in_ids::jsonb @> '[\"value_id\"]'") + assert.Contains(t, raw, "field_out_ids::jsonb @> '[\"value_id\"]") + assert.Regexp(t, "SELECT(.+)FROM(.+)WHERE(.+)field_(in|out)_ids(.+)OR(.+)field_(in|out)_ids(.+)ORDER BY(.+)LIMIT 1", raw) }) - t.Run("SQLite "+metadataField, func(t *testing.T) { - client, deferFunc := testClient(context.Background(), t) - defer deferFunc() - tx := mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } + t.Run("SQLite metadata", func(t *testing.T) { + client, gdb := mockClient(SQLite) conditions := map[string]interface{}{ metadataField: map[string]interface{}{ "field_name": "field_value", }, } - _ = client.CustomWhere(&tx, conditions, SQLite) - assert.Len(t, tx.WhereClauses, 1) - assert.Equal(t, "JSON_EXTRACT("+metadataField+", '$.field_name') = \"field_value\"", tx.WhereClauses[0]) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx = tx.Model(mockObject{}) + err := ApplyCustomWhere(client, tx, conditions) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Contains(t, raw, "JSON_EXTRACT(metadata, '$.field_name') = \"field_value\"") }) - t.Run("MySQL "+metadataField, func(t *testing.T) { - client, deferFunc := testClient(context.Background(), t) - defer deferFunc() - tx := mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } + t.Run("MySQL metadata", func(t *testing.T) { + client, gdb := mockClient(MySQL) conditions := map[string]interface{}{ metadataField: map[string]interface{}{ "field_name": "field_value", }, } - _ = client.CustomWhere(&tx, conditions, MySQL) - assert.Len(t, tx.WhereClauses, 1) - assert.Equal(t, "JSON_EXTRACT("+metadataField+", '$.field_name') = \"field_value\"", tx.WhereClauses[0]) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx = tx.Model(mockObject{}) + err := ApplyCustomWhere(client, tx, conditions) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Contains(t, raw, "JSON_EXTRACT(metadata, '$.field_name') = \"field_value\"") }) - t.Run("PostgreSQL "+metadataField, func(t *testing.T) { - client, deferFunc := testClient(context.Background(), t) - defer deferFunc() - tx := mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } + t.Run("PostgreSQL metadata", func(t *testing.T) { + client, gdb := mockClient(PostgreSQL) conditions := map[string]interface{}{ metadataField: map[string]interface{}{ "field_name": "field_value", }, } - _ = client.CustomWhere(&tx, conditions, PostgreSQL) - assert.Len(t, tx.WhereClauses, 1) - assert.Equal(t, metadataField+"::jsonb @> '{\"field_name\":\"field_value\"}'::jsonb", tx.WhereClauses[0]) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx = tx.Model(mockObject{}) + err := ApplyCustomWhere(client, tx, conditions) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Contains(t, raw, "metadata::jsonb @> '{\"field_name\":\"field_value\"}'::jsonb") }) - t.Run("SQLite "+conditionAnd, func(t *testing.T) { + t.Run("SQLite $and", func(t *testing.T) { arrayField1 := fieldInIDs arrayField2 := fieldOutIDs - client, deferFunc := testClient(context.Background(), t, WithCustomFields([]string{arrayField1, arrayField2}, nil)) - defer deferFunc() - tx := mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } + client, gdb := mockClient(SQLite) + WithCustomFields([]string{arrayField1, arrayField2}, nil)(client.options) + conditions := map[string]interface{}{ conditionAnd: []map[string]interface{}{{ "reference_id": "reference", @@ -434,23 +473,26 @@ func TestCustomWhere(t *testing.T) { }}, }}, } - _ = client.CustomWhere(&tx, conditions, SQLite) - assert.Len(t, tx.WhereClauses, 1) - assert.Equal(t, " ( reference_id = @var0 AND number = @var1 AND ( (EXISTS (SELECT 1 FROM json_each("+arrayField1+") WHERE value = \"value_id\")) OR (EXISTS (SELECT 1 FROM json_each("+arrayField2+") WHERE value = \"value_id\")) ) ) ", tx.WhereClauses[0]) - assert.Equal(t, "reference", tx.Vars["var0"]) - assert.Equal(t, 12, tx.Vars["var1"]) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx = tx.Model(mockObject{}) + err := ApplyCustomWhere(client, tx, conditions) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Regexp(t, "reference_id(.*)\\=(.*)reference", raw) + assert.Regexp(t, "number(.*)\\=(.*)12", raw) + assert.Regexp(t, "AND(.*)AND", raw) }) - t.Run("MySQL "+conditionAnd, func(t *testing.T) { + t.Run("MySQL $and", func(t *testing.T) { arrayField1 := fieldInIDs arrayField2 := fieldOutIDs - client, deferFunc := testClient(context.Background(), t, WithCustomFields([]string{arrayField1, arrayField2}, nil)) - defer deferFunc() - tx := mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } + client, gdb := mockClient(MySQL) + WithCustomFields([]string{arrayField1, arrayField2}, nil)(client.options) + conditions := map[string]interface{}{ conditionAnd: []map[string]interface{}{{ "reference_id": "reference", @@ -464,23 +506,26 @@ func TestCustomWhere(t *testing.T) { }}, }}, } - _ = client.CustomWhere(&tx, conditions, MySQL) - assert.Len(t, tx.WhereClauses, 1) - assert.Equal(t, " ( reference_id = @var0 AND number = @var1 AND ( (JSON_CONTAINS("+arrayField1+", CAST('[\"value_id\"]' AS JSON))) OR (JSON_CONTAINS("+arrayField2+", CAST('[\"value_id\"]' AS JSON))) ) ) ", tx.WhereClauses[0]) - assert.Equal(t, "reference", tx.Vars["var0"]) - assert.Equal(t, 12, tx.Vars["var1"]) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx = tx.Model(mockObject{}) + err := ApplyCustomWhere(client, tx, conditions) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Regexp(t, "reference_id(.*)\\=(.*)reference", raw) + assert.Regexp(t, "number(.*)\\=(.*)12", raw) + assert.Regexp(t, "AND(.*)AND", raw) }) - t.Run("PostgreSQL "+conditionAnd, func(t *testing.T) { + t.Run("PostgreSQL $and", func(t *testing.T) { arrayField1 := fieldInIDs arrayField2 := fieldOutIDs - client, deferFunc := testClient(context.Background(), t, WithCustomFields([]string{arrayField1, arrayField2}, nil)) - defer deferFunc() - tx := mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } + client, gdb := mockClient(PostgreSQL) + WithCustomFields([]string{arrayField1, arrayField2}, nil)(client.options) + conditions := map[string]interface{}{ conditionAnd: []map[string]interface{}{{ "reference_id": "reference", @@ -494,181 +539,142 @@ func TestCustomWhere(t *testing.T) { }}, }}, } - _ = client.CustomWhere(&tx, conditions, PostgreSQL) - assert.Len(t, tx.WhereClauses, 1) - assert.Equal(t, " ( reference_id = @var0 AND number = @var1 AND ( ("+arrayField1+"::jsonb @> '[\"value_id\"]') OR ("+arrayField2+"::jsonb @> '[\"value_id\"]') ) ) ", tx.WhereClauses[0]) - assert.Equal(t, "reference", tx.Vars["var0"]) - assert.Equal(t, 12, tx.Vars["var1"]) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx = tx.Model(mockObject{}) + err := ApplyCustomWhere(client, tx, conditions) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Regexp(t, "reference_id(.*)\\=(.*)reference", raw) + assert.Regexp(t, "number(.*)\\=(.*)12", raw) + assert.Regexp(t, "AND(.*)AND", raw) }) - t.Run("Where "+conditionGreaterThan, func(t *testing.T) { - client, deferFunc := testClient(context.Background(), t) - defer deferFunc() - tx := mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } + t.Run("Where $gt", func(t *testing.T) { + client, gdb := mockClient(PostgreSQL) + conditions := map[string]interface{}{ - "amount": map[string]interface{}{ + "number": map[string]interface{}{ conditionGreaterThan: 502, }, } - _ = client.CustomWhere(&tx, conditions, PostgreSQL) // all the same - assert.Len(t, tx.WhereClauses, 1) - assert.Equal(t, "amount > @var0", tx.WhereClauses[0]) - assert.Equal(t, 502, tx.Vars["var0"]) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx = tx.Model(mockObject{}) + err := ApplyCustomWhere(client, tx, conditions) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Regexp(t, "number(.*)\\>(.*)502", raw) }) - t.Run("Where "+conditionGreaterThan+" "+conditionLessThan, func(t *testing.T) { - client, deferFunc := testClient(context.Background(), t) - defer deferFunc() - tx := mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } + t.Run("Where $and $gt $lt", func(t *testing.T) { + client, gdb := mockClient(PostgreSQL) + conditions := map[string]interface{}{ conditionAnd: []map[string]interface{}{{ - "amount": map[string]interface{}{ + "number": map[string]interface{}{ conditionLessThan: 503, }, }, { - "amount": map[string]interface{}{ + "number": map[string]interface{}{ conditionGreaterThan: 203, }, }}, } - _ = client.CustomWhere(&tx, conditions, PostgreSQL) // all the same - assert.Len(t, tx.WhereClauses, 1) - assert.Equal(t, " ( amount < @var0 AND amount > @var1 ) ", tx.WhereClauses[0]) - assert.Equal(t, 503, tx.Vars["var0"]) - assert.Equal(t, 203, tx.Vars["var1"]) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx = tx.Model(mockObject{}) + err := ApplyCustomWhere(client, tx, conditions) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + // the order may vary + assert.Regexp(t, "number(.*)\\>(.*)203", raw) + assert.Regexp(t, "number(.*)\\<(.*)503", raw) + + assert.Regexp(t, "number(.*)[\\d](.*)AND(.*)number(.*)[\\d]", raw) }) - t.Run("Where "+conditionGreaterThanOrEqual+" "+conditionLessThanOrEqual, func(t *testing.T) { - client, deferFunc := testClient(context.Background(), t) - defer deferFunc() - tx := mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } + t.Run("Where $or $gte $lte", func(t *testing.T) { + client, gdb := mockClient(PostgreSQL) + conditions := map[string]interface{}{ conditionOr: []map[string]interface{}{{ - "amount": map[string]interface{}{ - conditionLessThanOrEqual: 203, + "number": map[string]interface{}{ + conditionLessThanOrEqual: 503, }, }, { - "amount": map[string]interface{}{ - conditionGreaterThanOrEqual: 1203, + "number": map[string]interface{}{ + conditionGreaterThanOrEqual: 203, }, }}, } - _ = client.CustomWhere(&tx, conditions, PostgreSQL) // all the same - assert.Len(t, tx.WhereClauses, 1) - assert.Equal(t, " ( (amount <= @var0) OR (amount >= @var1) ) ", tx.WhereClauses[0]) - assert.Equal(t, 203, tx.Vars["var0"]) - assert.Equal(t, 1203, tx.Vars["var1"]) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx = tx.Model(mockObject{}) + err := ApplyCustomWhere(client, tx, conditions) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + // the order may vary + assert.Regexp(t, "number(.*)\\>\\=(.*)203", raw) + assert.Regexp(t, "number(.*)\\<\\=(.*)503", raw) + + assert.Regexp(t, "number(.*)[\\d](.*)OR(.*)number(.*)[\\d]", raw) }) - t.Run("Where "+conditionOr+" "+conditionAnd+" "+conditionOr+" "+conditionGreaterThanOrEqual+" "+conditionLessThanOrEqual, func(t *testing.T) { - client, deferFunc := testClient(context.Background(), t) - defer deferFunc() - tx := mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } + t.Run("Where $or $and $or $gte $lte", func(t *testing.T) { + client, gdb := mockClient(PostgreSQL) + conditions := map[string]interface{}{ conditionOr: []map[string]interface{}{{ conditionAnd: []map[string]interface{}{{ - "amount": map[string]interface{}{ + "number": map[string]interface{}{ conditionLessThanOrEqual: 203, }, }, { conditionOr: []map[string]interface{}{{ - "amount": map[string]interface{}{ + "number": map[string]interface{}{ conditionGreaterThanOrEqual: 1203, }, }, { - "value": map[string]interface{}{ + "number": map[string]interface{}{ conditionGreaterThanOrEqual: 2203, }, }}, }}, }, { conditionAnd: []map[string]interface{}{{ - "amount": map[string]interface{}{ + "number": map[string]interface{}{ conditionGreaterThanOrEqual: 3203, }, }, { - "value": map[string]interface{}{ + "number": map[string]interface{}{ conditionGreaterThanOrEqual: 4203, }, }}, }}, } - _ = client.CustomWhere(&tx, conditions, PostgreSQL) // all the same - assert.Len(t, tx.WhereClauses, 1) - assert.Equal(t, " ( ( ( amount <= @var0 AND ( (amount >= @var1) OR (value >= @var2) ) ) ) OR ( ( amount >= @var3 AND value >= @var4 ) ) ) ", tx.WhereClauses[0]) - assert.Equal(t, 203, tx.Vars["var0"]) - assert.Equal(t, 1203, tx.Vars["var1"]) - assert.Equal(t, 2203, tx.Vars["var2"]) - assert.Equal(t, 3203, tx.Vars["var3"]) - assert.Equal(t, 4203, tx.Vars["var4"]) - }) - t.Run("Where "+conditionAnd+" "+conditionOr+" "+conditionOr+" "+conditionGreaterThanOrEqual+" "+conditionLessThanOrEqual, func(t *testing.T) { - client, deferFunc := testClient(context.Background(), t) - defer deferFunc() - tx := mockSQLCtx{ - WhereClauses: make([]interface{}, 0), - Vars: make(map[string]interface{}), - } - conditions := map[string]interface{}{ - conditionAnd: []map[string]interface{}{{ - conditionAnd: []map[string]interface{}{{ - "amount": map[string]interface{}{ - conditionLessThanOrEqual: 203, - conditionGreaterThanOrEqual: 103, - }, - }, { - conditionOr: []map[string]interface{}{{ - "amount": map[string]interface{}{ - conditionGreaterThanOrEqual: 1203, - }, - }, { - "value": map[string]interface{}{ - conditionGreaterThanOrEqual: 2203, - }, - }}, - }}, - }, { - conditionOr: []map[string]interface{}{{ - "amount": map[string]interface{}{ - conditionGreaterThanOrEqual: 3203, - }, - }, { - "value": map[string]interface{}{ - conditionGreaterThanOrEqual: 4203, - }, - }}, - }}, - } - _ = client.CustomWhere(&tx, conditions, PostgreSQL) // all the same - assert.Len(t, tx.WhereClauses, 1) - assert.Contains(t, []string{ - " ( ( amount <= @var0 AND amount >= @var1 AND ( (amount >= @var2) OR (value >= @var3) ) ) AND ( (amount >= @var4) OR (value >= @var5) ) ) ", - " ( ( amount >= @var0 AND amount <= @var1 AND ( (amount >= @var2) OR (value >= @var3) ) ) AND ( (amount >= @var4) OR (value >= @var5) ) ) ", - }, tx.WhereClauses[0]) - - // assert.Equal(t, " ( ( amount <= @var0 AND amount >= @var1 AND ( (amount >= @var2) OR (value >= @var3) ) ) AND ( (amount >= @var4) OR (value >= @var5) ) ) ", tx.WhereClauses[0]) - - assert.Contains(t, []int{203, 103}, tx.Vars["var0"]) - assert.Contains(t, []int{203, 103}, tx.Vars["var1"]) - // assert.Equal(t, 203, tx.Vars["var0"]) - // assert.Equal(t, 103, tx.Vars["var1"]) - assert.Equal(t, 1203, tx.Vars["var2"]) - assert.Equal(t, 2203, tx.Vars["var3"]) - assert.Equal(t, 3203, tx.Vars["var4"]) - assert.Equal(t, 4203, tx.Vars["var5"]) + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx = tx.Model(mockObject{}) + err := ApplyCustomWhere(client, tx, conditions) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Regexp(t, "number(.*)\\>\\=(.*)203", raw) + assert.Regexp(t, "number(.*)\\>\\=(.*)1203", raw) + assert.Regexp(t, "number(.*)\\<\\=(.*)2203", raw) + assert.Regexp(t, "number(.*)\\>\\=(.*)3203", raw) + assert.Regexp(t, "number(.*)\\<\\=(.*)4203", raw) + assert.Regexp(t, "AND(.+)OR(.+)AND", raw) }) } diff --git a/go.mod b/go.mod index 5b5a8f2b9..d26bdef58 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,8 @@ require ( github.com/swaggo/swag v1.16.3 ) +require github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect + require ( github.com/99designs/gqlgen v0.17.43 // indirect github.com/KyleBanks/depth v1.2.1 // indirect diff --git a/go.sum b/go.sum index 68f2047ec..ae7621e62 100644 --- a/go.sum +++ b/go.sum @@ -204,6 +204,7 @@ github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFF github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI= github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= From 988d07daf620f79b73ac283482412be7bf4fe752 Mon Sep 17 00:00:00 2001 From: Krzysztof Tomecki <152964795+chris-4chain@users.noreply.github.com> Date: Wed, 27 Mar 2024 09:33:55 +0100 Subject: [PATCH 02/12] fix(BUX-686): get gorm model field name --- engine/datastore/column_name.go | 38 ++++++++++++++++++++++ engine/datastore/where.go | 57 ++++++++++++++------------------- engine/datastore/where_test.go | 4 +-- 3 files changed, 64 insertions(+), 35 deletions(-) create mode 100644 engine/datastore/column_name.go diff --git a/engine/datastore/column_name.go b/engine/datastore/column_name.go new file mode 100644 index 000000000..1e934e889 --- /dev/null +++ b/engine/datastore/column_name.go @@ -0,0 +1,38 @@ +package datastore + +import ( + "sync" + + "gorm.io/gorm" + "gorm.io/gorm/schema" +) + +var modelsCache = sync.Map{} + +// GetColumnName checks if the the model has provided columnName as DBName or (struct field)Name +// Returns (DBName, true) if the column exists otherwise (_, false) +// Uses global cache store (thread safe) +// Checking is case-sensitive +// The gdb param is optional. When is provided, the actual naming strategy is used; otherwise default +func GetColumnName(columnName string, model interface{}, gdb *gorm.DB) (string, bool) { + var namer schema.Namer + if gdb != nil { + namer = gdb.NamingStrategy + } else { + namer = schema.NamingStrategy{} + } + + sch, err := schema.Parse(model, &modelsCache, namer) + if err != nil { + return "", false + } + if field, ok := sch.FieldsByDBName[columnName]; ok { + return field.DBName, true + } + + if field, ok := sch.FieldsByName[columnName]; ok { + return field.DBName, true + } + + return "", false +} diff --git a/engine/datastore/where.go b/engine/datastore/where.go index 645d87d53..91282ffa0 100644 --- a/engine/datastore/where.go +++ b/engine/datastore/where.go @@ -6,24 +6,23 @@ import ( "reflect" "strconv" "strings" - "sync" customtypes "github.com/bitcoin-sv/spv-wallet/engine/datastore/customtypes" "gorm.io/gorm" - "gorm.io/gorm/schema" ) -type CustomWhereInterface interface { +// customWhereInterface with single method Where which aligns with gorm.DB.Where +type customWhereInterface interface { Where(query interface{}, args ...interface{}) *gorm.DB } +// txAccumulator holds the state of the nested conditions for recursive processing type txAccumulator struct { - CustomWhereInterface WhereClauses []string Vars map[string]interface{} } -// Where is our custom where method +// Where makes txAccumulator implement customWhereInterface which will overload gorm.DB.Where behavior func (tx *txAccumulator) Where(query interface{}, args ...interface{}) *gorm.DB { tx.WhereClauses = append(tx.WhereClauses, query.(string)) @@ -38,14 +37,14 @@ func (tx *txAccumulator) Where(query interface{}, args ...interface{}) *gorm.DB return nil } -// WhereBuilder holds a state during custom where preparation -type WhereBuilder struct { +// whereBuilder holds a state during custom where preparation +type whereBuilder struct { client ClientInterface gdb *gorm.DB varNum int } -// ApplyCustomWhere adds conditions to the gorm db instance +// ApplyCustomWhere adds conditions (in-place) to the gorm db instance func ApplyCustomWhere(client ClientInterface, gdb *gorm.DB, conditions map[string]interface{}) (err error) { defer func() { if r := recover(); r != nil { @@ -53,7 +52,7 @@ func ApplyCustomWhere(client ClientInterface, gdb *gorm.DB, conditions map[strin } }() - builder := &WhereBuilder{ + builder := &whereBuilder{ client: client, gdb: gdb, varNum: 0, @@ -63,38 +62,31 @@ func ApplyCustomWhere(client ClientInterface, gdb *gorm.DB, conditions map[strin return nil } -func (builder *WhereBuilder) nextVarName() string { +func (builder *whereBuilder) nextVarName() string { varName := "var" + strconv.Itoa(builder.varNum) builder.varNum++ return varName } -func getColumnName(columnName string, model interface{}) string { - sch, err := schema.Parse(model, &sync.Map{}, schema.NamingStrategy{}) - if err != nil { - panic(fmt.Errorf("cannot parse a model %v", model)) - } - if field, ok := sch.FieldsByDBName[columnName]; ok { - return field.DBName - } - - if field, ok := sch.FieldsByName[columnName]; ok { - return field.DBName +func (builder *whereBuilder) getColumnNameOrPanic(columnName string, model interface{}) string { + columnName, ok := GetColumnName(columnName, model, builder.gdb) + if !ok { + panic(fmt.Errorf("column %s does not exist in the model", columnName)) } - panic(fmt.Errorf("column %s does not exist in the model", columnName)) + return columnName } -func (builder *WhereBuilder) applyCondition(tx CustomWhereInterface, key string, operator string, condition interface{}) { - columnName := getColumnName(key, builder.gdb.Statement.Model) +func (builder *whereBuilder) applyCondition(tx customWhereInterface, key string, operator string, condition interface{}) { + columnName := builder.getColumnNameOrPanic(key, builder.gdb.Statement.Model) varName := builder.nextVarName() query := fmt.Sprintf("%s %s @%s", columnName, operator, varName) tx.Where(query, map[string]interface{}{varName: builder.formatCondition(condition)}) } -func (builder *WhereBuilder) applyExistsCondition(tx CustomWhereInterface, key string, condition bool) { - columnName := getColumnName(key, builder.gdb.Statement.Model) +func (builder *whereBuilder) applyExistsCondition(tx customWhereInterface, key string, condition bool) { + columnName := builder.getColumnNameOrPanic(key, builder.gdb.Statement.Model) operator := "IS NULL" if condition { @@ -104,8 +96,7 @@ func (builder *WhereBuilder) applyExistsCondition(tx CustomWhereInterface, key s } // processConditions will process all conditions -func (builder *WhereBuilder) processConditions(tx CustomWhereInterface, conditions map[string]interface{}, parentKey *string, -) { +func (builder *whereBuilder) processConditions(tx customWhereInterface, conditions map[string]interface{}, parentKey *string) { for key, condition := range conditions { if key == conditionAnd { builder.processWhereAnd(tx, condition) @@ -149,7 +140,7 @@ func (builder *WhereBuilder) processConditions(tx CustomWhereInterface, conditio } // formatCondition will format the conditions -func (builder *WhereBuilder) formatCondition(condition interface{}) interface{} { +func (builder *whereBuilder) formatCondition(condition interface{}) interface{} { switch v := condition.(type) { case customtypes.NullTime: if v.Valid { @@ -169,7 +160,7 @@ func (builder *WhereBuilder) formatCondition(condition interface{}) interface{} } // processWhereAnd will process the AND statements -func (builder *WhereBuilder) processWhereAnd(tx CustomWhereInterface, condition interface{}) { +func (builder *whereBuilder) processWhereAnd(tx customWhereInterface, condition interface{}) { accumulator := &txAccumulator{ WhereClauses: make([]string, 0), Vars: make(map[string]interface{}), @@ -182,7 +173,7 @@ func (builder *WhereBuilder) processWhereAnd(tx CustomWhereInterface, condition } // processWhereOr will process the OR statements -func (builder *WhereBuilder) processWhereOr(tx CustomWhereInterface, condition interface{}) { +func (builder *whereBuilder) processWhereOr(tx customWhereInterface, condition interface{}) { or := make([]string, 0) orVars := make(map[string]interface{}) for _, cond := range condition.([]map[string]interface{}) { @@ -209,7 +200,7 @@ func escapeDBString(s string) string { } // whereObject generates the where object -func (builder *WhereBuilder) whereObject(k string, v interface{}) string { +func (builder *whereBuilder) whereObject(k string, v interface{}) string { queryParts := make([]string, 0) // we don't know the type, we handle the rangeValue as a map[string]interface{} @@ -262,7 +253,7 @@ func (builder *WhereBuilder) whereObject(k string, v interface{}) string { } // whereSlice generates the where slice -func (builder *WhereBuilder) whereSlice(k string, v interface{}) string { +func (builder *whereBuilder) whereSlice(k string, v interface{}) string { engine := builder.client.Engine() if engine == MySQL { return "JSON_CONTAINS(" + k + ", CAST('[\"" + v.(string) + "\"]' AS JSON))" diff --git a/engine/datastore/where_test.go b/engine/datastore/where_test.go index 32ca3d62b..58c9755ad 100644 --- a/engine/datastore/where_test.go +++ b/engine/datastore/where_test.go @@ -48,8 +48,8 @@ func mockClient(engine Engine) (*Client, *gorm.DB) { return client, gdb } -func makeWhereBuilder(client *Client, gdb *gorm.DB, model interface{}) *WhereBuilder { - return &WhereBuilder{ +func makeWhereBuilder(client *Client, gdb *gorm.DB, model interface{}) *whereBuilder { + return &whereBuilder{ client: client, gdb: gdb.Model(model), varNum: 0, From 28c4797b212927d89f5d1beab6cc28e432889c84 Mon Sep 17 00:00:00 2001 From: Krzysztof Tomecki <152964795+chris-4chain@users.noreply.github.com> Date: Wed, 27 Mar 2024 09:33:55 +0100 Subject: [PATCH 03/12] fix(BUX-686): further adjustments with the whole codebase unittests --- engine/datastore/models.go | 45 ++++++++++------------------- engine/datastore/where.go | 39 ++++++++++++++++--------- engine/datastore/where_test.go | 53 ++++++++++++---------------------- 3 files changed, 59 insertions(+), 78 deletions(-) diff --git a/engine/datastore/models.go b/engine/datastore/models.go index e317e2468..2f0c93cac 100644 --- a/engine/datastore/models.go +++ b/engine/datastore/models.go @@ -143,20 +143,6 @@ func convertToInt64(i interface{}) int64 { return i.(int64) } -type gormWhere struct { - tx *gorm.DB -} - -// Where will help fire the tx.Where method -func (g *gormWhere) Where(query interface{}, args ...interface{}) { - g.tx.Where(query, args...) -} - -// getGormTx returns the GORM db tx -func (g *gormWhere) getGormTx() *gorm.DB { - return g.tx -} - // GetModel will get a model from the datastore func (c *Client) GetModel( ctx context.Context, @@ -176,21 +162,17 @@ func (c *Client) GetModel( ctxDB, cancel := createCtx(ctx, c.options.db, timeout, c.IsDebug(), c.options.loggerDB) defer cancel() - // Get the model data using a select - // todo: optimize by specific fields - var tx *gorm.DB - if forceWriteDB { // Use the "write" database for this query (Only MySQL and Postgres) - if c.Engine() == MySQL || c.Engine() == PostgreSQL { - tx = ctxDB.Clauses(dbresolver.Write).Select("*") - } else { - tx = ctxDB.Select("*") - } - } else { // Use a replica if found - tx = ctxDB.Select("*") + tx := ctxDB.Model(model) + + if forceWriteDB && (c.Engine() == MySQL || c.Engine() == PostgreSQL) { + tx = ctxDB.Clauses(dbresolver.Write) } + tx = tx.Select("*") // todo: optimize by specific fields + if len(conditions) > 0 { - if err := ApplyCustomWhere(c, tx, conditions); err != nil { + var err error + if tx, err = ApplyCustomWhere(c, tx, conditions, model); err != nil { return err } } @@ -293,7 +275,8 @@ func (c *Client) find(ctx context.Context, result interface{}, conditions map[st } if len(conditions) > 0 { - if err := ApplyCustomWhere(c, tx, conditions); err != nil { + var err error + if tx, err = ApplyCustomWhere(c, tx, conditions, result); err != nil { return err } } @@ -317,7 +300,8 @@ func (c *Client) count(ctx context.Context, model interface{}, conditions map[st // Check for errors or no records found if len(conditions) > 0 { - if err := ApplyCustomWhere(c, tx, conditions); err != nil { + var err error + if tx, err = ApplyCustomWhere(c, tx, conditions, model); err != nil { return 0, err } } @@ -346,10 +330,11 @@ func (c *Client) aggregate(ctx context.Context, model interface{}, conditions ma // Check for errors or no records found var aggregate []map[string]interface{} if len(conditions) > 0 { - if err := ApplyCustomWhere(c, tx, conditions); err != nil { + var err error + if tx, err = ApplyCustomWhere(c, tx, conditions, model); err != nil { return nil, err } - err := checkResult(tx.Group(aggregateColumn).Scan(&aggregate)) + err = checkResult(tx.Group(aggregateColumn).Scan(&aggregate)) if err != nil { return nil, err } diff --git a/engine/datastore/where.go b/engine/datastore/where.go index 91282ffa0..8bbc9f384 100644 --- a/engine/datastore/where.go +++ b/engine/datastore/where.go @@ -40,26 +40,29 @@ func (tx *txAccumulator) Where(query interface{}, args ...interface{}) *gorm.DB // whereBuilder holds a state during custom where preparation type whereBuilder struct { client ClientInterface - gdb *gorm.DB + tx *gorm.DB varNum int } -// ApplyCustomWhere adds conditions (in-place) to the gorm db instance -func ApplyCustomWhere(client ClientInterface, gdb *gorm.DB, conditions map[string]interface{}) (err error) { +// ApplyCustomWhere adds conditions to the gorm db instance +// it returns a tx of type *gorm.DB with a model and conditions applied +func ApplyCustomWhere(client ClientInterface, gdb *gorm.DB, conditions map[string]interface{}, model interface{}) (tx *gorm.DB, err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("error processing conditions: %v", r) } }() + tx = gdb.Model(model) + builder := &whereBuilder{ client: client, - gdb: gdb, + tx: tx, varNum: 0, } - builder.processConditions(gdb, conditions, nil) - return nil + builder.processConditions(tx, conditions, nil) + return } func (builder *whereBuilder) nextVarName() string { @@ -68,17 +71,17 @@ func (builder *whereBuilder) nextVarName() string { return varName } -func (builder *whereBuilder) getColumnNameOrPanic(columnName string, model interface{}) string { - columnName, ok := GetColumnName(columnName, model, builder.gdb) +func (builder *whereBuilder) getColumnNameOrPanic(key string) string { + columnName, ok := GetColumnName(key, builder.tx.Statement.Model, builder.tx) if !ok { - panic(fmt.Errorf("column %s does not exist in the model", columnName)) + panic(fmt.Errorf("column %s does not exist in the model", key)) } return columnName } func (builder *whereBuilder) applyCondition(tx customWhereInterface, key string, operator string, condition interface{}) { - columnName := builder.getColumnNameOrPanic(key, builder.gdb.Statement.Model) + columnName := builder.getColumnNameOrPanic(key) varName := builder.nextVarName() query := fmt.Sprintf("%s %s @%s", columnName, operator, varName) @@ -86,7 +89,7 @@ func (builder *whereBuilder) applyCondition(tx customWhereInterface, key string, } func (builder *whereBuilder) applyExistsCondition(tx customWhereInterface, key string, condition bool) { - columnName := builder.getColumnNameOrPanic(key, builder.gdb.Statement.Model) + columnName := builder.getColumnNameOrPanic(key) operator := "IS NULL" if condition { @@ -169,7 +172,12 @@ func (builder *whereBuilder) processWhereAnd(tx customWhereInterface, condition builder.processConditions(accumulator, c, nil) } - tx.Where(" ( "+strings.Join(accumulator.WhereClauses, " AND ")+" ) ", accumulator.Vars) + query := " ( " + strings.Join(accumulator.WhereClauses, " AND ") + " ) " + if len(accumulator.Vars) > 0 { + tx.Where(query, accumulator.Vars) + } else { + tx.Where(query) + } } // processWhereOr will process the OR statements @@ -190,7 +198,12 @@ func (builder *whereBuilder) processWhereOr(tx customWhereInterface, condition i or = append(or, strings.Join(statement[:], " AND ")) } - tx.Where(" ( ("+strings.Join(or, ") OR (")+") ) ", orVars) + query := " ( (" + strings.Join(or, ") OR (") + ") ) " + if len(orVars) > 0 { + tx.Where(query, orVars) + } else { + tx.Where(query) + } } // escapeDBString will escape the database string diff --git a/engine/datastore/where_test.go b/engine/datastore/where_test.go index 58c9755ad..29c4bee23 100644 --- a/engine/datastore/where_test.go +++ b/engine/datastore/where_test.go @@ -51,7 +51,7 @@ func mockClient(engine Engine) (*Client, *gorm.DB) { func makeWhereBuilder(client *Client, gdb *gorm.DB, model interface{}) *whereBuilder { return &whereBuilder{ client: client, - gdb: gdb.Model(model), + tx: gdb.Model(model), varNum: 0, } } @@ -118,8 +118,7 @@ func Test_processConditions(t *testing.T) { client, gdb := mockClient(MySQL) raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { - tx = tx.Model(mockObject{}) - err := ApplyCustomWhere(client, tx, conditions) + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) assert.NoError(t, err) return tx.First(&mockObject{}) }) @@ -133,8 +132,7 @@ func Test_processConditions(t *testing.T) { client, gdb := mockClient(PostgreSQL) raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { - tx = tx.Model(mockObject{}) - err := ApplyCustomWhere(client, tx, conditions) + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) assert.NoError(t, err) return tx.First(&mockObject{}) }) @@ -148,8 +146,7 @@ func Test_processConditions(t *testing.T) { client, gdb := mockClient(SQLite) raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { - tx = tx.Model(mockObject{}) - err := ApplyCustomWhere(client, tx, conditions) + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) assert.NoError(t, err) return tx.First(&mockObject{}) }) @@ -318,8 +315,7 @@ func TestCustomWhere(t *testing.T) { conditions := map[string]interface{}{} raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { - tx = tx.Model(mockObject{}) - err := ApplyCustomWhere(client, tx, conditions) + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) assert.NoError(t, err) return tx.First(&mockObject{}) }) @@ -336,8 +332,7 @@ func TestCustomWhere(t *testing.T) { } raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { - tx = tx.Model(mockObject{}) - err := ApplyCustomWhere(client, tx, conditions) + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) assert.NoError(t, err) return tx.First(&mockObject{}) }) @@ -361,8 +356,7 @@ func TestCustomWhere(t *testing.T) { } raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { - tx = tx.Model(mockObject{}) - err := ApplyCustomWhere(client, tx, conditions) + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) assert.NoError(t, err) return tx.First(&mockObject{}) }) @@ -388,8 +382,7 @@ func TestCustomWhere(t *testing.T) { } raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { - tx = tx.Model(mockObject{}) - err := ApplyCustomWhere(client, tx, conditions) + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) assert.NoError(t, err) return tx.First(&mockObject{}) }) @@ -408,8 +401,7 @@ func TestCustomWhere(t *testing.T) { } raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { - tx = tx.Model(mockObject{}) - err := ApplyCustomWhere(client, tx, conditions) + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) assert.NoError(t, err) return tx.First(&mockObject{}) }) @@ -426,8 +418,7 @@ func TestCustomWhere(t *testing.T) { } raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { - tx = tx.Model(mockObject{}) - err := ApplyCustomWhere(client, tx, conditions) + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) assert.NoError(t, err) return tx.First(&mockObject{}) }) @@ -444,8 +435,7 @@ func TestCustomWhere(t *testing.T) { } raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { - tx = tx.Model(mockObject{}) - err := ApplyCustomWhere(client, tx, conditions) + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) assert.NoError(t, err) return tx.First(&mockObject{}) }) @@ -475,8 +465,7 @@ func TestCustomWhere(t *testing.T) { } raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { - tx = tx.Model(mockObject{}) - err := ApplyCustomWhere(client, tx, conditions) + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) assert.NoError(t, err) return tx.First(&mockObject{}) }) @@ -508,8 +497,7 @@ func TestCustomWhere(t *testing.T) { } raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { - tx = tx.Model(mockObject{}) - err := ApplyCustomWhere(client, tx, conditions) + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) assert.NoError(t, err) return tx.First(&mockObject{}) }) @@ -541,8 +529,7 @@ func TestCustomWhere(t *testing.T) { } raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { - tx = tx.Model(mockObject{}) - err := ApplyCustomWhere(client, tx, conditions) + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) assert.NoError(t, err) return tx.First(&mockObject{}) }) @@ -562,8 +549,7 @@ func TestCustomWhere(t *testing.T) { } raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { - tx = tx.Model(mockObject{}) - err := ApplyCustomWhere(client, tx, conditions) + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) assert.NoError(t, err) return tx.First(&mockObject{}) }) @@ -587,8 +573,7 @@ func TestCustomWhere(t *testing.T) { } raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { - tx = tx.Model(mockObject{}) - err := ApplyCustomWhere(client, tx, conditions) + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) assert.NoError(t, err) return tx.First(&mockObject{}) }) @@ -616,8 +601,7 @@ func TestCustomWhere(t *testing.T) { } raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { - tx = tx.Model(mockObject{}) - err := ApplyCustomWhere(client, tx, conditions) + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) assert.NoError(t, err) return tx.First(&mockObject{}) }) @@ -663,8 +647,7 @@ func TestCustomWhere(t *testing.T) { } raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { - tx = tx.Model(mockObject{}) - err := ApplyCustomWhere(client, tx, conditions) + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) assert.NoError(t, err) return tx.First(&mockObject{}) }) From a3189e2a5adca0f05eddaae3d526c3136f6cf4b9 Mon Sep 17 00:00:00 2001 From: Krzysztof Tomecki <152964795+chris-4chain@users.noreply.github.com> Date: Wed, 27 Mar 2024 09:33:55 +0100 Subject: [PATCH 04/12] feat(BUX-686): go mod tidy --- go.mod | 2 -- go.sum | 1 - 2 files changed, 3 deletions(-) diff --git a/go.mod b/go.mod index d26bdef58..5b5a8f2b9 100644 --- a/go.mod +++ b/go.mod @@ -30,8 +30,6 @@ require ( github.com/swaggo/swag v1.16.3 ) -require github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect - require ( github.com/99designs/gqlgen v0.17.43 // indirect github.com/KyleBanks/depth v1.2.1 // indirect diff --git a/go.sum b/go.sum index ae7621e62..68f2047ec 100644 --- a/go.sum +++ b/go.sum @@ -204,7 +204,6 @@ github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFF github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= -github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI= github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= From cbaed701639700f57bb4f42eba9d87f16c1fc5d5 Mon Sep 17 00:00:00 2001 From: Krzysztof Tomecki <152964795+chris-4chain@users.noreply.github.com> Date: Wed, 27 Mar 2024 13:01:22 +0100 Subject: [PATCH 05/12] feat(BUX-686): additional tests with sql-injection attempts --- engine/datastore/where_test.go | 57 +++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/engine/datastore/where_test.go b/engine/datastore/where_test.go index 29c4bee23..37d1b83cd 100644 --- a/engine/datastore/where_test.go +++ b/engine/datastore/where_test.go @@ -97,8 +97,6 @@ func Test_whereSlice(t *testing.T) { func Test_processConditions(t *testing.T) { t.Parallel() - dateField := dateCreatedAt - uniqueField := "unique_field_name" theTime := time.Date(2022, 4, 4, 15, 12, 37, 651387237, time.UTC) nullTime := sql.NullTime{ Valid: true, @@ -106,10 +104,10 @@ func Test_processConditions(t *testing.T) { } conditions := map[string]interface{}{ - dateField: map[string]interface{}{ + "created_at": map[string]interface{}{ conditionGreaterThan: customtypes.NullTime{NullTime: nullTime}, }, - uniqueField: map[string]interface{}{ + "unique_field_name": map[string]interface{}{ conditionExists: true, }, } @@ -661,6 +659,57 @@ func TestCustomWhere(t *testing.T) { }) } +func Test_sqlInjectionSafety(t *testing.T) { + t.Parallel() + + t.Run("injection as simple key", func(t *testing.T) { + client, gdb := mockClient(PostgreSQL) + + conditions := map[string]interface{}{ + "1=1 --": 12, + } + + gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) + assert.Error(t, err) + return tx.First(&mockObject{}) + }) + }) + + t.Run("injection in key as conditionExists", func(t *testing.T) { + client, gdb := mockClient(PostgreSQL) + + conditions := map[string]interface{}{ + "1=1 OR unique_field_name": map[string]interface{}{ + conditionExists: true, + }, + } + + gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) + assert.Error(t, err) + return tx.First(&mockObject{}) + }) + }) + + t.Run("injection in metadata", func(t *testing.T) { + client, gdb := mockClient(PostgreSQL) + conditions := map[string]interface{}{ + metadataField: map[string]interface{}{ + "1=1; DELETE FROM users": "field_value", + }, + } + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Contains(t, raw, `'{"1=1; DELETE FROM users":"field_value"}'`) + }) +} + // Test_escapeDBString will test the method escapeDBString() func Test_escapeDBString(t *testing.T) { t.Parallel() From 62fc20c37869e8e835a55ee9cd68a5dc51ead4ce Mon Sep 17 00:00:00 2001 From: Krzysztof Tomecki <152964795+chris-4chain@users.noreply.github.com> Date: Wed, 27 Mar 2024 15:04:06 +0100 Subject: [PATCH 06/12] fix(BUX-686): double 'the' in the comment --- engine/datastore/column_name.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/datastore/column_name.go b/engine/datastore/column_name.go index 1e934e889..77c219c97 100644 --- a/engine/datastore/column_name.go +++ b/engine/datastore/column_name.go @@ -9,7 +9,7 @@ import ( var modelsCache = sync.Map{} -// GetColumnName checks if the the model has provided columnName as DBName or (struct field)Name +// GetColumnName checks if the model has provided columnName as DBName or (struct field)Name // Returns (DBName, true) if the column exists otherwise (_, false) // Uses global cache store (thread safe) // Checking is case-sensitive From 5e651e2d4cf80626b8fc860d6421d672b78a9c15 Mon Sep 17 00:00:00 2001 From: Damian Orzepowski Date: Thu, 28 Mar 2024 09:34:06 +0100 Subject: [PATCH 07/12] fix(BUX-686): fix injections for postgres when field is jsonb --- engine/datastore/where.go | 43 +++++++++++++++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/engine/datastore/where.go b/engine/datastore/where.go index 8bbc9f384..3e9dcdeca 100644 --- a/engine/datastore/where.go +++ b/engine/datastore/where.go @@ -7,7 +7,7 @@ import ( "strconv" "strings" - customtypes "github.com/bitcoin-sv/spv-wallet/engine/datastore/customtypes" + "github.com/bitcoin-sv/spv-wallet/engine/datastore/customtypes" "gorm.io/gorm" ) @@ -83,11 +83,30 @@ func (builder *whereBuilder) getColumnNameOrPanic(key string) string { func (builder *whereBuilder) applyCondition(tx customWhereInterface, key string, operator string, condition interface{}) { columnName := builder.getColumnNameOrPanic(key) + if condition == nil { + tx.Where(columnName + " " + operator) + return + } varName := builder.nextVarName() query := fmt.Sprintf("%s %s @%s", columnName, operator, varName) tx.Where(query, map[string]interface{}{varName: builder.formatCondition(condition)}) } +func (builder *whereBuilder) applyJson(tx customWhereInterface, key string, condition interface{}) { + columnName := builder.getColumnNameOrPanic(key) + + varName := builder.nextVarName() + engine := builder.client.Engine() + + if engine != PostgreSQL { + //todo handle other databases then postgres + panic("eoeoeoeoeoeoeoeoeoeoeo not implemented yet") + } + + query := fmt.Sprintf("%s::jsonb @> @%s", columnName, varName) + tx.Where(query, map[string]interface{}{varName: builder.formatCondition(condition)}) +} + func (builder *whereBuilder) applyExistsCondition(tx customWhereInterface, key string, condition bool) { columnName := builder.getColumnNameOrPanic(key) @@ -116,12 +135,12 @@ func (builder *whereBuilder) processConditions(tx customWhereInterface, conditio } else if key == conditionExists { builder.applyExistsCondition(tx, *parentKey, condition.(bool)) } else if StringInSlice(key, builder.client.GetArrayFields()) { - tx.Where(builder.whereSlice(key, builder.formatCondition(condition))) + builder.applyArray(tx, key, condition) } else if StringInSlice(key, builder.client.GetObjectFields()) { - tx.Where(builder.whereObject(key, builder.formatCondition(condition))) + builder.applyJson(tx, key, condition) } else { if condition == nil { - tx.Where(key + " IS NULL") + builder.applyCondition(tx, key, "IS NULL", nil) } else { v := reflect.ValueOf(condition) switch v.Kind() { //nolint:exhaustive // not all cases are needed @@ -275,3 +294,19 @@ func (builder *whereBuilder) whereSlice(k string, v interface{}) string { } return "EXISTS (SELECT 1 FROM json_each(" + k + ") WHERE value = \"" + v.(string) + "\")" } + +func (builder *whereBuilder) applyArray(tx customWhereInterface, key string, condition interface{}) { + columnName := builder.getColumnNameOrPanic(key) + + varName := builder.nextVarName() + engine := builder.client.Engine() + + if engine != PostgreSQL { + //todo handle other databases then postgres + panic("eoeoeoeoeoeoeoeoeoeoeo not implemented yet") + } + + query := fmt.Sprintf("%s::jsonb @> @%s", columnName, varName) + c := condition.(string) + tx.Where(query, map[string]interface{}{varName: builder.formatCondition("[\"" + c + "\"]")}) +} From db804eef1947c920abd88e35fcdbfc30da9bcfc0 Mon Sep 17 00:00:00 2001 From: Krzysztof Tomecki <152964795+chris-4chain@users.noreply.github.com> Date: Thu, 28 Mar 2024 12:43:20 +0100 Subject: [PATCH 08/12] fix(BUX-686): applyJSONCondition --- engine/datastore/where.go | 97 +++++++++++++++-------- engine/datastore/where_test.go | 141 +++++++++++++++++++++++++++++---- 2 files changed, 192 insertions(+), 46 deletions(-) diff --git a/engine/datastore/where.go b/engine/datastore/where.go index 3e9dcdeca..0385db2f6 100644 --- a/engine/datastore/where.go +++ b/engine/datastore/where.go @@ -92,21 +92,6 @@ func (builder *whereBuilder) applyCondition(tx customWhereInterface, key string, tx.Where(query, map[string]interface{}{varName: builder.formatCondition(condition)}) } -func (builder *whereBuilder) applyJson(tx customWhereInterface, key string, condition interface{}) { - columnName := builder.getColumnNameOrPanic(key) - - varName := builder.nextVarName() - engine := builder.client.Engine() - - if engine != PostgreSQL { - //todo handle other databases then postgres - panic("eoeoeoeoeoeoeoeoeoeoeo not implemented yet") - } - - query := fmt.Sprintf("%s::jsonb @> @%s", columnName, varName) - tx.Where(query, map[string]interface{}{varName: builder.formatCondition(condition)}) -} - func (builder *whereBuilder) applyExistsCondition(tx customWhereInterface, key string, condition bool) { columnName := builder.getColumnNameOrPanic(key) @@ -135,9 +120,9 @@ func (builder *whereBuilder) processConditions(tx customWhereInterface, conditio } else if key == conditionExists { builder.applyExistsCondition(tx, *parentKey, condition.(bool)) } else if StringInSlice(key, builder.client.GetArrayFields()) { - builder.applyArray(tx, key, condition) + builder.applyArray(tx, key, condition.(string)) } else if StringInSlice(key, builder.client.GetObjectFields()) { - builder.applyJson(tx, key, condition) + builder.applyJSONCondition(tx, key, condition) } else { if condition == nil { builder.applyCondition(tx, key, "IS NULL", nil) @@ -285,28 +270,78 @@ func (builder *whereBuilder) whereObject(k string, v interface{}) string { } // whereSlice generates the where slice -func (builder *whereBuilder) whereSlice(k string, v interface{}) string { +// func (builder *whereBuilder) whereSlice(k string, v interface{}) string { +// engine := builder.client.Engine() +// if engine == MySQL { +// return "JSON_CONTAINS(" + k + ", CAST('[\"" + v.(string) + "\"]' AS JSON))" +// } else if engine == PostgreSQL { +// return k + "::jsonb @> '[\"" + v.(string) + "\"]'" +// } +// return "EXISTS (SELECT 1 FROM json_each(" + k + ") WHERE value = \"" + v.(string) + "\")" +// } + +func (builder *whereBuilder) applyArray(tx customWhereInterface, key string, condition string) { + columnName := builder.getColumnNameOrPanic(key) + + varName := builder.nextVarName() engine := builder.client.Engine() - if engine == MySQL { - return "JSON_CONTAINS(" + k + ", CAST('[\"" + v.(string) + "\"]' AS JSON))" - } else if engine == PostgreSQL { - return k + "::jsonb @> '[\"" + v.(string) + "\"]'" + + query := "" + arg := "" + + switch engine { + case PostgreSQL: + query = fmt.Sprintf("%s::jsonb @> @%s", columnName, varName) + arg = fmt.Sprintf(`["%s"]`, condition) + case MySQL: + query = fmt.Sprintf("JSON_CONTAINS(%s, CAST(@%s AS JSON))", columnName, varName) + arg = fmt.Sprintf(`["%s"]`, condition) + case SQLite: + query = fmt.Sprintf("EXISTS (SELECT 1 FROM json_each(%s) WHERE value = @%s)", columnName, varName) + arg = condition + default: + panic("Database engine not supported") } - return "EXISTS (SELECT 1 FROM json_each(" + k + ") WHERE value = \"" + v.(string) + "\")" + + tx.Where(query, map[string]interface{}{varName: arg}) } -func (builder *whereBuilder) applyArray(tx customWhereInterface, key string, condition interface{}) { +func (builder *whereBuilder) applyJSONCondition(tx customWhereInterface, key string, condition interface{}) { columnName := builder.getColumnNameOrPanic(key) - - varName := builder.nextVarName() engine := builder.client.Engine() - if engine != PostgreSQL { - //todo handle other databases then postgres - panic("eoeoeoeoeoeoeoeoeoeoeo not implemented yet") + if engine == PostgreSQL { + builder.applyJSONBCondition(tx, columnName, condition) + } else if engine == MySQL || engine == SQLite { + builder.applyJSONExtractCondition(tx, columnName, condition) + } else { + panic("Database engine not supported") } +} +func (builder *whereBuilder) applyJSONBCondition(tx customWhereInterface, columnName string, condition interface{}) { + varName := builder.nextVarName() query := fmt.Sprintf("%s::jsonb @> @%s", columnName, varName) - c := condition.(string) - tx.Where(query, map[string]interface{}{varName: builder.formatCondition("[\"" + c + "\"]")}) + tx.Where(query, map[string]interface{}{varName: condition}) +} + +func (builder *whereBuilder) applyJSONExtractCondition(tx customWhereInterface, columnName string, condition interface{}) { + dict := convertTo[map[string]interface{}](condition) + for key, value := range dict { + keyVarName := builder.nextVarName() + valueVarName := builder.nextVarName() + query := fmt.Sprintf("JSON_EXTRACT(%s, @%s) = @%s", columnName, keyVarName, valueVarName) + tx.Where(query, map[string]interface{}{ + keyVarName: fmt.Sprintf("$.%s", key), + valueVarName: value, + }) + } +} + +func convertTo[T any](object interface{}) T { + vJSON, _ := json.Marshal(object) + + var converted T + _ = json.Unmarshal(vJSON, &converted) + return converted } diff --git a/engine/datastore/where_test.go b/engine/datastore/where_test.go index 37d1b83cd..0e437856e 100644 --- a/engine/datastore/where_test.go +++ b/engine/datastore/where_test.go @@ -1,18 +1,25 @@ package datastore import ( + "bytes" "context" "database/sql" + "database/sql/driver" + "encoding/json" "testing" "time" "github.com/DATA-DOG/go-sqlmock" customtypes "github.com/bitcoin-sv/spv-wallet/engine/datastore/customtypes" + "github.com/bitcoin-sv/spv-wallet/engine/utils" "github.com/stretchr/testify/assert" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsontype" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" + "gorm.io/gorm/schema" ) func mockDialector(engine Engine) gorm.Dialector { @@ -56,40 +63,144 @@ func makeWhereBuilder(client *Client, gdb *gorm.DB, model interface{}) *whereBui } } +const ( + // MetadataField is the field name used for metadata (params) + MetadataField = "metadata" +) + +type Metadata map[string]interface{} + +func (m Metadata) GormDataType() string { + return "text" +} + +func (m *Metadata) Scan(value interface{}) error { + if value == nil { + return nil + } + + byteValue, err := utils.ToByteArray(value) + if err != nil || bytes.Equal(byteValue, []byte("")) || bytes.Equal(byteValue, []byte("\"\"")) { + return nil + } + + return json.Unmarshal(byteValue, &m) +} + +func (m Metadata) Value() (driver.Value, error) { + if m == nil { + return nil, nil + } + marshal, err := json.Marshal(m) + if err != nil { + return nil, err + } + + return string(marshal), nil +} + +func (Metadata) GormDBDataType(db *gorm.DB, _ *schema.Field) string { + if db.Dialector.Name() == Postgres { + return JSONB + } + return JSON +} + +func (m *Metadata) MarshalBSONValue() (bsontype.Type, []byte, error) { + if m == nil || len(*m) == 0 { + return bson.TypeNull, nil, nil + } + + metadata := make([]map[string]interface{}, 0) + for key, value := range *m { + metadata = append(metadata, map[string]interface{}{ + "k": key, + "v": value, + }) + } + + return bson.MarshalValue(metadata) +} + +func (m *Metadata) UnmarshalBSONValue(t bsontype.Type, data []byte) error { + raw := bson.RawValue{Type: t, Value: data} + + if raw.Value == nil { + return nil + } + + var uMap []map[string]interface{} + if err := raw.Unmarshal(&uMap); err != nil { + return err + } + + *m = make(Metadata) + for _, meta := range uMap { + key := meta["k"].(string) + (*m)[key] = meta["v"] + } + + return nil +} + type mockObject struct { ID string CreatedAt time.Time UniqueFieldName string Number int ReferenceID string + Metadata Metadata } // Test_whereObject test the SQL where selector func Test_whereSlice(t *testing.T) { t.Parallel() - t.Run("MySQL", func(t *testing.T) { - client, gdb := mockClient(MySQL) - builder := makeWhereBuilder(client, gdb, mockObject{}) - query := builder.whereSlice(fieldInIDs, "id_1") - expected := `JSON_CONTAINS(` + fieldInIDs + `, CAST('["id_1"]' AS JSON))` - assert.Equal(t, expected, query) - }) + conditions := map[string]interface{}{ + "metadata": Metadata{ + "domain": "test-domain", + }, + } t.Run("Postgres", func(t *testing.T) { client, gdb := mockClient(PostgreSQL) - builder := makeWhereBuilder(client, gdb, mockObject{}) - query := builder.whereSlice(fieldInIDs, "id_1") - expected := fieldInIDs + `::jsonb @> '["id_1"]'` - assert.Equal(t, expected, query) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Contains(t, raw, "metadata::jsonb @>") + assert.Contains(t, raw, `'{"domain":"test-domain"}'`) }) t.Run("SQLite", func(t *testing.T) { client, gdb := mockClient(SQLite) - builder := makeWhereBuilder(client, gdb, mockObject{}) - query := builder.whereSlice(fieldInIDs, "id_1") - expected := `EXISTS (SELECT 1 FROM json_each(` + fieldInIDs + `) WHERE value = "id_1")` - assert.Equal(t, expected, query) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Contains(t, raw, "JSON_EXTRACT(metadata") + assert.Contains(t, raw, `"$.domain"`) + assert.Contains(t, raw, `"test-domain"`) + }) + + t.Run("MySQL", func(t *testing.T) { + client, gdb := mockClient(MySQL) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + + assert.Contains(t, raw, "JSON_EXTRACT(metadata") + assert.Contains(t, raw, "'$.domain'") + assert.Contains(t, raw, "'test-domain'") }) } From 220d5798bb35ee88bedfa8667df23f68bd0cec99 Mon Sep 17 00:00:00 2001 From: Krzysztof Tomecki <152964795+chris-4chain@users.noreply.github.com> Date: Thu, 28 Mar 2024 13:38:44 +0100 Subject: [PATCH 09/12] fix(BUX-686): applyArrayCondition & adjust tests --- engine/datastore/where.go | 26 +-- engine/datastore/where_test.go | 306 +++++++++++++-------------------- 2 files changed, 138 insertions(+), 194 deletions(-) diff --git a/engine/datastore/where.go b/engine/datastore/where.go index 0385db2f6..8a4e5e733 100644 --- a/engine/datastore/where.go +++ b/engine/datastore/where.go @@ -283,25 +283,31 @@ func (builder *whereBuilder) whereObject(k string, v interface{}) string { func (builder *whereBuilder) applyArray(tx customWhereInterface, key string, condition string) { columnName := builder.getColumnNameOrPanic(key) - varName := builder.nextVarName() engine := builder.client.Engine() - query := "" - arg := "" - switch engine { case PostgreSQL: - query = fmt.Sprintf("%s::jsonb @> @%s", columnName, varName) - arg = fmt.Sprintf(`["%s"]`, condition) + builder.applyJSONBCondition(tx, columnName, fmt.Sprintf(`["%s"]`, condition)) case MySQL: - query = fmt.Sprintf("JSON_CONTAINS(%s, CAST(@%s AS JSON))", columnName, varName) - arg = fmt.Sprintf(`["%s"]`, condition) + builder.applyJSONContainsCondition(tx, columnName, condition) case SQLite: - query = fmt.Sprintf("EXISTS (SELECT 1 FROM json_each(%s) WHERE value = @%s)", columnName, varName) - arg = condition + builder.applyJSONExistsCondition(tx, columnName, condition) default: panic("Database engine not supported") } +} + +func (builder *whereBuilder) applyJSONExistsCondition(tx customWhereInterface, columnName string, condition string) { + varName := builder.nextVarName() + query := fmt.Sprintf("EXISTS (SELECT 1 FROM json_each(%s) WHERE value = @%s)", columnName, varName) + + tx.Where(query, map[string]interface{}{varName: condition}) +} + +func (builder *whereBuilder) applyJSONContainsCondition(tx customWhereInterface, columnName string, condition string) { + varName := builder.nextVarName() + query := fmt.Sprintf("JSON_CONTAINS(%s, CAST(@%s AS JSON))", columnName, varName) + arg := fmt.Sprintf(`["%s"]`, condition) tx.Where(query, map[string]interface{}{varName: arg}) } diff --git a/engine/datastore/where_test.go b/engine/datastore/where_test.go index 0e437856e..1d461d93a 100644 --- a/engine/datastore/where_test.go +++ b/engine/datastore/where_test.go @@ -143,6 +143,48 @@ func (m *Metadata) UnmarshalBSONValue(t bsontype.Type, data []byte) error { return nil } +type IDs []string + +// GormDataType type in gorm +func (i IDs) GormDataType() string { + return "text" +} + +// Scan scan value into JSON, implements sql.Scanner interface +func (i *IDs) Scan(value interface{}) error { + if value == nil { + return nil + } + + byteValue, err := utils.ToByteArray(value) + if err != nil { + return nil + } + + return json.Unmarshal(byteValue, &i) +} + +// Value return json value, implement driver.Valuer interface +func (i IDs) Value() (driver.Value, error) { + if i == nil { + return nil, nil + } + marshal, err := json.Marshal(i) + if err != nil { + return nil, err + } + + return string(marshal), nil +} + +// GormDBDataType the gorm data type for metadata +func (IDs) GormDBDataType(db *gorm.DB, _ *schema.Field) string { + if db.Dialector.Name() == Postgres { + return JSONB + } + return JSON +} + type mockObject struct { ID string CreatedAt time.Time @@ -150,10 +192,12 @@ type mockObject struct { Number int ReferenceID string Metadata Metadata + FieldInIDs IDs + FieldOutIDs IDs } // Test_whereObject test the SQL where selector -func Test_whereSlice(t *testing.T) { +func Test_whereObject(t *testing.T) { t.Parallel() conditions := map[string]interface{}{ @@ -204,6 +248,63 @@ func Test_whereSlice(t *testing.T) { }) } +// Test_whereObject test the SQL where selector +func Test_whereSlice(t *testing.T) { + t.Parallel() + + conditions := map[string]interface{}{ + "field_in_ids": "test", + } + + t.Run("Postgres", func(t *testing.T) { + client, gdb := mockClient(PostgreSQL) + WithCustomFields([]string{"field_in_ids"}, nil)(client.options) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + // produced SQL: + // SELECT * FROM "mock_objects" WHERE field_in_ids::jsonb @> '["test"]' ORDER BY "mock_objects"."id" LIMIT 1 + + assert.Contains(t, raw, "field_in_ids::jsonb @>") + assert.Contains(t, raw, `'["test"]'`) + }) + + t.Run("MySQL", func(t *testing.T) { + client, gdb := mockClient(MySQL) + WithCustomFields([]string{"field_in_ids"}, nil)(client.options) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + // produced SQL: + // SELECT * FROM `mock_objects` WHERE JSON_CONTAINS(field_in_ids, CAST('["test"]' AS JSON)) ORDER BY `mock_objects`.`id` LIMIT 1 + + assert.Contains(t, raw, "JSON_CONTAINS(field_in_ids") + assert.Contains(t, raw, `'["test"]'`) + }) + + t.Run("SQLite", func(t *testing.T) { + client, gdb := mockClient(SQLite) + WithCustomFields([]string{"field_in_ids"}, nil)(client.options) + + raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{}) + assert.NoError(t, err) + return tx.First(&mockObject{}) + }) + // produced SQL: + // SELECT * FROM `mock_objects` WHERE EXISTS (SELECT 1 FROM json_each(field_in_ids) WHERE value = "test") ORDER BY `mock_objects`.`id` LIMIT 1 + + assert.Contains(t, raw, "json_each(field_in_ids)") + assert.Contains(t, raw, `"test"`) + }) +} + // Test_processConditions test the SQL where selectors func Test_processConditions(t *testing.T) { t.Parallel() @@ -266,154 +367,6 @@ func Test_processConditions(t *testing.T) { }) } -// Test_whereObject test the SQL where selector -func Test_whereObject(t *testing.T) { - t.Parallel() - - t.Run("MySQL", func(t *testing.T) { - client, gdb := mockClient(MySQL) - builder := makeWhereBuilder(client, gdb, mockObject{}) - - metadata := map[string]interface{}{ - "test_key": "test-value", - } - query := builder.whereObject(metadataField, metadata) - expected := "JSON_EXTRACT(" + metadataField + ", '$.test_key') = \"test-value\"" - assert.Equal(t, expected, query) - - metadata = map[string]interface{}{ - "test_key": "test-'value'", - } - query = builder.whereObject(metadataField, metadata) - expected = "JSON_EXTRACT(" + metadataField + ", '$.test_key') = \"test-\\'value\\'\"" - assert.Equal(t, expected, query) - - metadata = map[string]interface{}{ - "test_key1": "test-value", - "test_key2": "test-value2", - } - query = builder.whereObject(metadataField, metadata) - - assert.Contains(t, []string{ - "(JSON_EXTRACT(" + metadataField + ", '$.test_key1') = \"test-value\" AND JSON_EXTRACT(" + metadataField + ", '$.test_key2') = \"test-value2\")", - "(JSON_EXTRACT(" + metadataField + ", '$.test_key2') = \"test-value2\" AND JSON_EXTRACT(" + metadataField + ", '$.test_key1') = \"test-value\")", - }, query) - - // NOTE: the order of the items can change, hence the query order can change - // assert.Equal(t, expected, query) - - objectMetadata := map[string]interface{}{ - "testId": map[string]interface{}{ - "test_key1": "test-value", - "test_key2": "test-value2", - }, - } - query = builder.whereObject("object_metadata", objectMetadata) - - assert.Contains(t, []string{ - "(JSON_EXTRACT(object_metadata, '$.testId.test_key1') = \"test-value\" AND JSON_EXTRACT(object_metadata, '$.testId.test_key2') = \"test-value2\")", - "(JSON_EXTRACT(object_metadata, '$.testId.test_key2') = \"test-value2\" AND JSON_EXTRACT(object_metadata, '$.testId.test_key1') = \"test-value\")", - }, query) - - // NOTE: the order of the items can change, hence the query order can change - // assert.Equal(t, expected, query) - }) - - t.Run("Postgres", func(t *testing.T) { - client, gdb := mockClient(PostgreSQL) - builder := makeWhereBuilder(client, gdb, mockObject{}) - - metadata := map[string]interface{}{ - "test_key": "test-value", - } - query := builder.whereObject(metadataField, metadata) - expected := metadataField + "::jsonb @> '{\"test_key\":\"test-value\"}'::jsonb" - assert.Equal(t, expected, query) - - metadata = map[string]interface{}{ - "test_key": "test-'value'", - } - query = builder.whereObject(metadataField, metadata) - expected = metadataField + "::jsonb @> '{\"test_key\":\"test-\\'value\\'\"}'::jsonb" - assert.Equal(t, expected, query) - - metadata = map[string]interface{}{ - "test_key1": "test-value", - "test_key2": "test-value2", - } - query = builder.whereObject(metadataField, metadata) - - assert.Contains(t, []string{ - "(" + metadataField + "::jsonb @> '{\"test_key1\":\"test-value\"}'::jsonb AND " + metadataField + "::jsonb @> '{\"test_key2\":\"test-value2\"}'::jsonb)", - "(" + metadataField + "::jsonb @> '{\"test_key2\":\"test-value2\"}'::jsonb AND " + metadataField + "::jsonb @> '{\"test_key1\":\"test-value\"}'::jsonb)", - }, query) - - // NOTE: the order of the items can change, hence the query order can change - // assert.Equal(t, expected, query) - - objectMetadata := map[string]interface{}{ - "testId": map[string]interface{}{ - "test_key1": "test-value", - "test_key2": "test-value2", - }, - } - query = builder.whereObject("object_metadata", objectMetadata) - assert.Contains(t, []string{ - "object_metadata::jsonb @> '{\"testId\":{\"test_key1\":\"test-value\",\"test_key2\":\"test-value2\"}}'::jsonb", - "object_metadata::jsonb @> '{\"testId\":{\"test_key2\":\"test-value2\",\"test_key1\":\"test-value\"}}'::jsonb", - }, query) - - // NOTE: the order of the items can change, hence the query order can change - // assert.Equal(t, expected, query) - }) - - t.Run("SQLite", func(t *testing.T) { - client, gdb := mockClient(SQLite) - builder := makeWhereBuilder(client, gdb, mockObject{}) - - metadata := map[string]interface{}{ - "test_key": "test-value", - } - query := builder.whereObject(metadataField, metadata) - expected := "JSON_EXTRACT(" + metadataField + ", '$.test_key') = \"test-value\"" - assert.Equal(t, expected, query) - - metadata = map[string]interface{}{ - "test_key": "test-'value'", - } - query = builder.whereObject(metadataField, metadata) - expected = "JSON_EXTRACT(" + metadataField + ", '$.test_key') = \"test-\\'value\\'\"" - assert.Equal(t, expected, query) - - metadata = map[string]interface{}{ - "test_key1": "test-value", - "test_key2": "test-value2", - } - query = builder.whereObject(metadataField, metadata) - assert.Contains(t, []string{ - "(JSON_EXTRACT(" + metadataField + ", '$.test_key1') = \"test-value\" AND JSON_EXTRACT(" + metadataField + ", '$.test_key2') = \"test-value2\")", - "(JSON_EXTRACT(" + metadataField + ", '$.test_key2') = \"test-value2\" AND JSON_EXTRACT(" + metadataField + ", '$.test_key1') = \"test-value\")", - }, query) - - // NOTE: the order of the items can change, hence the query order can change - // assert.Equal(t, expected, query) - - objectMetadata := map[string]interface{}{ - "testId": map[string]interface{}{ - "test_key1": "test-value", - "test_key2": "test-value2", - }, - } - query = builder.whereObject("object_metadata", objectMetadata) - assert.Contains(t, []string{ - "(JSON_EXTRACT(object_metadata, '$.testId.test_key1') = \"test-value\" AND JSON_EXTRACT(object_metadata, '$.testId.test_key2') = \"test-value2\")", - "(JSON_EXTRACT(object_metadata, '$.testId.test_key2') = \"test-value2\" AND JSON_EXTRACT(object_metadata, '$.testId.test_key1') = \"test-value\")", - }, query) - // NOTE: the order of the items can change, hence the query order can change - // assert.Equal(t, expected, query) - }) -} - // TestCustomWhere will test the method CustomWhere() func TestCustomWhere(t *testing.T) { t.Parallel() @@ -450,17 +403,14 @@ func TestCustomWhere(t *testing.T) { }) t.Run("SQLite $or in json", func(t *testing.T) { - arrayField1 := fieldInIDs - arrayField2 := fieldOutIDs - client, gdb := mockClient(SQLite) - WithCustomFields([]string{arrayField1, arrayField2}, nil)(client.options) + WithCustomFields([]string{"field_in_ids", "field_out_ids"}, nil)(client.options) conditions := map[string]interface{}{ conditionOr: []map[string]interface{}{{ - arrayField1: "value_id", + "field_in_ids": "value_id", }, { - arrayField2: "value_id", + "field_out_ids": "value_id", }}, } @@ -476,17 +426,14 @@ func TestCustomWhere(t *testing.T) { }) t.Run("PostgreSQL $or in json", func(t *testing.T) { - arrayField1 := fieldInIDs - arrayField2 := fieldOutIDs - client, gdb := mockClient(PostgreSQL) - WithCustomFields([]string{arrayField1, arrayField2}, nil)(client.options) + WithCustomFields([]string{"field_in_ids", "field_out_ids"}, nil)(client.options) conditions := map[string]interface{}{ conditionOr: []map[string]interface{}{{ - arrayField1: "value_id", + "field_in_ids": "value_id", }, { - arrayField2: "value_id", + "field_out_ids": "value_id", }}, } @@ -515,7 +462,7 @@ func TestCustomWhere(t *testing.T) { return tx.First(&mockObject{}) }) - assert.Contains(t, raw, "JSON_EXTRACT(metadata, '$.field_name') = \"field_value\"") + assert.Contains(t, raw, `JSON_EXTRACT(metadata, "$.field_name") = "field_value"`) }) t.Run("MySQL metadata", func(t *testing.T) { @@ -532,13 +479,13 @@ func TestCustomWhere(t *testing.T) { return tx.First(&mockObject{}) }) - assert.Contains(t, raw, "JSON_EXTRACT(metadata, '$.field_name') = \"field_value\"") + assert.Contains(t, raw, "JSON_EXTRACT(metadata, '$.field_name') = 'field_value'") }) t.Run("PostgreSQL metadata", func(t *testing.T) { client, gdb := mockClient(PostgreSQL) conditions := map[string]interface{}{ - metadataField: map[string]interface{}{ + metadataField: Metadata{ "field_name": "field_value", }, } @@ -549,15 +496,12 @@ func TestCustomWhere(t *testing.T) { return tx.First(&mockObject{}) }) - assert.Contains(t, raw, "metadata::jsonb @> '{\"field_name\":\"field_value\"}'::jsonb") + assert.Contains(t, raw, `metadata::jsonb @> '{"field_name":"field_value"}'`) }) t.Run("SQLite $and", func(t *testing.T) { - arrayField1 := fieldInIDs - arrayField2 := fieldOutIDs - client, gdb := mockClient(SQLite) - WithCustomFields([]string{arrayField1, arrayField2}, nil)(client.options) + WithCustomFields([]string{"field_in_ids", "field_out_ids"}, nil)(client.options) conditions := map[string]interface{}{ conditionAnd: []map[string]interface{}{{ @@ -566,9 +510,9 @@ func TestCustomWhere(t *testing.T) { "number": 12, }, { conditionOr: []map[string]interface{}{{ - arrayField1: "value_id", + "field_in_ids": "value_id", }, { - arrayField2: "value_id", + "field_out_ids": "value_id", }}, }}, } @@ -585,11 +529,8 @@ func TestCustomWhere(t *testing.T) { }) t.Run("MySQL $and", func(t *testing.T) { - arrayField1 := fieldInIDs - arrayField2 := fieldOutIDs - client, gdb := mockClient(MySQL) - WithCustomFields([]string{arrayField1, arrayField2}, nil)(client.options) + WithCustomFields([]string{"field_in_ids", "field_out_ids"}, nil)(client.options) conditions := map[string]interface{}{ conditionAnd: []map[string]interface{}{{ @@ -598,9 +539,9 @@ func TestCustomWhere(t *testing.T) { "number": 12, }, { conditionOr: []map[string]interface{}{{ - arrayField1: "value_id", + "field_in_ids": "value_id", }, { - arrayField2: "value_id", + "field_out_ids": "value_id", }}, }}, } @@ -617,11 +558,8 @@ func TestCustomWhere(t *testing.T) { }) t.Run("PostgreSQL $and", func(t *testing.T) { - arrayField1 := fieldInIDs - arrayField2 := fieldOutIDs - client, gdb := mockClient(PostgreSQL) - WithCustomFields([]string{arrayField1, arrayField2}, nil)(client.options) + WithCustomFields([]string{"field_in_ids", "field_out_ids"}, nil)(client.options) conditions := map[string]interface{}{ conditionAnd: []map[string]interface{}{{ @@ -630,9 +568,9 @@ func TestCustomWhere(t *testing.T) { "number": 12, }, { conditionOr: []map[string]interface{}{{ - arrayField1: "value_id", + "field_in_ids": "value_id", }, { - arrayField2: "value_id", + "field_out_ids": "value_id", }}, }}, } @@ -806,7 +744,7 @@ func Test_sqlInjectionSafety(t *testing.T) { t.Run("injection in metadata", func(t *testing.T) { client, gdb := mockClient(PostgreSQL) conditions := map[string]interface{}{ - metadataField: map[string]interface{}{ + metadataField: Metadata{ "1=1; DELETE FROM users": "field_value", }, } From 71486519e1d7b55e67787c4f6d08421fe476c96d Mon Sep 17 00:00:00 2001 From: Krzysztof Tomecki <152964795+chris-4chain@users.noreply.github.com> Date: Thu, 28 Mar 2024 14:25:07 +0100 Subject: [PATCH 10/12] refactor(BUX-686): splitting to miltiple files --- engine/datastore/where.go | 325 -------------------------- engine/datastore/where_accumulator.go | 31 +++ engine/datastore/where_builder.go | 224 ++++++++++++++++++ engine/datastore/where_test.go | 8 - 4 files changed, 255 insertions(+), 333 deletions(-) create mode 100644 engine/datastore/where_accumulator.go create mode 100644 engine/datastore/where_builder.go diff --git a/engine/datastore/where.go b/engine/datastore/where.go index 8a4e5e733..8e3e41067 100644 --- a/engine/datastore/where.go +++ b/engine/datastore/where.go @@ -1,49 +1,11 @@ package datastore import ( - "encoding/json" "fmt" - "reflect" - "strconv" - "strings" - "github.com/bitcoin-sv/spv-wallet/engine/datastore/customtypes" "gorm.io/gorm" ) -// customWhereInterface with single method Where which aligns with gorm.DB.Where -type customWhereInterface interface { - Where(query interface{}, args ...interface{}) *gorm.DB -} - -// txAccumulator holds the state of the nested conditions for recursive processing -type txAccumulator struct { - WhereClauses []string - Vars map[string]interface{} -} - -// Where makes txAccumulator implement customWhereInterface which will overload gorm.DB.Where behavior -func (tx *txAccumulator) Where(query interface{}, args ...interface{}) *gorm.DB { - tx.WhereClauses = append(tx.WhereClauses, query.(string)) - - if len(args) > 0 { - for _, variables := range args { - for key, value := range variables.(map[string]interface{}) { - tx.Vars[key] = value - } - } - } - - return nil -} - -// whereBuilder holds a state during custom where preparation -type whereBuilder struct { - client ClientInterface - tx *gorm.DB - varNum int -} - // ApplyCustomWhere adds conditions to the gorm db instance // it returns a tx of type *gorm.DB with a model and conditions applied func ApplyCustomWhere(client ClientInterface, gdb *gorm.DB, conditions map[string]interface{}, model interface{}) (tx *gorm.DB, err error) { @@ -64,290 +26,3 @@ func ApplyCustomWhere(client ClientInterface, gdb *gorm.DB, conditions map[strin builder.processConditions(tx, conditions, nil) return } - -func (builder *whereBuilder) nextVarName() string { - varName := "var" + strconv.Itoa(builder.varNum) - builder.varNum++ - return varName -} - -func (builder *whereBuilder) getColumnNameOrPanic(key string) string { - columnName, ok := GetColumnName(key, builder.tx.Statement.Model, builder.tx) - if !ok { - panic(fmt.Errorf("column %s does not exist in the model", key)) - } - - return columnName -} - -func (builder *whereBuilder) applyCondition(tx customWhereInterface, key string, operator string, condition interface{}) { - columnName := builder.getColumnNameOrPanic(key) - - if condition == nil { - tx.Where(columnName + " " + operator) - return - } - varName := builder.nextVarName() - query := fmt.Sprintf("%s %s @%s", columnName, operator, varName) - tx.Where(query, map[string]interface{}{varName: builder.formatCondition(condition)}) -} - -func (builder *whereBuilder) applyExistsCondition(tx customWhereInterface, key string, condition bool) { - columnName := builder.getColumnNameOrPanic(key) - - operator := "IS NULL" - if condition { - operator = "IS NOT NULL" - } - tx.Where(columnName + " " + operator) -} - -// processConditions will process all conditions -func (builder *whereBuilder) processConditions(tx customWhereInterface, conditions map[string]interface{}, parentKey *string) { - for key, condition := range conditions { - if key == conditionAnd { - builder.processWhereAnd(tx, condition) - } else if key == conditionOr { - builder.processWhereOr(tx, conditions[conditionOr]) - } else if key == conditionGreaterThan { - builder.applyCondition(tx, *parentKey, ">", condition) - } else if key == conditionLessThan { - builder.applyCondition(tx, *parentKey, "<", condition) - } else if key == conditionGreaterThanOrEqual { - builder.applyCondition(tx, *parentKey, ">=", condition) - } else if key == conditionLessThanOrEqual { - builder.applyCondition(tx, *parentKey, "<=", condition) - } else if key == conditionExists { - builder.applyExistsCondition(tx, *parentKey, condition.(bool)) - } else if StringInSlice(key, builder.client.GetArrayFields()) { - builder.applyArray(tx, key, condition.(string)) - } else if StringInSlice(key, builder.client.GetObjectFields()) { - builder.applyJSONCondition(tx, key, condition) - } else { - if condition == nil { - builder.applyCondition(tx, key, "IS NULL", nil) - } else { - v := reflect.ValueOf(condition) - switch v.Kind() { //nolint:exhaustive // not all cases are needed - case reflect.Map: - if _, ok := condition.(map[string]interface{}); ok { - builder.processConditions(tx, condition.(map[string]interface{}), &key) //nolint:scopelint // ignore for now - } else { - c, _ := json.Marshal(condition) //nolint:errchkjson // this check might break the current code - var cc map[string]interface{} - _ = json.Unmarshal(c, &cc) - builder.processConditions(tx, cc, &key) //nolint:scopelint // ignore for now - } - default: - builder.applyCondition(tx, key, "=", condition) - } - } - } - } -} - -// formatCondition will format the conditions -func (builder *whereBuilder) formatCondition(condition interface{}) interface{} { - switch v := condition.(type) { - case customtypes.NullTime: - if v.Valid { - engine := builder.client.Engine() - if engine == MySQL { - return v.Time.Format("2006-01-02 15:04:05") - } else if engine == PostgreSQL { - return v.Time.Format("2006-01-02T15:04:05Z07:00") - } - // default & SQLite - return v.Time.Format("2006-01-02T15:04:05.000Z") - } - return nil - } - - return condition -} - -// processWhereAnd will process the AND statements -func (builder *whereBuilder) processWhereAnd(tx customWhereInterface, condition interface{}) { - accumulator := &txAccumulator{ - WhereClauses: make([]string, 0), - Vars: make(map[string]interface{}), - } - for _, c := range condition.([]map[string]interface{}) { - builder.processConditions(accumulator, c, nil) - } - - query := " ( " + strings.Join(accumulator.WhereClauses, " AND ") + " ) " - if len(accumulator.Vars) > 0 { - tx.Where(query, accumulator.Vars) - } else { - tx.Where(query) - } -} - -// processWhereOr will process the OR statements -func (builder *whereBuilder) processWhereOr(tx customWhereInterface, condition interface{}) { - or := make([]string, 0) - orVars := make(map[string]interface{}) - for _, cond := range condition.([]map[string]interface{}) { - statement := make([]string, 0) - accumulator := &txAccumulator{ - WhereClauses: make([]string, 0), - Vars: make(map[string]interface{}), - } - builder.processConditions(accumulator, cond, nil) - statement = append(statement, accumulator.WhereClauses...) - for varName, varValue := range accumulator.Vars { - orVars[varName] = varValue - } - or = append(or, strings.Join(statement[:], " AND ")) - } - - query := " ( (" + strings.Join(or, ") OR (") + ") ) " - if len(orVars) > 0 { - tx.Where(query, orVars) - } else { - tx.Where(query) - } -} - -// escapeDBString will escape the database string -func escapeDBString(s string) string { - rs := strings.Replace(s, "'", "\\'", -1) - return strings.Replace(rs, "\"", "\\\"", -1) -} - -// whereObject generates the where object -func (builder *whereBuilder) whereObject(k string, v interface{}) string { - queryParts := make([]string, 0) - - // we don't know the type, we handle the rangeValue as a map[string]interface{} - vJSON, _ := json.Marshal(v) //nolint:errchkjson // this check might break the current code - - var rangeV map[string]interface{} - _ = json.Unmarshal(vJSON, &rangeV) - - engine := builder.client.Engine() - - for rangeKey, rangeValue := range rangeV { - if engine == MySQL || engine == SQLite { - switch vv := rangeValue.(type) { - case string: - rangeValue = "\"" + escapeDBString(rangeValue.(string)) + "\"" - queryParts = append(queryParts, "JSON_EXTRACT("+k+", '$."+rangeKey+"') = "+rangeValue.(string)) - default: - metadataJSON, _ := json.Marshal(vv) //nolint:errchkjson // this check might break the current code - var metadata map[string]interface{} - _ = json.Unmarshal(metadataJSON, &metadata) - for kk, vvv := range metadata { - mJSON, _ := json.Marshal(vvv) //nolint:errchkjson // this check might break the current code - vvv = string(mJSON) - queryParts = append(queryParts, "JSON_EXTRACT("+k+", '$."+rangeKey+"."+kk+"') = "+vvv.(string)) - } - } - } else if engine == PostgreSQL { - switch vv := rangeValue.(type) { - case string: - rangeValue = "\"" + escapeDBString(rangeValue.(string)) + "\"" - default: - metadataJSON, _ := json.Marshal(vv) //nolint:errchkjson // this check might break the current code - rangeValue = string(metadataJSON) - } - queryParts = append(queryParts, k+"::jsonb @> '{\""+rangeKey+"\":"+rangeValue.(string)+"}'::jsonb") - } else { - queryParts = append(queryParts, "JSON_EXTRACT("+k+", '$."+rangeKey+"') = '"+escapeDBString(rangeValue.(string))+"'") - } - } - - if len(queryParts) == 0 { - return "" - } - query := queryParts[0] - if len(queryParts) > 1 { - query = "(" + strings.Join(queryParts, " AND ") + ")" - } - - return query -} - -// whereSlice generates the where slice -// func (builder *whereBuilder) whereSlice(k string, v interface{}) string { -// engine := builder.client.Engine() -// if engine == MySQL { -// return "JSON_CONTAINS(" + k + ", CAST('[\"" + v.(string) + "\"]' AS JSON))" -// } else if engine == PostgreSQL { -// return k + "::jsonb @> '[\"" + v.(string) + "\"]'" -// } -// return "EXISTS (SELECT 1 FROM json_each(" + k + ") WHERE value = \"" + v.(string) + "\")" -// } - -func (builder *whereBuilder) applyArray(tx customWhereInterface, key string, condition string) { - columnName := builder.getColumnNameOrPanic(key) - - engine := builder.client.Engine() - - switch engine { - case PostgreSQL: - builder.applyJSONBCondition(tx, columnName, fmt.Sprintf(`["%s"]`, condition)) - case MySQL: - builder.applyJSONContainsCondition(tx, columnName, condition) - case SQLite: - builder.applyJSONExistsCondition(tx, columnName, condition) - default: - panic("Database engine not supported") - } -} - -func (builder *whereBuilder) applyJSONExistsCondition(tx customWhereInterface, columnName string, condition string) { - varName := builder.nextVarName() - query := fmt.Sprintf("EXISTS (SELECT 1 FROM json_each(%s) WHERE value = @%s)", columnName, varName) - - tx.Where(query, map[string]interface{}{varName: condition}) -} - -func (builder *whereBuilder) applyJSONContainsCondition(tx customWhereInterface, columnName string, condition string) { - varName := builder.nextVarName() - query := fmt.Sprintf("JSON_CONTAINS(%s, CAST(@%s AS JSON))", columnName, varName) - arg := fmt.Sprintf(`["%s"]`, condition) - - tx.Where(query, map[string]interface{}{varName: arg}) -} - -func (builder *whereBuilder) applyJSONCondition(tx customWhereInterface, key string, condition interface{}) { - columnName := builder.getColumnNameOrPanic(key) - engine := builder.client.Engine() - - if engine == PostgreSQL { - builder.applyJSONBCondition(tx, columnName, condition) - } else if engine == MySQL || engine == SQLite { - builder.applyJSONExtractCondition(tx, columnName, condition) - } else { - panic("Database engine not supported") - } -} - -func (builder *whereBuilder) applyJSONBCondition(tx customWhereInterface, columnName string, condition interface{}) { - varName := builder.nextVarName() - query := fmt.Sprintf("%s::jsonb @> @%s", columnName, varName) - tx.Where(query, map[string]interface{}{varName: condition}) -} - -func (builder *whereBuilder) applyJSONExtractCondition(tx customWhereInterface, columnName string, condition interface{}) { - dict := convertTo[map[string]interface{}](condition) - for key, value := range dict { - keyVarName := builder.nextVarName() - valueVarName := builder.nextVarName() - query := fmt.Sprintf("JSON_EXTRACT(%s, @%s) = @%s", columnName, keyVarName, valueVarName) - tx.Where(query, map[string]interface{}{ - keyVarName: fmt.Sprintf("$.%s", key), - valueVarName: value, - }) - } -} - -func convertTo[T any](object interface{}) T { - vJSON, _ := json.Marshal(object) - - var converted T - _ = json.Unmarshal(vJSON, &converted) - return converted -} diff --git a/engine/datastore/where_accumulator.go b/engine/datastore/where_accumulator.go new file mode 100644 index 000000000..8bbf2bbaf --- /dev/null +++ b/engine/datastore/where_accumulator.go @@ -0,0 +1,31 @@ +package datastore + +import ( + "gorm.io/gorm" +) + +// customWhereInterface with single method Where which aligns with gorm.DB.Where +type customWhereInterface interface { + Where(query interface{}, args ...interface{}) *gorm.DB +} + +// txAccumulator holds the state of the nested conditions for recursive processing +type txAccumulator struct { + WhereClauses []string + Vars map[string]interface{} +} + +// Where makes txAccumulator implement customWhereInterface which will overload gorm.DB.Where behavior +func (tx *txAccumulator) Where(query interface{}, args ...interface{}) *gorm.DB { + tx.WhereClauses = append(tx.WhereClauses, query.(string)) + + if len(args) > 0 { + for _, variables := range args { + for key, value := range variables.(map[string]interface{}) { + tx.Vars[key] = value + } + } + } + + return nil +} diff --git a/engine/datastore/where_builder.go b/engine/datastore/where_builder.go new file mode 100644 index 000000000..0cf389df0 --- /dev/null +++ b/engine/datastore/where_builder.go @@ -0,0 +1,224 @@ +package datastore + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + + "github.com/bitcoin-sv/spv-wallet/engine/datastore/customtypes" + "gorm.io/gorm" +) + +// whereBuilder holds a state during custom where preparation +type whereBuilder struct { + client ClientInterface + tx *gorm.DB + varNum int +} + +// processConditions will process all conditions +func (builder *whereBuilder) processConditions(tx customWhereInterface, conditions map[string]interface{}, parentKey *string) { + for key, condition := range conditions { + switch { + case key == conditionAnd: + builder.processWhereAnd(tx, condition) + case key == conditionOr: + builder.processWhereOr(tx, condition) + case key == conditionGreaterThan: + builder.applyCondition(tx, *parentKey, ">", condition) + case key == conditionLessThan: + builder.applyCondition(tx, *parentKey, "<", condition) + case key == conditionGreaterThanOrEqual: + builder.applyCondition(tx, *parentKey, ">=", condition) + case key == conditionLessThanOrEqual: + builder.applyCondition(tx, *parentKey, "<=", condition) + case key == conditionExists: + builder.applyExistsCondition(tx, *parentKey, condition.(bool)) + case StringInSlice(key, builder.client.GetArrayFields()): + builder.applyJSONArrayContains(tx, key, condition.(string)) + case StringInSlice(key, builder.client.GetObjectFields()): + builder.applyJSONCondition(tx, key, condition) + case condition == nil: + builder.applyCondition(tx, key, "IS NULL", nil) + default: + v := reflect.ValueOf(condition) + if v.Kind() == reflect.Map { + dict := convertToDict(condition) + builder.processConditions(tx, dict, &key) + } else { + builder.applyCondition(tx, key, "=", condition) + } + } + } +} + +func (builder *whereBuilder) applyCondition(tx customWhereInterface, key string, operator string, condition interface{}) { + columnName := builder.getColumnNameOrPanic(key) + + if condition == nil { + tx.Where(columnName + " " + operator) + return + } + varName := builder.nextVarName() + query := fmt.Sprintf("%s %s @%s", columnName, operator, varName) + tx.Where(query, map[string]interface{}{varName: builder.formatCondition(condition)}) +} + +func (builder *whereBuilder) applyExistsCondition(tx customWhereInterface, key string, condition bool) { + operator := "IS NULL" + if condition { + operator = "IS NOT NULL" + } + builder.applyCondition(tx, key, operator, nil) +} + +// applyJSONArrayContains will apply array condition on JSON Array field - client.GetArrayFields() +func (builder *whereBuilder) applyJSONArrayContains(tx customWhereInterface, key string, condition string) { + columnName := builder.getColumnNameOrPanic(key) + + engine := builder.client.Engine() + + switch engine { + case PostgreSQL: + builder.applyPostgresJSONB(tx, columnName, fmt.Sprintf(`["%s"]`, condition)) + case MySQL: + varName := builder.nextVarName() + tx.Where( + fmt.Sprintf("JSON_CONTAINS(%s, CAST(@%s AS JSON))", columnName, varName), + map[string]interface{}{varName: fmt.Sprintf(`["%s"]`, condition)}, + ) + case SQLite: + varName := builder.nextVarName() + tx.Where( + fmt.Sprintf("EXISTS (SELECT 1 FROM json_each(%s) WHERE value = @%s)", columnName, varName), + map[string]interface{}{varName: condition}, + ) + case MongoDB, Empty: + panic("Database engine not supported") + default: + panic("Unknown database engine") + } +} + +func (builder *whereBuilder) applyJSONCondition(tx customWhereInterface, key string, condition interface{}) { + columnName := builder.getColumnNameOrPanic(key) + engine := builder.client.Engine() + + if engine == PostgreSQL { + builder.applyPostgresJSONB(tx, columnName, condition) + } else if engine == MySQL || engine == SQLite { + builder.applyJSONExtract(tx, columnName, condition) + } else { + panic("Database engine not supported") + } +} + +func (builder *whereBuilder) applyPostgresJSONB(tx customWhereInterface, columnName string, condition interface{}) { + varName := builder.nextVarName() + query := fmt.Sprintf("%s::jsonb @> @%s", columnName, varName) + tx.Where(query, map[string]interface{}{varName: condition}) +} + +func (builder *whereBuilder) applyJSONExtract(tx customWhereInterface, columnName string, condition interface{}) { + dict := convertToDict(condition) + for key, value := range dict { + keyVarName := builder.nextVarName() + valueVarName := builder.nextVarName() + query := fmt.Sprintf("JSON_EXTRACT(%s, @%s) = @%s", columnName, keyVarName, valueVarName) + tx.Where(query, map[string]interface{}{ + keyVarName: fmt.Sprintf("$.%s", key), + valueVarName: value, + }) + } +} + +func (builder *whereBuilder) processWhereAnd(tx customWhereInterface, condition interface{}) { + accumulator := &txAccumulator{ + WhereClauses: make([]string, 0), + Vars: make(map[string]interface{}), + } + for _, c := range condition.([]map[string]interface{}) { + builder.processConditions(accumulator, c, nil) + } + + query := " ( " + strings.Join(accumulator.WhereClauses, " AND ") + " ) " + if len(accumulator.Vars) > 0 { + tx.Where(query, accumulator.Vars) + } else { + tx.Where(query) + } +} + +func (builder *whereBuilder) processWhereOr(tx customWhereInterface, condition interface{}) { + or := make([]string, 0) + orVars := make(map[string]interface{}) + for _, cond := range condition.([]map[string]interface{}) { + statement := make([]string, 0) + accumulator := &txAccumulator{ + WhereClauses: make([]string, 0), + Vars: make(map[string]interface{}), + } + builder.processConditions(accumulator, cond, nil) + statement = append(statement, accumulator.WhereClauses...) + for varName, varValue := range accumulator.Vars { + orVars[varName] = varValue + } + or = append(or, strings.Join(statement[:], " AND ")) + } + + query := " ( (" + strings.Join(or, ") OR (") + ") ) " + if len(orVars) > 0 { + tx.Where(query, orVars) + } else { + tx.Where(query) + } +} + +func (builder *whereBuilder) nextVarName() string { + varName := fmt.Sprintf("var%d", builder.varNum) + builder.varNum++ + return varName +} + +func (builder *whereBuilder) getColumnNameOrPanic(key string) string { + columnName, ok := GetColumnName(key, builder.tx.Statement.Model, builder.tx) + if !ok { + panic(fmt.Errorf("column %s does not exist in the model", key)) + } + + return columnName +} + +func (builder *whereBuilder) formatCondition(condition interface{}) interface{} { + if nullType, ok := condition.(customtypes.NullTime); ok { + return builder.formatNullTime(nullType) + } + + return condition +} + +func (builder *whereBuilder) formatNullTime(condition customtypes.NullTime) interface{} { + if !condition.Valid { + return nil + } + engine := builder.client.Engine() + if engine == MySQL { + return condition.Time.Format("2006-01-02 15:04:05") + } + if engine == PostgreSQL { + return condition.Time.Format("2006-01-02T15:04:05Z07:00") + } + return condition.Time.Format("2006-01-02T15:04:05.000Z") +} + +func convertToDict(object interface{}) map[string]interface{} { + if converted, ok := object.(map[string]interface{}); ok { + return converted + } + vJSON, _ := json.Marshal(object) + + var converted map[string]interface{} + _ = json.Unmarshal(vJSON, &converted) + return converted +} diff --git a/engine/datastore/where_test.go b/engine/datastore/where_test.go index 1d461d93a..f06221a45 100644 --- a/engine/datastore/where_test.go +++ b/engine/datastore/where_test.go @@ -758,11 +758,3 @@ func Test_sqlInjectionSafety(t *testing.T) { assert.Contains(t, raw, `'{"1=1; DELETE FROM users":"field_value"}'`) }) } - -// Test_escapeDBString will test the method escapeDBString() -func Test_escapeDBString(t *testing.T) { - t.Parallel() - - str := escapeDBString(`SELECT * FROM 'table' WHERE 'field'=1;`) - assert.Equal(t, `SELECT * FROM \'table\' WHERE \'field\'=1;`, str) -} From 9c283ca8fa63a9880b3b3fe42f366e7501df25ec Mon Sep 17 00:00:00 2001 From: Krzysztof Tomecki <152964795+chris-4chain@users.noreply.github.com> Date: Thu, 28 Mar 2024 14:32:04 +0100 Subject: [PATCH 11/12] chore(BUX-686): adjust comments --- engine/datastore/where_builder.go | 1 + 1 file changed, 1 insertion(+) diff --git a/engine/datastore/where_builder.go b/engine/datastore/where_builder.go index 0cf389df0..f8cf515f8 100644 --- a/engine/datastore/where_builder.go +++ b/engine/datastore/where_builder.go @@ -101,6 +101,7 @@ func (builder *whereBuilder) applyJSONArrayContains(tx customWhereInterface, key } } +// applyJSONCondition will apply condition on JSON Object field - client.GetObjectFields() func (builder *whereBuilder) applyJSONCondition(tx customWhereInterface, key string, condition interface{}) { columnName := builder.getColumnNameOrPanic(key) engine := builder.client.Engine() From cbc77e3083e5de5355b82be3674a542c49e7ce29 Mon Sep 17 00:00:00 2001 From: Krzysztof Tomecki <152964795+chris-4chain@users.noreply.github.com> Date: Thu, 28 Mar 2024 14:41:32 +0100 Subject: [PATCH 12/12] chore(BUX-686): where_mocks transfered to another file --- engine/datastore/where_mocks_test.go | 180 ++++++++++++++++++++++++++ engine/datastore/where_test.go | 186 --------------------------- 2 files changed, 180 insertions(+), 186 deletions(-) create mode 100644 engine/datastore/where_mocks_test.go diff --git a/engine/datastore/where_mocks_test.go b/engine/datastore/where_mocks_test.go new file mode 100644 index 000000000..9be3117d9 --- /dev/null +++ b/engine/datastore/where_mocks_test.go @@ -0,0 +1,180 @@ +package datastore + +import ( + "bytes" + "context" + "database/sql/driver" + "encoding/json" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/bitcoin-sv/spv-wallet/engine/utils" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsontype" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/schema" +) + +type mockObject struct { + ID string + CreatedAt time.Time + UniqueFieldName string + Number int + ReferenceID string + Metadata Metadata + FieldInIDs IDs + FieldOutIDs IDs +} + +func mockDialector(engine Engine) gorm.Dialector { + mockDb, _, _ := sqlmock.New() + switch engine { + case MySQL: + return mysql.New(mysql.Config{ + Conn: mockDb, + SkipInitializeWithVersion: true, + DriverName: "mysql", + }) + case PostgreSQL: + return postgres.New(postgres.Config{ + Conn: mockDb, + DriverName: "postgres", + }) + case SQLite: + return sqlite.Open("file::memory:?cache=shared") + case MongoDB, Empty: + // the where builder is not applicable for MongoDB + return nil + default: + return nil + } +} + +func mockClient(engine Engine) (*Client, *gorm.DB) { + clientInterface, _ := NewClient(context.Background()) + client, _ := clientInterface.(*Client) + client.options.engine = engine + dialector := mockDialector(engine) + gdb, _ := gorm.Open(dialector, &gorm.Config{}) + return client, gdb +} + +type Metadata map[string]interface{} + +func (m Metadata) GormDataType() string { + return "text" +} + +func (m *Metadata) Scan(value interface{}) error { + if value == nil { + return nil + } + + byteValue, err := utils.ToByteArray(value) + if err != nil || bytes.Equal(byteValue, []byte("")) || bytes.Equal(byteValue, []byte("\"\"")) { + return nil + } + + return json.Unmarshal(byteValue, &m) +} + +func (m Metadata) Value() (driver.Value, error) { + if m == nil { + return nil, nil + } + marshal, err := json.Marshal(m) + if err != nil { + return nil, err + } + + return string(marshal), nil +} + +func (Metadata) GormDBDataType(db *gorm.DB, _ *schema.Field) string { + if db.Dialector.Name() == Postgres { + return JSONB + } + return JSON +} + +func (m *Metadata) MarshalBSONValue() (bsontype.Type, []byte, error) { + if m == nil || len(*m) == 0 { + return bson.TypeNull, nil, nil + } + + metadata := make([]map[string]interface{}, 0) + for key, value := range *m { + metadata = append(metadata, map[string]interface{}{ + "k": key, + "v": value, + }) + } + + return bson.MarshalValue(metadata) +} + +func (m *Metadata) UnmarshalBSONValue(t bsontype.Type, data []byte) error { + raw := bson.RawValue{Type: t, Value: data} + + if raw.Value == nil { + return nil + } + + var uMap []map[string]interface{} + if err := raw.Unmarshal(&uMap); err != nil { + return err + } + + *m = make(Metadata) + for _, meta := range uMap { + key := meta["k"].(string) + (*m)[key] = meta["v"] + } + + return nil +} + +type IDs []string + +// GormDataType type in gorm +func (i IDs) GormDataType() string { + return "text" +} + +// Scan scan value into JSON, implements sql.Scanner interface +func (i *IDs) Scan(value interface{}) error { + if value == nil { + return nil + } + + byteValue, err := utils.ToByteArray(value) + if err != nil { + return nil + } + + return json.Unmarshal(byteValue, &i) +} + +// Value return json value, implement driver.Valuer interface +func (i IDs) Value() (driver.Value, error) { + if i == nil { + return nil, nil + } + marshal, err := json.Marshal(i) + if err != nil { + return nil, err + } + + return string(marshal), nil +} + +// GormDBDataType the gorm data type for metadata +func (IDs) GormDBDataType(db *gorm.DB, _ *schema.Field) string { + if db.Dialector.Name() == Postgres { + return JSONB + } + return JSON +} diff --git a/engine/datastore/where_test.go b/engine/datastore/where_test.go index f06221a45..25b8dab40 100644 --- a/engine/datastore/where_test.go +++ b/engine/datastore/where_test.go @@ -1,201 +1,15 @@ package datastore import ( - "bytes" - "context" "database/sql" - "database/sql/driver" - "encoding/json" "testing" "time" - "github.com/DATA-DOG/go-sqlmock" customtypes "github.com/bitcoin-sv/spv-wallet/engine/datastore/customtypes" - "github.com/bitcoin-sv/spv-wallet/engine/utils" "github.com/stretchr/testify/assert" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/bsontype" - "gorm.io/driver/mysql" - "gorm.io/driver/postgres" - "gorm.io/driver/sqlite" "gorm.io/gorm" - "gorm.io/gorm/schema" ) -func mockDialector(engine Engine) gorm.Dialector { - mockDb, _, _ := sqlmock.New() - switch engine { - case MySQL: - return mysql.New(mysql.Config{ - Conn: mockDb, - SkipInitializeWithVersion: true, - DriverName: "mysql", - }) - case PostgreSQL: - return postgres.New(postgres.Config{ - Conn: mockDb, - DriverName: "postgres", - }) - case SQLite: - return sqlite.Open("file::memory:?cache=shared") - case MongoDB, Empty: - // the where builder is not applicable for MongoDB - return nil - default: - return nil - } -} - -func mockClient(engine Engine) (*Client, *gorm.DB) { - clientInterface, _ := NewClient(context.Background()) - client, _ := clientInterface.(*Client) - client.options.engine = engine - dialector := mockDialector(engine) - gdb, _ := gorm.Open(dialector, &gorm.Config{}) - return client, gdb -} - -func makeWhereBuilder(client *Client, gdb *gorm.DB, model interface{}) *whereBuilder { - return &whereBuilder{ - client: client, - tx: gdb.Model(model), - varNum: 0, - } -} - -const ( - // MetadataField is the field name used for metadata (params) - MetadataField = "metadata" -) - -type Metadata map[string]interface{} - -func (m Metadata) GormDataType() string { - return "text" -} - -func (m *Metadata) Scan(value interface{}) error { - if value == nil { - return nil - } - - byteValue, err := utils.ToByteArray(value) - if err != nil || bytes.Equal(byteValue, []byte("")) || bytes.Equal(byteValue, []byte("\"\"")) { - return nil - } - - return json.Unmarshal(byteValue, &m) -} - -func (m Metadata) Value() (driver.Value, error) { - if m == nil { - return nil, nil - } - marshal, err := json.Marshal(m) - if err != nil { - return nil, err - } - - return string(marshal), nil -} - -func (Metadata) GormDBDataType(db *gorm.DB, _ *schema.Field) string { - if db.Dialector.Name() == Postgres { - return JSONB - } - return JSON -} - -func (m *Metadata) MarshalBSONValue() (bsontype.Type, []byte, error) { - if m == nil || len(*m) == 0 { - return bson.TypeNull, nil, nil - } - - metadata := make([]map[string]interface{}, 0) - for key, value := range *m { - metadata = append(metadata, map[string]interface{}{ - "k": key, - "v": value, - }) - } - - return bson.MarshalValue(metadata) -} - -func (m *Metadata) UnmarshalBSONValue(t bsontype.Type, data []byte) error { - raw := bson.RawValue{Type: t, Value: data} - - if raw.Value == nil { - return nil - } - - var uMap []map[string]interface{} - if err := raw.Unmarshal(&uMap); err != nil { - return err - } - - *m = make(Metadata) - for _, meta := range uMap { - key := meta["k"].(string) - (*m)[key] = meta["v"] - } - - return nil -} - -type IDs []string - -// GormDataType type in gorm -func (i IDs) GormDataType() string { - return "text" -} - -// Scan scan value into JSON, implements sql.Scanner interface -func (i *IDs) Scan(value interface{}) error { - if value == nil { - return nil - } - - byteValue, err := utils.ToByteArray(value) - if err != nil { - return nil - } - - return json.Unmarshal(byteValue, &i) -} - -// Value return json value, implement driver.Valuer interface -func (i IDs) Value() (driver.Value, error) { - if i == nil { - return nil, nil - } - marshal, err := json.Marshal(i) - if err != nil { - return nil, err - } - - return string(marshal), nil -} - -// GormDBDataType the gorm data type for metadata -func (IDs) GormDBDataType(db *gorm.DB, _ *schema.Field) string { - if db.Dialector.Name() == Postgres { - return JSONB - } - return JSON -} - -type mockObject struct { - ID string - CreatedAt time.Time - UniqueFieldName string - Number int - ReferenceID string - Metadata Metadata - FieldInIDs IDs - FieldOutIDs IDs -} - // Test_whereObject test the SQL where selector func Test_whereObject(t *testing.T) { t.Parallel()