Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow structs to be self-verifying. #293

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var (
tags []string
structNameFromTitle bool
minSizedInts bool
structVerify bool

errFlagFormat = errors.New("flag must be in the format URI=PACKAGE")

Expand Down Expand Up @@ -77,6 +78,7 @@ var (
Tags: tags,
OnlyModels: onlyModels,
MinSizedInts: minSizedInts,
StructVerify: structVerify,
}
for _, id := range allKeys(schemaPackageMap, schemaOutputMap, schemaRootTypeMap) {
mapping := generator.SchemaMapping{SchemaID: id}
Expand Down Expand Up @@ -174,6 +176,12 @@ also look for foo.json if --resolve-extension json is provided.`)
false,
"Uses sized int and uint values based on the min and max values for the field")

rootCmd.PersistentFlags().BoolVar(
&structVerify,
"struct-verify",
false,
"Add a Verify method to the generated struct that validates the struct against the schema")

abortWithErr(rootCmd.Execute())
}

Expand Down
1 change: 1 addition & 0 deletions pkg/generator/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type Config struct {
OnlyModels bool
MinSizedInts bool
Loader schemas.Loader
StructVerify bool
}

type SchemaMapping struct {
Expand Down
4 changes: 4 additions & 0 deletions pkg/generator/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ func New(config Config) (*Generator, error) {
formatters = append(formatters, &yamlFormatter{})
}

if config.StructVerify {
formatters = append(formatters, &verifyFormatter{})
}

generator := &Generator{
caser: text.NewCaser(config.Capitalizations, config.ResolveExtensions),
config: config,
Expand Down
5 changes: 1 addition & 4 deletions pkg/generator/json_formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,7 @@ func (jf *jsonFormatter) enumUnmarshal(
enumType.Generate(out)
out.Newline()

varName := "v"
if wrapInStruct {
varName += ".Value"
}
varName := enumVarName(wrapInStruct)

out.Printlnf("if err := json.Unmarshal(b, &%s); err != nil { return err }", varName)
out.Printlnf("var ok bool")
Expand Down
10 changes: 10 additions & 0 deletions pkg/generator/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,13 @@ func isNamedType(t codegen.Type) bool {

return false
}

func enumVarName(wrapInStruct bool) string {
varName := "v"

if wrapInStruct {
varName += ".Value"
}

return varName
}
195 changes: 195 additions & 0 deletions pkg/generator/verify_formatter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
package generator

import (
"strings"

"github.com/atombender/go-jsonschema/pkg/codegen"
)

type verifyFormatter struct{}

func (v verifyFormatter) addImport(_ *codegen.File) {}

func (v verifyFormatter) generate(declType codegen.TypeDecl, validators []validator) func(*codegen.Emitter) {
return func(out *codegen.Emitter) {
var prefix string
switch declType.Type.(type) {
// No need to dereference the struct just to verify it.
case *codegen.StructType:
prefix = "*"

default:
prefix = ""
}

out.Comment("Verify checks all fields on the struct match the schema.")
out.Printlnf("func (%s %s%s) Verify() error {", varNamePlainStruct, prefix, declType.Name)
out.Indent(1)

for _, va := range validators {
desc := va.desc()
if desc.beforeJSONUnmarshal || desc.requiresRawAfter || !desc.hasError {
continue
}

va.generate(out)
}

if stct, ok := declType.Type.(*codegen.StructType); ok {
for _, field := range stct.Fields {
name := strings.ToLower(field.Name[0:1]) + field.Name[1:]
if verifyEmit := v.verifyType(field.Type, name); verifyEmit != nil {
out.Printlnf("%s := %s", name, getPlainName(field.Name))
verifyEmit(out)
}
}
}

out.Printlnf("return nil")
out.Indent(-1)
out.Printlnf("}")
}
}

func (v verifyFormatter) verifyType(tpe codegen.Type, access string) func(*codegen.Emitter) {
// For some types, pointers are sometimes used and sometime not.
switch utpe := tpe.(type) {
case *codegen.ArrayType:
return v.verifyArray(*utpe, access)

case codegen.ArrayType:
return v.verifyArray(utpe, access)

Check warning on line 61 in pkg/generator/verify_formatter.go

View check run for this annotation

Codecov / codecov/patch

pkg/generator/verify_formatter.go#L60-L61

Added lines #L60 - L61 were not covered by tests

case codegen.CustomNameType, *codegen.CustomNameType, codegen.NamedType, *codegen.NamedType, *codegen.StructType:
return func(out *codegen.Emitter) {
out.Printlnf("if err := %s.Verify(); err != nil {", access)
out.Indent(1)
out.Printlnf("return err")
out.Indent(-1)
out.Printlnf("}")
}

case *codegen.MapType:
return v.verifyMap(*utpe, access)

Check warning on line 73 in pkg/generator/verify_formatter.go

View check run for this annotation

Codecov / codecov/patch

pkg/generator/verify_formatter.go#L72-L73

Added lines #L72 - L73 were not covered by tests

case codegen.MapType:
return v.verifyMap(utpe, access)

Check warning on line 76 in pkg/generator/verify_formatter.go

View check run for this annotation

Codecov / codecov/patch

pkg/generator/verify_formatter.go#L75-L76

Added lines #L75 - L76 were not covered by tests

case *codegen.PointerType:
return v.verifyPointer(*utpe, access)

case codegen.PointerType:
return v.verifyPointer(utpe, access)

Check warning on line 82 in pkg/generator/verify_formatter.go

View check run for this annotation

Codecov / codecov/patch

pkg/generator/verify_formatter.go#L81-L82

Added lines #L81 - L82 were not covered by tests

default:
return nil
}
}

func (v verifyFormatter) enumMarshal(_ codegen.TypeDecl) func(*codegen.Emitter) {
return func(out *codegen.Emitter) {}
}

func (v verifyFormatter) enumUnmarshal(
declType codegen.TypeDecl,
_ codegen.Type,
valueConstant *codegen.Var,
wrapInStruct bool,
) func(*codegen.Emitter) {
return func(out *codegen.Emitter) {
varName := enumVarName(wrapInStruct)

out.Comment("Verify checks all fields on the struct match the schema.")
out.Printlnf("func (%s %s) Verify() error {", varNamePlainStruct, declType.Name)
out.Indent(1)
out.Printlnf("for _, expected := range %s {", valueConstant.Name)
out.Indent(1)
out.Printlnf("if reflect.DeepEqual(%s, expected) { return nil }", varName)
out.Indent(-1)
out.Printlnf("}")
out.Printlnf(`return fmt.Errorf("invalid value (expected one of %%#v): %%#v", %s, %s)`,
valueConstant.Name, varName)
out.Indent(-1)
out.Printlnf("}")
}
}

func (v verifyFormatter) verifyArray(tpe codegen.ArrayType, access string) func(*codegen.Emitter) {
aaccess := "a" + access

verifyFn := v.verifyType(tpe.Type, aaccess)
if verifyFn == nil {
return nil
}

return func(out *codegen.Emitter) {
out.Printlnf("for _, %s := range %s {", aaccess, access)
out.Indent(1)
verifyFn(out)
out.Indent(-1)
out.Printlnf("}")
}
}

func (v verifyFormatter) verifyMap(tpe codegen.MapType, access string) func(*codegen.Emitter) {
keyAccess := "k" + access
valueAccess := "v" + access
verifyKeyFn := v.verifyType(tpe.KeyType, keyAccess)
verifyValueFn := v.verifyType(tpe.ValueType, valueAccess)

if verifyKeyFn == nil && verifyValueFn == nil {
return nil
}

Check warning on line 142 in pkg/generator/verify_formatter.go

View check run for this annotation

Codecov / codecov/patch

pkg/generator/verify_formatter.go#L134-L142

Added lines #L134 - L142 were not covered by tests

if verifyKeyFn == nil {
keyAccess = "_"
}

Check warning on line 146 in pkg/generator/verify_formatter.go

View check run for this annotation

Codecov / codecov/patch

pkg/generator/verify_formatter.go#L144-L146

Added lines #L144 - L146 were not covered by tests

if verifyValueFn == nil {
valueAccess = "_"
}

Check warning on line 150 in pkg/generator/verify_formatter.go

View check run for this annotation

Codecov / codecov/patch

pkg/generator/verify_formatter.go#L148-L150

Added lines #L148 - L150 were not covered by tests

return func(out *codegen.Emitter) {
out.Printlnf("for %s, %s := range %s {", keyAccess, valueAccess, access)
out.Indent(1)

if verifyKeyFn != nil {
verifyKeyFn(out)
}

Check warning on line 158 in pkg/generator/verify_formatter.go

View check run for this annotation

Codecov / codecov/patch

pkg/generator/verify_formatter.go#L152-L158

Added lines #L152 - L158 were not covered by tests

if verifyValueFn != nil {
verifyValueFn(out)
}

Check warning on line 162 in pkg/generator/verify_formatter.go

View check run for this annotation

Codecov / codecov/patch

pkg/generator/verify_formatter.go#L160-L162

Added lines #L160 - L162 were not covered by tests

out.Indent(-1)
out.Printlnf("}")

Check warning on line 165 in pkg/generator/verify_formatter.go

View check run for this annotation

Codecov / codecov/patch

pkg/generator/verify_formatter.go#L164-L165

Added lines #L164 - L165 were not covered by tests
}
}

func (v verifyFormatter) verifyPointer(tpe codegen.PointerType, access string) func(*codegen.Emitter) {
var prefix string
switch tpe.Type.(type) {
// Access the verify and fields without copying it.
case codegen.CustomNameType, *codegen.CustomNameType, codegen.NamedType, *codegen.NamedType:
prefix = ""

default:
prefix = "*"
}

paccess := "p" + access

verifyFn := v.verifyType(tpe.Type, paccess)
if verifyFn == nil {
return nil
}

return func(out *codegen.Emitter) {
out.Printlnf("if %s != nil {", access)
out.Printlnf("%s := %s%s", paccess, prefix, access)
out.Indent(1)
verifyFn(out)
out.Indent(-1)
out.Printlnf("}")
}
}
5 changes: 1 addition & 4 deletions pkg/generator/yaml_formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,7 @@ func (yf *yamlFormatter) enumUnmarshal(
enumType.Generate(out)
out.Newline()

varName := "v"
if wrapInStruct {
varName += ".Value"
}
varName := enumVarName(wrapInStruct)

out.Printlnf("if err := value.Decode(&%s); err != nil { return err }", varName)
out.Printlnf("var ok bool")
Expand Down
Loading