mirror of
https://github.com/mautrix/signal.git
synced 2025-03-14 14:15:36 +00:00
signalmeow: block until stop completes
Currently the disconnect/stop bridge call will complete before all the loops have returned. This switches them all to use a shared cancelable context and wait group to block on stop until all loops exit.
This commit is contained in:
parent
272f38297f
commit
4713ddfcd1
3 changed files with 69 additions and 60 deletions
|
@ -48,9 +48,11 @@ type Client struct {
|
|||
|
||||
AuthedWS *web.SignalWebsocket
|
||||
UnauthedWS *web.SignalWebsocket
|
||||
WSCancel context.CancelFunc
|
||||
lastConnectionStatus SignalConnectionStatus
|
||||
|
||||
loopCancel context.CancelFunc
|
||||
loopWg sync.WaitGroup
|
||||
|
||||
EventHandler func(events.SignalEvent)
|
||||
|
||||
storageAuthLock sync.Mutex
|
||||
|
|
|
@ -599,49 +599,48 @@ func (cli *Client) CheckAndUploadNewPreKeys(ctx context.Context, pks store.PreKe
|
|||
return nil
|
||||
}
|
||||
|
||||
func (cli *Client) StartKeyCheckLoop(ctx context.Context) {
|
||||
func (cli *Client) keyCheckLoop(ctx context.Context) {
|
||||
log := zerolog.Ctx(ctx).With().Str("action", "start key check loop").Logger()
|
||||
go func() {
|
||||
// Do the initial check in 5-10 minutes after starting the loop
|
||||
window_start := 0
|
||||
window_size := 1
|
||||
for {
|
||||
random_minutes_in_window := rand.Intn(window_size) + window_start
|
||||
check_time := time.Duration(random_minutes_in_window) * time.Minute
|
||||
log.Debug().Dur("check_time", check_time).Msg("Waiting to check for new prekeys")
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(check_time):
|
||||
err := cli.CheckAndUploadNewPreKeys(ctx, cli.Store.ACIPreKeyStore)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Error checking and uploading new prekeys for ACI identity")
|
||||
// Retry within half an hour
|
||||
window_start = 5
|
||||
window_size = 25
|
||||
continue
|
||||
}
|
||||
err = cli.CheckAndUploadNewPreKeys(ctx, cli.Store.PNIPreKeyStore)
|
||||
if err != nil {
|
||||
if errors.Is(err, errPrekeyUpload422) {
|
||||
log.Err(err).Msg("Got 422 error while uploading PNI prekeys, deleting session")
|
||||
disconnectErr := cli.ClearKeysAndDisconnect(ctx)
|
||||
if disconnectErr != nil {
|
||||
log.Err(disconnectErr).Msg("ClearKeysAndDisconnect error")
|
||||
}
|
||||
return
|
||||
}
|
||||
log.Err(err).Msg("Error checking and uploading new prekeys for PNI identity")
|
||||
// Retry within half an hour
|
||||
window_start = 5
|
||||
window_size = 25
|
||||
continue
|
||||
}
|
||||
// After a successful check, check again in 36 to 60 hours
|
||||
window_start = 36 * 60
|
||||
window_size = 24 * 60
|
||||
// Do the initial check in 5-10 minutes after starting the loop
|
||||
window_start := 0
|
||||
window_size := 1
|
||||
for {
|
||||
random_minutes_in_window := rand.Intn(window_size) + window_start
|
||||
check_time := time.Duration(random_minutes_in_window) * time.Minute
|
||||
log.Debug().Dur("check_time", check_time).Msg("Waiting to check for new prekeys")
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(check_time):
|
||||
err := cli.CheckAndUploadNewPreKeys(ctx, cli.Store.ACIPreKeyStore)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Error checking and uploading new prekeys for ACI identity")
|
||||
// Retry within half an hour
|
||||
window_start = 5
|
||||
window_size = 25
|
||||
continue
|
||||
}
|
||||
err = cli.CheckAndUploadNewPreKeys(ctx, cli.Store.PNIPreKeyStore)
|
||||
if err != nil {
|
||||
if errors.Is(err, errPrekeyUpload422) {
|
||||
log.Err(err).Msg("Got 422 error while uploading PNI prekeys, deleting session")
|
||||
disconnectErr := cli.ClearKeysAndDisconnect(ctx)
|
||||
if disconnectErr != nil {
|
||||
log.Err(disconnectErr).Msg("ClearKeysAndDisconnect error")
|
||||
}
|
||||
return
|
||||
}
|
||||
log.Err(err).Msg("Error checking and uploading new prekeys for PNI identity")
|
||||
// Retry within half an hour
|
||||
window_start = 5
|
||||
window_size = 25
|
||||
continue
|
||||
}
|
||||
// After a successful check, check again in 36 to 60 hours
|
||||
window_start = 36 * 60
|
||||
window_size = 24 * 60
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -77,31 +77,30 @@ func (cli *Client) startWebsocketsInternal(
|
|||
ctx context.Context,
|
||||
) (
|
||||
authChan, unauthChan chan web.SignalWebsocketConnectionStatus,
|
||||
cancelCtx context.Context,
|
||||
cancelFunc context.CancelFunc,
|
||||
loopCtx context.Context, loopCancel context.CancelFunc,
|
||||
err error,
|
||||
) {
|
||||
cancelCtx, cancelFunc = context.WithCancel(ctx)
|
||||
cli.WSCancel = cancelFunc
|
||||
unauthChan, err = cli.connectUnauthedWS(cancelCtx)
|
||||
loopCtx, loopCancel = context.WithCancel(ctx)
|
||||
unauthChan, err = cli.connectUnauthedWS(loopCtx)
|
||||
if err != nil {
|
||||
cancelFunc()
|
||||
loopCancel()
|
||||
return
|
||||
}
|
||||
zerolog.Ctx(ctx).Info().Msg("Unauthed websocket connecting")
|
||||
authChan, err = cli.connectAuthedWS(cancelCtx, cli.incomingRequestHandler)
|
||||
authChan, err = cli.connectAuthedWS(loopCtx, cli.incomingRequestHandler)
|
||||
if err != nil {
|
||||
cancelFunc()
|
||||
loopCancel()
|
||||
return
|
||||
}
|
||||
zerolog.Ctx(ctx).Info().Msg("Authed websocket connecting")
|
||||
cli.loopCancel = loopCancel
|
||||
return
|
||||
}
|
||||
|
||||
func (cli *Client) StartReceiveLoops(ctx context.Context) (chan SignalConnectionStatus, error) {
|
||||
log := zerolog.Ctx(ctx).With().Str("action", "start receive loops").Logger()
|
||||
ctx = log.WithContext(ctx)
|
||||
authChan, unauthChan, ctx, cancel, err := cli.startWebsocketsInternal(log.WithContext(ctx))
|
||||
|
||||
authChan, unauthChan, loopCtx, loopCancel, err := cli.startWebsocketsInternal(log.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -110,13 +109,15 @@ func (cli *Client) StartReceiveLoops(ctx context.Context) (chan SignalConnection
|
|||
initialConnectChan := make(chan struct{})
|
||||
|
||||
// Combine both websocket status channels into a single, more generic "Signal" connection status channel
|
||||
cli.loopWg.Add(1)
|
||||
go func() {
|
||||
defer cli.loopWg.Done()
|
||||
defer close(statusChan)
|
||||
defer cancel()
|
||||
defer loopCancel()
|
||||
var currentStatus, lastAuthStatus, lastUnauthStatus web.SignalWebsocketConnectionStatus
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-loopCtx.Done():
|
||||
log.Info().Msg("Context done, exiting websocket status loop")
|
||||
return
|
||||
case status := <-authChan:
|
||||
|
@ -201,19 +202,21 @@ func (cli *Client) StartReceiveLoops(ctx context.Context) (chan SignalConnection
|
|||
}()
|
||||
|
||||
// Send sync message once both websockets are connected
|
||||
cli.loopWg.Add(1)
|
||||
go func() {
|
||||
defer cli.loopWg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-loopCtx.Done():
|
||||
return
|
||||
case <-initialConnectChan:
|
||||
log.Info().Msg("Both websockets connected, sending contacts sync request")
|
||||
// TODO hacky
|
||||
if cli.SyncContactsOnConnect {
|
||||
cli.SendContactSyncRequest(ctx)
|
||||
cli.SendContactSyncRequest(loopCtx)
|
||||
}
|
||||
if cli.Store.MasterKey == nil {
|
||||
cli.SendStorageMasterKeyRequest(ctx)
|
||||
cli.SendStorageMasterKeyRequest(loopCtx)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -221,7 +224,11 @@ func (cli *Client) StartReceiveLoops(ctx context.Context) (chan SignalConnection
|
|||
}()
|
||||
|
||||
// Start loop to check for and upload more prekeys
|
||||
cli.StartKeyCheckLoop(ctx)
|
||||
cli.loopWg.Add(1)
|
||||
go func() {
|
||||
defer cli.loopWg.Done()
|
||||
cli.keyCheckLoop(loopCtx)
|
||||
}()
|
||||
|
||||
return statusChan, nil
|
||||
}
|
||||
|
@ -233,8 +240,9 @@ func (cli *Client) StopReceiveLoops() error {
|
|||
}()
|
||||
authErr := cli.AuthedWS.Close()
|
||||
unauthErr := cli.UnauthedWS.Close()
|
||||
if cli.WSCancel != nil {
|
||||
cli.WSCancel()
|
||||
if cli.loopCancel != nil {
|
||||
cli.loopCancel()
|
||||
cli.loopWg.Wait()
|
||||
}
|
||||
if authErr != nil {
|
||||
return authErr
|
||||
|
|
Loading…
Add table
Reference in a new issue