2022-02-06 19:08:30 -06:00
package database
import (
"database/sql"
"errors"
2022-06-27 10:53:49 +03:00
"fmt"
"strings"
2022-02-06 19:08:30 -06:00
"time"
2023-08-17 00:54:38 +03:00
"go.mau.fi/util/dbutil"
2022-02-06 19:08:30 -06:00
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
2022-05-27 15:58:09 +03:00
type MessageQuery struct {
2022-02-06 19:08:30 -06:00
db * Database
log log . Logger
2022-05-27 15:58:09 +03:00
}
2022-02-06 19:08:30 -06:00
2022-05-27 15:58:09 +03:00
// messageSelect is the SELECT prefix shared by every message lookup. Callers
// append their own WHERE/ORDER BY clause. The column order here has to match
// the destination order used by Message.Scan.
const messageSelect = "SELECT dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, " +
	"timestamp, dc_edit_timestamp, dc_thread_id, mxid, sender_mxid FROM message"
2022-02-06 19:08:30 -06:00
2022-05-27 15:58:09 +03:00
// New returns an empty Message that shares this query's database handle and
// logger, ready to be filled by Scan or populated by the caller.
func (mq *MessageQuery) New() *Message {
	msg := new(Message)
	msg.db = mq.db
	msg.log = mq.log
	return msg
}
2022-10-28 23:45:35 +03:00
func ( mq * MessageQuery ) scanAll ( rows dbutil . Rows , err error ) [ ] * Message {
2022-06-27 10:53:49 +03:00
if err != nil {
mq . log . Warnfln ( "Failed to query many messages: %v" , err )
panic ( err )
} else if rows == nil {
2022-05-27 15:58:09 +03:00
return nil
}
2022-02-06 19:08:30 -06:00
2022-05-27 15:58:09 +03:00
var messages [ ] * Message
for rows . Next ( ) {
messages = append ( messages , mq . New ( ) . Scan ( rows ) )
}
return messages
}
2022-06-27 10:53:49 +03:00
func ( mq * MessageQuery ) GetByDiscordID ( key PortalKey , discordID string ) [ ] * Message {
2023-05-06 22:10:00 +03:00
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id ASC"
2022-06-27 10:53:49 +03:00
return mq . scanAll ( mq . db . Query ( query , key . ChannelID , key . Receiver , discordID ) )
}
func ( mq * MessageQuery ) GetFirstByDiscordID ( key PortalKey , discordID string ) * Message {
2023-05-06 22:10:00 +03:00
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id ASC LIMIT 1"
2022-05-28 23:03:24 +03:00
return mq . New ( ) . Scan ( mq . db . QueryRow ( query , key . ChannelID , key . Receiver , discordID ) )
}
2022-05-27 15:58:09 +03:00
2022-07-08 11:54:09 +03:00
func ( mq * MessageQuery ) GetLastByDiscordID ( key PortalKey , discordID string ) * Message {
2023-05-06 22:10:00 +03:00
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id DESC LIMIT 1"
2022-07-08 11:54:09 +03:00
return mq . New ( ) . Scan ( mq . db . QueryRow ( query , key . ChannelID , key . Receiver , discordID ) )
}
2022-10-28 23:35:31 +03:00
func ( mq * MessageQuery ) GetClosestBefore ( key PortalKey , threadID string , ts time . Time ) * Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 AND timestamp<=$4 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
return mq . New ( ) . Scan ( mq . db . QueryRow ( query , key . ChannelID , key . Receiver , threadID , ts . UnixMilli ( ) ) )
2022-07-08 11:54:09 +03:00
}
2022-05-28 23:03:24 +03:00
func ( mq * MessageQuery ) GetLastInThread ( key PortalKey , threadID string ) * Message {
2023-05-06 22:10:00 +03:00
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
2022-05-28 23:03:24 +03:00
return mq . New ( ) . Scan ( mq . db . QueryRow ( query , key . ChannelID , key . Receiver , threadID ) )
2022-05-27 15:58:09 +03:00
}
2023-04-14 18:48:35 +03:00
func ( mq * MessageQuery ) GetLast ( key PortalKey ) * Message {
2023-05-06 22:10:00 +03:00
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 ORDER BY timestamp DESC LIMIT 1"
2023-04-14 18:48:35 +03:00
return mq . New ( ) . Scan ( mq . db . QueryRow ( query , key . ChannelID , key . Receiver ) )
}
2022-07-09 16:51:43 +03:00
// DeleteAll removes every message row belonging to the given portal. A
// database error is logged and then panics.
func (mq *MessageQuery) DeleteAll(key PortalKey) {
	const stmt = "DELETE FROM message WHERE dc_chan_id=$1 AND dc_chan_receiver=$2"
	if _, err := mq.db.Exec(stmt, key.ChannelID, key.Receiver); err != nil {
		mq.log.Warnfln("Failed to delete messages of %s: %v", key, err)
		panic(err)
	}
}
2022-05-27 15:58:09 +03:00
// GetByMXID looks up the row in the given portal whose Matrix event ID equals
// mxid. It returns nil when the query yields no row object or no matching row.
func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message {
	const filter = " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND mxid=$3"
	if row := mq.db.QueryRow(messageSelect+filter, key.ChannelID, key.Receiver, mxid); row != nil {
		return mq.New().Scan(row)
	}
	return nil
}
2023-04-16 15:06:02 +03:00
func ( mq * MessageQuery ) MassInsert ( key PortalKey , msgs [ ] Message ) {
if len ( msgs ) == 0 {
return
}
2023-05-24 13:18:23 +03:00
valueStringFormat := "($%d, $%d, $1, $2, $%d, $%d, $%d, $%d, $%d, $%d)"
2023-04-16 15:06:02 +03:00
if mq . db . Dialect == dbutil . SQLite {
valueStringFormat = strings . ReplaceAll ( valueStringFormat , "$" , "?" )
}
2023-05-24 13:18:23 +03:00
params := make ( [ ] interface { } , 2 + len ( msgs ) * 8 )
2023-04-16 15:06:02 +03:00
placeholders := make ( [ ] string , len ( msgs ) )
params [ 0 ] = key . ChannelID
params [ 1 ] = key . Receiver
for i , msg := range msgs {
2024-02-18 23:10:31 +02:00
baseIndex := 2 + i * 8
2023-04-16 15:06:02 +03:00
params [ baseIndex ] = msg . DiscordID
params [ baseIndex + 1 ] = msg . AttachmentID
2023-05-06 22:10:00 +03:00
params [ baseIndex + 2 ] = msg . SenderID
params [ baseIndex + 3 ] = msg . Timestamp . UnixMilli ( )
params [ baseIndex + 4 ] = msg . editTimestampVal ( )
2023-04-16 15:06:02 +03:00
params [ baseIndex + 5 ] = msg . ThreadID
params [ baseIndex + 6 ] = msg . MXID
2023-05-24 13:18:23 +03:00
params [ baseIndex + 7 ] = msg . SenderMXID . String ( )
placeholders [ i ] = fmt . Sprintf ( valueStringFormat , baseIndex + 1 , baseIndex + 2 , baseIndex + 3 , baseIndex + 4 , baseIndex + 5 , baseIndex + 6 , baseIndex + 7 , baseIndex + 8 )
2023-04-16 15:06:02 +03:00
}
_ , err := mq . db . Exec ( fmt . Sprintf ( messageMassInsertTemplate , strings . Join ( placeholders , ", " ) ) , params ... )
if err != nil {
mq . log . Warnfln ( "Failed to insert %d messages: %v" , len ( msgs ) , err )
panic ( err )
}
}
2022-05-27 15:58:09 +03:00
type Message struct {
db * Database
log log . Logger
2023-05-06 22:10:00 +03:00
DiscordID string
AttachmentID string
Channel PortalKey
SenderID string
Timestamp time . Time
EditTimestamp time . Time
ThreadID string
2022-05-27 15:58:09 +03:00
2023-05-24 13:18:23 +03:00
MXID id . EventID
SenderMXID id . UserID
2022-02-06 19:08:30 -06:00
}
2022-05-28 23:03:24 +03:00
// DiscordProtoChannelID returns the thread ID when the message belongs to a
// thread, and the portal's channel ID otherwise.
func (m *Message) DiscordProtoChannelID() string {
	if m.ThreadID == "" {
		return m.Channel.ChannelID
	}
	return m.ThreadID
}
2022-05-22 22:16:42 +03:00
func ( m * Message ) Scan ( row dbutil . Scannable ) * Message {
2023-05-06 22:10:00 +03:00
var ts , editTS int64
2022-02-06 19:08:30 -06:00
2023-05-24 13:18:23 +03:00
err := row . Scan ( & m . DiscordID , & m . AttachmentID , & m . Channel . ChannelID , & m . Channel . Receiver , & m . SenderID , & ts , & editTS , & m . ThreadID , & m . MXID , & m . SenderMXID )
2022-02-06 19:08:30 -06:00
if err != nil {
if ! errors . Is ( err , sql . ErrNoRows ) {
m . log . Errorln ( "Database scan failed:" , err )
2022-05-28 23:03:24 +03:00
panic ( err )
2022-02-06 19:08:30 -06:00
}
return nil
}
if ts != 0 {
2023-05-06 22:10:00 +03:00
m . Timestamp = time . UnixMilli ( ts ) . UTC ( )
}
if editTS != 0 {
m . EditTimestamp = time . Unix ( 0 , editTS ) . UTC ( )
2022-02-06 19:08:30 -06:00
}
return m
}
2022-06-27 10:53:49 +03:00
// messageInsertQuery inserts a single row into the message table. The
// placeholders are positional ($1..$10) and follow the column list order.
const messageInsertQuery = `
	INSERT INTO message (
		dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid, sender_mxid
	)
	VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
`

// messageMassInsertTemplate is messageInsertQuery with its VALUES tuple
// replaced by a single %s verb, so that bulk inserts can substitute a
// comma-separated list of tuples via fmt.Sprintf. The search string must stay
// textually identical to the tuple in messageInsertQuery, otherwise the
// replacement silently does nothing.
var messageMassInsertTemplate = strings.Replace(messageInsertQuery, "($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)", "%s", 1)
2022-06-27 10:53:49 +03:00
// MessagePart carries the two values that differ between the rows written by
// Message.MassInsertParts; every other column is taken from the Message.
type MessagePart struct {
	// AttachmentID is stored in the dc_attachment_id column.
	AttachmentID string
	// MXID is stored in the mxid column.
	MXID id.EventID
}
2022-02-06 19:08:30 -06:00
2023-05-06 22:10:00 +03:00
// editTimestampVal returns the value stored in dc_edit_timestamp: the edit
// time in Unix nanoseconds, or 0 when EditTimestamp is the zero time.
func (m *Message) editTimestampVal() int64 {
	if edited := m.EditTimestamp; !edited.IsZero() {
		return edited.UnixNano()
	}
	return 0
}
2023-04-16 15:06:02 +03:00
func ( m * Message ) MassInsertParts ( msgs [ ] MessagePart ) {
2022-07-02 23:18:49 +03:00
if len ( msgs ) == 0 {
return
}
2023-05-24 13:18:23 +03:00
valueStringFormat := "($1, $%d, $2, $3, $4, $5, $6, $7, $%d, $8)"
2022-06-27 10:53:49 +03:00
if m . db . Dialect == dbutil . SQLite {
valueStringFormat = strings . ReplaceAll ( valueStringFormat , "$" , "?" )
}
2023-05-24 13:18:23 +03:00
params := make ( [ ] interface { } , 8 + len ( msgs ) * 2 )
2022-06-27 10:53:49 +03:00
placeholders := make ( [ ] string , len ( msgs ) )
params [ 0 ] = m . DiscordID
2023-05-06 22:10:00 +03:00
params [ 1 ] = m . Channel . ChannelID
params [ 2 ] = m . Channel . Receiver
params [ 3 ] = m . SenderID
params [ 4 ] = m . Timestamp . UnixMilli ( )
params [ 5 ] = m . editTimestampVal ( )
2022-06-27 10:53:49 +03:00
params [ 6 ] = m . ThreadID
2023-05-24 13:18:23 +03:00
params [ 7 ] = m . SenderMXID . String ( )
2022-06-27 10:53:49 +03:00
for i , msg := range msgs {
2023-05-24 13:18:23 +03:00
params [ 8 + i * 2 ] = msg . AttachmentID
params [ 8 + i * 2 + 1 ] = msg . MXID
placeholders [ i ] = fmt . Sprintf ( valueStringFormat , 8 + i * 2 + 1 , 8 + i * 2 + 2 )
2022-06-27 10:53:49 +03:00
}
_ , err := m . db . Exec ( fmt . Sprintf ( messageMassInsertTemplate , strings . Join ( placeholders , ", " ) ) , params ... )
if err != nil {
m . log . Warnfln ( "Failed to insert %d parts of %s@%s: %v" , len ( msgs ) , m . DiscordID , m . Channel , err )
panic ( err )
}
}
func ( m * Message ) Insert ( ) {
_ , err := m . db . Exec ( messageInsertQuery ,
2023-05-06 22:10:00 +03:00
m . DiscordID , m . AttachmentID , m . Channel . ChannelID , m . Channel . Receiver , m . SenderID ,
2023-05-24 13:18:23 +03:00
m . Timestamp . UnixMilli ( ) , m . editTimestampVal ( ) , m . ThreadID , m . MXID , m . SenderMXID . String ( ) )
2022-02-06 19:08:30 -06:00
if err != nil {
2022-05-27 15:58:09 +03:00
m . log . Warnfln ( "Failed to insert %s@%s: %v" , m . DiscordID , m . Channel , err )
2022-05-28 23:03:24 +03:00
panic ( err )
2022-02-06 19:08:30 -06:00
}
}
2023-05-06 22:10:00 +03:00
// editUpdateQuery sets the edit timestamp ($1) on the row identified by
// Discord message ID ($2), attachment ID ($3), channel ID ($4) and receiver
// ($5). The trailing condition means a row is only touched when the new
// timestamp is strictly greater than the stored one.
const editUpdateQuery = `
	UPDATE message
	SET dc_edit_timestamp=$1
	WHERE dcid=$2 AND dc_attachment_id=$3 AND dc_chan_id=$4 AND dc_chan_receiver=$5 AND dc_edit_timestamp<$1
`
// UpdateEditTimestamp runs editUpdateQuery for this row with ts converted to
// Unix nanoseconds. It does not modify m.EditTimestamp. A database error is
// logged and then panics.
func (m *Message) UpdateEditTimestamp(ts time.Time) {
	editedNanos := ts.UnixNano()
	if _, err := m.db.Exec(editUpdateQuery, editedNanos, m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver); err != nil {
		m.log.Warnfln("Failed to update edit timestamp of %s@%s: %v", m.DiscordID, m.Channel, err)
		panic(err)
	}
}
2022-02-08 03:51:29 -06:00
func ( m * Message ) Delete ( ) {
2022-06-27 16:13:26 +03:00
query := "DELETE FROM message WHERE dcid=$1 AND dc_chan_id=$2 AND dc_chan_receiver=$3 AND dc_attachment_id=$4"
_ , err := m . db . Exec ( query , m . DiscordID , m . Channel . ChannelID , m . Channel . Receiver , m . AttachmentID )
2022-02-06 19:08:30 -06:00
if err != nil {
2022-06-27 16:13:26 +03:00
m . log . Warnfln ( "Failed to delete %q of %s@%s: %v" , m . AttachmentID , m . DiscordID , m . Channel , err )
2022-05-28 23:03:24 +03:00
panic ( err )
2022-02-06 19:08:30 -06:00
}
}