mautrix-discord/database/message.go

259 lines
7.9 KiB
Go
Raw Permalink Normal View History

2022-02-06 19:08:30 -06:00
package database
import (
"database/sql"
"errors"
"fmt"
"strings"
2022-02-06 19:08:30 -06:00
"time"
2023-08-17 00:54:38 +03:00
"go.mau.fi/util/dbutil"
2022-02-06 19:08:30 -06:00
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
2022-05-27 15:58:09 +03:00
type MessageQuery struct {
2022-02-06 19:08:30 -06:00
db *Database
log log.Logger
2022-05-27 15:58:09 +03:00
}
2022-02-06 19:08:30 -06:00
2022-05-27 15:58:09 +03:00
// messageSelect is the shared SELECT prefix for message lookups. The column
// order here is the order in which (*Message).Scan reads the row.
const messageSelect = "SELECT " +
	"dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, " +
	"timestamp, dc_edit_timestamp, dc_thread_id, mxid, sender_mxid " +
	"FROM message"
2022-02-06 19:08:30 -06:00
2022-05-27 15:58:09 +03:00
// New returns an empty Message bound to this query's database handle and
// logger, ready to be populated via Scan or by setting its fields directly.
func (mq *MessageQuery) New() *Message {
	msg := new(Message)
	msg.db, msg.log = mq.db, mq.log
	return msg
}
func (mq *MessageQuery) scanAll(rows dbutil.Rows, err error) []*Message {
if err != nil {
mq.log.Warnfln("Failed to query many messages: %v", err)
panic(err)
} else if rows == nil {
2022-05-27 15:58:09 +03:00
return nil
}
2022-02-06 19:08:30 -06:00
2022-05-27 15:58:09 +03:00
var messages []*Message
for rows.Next() {
messages = append(messages, mq.New().Scan(rows))
}
return messages
}
// GetByDiscordID returns every row stored for the Discord message discordID
// in the given portal, sorted by attachment ID ascending.
func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) []*Message {
	rows, err := mq.db.Query(
		messageSelect+" WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id ASC",
		key.ChannelID, key.Receiver, discordID,
	)
	return mq.scanAll(rows, err)
}
func (mq *MessageQuery) GetFirstByDiscordID(key PortalKey, discordID string) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id ASC LIMIT 1"
2022-05-28 23:03:24 +03:00
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
}
2022-05-27 15:58:09 +03:00
// GetLastByDiscordID returns the row with the highest attachment ID stored for
// the Discord message discordID in the given portal, or nil when Scan reports
// no rows.
func (mq *MessageQuery) GetLastByDiscordID(key PortalKey, discordID string) *Message {
	const query = messageSelect +
		" WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3" +
		" ORDER BY dc_attachment_id DESC LIMIT 1"
	msg := mq.New()
	return msg.Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
}
func (mq *MessageQuery) GetClosestBefore(key PortalKey, threadID string, ts time.Time) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 AND timestamp<=$4 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID, ts.UnixMilli()))
}
2022-05-28 23:03:24 +03:00
func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
2022-05-28 23:03:24 +03:00
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID))
2022-05-27 15:58:09 +03:00
}
// GetLast returns the newest row stored in the given portal, or nil when Scan
// reports no rows.
//
// Several rows can share a timestamp; for example, every part written by
// MassInsertParts copies the parent's timestamp. Such ties are broken by
// descending attachment ID, as in GetClosestBefore and GetLastInThread, so the
// chosen row is deterministic instead of left to the database's whim.
func (mq *MessageQuery) GetLast(key PortalKey) *Message {
	query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
	return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver))
}
2022-07-09 16:51:43 +03:00
// DeleteAll removes every message row belonging to the given portal.
// Database errors are logged and then re-raised as a panic.
func (mq *MessageQuery) DeleteAll(key PortalKey) {
	const query = "DELETE FROM message WHERE dc_chan_id=$1 AND dc_chan_receiver=$2"
	if _, err := mq.db.Exec(query, key.ChannelID, key.Receiver); err != nil {
		mq.log.Warnfln("Failed to delete messages of %s: %v", key, err)
		panic(err)
	}
}
2022-05-27 15:58:09 +03:00
// GetByMXID looks up the row in the given portal whose mxid column equals
// mxid. It returns nil if QueryRow yields no row object or Scan reports no rows.
func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message {
	const query = messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND mxid=$3"
	if row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, mxid); row != nil {
		return mq.New().Scan(row)
	}
	return nil
}
2023-04-16 15:06:02 +03:00
func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) {
if len(msgs) == 0 {
return
}
2023-05-24 13:18:23 +03:00
valueStringFormat := "($%d, $%d, $1, $2, $%d, $%d, $%d, $%d, $%d, $%d)"
2023-04-16 15:06:02 +03:00
if mq.db.Dialect == dbutil.SQLite {
valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
}
2023-05-24 13:18:23 +03:00
params := make([]interface{}, 2+len(msgs)*8)
2023-04-16 15:06:02 +03:00
placeholders := make([]string, len(msgs))
params[0] = key.ChannelID
params[1] = key.Receiver
for i, msg := range msgs {
2024-02-18 23:10:31 +02:00
baseIndex := 2 + i*8
2023-04-16 15:06:02 +03:00
params[baseIndex] = msg.DiscordID
params[baseIndex+1] = msg.AttachmentID
params[baseIndex+2] = msg.SenderID
params[baseIndex+3] = msg.Timestamp.UnixMilli()
params[baseIndex+4] = msg.editTimestampVal()
2023-04-16 15:06:02 +03:00
params[baseIndex+5] = msg.ThreadID
params[baseIndex+6] = msg.MXID
2023-05-24 13:18:23 +03:00
params[baseIndex+7] = msg.SenderMXID.String()
placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7, baseIndex+8)
2023-04-16 15:06:02 +03:00
}
_, err := mq.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...)
if err != nil {
mq.log.Warnfln("Failed to insert %d messages: %v", len(msgs), err)
panic(err)
}
}
2022-05-27 15:58:09 +03:00
type Message struct {
db *Database
log log.Logger
DiscordID string
AttachmentID string
Channel PortalKey
SenderID string
Timestamp time.Time
EditTimestamp time.Time
ThreadID string
2022-05-27 15:58:09 +03:00
2023-05-24 13:18:23 +03:00
MXID id.EventID
SenderMXID id.UserID
2022-02-06 19:08:30 -06:00
}
2022-05-28 23:03:24 +03:00
// DiscordProtoChannelID returns the thread ID when the message has one, and
// otherwise the portal's channel ID.
func (m *Message) DiscordProtoChannelID() string {
	if m.ThreadID == "" {
		return m.Channel.ChannelID
	}
	return m.ThreadID
}
// Scan reads one row, in messageSelect column order, into m and returns m.
//
// When the row is missing (sql.ErrNoRows) it returns nil. Any other scan error
// is logged and re-raised as a panic. The timestamp column holds Unix
// milliseconds and dc_edit_timestamp holds Unix nanoseconds. A stored 0 leaves
// the corresponding field at the zero time.
func (m *Message) Scan(row dbutil.Scannable) *Message {
	var timestampMs, editTimestampNs int64
	if err := row.Scan(
		&m.DiscordID, &m.AttachmentID,
		&m.Channel.ChannelID, &m.Channel.Receiver,
		&m.SenderID, &timestampMs, &editTimestampNs,
		&m.ThreadID, &m.MXID, &m.SenderMXID,
	); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil
		}
		m.log.Errorln("Database scan failed:", err)
		panic(err)
	}
	if timestampMs != 0 {
		m.Timestamp = time.UnixMilli(timestampMs).UTC()
	}
	if editTimestampNs != 0 {
		m.EditTimestamp = time.Unix(0, editTimestampNs).UTC()
	}
	return m
}
const messageInsertQuery = `
INSERT INTO message (
2023-05-24 13:18:23 +03:00
dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid, sender_mxid
)
2023-05-24 13:18:23 +03:00
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
`
2023-05-24 13:18:23 +03:00
var messageMassInsertTemplate = strings.Replace(messageInsertQuery, "($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)", "%s", 1)
// MessagePart carries the per-row values used by (*Message).MassInsertParts:
// an attachment ID and its Matrix event ID. All other columns of the inserted
// row are copied from the parent Message.
type MessagePart struct {
	AttachmentID string
	MXID         id.EventID
}
2022-02-06 19:08:30 -06:00
// editTimestampVal converts EditTimestamp to the integer written to the
// dc_edit_timestamp column: Unix nanoseconds, or 0 when EditTimestamp is the
// zero time.
func (m *Message) editTimestampVal() int64 {
	if ts := m.EditTimestamp; !ts.IsZero() {
		return ts.UnixNano()
	}
	return 0
}
2023-04-16 15:06:02 +03:00
// MassInsertParts inserts one row per entry of msgs in a single statement.
//
// Every row copies m's Discord ID, portal key, sender, timestamps, thread ID
// and sender MXID; only the attachment ID and Matrix event ID vary per part.
// An empty msgs is a no-op. Database errors are logged and re-raised as a
// panic.
func (m *Message) MassInsertParts(msgs []MessagePart) {
	if len(msgs) == 0 {
		return
	}
	// Parameters 1-8 are shared by every row; each part appends two more.
	rowFormat := "($1, $%d, $2, $3, $4, $5, $6, $7, $%d, $8)"
	if m.db.Dialect == dbutil.SQLite {
		rowFormat = strings.ReplaceAll(rowFormat, "$", "?")
	}

	params := make([]interface{}, 0, 8+len(msgs)*2)
	params = append(params,
		m.DiscordID,
		m.Channel.ChannelID,
		m.Channel.Receiver,
		m.SenderID,
		m.Timestamp.UnixMilli(),
		m.editTimestampVal(),
		m.ThreadID,
		m.SenderMXID.String(),
	)
	rowPlaceholders := make([]string, len(msgs))
	for i, part := range msgs {
		// 1-based number of the attachment ID parameter about to be appended.
		attachmentIdx := len(params) + 1
		params = append(params, part.AttachmentID, part.MXID)
		rowPlaceholders[i] = fmt.Sprintf(rowFormat, attachmentIdx, attachmentIdx+1)
	}

	query := fmt.Sprintf(messageMassInsertTemplate, strings.Join(rowPlaceholders, ", "))
	if _, err := m.db.Exec(query, params...); err != nil {
		m.log.Warnfln("Failed to insert %d parts of %s@%s: %v", len(msgs), m.DiscordID, m.Channel, err)
		panic(err)
	}
}
func (m *Message) Insert() {
_, err := m.db.Exec(messageInsertQuery,
m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
2023-05-24 13:18:23 +03:00
m.Timestamp.UnixMilli(), m.editTimestampVal(), m.ThreadID, m.MXID, m.SenderMXID.String())
2022-02-06 19:08:30 -06:00
if err != nil {
2022-05-27 15:58:09 +03:00
m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)
2022-05-28 23:03:24 +03:00
panic(err)
2022-02-06 19:08:30 -06:00
}
}
// editUpdateQuery writes dc_edit_timestamp ($1) on the row identified by
// dcid ($2), dc_attachment_id ($3), dc_chan_id ($4) and dc_chan_receiver ($5).
// It only does so while the stored value is strictly smaller, so an older edit
// time never overwrites a newer one.
const editUpdateQuery = "\n" +
	"UPDATE message\n" +
	"SET dc_edit_timestamp=$1\n" +
	"WHERE dcid=$2 AND dc_attachment_id=$3 AND dc_chan_id=$4 AND dc_chan_receiver=$5 AND dc_edit_timestamp<$1\n"
// UpdateEditTimestamp stores ts, as Unix nanoseconds, as this row's edit time.
// The stored value is left alone when it is already equal to or later than ts.
// Only the database row changes; m.EditTimestamp is not modified. Database
// errors are logged and re-raised as a panic.
func (m *Message) UpdateEditTimestamp(ts time.Time) {
	nanos := ts.UnixNano()
	_, err := m.db.Exec(editUpdateQuery, nanos, m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver)
	if err == nil {
		return
	}
	m.log.Warnfln("Failed to update edit timestamp of %s@%s: %v", m.DiscordID, m.Channel, err)
	panic(err)
}
func (m *Message) Delete() {
query := "DELETE FROM message WHERE dcid=$1 AND dc_chan_id=$2 AND dc_chan_receiver=$3 AND dc_attachment_id=$4"
_, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.AttachmentID)
2022-02-06 19:08:30 -06:00
if err != nil {
m.log.Warnfln("Failed to delete %q of %s@%s: %v", m.AttachmentID, m.DiscordID, m.Channel, err)
2022-05-28 23:03:24 +03:00
panic(err)
2022-02-06 19:08:30 -06:00
}
}