signalmeow/store: fix locking recipient store

This commit is contained in:
Tulir Asokan 2025-02-04 11:48:31 +02:00
parent f07723070d
commit 3781461b28
6 changed files with 47 additions and 25 deletions

View file

@ -93,7 +93,7 @@ func (cli *Client) FetchAndProcessTransfer(ctx context.Context, meta *TransferAr
if err != nil {
return fmt.Errorf("failed to seek to start of file: %w", err)
}
err = cli.Store.DoTxn(ctx, func(ctx context.Context) error {
err = cli.Store.DoContactTxn(ctx, func(ctx context.Context) error {
err = cli.Store.BackupStore.ClearBackup(ctx)
if err != nil {
return fmt.Errorf("failed to clear backup: %w", err)

View file

@ -805,24 +805,30 @@ func (cli *Client) handleDecryptedResult(
}
log.Debug().Int("contact_count", len(contacts)).Msg("Contacts Sync received contacts")
convertedContacts := make([]*types.Recipient, 0, len(contacts))
for i, signalContact := range contacts {
if signalContact.Aci == nil || *signalContact.Aci == "" {
// TODO lookup PNI via CDSI and store that when ACI is missing?
log.Info().
Any("contact", signalContact).
Msg("Signal Contact UUID is nil, skipping")
continue
err = cli.Store.DoContactTxn(ctx, func(ctx context.Context) error {
for i, signalContact := range contacts {
if signalContact.Aci == nil || *signalContact.Aci == "" {
// TODO lookup PNI via CDSI and store that when ACI is missing?
log.Info().
Any("contact", signalContact).
Msg("Signal Contact UUID is nil, skipping")
continue
}
contact, err := cli.StoreContactDetailsAsContact(ctx, signalContact, &avatars[i])
if err != nil {
return err
}
convertedContacts = append(convertedContacts, contact)
}
contact, err := cli.StoreContactDetailsAsContact(ctx, signalContact, &avatars[i])
if err != nil {
log.Err(err).Msg("StoreContactDetailsAsContact error")
continue
}
convertedContacts = append(convertedContacts, contact)
}
cli.handleEvent(&events.ContactList{
Contacts: convertedContacts,
return nil
})
if err != nil {
log.Err(err).Msg("Error storing contacts")
} else {
cli.handleEvent(&events.ContactList{
Contacts: convertedContacts,
})
}
}
}
if content.SyncMessage.Read != nil {

View file

@ -49,7 +49,7 @@ func (cli *Client) SyncStorage(ctx context.Context) {
log.Err(err).Msg("Failed to fetch storage")
return
}
err = cli.Store.DoTxn(ctx, func(ctx context.Context) error {
err = cli.Store.DoContactTxn(ctx, func(ctx context.Context) error {
return cli.processStorageInTxn(ctx, update)
})
if err != nil {

View file

@ -108,10 +108,8 @@ func (c *Container) scanDevice(row dbutil.Scannable) (*Device, error) {
device.RecipientStore = baseStore
device.DeviceStore = baseStore
device.BackupStore = baseStore
device.DoTxn = func(ctx context.Context, fn func(context.Context) error) error {
return c.db.DoTxn(context.WithValue(ctx, dbutil.ContextKeyDoTxnCallerSkip, 1), nil, fn)
}
device.sqlStore = baseStore
device.db = c.db
return &device, nil
}

View file

@ -7,6 +7,7 @@ import (
"github.com/google/uuid"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf"
@ -79,7 +80,22 @@ type Device struct {
DeviceStore DeviceStore
BackupStore BackupStore
DoTxn func(context.Context, func(context.Context) error) error
sqlStore *sqlStore
db *dbutil.Database
}
type contextKey int64
const (
contextKeyContactLock contextKey = 1
)
func (d *Device) DoContactTxn(ctx context.Context, fn func(context.Context) error) error {
d.sqlStore.contactLock.Lock()
defer d.sqlStore.contactLock.Unlock()
ctx = context.WithValue(ctx, dbutil.ContextKeyDoTxnCallerSkip, 1)
ctx = context.WithValue(ctx, contextKeyContactLock, true)
return d.db.DoTxn(ctx, nil, fn)
}
func (d *Device) ClearDeviceKeys(ctx context.Context) error {

View file

@ -204,8 +204,10 @@ func (s *sqlStore) LoadAndUpdateRecipient(ctx context.Context, aci, pni uuid.UUI
return false, nil
}
}
s.contactLock.Lock()
defer s.contactLock.Unlock()
if ctx.Value(contextKeyContactLock) == nil {
s.contactLock.Lock()
defer s.contactLock.Unlock()
}
outErr = s.db.DoTxn(ctx, nil, func(ctx context.Context) error {
var entries []*types.Recipient
var err error