mirror of
https://github.com/mautrix/signal.git
synced 2025-03-14 14:15:36 +00:00
Add bpnsclient and connect first user to BPNS as PoC
This commit is contained in:
parent
60263e7529
commit
01048dc165
3 changed files with 337 additions and 2 deletions
309
bpnsclient/bpnsclient.go
Normal file
309
bpnsclient/bpnsclient.go
Normal file
|
@ -0,0 +1,309 @@
|
|||
package bpnsclient
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBPNSUnauthorized = errors.New("imat unauthorized")
|
||||
ErrBPNSUnknownDevice = errors.New("device not registered on bpns")
|
||||
ErrBPNSConflict = errors.New("registration conflict on bpns")
|
||||
ErrBPNSUnexpectedResponse = errors.New("unexpected error from bpns")
|
||||
)
|
||||
|
||||
const bpnsMinKeepalive = time.Second * 1
|
||||
|
||||
type BPNSClient struct {
|
||||
log *zerolog.Logger
|
||||
url string
|
||||
imaToken string
|
||||
cfg DeviceConfig
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
type FCMConfig struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type SignalConfig struct {
|
||||
AciUUID string `json:"aci_uuid"`
|
||||
DeviceID int `json:"device_id"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type DeviceConfig struct {
|
||||
FCM FCMConfig `json:"fcm"`
|
||||
Signal SignalConfig `json:"signal,omitempty"`
|
||||
}
|
||||
|
||||
type JSONDurationMs time.Duration
|
||||
|
||||
func (t *JSONDurationMs) UnmarshalJSON(data []byte) error {
|
||||
var value int64
|
||||
|
||||
if err := json.Unmarshal(data, &value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*t = JSONDurationMs(value * int64(time.Millisecond))
|
||||
return nil
|
||||
}
|
||||
|
||||
type KeepaliveResponse struct {
|
||||
Rate JSONDurationMs `json:"rate_ms"`
|
||||
}
|
||||
|
||||
func NewBPNSClient(log *zerolog.Logger, url, imaToken, pushToken string, signalACI uuid.UUID, signalPassword string, deviceID int) *BPNSClient {
|
||||
deviceConfig := DeviceConfig{
|
||||
FCM: FCMConfig{
|
||||
Token: pushToken,
|
||||
},
|
||||
Signal: SignalConfig{
|
||||
AciUUID: signalACI.String(),
|
||||
Password: signalPassword,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
}
|
||||
|
||||
return &BPNSClient{
|
||||
log: log,
|
||||
url: url,
|
||||
imaToken: imaToken,
|
||||
cfg: deviceConfig,
|
||||
}
|
||||
}
|
||||
|
||||
func (bpns *BPNSClient) request(ctx context.Context, method, endpoint string, body any) (*http.Response, error) {
|
||||
var bodyReader io.Reader
|
||||
|
||||
if body != nil {
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bodyReader = bytes.NewReader(data)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s/v1/signal/%s", bpns.url, endpoint), bodyReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+bpns.imaToken)
|
||||
req.Header.Set("User-Agent", "signal")
|
||||
|
||||
normalTransport := &http.Transport{
|
||||
DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: 5 * time.Second,
|
||||
}
|
||||
normalClient := &http.Client{
|
||||
Timeout: 1 * time.Minute,
|
||||
Transport: normalTransport,
|
||||
}
|
||||
resp, err := normalClient.Do(req)
|
||||
if err != nil {
|
||||
bpns.log.Trace().Err(err).Stringer("url", req.URL).Str("method", method).Msg("bpns Request Error")
|
||||
} else {
|
||||
bpns.log.Trace().Stringer("url", req.URL).Str("method", method).Str("status", resp.Status).Msg("bpns Request")
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (bpns *BPNSClient) register(ctx context.Context) (time.Duration, error) {
|
||||
resp, err := bpns.request(ctx, http.MethodPut, "device", &bpns.cfg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusCreated:
|
||||
return decodeKeepaliveRate(resp)
|
||||
case http.StatusConflict:
|
||||
return 0, ErrBPNSConflict
|
||||
case http.StatusUnauthorized:
|
||||
return 0, ErrBPNSUnauthorized
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: register returned unexpected status %d", ErrBPNSUnexpectedResponse, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func decodeKeepaliveRate(resp *http.Response) (rate time.Duration, err error) {
|
||||
if resp.Body == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var keepaliveResp KeepaliveResponse
|
||||
if err = json.NewDecoder(resp.Body).Decode(&keepaliveResp); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
rate = time.Duration(keepaliveResp.Rate)
|
||||
return
|
||||
}
|
||||
|
||||
func (bpns *BPNSClient) keepalive(ctx context.Context) (time.Duration, error) {
|
||||
resp, err := bpns.request(ctx, http.MethodPost, "keepalive", nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK:
|
||||
return decodeKeepaliveRate(resp)
|
||||
case http.StatusNoContent:
|
||||
return 0, nil
|
||||
case http.StatusUnauthorized:
|
||||
return 0, ErrBPNSUnauthorized
|
||||
case http.StatusNotFound:
|
||||
return 0, ErrBPNSUnknownDevice
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: keepalive returned unexpected status %d", ErrBPNSUnexpectedResponse, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func (bpns *BPNSClient) yield(ctx context.Context) error {
|
||||
resp, err := bpns.request(ctx, http.MethodPost, "yield", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusNoContent:
|
||||
return nil
|
||||
case http.StatusUnauthorized:
|
||||
return ErrBPNSUnauthorized
|
||||
case http.StatusNotFound:
|
||||
return ErrBPNSUnknownDevice
|
||||
default:
|
||||
return fmt.Errorf("%w: yield returned unexpected status %d", ErrBPNSUnexpectedResponse, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func (bpns *BPNSClient) keepaliveOrRegister(ctx context.Context) (time.Duration, error) {
|
||||
rate, err := bpns.keepalive(ctx)
|
||||
|
||||
if err == ErrBPNSUnknownDevice {
|
||||
rate, err = bpns.register(ctx)
|
||||
}
|
||||
|
||||
return rate, err
|
||||
}
|
||||
|
||||
func (bpns *BPNSClient) Start(ctx context.Context) {
|
||||
if bpns.cancel != nil {
|
||||
bpns.log.Debug().Msg("Start() called when already running")
|
||||
return
|
||||
}
|
||||
|
||||
if bpns.url == "" {
|
||||
bpns.log.Warn().Msg("No URL set, bpns disabled")
|
||||
return
|
||||
}
|
||||
|
||||
// we need to run once with start context to make bpns disconnect from remote network immediately
|
||||
rate, err := bpns.keepaliveOrRegister(ctx)
|
||||
if err != nil {
|
||||
bpns.log.Warn().Err(err).Msg("Unexpected bpns error during start")
|
||||
// first successful keepalive will reset the ticker
|
||||
rate = time.Second * 30
|
||||
} else if rate == 0 {
|
||||
bpns.log.Warn().Msg("Server requested we don't do keepalives, not starting loop")
|
||||
return
|
||||
} else if rate < bpnsMinKeepalive {
|
||||
bpns.log.Warn().Dur("rate_req", rate).Dur("rate_min", bpnsMinKeepalive).Msg("Server requested keepalive time less than allowed, using ours")
|
||||
rate = bpnsMinKeepalive
|
||||
}
|
||||
|
||||
bpns.log.Info().Str("url", bpns.url).Dur("rate", rate).Msg("Starting bpns loop")
|
||||
|
||||
ctx, bpns.cancel = context.WithCancel(context.Background())
|
||||
|
||||
bpns.wg.Add(1)
|
||||
go func() {
|
||||
defer func() {
|
||||
bpns.wg.Done()
|
||||
bpns.log.Trace().Msg("Leaving bpns loop")
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(rate)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
newRate, err := bpns.keepaliveOrRegister(ctx)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
} else if err != nil {
|
||||
bpns.log.Warn().Err(err).Msg("Unexpected bpns error")
|
||||
} else if newRate == 0 {
|
||||
bpns.log.Warn().Msg("Keepalive returned zero rate, stopping loop")
|
||||
return
|
||||
} else if newRate > bpnsMinKeepalive && newRate != rate {
|
||||
bpns.log.Info().Dur("rate_new", newRate).Dur("rate_old", rate).Msg("Keepalive rate changed by bpns")
|
||||
rate = newRate
|
||||
ticker.Reset(rate)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (bpns *BPNSClient) Stop(ctx context.Context) {
|
||||
if bpns.cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
bpns.log.Info().Msg("Stopping bpns loop")
|
||||
|
||||
bpns.Close()
|
||||
|
||||
// send yield before returning
|
||||
bpns.log.Debug().Msg("Sending final yield")
|
||||
err := bpns.yield(ctx)
|
||||
if err != nil {
|
||||
bpns.log.Err(err).Msg("Final yield failed")
|
||||
} else {
|
||||
bpns.log.Info().Msg("Yielded bpns connection")
|
||||
}
|
||||
}
|
||||
|
||||
func (bpns *BPNSClient) Close() {
|
||||
if bpns.cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
bpns.cancel()
|
||||
bpns.cancel = nil
|
||||
bpns.wg.Wait()
|
||||
}
|
||||
|
||||
func (bpns *BPNSClient) Deregister(ctx context.Context) error {
|
||||
_, err := bpns.request(ctx, http.MethodDelete, "device", &bpns.cfg)
|
||||
return err
|
||||
}
|
||||
|
||||
func (bpns *BPNSClient) IsUsingSignalACI(aci uuid.UUID) bool {
|
||||
return uuid.MustParse(bpns.cfg.Signal.AciUUID) == aci
|
||||
}
|
||||
|
||||
func (bpns *BPNSClient) IsUsingFCMPushToken(token string) bool {
|
||||
return bpns.cfg.FCM.Token == token
|
||||
}
|
25
main.go
25
main.go
|
@ -34,6 +34,7 @@ import (
|
|||
"maunium.net/go/mautrix/format"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/mautrix-signal/bpnsclient"
|
||||
"go.mau.fi/mautrix-signal/config"
|
||||
"go.mau.fi/mautrix-signal/database"
|
||||
"go.mau.fi/mautrix-signal/msgconv/matrixfmt"
|
||||
|
@ -79,6 +80,8 @@ type SignalBridge struct {
|
|||
puppetsLock sync.Mutex
|
||||
|
||||
disappearingMessagesManager *DisappearingMessagesManager
|
||||
|
||||
bpnsClient *bpnsclient.BPNSClient
|
||||
}
|
||||
|
||||
var _ bridge.ChildOverride = (*SignalBridge)(nil)
|
||||
|
@ -187,7 +190,24 @@ func (br *SignalBridge) Start() {
|
|||
br.Log.Debugln("Initializing provisioning API")
|
||||
br.provisioning.Init()
|
||||
}
|
||||
go br.StartUsers()
|
||||
go func() {
|
||||
startingUsers := br.StartUsers()
|
||||
|
||||
logger := br.ZLog.With().Str("component", "bpnsclient").Logger()
|
||||
br.bpnsClient = bpnsclient.NewBPNSClient(
|
||||
&logger,
|
||||
"http://localhost:8000",
|
||||
"myimaToken",
|
||||
"mypushToken",
|
||||
startingUsers[0].SignalID,
|
||||
startingUsers[0].Client.Store.Password,
|
||||
startingUsers[0].Client.Store.DeviceID,
|
||||
)
|
||||
if br.bpnsClient != nil {
|
||||
br.Log.Debugln("Starting BPNS client")
|
||||
br.bpnsClient.Start(context.TODO())
|
||||
}
|
||||
}()
|
||||
if br.Config.Metrics.Enabled {
|
||||
go br.Metrics.Start()
|
||||
}
|
||||
|
@ -200,6 +220,9 @@ func (br *SignalBridge) Stop() {
|
|||
br.Log.Debugln("Disconnecting", user.MXID)
|
||||
user.Disconnect()
|
||||
}
|
||||
if br.bpnsClient != nil {
|
||||
br.bpnsClient.Stop(context.TODO())
|
||||
}
|
||||
}
|
||||
|
||||
func (br *SignalBridge) GetIPortal(mxid id.RoomID) bridge.Portal {
|
||||
|
|
5
user.go
5
user.go
|
@ -517,8 +517,9 @@ func (user *User) clearKeysAndDisconnect() {
|
|||
}
|
||||
}
|
||||
|
||||
func (br *SignalBridge) StartUsers() {
|
||||
func (br *SignalBridge) StartUsers() []*User {
|
||||
br.ZLog.Debug().Msg("Starting users")
|
||||
var startingUsers []*User
|
||||
|
||||
usersWithToken := br.GetAllLoggedInUsers()
|
||||
for _, u := range usersWithToken {
|
||||
|
@ -528,6 +529,7 @@ func (br *SignalBridge) StartUsers() {
|
|||
u.BridgeState.Send(status.BridgeState{StateEvent: status.StateBadCredentials, Message: "You have been logged out of Signal, please reconnect"})
|
||||
continue
|
||||
}
|
||||
startingUsers = append(startingUsers, u)
|
||||
go u.Connect()
|
||||
}
|
||||
if len(usersWithToken) == 0 {
|
||||
|
@ -544,6 +546,7 @@ func (br *SignalBridge) StartUsers() {
|
|||
}
|
||||
}(customPuppet)
|
||||
}
|
||||
return startingUsers
|
||||
}
|
||||
|
||||
func (user *User) Login() (<-chan signalmeow.ProvisioningResponse, error) {
|
||||
|
|
Loading…
Add table
Reference in a new issue