diff --git a/pattern.go b/pattern.go index dcd688f..ceacdd4 100644 --- a/pattern.go +++ b/pattern.go @@ -45,7 +45,7 @@ type Pattern struct { } func NewPattern(cleanRedundantAt ...bool) *Pattern { - clean := false + clean := true if len(cleanRedundantAt) > 0 { clean = cleanRedundantAt[0] } @@ -57,10 +57,9 @@ func NewPattern(cleanRedundantAt ...bool) *Pattern { } type PatternSegment struct { - typ string - optional bool - parse Parser - cleanRedundantAt bool // only for Reply + typ string + optional bool + parse Parser } type Parser func(msg *message.Segment) PatternParsed @@ -68,13 +67,13 @@ type Parser func(msg *message.Segment) PatternParsed // SetOptional set previous segment is optional, is v is empty, optional will be true // if Pattern is empty, panic func (p *Pattern) SetOptional(v ...bool) *Pattern { - if len((*p).segments) == 0 { + if len(p.segments) == 0 { panic("pattern is empty") } if len(v) == 1 { - (*p).segments[len((*p).segments)-1].optional = v[0] + p.segments[len(p.segments)-1].optional = v[0] } else { - (*p).segments[len((*p).segments)-1].optional = true + p.segments[len(p.segments)-1].optional = true } return p } @@ -122,16 +121,11 @@ func (p PatternParsed) Raw() *message.Segment { return p.msg } -func (p *Pattern) Add(typ string, optional bool, parse Parser, cleanRedundantAt ...bool) *Pattern { - clean := false - if len(cleanRedundantAt) > 0 { - clean = cleanRedundantAt[0] - } +func (p *Pattern) Add(typ string, optional bool, parse Parser) *Pattern { pattern := &PatternSegment{ - typ: typ, - optional: optional, - parse: parse, - cleanRedundantAt: clean, + typ: typ, + optional: optional, + parse: parse, } p.segments = append(p.segments, *pattern) return p @@ -197,12 +191,8 @@ func NewImageParser() Parser { } // Reply type zero.PatternReplyMatched -func (p *Pattern) Reply(noCleanRedundantAt ...bool) *Pattern { - noClean := false - if len(noCleanRedundantAt) > 0 { - noClean = noCleanRedundantAt[0] - } - p.Add("reply", false, NewReplyParser(), !noClean) +func (p *Pattern) Reply() *Pattern { + p.Add("reply", false, NewReplyParser()) return p } @@ -264,12 +254,14 @@ func patternMatch(ctx *Ctx, pattern Pattern, msgs []message.Segment) bool { for i := 0; i < len(pattern.segments); i++ { if j < len(msgs) && pattern.segments[i].matchType(msgs[j]) { patternState[i] = pattern.segments[i].parse(&msgs[j]) - } else { + } + if patternState[i].value == nil { if pattern.segments[i].optional { continue } return false } + j++ } ctx.State[KeyPattern] = patternState return true diff --git a/pattern_test.go b/pattern_test.go index 668c189..3ceb4eb 100644 --- a/pattern_test.go +++ b/pattern_test.go @@ -1,6 +1,7 @@ package zero import ( + "fmt" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" "github.com/wdvxdr1123/ZeroBot/message" @@ -27,7 +28,7 @@ func fakeCtx(msg message.Message) *Ctx { // copy from extension.PatternModel type PatternModel struct { - Matched []*PatternParsed `zero:"pattern_matched"` + Matched []PatternParsed `zero:"pattern_matched"` } // Test Match @@ -128,7 +129,7 @@ func TestPattern_ReplyFilter(t *testing.T) { expected bool }{ {[]message.Segment{message.Reply(12345), message.At(12345), message.Text("1234")}, NewPattern().Reply().Text("1234"), true}, - {[]message.Segment{message.Reply(12345), message.At(12345), message.Text("1234")}, NewPattern().Reply(true).Text("1234"), false}, + {[]message.Segment{message.Reply(12345), message.At(12345), message.Text("1234")}, NewPattern(false).Reply().Text("1234"), false}, } for i, v := range textTests { t.Run(strconv.Itoa(i), func(t *testing.T) { @@ -234,7 +235,10 @@ func TestPattern_SetOptional(t *testing.T) { } assert.Equal(t, len(v.expected), len(parsed.Matched)) for i := range parsed.Matched { - assert.Equal(t, v.expected[i].value != nil, parsed.Matched[i].value != nil) + t.Run(strconv.Itoa(i), func(t *testing.T) { + fmt.Println((parsed.Matched[i].value)) + assert.Equal(t, v.expected[i].value != nil, parsed.Matched[i].value != nil) + }) } }) }