diff --git a/interfaces.go b/interfaces.go index 3bcc3d570..d64e428c4 100644 --- a/interfaces.go +++ b/interfaces.go @@ -20,6 +20,10 @@ type Dialector interface { Explain(sql string, vars ...interface{}) string } +type ArrayValueHandler interface { + HandleArray(field *schema.Field) error +} + // Plugin GORM plugin interface type Plugin interface { Name() string diff --git a/schema/field.go b/schema/field.go index a16c98ab0..bfe989622 100644 --- a/schema/field.go +++ b/schema/field.go @@ -47,6 +47,7 @@ const ( String DataType = "string" Time DataType = "time" Bytes DataType = "bytes" + Array DataType = "array" ) const DefaultAutoIncrementIncrement int64 = 1 @@ -282,6 +283,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.Array, reflect.Slice: if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" { field.DataType = Bytes + } else { + elemType := reflect.Indirect(fieldValue).Type().Elem() + field.DataType = Array + field.TagSettings["ELEM_TYPE"] = elemType.Kind().String() } } @@ -977,6 +982,10 @@ func (field *Field) setupValuerAndSetter() { return } } + + if field.DataType != "" && field.FieldType.Kind() == reflect.Slice && field.FieldType.Elem().Kind() != reflect.Uint8 { + field.TagSettings["ARRAY_FIELD"] = "true" + } } func (field *Field) setupNewValuePool() { diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 472434b48..32ad1ae5e 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -7,6 +7,8 @@ import ( "encoding/json" "errors" "fmt" + "log" + "os" "reflect" "regexp" "strconv" @@ -65,6 +67,54 @@ func TestScannerValuer(t *testing.T) { AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs") } +func TestScannerValuerArray(t *testing.T) { + // Use custom dialector to enable array handler + os.Setenv("GORM_DIALECT", "postgres") + os.Setenv("GORM_ENABLE_ARRAY_HANDLER", "true") + var err error + if DB, err = OpenTestConnection(&gorm.Config{}); err != nil { + log.Printf("failed to connect database, got error %v", err) + os.Exit(1) + } + + DB.Migrator().DropTable(&ScannerValuerStructOfArrays{}) + if err := DB.Migrator().AutoMigrate(&ScannerValuerStructOfArrays{}); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + data := ScannerValuerStructOfArrays{ + StringArray: []string{"a", "b", "c"}, + IntArray: []int{1, 2, 3}, + Int8Array: []int8{1, 2, 3}, + Int16Array: []int16{1, 2, 3}, + Int32Array: []int32{1, 2, 3}, + Int64Array: []int64{1, 2, 3}, + UintArray: []uint{1, 2, 3}, + Uint16Array: []uint16{1, 2, 3}, + Uint32Array: []uint32{1, 2, 3}, + Uint64Array: []uint64{1, 2, 3}, + Float32Array: []float32{ + 1.1, 2.2, 3.3, + }, + Float64Array: []float64{ + 1.1, 2.2, 3.3, + }, + BoolArray: []bool{true, false, true}, + } + + if err := DB.Create(&data).Error; err != nil { + t.Fatalf("No error should happened when create scanner valuer struct, but got %v", err) + } + + var result ScannerValuerStructOfArrays + + if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil { + t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err) + } + + AssertObjEqual(t, data, result, "StringArray", "IntArray", "Int8Array", "Int16Array", "Int32Array", "Int64Array", "UintArray", "Uint16Array", "Uint32Array", "Uint64Array", "Float32Array", "Float64Array", "BoolArray") +} + func TestScannerValuerWithFirstOrCreate(t *testing.T) { DB.Migrator().DropTable(&ScannerValuerStruct{}) if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { @@ -162,6 +212,23 @@ type ScannerValuerStruct struct { ExampleStructPtr *ExampleStruct } +type ScannerValuerStructOfArrays struct { + gorm.Model + StringArray []string + IntArray []int + Int8Array []int8 + Int16Array []int16 + Int32Array []int32 + Int64Array []int64 + UintArray []uint + Uint16Array []uint16 + Uint32Array []uint32 + Uint64Array []uint64 + Float32Array []float32 + Float64Array []float64 + BoolArray []bool +} + type EncryptedData []byte func (data *EncryptedData) Scan(value interface{}) error { diff --git a/tests/tests_test.go b/tests/tests_test.go index e84162cd3..a59130928 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -48,6 +48,7 @@ func init() { func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) { dbDSN := os.Getenv("GORM_DSN") + enableArrayHandler := os.Getenv("GORM_ENABLE_ARRAY_HANDLER") switch os.Getenv("GORM_DIALECT") { case "mysql": log.Println("testing mysql...") @@ -63,6 +64,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) { db, err = gorm.Open(postgres.New(postgres.Config{ DSN: dbDSN, PreferSimpleProtocol: true, + EnableArrayHandler: enableArrayHandler == "true", }), cfg) case "sqlserver": // go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest