mautrix-signal/pkg/signalmeow/profile.go
2025-01-19 14:24:50 +02:00

399 lines
13 KiB
Go

// mautrix-signal - A Matrix-signal puppeting bridge.
// Copyright (C) 2023 Scott Weber
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package signalmeow
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/google/uuid"
"github.com/rs/zerolog"
"go.mau.fi/mautrix-signal/pkg/libsignalgo"
"go.mau.fi/mautrix-signal/pkg/signalmeow/types"
"go.mau.fi/mautrix-signal/pkg/signalmeow/web"
)
type Capabilities struct {
SenderKey bool `json:"senderKey"`
AnnouncementGroup bool `json:"announcementGroup"`
ChangeNumber bool `json:"changeNumber"`
Stories bool `json:"stories"`
GiftBadges bool `json:"giftBadges"`
PaymentActivation bool `json:"paymentActivation"`
PNI bool `json:"pni"`
Gv1Migration bool `json:"gv1-migration"`
}
type ProfileResponse struct {
UUID uuid.UUID `json:"uuid"`
Name []byte `json:"name"`
About []byte `json:"about"`
AboutEmoji []byte `json:"aboutEmoji"`
Avatar string `json:"avatar"`
Capabilities Capabilities `json:"capabilities"`
Credential []byte `json:"credential"`
IdentityKey []byte `json:"identityKey"`
UnidentifiedAccess []byte `json:"unidentifiedAccess"`
UnrestrictedUnidentifiedAccess bool `json:"UnrestrictedUnidentifiedAccess"`
//Badges []any `json:"badges"`
//PhoneNumberSharing []byte `json:"phoneNumberSharing"`
//PaymentAddress []byte `json:"paymentAddress"`
}
type ProfileCache struct {
profiles map[string]*types.Profile
errors map[string]*error
lastFetched map[string]time.Time
}
func (cli *Client) ProfileKeyCredentialRequest(ctx context.Context, signalACI uuid.UUID) ([]byte, error) {
profileKey, err := cli.ProfileKeyForSignalID(ctx, signalACI)
if err != nil {
return nil, fmt.Errorf("error getting profile key for ACI: %w", err)
}
requestContext, err := libsignalgo.CreateProfileKeyCredentialRequestContext(
prodServerPublicParams,
signalACI,
*profileKey,
)
if err != nil {
return nil, fmt.Errorf("error creating profile key credential request context: %w", err)
}
request, err := requestContext.ProfileKeyCredentialRequestContextGetRequest()
if err != nil {
return nil, fmt.Errorf("error getting profile key credential request: %w", err)
}
// convert request bytes to hexidecimal representation
hexRequest := hex.EncodeToString(request[:])
return []byte(hexRequest), nil
}
func (cli *Client) ProfileKeyForSignalID(ctx context.Context, signalACI uuid.UUID) (*libsignalgo.ProfileKey, error) {
profileKey, err := cli.Store.RecipientStore.LoadProfileKey(ctx, signalACI)
if err != nil {
return nil, fmt.Errorf("error getting profile key: %w", err)
}
return profileKey, nil
}
var errProfileKeyNotFound = errors.New("profile key not found")
func (cli *Client) RetrieveProfileByID(ctx context.Context, signalID uuid.UUID) (*types.Profile, error) {
if cli.ProfileCache == nil {
cli.ProfileCache = &ProfileCache{
profiles: make(map[string]*types.Profile),
errors: make(map[string]*error),
lastFetched: make(map[string]time.Time),
}
}
// Check if we have a cached profile that is less than an hour old
// or if we have a cached error that is less than an hour old
lastFetched, ok := cli.ProfileCache.lastFetched[signalID.String()]
if ok && time.Since(lastFetched) < 1*time.Hour {
profile, ok := cli.ProfileCache.profiles[signalID.String()]
if ok {
return profile, nil
}
err, ok := cli.ProfileCache.errors[signalID.String()]
if ok {
return nil, fmt.Errorf("%w: %w", ErrCachedError, *err)
}
}
// If we get here, we don't have a cached profile, so fetch it
profile, err := cli.fetchProfileByID(ctx, signalID)
if err != nil {
// If we get a 401 or 5xx error, we should not retry until the cache expires
if errors.Is(err, ErrProfileUnauthorized) || errors.Is(err, ErrProfileInternalError) {
cli.ProfileCache.errors[signalID.String()] = &err
cli.ProfileCache.lastFetched[signalID.String()] = time.Now()
}
return nil, err
}
// If we get here, we have a valid profile, so cache it
cli.ProfileCache.profiles[signalID.String()] = profile
cli.ProfileCache.lastFetched[signalID.String()] = time.Now()
return profile, nil
}
func (cli *Client) fetchProfileByID(ctx context.Context, signalID uuid.UUID) (*types.Profile, error) {
profileKey, err := cli.ProfileKeyForSignalID(ctx, signalID)
if err != nil {
return nil, fmt.Errorf("error getting profile key: %w", err)
} else if profileKey == nil {
return nil, errProfileKeyNotFound
}
credentialRequest, err := cli.ProfileKeyCredentialRequest(ctx, signalID)
if err != nil {
return nil, fmt.Errorf("error getting profile key credential request: %w", err)
}
return cli.fetchProfileWithRequestAndKey(ctx, signalID, credentialRequest, profileKey)
}
var (
ErrCachedError = errors.New("cached error")
ErrProfileUnauthorized = errors.New("profile get returned 401")
ErrProfileNotFound = errors.New("profile get returned 404")
ErrProfileInternalError = errors.New("profile get returned")
)
func (cli *Client) fetchProfileWithRequestAndKey(ctx context.Context, signalID uuid.UUID, credentialRequest []byte, profileKey *libsignalgo.ProfileKey) (*types.Profile, error) {
log := zerolog.Ctx(ctx)
profileKeyVersion, err := profileKey.GetProfileKeyVersion(signalID)
if err != nil {
return nil, fmt.Errorf("error getting profile key version: %w", err)
}
accessKey, err := profileKey.DeriveAccessKey()
if err != nil {
return nil, fmt.Errorf("error deriving access key: %w", err)
}
base64AccessKey := base64.StdEncoding.EncodeToString(accessKey[:])
path := "/v1/profile/" + signalID.String()
useUnidentified := profileKeyVersion != nil && accessKey != nil
if useUnidentified {
log.Trace().
Hex("profile_key_version", profileKeyVersion[:]).
Msg("Using unidentified profile request")
// Assuming we can just make the version bytes into a string
path += "/" + profileKeyVersion.String()
}
if credentialRequest != nil {
path += "/" + string(credentialRequest)
path += "?credentialType=expiringProfileKey"
}
profileRequest := web.CreateWSRequest(http.MethodGet, path, nil, nil, nil)
if useUnidentified {
profileRequest.Headers = append(profileRequest.Headers, "unidentified-access-key:"+base64AccessKey)
profileRequest.Headers = append(profileRequest.Headers, "accept-language:en-CA")
}
resp, err := cli.UnauthedWS.SendRequest(ctx, profileRequest)
if err != nil {
return nil, fmt.Errorf("error sending request: %w", err)
}
var profile types.Profile
profile.FetchedAt = time.Now()
logEvt := log.Trace().Uint32("status_code", resp.GetStatus())
if logEvt.Enabled() {
if json.Valid(resp.Body) {
logEvt.RawJSON("response_data", resp.Body)
} else {
logEvt.Str("invalid_response_data", base64.StdEncoding.EncodeToString(resp.Body))
}
}
logEvt.Msg("Got profile response")
if *resp.Status < 200 || *resp.Status >= 300 {
if *resp.Status == 401 {
return nil, ErrProfileUnauthorized
} else if *resp.Status == 404 {
return nil, ErrProfileNotFound
} else if *resp.Status >= 500 {
return nil, fmt.Errorf("%w %d", ErrProfileInternalError, *resp.Status)
}
return nil, fmt.Errorf("unexpected status code %d", *resp.Status)
}
var profileResponse ProfileResponse
err = json.Unmarshal(resp.Body, &profileResponse)
if err != nil {
return nil, fmt.Errorf("error unmarshalling profile response: %w", err)
}
if len(profileResponse.Name) > 0 {
profile.Name, err = decryptString(profileKey, profileResponse.Name)
if err != nil {
return nil, fmt.Errorf("error decrypting profile name: %w", err)
}
// TODO store first and last name separately instead of removing the separator
profile.Name = strings.ReplaceAll(profile.Name, "\x00", " ")
}
if len(profileResponse.About) > 0 {
profile.About, err = decryptString(profileKey, profileResponse.About)
if err != nil {
return nil, fmt.Errorf("error decrypting profile about: %w", err)
}
}
if len(profileResponse.AboutEmoji) > 0 {
profile.AboutEmoji, err = decryptString(profileKey, profileResponse.AboutEmoji)
if err != nil {
return nil, fmt.Errorf("error decrypting profile aboutEmoji: %w", err)
}
}
// TODO store other metadata fields?
if profileResponse.Avatar == "" {
profile.AvatarPath = "clear"
} else {
profile.AvatarPath = profileResponse.Avatar
}
profile.Credential = profileResponse.Credential
profile.Key = *profileKey
return &profile, nil
}
func (cli *Client) DownloadUserAvatar(ctx context.Context, avatarPath string, profileKey libsignalgo.ProfileKey) ([]byte, error) {
username, password := cli.Store.BasicAuthCreds()
opts := &web.HTTPReqOpt{
Host: web.CDN1Hostname,
Username: &username,
Password: &password,
}
resp, err := web.SendHTTPRequest(ctx, http.MethodGet, avatarPath, opts)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("unexpected response status %d", resp.StatusCode)
}
encryptedAvatar, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
avatar, err := decryptBytes(profileKey[:], encryptedAvatar)
if err != nil {
return nil, fmt.Errorf("failed to decrypt response: %w", err)
}
return avatar, nil
}
func decryptBytes(key []byte, encryptedText []byte) ([]byte, error) {
if len(encryptedText) < NONCE_LENGTH+16+1 {
return nil, errors.New("invalid encryptedBytes length")
}
nonce := encryptedText[:NONCE_LENGTH]
ciphertext := encryptedText[NONCE_LENGTH:]
decrypted, err := AesgcmDecrypt(key, nonce, ciphertext, []byte{})
if err != nil {
return nil, err
}
return decrypted, nil
}
func decryptString(key *libsignalgo.ProfileKey, encryptedText []byte) (string, error) {
data, err := decryptBytes(key[:], encryptedText)
if err != nil {
return "", err
}
return string(bytes.TrimRight(data, "\x00")), nil
}
func encryptString(key libsignalgo.ProfileKey, plaintext string, paddedLength int) ([]byte, error) {
inputLength := len(plaintext)
if inputLength > paddedLength {
return nil, errors.New("plaintext longer than paddedLength")
}
padded := append([]byte(plaintext), make([]byte, paddedLength-inputLength)...)
nonce := make([]byte, NONCE_LENGTH)
rand.Read(nonce)
keyBytes := key[:]
ciphertext, err := AesgcmEncrypt(keyBytes, nonce, padded)
if err != nil {
return nil, err
}
return append(nonce, ciphertext...), nil
}
const NONCE_LENGTH = 12
const TAG_LENGTH_BYTES = 16
func AesgcmDecrypt(key, nonce, data, mac []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aesgcm, err := cipher.NewGCMWithTagSize(block, TAG_LENGTH_BYTES)
if err != nil {
return nil, err
}
ciphertext := append(data, mac...)
return aesgcm.Open(nil, nonce, ciphertext, nil)
}
func AesgcmEncrypt(key, nonce, plaintext []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
return aesgcm.Seal(nil, nonce, plaintext, nil), nil
}
func (cli *Client) FetchExpiringProfileKeyCredentialById(ctx context.Context, signalACI uuid.UUID) (*libsignalgo.ExpiringProfileKeyCredential, error) {
profileKey, err := cli.ProfileKeyForSignalID(ctx, signalACI)
if err != nil {
return nil, fmt.Errorf("error getting profile key for ACI: %w", err)
}
requestContext, err := libsignalgo.CreateProfileKeyCredentialRequestContext(
prodServerPublicParams,
signalACI,
*profileKey,
)
if err != nil {
return nil, fmt.Errorf("error creating profile key credential request context: %w", err)
}
request, err := requestContext.ProfileKeyCredentialRequestContextGetRequest()
if err != nil {
return nil, fmt.Errorf("error getting profile key credential request: %w", err)
}
// convert request bytes to hexidecimal representation
hexRequest := hex.EncodeToString(request[:])
credentialRequest := []byte(hexRequest)
profile, err := cli.fetchProfileWithRequestAndKey(ctx, signalACI, credentialRequest, profileKey)
if err != nil {
return nil, fmt.Errorf("failed to fetch profile: %w", err)
}
response, err := libsignalgo.NewExpiringProfileKeyCredentialResponse(profile.Credential)
if err != nil {
return nil, fmt.Errorf("failed to get expiring profile key credential response: %w", err)
}
epkc, err := libsignalgo.ReceiveExpiringProfileKeyCredential(prodServerPublicParams, requestContext, response, uint64(time.Now().Unix()))
if err != nil {
return nil, fmt.Errorf("failed to receive expiring profile key credential: %w", err)
}
return epkc, nil
}