Skip to content

Commit

Permalink
msgconv/from-matrix: add support for polls
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Oct 7, 2024
1 parent 21f62e3 commit 7a6eabf
Show file tree
Hide file tree
Showing 9 changed files with 189 additions and 95 deletions.
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
prevented the bridge from writing to the config, you should update it
manually.
* Group management features and commands are not yet available.
* Polls are not yet supported.

# v0.10.9 (2024-07-16)

Expand Down
4 changes: 2 additions & 2 deletions ROADMAP.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
* [x] Location messages
* [x] Media/files
* [x] Replies
* [ ] Polls
* [ ] Poll votes
* [x] Polls
* [x] Poll votes
* [x] Message redactions
* [x] Reactions
* [x] Presence
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ require (
golang.org/x/sync v0.8.0
google.golang.org/protobuf v1.34.2
gopkg.in/yaml.v3 v3.0.1
maunium.net/go/mautrix v0.21.1-0.20241007110956-c4d8189d4742
maunium.net/go/mautrix v0.21.1-0.20241007131243-7a9269e8ff9f
)

require (
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
maunium.net/go/mautrix v0.21.1-0.20241007110956-c4d8189d4742 h1:ewmYZN7GY4XDd8jrHg/qUQu2/4tsEnks7GtAQnpfkm0=
maunium.net/go/mautrix v0.21.1-0.20241007110956-c4d8189d4742/go.mod h1:+fF5qsmXRCEXQZgW5ececC0PI3c7gISHTLcyftP4Bh0=
maunium.net/go/mautrix v0.21.1-0.20241007131243-7a9269e8ff9f h1:EB+aYheAuukwFKCb/125Baz7oWFGM9U1JdVONx1jxKI=
maunium.net/go/mautrix v0.21.1-0.20241007131243-7a9269e8ff9f/go.mod h1:+fF5qsmXRCEXQZgW5ececC0PI3c7gISHTLcyftP4Bh0=
1 change: 1 addition & 0 deletions pkg/connector/capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ var whatsappCaps = &bridgev2.NetworkRoomCapabilities{
LocationMessages: true,
Captions: true,
Replies: true,
Polls: true,
Edits: true,
EditMaxCount: 10,
EditMaxAge: 15 * time.Minute,
Expand Down
32 changes: 32 additions & 0 deletions pkg/connector/handlematrix.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,45 @@ var (
_ bridgev2.ReactionHandlingNetworkAPI = (*WhatsAppClient)(nil)
_ bridgev2.RedactionHandlingNetworkAPI = (*WhatsAppClient)(nil)
_ bridgev2.ReadReceiptHandlingNetworkAPI = (*WhatsAppClient)(nil)
_ bridgev2.PollHandlingNetworkAPI = (*WhatsAppClient)(nil)
)

func (wa *WhatsAppClient) HandleMatrixPollStart(ctx context.Context, msg *bridgev2.MatrixPollStart) (*bridgev2.MatrixMessageResponse, error) {
waMsg, optionMap, err := wa.Main.MsgConv.PollStartToWhatsApp(ctx, msg.Content, msg.ReplyTo, msg.Portal)
if err != nil {
return nil, fmt.Errorf("failed to convert poll vote: %w", err)
}
resp, err := wa.handleConvertedMatrixMessage(ctx, &msg.MatrixMessage, waMsg)
if err != nil {
return nil, err
}
resp.DB.Metadata.(*waid.MessageMetadata).IsMatrixPoll = true
resp.PostSave = func(ctx context.Context, message *database.Message) {
err := wa.Main.DB.PollOption.Put(ctx, message.MXID, optionMap)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to save poll options")
}
}
return resp, nil
}

func (wa *WhatsAppClient) HandleMatrixPollVote(ctx context.Context, msg *bridgev2.MatrixPollVote) (*bridgev2.MatrixMessageResponse, error) {
waMsg, err := wa.Main.MsgConv.PollVoteToWhatsApp(ctx, wa.Client, msg.Content, msg.VoteTo)
if err != nil {
return nil, fmt.Errorf("failed to convert poll vote: %w", err)
}
return wa.handleConvertedMatrixMessage(ctx, &msg.MatrixMessage, waMsg)
}

func (wa *WhatsAppClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage) (*bridgev2.MatrixMessageResponse, error) {
waMsg, err := wa.Main.MsgConv.ToWhatsApp(ctx, wa.Client, msg.Event, msg.Content, msg.ReplyTo, msg.Portal)
if err != nil {
return nil, fmt.Errorf("failed to convert message: %w", err)
}
return wa.handleConvertedMatrixMessage(ctx, msg, waMsg)
}

func (wa *WhatsAppClient) handleConvertedMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage, waMsg *waE2E.Message) (*bridgev2.MatrixMessageResponse, error) {
messageID := wa.Client.GenerateMessageID()
chatJID, err := waid.ParsePortalID(msg.Portal.ID)
if err != nil {
Expand Down
48 changes: 30 additions & 18 deletions pkg/msgconv/from-matrix.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,8 @@ import (
"maunium.net/go/mautrix-whatsapp/pkg/waid"
)

func (mc *MessageConverter) ToWhatsApp(
ctx context.Context,
client *whatsmeow.Client,
evt *event.Event,
content *event.MessageEventContent,
replyTo *database.Message,
portal *bridgev2.Portal,
) (*waE2E.Message, error) {
ctx = context.WithValue(ctx, contextKeyClient, client)
ctx = context.WithValue(ctx, contextKeyPortal, portal)
if evt.Type == event.EventSticker {
content.MsgType = event.MsgImage
}

message := &waE2E.Message{}
func (mc *MessageConverter) generateContextInfo(replyTo *database.Message, portal *bridgev2.Portal) (*waE2E.ContextInfo, error) {
contextInfo := &waE2E.ContextInfo{}

if replyTo != nil {
msgID, err := waid.ParseMessageID(replyTo.ID)
if err == nil {
Expand All @@ -81,6 +66,28 @@ func (mc *MessageConverter) ToWhatsApp(
contextInfo.EphemeralSettingTimestamp = ptr.Ptr(setAt)
}
}
return contextInfo, nil
}

func (mc *MessageConverter) ToWhatsApp(
ctx context.Context,
client *whatsmeow.Client,
evt *event.Event,
content *event.MessageEventContent,
replyTo *database.Message,
portal *bridgev2.Portal,
) (*waE2E.Message, error) {
ctx = context.WithValue(ctx, contextKeyClient, client)
ctx = context.WithValue(ctx, contextKeyPortal, portal)
if evt.Type == event.EventSticker {
content.MsgType = event.MsgImage
}

message := &waE2E.Message{}
contextInfo, err := mc.generateContextInfo(replyTo, portal)
if err != nil {
return nil, err
}

switch content.MsgType {
case event.MsgText, event.MsgNotice, event.MsgEmote:
Expand Down Expand Up @@ -248,7 +255,8 @@ func (mc *MessageConverter) parseText(ctx context.Context, content *event.Messag
mentions = make([]string, 0)

parseCtx := format.NewContext(ctx)
parseCtx.ReturnData["mentions"] = &mentions
parseCtx.ReturnData["allowed_mentions"] = content.Mentions
parseCtx.ReturnData["output_mentions"] = &mentions
if content.Format == event.FormatHTML {
text = mc.HTMLParser.Parse(content.FormattedBody, parseCtx)
} else {
Expand All @@ -275,6 +283,10 @@ func (mc *MessageConverter) convertPill(displayname, mxid, eventID string, ctx f
if len(mxid) == 0 || mxid[0] != '@' {
return format.DefaultPillConverter(displayname, mxid, eventID, ctx)
}
allowedMentions, _ := ctx.ReturnData["allowed_mentions"].(*event.Mentions)
if allowedMentions != nil && !allowedMentions.Has(id.UserID(mxid)) {
return displayname
}
var jid types.JID
ghost, err := mc.Bridge.GetGhostByMXID(ctx.Ctx, id.UserID(mxid))
if err != nil {
Expand All @@ -295,7 +307,7 @@ func (mc *MessageConverter) convertPill(displayname, mxid, eventID string, ctx f
} else {
return displayname
}
mentions := ctx.ReturnData["mentions"].(*[]string)
mentions := ctx.ReturnData["output_mentions"].(*[]string)
*mentions = append(*mentions, jid.String())
return fmt.Sprintf("@%s", jid.User)
}
Expand Down
186 changes: 118 additions & 68 deletions pkg/msgconv/matrixpoll.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,92 +17,142 @@
package msgconv

import (
"reflect"
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"

"github.com/rs/zerolog"
"go.mau.fi/util/ptr"
"go.mau.fi/util/random"
"go.mau.fi/whatsmeow"
"go.mau.fi/whatsmeow/proto/waE2E"
"go.mau.fi/whatsmeow/types"

"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/event"
)
"maunium.net/go/mautrix/format"

var (
TypeMSC3381PollStart = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.start"}
TypeMSC3381PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"}
"maunium.net/go/mautrix-whatsapp/pkg/waid"
)

type PollResponseContent struct {
RelatesTo event.RelatesTo `json:"m.relates_to"`
V1Response struct {
Answers []string `json:"answers"`
} `json:"org.matrix.msc3381.poll.response"`
V2Selections []string `json:"org.matrix.msc3381.v2.selections"`
}

func (content *PollResponseContent) GetRelatesTo() *event.RelatesTo {
return &content.RelatesTo
}

func (content *PollResponseContent) OptionalGetRelatesTo() *event.RelatesTo {
if content.RelatesTo.Type == "" {
return nil
}
return &content.RelatesTo
}

func (content *PollResponseContent) SetRelatesTo(rel *event.RelatesTo) {
content.RelatesTo = *rel
}

type MSC1767Message struct {
Text string `json:"org.matrix.msc1767.text,omitempty"`
HTML string `json:"org.matrix.msc1767.html,omitempty"`
Message []struct {
MimeType string `json:"mimetype"`
Body string `json:"body"`
} `json:"org.matrix.msc1767.message,omitempty"`
}

//lint:ignore U1000 Unused function
func msc1767ToWhatsApp(msg MSC1767Message) string {
func (mc *MessageConverter) msc1767ToWhatsApp(ctx context.Context, msg event.MSC1767Message, allowedMentions *event.Mentions) (string, []string) {
for _, part := range msg.Message {
if part.MimeType == "text/html" && msg.HTML == "" {
msg.HTML = part.Body
} else if part.MimeType == "text/plain" && msg.Text == "" {
msg.Text = part.Body
}
}
mentions := make([]string, 0)
if msg.HTML != "" {
return parseWAFormattingToHTML(msg.HTML, false)
parseCtx := format.NewContext(ctx)
parseCtx.ReturnData["allowed_mentions"] = allowedMentions
parseCtx.ReturnData["output_mentions"] = &mentions
return mc.HTMLParser.Parse(msg.HTML, parseCtx), mentions
}
return msg.Text
return msg.Text, mentions
}

type PollStartContent struct {
RelatesTo *event.RelatesTo `json:"m.relates_to"`
PollStart struct {
Kind string `json:"kind"`
MaxSelections int `json:"max_selections"`
Question MSC1767Message `json:"question"`
Answers []struct {
ID string `json:"id"`
MSC1767Message
} `json:"answers"`
} `json:"org.matrix.msc3381.poll.start"`
}
var (
errPollMissingQuestion = bridgev2.WrapErrorInStatus(errors.New("poll message is missing question")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
errPollDuplicateOption = bridgev2.WrapErrorInStatus(errors.New("poll options must be unique")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
)

func (content *PollStartContent) GetRelatesTo() *event.RelatesTo {
if content.RelatesTo == nil {
content.RelatesTo = &event.RelatesTo{}
func (mc *MessageConverter) PollStartToWhatsApp(
ctx context.Context,
content *event.PollStartEventContent,
replyTo *database.Message,
portal *bridgev2.Portal,
) (*waE2E.Message, map[[32]byte]string, error) {
maxAnswers := content.PollStart.MaxSelections
if maxAnswers >= len(content.PollStart.Answers) || maxAnswers < 0 {
maxAnswers = 0
}
return content.RelatesTo
}

func (content *PollStartContent) OptionalGetRelatesTo() *event.RelatesTo {
return content.RelatesTo
}

func (content *PollStartContent) SetRelatesTo(rel *event.RelatesTo) {
content.RelatesTo = rel
contextInfo, err := mc.generateContextInfo(replyTo, portal)
if err != nil {
return nil, nil, err
}
var question string
question, contextInfo.MentionedJID = mc.msc1767ToWhatsApp(ctx, content.PollStart.Question, content.Mentions)
if len(question) == 0 {
return nil, nil, errPollMissingQuestion
}
options := make([]*waE2E.PollCreationMessage_Option, len(content.PollStart.Answers))
optionMap := make(map[[32]byte]string, len(options))
for i, opt := range content.PollStart.Answers {
body, _ := mc.msc1767ToWhatsApp(ctx, opt.MSC1767Message, &event.Mentions{})
hash := sha256.Sum256([]byte(body))
if _, alreadyExists := optionMap[hash]; alreadyExists {
zerolog.Ctx(ctx).Warn().Str("option", body).Msg("Poll has duplicate options, rejecting")
return nil, nil, errPollDuplicateOption
}
optionMap[hash] = opt.ID
options[i] = &waE2E.PollCreationMessage_Option{
OptionName: ptr.Ptr(body),
}
}
return &waE2E.Message{
PollCreationMessage: &waE2E.PollCreationMessage{
Name: ptr.Ptr(question),
Options: options,
SelectableOptionsCount: ptr.Ptr(uint32(maxAnswers)),
ContextInfo: contextInfo,
},
MessageContextInfo: &waE2E.MessageContextInfo{
MessageSecret: random.Bytes(32),
},
}, optionMap, nil
}

func init() {
event.TypeMap[TypeMSC3381PollResponse] = reflect.TypeOf(PollResponseContent{})
event.TypeMap[TypeMSC3381PollStart] = reflect.TypeOf(PollStartContent{})
func (mc *MessageConverter) PollVoteToWhatsApp(
ctx context.Context,
client *whatsmeow.Client,
content *event.PollResponseEventContent,
pollMsg *database.Message,
) (*waE2E.Message, error) {
parsedMsgID, err := waid.ParseMessageID(pollMsg.ID)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to parse message ID")
return nil, fmt.Errorf("failed to parse message ID")
}
pollMsgInfo := &types.MessageInfo{
MessageSource: types.MessageSource{
Chat: parsedMsgID.Chat,
Sender: parsedMsgID.Sender,
IsFromMe: parsedMsgID.Sender.User == client.Store.ID.User,
IsGroup: parsedMsgID.Chat.Server == types.GroupServer,
},
ID: parsedMsgID.ID,
Type: "poll",
}
optionHashes := make([][]byte, 0, len(content.Response.Answers))
if pollMsg.Metadata.(*waid.MessageMetadata).IsMatrixPoll {
mappedAnswers, err := mc.DB.PollOption.GetHashes(ctx, pollMsg.MXID, content.Response.Answers)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get poll option hashes from database")
return nil, fmt.Errorf("failed to get poll option hashes")
}
for _, selection := range content.Response.Answers {
hash, ok := mappedAnswers[selection]
if ok {
optionHashes = append(optionHashes, hash[:])
} else {
zerolog.Ctx(ctx).Warn().Str("option", selection).Msg("Didn't find hash for selected option")
}
}
} else {
for _, selection := range content.Response.Answers {
hash, _ := hex.DecodeString(selection)
if len(hash) == 32 {
optionHashes = append(optionHashes, hash)
}
}
}
pollUpdate, err := client.EncryptPollVote(pollMsgInfo, &waE2E.PollVoteMessage{
SelectedOptions: optionHashes,
})
return &waE2E.Message{PollUpdateMessage: pollUpdate}, err
}
Loading

0 comments on commit 7a6eabf

Please sign in to comment.