diff --git a/attachments.go b/attachments.go index 9a0ea88..d710849 100644 --- a/attachments.go +++ b/attachments.go @@ -323,19 +323,17 @@ func (br *DiscordBridge) copyAttachmentToMatrix(intent *appservice.IntentAPI, ur } func (portal *Portal) getEmojiMXCByDiscordID(emojiID, name string, animated bool) id.ContentURI { - var url, mimeType, ext string + mxc := portal.bridge.DMA.EmojiMXC(emojiID, name, animated) + if !mxc.IsEmpty() { + return mxc + } + var url, mimeType string if animated { url = discordgo.EndpointEmojiAnimated(emojiID) mimeType = "image/gif" - ext = "gif" } else { url = discordgo.EndpointEmoji(emojiID) mimeType = "image/png" - ext = "png" - } - mxc := portal.bridge.Config.Bridge.MediaPatterns.Emoji(emojiID, ext) - if !mxc.IsEmpty() { - return mxc } dbFile, err := portal.bridge.copyAttachmentToMatrix(portal.MainIntent(), url, false, AttachmentMeta{ AttachmentID: emojiID, diff --git a/config/bridge.go b/config/bridge.go index 046632f..47571e6 100644 --- a/config/bridge.go +++ b/config/bridge.go @@ -25,7 +25,6 @@ import ( "github.com/bwmarrin/discordgo" "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/id" ) type BridgeConfig struct { @@ -55,8 +54,8 @@ type BridgeConfig struct { EnableWebhookAvatars bool `yaml:"enable_webhook_avatars"` UseDiscordCDNUpload bool `yaml:"use_discord_cdn_upload"` - CacheMedia string `yaml:"cache_media"` - MediaPatterns MediaPatterns `yaml:"media_patterns"` + CacheMedia string `yaml:"cache_media"` + DirectMedia DirectMedia `yaml:"direct_media"` AnimatedSticker struct { Target string `yaml:"target"` @@ -96,111 +95,12 @@ type BridgeConfig struct { guildNameTemplate *template.Template `yaml:"-"` } -type MediaPatterns struct { - Enabled bool `yaml:"enabled"` - TplAttachments string `yaml:"attachments"` - TplEmojis string `yaml:"emojis"` - TplStickers string `yaml:"stickers"` - TplAvatars string `yaml:"avatars"` - - attachments *template.Template `yaml:"-"` - emojis *template.Template `yaml:"-"` - stickers *template.Template `yaml:"-"` - avatars *template.Template `yaml:"-"` -} - -type umMediaPatterns MediaPatterns - -func (mp *MediaPatterns) UnmarshalYAML(unmarshal func(interface{}) error) error { - err := unmarshal((*umMediaPatterns)(mp)) - if err != nil { - return err - } - tpl := template.New("media_patterns") - - pairs := []struct { - ptr **template.Template - name string - template string - }{ - {&mp.attachments, "attachments", mp.TplAttachments}, - {&mp.emojis, "emojis", mp.TplEmojis}, - {&mp.stickers, "stickers", mp.TplStickers}, - {&mp.avatars, "avatars", mp.TplAvatars}, - } - for _, pair := range pairs { - if pair.template == "" { - continue - } - *pair.ptr, err = tpl.New(pair.name).Parse(pair.template) - if err != nil { - return err - } - } - return nil -} - -type attachmentParams struct { - ChannelID string - AttachmentID string - FileName string -} - -type emojiStickerParams struct { - ID string - Ext string -} - -type avatarParams struct { - UserID string - AvatarID string - Ext string -} - -func (mp *MediaPatterns) execute(tpl *template.Template, params any) id.ContentURI { - if tpl == nil || !mp.Enabled { - return id.ContentURI{} - } - var out strings.Builder - err := tpl.Execute(&out, params) - if err != nil { - panic(err) - } - uri, err := id.ParseContentURI(out.String()) - if err != nil { - panic(err) - } - return uri -} - -func (mp *MediaPatterns) Attachment(channelID, attachmentID, filename string) id.ContentURI { - return mp.execute(mp.attachments, attachmentParams{ - ChannelID: channelID, - AttachmentID: attachmentID, - FileName: filename, - }) -} - -func (mp *MediaPatterns) Emoji(emojiID, ext string) id.ContentURI { - return mp.execute(mp.emojis, emojiStickerParams{ - ID: emojiID, - Ext: ext, - }) -} - -func (mp *MediaPatterns) Sticker(stickerID, ext string) id.ContentURI { - return mp.execute(mp.stickers, emojiStickerParams{ - ID: stickerID, - Ext: ext, - }) -} - -func (mp *MediaPatterns) Avatar(userID, avatarID, ext string) id.ContentURI { - return mp.execute(mp.avatars, avatarParams{ - UserID: userID, - AvatarID: avatarID, - Ext: ext, - }) +type DirectMedia struct { + Enabled bool `yaml:"enabled"` + ServerName string `yaml:"server_name"` + WellKnownResponse string `yaml:"well_known_response"` + AllowProxy bool `yaml:"allow_proxy"` + ServerKey string `yaml:"server_key"` } type BackfillLimitPart struct { diff --git a/config/upgrade.go b/config/upgrade.go index c3e9cff..9066af5 100644 --- a/config/upgrade.go +++ b/config/upgrade.go @@ -20,6 +20,7 @@ import ( up "go.mau.fi/util/configupgrade" "go.mau.fi/util/random" "maunium.net/go/mautrix/bridge/bridgeconfig" + "maunium.net/go/mautrix/federation" ) func DoUpgrade(helper *up.Helper) { @@ -58,12 +59,17 @@ func DoUpgrade(helper *up.Helper) { helper.Copy(up.Bool, "bridge", "prefix_webhook_messages") helper.Copy(up.Bool, "bridge", "enable_webhook_avatars") helper.Copy(up.Bool, "bridge", "use_discord_cdn_upload") - helper.Copy(up.Bool, "bridge", "media_patterns", "enabled") helper.Copy(up.Str, "bridge", "cache_media") - helper.Copy(up.Str|up.Null, "bridge", "media_patterns", "attachments") - helper.Copy(up.Str|up.Null, "bridge", "media_patterns", "emojis") - helper.Copy(up.Str|up.Null, "bridge", "media_patterns", "stickers") - helper.Copy(up.Str|up.Null, "bridge", "media_patterns", "avatars") + helper.Copy(up.Bool, "bridge", "direct_media", "enabled") + helper.Copy(up.Str, "bridge", "direct_media", "server_name") + helper.Copy(up.Str|up.Null, "bridge", "direct_media", "well_known_response") + helper.Copy(up.Bool, "bridge", "direct_media", "allow_proxy") + if serverKey, ok := helper.Get(up.Str, "bridge", "direct_media", "server_key"); !ok || serverKey == "generate" { + serverKey = federation.GenerateSigningKey().SynapseString() + helper.Set(up.Str, serverKey, "bridge", "direct_media", "server_key") + } else { + helper.Copy(up.Str, "bridge", "direct_media", "server_key") + } helper.Copy(up.Str, "bridge", "animated_sticker", "target") helper.Copy(up.Int, "bridge", "animated_sticker", "args", "width") helper.Copy(up.Int, "bridge", "animated_sticker", "args", "height") diff --git a/database/userportal.go b/database/userportal.go index 34d5660..783b83d 100644 --- a/database/userportal.go +++ b/database/userportal.go @@ -7,6 +7,7 @@ import ( "go.mau.fi/util/dbutil" log "maunium.net/go/maulogger/v2" + "maunium.net/go/mautrix/id" ) const ( @@ -44,6 +45,24 @@ func (u *User) scanUserPortals(rows dbutil.Rows) []UserPortal { return ups } +func (db *Database) GetUsersInPortal(channelID string) []id.UserID { + rows, err := db.Query("SELECT user_mxid FROM user_portal WHERE discord_id=$1", channelID) + if err != nil { + db.Portal.log.Errorln("Failed to get users in portal:", err) + } + var users []id.UserID + for rows.Next() { + var mxid id.UserID + err = rows.Scan(&mxid) + if err != nil { + db.Portal.log.Errorln("Failed to scan user in portal:", err) + } else { + users = append(users, mxid) + } + } + return users +} + func (u *User) GetPortals() []UserPortal { rows, err := u.db.Query("SELECT discord_id, type, timestamp, in_space FROM user_portal WHERE user_mxid=$1", u.MXID) if err != nil { diff --git a/directmedia.go b/directmedia.go new file mode 100644 index 0000000..a765f3f --- /dev/null +++ b/directmedia.go @@ -0,0 +1,570 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2024 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package main + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "mime" + "net" + "net/http" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/bwmarrin/discordgo" + "github.com/gorilla/mux" + "github.com/rs/zerolog" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/federation" + "maunium.net/go/mautrix/id" + + "go.mau.fi/mautrix-discord/config" +) + +type DirectMediaAPI struct { + bridge *DiscordBridge + ks *federation.KeyServer + cfg config.DirectMedia + log zerolog.Logger + proxy http.Client + + signatureKey [32]byte + + attachmentCache map[AttachmentCacheKey]AttachmentCacheValue + attachmentCacheLock sync.Mutex +} + +type AttachmentCacheKey struct { + ChannelID uint64 + AttachmentID uint64 +} + +type AttachmentCacheValue struct { + URL string + Expiry time.Time +} + +func newDirectMediaAPI(br *DiscordBridge) *DirectMediaAPI { + if !br.Config.Bridge.DirectMedia.Enabled { + return nil + } + dma := &DirectMediaAPI{ + bridge: br, + cfg: br.Config.Bridge.DirectMedia, + log: br.ZLog.With().Str("component", "direct media").Logger(), + proxy: http.Client{ + Transport: &http.Transport{ + DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext, + TLSHandshakeTimeout: 10 * time.Second, + ForceAttemptHTTP2: false, + }, + Timeout: 60 * time.Second, + }, + attachmentCache: make(map[AttachmentCacheKey]AttachmentCacheValue), + } + r := br.AS.Router + + parsed, err := federation.ParseSynapseKey(dma.cfg.ServerKey) + if err != nil { + dma.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to parse server key") + os.Exit(11) + return nil + } + dma.signatureKey = sha256.Sum256(parsed.Priv.Seed()) + dma.ks = &federation.KeyServer{ + KeyProvider: &federation.StaticServerKey{ + ServerName: dma.cfg.ServerName, + Key: parsed, + }, + WellKnownTarget: dma.cfg.WellKnownResponse, + Version: federation.ServerVersion{ + Name: br.Name, + Version: br.Version, + }, + } + if dma.ks.WellKnownTarget == "" { + dma.ks.WellKnownTarget = fmt.Sprintf("%s:443", dma.cfg.ServerName) + } + mediaRouter := r.PathPrefix("/_matrix/media").Subrouter() + var reqIDCounter atomic.Uint64 + mediaRouter.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization") + log := dma.log.With(). + Str("remote_addr", r.RemoteAddr). + Str("request_path", r.URL.Path). + Uint64("req_id", reqIDCounter.Add(1)). + Logger() + next.ServeHTTP(w, r.WithContext(log.WithContext(r.Context()))) + }) + }) + addRoutes := func(version string) { + mediaRouter.HandleFunc("/"+version+"/download/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet) + mediaRouter.HandleFunc("/"+version+"/download/{serverName}/{mediaID}/{fileName}", dma.DownloadMedia).Methods(http.MethodGet) + mediaRouter.HandleFunc("/"+version+"/thumbnail/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet) + mediaRouter.HandleFunc("/"+version+"/upload/{serverName}/{mediaID}", dma.UploadNotSupported).Methods(http.MethodPut) + mediaRouter.HandleFunc("/"+version+"/upload", dma.UploadNotSupported).Methods(http.MethodPost) + mediaRouter.HandleFunc("/"+version+"/config", dma.UploadNotSupported).Methods(http.MethodGet) + mediaRouter.HandleFunc("/"+version+"/preview_url", dma.PreviewURLNotSupported).Methods(http.MethodGet) + } + addRoutes("v3") + addRoutes("r0") + addRoutes("v1") + mediaRouter.NotFoundHandler = http.HandlerFunc(dma.UnknownEndpoint) + mediaRouter.MethodNotAllowedHandler = http.HandlerFunc(dma.UnsupportedMethod) + dma.ks.Register(r) + + return dma +} + +func (dma *DirectMediaAPI) makeMXC(data MediaIDData) id.ContentURI { + return id.ContentURI{ + Homeserver: dma.cfg.ServerName, + FileID: data.Wrap().SignedString(dma.signatureKey), + } +} + +func (dma *DirectMediaAPI) addAttachmentToCache(channelID uint64, att *discordgo.MessageAttachment) { + attachmentID, err := strconv.ParseUint(att.ID, 10, 64) + if err != nil { + return + } + dma.attachmentCache[AttachmentCacheKey{ + ChannelID: channelID, + AttachmentID: attachmentID, + }] = AttachmentCacheValue{ + URL: att.URL, + // TODO find expiry somehow properly? + Expiry: time.Now().Add(23 * time.Hour), + } +} + +func (dma *DirectMediaAPI) AttachmentMXC(channelID, messageID string, att *discordgo.MessageAttachment) (mxc id.ContentURI) { + if dma == nil { + return + } + channelIDInt, err := strconv.ParseUint(channelID, 10, 64) + if err != nil { + dma.log.Warn().Str("channel_id", channelID).Msg("Got non-integer channel ID") + return + } + messageIDInt, err := strconv.ParseUint(messageID, 10, 64) + if err != nil { + dma.log.Warn().Str("message_id", messageID).Msg("Got non-integer message ID") + return + } + attachmentIDInt, err := strconv.ParseUint(att.ID, 10, 64) + if err != nil { + dma.log.Warn().Str("attachment_id", att.ID).Msg("Got non-integer attachment ID") + return + } + dma.attachmentCacheLock.Lock() + dma.addAttachmentToCache(channelIDInt, att) + dma.attachmentCacheLock.Unlock() + return dma.makeMXC(&AttachmentMediaData{ + ChannelID: channelIDInt, + MessageID: messageIDInt, + AttachmentID: attachmentIDInt, + }) +} + +func (dma *DirectMediaAPI) EmojiMXC(emojiID, name string, animated bool) (mxc id.ContentURI) { + if dma == nil { + return + } + emojiIDInt, err := strconv.ParseUint(emojiID, 10, 64) + if err != nil { + dma.log.Warn().Str("emoji_id", emojiID).Msg("Got non-integer emoji ID") + return + } + return dma.makeMXC(&EmojiMediaData{ + EmojiMediaDataInner: EmojiMediaDataInner{ + EmojiID: emojiIDInt, + Animated: animated, + }, + Name: name, + }) +} + +func (dma *DirectMediaAPI) StickerMXC(stickerID string, format discordgo.StickerFormat) (mxc id.ContentURI) { + if dma == nil { + return + } + stickerIDInt, err := strconv.ParseUint(stickerID, 10, 64) + if err != nil { + dma.log.Warn().Str("sticker_id", stickerID).Msg("Got non-integer sticker ID") + return + } else if format > 255 || format < 0 { + dma.log.Warn().Int("format", int(format)).Msg("Got invalid sticker format") + return + } + return dma.makeMXC(&StickerMediaData{ + StickerID: stickerIDInt, + Format: byte(format), + }) +} + +func (dma *DirectMediaAPI) AvatarMXC(guildID, userID, avatarID string) (mxc id.ContentURI) { + if dma == nil { + return + } + animated := strings.HasPrefix(avatarID, "a_") + avatarIDBytes, err := hex.DecodeString(strings.TrimPrefix(avatarID, "a_")) + if err != nil { + dma.log.Warn().Str("avatar_id", avatarID).Msg("Got non-hex avatar ID") + return + } else if len(avatarIDBytes) != 16 { + dma.log.Warn().Str("avatar_id", avatarID).Msg("Got invalid avatar ID length") + return + } + avatarIDArray := [16]byte(avatarIDBytes) + userIDInt, err := strconv.ParseUint(userID, 10, 64) + if err != nil { + dma.log.Warn().Str("user_id", userID).Msg("Got non-integer user ID") + return + } + if guildID != "" { + guildIDInt, err := strconv.ParseUint(guildID, 10, 64) + if err != nil { + dma.log.Warn().Str("guild_id", guildID).Msg("Got non-integer guild ID") + return + } + return dma.makeMXC(&GuildMemberAvatarMediaData{ + GuildID: guildIDInt, + UserID: userIDInt, + AvatarID: avatarIDArray, + Animated: animated, + }) + } else { + return dma.makeMXC(&UserAvatarMediaData{ + UserID: userIDInt, + AvatarID: avatarIDArray, + Animated: animated, + }) + } +} + +type RespError struct { + Code string + Message string + Status int +} + +func (re *RespError) Error() string { + return re.Message +} + +var ErrNoUsersWithAccessFound = errors.New("no users found to fetch message") +var ErrAttachmentNotFound = errors.New("attachment not found") + +func (dma *DirectMediaAPI) fetchNewAttachmentURL(ctx context.Context, meta *AttachmentMediaData) (string, error) { + var client *discordgo.Session + channelIDStr := strconv.FormatUint(meta.ChannelID, 10) + users := dma.bridge.DB.GetUsersInPortal(channelIDStr) + for _, userID := range users { + user := dma.bridge.GetCachedUserByMXID(userID) + if user != nil && user.Session != nil { + client = user.Session + if !client.IsUser { + break + } + } + } + if client == nil { + return "", ErrNoUsersWithAccessFound + } + var url string + var msgs []*discordgo.Message + var err error + messageIDStr := strconv.FormatUint(meta.MessageID, 10) + if client.IsUser { + msgs, err = client.ChannelMessages(channelIDStr, 5, "", "", messageIDStr) + } else { + var msg *discordgo.Message + msg, err = client.ChannelMessage(channelIDStr, messageIDStr) + msgs = []*discordgo.Message{msg} + } + if err != nil { + return "", fmt.Errorf("failed to fetch message: %w", err) + } + attachmentIDStr := strconv.FormatUint(meta.AttachmentID, 10) + for _, item := range msgs { + for _, att := range item.Attachments { + dma.addAttachmentToCache(meta.ChannelID, att) + if att.ID == attachmentIDStr { + url = att.URL + } + } + } + if url == "" { + return "", ErrAttachmentNotFound + } + return url, nil +} + +func (dma *DirectMediaAPI) GetEmojiInfo(contentURI id.ContentURI) *EmojiMediaData { + if dma == nil || contentURI.IsEmpty() || contentURI.Homeserver != dma.cfg.ServerName { + return nil + } + mediaID, err := ParseMediaID(contentURI.FileID, dma.signatureKey) + if err != nil { + return nil + } + emojiData, ok := mediaID.Data.(*EmojiMediaData) + if !ok { + return nil + } + return emojiData + +} + +func (dma *DirectMediaAPI) getMediaURL(ctx context.Context, encodedMediaID string) (url string, expiry time.Time, err error) { + var mediaID *MediaID + mediaID, err = ParseMediaID(encodedMediaID, dma.signatureKey) + if err != nil { + err = &RespError{ + Code: mautrix.MNotFound.ErrCode, + Message: err.Error(), + Status: http.StatusNotFound, + } + return + } + switch mediaData := mediaID.Data.(type) { + case *AttachmentMediaData: + dma.attachmentCacheLock.Lock() + defer dma.attachmentCacheLock.Unlock() + cached, ok := dma.attachmentCache[mediaData.CacheKey()] + if ok && time.Until(cached.Expiry) > 5*time.Minute { + return cached.URL, cached.Expiry, nil + } + zerolog.Ctx(ctx).Debug(). + Uint64("channel_id", mediaData.ChannelID). + Uint64("message_id", mediaData.MessageID). + Uint64("attachment_id", mediaData.AttachmentID). + Msg("Refreshing attachment URL") + url, err = dma.fetchNewAttachmentURL(ctx, mediaData) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to refresh attachment URL") + msg := "Failed to refresh attachment URL" + if errors.Is(err, ErrNoUsersWithAccessFound) { + msg = "No users found with access to the channel" + } else if errors.Is(err, ErrAttachmentNotFound) { + msg = "Attachment not found in message. Perhaps it was deleted?" + } + err = &RespError{ + Code: mautrix.MNotFound.ErrCode, + Message: msg, + Status: http.StatusNotFound, + } + } else { + zerolog.Ctx(ctx).Debug().Msg("Successfully refreshed attachment URL") + // TODO find expiry somehow properly? + expiry = time.Now().Add(23 * time.Hour) + } + case *EmojiMediaData: + if mediaData.Animated { + url = discordgo.EndpointEmojiAnimated(strconv.FormatUint(mediaData.EmojiID, 10)) + } else { + url = discordgo.EndpointEmoji(strconv.FormatUint(mediaData.EmojiID, 10)) + } + case *StickerMediaData: + url = discordgo.EndpointStickerImage( + strconv.FormatUint(mediaData.StickerID, 10), + discordgo.StickerFormat(mediaData.Format), + ) + case *UserAvatarMediaData: + if mediaData.Animated { + url = discordgo.EndpointUserAvatarAnimated( + strconv.FormatUint(mediaData.UserID, 10), + fmt.Sprintf("a_%x", mediaData.AvatarID), + ) + } else { + url = discordgo.EndpointUserAvatar( + strconv.FormatUint(mediaData.UserID, 10), + fmt.Sprintf("%x", mediaData.AvatarID), + ) + } + case *GuildMemberAvatarMediaData: + if mediaData.Animated { + url = discordgo.EndpointGuildMemberAvatarAnimated( + strconv.FormatUint(mediaData.GuildID, 10), + strconv.FormatUint(mediaData.UserID, 10), + fmt.Sprintf("a_%x", mediaData.AvatarID), + ) + } else { + url = discordgo.EndpointGuildMemberAvatar( + strconv.FormatUint(mediaData.GuildID, 10), + strconv.FormatUint(mediaData.UserID, 10), + fmt.Sprintf("%x", mediaData.AvatarID), + ) + } + default: + zerolog.Ctx(ctx).Error().Type("media_data_type", mediaData).Msg("Unrecognized media data struct") + err = &RespError{ + Code: "M_UNKNOWN", + Message: "Unrecognized media data struct", + Status: http.StatusInternalServerError, + } + } + return +} + +func (dma *DirectMediaAPI) proxyDownload(ctx context.Context, w http.ResponseWriter, url, fileName string) { + log := zerolog.Ctx(ctx) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + log.Err(err).Str("url", url).Msg("Failed to create proxy request") + jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ + ErrCode: "M_UNKNOWN", + Err: "Failed to create proxy request", + }) + return + } + for key, val := range discordgo.DroidDownloadHeaders { + req.Header.Set(key, val) + } + resp, err := dma.proxy.Do(req) + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + if err != nil { + log.Err(err).Str("url", url).Msg("Failed to proxy download") + jsonResponse(w, http.StatusServiceUnavailable, &mautrix.RespError{ + ErrCode: "M_UNKNOWN", + Err: "Failed to proxy download", + }) + return + } else if resp.StatusCode != http.StatusOK { + log.Warn().Str("url", url).Int("status", resp.StatusCode).Msg("Unexpected status code proxying download") + jsonResponse(w, resp.StatusCode, &mautrix.RespError{ + ErrCode: "M_UNKNOWN", + Err: "Unexpected status code proxying download", + }) + return + } + w.Header()["Content-Type"] = resp.Header["Content-Type"] + w.Header()["Content-Length"] = resp.Header["Content-Length"] + w.Header()["Last-Modified"] = resp.Header["Last-Modified"] + w.Header()["Cache-Control"] = resp.Header["Cache-Control"] + contentDisposition := "attachment" + switch resp.Header.Get("Content-Type") { + case "text/css", "text/plain", "text/csv", "application/json", "application/ld+json", "image/jpeg", "image/gif", + "image/png", "image/apng", "image/webp", "image/avif", "video/mp4", "video/webm", "video/ogg", "video/quicktime", + "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", "audio/wav", "audio/x-wav", + "audio/x-pn-wav", "audio/flac", "audio/x-flac", "application/pdf": + contentDisposition = "inline" + } + if fileName != "" { + contentDisposition = mime.FormatMediaType(contentDisposition, map[string]string{ + "filename": fileName, + }) + } + w.Header().Set("Content-Disposition", contentDisposition) + w.WriteHeader(http.StatusOK) + _, err = io.Copy(w, resp.Body) + if err != nil { + log.Debug().Err(err).Msg("Failed to write proxy response") + } +} + +func (dma *DirectMediaAPI) DownloadMedia(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + log := zerolog.Ctx(ctx) + vars := mux.Vars(r) + if vars["serverName"] != dma.cfg.ServerName { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MNotFound.ErrCode, + Err: fmt.Sprintf("This is a Discord media proxy for %q, other media downloads are not available here", dma.cfg.ServerName), + }) + return + } + url, expiresAt, err := dma.getMediaURL(ctx, vars["mediaID"]) + if err != nil { + var respError *RespError + if errors.As(err, &respError) { + jsonResponse(w, respError.Status, &mautrix.RespError{ + ErrCode: respError.Code, + Err: respError.Message, + }) + } else { + log.Err(err).Str("media_id", vars["mediaID"]).Msg("Failed to get media URL") + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MNotFound.ErrCode, + Err: "Media not found", + }) + } + return + + } + // Proxy if the config allows proxying and the request doesn't allow redirects. + // In any other case, redirect to the Discord CDN. + if dma.cfg.AllowProxy && r.URL.Query().Get("allow_redirect") != "true" { + dma.proxyDownload(ctx, w, url, vars["fileName"]) + return + } + w.Header().Set("Location", url) + expirySeconds := (time.Until(expiresAt) - 5*time.Minute).Seconds() + if expiresAt.IsZero() { + w.Header().Set("Cache-Control", "public, max-age=31536000, immutable") + } else if expirySeconds > 0 { + cacheControl := fmt.Sprintf("public, max-age=%d, immutable", int(expirySeconds)) + w.Header().Set("Cache-Control", cacheControl) + } else { + w.Header().Set("Cache-Control", "no-store") + } + w.WriteHeader(http.StatusTemporaryRedirect) +} + +func (dma *DirectMediaAPI) UploadNotSupported(w http.ResponseWriter, r *http.Request) { + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "This bridge only supports proxying Discord media downloads and does not support media uploads.", + }) +} + +func (dma *DirectMediaAPI) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) { + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "This bridge only supports proxying Discord media downloads and does not support URL previews.", + }) +} + +func (dma *DirectMediaAPI) UnknownEndpoint(w http.ResponseWriter, r *http.Request) { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "Unrecognized endpoint", + }) +} + +func (dma *DirectMediaAPI) UnsupportedMethod(w http.ResponseWriter, r *http.Request) { + jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "Invalid method for endpoint", + }) +} diff --git a/directmedia_id.go b/directmedia_id.go new file mode 100644 index 0000000..92b935a --- /dev/null +++ b/directmedia_id.go @@ -0,0 +1,287 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2024 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package main + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "io" +) + +const MediaIDPrefix = "\U0001F408DISCORD" +const MediaIDVersion = 1 + +type MediaIDClass uint8 + +const ( + MediaIDClassAttachment MediaIDClass = 1 + MediaIDClassEmoji MediaIDClass = 2 + MediaIDClassSticker MediaIDClass = 3 + MediaIDClassUserAvatar MediaIDClass = 4 + MediaIDClassGuildMemberAvatar MediaIDClass = 5 +) + +type MediaIDData interface { + Write(to io.Writer) + Read(from io.Reader) error + Size() int + Wrap() *MediaID +} + +type MediaID struct { + Version uint8 + TypeClass MediaIDClass + Data MediaIDData +} + +func ParseMediaID(id string, key [32]byte) (*MediaID, error) { + data, err := base64.RawURLEncoding.DecodeString(id) + if err != nil { + return nil, fmt.Errorf("failed to decode base64: %w", err) + } + hasher := hmac.New(sha256.New, key[:]) + checksum := data[len(data)-TruncatedHashLength:] + data = data[:len(data)-TruncatedHashLength] + hasher.Write(data) + if !hmac.Equal(checksum, hasher.Sum(nil)[:TruncatedHashLength]) { + return nil, ErrMediaIDChecksumMismatch + } + mid := &MediaID{} + err = mid.Read(bytes.NewReader(data)) + if err != nil { + return nil, fmt.Errorf("failed to parse media ID: %w", err) + } + return mid, nil +} + +const TruncatedHashLength = 16 + +func (mid *MediaID) SignedString(key [32]byte) string { + buf := bytes.NewBuffer(make([]byte, 0, mid.Size())) + mid.Write(buf) + hasher := hmac.New(sha256.New, key[:]) + hasher.Write(buf.Bytes()) + buf.Write(hasher.Sum(nil)[:TruncatedHashLength]) + return base64.RawURLEncoding.EncodeToString(buf.Bytes()) +} + +func (mid *MediaID) Write(to io.Writer) { + _, _ = to.Write([]byte(MediaIDPrefix)) + _ = binary.Write(to, binary.BigEndian, mid.Version) + _ = binary.Write(to, binary.BigEndian, mid.TypeClass) + mid.Data.Write(to) +} + +func (mid *MediaID) Size() int { + return len(MediaIDPrefix) + 2 + mid.Data.Size() + TruncatedHashLength +} + +var ( + ErrInvalidMediaID = errors.New("invalid media ID") + ErrMediaIDChecksumMismatch = errors.New("invalid checksum in media ID") + ErrUnsupportedMediaID = errors.New("unsupported media ID") +) + +func (mid *MediaID) Read(from io.Reader) error { + prefix := make([]byte, len(MediaIDPrefix)) + _, err := io.ReadFull(from, prefix) + if err != nil || !bytes.Equal(prefix, []byte(MediaIDPrefix)) { + return fmt.Errorf("%w: prefix not found", ErrInvalidMediaID) + } + versionAndClass := make([]byte, 2) + _, err = io.ReadFull(from, versionAndClass) + if err != nil { + return fmt.Errorf("%w: version and class not found", ErrInvalidMediaID) + } else if versionAndClass[0] != MediaIDVersion { + return fmt.Errorf("%w: unknown version %d", ErrUnsupportedMediaID, versionAndClass[0]) + } + switch MediaIDClass(versionAndClass[1]) { + case MediaIDClassAttachment: + mid.Data = &AttachmentMediaData{} + case MediaIDClassEmoji: + mid.Data = &EmojiMediaData{} + case MediaIDClassSticker: + mid.Data = &StickerMediaData{} + case MediaIDClassUserAvatar: + mid.Data = &UserAvatarMediaData{} + case MediaIDClassGuildMemberAvatar: + mid.Data = &GuildMemberAvatarMediaData{} + default: + return fmt.Errorf("%w: unrecognized type class %d", ErrUnsupportedMediaID, versionAndClass[1]) + } + err = mid.Data.Read(from) + if err != nil { + return fmt.Errorf("failed to parse media ID data: %w", err) + } + return nil +} + +type AttachmentMediaData struct { + ChannelID uint64 + MessageID uint64 + AttachmentID uint64 +} + +func (amd *AttachmentMediaData) Write(to io.Writer) { + _ = binary.Write(to, binary.BigEndian, amd) +} + +func (amd *AttachmentMediaData) Read(from io.Reader) (err error) { + return binary.Read(from, binary.BigEndian, amd) +} + +func (amd *AttachmentMediaData) Size() int { + return binary.Size(amd) +} + +func (amd *AttachmentMediaData) Wrap() *MediaID { + return &MediaID{ + Version: MediaIDVersion, + TypeClass: MediaIDClassAttachment, + Data: amd, + } +} + +func (amd *AttachmentMediaData) CacheKey() AttachmentCacheKey { + return AttachmentCacheKey{ + ChannelID: amd.ChannelID, + AttachmentID: amd.AttachmentID, + } +} + +type StickerMediaData struct { + StickerID uint64 + Format uint8 +} + +func (smd *StickerMediaData) Write(to io.Writer) { + _ = binary.Write(to, binary.BigEndian, smd) +} + +func (smd *StickerMediaData) Read(from io.Reader) error { + return binary.Read(from, binary.BigEndian, smd) +} + +func (smd *StickerMediaData) Size() int { + return binary.Size(smd) +} + +func (smd *StickerMediaData) Wrap() *MediaID { + return &MediaID{ + Version: MediaIDVersion, + TypeClass: MediaIDClassSticker, + Data: smd, + } +} + +type EmojiMediaDataInner struct { + EmojiID uint64 + Animated bool +} + +type EmojiMediaData struct { + EmojiMediaDataInner + Name string +} + +func (emd *EmojiMediaData) Write(to io.Writer) { + _ = binary.Write(to, binary.BigEndian, &emd.EmojiMediaDataInner) + _, _ = to.Write([]byte(emd.Name)) +} + +func (emd *EmojiMediaData) Read(from io.Reader) (err error) { + err = binary.Read(from, binary.BigEndian, &emd.EmojiMediaDataInner) + if err != nil { + return + } + name, err := io.ReadAll(from) + if err != nil { + return + } + emd.Name = string(name) + return +} + +func (emd *EmojiMediaData) Size() int { + return binary.Size(&emd.EmojiMediaDataInner) + len(emd.Name) +} + +func (emd *EmojiMediaData) Wrap() *MediaID { + return &MediaID{ + Version: MediaIDVersion, + TypeClass: MediaIDClassEmoji, + Data: emd, + } +} + +type UserAvatarMediaData struct { + UserID uint64 + Animated bool + AvatarID [16]byte +} + +func (uamd *UserAvatarMediaData) Write(to io.Writer) { + _ = binary.Write(to, binary.BigEndian, uamd) +} + +func (uamd *UserAvatarMediaData) Read(from io.Reader) error { + return binary.Read(from, binary.BigEndian, uamd) +} + +func (uamd *UserAvatarMediaData) Size() int { + return binary.Size(uamd) +} + +func (uamd *UserAvatarMediaData) Wrap() *MediaID { + return &MediaID{ + Version: MediaIDVersion, + TypeClass: MediaIDClassUserAvatar, + Data: uamd, + } +} + +type GuildMemberAvatarMediaData struct { + GuildID uint64 + UserID uint64 + Animated bool + AvatarID [16]byte +} + +func (guamd *GuildMemberAvatarMediaData) Write(to io.Writer) { + _ = binary.Write(to, binary.BigEndian, guamd) +} + +func (guamd *GuildMemberAvatarMediaData) Read(from io.Reader) error { + return binary.Read(from, binary.BigEndian, guamd) +} + +func (guamd *GuildMemberAvatarMediaData) Size() int { + return binary.Size(guamd) +} + +func (guamd *GuildMemberAvatarMediaData) Wrap() *MediaID { + return &MediaID{ + Version: MediaIDVersion, + TypeClass: MediaIDClassGuildMemberAvatar, + Data: guamd, + } +} diff --git a/example-config.yaml b/example-config.yaml index 42b5c6c..9e2dc4c 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -168,20 +168,23 @@ bridge: # This can be `never` to never cache, `unencrypted` to only cache unencrypted mxc uris, or `always` to cache everything. # If you have a media repo that generates non-unique mxc uris, you should set this to never. cache_media: unencrypted - # Patterns for converting Discord media to custom mxc:// URIs instead of reuploading. - # Each of the patterns can be set to null to disable custom URIs for that type of media. + # Settings for converting Discord media to custom mxc:// URIs instead of reuploading. # More details can be found at https://docs.mau.fi/bridges/go/discord/direct-media.html - media_patterns: + direct_media: # Should custom mxc:// URIs be used instead of reuploading media? enabled: false - # Pattern for normal message attachments. - attachments: mxc://discord-media.mau.dev/attachments|{{.ChannelID}}|{{.AttachmentID}}|{{.FileName}} - # Pattern for custom emojis. - emojis: mxc://discord-media.mau.dev/emojis|{{.ID}}.{{.Ext}} - # Pattern for stickers. Note that animated lottie stickers will not be converted if this is enabled. - stickers: mxc://discord-media.mau.dev/stickers|{{.ID}}.{{.Ext}} - # Pattern for static user avatars. - avatars: mxc://discord-media.mau.dev/avatars|{{.UserID}}|{{.AvatarID}}.{{.Ext}} + # The server name to use for the custom mxc:// URIs. + # This server name will effectively be a real Matrix server, it just won't implement anything other than media. + # You must either set up .well-known delegation from this domain to the bridge, or proxy the domain directly to the bridge. + server_name: discord-media.example.com + # Optionally a custom .well-known response. This defaults to `server_name:443` + well_known_response: + # The bridge supports MSC3860 media download redirects and will use them if the requester supports it. + # Optionally, you can force redirects and not allow proxying at all by setting this to false. + allow_proxy: true + # Matrix server signing key to make the federation tester pass, same format as synapse's .signing.key file. + # This key is also used to sign the mxc:// URIs to ensure only the bridge can generate them. + server_key: generate # Settings for converting animated stickers. animated_sticker: # Format to which animated stickers should be converted. diff --git a/go.mod b/go.mod index 4e9f583..bf3770f 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 golang.org/x/sync v0.5.0 maunium.net/go/maulogger/v2 v2.4.1 - maunium.net/go/mautrix v0.16.2 + maunium.net/go/mautrix v0.16.3-0.20240218195727-4ceb1123b660 ) require ( diff --git a/go.sum b/go.sum index 464f848..6a9df03 100644 --- a/go.sum +++ b/go.sum @@ -71,5 +71,5 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= -maunium.net/go/mautrix v0.16.2 h1:a6GUJXNWsTEOO8VE4dROBfCIfPp50mqaqzv7KPzChvg= -maunium.net/go/mautrix v0.16.2/go.mod h1:YL4l4rZB46/vj/ifRMEjcibbvHjgxHftOF1SgmruLu4= +maunium.net/go/mautrix v0.16.3-0.20240218195727-4ceb1123b660 h1:ZPg1i0wsyEs5ee7z6Gn4mSRsq9BtDV//rfYTGs82l8c= +maunium.net/go/mautrix v0.16.3-0.20240218195727-4ceb1123b660/go.mod h1:YL4l4rZB46/vj/ifRMEjcibbvHjgxHftOF1SgmruLu4= diff --git a/main.go b/main.go index 7fd2501..f3b51d3 100644 --- a/main.go +++ b/main.go @@ -48,6 +48,7 @@ type DiscordBridge struct { Config *config.Config DB *database.Database + DMA *DirectMediaAPI provisioning *ProvisioningAPI usersByMXID map[id.UserID]*User @@ -104,6 +105,7 @@ func (br *DiscordBridge) Start() { if br.Config.Bridge.Provisioning.SharedSecret != "disable" { br.provisioning = newProvisioningAPI(br) } + br.DMA = newDirectMediaAPI(br) br.WaitWebsocketConnected() go br.startUsers() } diff --git a/portal.go b/portal.go index f32ba28..e63c83f 100644 --- a/portal.go +++ b/portal.go @@ -1832,13 +1832,15 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) { emojiID := reaction.RelatesTo.Key if strings.HasPrefix(emojiID, "mxc://") { uri, _ := id.ParseContentURI(emojiID) - emojiFile := portal.bridge.DB.File.GetEmojiByMXC(uri) - if emojiFile == nil || emojiFile.ID == "" || emojiFile.EmojiName == "" { + emojiInfo := portal.bridge.DMA.GetEmojiInfo(uri) + if emojiInfo != nil { + emojiID = fmt.Sprintf("%s:%d", emojiInfo.Name, emojiInfo.EmojiID) + } else if emojiFile := portal.bridge.DB.File.GetEmojiByMXC(uri); emojiFile != nil && emojiFile.ID != "" && emojiFile.EmojiName != "" { + emojiID = fmt.Sprintf("%s:%s", emojiFile.EmojiName, emojiFile.ID) + } else { go portal.sendMessageMetrics(evt, fmt.Errorf("%w %s", errUnknownEmoji, emojiID), "Ignoring") return } - - emojiID = fmt.Sprintf("%s:%s", emojiFile.EmojiName, emojiFile.ID) } else { emojiID = variationselector.FullyQualify(emojiID) } diff --git a/portal_convert.go b/portal_convert.go index 3d34005..90f0100 100644 --- a/portal_convert.go +++ b/portal_convert.go @@ -27,12 +27,11 @@ import ( "github.com/bwmarrin/discordgo" "github.com/rs/zerolog" "golang.org/x/exp/slices" - "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/id" ) type ConvertedMessage struct { @@ -103,20 +102,16 @@ func (portal *Portal) cleanupConvertedStickerInfo(content *event.MessageEventCon } func (portal *Portal) convertDiscordSticker(ctx context.Context, intent *appservice.IntentAPI, sticker *discordgo.Sticker) *ConvertedMessage { - var mime, ext string + var mime string switch sticker.FormatType { case discordgo.StickerFormatTypePNG: mime = "image/png" - ext = "png" case discordgo.StickerFormatTypeAPNG: mime = "image/apng" - ext = "png" case discordgo.StickerFormatTypeLottie: mime = "application/json" - ext = "json" case discordgo.StickerFormatTypeGIF: mime = "image/gif" - ext = "gif" default: zerolog.Ctx(ctx).Warn(). Int("sticker_format", int(sticker.FormatType)). @@ -130,8 +125,9 @@ func (portal *Portal) convertDiscordSticker(ctx context.Context, intent *appserv }, } - mxc := portal.bridge.Config.Bridge.MediaPatterns.Sticker(sticker.ID, ext) - if mxc.IsEmpty() { + mxc := portal.bridge.DMA.StickerMXC(sticker.ID, sticker.FormatType) + // TODO add config option to use direct media even for lottie stickers + if mxc.IsEmpty() && mime != "application/json" { content = portal.convertDiscordFile(ctx, "sticker", intent, sticker.ID, sticker.URL(), content) } else { content.URL = mxc.CUString() @@ -144,7 +140,7 @@ func (portal *Portal) convertDiscordSticker(ctx context.Context, intent *appserv } } -func (portal *Portal) convertDiscordAttachment(ctx context.Context, intent *appservice.IntentAPI, att *discordgo.MessageAttachment) *ConvertedMessage { +func (portal *Portal) convertDiscordAttachment(ctx context.Context, intent *appservice.IntentAPI, messageID string, att *discordgo.MessageAttachment) *ConvertedMessage { content := &event.MessageEventContent{ Body: att.Filename, Info: &event.FileInfo{ @@ -182,7 +178,7 @@ func (portal *Portal) convertDiscordAttachment(ctx context.Context, intent *apps default: content.MsgType = event.MsgFile } - mxc := portal.bridge.Config.Bridge.MediaPatterns.Attachment(portal.Key.ChannelID, att.ID, att.Filename) + mxc := portal.bridge.DMA.AttachmentMXC(portal.Key.ChannelID, messageID, att) if mxc.IsEmpty() { content = portal.convertDiscordFile(ctx, "attachment", intent, att.ID, att.URL, content) } else { @@ -287,7 +283,7 @@ func (portal *Portal) convertDiscordMessage(ctx context.Context, puppet *Puppet, } handledIDs[att.ID] = struct{}{} log := log.With().Str("attachment_id", att.ID).Logger() - if part := portal.convertDiscordAttachment(log.WithContext(ctx), intent, att); part != nil { + if part := portal.convertDiscordAttachment(log.WithContext(ctx), intent, msg.ID, att); part != nil { parts = append(parts, part) } } diff --git a/puppet.go b/puppet.go index af47d8c..ca6489e 100644 --- a/puppet.go +++ b/puppet.go @@ -217,27 +217,23 @@ func (puppet *Puppet) UpdateName(info *discordgo.User) bool { } func (br *DiscordBridge) reuploadUserAvatar(intent *appservice.IntentAPI, guildID, userID, avatarID string) (id.ContentURI, string, error) { - var downloadURL, ext string + var downloadURL string if guildID == "" { - downloadURL = discordgo.EndpointUserAvatar(userID, avatarID) - ext = "png" if strings.HasPrefix(avatarID, "a_") { downloadURL = discordgo.EndpointUserAvatarAnimated(userID, avatarID) - ext = "gif" + } else { + downloadURL = discordgo.EndpointUserAvatar(userID, avatarID) } } else { - downloadURL = discordgo.EndpointGuildMemberAvatar(guildID, userID, avatarID) - ext = "png" if strings.HasPrefix(avatarID, "a_") { downloadURL = discordgo.EndpointGuildMemberAvatarAnimated(guildID, userID, avatarID) - ext = "gif" + } else { + downloadURL = discordgo.EndpointGuildMemberAvatar(guildID, userID, avatarID) } } - if guildID == "" { - url := br.Config.Bridge.MediaPatterns.Avatar(userID, avatarID, ext) - if !url.IsEmpty() { - return url, downloadURL, nil - } + url := br.DMA.AvatarMXC(guildID, userID, avatarID) + if !url.IsEmpty() { + return url, downloadURL, nil } copied, err := br.copyAttachmentToMatrix(intent, downloadURL, false, AttachmentMeta{ AttachmentID: fmt.Sprintf("avatar/%s/%s/%s", guildID, userID, avatarID), diff --git a/user.go b/user.go index 174ed0d..1f08fba 100644 --- a/user.go +++ b/user.go @@ -195,6 +195,12 @@ func (br *DiscordBridge) GetCachedUserByID(id string) *User { return br.usersByID[id] } +func (br *DiscordBridge) GetCachedUserByMXID(userID id.UserID) *User { + br.usersLock.Lock() + defer br.usersLock.Unlock() + return br.usersByMXID[userID] +} + func (br *DiscordBridge) NewUser(dbUser *database.User) *User { user := &User{ User: dbUser,