mirror of
https://github.com/mautrix/whatsapp.git
synced 2025-03-14 14:15:38 +00:00
Move a bunch of stuff to mautrix-go
See d578d1a610
Database upgrades from before v0.4.0 were squashed, users must update
to at least v0.4.0 before updating beyond this commit.
This commit is contained in:
parent
42a4839a4e
commit
a948ea0146
83 changed files with 627 additions and 2838 deletions
|
@ -114,18 +114,18 @@ func (pong *BridgeState) shouldDeduplicate(newPong *BridgeState) bool {
|
|||
return pong.Timestamp+int64(pong.TTL/5) > time.Now().Unix()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) sendBridgeState(ctx context.Context, state *BridgeState) error {
|
||||
func (br *WABridge) sendBridgeState(ctx context.Context, state *BridgeState) error {
|
||||
var body bytes.Buffer
|
||||
if err := json.NewEncoder(&body).Encode(&state); err != nil {
|
||||
return fmt.Errorf("failed to encode bridge state JSON: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, bridge.Config.Homeserver.StatusEndpoint, &body)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, br.Config.Homeserver.StatusEndpoint, &body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+bridge.Config.AppService.ASToken)
|
||||
req.Header.Set("Authorization", "Bearer "+br.Config.AppService.ASToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
|
@ -143,17 +143,17 @@ func (bridge *Bridge) sendBridgeState(ctx context.Context, state *BridgeState) e
|
|||
return nil
|
||||
}
|
||||
|
||||
func (bridge *Bridge) sendGlobalBridgeState(state BridgeState) {
|
||||
if len(bridge.Config.Homeserver.StatusEndpoint) == 0 {
|
||||
func (br *WABridge) sendGlobalBridgeState(state BridgeState) {
|
||||
if len(br.Config.Homeserver.StatusEndpoint) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := bridge.sendBridgeState(ctx, &state); err != nil {
|
||||
bridge.Log.Warnln("Failed to update global bridge state:", err)
|
||||
if err := br.sendBridgeState(ctx, &state); err != nil {
|
||||
br.Log.Warnln("Failed to update global bridge state:", err)
|
||||
} else {
|
||||
bridge.Log.Debugfln("Sent new global bridge state %+v", state)
|
||||
br.Log.Debugfln("Sent new global bridge state %+v", state)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
19
commands.go
19
commands.go
|
@ -39,6 +39,7 @@ import (
|
|||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/bridge"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/format"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
@ -47,12 +48,12 @@ import (
|
|||
)
|
||||
|
||||
type CommandHandler struct {
|
||||
bridge *Bridge
|
||||
bridge *WABridge
|
||||
log maulogger.Logger
|
||||
}
|
||||
|
||||
// NewCommandHandler creates a CommandHandler
|
||||
func NewCommandHandler(bridge *Bridge) *CommandHandler {
|
||||
func NewCommandHandler(bridge *WABridge) *CommandHandler {
|
||||
return &CommandHandler{
|
||||
bridge: bridge,
|
||||
log: bridge.Log.Sub("Command handler"),
|
||||
|
@ -62,7 +63,7 @@ func NewCommandHandler(bridge *Bridge) *CommandHandler {
|
|||
// CommandEvent stores all data which might be used to handle commands
|
||||
type CommandEvent struct {
|
||||
Bot *appservice.IntentAPI
|
||||
Bridge *Bridge
|
||||
Bridge *WABridge
|
||||
Portal *Portal
|
||||
Handler *CommandHandler
|
||||
RoomID id.RoomID
|
||||
|
@ -251,13 +252,7 @@ func (handler *CommandHandler) CommandDevTest(_ *CommandEvent) {
|
|||
const cmdVersionHelp = `version - View the bridge version`
|
||||
|
||||
func (handler *CommandHandler) CommandVersion(ce *CommandEvent) {
|
||||
linkifiedVersion := fmt.Sprintf("v%s", Version)
|
||||
if Tag == Version {
|
||||
linkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", Version, URL, Tag)
|
||||
} else if len(Commit) > 8 {
|
||||
linkifiedVersion = strings.Replace(linkifiedVersion, Commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", Commit[:8], URL, Commit), 1)
|
||||
}
|
||||
ce.Reply(fmt.Sprintf("[%s](%s) %s (%s)", Name, URL, linkifiedVersion, BuildTime))
|
||||
ce.Reply(fmt.Sprintf("[%s](%s) %s (%s)", ce.Bridge.Name, ce.Bridge.URL, ce.Bridge.LinkifiedVersion, BuildTime))
|
||||
}
|
||||
|
||||
const cmdInviteLinkHelp = `invite-link [--reset] - Get an invite link to the current group chat, optionally regenerating the link and revoking the old link.`
|
||||
|
@ -331,7 +326,7 @@ func (handler *CommandHandler) CommandJoin(ce *CommandEvent) {
|
|||
ce.Reply("Successfully joined group `%s`, the portal should be created momentarily", jid)
|
||||
}
|
||||
|
||||
func tryDecryptEvent(crypto Crypto, evt *event.Event) (json.RawMessage, error) {
|
||||
func tryDecryptEvent(crypto bridge.Crypto, evt *event.Event) (json.RawMessage, error) {
|
||||
var data json.RawMessage
|
||||
if evt.Type != event.EventEncrypted {
|
||||
data = evt.Content.VeryRaw
|
||||
|
@ -903,7 +898,7 @@ func matchesQuery(str string, query string) bool {
|
|||
return strings.Contains(strings.ToLower(str), query)
|
||||
}
|
||||
|
||||
func formatContacts(bridge *Bridge, input map[types.JID]types.ContactInfo, query string) (result []string) {
|
||||
func formatContacts(bridge *WABridge, input map[types.JID]types.ContactInfo, query string) (result []string) {
|
||||
hasQuery := len(query) > 0
|
||||
for jid, contact := range input {
|
||||
if len(contact.FullName) == 0 {
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
|
||||
"maunium.net/go/mautrix/bridge/bridgeconfig"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
@ -118,23 +119,23 @@ type BridgeConfig struct {
|
|||
AdditionalHelp string `yaml:"additional_help"`
|
||||
} `yaml:"management_room_text"`
|
||||
|
||||
Encryption struct {
|
||||
Allow bool `yaml:"allow"`
|
||||
Default bool `yaml:"default"`
|
||||
Encryption bridgeconfig.EncryptionConfig `yaml:"encryption"`
|
||||
|
||||
KeySharing struct {
|
||||
Allow bool `yaml:"allow"`
|
||||
RequireCrossSigning bool `yaml:"require_cross_signing"`
|
||||
RequireVerification bool `yaml:"require_verification"`
|
||||
} `yaml:"key_sharing"`
|
||||
} `yaml:"encryption"`
|
||||
Provisioning struct {
|
||||
Prefix string `yaml:"prefix"`
|
||||
SharedSecret string `yaml:"shared_secret"`
|
||||
} `yaml:"provisioning"`
|
||||
|
||||
Permissions PermissionConfig `yaml:"permissions"`
|
||||
|
||||
Relay RelaybotConfig `yaml:"relay"`
|
||||
|
||||
usernameTemplate *template.Template `yaml:"-"`
|
||||
displaynameTemplate *template.Template `yaml:"-"`
|
||||
ParsedUsernameTemplate *template.Template `yaml:"-"`
|
||||
displaynameTemplate *template.Template `yaml:"-"`
|
||||
}
|
||||
|
||||
func (bc BridgeConfig) GetEncryptionConfig() bridgeconfig.EncryptionConfig {
|
||||
return bc.Encryption
|
||||
}
|
||||
|
||||
type umBridgeConfig BridgeConfig
|
||||
|
@ -145,7 +146,7 @@ func (bc *BridgeConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
|||
return err
|
||||
}
|
||||
|
||||
bc.usernameTemplate, err = template.New("username").Parse(bc.UsernameTemplate)
|
||||
bc.ParsedUsernameTemplate, err = template.New("username").Parse(bc.UsernameTemplate)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !strings.Contains(bc.FormatUsername("1234567890"), "1234567890") {
|
||||
|
@ -206,7 +207,7 @@ func (bc BridgeConfig) FormatDisplayname(jid types.JID, contact types.ContactInf
|
|||
|
||||
func (bc BridgeConfig) FormatUsername(username string) string {
|
||||
var buf strings.Builder
|
||||
_ = bc.usernameTemplate.Execute(&buf, username)
|
||||
_ = bc.ParsedUsernameTemplate.Execute(&buf, username)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
|
|
|
@ -17,52 +17,12 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/bridge/bridgeconfig"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
var ExampleConfig string
|
||||
|
||||
type Config struct {
|
||||
Homeserver struct {
|
||||
Address string `yaml:"address"`
|
||||
Domain string `yaml:"domain"`
|
||||
Asmux bool `yaml:"asmux"`
|
||||
StatusEndpoint string `yaml:"status_endpoint"`
|
||||
MessageSendCheckpointEndpoint string `yaml:"message_send_checkpoint_endpoint"`
|
||||
AsyncMedia bool `yaml:"async_media"`
|
||||
} `yaml:"homeserver"`
|
||||
|
||||
AppService struct {
|
||||
Address string `yaml:"address"`
|
||||
Hostname string `yaml:"hostname"`
|
||||
Port uint16 `yaml:"port"`
|
||||
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
|
||||
Provisioning struct {
|
||||
Prefix string `yaml:"prefix"`
|
||||
SharedSecret string `yaml:"shared_secret"`
|
||||
} `yaml:"provisioning"`
|
||||
|
||||
ID string `yaml:"id"`
|
||||
Bot struct {
|
||||
Username string `yaml:"username"`
|
||||
Displayname string `yaml:"displayname"`
|
||||
Avatar string `yaml:"avatar"`
|
||||
|
||||
ParsedAvatar id.ContentURI `yaml:"-"`
|
||||
} `yaml:"bot"`
|
||||
|
||||
EphemeralEvents bool `yaml:"ephemeral_events"`
|
||||
|
||||
ASToken string `yaml:"as_token"`
|
||||
HSToken string `yaml:"hs_token"`
|
||||
} `yaml:"appservice"`
|
||||
*bridgeconfig.BaseConfig `yaml:",inline"`
|
||||
|
||||
SegmentKey string `yaml:"segment_key"`
|
||||
|
||||
|
@ -77,8 +37,6 @@ type Config struct {
|
|||
} `yaml:"whatsapp"`
|
||||
|
||||
Bridge BridgeConfig `yaml:"bridge"`
|
||||
|
||||
Logging appservice.LogConfig `yaml:"logging"`
|
||||
}
|
||||
|
||||
func (config *Config) CanAutoDoublePuppet(userID id.UserID) bool {
|
||||
|
@ -98,44 +56,3 @@ func (config *Config) CanDoublePuppetBackfill(userID id.UserID) bool {
|
|||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func Load(data []byte, upgraded bool) (*Config, error) {
|
||||
var config = &Config{}
|
||||
if !upgraded {
|
||||
// Fallback: if config upgrading failed, load example config for base values
|
||||
err := yaml.Unmarshal([]byte(ExampleConfig), config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal example config: %w", err)
|
||||
}
|
||||
}
|
||||
err := yaml.Unmarshal(data, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, err
|
||||
}
|
||||
|
||||
func (config *Config) MakeAppService() (*appservice.AppService, error) {
|
||||
as := appservice.Create()
|
||||
as.HomeserverDomain = config.Homeserver.Domain
|
||||
as.HomeserverURL = config.Homeserver.Address
|
||||
as.Host.Hostname = config.AppService.Hostname
|
||||
as.Host.Port = config.AppService.Port
|
||||
as.MessageSendCheckpointEndpoint = config.Homeserver.MessageSendCheckpointEndpoint
|
||||
as.DefaultHTTPRetries = 4
|
||||
var err error
|
||||
as.Registration, err = config.GetRegistration()
|
||||
return as, err
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Type string `yaml:"type"`
|
||||
URI string `yaml:"uri"`
|
||||
|
||||
MaxOpenConns int `yaml:"max_open_conns"`
|
||||
MaxIdleConns int `yaml:"max_idle_conns"`
|
||||
|
||||
ConnMaxIdleTime string `yaml:"conn_max_idle_time"`
|
||||
ConnMaxLifetime string `yaml:"conn_max_lifetime"`
|
||||
}
|
||||
|
|
|
@ -1,82 +0,0 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2019 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
)
|
||||
|
||||
func (config *Config) NewRegistration() (*appservice.Registration, error) {
|
||||
registration := appservice.CreateRegistration()
|
||||
|
||||
err := config.copyToRegistration(registration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.AppService.ASToken = registration.AppToken
|
||||
config.AppService.HSToken = registration.ServerToken
|
||||
|
||||
// Workaround for https://github.com/matrix-org/synapse/pull/5758
|
||||
registration.SenderLocalpart = appservice.RandomString(32)
|
||||
botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$",
|
||||
regexp.QuoteMeta(config.AppService.Bot.Username),
|
||||
regexp.QuoteMeta(config.Homeserver.Domain)))
|
||||
registration.Namespaces.RegisterUserIDs(botRegex, true)
|
||||
|
||||
return registration, nil
|
||||
}
|
||||
|
||||
func (config *Config) GetRegistration() (*appservice.Registration, error) {
|
||||
registration := appservice.CreateRegistration()
|
||||
|
||||
err := config.copyToRegistration(registration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
registration.AppToken = config.AppService.ASToken
|
||||
registration.ServerToken = config.AppService.HSToken
|
||||
return registration, nil
|
||||
}
|
||||
|
||||
func (config *Config) copyToRegistration(registration *appservice.Registration) error {
|
||||
registration.ID = config.AppService.ID
|
||||
registration.URL = config.AppService.Address
|
||||
falseVal := false
|
||||
registration.RateLimited = &falseVal
|
||||
registration.SenderLocalpart = config.AppService.Bot.Username
|
||||
registration.EphemeralEvents = config.AppService.EphemeralEvents
|
||||
|
||||
usernamePlaceholder := appservice.RandomString(16)
|
||||
usernameTemplate := fmt.Sprintf("@%s:%s",
|
||||
config.Bridge.FormatUsername(usernamePlaceholder),
|
||||
config.Homeserver.Domain)
|
||||
usernameTemplate = regexp.QuoteMeta(usernameTemplate)
|
||||
usernameTemplate = strings.Replace(usernameTemplate, usernamePlaceholder, "[0-9]+", 1)
|
||||
usernameTemplate = fmt.Sprintf("^%s$", usernameTemplate)
|
||||
userIDRegex, err := regexp.Compile(usernameTemplate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
registration.Namespaces.RegisterUserIDs(userIDRegex, true)
|
||||
return nil
|
||||
}
|
|
@ -20,50 +20,12 @@ import (
|
|||
"strings"
|
||||
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/bridge/bridgeconfig"
|
||||
up "maunium.net/go/mautrix/util/configupgrade"
|
||||
)
|
||||
|
||||
type waUpgrader struct{}
|
||||
|
||||
func (wau waUpgrader) GetBase() string {
|
||||
return ExampleConfig
|
||||
}
|
||||
|
||||
func (wau waUpgrader) DoUpgrade(helper *up.Helper) {
|
||||
helper.Copy(up.Str, "homeserver", "address")
|
||||
helper.Copy(up.Str, "homeserver", "domain")
|
||||
helper.Copy(up.Bool, "homeserver", "asmux")
|
||||
helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint")
|
||||
helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint")
|
||||
helper.Copy(up.Bool, "homeserver", "async_media")
|
||||
|
||||
helper.Copy(up.Str, "appservice", "address")
|
||||
helper.Copy(up.Str, "appservice", "hostname")
|
||||
helper.Copy(up.Int, "appservice", "port")
|
||||
helper.Copy(up.Str, "appservice", "database", "type")
|
||||
helper.Copy(up.Str, "appservice", "database", "uri")
|
||||
helper.Copy(up.Int, "appservice", "database", "max_open_conns")
|
||||
helper.Copy(up.Int, "appservice", "database", "max_idle_conns")
|
||||
helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_idle_time")
|
||||
helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_lifetime")
|
||||
if prefix, ok := helper.Get(up.Str, "appservice", "provisioning", "prefix"); ok && strings.HasSuffix(prefix, "/v1") {
|
||||
helper.Set(up.Str, strings.TrimSuffix(prefix, "/v1"), "appservice", "provisioning", "prefix")
|
||||
} else {
|
||||
helper.Copy(up.Str, "appservice", "provisioning", "prefix")
|
||||
}
|
||||
if secret, ok := helper.Get(up.Str, "appservice", "provisioning", "shared_secret"); !ok || secret == "generate" {
|
||||
sharedSecret := appservice.RandomString(64)
|
||||
helper.Set(up.Str, sharedSecret, "appservice", "provisioning", "shared_secret")
|
||||
} else {
|
||||
helper.Copy(up.Str, "appservice", "provisioning", "shared_secret")
|
||||
}
|
||||
helper.Copy(up.Str, "appservice", "id")
|
||||
helper.Copy(up.Str, "appservice", "bot", "username")
|
||||
helper.Copy(up.Str, "appservice", "bot", "displayname")
|
||||
helper.Copy(up.Str, "appservice", "bot", "avatar")
|
||||
helper.Copy(up.Bool, "appservice", "ephemeral_events")
|
||||
helper.Copy(up.Str, "appservice", "as_token")
|
||||
helper.Copy(up.Str, "appservice", "hs_token")
|
||||
func DoUpgrade(helper *up.Helper) {
|
||||
bridgeconfig.Upgrader.DoUpgrade(helper)
|
||||
|
||||
helper.Copy(up.Str|up.Null, "segment_key")
|
||||
|
||||
|
@ -134,46 +96,41 @@ func (wau waUpgrader) DoUpgrade(helper *up.Helper) {
|
|||
helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "allow")
|
||||
helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "require_cross_signing")
|
||||
helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "require_verification")
|
||||
if prefix, ok := helper.Get(up.Str, "appservice", "provisioning", "prefix"); ok {
|
||||
helper.Set(up.Str, strings.TrimSuffix(prefix, "/v1"), "bridge", "provisioning", "prefix")
|
||||
} else {
|
||||
helper.Copy(up.Str, "bridge", "provisioning", "prefix")
|
||||
}
|
||||
if secret, ok := helper.Get(up.Str, "appservice", "provisioning", "shared_secret"); ok && secret != "generate" {
|
||||
helper.Set(up.Str, secret, "bridge", "provisioning", "shared_secret")
|
||||
} else if secret, ok = helper.Get(up.Str, "bridge", "provisioning", "shared_secret"); !ok || secret == "generate" {
|
||||
sharedSecret := appservice.RandomString(64)
|
||||
helper.Set(up.Str, sharedSecret, "bridge", "provisioning", "shared_secret")
|
||||
} else {
|
||||
helper.Copy(up.Str, "bridge", "provisioning", "shared_secret")
|
||||
}
|
||||
helper.Copy(up.Map, "bridge", "permissions")
|
||||
helper.Copy(up.Bool, "bridge", "relay", "enabled")
|
||||
helper.Copy(up.Bool, "bridge", "relay", "admin_only")
|
||||
helper.Copy(up.Map, "bridge", "relay", "message_formats")
|
||||
|
||||
helper.Copy(up.Str, "logging", "directory")
|
||||
helper.Copy(up.Str|up.Null, "logging", "file_name_format")
|
||||
helper.Copy(up.Str|up.Timestamp, "logging", "file_date_format")
|
||||
helper.Copy(up.Int, "logging", "file_mode")
|
||||
helper.Copy(up.Str|up.Timestamp, "logging", "timestamp_format")
|
||||
helper.Copy(up.Str, "logging", "print_level")
|
||||
}
|
||||
|
||||
func (wau waUpgrader) SpacedBlocks() [][]string {
|
||||
return [][]string{
|
||||
{"homeserver", "asmux"},
|
||||
{"appservice"},
|
||||
{"appservice", "hostname"},
|
||||
{"appservice", "database"},
|
||||
{"appservice", "provisioning"},
|
||||
{"appservice", "id"},
|
||||
{"appservice", "as_token"},
|
||||
{"segment_key"},
|
||||
{"metrics"},
|
||||
{"whatsapp"},
|
||||
{"bridge"},
|
||||
{"bridge", "command_prefix"},
|
||||
{"bridge", "management_room_text"},
|
||||
{"bridge", "encryption"},
|
||||
{"bridge", "permissions"},
|
||||
{"bridge", "relay"},
|
||||
{"logging"},
|
||||
}
|
||||
}
|
||||
|
||||
func Mutate(path string, mutate func(helper *up.Helper)) error {
|
||||
_, _, err := up.Do(path, true, waUpgrader{}, up.SimpleUpgrader(mutate))
|
||||
return err
|
||||
}
|
||||
|
||||
func Upgrade(path string, save bool) ([]byte, bool, error) {
|
||||
return up.Do(path, save, waUpgrader{})
|
||||
var SpacedBlocks = [][]string{
|
||||
{"homeserver", "asmux"},
|
||||
{"appservice"},
|
||||
{"appservice", "hostname"},
|
||||
{"appservice", "database"},
|
||||
{"appservice", "id"},
|
||||
{"appservice", "as_token"},
|
||||
{"segment_key"},
|
||||
{"metrics"},
|
||||
{"whatsapp"},
|
||||
{"bridge"},
|
||||
{"bridge", "command_prefix"},
|
||||
{"bridge", "management_room_text"},
|
||||
{"bridge", "encryption"},
|
||||
{"bridge", "provisioning"},
|
||||
{"bridge", "permissions"},
|
||||
{"bridge", "relay"},
|
||||
{"logging"},
|
||||
}
|
||||
|
|
327
crypto.go
327
crypto.go
|
@ -1,327 +0,0 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2020 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
//go:build cgo && !nocrypto
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"maunium.net/go/mautrix-whatsapp/database"
|
||||
)
|
||||
|
||||
var NoSessionFound = crypto.NoSessionFound
|
||||
|
||||
var levelTrace = maulogger.Level{
|
||||
Name: "TRACE",
|
||||
Severity: -10,
|
||||
Color: -1,
|
||||
}
|
||||
|
||||
type CryptoHelper struct {
|
||||
bridge *Bridge
|
||||
client *mautrix.Client
|
||||
mach *crypto.OlmMachine
|
||||
store *database.SQLCryptoStore
|
||||
log maulogger.Logger
|
||||
baseLog maulogger.Logger
|
||||
}
|
||||
|
||||
func init() {
|
||||
crypto.PostgresArrayWrapper = pq.Array
|
||||
}
|
||||
|
||||
func NewCryptoHelper(bridge *Bridge) Crypto {
|
||||
if !bridge.Config.Bridge.Encryption.Allow {
|
||||
bridge.Log.Debugln("Bridge built with end-to-bridge encryption, but disabled in config")
|
||||
return nil
|
||||
}
|
||||
baseLog := bridge.Log.Sub("Crypto")
|
||||
return &CryptoHelper{
|
||||
bridge: bridge,
|
||||
log: baseLog.Sub("Helper"),
|
||||
baseLog: baseLog,
|
||||
}
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Init() error {
|
||||
helper.log.Debugln("Initializing end-to-bridge encryption...")
|
||||
|
||||
helper.store = database.NewSQLCryptoStore(helper.bridge.DB, helper.bridge.AS.BotMXID(),
|
||||
fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.AS.HomeserverDomain))
|
||||
|
||||
var err error
|
||||
helper.client, err = helper.loginBot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
helper.log.Debugln("Logged in as bridge bot with device ID", helper.client.DeviceID)
|
||||
logger := &cryptoLogger{helper.baseLog}
|
||||
stateStore := &cryptoStateStore{helper.bridge}
|
||||
helper.mach = crypto.NewOlmMachine(helper.client, logger, helper.store, stateStore)
|
||||
helper.mach.AllowKeyShare = helper.allowKeyShare
|
||||
|
||||
helper.client.Syncer = &cryptoSyncer{helper.mach}
|
||||
helper.client.Store = &cryptoClientStore{helper.store}
|
||||
|
||||
return helper.mach.Load()
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) allowKeyShare(device *crypto.DeviceIdentity, info event.RequestedKeyInfo) *crypto.KeyShareRejection {
|
||||
cfg := helper.bridge.Config.Bridge.Encryption.KeySharing
|
||||
if !cfg.Allow {
|
||||
return &crypto.KeyShareRejectNoResponse
|
||||
} else if device.Trust == crypto.TrustStateBlacklisted {
|
||||
return &crypto.KeyShareRejectBlacklisted
|
||||
} else if device.Trust == crypto.TrustStateVerified || !cfg.RequireVerification {
|
||||
portal := helper.bridge.GetPortalByMXID(info.RoomID)
|
||||
if portal == nil {
|
||||
helper.log.Debugfln("Rejecting key request for %s from %s/%s: room is not a portal", info.SessionID, device.UserID, device.DeviceID)
|
||||
return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnavailable, Reason: "Requested room is not a portal room"}
|
||||
}
|
||||
user := helper.bridge.GetUserByMXID(device.UserID)
|
||||
// FIXME reimplement IsInPortal
|
||||
if !user.Admin /*&& !user.IsInPortal(portal.Key)*/ {
|
||||
helper.log.Debugfln("Rejecting key request for %s from %s/%s: user is not in portal", info.SessionID, device.UserID, device.DeviceID)
|
||||
return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "You're not in that portal"}
|
||||
}
|
||||
helper.log.Debugfln("Accepting key request for %s from %s/%s", info.SessionID, device.UserID, device.DeviceID)
|
||||
return nil
|
||||
} else {
|
||||
return &crypto.KeyShareRejectUnverified
|
||||
}
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) loginBot() (*mautrix.Client, error) {
|
||||
deviceID := helper.store.FindDeviceID()
|
||||
if len(deviceID) > 0 {
|
||||
helper.log.Debugln("Found existing device ID for bot in database:", deviceID)
|
||||
}
|
||||
client, err := mautrix.NewClient(helper.bridge.AS.HomeserverURL, "", "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize client: %w", err)
|
||||
}
|
||||
client.Logger = helper.baseLog.Sub("Bot")
|
||||
client.Client = helper.bridge.AS.HTTPClient
|
||||
client.DefaultHTTPRetries = helper.bridge.AS.DefaultHTTPRetries
|
||||
flows, err := client.GetLoginFlows()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get supported login flows: %w", err)
|
||||
} else if !flows.HasFlow(mautrix.AuthTypeAppservice) {
|
||||
return nil, fmt.Errorf("homeserver does not support appservice login")
|
||||
}
|
||||
// We set the API token to the AS token here to authenticate the appservice login
|
||||
// It'll get overridden after the login
|
||||
client.AccessToken = helper.bridge.AS.Registration.AppToken
|
||||
resp, err := client.Login(&mautrix.ReqLogin{
|
||||
Type: mautrix.AuthTypeAppservice,
|
||||
Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(helper.bridge.AS.BotMXID())},
|
||||
DeviceID: deviceID,
|
||||
InitialDeviceDisplayName: "WhatsApp Bridge",
|
||||
StoreCredentials: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to log in as bridge bot: %w", err)
|
||||
}
|
||||
helper.store.DeviceID = resp.DeviceID
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Start() {
|
||||
helper.log.Debugln("Starting syncer for receiving to-device messages")
|
||||
err := helper.client.Sync()
|
||||
if err != nil {
|
||||
helper.log.Errorln("Fatal error syncing:", err)
|
||||
} else {
|
||||
helper.log.Infoln("Bridge bot to-device syncer stopped without error")
|
||||
}
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Stop() {
|
||||
helper.log.Debugln("CryptoHelper.Stop() called, stopping bridge bot sync")
|
||||
helper.client.StopSync()
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
|
||||
return helper.mach.DecryptMegolmEvent(evt)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content event.Content) (*event.EncryptedEventContent, error) {
|
||||
encrypted, err := helper.mach.EncryptMegolmEvent(roomID, evtType, &content)
|
||||
if err != nil {
|
||||
if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession {
|
||||
return nil, err
|
||||
}
|
||||
helper.log.Debugfln("Got %v while encrypting event for %s, sharing group session and trying again...", err, roomID)
|
||||
users, err := helper.store.GetRoomMembers(roomID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get room member list: %w", err)
|
||||
}
|
||||
err = helper.mach.ShareGroupSession(roomID, users)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to share group session: %w", err)
|
||||
}
|
||||
encrypted, err = helper.mach.EncryptMegolmEvent(roomID, evtType, &content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err)
|
||||
}
|
||||
}
|
||||
return encrypted, nil
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
||||
return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
|
||||
err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}})
|
||||
if err != nil {
|
||||
helper.log.Warnfln("Failed to send key request to %s/%s for %s in %s: %v", userID, deviceID, sessionID, roomID, err)
|
||||
} else {
|
||||
helper.log.Debugfln("Sent key request to %s/%s for %s in %s", userID, deviceID, sessionID, roomID)
|
||||
}
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) ResetSession(roomID id.RoomID) {
|
||||
err := helper.mach.CryptoStore.RemoveOutboundGroupSession(roomID)
|
||||
if err != nil {
|
||||
helper.log.Debugfln("Error manually removing outbound group session in %s: %v", roomID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) HandleMemberEvent(evt *event.Event) {
|
||||
helper.mach.HandleMemberEvent(evt)
|
||||
}
|
||||
|
||||
type cryptoSyncer struct {
|
||||
*crypto.OlmMachine
|
||||
}
|
||||
|
||||
func (syncer *cryptoSyncer) ProcessResponse(resp *mautrix.RespSync, since string) error {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
syncer.Log.Error("Processing sync response (%s) panicked: %v\n%s", since, err, debug.Stack())
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
syncer.Log.Trace("Starting sync response handling (%s)", since)
|
||||
syncer.ProcessSyncResponse(resp, since)
|
||||
syncer.Log.Trace("Successfully handled sync response (%s)", since)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(30 * time.Second):
|
||||
syncer.Log.Warn("Handling sync response (%s) is taking unusually long", since)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) {
|
||||
syncer.Log.Error("Error /syncing, waiting 10 seconds: %v", err)
|
||||
return 10 * time.Second, nil
|
||||
}
|
||||
|
||||
func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter {
|
||||
everything := []event.Type{{Type: "*"}}
|
||||
return &mautrix.Filter{
|
||||
Presence: mautrix.FilterPart{NotTypes: everything},
|
||||
AccountData: mautrix.FilterPart{NotTypes: everything},
|
||||
Room: mautrix.RoomFilter{
|
||||
IncludeLeave: false,
|
||||
Ephemeral: mautrix.FilterPart{NotTypes: everything},
|
||||
AccountData: mautrix.FilterPart{NotTypes: everything},
|
||||
State: mautrix.FilterPart{NotTypes: everything},
|
||||
Timeline: mautrix.FilterPart{NotTypes: everything},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type cryptoLogger struct {
|
||||
int maulogger.Logger
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Error(message string, args ...interface{}) {
|
||||
c.int.Errorfln(message, args...)
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Warn(message string, args ...interface{}) {
|
||||
c.int.Warnfln(message, args...)
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Debug(message string, args ...interface{}) {
|
||||
c.int.Debugfln(message, args...)
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Trace(message string, args ...interface{}) {
|
||||
c.int.Logfln(levelTrace, message, args...)
|
||||
}
|
||||
|
||||
type cryptoClientStore struct {
|
||||
int *database.SQLCryptoStore
|
||||
}
|
||||
|
||||
func (c cryptoClientStore) SaveFilterID(_ id.UserID, _ string) {}
|
||||
func (c cryptoClientStore) LoadFilterID(_ id.UserID) string { return "" }
|
||||
func (c cryptoClientStore) SaveRoom(_ *mautrix.Room) {}
|
||||
func (c cryptoClientStore) LoadRoom(_ id.RoomID) *mautrix.Room { return nil }
|
||||
|
||||
func (c cryptoClientStore) SaveNextBatch(_ id.UserID, nextBatchToken string) {
|
||||
c.int.PutNextBatch(nextBatchToken)
|
||||
}
|
||||
|
||||
func (c cryptoClientStore) LoadNextBatch(_ id.UserID) string {
|
||||
return c.int.GetNextBatch()
|
||||
}
|
||||
|
||||
var _ mautrix.Storer = (*cryptoClientStore)(nil)
|
||||
|
||||
type cryptoStateStore struct {
|
||||
bridge *Bridge
|
||||
}
|
||||
|
||||
var _ crypto.StateStore = (*cryptoStateStore)(nil)
|
||||
|
||||
func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool {
|
||||
portal := c.bridge.GetPortalByMXID(id)
|
||||
if portal != nil {
|
||||
return portal.Encrypted
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *cryptoStateStore) FindSharedRooms(id id.UserID) []id.RoomID {
|
||||
return c.bridge.StateStore.FindSharedRooms(id)
|
||||
}
|
||||
|
||||
func (c *cryptoStateStore) GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent {
|
||||
// TODO implement
|
||||
return nil
|
||||
}
|
|
@ -75,8 +75,8 @@ func (puppet *Puppet) loginWithSharedSecret(mxid id.UserID) (string, error) {
|
|||
Type: mautrix.AuthTypePassword,
|
||||
Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(mxid)},
|
||||
Password: hex.EncodeToString(mac.Sum(nil)),
|
||||
DeviceID: "WhatsApp Bridge",
|
||||
InitialDeviceDisplayName: "WhatsApp Bridge",
|
||||
DeviceID: "WhatsApp bridge",
|
||||
InitialDeviceDisplayName: "WhatsApp bridge",
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -84,22 +84,22 @@ func (puppet *Puppet) loginWithSharedSecret(mxid id.UserID) (string, error) {
|
|||
return resp.AccessToken, nil
|
||||
}
|
||||
|
||||
func (bridge *Bridge) newDoublePuppetClient(mxid id.UserID, accessToken string) (*mautrix.Client, error) {
|
||||
func (br *WABridge) newDoublePuppetClient(mxid id.UserID, accessToken string) (*mautrix.Client, error) {
|
||||
_, homeserver, err := mxid.Parse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
homeserverURL, found := bridge.Config.Bridge.DoublePuppetServerMap[homeserver]
|
||||
homeserverURL, found := br.Config.Bridge.DoublePuppetServerMap[homeserver]
|
||||
if !found {
|
||||
if homeserver == bridge.AS.HomeserverDomain {
|
||||
homeserverURL = bridge.AS.HomeserverURL
|
||||
} else if bridge.Config.Bridge.DoublePuppetAllowDiscovery {
|
||||
if homeserver == br.AS.HomeserverDomain {
|
||||
homeserverURL = br.AS.HomeserverURL
|
||||
} else if br.Config.Bridge.DoublePuppetAllowDiscovery {
|
||||
resp, err := mautrix.DiscoverClientAPI(homeserver)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err)
|
||||
}
|
||||
homeserverURL = resp.Homeserver.BaseURL
|
||||
bridge.Log.Debugfln("Discovered URL %s for %s to enable double puppeting for %s", homeserverURL, homeserver, mxid)
|
||||
br.Log.Debugfln("Discovered URL %s for %s to enable double puppeting for %s", homeserverURL, homeserver, mxid)
|
||||
} else {
|
||||
return nil, fmt.Errorf("double puppeting from %s is not allowed", homeserver)
|
||||
}
|
||||
|
@ -108,9 +108,9 @@ func (bridge *Bridge) newDoublePuppetClient(mxid id.UserID, accessToken string)
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client.Logger = bridge.AS.Log.Sub(mxid.String())
|
||||
client.Client = bridge.AS.HTTPClient
|
||||
client.DefaultHTTPRetries = bridge.AS.DefaultHTTPRetries
|
||||
client.Logger = br.AS.Log.Sub(mxid.String())
|
||||
client.Client = br.AS.HTTPClient
|
||||
client.DefaultHTTPRetries = br.AS.DefaultHTTPRetries
|
||||
return client, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -26,7 +26,9 @@ import (
|
|||
"time"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
type BackfillType int
|
||||
|
@ -165,7 +167,7 @@ func (b *Backfill) String() string {
|
|||
)
|
||||
}
|
||||
|
||||
func (b *Backfill) Scan(row Scannable) *Backfill {
|
||||
func (b *Backfill) Scan(row dbutil.Scannable) *Backfill {
|
||||
err := row.Scan(&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart, &b.MaxBatchEvents, &b.MaxTotalEvents, &b.BatchDelay)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
|
@ -256,7 +258,7 @@ type BackfillState struct {
|
|||
FirstExpectedTimestamp uint64
|
||||
}
|
||||
|
||||
func (b *BackfillState) Scan(row Scannable) *BackfillState {
|
||||
func (b *BackfillState) Scan(row dbutil.Scannable) *BackfillState {
|
||||
err := row.Scan(&b.UserID, &b.Portal.JID, &b.Portal.Receiver, &b.ProcessingBatch, &b.BackfillComplete, &b.FirstExpectedTimestamp)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
|
|
|
@ -1,18 +1,8 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2020 Tulir Asokan
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
//go:build cgo && !nocrypto
|
||||
|
||||
|
@ -21,8 +11,6 @@ package database
|
|||
import (
|
||||
"database/sql"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
@ -37,11 +25,9 @@ var _ crypto.Store = (*SQLCryptoStore)(nil)
|
|||
|
||||
func NewSQLCryptoStore(db *Database, userID id.UserID, ghostIDFormat string) *SQLCryptoStore {
|
||||
return &SQLCryptoStore{
|
||||
SQLCryptoStore: crypto.NewSQLCryptoStore(db.DB, db.dialect, "", "",
|
||||
[]byte("maunium.net/go/mautrix-whatsapp"),
|
||||
&cryptoLogger{db.log.Sub("CryptoStore")}),
|
||||
UserID: userID,
|
||||
GhostIDFormat: ghostIDFormat,
|
||||
SQLCryptoStore: crypto.NewSQLCryptoStore(db.Database, "", "", []byte("maunium.net/go/mautrix-whatsapp")),
|
||||
UserID: userID,
|
||||
GhostIDFormat: ghostIDFormat,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -76,30 +62,3 @@ func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) (members []id.User
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
// TODO merge this with the one in the parent package
|
||||
type cryptoLogger struct {
|
||||
int log.Logger
|
||||
}
|
||||
|
||||
var levelTrace = log.Level{
|
||||
Name: "TRACE",
|
||||
Severity: -10,
|
||||
Color: -1,
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Error(message string, args ...interface{}) {
|
||||
c.int.Errorfln(message, args...)
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Warn(message string, args ...interface{}) {
|
||||
c.int.Warnfln(message, args...)
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Debug(message string, args ...interface{}) {
|
||||
c.int.Debugfln(message, args...)
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Trace(message string, args ...interface{}) {
|
||||
c.int.Logfln(levelTrace, message, args...)
|
||||
}
|
||||
|
|
|
@ -17,21 +17,17 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"go.mau.fi/whatsmeow/store"
|
||||
"go.mau.fi/whatsmeow/store/sqlstore"
|
||||
|
||||
"maunium.net/go/mautrix-whatsapp/config"
|
||||
"maunium.net/go/mautrix-whatsapp/database/upgrades"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -39,9 +35,7 @@ func init() {
|
|||
}
|
||||
|
||||
type Database struct {
|
||||
*sql.DB
|
||||
log log.Logger
|
||||
dialect string
|
||||
*dbutil.Database
|
||||
|
||||
User *UserQuery
|
||||
Portal *PortalQuery
|
||||
|
@ -55,79 +49,46 @@ type Database struct {
|
|||
MediaBackfillRequest *MediaBackfillRequestQuery
|
||||
}
|
||||
|
||||
func New(cfg config.DatabaseConfig, baseLog log.Logger) (*Database, error) {
|
||||
conn, err := sql.Open(cfg.Type, cfg.URI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db := &Database{
|
||||
DB: conn,
|
||||
log: baseLog.Sub("Database"),
|
||||
dialect: cfg.Type,
|
||||
}
|
||||
func New(baseDB *dbutil.Database) *Database {
|
||||
db := &Database{Database: baseDB}
|
||||
db.UpgradeTable = upgrades.Table
|
||||
db.User = &UserQuery{
|
||||
db: db,
|
||||
log: db.log.Sub("User"),
|
||||
log: db.Log.Sub("User"),
|
||||
}
|
||||
db.Portal = &PortalQuery{
|
||||
db: db,
|
||||
log: db.log.Sub("Portal"),
|
||||
log: db.Log.Sub("Portal"),
|
||||
}
|
||||
db.Puppet = &PuppetQuery{
|
||||
db: db,
|
||||
log: db.log.Sub("Puppet"),
|
||||
log: db.Log.Sub("Puppet"),
|
||||
}
|
||||
db.Message = &MessageQuery{
|
||||
db: db,
|
||||
log: db.log.Sub("Message"),
|
||||
log: db.Log.Sub("Message"),
|
||||
}
|
||||
db.Reaction = &ReactionQuery{
|
||||
db: db,
|
||||
log: db.log.Sub("Reaction"),
|
||||
log: db.Log.Sub("Reaction"),
|
||||
}
|
||||
db.DisappearingMessage = &DisappearingMessageQuery{
|
||||
db: db,
|
||||
log: db.log.Sub("DisappearingMessage"),
|
||||
log: db.Log.Sub("DisappearingMessage"),
|
||||
}
|
||||
db.Backfill = &BackfillQuery{
|
||||
db: db,
|
||||
log: db.log.Sub("Backfill"),
|
||||
log: db.Log.Sub("Backfill"),
|
||||
}
|
||||
db.HistorySync = &HistorySyncQuery{
|
||||
db: db,
|
||||
log: db.log.Sub("HistorySync"),
|
||||
log: db.Log.Sub("HistorySync"),
|
||||
}
|
||||
db.MediaBackfillRequest = &MediaBackfillRequestQuery{
|
||||
db: db,
|
||||
log: db.log.Sub("MediaBackfillRequest"),
|
||||
log: db.Log.Sub("MediaBackfillRequest"),
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
if len(cfg.ConnMaxIdleTime) > 0 {
|
||||
maxIdleTimeDuration, err := time.ParseDuration(cfg.ConnMaxIdleTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
|
||||
}
|
||||
db.SetConnMaxIdleTime(maxIdleTimeDuration)
|
||||
}
|
||||
if len(cfg.ConnMaxLifetime) > 0 {
|
||||
maxLifetimeDuration, err := time.ParseDuration(cfg.ConnMaxLifetime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
|
||||
}
|
||||
db.SetConnMaxLifetime(maxLifetimeDuration)
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (db *Database) Init() error {
|
||||
return upgrades.Run(db.log.Sub("Upgrade"), db.dialect, db.DB)
|
||||
}
|
||||
|
||||
type Scannable interface {
|
||||
Scan(...interface{}) error
|
||||
return db
|
||||
}
|
||||
|
||||
func isRetryableError(err error) bool {
|
||||
|
@ -145,7 +106,7 @@ func isRetryableError(err error) bool {
|
|||
}
|
||||
|
||||
func (db *Database) HandleSignalStoreError(device *store.Device, action string, attemptIndex int, err error) (retry bool) {
|
||||
if db.dialect != "sqlite" && isRetryableError(err) {
|
||||
if db.Dialect != dbutil.SQLite && isRetryableError(err) {
|
||||
sleepTime := time.Duration(attemptIndex*2) * time.Second
|
||||
device.Log.Warnf("Failed to %s (attempt #%d): %v - retrying in %v", action, attemptIndex+1, err, sleepTime)
|
||||
time.Sleep(sleepTime)
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
type DisappearingMessageQuery struct {
|
||||
|
@ -94,7 +95,7 @@ type DisappearingMessage struct {
|
|||
ExpireAt time.Time
|
||||
}
|
||||
|
||||
func (msg *DisappearingMessage) Scan(row Scannable) *DisappearingMessage {
|
||||
func (msg *DisappearingMessage) Scan(row dbutil.Scannable) *DisappearingMessage {
|
||||
var expireIn int64
|
||||
var expireAt sql.NullInt64
|
||||
err := row.Scan(&msg.RoomID, &msg.EventID, &expireIn, &expireAt)
|
||||
|
|
|
@ -27,7 +27,9 @@ import (
|
|||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
type HistorySyncQuery struct {
|
||||
|
@ -139,7 +141,7 @@ func (hsc *HistorySyncConversation) Upsert() {
|
|||
}
|
||||
}
|
||||
|
||||
func (hsc *HistorySyncConversation) Scan(row Scannable) *HistorySyncConversation {
|
||||
func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) *HistorySyncConversation {
|
||||
err := row.Scan(
|
||||
&hsc.UserID,
|
||||
&hsc.ConversationID,
|
||||
|
@ -166,7 +168,7 @@ func (hsc *HistorySyncConversation) Scan(row Scannable) *HistorySyncConversation
|
|||
func (hsq *HistorySyncQuery) GetNMostRecentConversations(userID id.UserID, n int) (conversations []*HistorySyncConversation) {
|
||||
nPtr := &n
|
||||
// Negative limit on SQLite means unlimited, but Postgres prefers a NULL limit.
|
||||
if n < 0 && hsq.db.dialect == "postgres" {
|
||||
if n < 0 && hsq.db.Dialect == dbutil.Postgres {
|
||||
nPtr = nil
|
||||
}
|
||||
rows, err := hsq.db.Query(getNMostRecentConversations, userID, nPtr)
|
||||
|
|
|
@ -22,7 +22,9 @@ import (
|
|||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
type MediaBackfillRequestStatus int
|
||||
|
@ -100,7 +102,7 @@ func (mbr *MediaBackfillRequest) Upsert() {
|
|||
}
|
||||
}
|
||||
|
||||
func (mbr *MediaBackfillRequest) Scan(row Scannable) *MediaBackfillRequest {
|
||||
func (mbr *MediaBackfillRequest) Scan(row dbutil.Scannable) *MediaBackfillRequest {
|
||||
err := row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
)
|
||||
|
@ -163,7 +164,7 @@ func (msg *Message) IsFakeJID() bool {
|
|||
return strings.HasPrefix(msg.JID, "FAKE::") || msg.JID == string(msg.MXID)
|
||||
}
|
||||
|
||||
func (msg *Message) Scan(row Scannable) *Message {
|
||||
func (msg *Message) Scan(row dbutil.Scannable) *Message {
|
||||
var ts int64
|
||||
err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID)
|
||||
if err != nil {
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
)
|
||||
|
@ -152,7 +153,7 @@ type Portal struct {
|
|||
ExpirationTime uint32
|
||||
}
|
||||
|
||||
func (portal *Portal) Scan(row Scannable) *Portal {
|
||||
func (portal *Portal) Scan(row dbutil.Scannable) *Portal {
|
||||
var mxid, avatarURL, firstEventID, nextBatchID, relayUserID sql.NullString
|
||||
err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar, &avatarURL, &portal.Encrypted, &firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime)
|
||||
if err != nil {
|
||||
|
|
|
@ -20,7 +20,9 @@ import (
|
|||
"database/sql"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
)
|
||||
|
@ -97,7 +99,7 @@ type Puppet struct {
|
|||
EnableReceipts bool
|
||||
}
|
||||
|
||||
func (puppet *Puppet) Scan(row Scannable) *Puppet {
|
||||
func (puppet *Puppet) Scan(row dbutil.Scannable) *Puppet {
|
||||
var displayname, avatar, avatarURL, customMXID, accessToken, nextBatch sql.NullString
|
||||
var quality sql.NullInt64
|
||||
var enablePresence, enableReceipts sql.NullBool
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
)
|
||||
|
@ -85,7 +86,7 @@ type Reaction struct {
|
|||
JID types.MessageID
|
||||
}
|
||||
|
||||
func (reaction *Reaction) Scan(row Scannable) *Reaction {
|
||||
func (reaction *Reaction) Scan(row dbutil.Scannable) *Reaction {
|
||||
err := row.Scan(&reaction.Chat.JID, &reaction.Chat.Receiver, &reaction.TargetJID, &reaction.Sender, &reaction.MXID, &reaction.JID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
|
|
|
@ -1,282 +0,0 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2022 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type SQLStateStore struct {
|
||||
*appservice.TypingStateStore
|
||||
|
||||
db *Database
|
||||
log log.Logger
|
||||
|
||||
Typing map[id.RoomID]map[id.UserID]int64
|
||||
typingLock sync.RWMutex
|
||||
}
|
||||
|
||||
var _ appservice.StateStore = (*SQLStateStore)(nil)
|
||||
|
||||
func NewSQLStateStore(db *Database) *SQLStateStore {
|
||||
return &SQLStateStore{
|
||||
TypingStateStore: appservice.NewTypingStateStore(),
|
||||
db: db,
|
||||
log: db.log.Sub("StateStore"),
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsRegistered(userID id.UserID) bool {
|
||||
var isRegistered bool
|
||||
err := store.db.
|
||||
QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID).
|
||||
Scan(&isRegistered)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to scan registration existence for %s: %v", userID, err)
|
||||
}
|
||||
return isRegistered
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) MarkRegistered(userID id.UserID) {
|
||||
_, err := store.db.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to mark %s as registered: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent {
|
||||
members := make(map[id.UserID]*event.MemberEventContent)
|
||||
rows, err := store.db.Query("SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1", roomID)
|
||||
if err != nil {
|
||||
return members
|
||||
}
|
||||
var userID id.UserID
|
||||
var member event.MemberEventContent
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
|
||||
} else {
|
||||
members[userID] = &member
|
||||
}
|
||||
}
|
||||
return members
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
|
||||
membership := event.MembershipLeave
|
||||
err := store.db.
|
||||
QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
|
||||
Scan(&membership)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
store.log.Warnfln("Failed to scan membership of %s in %s: %v", userID, roomID, err)
|
||||
}
|
||||
return membership
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
|
||||
member, ok := store.TryGetMember(roomID, userID)
|
||||
if !ok {
|
||||
member.Membership = event.MembershipLeave
|
||||
}
|
||||
return member
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) {
|
||||
var member event.MemberEventContent
|
||||
err := store.db.
|
||||
QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
|
||||
Scan(&member.Membership, &member.Displayname, &member.AvatarURL)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
store.log.Warnfln("Failed to scan member info of %s in %s: %v", userID, roomID, err)
|
||||
}
|
||||
return &member, err == nil
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) {
|
||||
rows, err := store.db.Query(`
|
||||
SELECT room_id FROM mx_user_profile
|
||||
LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id
|
||||
WHERE user_id=$1 AND portal.encrypted=true
|
||||
`, userID)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to query shared rooms with %s: %v", userID, err)
|
||||
return
|
||||
}
|
||||
for rows.Next() {
|
||||
var roomID id.RoomID
|
||||
err = rows.Scan(&roomID)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to scan room ID: %v", err)
|
||||
} else {
|
||||
rooms = append(rooms, roomID)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join")
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join", "invite")
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
|
||||
membership := store.GetMembership(roomID, userID)
|
||||
for _, allowedMembership := range allowedMemberships {
|
||||
if allowedMembership == membership {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
|
||||
_, err := store.db.Exec(`
|
||||
INSERT INTO mx_user_profile (room_id, user_id, membership) VALUES ($1, $2, $3)
|
||||
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership
|
||||
`, roomID, userID, membership)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to set membership of %s in %s to %s: %v", userID, roomID, membership, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
|
||||
_, err := store.db.Exec(`
|
||||
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url
|
||||
`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to set membership of %s in %s to %s: %v", userID, roomID, member, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
|
||||
levelsBytes, err := json.Marshal(levels)
|
||||
if err != nil {
|
||||
store.log.Errorfln("Failed to marshal power levels of %s: %v", roomID, err)
|
||||
return
|
||||
}
|
||||
_, err = store.db.Exec(`
|
||||
INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2)
|
||||
ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels
|
||||
`, roomID, levelsBytes)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to store power levels of %s: %v", roomID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
|
||||
var data []byte
|
||||
err := store.db.
|
||||
QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
|
||||
Scan(&data)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.log.Errorfln("Failed to scan power levels of %s: %v", roomID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
levels = &event.PowerLevelsEventContent{}
|
||||
err = json.Unmarshal(data, levels)
|
||||
if err != nil {
|
||||
store.log.Errorfln("Failed to parse power levels of %s: %v", roomID, err)
|
||||
return nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
|
||||
if store.db.dialect == "postgres" {
|
||||
var powerLevel int
|
||||
err := store.db.
|
||||
QueryRow(`
|
||||
SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
|
||||
FROM mx_room_state WHERE room_id=$1
|
||||
`, roomID, userID).
|
||||
Scan(&powerLevel)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
store.log.Errorfln("Failed to scan power level of %s in %s: %v", userID, roomID, err)
|
||||
}
|
||||
return powerLevel
|
||||
}
|
||||
return store.GetPowerLevels(roomID).GetUserLevel(userID)
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
|
||||
if store.db.dialect == "postgres" {
|
||||
defaultType := "events_default"
|
||||
defaultValue := 0
|
||||
if eventType.IsState() {
|
||||
defaultType = "state_default"
|
||||
defaultValue = 50
|
||||
}
|
||||
var powerLevel int
|
||||
err := store.db.
|
||||
QueryRow(`
|
||||
SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4)
|
||||
FROM mx_room_state WHERE room_id=$1
|
||||
`, roomID, eventType.Type, defaultType, defaultValue).
|
||||
Scan(&powerLevel)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.log.Errorfln("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
return powerLevel
|
||||
}
|
||||
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
|
||||
if store.db.dialect == "postgres" {
|
||||
defaultType := "events_default"
|
||||
defaultValue := 0
|
||||
if eventType.IsState() {
|
||||
defaultType = "state_default"
|
||||
defaultValue = 50
|
||||
}
|
||||
var hasPower bool
|
||||
err := store.db.
|
||||
QueryRow(`SELECT
|
||||
COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
|
||||
>=
|
||||
COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5)
|
||||
FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue).
|
||||
Scan(&hasPower)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.log.Errorfln("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
|
||||
}
|
||||
return defaultValue == 0
|
||||
}
|
||||
return hasPower
|
||||
}
|
||||
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
|
||||
}
|
181
database/upgrades/00-latest-revision.sql
Normal file
181
database/upgrades/00-latest-revision.sql
Normal file
|
@ -0,0 +1,181 @@
|
|||
-- v0 -> v48: Latest revision
|
||||
|
||||
CREATE TABLE "user" (
|
||||
mxid TEXT PRIMARY KEY,
|
||||
username TEXT UNIQUE,
|
||||
agent SMALLINT,
|
||||
device SMALLINT,
|
||||
|
||||
management_room TEXT,
|
||||
space_room TEXT,
|
||||
|
||||
phone_last_seen BIGINT,
|
||||
phone_last_pinged BIGINT,
|
||||
|
||||
timezone TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE portal (
|
||||
jid TEXT,
|
||||
receiver TEXT,
|
||||
mxid TEXT UNIQUE,
|
||||
name TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
avatar TEXT NOT NULL,
|
||||
avatar_url TEXT,
|
||||
encrypted BOOLEAN NOT NULL DEFAULT false,
|
||||
|
||||
first_event_id TEXT,
|
||||
next_batch_id TEXT,
|
||||
relay_user_id TEXT,
|
||||
expiration_time BIGINT NOT NULL DEFAULT 0,
|
||||
|
||||
PRIMARY KEY (jid, receiver)
|
||||
);
|
||||
|
||||
CREATE TABLE puppet (
|
||||
username TEXT PRIMARY KEY,
|
||||
displayname TEXT,
|
||||
name_quality SMALLINT,
|
||||
avatar TEXT,
|
||||
avatar_url TEXT,
|
||||
|
||||
custom_mxid TEXT,
|
||||
access_token TEXT,
|
||||
next_batch TEXT,
|
||||
|
||||
enable_presence BOOLEAN NOT NULL DEFAULT true,
|
||||
enable_receipts BOOLEAN NOT NULL DEFAULT true
|
||||
);
|
||||
|
||||
-- only: postgres
|
||||
CREATE TYPE error_type AS ENUM ('', 'decryption_failed', 'media_not_found');
|
||||
|
||||
CREATE TABLE message (
|
||||
chat_jid TEXT,
|
||||
chat_receiver TEXT,
|
||||
jid TEXT,
|
||||
mxid TEXT UNIQUE,
|
||||
sender TEXT,
|
||||
timestamp BIGINT,
|
||||
sent BOOLEAN,
|
||||
error error_type,
|
||||
type TEXT,
|
||||
|
||||
broadcast_list_jid TEXT,
|
||||
|
||||
PRIMARY KEY (chat_jid, chat_receiver, jid),
|
||||
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE reaction (
|
||||
chat_jid TEXT,
|
||||
chat_receiver TEXT,
|
||||
target_jid TEXT,
|
||||
sender TEXT,
|
||||
|
||||
mxid TEXT NOT NULL,
|
||||
jid TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY (chat_jid, chat_receiver, target_jid, sender),
|
||||
FOREIGN KEY (chat_jid, chat_receiver, target_jid) REFERENCES message(chat_jid, chat_receiver, jid)
|
||||
ON DELETE CASCADE ON UPDATE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE disappearing_message (
|
||||
room_id TEXT,
|
||||
event_id TEXT,
|
||||
expire_in BIGINT NOT NULL,
|
||||
expire_at BIGINT,
|
||||
PRIMARY KEY (room_id, event_id)
|
||||
);
|
||||
|
||||
CREATE TABLE user_portal (
|
||||
user_mxid TEXT,
|
||||
portal_jid TEXT,
|
||||
portal_receiver TEXT,
|
||||
last_read_ts BIGINT NOT NULL DEFAULT 0,
|
||||
in_space BOOLEAN NOT NULL DEFAULT false,
|
||||
PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
|
||||
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON UPDATE CASCADE ON DELETE CASCADE,
|
||||
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON UPDATE CASCADE ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE backfill_queue (
|
||||
queue_id INTEGER PRIMARY KEY
|
||||
-- only: postgres
|
||||
GENERATED ALWAYS AS IDENTITY
|
||||
,
|
||||
user_mxid TEXT,
|
||||
type INTEGER NOT NULL,
|
||||
priority INTEGER NOT NULL,
|
||||
portal_jid TEXT,
|
||||
portal_receiver TEXT,
|
||||
time_start TIMESTAMP,
|
||||
dispatch_time TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
batch_delay INTEGER,
|
||||
max_batch_events INTEGER NOT NULL,
|
||||
max_total_events INTEGER,
|
||||
|
||||
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE backfill_state (
|
||||
user_mxid TEXT,
|
||||
portal_jid TEXT,
|
||||
portal_receiver TEXT,
|
||||
processing_batch BOOLEAN,
|
||||
backfill_complete BOOLEAN,
|
||||
first_expected_ts TIMESTAMP,
|
||||
PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
|
||||
FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal (jid, receiver) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE media_backfill_requests (
|
||||
user_mxid TEXT,
|
||||
portal_jid TEXT,
|
||||
portal_receiver TEXT,
|
||||
event_id TEXT,
|
||||
media_key bytea,
|
||||
status INTEGER,
|
||||
error TEXT,
|
||||
PRIMARY KEY (user_mxid, portal_jid, portal_receiver, event_id),
|
||||
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON UPDATE CASCADE ON DELETE CASCADE,
|
||||
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON UPDATE CASCADE ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE history_sync_conversation (
|
||||
user_mxid TEXT,
|
||||
conversation_id TEXT,
|
||||
portal_jid TEXT,
|
||||
portal_receiver TEXT,
|
||||
|
||||
last_message_timestamp TIMESTAMP,
|
||||
archived BOOLEAN,
|
||||
pinned INTEGER,
|
||||
mute_end_time TIMESTAMP,
|
||||
disappearing_mode INTEGER,
|
||||
end_of_history_transfer_type INTEGER,
|
||||
ephemeral_Expiration INTEGER,
|
||||
marked_as_unread BOOLEAN,
|
||||
unread_count INTEGER,
|
||||
|
||||
PRIMARY KEY (user_mxid, conversation_id),
|
||||
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON UPDATE CASCADE ON DELETE CASCADE,
|
||||
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON UPDATE CASCADE ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE history_sync_message (
|
||||
user_mxid TEXT,
|
||||
conversation_id TEXT,
|
||||
message_id TEXT,
|
||||
timestamp TIMESTAMP,
|
||||
data bytea,
|
||||
inserted_time TIMESTAMP,
|
||||
|
||||
PRIMARY KEY (user_mxid, conversation_id, message_id),
|
||||
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON UPDATE CASCADE ON DELETE CASCADE,
|
||||
FOREIGN KEY (user_mxid, conversation_id) REFERENCES history_sync_conversation(user_mxid, conversation_id) ON DELETE CASCADE
|
||||
);
|
|
@ -1,67 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[0] = upgrade{"Initial schema", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`CREATE TABLE IF NOT EXISTS portal (
|
||||
jid VARCHAR(255),
|
||||
receiver VARCHAR(255),
|
||||
mxid VARCHAR(255) UNIQUE,
|
||||
|
||||
name VARCHAR(255) NOT NULL,
|
||||
topic VARCHAR(255) NOT NULL,
|
||||
avatar VARCHAR(255) NOT NULL,
|
||||
|
||||
PRIMARY KEY (jid, receiver)
|
||||
)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS puppet (
|
||||
jid VARCHAR(255) PRIMARY KEY,
|
||||
avatar VARCHAR(255),
|
||||
displayname VARCHAR(255),
|
||||
name_quality SMALLINT
|
||||
)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS "user" (
|
||||
mxid VARCHAR(255) PRIMARY KEY,
|
||||
jid VARCHAR(255) UNIQUE,
|
||||
|
||||
management_room VARCHAR(255),
|
||||
|
||||
client_id VARCHAR(255),
|
||||
client_token VARCHAR(255),
|
||||
server_token VARCHAR(255),
|
||||
enc_key bytea,
|
||||
mac_key bytea
|
||||
)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS message (
|
||||
chat_jid VARCHAR(255),
|
||||
chat_receiver VARCHAR(255),
|
||||
jid VARCHAR(255),
|
||||
mxid VARCHAR(255) NOT NULL UNIQUE,
|
||||
sender VARCHAR(255) NOT NULL,
|
||||
content bytea NOT NULL,
|
||||
|
||||
PRIMARY KEY (chat_jid, chat_receiver, jid),
|
||||
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
|
||||
)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}}
|
||||
}
|
|
@ -1,15 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[2] = upgrade{"Add timestamp column to messages", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec("ALTER TABLE message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
|
@ -1,15 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[3] = upgrade{"Add last_connection column to users", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN last_connection BIGINT NOT NULL DEFAULT 0`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
|
@ -1,23 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[5] = upgrade{"Add columns to store custom puppet info", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN custom_mxid VARCHAR(255)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`ALTER TABLE puppet ADD COLUMN access_token VARCHAR(1023)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`ALTER TABLE puppet ADD COLUMN next_batch VARCHAR(255)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
|
@ -1,19 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[6] = upgrade{"Add user-portal mapping table", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`CREATE TABLE user_portal (
|
||||
user_jid VARCHAR(255),
|
||||
portal_jid VARCHAR(255),
|
||||
portal_receiver VARCHAR(255),
|
||||
PRIMARY KEY (user_jid, portal_jid, portal_receiver),
|
||||
FOREIGN KEY (user_jid) REFERENCES "user"(jid) ON DELETE CASCADE,
|
||||
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
|
||||
)`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,19 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[7] = upgrade{"Add columns to store avatar MXC URIs", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN avatar_url VARCHAR(255)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`ALTER TABLE portal ADD COLUMN avatar_url VARCHAR(255)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[8] = upgrade{"Add columns to store portal in filtering community meta", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE user_portal ADD COLUMN in_community BOOLEAN NOT NULL DEFAULT FALSE`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,39 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func init() {
|
||||
userProfileTable := `CREATE TABLE mx_user_profile (
|
||||
room_id VARCHAR(255),
|
||||
user_id VARCHAR(255),
|
||||
membership VARCHAR(15) NOT NULL,
|
||||
PRIMARY KEY (room_id, user_id)
|
||||
)`
|
||||
|
||||
roomStateTable := `CREATE TABLE mx_room_state (
|
||||
room_id VARCHAR(255) PRIMARY KEY,
|
||||
power_levels TEXT
|
||||
)`
|
||||
|
||||
registrationsTable := `CREATE TABLE mx_registrations (
|
||||
user_id VARCHAR(255) PRIMARY KEY
|
||||
)`
|
||||
|
||||
upgrades[9] = upgrade{"Move state store to main DB", func(tx *sql.Tx, ctx context) error {
|
||||
if ctx.dialect == Postgres {
|
||||
roomStateTable = strings.Replace(roomStateTable, "TEXT", "JSONB", 1)
|
||||
}
|
||||
|
||||
if _, err := tx.Exec(userProfileTable); err != nil {
|
||||
return err
|
||||
} else if _, err = tx.Exec(roomStateTable); err != nil {
|
||||
return err
|
||||
} else if _, err = tx.Exec(registrationsTable); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
|
@ -1,16 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[10] = upgrade{"Add columns to store full member info in state store", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE mx_user_profile ADD COLUMN displayname TEXT`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`ALTER TABLE mx_user_profile ADD COLUMN avatar_url VARCHAR(255)`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,16 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[11] = upgrade{"Adjust the length of column topic in portal", func(tx *sql.Tx, ctx context) error {
|
||||
if ctx.dialect == SQLite {
|
||||
// SQLite doesn't support constraint updates, but it isn't that careful with constraints anyway.
|
||||
return nil
|
||||
}
|
||||
_, err := tx.Exec(`ALTER TABLE portal ALTER COLUMN topic TYPE VARCHAR(512)`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[12] = upgrade{"Add encryption status to portal table", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN encrypted BOOLEAN NOT NULL DEFAULT false`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,73 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[13] = upgrade{"Add crypto store to database", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`CREATE TABLE crypto_account (
|
||||
device_id VARCHAR(255) PRIMARY KEY,
|
||||
shared BOOLEAN NOT NULL,
|
||||
sync_token TEXT NOT NULL,
|
||||
account bytea NOT NULL
|
||||
)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`CREATE TABLE crypto_message_index (
|
||||
sender_key CHAR(43),
|
||||
session_id CHAR(43),
|
||||
"index" INTEGER,
|
||||
event_id VARCHAR(255) NOT NULL,
|
||||
timestamp BIGINT NOT NULL,
|
||||
|
||||
PRIMARY KEY (sender_key, session_id, "index")
|
||||
)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`CREATE TABLE crypto_tracked_user (
|
||||
user_id VARCHAR(255) PRIMARY KEY
|
||||
)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`CREATE TABLE crypto_device (
|
||||
user_id VARCHAR(255),
|
||||
device_id VARCHAR(255),
|
||||
identity_key CHAR(43) NOT NULL,
|
||||
signing_key CHAR(43) NOT NULL,
|
||||
trust SMALLINT NOT NULL,
|
||||
deleted BOOLEAN NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
|
||||
PRIMARY KEY (user_id, device_id)
|
||||
)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`CREATE TABLE crypto_olm_session (
|
||||
session_id CHAR(43) PRIMARY KEY,
|
||||
sender_key CHAR(43) NOT NULL,
|
||||
session bytea NOT NULL,
|
||||
created_at timestamp NOT NULL,
|
||||
last_used timestamp NOT NULL
|
||||
)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`CREATE TABLE crypto_megolm_inbound_session (
|
||||
session_id CHAR(43) PRIMARY KEY,
|
||||
sender_key CHAR(43) NOT NULL,
|
||||
signing_key CHAR(43) NOT NULL,
|
||||
room_id VARCHAR(255) NOT NULL,
|
||||
session bytea NOT NULL,
|
||||
forwarding_chains bytea NOT NULL
|
||||
)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[14] = upgrade{"Add outbound group sessions to database", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`CREATE TABLE crypto_megolm_outbound_session (
|
||||
room_id VARCHAR(255) PRIMARY KEY,
|
||||
session_id CHAR(43) NOT NULL UNIQUE,
|
||||
session bytea NOT NULL,
|
||||
shared BOOLEAN NOT NULL,
|
||||
max_messages INTEGER NOT NULL,
|
||||
message_count INTEGER NOT NULL,
|
||||
max_age BIGINT NOT NULL,
|
||||
created_at timestamp NOT NULL,
|
||||
last_used timestamp NOT NULL
|
||||
)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[15] = upgrade{"Add enable_presence column for puppets", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN enable_presence BOOLEAN NOT NULL DEFAULT true`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,13 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[16] = upgrade{"Add account_id to crypto store", func(tx *sql.Tx, c context) error {
|
||||
return sql_store_upgrade.Upgrades[1](tx, c.dialect.String())
|
||||
}}
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[17] = upgrade{"Add enable_receipts column for puppets", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN enable_receipts BOOLEAN NOT NULL DEFAULT true`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,13 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[18] = upgrade{"Add megolm withheld data to crypto store", func(tx *sql.Tx, c context) error {
|
||||
return sql_store_upgrade.Upgrades[2](tx, c.dialect.String())
|
||||
}}
|
||||
}
|
|
@ -1,13 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[19] = upgrade{"Add cross-signing keys to crypto store", func(tx *sql.Tx, c context) error {
|
||||
return sql_store_upgrade.Upgrades[3](tx, c.dialect.String())
|
||||
}}
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[20] = upgrade{"Add sent column for messages", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE message ADD COLUMN sent BOOLEAN NOT NULL DEFAULT true`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,44 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[21] = upgrade{"Remove message content from local database", func(tx *sql.Tx, ctx context) error {
|
||||
if ctx.dialect == SQLite {
|
||||
_, err := tx.Exec("ALTER TABLE message RENAME TO old_message")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS message (
|
||||
chat_jid TEXT,
|
||||
chat_receiver TEXT,
|
||||
jid TEXT,
|
||||
mxid TEXT NOT NULL UNIQUE,
|
||||
sender TEXT NOT NULL,
|
||||
timestamp BIGINT NOT NULL,
|
||||
sent BOOLEAN NOT NULL,
|
||||
|
||||
PRIMARY KEY (chat_jid, chat_receiver, jid),
|
||||
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
|
||||
)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("INSERT INTO message SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent FROM old_message")
|
||||
return err
|
||||
} else {
|
||||
_, err := tx.Exec(`ALTER TABLE message DROP COLUMN content`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`ALTER TABLE message ALTER COLUMN timestamp DROP DEFAULT`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`ALTER TABLE message ALTER COLUMN sent DROP DEFAULT`)
|
||||
return err
|
||||
}
|
||||
}}
|
||||
}
|
|
@ -1,13 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[23] = upgrade{"Replace VARCHAR(255) with TEXT in the crypto database", func(tx *sql.Tx, ctx context) error {
|
||||
return sql_store_upgrade.Upgrades[4](tx, ctx.dialect.String())
|
||||
}}
|
||||
}
|
|
@ -1,48 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[22] = upgrade{"Replace VARCHAR(255) with TEXT in the database", func(tx *sql.Tx, ctx context) error {
|
||||
if ctx.dialect == SQLite {
|
||||
// SQLite doesn't enforce varchar sizes anyway
|
||||
return nil
|
||||
}
|
||||
return execMany(tx,
|
||||
`ALTER TABLE message ALTER COLUMN chat_jid TYPE TEXT`,
|
||||
`ALTER TABLE message ALTER COLUMN chat_receiver TYPE TEXT`,
|
||||
`ALTER TABLE message ALTER COLUMN jid TYPE TEXT`,
|
||||
`ALTER TABLE message ALTER COLUMN mxid TYPE TEXT`,
|
||||
`ALTER TABLE message ALTER COLUMN sender TYPE TEXT`,
|
||||
|
||||
`ALTER TABLE portal ALTER COLUMN jid TYPE TEXT`,
|
||||
`ALTER TABLE portal ALTER COLUMN receiver TYPE TEXT`,
|
||||
`ALTER TABLE portal ALTER COLUMN mxid TYPE TEXT`,
|
||||
`ALTER TABLE portal ALTER COLUMN name TYPE TEXT`,
|
||||
`ALTER TABLE portal ALTER COLUMN topic TYPE TEXT`,
|
||||
`ALTER TABLE portal ALTER COLUMN avatar TYPE TEXT`,
|
||||
`ALTER TABLE portal ALTER COLUMN avatar_url TYPE TEXT`,
|
||||
|
||||
`ALTER TABLE puppet ALTER COLUMN jid TYPE TEXT`,
|
||||
`ALTER TABLE puppet ALTER COLUMN avatar TYPE TEXT`,
|
||||
`ALTER TABLE puppet ALTER COLUMN displayname TYPE TEXT`,
|
||||
`ALTER TABLE puppet ALTER COLUMN custom_mxid TYPE TEXT`,
|
||||
`ALTER TABLE puppet ALTER COLUMN access_token TYPE TEXT`,
|
||||
`ALTER TABLE puppet ALTER COLUMN next_batch TYPE TEXT`,
|
||||
`ALTER TABLE puppet ALTER COLUMN avatar_url TYPE TEXT`,
|
||||
|
||||
`ALTER TABLE "user" ALTER COLUMN mxid TYPE TEXT`,
|
||||
`ALTER TABLE "user" ALTER COLUMN jid TYPE TEXT`,
|
||||
`ALTER TABLE "user" ALTER COLUMN management_room TYPE TEXT`,
|
||||
`ALTER TABLE "user" ALTER COLUMN client_id TYPE TEXT`,
|
||||
`ALTER TABLE "user" ALTER COLUMN client_token TYPE TEXT`,
|
||||
`ALTER TABLE "user" ALTER COLUMN server_token TYPE TEXT`,
|
||||
|
||||
`ALTER TABLE user_portal ALTER COLUMN user_jid TYPE TEXT`,
|
||||
`ALTER TABLE user_portal ALTER COLUMN portal_jid TYPE TEXT`,
|
||||
`ALTER TABLE user_portal ALTER COLUMN portal_receiver TYPE TEXT`,
|
||||
)
|
||||
}}
|
||||
}
|
|
@ -1,13 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"go.mau.fi/whatsmeow/store/sqlstore"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[24] = upgrade{"Add whatsmeow state store", func(tx *sql.Tx, ctx context) error {
|
||||
return sqlstore.Upgrades[0](tx, sqlstore.NewWithDB(ctx.db, ctx.dialect.String(), nil))
|
||||
}}
|
||||
}
|
|
@ -1,93 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[25] = upgrade{"Update things for multidevice", func(tx *sql.Tx, ctx context) error {
|
||||
// This is probably not necessary
|
||||
_, err := tx.Exec("DROP TABLE user_portal")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove invalid puppet rows
|
||||
_, err = tx.Exec("DELETE FROM puppet WHERE jid LIKE '%@g.us' OR jid LIKE '%@broadcast'")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Remove the suffix from puppets since they'll all have the same suffix
|
||||
_, err = tx.Exec("UPDATE puppet SET jid=REPLACE(jid, '@s.whatsapp.net', '')")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Rename column to correctly represent the new content
|
||||
_, err = tx.Exec("ALTER TABLE puppet RENAME COLUMN jid TO username")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ctx.dialect == SQLite {
|
||||
// Message content was removed from the main message table earlier, but the backup table still exists for SQLite
|
||||
_, err = tx.Exec("DROP TABLE IF EXISTS old_message")
|
||||
|
||||
_, err = tx.Exec(`ALTER TABLE "user" RENAME TO old_user`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`CREATE TABLE "user" (
|
||||
mxid TEXT PRIMARY KEY,
|
||||
username TEXT UNIQUE,
|
||||
agent SMALLINT,
|
||||
device SMALLINT,
|
||||
management_room TEXT
|
||||
)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// No need to copy auth data, users need to relogin anyway
|
||||
_, err = tx.Exec(`INSERT INTO "user" (mxid, management_room) SELECT mxid, management_room FROM old_user`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec("DROP TABLE old_user")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// The jid column never actually contained the full JID, so let's rename it.
|
||||
_, err = tx.Exec(`ALTER TABLE "user" RENAME COLUMN jid TO username`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// The auth data is now in the whatsmeow_device table.
|
||||
for _, column := range []string{"last_connection", "client_id", "client_token", "server_token", "enc_key", "mac_key"} {
|
||||
_, err = tx.Exec(`ALTER TABLE "user" DROP COLUMN ` + column)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// The whatsmeow_device table is keyed by the full JID, so we need to store the other parts of the JID here too.
|
||||
_, err = tx.Exec(`ALTER TABLE "user" ADD COLUMN agent SMALLINT`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`ALTER TABLE "user" ADD COLUMN device SMALLINT`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Clear all usernames, the users need to relogin anyway.
|
||||
_, err = tx.Exec(`UPDATE "user" SET username=null`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
|
@ -1,19 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[26] = upgrade{"Add columns to store infinite backfill pointers for portals", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN first_event_id TEXT NOT NULL DEFAULT ''`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`ALTER TABLE portal ADD COLUMN next_batch_id TEXT NOT NULL DEFAULT ''`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[27] = upgrade{"Add marker for WhatsApp decryption errors in message table", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE message ADD COLUMN decryption_error BOOLEAN NOT NULL DEFAULT false`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[28] = upgrade{"Add relay user field to portal table", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN relay_user_id TEXT`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,22 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[29] = upgrade{"Replace VARCHAR(255) with TEXT in the Matrix state store", func(tx *sql.Tx, ctx context) error {
|
||||
if ctx.dialect == SQLite {
|
||||
// SQLite doesn't enforce varchar sizes anyway
|
||||
return nil
|
||||
}
|
||||
return execMany(tx,
|
||||
`ALTER TABLE mx_registrations ALTER COLUMN user_id TYPE TEXT`,
|
||||
`ALTER TABLE mx_room_state ALTER COLUMN room_id TYPE TEXT`,
|
||||
`ALTER TABLE mx_user_profile ALTER COLUMN room_id TYPE TEXT`,
|
||||
`ALTER TABLE mx_user_profile ALTER COLUMN user_id TYPE TEXT`,
|
||||
`ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE TEXT`,
|
||||
`ALTER TABLE mx_user_profile ALTER COLUMN avatar_url TYPE TEXT`,
|
||||
)
|
||||
}}
|
||||
}
|
|
@ -1,22 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[30] = upgrade{"Store last read message timestamp in database", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`CREATE TABLE user_portal (
|
||||
user_mxid TEXT,
|
||||
portal_jid TEXT,
|
||||
portal_receiver TEXT,
|
||||
|
||||
last_read_ts BIGINT NOT NULL DEFAULT 0,
|
||||
|
||||
PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
|
||||
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
|
||||
)`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,13 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[31] = upgrade{"Split last_used into last_encrypted and last_decrypted in crypto store", func(tx *sql.Tx, c context) error {
|
||||
return sql_store_upgrade.Upgrades[5](tx, c.dialect.String())
|
||||
}}
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[32] = upgrade{"Store source broadcast list in message table", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE message ADD COLUMN broadcast_list_jid TEXT`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,16 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[33] = upgrade{"Add personal filtering space info to user tables", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN space_room TEXT NOT NULL DEFAULT ''`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`ALTER TABLE user_portal ADD COLUMN in_space BOOLEAN NOT NULL DEFAULT false`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,20 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import "database/sql"
|
||||
|
||||
func init() {
|
||||
upgrades[34] = upgrade{"Add support for disappearing messages", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN expiration_time BIGINT NOT NULL DEFAULT 0 CHECK (expiration_time >= 0 AND expiration_time < 4294967296)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`CREATE TABLE disappearing_message (
|
||||
room_id TEXT,
|
||||
event_id TEXT,
|
||||
expire_in BIGINT NOT NULL,
|
||||
expire_at BIGINT,
|
||||
PRIMARY KEY (room_id, event_id)
|
||||
)`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,10 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import "database/sql"
|
||||
|
||||
func init() {
|
||||
upgrades[35] = upgrade{"Store approximate last seen timestamp of the main device", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN phone_last_seen BIGINT`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,30 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import "database/sql"
|
||||
|
||||
func init() {
|
||||
upgrades[36] = upgrade{"Store message error type as string", func(tx *sql.Tx, ctx context) error {
|
||||
if ctx.dialect == Postgres {
|
||||
_, err := tx.Exec("CREATE TYPE error_type AS ENUM ('', 'decryption_failed', 'media_not_found')")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := tx.Exec("ALTER TABLE message ADD COLUMN error error_type NOT NULL DEFAULT ''")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("UPDATE message SET error='decryption_failed' WHERE decryption_error=true")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ctx.dialect == Postgres {
|
||||
// TODO do this on sqlite at some point
|
||||
_, err = tx.Exec("ALTER TABLE message DROP COLUMN decryption_error")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
|
@ -1,10 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import "database/sql"
|
||||
|
||||
func init() {
|
||||
upgrades[37] = upgrade{"Store timestamp for previous phone ping", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN phone_last_pinged BIGINT`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,39 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import "database/sql"
|
||||
|
||||
func init() {
|
||||
upgrades[38] = upgrade{"Add support for reactions", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE message ADD COLUMN type TEXT NOT NULL DEFAULT 'message'`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ctx.dialect == Postgres {
|
||||
_, err = tx.Exec("ALTER TABLE message ALTER COLUMN type DROP DEFAULT")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err = tx.Exec("UPDATE message SET type='' WHERE error='decryption_failed'")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("UPDATE message SET type='fake' WHERE jid LIKE 'FAKE::%' OR mxid LIKE 'net.maunium.whatsapp.fake::%' OR jid=mxid")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`CREATE TABLE reaction (
|
||||
chat_jid TEXT,
|
||||
chat_receiver TEXT,
|
||||
target_jid TEXT,
|
||||
sender TEXT,
|
||||
mxid TEXT NOT NULL,
|
||||
jid TEXT NOT NULL,
|
||||
PRIMARY KEY (chat_jid, chat_receiver, target_jid, sender),
|
||||
CONSTRAINT target_message_fkey FOREIGN KEY (chat_jid, chat_receiver, target_jid)
|
||||
REFERENCES message(chat_jid, chat_receiver, jid)
|
||||
ON DELETE CASCADE ON UPDATE CASCADE
|
||||
)`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,45 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[39] = upgrade{"Add backfill queue", func(tx *sql.Tx, ctx context) error {
|
||||
// The queue_id needs to auto-increment every insertion. For SQLite,
|
||||
// INTEGER PRIMARY KEY is an alias for the ROWID, so it will
|
||||
// auto-increment. See https://sqlite.org/lang_createtable.html#rowid
|
||||
// For Postgres, we need to add GENERATED ALWAYS AS IDENTITY for the
|
||||
// same functionality.
|
||||
queueIDColumnTypeModifier := ""
|
||||
if ctx.dialect == Postgres {
|
||||
queueIDColumnTypeModifier = "GENERATED ALWAYS AS IDENTITY"
|
||||
}
|
||||
|
||||
_, err := tx.Exec(fmt.Sprintf(`
|
||||
CREATE TABLE backfill_queue (
|
||||
queue_id INTEGER PRIMARY KEY %s,
|
||||
user_mxid TEXT,
|
||||
type INTEGER NOT NULL,
|
||||
priority INTEGER NOT NULL,
|
||||
portal_jid TEXT,
|
||||
portal_receiver TEXT,
|
||||
time_start TIMESTAMP,
|
||||
time_end TIMESTAMP,
|
||||
max_batch_events INTEGER NOT NULL,
|
||||
max_total_events INTEGER,
|
||||
batch_delay INTEGER,
|
||||
completed_at TIMESTAMP,
|
||||
|
||||
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
|
||||
)
|
||||
`, queueIDColumnTypeModifier))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,52 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[40] = upgrade{"Store history syncs for later backfills", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`
|
||||
CREATE TABLE history_sync_conversation (
|
||||
user_mxid TEXT,
|
||||
conversation_id TEXT,
|
||||
portal_jid TEXT,
|
||||
portal_receiver TEXT,
|
||||
last_message_timestamp TIMESTAMP,
|
||||
archived BOOLEAN,
|
||||
pinned INTEGER,
|
||||
mute_end_time TIMESTAMP,
|
||||
disappearing_mode INTEGER,
|
||||
end_of_history_transfer_type INTEGER,
|
||||
ephemeral_expiration INTEGER,
|
||||
marked_as_unread BOOLEAN,
|
||||
unread_count INTEGER,
|
||||
|
||||
PRIMARY KEY (user_mxid, conversation_id),
|
||||
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`
|
||||
CREATE TABLE history_sync_message (
|
||||
user_mxid TEXT,
|
||||
conversation_id TEXT,
|
||||
message_id TEXT,
|
||||
timestamp TIMESTAMP,
|
||||
data BYTEA,
|
||||
|
||||
PRIMARY KEY (user_mxid, conversation_id, message_id),
|
||||
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
FOREIGN KEY (user_mxid, conversation_id) REFERENCES history_sync_conversation(user_mxid, conversation_id) ON DELETE CASCADE
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}}
|
||||
}
|
|
@ -1,20 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[41] = upgrade{"Update backfill queue tables to be sortable by priority", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`
|
||||
UPDATE backfill_queue
|
||||
SET type=CASE
|
||||
WHEN type=1 THEN 200
|
||||
WHEN type=2 THEN 300
|
||||
ELSE type
|
||||
END
|
||||
WHERE type=1 OR type=2
|
||||
`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,26 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[42] = upgrade{"Add table of media to request from the user's phone", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`
|
||||
CREATE TABLE media_backfill_requests (
|
||||
user_mxid TEXT,
|
||||
portal_jid TEXT,
|
||||
portal_receiver TEXT,
|
||||
event_id TEXT,
|
||||
media_key BYTEA,
|
||||
status INTEGER,
|
||||
error TEXT,
|
||||
|
||||
PRIMARY KEY (user_mxid, portal_jid, portal_receiver, event_id),
|
||||
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
|
||||
)
|
||||
`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[43] = upgrade{"Add timezone column to user table", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN timezone TEXT`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,34 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[44] = upgrade{"Add dispatch time to backfill queue", func(tx *sql.Tx, ctx context) error {
|
||||
// First, add dispatch_time TIMESTAMP column
|
||||
_, err := tx.Exec(`
|
||||
ALTER TABLE backfill_queue
|
||||
ADD COLUMN dispatch_time TIMESTAMP
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// For all previous jobs, set dispatch time to the completed time.
|
||||
_, err = tx.Exec(`
|
||||
UPDATE backfill_queue
|
||||
SET dispatch_time=completed_at
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove time_end from the backfill queue
|
||||
_, err = tx.Exec(`
|
||||
ALTER TABLE backfill_queue
|
||||
DROP COLUMN time_end
|
||||
`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,16 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[45] = upgrade{"Add inserted time to history sync message", func(tx *sql.Tx, ctx context) error {
|
||||
// Add the inserted time TIMESTAMP column to history_sync_message
|
||||
_, err := tx.Exec(`
|
||||
ALTER TABLE history_sync_message
|
||||
ADD COLUMN inserted_time TIMESTAMP
|
||||
`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[46] = upgrade{"Create the backfill state table", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`
|
||||
CREATE TABLE backfill_state (
|
||||
user_mxid TEXT,
|
||||
portal_jid TEXT,
|
||||
portal_receiver TEXT,
|
||||
processing_batch BOOLEAN,
|
||||
backfill_complete BOOLEAN,
|
||||
first_expected_ts INTEGER,
|
||||
|
||||
PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
|
||||
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
|
||||
)
|
||||
`)
|
||||
return err
|
||||
}}
|
||||
}
|
5
database/upgrades/45-backfillqueue-dispatch-time.sql
Normal file
5
database/upgrades/45-backfillqueue-dispatch-time.sql
Normal file
|
@ -0,0 +1,5 @@
|
|||
-- v45: Add dispatch time to backfill queue
|
||||
|
||||
ALTER TABLE backfill_queue ADD COLUMN dispatch_time TIMESTAMP;
|
||||
UPDATE backfill_queue SET dispatch_time=completed_at;
|
||||
ALTER TABLE backfill_queue DROP COLUMN time_end;
|
|
@ -0,0 +1,3 @@
|
|||
-- v46: Add inserted time to history sync message
|
||||
|
||||
ALTER TABLE history_sync_message ADD COLUMN inserted_time TIMESTAMP;
|
13
database/upgrades/47-room-backfill-state.sql
Normal file
13
database/upgrades/47-room-backfill-state.sql
Normal file
|
@ -0,0 +1,13 @@
|
|||
-- v47: Add table for keeping track of backfill state
|
||||
|
||||
CREATE TABLE backfill_state (
|
||||
user_mxid TEXT,
|
||||
portal_jid TEXT,
|
||||
portal_receiver TEXT,
|
||||
processing_batch BOOLEAN,
|
||||
backfill_complete BOOLEAN,
|
||||
first_expected_ts TIMESTAMP,
|
||||
PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
|
||||
FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal (jid, receiver) ON DELETE CASCADE
|
||||
);
|
7
database/upgrades/48-crypto-store-handling-split.sql
Normal file
7
database/upgrades/48-crypto-store-handling-split.sql
Normal file
|
@ -0,0 +1,7 @@
|
|||
-- v48: Move crypto/state/whatsmeow store upgrade handling to separate systems
|
||||
CREATE TABLE crypto_version (version INTEGER PRIMARY KEY);
|
||||
INSERT INTO crypto_version VALUES (6);
|
||||
CREATE TABLE whatsmeow_version (version INTEGER PRIMARY KEY);
|
||||
INSERT INTO whatsmeow_version VALUES (1);
|
||||
CREATE TABLE mx_version (version INTEGER PRIMARY KEY);
|
||||
INSERT INTO mx_version VALUES (1);
|
|
@ -1,180 +1,27 @@
|
|||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
type Dialect int
|
||||
var Table dbutil.UpgradeTable
|
||||
|
||||
const (
|
||||
Postgres Dialect = iota
|
||||
SQLite
|
||||
)
|
||||
//go:embed *.sql
|
||||
var rawUpgrades embed.FS
|
||||
|
||||
func (dialect Dialect) String() string {
|
||||
switch dialect {
|
||||
case Postgres:
|
||||
return "postgres"
|
||||
case SQLite:
|
||||
return "sqlite3"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
type upgradeFunc func(*sql.Tx, context) error
|
||||
|
||||
type context struct {
|
||||
dialect Dialect
|
||||
db *sql.DB
|
||||
log log.Logger
|
||||
}
|
||||
|
||||
type upgrade struct {
|
||||
message string
|
||||
fn upgradeFunc
|
||||
}
|
||||
|
||||
const NumberOfUpgrades = 47
|
||||
|
||||
var upgrades [NumberOfUpgrades]upgrade
|
||||
|
||||
var ErrUnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version")
|
||||
var ErrForeignTables = fmt.Errorf("the database contains foreign tables")
|
||||
var ErrNotOwned = fmt.Errorf("the database is owned by")
|
||||
var IgnoreForeignTables = false
|
||||
|
||||
const databaseOwner = "mautrix-whatsapp"
|
||||
|
||||
func GetVersion(db *sql.DB) (int, error) {
|
||||
_, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)")
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
version := 0
|
||||
err = db.QueryRow("SELECT version FROM version LIMIT 1").Scan(&version)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return -1, err
|
||||
}
|
||||
return version, nil
|
||||
}
|
||||
|
||||
const tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)"
|
||||
const tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND table_name=$1)"
|
||||
|
||||
func tableExists(dialect Dialect, db *sql.DB, table string) (exists bool) {
|
||||
if dialect == SQLite {
|
||||
_ = db.QueryRow(tableExistsSQLite, table).Scan(&exists)
|
||||
} else if dialect == Postgres {
|
||||
_ = db.QueryRow(tableExistsPostgres, table).Scan(&exists)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const createOwnerTable = `
|
||||
CREATE TABLE IF NOT EXISTS database_owner (
|
||||
key INTEGER PRIMARY KEY DEFAULT 0,
|
||||
owner TEXT NOT NULL
|
||||
)
|
||||
`
|
||||
|
||||
func CheckDatabaseOwner(dialect Dialect, db *sql.DB) error {
|
||||
var owner string
|
||||
if !IgnoreForeignTables {
|
||||
if tableExists(dialect, db, "state_groups_state") {
|
||||
return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables)
|
||||
} else if tableExists(dialect, db, "goose_db_version") {
|
||||
return fmt.Errorf("%w (found goose_db_version, possibly belonging to Dendrite)", ErrForeignTables)
|
||||
}
|
||||
}
|
||||
if _, err := db.Exec(createOwnerTable); err != nil {
|
||||
return fmt.Errorf("failed to ensure database owner table exists: %w", err)
|
||||
} else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
|
||||
_, err = db.Exec("INSERT INTO database_owner (owner) VALUES ($1)", databaseOwner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert database owner: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to check database owner: %w", err)
|
||||
} else if owner != databaseOwner {
|
||||
return fmt.Errorf("%w %s", ErrNotOwned, owner)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetVersion(tx *sql.Tx, version int) error {
|
||||
_, err := tx.Exec("DELETE FROM version")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("INSERT INTO version (version) VALUES ($1)", version)
|
||||
return err
|
||||
}
|
||||
|
||||
func execMany(tx *sql.Tx, queries ...string) error {
|
||||
for _, query := range queries {
|
||||
_, err := tx.Exec(query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Run(log log.Logger, dialectName string, db *sql.DB) error {
|
||||
var dialect Dialect
|
||||
switch strings.ToLower(dialectName) {
|
||||
case "postgres":
|
||||
dialect = Postgres
|
||||
case "sqlite3":
|
||||
dialect = SQLite
|
||||
default:
|
||||
return fmt.Errorf("unknown dialect %s", dialectName)
|
||||
}
|
||||
|
||||
err := CheckDatabaseOwner(dialect, db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
version, err := GetVersion(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if version > NumberOfUpgrades {
|
||||
return fmt.Errorf("%w: currently on v%d, latest known: v%d", ErrUnsupportedDatabaseVersion, version, NumberOfUpgrades)
|
||||
}
|
||||
|
||||
log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades)
|
||||
for i, upgradeItem := range upgrades[version:] {
|
||||
if upgradeItem.fn == nil {
|
||||
continue
|
||||
}
|
||||
log.Infofln("Upgrading database to v%d: %s", version+i+1, upgradeItem.message)
|
||||
var tx *sql.Tx
|
||||
tx, err = db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = upgradeItem.fn(tx, context{dialect, db, log})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = SetVersion(tx, version+i+1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
func init() {
|
||||
Table.Register(-1, 43, "Unsupported version", func(tx *sql.Tx, database *dbutil.Database) error {
|
||||
return errors.New("please upgrade to mautrix-whatsapp v0.4.0 before upgrading to a newer version")
|
||||
})
|
||||
Table.RegisterFS(rawUpgrades)
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
)
|
||||
|
@ -89,7 +90,7 @@ type User struct {
|
|||
inSpaceCacheLock sync.Mutex
|
||||
}
|
||||
|
||||
func (user *User) Scan(row Scannable) *User {
|
||||
func (user *User) Scan(row dbutil.Scannable) *User {
|
||||
var username, timezone sql.NullString
|
||||
var device, agent sql.NullByte
|
||||
var phoneLastSeen, phoneLastPinged sql.NullInt64
|
||||
|
|
|
@ -50,9 +50,9 @@ func (portal *Portal) ScheduleDisappearing() {
|
|||
}
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SleepAndDeleteUpcoming() {
|
||||
for _, msg := range bridge.DB.DisappearingMessage.GetUpcomingScheduled(1 * time.Hour) {
|
||||
portal := bridge.GetPortalByMXID(msg.RoomID)
|
||||
func (br *WABridge) SleepAndDeleteUpcoming() {
|
||||
for _, msg := range br.DB.DisappearingMessage.GetUpcomingScheduled(1 * time.Hour) {
|
||||
portal := br.GetPortalByMXID(msg.RoomID)
|
||||
if portal == nil {
|
||||
msg.Delete()
|
||||
} else {
|
||||
|
|
|
@ -43,14 +43,6 @@ appservice:
|
|||
max_conn_idle_time: null
|
||||
max_conn_lifetime: null
|
||||
|
||||
# Settings for provisioning API
|
||||
provisioning:
|
||||
# Prefix for the provisioning API paths.
|
||||
prefix: /_matrix/provision
|
||||
# Shared secret for authentication. If set to "generate", a random secret will be generated,
|
||||
# or if set to "disable", the provisioning API will be disabled.
|
||||
shared_secret: generate
|
||||
|
||||
# The unique ID of this appservice.
|
||||
id: whatsapp
|
||||
# Appservice bot details.
|
||||
|
@ -317,6 +309,14 @@ bridge:
|
|||
# Verification by the bridge is not yet implemented.
|
||||
require_verification: true
|
||||
|
||||
# Settings for provisioning API
|
||||
provisioning:
|
||||
# Prefix for the provisioning API paths.
|
||||
prefix: /_matrix/provision
|
||||
# Shared secret for authentication. If set to "generate", a random secret will be generated,
|
||||
# or if set to "disable", the provisioning API will be disabled.
|
||||
shared_secret: generate
|
||||
|
||||
# Permissions for using the bridge.
|
||||
# Permitted values:
|
||||
# relay - Talk through the relaybot (if enabled), no access otherwise
|
||||
|
|
|
@ -37,7 +37,7 @@ var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```")
|
|||
const mentionedJIDsContextKey = "net.maunium.whatsapp.mentioned_jids"
|
||||
|
||||
type Formatter struct {
|
||||
bridge *Bridge
|
||||
bridge *WABridge
|
||||
|
||||
matrixHTMLParser *format.HTMLParser
|
||||
|
||||
|
@ -46,7 +46,7 @@ type Formatter struct {
|
|||
waReplFuncText map[*regexp.Regexp]func(string) string
|
||||
}
|
||||
|
||||
func NewFormatter(bridge *Bridge) *Formatter {
|
||||
func NewFormatter(bridge *WABridge) *Formatter {
|
||||
formatter := &Formatter{
|
||||
bridge: bridge,
|
||||
matrixHTMLParser: &format.HTMLParser{
|
||||
|
|
7
go.mod
7
go.mod
|
@ -14,10 +14,8 @@ require (
|
|||
golang.org/x/image v0.0.0-20220413100746-70e8d0d3baa9
|
||||
golang.org/x/net v0.0.0-20220513224357-95641704303c
|
||||
google.golang.org/protobuf v1.28.0
|
||||
gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99
|
||||
maunium.net/go/mauflag v1.0.0
|
||||
maunium.net/go/maulogger/v2 v2.3.2
|
||||
maunium.net/go/mautrix v0.11.1-0.20220518174602-87d2cd49a4d1
|
||||
maunium.net/go/mautrix v0.11.1-0.20220521215033-d578d1a610d5
|
||||
)
|
||||
|
||||
require (
|
||||
|
@ -37,7 +35,8 @@ require (
|
|||
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 // indirect
|
||||
golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886 // indirect
|
||||
golang.org/x/text v0.3.7 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.0 // indirect
|
||||
maunium.net/go/mauflag v1.0.0 // indirect
|
||||
)
|
||||
|
||||
// Exclude some things that cause go.sum to explode
|
||||
|
|
9
go.sum
9
go.sum
|
@ -99,14 +99,13 @@ google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw
|
|||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99 h1:dbuHpmKjkDzSOMKAWl10QNlgaZUd3V1q99xc81tt2Kc=
|
||||
gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0 h1:hjy8E9ON/egN1tAYqKb61G10WtihqetD4sz2H+8nIeA=
|
||||
gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
|
||||
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
|
||||
maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0=
|
||||
maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A=
|
||||
maunium.net/go/mautrix v0.11.1-0.20220518174602-87d2cd49a4d1 h1:+KEF+nSuBfHWsfQRz92YP/DdSLbComLoXCXgcrH6WRU=
|
||||
maunium.net/go/mautrix v0.11.1-0.20220518174602-87d2cd49a4d1/go.mod h1:K29EcHwsNg6r7fMfwvi0GHQ9o5wSjqB9+Q8RjCIQEjA=
|
||||
maunium.net/go/mautrix v0.11.1-0.20220521215033-d578d1a610d5 h1:7ZORg2h+lflc1HwjTKCXZnykauXD+wzbW+VDknbv6SU=
|
||||
maunium.net/go/mautrix v0.11.1-0.20220521215033-d578d1a610d5/go.mod h1:oma8o6Y/5jcViBlDbX7tp1ajP2XP+b78h8twdI+zKI0=
|
||||
|
|
515
main.go
515
main.go
|
@ -18,43 +18,26 @@ package main
|
|||
|
||||
import (
|
||||
_ "embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"go.mau.fi/whatsmeow"
|
||||
waProto "go.mau.fi/whatsmeow/binary/proto"
|
||||
"go.mau.fi/whatsmeow/store"
|
||||
"go.mau.fi/whatsmeow/store/sqlstore"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
flag "maunium.net/go/mauflag"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/bridge"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/configupgrade"
|
||||
|
||||
"maunium.net/go/mautrix-whatsapp/config"
|
||||
"maunium.net/go/mautrix-whatsapp/database"
|
||||
"maunium.net/go/mautrix-whatsapp/database/upgrades"
|
||||
)
|
||||
|
||||
// The name and repo URL of the bridge.
|
||||
var (
|
||||
Name = "mautrix-whatsapp"
|
||||
URL = "https://github.com/mautrix/whatsapp"
|
||||
)
|
||||
|
||||
// Information to find out exactly which commit the bridge was built from.
|
||||
|
@ -65,120 +48,19 @@ var (
|
|||
BuildTime = "unknown"
|
||||
)
|
||||
|
||||
var (
|
||||
// Version is the version number of the bridge. Changed manually when making a release.
|
||||
Version = "0.4.0"
|
||||
// WAVersion is the version number exposed to WhatsApp. Filled in init()
|
||||
WAVersion = ""
|
||||
// VersionString is the bridge version, plus commit information. Filled in init() using the build-time values.
|
||||
VersionString = ""
|
||||
)
|
||||
|
||||
//go:embed example-config.yaml
|
||||
var ExampleConfig string
|
||||
|
||||
func init() {
|
||||
if len(Tag) > 0 && Tag[0] == 'v' {
|
||||
Tag = Tag[1:]
|
||||
}
|
||||
if Tag != Version {
|
||||
suffix := ""
|
||||
if !strings.HasSuffix(Version, "+dev") {
|
||||
suffix = "+dev"
|
||||
}
|
||||
if len(Commit) > 8 {
|
||||
Version = fmt.Sprintf("%s%s.%s", Version, suffix, Commit[:8])
|
||||
} else {
|
||||
Version = fmt.Sprintf("%s%s.unknown", Version, suffix)
|
||||
}
|
||||
}
|
||||
mautrix.DefaultUserAgent = fmt.Sprintf("mautrix-whatsapp/%s %s", Version, mautrix.DefaultUserAgent)
|
||||
WAVersion = strings.FieldsFunc(Version, func(r rune) bool { return r == '-' || r == '+' })[0]
|
||||
VersionString = fmt.Sprintf("%s %s (%s)", Name, Version, BuildTime)
|
||||
|
||||
config.ExampleConfig = ExampleConfig
|
||||
}
|
||||
|
||||
var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String()
|
||||
var dontSaveConfig = flag.MakeFull("n", "no-update", "Don't save updated config to disk.", "false").Bool()
|
||||
var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String()
|
||||
var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool()
|
||||
var version = flag.MakeFull("v", "version", "View bridge version and quit.", "false").Bool()
|
||||
var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if the database schema is too new").Default("false").Bool()
|
||||
var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Run even if the database contains tables from other programs (like Synapse)").Default("false").Bool()
|
||||
var migrateFrom = flag.Make().LongKey("migrate-db").Usage("Source database type and URI to migrate from.").Bool()
|
||||
var wantHelp, _ = flag.MakeHelpFlag()
|
||||
|
||||
func (bridge *Bridge) GenerateRegistration() {
|
||||
if *dontSaveConfig {
|
||||
// We need to save the generated as_token and hs_token in the config
|
||||
_, _ = fmt.Fprintln(os.Stderr, "--no-update is not compatible with --generate-registration")
|
||||
os.Exit(5)
|
||||
}
|
||||
reg, err := bridge.Config.NewRegistration()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Failed to generate registration:", err)
|
||||
os.Exit(20)
|
||||
}
|
||||
|
||||
err = reg.Save(*registrationPath)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Failed to save registration:", err)
|
||||
os.Exit(21)
|
||||
}
|
||||
|
||||
err = config.Mutate(*configPath, func(helper *configupgrade.Helper) {
|
||||
helper.Set(configupgrade.Str, bridge.Config.AppService.ASToken, "appservice", "as_token")
|
||||
helper.Set(configupgrade.Str, bridge.Config.AppService.HSToken, "appservice", "hs_token")
|
||||
})
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Failed to save config:", err)
|
||||
os.Exit(22)
|
||||
}
|
||||
fmt.Println("Registration generated. Add the path to the registration to your Synapse config, restart it, then start the bridge.")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) MigrateDatabase() {
|
||||
oldDB, err := database.New(config.DatabaseConfig{Type: flag.Arg(0), URI: flag.Arg(1)}, log.DefaultLogger)
|
||||
if err != nil {
|
||||
fmt.Println("Failed to open old database:", err)
|
||||
os.Exit(30)
|
||||
}
|
||||
err = oldDB.Init()
|
||||
if err != nil {
|
||||
fmt.Println("Failed to upgrade old database:", err)
|
||||
os.Exit(31)
|
||||
}
|
||||
|
||||
newDB, err := database.New(bridge.Config.AppService.Database, log.DefaultLogger)
|
||||
if err != nil {
|
||||
fmt.Println("Failed to open new database:", err)
|
||||
os.Exit(32)
|
||||
}
|
||||
err = newDB.Init()
|
||||
if err != nil {
|
||||
fmt.Println("Failed to upgrade new database:", err)
|
||||
os.Exit(33)
|
||||
}
|
||||
|
||||
database.Migrate(oldDB, newDB)
|
||||
}
|
||||
|
||||
type Bridge struct {
|
||||
AS *appservice.AppService
|
||||
EventProcessor *appservice.EventProcessor
|
||||
MatrixHandler *MatrixHandler
|
||||
Config *config.Config
|
||||
DB *database.Database
|
||||
Log log.Logger
|
||||
StateStore *database.SQLStateStore
|
||||
Provisioning *ProvisioningAPI
|
||||
Bot *appservice.IntentAPI
|
||||
Formatter *Formatter
|
||||
Crypto Crypto
|
||||
Metrics *MetricsHandler
|
||||
WAContainer *sqlstore.Container
|
||||
type WABridge struct {
|
||||
bridge.Bridge
|
||||
MatrixHandler *MatrixHandler
|
||||
Config *config.Config
|
||||
DB *database.Database
|
||||
Provisioning *ProvisioningAPI
|
||||
Formatter *Formatter
|
||||
Metrics *MetricsHandler
|
||||
WAContainer *sqlstore.Container
|
||||
WAVersion string
|
||||
|
||||
usersByMXID map[id.UserID]*User
|
||||
usersByUsername map[string]*User
|
||||
|
@ -195,111 +77,32 @@ type Bridge struct {
|
|||
puppetsLock sync.Mutex
|
||||
}
|
||||
|
||||
type Crypto interface {
|
||||
HandleMemberEvent(*event.Event)
|
||||
Decrypt(*event.Event) (*event.Event, error)
|
||||
Encrypt(id.RoomID, event.Type, event.Content) (*event.EncryptedEventContent, error)
|
||||
WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
|
||||
RequestSession(id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID)
|
||||
ResetSession(id.RoomID)
|
||||
Init() error
|
||||
Start()
|
||||
Stop()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) ensureConnection() {
|
||||
for {
|
||||
versions, err := bridge.Bot.Versions()
|
||||
if err != nil {
|
||||
bridge.Log.Errorfln("Failed to connect to homeserver: %v. Retrying in 10 seconds...", err)
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
if !versions.ContainsGreaterOrEqual(mautrix.SpecV11) {
|
||||
bridge.Log.Warnfln("Server isn't advertising modern spec versions")
|
||||
}
|
||||
resp, err := bridge.Bot.Whoami()
|
||||
if err != nil {
|
||||
if errors.Is(err, mautrix.MUnknownToken) {
|
||||
bridge.Log.Fatalln("The as_token was not accepted. Is the registration file installed in your homeserver correctly?")
|
||||
os.Exit(16)
|
||||
} else if errors.Is(err, mautrix.MExclusive) {
|
||||
bridge.Log.Fatalln("The as_token was accepted, but the /register request was not. Are the homeserver domain and username template in the config correct, and do they match the values in the registration?")
|
||||
os.Exit(16)
|
||||
}
|
||||
bridge.Log.Errorfln("Failed to connect to homeserver: %v. Retrying in 10 seconds...", err)
|
||||
time.Sleep(10 * time.Second)
|
||||
} else if resp.UserID != bridge.Bot.UserID {
|
||||
bridge.Log.Fatalln("Unexpected user ID in whoami call: got %s, expected %s", resp.UserID, bridge.Bot.UserID)
|
||||
os.Exit(17)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (bridge *Bridge) Init() {
|
||||
var err error
|
||||
|
||||
bridge.AS, err = bridge.Config.MakeAppService()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Failed to initialize AppService:", err)
|
||||
os.Exit(11)
|
||||
}
|
||||
_, _ = bridge.AS.Init()
|
||||
|
||||
bridge.Log = log.Create()
|
||||
bridge.Config.Logging.Configure(bridge.Log)
|
||||
log.DefaultLogger = bridge.Log.(*log.BasicLogger)
|
||||
if len(bridge.Config.Logging.FileNameFormat) > 0 {
|
||||
err = log.OpenFile()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Failed to open log file:", err)
|
||||
os.Exit(12)
|
||||
}
|
||||
}
|
||||
bridge.AS.Log = log.Sub("Matrix")
|
||||
bridge.Bot = bridge.AS.BotIntent()
|
||||
bridge.Log.Infoln("Initializing", VersionString)
|
||||
|
||||
bridge.Log.Debugln("Initializing database connection")
|
||||
bridge.DB, err = database.New(bridge.Config.AppService.Database, bridge.Log)
|
||||
if err != nil {
|
||||
bridge.Log.Fatalln("Failed to initialize database connection:", err)
|
||||
os.Exit(14)
|
||||
}
|
||||
|
||||
bridge.Log.Debugln("Initializing state store")
|
||||
bridge.StateStore = database.NewSQLStateStore(bridge.DB)
|
||||
bridge.AS.StateStore = bridge.StateStore
|
||||
|
||||
Segment.log = bridge.Log.Sub("Segment")
|
||||
Segment.key = bridge.Config.SegmentKey
|
||||
func (br *WABridge) Init() {
|
||||
Segment.log = br.Log.Sub("Segment")
|
||||
Segment.key = br.Config.SegmentKey
|
||||
if Segment.IsEnabled() {
|
||||
Segment.log.Infoln("Segment metrics are enabled")
|
||||
}
|
||||
|
||||
bridge.WAContainer = sqlstore.NewWithDB(bridge.DB.DB, bridge.Config.AppService.Database.Type, nil)
|
||||
bridge.WAContainer.DatabaseErrorHandler = bridge.DB.HandleSignalStoreError
|
||||
br.DB = database.New(br.Bridge.DB)
|
||||
br.WAContainer = sqlstore.NewWithDB(br.DB.DB, br.DB.Dialect.String(), nil)
|
||||
br.WAContainer.DatabaseErrorHandler = br.DB.HandleSignalStoreError
|
||||
|
||||
ss := bridge.Config.AppService.Provisioning.SharedSecret
|
||||
ss := br.Config.Bridge.Provisioning.SharedSecret
|
||||
if len(ss) > 0 && ss != "disable" {
|
||||
bridge.Provisioning = &ProvisioningAPI{bridge: bridge}
|
||||
br.Provisioning = &ProvisioningAPI{bridge: br}
|
||||
}
|
||||
|
||||
bridge.Log.Debugln("Initializing Matrix event processor")
|
||||
bridge.EventProcessor = appservice.NewEventProcessor(bridge.AS)
|
||||
bridge.Log.Debugln("Initializing Matrix event handler")
|
||||
bridge.MatrixHandler = NewMatrixHandler(bridge)
|
||||
bridge.Formatter = NewFormatter(bridge)
|
||||
bridge.Crypto = NewCryptoHelper(bridge)
|
||||
bridge.Metrics = NewMetricsHandler(bridge.Config.Metrics.Listen, bridge.Log.Sub("Metrics"), bridge.DB)
|
||||
br.Log.Debugln("Initializing Matrix event handler")
|
||||
br.MatrixHandler = NewMatrixHandler(br)
|
||||
br.Formatter = NewFormatter(br)
|
||||
br.Metrics = NewMetricsHandler(br.Config.Metrics.Listen, br.Log.Sub("Metrics"), br.DB)
|
||||
|
||||
store.BaseClientPayload.UserAgent.OsVersion = proto.String(WAVersion)
|
||||
store.BaseClientPayload.UserAgent.OsBuildNumber = proto.String(WAVersion)
|
||||
store.CompanionProps.Os = proto.String(bridge.Config.WhatsApp.OSName)
|
||||
store.CompanionProps.RequireFullSync = proto.Bool(bridge.Config.Bridge.HistorySync.RequestFullSync)
|
||||
versionParts := strings.Split(WAVersion, ".")
|
||||
store.BaseClientPayload.UserAgent.OsVersion = proto.String(br.WAVersion)
|
||||
store.BaseClientPayload.UserAgent.OsBuildNumber = proto.String(br.WAVersion)
|
||||
store.CompanionProps.Os = proto.String(br.Config.WhatsApp.OSName)
|
||||
store.CompanionProps.RequireFullSync = proto.Bool(br.Config.Bridge.HistorySync.RequestFullSync)
|
||||
versionParts := strings.Split(br.WAVersion, ".")
|
||||
if len(versionParts) > 2 {
|
||||
primary, _ := strconv.Atoi(versionParts[0])
|
||||
secondary, _ := strconv.Atoi(versionParts[1])
|
||||
|
@ -308,161 +111,107 @@ func (bridge *Bridge) Init() {
|
|||
store.CompanionProps.Version.Secondary = proto.Uint32(uint32(secondary))
|
||||
store.CompanionProps.Version.Tertiary = proto.Uint32(uint32(tertiary))
|
||||
}
|
||||
platformID, ok := waProto.CompanionProps_CompanionPropsPlatformType_value[strings.ToUpper(bridge.Config.WhatsApp.BrowserName)]
|
||||
platformID, ok := waProto.CompanionProps_CompanionPropsPlatformType_value[strings.ToUpper(br.Config.WhatsApp.BrowserName)]
|
||||
if ok {
|
||||
store.CompanionProps.PlatformType = waProto.CompanionProps_CompanionPropsPlatformType(platformID).Enum()
|
||||
}
|
||||
}
|
||||
|
||||
func (bridge *Bridge) Start() {
|
||||
bridge.Log.Debugln("Running database upgrades")
|
||||
err := bridge.DB.Init()
|
||||
if err != nil && (!errors.Is(err, upgrades.ErrUnsupportedDatabaseVersion) || !*ignoreUnsupportedDatabase) {
|
||||
bridge.Log.Fatalln("Failed to initialize database:", err)
|
||||
if errors.Is(err, upgrades.ErrForeignTables) {
|
||||
bridge.Log.Infoln("You can use --ignore-foreign-tables to ignore this error")
|
||||
} else if errors.Is(err, upgrades.ErrNotOwned) {
|
||||
bridge.Log.Infoln("Sharing the same database with different programs is not supported")
|
||||
} else if errors.Is(err, upgrades.ErrUnsupportedDatabaseVersion) {
|
||||
bridge.Log.Infoln("Downgrading the bridge is not supported")
|
||||
}
|
||||
func (br *WABridge) Start() {
|
||||
err := br.WAContainer.Upgrade()
|
||||
if err != nil {
|
||||
br.Log.Fatalln("Failed to upgrade whatsmeow database: %v", err)
|
||||
os.Exit(15)
|
||||
}
|
||||
bridge.Log.Debugln("Checking connection to homeserver")
|
||||
bridge.ensureConnection()
|
||||
if bridge.Crypto != nil {
|
||||
err = bridge.Crypto.Init()
|
||||
if err != nil {
|
||||
bridge.Log.Fatalln("Error initializing end-to-bridge encryption:", err)
|
||||
os.Exit(19)
|
||||
}
|
||||
if br.Provisioning != nil {
|
||||
br.Log.Debugln("Initializing provisioning API")
|
||||
br.Provisioning.Init()
|
||||
}
|
||||
if bridge.Provisioning != nil {
|
||||
bridge.Log.Debugln("Initializing provisioning API")
|
||||
bridge.Provisioning.Init()
|
||||
}
|
||||
bridge.Log.Debugln("Starting application service HTTP server")
|
||||
go bridge.AS.Start()
|
||||
bridge.Log.Debugln("Starting event processor")
|
||||
go bridge.EventProcessor.Start()
|
||||
go bridge.CheckWhatsAppUpdate()
|
||||
go bridge.UpdateBotProfile()
|
||||
if bridge.Crypto != nil {
|
||||
go bridge.Crypto.Start()
|
||||
}
|
||||
go bridge.StartUsers()
|
||||
if bridge.Config.Metrics.Enabled {
|
||||
go bridge.Metrics.Start()
|
||||
go br.CheckWhatsAppUpdate()
|
||||
go br.StartUsers()
|
||||
if br.Config.Metrics.Enabled {
|
||||
go br.Metrics.Start()
|
||||
}
|
||||
|
||||
if bridge.Config.Bridge.ResendBridgeInfo {
|
||||
go bridge.ResendBridgeInfo()
|
||||
if br.Config.Bridge.ResendBridgeInfo {
|
||||
go br.ResendBridgeInfo()
|
||||
}
|
||||
go bridge.Loop()
|
||||
bridge.AS.Ready = true
|
||||
go br.Loop()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) CheckWhatsAppUpdate() {
|
||||
bridge.Log.Debugfln("Checking for WhatsApp web update")
|
||||
func (br *WABridge) CheckWhatsAppUpdate() {
|
||||
br.Log.Debugfln("Checking for WhatsApp web update")
|
||||
resp, err := whatsmeow.CheckUpdate(http.DefaultClient)
|
||||
if err != nil {
|
||||
bridge.Log.Warnfln("Failed to check for WhatsApp web update: %v", err)
|
||||
br.Log.Warnfln("Failed to check for WhatsApp web update: %v", err)
|
||||
return
|
||||
}
|
||||
if store.GetWAVersion() == resp.ParsedVersion {
|
||||
bridge.Log.Debugfln("Bridge is using latest WhatsApp web protocol")
|
||||
br.Log.Debugfln("Bridge is using latest WhatsApp web protocol")
|
||||
} else if store.GetWAVersion().LessThan(resp.ParsedVersion) {
|
||||
if resp.IsBelowHard || resp.IsBroken {
|
||||
bridge.Log.Warnfln("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
|
||||
br.Log.Warnfln("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
|
||||
} else if resp.IsBelowSoft {
|
||||
bridge.Log.Infofln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
|
||||
br.Log.Infofln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
|
||||
} else {
|
||||
bridge.Log.Debugfln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
|
||||
br.Log.Debugfln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
|
||||
}
|
||||
} else {
|
||||
bridge.Log.Debugfln("Bridge is using newer than latest WhatsApp web protocol")
|
||||
br.Log.Debugfln("Bridge is using newer than latest WhatsApp web protocol")
|
||||
}
|
||||
}
|
||||
|
||||
func (bridge *Bridge) Loop() {
|
||||
func (br *WABridge) Loop() {
|
||||
for {
|
||||
bridge.SleepAndDeleteUpcoming()
|
||||
br.SleepAndDeleteUpcoming()
|
||||
time.Sleep(1 * time.Hour)
|
||||
bridge.WarnUsersAboutDisconnection()
|
||||
br.WarnUsersAboutDisconnection()
|
||||
}
|
||||
}
|
||||
|
||||
func (bridge *Bridge) WarnUsersAboutDisconnection() {
|
||||
bridge.usersLock.Lock()
|
||||
for _, user := range bridge.usersByUsername {
|
||||
func (br *WABridge) WarnUsersAboutDisconnection() {
|
||||
br.usersLock.Lock()
|
||||
for _, user := range br.usersByUsername {
|
||||
if user.IsConnected() && !user.PhoneRecentlySeen(true) {
|
||||
go user.sendPhoneOfflineWarning()
|
||||
}
|
||||
}
|
||||
bridge.usersLock.Unlock()
|
||||
br.usersLock.Unlock()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) ResendBridgeInfo() {
|
||||
if *dontSaveConfig {
|
||||
bridge.Log.Warnln("Not setting resend_bridge_info to false in config due to --no-update flag")
|
||||
} else {
|
||||
err := config.Mutate(*configPath, func(helper *configupgrade.Helper) {
|
||||
helper.Set(configupgrade.Bool, "false", "bridge", "resend_bridge_info")
|
||||
})
|
||||
if err != nil {
|
||||
bridge.Log.Errorln("Failed to save config after setting resend_bridge_info to false:", err)
|
||||
}
|
||||
}
|
||||
bridge.Log.Infoln("Re-sending bridge info state event to all portals")
|
||||
for _, portal := range bridge.GetAllPortals() {
|
||||
portal.UpdateBridgeInfo()
|
||||
}
|
||||
bridge.Log.Infoln("Finished re-sending bridge info state events")
|
||||
func (br *WABridge) ResendBridgeInfo() {
|
||||
// FIXME
|
||||
//if *dontSaveConfig {
|
||||
// br.Log.Warnln("Not setting resend_bridge_info to false in config due to --no-update flag")
|
||||
//} else {
|
||||
// err := config.Mutate(*configPath, func(helper *configupgrade.Helper) {
|
||||
// helper.Set(configupgrade.Bool, "false", "bridge", "resend_bridge_info")
|
||||
// })
|
||||
// if err != nil {
|
||||
// br.Log.Errorln("Failed to save config after setting resend_bridge_info to false:", err)
|
||||
// }
|
||||
//}
|
||||
//br.Log.Infoln("Re-sending bridge info state event to all portals")
|
||||
//for _, portal := range br.GetAllPortals() {
|
||||
// portal.UpdateBridgeInfo()
|
||||
//}
|
||||
//br.Log.Infoln("Finished re-sending bridge info state events")
|
||||
}
|
||||
|
||||
func (bridge *Bridge) UpdateBotProfile() {
|
||||
bridge.Log.Debugln("Updating bot profile")
|
||||
botConfig := &bridge.Config.AppService.Bot
|
||||
|
||||
var err error
|
||||
var mxc id.ContentURI
|
||||
if botConfig.Avatar == "remove" {
|
||||
err = bridge.Bot.SetAvatarURL(mxc)
|
||||
} else if len(botConfig.Avatar) > 0 {
|
||||
mxc, err = id.ParseContentURI(botConfig.Avatar)
|
||||
if err == nil {
|
||||
err = bridge.Bot.SetAvatarURL(mxc)
|
||||
}
|
||||
botConfig.ParsedAvatar = mxc
|
||||
}
|
||||
if err != nil {
|
||||
bridge.Log.Warnln("Failed to update bot avatar:", err)
|
||||
}
|
||||
|
||||
if botConfig.Displayname == "remove" {
|
||||
err = bridge.Bot.SetDisplayName("")
|
||||
} else if len(botConfig.Displayname) > 0 {
|
||||
err = bridge.Bot.SetDisplayName(botConfig.Displayname)
|
||||
}
|
||||
if err != nil {
|
||||
bridge.Log.Warnln("Failed to update bot displayname:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (bridge *Bridge) StartUsers() {
|
||||
bridge.Log.Debugln("Starting users")
|
||||
func (br *WABridge) StartUsers() {
|
||||
br.Log.Debugln("Starting users")
|
||||
foundAnySessions := false
|
||||
for _, user := range bridge.GetAllUsers() {
|
||||
for _, user := range br.GetAllUsers() {
|
||||
if !user.JID.IsEmpty() {
|
||||
foundAnySessions = true
|
||||
}
|
||||
go user.Connect()
|
||||
}
|
||||
if !foundAnySessions {
|
||||
bridge.sendGlobalBridgeState(BridgeState{StateEvent: StateUnconfigured}.fill(nil))
|
||||
br.sendGlobalBridgeState(BridgeState{StateEvent: StateUnconfigured}.fill(nil))
|
||||
}
|
||||
bridge.Log.Debugln("Starting custom puppets")
|
||||
for _, loopuppet := range bridge.GetAllPuppetsWithCustomMXID() {
|
||||
br.Log.Debugln("Starting custom puppets")
|
||||
for _, loopuppet := range br.GetAllPuppetsWithCustomMXID() {
|
||||
go func(puppet *Puppet) {
|
||||
puppet.log.Debugln("Starting custom puppet", puppet.CustomMXID)
|
||||
err := puppet.StartCustomMXID(true)
|
||||
|
@ -473,80 +222,37 @@ func (bridge *Bridge) StartUsers() {
|
|||
}
|
||||
}
|
||||
|
||||
func (bridge *Bridge) Stop() {
|
||||
if bridge.Crypto != nil {
|
||||
bridge.Crypto.Stop()
|
||||
func (br *WABridge) Stop() {
|
||||
if br.Crypto != nil {
|
||||
br.Crypto.Stop()
|
||||
}
|
||||
bridge.AS.Stop()
|
||||
bridge.Metrics.Stop()
|
||||
bridge.EventProcessor.Stop()
|
||||
for _, user := range bridge.usersByUsername {
|
||||
br.AS.Stop()
|
||||
br.Metrics.Stop()
|
||||
br.EventProcessor.Stop()
|
||||
for _, user := range br.usersByUsername {
|
||||
if user.Client == nil {
|
||||
continue
|
||||
}
|
||||
bridge.Log.Debugln("Disconnecting", user.MXID)
|
||||
br.Log.Debugln("Disconnecting", user.MXID)
|
||||
user.Client.Disconnect()
|
||||
close(user.historySyncs)
|
||||
}
|
||||
}
|
||||
|
||||
func (bridge *Bridge) Main() {
|
||||
configData, upgraded, err := config.Upgrade(*configPath, !*dontSaveConfig)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Error updating config:", err)
|
||||
if configData == nil {
|
||||
os.Exit(10)
|
||||
}
|
||||
func (br *WABridge) GetExampleConfig() string {
|
||||
return ExampleConfig
|
||||
}
|
||||
|
||||
func (br *WABridge) GetConfigPtr() interface{} {
|
||||
br.Config = &config.Config{
|
||||
BaseConfig: &br.Bridge.Config,
|
||||
}
|
||||
|
||||
bridge.Config, err = config.Load(configData, upgraded)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Failed to parse config:", err)
|
||||
os.Exit(10)
|
||||
}
|
||||
|
||||
if *generateRegistration {
|
||||
bridge.GenerateRegistration()
|
||||
return
|
||||
} else if *migrateFrom {
|
||||
bridge.MigrateDatabase()
|
||||
return
|
||||
}
|
||||
|
||||
bridge.Init()
|
||||
bridge.Log.Infoln("Bridge initialization complete, starting...")
|
||||
bridge.Start()
|
||||
bridge.Log.Infoln("Bridge started!")
|
||||
|
||||
c := make(chan os.Signal)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
<-c
|
||||
|
||||
bridge.Log.Infoln("Interrupt received, stopping...")
|
||||
bridge.Stop()
|
||||
bridge.Log.Infoln("Bridge stopped.")
|
||||
os.Exit(0)
|
||||
br.Config.BaseConfig.Bridge = &br.Config.Bridge
|
||||
return br.Config
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.SetHelpTitles(
|
||||
"mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.",
|
||||
"mautrix-whatsapp [-h] [-c <path>] [-r <path>] [-g] [--migrate-db <source type> <source uri>]")
|
||||
err := flag.Parse()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, err)
|
||||
flag.PrintHelp()
|
||||
os.Exit(1)
|
||||
} else if *wantHelp {
|
||||
flag.PrintHelp()
|
||||
os.Exit(0)
|
||||
} else if *version {
|
||||
fmt.Println(VersionString)
|
||||
return
|
||||
}
|
||||
upgrades.IgnoreForeignTables = *ignoreForeignTables
|
||||
|
||||
(&Bridge{
|
||||
br := &WABridge{
|
||||
usersByMXID: make(map[id.UserID]*User),
|
||||
usersByUsername: make(map[string]*User),
|
||||
spaceRooms: make(map[id.RoomID]*User),
|
||||
|
@ -555,5 +261,24 @@ func main() {
|
|||
portalsByJID: make(map[database.PortalKey]*Portal),
|
||||
puppets: make(map[types.JID]*Puppet),
|
||||
puppetsByCustomMXID: make(map[id.UserID]*Puppet),
|
||||
}).Main()
|
||||
}
|
||||
br.Bridge = bridge.Bridge{
|
||||
Name: "mautrix-whatsapp",
|
||||
URL: "https://github.com/mautrix/whatsapp",
|
||||
Description: "A Matrix-WhatsApp puppeting bridge.",
|
||||
Version: "0.4.0",
|
||||
ProtocolName: "WhatsApp",
|
||||
|
||||
ConfigUpgrader: &configupgrade.StructUpgrader{
|
||||
SimpleUpgrader: configupgrade.SimpleUpgrader(config.DoUpgrade),
|
||||
Blocks: config.SpacedBlocks,
|
||||
Base: ExampleConfig,
|
||||
},
|
||||
|
||||
Child: br,
|
||||
}
|
||||
br.InitVersion(Tag, Commit, BuildTime)
|
||||
br.WAVersion = strings.FieldsFunc(br.Version, func(r rune) bool { return r == '-' || r == '+' })[0]
|
||||
|
||||
br.Main()
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ import (
|
|||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/bridge"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/format"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
@ -36,13 +37,13 @@ import (
|
|||
)
|
||||
|
||||
type MatrixHandler struct {
|
||||
bridge *Bridge
|
||||
bridge *WABridge
|
||||
as *appservice.AppService
|
||||
log maulogger.Logger
|
||||
cmd *CommandHandler
|
||||
}
|
||||
|
||||
func NewMatrixHandler(bridge *Bridge) *MatrixHandler {
|
||||
func NewMatrixHandler(bridge *WABridge) *MatrixHandler {
|
||||
handler := &MatrixHandler{
|
||||
bridge: bridge,
|
||||
as: bridge.AS,
|
||||
|
@ -362,7 +363,7 @@ func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) {
|
|||
|
||||
decrypted, err := mx.bridge.Crypto.Decrypt(evt)
|
||||
decryptionRetryCount := 0
|
||||
if errors.Is(err, NoSessionFound) {
|
||||
if errors.Is(err, bridge.NoSessionFound) {
|
||||
content := evt.Content.AsEncrypted()
|
||||
mx.log.Debugfln("Couldn't find session %s trying to decrypt %s, waiting %d seconds...", content.SessionID, evt.ID, int(sessionWaitTimeout.Seconds()))
|
||||
mx.as.SendErrorMessageSendCheckpoint(evt, appservice.StepDecrypted, err, false, decryptionRetryCount)
|
||||
|
|
17
no-crypto.go
17
no-crypto.go
|
@ -1,17 +0,0 @@
|
|||
//go:build !cgo || nocrypto
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
func NewCryptoHelper(bridge *Bridge) Crypto {
|
||||
if !bridge.Config.Bridge.Encryption.Allow {
|
||||
bridge.Log.Warnln("Bridge built without end-to-bridge encryption, but encryption is enabled in config")
|
||||
}
|
||||
bridge.Log.Debugln("Bridge built without end-to-bridge encryption")
|
||||
return nil
|
||||
}
|
||||
|
||||
var NoSessionFound = errors.New("nil")
|
85
portal.go
85
portal.go
|
@ -45,6 +45,7 @@ import (
|
|||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/bridge"
|
||||
"maunium.net/go/mautrix/crypto/attachment"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/format"
|
||||
|
@ -69,68 +70,72 @@ const PrivateChatTopic = "WhatsApp private chat"
|
|||
|
||||
var ErrStatusBroadcastDisabled = errors.New("status bridging is disabled")
|
||||
|
||||
func (bridge *Bridge) GetPortalByMXID(mxid id.RoomID) *Portal {
|
||||
bridge.portalsLock.Lock()
|
||||
defer bridge.portalsLock.Unlock()
|
||||
portal, ok := bridge.portalsByMXID[mxid]
|
||||
func (br *WABridge) GetPortalByMXID(mxid id.RoomID) *Portal {
|
||||
br.portalsLock.Lock()
|
||||
defer br.portalsLock.Unlock()
|
||||
portal, ok := br.portalsByMXID[mxid]
|
||||
if !ok {
|
||||
return bridge.loadDBPortal(bridge.DB.Portal.GetByMXID(mxid), nil)
|
||||
return br.loadDBPortal(br.DB.Portal.GetByMXID(mxid), nil)
|
||||
}
|
||||
return portal
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetPortalByJID(key database.PortalKey) *Portal {
|
||||
bridge.portalsLock.Lock()
|
||||
defer bridge.portalsLock.Unlock()
|
||||
portal, ok := bridge.portalsByJID[key]
|
||||
func (br *WABridge) GetIPortalByMXID(mxid id.RoomID) bridge.Portal {
|
||||
return br.GetPortalByMXID(mxid)
|
||||
}
|
||||
|
||||
func (br *WABridge) GetPortalByJID(key database.PortalKey) *Portal {
|
||||
br.portalsLock.Lock()
|
||||
defer br.portalsLock.Unlock()
|
||||
portal, ok := br.portalsByJID[key]
|
||||
if !ok {
|
||||
return bridge.loadDBPortal(bridge.DB.Portal.GetByJID(key), &key)
|
||||
return br.loadDBPortal(br.DB.Portal.GetByJID(key), &key)
|
||||
}
|
||||
return portal
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetAllPortals() []*Portal {
|
||||
return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAll())
|
||||
func (br *WABridge) GetAllPortals() []*Portal {
|
||||
return br.dbPortalsToPortals(br.DB.Portal.GetAll())
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetAllPortalsForUser(userID id.UserID) []*Portal {
|
||||
return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAllForUser(userID))
|
||||
func (br *WABridge) GetAllPortalsForUser(userID id.UserID) []*Portal {
|
||||
return br.dbPortalsToPortals(br.DB.Portal.GetAllForUser(userID))
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetAllPortalsByJID(jid types.JID) []*Portal {
|
||||
return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAllByJID(jid))
|
||||
func (br *WABridge) GetAllPortalsByJID(jid types.JID) []*Portal {
|
||||
return br.dbPortalsToPortals(br.DB.Portal.GetAllByJID(jid))
|
||||
}
|
||||
|
||||
func (bridge *Bridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Portal {
|
||||
bridge.portalsLock.Lock()
|
||||
defer bridge.portalsLock.Unlock()
|
||||
func (br *WABridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Portal {
|
||||
br.portalsLock.Lock()
|
||||
defer br.portalsLock.Unlock()
|
||||
output := make([]*Portal, len(dbPortals))
|
||||
for index, dbPortal := range dbPortals {
|
||||
if dbPortal == nil {
|
||||
continue
|
||||
}
|
||||
portal, ok := bridge.portalsByJID[dbPortal.Key]
|
||||
portal, ok := br.portalsByJID[dbPortal.Key]
|
||||
if !ok {
|
||||
portal = bridge.loadDBPortal(dbPortal, nil)
|
||||
portal = br.loadDBPortal(dbPortal, nil)
|
||||
}
|
||||
output[index] = portal
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
func (bridge *Bridge) loadDBPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal {
|
||||
func (br *WABridge) loadDBPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal {
|
||||
if dbPortal == nil {
|
||||
if key == nil {
|
||||
return nil
|
||||
}
|
||||
dbPortal = bridge.DB.Portal.New()
|
||||
dbPortal = br.DB.Portal.New()
|
||||
dbPortal.Key = *key
|
||||
dbPortal.Insert()
|
||||
}
|
||||
portal := bridge.NewPortal(dbPortal)
|
||||
bridge.portalsByJID[portal.Key] = portal
|
||||
portal := br.NewPortal(dbPortal)
|
||||
br.portalsByJID[portal.Key] = portal
|
||||
if len(portal.MXID) > 0 {
|
||||
bridge.portalsByMXID[portal.MXID] = portal
|
||||
br.portalsByMXID[portal.MXID] = portal
|
||||
}
|
||||
return portal
|
||||
}
|
||||
|
@ -139,14 +144,14 @@ func (portal *Portal) GetUsers() []*User {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (bridge *Bridge) newBlankPortal(key database.PortalKey) *Portal {
|
||||
func (br *WABridge) newBlankPortal(key database.PortalKey) *Portal {
|
||||
portal := &Portal{
|
||||
bridge: bridge,
|
||||
log: bridge.Log.Sub(fmt.Sprintf("Portal/%s", key)),
|
||||
bridge: br,
|
||||
log: br.Log.Sub(fmt.Sprintf("Portal/%s", key)),
|
||||
|
||||
messages: make(chan PortalMessage, bridge.Config.Bridge.PortalMessageBuffer),
|
||||
matrixMessages: make(chan PortalMatrixMessage, bridge.Config.Bridge.PortalMessageBuffer),
|
||||
mediaRetries: make(chan PortalMediaRetry, bridge.Config.Bridge.PortalMessageBuffer),
|
||||
messages: make(chan PortalMessage, br.Config.Bridge.PortalMessageBuffer),
|
||||
matrixMessages: make(chan PortalMatrixMessage, br.Config.Bridge.PortalMessageBuffer),
|
||||
mediaRetries: make(chan PortalMediaRetry, br.Config.Bridge.PortalMessageBuffer),
|
||||
|
||||
mediaErrorCache: make(map[types.MessageID]*FailedMediaMeta),
|
||||
}
|
||||
|
@ -154,15 +159,15 @@ func (bridge *Bridge) newBlankPortal(key database.PortalKey) *Portal {
|
|||
return portal
|
||||
}
|
||||
|
||||
func (bridge *Bridge) NewManualPortal(key database.PortalKey) *Portal {
|
||||
portal := bridge.newBlankPortal(key)
|
||||
portal.Portal = bridge.DB.Portal.New()
|
||||
func (br *WABridge) NewManualPortal(key database.PortalKey) *Portal {
|
||||
portal := br.newBlankPortal(key)
|
||||
portal.Portal = br.DB.Portal.New()
|
||||
portal.Key = key
|
||||
return portal
|
||||
}
|
||||
|
||||
func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal {
|
||||
portal := bridge.newBlankPortal(dbPortal.Key)
|
||||
func (br *WABridge) NewPortal(dbPortal *database.Portal) *Portal {
|
||||
portal := br.newBlankPortal(dbPortal.Key)
|
||||
portal.Portal = dbPortal
|
||||
return portal
|
||||
}
|
||||
|
@ -203,7 +208,7 @@ type recentlyHandledWrapper struct {
|
|||
type Portal struct {
|
||||
*database.Portal
|
||||
|
||||
bridge *Bridge
|
||||
bridge *WABridge
|
||||
log log.Logger
|
||||
|
||||
roomCreateLock sync.Mutex
|
||||
|
@ -229,6 +234,10 @@ type Portal struct {
|
|||
relayUser *User
|
||||
}
|
||||
|
||||
func (portal *Portal) IsEncrypted() bool {
|
||||
return portal.Encrypted
|
||||
}
|
||||
|
||||
func (portal *Portal) handleMessageLoopItem(msg PortalMessage) {
|
||||
if len(portal.MXID) == 0 {
|
||||
if msg.fake == nil && msg.undecryptable == nil && (msg.evt == nil || !containsSupportedMessage(msg.evt.Message)) {
|
||||
|
|
|
@ -43,15 +43,15 @@ import (
|
|||
)
|
||||
|
||||
type ProvisioningAPI struct {
|
||||
bridge *Bridge
|
||||
bridge *WABridge
|
||||
log log.Logger
|
||||
}
|
||||
|
||||
func (prov *ProvisioningAPI) Init() {
|
||||
prov.log = prov.bridge.Log.Sub("Provisioning")
|
||||
|
||||
prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.AppService.Provisioning.Prefix)
|
||||
r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.AppService.Provisioning.Prefix).Subrouter()
|
||||
prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.Bridge.Provisioning.Prefix)
|
||||
r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.Bridge.Provisioning.Prefix).Subrouter()
|
||||
r.Use(prov.AuthMiddleware)
|
||||
r.HandleFunc("/v1/ping", prov.Ping).Methods(http.MethodGet)
|
||||
r.HandleFunc("/v1/login", prov.Login).Methods(http.MethodGet)
|
||||
|
@ -109,7 +109,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
|
|||
} else if strings.HasPrefix(auth, "Bearer ") {
|
||||
auth = auth[len("Bearer "):]
|
||||
}
|
||||
if auth != prov.bridge.Config.AppService.Provisioning.SharedSecret {
|
||||
if auth != prov.bridge.Config.Bridge.Provisioning.SharedSecret {
|
||||
jsonResponse(w, http.StatusForbidden, map[string]interface{}{
|
||||
"error": "Invalid auth token",
|
||||
"errcode": "M_FORBIDDEN",
|
||||
|
|
84
puppet.go
84
puppet.go
|
@ -39,11 +39,11 @@ import (
|
|||
|
||||
var userIDRegex *regexp.Regexp
|
||||
|
||||
func (bridge *Bridge) ParsePuppetMXID(mxid id.UserID) (jid types.JID, ok bool) {
|
||||
func (br *WABridge) ParsePuppetMXID(mxid id.UserID) (jid types.JID, ok bool) {
|
||||
if userIDRegex == nil {
|
||||
userIDRegex = regexp.MustCompile(fmt.Sprintf("^@%s:%s$",
|
||||
bridge.Config.Bridge.FormatUsername("([0-9]+)"),
|
||||
bridge.Config.Homeserver.Domain))
|
||||
br.Config.Bridge.FormatUsername("([0-9]+)"),
|
||||
br.Config.Homeserver.Domain))
|
||||
}
|
||||
match := userIDRegex.FindStringSubmatch(string(mxid))
|
||||
if len(match) == 2 {
|
||||
|
@ -53,79 +53,79 @@ func (bridge *Bridge) ParsePuppetMXID(mxid id.UserID) (jid types.JID, ok bool) {
|
|||
return
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetPuppetByMXID(mxid id.UserID) *Puppet {
|
||||
jid, ok := bridge.ParsePuppetMXID(mxid)
|
||||
func (br *WABridge) GetPuppetByMXID(mxid id.UserID) *Puppet {
|
||||
jid, ok := br.ParsePuppetMXID(mxid)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return bridge.GetPuppetByJID(jid)
|
||||
return br.GetPuppetByJID(jid)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetPuppetByJID(jid types.JID) *Puppet {
|
||||
func (br *WABridge) GetPuppetByJID(jid types.JID) *Puppet {
|
||||
jid = jid.ToNonAD()
|
||||
if jid.Server == types.LegacyUserServer {
|
||||
jid.Server = types.DefaultUserServer
|
||||
} else if jid.Server != types.DefaultUserServer {
|
||||
return nil
|
||||
}
|
||||
bridge.puppetsLock.Lock()
|
||||
defer bridge.puppetsLock.Unlock()
|
||||
puppet, ok := bridge.puppets[jid]
|
||||
br.puppetsLock.Lock()
|
||||
defer br.puppetsLock.Unlock()
|
||||
puppet, ok := br.puppets[jid]
|
||||
if !ok {
|
||||
dbPuppet := bridge.DB.Puppet.Get(jid)
|
||||
dbPuppet := br.DB.Puppet.Get(jid)
|
||||
if dbPuppet == nil {
|
||||
dbPuppet = bridge.DB.Puppet.New()
|
||||
dbPuppet = br.DB.Puppet.New()
|
||||
dbPuppet.JID = jid
|
||||
dbPuppet.Insert()
|
||||
}
|
||||
puppet = bridge.NewPuppet(dbPuppet)
|
||||
bridge.puppets[puppet.JID] = puppet
|
||||
puppet = br.NewPuppet(dbPuppet)
|
||||
br.puppets[puppet.JID] = puppet
|
||||
if len(puppet.CustomMXID) > 0 {
|
||||
bridge.puppetsByCustomMXID[puppet.CustomMXID] = puppet
|
||||
br.puppetsByCustomMXID[puppet.CustomMXID] = puppet
|
||||
}
|
||||
}
|
||||
return puppet
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet {
|
||||
bridge.puppetsLock.Lock()
|
||||
defer bridge.puppetsLock.Unlock()
|
||||
puppet, ok := bridge.puppetsByCustomMXID[mxid]
|
||||
func (br *WABridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet {
|
||||
br.puppetsLock.Lock()
|
||||
defer br.puppetsLock.Unlock()
|
||||
puppet, ok := br.puppetsByCustomMXID[mxid]
|
||||
if !ok {
|
||||
dbPuppet := bridge.DB.Puppet.GetByCustomMXID(mxid)
|
||||
dbPuppet := br.DB.Puppet.GetByCustomMXID(mxid)
|
||||
if dbPuppet == nil {
|
||||
return nil
|
||||
}
|
||||
puppet = bridge.NewPuppet(dbPuppet)
|
||||
bridge.puppets[puppet.JID] = puppet
|
||||
bridge.puppetsByCustomMXID[puppet.CustomMXID] = puppet
|
||||
puppet = br.NewPuppet(dbPuppet)
|
||||
br.puppets[puppet.JID] = puppet
|
||||
br.puppetsByCustomMXID[puppet.CustomMXID] = puppet
|
||||
}
|
||||
return puppet
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetAllPuppetsWithCustomMXID() []*Puppet {
|
||||
return bridge.dbPuppetsToPuppets(bridge.DB.Puppet.GetAllWithCustomMXID())
|
||||
func (br *WABridge) GetAllPuppetsWithCustomMXID() []*Puppet {
|
||||
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAllWithCustomMXID())
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetAllPuppets() []*Puppet {
|
||||
return bridge.dbPuppetsToPuppets(bridge.DB.Puppet.GetAll())
|
||||
func (br *WABridge) GetAllPuppets() []*Puppet {
|
||||
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAll())
|
||||
}
|
||||
|
||||
func (bridge *Bridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet {
|
||||
bridge.puppetsLock.Lock()
|
||||
defer bridge.puppetsLock.Unlock()
|
||||
func (br *WABridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet {
|
||||
br.puppetsLock.Lock()
|
||||
defer br.puppetsLock.Unlock()
|
||||
output := make([]*Puppet, len(dbPuppets))
|
||||
for index, dbPuppet := range dbPuppets {
|
||||
if dbPuppet == nil {
|
||||
continue
|
||||
}
|
||||
puppet, ok := bridge.puppets[dbPuppet.JID]
|
||||
puppet, ok := br.puppets[dbPuppet.JID]
|
||||
if !ok {
|
||||
puppet = bridge.NewPuppet(dbPuppet)
|
||||
bridge.puppets[dbPuppet.JID] = puppet
|
||||
puppet = br.NewPuppet(dbPuppet)
|
||||
br.puppets[dbPuppet.JID] = puppet
|
||||
if len(dbPuppet.CustomMXID) > 0 {
|
||||
bridge.puppetsByCustomMXID[dbPuppet.CustomMXID] = puppet
|
||||
br.puppetsByCustomMXID[dbPuppet.CustomMXID] = puppet
|
||||
}
|
||||
}
|
||||
output[index] = puppet
|
||||
|
@ -133,26 +133,26 @@ func (bridge *Bridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet
|
|||
return output
|
||||
}
|
||||
|
||||
func (bridge *Bridge) FormatPuppetMXID(jid types.JID) id.UserID {
|
||||
func (br *WABridge) FormatPuppetMXID(jid types.JID) id.UserID {
|
||||
return id.NewUserID(
|
||||
bridge.Config.Bridge.FormatUsername(jid.User),
|
||||
bridge.Config.Homeserver.Domain)
|
||||
br.Config.Bridge.FormatUsername(jid.User),
|
||||
br.Config.Homeserver.Domain)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) NewPuppet(dbPuppet *database.Puppet) *Puppet {
|
||||
func (br *WABridge) NewPuppet(dbPuppet *database.Puppet) *Puppet {
|
||||
return &Puppet{
|
||||
Puppet: dbPuppet,
|
||||
bridge: bridge,
|
||||
log: bridge.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
|
||||
bridge: br,
|
||||
log: br.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
|
||||
|
||||
MXID: bridge.FormatPuppetMXID(dbPuppet.JID),
|
||||
MXID: br.FormatPuppetMXID(dbPuppet.JID),
|
||||
}
|
||||
}
|
||||
|
||||
type Puppet struct {
|
||||
*database.Puppet
|
||||
|
||||
bridge *Bridge
|
||||
bridge *WABridge
|
||||
log log.Logger
|
||||
|
||||
typingIn id.RoomID
|
||||
|
|
75
user.go
75
user.go
|
@ -35,6 +35,7 @@ import (
|
|||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/bridge"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/format"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
@ -56,7 +57,7 @@ type User struct {
|
|||
Client *whatsmeow.Client
|
||||
Session *store.Device
|
||||
|
||||
bridge *Bridge
|
||||
bridge *WABridge
|
||||
log log.Logger
|
||||
|
||||
Admin bool
|
||||
|
@ -84,38 +85,46 @@ type User struct {
|
|||
BackfillQueue *BackfillQueue
|
||||
}
|
||||
|
||||
func (bridge *Bridge) getUserByMXID(userID id.UserID, onlyIfExists bool) *User {
|
||||
_, isPuppet := bridge.ParsePuppetMXID(userID)
|
||||
if isPuppet || userID == bridge.Bot.UserID {
|
||||
func (br *WABridge) getUserByMXID(userID id.UserID, onlyIfExists bool) *User {
|
||||
_, isPuppet := br.ParsePuppetMXID(userID)
|
||||
if isPuppet || userID == br.Bot.UserID {
|
||||
return nil
|
||||
}
|
||||
bridge.usersLock.Lock()
|
||||
defer bridge.usersLock.Unlock()
|
||||
user, ok := bridge.usersByMXID[userID]
|
||||
br.usersLock.Lock()
|
||||
defer br.usersLock.Unlock()
|
||||
user, ok := br.usersByMXID[userID]
|
||||
if !ok {
|
||||
userIDPtr := &userID
|
||||
if onlyIfExists {
|
||||
userIDPtr = nil
|
||||
}
|
||||
return bridge.loadDBUser(bridge.DB.User.GetByMXID(userID), userIDPtr)
|
||||
return br.loadDBUser(br.DB.User.GetByMXID(userID), userIDPtr)
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetUserByMXID(userID id.UserID) *User {
|
||||
return bridge.getUserByMXID(userID, false)
|
||||
func (br *WABridge) GetUserByMXID(userID id.UserID) *User {
|
||||
return br.getUserByMXID(userID, false)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetUserByMXIDIfExists(userID id.UserID) *User {
|
||||
return bridge.getUserByMXID(userID, true)
|
||||
func (br *WABridge) GetIUserByMXID(userID id.UserID) bridge.User {
|
||||
return br.getUserByMXID(userID, false)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetUserByJID(jid types.JID) *User {
|
||||
bridge.usersLock.Lock()
|
||||
defer bridge.usersLock.Unlock()
|
||||
user, ok := bridge.usersByUsername[jid.User]
|
||||
func (user *User) IsAdmin() bool {
|
||||
return user.Admin
|
||||
}
|
||||
|
||||
func (br *WABridge) GetUserByMXIDIfExists(userID id.UserID) *User {
|
||||
return br.getUserByMXID(userID, true)
|
||||
}
|
||||
|
||||
func (br *WABridge) GetUserByJID(jid types.JID) *User {
|
||||
br.usersLock.Lock()
|
||||
defer br.usersLock.Unlock()
|
||||
user, ok := br.usersByUsername[jid.User]
|
||||
if !ok {
|
||||
return bridge.loadDBUser(bridge.DB.User.GetByUsername(jid.User), nil)
|
||||
return br.loadDBUser(br.DB.User.GetByUsername(jid.User), nil)
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
@ -137,35 +146,35 @@ func (user *User) removeFromJIDMap(state BridgeState) {
|
|||
user.sendBridgeState(state)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetAllUsers() []*User {
|
||||
bridge.usersLock.Lock()
|
||||
defer bridge.usersLock.Unlock()
|
||||
dbUsers := bridge.DB.User.GetAll()
|
||||
func (br *WABridge) GetAllUsers() []*User {
|
||||
br.usersLock.Lock()
|
||||
defer br.usersLock.Unlock()
|
||||
dbUsers := br.DB.User.GetAll()
|
||||
output := make([]*User, len(dbUsers))
|
||||
for index, dbUser := range dbUsers {
|
||||
user, ok := bridge.usersByMXID[dbUser.MXID]
|
||||
user, ok := br.usersByMXID[dbUser.MXID]
|
||||
if !ok {
|
||||
user = bridge.loadDBUser(dbUser, nil)
|
||||
user = br.loadDBUser(dbUser, nil)
|
||||
}
|
||||
output[index] = user
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
func (bridge *Bridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User {
|
||||
func (br *WABridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User {
|
||||
if dbUser == nil {
|
||||
if mxid == nil {
|
||||
return nil
|
||||
}
|
||||
dbUser = bridge.DB.User.New()
|
||||
dbUser = br.DB.User.New()
|
||||
dbUser.MXID = *mxid
|
||||
dbUser.Insert()
|
||||
}
|
||||
user := bridge.NewUser(dbUser)
|
||||
bridge.usersByMXID[user.MXID] = user
|
||||
user := br.NewUser(dbUser)
|
||||
br.usersByMXID[user.MXID] = user
|
||||
if !user.JID.IsEmpty() {
|
||||
var err error
|
||||
user.Session, err = bridge.WAContainer.GetDevice(user.JID)
|
||||
user.Session, err = br.WAContainer.GetDevice(user.JID)
|
||||
if err != nil {
|
||||
user.log.Errorfln("Failed to load user's whatsapp session: %v", err)
|
||||
} else if user.Session == nil {
|
||||
|
@ -174,20 +183,20 @@ func (bridge *Bridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User {
|
|||
user.Update()
|
||||
} else {
|
||||
user.Session.Log = &waLogger{user.log.Sub("Session")}
|
||||
bridge.usersByUsername[user.JID.User] = user
|
||||
br.usersByUsername[user.JID.User] = user
|
||||
}
|
||||
}
|
||||
if len(user.ManagementRoom) > 0 {
|
||||
bridge.managementRooms[user.ManagementRoom] = user
|
||||
br.managementRooms[user.ManagementRoom] = user
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
func (bridge *Bridge) NewUser(dbUser *database.User) *User {
|
||||
func (br *WABridge) NewUser(dbUser *database.User) *User {
|
||||
user := &User{
|
||||
User: dbUser,
|
||||
bridge: bridge,
|
||||
log: bridge.Log.Sub("User").Sub(string(dbUser.MXID)),
|
||||
bridge: br,
|
||||
log: br.Log.Sub("User").Sub(string(dbUser.MXID)),
|
||||
|
||||
historySyncs: make(chan *events.HistorySync, 32),
|
||||
lastPresence: types.PresenceUnavailable,
|
||||
|
|
Loading…
Add table
Reference in a new issue