From 7a6eabf25f5bcef803d8d368fb9eb0c20b9fd5f7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 7 Oct 2024 16:14:58 +0300 Subject: [PATCH] msgconv/from-matrix: add support for polls --- CHANGELOG.md | 1 - ROADMAP.md | 4 +- go.mod | 2 +- go.sum | 4 +- pkg/connector/capabilities.go | 1 + pkg/connector/handlematrix.go | 32 ++++++ pkg/msgconv/from-matrix.go | 48 +++++---- pkg/msgconv/matrixpoll.go | 186 +++++++++++++++++++++------------- pkg/msgconv/wa-poll.go | 6 +- 9 files changed, 189 insertions(+), 95 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fdf10190..4ada1271 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,6 @@ prevented the bridge from writing to the config, you should update it manually. * Group management features and commands are not yet available. - * Polls are not yet supported. # v0.10.9 (2024-07-16) diff --git a/ROADMAP.md b/ROADMAP.md index ad1b7cca..91f43710 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -6,8 +6,8 @@ * [x] Location messages * [x] Media/files * [x] Replies - * [ ] Polls - * [ ] Poll votes + * [x] Polls + * [x] Poll votes * [x] Message redactions * [x] Reactions * [x] Presence diff --git a/go.mod b/go.mod index 373aa910..9b4bb0bd 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( golang.org/x/sync v0.8.0 google.golang.org/protobuf v1.34.2 gopkg.in/yaml.v3 v3.0.1 - maunium.net/go/mautrix v0.21.1-0.20241007110956-c4d8189d4742 + maunium.net/go/mautrix v0.21.1-0.20241007131243-7a9269e8ff9f ) require ( diff --git a/go.sum b/go.sum index e859272a..54264d63 100644 --- a/go.sum +++ b/go.sum @@ -101,5 +101,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= -maunium.net/go/mautrix v0.21.1-0.20241007110956-c4d8189d4742 h1:ewmYZN7GY4XDd8jrHg/qUQu2/4tsEnks7GtAQnpfkm0= -maunium.net/go/mautrix v0.21.1-0.20241007110956-c4d8189d4742/go.mod h1:+fF5qsmXRCEXQZgW5ececC0PI3c7gISHTLcyftP4Bh0= +maunium.net/go/mautrix v0.21.1-0.20241007131243-7a9269e8ff9f h1:EB+aYheAuukwFKCb/125Baz7oWFGM9U1JdVONx1jxKI= +maunium.net/go/mautrix v0.21.1-0.20241007131243-7a9269e8ff9f/go.mod h1:+fF5qsmXRCEXQZgW5ececC0PI3c7gISHTLcyftP4Bh0= diff --git a/pkg/connector/capabilities.go b/pkg/connector/capabilities.go index 344d136e..dc3e18f5 100644 --- a/pkg/connector/capabilities.go +++ b/pkg/connector/capabilities.go @@ -25,6 +25,7 @@ var whatsappCaps = &bridgev2.NetworkRoomCapabilities{ LocationMessages: true, Captions: true, Replies: true, + Polls: true, Edits: true, EditMaxCount: 10, EditMaxAge: 15 * time.Minute, diff --git a/pkg/connector/handlematrix.go b/pkg/connector/handlematrix.go index 006ed36e..6731229b 100644 --- a/pkg/connector/handlematrix.go +++ b/pkg/connector/handlematrix.go @@ -24,13 +24,45 @@ var ( _ bridgev2.ReactionHandlingNetworkAPI = (*WhatsAppClient)(nil) _ bridgev2.RedactionHandlingNetworkAPI = (*WhatsAppClient)(nil) _ bridgev2.ReadReceiptHandlingNetworkAPI = (*WhatsAppClient)(nil) + _ bridgev2.PollHandlingNetworkAPI = (*WhatsAppClient)(nil) ) +func (wa *WhatsAppClient) HandleMatrixPollStart(ctx context.Context, msg *bridgev2.MatrixPollStart) (*bridgev2.MatrixMessageResponse, error) { + waMsg, optionMap, err := wa.Main.MsgConv.PollStartToWhatsApp(ctx, msg.Content, msg.ReplyTo, msg.Portal) + if err != nil { + return nil, fmt.Errorf("failed to convert poll vote: %w", err) + } + resp, err := wa.handleConvertedMatrixMessage(ctx, &msg.MatrixMessage, waMsg) + if err != nil { + return nil, err + } + resp.DB.Metadata.(*waid.MessageMetadata).IsMatrixPoll = true + resp.PostSave = func(ctx context.Context, message *database.Message) { + err := wa.Main.DB.PollOption.Put(ctx, message.MXID, optionMap) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save poll options") + } + } + return resp, nil +} + +func (wa *WhatsAppClient) HandleMatrixPollVote(ctx context.Context, msg *bridgev2.MatrixPollVote) (*bridgev2.MatrixMessageResponse, error) { + waMsg, err := wa.Main.MsgConv.PollVoteToWhatsApp(ctx, wa.Client, msg.Content, msg.VoteTo) + if err != nil { + return nil, fmt.Errorf("failed to convert poll vote: %w", err) + } + return wa.handleConvertedMatrixMessage(ctx, &msg.MatrixMessage, waMsg) +} + func (wa *WhatsAppClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage) (*bridgev2.MatrixMessageResponse, error) { waMsg, err := wa.Main.MsgConv.ToWhatsApp(ctx, wa.Client, msg.Event, msg.Content, msg.ReplyTo, msg.Portal) if err != nil { return nil, fmt.Errorf("failed to convert message: %w", err) } + return wa.handleConvertedMatrixMessage(ctx, msg, waMsg) +} + +func (wa *WhatsAppClient) handleConvertedMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage, waMsg *waE2E.Message) (*bridgev2.MatrixMessageResponse, error) { messageID := wa.Client.GenerateMessageID() chatJID, err := waid.ParsePortalID(msg.Portal.ID) if err != nil { diff --git a/pkg/msgconv/from-matrix.go b/pkg/msgconv/from-matrix.go index c4f3acba..92247b87 100644 --- a/pkg/msgconv/from-matrix.go +++ b/pkg/msgconv/from-matrix.go @@ -47,23 +47,8 @@ import ( "maunium.net/go/mautrix-whatsapp/pkg/waid" ) -func (mc *MessageConverter) ToWhatsApp( - ctx context.Context, - client *whatsmeow.Client, - evt *event.Event, - content *event.MessageEventContent, - replyTo *database.Message, - portal *bridgev2.Portal, -) (*waE2E.Message, error) { - ctx = context.WithValue(ctx, contextKeyClient, client) - ctx = context.WithValue(ctx, contextKeyPortal, portal) - if evt.Type == event.EventSticker { - content.MsgType = event.MsgImage - } - - message := &waE2E.Message{} +func (mc *MessageConverter) generateContextInfo(replyTo *database.Message, portal *bridgev2.Portal) (*waE2E.ContextInfo, error) { contextInfo := &waE2E.ContextInfo{} - if replyTo != nil { msgID, err := waid.ParseMessageID(replyTo.ID) if err == nil { @@ -81,6 +66,28 @@ func (mc *MessageConverter) ToWhatsApp( contextInfo.EphemeralSettingTimestamp = ptr.Ptr(setAt) } } + return contextInfo, nil +} + +func (mc *MessageConverter) ToWhatsApp( + ctx context.Context, + client *whatsmeow.Client, + evt *event.Event, + content *event.MessageEventContent, + replyTo *database.Message, + portal *bridgev2.Portal, +) (*waE2E.Message, error) { + ctx = context.WithValue(ctx, contextKeyClient, client) + ctx = context.WithValue(ctx, contextKeyPortal, portal) + if evt.Type == event.EventSticker { + content.MsgType = event.MsgImage + } + + message := &waE2E.Message{} + contextInfo, err := mc.generateContextInfo(replyTo, portal) + if err != nil { + return nil, err + } switch content.MsgType { case event.MsgText, event.MsgNotice, event.MsgEmote: @@ -248,7 +255,8 @@ func (mc *MessageConverter) parseText(ctx context.Context, content *event.Messag mentions = make([]string, 0) parseCtx := format.NewContext(ctx) - parseCtx.ReturnData["mentions"] = &mentions + parseCtx.ReturnData["allowed_mentions"] = content.Mentions + parseCtx.ReturnData["output_mentions"] = &mentions if content.Format == event.FormatHTML { text = mc.HTMLParser.Parse(content.FormattedBody, parseCtx) } else { @@ -275,6 +283,10 @@ func (mc *MessageConverter) convertPill(displayname, mxid, eventID string, ctx f if len(mxid) == 0 || mxid[0] != '@' { return format.DefaultPillConverter(displayname, mxid, eventID, ctx) } + allowedMentions, _ := ctx.ReturnData["allowed_mentions"].(*event.Mentions) + if allowedMentions != nil && !allowedMentions.Has(id.UserID(mxid)) { + return displayname + } var jid types.JID ghost, err := mc.Bridge.GetGhostByMXID(ctx.Ctx, id.UserID(mxid)) if err != nil { @@ -295,7 +307,7 @@ func (mc *MessageConverter) convertPill(displayname, mxid, eventID string, ctx f } else { return displayname } - mentions := ctx.ReturnData["mentions"].(*[]string) + mentions := ctx.ReturnData["output_mentions"].(*[]string) *mentions = append(*mentions, jid.String()) return fmt.Sprintf("@%s", jid.User) } diff --git a/pkg/msgconv/matrixpoll.go b/pkg/msgconv/matrixpoll.go index f2298e04..03ab714d 100644 --- a/pkg/msgconv/matrixpoll.go +++ b/pkg/msgconv/matrixpoll.go @@ -17,50 +17,28 @@ package msgconv import ( - "reflect" + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "github.com/rs/zerolog" + "go.mau.fi/util/ptr" + "go.mau.fi/util/random" + "go.mau.fi/whatsmeow" + "go.mau.fi/whatsmeow/proto/waE2E" + "go.mau.fi/whatsmeow/types" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/event" -) + "maunium.net/go/mautrix/format" -var ( - TypeMSC3381PollStart = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.start"} - TypeMSC3381PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"} + "maunium.net/go/mautrix-whatsapp/pkg/waid" ) -type PollResponseContent struct { - RelatesTo event.RelatesTo `json:"m.relates_to"` - V1Response struct { - Answers []string `json:"answers"` - } `json:"org.matrix.msc3381.poll.response"` - V2Selections []string `json:"org.matrix.msc3381.v2.selections"` -} - -func (content *PollResponseContent) GetRelatesTo() *event.RelatesTo { - return &content.RelatesTo -} - -func (content *PollResponseContent) OptionalGetRelatesTo() *event.RelatesTo { - if content.RelatesTo.Type == "" { - return nil - } - return &content.RelatesTo -} - -func (content *PollResponseContent) SetRelatesTo(rel *event.RelatesTo) { - content.RelatesTo = *rel -} - -type MSC1767Message struct { - Text string `json:"org.matrix.msc1767.text,omitempty"` - HTML string `json:"org.matrix.msc1767.html,omitempty"` - Message []struct { - MimeType string `json:"mimetype"` - Body string `json:"body"` - } `json:"org.matrix.msc1767.message,omitempty"` -} - -//lint:ignore U1000 Unused function -func msc1767ToWhatsApp(msg MSC1767Message) string { +func (mc *MessageConverter) msc1767ToWhatsApp(ctx context.Context, msg event.MSC1767Message, allowedMentions *event.Mentions) (string, []string) { for _, part := range msg.Message { if part.MimeType == "text/html" && msg.HTML == "" { msg.HTML = part.Body @@ -68,41 +46,113 @@ func msc1767ToWhatsApp(msg MSC1767Message) string { msg.Text = part.Body } } + mentions := make([]string, 0) if msg.HTML != "" { - return parseWAFormattingToHTML(msg.HTML, false) + parseCtx := format.NewContext(ctx) + parseCtx.ReturnData["allowed_mentions"] = allowedMentions + parseCtx.ReturnData["output_mentions"] = &mentions + return mc.HTMLParser.Parse(msg.HTML, parseCtx), mentions } - return msg.Text + return msg.Text, mentions } -type PollStartContent struct { - RelatesTo *event.RelatesTo `json:"m.relates_to"` - PollStart struct { - Kind string `json:"kind"` - MaxSelections int `json:"max_selections"` - Question MSC1767Message `json:"question"` - Answers []struct { - ID string `json:"id"` - MSC1767Message - } `json:"answers"` - } `json:"org.matrix.msc3381.poll.start"` -} +var ( + errPollMissingQuestion = bridgev2.WrapErrorInStatus(errors.New("poll message is missing question")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + errPollDuplicateOption = bridgev2.WrapErrorInStatus(errors.New("poll options must be unique")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) +) -func (content *PollStartContent) GetRelatesTo() *event.RelatesTo { - if content.RelatesTo == nil { - content.RelatesTo = &event.RelatesTo{} +func (mc *MessageConverter) PollStartToWhatsApp( + ctx context.Context, + content *event.PollStartEventContent, + replyTo *database.Message, + portal *bridgev2.Portal, +) (*waE2E.Message, map[[32]byte]string, error) { + maxAnswers := content.PollStart.MaxSelections + if maxAnswers >= len(content.PollStart.Answers) || maxAnswers < 0 { + maxAnswers = 0 } - return content.RelatesTo -} - -func (content *PollStartContent) OptionalGetRelatesTo() *event.RelatesTo { - return content.RelatesTo -} - -func (content *PollStartContent) SetRelatesTo(rel *event.RelatesTo) { - content.RelatesTo = rel + contextInfo, err := mc.generateContextInfo(replyTo, portal) + if err != nil { + return nil, nil, err + } + var question string + question, contextInfo.MentionedJID = mc.msc1767ToWhatsApp(ctx, content.PollStart.Question, content.Mentions) + if len(question) == 0 { + return nil, nil, errPollMissingQuestion + } + options := make([]*waE2E.PollCreationMessage_Option, len(content.PollStart.Answers)) + optionMap := make(map[[32]byte]string, len(options)) + for i, opt := range content.PollStart.Answers { + body, _ := mc.msc1767ToWhatsApp(ctx, opt.MSC1767Message, &event.Mentions{}) + hash := sha256.Sum256([]byte(body)) + if _, alreadyExists := optionMap[hash]; alreadyExists { + zerolog.Ctx(ctx).Warn().Str("option", body).Msg("Poll has duplicate options, rejecting") + return nil, nil, errPollDuplicateOption + } + optionMap[hash] = opt.ID + options[i] = &waE2E.PollCreationMessage_Option{ + OptionName: ptr.Ptr(body), + } + } + return &waE2E.Message{ + PollCreationMessage: &waE2E.PollCreationMessage{ + Name: ptr.Ptr(question), + Options: options, + SelectableOptionsCount: ptr.Ptr(uint32(maxAnswers)), + ContextInfo: contextInfo, + }, + MessageContextInfo: &waE2E.MessageContextInfo{ + MessageSecret: random.Bytes(32), + }, + }, optionMap, nil } -func init() { - event.TypeMap[TypeMSC3381PollResponse] = reflect.TypeOf(PollResponseContent{}) - event.TypeMap[TypeMSC3381PollStart] = reflect.TypeOf(PollStartContent{}) +func (mc *MessageConverter) PollVoteToWhatsApp( + ctx context.Context, + client *whatsmeow.Client, + content *event.PollResponseEventContent, + pollMsg *database.Message, +) (*waE2E.Message, error) { + parsedMsgID, err := waid.ParseMessageID(pollMsg.ID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to parse message ID") + return nil, fmt.Errorf("failed to parse message ID") + } + pollMsgInfo := &types.MessageInfo{ + MessageSource: types.MessageSource{ + Chat: parsedMsgID.Chat, + Sender: parsedMsgID.Sender, + IsFromMe: parsedMsgID.Sender.User == client.Store.ID.User, + IsGroup: parsedMsgID.Chat.Server == types.GroupServer, + }, + ID: parsedMsgID.ID, + Type: "poll", + } + optionHashes := make([][]byte, 0, len(content.Response.Answers)) + if pollMsg.Metadata.(*waid.MessageMetadata).IsMatrixPoll { + mappedAnswers, err := mc.DB.PollOption.GetHashes(ctx, pollMsg.MXID, content.Response.Answers) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get poll option hashes from database") + return nil, fmt.Errorf("failed to get poll option hashes") + } + for _, selection := range content.Response.Answers { + hash, ok := mappedAnswers[selection] + if ok { + optionHashes = append(optionHashes, hash[:]) + } else { + zerolog.Ctx(ctx).Warn().Str("option", selection).Msg("Didn't find hash for selected option") + } + } + } else { + for _, selection := range content.Response.Answers { + hash, _ := hex.DecodeString(selection) + if len(hash) == 32 { + optionHashes = append(optionHashes, hash) + } + } + } + pollUpdate, err := client.EncryptPollVote(pollMsgInfo, &waE2E.PollVoteMessage{ + SelectedOptions: optionHashes, + }) + return &waE2E.Message{PollUpdateMessage: pollUpdate}, err } diff --git a/pkg/msgconv/wa-poll.go b/pkg/msgconv/wa-poll.go index 8b3e235f..e11fc443 100644 --- a/pkg/msgconv/wa-poll.go +++ b/pkg/msgconv/wa-poll.go @@ -60,7 +60,7 @@ func (mc *MessageConverter) convertPollCreationMessage(ctx context.Context, msg } evtType := event.EventMessage if mc.ExtEvPolls { - evtType = TypeMSC3381PollStart + evtType = event.EventUnstablePollStart } return &bridgev2.ConvertedMessagePart{ @@ -130,7 +130,7 @@ func (mc *MessageConverter) keyToMessageID(ctx context.Context, chat, sender typ } var failedPollUpdatePart = &bridgev2.ConvertedMessagePart{ - Type: TypeMSC3381PollResponse, + Type: event.EventUnstablePollResponse, Content: &event.MessageEventContent{}, DontBridge: true, } @@ -176,7 +176,7 @@ func (mc *MessageConverter) convertPollUpdateMessage(ctx context.Context, info * } } return &bridgev2.ConvertedMessagePart{ - Type: TypeMSC3381PollResponse, + Type: event.EventUnstablePollResponse, Content: &event.MessageEventContent{ RelatesTo: &event.RelatesTo{ Type: event.RelReference,