From b330c5836e4b9de52e627637b7f711ed1373b00f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 29 Nov 2024 19:25:43 +0200 Subject: [PATCH] client: set referers properly --- attachments.go | 6 ++++++ backfill.go | 4 ++-- commands.go | 4 ++-- commands_botinteraction.go | 6 +++--- directmedia.go | 6 +++++- go.mod | 2 +- go.sum | 4 ++-- portal.go | 30 ++++++++++++++++++++++-------- thread.go | 6 +++++- 9 files changed, 48 insertions(+), 20 deletions(-) diff --git a/attachments.go b/attachments.go index d710849..50d3c0b 100644 --- a/attachments.go +++ b/attachments.go @@ -73,6 +73,12 @@ func uploadDiscordAttachment(url string, data []byte) error { for key, value := range discordgo.DroidFetchHeaders { req.Header.Set(key, value) } + req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("Referer", "https://discord.com/") + req.Header.Del("X-Debug-Options") + req.Header.Del("X-Discord-Locale") + req.Header.Del("X-Discord-Timezone") + req.Header.Del("X-Super-Properties") resp, err := http.DefaultClient.Do(req) if err != nil { diff --git a/backfill.go b/backfill.go index 122181b..c4966bd 100644 --- a/backfill.go +++ b/backfill.go @@ -117,7 +117,7 @@ func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, } for { log.Debug().Str("before_id", before).Msg("Fetching messages for backfill") - newMessages, err := source.Session.ChannelMessages(protoChannelID, messageFetchChunkSize, before, "", "") + newMessages, err := source.Session.ChannelMessages(protoChannelID, messageFetchChunkSize, before, "", "", portal.RefererOptIfUser(source.Session, protoChannelID)...) if err != nil { return nil, false, err } @@ -183,7 +183,7 @@ func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, } for { log.Debug().Str("after_id", after).Msg("Fetching chunk of messages to backfill") - messages, err := source.Session.ChannelMessages(protoChannelID, messageFetchChunkSize, "", after, "") + messages, err := source.Session.ChannelMessages(protoChannelID, messageFetchChunkSize, "", after, "", portal.RefererOptIfUser(source.Session, protoChannelID)...) if err != nil { log.Err(err).Msg("Error fetching chunk of messages to forward backfill") return diff --git a/commands.go b/commands.go index ab407ed..060a4ab 100644 --- a/commands.go +++ b/commands.go @@ -468,7 +468,7 @@ func fnSetRelay(ce *WrappedCommandEvent) { return } case "create": - perms, err := ce.User.Session.UserChannelPermissions(ce.User.DiscordID, portal.Key.ChannelID) + perms, err := ce.User.Session.UserChannelPermissions(ce.User.DiscordID, portal.Key.ChannelID, portal.RefererOptIfUser(ce.User.Session, "")...) if err != nil { log.Warn().Err(err).Msg("Failed to check user permissions") ce.Reply("Failed to check if you have permission to create webhooks") @@ -483,7 +483,7 @@ func fnSetRelay(ce *WrappedCommandEvent) { name = strings.Join(ce.Args[1:], " ") } log.Debug().Str("webhook_name", name).Msg("Creating webhook") - webhookMeta, err = ce.User.Session.WebhookCreate(portal.Key.ChannelID, name, "") + webhookMeta, err = ce.User.Session.WebhookCreate(portal.Key.ChannelID, name, "", portal.RefererOptIfUser(ce.User.Session, "")...) if err != nil { log.Warn().Err(err).Msg("Failed to create webhook") ce.Reply("Failed to create webhook: %v", err) diff --git a/commands_botinteraction.go b/commands_botinteraction.go index 28a1340..8dd585a 100644 --- a/commands_botinteraction.go +++ b/commands_botinteraction.go @@ -61,7 +61,7 @@ func (portal *Portal) getCommand(user *User, command string) (*discordgo.Applica defer portal.commandsLock.Unlock() cmd, ok := portal.commands[command] if !ok { - results, err := user.Session.ApplicationCommandsSearch(portal.Key.ChannelID, command) + results, err := user.Session.ApplicationCommandsSearch(portal.Key.ChannelID, command, portal.RefererOpt("")) if err != nil { return nil, err } @@ -247,7 +247,7 @@ func fnCommands(ce *WrappedCommandEvent) { } subcmd := strings.ToLower(ce.Args[0]) if subcmd == "search" { - results, err := ce.User.Session.ApplicationCommandsSearch(ce.Portal.Key.ChannelID, ce.Args[1]) + results, err := ce.User.Session.ApplicationCommandsSearch(ce.Portal.Key.ChannelID, ce.Args[1], ce.Portal.RefererOpt("")) if err != nil { ce.Reply("Error searching for commands: %v", err) return @@ -297,7 +297,7 @@ func fnExec(ce *WrappedCommandEvent) { ce.User.pendingInteractionsLock.Lock() ce.User.pendingInteractions[nonce] = ce ce.User.pendingInteractionsLock.Unlock() - err = ce.User.Session.SendInteractions(ce.Portal.GuildID, ce.Portal.Key.ChannelID, cmd, options, nonce) + err = ce.User.Session.SendInteractions(ce.Portal.GuildID, ce.Portal.Key.ChannelID, cmd, options, nonce, ce.Portal.RefererOpt("")) if err != nil { ce.Reply("Error sending interaction: %v", err) ce.User.pendingInteractionsLock.Lock() diff --git a/directmedia.go b/directmedia.go index c6f8a2b..4499c1a 100644 --- a/directmedia.go +++ b/directmedia.go @@ -357,7 +357,11 @@ func (dma *DirectMediaAPI) fetchNewAttachmentURL(ctx context.Context, meta *Atta var err error messageIDStr := strconv.FormatUint(meta.MessageID, 10) if client.IsUser { - msgs, err = client.ChannelMessages(channelIDStr, 5, "", "", messageIDStr) + var refs []discordgo.RequestOption + if portal != nil { + refs = append(refs, discordgo.WithChannelReferer(portal.GuildID, channelIDStr)) + } + msgs, err = client.ChannelMessages(channelIDStr, 5, "", "", messageIDStr, refs...) } else { var msg *discordgo.Message msg, err = client.ChannelMessage(channelIDStr, messageIDStr) diff --git a/go.mod b/go.mod index 272fe12..c6e4cee 100644 --- a/go.mod +++ b/go.mod @@ -42,4 +42,4 @@ require ( maunium.net/go/mauflag v1.0.0 // indirect ) -replace github.com/bwmarrin/discordgo => github.com/beeper/discordgo v0.0.0-20241121222213-424cfdb527f5 +replace github.com/bwmarrin/discordgo => github.com/beeper/discordgo v0.0.0-20241129150404-0ddeff8635e8 diff --git a/go.sum b/go.sum index 8ad3197..978f572 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= -github.com/beeper/discordgo v0.0.0-20241121222213-424cfdb527f5 h1:3w2UYUgbHJ8HJhmOi4fNSYRPV25rmoCcOoV1T74bzRI= -github.com/beeper/discordgo v0.0.0-20241121222213-424cfdb527f5/go.mod h1:59+AOzzjmL6onAh62nuLXmn7dJCaC/owDLWbGtjTcFA= +github.com/beeper/discordgo v0.0.0-20241129150404-0ddeff8635e8 h1:pJeDjlzwk6z6XTC1N54QNdAplXlnQ+Er+tO6ogquj0Q= +github.com/beeper/discordgo v0.0.0-20241129150404-0ddeff8635e8/go.mod h1:59+AOzzjmL6onAh62nuLXmn7dJCaC/owDLWbGtjTcFA= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/portal.go b/portal.go index 09c1853..2c91156 100644 --- a/portal.go +++ b/portal.go @@ -1167,7 +1167,7 @@ func (portal *Portal) startThreadFromMatrix(sender *User, threadRoot id.EventID) AutoArchiveDuration: 24 * 60, Type: discordgo.ChannelTypeGuildPublicThread, Location: "Message", - }) + }, portal.RefererOptIfUser(sender.Session, "")...) if err != nil { return "", fmt.Errorf("error starting thread: %v", err) } @@ -1497,6 +1497,20 @@ func (portal *Portal) convertReplyMessageToEmbed(eventID id.EventID, url string) return embed, nil } +func (portal *Portal) RefererOpt(threadID string) discordgo.RequestOption { + if threadID != "" && threadID != portal.Key.ChannelID { + return discordgo.WithThreadReferer(portal.GuildID, portal.Key.ChannelID, threadID) + } + return discordgo.WithChannelReferer(portal.GuildID, portal.Key.ChannelID) +} + +func (portal *Portal) RefererOptIfUser(sess *discordgo.Session, threadID string) []discordgo.RequestOption { + if sess == nil || !sess.IsUser { + return nil + } + return []discordgo.RequestOption{portal.RefererOpt(threadID)} +} + func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { if portal.IsPrivateChat() && sender.DiscordID != portal.Key.Receiver { go portal.sendMessageMetrics(evt, errUserNotReceiver, "Ignoring") @@ -1626,7 +1640,7 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { Name: att.Filename, ID: sender.NextDiscordUploadID(), }}, - }) + }, portal.RefererOpt(threadID)) if err != nil { go portal.sendMessageMetrics(evt, err, "Error preparing to reupload media in") return @@ -1667,7 +1681,7 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { var msg *discordgo.Message var err error if !isWebhookSend { - msg, err = sess.ChannelMessageSendComplex(channelID, &sendReq) + msg, err = sess.ChannelMessageSendComplex(channelID, &sendReq, portal.RefererOptIfUser(sess, threadID)...) } else { username, avatarURL := portal.getRelayUserMeta(sender) msg, err = relayClient.WebhookThreadExecute(portal.RelayWebhookID, portal.RelayWebhookSecret, true, threadID, &discordgo.WebhookParams{ @@ -1915,7 +1929,7 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) { return } - err := sender.Session.MessageReactionAdd(msg.DiscordProtoChannelID(), msg.DiscordID, emojiID) + err := sender.Session.MessageReactionAddUser(portal.GuildID, msg.DiscordProtoChannelID(), msg.DiscordID, emojiID) go portal.sendMessageMetrics(evt, err, "Error sending") if err == nil { dbReaction := portal.bridge.DB.Reaction.New() @@ -2051,7 +2065,7 @@ func (portal *Portal) handleMatrixRedaction(sender *User, evt *event.Event) { var err error // TODO add support for deleting individual attachments from messages if sess != nil { - err = sess.ChannelMessageDelete(message.DiscordProtoChannelID(), message.DiscordID) + err = sess.ChannelMessageDelete(message.DiscordProtoChannelID(), message.DiscordID, portal.RefererOptIfUser(sess, message.ThreadID)...) } else { // TODO pre-validate that the message was sent by the webhook? err = relayClient.WebhookMessageDelete(portal.RelayWebhookID, portal.RelayWebhookSecret, message.DiscordID) @@ -2066,7 +2080,7 @@ func (portal *Portal) handleMatrixRedaction(sender *User, evt *event.Event) { if sess != nil { reaction := portal.bridge.DB.Reaction.GetByMXID(evt.Redacts) if reaction != nil && reaction.Channel == portal.Key { - err := sess.MessageReactionRemove(reaction.DiscordProtoChannelID(), reaction.MessageID, reaction.EmojiName, reaction.Sender) + err := sess.MessageReactionRemoveUser(portal.GuildID, reaction.DiscordProtoChannelID(), reaction.MessageID, reaction.EmojiName, reaction.Sender) go portal.sendMessageMetrics(evt, err, "Error sending") if err == nil { reaction.Delete() @@ -2135,7 +2149,7 @@ func (portal *Portal) HandleMatrixReadReceipt(brUser bridge.User, eventID id.Eve Msg("Dropping read receipt: thread ID mismatch") return } - resp, err := sender.Session.ChannelMessageAckNoToken(msg.DiscordProtoChannelID(), msg.DiscordID) + resp, err := sender.Session.ChannelMessageAckNoToken(msg.DiscordProtoChannelID(), msg.DiscordID, portal.RefererOpt(msg.DiscordProtoChannelID())) if err != nil { log.Err(err).Msg("Failed to send read receipt to Discord") } else if resp.Token != nil { @@ -2169,7 +2183,7 @@ func (portal *Portal) HandleMatrixTyping(newTyping []id.UserID) { user := portal.bridge.GetUserByMXID(userID) if user != nil && user.Session != nil { user.ViewingChannel(portal) - err := user.Session.ChannelTyping(portal.Key.ChannelID) + err := user.Session.ChannelTyping(portal.Key.ChannelID, portal.RefererOptIfUser(user.Session, "")...) if err != nil { portal.log.Warn().Err(err). Str("user_id", user.MXID.String()). diff --git a/thread.go b/thread.go index 5de2410..6e6aa7b 100644 --- a/thread.go +++ b/thread.go @@ -112,6 +112,10 @@ func (thread *Thread) maybeInitialBackfill(source *User) { thread.Parent.forwardBackfillInitial(source, thread) } +func (thread *Thread) RefererOpt() discordgo.RequestOption { + return discordgo.WithThreadReferer(thread.Parent.GuildID, thread.ParentID, thread.ID) +} + func (thread *Thread) Join(user *User) { if user.IsInPortal(thread.ID) { return @@ -137,7 +141,7 @@ func (thread *Thread) Join(user *User) { var err error if user.Session.IsUser { - err = user.Session.ThreadJoinWithLocation(thread.ID, discordgo.ThreadJoinLocationContextMenu) + err = user.Session.ThreadJoin(thread.ID, discordgo.WithLocationParam(discordgo.ThreadJoinLocationContextMenu), thread.RefererOpt()) } else { err = user.Session.ThreadJoin(thread.ID) }