Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add command line flags and fix some bugs #1

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 86 additions & 19 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package main

import (
"database/sql"
"flag"
"fmt"
"log"
"os"
"path/filepath"
"strconv"
"strings"
"unicode"
Expand Down Expand Up @@ -48,6 +51,16 @@ var commonInitialisms = map[string]bool{
"DB": true,
}

var flagDbFile = flag.String("db-file", "./example.db", "path to the DB")
var flagOut = flag.String("out", "./gen", "output file for generated files")
var flagIgnoreColumns = flag.String("skip", "rowid,_rowid_,_rid,rid", "list of columns to be excluded from struct generation")
var flagGenJson = flag.Bool("json", true, "generate JSON annotation")
var flagGenDb = flag.Bool("db", true, "generate DB annotation")
var flagGenGorm = flag.Bool("gorm", true, "generate GORM annotation")
var flagPkgName = flag.String("pkg", "def", "specify package name")

var ignoreColumns []string

var intToWordMap = []string{
"zero",
"one",
Expand All @@ -62,25 +75,46 @@ var intToWordMap = []string{
}

func main() {
db, err := sql.Open("sqlite3", "./example.db")
flag.Parse()

db, err := sql.Open("sqlite3", *flagDbFile)

ignoreColumns = strings.Split(strings.ToLower(*flagIgnoreColumns), ",")

if err != nil {
log.Fatal(err)
}

defer db.Close()

tableNames := getTableNames(db)
outPath := filepath.Clean(*flagOut)
errOs := os.MkdirAll(outPath, 0770)

if errOs != nil {
log.Fatal(errOs)
}

os.Create(outPath)
c := 0
fmt.Printf("Generating code for the following tables (%d)\n", len(tableNames))
for _, tableName := range tableNames {
file := scanTableStructure(db, tableName)
c++
fmt.Printf("[%d] %s\n", c, tableName)

file := scanTableStructure(db, tableName, outPath, *flagPkgName)
structureName := formatFieldName(tableName)
err = file.Save("gen/" + fmt.Sprintf("%s.go", structureName))
fileName := filepath.Join(outPath, fmt.Sprintf("%s.go", structureName))
err = file.Save(fileName)
}

if err != nil {
panic(err)
}
}

func scanTableStructure(db *sql.DB, tableName string) *jen.File {
file := jen.NewFilePathName("gen", "def")
func scanTableStructure(db *sql.DB, tableName string, outPath string, packageName string) *jen.File {
file := jen.NewFilePathName(outPath, packageName)
structureName := formatFieldName(tableName)
file.Comment(fmt.Sprintf("// %s represent database table (%s)", structureName, tableName))
file.Type().Id(structureName).Struct(
Expand Down Expand Up @@ -120,49 +154,70 @@ func generateTableFields(db *sql.DB, tableName string) *[]jen.Code {
}
defer rows.Close()

fmt.Println(getTableNames(db))

//fmt.Println(getTableNames(db))
for rows.Next() {
var cid int
var name string
var ctype string
var notnull string
var dfltValue sql.NullString
var pk string
var ignore bool

err = rows.Scan(&cid, &name, &ctype, &notnull, &dfltValue, &pk)
if err != nil {
log.Fatal(err)
}
fmt.Println(cid, name, ctype, notnull, dfltValue, pk)
field := jen.Id(formatFieldName(name))

ignore = isIgnoreField(name)

//fmt.Println(cid, name, ctype, notnull, dfltValue, pk)
name2 := formatFieldName(name)
if ignore {
name2 = `// ` + name2
}
field := jen.Id(name2)
setFieldType(field, ctype)
setFieldTags(field, name)
fields = append(fields, field)
}

return &fields
}

func setFieldTags(field *jen.Statement, name string) {
field.Tag(
map[string]string{
"json": name,
"gorm": fmt.Sprintf("column:%s", name),
},
)
m := map[string]string{}

if *flagGenDb {
m["db"] = name
}

if *flagGenGorm {
m["gorm"] = fmt.Sprintf("column:%s", name)
}

if *flagGenJson {
m["json"] = name
}

field.Tag(m)
}

func setFieldType(field *jen.Statement, ctype string) {
dbType := strings.Split(ctype, "(")[0]
dbType := strings.ToUpper(strings.Split(ctype, "(")[0])
switch dbType {
case "VARCHAR", "TEXT":
field.String()
case "BOOL":
case "BOOL", "BOOLEAN":
field.Bool()
case "INTEGER":
case "TINYINT", "SMALLINT":
field.Int32()
case "FLOAT":
case "INTEGER", "INT", "INT2", "MEDIUMINT", "BIGINT", "UNSIGNED BIG INT", "INT8":
field.Int64()
case "REAL", "DOUBLE", "DOUBLE PRECISION", "FLOAT":
field.Float32()
case "NUMERIC", "DECIMAL", "DECIMAL(10,5)":
field.Float64()
default:
field.String()
}
Expand Down Expand Up @@ -210,6 +265,7 @@ func lintFieldName(name string) string {
break
}
}

if allLower {
runes := []rune(name)
if u := strings.ToUpper(name); commonInitialisms[u] {
Expand All @@ -227,6 +283,7 @@ func lintFieldName(name string) string {
break
}
}

if allUpperWithUnderscore {
name = strings.ToLower(name)
}
Expand Down Expand Up @@ -291,3 +348,13 @@ func stringifyFirstChar(str string) string {

return intToWordMap[i] + "_" + str[1:]
}

func isIgnoreField(fieldName string) bool {
lowerFieldName := strings.ToLower(fieldName)
for _, v := range ignoreColumns {
if v == lowerFieldName {
return true
}
}
return false
}