diff --git a/attachments.go b/attachments.go
index 9a0ea88..d710849 100644
--- a/attachments.go
+++ b/attachments.go
@@ -323,19 +323,17 @@ func (br *DiscordBridge) copyAttachmentToMatrix(intent *appservice.IntentAPI, ur
}
func (portal *Portal) getEmojiMXCByDiscordID(emojiID, name string, animated bool) id.ContentURI {
- var url, mimeType, ext string
+ mxc := portal.bridge.DMA.EmojiMXC(emojiID, name, animated)
+ if !mxc.IsEmpty() {
+ return mxc
+ }
+ var url, mimeType string
if animated {
url = discordgo.EndpointEmojiAnimated(emojiID)
mimeType = "image/gif"
- ext = "gif"
} else {
url = discordgo.EndpointEmoji(emojiID)
mimeType = "image/png"
- ext = "png"
- }
- mxc := portal.bridge.Config.Bridge.MediaPatterns.Emoji(emojiID, ext)
- if !mxc.IsEmpty() {
- return mxc
}
dbFile, err := portal.bridge.copyAttachmentToMatrix(portal.MainIntent(), url, false, AttachmentMeta{
AttachmentID: emojiID,
diff --git a/config/bridge.go b/config/bridge.go
index 046632f..47571e6 100644
--- a/config/bridge.go
+++ b/config/bridge.go
@@ -25,7 +25,6 @@ import (
"github.com/bwmarrin/discordgo"
"maunium.net/go/mautrix/bridge/bridgeconfig"
- "maunium.net/go/mautrix/id"
)
type BridgeConfig struct {
@@ -55,8 +54,8 @@ type BridgeConfig struct {
EnableWebhookAvatars bool `yaml:"enable_webhook_avatars"`
UseDiscordCDNUpload bool `yaml:"use_discord_cdn_upload"`
- CacheMedia string `yaml:"cache_media"`
- MediaPatterns MediaPatterns `yaml:"media_patterns"`
+ CacheMedia string `yaml:"cache_media"`
+ DirectMedia DirectMedia `yaml:"direct_media"`
AnimatedSticker struct {
Target string `yaml:"target"`
@@ -96,111 +95,12 @@ type BridgeConfig struct {
guildNameTemplate *template.Template `yaml:"-"`
}
-type MediaPatterns struct {
- Enabled bool `yaml:"enabled"`
- TplAttachments string `yaml:"attachments"`
- TplEmojis string `yaml:"emojis"`
- TplStickers string `yaml:"stickers"`
- TplAvatars string `yaml:"avatars"`
-
- attachments *template.Template `yaml:"-"`
- emojis *template.Template `yaml:"-"`
- stickers *template.Template `yaml:"-"`
- avatars *template.Template `yaml:"-"`
-}
-
-type umMediaPatterns MediaPatterns
-
-func (mp *MediaPatterns) UnmarshalYAML(unmarshal func(interface{}) error) error {
- err := unmarshal((*umMediaPatterns)(mp))
- if err != nil {
- return err
- }
- tpl := template.New("media_patterns")
-
- pairs := []struct {
- ptr **template.Template
- name string
- template string
- }{
- {&mp.attachments, "attachments", mp.TplAttachments},
- {&mp.emojis, "emojis", mp.TplEmojis},
- {&mp.stickers, "stickers", mp.TplStickers},
- {&mp.avatars, "avatars", mp.TplAvatars},
- }
- for _, pair := range pairs {
- if pair.template == "" {
- continue
- }
- *pair.ptr, err = tpl.New(pair.name).Parse(pair.template)
- if err != nil {
- return err
- }
- }
- return nil
-}
-
-type attachmentParams struct {
- ChannelID string
- AttachmentID string
- FileName string
-}
-
-type emojiStickerParams struct {
- ID string
- Ext string
-}
-
-type avatarParams struct {
- UserID string
- AvatarID string
- Ext string
-}
-
-func (mp *MediaPatterns) execute(tpl *template.Template, params any) id.ContentURI {
- if tpl == nil || !mp.Enabled {
- return id.ContentURI{}
- }
- var out strings.Builder
- err := tpl.Execute(&out, params)
- if err != nil {
- panic(err)
- }
- uri, err := id.ParseContentURI(out.String())
- if err != nil {
- panic(err)
- }
- return uri
-}
-
-func (mp *MediaPatterns) Attachment(channelID, attachmentID, filename string) id.ContentURI {
- return mp.execute(mp.attachments, attachmentParams{
- ChannelID: channelID,
- AttachmentID: attachmentID,
- FileName: filename,
- })
-}
-
-func (mp *MediaPatterns) Emoji(emojiID, ext string) id.ContentURI {
- return mp.execute(mp.emojis, emojiStickerParams{
- ID: emojiID,
- Ext: ext,
- })
-}
-
-func (mp *MediaPatterns) Sticker(stickerID, ext string) id.ContentURI {
- return mp.execute(mp.stickers, emojiStickerParams{
- ID: stickerID,
- Ext: ext,
- })
-}
-
-func (mp *MediaPatterns) Avatar(userID, avatarID, ext string) id.ContentURI {
- return mp.execute(mp.avatars, avatarParams{
- UserID: userID,
- AvatarID: avatarID,
- Ext: ext,
- })
+type DirectMedia struct {
+ Enabled bool `yaml:"enabled"`
+ ServerName string `yaml:"server_name"`
+ WellKnownResponse string `yaml:"well_known_response"`
+ AllowProxy bool `yaml:"allow_proxy"`
+ ServerKey string `yaml:"server_key"`
}
type BackfillLimitPart struct {
diff --git a/config/upgrade.go b/config/upgrade.go
index c3e9cff..9066af5 100644
--- a/config/upgrade.go
+++ b/config/upgrade.go
@@ -20,6 +20,7 @@ import (
up "go.mau.fi/util/configupgrade"
"go.mau.fi/util/random"
"maunium.net/go/mautrix/bridge/bridgeconfig"
+ "maunium.net/go/mautrix/federation"
)
func DoUpgrade(helper *up.Helper) {
@@ -58,12 +59,17 @@ func DoUpgrade(helper *up.Helper) {
helper.Copy(up.Bool, "bridge", "prefix_webhook_messages")
helper.Copy(up.Bool, "bridge", "enable_webhook_avatars")
helper.Copy(up.Bool, "bridge", "use_discord_cdn_upload")
- helper.Copy(up.Bool, "bridge", "media_patterns", "enabled")
helper.Copy(up.Str, "bridge", "cache_media")
- helper.Copy(up.Str|up.Null, "bridge", "media_patterns", "attachments")
- helper.Copy(up.Str|up.Null, "bridge", "media_patterns", "emojis")
- helper.Copy(up.Str|up.Null, "bridge", "media_patterns", "stickers")
- helper.Copy(up.Str|up.Null, "bridge", "media_patterns", "avatars")
+ helper.Copy(up.Bool, "bridge", "direct_media", "enabled")
+ helper.Copy(up.Str, "bridge", "direct_media", "server_name")
+ helper.Copy(up.Str|up.Null, "bridge", "direct_media", "well_known_response")
+ helper.Copy(up.Bool, "bridge", "direct_media", "allow_proxy")
+ if serverKey, ok := helper.Get(up.Str, "bridge", "direct_media", "server_key"); !ok || serverKey == "generate" {
+ serverKey = federation.GenerateSigningKey().SynapseString()
+ helper.Set(up.Str, serverKey, "bridge", "direct_media", "server_key")
+ } else {
+ helper.Copy(up.Str, "bridge", "direct_media", "server_key")
+ }
helper.Copy(up.Str, "bridge", "animated_sticker", "target")
helper.Copy(up.Int, "bridge", "animated_sticker", "args", "width")
helper.Copy(up.Int, "bridge", "animated_sticker", "args", "height")
diff --git a/database/userportal.go b/database/userportal.go
index 34d5660..783b83d 100644
--- a/database/userportal.go
+++ b/database/userportal.go
@@ -7,6 +7,7 @@ import (
"go.mau.fi/util/dbutil"
log "maunium.net/go/maulogger/v2"
+ "maunium.net/go/mautrix/id"
)
const (
@@ -44,6 +45,24 @@ func (u *User) scanUserPortals(rows dbutil.Rows) []UserPortal {
return ups
}
+func (db *Database) GetUsersInPortal(channelID string) []id.UserID {
+ rows, err := db.Query("SELECT user_mxid FROM user_portal WHERE discord_id=$1", channelID)
+ if err != nil {
+ db.Portal.log.Errorln("Failed to get users in portal:", err)
+ }
+ var users []id.UserID
+ for rows.Next() {
+ var mxid id.UserID
+ err = rows.Scan(&mxid)
+ if err != nil {
+ db.Portal.log.Errorln("Failed to scan user in portal:", err)
+ } else {
+ users = append(users, mxid)
+ }
+ }
+ return users
+}
+
func (u *User) GetPortals() []UserPortal {
rows, err := u.db.Query("SELECT discord_id, type, timestamp, in_space FROM user_portal WHERE user_mxid=$1", u.MXID)
if err != nil {
diff --git a/directmedia.go b/directmedia.go
new file mode 100644
index 0000000..a765f3f
--- /dev/null
+++ b/directmedia.go
@@ -0,0 +1,570 @@
+// mautrix-discord - A Matrix-Discord puppeting bridge.
+// Copyright (C) 2024 Tulir Asokan
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package main
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+ "mime"
+ "net"
+ "net/http"
+ "os"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/bwmarrin/discordgo"
+ "github.com/gorilla/mux"
+ "github.com/rs/zerolog"
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/federation"
+ "maunium.net/go/mautrix/id"
+
+ "go.mau.fi/mautrix-discord/config"
+)
+
+type DirectMediaAPI struct {
+ bridge *DiscordBridge
+ ks *federation.KeyServer
+ cfg config.DirectMedia
+ log zerolog.Logger
+ proxy http.Client
+
+ signatureKey [32]byte
+
+ attachmentCache map[AttachmentCacheKey]AttachmentCacheValue
+ attachmentCacheLock sync.Mutex
+}
+
+type AttachmentCacheKey struct {
+ ChannelID uint64
+ AttachmentID uint64
+}
+
+type AttachmentCacheValue struct {
+ URL string
+ Expiry time.Time
+}
+
+func newDirectMediaAPI(br *DiscordBridge) *DirectMediaAPI {
+ if !br.Config.Bridge.DirectMedia.Enabled {
+ return nil
+ }
+ dma := &DirectMediaAPI{
+ bridge: br,
+ cfg: br.Config.Bridge.DirectMedia,
+ log: br.ZLog.With().Str("component", "direct media").Logger(),
+ proxy: http.Client{
+ Transport: &http.Transport{
+ DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext,
+ TLSHandshakeTimeout: 10 * time.Second,
+ ForceAttemptHTTP2: false,
+ },
+ Timeout: 60 * time.Second,
+ },
+ attachmentCache: make(map[AttachmentCacheKey]AttachmentCacheValue),
+ }
+ r := br.AS.Router
+
+ parsed, err := federation.ParseSynapseKey(dma.cfg.ServerKey)
+ if err != nil {
+ dma.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to parse server key")
+ os.Exit(11)
+ return nil
+ }
+ dma.signatureKey = sha256.Sum256(parsed.Priv.Seed())
+ dma.ks = &federation.KeyServer{
+ KeyProvider: &federation.StaticServerKey{
+ ServerName: dma.cfg.ServerName,
+ Key: parsed,
+ },
+ WellKnownTarget: dma.cfg.WellKnownResponse,
+ Version: federation.ServerVersion{
+ Name: br.Name,
+ Version: br.Version,
+ },
+ }
+ if dma.ks.WellKnownTarget == "" {
+ dma.ks.WellKnownTarget = fmt.Sprintf("%s:443", dma.cfg.ServerName)
+ }
+ mediaRouter := r.PathPrefix("/_matrix/media").Subrouter()
+ var reqIDCounter atomic.Uint64
+ mediaRouter.Use(func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Access-Control-Allow-Origin", "*")
+ w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
+ w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization")
+ log := dma.log.With().
+ Str("remote_addr", r.RemoteAddr).
+ Str("request_path", r.URL.Path).
+ Uint64("req_id", reqIDCounter.Add(1)).
+ Logger()
+ next.ServeHTTP(w, r.WithContext(log.WithContext(r.Context())))
+ })
+ })
+ addRoutes := func(version string) {
+ mediaRouter.HandleFunc("/"+version+"/download/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet)
+ mediaRouter.HandleFunc("/"+version+"/download/{serverName}/{mediaID}/{fileName}", dma.DownloadMedia).Methods(http.MethodGet)
+ mediaRouter.HandleFunc("/"+version+"/thumbnail/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet)
+ mediaRouter.HandleFunc("/"+version+"/upload/{serverName}/{mediaID}", dma.UploadNotSupported).Methods(http.MethodPut)
+ mediaRouter.HandleFunc("/"+version+"/upload", dma.UploadNotSupported).Methods(http.MethodPost)
+ mediaRouter.HandleFunc("/"+version+"/config", dma.UploadNotSupported).Methods(http.MethodGet)
+ mediaRouter.HandleFunc("/"+version+"/preview_url", dma.PreviewURLNotSupported).Methods(http.MethodGet)
+ }
+ addRoutes("v3")
+ addRoutes("r0")
+ addRoutes("v1")
+ mediaRouter.NotFoundHandler = http.HandlerFunc(dma.UnknownEndpoint)
+ mediaRouter.MethodNotAllowedHandler = http.HandlerFunc(dma.UnsupportedMethod)
+ dma.ks.Register(r)
+
+ return dma
+}
+
+func (dma *DirectMediaAPI) makeMXC(data MediaIDData) id.ContentURI {
+ return id.ContentURI{
+ Homeserver: dma.cfg.ServerName,
+ FileID: data.Wrap().SignedString(dma.signatureKey),
+ }
+}
+
+func (dma *DirectMediaAPI) addAttachmentToCache(channelID uint64, att *discordgo.MessageAttachment) {
+ attachmentID, err := strconv.ParseUint(att.ID, 10, 64)
+ if err != nil {
+ return
+ }
+ dma.attachmentCache[AttachmentCacheKey{
+ ChannelID: channelID,
+ AttachmentID: attachmentID,
+ }] = AttachmentCacheValue{
+ URL: att.URL,
+ // TODO find expiry somehow properly?
+ Expiry: time.Now().Add(23 * time.Hour),
+ }
+}
+
+func (dma *DirectMediaAPI) AttachmentMXC(channelID, messageID string, att *discordgo.MessageAttachment) (mxc id.ContentURI) {
+ if dma == nil {
+ return
+ }
+ channelIDInt, err := strconv.ParseUint(channelID, 10, 64)
+ if err != nil {
+ dma.log.Warn().Str("channel_id", channelID).Msg("Got non-integer channel ID")
+ return
+ }
+ messageIDInt, err := strconv.ParseUint(messageID, 10, 64)
+ if err != nil {
+ dma.log.Warn().Str("message_id", messageID).Msg("Got non-integer message ID")
+ return
+ }
+ attachmentIDInt, err := strconv.ParseUint(att.ID, 10, 64)
+ if err != nil {
+ dma.log.Warn().Str("attachment_id", att.ID).Msg("Got non-integer attachment ID")
+ return
+ }
+ dma.attachmentCacheLock.Lock()
+ dma.addAttachmentToCache(channelIDInt, att)
+ dma.attachmentCacheLock.Unlock()
+ return dma.makeMXC(&AttachmentMediaData{
+ ChannelID: channelIDInt,
+ MessageID: messageIDInt,
+ AttachmentID: attachmentIDInt,
+ })
+}
+
+func (dma *DirectMediaAPI) EmojiMXC(emojiID, name string, animated bool) (mxc id.ContentURI) {
+ if dma == nil {
+ return
+ }
+ emojiIDInt, err := strconv.ParseUint(emojiID, 10, 64)
+ if err != nil {
+ dma.log.Warn().Str("emoji_id", emojiID).Msg("Got non-integer emoji ID")
+ return
+ }
+ return dma.makeMXC(&EmojiMediaData{
+ EmojiMediaDataInner: EmojiMediaDataInner{
+ EmojiID: emojiIDInt,
+ Animated: animated,
+ },
+ Name: name,
+ })
+}
+
+func (dma *DirectMediaAPI) StickerMXC(stickerID string, format discordgo.StickerFormat) (mxc id.ContentURI) {
+ if dma == nil {
+ return
+ }
+ stickerIDInt, err := strconv.ParseUint(stickerID, 10, 64)
+ if err != nil {
+ dma.log.Warn().Str("sticker_id", stickerID).Msg("Got non-integer sticker ID")
+ return
+ } else if format > 255 || format < 0 {
+ dma.log.Warn().Int("format", int(format)).Msg("Got invalid sticker format")
+ return
+ }
+ return dma.makeMXC(&StickerMediaData{
+ StickerID: stickerIDInt,
+ Format: byte(format),
+ })
+}
+
+func (dma *DirectMediaAPI) AvatarMXC(guildID, userID, avatarID string) (mxc id.ContentURI) {
+ if dma == nil {
+ return
+ }
+ animated := strings.HasPrefix(avatarID, "a_")
+ avatarIDBytes, err := hex.DecodeString(strings.TrimPrefix(avatarID, "a_"))
+ if err != nil {
+ dma.log.Warn().Str("avatar_id", avatarID).Msg("Got non-hex avatar ID")
+ return
+ } else if len(avatarIDBytes) != 16 {
+ dma.log.Warn().Str("avatar_id", avatarID).Msg("Got invalid avatar ID length")
+ return
+ }
+ avatarIDArray := [16]byte(avatarIDBytes)
+ userIDInt, err := strconv.ParseUint(userID, 10, 64)
+ if err != nil {
+ dma.log.Warn().Str("user_id", userID).Msg("Got non-integer user ID")
+ return
+ }
+ if guildID != "" {
+ guildIDInt, err := strconv.ParseUint(guildID, 10, 64)
+ if err != nil {
+ dma.log.Warn().Str("guild_id", guildID).Msg("Got non-integer guild ID")
+ return
+ }
+ return dma.makeMXC(&GuildMemberAvatarMediaData{
+ GuildID: guildIDInt,
+ UserID: userIDInt,
+ AvatarID: avatarIDArray,
+ Animated: animated,
+ })
+ } else {
+ return dma.makeMXC(&UserAvatarMediaData{
+ UserID: userIDInt,
+ AvatarID: avatarIDArray,
+ Animated: animated,
+ })
+ }
+}
+
+type RespError struct {
+ Code string
+ Message string
+ Status int
+}
+
+func (re *RespError) Error() string {
+ return re.Message
+}
+
+var ErrNoUsersWithAccessFound = errors.New("no users found to fetch message")
+var ErrAttachmentNotFound = errors.New("attachment not found")
+
+func (dma *DirectMediaAPI) fetchNewAttachmentURL(ctx context.Context, meta *AttachmentMediaData) (string, error) {
+ var client *discordgo.Session
+ channelIDStr := strconv.FormatUint(meta.ChannelID, 10)
+ users := dma.bridge.DB.GetUsersInPortal(channelIDStr)
+ for _, userID := range users {
+ user := dma.bridge.GetCachedUserByMXID(userID)
+ if user != nil && user.Session != nil {
+ client = user.Session
+ if !client.IsUser {
+ break
+ }
+ }
+ }
+ if client == nil {
+ return "", ErrNoUsersWithAccessFound
+ }
+ var url string
+ var msgs []*discordgo.Message
+ var err error
+ messageIDStr := strconv.FormatUint(meta.MessageID, 10)
+ if client.IsUser {
+ msgs, err = client.ChannelMessages(channelIDStr, 5, "", "", messageIDStr)
+ } else {
+ var msg *discordgo.Message
+ msg, err = client.ChannelMessage(channelIDStr, messageIDStr)
+ msgs = []*discordgo.Message{msg}
+ }
+ if err != nil {
+ return "", fmt.Errorf("failed to fetch message: %w", err)
+ }
+ attachmentIDStr := strconv.FormatUint(meta.AttachmentID, 10)
+ for _, item := range msgs {
+ for _, att := range item.Attachments {
+ dma.addAttachmentToCache(meta.ChannelID, att)
+ if att.ID == attachmentIDStr {
+ url = att.URL
+ }
+ }
+ }
+ if url == "" {
+ return "", ErrAttachmentNotFound
+ }
+ return url, nil
+}
+
+func (dma *DirectMediaAPI) GetEmojiInfo(contentURI id.ContentURI) *EmojiMediaData {
+ if dma == nil || contentURI.IsEmpty() || contentURI.Homeserver != dma.cfg.ServerName {
+ return nil
+ }
+ mediaID, err := ParseMediaID(contentURI.FileID, dma.signatureKey)
+ if err != nil {
+ return nil
+ }
+ emojiData, ok := mediaID.Data.(*EmojiMediaData)
+ if !ok {
+ return nil
+ }
+ return emojiData
+
+}
+
+func (dma *DirectMediaAPI) getMediaURL(ctx context.Context, encodedMediaID string) (url string, expiry time.Time, err error) {
+ var mediaID *MediaID
+ mediaID, err = ParseMediaID(encodedMediaID, dma.signatureKey)
+ if err != nil {
+ err = &RespError{
+ Code: mautrix.MNotFound.ErrCode,
+ Message: err.Error(),
+ Status: http.StatusNotFound,
+ }
+ return
+ }
+ switch mediaData := mediaID.Data.(type) {
+ case *AttachmentMediaData:
+ dma.attachmentCacheLock.Lock()
+ defer dma.attachmentCacheLock.Unlock()
+ cached, ok := dma.attachmentCache[mediaData.CacheKey()]
+ if ok && time.Until(cached.Expiry) > 5*time.Minute {
+ return cached.URL, cached.Expiry, nil
+ }
+ zerolog.Ctx(ctx).Debug().
+ Uint64("channel_id", mediaData.ChannelID).
+ Uint64("message_id", mediaData.MessageID).
+ Uint64("attachment_id", mediaData.AttachmentID).
+ Msg("Refreshing attachment URL")
+ url, err = dma.fetchNewAttachmentURL(ctx, mediaData)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to refresh attachment URL")
+ msg := "Failed to refresh attachment URL"
+ if errors.Is(err, ErrNoUsersWithAccessFound) {
+ msg = "No users found with access to the channel"
+ } else if errors.Is(err, ErrAttachmentNotFound) {
+ msg = "Attachment not found in message. Perhaps it was deleted?"
+ }
+ err = &RespError{
+ Code: mautrix.MNotFound.ErrCode,
+ Message: msg,
+ Status: http.StatusNotFound,
+ }
+ } else {
+ zerolog.Ctx(ctx).Debug().Msg("Successfully refreshed attachment URL")
+ // TODO find expiry somehow properly?
+ expiry = time.Now().Add(23 * time.Hour)
+ }
+ case *EmojiMediaData:
+ if mediaData.Animated {
+ url = discordgo.EndpointEmojiAnimated(strconv.FormatUint(mediaData.EmojiID, 10))
+ } else {
+ url = discordgo.EndpointEmoji(strconv.FormatUint(mediaData.EmojiID, 10))
+ }
+ case *StickerMediaData:
+ url = discordgo.EndpointStickerImage(
+ strconv.FormatUint(mediaData.StickerID, 10),
+ discordgo.StickerFormat(mediaData.Format),
+ )
+ case *UserAvatarMediaData:
+ if mediaData.Animated {
+ url = discordgo.EndpointUserAvatarAnimated(
+ strconv.FormatUint(mediaData.UserID, 10),
+ fmt.Sprintf("a_%x", mediaData.AvatarID),
+ )
+ } else {
+ url = discordgo.EndpointUserAvatar(
+ strconv.FormatUint(mediaData.UserID, 10),
+ fmt.Sprintf("%x", mediaData.AvatarID),
+ )
+ }
+ case *GuildMemberAvatarMediaData:
+ if mediaData.Animated {
+ url = discordgo.EndpointGuildMemberAvatarAnimated(
+ strconv.FormatUint(mediaData.GuildID, 10),
+ strconv.FormatUint(mediaData.UserID, 10),
+ fmt.Sprintf("a_%x", mediaData.AvatarID),
+ )
+ } else {
+ url = discordgo.EndpointGuildMemberAvatar(
+ strconv.FormatUint(mediaData.GuildID, 10),
+ strconv.FormatUint(mediaData.UserID, 10),
+ fmt.Sprintf("%x", mediaData.AvatarID),
+ )
+ }
+ default:
+ zerolog.Ctx(ctx).Error().Type("media_data_type", mediaData).Msg("Unrecognized media data struct")
+ err = &RespError{
+ Code: "M_UNKNOWN",
+ Message: "Unrecognized media data struct",
+ Status: http.StatusInternalServerError,
+ }
+ }
+ return
+}
+
+func (dma *DirectMediaAPI) proxyDownload(ctx context.Context, w http.ResponseWriter, url, fileName string) {
+ log := zerolog.Ctx(ctx)
+ req, err := http.NewRequest(http.MethodGet, url, nil)
+ if err != nil {
+ log.Err(err).Str("url", url).Msg("Failed to create proxy request")
+ jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{
+ ErrCode: "M_UNKNOWN",
+ Err: "Failed to create proxy request",
+ })
+ return
+ }
+ for key, val := range discordgo.DroidDownloadHeaders {
+ req.Header.Set(key, val)
+ }
+ resp, err := dma.proxy.Do(req)
+ defer func() {
+ if resp != nil && resp.Body != nil {
+ _ = resp.Body.Close()
+ }
+ }()
+ if err != nil {
+ log.Err(err).Str("url", url).Msg("Failed to proxy download")
+ jsonResponse(w, http.StatusServiceUnavailable, &mautrix.RespError{
+ ErrCode: "M_UNKNOWN",
+ Err: "Failed to proxy download",
+ })
+ return
+ } else if resp.StatusCode != http.StatusOK {
+ log.Warn().Str("url", url).Int("status", resp.StatusCode).Msg("Unexpected status code proxying download")
+ jsonResponse(w, resp.StatusCode, &mautrix.RespError{
+ ErrCode: "M_UNKNOWN",
+ Err: "Unexpected status code proxying download",
+ })
+ return
+ }
+ w.Header()["Content-Type"] = resp.Header["Content-Type"]
+ w.Header()["Content-Length"] = resp.Header["Content-Length"]
+ w.Header()["Last-Modified"] = resp.Header["Last-Modified"]
+ w.Header()["Cache-Control"] = resp.Header["Cache-Control"]
+ contentDisposition := "attachment"
+ switch resp.Header.Get("Content-Type") {
+ case "text/css", "text/plain", "text/csv", "application/json", "application/ld+json", "image/jpeg", "image/gif",
+ "image/png", "image/apng", "image/webp", "image/avif", "video/mp4", "video/webm", "video/ogg", "video/quicktime",
+ "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", "audio/wav", "audio/x-wav",
+ "audio/x-pn-wav", "audio/flac", "audio/x-flac", "application/pdf":
+ contentDisposition = "inline"
+ }
+ if fileName != "" {
+ contentDisposition = mime.FormatMediaType(contentDisposition, map[string]string{
+ "filename": fileName,
+ })
+ }
+ w.Header().Set("Content-Disposition", contentDisposition)
+ w.WriteHeader(http.StatusOK)
+ _, err = io.Copy(w, resp.Body)
+ if err != nil {
+ log.Debug().Err(err).Msg("Failed to write proxy response")
+ }
+}
+
+func (dma *DirectMediaAPI) DownloadMedia(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ log := zerolog.Ctx(ctx)
+ vars := mux.Vars(r)
+ if vars["serverName"] != dma.cfg.ServerName {
+ jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
+ ErrCode: mautrix.MNotFound.ErrCode,
+ Err: fmt.Sprintf("This is a Discord media proxy for %q, other media downloads are not available here", dma.cfg.ServerName),
+ })
+ return
+ }
+ url, expiresAt, err := dma.getMediaURL(ctx, vars["mediaID"])
+ if err != nil {
+ var respError *RespError
+ if errors.As(err, &respError) {
+ jsonResponse(w, respError.Status, &mautrix.RespError{
+ ErrCode: respError.Code,
+ Err: respError.Message,
+ })
+ } else {
+ log.Err(err).Str("media_id", vars["mediaID"]).Msg("Failed to get media URL")
+ jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
+ ErrCode: mautrix.MNotFound.ErrCode,
+ Err: "Media not found",
+ })
+ }
+ return
+
+ }
+ // Proxy if the config allows proxying and the request doesn't allow redirects.
+ // In any other case, redirect to the Discord CDN.
+ if dma.cfg.AllowProxy && r.URL.Query().Get("allow_redirect") != "true" {
+ dma.proxyDownload(ctx, w, url, vars["fileName"])
+ return
+ }
+ w.Header().Set("Location", url)
+ expirySeconds := (time.Until(expiresAt) - 5*time.Minute).Seconds()
+ if expiresAt.IsZero() {
+ w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
+ } else if expirySeconds > 0 {
+ cacheControl := fmt.Sprintf("public, max-age=%d, immutable", int(expirySeconds))
+ w.Header().Set("Cache-Control", cacheControl)
+ } else {
+ w.Header().Set("Cache-Control", "no-store")
+ }
+ w.WriteHeader(http.StatusTemporaryRedirect)
+}
+
+func (dma *DirectMediaAPI) UploadNotSupported(w http.ResponseWriter, r *http.Request) {
+ jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{
+ ErrCode: mautrix.MUnrecognized.ErrCode,
+ Err: "This bridge only supports proxying Discord media downloads and does not support media uploads.",
+ })
+}
+
+func (dma *DirectMediaAPI) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) {
+ jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{
+ ErrCode: mautrix.MUnrecognized.ErrCode,
+ Err: "This bridge only supports proxying Discord media downloads and does not support URL previews.",
+ })
+}
+
+func (dma *DirectMediaAPI) UnknownEndpoint(w http.ResponseWriter, r *http.Request) {
+ jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
+ ErrCode: mautrix.MUnrecognized.ErrCode,
+ Err: "Unrecognized endpoint",
+ })
+}
+
+func (dma *DirectMediaAPI) UnsupportedMethod(w http.ResponseWriter, r *http.Request) {
+ jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{
+ ErrCode: mautrix.MUnrecognized.ErrCode,
+ Err: "Invalid method for endpoint",
+ })
+}
diff --git a/directmedia_id.go b/directmedia_id.go
new file mode 100644
index 0000000..92b935a
--- /dev/null
+++ b/directmedia_id.go
@@ -0,0 +1,287 @@
+// mautrix-discord - A Matrix-Discord puppeting bridge.
+// Copyright (C) 2024 Tulir Asokan
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package main
+
+import (
+ "bytes"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+)
+
+const MediaIDPrefix = "\U0001F408DISCORD"
+const MediaIDVersion = 1
+
+type MediaIDClass uint8
+
+const (
+ MediaIDClassAttachment MediaIDClass = 1
+ MediaIDClassEmoji MediaIDClass = 2
+ MediaIDClassSticker MediaIDClass = 3
+ MediaIDClassUserAvatar MediaIDClass = 4
+ MediaIDClassGuildMemberAvatar MediaIDClass = 5
+)
+
+type MediaIDData interface {
+ Write(to io.Writer)
+ Read(from io.Reader) error
+ Size() int
+ Wrap() *MediaID
+}
+
+type MediaID struct {
+ Version uint8
+ TypeClass MediaIDClass
+ Data MediaIDData
+}
+
+func ParseMediaID(id string, key [32]byte) (*MediaID, error) {
+ data, err := base64.RawURLEncoding.DecodeString(id)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode base64: %w", err)
+ }
+ hasher := hmac.New(sha256.New, key[:])
+ checksum := data[len(data)-TruncatedHashLength:]
+ data = data[:len(data)-TruncatedHashLength]
+ hasher.Write(data)
+ if !hmac.Equal(checksum, hasher.Sum(nil)[:TruncatedHashLength]) {
+ return nil, ErrMediaIDChecksumMismatch
+ }
+ mid := &MediaID{}
+ err = mid.Read(bytes.NewReader(data))
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse media ID: %w", err)
+ }
+ return mid, nil
+}
+
+const TruncatedHashLength = 16
+
+func (mid *MediaID) SignedString(key [32]byte) string {
+ buf := bytes.NewBuffer(make([]byte, 0, mid.Size()))
+ mid.Write(buf)
+ hasher := hmac.New(sha256.New, key[:])
+ hasher.Write(buf.Bytes())
+ buf.Write(hasher.Sum(nil)[:TruncatedHashLength])
+ return base64.RawURLEncoding.EncodeToString(buf.Bytes())
+}
+
+func (mid *MediaID) Write(to io.Writer) {
+ _, _ = to.Write([]byte(MediaIDPrefix))
+ _ = binary.Write(to, binary.BigEndian, mid.Version)
+ _ = binary.Write(to, binary.BigEndian, mid.TypeClass)
+ mid.Data.Write(to)
+}
+
+func (mid *MediaID) Size() int {
+ return len(MediaIDPrefix) + 2 + mid.Data.Size() + TruncatedHashLength
+}
+
+var (
+ ErrInvalidMediaID = errors.New("invalid media ID")
+ ErrMediaIDChecksumMismatch = errors.New("invalid checksum in media ID")
+ ErrUnsupportedMediaID = errors.New("unsupported media ID")
+)
+
+func (mid *MediaID) Read(from io.Reader) error {
+ prefix := make([]byte, len(MediaIDPrefix))
+ _, err := io.ReadFull(from, prefix)
+ if err != nil || !bytes.Equal(prefix, []byte(MediaIDPrefix)) {
+ return fmt.Errorf("%w: prefix not found", ErrInvalidMediaID)
+ }
+ versionAndClass := make([]byte, 2)
+ _, err = io.ReadFull(from, versionAndClass)
+ if err != nil {
+ return fmt.Errorf("%w: version and class not found", ErrInvalidMediaID)
+ } else if versionAndClass[0] != MediaIDVersion {
+ return fmt.Errorf("%w: unknown version %d", ErrUnsupportedMediaID, versionAndClass[0])
+ }
+ switch MediaIDClass(versionAndClass[1]) {
+ case MediaIDClassAttachment:
+ mid.Data = &AttachmentMediaData{}
+ case MediaIDClassEmoji:
+ mid.Data = &EmojiMediaData{}
+ case MediaIDClassSticker:
+ mid.Data = &StickerMediaData{}
+ case MediaIDClassUserAvatar:
+ mid.Data = &UserAvatarMediaData{}
+ case MediaIDClassGuildMemberAvatar:
+ mid.Data = &GuildMemberAvatarMediaData{}
+ default:
+ return fmt.Errorf("%w: unrecognized type class %d", ErrUnsupportedMediaID, versionAndClass[1])
+ }
+ err = mid.Data.Read(from)
+ if err != nil {
+ return fmt.Errorf("failed to parse media ID data: %w", err)
+ }
+ return nil
+}
+
+type AttachmentMediaData struct {
+ ChannelID uint64
+ MessageID uint64
+ AttachmentID uint64
+}
+
+func (amd *AttachmentMediaData) Write(to io.Writer) {
+ _ = binary.Write(to, binary.BigEndian, amd)
+}
+
+func (amd *AttachmentMediaData) Read(from io.Reader) (err error) {
+ return binary.Read(from, binary.BigEndian, amd)
+}
+
+func (amd *AttachmentMediaData) Size() int {
+ return binary.Size(amd)
+}
+
+func (amd *AttachmentMediaData) Wrap() *MediaID {
+ return &MediaID{
+ Version: MediaIDVersion,
+ TypeClass: MediaIDClassAttachment,
+ Data: amd,
+ }
+}
+
+func (amd *AttachmentMediaData) CacheKey() AttachmentCacheKey {
+ return AttachmentCacheKey{
+ ChannelID: amd.ChannelID,
+ AttachmentID: amd.AttachmentID,
+ }
+}
+
+type StickerMediaData struct {
+ StickerID uint64
+ Format uint8
+}
+
+func (smd *StickerMediaData) Write(to io.Writer) {
+ _ = binary.Write(to, binary.BigEndian, smd)
+}
+
+func (smd *StickerMediaData) Read(from io.Reader) error {
+ return binary.Read(from, binary.BigEndian, smd)
+}
+
+func (smd *StickerMediaData) Size() int {
+ return binary.Size(smd)
+}
+
+func (smd *StickerMediaData) Wrap() *MediaID {
+ return &MediaID{
+ Version: MediaIDVersion,
+ TypeClass: MediaIDClassSticker,
+ Data: smd,
+ }
+}
+
+type EmojiMediaDataInner struct {
+ EmojiID uint64
+ Animated bool
+}
+
+type EmojiMediaData struct {
+ EmojiMediaDataInner
+ Name string
+}
+
+func (emd *EmojiMediaData) Write(to io.Writer) {
+ _ = binary.Write(to, binary.BigEndian, &emd.EmojiMediaDataInner)
+ _, _ = to.Write([]byte(emd.Name))
+}
+
+func (emd *EmojiMediaData) Read(from io.Reader) (err error) {
+ err = binary.Read(from, binary.BigEndian, &emd.EmojiMediaDataInner)
+ if err != nil {
+ return
+ }
+ name, err := io.ReadAll(from)
+ if err != nil {
+ return
+ }
+ emd.Name = string(name)
+ return
+}
+
+func (emd *EmojiMediaData) Size() int {
+ return binary.Size(&emd.EmojiMediaDataInner) + len(emd.Name)
+}
+
+func (emd *EmojiMediaData) Wrap() *MediaID {
+ return &MediaID{
+ Version: MediaIDVersion,
+ TypeClass: MediaIDClassEmoji,
+ Data: emd,
+ }
+}
+
+type UserAvatarMediaData struct {
+ UserID uint64
+ Animated bool
+ AvatarID [16]byte
+}
+
+func (uamd *UserAvatarMediaData) Write(to io.Writer) {
+ _ = binary.Write(to, binary.BigEndian, uamd)
+}
+
+func (uamd *UserAvatarMediaData) Read(from io.Reader) error {
+ return binary.Read(from, binary.BigEndian, uamd)
+}
+
+func (uamd *UserAvatarMediaData) Size() int {
+ return binary.Size(uamd)
+}
+
+func (uamd *UserAvatarMediaData) Wrap() *MediaID {
+ return &MediaID{
+ Version: MediaIDVersion,
+ TypeClass: MediaIDClassUserAvatar,
+ Data: uamd,
+ }
+}
+
+type GuildMemberAvatarMediaData struct {
+ GuildID uint64
+ UserID uint64
+ Animated bool
+ AvatarID [16]byte
+}
+
+func (guamd *GuildMemberAvatarMediaData) Write(to io.Writer) {
+ _ = binary.Write(to, binary.BigEndian, guamd)
+}
+
+func (guamd *GuildMemberAvatarMediaData) Read(from io.Reader) error {
+ return binary.Read(from, binary.BigEndian, guamd)
+}
+
+func (guamd *GuildMemberAvatarMediaData) Size() int {
+ return binary.Size(guamd)
+}
+
+func (guamd *GuildMemberAvatarMediaData) Wrap() *MediaID {
+ return &MediaID{
+ Version: MediaIDVersion,
+ TypeClass: MediaIDClassGuildMemberAvatar,
+ Data: guamd,
+ }
+}
diff --git a/example-config.yaml b/example-config.yaml
index 42b5c6c..9e2dc4c 100644
--- a/example-config.yaml
+++ b/example-config.yaml
@@ -168,20 +168,23 @@ bridge:
# This can be `never` to never cache, `unencrypted` to only cache unencrypted mxc uris, or `always` to cache everything.
# If you have a media repo that generates non-unique mxc uris, you should set this to never.
cache_media: unencrypted
- # Patterns for converting Discord media to custom mxc:// URIs instead of reuploading.
- # Each of the patterns can be set to null to disable custom URIs for that type of media.
+ # Settings for converting Discord media to custom mxc:// URIs instead of reuploading.
# More details can be found at https://docs.mau.fi/bridges/go/discord/direct-media.html
- media_patterns:
+ direct_media:
# Should custom mxc:// URIs be used instead of reuploading media?
enabled: false
- # Pattern for normal message attachments.
- attachments: mxc://discord-media.mau.dev/attachments|{{.ChannelID}}|{{.AttachmentID}}|{{.FileName}}
- # Pattern for custom emojis.
- emojis: mxc://discord-media.mau.dev/emojis|{{.ID}}.{{.Ext}}
- # Pattern for stickers. Note that animated lottie stickers will not be converted if this is enabled.
- stickers: mxc://discord-media.mau.dev/stickers|{{.ID}}.{{.Ext}}
- # Pattern for static user avatars.
- avatars: mxc://discord-media.mau.dev/avatars|{{.UserID}}|{{.AvatarID}}.{{.Ext}}
+ # The server name to use for the custom mxc:// URIs.
+ # This server name will effectively be a real Matrix server, it just won't implement anything other than media.
+ # You must either set up .well-known delegation from this domain to the bridge, or proxy the domain directly to the bridge.
+ server_name: discord-media.example.com
+ # Optionally a custom .well-known response. This defaults to `server_name:443`
+ well_known_response:
+ # The bridge supports MSC3860 media download redirects and will use them if the requester supports it.
+ # Optionally, you can force redirects and not allow proxying at all by setting this to false.
+ allow_proxy: true
+ # Matrix server signing key to make the federation tester pass, same format as synapse's .signing.key file.
+ # This key is also used to sign the mxc:// URIs to ensure only the bridge can generate them.
+ server_key: generate
# Settings for converting animated stickers.
animated_sticker:
# Format to which animated stickers should be converted.
diff --git a/go.mod b/go.mod
index 4e9f583..bf3770f 100644
--- a/go.mod
+++ b/go.mod
@@ -18,7 +18,7 @@ require (
golang.org/x/exp v0.0.0-20231219180239-dc181d75b848
golang.org/x/sync v0.5.0
maunium.net/go/maulogger/v2 v2.4.1
- maunium.net/go/mautrix v0.16.2
+ maunium.net/go/mautrix v0.16.3-0.20240218195727-4ceb1123b660
)
require (
diff --git a/go.sum b/go.sum
index 464f848..6a9df03 100644
--- a/go.sum
+++ b/go.sum
@@ -71,5 +71,5 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8=
maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho=
-maunium.net/go/mautrix v0.16.2 h1:a6GUJXNWsTEOO8VE4dROBfCIfPp50mqaqzv7KPzChvg=
-maunium.net/go/mautrix v0.16.2/go.mod h1:YL4l4rZB46/vj/ifRMEjcibbvHjgxHftOF1SgmruLu4=
+maunium.net/go/mautrix v0.16.3-0.20240218195727-4ceb1123b660 h1:ZPg1i0wsyEs5ee7z6Gn4mSRsq9BtDV//rfYTGs82l8c=
+maunium.net/go/mautrix v0.16.3-0.20240218195727-4ceb1123b660/go.mod h1:YL4l4rZB46/vj/ifRMEjcibbvHjgxHftOF1SgmruLu4=
diff --git a/main.go b/main.go
index 7fd2501..f3b51d3 100644
--- a/main.go
+++ b/main.go
@@ -48,6 +48,7 @@ type DiscordBridge struct {
Config *config.Config
DB *database.Database
+ DMA *DirectMediaAPI
provisioning *ProvisioningAPI
usersByMXID map[id.UserID]*User
@@ -104,6 +105,7 @@ func (br *DiscordBridge) Start() {
if br.Config.Bridge.Provisioning.SharedSecret != "disable" {
br.provisioning = newProvisioningAPI(br)
}
+ br.DMA = newDirectMediaAPI(br)
br.WaitWebsocketConnected()
go br.startUsers()
}
diff --git a/portal.go b/portal.go
index f32ba28..e63c83f 100644
--- a/portal.go
+++ b/portal.go
@@ -1832,13 +1832,15 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) {
emojiID := reaction.RelatesTo.Key
if strings.HasPrefix(emojiID, "mxc://") {
uri, _ := id.ParseContentURI(emojiID)
- emojiFile := portal.bridge.DB.File.GetEmojiByMXC(uri)
- if emojiFile == nil || emojiFile.ID == "" || emojiFile.EmojiName == "" {
+ emojiInfo := portal.bridge.DMA.GetEmojiInfo(uri)
+ if emojiInfo != nil {
+ emojiID = fmt.Sprintf("%s:%d", emojiInfo.Name, emojiInfo.EmojiID)
+ } else if emojiFile := portal.bridge.DB.File.GetEmojiByMXC(uri); emojiFile != nil && emojiFile.ID != "" && emojiFile.EmojiName != "" {
+ emojiID = fmt.Sprintf("%s:%s", emojiFile.EmojiName, emojiFile.ID)
+ } else {
go portal.sendMessageMetrics(evt, fmt.Errorf("%w %s", errUnknownEmoji, emojiID), "Ignoring")
return
}
-
- emojiID = fmt.Sprintf("%s:%s", emojiFile.EmojiName, emojiFile.ID)
} else {
emojiID = variationselector.FullyQualify(emojiID)
}
diff --git a/portal_convert.go b/portal_convert.go
index 3d34005..90f0100 100644
--- a/portal_convert.go
+++ b/portal_convert.go
@@ -27,12 +27,11 @@ import (
"github.com/bwmarrin/discordgo"
"github.com/rs/zerolog"
"golang.org/x/exp/slices"
- "maunium.net/go/mautrix/id"
-
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
+ "maunium.net/go/mautrix/id"
)
type ConvertedMessage struct {
@@ -103,20 +102,16 @@ func (portal *Portal) cleanupConvertedStickerInfo(content *event.MessageEventCon
}
func (portal *Portal) convertDiscordSticker(ctx context.Context, intent *appservice.IntentAPI, sticker *discordgo.Sticker) *ConvertedMessage {
- var mime, ext string
+ var mime string
switch sticker.FormatType {
case discordgo.StickerFormatTypePNG:
mime = "image/png"
- ext = "png"
case discordgo.StickerFormatTypeAPNG:
mime = "image/apng"
- ext = "png"
case discordgo.StickerFormatTypeLottie:
mime = "application/json"
- ext = "json"
case discordgo.StickerFormatTypeGIF:
mime = "image/gif"
- ext = "gif"
default:
zerolog.Ctx(ctx).Warn().
Int("sticker_format", int(sticker.FormatType)).
@@ -130,8 +125,9 @@ func (portal *Portal) convertDiscordSticker(ctx context.Context, intent *appserv
},
}
- mxc := portal.bridge.Config.Bridge.MediaPatterns.Sticker(sticker.ID, ext)
- if mxc.IsEmpty() {
+ mxc := portal.bridge.DMA.StickerMXC(sticker.ID, sticker.FormatType)
+ // TODO add config option to use direct media even for lottie stickers
+ if mxc.IsEmpty() && mime != "application/json" {
content = portal.convertDiscordFile(ctx, "sticker", intent, sticker.ID, sticker.URL(), content)
} else {
content.URL = mxc.CUString()
@@ -144,7 +140,7 @@ func (portal *Portal) convertDiscordSticker(ctx context.Context, intent *appserv
}
}
-func (portal *Portal) convertDiscordAttachment(ctx context.Context, intent *appservice.IntentAPI, att *discordgo.MessageAttachment) *ConvertedMessage {
+func (portal *Portal) convertDiscordAttachment(ctx context.Context, intent *appservice.IntentAPI, messageID string, att *discordgo.MessageAttachment) *ConvertedMessage {
content := &event.MessageEventContent{
Body: att.Filename,
Info: &event.FileInfo{
@@ -182,7 +178,7 @@ func (portal *Portal) convertDiscordAttachment(ctx context.Context, intent *apps
default:
content.MsgType = event.MsgFile
}
- mxc := portal.bridge.Config.Bridge.MediaPatterns.Attachment(portal.Key.ChannelID, att.ID, att.Filename)
+ mxc := portal.bridge.DMA.AttachmentMXC(portal.Key.ChannelID, messageID, att)
if mxc.IsEmpty() {
content = portal.convertDiscordFile(ctx, "attachment", intent, att.ID, att.URL, content)
} else {
@@ -287,7 +283,7 @@ func (portal *Portal) convertDiscordMessage(ctx context.Context, puppet *Puppet,
}
handledIDs[att.ID] = struct{}{}
log := log.With().Str("attachment_id", att.ID).Logger()
- if part := portal.convertDiscordAttachment(log.WithContext(ctx), intent, att); part != nil {
+ if part := portal.convertDiscordAttachment(log.WithContext(ctx), intent, msg.ID, att); part != nil {
parts = append(parts, part)
}
}
diff --git a/puppet.go b/puppet.go
index af47d8c..ca6489e 100644
--- a/puppet.go
+++ b/puppet.go
@@ -217,27 +217,23 @@ func (puppet *Puppet) UpdateName(info *discordgo.User) bool {
}
func (br *DiscordBridge) reuploadUserAvatar(intent *appservice.IntentAPI, guildID, userID, avatarID string) (id.ContentURI, string, error) {
- var downloadURL, ext string
+ var downloadURL string
if guildID == "" {
- downloadURL = discordgo.EndpointUserAvatar(userID, avatarID)
- ext = "png"
if strings.HasPrefix(avatarID, "a_") {
downloadURL = discordgo.EndpointUserAvatarAnimated(userID, avatarID)
- ext = "gif"
+ } else {
+ downloadURL = discordgo.EndpointUserAvatar(userID, avatarID)
}
} else {
- downloadURL = discordgo.EndpointGuildMemberAvatar(guildID, userID, avatarID)
- ext = "png"
if strings.HasPrefix(avatarID, "a_") {
downloadURL = discordgo.EndpointGuildMemberAvatarAnimated(guildID, userID, avatarID)
- ext = "gif"
+ } else {
+ downloadURL = discordgo.EndpointGuildMemberAvatar(guildID, userID, avatarID)
}
}
- if guildID == "" {
- url := br.Config.Bridge.MediaPatterns.Avatar(userID, avatarID, ext)
- if !url.IsEmpty() {
- return url, downloadURL, nil
- }
+ url := br.DMA.AvatarMXC(guildID, userID, avatarID)
+ if !url.IsEmpty() {
+ return url, downloadURL, nil
}
copied, err := br.copyAttachmentToMatrix(intent, downloadURL, false, AttachmentMeta{
AttachmentID: fmt.Sprintf("avatar/%s/%s/%s", guildID, userID, avatarID),
diff --git a/user.go b/user.go
index 174ed0d..1f08fba 100644
--- a/user.go
+++ b/user.go
@@ -195,6 +195,12 @@ func (br *DiscordBridge) GetCachedUserByID(id string) *User {
return br.usersByID[id]
}
+func (br *DiscordBridge) GetCachedUserByMXID(userID id.UserID) *User {
+ br.usersLock.Lock()
+ defer br.usersLock.Unlock()
+ return br.usersByMXID[userID]
+}
+
func (br *DiscordBridge) NewUser(dbUser *database.User) *User {
user := &User{
User: dbUser,