mirror of
https://github.com/element-hq/dendrite.git
synced 2025-03-14 14:15:35 +00:00
Support for fallback keys (#3451)
Backports support for fallback keys from Harmony, which should make E2EE more reliable in the face of OTK exhaustion. Signed-off-by: Neil Alexander <git@neilalexander.dev> Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com> [skip ci]
This commit is contained in:
parent
c3d7a34c15
commit
78dbf21c5f
13 changed files with 446 additions and 20 deletions
|
@ -35,6 +35,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceI
|
|||
return queryRes.Error
|
||||
}
|
||||
res.DeviceListsOTKCount = queryRes.Count.KeyCount
|
||||
res.DeviceListsUnusedFallbackAlgorithms = queryRes.UnusedFallbackAlgorithms
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -350,13 +350,14 @@ type ToDeviceResponse struct {
|
|||
|
||||
// Response represents a /sync API response. See https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-sync
|
||||
type Response struct {
|
||||
NextBatch StreamingToken `json:"next_batch"`
|
||||
AccountData *ClientEvents `json:"account_data,omitempty"`
|
||||
Presence *ClientEvents `json:"presence,omitempty"`
|
||||
Rooms *RoomsResponse `json:"rooms,omitempty"`
|
||||
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
|
||||
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
|
||||
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
|
||||
NextBatch StreamingToken `json:"next_batch"`
|
||||
AccountData *ClientEvents `json:"account_data,omitempty"`
|
||||
Presence *ClientEvents `json:"presence,omitempty"`
|
||||
Rooms *RoomsResponse `json:"rooms,omitempty"`
|
||||
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
|
||||
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
|
||||
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
|
||||
DeviceListsUnusedFallbackAlgorithms []string `json:"device_unused_fallback_key_types"`
|
||||
}
|
||||
|
||||
func (r Response) MarshalJSON() ([]byte, error) {
|
||||
|
@ -419,6 +420,7 @@ func NewResponse() *Response {
|
|||
res.DeviceLists = &DeviceLists{}
|
||||
res.ToDevice = &ToDeviceResponse{}
|
||||
res.DeviceListsOTKCount = map[string]int{}
|
||||
res.DeviceListsUnusedFallbackAlgorithms = []string{}
|
||||
|
||||
return &res
|
||||
}
|
||||
|
|
|
@ -791,4 +791,6 @@ remote user can join room with version 11
|
|||
User can invite remote user to room with version 11
|
||||
Remote user can backfill in a room with version 11
|
||||
Can reject invites over federation for rooms with version 11
|
||||
Can receive redactions from regular users over federation in room version 11
|
||||
Can receive redactions from regular users over federation in room version 11
|
||||
Can upload self-signing keys
|
||||
uploading signed devices gets propagated over federation
|
||||
|
|
|
@ -788,12 +788,30 @@ type OneTimeKeysCount struct {
|
|||
KeyCount map[string]int
|
||||
}
|
||||
|
||||
// FallbackKeys represents a set of fallback keys for a single device
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
|
||||
type FallbackKeys struct {
|
||||
// The user who owns this device
|
||||
UserID string
|
||||
// The device ID of this device
|
||||
DeviceID string
|
||||
// A map of algorithm:key_id => key JSON
|
||||
KeyJSON map[string]json.RawMessage
|
||||
}
|
||||
|
||||
// Split a key in KeyJSON into algorithm and key ID
|
||||
func (k *FallbackKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
|
||||
segments := strings.Split(keyIDWithAlgo, ":")
|
||||
return segments[0], segments[1]
|
||||
}
|
||||
|
||||
// PerformUploadKeysRequest is the request to PerformUploadKeys
|
||||
type PerformUploadKeysRequest struct {
|
||||
UserID string // Required - User performing the request
|
||||
DeviceID string // Optional - Device performing the request, for fetching OTK count
|
||||
DeviceKeys []DeviceKeys
|
||||
OneTimeKeys []OneTimeKeys
|
||||
UserID string // Required - User performing the request
|
||||
DeviceID string // Optional - Device performing the request, for fetching OTK count
|
||||
DeviceKeys []DeviceKeys
|
||||
OneTimeKeys []OneTimeKeys
|
||||
FallbackKeys []FallbackKeys
|
||||
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
|
||||
// the display name for their respective device, and NOT to modify the keys. The key
|
||||
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
|
||||
|
@ -810,8 +828,9 @@ type PerformUploadKeysResponse struct {
|
|||
// A fatal error when processing e.g database failures
|
||||
Error *KeyError
|
||||
// A map of user_id -> device_id -> Error for tracking failures.
|
||||
KeyErrors map[string]map[string]*KeyError
|
||||
OneTimeKeyCounts []OneTimeKeysCount
|
||||
KeyErrors map[string]map[string]*KeyError
|
||||
OneTimeKeyCounts []OneTimeKeysCount
|
||||
FallbackKeysUnusedAlgorithms []string
|
||||
}
|
||||
|
||||
// PerformDeleteKeysRequest asks the keyserver to forget about certain
|
||||
|
@ -917,8 +936,9 @@ type QueryOneTimeKeysRequest struct {
|
|||
|
||||
type QueryOneTimeKeysResponse struct {
|
||||
// OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84
|
||||
Count OneTimeKeysCount
|
||||
Error *KeyError
|
||||
Count OneTimeKeysCount
|
||||
UnusedFallbackAlgorithms []string
|
||||
Error *KeyError
|
||||
}
|
||||
|
||||
type QueryDeviceMessagesRequest struct {
|
||||
|
|
|
@ -44,14 +44,22 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
|
|||
if len(req.DeviceKeys) > 0 {
|
||||
a.uploadLocalDeviceKeys(ctx, req, res)
|
||||
}
|
||||
if len(req.OneTimeKeys) > 0 {
|
||||
a.uploadOneTimeKeys(ctx, req, res)
|
||||
if len(req.OneTimeKeys) > 0 || len(req.FallbackKeys) > 0 {
|
||||
a.uploadOneTimeAndFallbackKeys(ctx, req, res)
|
||||
}
|
||||
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
algos, err := a.KeyDatabase.UnusedFallbackKeyAlgorithms(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("Failed to query unused fallback algorithms: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
|
||||
res.FallbackKeysUnusedAlgorithms = algos
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -169,7 +177,15 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn
|
|||
}
|
||||
return nil
|
||||
}
|
||||
algos, err := a.KeyDatabase.UnusedFallbackKeyAlgorithms(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("Failed to query unused fallback algorithms: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
res.Count = *count
|
||||
res.UnusedFallbackAlgorithms = algos
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -507,6 +523,9 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer(
|
|||
for userID := range userIDsForAllDevices {
|
||||
err := a.Updater.ManualUpdate(context.Background(), spec.ServerName(serverName), userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
logrus.WithFields(logrus.Fields{
|
||||
logrus.ErrorKey: err,
|
||||
"user_id": userID,
|
||||
|
@ -520,6 +539,9 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer(
|
|||
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
|
||||
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
logrus.WithFields(logrus.Fields{
|
||||
logrus.ErrorKey: err,
|
||||
"user_id": userID,
|
||||
|
@ -715,7 +737,7 @@ func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Pe
|
|||
}
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||
func (a *UserInternalAPI) uploadOneTimeAndFallbackKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||
if req.UserID == "" {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "user ID missing",
|
||||
|
@ -768,7 +790,32 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
|
|||
// collect counts
|
||||
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
|
||||
}
|
||||
|
||||
if len(req.FallbackKeys) > 0 {
|
||||
if err := a.KeyDatabase.DeleteFallbackKeys(ctx, req.UserID, req.DeviceID); err != nil {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf("%s device %s : failed to clear fallback keys: %s", req.UserID, req.DeviceID, err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
for _, key := range req.FallbackKeys {
|
||||
// grab existing keys based on (user/device/algorithm/key ID)
|
||||
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
|
||||
i := 0
|
||||
for keyIDWithAlgo := range key.KeyJSON {
|
||||
keyIDsWithAlgorithms[i] = keyIDWithAlgo
|
||||
i++
|
||||
}
|
||||
unused, err := a.KeyDatabase.StoreFallbackKeys(ctx, key)
|
||||
if err != nil {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf("%s device %s : failed to store fallback keys: %s", req.UserID, req.DeviceID, err.Error()),
|
||||
})
|
||||
continue
|
||||
}
|
||||
// collect counts
|
||||
res.FallbackKeysUnusedAlgorithms = unused
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
|
||||
|
|
|
@ -167,6 +167,15 @@ type KeyDatabase interface {
|
|||
// OneTimeKeysCount returns a count of all OTKs for this device.
|
||||
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
||||
|
||||
// StoreFallbackKeys persists the given fallback keys.
|
||||
StoreFallbackKeys(ctx context.Context, keys api.FallbackKeys) ([]string, error)
|
||||
|
||||
// UnusedFallbackKeyAlgorithms returns unused fallback algorithms for this user/device.
|
||||
UnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error)
|
||||
|
||||
// DeleteFallbackKeys deletes all fallback keys for the user.
|
||||
DeleteFallbackKeys(ctx context.Context, userID, deviceID string) error
|
||||
|
||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
|
|
134
userapi/storage/postgres/fallback_keys_table.go
Normal file
134
userapi/storage/postgres/fallback_keys_table.go
Normal file
|
@ -0,0 +1,134 @@
|
|||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/element-hq/dendrite/internal"
|
||||
"github.com/element-hq/dendrite/internal/sqlutil"
|
||||
"github.com/element-hq/dendrite/userapi/api"
|
||||
"github.com/element-hq/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var fallbackKeysSchema = `
|
||||
-- Stores one-time public keys for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_fallback_keys (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
key_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
used BOOLEAN NOT NULL,
|
||||
-- Clobber based on tuple of user/device/algorithm.
|
||||
CONSTRAINT keyserver_fallback_keys_unique UNIQUE (user_id, device_id, algorithm)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_fallback_keys_idx ON keyserver_fallback_keys (user_id, device_id);
|
||||
`
|
||||
|
||||
const upsertFallbackKeysSQL = "" +
|
||||
"INSERT INTO keyserver_fallback_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json, used)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, false)" +
|
||||
" ON CONFLICT ON CONSTRAINT keyserver_fallback_keys_unique" +
|
||||
" DO UPDATE SET key_id = $3, key_json = $6, used = false"
|
||||
|
||||
const selectFallbackUnusedAlgorithmsSQL = "" +
|
||||
"SELECT algorithm FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND used = false"
|
||||
|
||||
const selectFallbackKeysByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 ORDER BY used ASC LIMIT 1"
|
||||
|
||||
const deleteFallbackKeysSQL = "" +
|
||||
"DELETE FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2"
|
||||
|
||||
const updateFallbackKeyUsedSQL = "" +
|
||||
"UPDATE keyserver_fallback_keys SET used=true WHERE user_id = $1 AND device_id = $2 AND key_id = $3 AND algorithm = $4"
|
||||
|
||||
type fallbackKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertKeysStmt *sql.Stmt
|
||||
selectUnusedAlgorithmsStmt *sql.Stmt
|
||||
selectKeyByAlgorithmStmt *sql.Stmt
|
||||
deleteFallbackKeysStmt *sql.Stmt
|
||||
updateFallbackKeyUsedStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresFallbackKeysTable(db *sql.DB) (tables.FallbackKeys, error) {
|
||||
s := &fallbackKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(fallbackKeysSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertKeysStmt, upsertFallbackKeysSQL},
|
||||
{&s.selectUnusedAlgorithmsStmt, selectFallbackUnusedAlgorithmsSQL},
|
||||
{&s.selectKeyByAlgorithmStmt, selectFallbackKeysByAlgorithmSQL},
|
||||
{&s.deleteFallbackKeysStmt, deleteFallbackKeysSQL},
|
||||
{&s.updateFallbackKeyUsedStmt, updateFallbackKeyUsedSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *fallbackKeysStatements) SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) {
|
||||
rows, err := s.selectUnusedAlgorithmsStmt.QueryContext(ctx, userID, deviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||
algos := []string{}
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
if err = rows.Scan(&algorithm); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
algos = append(algos, algorithm)
|
||||
}
|
||||
return algos, rows.Err()
|
||||
}
|
||||
|
||||
func (s *fallbackKeysStatements) InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error) {
|
||||
now := time.Now().Unix()
|
||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
|
||||
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return s.SelectUnusedFallbackKeyAlgorithms(ctx, keys.UserID, keys.DeviceID)
|
||||
}
|
||||
|
||||
func (s *fallbackKeysStatements) DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteFallbackKeysStmt).ExecContext(ctx, userID, deviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *fallbackKeysStatements) SelectAndUpdateFallbackKey(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
|
||||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.updateFallbackKeyUsedStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||
return map[string]json.RawMessage{
|
||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||
}, err
|
||||
}
|
|
@ -141,6 +141,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fk, err := NewPostgresFallbackKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dk, err := NewPostgresDeviceKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -164,6 +168,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
|
||||
return &shared.KeyDatabase{
|
||||
OneTimeKeysTable: otk,
|
||||
FallbackKeysTable: fk,
|
||||
DeviceKeysTable: dk,
|
||||
KeyChangesTable: kc,
|
||||
StaleDeviceListsTable: sdl,
|
||||
|
|
|
@ -57,6 +57,7 @@ type Database struct {
|
|||
|
||||
type KeyDatabase struct {
|
||||
OneTimeKeysTable tables.OneTimeKeys
|
||||
FallbackKeysTable tables.FallbackKeys
|
||||
DeviceKeysTable tables.DeviceKeys
|
||||
KeyChangesTable tables.KeyChanges
|
||||
StaleDeviceListsTable tables.StaleDeviceLists
|
||||
|
@ -937,6 +938,22 @@ func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID str
|
|||
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) StoreFallbackKeys(ctx context.Context, keys api.FallbackKeys) (unused []string, err error) {
|
||||
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
unused, err = d.FallbackKeysTable.InsertFallbackKeys(ctx, txn, keys)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) DeleteFallbackKeys(ctx context.Context, userID, deviceID string) error {
|
||||
return d.FallbackKeysTable.DeleteFallbackKeys(ctx, nil, userID, deviceID)
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) UnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) {
|
||||
return d.FallbackKeysTable.SelectUnusedFallbackKeyAlgorithms(ctx, userID, deviceID)
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
||||
}
|
||||
|
@ -999,6 +1016,12 @@ func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(keyJSON) == 0 {
|
||||
keyJSON, err = d.FallbackKeysTable.SelectAndUpdateFallbackKey(ctx, txn, userID, deviceID, algo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if keyJSON != nil {
|
||||
result = append(result, api.OneTimeKeys{
|
||||
UserID: userID,
|
||||
|
|
132
userapi/storage/sqlite3/fallback_keys_table.go
Normal file
132
userapi/storage/sqlite3/fallback_keys_table.go
Normal file
|
@ -0,0 +1,132 @@
|
|||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/element-hq/dendrite/internal"
|
||||
"github.com/element-hq/dendrite/internal/sqlutil"
|
||||
"github.com/element-hq/dendrite/userapi/api"
|
||||
"github.com/element-hq/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var fallbackKeysSchema = `
|
||||
-- Stores one-time public keys for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_fallback_keys (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
key_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
used BOOLEAN NOT NULL
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS keyserver_fallback_keys_unique_idx ON keyserver_fallback_keys(user_id, device_id, algorithm);
|
||||
CREATE INDEX IF NOT EXISTS keyserver_fallback_keys_idx ON keyserver_fallback_keys (user_id, device_id);
|
||||
`
|
||||
|
||||
const upsertFallbackKeysSQL = "" +
|
||||
"INSERT INTO keyserver_fallback_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json, used)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, false)" +
|
||||
" ON CONFLICT (user_id, device_id, algorithm)" +
|
||||
" DO UPDATE SET key_id = $3, key_json = $6, used = false"
|
||||
|
||||
const selectFallbackUnusedAlgorithmsSQL = "" +
|
||||
"SELECT algorithm FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND used = false"
|
||||
|
||||
const selectFallbackKeysByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 ORDER BY used ASC LIMIT 1"
|
||||
|
||||
const deleteFallbackKeysSQL = "" +
|
||||
"DELETE FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2"
|
||||
|
||||
const updateFallbackKeyUsedSQL = "" +
|
||||
"UPDATE keyserver_fallback_keys SET used=true WHERE user_id = $1 AND device_id = $2 AND key_id = $3 AND algorithm = $4"
|
||||
|
||||
type fallbackKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertKeysStmt *sql.Stmt
|
||||
selectUnusedAlgorithmsStmt *sql.Stmt
|
||||
selectKeyByAlgorithmStmt *sql.Stmt
|
||||
deleteFallbackKeysStmt *sql.Stmt
|
||||
updateFallbackKeyUsedStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteFallbackKeysTable(db *sql.DB) (tables.FallbackKeys, error) {
|
||||
s := &fallbackKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(fallbackKeysSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertKeysStmt, upsertFallbackKeysSQL},
|
||||
{&s.selectUnusedAlgorithmsStmt, selectFallbackUnusedAlgorithmsSQL},
|
||||
{&s.selectKeyByAlgorithmStmt, selectFallbackKeysByAlgorithmSQL},
|
||||
{&s.deleteFallbackKeysStmt, deleteFallbackKeysSQL},
|
||||
{&s.updateFallbackKeyUsedStmt, updateFallbackKeyUsedSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *fallbackKeysStatements) SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) {
|
||||
rows, err := s.selectUnusedAlgorithmsStmt.QueryContext(ctx, userID, deviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||
algos := []string{}
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
if err = rows.Scan(&algorithm); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
algos = append(algos, algorithm)
|
||||
}
|
||||
return algos, rows.Err()
|
||||
}
|
||||
|
||||
func (s *fallbackKeysStatements) InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error) {
|
||||
now := time.Now().Unix()
|
||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
|
||||
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return s.SelectUnusedFallbackKeyAlgorithms(ctx, keys.UserID, keys.DeviceID)
|
||||
}
|
||||
|
||||
func (s *fallbackKeysStatements) DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteFallbackKeysStmt).ExecContext(ctx, userID, deviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *fallbackKeysStatements) SelectAndUpdateFallbackKey(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
|
||||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.updateFallbackKeyUsedStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||
return map[string]json.RawMessage{
|
||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||
}, err
|
||||
}
|
|
@ -138,6 +138,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fk, err := NewSqliteFallbackKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dk, err := NewSqliteDeviceKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -161,6 +165,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
|
||||
return &shared.KeyDatabase{
|
||||
OneTimeKeysTable: otk,
|
||||
FallbackKeysTable: fk,
|
||||
DeviceKeysTable: dk,
|
||||
KeyChangesTable: kc,
|
||||
StaleDeviceListsTable: sdl,
|
||||
|
|
|
@ -809,3 +809,42 @@ func TestOneTimeKeys(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFallbackKeys(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||
defer clean()
|
||||
userID := "@alice:localhost"
|
||||
deviceID := "alice_device"
|
||||
fk := api.FallbackKeys{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
KeyJSON: map[string]json.RawMessage{"curve25519:KEY1": []byte(`{"key":"v1"}`)},
|
||||
}
|
||||
|
||||
_, err := db.StoreFallbackKeys(ctx, fk)
|
||||
MustNotError(t, err)
|
||||
|
||||
unused, err := db.UnusedFallbackKeyAlgorithms(ctx, userID, deviceID)
|
||||
MustNotError(t, err)
|
||||
if c := len(unused); c != 1 {
|
||||
t.Fatalf("Expected 1 unused key algorithm, got %d", c)
|
||||
}
|
||||
if unused[0] != "curve25519" {
|
||||
t.Fatalf("Expected unused key algorithm to be 'curve25519', got '%s'", unused[0])
|
||||
}
|
||||
|
||||
// No other one-time keys have been uploaded so we expect to get the fallback key instead.
|
||||
claimed, err := db.ClaimKeys(ctx, map[string]map[string]string{userID: {deviceID: "curve25519"}})
|
||||
MustNotError(t, err)
|
||||
|
||||
switch {
|
||||
case claimed[0].UserID != fk.UserID:
|
||||
t.Fatalf("Claimed user ID ID doesn't match, got %q, want %q", claimed[0].UserID, fk.DeviceID)
|
||||
case claimed[0].DeviceID != fk.DeviceID:
|
||||
t.Fatalf("Claimed device ID doesn't match, got %q, want %q", claimed[0].DeviceID, fk.DeviceID)
|
||||
case claimed[0].KeyJSON["curve25519:KEY1"] == nil:
|
||||
t.Fatalf("Claimed key JSON for curve25519:KEY1 not found")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -170,6 +170,13 @@ type DeviceKeys interface {
|
|||
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
|
||||
}
|
||||
|
||||
type FallbackKeys interface {
|
||||
SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error)
|
||||
InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error)
|
||||
DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
||||
SelectAndUpdateFallbackKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
|
||||
}
|
||||
|
||||
type KeyChanges interface {
|
||||
InsertKeyChange(ctx context.Context, userID string) (int64, error)
|
||||
// SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
|
||||
|
|
Loading…
Add table
Reference in a new issue