Support for fallback keys (#3451)

Backports support for fallback keys from Harmony, which should make E2EE
more reliable in the face of OTK exhaustion.

Signed-off-by: Neil Alexander <git@neilalexander.dev>

Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>

[skip ci]
This commit is contained in:
Neil 2024-12-17 19:19:15 +01:00 committed by GitHub
parent c3d7a34c15
commit 78dbf21c5f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 446 additions and 20 deletions

View file

@ -35,6 +35,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceI
return queryRes.Error
}
res.DeviceListsOTKCount = queryRes.Count.KeyCount
res.DeviceListsUnusedFallbackAlgorithms = queryRes.UnusedFallbackAlgorithms
return nil
}

View file

@ -350,13 +350,14 @@ type ToDeviceResponse struct {
// Response represents a /sync API response. See https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-sync
type Response struct {
NextBatch StreamingToken `json:"next_batch"`
AccountData *ClientEvents `json:"account_data,omitempty"`
Presence *ClientEvents `json:"presence,omitempty"`
Rooms *RoomsResponse `json:"rooms,omitempty"`
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
NextBatch StreamingToken `json:"next_batch"`
AccountData *ClientEvents `json:"account_data,omitempty"`
Presence *ClientEvents `json:"presence,omitempty"`
Rooms *RoomsResponse `json:"rooms,omitempty"`
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
DeviceListsUnusedFallbackAlgorithms []string `json:"device_unused_fallback_key_types"`
}
func (r Response) MarshalJSON() ([]byte, error) {
@ -419,6 +420,7 @@ func NewResponse() *Response {
res.DeviceLists = &DeviceLists{}
res.ToDevice = &ToDeviceResponse{}
res.DeviceListsOTKCount = map[string]int{}
res.DeviceListsUnusedFallbackAlgorithms = []string{}
return &res
}

View file

@ -792,3 +792,5 @@ User can invite remote user to room with version 11
Remote user can backfill in a room with version 11
Can reject invites over federation for rooms with version 11
Can receive redactions from regular users over federation in room version 11
Can upload self-signing keys
uploading signed devices gets propagated over federation

View file

@ -788,12 +788,30 @@ type OneTimeKeysCount struct {
KeyCount map[string]int
}
// FallbackKeys represents a set of fallback keys for a single device
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
type FallbackKeys struct {
// The user who owns this device
UserID string
// The device ID of this device
DeviceID string
// A map of algorithm:key_id => key JSON
KeyJSON map[string]json.RawMessage
}
// Split a key in KeyJSON into algorithm and key ID
func (k *FallbackKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
segments := strings.Split(keyIDWithAlgo, ":")
return segments[0], segments[1]
}
// PerformUploadKeysRequest is the request to PerformUploadKeys
type PerformUploadKeysRequest struct {
UserID string // Required - User performing the request
DeviceID string // Optional - Device performing the request, for fetching OTK count
DeviceKeys []DeviceKeys
OneTimeKeys []OneTimeKeys
UserID string // Required - User performing the request
DeviceID string // Optional - Device performing the request, for fetching OTK count
DeviceKeys []DeviceKeys
OneTimeKeys []OneTimeKeys
FallbackKeys []FallbackKeys
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
// the display name for their respective device, and NOT to modify the keys. The key
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
@ -810,8 +828,9 @@ type PerformUploadKeysResponse struct {
// A fatal error when processing e.g database failures
Error *KeyError
// A map of user_id -> device_id -> Error for tracking failures.
KeyErrors map[string]map[string]*KeyError
OneTimeKeyCounts []OneTimeKeysCount
KeyErrors map[string]map[string]*KeyError
OneTimeKeyCounts []OneTimeKeysCount
FallbackKeysUnusedAlgorithms []string
}
// PerformDeleteKeysRequest asks the keyserver to forget about certain
@ -917,8 +936,9 @@ type QueryOneTimeKeysRequest struct {
type QueryOneTimeKeysResponse struct {
// OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84
Count OneTimeKeysCount
Error *KeyError
Count OneTimeKeysCount
UnusedFallbackAlgorithms []string
Error *KeyError
}
type QueryDeviceMessagesRequest struct {

View file

@ -44,14 +44,22 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
if len(req.DeviceKeys) > 0 {
a.uploadLocalDeviceKeys(ctx, req, res)
}
if len(req.OneTimeKeys) > 0 {
a.uploadOneTimeKeys(ctx, req, res)
if len(req.OneTimeKeys) > 0 || len(req.FallbackKeys) > 0 {
a.uploadOneTimeAndFallbackKeys(ctx, req, res)
}
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
return err
}
algos, err := a.KeyDatabase.UnusedFallbackKeyAlgorithms(ctx, req.UserID, req.DeviceID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to query unused fallback algorithms: %s", err),
}
return nil
}
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
res.FallbackKeysUnusedAlgorithms = algos
return nil
}
@ -169,7 +177,15 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn
}
return nil
}
algos, err := a.KeyDatabase.UnusedFallbackKeyAlgorithms(ctx, req.UserID, req.DeviceID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to query unused fallback algorithms: %s", err),
}
return nil
}
res.Count = *count
res.UnusedFallbackAlgorithms = algos
return nil
}
@ -507,6 +523,9 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer(
for userID := range userIDsForAllDevices {
err := a.Updater.ManualUpdate(context.Background(), spec.ServerName(serverName), userID)
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
logrus.WithFields(logrus.Fields{
logrus.ErrorKey: err,
"user_id": userID,
@ -520,6 +539,9 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer(
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
logrus.WithFields(logrus.Fields{
logrus.ErrorKey: err,
"user_id": userID,
@ -715,7 +737,7 @@ func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Pe
}
}
func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
func (a *UserInternalAPI) uploadOneTimeAndFallbackKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
if req.UserID == "" {
res.Error = &api.KeyError{
Err: "user ID missing",
@ -768,7 +790,32 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
// collect counts
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
}
if len(req.FallbackKeys) > 0 {
if err := a.KeyDatabase.DeleteFallbackKeys(ctx, req.UserID, req.DeviceID); err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s : failed to clear fallback keys: %s", req.UserID, req.DeviceID, err.Error()),
})
return
}
for _, key := range req.FallbackKeys {
// grab existing keys based on (user/device/algorithm/key ID)
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
i := 0
for keyIDWithAlgo := range key.KeyJSON {
keyIDsWithAlgorithms[i] = keyIDWithAlgo
i++
}
unused, err := a.KeyDatabase.StoreFallbackKeys(ctx, key)
if err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s : failed to store fallback keys: %s", req.UserID, req.DeviceID, err.Error()),
})
continue
}
// collect counts
res.FallbackKeysUnusedAlgorithms = unused
}
}
}
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {

View file

@ -167,6 +167,15 @@ type KeyDatabase interface {
// OneTimeKeysCount returns a count of all OTKs for this device.
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
// StoreFallbackKeys persists the given fallback keys.
StoreFallbackKeys(ctx context.Context, keys api.FallbackKeys) ([]string, error)
// UnusedFallbackKeyAlgorithms returns unused fallback algorithms for this user/device.
UnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error)
// DeleteFallbackKeys deletes all fallback keys for the user.
DeleteFallbackKeys(ctx context.Context, userID, deviceID string) error
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error

View file

@ -0,0 +1,134 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2017 Vector Creations Ltd
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/element-hq/dendrite/internal"
"github.com/element-hq/dendrite/internal/sqlutil"
"github.com/element-hq/dendrite/userapi/api"
"github.com/element-hq/dendrite/userapi/storage/tables"
)
var fallbackKeysSchema = `
-- Stores one-time public keys for users
CREATE TABLE IF NOT EXISTS keyserver_fallback_keys (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
key_id TEXT NOT NULL,
algorithm TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
used BOOLEAN NOT NULL,
-- Clobber based on tuple of user/device/algorithm.
CONSTRAINT keyserver_fallback_keys_unique UNIQUE (user_id, device_id, algorithm)
);
CREATE INDEX IF NOT EXISTS keyserver_fallback_keys_idx ON keyserver_fallback_keys (user_id, device_id);
`
const upsertFallbackKeysSQL = "" +
"INSERT INTO keyserver_fallback_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json, used)" +
" VALUES ($1, $2, $3, $4, $5, $6, false)" +
" ON CONFLICT ON CONSTRAINT keyserver_fallback_keys_unique" +
" DO UPDATE SET key_id = $3, key_json = $6, used = false"
const selectFallbackUnusedAlgorithmsSQL = "" +
"SELECT algorithm FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND used = false"
const selectFallbackKeysByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 ORDER BY used ASC LIMIT 1"
const deleteFallbackKeysSQL = "" +
"DELETE FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2"
const updateFallbackKeyUsedSQL = "" +
"UPDATE keyserver_fallback_keys SET used=true WHERE user_id = $1 AND device_id = $2 AND key_id = $3 AND algorithm = $4"
type fallbackKeysStatements struct {
db *sql.DB
upsertKeysStmt *sql.Stmt
selectUnusedAlgorithmsStmt *sql.Stmt
selectKeyByAlgorithmStmt *sql.Stmt
deleteFallbackKeysStmt *sql.Stmt
updateFallbackKeyUsedStmt *sql.Stmt
}
func NewPostgresFallbackKeysTable(db *sql.DB) (tables.FallbackKeys, error) {
s := &fallbackKeysStatements{
db: db,
}
_, err := db.Exec(fallbackKeysSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertKeysStmt, upsertFallbackKeysSQL},
{&s.selectUnusedAlgorithmsStmt, selectFallbackUnusedAlgorithmsSQL},
{&s.selectKeyByAlgorithmStmt, selectFallbackKeysByAlgorithmSQL},
{&s.deleteFallbackKeysStmt, deleteFallbackKeysSQL},
{&s.updateFallbackKeyUsedStmt, updateFallbackKeyUsedSQL},
}.Prepare(db)
}
func (s *fallbackKeysStatements) SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) {
rows, err := s.selectUnusedAlgorithmsStmt.QueryContext(ctx, userID, deviceID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
algos := []string{}
for rows.Next() {
var algorithm string
if err = rows.Scan(&algorithm); err != nil {
return nil, err
}
algos = append(algos, algorithm)
}
return algos, rows.Err()
}
func (s *fallbackKeysStatements) InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error) {
now := time.Now().Unix()
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo)
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
)
if err != nil {
return nil, err
}
}
return s.SelectUnusedFallbackKeyAlgorithms(ctx, keys.UserID, keys.DeviceID)
}
func (s *fallbackKeysStatements) DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteFallbackKeysStmt).ExecContext(ctx, userID, deviceID)
return err
}
func (s *fallbackKeysStatements) SelectAndUpdateFallbackKey(
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
_, err = sqlutil.TxStmtContext(ctx, txn, s.updateFallbackKeyUsedStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err
}

View file

@ -141,6 +141,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
if err != nil {
return nil, err
}
fk, err := NewPostgresFallbackKeysTable(db)
if err != nil {
return nil, err
}
dk, err := NewPostgresDeviceKeysTable(db)
if err != nil {
return nil, err
@ -164,6 +168,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
return &shared.KeyDatabase{
OneTimeKeysTable: otk,
FallbackKeysTable: fk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,

View file

@ -57,6 +57,7 @@ type Database struct {
type KeyDatabase struct {
OneTimeKeysTable tables.OneTimeKeys
FallbackKeysTable tables.FallbackKeys
DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges
StaleDeviceListsTable tables.StaleDeviceLists
@ -937,6 +938,22 @@ func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID str
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
}
func (d *KeyDatabase) StoreFallbackKeys(ctx context.Context, keys api.FallbackKeys) (unused []string, err error) {
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
unused, err = d.FallbackKeysTable.InsertFallbackKeys(ctx, txn, keys)
return err
})
return
}
func (d *KeyDatabase) DeleteFallbackKeys(ctx context.Context, userID, deviceID string) error {
return d.FallbackKeysTable.DeleteFallbackKeys(ctx, nil, userID, deviceID)
}
func (d *KeyDatabase) UnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) {
return d.FallbackKeysTable.SelectUnusedFallbackKeyAlgorithms(ctx, userID, deviceID)
}
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
}
@ -999,6 +1016,12 @@ func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map
if err != nil {
return err
}
if len(keyJSON) == 0 {
keyJSON, err = d.FallbackKeysTable.SelectAndUpdateFallbackKey(ctx, txn, userID, deviceID, algo)
if err != nil {
return err
}
}
if keyJSON != nil {
result = append(result, api.OneTimeKeys{
UserID: userID,

View file

@ -0,0 +1,132 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2017 Vector Creations Ltd
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/element-hq/dendrite/internal"
"github.com/element-hq/dendrite/internal/sqlutil"
"github.com/element-hq/dendrite/userapi/api"
"github.com/element-hq/dendrite/userapi/storage/tables"
)
var fallbackKeysSchema = `
-- Stores one-time public keys for users
CREATE TABLE IF NOT EXISTS keyserver_fallback_keys (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
key_id TEXT NOT NULL,
algorithm TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
used BOOLEAN NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS keyserver_fallback_keys_unique_idx ON keyserver_fallback_keys(user_id, device_id, algorithm);
CREATE INDEX IF NOT EXISTS keyserver_fallback_keys_idx ON keyserver_fallback_keys (user_id, device_id);
`
const upsertFallbackKeysSQL = "" +
"INSERT INTO keyserver_fallback_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json, used)" +
" VALUES ($1, $2, $3, $4, $5, $6, false)" +
" ON CONFLICT (user_id, device_id, algorithm)" +
" DO UPDATE SET key_id = $3, key_json = $6, used = false"
const selectFallbackUnusedAlgorithmsSQL = "" +
"SELECT algorithm FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND used = false"
const selectFallbackKeysByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 ORDER BY used ASC LIMIT 1"
const deleteFallbackKeysSQL = "" +
"DELETE FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2"
const updateFallbackKeyUsedSQL = "" +
"UPDATE keyserver_fallback_keys SET used=true WHERE user_id = $1 AND device_id = $2 AND key_id = $3 AND algorithm = $4"
type fallbackKeysStatements struct {
db *sql.DB
upsertKeysStmt *sql.Stmt
selectUnusedAlgorithmsStmt *sql.Stmt
selectKeyByAlgorithmStmt *sql.Stmt
deleteFallbackKeysStmt *sql.Stmt
updateFallbackKeyUsedStmt *sql.Stmt
}
func NewSqliteFallbackKeysTable(db *sql.DB) (tables.FallbackKeys, error) {
s := &fallbackKeysStatements{
db: db,
}
_, err := db.Exec(fallbackKeysSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertKeysStmt, upsertFallbackKeysSQL},
{&s.selectUnusedAlgorithmsStmt, selectFallbackUnusedAlgorithmsSQL},
{&s.selectKeyByAlgorithmStmt, selectFallbackKeysByAlgorithmSQL},
{&s.deleteFallbackKeysStmt, deleteFallbackKeysSQL},
{&s.updateFallbackKeyUsedStmt, updateFallbackKeyUsedSQL},
}.Prepare(db)
}
func (s *fallbackKeysStatements) SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) {
rows, err := s.selectUnusedAlgorithmsStmt.QueryContext(ctx, userID, deviceID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
algos := []string{}
for rows.Next() {
var algorithm string
if err = rows.Scan(&algorithm); err != nil {
return nil, err
}
algos = append(algos, algorithm)
}
return algos, rows.Err()
}
func (s *fallbackKeysStatements) InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error) {
now := time.Now().Unix()
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo)
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
)
if err != nil {
return nil, err
}
}
return s.SelectUnusedFallbackKeyAlgorithms(ctx, keys.UserID, keys.DeviceID)
}
func (s *fallbackKeysStatements) DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteFallbackKeysStmt).ExecContext(ctx, userID, deviceID)
return err
}
func (s *fallbackKeysStatements) SelectAndUpdateFallbackKey(
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
_, err = sqlutil.TxStmtContext(ctx, txn, s.updateFallbackKeyUsedStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err
}

View file

@ -138,6 +138,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
if err != nil {
return nil, err
}
fk, err := NewSqliteFallbackKeysTable(db)
if err != nil {
return nil, err
}
dk, err := NewSqliteDeviceKeysTable(db)
if err != nil {
return nil, err
@ -161,6 +165,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
return &shared.KeyDatabase{
OneTimeKeysTable: otk,
FallbackKeysTable: fk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,

View file

@ -809,3 +809,42 @@ func TestOneTimeKeys(t *testing.T) {
}
})
}
func TestFallbackKeys(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := mustCreateKeyDatabase(t, dbType)
defer clean()
userID := "@alice:localhost"
deviceID := "alice_device"
fk := api.FallbackKeys{
UserID: userID,
DeviceID: deviceID,
KeyJSON: map[string]json.RawMessage{"curve25519:KEY1": []byte(`{"key":"v1"}`)},
}
_, err := db.StoreFallbackKeys(ctx, fk)
MustNotError(t, err)
unused, err := db.UnusedFallbackKeyAlgorithms(ctx, userID, deviceID)
MustNotError(t, err)
if c := len(unused); c != 1 {
t.Fatalf("Expected 1 unused key algorithm, got %d", c)
}
if unused[0] != "curve25519" {
t.Fatalf("Expected unused key algorithm to be 'curve25519', got '%s'", unused[0])
}
// No other one-time keys have been uploaded so we expect to get the fallback key instead.
claimed, err := db.ClaimKeys(ctx, map[string]map[string]string{userID: {deviceID: "curve25519"}})
MustNotError(t, err)
switch {
case claimed[0].UserID != fk.UserID:
t.Fatalf("Claimed user ID ID doesn't match, got %q, want %q", claimed[0].UserID, fk.DeviceID)
case claimed[0].DeviceID != fk.DeviceID:
t.Fatalf("Claimed device ID doesn't match, got %q, want %q", claimed[0].DeviceID, fk.DeviceID)
case claimed[0].KeyJSON["curve25519:KEY1"] == nil:
t.Fatalf("Claimed key JSON for curve25519:KEY1 not found")
}
})
}

View file

@ -170,6 +170,13 @@ type DeviceKeys interface {
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
}
type FallbackKeys interface {
SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error)
InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error)
DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
SelectAndUpdateFallbackKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
}
type KeyChanges interface {
InsertKeyChange(ctx context.Context, userID string) (int64, error)
// SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.