diff --git a/extension/model.go b/extension/model.go index 5690a9a..5e4ee1b 100644 --- a/extension/model.go +++ b/extension/model.go @@ -1,5 +1,7 @@ package extension +import zero "github.com/wdvxdr1123/ZeroBot" + // PrefixModel is model of zero.PrefixRule type PrefixModel struct { Prefix string `zero:"prefix"` @@ -32,3 +34,8 @@ type FullMatchModel struct { type RegexModel struct { Matched []string `zero:"regex_matched"` } + +// PatternModel is model of zero.PatternRule +type PatternModel struct { + Matched []*zero.PatternParsed `zero:"pattern_matched"` +} diff --git a/extension/rate/rate.go b/extension/rate/rate.go index f176e54..24720f3 100644 --- a/extension/rate/rate.go +++ b/extension/rate/rate.go @@ -26,7 +26,7 @@ func NewManager[K comparable](interval time.Duration, burst int) *LimiterManager } // Delete 删除对应限速器 -func (l *LimiterManager[K]) Delete(key K) { +func (l *LimiterManager[K]) Delete(key K) { l.limiters.Delete(key) } diff --git a/extension/single/single.go b/extension/single/single.go index af52909..e4304fb 100644 --- a/extension/single/single.go +++ b/extension/single/single.go @@ -4,6 +4,7 @@ import ( "runtime" "github.com/RomiChan/syncx" + zero "github.com/wdvxdr1123/ZeroBot" ) diff --git a/pattern.go b/pattern.go new file mode 100644 index 0000000..d25d2c5 --- /dev/null +++ b/pattern.go @@ -0,0 +1,268 @@ +package zero + +import ( + "regexp" + "strconv" + "strings" + + "github.com/wdvxdr1123/ZeroBot/message" +) + +const ( + KeyPattern = "pattern_matched" +) + +// AsRule build PatternRule +func (p *Pattern) AsRule() Rule { + return func(ctx *Ctx) bool { + if len(ctx.Event.Message) == 0 { + return false + } + if !p.cleanRedundantAt { + return patternMatch(ctx, *p, ctx.Event.Message) + } + + // copy messages + msgs := make([]message.Segment, 0, len(ctx.Event.Message)) + msgs = append(msgs, ctx.Event.Message[0]) + for i := 1; i < len(ctx.Event.Message); i++ { + if ctx.Event.Message[i-1].Type == "reply" && ctx.Event.Message[i].Type == "at" { + // [reply][at] + reply := ctx.GetMessage(ctx.Event.Message[i-1].Data["id"]) + if reply.MessageID.ID() != 0 && reply.Sender != nil && reply.Sender.ID != 0 && strconv.FormatInt(reply.Sender.ID, 10) == ctx.Event.Message[i].Data["qq"] { + continue + } + } + msgs = append(msgs, ctx.Event.Message[i]) + } + return patternMatch(ctx, *p, msgs) + } +} + +type Pattern struct { + cleanRedundantAt bool + segments []PatternSegment +} + +func NewPattern(cleanRedundantAt ...bool) *Pattern { + clean := true + if len(cleanRedundantAt) > 0 { + clean = cleanRedundantAt[0] + } + pattern := Pattern{ + cleanRedundantAt: clean, + segments: make([]PatternSegment, 0, 4), + } + return &pattern +} + +type PatternSegment struct { + typ string + optional bool + parse Parser +} + +type Parser func(msg *message.Segment) PatternParsed + +// SetOptional set previous segment is optional, is v is empty, optional will be true +// if Pattern is empty, panic +func (p *Pattern) SetOptional(v ...bool) *Pattern { + if len(p.segments) == 0 { + panic("pattern is empty") + } + if len(v) == 1 { + p.segments[len(p.segments)-1].optional = v[0] + } else { + p.segments[len(p.segments)-1].optional = true + } + return p +} + +// PatternParsed PatternRule parse result +type PatternParsed struct { + value any + msg *message.Segment +} + +// Text 获取正则表达式匹配到的文本数组 +func (p PatternParsed) Text() []string { + if p.value == nil { + return nil + } + return p.value.([]string) +} + +// At 获取被@者ID +func (p PatternParsed) At() string { + if p.value == nil { + return "" + } + return p.value.(string) +} + +// Image 获取图片URL +func (p PatternParsed) Image() string { + if p.value == nil { + return "" + } + return p.value.(string) +} + +// Reply 获取被回复的消息ID +func (p PatternParsed) Reply() string { + if p.value == nil { + return "" + } + return p.value.(string) +} + +// Raw 获取原始消息 +func (p PatternParsed) Raw() *message.Segment { + return p.msg +} + +func (p *Pattern) Add(typ string, optional bool, parse Parser) *Pattern { + pattern := &PatternSegment{ + typ: typ, + optional: optional, + parse: parse, + } + p.segments = append(p.segments, *pattern) + return p +} + +// Text use regex to search a 'text' segment +func (p *Pattern) Text(regex string) *Pattern { + p.Add("text", false, NewTextParser(regex)) + return p +} + +func NewTextParser(regex string) Parser { + re := regexp.MustCompile(regex) + return func(msg *message.Segment) PatternParsed { + s := msg.Data["text"] + s = strings.Trim(s, " \n\r\t") + matchString := re.MatchString(s) + if matchString { + return PatternParsed{ + value: re.FindStringSubmatch(s), + msg: msg, + } + } + + return PatternParsed{} + } +} + +// At use regex to match an 'at' segment, if id is not empty, only match specific target +func (p *Pattern) At(id ...message.ID) *Pattern { + if len(id) > 1 { + panic("at pattern only support one id") + } + p.Add("at", false, NewAtParser(id...)) + return p +} + +func NewAtParser(id ...message.ID) Parser { + return func(msg *message.Segment) PatternParsed { + if len(id) == 0 || len(id) == 1 && id[0].String() == msg.Data["qq"] { + return PatternParsed{ + value: msg.Data["qq"], + msg: msg, + } + } + return PatternParsed{} + } +} + +// Image use regex to match an 'at' segment, if id is not empty, only match specific target +func (p *Pattern) Image() *Pattern { + p.Add("image", false, NewImageParser()) + return p +} + +func NewImageParser() Parser { + return func(msg *message.Segment) PatternParsed { + return PatternParsed{ + value: msg.Data["file"], + msg: msg, + } + } +} + +// Reply type zero.PatternReplyMatched +func (p *Pattern) Reply() *Pattern { + p.Add("reply", false, NewReplyParser()) + return p +} + +func NewReplyParser() Parser { + return func(msg *message.Segment) PatternParsed { + return PatternParsed{ + value: msg.Data["id"], + msg: msg, + } + } +} + +// Any match any segment +func (p *Pattern) Any() *Pattern { + p.Add("any", false, NewAnyParser()) + return p +} + +func NewAnyParser() Parser { + return func(msg *message.Segment) PatternParsed { + parsed := PatternParsed{ + value: nil, + msg: msg, + } + switch { + case msg.Data["text"] != "": + parsed.value = msg.Data["text"] + case msg.Data["qq"] != "": + parsed.value = msg.Data["qq"] + case msg.Data["file"] != "": + parsed.value = msg.Data["file"] + case msg.Data["id"] != "": + parsed.value = msg.Data["id"] + default: + parsed.value = msg.Data + } + return parsed + } +} + +func (s *PatternSegment) matchType(msg message.Segment) bool { + return s.typ == msg.Type || s.typ == "any" +} +func mustMatchAllPatterns(pattern Pattern) bool { + for _, p := range pattern.segments { + if p.optional { + return false + } + } + return true +} +func patternMatch(ctx *Ctx, pattern Pattern, msgs []message.Segment) bool { + if mustMatchAllPatterns(pattern) && len(pattern.segments) != len(msgs) { + return false + } + patternState := make([]PatternParsed, len(pattern.segments)) + + j := 0 + for i := range pattern.segments { + if j < len(msgs) && pattern.segments[i].matchType(msgs[j]) { + patternState[i] = pattern.segments[i].parse(&msgs[j]) + } + if patternState[i].value == nil { + if pattern.segments[i].optional { + continue + } + return false + } + j++ + } + ctx.State[KeyPattern] = patternState + return true +} diff --git a/pattern_test.go b/pattern_test.go new file mode 100644 index 0000000..3ceb4eb --- /dev/null +++ b/pattern_test.go @@ -0,0 +1,302 @@ +package zero + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "github.com/tidwall/gjson" + "github.com/wdvxdr1123/ZeroBot/message" + "strconv" + "testing" +) + +type mockAPICaller struct{} + +func (m mockAPICaller) CallAPI(_ APIRequest) (APIResponse, error) { + return APIResponse{ + Status: "", + Data: gjson.Parse(`{"message_id":"12345","sender":{"user_id":12345}}`), // just for reply cleaner + Msg: "", + Wording: "", + RetCode: 0, + Echo: 0, + }, nil +} +func fakeCtx(msg message.Message) *Ctx { + ctx := &Ctx{Event: &Event{Message: msg}, State: map[string]interface{}{}, caller: mockAPICaller{}} + return ctx +} + +// copy from extension.PatternModel +type PatternModel struct { + Matched []PatternParsed `zero:"pattern_matched"` +} + +// Test Match +func TestPattern_Text(t *testing.T) { + textTests := [...]struct { + msg message.Message + pattern *Pattern + expected bool + }{ + {[]message.Segment{message.Text("haha")}, NewPattern().Text("haha"), true}, + {[]message.Segment{message.Text("aaa")}, NewPattern().Text("not match"), false}, + {[]message.Segment{message.Image("not a image")}, NewPattern().Text("not match"), false}, + {[]message.Segment{message.At(114514)}, NewPattern().Text("not match"), false}, + {[]message.Segment{message.Text("你说的对但是ZeroBot-Plugin 是 ZeroBot 的 实用插件合集")}, NewPattern().Text("实用插件合集"), true}, + {[]message.Segment{message.Text("你说的对但是ZeroBot-Plugin 是 ZeroBot 的 实用插件合集")}, NewPattern().Text("nonono"), false}, + } + for i, v := range textTests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + ctx := fakeCtx(v.msg) + rule := v.pattern.AsRule() + out := rule(ctx) + assert.Equal(t, out, v.expected) + }) + } +} + +func TestPattern_Image(t *testing.T) { + textTests := [...]struct { + msg message.Message + pattern *Pattern + expected bool + }{ + {[]message.Segment{message.Text("haha")}, NewPattern().Image(), false}, + {[]message.Segment{message.Text("haha"), message.Image("not a image")}, NewPattern().Image().Image(), false}, + {[]message.Segment{message.Text("haha"), message.Image("not a image")}, NewPattern().Text("haha").Image(), true}, + {[]message.Segment{message.Image("not a image")}, NewPattern().Image(), true}, + {[]message.Segment{message.Image("not a image"), message.Image("not a image")}, NewPattern().Image(), false}, + {[]message.Segment{message.Image("not a image"), message.Image("not a image")}, NewPattern().Image().Image(), true}, + {[]message.Segment{message.Image("not a image"), message.Image("not a image")}, NewPattern().Image().Image().Image(), false}, + } + for i, v := range textTests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + ctx := fakeCtx(v.msg) + rule := v.pattern.AsRule() + out := rule(ctx) + assert.Equal(t, v.expected, out) + }) + } +} + +func TestPattern_At(t *testing.T) { + textTests := [...]struct { + msg message.Message + pattern *Pattern + expected bool + }{ + {[]message.Segment{message.Text("haha")}, NewPattern().At(), false}, + {[]message.Segment{message.Image("not a image")}, NewPattern().At(), false}, + {[]message.Segment{message.At(114514)}, NewPattern().At(), true}, + {[]message.Segment{message.At(114514)}, NewPattern().At(message.NewMessageIDFromString("1919810")), false}, + } + for i, v := range textTests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + ctx := fakeCtx(v.msg) + rule := v.pattern.AsRule() + out := rule(ctx) + assert.Equal(t, out, v.expected) + }) + } +} + +func TestPattern_Reply(t *testing.T) { + textTests := [...]struct { + msg message.Message + pattern *Pattern + expected bool + }{ + {[]message.Segment{message.Text("haha")}, NewPattern().Reply(), false}, + {[]message.Segment{message.Image("not a image")}, NewPattern().Reply(), false}, + {[]message.Segment{message.At(1919810), message.Reply(12345)}, NewPattern().Reply().At(), false}, + {[]message.Segment{message.Reply(12345), message.At(1919810)}, NewPattern().Reply().At(), true}, + {[]message.Segment{message.Reply(12345)}, NewPattern().Reply(), true}, + {[]message.Segment{message.Reply(12345), message.At(1919810)}, NewPattern().Reply(), false}, + } + for i, v := range textTests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + ctx := fakeCtx(v.msg) + rule := v.pattern.AsRule() + out := rule(ctx) + assert.Equal(t, out, v.expected) + }) + } +} +func TestPattern_ReplyFilter(t *testing.T) { + textTests := [...]struct { + msg message.Message + pattern *Pattern + expected bool + }{ + {[]message.Segment{message.Reply(12345), message.At(12345), message.Text("1234")}, NewPattern().Reply().Text("1234"), true}, + {[]message.Segment{message.Reply(12345), message.At(12345), message.Text("1234")}, NewPattern(false).Reply().Text("1234"), false}, + } + for i, v := range textTests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + ctx := fakeCtx(v.msg) + rule := v.pattern.AsRule() + out := rule(ctx) + assert.Equal(t, v.expected, out) + }) + } +} +func TestPattern_Any(t *testing.T) { + textTests := [...]struct { + msg message.Message + pattern *Pattern + expected bool + }{ + {[]message.Segment{message.Text("haha")}, NewPattern().Any(), true}, + {[]message.Segment{message.Image("not a image")}, NewPattern().Any(), true}, + {[]message.Segment{message.At(1919810), message.Reply(12345)}, NewPattern().Any().Reply(), true}, + {[]message.Segment{message.Reply(12345), message.At(1919810)}, NewPattern().Any().At(), true}, + } + for i, v := range textTests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + ctx := fakeCtx(v.msg) + rule := v.pattern.AsRule() + out := rule(ctx) + assert.Equal(t, out, v.expected) + }) + } + t.Run("get", func(t *testing.T) { + ctx := fakeCtx([]message.Segment{message.Reply("just for test")}) + rule := NewPattern().Any().AsRule() + _ = rule(ctx) + model := PatternModel{} + err := ctx.Parse(&model) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "just for test", model.Matched[0].Reply()) + }) +} +func TestPatternParsed_Gets(t *testing.T) { + assert.Equal(t, []string{"gaga"}, PatternParsed{value: []string{"gaga"}}.Text()) + assert.Equal(t, "image", PatternParsed{value: "image"}.Image()) + assert.Equal(t, "reply", PatternParsed{value: "reply"}.Reply()) + assert.Equal(t, "114514", PatternParsed{value: "114514"}.At()) + text := message.Text("1234") + assert.Equal(t, &text, PatternParsed{msg: &text}.Raw()) +} +func TestPattern_SetOptional(t *testing.T) { + assert.Panics(t, func() { + NewPattern().SetOptional() + }) + tests := [...]struct { + msg message.Message + pattern *Pattern + expected []PatternParsed + }{ + {[]message.Segment{message.Text("/do it")}, NewPattern().Text("/(do) (.*)").At().SetOptional(true), []PatternParsed{ + { + value: []string{"/do it", "do", "it"}, + }, { + value: nil, + }, + }}, + {[]message.Segment{message.Text("/do it")}, NewPattern().Text("/(do) (.*)").At().SetOptional(false), []PatternParsed{}}, + {[]message.Segment{message.Text("happy bear"), message.At(114514)}, NewPattern().Reply().SetOptional().Text(".+").SetOptional().At().SetOptional(false), []PatternParsed{ + { + value: nil, + }, + { + value: "happy bear", + }, + { + value: "114514", + }, + }}, + {[]message.Segment{message.Text("happy bear"), message.At(114514)}, NewPattern().Image().SetOptional().Image().SetOptional().Image().SetOptional(), []PatternParsed{ // why you do this + { + value: nil, + }, + { + value: nil, + }, + { + value: nil, + }, + }}, + } + for i, v := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + ctx := fakeCtx(v.msg) + rule := v.pattern.AsRule() + matched := rule(ctx) + if !matched { + assert.Equal(t, 0, len(v.expected)) + return + } + parsed := &PatternModel{} + err := ctx.Parse(parsed) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, len(v.expected), len(parsed.Matched)) + for i := range parsed.Matched { + t.Run(strconv.Itoa(i), func(t *testing.T) { + fmt.Println((parsed.Matched[i].value)) + assert.Equal(t, v.expected[i].value != nil, parsed.Matched[i].value != nil) + }) + } + }) + } +} + +// Test Parse +func TestAllParse(t *testing.T) { + textTests := [...]struct { + msg message.Message + pattern *Pattern + expected []PatternParsed + }{ + {[]message.Segment{message.Text("test haha test"), message.At(123)}, NewPattern().Text("((ha)+)").At(), []PatternParsed{ + { + value: []string{"haha", "haha", "ha"}, + }, { + value: "123", + }, + }}, + {[]message.Segment{message.Text("haha")}, NewPattern().Text("(h)(a)(h)(a)"), []PatternParsed{ + { + value: []string{"haha", "h", "a", "h", "a"}, + }, + }}, + {[]message.Segment{message.Reply("fake reply"), message.Image("fake image"), message.At(999), message.At(124), message.Text("haha")}, NewPattern().Reply().Image().At().At(message.NewMessageIDFromInteger(124)).Text("(h)(a)(h)(a)"), []PatternParsed{ + + { + value: "fake reply", + }, + { + value: "fake image", + }, + { + value: "999", + }, + { + value: "124", + }, + { + value: []string{"haha", "h", "a", "h", "a"}, + }, + }}, + } + for i, v := range textTests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + ctx := fakeCtx(v.msg) + rule := v.pattern.AsRule() + matched := rule(ctx) + parsed := &PatternModel{} + err := ctx.Parse(parsed) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, true, matched) + for i := range parsed.Matched { + assert.Equal(t, v.expected[i].value, parsed.Matched[i].value) + assert.Equal(t, &(v.msg[i]), parsed.Matched[i].msg) + } + }) + } +}