mirror of
https://github.com/element-hq/dendrite.git
synced 2025-03-14 14:15:35 +00:00
Support for fallback keys (#3451)
Backports support for fallback keys from Harmony, which should make E2EE more reliable in the face of OTK exhaustion. Signed-off-by: Neil Alexander <git@neilalexander.dev> Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com> [skip ci]
This commit is contained in:
parent
c3d7a34c15
commit
78dbf21c5f
13 changed files with 446 additions and 20 deletions
|
@ -35,6 +35,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceI
|
||||||
return queryRes.Error
|
return queryRes.Error
|
||||||
}
|
}
|
||||||
res.DeviceListsOTKCount = queryRes.Count.KeyCount
|
res.DeviceListsOTKCount = queryRes.Count.KeyCount
|
||||||
|
res.DeviceListsUnusedFallbackAlgorithms = queryRes.UnusedFallbackAlgorithms
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -350,13 +350,14 @@ type ToDeviceResponse struct {
|
||||||
|
|
||||||
// Response represents a /sync API response. See https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-sync
|
// Response represents a /sync API response. See https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-sync
|
||||||
type Response struct {
|
type Response struct {
|
||||||
NextBatch StreamingToken `json:"next_batch"`
|
NextBatch StreamingToken `json:"next_batch"`
|
||||||
AccountData *ClientEvents `json:"account_data,omitempty"`
|
AccountData *ClientEvents `json:"account_data,omitempty"`
|
||||||
Presence *ClientEvents `json:"presence,omitempty"`
|
Presence *ClientEvents `json:"presence,omitempty"`
|
||||||
Rooms *RoomsResponse `json:"rooms,omitempty"`
|
Rooms *RoomsResponse `json:"rooms,omitempty"`
|
||||||
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
|
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
|
||||||
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
|
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
|
||||||
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
|
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
|
||||||
|
DeviceListsUnusedFallbackAlgorithms []string `json:"device_unused_fallback_key_types"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r Response) MarshalJSON() ([]byte, error) {
|
func (r Response) MarshalJSON() ([]byte, error) {
|
||||||
|
@ -419,6 +420,7 @@ func NewResponse() *Response {
|
||||||
res.DeviceLists = &DeviceLists{}
|
res.DeviceLists = &DeviceLists{}
|
||||||
res.ToDevice = &ToDeviceResponse{}
|
res.ToDevice = &ToDeviceResponse{}
|
||||||
res.DeviceListsOTKCount = map[string]int{}
|
res.DeviceListsOTKCount = map[string]int{}
|
||||||
|
res.DeviceListsUnusedFallbackAlgorithms = []string{}
|
||||||
|
|
||||||
return &res
|
return &res
|
||||||
}
|
}
|
||||||
|
|
|
@ -791,4 +791,6 @@ remote user can join room with version 11
|
||||||
User can invite remote user to room with version 11
|
User can invite remote user to room with version 11
|
||||||
Remote user can backfill in a room with version 11
|
Remote user can backfill in a room with version 11
|
||||||
Can reject invites over federation for rooms with version 11
|
Can reject invites over federation for rooms with version 11
|
||||||
Can receive redactions from regular users over federation in room version 11
|
Can receive redactions from regular users over federation in room version 11
|
||||||
|
Can upload self-signing keys
|
||||||
|
uploading signed devices gets propagated over federation
|
||||||
|
|
|
@ -788,12 +788,30 @@ type OneTimeKeysCount struct {
|
||||||
KeyCount map[string]int
|
KeyCount map[string]int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FallbackKeys represents a set of fallback keys for a single device
|
||||||
|
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
|
||||||
|
type FallbackKeys struct {
|
||||||
|
// The user who owns this device
|
||||||
|
UserID string
|
||||||
|
// The device ID of this device
|
||||||
|
DeviceID string
|
||||||
|
// A map of algorithm:key_id => key JSON
|
||||||
|
KeyJSON map[string]json.RawMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split a key in KeyJSON into algorithm and key ID
|
||||||
|
func (k *FallbackKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
|
||||||
|
segments := strings.Split(keyIDWithAlgo, ":")
|
||||||
|
return segments[0], segments[1]
|
||||||
|
}
|
||||||
|
|
||||||
// PerformUploadKeysRequest is the request to PerformUploadKeys
|
// PerformUploadKeysRequest is the request to PerformUploadKeys
|
||||||
type PerformUploadKeysRequest struct {
|
type PerformUploadKeysRequest struct {
|
||||||
UserID string // Required - User performing the request
|
UserID string // Required - User performing the request
|
||||||
DeviceID string // Optional - Device performing the request, for fetching OTK count
|
DeviceID string // Optional - Device performing the request, for fetching OTK count
|
||||||
DeviceKeys []DeviceKeys
|
DeviceKeys []DeviceKeys
|
||||||
OneTimeKeys []OneTimeKeys
|
OneTimeKeys []OneTimeKeys
|
||||||
|
FallbackKeys []FallbackKeys
|
||||||
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
|
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
|
||||||
// the display name for their respective device, and NOT to modify the keys. The key
|
// the display name for their respective device, and NOT to modify the keys. The key
|
||||||
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
|
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
|
||||||
|
@ -810,8 +828,9 @@ type PerformUploadKeysResponse struct {
|
||||||
// A fatal error when processing e.g database failures
|
// A fatal error when processing e.g database failures
|
||||||
Error *KeyError
|
Error *KeyError
|
||||||
// A map of user_id -> device_id -> Error for tracking failures.
|
// A map of user_id -> device_id -> Error for tracking failures.
|
||||||
KeyErrors map[string]map[string]*KeyError
|
KeyErrors map[string]map[string]*KeyError
|
||||||
OneTimeKeyCounts []OneTimeKeysCount
|
OneTimeKeyCounts []OneTimeKeysCount
|
||||||
|
FallbackKeysUnusedAlgorithms []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// PerformDeleteKeysRequest asks the keyserver to forget about certain
|
// PerformDeleteKeysRequest asks the keyserver to forget about certain
|
||||||
|
@ -917,8 +936,9 @@ type QueryOneTimeKeysRequest struct {
|
||||||
|
|
||||||
type QueryOneTimeKeysResponse struct {
|
type QueryOneTimeKeysResponse struct {
|
||||||
// OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84
|
// OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84
|
||||||
Count OneTimeKeysCount
|
Count OneTimeKeysCount
|
||||||
Error *KeyError
|
UnusedFallbackAlgorithms []string
|
||||||
|
Error *KeyError
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryDeviceMessagesRequest struct {
|
type QueryDeviceMessagesRequest struct {
|
||||||
|
|
|
@ -44,14 +44,22 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
|
||||||
if len(req.DeviceKeys) > 0 {
|
if len(req.DeviceKeys) > 0 {
|
||||||
a.uploadLocalDeviceKeys(ctx, req, res)
|
a.uploadLocalDeviceKeys(ctx, req, res)
|
||||||
}
|
}
|
||||||
if len(req.OneTimeKeys) > 0 {
|
if len(req.OneTimeKeys) > 0 || len(req.FallbackKeys) > 0 {
|
||||||
a.uploadOneTimeKeys(ctx, req, res)
|
a.uploadOneTimeAndFallbackKeys(ctx, req, res)
|
||||||
}
|
}
|
||||||
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
algos, err := a.KeyDatabase.UnusedFallbackKeyAlgorithms(ctx, req.UserID, req.DeviceID)
|
||||||
|
if err != nil {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("Failed to query unused fallback algorithms: %s", err),
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
|
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
|
||||||
|
res.FallbackKeysUnusedAlgorithms = algos
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,7 +177,15 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
algos, err := a.KeyDatabase.UnusedFallbackKeyAlgorithms(ctx, req.UserID, req.DeviceID)
|
||||||
|
if err != nil {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("Failed to query unused fallback algorithms: %s", err),
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
res.Count = *count
|
res.Count = *count
|
||||||
|
res.UnusedFallbackAlgorithms = algos
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -507,6 +523,9 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer(
|
||||||
for userID := range userIDsForAllDevices {
|
for userID := range userIDsForAllDevices {
|
||||||
err := a.Updater.ManualUpdate(context.Background(), spec.ServerName(serverName), userID)
|
err := a.Updater.ManualUpdate(context.Background(), spec.ServerName(serverName), userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
return
|
||||||
|
}
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
logrus.ErrorKey: err,
|
logrus.ErrorKey: err,
|
||||||
"user_id": userID,
|
"user_id": userID,
|
||||||
|
@ -520,6 +539,9 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer(
|
||||||
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
|
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
|
||||||
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
|
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
return
|
||||||
|
}
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
logrus.ErrorKey: err,
|
logrus.ErrorKey: err,
|
||||||
"user_id": userID,
|
"user_id": userID,
|
||||||
|
@ -715,7 +737,7 @@ func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Pe
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
func (a *UserInternalAPI) uploadOneTimeAndFallbackKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||||
if req.UserID == "" {
|
if req.UserID == "" {
|
||||||
res.Error = &api.KeyError{
|
res.Error = &api.KeyError{
|
||||||
Err: "user ID missing",
|
Err: "user ID missing",
|
||||||
|
@ -768,7 +790,32 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
|
||||||
// collect counts
|
// collect counts
|
||||||
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
|
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
|
||||||
}
|
}
|
||||||
|
if len(req.FallbackKeys) > 0 {
|
||||||
|
if err := a.KeyDatabase.DeleteFallbackKeys(ctx, req.UserID, req.DeviceID); err != nil {
|
||||||
|
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("%s device %s : failed to clear fallback keys: %s", req.UserID, req.DeviceID, err.Error()),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, key := range req.FallbackKeys {
|
||||||
|
// grab existing keys based on (user/device/algorithm/key ID)
|
||||||
|
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
|
||||||
|
i := 0
|
||||||
|
for keyIDWithAlgo := range key.KeyJSON {
|
||||||
|
keyIDsWithAlgorithms[i] = keyIDWithAlgo
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
unused, err := a.KeyDatabase.StoreFallbackKeys(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("%s device %s : failed to store fallback keys: %s", req.UserID, req.DeviceID, err.Error()),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// collect counts
|
||||||
|
res.FallbackKeysUnusedAlgorithms = unused
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
|
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
|
||||||
|
|
|
@ -167,6 +167,15 @@ type KeyDatabase interface {
|
||||||
// OneTimeKeysCount returns a count of all OTKs for this device.
|
// OneTimeKeysCount returns a count of all OTKs for this device.
|
||||||
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
||||||
|
|
||||||
|
// StoreFallbackKeys persists the given fallback keys.
|
||||||
|
StoreFallbackKeys(ctx context.Context, keys api.FallbackKeys) ([]string, error)
|
||||||
|
|
||||||
|
// UnusedFallbackKeyAlgorithms returns unused fallback algorithms for this user/device.
|
||||||
|
UnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error)
|
||||||
|
|
||||||
|
// DeleteFallbackKeys deletes all fallback keys for the user.
|
||||||
|
DeleteFallbackKeys(ctx context.Context, userID, deviceID string) error
|
||||||
|
|
||||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
|
|
||||||
|
|
134
userapi/storage/postgres/fallback_keys_table.go
Normal file
134
userapi/storage/postgres/fallback_keys_table.go
Normal file
|
@ -0,0 +1,134 @@
|
||||||
|
// Copyright 2024 New Vector Ltd.
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||||
|
// Please see LICENSE files in the repository root for full details.
|
||||||
|
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/element-hq/dendrite/internal"
|
||||||
|
"github.com/element-hq/dendrite/internal/sqlutil"
|
||||||
|
"github.com/element-hq/dendrite/userapi/api"
|
||||||
|
"github.com/element-hq/dendrite/userapi/storage/tables"
|
||||||
|
)
|
||||||
|
|
||||||
|
var fallbackKeysSchema = `
|
||||||
|
-- Stores one-time public keys for users
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_fallback_keys (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
device_id TEXT NOT NULL,
|
||||||
|
key_id TEXT NOT NULL,
|
||||||
|
algorithm TEXT NOT NULL,
|
||||||
|
ts_added_secs BIGINT NOT NULL,
|
||||||
|
key_json TEXT NOT NULL,
|
||||||
|
used BOOLEAN NOT NULL,
|
||||||
|
-- Clobber based on tuple of user/device/algorithm.
|
||||||
|
CONSTRAINT keyserver_fallback_keys_unique UNIQUE (user_id, device_id, algorithm)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS keyserver_fallback_keys_idx ON keyserver_fallback_keys (user_id, device_id);
|
||||||
|
`
|
||||||
|
|
||||||
|
const upsertFallbackKeysSQL = "" +
|
||||||
|
"INSERT INTO keyserver_fallback_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json, used)" +
|
||||||
|
" VALUES ($1, $2, $3, $4, $5, $6, false)" +
|
||||||
|
" ON CONFLICT ON CONSTRAINT keyserver_fallback_keys_unique" +
|
||||||
|
" DO UPDATE SET key_id = $3, key_json = $6, used = false"
|
||||||
|
|
||||||
|
const selectFallbackUnusedAlgorithmsSQL = "" +
|
||||||
|
"SELECT algorithm FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND used = false"
|
||||||
|
|
||||||
|
const selectFallbackKeysByAlgorithmSQL = "" +
|
||||||
|
"SELECT key_id, key_json FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 ORDER BY used ASC LIMIT 1"
|
||||||
|
|
||||||
|
const deleteFallbackKeysSQL = "" +
|
||||||
|
"DELETE FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2"
|
||||||
|
|
||||||
|
const updateFallbackKeyUsedSQL = "" +
|
||||||
|
"UPDATE keyserver_fallback_keys SET used=true WHERE user_id = $1 AND device_id = $2 AND key_id = $3 AND algorithm = $4"
|
||||||
|
|
||||||
|
type fallbackKeysStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
upsertKeysStmt *sql.Stmt
|
||||||
|
selectUnusedAlgorithmsStmt *sql.Stmt
|
||||||
|
selectKeyByAlgorithmStmt *sql.Stmt
|
||||||
|
deleteFallbackKeysStmt *sql.Stmt
|
||||||
|
updateFallbackKeyUsedStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresFallbackKeysTable(db *sql.DB) (tables.FallbackKeys, error) {
|
||||||
|
s := &fallbackKeysStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
_, err := db.Exec(fallbackKeysSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
|
{&s.upsertKeysStmt, upsertFallbackKeysSQL},
|
||||||
|
{&s.selectUnusedAlgorithmsStmt, selectFallbackUnusedAlgorithmsSQL},
|
||||||
|
{&s.selectKeyByAlgorithmStmt, selectFallbackKeysByAlgorithmSQL},
|
||||||
|
{&s.deleteFallbackKeysStmt, deleteFallbackKeysSQL},
|
||||||
|
{&s.updateFallbackKeyUsedStmt, updateFallbackKeyUsedSQL},
|
||||||
|
}.Prepare(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fallbackKeysStatements) SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) {
|
||||||
|
rows, err := s.selectUnusedAlgorithmsStmt.QueryContext(ctx, userID, deviceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||||
|
algos := []string{}
|
||||||
|
for rows.Next() {
|
||||||
|
var algorithm string
|
||||||
|
if err = rows.Scan(&algorithm); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
algos = append(algos, algorithm)
|
||||||
|
}
|
||||||
|
return algos, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fallbackKeysStatements) InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error) {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||||
|
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||||
|
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
|
||||||
|
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.SelectUnusedFallbackKeyAlgorithms(ctx, keys.UserID, keys.DeviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fallbackKeysStatements) DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
|
||||||
|
_, err := sqlutil.TxStmt(txn, s.deleteFallbackKeysStmt).ExecContext(ctx, userID, deviceID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fallbackKeysStatements) SelectAndUpdateFallbackKey(
|
||||||
|
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
|
||||||
|
) (map[string]json.RawMessage, error) {
|
||||||
|
var keyID string
|
||||||
|
var keyJSON string
|
||||||
|
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, err = sqlutil.TxStmtContext(ctx, txn, s.updateFallbackKeyUsedStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||||
|
return map[string]json.RawMessage{
|
||||||
|
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||||
|
}, err
|
||||||
|
}
|
|
@ -141,6 +141,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
fk, err := NewPostgresFallbackKeysTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
dk, err := NewPostgresDeviceKeysTable(db)
|
dk, err := NewPostgresDeviceKeysTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -164,6 +168,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
||||||
|
|
||||||
return &shared.KeyDatabase{
|
return &shared.KeyDatabase{
|
||||||
OneTimeKeysTable: otk,
|
OneTimeKeysTable: otk,
|
||||||
|
FallbackKeysTable: fk,
|
||||||
DeviceKeysTable: dk,
|
DeviceKeysTable: dk,
|
||||||
KeyChangesTable: kc,
|
KeyChangesTable: kc,
|
||||||
StaleDeviceListsTable: sdl,
|
StaleDeviceListsTable: sdl,
|
||||||
|
|
|
@ -57,6 +57,7 @@ type Database struct {
|
||||||
|
|
||||||
type KeyDatabase struct {
|
type KeyDatabase struct {
|
||||||
OneTimeKeysTable tables.OneTimeKeys
|
OneTimeKeysTable tables.OneTimeKeys
|
||||||
|
FallbackKeysTable tables.FallbackKeys
|
||||||
DeviceKeysTable tables.DeviceKeys
|
DeviceKeysTable tables.DeviceKeys
|
||||||
KeyChangesTable tables.KeyChanges
|
KeyChangesTable tables.KeyChanges
|
||||||
StaleDeviceListsTable tables.StaleDeviceLists
|
StaleDeviceListsTable tables.StaleDeviceLists
|
||||||
|
@ -937,6 +938,22 @@ func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID str
|
||||||
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
|
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *KeyDatabase) StoreFallbackKeys(ctx context.Context, keys api.FallbackKeys) (unused []string, err error) {
|
||||||
|
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
unused, err = d.FallbackKeysTable.InsertFallbackKeys(ctx, txn, keys)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *KeyDatabase) DeleteFallbackKeys(ctx context.Context, userID, deviceID string) error {
|
||||||
|
return d.FallbackKeysTable.DeleteFallbackKeys(ctx, nil, userID, deviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *KeyDatabase) UnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) {
|
||||||
|
return d.FallbackKeysTable.SelectUnusedFallbackKeyAlgorithms(ctx, userID, deviceID)
|
||||||
|
}
|
||||||
|
|
||||||
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
||||||
}
|
}
|
||||||
|
@ -999,6 +1016,12 @@ func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if len(keyJSON) == 0 {
|
||||||
|
keyJSON, err = d.FallbackKeysTable.SelectAndUpdateFallbackKey(ctx, txn, userID, deviceID, algo)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
if keyJSON != nil {
|
if keyJSON != nil {
|
||||||
result = append(result, api.OneTimeKeys{
|
result = append(result, api.OneTimeKeys{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
|
|
132
userapi/storage/sqlite3/fallback_keys_table.go
Normal file
132
userapi/storage/sqlite3/fallback_keys_table.go
Normal file
|
@ -0,0 +1,132 @@
|
||||||
|
// Copyright 2024 New Vector Ltd.
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||||
|
// Please see LICENSE files in the repository root for full details.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/element-hq/dendrite/internal"
|
||||||
|
"github.com/element-hq/dendrite/internal/sqlutil"
|
||||||
|
"github.com/element-hq/dendrite/userapi/api"
|
||||||
|
"github.com/element-hq/dendrite/userapi/storage/tables"
|
||||||
|
)
|
||||||
|
|
||||||
|
var fallbackKeysSchema = `
|
||||||
|
-- Stores one-time public keys for users
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_fallback_keys (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
device_id TEXT NOT NULL,
|
||||||
|
key_id TEXT NOT NULL,
|
||||||
|
algorithm TEXT NOT NULL,
|
||||||
|
ts_added_secs BIGINT NOT NULL,
|
||||||
|
key_json TEXT NOT NULL,
|
||||||
|
used BOOLEAN NOT NULL
|
||||||
|
);
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS keyserver_fallback_keys_unique_idx ON keyserver_fallback_keys(user_id, device_id, algorithm);
|
||||||
|
CREATE INDEX IF NOT EXISTS keyserver_fallback_keys_idx ON keyserver_fallback_keys (user_id, device_id);
|
||||||
|
`
|
||||||
|
|
||||||
|
const upsertFallbackKeysSQL = "" +
|
||||||
|
"INSERT INTO keyserver_fallback_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json, used)" +
|
||||||
|
" VALUES ($1, $2, $3, $4, $5, $6, false)" +
|
||||||
|
" ON CONFLICT (user_id, device_id, algorithm)" +
|
||||||
|
" DO UPDATE SET key_id = $3, key_json = $6, used = false"
|
||||||
|
|
||||||
|
const selectFallbackUnusedAlgorithmsSQL = "" +
|
||||||
|
"SELECT algorithm FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND used = false"
|
||||||
|
|
||||||
|
const selectFallbackKeysByAlgorithmSQL = "" +
|
||||||
|
"SELECT key_id, key_json FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 ORDER BY used ASC LIMIT 1"
|
||||||
|
|
||||||
|
const deleteFallbackKeysSQL = "" +
|
||||||
|
"DELETE FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2"
|
||||||
|
|
||||||
|
const updateFallbackKeyUsedSQL = "" +
|
||||||
|
"UPDATE keyserver_fallback_keys SET used=true WHERE user_id = $1 AND device_id = $2 AND key_id = $3 AND algorithm = $4"
|
||||||
|
|
||||||
|
type fallbackKeysStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
upsertKeysStmt *sql.Stmt
|
||||||
|
selectUnusedAlgorithmsStmt *sql.Stmt
|
||||||
|
selectKeyByAlgorithmStmt *sql.Stmt
|
||||||
|
deleteFallbackKeysStmt *sql.Stmt
|
||||||
|
updateFallbackKeyUsedStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSqliteFallbackKeysTable(db *sql.DB) (tables.FallbackKeys, error) {
|
||||||
|
s := &fallbackKeysStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
_, err := db.Exec(fallbackKeysSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
|
{&s.upsertKeysStmt, upsertFallbackKeysSQL},
|
||||||
|
{&s.selectUnusedAlgorithmsStmt, selectFallbackUnusedAlgorithmsSQL},
|
||||||
|
{&s.selectKeyByAlgorithmStmt, selectFallbackKeysByAlgorithmSQL},
|
||||||
|
{&s.deleteFallbackKeysStmt, deleteFallbackKeysSQL},
|
||||||
|
{&s.updateFallbackKeyUsedStmt, updateFallbackKeyUsedSQL},
|
||||||
|
}.Prepare(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fallbackKeysStatements) SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) {
|
||||||
|
rows, err := s.selectUnusedAlgorithmsStmt.QueryContext(ctx, userID, deviceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||||
|
algos := []string{}
|
||||||
|
for rows.Next() {
|
||||||
|
var algorithm string
|
||||||
|
if err = rows.Scan(&algorithm); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
algos = append(algos, algorithm)
|
||||||
|
}
|
||||||
|
return algos, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fallbackKeysStatements) InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error) {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||||
|
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||||
|
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
|
||||||
|
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.SelectUnusedFallbackKeyAlgorithms(ctx, keys.UserID, keys.DeviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fallbackKeysStatements) DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
|
||||||
|
_, err := sqlutil.TxStmt(txn, s.deleteFallbackKeysStmt).ExecContext(ctx, userID, deviceID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fallbackKeysStatements) SelectAndUpdateFallbackKey(
|
||||||
|
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
|
||||||
|
) (map[string]json.RawMessage, error) {
|
||||||
|
var keyID string
|
||||||
|
var keyJSON string
|
||||||
|
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, err = sqlutil.TxStmtContext(ctx, txn, s.updateFallbackKeyUsedStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||||
|
return map[string]json.RawMessage{
|
||||||
|
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||||
|
}, err
|
||||||
|
}
|
|
@ -138,6 +138,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
fk, err := NewSqliteFallbackKeysTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
dk, err := NewSqliteDeviceKeysTable(db)
|
dk, err := NewSqliteDeviceKeysTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -161,6 +165,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
||||||
|
|
||||||
return &shared.KeyDatabase{
|
return &shared.KeyDatabase{
|
||||||
OneTimeKeysTable: otk,
|
OneTimeKeysTable: otk,
|
||||||
|
FallbackKeysTable: fk,
|
||||||
DeviceKeysTable: dk,
|
DeviceKeysTable: dk,
|
||||||
KeyChangesTable: kc,
|
KeyChangesTable: kc,
|
||||||
StaleDeviceListsTable: sdl,
|
StaleDeviceListsTable: sdl,
|
||||||
|
|
|
@ -809,3 +809,42 @@ func TestOneTimeKeys(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFallbackKeys(t *testing.T) {
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||||
|
defer clean()
|
||||||
|
userID := "@alice:localhost"
|
||||||
|
deviceID := "alice_device"
|
||||||
|
fk := api.FallbackKeys{
|
||||||
|
UserID: userID,
|
||||||
|
DeviceID: deviceID,
|
||||||
|
KeyJSON: map[string]json.RawMessage{"curve25519:KEY1": []byte(`{"key":"v1"}`)},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := db.StoreFallbackKeys(ctx, fk)
|
||||||
|
MustNotError(t, err)
|
||||||
|
|
||||||
|
unused, err := db.UnusedFallbackKeyAlgorithms(ctx, userID, deviceID)
|
||||||
|
MustNotError(t, err)
|
||||||
|
if c := len(unused); c != 1 {
|
||||||
|
t.Fatalf("Expected 1 unused key algorithm, got %d", c)
|
||||||
|
}
|
||||||
|
if unused[0] != "curve25519" {
|
||||||
|
t.Fatalf("Expected unused key algorithm to be 'curve25519', got '%s'", unused[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// No other one-time keys have been uploaded so we expect to get the fallback key instead.
|
||||||
|
claimed, err := db.ClaimKeys(ctx, map[string]map[string]string{userID: {deviceID: "curve25519"}})
|
||||||
|
MustNotError(t, err)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case claimed[0].UserID != fk.UserID:
|
||||||
|
t.Fatalf("Claimed user ID ID doesn't match, got %q, want %q", claimed[0].UserID, fk.DeviceID)
|
||||||
|
case claimed[0].DeviceID != fk.DeviceID:
|
||||||
|
t.Fatalf("Claimed device ID doesn't match, got %q, want %q", claimed[0].DeviceID, fk.DeviceID)
|
||||||
|
case claimed[0].KeyJSON["curve25519:KEY1"] == nil:
|
||||||
|
t.Fatalf("Claimed key JSON for curve25519:KEY1 not found")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -170,6 +170,13 @@ type DeviceKeys interface {
|
||||||
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
|
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type FallbackKeys interface {
|
||||||
|
SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error)
|
||||||
|
InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error)
|
||||||
|
DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
||||||
|
SelectAndUpdateFallbackKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
|
||||||
|
}
|
||||||
|
|
||||||
type KeyChanges interface {
|
type KeyChanges interface {
|
||||||
InsertKeyChange(ctx context.Context, userID string) (int64, error)
|
InsertKeyChange(ctx context.Context, userID string) (int64, error)
|
||||||
// SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
|
// SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
|
||||||
|
|
Loading…
Add table
Reference in a new issue