This commit is contained in:
Roman Isaev 2025-03-03 00:29:58 +00:00 committed by GitHub
commit a0788425eb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
64 changed files with 3604 additions and 530 deletions

View file

@ -28,7 +28,6 @@ run:
# the dependency descriptions in go.mod.
#modules-download-mode: (release|readonly|vendor)
# output configuration options
output:
# colored-line-number|line-number|json|tab|checkstyle|code-climate, default is "colored-line-number"
@ -41,7 +40,6 @@ output:
# print linter name in the end of issue text, default is true
print-linter-name: true
# all available settings of specific linters
linters-settings:
errcheck:
@ -72,22 +70,12 @@ linters-settings:
- (github.com/golangci/golangci-lint/pkg/logutils.Log).Warnf
- (github.com/golangci/golangci-lint/pkg/logutils.Log).Errorf
- (github.com/golangci/golangci-lint/pkg/logutils.Log).Fatalf
golint:
# minimal confidence for issues, default is 0.8
min-confidence: 0.8
gofmt:
# simplify code: gofmt with `-s` option, true by default
simplify: true
goimports:
# put imports beginning with prefix after 3rd-party packages;
# it's a comma-separated list of prefixes
#local-prefixes: github.com/org/project
gocyclo:
# minimal code complexity to report, 30 by default (but we recommend 10-20)
min-complexity: 25
maligned:
# print struct with more effective memory layout or not, false by default
suggest-new: true
dupl:
# tokens count to trigger issue, 150 by default
threshold: 100
@ -96,30 +84,17 @@ linters-settings:
min-len: 3
# minimal occurrences count to trigger, 3 by default
min-occurrences: 3
depguard:
list-type: blacklist
include-go-root: false
packages:
# - github.com/davecgh/go-spew/spew
misspell:
# Correct spellings using locale preferences for US or UK.
# Default is to use a neutral variety of English.
# Setting locale to US will correct the British spelling of 'colour' to 'color'.
locale: UK
ignore-words:
# - someword
lll:
# max line length, lines longer will be reported. Default is 120.
# '\t' is counted as 1 character by default, and can be changed with the tab-width option
line-length: 96
# tab width in spaces. Default to 1.
tab-width: 1
unused:
# treat code as a program (not a library) and report unused exported identifiers; default is false.
# XXX: if you enable this setting, unused will report a lot of false-positives in text editors:
# if it's called for subdir of a project it can't find funcs usages. All text editor integrations
# with golangci-lint call it on a directory with the changed file.
check-exported: false
unparam:
# Inspect exported functions, default is false. Set to true if no external program/library imports your code.
# XXX: if you enable this setting, unparam will report a lot of false-positives in text editors:
@ -167,7 +142,6 @@ linters:
- gosimple
- govet
- ineffassign
- megacheck
- misspell # Check code comments, whereas misspell in CI checks *.md files
- nakedret
- staticcheck
@ -182,19 +156,14 @@ linters:
- gochecknoinits
- gocritic
- gofmt
- golint
- gosec # Should turn back on soon
- interfacer
- lll
- maligned
- prealloc # Should turn back on soon
- scopelint
- stylecheck
- typecheck # Should turn back on soon
- unconvert # Should turn back on soon
- goconst # Slightly annoying, as it reports "issues" in SQL statements
disable-all: false
presets:
fast: false
@ -217,13 +186,6 @@ issues:
- bin
- docs
# List of regexps of issue texts to exclude, empty list by default.
# But independently from this option we use default exclude patterns,
# it can be disabled by `exclude-use-default: false`. To list all
# excluded by default patterns execute `golangci-lint run --help`
exclude:
# - abcdef
# Excluding configuration per-path, per-linter, per-text and per-source
exclude-rules:
# Exclude some linters from running on tests files.

View file

@ -3,7 +3,7 @@
#
# base installs required dependencies and runs go mod download to cache dependencies
#
FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22-alpine AS base
FROM --platform=${BUILDPLATFORM} docker.io/golang:1.23-alpine AS base
RUN apk --update --no-cache add bash build-base curl git
#

View file

@ -53,6 +53,7 @@ func NewInternalAPI(
if err := generateAppServiceAccount(userAPI, appservice, cfg.Global.ServerName); err != nil {
logrus.WithFields(logrus.Fields{
"appservice": appservice.ID,
"as_token": appservice.ASToken,
}).WithError(err).Panicf("failed to generate bot account for appservice")
}
}
@ -92,12 +93,13 @@ func generateAppServiceAccount(
}
var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{
Localpart: as.SenderLocalpart,
ServerName: serverName,
AccessToken: as.ASToken,
DeviceID: &as.SenderLocalpart,
DeviceDisplayName: &as.SenderLocalpart,
NoDeviceListUpdate: true,
Localpart: as.SenderLocalpart,
ServerName: serverName,
AccessToken: as.ASToken,
DeviceID: &as.SenderLocalpart,
DeviceDisplayName: &as.SenderLocalpart,
NoDeviceListUpdate: true,
AccessTokenUniqueConstraintDisabled: false,
}, &devRes)
return err
}

View file

@ -15,6 +15,7 @@ import (
"time"
"github.com/element-hq/dendrite/clientapi"
"github.com/element-hq/dendrite/clientapi/auth"
"github.com/element-hq/dendrite/clientapi/auth/authtypes"
"github.com/element-hq/dendrite/federationapi/statistics"
"github.com/element-hq/dendrite/internal/httputil"
@ -138,7 +139,7 @@ func TestAppserviceInternalAPI(t *testing.T) {
as := &config.ApplicationService{
ID: "someID",
URL: srv.URL,
ASToken: "",
ASToken: util.RandomString(12),
HSToken: "",
SenderLocalpart: "senderLocalPart",
NamespaceMap: map[string][]config.ApplicationServiceNamespace{
@ -232,7 +233,7 @@ func TestAppserviceInternalAPI_UnixSocket_Simple(t *testing.T) {
as := &config.ApplicationService{
ID: "someID",
URL: fmt.Sprintf("unix://%s", socket),
ASToken: "",
ASToken: util.RandomString(8),
HSToken: "",
SenderLocalpart: "senderLocalPart",
NamespaceMap: map[string][]config.ApplicationServiceNamespace{
@ -376,7 +377,7 @@ func TestRoomserverConsumerOneInvite(t *testing.T) {
as := &config.ApplicationService{
ID: "someID",
URL: srv.URL,
ASToken: "",
ASToken: util.RandomString(8),
HSToken: "",
SenderLocalpart: "senderLocalPart",
NamespaceMap: map[string][]config.ApplicationServiceNamespace{
@ -446,7 +447,8 @@ func TestOutputAppserviceEvent(t *testing.T) {
}
usrAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
clientapi.AddPublicRoutes(processCtx, routers, cfg, natsInstance, nil, rsAPI, nil, nil, nil, usrAPI, nil, nil, caching.DisableMetrics)
userVerifier := auth.DefaultUserVerifier{UserAPI: usrAPI}
clientapi.AddPublicRoutes(processCtx, routers, cfg, natsInstance, nil, rsAPI, nil, nil, nil, usrAPI, nil, nil, &userVerifier, caching.DisableMetrics)
createAccessTokens(t, accessTokens, usrAPI, processCtx.Context(), routers)
room := test.NewRoom(t, alice)
@ -508,7 +510,7 @@ func TestOutputAppserviceEvent(t *testing.T) {
as := &config.ApplicationService{
ID: "someID",
URL: srv.URL,
ASToken: "",
ASToken: util.RandomString(8),
HSToken: "",
SenderLocalpart: "senderLocalPart",
NamespaceMap: map[string][]config.ApplicationServiceNamespace{
@ -537,7 +539,7 @@ func TestOutputAppserviceEvent(t *testing.T) {
}
// Start the syncAPI to have `/joined_members` available
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, usrAPI, rsAPI, caches, caching.DisableMetrics)
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, usrAPI, rsAPI, caches, &userVerifier, caching.DisableMetrics)
// start the consumer
appservice.NewInternalAPI(processCtx, cfg, natsInstance, usrAPI, rsAPI)

View file

@ -1,6 +1,6 @@
#syntax=docker/dockerfile:1.2
FROM golang:1.22-bookworm as build
FROM golang:1.23-bookworm as build
RUN apt-get update && apt-get install -y sqlite3
WORKDIR /build

View file

@ -8,7 +8,7 @@
#
# Use these mounts to make use of this dockerfile:
# COMPLEMENT_HOST_MOUNTS='/your/local/dendrite:/dendrite:ro;/your/go/path:/go:ro'
FROM golang:1.22-bookworm
FROM golang:1.23-bookworm
RUN apt-get update && apt-get install -y sqlite3
ENV SERVER_NAME=localhost

View file

@ -1,6 +1,6 @@
#syntax=docker/dockerfile:1.2
FROM golang:1.22-bookworm as build
FROM golang:1.23-bookworm as build
RUN apt-get update && apt-get install -y postgresql
WORKDIR /build

View file

@ -1,3 +1,8 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
package clientapi
import (
@ -11,6 +16,8 @@ import (
"testing"
"time"
"github.com/element-hq/dendrite/userapi/types"
"github.com/element-hq/dendrite/federationapi"
"github.com/element-hq/dendrite/internal/caching"
"github.com/element-hq/dendrite/internal/httputil"
@ -27,6 +34,7 @@ import (
"github.com/tidwall/gjson"
capi "github.com/element-hq/dendrite/clientapi/api"
"github.com/element-hq/dendrite/clientapi/auth"
"github.com/element-hq/dendrite/test"
"github.com/element-hq/dendrite/test/testrig"
"github.com/element-hq/dendrite/userapi"
@ -48,7 +56,8 @@ func TestAdminCreateToken(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
aliceAdmin: {},
bob: {},
@ -199,7 +208,8 @@ func TestAdminListRegistrationTokens(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
aliceAdmin: {},
bob: {},
@ -317,7 +327,8 @@ func TestAdminGetRegistrationToken(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
aliceAdmin: {},
bob: {},
@ -418,7 +429,8 @@ func TestAdminDeleteRegistrationToken(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
aliceAdmin: {},
bob: {},
@ -512,7 +524,8 @@ func TestAdminUpdateRegistrationToken(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
aliceAdmin: {},
bob: {},
@ -697,7 +710,8 @@ func TestAdminResetPassword(t *testing.T) {
// Needed for changing the password/login
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
// Create the users in the userapi and login
accessTokens := map[*test.User]userDevice{
@ -794,7 +808,8 @@ func TestPurgeRoom(t *testing.T) {
rsAPI.SetFederationAPI(fsAPI, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, &userVerifier, caching.DisableMetrics)
// Create the room
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
@ -802,7 +817,7 @@ func TestPurgeRoom(t *testing.T) {
}
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
// Create the users in the userapi and login
accessTokens := map[*test.User]userDevice{
@ -872,8 +887,10 @@ func TestAdminEvacuateRoom(t *testing.T) {
t.Fatalf("failed to send events: %v", err)
}
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
// Create the users in the userapi and login
accessTokens := map[*test.User]userDevice{
@ -976,8 +993,10 @@ func TestAdminEvacuateUser(t *testing.T) {
t.Fatalf("failed to send events: %v", err)
}
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
// Create the users in the userapi and login
accessTokens := map[*test.User]userDevice{
@ -1059,8 +1078,10 @@ func TestAdminMarkAsStale(t *testing.T) {
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
// Create the users in the userapi and login
accessTokens := map[*test.User]userDevice{
@ -1147,8 +1168,10 @@ func TestAdminQueryEventReports(t *testing.T) {
t.Fatalf("failed to send events: %v", err)
}
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
alice: {},
@ -1376,8 +1399,10 @@ func TestEventReportsGetDelete(t *testing.T) {
t.Fatalf("failed to send events: %v", err)
}
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
alice: {},
@ -1473,3 +1498,840 @@ func TestEventReportsGetDelete(t *testing.T) {
})
})
}
func TestAdminCheckUsernameAvailable(t *testing.T) {
alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
adminToken := "superSecretAdminToken"
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
defer close()
natsInstance := jetstream.NATSInstance{}
// add a vhost
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
})
// There's no need to add a full config for msc3861 as we need only an admin token
cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"}
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken}
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics)
userRes := &uapi.PerformAccountCreationResponse{}
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
AccountType: alice.AccountType,
Localpart: alice.Localpart,
ServerName: cfg.Global.ServerName,
Password: "",
}, userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
testCases := []struct {
name string
accessToken string
userID string
wantOK bool
isAvailable bool
}{
{name: "Missing auth", accessToken: "", wantOK: false, userID: alice.Localpart, isAvailable: false},
{name: "Alice - user exists", accessToken: adminToken, wantOK: true, userID: alice.Localpart, isAvailable: false},
{name: "Bob - user does not exist", accessToken: adminToken, wantOK: true, userID: "bob", isAvailable: true},
}
for _, tc := range testCases {
tc := tc // ensure we don't accidentally only test the last test case
t.Run(tc.name, func(t *testing.T) {
req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v1/username_available?username="+tc.userID)
if tc.accessToken != "" {
req.Header.Set("Authorization", "Bearer "+tc.accessToken)
}
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if tc.wantOK && rec.Code != http.StatusOK || !tc.wantOK && rec.Code != http.StatusUnauthorized {
t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
}
// Nothing more to check, test is done.
if !tc.wantOK {
return
}
b := make(map[string]bool, 1)
_ = json.NewDecoder(rec.Body).Decode(&b)
available, ok := b["available"]
if !ok {
t.Fatal("'available' not found in body")
}
if available != tc.isAvailable {
t.Fatalf("expected 'available' to be %t, got %t instead", tc.isAvailable, available)
}
})
}
})
}
func TestAdminUserDeviceRetrieveCreate(t *testing.T) {
alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
adminToken := "superSecretAdminToken"
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
defer close()
natsInstance := jetstream.NATSInstance{}
// add a vhost
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
})
// There's no need to add a full config for msc3861 as we need only an admin token
cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"}
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken}
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics)
for _, u := range []*test.User{alice, bob} {
userRes := &uapi.PerformAccountCreationResponse{}
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
AccountType: u.AccountType,
Localpart: u.Localpart,
ServerName: cfg.Global.ServerName,
Password: "",
}, userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
}
t.Run("Missing auth token", func(t *testing.T) {
req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+alice.ID+"/devices")
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if rec.Code != http.StatusUnauthorized {
t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String())
}
var b spec.MatrixError
_ = json.NewDecoder(rec.Body).Decode(&b)
if b.ErrCode != spec.ErrorMissingToken {
t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode)
}
})
t.Run("Retrieve device", func(t *testing.T) {
var deviceRes uapi.PerformDeviceCreationResponse
if err := userAPI.PerformDeviceCreation(ctx, &uapi.PerformDeviceCreationRequest{
Localpart: alice.Localpart,
ServerName: cfg.Global.ServerName,
AccessTokenUniqueConstraintDisabled: true,
}, &deviceRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+alice.ID+"/devices")
req.Header.Set("Authorization", "Bearer "+adminToken)
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
var body struct {
Total int `json:"total"`
Devices []struct {
DeviceID string `json:"device_id"`
} `json:"devices"`
}
_ = json.NewDecoder(rec.Body).Decode(&body)
if body.Total != 1 {
t.Errorf("expected 1 device, got %d", body.Total)
}
if len(body.Devices) != 1 {
t.Errorf("expected 1 device, got %d", len(body.Devices))
}
})
t.Run("Create device", func(t *testing.T) {
reqBody := struct {
DeviceID string `json:"device_id"`
}{DeviceID: "devBob"}
req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v2/users/"+bob.ID+"/devices", test.WithJSONBody(t, reqBody))
req.Header.Set("Authorization", "Bearer "+adminToken)
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if rec.Code != http.StatusCreated {
t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusCreated, rec.Code, rec.Body.String())
}
var res uapi.QueryDevicesResponse
_ = userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{UserID: bob.ID}, &res)
if len(res.Devices) != 1 {
t.Errorf("expected 1 device, got %d", len(res.Devices))
}
if res.Devices[0].ID != "devBob" {
t.Errorf("expected device to be devBob, got %s", res.Devices[0].ID)
}
})
})
}
func TestAdminUserDeviceDelete(t *testing.T) {
alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
adminToken := "superSecretAdminToken"
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
defer close()
natsInstance := jetstream.NATSInstance{}
// add a vhost
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
})
// There's no need to add a full config for msc3861 as we need only an admin token
cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"}
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken}
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics)
for _, u := range []*test.User{alice} {
userRes := &uapi.PerformAccountCreationResponse{}
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
AccountType: u.AccountType,
Localpart: u.Localpart,
ServerName: cfg.Global.ServerName,
Password: "",
}, userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
}
t.Run("Missing auth token", func(t *testing.T) {
req := test.NewRequest(t, http.MethodDelete, "/_synapse/admin/v2/users/"+alice.ID+"/devices/anything")
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if rec.Code != http.StatusUnauthorized {
t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String())
}
var b spec.MatrixError
_ = json.NewDecoder(rec.Body).Decode(&b)
if b.ErrCode != spec.ErrorMissingToken {
t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode)
}
})
t.Run("Delete existing device", func(t *testing.T) {
var deviceRes uapi.PerformDeviceCreationResponse
if err := userAPI.PerformDeviceCreation(ctx, &uapi.PerformDeviceCreationRequest{
Localpart: alice.Localpart,
ServerName: cfg.Global.ServerName,
AccessTokenUniqueConstraintDisabled: true,
}, &deviceRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
req := test.NewRequest(t, http.MethodDelete, "/_synapse/admin/v2/users/"+alice.ID+"/devices/"+deviceRes.Device.ID)
req.Header.Set("Authorization", "Bearer "+adminToken)
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if rec.Code != http.StatusOK {
t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
}
var rs uapi.QueryDevicesResponse
_ = userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{UserID: alice.ID}, &rs)
if len(rs.Devices) > 0 {
t.Errorf("expected 0 devices, got %d", len(rs.Devices))
}
})
t.Run("Delete non-existing user's devices", func(t *testing.T) {
req := test.NewRequest(t, http.MethodDelete, "/_synapse/admin/v2/users/"+bob.ID+"/devices/anything")
req.Header.Set("Authorization", "Bearer "+adminToken)
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if rec.Code != http.StatusOK {
t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
}
})
})
}
func TestAdminUserDevicesDelete(t *testing.T) {
alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
adminToken := "superSecretAdminToken"
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
defer close()
natsInstance := jetstream.NATSInstance{}
// add a vhost
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
})
// There's no need to add a full config for msc3861 as we need only an admin token
cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"}
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken}
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics)
for _, u := range []*test.User{alice} {
userRes := &uapi.PerformAccountCreationResponse{}
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
AccountType: u.AccountType,
Localpart: u.Localpart,
ServerName: cfg.Global.ServerName,
Password: "",
}, userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
}
type payload struct {
Devices []string `json:"devices"`
}
t.Run("Missing auth token", func(t *testing.T) {
req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v2/users/"+alice.ID+"/delete_devices")
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if rec.Code != http.StatusUnauthorized {
t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String())
}
var b spec.MatrixError
_ = json.NewDecoder(rec.Body).Decode(&b)
if b.ErrCode != spec.ErrorMissingToken {
t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode)
}
})
t.Run("Delete existing user's devices", func(t *testing.T) {
var deviceRes uapi.PerformDeviceCreationResponse
if err := userAPI.PerformDeviceCreation(ctx, &uapi.PerformDeviceCreationRequest{
Localpart: alice.Localpart,
ServerName: cfg.Global.ServerName,
AccessTokenUniqueConstraintDisabled: true,
}, &deviceRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
req := test.NewRequest(
t,
http.MethodPost,
"/_synapse/admin/v2/users/"+alice.ID+"/delete_devices",
test.WithJSONBody(t, payload{Devices: []string{deviceRes.Device.ID}}),
)
req.Header.Set("Authorization", "Bearer "+adminToken)
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if rec.Code != http.StatusOK {
t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
}
var rs uapi.QueryDevicesResponse
_ = userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{UserID: alice.ID}, &rs)
if len(rs.Devices) > 0 {
t.Errorf("expected 0 devices, got %d", len(rs.Devices))
}
})
t.Run("Delete non-existing user's devices", func(t *testing.T) {
req := test.NewRequest(
t,
http.MethodPost,
"/_synapse/admin/v2/users/"+bob.ID+"/delete_devices",
test.WithJSONBody(t, payload{Devices: []string{"anyDevID"}}),
)
req.Header.Set("Authorization", "Bearer "+adminToken)
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if rec.Code != http.StatusOK {
t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
}
})
})
}
func TestAdminDeactivateAccount(t *testing.T) {
alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
adminToken := "superSecretAdminToken"
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
defer close()
natsInstance := jetstream.NATSInstance{}
// add a vhost
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
})
// There's no need to add a full config for msc3861 as we need only an admin token
cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"}
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken}
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics)
for _, u := range []*test.User{alice} {
userRes := &uapi.PerformAccountCreationResponse{}
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
AccountType: u.AccountType,
Localpart: u.Localpart,
ServerName: cfg.Global.ServerName,
Password: "",
}, userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
}
t.Run("Missing auth token", func(t *testing.T) {
req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/deactivate/"+alice.ID)
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if rec.Code != http.StatusUnauthorized {
t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String())
}
var b spec.MatrixError
_ = json.NewDecoder(rec.Body).Decode(&b)
if b.ErrCode != spec.ErrorMissingToken {
t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode)
}
})
t.Run("Deactivate existing account", func(t *testing.T) {
req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/deactivate/"+alice.ID)
req.Header.Set("Authorization", "Bearer "+adminToken)
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if rec.Code != http.StatusOK {
t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
}
var rs uapi.QueryAccountByLocalpartResponse
_ = userAPI.QueryAccountByLocalpart(ctx, &uapi.QueryAccountByLocalpartRequest{Localpart: alice.Localpart, ServerName: cfg.Global.ServerName}, &rs)
if !rs.Account.Deactivated {
t.Fatalf("expected account is deactivated")
}
})
t.Run("Deactivate non-existing account", func(t *testing.T) {
req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/deactivate/"+bob.ID)
req.Header.Set("Authorization", "Bearer "+adminToken)
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if rec.Code != http.StatusOK {
t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
}
})
})
}
func TestAdminAllowCrossSigningReplacementWithoutUIA(t *testing.T) {
alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
adminToken := "superSecretAdminToken"
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
defer close()
natsInstance := jetstream.NATSInstance{}
// add a vhost
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
})
// There's no need to add a full config for msc3861 as we need only an admin token
cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"}
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken}
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics)
t.Run("Missing auth token", func(t *testing.T) {
req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/users/"+alice.ID+"/_allow_cross_signing_replacement_without_uia")
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if rec.Code != http.StatusUnauthorized {
t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String())
}
var b spec.MatrixError
_ = json.NewDecoder(rec.Body).Decode(&b)
if b.ErrCode != spec.ErrorMissingToken {
t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode)
}
})
for _, u := range []*test.User{alice} {
var userRes uapi.PerformAccountCreationResponse
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
AccountType: u.AccountType,
Localpart: u.Localpart,
ServerName: cfg.Global.ServerName,
Password: "",
}, &userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
_ = userAPI.KeyDatabase.StoreCrossSigningKeysForUser(ctx, alice.ID, types.CrossSigningKeyMap{
fclient.CrossSigningKeyPurposeMaster: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"),
fclient.CrossSigningKeyPurposeSelfSigning: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"),
fclient.CrossSigningKeyPurposeUserSigning: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"),
})
}
testCases := []struct {
Name string
User *test.User
Code int
}{
{Name: "existing user", User: alice, Code: 200},
{Name: "non-existing user", User: bob, Code: 404},
}
now := time.Now()
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/users/"+tc.User.ID+"/_allow_cross_signing_replacement_without_uia")
req.Header.Set("Authorization", "Bearer "+adminToken)
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if rec.Code != tc.Code {
t.Fatalf("expected HTTP status %d, got %d: %s", tc.Code, rec.Code, rec.Body.String())
}
if rec.Code == 200 {
buf := make(map[string]int64, 1)
_ = json.NewDecoder(rec.Body).Decode(&buf)
if ts := buf["updatable_without_uia_before_ms"]; ts <= now.UnixMilli() {
t.Fatalf("expected updatable_without_uia_before_ms is in future, got %d", ts)
}
}
})
}
})
}
func TestAdminCreateOrModifyAccount(t *testing.T) {
alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
adminToken := "superSecretAdminToken"
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
defer close()
natsInstance := jetstream.NATSInstance{}
// add a vhost
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
})
// There's no need to add a full config for msc3861 as we need only an admin token
cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"}
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken}
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics)
for _, u := range []*test.User{alice} {
userRes := &uapi.PerformAccountCreationResponse{}
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
AccountType: u.AccountType,
Localpart: u.Localpart,
ServerName: cfg.Global.ServerName,
Password: "",
}, userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
}
type threePID struct {
Medium string `json:"medium"`
Address string `json:"address"`
}
type adminCreateOrModifyAccountRequest struct {
DisplayName string `json:"displayname"`
AvatarURL string `json:"avatar_url"`
ThreePIDs []threePID `json:"threepids"`
}
t.Run("Missing auth token", func(t *testing.T) {
req := test.NewRequest(t, http.MethodPut, "/_synapse/admin/v2/users/"+alice.ID)
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if rec.Code != http.StatusUnauthorized {
t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String())
}
var b spec.MatrixError
_ = json.NewDecoder(rec.Body).Decode(&b)
if b.ErrCode != spec.ErrorMissingToken {
t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode)
}
})
testCases := []struct {
Name string
User *test.User
Payload adminCreateOrModifyAccountRequest
Expected struct {
DisplayName,
AvatarURL string
ThreePIDs []string
}
Code int
}{
{
Name: fmt.Sprintf("Modify user %s", alice.ID),
User: alice,
Payload: adminCreateOrModifyAccountRequest{
DisplayName: "alice",
AvatarURL: "https://alice-avatar.example.com",
ThreePIDs: []threePID{
{
Medium: "email",
Address: "alice@example.com",
},
},
},
Expected: struct {
DisplayName, AvatarURL string
ThreePIDs []string
}{
// In order to avoid any confusion and undesired behaviour, we do not change display name and avatar url if account already exists
DisplayName: alice.Localpart,
AvatarURL: "",
ThreePIDs: []string{"alice@example.com"},
},
Code: http.StatusOK,
},
{
Name: fmt.Sprintf("Create user %s", bob.ID),
User: bob,
Payload: adminCreateOrModifyAccountRequest{
DisplayName: "bob",
AvatarURL: "https://bob-avatar.example.com",
ThreePIDs: []threePID{
{
Medium: "email",
Address: "bob@example.com",
},
},
},
Expected: struct {
DisplayName, AvatarURL string
ThreePIDs []string
}{
DisplayName: "bob",
AvatarURL: "https://bob-avatar.example.com",
ThreePIDs: []string{"bob@example.com"},
},
Code: http.StatusCreated,
},
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req := test.NewRequest(
t,
http.MethodPut,
"/_synapse/admin/v2/users/"+tc.User.ID,
test.WithJSONBody(t, tc.Payload),
)
req.Header.Set("Authorization", "Bearer "+adminToken)
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if rec.Code != tc.Code {
t.Fatalf("expected HTTP status %d, got %d: %s", tc.Code, rec.Code, rec.Body.String())
}
p, _ := userAPI.QueryProfile(ctx, tc.User.ID)
if p.DisplayName != tc.Expected.DisplayName {
t.Fatalf("expected display name %s, got %s", tc.Expected.DisplayName, p.DisplayName)
}
if p.AvatarURL != tc.Expected.AvatarURL {
t.Fatalf("expected avatar_url %s, got %s", tc.Expected.AvatarURL, p.AvatarURL)
}
var threePidRs uapi.QueryThreePIDsForLocalpartResponse
_ = userAPI.QueryThreePIDsForLocalpart(
ctx,
&uapi.QueryThreePIDsForLocalpartRequest{Localpart: tc.User.Localpart, ServerName: cfg.Global.ServerName},
&threePidRs,
)
if len(threePidRs.ThreePIDs) != 1 {
t.Fatalf("expected 1 3pid got %d", len(threePidRs.ThreePIDs))
}
tp := threePidRs.ThreePIDs[0]
if tp.Medium != "email" {
t.Fatalf("expected 3pid medium email got %s", tp.Medium)
}
if tp.Address != tc.Payload.ThreePIDs[0].Address {
t.Fatalf("expected 3pid address %s got %s", tc.Expected.ThreePIDs[0], tp.Address)
}
})
}
})
}
func TestAdminRetrieveAccount(t *testing.T) {
alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
adminToken := "superSecretAdminToken"
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
defer close()
natsInstance := jetstream.NATSInstance{}
// add a vhost
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
})
// There's no need to add a full config for msc3861 as we need only an admin token
cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"}
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken}
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics)
for _, u := range []*test.User{alice} {
userRes := &uapi.PerformAccountCreationResponse{}
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
AccountType: u.AccountType,
Localpart: u.Localpart,
ServerName: cfg.Global.ServerName,
Password: "",
}, userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
}
t.Run("Missing auth token", func(t *testing.T) {
req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+alice.ID)
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if rec.Code != http.StatusUnauthorized {
t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String())
}
var b spec.MatrixError
_ = json.NewDecoder(rec.Body).Decode(&b)
if b.ErrCode != spec.ErrorMissingToken {
t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode)
}
})
testCase := []struct {
Name string
User *test.User
Code int
Body string
}{
{
Name: "Retrieve existing account",
User: alice,
Code: http.StatusOK,
Body: fmt.Sprintf(`{"display_name":"%s","avatar_url":"","deactivated":false}`, alice.Localpart),
},
{
Name: "Retrieve non-existing account",
User: bob,
Code: http.StatusNotFound,
Body: "",
},
}
for _, tc := range testCase {
t.Run(tc.Name, func(t *testing.T) {
req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+tc.User.ID)
req.Header.Set("Authorization", "Bearer "+adminToken)
rec := httptest.NewRecorder()
routers.SynapseAdmin.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if rec.Code != tc.Code {
t.Fatalf("expected HTTP status %d, got %d: %s", tc.Code, rec.Code, rec.Body.String())
}
if tc.Body != "" && tc.Body != rec.Body.String() {
t.Fatalf("expected body %s, got %s", tc.Body, rec.Body.String())
}
})
}
})
}

View file

@ -16,8 +16,6 @@ import (
"strings"
"github.com/element-hq/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
)
// OWASP recommends at least 128 bits of entropy for tokens: https://www.owasp.org/index.php/Insufficient_Session-ID_Length
@ -37,51 +35,6 @@ type AccountDatabase interface {
GetAccountByPassword(ctx context.Context, localpart, password string) (*api.Account, error)
}
// VerifyUserFromRequest authenticates the HTTP request,
// on success returns Device of the requester.
// Finds local user or an application service user.
// Note: For an AS user, AS dummy device is returned.
// On failure returns an JSON error response which can be sent to the client.
func VerifyUserFromRequest(
req *http.Request, userAPI api.QueryAcccessTokenAPI,
) (*api.Device, *util.JSONResponse) {
// Try to find the Application Service user
token, err := ExtractAccessToken(req)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: spec.MissingToken(err.Error()),
}
}
var res api.QueryAccessTokenResponse
err = userAPI.QueryAccessToken(req.Context(), &api.QueryAccessTokenRequest{
AccessToken: token,
AppServiceUserID: req.URL.Query().Get("user_id"),
}, &res)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccessToken failed")
return nil, &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
if res.Err != "" {
if strings.HasPrefix(strings.ToLower(res.Err), "forbidden:") { // TODO: use actual error and no string comparison
return nil, &util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden(res.Err),
}
}
}
if res.Device == nil {
return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: spec.UnknownToken("Unknown token"),
}
}
return res.Device, nil
}
// GenerateAccessToken creates a new access token. Returns an error if failed to generate
// random bytes.
func GenerateAccessToken() (string, error) {

View file

@ -0,0 +1,65 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
package auth
import (
"net/http"
"strings"
"github.com/element-hq/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
)
// DefaultUserVerifier implements UserVerifier interface
type DefaultUserVerifier struct {
UserAPI api.QueryAccessTokenAPI
}
// VerifyUserFromRequest authenticates the HTTP request,
// on success returns Device of the requester.
// Finds local user or an application service user.
// Note: For an AS user, AS dummy device is returned.
// On failure returns an JSON error response which can be sent to the client.
func (d *DefaultUserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Device, *util.JSONResponse) {
ctx := req.Context()
util.GetLogger(ctx).Debug("Default VerifyUserFromRequest")
// Try to find the Application Service user
token, err := ExtractAccessToken(req)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: spec.MissingToken(err.Error()),
}
}
var res api.QueryAccessTokenResponse
err = d.UserAPI.QueryAccessToken(ctx, &api.QueryAccessTokenRequest{
AccessToken: token,
AppServiceUserID: req.URL.Query().Get("user_id"),
}, &res)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.QueryAccessToken failed")
return nil, &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
if res.Err != "" {
if strings.HasPrefix(strings.ToLower(res.Err), "forbidden:") { // TODO: use actual error and no string comparison
return nil, &util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden(res.Err),
}
}
}
if res.Device == nil {
return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: spec.UnknownToken("Unknown token"),
}
}
return res.Device, nil
}

View file

@ -36,7 +36,9 @@ func AddPublicRoutes(
fsAPI federationAPI.ClientFederationAPI,
userAPI userapi.ClientUserAPI,
userDirectoryProvider userapi.QuerySearchProfilesAPI,
extRoomsProvider api.ExtraPublicRoomsProvider, enableMetrics bool,
extRoomsProvider api.ExtraPublicRoomsProvider,
userVerifier httputil.UserVerifier,
enableMetrics bool,
) {
js, natsClient := natsInstance.Prepare(processContext, &cfg.Global.JetStream)
@ -55,6 +57,7 @@ func AddPublicRoutes(
cfg, rsAPI, asAPI,
userAPI, userDirectoryProvider, federation,
syncProducer, transactionsCache, fsAPI,
extRoomsProvider, natsClient, enableMetrics,
extRoomsProvider, natsClient,
userVerifier, enableMetrics,
)
}

View file

@ -14,6 +14,7 @@ import (
"time"
"github.com/element-hq/dendrite/appservice"
"github.com/element-hq/dendrite/clientapi/auth"
"github.com/element-hq/dendrite/clientapi/auth/authtypes"
"github.com/element-hq/dendrite/clientapi/routing"
"github.com/element-hq/dendrite/clientapi/threepid"
@ -127,9 +128,10 @@ func TestGetPutDevices(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
alice: {},
@ -176,9 +178,10 @@ func TestDeleteDevice(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI/ for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
alice: {},
@ -281,9 +284,10 @@ func TestDeleteDevices(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI/ for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
alice: {},
@ -449,8 +453,9 @@ func TestSetDisplayname(t *testing.T) {
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
asPI := appservice.NewInternalAPI(processCtx, cfg, natsInstance, userAPI, rsAPI)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
alice: {},
@ -561,8 +566,9 @@ func TestSetAvatarURL(t *testing.T) {
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
asPI := appservice.NewInternalAPI(processCtx, cfg, natsInstance, userAPI, rsAPI)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
alice: {},
@ -638,8 +644,9 @@ func TestTyping(t *testing.T) {
rsAPI.SetFederationAPI(nil, nil)
// Needed to create accounts
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
// Create the users in the userapi and login
accessTokens := map[*test.User]userDevice{
@ -723,8 +730,9 @@ func TestMembership(t *testing.T) {
// Needed to create accounts
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
rsAPI.SetUserAPI(userAPI)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
// Create the users in the userapi and login
accessTokens := map[*test.User]userDevice{
@ -962,8 +970,9 @@ func TestCapabilities(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
// Create the users in the userapi and login
accessTokens := map[*test.User]userDevice{
@ -1010,9 +1019,10 @@ func TestTurnserver(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
//rsAPI.SetUserAPI(userAPI)
// We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
// Create the users in the userapi and login
accessTokens := map[*test.User]userDevice{
@ -1109,8 +1119,9 @@ func Test3PID(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
// Create the users in the userapi and login
accessTokens := map[*test.User]userDevice{
@ -1285,9 +1296,10 @@ func TestPushRules(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
alice: {},
@ -1672,9 +1684,10 @@ func TestKeys(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
alice: {},
@ -2134,9 +2147,10 @@ func TestKeyBackup(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
alice: {},
@ -2238,9 +2252,10 @@ func TestGetMembership(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
alice: {},
@ -2301,9 +2316,10 @@ func TestCreateRoomInvite(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
alice: {},
@ -2376,9 +2392,10 @@ func TestReportEvent(t *testing.T) {
if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err)
}
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
accessTokens := map[*test.User]userDevice{
alice: {},

View file

@ -1,7 +1,13 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
package routing
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
@ -20,16 +26,24 @@ import (
"github.com/sirupsen/logrus"
"golang.org/x/exp/constraints"
appserviceAPI "github.com/element-hq/dendrite/appservice/api"
clientapi "github.com/element-hq/dendrite/clientapi/api"
"github.com/element-hq/dendrite/clientapi/auth/authtypes"
clienthttputil "github.com/element-hq/dendrite/clientapi/httputil"
"github.com/element-hq/dendrite/clientapi/userutil"
"github.com/element-hq/dendrite/internal/httputil"
roomserverAPI "github.com/element-hq/dendrite/roomserver/api"
"github.com/element-hq/dendrite/setup/config"
"github.com/element-hq/dendrite/setup/jetstream"
"github.com/element-hq/dendrite/userapi/api"
userapi "github.com/element-hq/dendrite/userapi/api"
"github.com/element-hq/dendrite/userapi/storage/shared"
)
var validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$")
var (
validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$")
deviceDisplayName = "OIDC-native client"
)
func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
if !cfg.RegistrationRequiresToken {
@ -496,6 +510,485 @@ func AdminDownloadState(req *http.Request, device *api.Device, rsAPI roomserverA
}
}
func AdminCheckUsernameAvailable(
req *http.Request,
userAPI userapi.ClientUserAPI,
cfg *config.ClientAPI,
) util.JSONResponse {
username := req.URL.Query().Get("username")
if username == "" {
return util.MessageResponse(http.StatusBadRequest, "Query parameter 'username' is missing or empty")
}
rq := userapi.QueryAccountAvailabilityRequest{Localpart: username, ServerName: cfg.Matrix.ServerName}
rs := userapi.QueryAccountAvailabilityResponse{}
if err := userAPI.QueryAccountAvailability(req.Context(), &rq, &rs); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: map[string]bool{"available": rs.Available},
}
}
func AdminUserDeviceRetrieveCreate(
req *http.Request,
userAPI userapi.ClientUserAPI,
cfg *config.ClientAPI,
) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
userID := vars["userID"]
local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam(err.Error()),
}
}
logger := util.GetLogger(req.Context())
switch req.Method {
case http.MethodPost:
var payload struct {
DeviceID string `json:"device_id"`
}
if resErr := clienthttputil.UnmarshalJSONRequest(req, &payload); resErr != nil {
return *resErr
}
userDeviceExists := false
var rs api.QueryDevicesResponse
if err = userAPI.QueryDevices(req.Context(), &api.QueryDevicesRequest{UserID: userID}, &rs); err != nil {
logger.WithError(err).Error("QueryDevices")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
if !rs.UserExists {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound("Given user ID does not exist"),
}
}
for i := range rs.Devices {
if d := rs.Devices[i]; d.ID == payload.DeviceID && d.UserID == userID {
userDeviceExists = true
break
}
}
if !userDeviceExists {
var rs userapi.PerformDeviceCreationResponse
if err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{
Localpart: local,
ServerName: domain,
DeviceID: &payload.DeviceID,
DeviceDisplayName: &deviceDisplayName,
IPAddr: "",
UserAgent: req.UserAgent(),
NoDeviceListUpdate: false,
FromRegistration: false,
AccessTokenUniqueConstraintDisabled: true,
}, &rs); err != nil {
logger.WithError(err).Error("PerformDeviceCreation")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
logger.WithError(err).Debug("PerformDeviceCreation succeeded")
}
return util.JSONResponse{
Code: http.StatusCreated,
JSON: struct{}{},
}
case http.MethodGet:
var res userapi.QueryDevicesResponse
if err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{UserID: userID}, &res); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
jsonDevices := make([]deviceJSON, 0, len(res.Devices))
for i := range res.Devices {
d := &res.Devices[i]
jsonDevices = append(jsonDevices, deviceJSON{
DeviceID: d.ID,
DisplayName: d.DisplayName,
LastSeenIP: d.LastSeenIP,
LastSeenTS: d.LastSeenTS,
})
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct {
Devices []deviceJSON `json:"devices"`
Total int `json:"total"`
}{
Devices: jsonDevices,
Total: len(res.Devices),
},
}
default:
return util.JSONResponse{
Code: http.StatusMethodNotAllowed,
JSON: struct{}{},
}
}
}
func AdminUserDeviceDelete(
req *http.Request,
userAPI userapi.ClientUserAPI,
cfg *config.ClientAPI,
) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
userID := vars["userID"]
deviceID := vars["deviceID"]
logger := util.GetLogger(req.Context())
// XXX: we probably have to delete session from the sessions dict
// like we do in DeleteDeviceById. If so, we have to fi
var device *api.Device
{
var rs api.QueryDevicesResponse
if err := userAPI.QueryDevices(req.Context(), &api.QueryDevicesRequest{UserID: userID}, &rs); err != nil {
logger.WithError(err).Error("userAPI.QueryDevices failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
if !rs.UserExists {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound("Given user ID does not exist"),
}
}
for i := range rs.Devices {
if d := rs.Devices[i]; d.ID == deviceID && d.UserID == userID {
device = &d
break
}
}
}
if device != nil {
// XXX: this response struct can completely removed everywhere as it doesn't
// have any functional purpose
var res api.PerformDeviceDeletionResponse
if err := userAPI.PerformDeviceDeletion(req.Context(), &api.PerformDeviceDeletionRequest{
UserID: device.UserID,
DeviceIDs: []string{device.ID},
}, &res); err != nil {
logger.WithError(err).Error("userAPI.PerformDeviceDeletion failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
func AdminUserDevicesDelete(
req *http.Request,
userAPI userapi.ClientUserAPI,
cfg *config.ClientAPI,
) util.JSONResponse {
logger := util.GetLogger(req.Context())
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
userID := vars["userID"]
if req.Body == nil {
return util.MessageResponse(http.StatusBadRequest, "body is required")
}
var payload struct {
Devices []string `json:"devices"`
}
if err = json.NewDecoder(req.Body).Decode(&payload); err != nil {
logger.WithError(err).Error("unable to decode device deletion request")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
defer req.Body.Close() // nolint: errcheck
{
// XXX: this response struct can completely removed everywhere as it doesn't
// have any functional purpose
var rs api.PerformDeviceDeletionResponse
if err := userAPI.PerformDeviceDeletion(req.Context(), &api.PerformDeviceDeletionRequest{
UserID: userID,
DeviceIDs: payload.Devices,
}, &rs); err != nil {
logger.WithError(err).Error("userAPI.PerformDeviceDeletion failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
func AdminDeactivateAccount(
req *http.Request,
userAPI userapi.ClientUserAPI,
cfg *config.ClientAPI,
) util.JSONResponse {
logger := util.GetLogger(req.Context())
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
userID := vars["userID"]
local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix)
if err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
// TODO: "erase" field must also be processed here
// see https://github.com/element-hq/synapse/blob/develop/docs/admin_api/user_admin_api.md#deactivate-account
var rs api.PerformAccountDeactivationResponse
if err := userAPI.PerformAccountDeactivation(req.Context(), &api.PerformAccountDeactivationRequest{
Localpart: local, ServerName: domain,
}, &rs); err != nil {
logger.WithError(err).Error("userAPI.PerformDeviceDeletion failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
func AdminAllowCrossSigningReplacementWithoutUIA(
req *http.Request,
userAPI userapi.ClientUserAPI,
) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
userIDstr, ok := vars["userID"]
userID, err := spec.NewUserID(userIDstr, false)
if !ok || err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.MissingParam("User not found."),
}
}
var rs api.QueryAccountByLocalpartResponse
err = userAPI.QueryAccountByLocalpart(req.Context(), &api.QueryAccountByLocalpartRequest{
Localpart: userID.Local(),
ServerName: userID.Domain(),
}, &rs)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountByLocalpart")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown(err.Error()),
}
} else if errors.Is(err, sql.ErrNoRows) {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound("User not found."),
}
}
switch req.Method {
case http.MethodPost:
ts := sessions.allowCrossSigningKeysReplacement(userID.String())
return util.JSONResponse{
Code: http.StatusOK,
JSON: map[string]int64{"updatable_without_uia_before_ms": ts},
}
default:
return util.JSONResponse{
Code: http.StatusMethodNotAllowed,
JSON: spec.Unknown("Method not allowed."),
}
}
}
type adminCreateOrModifyAccountRequest struct {
DisplayName string `json:"displayname"`
AvatarURL string `json:"avatar_url"`
ThreePIDs []struct {
Medium string `json:"medium"`
Address string `json:"address"`
} `json:"threepids"`
// TODO: the following fields are not used by dendrite, but they are used in Synapse.
// Password string `json:"password"`
// LogoutDevices bool `json:"logout_devices"`
// ExternalIDs []struct{
// AuthProvider string `json:"auth_provider"`
// ExternalID string `json:"external_id"`
// } `json:"external_ids"`
// Admin bool `json:"admin"`
// Deactivated bool `json:"deactivated"`
// Locked bool `json:"locked"`
}
func AdminCreateOrModifyAccount(req *http.Request, userAPI userapi.ClientUserAPI, cfg *config.ClientAPI) util.JSONResponse {
logger := util.GetLogger(req.Context())
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
userID := vars["userID"]
local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam(userID),
}
}
var r adminCreateOrModifyAccountRequest
if resErr := clienthttputil.UnmarshalJSONRequest(req, &r); resErr != nil {
logger.Debugf("UnmarshalJSONRequest failed: %+v", *resErr)
return *resErr
}
logger.Debugf("adminCreateOrModifyAccountRequest is: %#v", r)
statusCode := http.StatusOK
// TODO: Ideally, the following commands should be executed in one transaction.
// can we propagate the tx object and pass it in context?
var res userapi.PerformAccountCreationResponse
err = userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeUser,
Localpart: local,
ServerName: domain,
OnConflict: api.ConflictUpdate,
AvatarURL: r.AvatarURL,
DisplayName: r.DisplayName,
}, &res)
if err != nil {
logger.WithError(err).Error("userAPI.PerformAccountCreation")
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if res.AccountCreated {
statusCode = http.StatusCreated
}
if l := len(r.ThreePIDs); l > 0 {
logger.Debugf("Trying to bulk save 3PID associations: %+v", r.ThreePIDs)
threePIDs := make([]authtypes.ThreePID, 0, len(r.ThreePIDs))
for i := range r.ThreePIDs {
tpid := &r.ThreePIDs[i]
threePIDs = append(threePIDs, authtypes.ThreePID{Medium: tpid.Medium, Address: tpid.Address})
}
err = userAPI.PerformBulkSaveThreePIDAssociation(req.Context(), &userapi.PerformBulkSaveThreePIDAssociationRequest{
ThreePIDs: threePIDs,
Localpart: local,
ServerName: domain,
}, &struct{}{})
if err == shared.Err3PIDInUse {
return util.MessageResponse(http.StatusBadRequest, err.Error())
} else if err != nil {
logger.WithError(err).Error("userAPI.PerformSaveThreePIDAssociation")
return util.ErrorResponse(err)
}
}
return util.JSONResponse{
Code: statusCode,
JSON: nil,
}
}
func AdminRetrieveAccount(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
logger := util.GetLogger(req.Context())
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
userID, ok := vars["userID"]
if !ok {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.MissingParam("Expecting user ID."),
}
}
local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam(err.Error()),
}
}
body := struct {
DisplayName string `json:"display_name"`
AvatarURL string `json:"avatar_url"`
Deactivated bool `json:"deactivated"`
}{}
var rs api.QueryAccountByLocalpartResponse
err = userAPI.QueryAccountByLocalpart(req.Context(), &api.QueryAccountByLocalpartRequest{Localpart: local, ServerName: domain}, &rs)
if err == sql.ErrNoRows {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound(fmt.Sprintf("User '%s' not found", userID)),
}
} else if err != nil {
logger.WithError(err).Error("userAPI.QueryAccountByLocalpart")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown(err.Error()),
}
}
body.Deactivated = rs.Account.Deactivated
profile, err := userAPI.QueryProfile(req.Context(), userID)
if err != nil {
if err == appserviceAPI.ErrProfileNotExists {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound(err.Error()),
}
}
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown(err.Error()),
}
}
body.AvatarURL = profile.AvatarURL
body.DisplayName = profile.DisplayName
return util.JSONResponse{
Code: http.StatusOK,
JSON: body,
}
}
// GetEventReports returns reported events for a given user/room.
func GetEventReports(
req *http.Request,

View file

@ -9,20 +9,21 @@ package routing
import (
"context"
"net/http"
"strings"
"time"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/sirupsen/logrus"
"github.com/element-hq/dendrite/clientapi/auth"
"github.com/element-hq/dendrite/clientapi/auth/authtypes"
"github.com/element-hq/dendrite/clientapi/httputil"
"github.com/element-hq/dendrite/setup/config"
"github.com/element-hq/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
)
const CrossSigningResetStage = "org.matrix.cross_signing_reset"
type crossSigningRequest struct {
api.PerformUploadDeviceKeysRequest
Auth newPasswordAuth `json:"auth"`
@ -38,6 +39,7 @@ func UploadCrossSigningDeviceKeys(
keyserverAPI UploadKeysAPI, device *api.Device,
accountAPI auth.GetAccountByPassword, cfg *config.ClientAPI,
) util.JSONResponse {
logger := util.GetLogger(req.Context())
uploadReq := &crossSigningRequest{}
uploadRes := &api.PerformUploadDeviceKeysResponse{}
@ -55,76 +57,92 @@ func UploadCrossSigningDeviceKeys(
}, &keyResp)
if keyResp.Error != nil {
logrus.WithError(keyResp.Error).Error("Failed to query keys")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Unknown(keyResp.Error.Error()),
}
logger.WithError(keyResp.Error).Error("Failed to query keys")
return convertKeyError(keyResp.Error)
}
existingMasterKey, hasMasterKey := keyResp.MasterKeys[device.UserID]
requireUIA := false
if hasMasterKey {
// If we have a master key, check if any of the existing keys differ. If they do,
// we need to re-authenticate the user.
requireUIA = keysDiffer(existingMasterKey, keyResp, uploadReq, device.UserID)
}
if requireUIA {
sessionID := uploadReq.Auth.Session
if sessionID == "" {
sessionID = util.RandomString(sessionIDLength)
}
if uploadReq.Auth.Type != authtypes.LoginTypePassword {
if hasMasterKey {
if !keysDiffer(existingMasterKey, keyResp, uploadReq, device.UserID) {
// If we have a master key, check if any of the existing keys differ. If they don't
// we return 200 as keys are still valid and there's nothing to do.
return util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: newUserInteractiveResponse(
sessionID,
[]authtypes.Flow{
{
Stages: []authtypes.LoginType{authtypes.LoginTypePassword},
},
},
nil,
),
Code: http.StatusOK,
JSON: struct{}{},
}
}
typePassword := auth.LoginTypePassword{
GetAccountByPassword: accountAPI,
Config: cfg,
// With MSC3861, UIA is not possible. Instead, the auth service has to explicitly mark the master key as replaceable.
if cfg.MSCs.MSC3861Enabled() {
requireUIA := !sessions.isCrossSigningKeysReplacementAllowed(device.UserID)
if requireUIA {
url := ""
if m := cfg.MSCs.MSC3861; m.AccountManagementURL != "" {
url = strings.Join([]string{m.AccountManagementURL, "?action=", CrossSigningResetStage}, "")
} else {
url = m.Issuer
}
return util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: newUserInteractiveResponse(
"dummy",
[]authtypes.Flow{
{
Stages: []authtypes.LoginType{CrossSigningResetStage},
},
},
map[string]interface{}{
CrossSigningResetStage: map[string]string{
"url": url,
},
},
strings.Join([]string{
"To reset your end-to-end encryption cross-signing identity, you first need to approve it at",
url,
"and then try again.",
}, " "),
),
}
}
sessions.restrictCrossSigningKeysReplacement(device.UserID)
} else {
sessionID := uploadReq.Auth.Session
if sessionID == "" {
sessionID = util.RandomString(sessionIDLength)
}
if uploadReq.Auth.Type != authtypes.LoginTypePassword {
return util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: newUserInteractiveResponse(
sessionID,
[]authtypes.Flow{
{
Stages: []authtypes.LoginType{authtypes.LoginTypePassword},
},
},
nil,
"",
),
}
}
typePassword := auth.LoginTypePassword{
GetAccountByPassword: accountAPI,
Config: cfg,
}
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil {
return *authErr
}
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
}
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil {
return *authErr
}
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
}
uploadReq.UserID = device.UserID
keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes)
if err := uploadRes.Error; err != nil {
switch {
case err.IsInvalidSignature:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidSignature(err.Error()),
}
case err.IsMissingParam:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.MissingParam(err.Error()),
}
case err.IsInvalidParam:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam(err.Error()),
}
default:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Unknown(err.Error()),
}
}
return convertKeyError(err)
}
return util.JSONResponse{
@ -160,28 +178,7 @@ func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.Clie
keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes)
if err := uploadRes.Error; err != nil {
switch {
case err.IsInvalidSignature:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidSignature(err.Error()),
}
case err.IsMissingParam:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.MissingParam(err.Error()),
}
case err.IsInvalidParam:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam(err.Error()),
}
default:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Unknown(err.Error()),
}
}
return convertKeyError(err)
}
return util.JSONResponse{
@ -189,3 +186,28 @@ func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.Clie
JSON: struct{}{},
}
}
func convertKeyError(err *api.KeyError) util.JSONResponse {
switch {
case err.IsInvalidSignature:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidSignature(err.Error()),
}
case err.IsMissingParam:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.MissingParam(err.Error()),
}
case err.IsInvalidParam:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam(err.Error()),
}
default:
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Unknown(err.Error()),
}
}
}

View file

@ -19,15 +19,17 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec"
)
// TODO: add more tests to cover cases related to MSC3861
type mockKeyAPI struct {
t *testing.T
userResponses map[string]api.QueryKeysResponse
queryKeysData map[string]api.QueryKeysResponse
}
func (m mockKeyAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
res.MasterKeys = m.userResponses[req.UserID].MasterKeys
res.SelfSigningKeys = m.userResponses[req.UserID].SelfSigningKeys
res.UserSigningKeys = m.userResponses[req.UserID].UserSigningKeys
res.MasterKeys = m.queryKeysData[req.UserID].MasterKeys
res.SelfSigningKeys = m.queryKeysData[req.UserID].SelfSigningKeys
res.UserSigningKeys = m.queryKeysData[req.UserID].UserSigningKeys
if m.t != nil {
m.t.Logf("QueryKeys: %+v => %+v", req, res)
}
@ -53,13 +55,16 @@ func Test_UploadCrossSigningDeviceKeys_ValidRequest(t *testing.T) {
req.Header.Set("Content-Type", "application/json")
keyserverAPI := &mockKeyAPI{
userResponses: map[string]api.QueryKeysResponse{
queryKeysData: map[string]api.QueryKeysResponse{
"@user:example.com": {},
},
}
device := &api.Device{UserID: "@user:example.com", ID: "device"}
cfg := &config.ClientAPI{}
cfg := &config.ClientAPI{
MSCs: &config.MSCs{
MSCs: []string{},
},
}
res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg)
if res.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, res.Code)
@ -101,10 +106,13 @@ func Test_UploadCrossSigningDeviceKeys_Unauthorised(t *testing.T) {
keyserverAPI := &mockKeyAPI{
t: t,
userResponses: map[string]api.QueryKeysResponse{
queryKeysData: map[string]api.QueryKeysResponse{
"@user:example.com": {
MasterKeys: map[string]fclient.CrossSigningKey{
"@user:example.com": {UserID: "@user:example.com", Usage: []fclient.CrossSigningKeyPurpose{"master"}, Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("key1")}},
"@user:example.com": {
UserID: "@user:example.com",
Usage: []fclient.CrossSigningKeyPurpose{fclient.CrossSigningKeyPurposeMaster},
Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("key1")}},
},
SelfSigningKeys: nil,
UserSigningKeys: nil,
@ -112,7 +120,11 @@ func Test_UploadCrossSigningDeviceKeys_Unauthorised(t *testing.T) {
},
}
device := &api.Device{UserID: "@user:example.com", ID: "device"}
cfg := &config.ClientAPI{}
cfg := &config.ClientAPI{
MSCs: &config.MSCs{
MSCs: []string{},
},
}
res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg)
if res.Code != http.StatusUnauthorized {
@ -132,8 +144,11 @@ func Test_UploadCrossSigningDeviceKeys_InvalidJSON(t *testing.T) {
keyserverAPI := &mockKeyAPI{}
device := &api.Device{UserID: "@user:example.com", ID: "device"}
cfg := &config.ClientAPI{}
cfg := &config.ClientAPI{
MSCs: &config.MSCs{
MSCs: []string{},
},
}
res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg)
if res.Code != http.StatusBadRequest {
t.Fatalf("expected status %d, got %d", http.StatusBadRequest, res.Code)
@ -151,10 +166,14 @@ func Test_UploadCrossSigningDeviceKeys_ExistingKeysMismatch(t *testing.T) {
req.Header.Set("Content-Type", "application/json")
keyserverAPI := &mockKeyAPI{
userResponses: map[string]api.QueryKeysResponse{
queryKeysData: map[string]api.QueryKeysResponse{
"@user:example.com": {
MasterKeys: map[string]fclient.CrossSigningKey{
"@user:example.com": {UserID: "@user:example.com", Usage: []fclient.CrossSigningKeyPurpose{"master"}, Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("different_key")}},
"@user:example.com": {
UserID: "@user:example.com",
Usage: []fclient.CrossSigningKeyPurpose{fclient.CrossSigningKeyPurposeMaster},
Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("different_key")},
},
},
},
},

View file

@ -9,6 +9,7 @@ import (
"testing"
"time"
"github.com/element-hq/dendrite/clientapi/auth"
"github.com/element-hq/dendrite/clientapi/auth/authtypes"
"github.com/element-hq/dendrite/internal/caching"
"github.com/element-hq/dendrite/internal/httputil"
@ -50,9 +51,10 @@ func TestLogin(t *testing.T) {
rsAPI.SetFederationAPI(nil, nil)
// Needed for /login
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
Setup(routers, cfg, nil, nil, userAPI, nil, nil, nil, nil, nil, nil, nil, caching.DisableMetrics)
Setup(routers, cfg, nil, nil, userAPI, nil, nil, nil, nil, nil, nil, nil, &userVerifier, caching.DisableMetrics)
// Create password
password := util.RandomString(8)

View file

@ -67,6 +67,7 @@ func Password(
},
},
nil,
"",
),
}
}

View file

@ -172,24 +172,20 @@ func GetDisplayName(
// SetDisplayName implements PUT /profile/{userID}/displayname
func SetDisplayName(
req *http.Request, profileAPI userapi.ProfileAPI,
req *http.Request, userAPI userapi.ClientUserAPI,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.ClientRoomserverAPI,
) util.JSONResponse {
if userID != device.UserID {
if userID != device.UserID && device.AccountType != userapi.AccountTypeOIDCService {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID does not match the current user"),
}
}
var r eventutil.UserProfile
if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil {
return *resErr
}
logger := util.GetLogger(req.Context())
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
logger.WithError(err).Error("gomatrixserverlib.SplitID failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
@ -203,6 +199,28 @@ func SetDisplayName(
}
}
if device.AccountType == userapi.AccountTypeOIDCService {
// When a request is made on behalf of an OIDC provider service, the original device object refers
// to the provider's pseudo-device and includes only the AccountTypeOIDCService flag. To continue,
// we need to replace the admin's device with the user's device
var rs userapi.QueryDevicesResponse
err = userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{UserID: userID}, &rs)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
if len(rs.Devices) > 0 {
device = &rs.Devices[0]
}
}
var r eventutil.UserProfile
if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil {
return *resErr
}
evTime, err := httputil.ParseTSParam(req)
if err != nil {
return util.JSONResponse{
@ -211,9 +229,9 @@ func SetDisplayName(
}
}
profile, changed, err := profileAPI.SetDisplayName(req.Context(), localpart, domain, r.DisplayName)
profile, changed, err := userAPI.SetDisplayName(req.Context(), localpart, domain, r.DisplayName)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed")
logger.WithError(err).Error("profileAPI.SetDisplayName failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},

View file

@ -66,11 +66,17 @@ type sessionsDict struct {
// If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2,
// the delete request will fail for device2 since the UIA was initiated by trying to delete device1.
deleteSessionToDeviceID map[string]string
// crossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating
// cross-signing keys without UIA.
crossSigningKeysReplacement map[string]*time.Timer
}
// defaultTimeout is the timeout used to clean up sessions
const defaultTimeOut = time.Minute * 5
// crossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA
const crossSigningKeysReplacementDuration = time.Minute * 10
// getCompletedStages returns the completed stages for a session.
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
d.RLock()
@ -119,13 +125,54 @@ func (d *sessionsDict) deleteSession(sessionID string) {
}
}
func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 {
d.Lock()
defer d.Unlock()
allowedUntilTS := time.Now().Add(crossSigningKeysReplacementDuration).UnixMilli()
t, ok := d.crossSigningKeysReplacement[userID]
if ok {
t.Reset(crossSigningKeysReplacementDuration)
return allowedUntilTS
}
d.crossSigningKeysReplacement[userID] = time.AfterFunc(
crossSigningKeysReplacementDuration,
func() {
d.restrictCrossSigningKeysReplacement(userID)
},
)
return allowedUntilTS
}
func (d *sessionsDict) isCrossSigningKeysReplacementAllowed(userID string) bool {
d.RLock()
defer d.RUnlock()
_, ok := d.crossSigningKeysReplacement[userID]
return ok
}
func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) {
d.Lock()
defer d.Unlock()
t, ok := d.crossSigningKeysReplacement[userID]
if ok {
if !t.Stop() {
select {
case <-t.C:
default:
}
}
delete(d.crossSigningKeysReplacement, userID)
}
}
func newSessionsDict() *sessionsDict {
return &sessionsDict{
sessions: make(map[string][]authtypes.LoginType),
sessionCompletedResult: make(map[string]registerResponse),
params: make(map[string]registerRequest),
timer: make(map[string]*time.Timer),
deleteSessionToDeviceID: make(map[string]string),
sessions: make(map[string][]authtypes.LoginType),
sessionCompletedResult: make(map[string]registerResponse),
params: make(map[string]registerRequest),
timer: make(map[string]*time.Timer),
deleteSessionToDeviceID: make(map[string]string),
crossSigningKeysReplacement: make(map[string]*time.Timer),
}
}
@ -234,6 +281,7 @@ type userInteractiveResponse struct {
Completed []authtypes.LoginType `json:"completed"`
Params map[string]interface{} `json:"params"`
Session string `json:"session"`
Msg string `json:"msg,omitempty"`
}
// newUserInteractiveResponse will return a struct to be sent back to the client
@ -242,9 +290,10 @@ func newUserInteractiveResponse(
sessionID string,
fs []authtypes.Flow,
params map[string]interface{},
msg string,
) userInteractiveResponse {
return userInteractiveResponse{
fs, sessions.getCompletedStages(sessionID), params, sessionID,
fs, sessions.getCompletedStages(sessionID), params, sessionID, msg,
}
}
@ -817,7 +866,7 @@ func checkAndCompleteFlow(
return util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: newUserInteractiveResponse(sessionID,
cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params),
cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params, ""),
}
}

View file

@ -669,3 +669,57 @@ func TestRegisterAdminUsingSharedSecret(t *testing.T) {
assert.Equal(t, expectedDisplayName, profile.DisplayName)
})
}
func TestCrossSigningKeysReplacement(t *testing.T) {
userID := "@user:example.com"
t.Run("Can add new session", func(t *testing.T) {
s := newSessionsDict()
assert.Empty(t, s.crossSigningKeysReplacement)
s.allowCrossSigningKeysReplacement(userID)
assert.Len(t, s.crossSigningKeysReplacement, 1)
assert.Contains(t, s.crossSigningKeysReplacement, userID)
})
t.Run("Can check if session exists or not", func(t *testing.T) {
s := newSessionsDict()
t.Run("exists", func(t *testing.T) {
s.allowCrossSigningKeysReplacement(userID)
assert.Len(t, s.crossSigningKeysReplacement, 1)
assert.True(t, s.isCrossSigningKeysReplacementAllowed(userID))
})
t.Run("not exists", func(t *testing.T) {
assert.False(t, s.isCrossSigningKeysReplacementAllowed("@random:test.com"))
})
})
t.Run("Can deactivate session", func(t *testing.T) {
s := newSessionsDict()
assert.Empty(t, s.crossSigningKeysReplacement)
t.Run("not exists", func(t *testing.T) {
s.restrictCrossSigningKeysReplacement("@random:test.com")
assert.Empty(t, s.crossSigningKeysReplacement)
})
t.Run("exists", func(t *testing.T) {
s.allowCrossSigningKeysReplacement(userID)
s.restrictCrossSigningKeysReplacement(userID)
assert.Empty(t, s.crossSigningKeysReplacement)
})
})
t.Run("Can erase expired sessions", func(t *testing.T) {
s := newSessionsDict()
s.allowCrossSigningKeysReplacement(userID)
assert.Len(t, s.crossSigningKeysReplacement, 1)
assert.True(t, s.isCrossSigningKeysReplacementAllowed(userID))
timer := s.crossSigningKeysReplacement[userID]
// pretending the timer is expired
timer.Reset(time.Millisecond)
time.Sleep(time.Millisecond * 500)
assert.Empty(t, s.crossSigningKeysReplacement)
})
}

File diff suppressed because it is too large Load diff

View file

@ -55,7 +55,7 @@ var latest, _ = semver.NewVersion("v6.6.6") // Dummy version, used as "HEAD"
// due to the error:
// When using COPY with more than one source file, the destination must be a directory and end with a /
// We need to run a postgres anyway, so use the dockerfile associated with Complement instead.
const DockerfilePostgreSQL = `FROM golang:1.22-bookworm as build
const DockerfilePostgreSQL = `FROM golang:1.23-bookworm as build
RUN apt-get update && apt-get install -y postgresql
WORKDIR /build
ARG BINARY
@ -99,7 +99,7 @@ ENV BINARY=dendrite
EXPOSE 8008 8448
CMD /build/run_dendrite.sh`
const DockerfileSQLite = `FROM golang:1.22-bookworm as build
const DockerfileSQLite = `FROM golang:1.23-bookworm as build
RUN apt-get update && apt-get install -y postgresql
WORKDIR /build
ARG BINARY

View file

@ -303,8 +303,27 @@ media_api:
# Configuration for enabling experimental MSCs on this homeserver.
mscs:
mscs:
# - msc3861 # (Next-gen auth, see https://github.com/matrix-org/matrix-doc/pull/3861)
# - msc2836 # (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836)
# This block has no effect if the feature is not activated in the list above
# msc3861:
# # OIDC issuer advertised by the service.
# # See https://element-hq.github.io/matrix-authentication-service/reference/configuration.html#http
# issuer: "https://mas.example.com/"
# # Credentials used for authenticating requests coming from dendrite to auth service.
# # See https://element-hq.github.io/matrix-authentication-service/reference/configuration.html#clients
# client_id: 01JFNM9MCHKV6A7A0C0RBHMYC0
# client_secret: c85731184ac8f9aea76cf48146046b454473ca667a0cd1fd52a43034a0662eed
# # The service token used for authenticating requests coming from auth service to dendrite.
# # See https://element-hq.github.io/matrix-authentication-service/reference/configuration.html#matrix
# admin_token: ttJORW9oV4Wf4DJ63GdZEYekE2KElP4g
# # URL of the account page on the auth service side
# account_management_url: "https://mas.example.com/account"
# Configuration for the Sync API.
sync_api:
# This option controls which HTTP header to inspect to find the real remote IP

View file

@ -14,7 +14,7 @@ import (
"github.com/element-hq/dendrite/internal/caching"
"github.com/element-hq/dendrite/internal/sqlutil"
"github.com/element-hq/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
)
// NewDatabase opens a new database

2
go.mod
View file

@ -156,6 +156,6 @@ require (
nhooyr.io/websocket v1.8.7 // indirect
)
go 1.22
go 1.23
toolchain go1.23.2

View file

@ -58,17 +58,26 @@ func WithAuth() AuthAPIOption {
}
}
// UserVerifier verifies users by their access tokens. Currently, there are two interface implementations:
// DefaultUserVerifier and MSC3861UserVerifier. The first one checks if the token exists in the server's database,
// whereas the latter passes the token for verification to MAS and acts in accordance with MAS's response.
type UserVerifier interface {
// VerifyUserFromRequest authenticates the HTTP request,
// on success returns Device of the requester.
VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse)
}
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request.
func MakeAuthAPI(
metricsName string, userAPI userapi.QueryAcccessTokenAPI,
metricsName string, userVerifier UserVerifier,
f func(*http.Request, *userapi.Device) util.JSONResponse,
checks ...AuthAPIOption,
) http.Handler {
h := func(req *http.Request) util.JSONResponse {
logger := util.GetLogger(req.Context())
device, err := auth.VerifyUserFromRequest(req, userAPI)
device, err := userVerifier.VerifyUserFromRequest(req)
if err != nil {
logger.Debugf("VerifyUserFromRequest %s -> HTTP %d", req.RemoteAddr, err.Code)
logger.Debugf("VerifyUserFromRequest %s -> HTTP %d: JSON %+v", req.RemoteAddr, err.Code, err.JSON)
return *err
}
// add the user ID to the logger
@ -122,11 +131,11 @@ func MakeAuthAPI(
// MakeAdminAPI is a wrapper around MakeAuthAPI which enforces that the request can only be
// completed by a user that is a server administrator.
func MakeAdminAPI(
metricsName string, userAPI userapi.QueryAcccessTokenAPI,
metricsName string, userVerifier UserVerifier,
f func(*http.Request, *userapi.Device) util.JSONResponse,
) http.Handler {
return MakeAuthAPI(metricsName, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if device.AccountType != userapi.AccountTypeAdmin {
return MakeAuthAPI(metricsName, userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if device == nil || device.AccountType != userapi.AccountTypeAdmin {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("This API can only be used by admin users."),
@ -136,6 +145,38 @@ func MakeAdminAPI(
})
}
// MakeServiceAdminAPI is a wrapper around MakeExternalAPI which enforces that the request can only be
// completed by a trusted service e.g. Matrix Auth Service (MAS).
func MakeServiceAdminAPI(
metricsName, serviceToken string,
f func(*http.Request) util.JSONResponse,
) http.Handler {
h := func(req *http.Request) util.JSONResponse {
logger := util.GetLogger(req.Context())
token, err := auth.ExtractAccessToken(req)
if err != nil {
logger.Debugf("ExtractAccessToken %s -> HTTP %d", req.RemoteAddr, http.StatusUnauthorized)
return util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: spec.MissingToken(err.Error()),
}
}
if token != serviceToken {
logger.Debug("Invalid service token")
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.UnknownToken(token),
}
}
// add the service addr to the logger
logger = logger.WithField("service_useragent", req.UserAgent())
req = req.WithContext(util.ContextWithLogger(req.Context(), logger))
return f(req)
}
return MakeExternalAPI(metricsName, h)
}
// MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler.
// This is used for APIs that are called from the internet.
func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {
@ -200,7 +241,7 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
// MakeHTTPAPI adds Span metrics to the HTML Handler function
// This is used to serve HTML alongside JSON error messages
func MakeHTTPAPI(metricsName string, userAPI userapi.QueryAcccessTokenAPI, enableMetrics bool, f func(http.ResponseWriter, *http.Request), checks ...AuthAPIOption) http.Handler {
func MakeHTTPAPI(metricsName string, userVerifier UserVerifier, enableMetrics bool, f func(http.ResponseWriter, *http.Request), checks ...AuthAPIOption) http.Handler {
withSpan := func(w http.ResponseWriter, req *http.Request) {
if req.Method == http.MethodOptions {
util.SetCORSHeaders(w)
@ -220,7 +261,7 @@ func MakeHTTPAPI(metricsName string, userAPI userapi.QueryAcccessTokenAPI, enabl
if opts.WithAuth {
logger := util.GetLogger(req.Context())
_, jsonErr := auth.VerifyUserFromRequest(req, userAPI)
_, jsonErr := userVerifier.VerifyUserFromRequest(req)
if jsonErr != nil {
w.WriteHeader(jsonErr.Code)
if err := json.NewEncoder(w).Encode(jsonErr.JSON); err != nil {

View file

@ -10,6 +10,8 @@ import (
"net/http"
"net/http/httptest"
"testing"
"github.com/matrix-org/util"
)
func TestWrapHandlerInBasicAuth(t *testing.T) {
@ -99,3 +101,68 @@ func TestWrapHandlerInBasicAuth(t *testing.T) {
})
}
}
func TestMakeServiceAdminAPI(t *testing.T) {
serviceToken := "valid_secret_token"
type args struct {
f func(*http.Request) util.JSONResponse
serviceToken string
}
f := func(*http.Request) util.JSONResponse {
return util.JSONResponse{Code: http.StatusOK}
}
tests := []struct {
name string
args args
want int
reqAuth bool
}{
{
name: "service token valid",
args: args{
f: f,
serviceToken: serviceToken,
},
want: http.StatusOK,
reqAuth: true,
},
{
name: "service token invalid",
args: args{
f: f,
serviceToken: "invalid_service_token",
},
want: http.StatusForbidden,
reqAuth: true,
},
{
name: "service token is missing",
args: args{
f: f,
serviceToken: "",
},
want: http.StatusUnauthorized,
reqAuth: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := MakeServiceAdminAPI("metrics", serviceToken, tt.args.f)
req := httptest.NewRequest("GET", "http://localhost/admin/v1/username_available", nil)
if tt.reqAuth {
req.Header.Add("Authorization", "Bearer "+tt.args.serviceToken)
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
resp := w.Result()
if resp.StatusCode != tt.want {
t.Errorf("Expected status code %d, got %d", resp.StatusCode, tt.want)
}
})
}
}

View file

@ -14,7 +14,6 @@ import (
"github.com/element-hq/dendrite/mediaapi/routing"
"github.com/element-hq/dendrite/mediaapi/storage"
"github.com/element-hq/dendrite/setup/config"
userapi "github.com/element-hq/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
)
@ -24,10 +23,10 @@ func AddPublicRoutes(
routers httputil.Routers,
cm *sqlutil.Connections,
cfg *config.Dendrite,
userAPI userapi.MediaUserAPI,
client *fclient.Client,
fedClient fclient.FederationClient,
keyRing gomatrixserverlib.JSONVerifier,
userVerifier httputil.UserVerifier,
) {
mediaDB, err := storage.NewMediaAPIDatasource(cm, &cfg.MediaAPI.Database)
if err != nil {
@ -35,6 +34,6 @@ func AddPublicRoutes(
}
routing.Setup(
routers, cfg, mediaDB, userAPI, client, fedClient, keyRing,
routers, cfg, mediaDB, client, fedClient, keyRing, userVerifier,
)
}

View file

@ -42,10 +42,10 @@ func Setup(
routers httputil.Routers,
cfg *config.Dendrite,
db storage.Database,
userAPI userapi.MediaUserAPI,
client *fclient.Client,
federationClient fclient.FederationClient,
keyRing gomatrixserverlib.JSONVerifier,
userVerifier httputil.UserVerifier,
) {
rateLimits := httputil.NewRateLimits(&cfg.ClientAPI.RateLimiting)
@ -58,7 +58,7 @@ func Setup(
}
uploadHandler := httputil.MakeAuthAPI(
"upload", userAPI,
"upload", userVerifier,
func(req *http.Request, dev *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req, dev); r != nil {
return *r
@ -67,7 +67,7 @@ func Setup(
},
)
configHandler := httputil.MakeAuthAPI("config", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
configHandler := httputil.MakeAuthAPI("config", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req, device); r != nil {
return *r
}
@ -97,13 +97,13 @@ func Setup(
).Methods(http.MethodGet, http.MethodOptions)
// v1 client endpoints requiring auth
downloadHandlerAuthed := httputil.MakeHTTPAPI("download", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("download_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth())
downloadHandlerAuthed := httputil.MakeHTTPAPI("download", userVerifier, cfg.Global.Metrics.Enabled, makeDownloadAPI("download_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth())
v1mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions)
v1mux.Handle("/download/{serverName}/{mediaId}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions)
v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions)
v1mux.Handle("/thumbnail/{serverName}/{mediaId}",
httputil.MakeHTTPAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()),
httputil.MakeHTTPAPI("thumbnail", userVerifier, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()),
).Methods(http.MethodGet, http.MethodOptions)
// same, but for federation

View file

@ -13,7 +13,7 @@ import (
"github.com/element-hq/dendrite/internal/sqlutil"
"github.com/element-hq/dendrite/relayapi/storage/sqlite3"
"github.com/element-hq/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
)
// NewDatabase opens a new database

View file

@ -7,6 +7,7 @@ import (
"testing"
"time"
"github.com/element-hq/dendrite/clientapi/auth"
"github.com/element-hq/dendrite/federationapi/statistics"
"github.com/element-hq/dendrite/internal/caching"
"github.com/element-hq/dendrite/internal/eventutil"
@ -267,7 +268,8 @@ func TestPurgeRoom(t *testing.T) {
rsAPI.SetFederationAPI(fsAPI, nil)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, fsAPI.IsBlacklistedOrBackingOff)
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, &userVerifier, caching.DisableMetrics)
// Create the room
if err = api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {

View file

@ -74,34 +74,45 @@ func (c *ClientAPI) Defaults(opts DefaultOpts) {
func (c *ClientAPI) Verify(configErrs *ConfigErrors) {
c.TURN.Verify(configErrs)
c.RateLimiting.Verify(configErrs)
if c.RecaptchaEnabled {
if c.RecaptchaSiteVerifyAPI == "" {
c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify"
if c.MSCs.MSC3861Enabled() {
if !c.RegistrationDisabled || c.RecaptchaEnabled {
configErrs.Add(
"You have enabled the experimental feature MSC3861 which implements the delegated authentication via OIDC. " +
"As a result, the feature conflicts with the standard Dendrite's registration and login flows and cannot be used if any of those is enabled. " +
"You need to disable registration (client_api.registration_disabled) and recapthca (client_api.enable_registration_captcha) options to proceed.",
)
}
if c.RecaptchaApiJsUrl == "" {
c.RecaptchaApiJsUrl = "https://www.google.com/recaptcha/api.js"
} else {
if c.RecaptchaEnabled {
if c.RecaptchaSiteVerifyAPI == "" {
c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify"
}
if c.RecaptchaApiJsUrl == "" {
c.RecaptchaApiJsUrl = "https://www.google.com/recaptcha/api.js"
}
if c.RecaptchaFormField == "" {
c.RecaptchaFormField = "g-recaptcha-response"
}
if c.RecaptchaSitekeyClass == "" {
c.RecaptchaSitekeyClass = "g-recaptcha"
}
checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey)
checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey)
checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI)
checkNotEmpty(configErrs, "client_api.recaptcha_sitekey_class", c.RecaptchaSitekeyClass)
}
if c.RecaptchaFormField == "" {
c.RecaptchaFormField = "g-recaptcha-response"
// Ensure there is any spam counter measure when enabling registration
if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled && !c.RecaptchaEnabled {
configErrs.Add(
"You have tried to enable open registration without any secondary verification methods " +
"(such as reCAPTCHA). By enabling open registration, you are SIGNIFICANTLY " +
"increasing the risk that your server will be used to send spam or abuse, and may result in " +
"your server being banned from some rooms. If you are ABSOLUTELY CERTAIN you want to do this, " +
"start Dendrite with the -really-enable-open-registration command line flag. Otherwise, you " +
"should set the registration_disabled option in your Dendrite config.",
)
}
if c.RecaptchaSitekeyClass == "" {
c.RecaptchaSitekeyClass = "g-recaptcha"
}
checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey)
checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey)
checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI)
checkNotEmpty(configErrs, "client_api.recaptcha_sitekey_class", c.RecaptchaSitekeyClass)
}
// Ensure there is any spam counter measure when enabling registration
if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled && !c.RecaptchaEnabled {
configErrs.Add(
"You have tried to enable open registration without any secondary verification methods " +
"(such as reCAPTCHA). By enabling open registration, you are SIGNIFICANTLY " +
"increasing the risk that your server will be used to send spam or abuse, and may result in " +
"your server being banned from some rooms. If you are ABSOLUTELY CERTAIN you want to do this, " +
"start Dendrite with the -really-enable-open-registration command line flag. Otherwise, you " +
"should set the registration_disabled option in your Dendrite config.",
)
}
}

View file

@ -4,11 +4,16 @@ type MSCs struct {
Matrix *Global `yaml:"-"`
// The MSCs to enable. Supported MSCs include:
// 'msc3861': Delegate auth to an OIDC provider - https://github.com/matrix-org/matrix-spec-proposals/pull/3861
// 'msc2444': Peeking over federation - https://github.com/matrix-org/matrix-doc/pull/2444
// 'msc2753': Peeking via /sync - https://github.com/matrix-org/matrix-doc/pull/2753
// 'msc2836': Threading - https://github.com/matrix-org/matrix-doc/pull/2836
MSCs []string `yaml:"mscs"`
// MSC3861 contains config related to the experimental feature MSC3861.
// It takes effect only if 'msc3861' is included in 'MSCs' array.
MSC3861 *MSC3861 `yaml:"msc3861,omitempty"`
Database DatabaseOptions `yaml:"database,omitempty"`
}
@ -34,4 +39,27 @@ func (c *MSCs) Verify(configErrs *ConfigErrors) {
if c.Matrix.DatabaseOptions.ConnectionString == "" {
checkNotEmpty(configErrs, "mscs.database.connection_string", string(c.Database.ConnectionString))
}
if m := c.MSC3861; m != nil && c.MSC3861Enabled() {
m.Verify(configErrs)
}
}
func (c *MSCs) MSC3861Enabled() bool {
return c.Enabled("msc3861") && c.MSC3861 != nil
}
type MSC3861 struct {
Issuer string `yaml:"issuer"`
ClientID string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"`
AdminToken string `yaml:"admin_token"`
AccountManagementURL string `yaml:"account_management_url"`
}
func (m *MSC3861) Verify(configErrs *ConfigErrors) {
checkNotEmpty(configErrs, "mscs.msc3861.issuer", string(m.Issuer))
checkNotEmpty(configErrs, "mscs.msc3861.client_id", string(m.ClientID))
checkNotEmpty(configErrs, "mscs.msc3861.client_secret", string(m.ClientSecret))
checkNotEmpty(configErrs, "mscs.msc3861.admin_token", string(m.AdminToken))
checkNotEmpty(configErrs, "mscs.msc3861.account_management_url", string(m.AccountManagementURL))
}

View file

@ -7,9 +7,12 @@
package setup
import (
"net/http"
appserviceAPI "github.com/element-hq/dendrite/appservice/api"
"github.com/element-hq/dendrite/clientapi"
"github.com/element-hq/dendrite/clientapi/api"
"github.com/element-hq/dendrite/clientapi/auth"
"github.com/element-hq/dendrite/federationapi"
federationAPI "github.com/element-hq/dendrite/federationapi/api"
"github.com/element-hq/dendrite/internal/caching"
@ -27,6 +30,7 @@ import (
userapi "github.com/element-hq/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/util"
)
// Monolith represents an instantiation of all dependencies required to build
@ -46,6 +50,8 @@ type Monolith struct {
// Optional
ExtPublicRoomsProvider api.ExtraPublicRoomsProvider
ExtUserDirectoryProvider userapi.QuerySearchProfilesAPI
UserVerifierProvider httputil.UserVerifier
}
// AddAllPublicRoutes attaches all public paths to the given router
@ -58,6 +64,10 @@ func (m *Monolith) AddAllPublicRoutes(
caches *caching.Caches,
enableMetrics bool,
) {
if m.UserVerifierProvider == nil {
m.UserVerifierProvider = NewUserVerifierProvider(&auth.DefaultUserVerifier{UserAPI: m.UserAPI})
}
userDirectoryProvider := m.ExtUserDirectoryProvider
if userDirectoryProvider == nil {
userDirectoryProvider = m.UserAPI
@ -65,15 +75,29 @@ func (m *Monolith) AddAllPublicRoutes(
clientapi.AddPublicRoutes(
processCtx, routers, cfg, natsInstance, m.FedClient, m.RoomserverAPI, m.AppserviceAPI, transactions.New(),
m.FederationAPI, m.UserAPI, userDirectoryProvider,
m.ExtPublicRoomsProvider, enableMetrics,
m.ExtPublicRoomsProvider, m.UserVerifierProvider, enableMetrics,
)
federationapi.AddPublicRoutes(
processCtx, routers, cfg, natsInstance, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, enableMetrics,
)
mediaapi.AddPublicRoutes(routers, cm, cfg, m.UserAPI, m.Client, m.FedClient, m.KeyRing)
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, enableMetrics)
mediaapi.AddPublicRoutes(routers, cm, cfg, m.Client, m.FedClient, m.KeyRing, m.UserVerifierProvider)
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, m.UserVerifierProvider, enableMetrics)
if m.RelayAPI != nil {
relayapi.AddPublicRoutes(routers, cfg, m.KeyRing, m.RelayAPI)
}
}
type UserVerifierProvider struct {
httputil.UserVerifier
}
func (u *UserVerifierProvider) VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse) {
return u.UserVerifier.VerifyUserFromRequest(req)
}
func NewUserVerifierProvider(userVerifier httputil.UserVerifier) *UserVerifierProvider {
return &UserVerifierProvider{
UserVerifier: userVerifier,
}
}

View file

@ -98,7 +98,7 @@ func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsRespons
// Enable this MSC
func Enable(
cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.Routers, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI,
userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier,
userVerifier httputil.UserVerifier, keyRing gomatrixserverlib.JSONVerifier,
) error {
db, err := NewDatabase(cm, &cfg.MSCs.Database)
if err != nil {
@ -124,7 +124,7 @@ func Enable(
})
routers.Client.Handle("/unstable/event_relationships",
httputil.MakeAuthAPI("eventRelationships", userAPI, eventRelationshipHandler(db, rsAPI, fsAPI)),
httputil.MakeAuthAPI("eventRelationships", userVerifier, eventRelationshipHandler(db, rsAPI, fsAPI)),
).Methods(http.MethodPost, http.MethodOptions)
routers.Federation.Handle("/unstable/event_relationships", httputil.MakeExternalAPI(

View file

@ -14,6 +14,7 @@ import (
"testing"
"time"
"github.com/element-hq/dendrite/clientapi/auth"
"github.com/element-hq/dendrite/setup/process"
"github.com/element-hq/dendrite/syncapi/synctypes"
"github.com/gorilla/mux"
@ -571,7 +572,8 @@ func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserve
processCtx := process.NewProcessContext()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
routers := httputil.NewRouters()
err := msc2836.Enable(cfg, cm, routers, rsAPI, nil, userAPI, nil)
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
err := msc2836.Enable(cfg, cm, routers, rsAPI, nil, &userVerifier, nil)
if err != nil {
t.Fatalf("failed to enable MSC2836: %s", err)
}

View file

@ -0,0 +1,38 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
package msc3861
import (
"errors"
"github.com/element-hq/dendrite/setup"
"github.com/matrix-org/gomatrixserverlib/fclient"
)
func Enable(m *setup.Monolith) error {
client := fclient.NewClient()
userVerifier, err := newMSC3861UserVerifier(
m.UserAPI, m.Config.Global.ServerName,
m.Config.MSCs.MSC3861, !m.Config.ClientAPI.GuestsDisabled,
client,
)
if err != nil {
return err
}
if m.UserVerifierProvider == nil {
return errors.New("msc3861: UserVerifierProvider is not initialised")
}
provider, ok := m.UserVerifierProvider.(*setup.UserVerifierProvider)
if !ok {
return errors.New("msc3861: the expected type of m.UserVerifierProvider is *setup.UserVerifierProvider")
}
provider.UserVerifier = userVerifier
return nil
}

View file

@ -0,0 +1,458 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
package msc3861
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"slices"
"strings"
"github.com/element-hq/dendrite/clientapi/auth"
"github.com/element-hq/dendrite/setup/config"
"github.com/element-hq/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
)
const externalAuthProvider string = "oauth-delegated"
// Scopes as defined by MSC2967
// https://github.com/matrix-org/matrix-spec-proposals/pull/2967
const (
scopeMatrixAPI string = "urn:matrix:org.matrix.msc2967.client:api:*"
scopeMatrixGuest string = "urn:matrix:org.matrix.msc2967.client:api:guest"
scopeMatrixDevicePrefix string = "urn:matrix:org.matrix.msc2967.client:device:"
)
type errCode string
const (
codeIntrospectionNot2xx errCode = "introspectionIsNot2xx"
codeInvalidClientToken errCode = "invalidClientToken"
codeAuthError errCode = "authError"
codeMxidError errCode = "mxidError"
codeOpenidConfigEndpointNon2xx errCode = "openidConfigEndpointNon2xx"
codeOpenidConfigDecodingFailed errCode = "openidConfigDecodingFailed"
)
// MSC3861UserVerifier implements UserVerifier interface
type MSC3861UserVerifier struct {
userAPI api.UserInternalAPI
serverName spec.ServerName
cfg *config.MSC3861
httpClient *fclient.Client
openIdConfig *OpenIDConfiguration
allowGuest bool
}
func newMSC3861UserVerifier(
userAPI api.UserInternalAPI,
serverName spec.ServerName,
cfg *config.MSC3861,
allowGuest bool,
client *fclient.Client,
) (*MSC3861UserVerifier, error) {
if cfg == nil {
return nil, errors.New("unable to create MSC3861UserVerifier object as 'cfg' param is nil")
}
if client == nil {
return nil, errors.New("unable to create MSC3861UserVerifier object as 'client' param is nil")
}
openIdConfig, err := fetchOpenIDConfiguration(client, cfg.Issuer)
if err != nil {
return nil, err
}
return &MSC3861UserVerifier{
userAPI: userAPI,
serverName: serverName,
cfg: cfg,
openIdConfig: openIdConfig,
allowGuest: allowGuest,
httpClient: client,
}, nil
}
type mscError struct {
Code errCode
Msg string
}
func (r *mscError) Error() string {
return fmt.Sprintf("%s: %s", r.Code, r.Msg)
}
// VerifyUserFromRequest authenticates the HTTP request, on success returns Device of the requester.
func (m *MSC3861UserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Device, *util.JSONResponse) {
ctx := req.Context()
util.GetLogger(ctx).Debug("MSC3861.VerifyUserFromRequest")
// Try to find the Application Service user
token, err := auth.ExtractAccessToken(req)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: spec.MissingToken(err.Error()),
}
}
if appServiceUserID := req.URL.Query().Get("user_id"); appServiceUserID != "" {
var res api.QueryAccessTokenResponse
err = m.userAPI.QueryAccessToken(ctx, &api.QueryAccessTokenRequest{
AccessToken: token,
AppServiceUserID: appServiceUserID,
}, &res)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.QueryAccessToken failed")
return nil, &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
}
userData, err := m.getUserByAccessToken(ctx, token)
if err != nil {
switch e := err.(type) {
case (*mscError):
switch e.Code {
case codeIntrospectionNot2xx, codeOpenidConfigDecodingFailed, codeOpenidConfigEndpointNon2xx:
return nil, &util.JSONResponse{
Code: http.StatusServiceUnavailable,
JSON: spec.Unknown(e.Error()),
}
case codeInvalidClientToken:
return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: spec.UnknownToken(e.Error()),
}
case codeAuthError, codeMxidError:
return nil, &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown(e.Error()),
}
default:
r := util.ErrorResponse(err)
return nil, &r
}
default:
r := util.ErrorResponse(err)
return nil, &r
}
}
// Do not record requests from MAS using the virtual `__oidc_admin` user.
if token != m.cfg.AdminToken {
// XXX: not sure which exact data we should record here. See the link for reference
// https://github.com/element-hq/synapse/blob/develop/synapse/api/auth/base.py#L365
}
if !m.allowGuest && userData.IsGuest {
return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: spec.Forbidden(strings.Join([]string{"Insufficient scope: ", scopeMatrixAPI}, "")),
}
}
return userData.Device, nil
}
type requester struct {
Device *api.Device
UserID *spec.UserID
Scope []string
IsGuest bool
}
// nolint: gocyclo
func (m *MSC3861UserVerifier) getUserByAccessToken(ctx context.Context, token string) (*requester, error) {
var userID *spec.UserID
logger := util.GetLogger(ctx)
if adminToken := m.cfg.AdminToken; adminToken != "" && token == adminToken {
// XXX: This is a temporary solution so that the admin API can be called by
// the OIDC provider. This will be removed once we have OIDC client
// credentials grant support in matrix-authentication-service.
// XXX: that user doesn't exist and won't be provisioned.
adminUser, err := createUserID("__oidc_admin", m.serverName)
if err != nil {
return nil, err
}
return &requester{
UserID: adminUser,
Scope: []string{"urn:synapse:admin:*"},
Device: &api.Device{UserID: adminUser.Local(), AccountType: api.AccountTypeOIDCService},
}, nil
}
introspectionResult, err := m.introspectToken(ctx, token)
if err != nil {
logger.WithError(err).Error("MSC3861UserVerifier:introspectToken")
return nil, err
}
if !introspectionResult.Active {
return nil, &mscError{Code: codeInvalidClientToken, Msg: "Token is not active"}
}
scopes := introspectionResult.Scopes()
hasUserScope, hasGuestScope := slices.Contains(scopes, scopeMatrixAPI), slices.Contains(scopes, scopeMatrixGuest)
if !hasUserScope && !hasGuestScope {
return nil, &mscError{Code: codeInvalidClientToken, Msg: "No scope in token granting user rights"}
}
sub := introspectionResult.Sub
if sub == "" {
return nil, &mscError{Code: codeInvalidClientToken, Msg: "Invalid sub claim in the introspection result"}
}
localpart := ""
localpartExternalID, err := m.userAPI.QueryExternalUserIDByLocalpartAndProvider(ctx, sub, externalAuthProvider)
if err != nil && err != sql.ErrNoRows {
return nil, err
}
if localpartExternalID != nil {
localpart = localpartExternalID.Localpart
}
if localpart == "" {
// If we could not find a user via the external_id, it either does not exist,
// or the external_id was never recorded
username := introspectionResult.Username
if username == "" {
return nil, &mscError{Code: codeAuthError, Msg: "Invalid username claim in the introspection result"}
}
userID, err = createUserID(username, m.serverName)
if err != nil {
logger.WithError(err).Error("getUserByAccessToken:createUserID")
return nil, err
}
// First try to find a user from the username claim
var account *api.Account
{
var rs api.QueryAccountByLocalpartResponse
err = m.userAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{Localpart: userID.Local(), ServerName: userID.Domain()}, &rs)
if err != nil && err != sql.ErrNoRows {
logger.WithError(err).Error("QueryAccountByLocalpart")
return nil, err
}
account = rs.Account
}
if account == nil {
// If the user does not exist, we should create it on the fly
var rs api.PerformAccountCreationResponse
if err = m.userAPI.PerformAccountCreation(ctx, &api.PerformAccountCreationRequest{
AccountType: api.AccountTypeUser,
Localpart: userID.Local(),
ServerName: userID.Domain(),
}, &rs); err != nil {
logger.WithError(err).Error("PerformAccountCreation")
return nil, err
}
}
if err = m.userAPI.PerformLocalpartExternalUserIDCreation(ctx, userID.Local(), sub, externalAuthProvider); err != nil {
logger.WithError(err).Error("PerformLocalpartExternalUserIDCreation")
return nil, err
}
localpart = userID.Local()
}
if userID == nil {
userID, err = createUserID(localpart, m.serverName)
if err != nil {
logger.WithError(err).Error("getUserByAccessToken:createUserID")
return nil, err
}
}
deviceIDs := make([]string, 0, 1)
for i := range scopes {
if s := scopes[i]; strings.HasPrefix(s, scopeMatrixDevicePrefix) {
deviceIDs = append(deviceIDs, s[len(scopeMatrixDevicePrefix):])
}
}
if len(deviceIDs) != 1 {
logger.Errorf("Invalid device IDs in scope: %+v", deviceIDs)
return nil, &mscError{Code: codeAuthError, Msg: "Invalid device IDs in scope"}
}
var device *api.Device
deviceID := deviceIDs[0]
if len(deviceID) > 255 || len(deviceID) < 1 {
return nil, &mscError{
Code: codeAuthError,
Msg: strings.Join([]string{"Invalid device ID in scope: ", deviceID}, ""),
}
}
userDeviceExists := false
{
var rs api.QueryDevicesResponse
err := m.userAPI.QueryDevices(ctx, &api.QueryDevicesRequest{UserID: userID.String()}, &rs)
if err != nil && err != sql.ErrNoRows {
return nil, err
}
for i := range rs.Devices {
if d := &rs.Devices[i]; d.ID == deviceID {
userDeviceExists = true
device = d
break
}
}
}
if !userDeviceExists {
var rs api.PerformDeviceCreationResponse
deviceDisplayName := "OIDC-native client"
if err := m.userAPI.PerformDeviceCreation(ctx, &api.PerformDeviceCreationRequest{
Localpart: localpart,
ServerName: m.serverName,
AccessToken: "",
DeviceID: &deviceID,
DeviceDisplayName: &deviceDisplayName,
AccessTokenUniqueConstraintDisabled: true,
// TODO: Cannot add IPAddr and Useragent values here. Should we care about it here?
}, &rs); err != nil {
logger.WithError(err).Error("PerformDeviceCreation")
return nil, err
}
device = rs.Device
logger.Debugf("PerformDeviceCreationResponse is: %+v", rs)
}
return &requester{
Device: device,
UserID: userID,
Scope: scopes,
IsGuest: hasGuestScope && !hasUserScope,
}, nil
}
func createUserID(local string, serverName spec.ServerName) (*spec.UserID, error) {
userID, err := spec.NewUserID(strings.Join([]string{"@", local, ":", string(serverName)}, ""), false)
if err != nil {
return nil, &mscError{Code: codeMxidError, Msg: err.Error()}
}
return userID, nil
}
func (m *MSC3861UserVerifier) introspectToken(ctx context.Context, token string) (*introspectionResponse, error) {
formBody := url.Values{"token": []string{token}}
encoded := formBody.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, m.openIdConfig.IntrospectionEndpoint, strings.NewReader(encoded))
if err != nil {
return nil, err
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(m.cfg.ClientID, m.cfg.ClientSecret)
resp, err := m.httpClient.DoHTTPRequest(ctx, req)
if err != nil {
return nil, err
}
defer resp.Body.Close() // nolint: errcheck
if c := resp.StatusCode; c/100 != 2 {
return nil, errors.New(strings.Join([]string{"The introspection endpoint returned a '", resp.Status, "' response"}, ""))
}
var ir introspectionResponse
if err := json.NewDecoder(resp.Body).Decode(&ir); err != nil {
return nil, err
}
return &ir, nil
}
type OpenIDConfiguration struct {
Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
JWKsURI string `json:"jwks_uri"`
RegistrationEndpoint string `json:"registration_endpoint"`
ScopesSupported []string `json:"scopes_supported"`
ResponseTypesSupported []string `json:"response_types_supported"`
ResponseModesSupported []string `json:"response_modes_supported"`
GrantTypesSupported []string `json:"grant_types_supported"`
TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"`
TokenEndpointAuthSigningAlgCaluesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported"`
RevocationEnpoint string `json:"revocation_endpoint"`
RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported"`
RevocationEndpointAuthSigningAlgValues []string `json:"revocation_endpoint_auth_signing_alg_values_supported"`
IntrospectionEndpoint string `json:"introspection_endpoint"`
IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported"`
IntrospectionEndpointAuthSigningAlgValues []string `json:"introspection_endpoint_auth_signing_alg_values_supported"`
CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"`
UserinfoEndpoint string `json:"userinfo_endpoint"`
SubjectTypesSupported []string `json:"subject_types_supported"`
IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"`
UserinfoSigningAlgValuesSupported []string `json:"userinfo_signing_alg_values_supported"`
DisplayValuesSupported []string `json:"display_values_supported"`
ClaimTypesSupported []string `json:"claim_types_supported"`
ClaimsSupported []string `json:"claims_supported"`
ClaimsParameterSupported bool `json:"claims_parameter_supported"`
RequestParameterSupported bool `json:"request_parameter_supported"`
RequestURIParameterSupported bool `json:"request_uri_parameter_supported"`
PromptValuesSupported []string `json:"prompt_values_supported"`
DeviceAuthorizaEndpoint string `json:"device_authorization_endpoint"`
AccountManagementURI string `json:"account_management_uri"`
AccountManagementActionsSupported []string `json:"account_management_actions_supported"`
}
func fetchOpenIDConfiguration(httpClient *fclient.Client, authHostURL string) (*OpenIDConfiguration, error) {
u, err := url.Parse(authHostURL)
if err != nil {
return nil, err
}
u = u.JoinPath(".well-known/openid-configuration")
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
if err != nil {
return nil, err
}
resp, err := httpClient.DoHTTPRequest(context.Background(), req)
if err != nil {
return nil, err
}
defer resp.Body.Close() // nolint: errcheck
if resp.StatusCode != http.StatusOK {
return nil, &mscError{Code: codeOpenidConfigEndpointNon2xx, Msg: ".well-known/openid-configuration endpoint returned non-200 response"}
}
var oic OpenIDConfiguration
if err := json.NewDecoder(resp.Body).Decode(&oic); err != nil {
return nil, &mscError{Code: codeOpenidConfigDecodingFailed, Msg: err.Error()}
}
return &oic, nil
}
// introspectionResponse as described in the RFC https://datatracker.ietf.org/doc/html/rfc7662#section-2.2
type introspectionResponse struct {
Active bool `json:"active"` // required
Scope string `json:"scope"` // optional
Username string `json:"username"` // optional
TokenType string `json:"token_type"` // optional
Exp *int64 `json:"exp"` // optional
Iat *int64 `json:"iat"` // optional
Nfb *int64 `json:"nfb"` // optional
Sub string `json:"sub"` // optional
Jti string `json:"jti"` // optional
Aud string `json:"aud"` // optional
Iss string `json:"iss"` // optional
}
func (i *introspectionResponse) Scopes() []string {
return strings.Split(i.Scope, " ")
}

View file

@ -0,0 +1,234 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
package msc3861
import (
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"
"errors"
"github.com/element-hq/dendrite/federationapi/statistics"
"github.com/element-hq/dendrite/internal/caching"
"github.com/element-hq/dendrite/internal/sqlutil"
"github.com/element-hq/dendrite/roomserver"
"github.com/element-hq/dendrite/setup/config"
"github.com/element-hq/dendrite/setup/jetstream"
"github.com/element-hq/dendrite/test"
"github.com/element-hq/dendrite/test/testrig"
"github.com/element-hq/dendrite/userapi"
uapi "github.com/element-hq/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec"
)
var testIsBlacklistedOrBackingOff = func(s spec.ServerName) (*statistics.ServerStatistics, error) {
return &statistics.ServerStatistics{}, nil
}
type roundTripper struct {
roundTrip func(request *http.Request) (*http.Response, error)
}
func (rt *roundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
return rt.roundTrip(request)
}
func TestVerifyUserFromRequest(t *testing.T) {
aliceUser := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
bobUser := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
roundTrip := func(request *http.Request) (*http.Response, error) {
var (
respBody string
statusCode int
)
switch request.URL.String() {
case "https://mas.example.com/.well-known/openid-configuration":
respBody = `{"introspection_endpoint": "https://mas.example.com/oauth2/introspect"}`
statusCode = http.StatusOK
case "https://mas.example.com/oauth2/introspect":
_ = request.ParseForm()
switch request.Form.Get("token") {
case "validTokenUserExistsTokenActive":
statusCode = http.StatusOK
resp := introspectionResponse{
Active: true,
Scope: "urn:matrix:org.matrix.msc2967.client:device:devAlice urn:matrix:org.matrix.msc2967.client:api:*",
Sub: "111111111111111111",
Username: aliceUser.Localpart,
}
b, _ := json.Marshal(resp)
respBody = string(b)
case "validTokenUserDoesNotExistTokenActive":
statusCode = http.StatusOK
resp := introspectionResponse{
Active: true,
Scope: "urn:matrix:org.matrix.msc2967.client:device:devBob urn:matrix:org.matrix.msc2967.client:api:*",
Sub: "222222222222222222",
Username: bobUser.Localpart,
}
b, _ := json.Marshal(resp)
respBody = string(b)
case "validTokenUserExistsTokenInactive":
statusCode = http.StatusOK
resp := introspectionResponse{Active: false}
b, _ := json.Marshal(resp)
respBody = string(b)
default:
return nil, errors.New("Request URL not supported by stub")
}
}
respReader := io.NopCloser(strings.NewReader(respBody))
resp := http.Response{
StatusCode: statusCode,
Body: respReader,
ContentLength: int64(len(respBody)),
Header: map[string][]string{"Content-Type": {"application/json"}},
}
return &resp, nil
}
httpClient := fclient.NewClient(
fclient.WithTransport(&roundTripper{roundTrip: roundTrip}),
)
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
defer close()
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{
Issuer: "https://mas.example.com",
}
cfg.ClientAPI.RateLimiting.Enabled = false
natsInstance := jetstream.NATSInstance{}
// add a vhost
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
})
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
// Needed for /login
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
userVerifier, err := newMSC3861UserVerifier(
userAPI,
cfg.Global.ServerName,
cfg.MSCs.MSC3861,
false,
httpClient,
)
if err != nil {
t.Fatal(err.Error())
}
u, _ := url.Parse("https://example.com/something")
t.Run("existing user and active token", func(t *testing.T) {
localpart, serverName, _ := gomatrixserverlib.SplitID('@', aliceUser.ID)
userRes := &uapi.PerformAccountCreationResponse{}
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
AccountType: aliceUser.AccountType,
Localpart: localpart,
ServerName: serverName,
}, userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
if !userRes.AccountCreated {
t.Fatalf("account not created")
}
httpReq := http.Request{
URL: u,
Header: map[string][]string{
"Content-Type": {"application/json"},
"Authorization": {"Bearer validTokenUserExistsTokenActive"},
},
}
device, jsonResp := userVerifier.VerifyUserFromRequest(&httpReq)
if jsonResp != nil {
t.Fatalf("JSONResponse is not expected: %+v", jsonResp)
}
deviceRes := uapi.QueryDevicesResponse{}
if err := userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{
UserID: aliceUser.ID,
}, &deviceRes); err != nil {
t.Errorf("failed to query user devices")
}
if !deviceRes.UserExists {
t.Fatalf("user does not exist")
}
if l := len(deviceRes.Devices); l != 1 {
t.Fatalf("Incorrect number of user devices. Got %d, want 1", l)
}
if device.ID != deviceRes.Devices[0].ID {
t.Fatalf("Device IDs do not match: %s != %s", device.ID, deviceRes.Devices[0].ID)
}
})
t.Run("inactive token", func(t *testing.T) {
httpReq := http.Request{
URL: u,
Header: map[string][]string{
"Content-Type": {"application/json"},
"Authorization": {"Bearer validTokenUserExistsTokenInactive"},
},
}
device, jsonResp := userVerifier.VerifyUserFromRequest(&httpReq)
if jsonResp == nil {
t.Fatal("JSONResponse is expected to be nil")
}
if device != nil {
t.Fatalf("Device is not nil: %+v", device)
}
if jsonResp.Code != http.StatusUnauthorized {
t.Fatalf("Incorrect status code: want=401, got=%d", jsonResp.Code)
}
mErr, _ := jsonResp.JSON.(spec.MatrixError)
if mErr.ErrCode != spec.ErrorUnknownToken {
t.Fatalf("Unexpected error code: want=%s, got=%s", spec.ErrorUnknownToken, mErr.ErrCode)
}
})
t.Run("non-existing user", func(t *testing.T) {
httpReq := http.Request{
URL: u,
Header: map[string][]string{
"Content-Type": {"application/json"},
"Authorization": {"Bearer validTokenUserDoesNotExistTokenActive"},
},
}
device, jsonResp := userVerifier.VerifyUserFromRequest(&httpReq)
if jsonResp != nil {
t.Fatalf("JSONResponse is not expected: %+v", jsonResp)
}
deviceRes := uapi.QueryDevicesResponse{}
if err := userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{
UserID: bobUser.ID,
}, &deviceRes); err != nil {
t.Errorf("failed to query user devices")
}
if !deviceRes.UserExists {
t.Fatalf("user does not exist")
}
if l := len(deviceRes.Devices); l != 1 {
t.Fatalf("Incorrect number of user devices. Got %d, want 1", l)
}
if device.ID != deviceRes.Devices[0].ID {
t.Fatalf("Device IDs do not match: %s != %s", device.ID, deviceRes.Devices[0].ID)
}
})
})
}

View file

@ -16,6 +16,7 @@ import (
"github.com/element-hq/dendrite/setup"
"github.com/element-hq/dendrite/setup/config"
"github.com/element-hq/dendrite/setup/mscs/msc2836"
"github.com/element-hq/dendrite/setup/mscs/msc3861"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
)
@ -34,9 +35,11 @@ func Enable(cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.Rout
func EnableMSC(cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.Routers, monolith *setup.Monolith, msc string, caches *caching.Caches) error {
switch msc {
case "msc2836":
return msc2836.Enable(cfg, cm, routers, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing)
return msc2836.Enable(cfg, cm, routers, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserVerifierProvider, monolith.KeyRing)
case "msc2444": // enabled inside federationapi
case "msc2753": // enabled inside clientapi
case "msc3861":
return msc3861.Enable(monolith)
default:
logrus.Warnf("EnableMSC: unknown MSC '%s', this MSC is either not supported or is natively supported by Dendrite", msc)
}

View file

@ -36,16 +36,17 @@ func Setup(
lazyLoadCache caching.LazyLoadCache,
fts fulltext.Indexer,
rateLimits *httputil.RateLimits,
userVerifier httputil.UserVerifier,
) {
v1unstablemux := csMux.PathPrefix("/{apiversion:(?:v1|unstable)}/").Subrouter()
v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
// TODO: Add AS support for all handlers below.
v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return srp.OnIncomingSyncRequest(req, device)
}, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
// not specced, but ensure we're rate limiting requests to this endpoint
if r := rateLimits.Limit(req, device); r != nil {
return *r
@ -58,7 +59,7 @@ func Setup(
}, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/event/{eventID}",
httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
httputil.MakeAuthAPI("rooms_get_event", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -68,7 +69,7 @@ func Setup(
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/user/{userId}/filter",
httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
httputil.MakeAuthAPI("put_filter", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -78,7 +79,7 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/user/{userId}/filter/{filterId}",
httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
httputil.MakeAuthAPI("get_filter", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -87,12 +88,12 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return srp.OnIncomingKeyChangeRequest(req, device)
}, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomId}/context/{eventId}",
httputil.MakeAuthAPI("context", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
httputil.MakeAuthAPI("context", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -108,7 +109,7 @@ func Setup(
).Methods(http.MethodGet, http.MethodOptions)
v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}",
httputil.MakeAuthAPI("relations", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
httputil.MakeAuthAPI("relations", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -122,7 +123,7 @@ func Setup(
).Methods(http.MethodGet, http.MethodOptions)
v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}",
httputil.MakeAuthAPI("relation_type", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
httputil.MakeAuthAPI("relation_type", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -136,7 +137,7 @@ func Setup(
).Methods(http.MethodGet, http.MethodOptions)
v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}",
httputil.MakeAuthAPI("relation_type_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
httputil.MakeAuthAPI("relation_type_event", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -150,7 +151,7 @@ func Setup(
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/search",
httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
httputil.MakeAuthAPI("search", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if !cfg.Fulltext.Enabled {
return util.JSONResponse{
Code: http.StatusNotImplemented,
@ -173,7 +174,7 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/members",
httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
httputil.MakeAuthAPI("rooms_members", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)

View file

@ -42,6 +42,7 @@ func AddPublicRoutes(
userAPI userapi.SyncUserAPI,
rsAPI api.SyncRoomserverAPI,
caches caching.LazyLoadCache,
userVerifier httputil.UserVerifier,
enableMetrics bool,
) {
js, natsClient := natsInstance.Prepare(processContext, &dendriteCfg.Global.JetStream)
@ -149,5 +150,6 @@ func AddPublicRoutes(
routers.Client, requestPool, syncDB, userAPI,
rsAPI, &dendriteCfg.SyncAPI, caches, fts,
rateLimits,
userVerifier,
)
}

View file

@ -18,6 +18,7 @@ import (
"github.com/gorilla/mux"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
"github.com/nats-io/nats.go"
"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
@ -119,6 +120,20 @@ func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.Pe
return nil
}
type mockUserVerifier struct {
accessTokenToDeviceAndResponse map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}
}
func (u *mockUserVerifier) VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse) {
if pair, ok := u.accessTokenToDeviceAndResponse[req.URL.Query().Get("access_token")]; ok {
return pair.Device, pair.Response
}
return nil, nil
}
func TestSyncAPIAccessTokens(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
testSyncAccessTokens(t, dbType)
@ -146,12 +161,16 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream)
msgs := toNATSMsgs(t, cfg, room.Events()...)
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics)
uv := &mockUserVerifier{}
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, uv, caching.DisableMetrics)
testrig.MustPublishMsgs(t, jsctx, msgs...)
testCases := []struct {
name string
req *http.Request
device *userapi.Device
response *util.JSONResponse
wantCode int
wantJoinedRooms []string
}{
@ -160,6 +179,11 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
req: test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
"timeout": "0",
})),
device: nil,
response: &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: spec.UnknownToken("Unknown token"),
},
wantCode: 401,
},
{
@ -168,6 +192,11 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
"access_token": "foo",
"timeout": "0",
})),
device: nil,
response: &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: spec.UnknownToken("Unknown token"),
},
wantCode: 401,
},
{
@ -176,11 +205,25 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
"access_token": alice.AccessToken,
"timeout": "0",
})),
device: &alice,
response: nil,
wantCode: 200,
wantJoinedRooms: []string{room.ID},
},
}
uv.accessTokenToDeviceAndResponse = make(map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}, len(testCases))
for _, tc := range testCases {
uv.accessTokenToDeviceAndResponse[tc.req.URL.Query().Get("access_token")] = struct {
Device *userapi.Device
Response *util.JSONResponse
}{Device: tc.device, Response: tc.response}
}
syncUntil(t, routers, alice.AccessToken, false, func(syncBody string) bool {
// wait for the last sent eventID to come down sync
path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID())
@ -241,12 +284,20 @@ func testSyncEventFormatPowerLevels(t *testing.T, dbType test.DBType) {
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
natsInstance := jetstream.NATSInstance{}
uv := mockUserVerifier{
accessTokenToDeviceAndResponse: map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}{
alice.AccessToken: {Device: &alice, Response: nil},
},
}
defer close()
jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream)
msgs := toNATSMsgs(t, cfg, room.Events()...)
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, &uv, caching.DisableMetrics)
testrig.MustPublishMsgs(t, jsctx, msgs...)
testCases := []struct {
@ -399,7 +450,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
// m.room.history_visibility
msgs := toNATSMsgs(t, cfg, room.Events()...)
sinceTokens := make([]string, len(msgs))
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, nil, caching.DisableMetrics)
for i, msg := range msgs {
testrig.MustPublishMsgs(t, jsctx, msg)
time.Sleep(100 * time.Millisecond)
@ -487,7 +538,15 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) {
jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream)
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, caching.DisableMetrics)
uv := mockUserVerifier{
accessTokenToDeviceAndResponse: map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}{
alice.AccessToken: {Device: &alice, Response: nil},
},
}
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, &uv, caching.DisableMetrics)
w := httptest.NewRecorder()
routers.Client.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
"access_token": alice.AccessToken,
@ -609,7 +668,16 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
// Use the actual internal roomserver API
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, caching.DisableMetrics)
uv := mockUserVerifier{
accessTokenToDeviceAndResponse: map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}{
aliceDev.AccessToken: {Device: &aliceDev, Response: nil},
bobDev.AccessToken: {Device: &bobDev, Response: nil},
},
}
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, &uv, caching.DisableMetrics)
for _, tc := range testCases {
testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType)
@ -878,8 +946,17 @@ func TestGetMembership(t *testing.T) {
// Use an actual roomserver for this
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
uv := mockUserVerifier{
accessTokenToDeviceAndResponse: map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}{
aliceDev.AccessToken: {Device: &aliceDev, Response: nil},
bobDev.AccessToken: {Device: &bobDev, Response: nil},
},
}
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, &uv, caching.DisableMetrics)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
@ -946,10 +1023,18 @@ func testSendToDevice(t *testing.T, dbType test.DBType) {
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
defer close()
natsInstance := jetstream.NATSInstance{}
uv := mockUserVerifier{
accessTokenToDeviceAndResponse: map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}{
alice.AccessToken: {Device: &alice, Response: nil},
},
}
jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream)
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, &uv, caching.DisableMetrics)
producer := producers.SyncAPIProducer{
TopicSendToDeviceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
@ -1172,7 +1257,16 @@ func testContext(t *testing.T, dbType test.DBType) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, caches, caching.DisableMetrics)
uv := mockUserVerifier{
accessTokenToDeviceAndResponse: map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}{
alice.AccessToken: {Device: &alice, Response: nil},
},
}
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, caches, &uv, caching.DisableMetrics)
room := test.NewRoom(t, user)
@ -1351,9 +1445,17 @@ func TestRemoveEditedEventFromSearchIndex(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
uv := mockUserVerifier{
accessTokenToDeviceAndResponse: map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}{
alice.AccessToken: {Device: &alice, Response: nil},
},
}
room := test.NewRoom(t, user)
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics)
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, &uv, caching.DisableMetrics)
if err := api.SendEvents(processCtx.Context(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err)
@ -1416,6 +1518,7 @@ func searchRequest(t *testing.T, router *mux.Router, accessToken, searchTerm str
assert.NoError(t, err)
return body
}
func syncUntil(t *testing.T,
routers httputil.Routers, accessToken string,
skip bool,

View file

@ -31,7 +31,8 @@ type UserInternalAPI interface {
FederationUserAPI
QuerySearchProfilesAPI // used by p2p demos
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, externalID, authProvider string) (*LocalpartExternalID, error)
PerformLocalpartExternalUserIDCreation(ctx context.Context, localpart, externalID, authProvider string) error
}
// api functions required by the appservice api
@ -47,7 +48,7 @@ type RoomserverUserAPI interface {
// api functions required by the media api
type MediaUserAPI interface {
QueryAcccessTokenAPI
QueryAccessTokenAPI
}
// api functions required by the federation api
@ -64,7 +65,7 @@ type FederationUserAPI interface {
// api functions required by the sync api
type SyncUserAPI interface {
QueryAcccessTokenAPI
QueryAccessTokenAPI
SyncKeyAPI
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error
@ -75,7 +76,7 @@ type SyncUserAPI interface {
// api functions required by the client api
type ClientUserAPI interface {
QueryAcccessTokenAPI
QueryAccessTokenAPI
LoginTokenInternalAPI
UserLoginAPI
ClientKeyAPI
@ -87,6 +88,7 @@ type ClientUserAPI interface {
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error)
QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error)
PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error)
@ -109,6 +111,7 @@ type ClientUserAPI interface {
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error
PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error
PerformBulkSaveThreePIDAssociation(ctx context.Context, req *PerformBulkSaveThreePIDAssociationRequest, res *struct{}) error
}
type KeyBackupAPI interface {
@ -130,7 +133,7 @@ type QuerySearchProfilesAPI interface {
}
// common function for creating authenticated endpoints (used in client/media/sync api)
type QueryAcccessTokenAPI interface {
type QueryAccessTokenAPI interface {
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
}
@ -316,6 +319,9 @@ type PerformAccountCreationRequest struct {
Localpart string // Required: The localpart for this account. Ignored if account type is guest.
ServerName spec.ServerName // optional: if not specified, default server name used instead
DisplayName string // optional: this is populated only by MAS. In the legacy flow it's not used
AvatarURL string // optional: this is populated only by MAS. In the legacy flow it's not used
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
Password string // optional: if missing then this account will be a passwordless account
OnConflict Conflict
@ -375,6 +381,11 @@ type PerformDeviceCreationRequest struct {
// FromRegistration determines if this request comes from registering a new account
// and is in most cases false.
FromRegistration bool
// AccessTokenUniqueConstraintDisabled determines if unique constraint is applicable for the AccessToken.
// It is false if an external auth service is in use (e.g. MAS) and server does not generate its own
// auth tokens. Otherwise, if traditional login is in use, the value is true. Default is false.
AccessTokenUniqueConstraintDisabled bool
}
// PerformDeviceCreationResponse is the response for PerformDeviceCreation
@ -455,6 +466,7 @@ type Account struct {
ServerName spec.ServerName
AppServiceID string
AccountType AccountType
Deactivated bool
// TODO: Associations (e.g. with application services)
}
@ -471,6 +483,14 @@ type OpenIDTokenAttributes struct {
ExpiresAtMS int64
}
// LocalpartExternalID represents a connection between Matrix account and OpenID Connect provider
type LocalpartExternalID struct {
Localpart string
ExternalID string
AuthProvider string
CreatedTS int64
}
// UserInfo is for returning information about the user an OpenID token was issued for
type UserInfo struct {
Sub string // The Matrix user's ID who generated the token
@ -514,6 +534,8 @@ const (
AccountTypeAdmin AccountType = 3
// AccountTypeAppService indicates this is an appservice account
AccountTypeAppService AccountType = 4
// AccountTypeOIDC indicates this is an account belonging to Matrix Authentication Service (MAS)
AccountTypeOIDCService AccountType = 5
)
type QueryPushersRequest struct {
@ -636,6 +658,12 @@ type PerformSaveThreePIDAssociationRequest struct {
Medium string
}
type PerformBulkSaveThreePIDAssociationRequest struct {
ThreePIDs []authtypes.ThreePID
Localpart string
ServerName spec.ServerName
}
type QueryAccountByLocalpartRequest struct {
Localpart string
ServerName spec.ServerName

View file

@ -455,10 +455,11 @@ func (a *UserInternalAPI) processOtherSignatures(
func (a *UserInternalAPI) crossSigningKeysFromDatabase(
ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse,
) {
logger := logrus.WithContext(ctx)
for targetUserID := range req.UserToDevices {
keys, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID)
if err != nil {
logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID)
logger.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID)
continue
}
@ -471,7 +472,7 @@ func (a *UserInternalAPI) crossSigningKeysFromDatabase(
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID)
if err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID)
logger.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID)
continue
}

View file

@ -272,7 +272,7 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque
DeviceIDs: dids,
}, &queryRes)
if err != nil {
util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing")
util.GetLogger(ctx).WithError(err).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing")
}
if res.DeviceKeys[userID] == nil {

View file

@ -7,6 +7,7 @@
package internal
import (
"cmp"
"context"
"database/sql"
"encoding/json"
@ -247,10 +248,17 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil
}
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil {
displayName := cmp.Or(req.DisplayName, req.Localpart)
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, displayName); err != nil {
return fmt.Errorf("a.DB.SetDisplayName: %w", err)
}
if req.AvatarURL != "" {
if _, _, err := a.DB.SetAvatarURL(ctx, req.Localpart, serverName, req.AvatarURL); err != nil {
return fmt.Errorf("a.DB.SetAvatarURL: %w", err)
}
}
postRegisterJoinRooms(a.Config, acc, a.RSAPI)
res.AccountCreated = true
@ -298,6 +306,15 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
"device_id": req.DeviceID,
"display_name": req.DeviceDisplayName,
}).Info("PerformDeviceCreation")
if !req.AccessTokenUniqueConstraintDisabled {
dev, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
if dev.UserID != "" {
return errors.New("unique constraint violation. Access token is not unique" + dev.AccessToken)
}
}
dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
if err != nil {
return err
@ -594,6 +611,14 @@ func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.
return
}
func (a *UserInternalAPI) PerformLocalpartExternalUserIDCreation(ctx context.Context, localpart, externalID, authProvider string) (err error) {
return a.DB.CreateLocalpartExternalID(ctx, localpart, externalID, authProvider)
}
func (a *UserInternalAPI) QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, externalID, authProvider string) (*api.LocalpartExternalID, error) {
return a.DB.GetLocalpartForExternalID(ctx, externalID, authProvider)
}
// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem
// creating a 'device'.
func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) {
@ -970,4 +995,8 @@ func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, re
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium)
}
func (a *UserInternalAPI) PerformBulkSaveThreePIDAssociation(ctx context.Context, req *api.PerformBulkSaveThreePIDAssociationRequest, res *struct{}) error {
return a.DB.BulkSaveThreePIDAssociation(ctx, req.ThreePIDs, req.Localpart, req.ServerName)
}
const pushRulesAccountDataType = "m.push_rules"

View file

@ -119,6 +119,7 @@ type Pusher interface {
type ThreePID interface {
SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName spec.ServerName, medium string) (err error)
BulkSaveThreePIDAssociation(ctx context.Context, threePIDs []authtypes.ThreePID, localpart string, serverName spec.ServerName) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName spec.ServerName, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (threepids []authtypes.ThreePID, err error)
@ -134,6 +135,12 @@ type Notification interface {
DeleteOldNotifications(ctx context.Context) error
}
type LocalpartExternalID interface {
CreateLocalpartExternalID(ctx context.Context, localpart, externalID, authProvider string) error
GetLocalpartForExternalID(ctx context.Context, externalID, authProvider string) (*api.LocalpartExternalID, error)
DeleteLocalpartExternalID(ctx context.Context, externalID, authProvider string) error
}
type UserDatabase interface {
Account
AccountData
@ -147,6 +154,7 @@ type UserDatabase interface {
Statistics
ThreePID
RegistrationTokens
LocalpartExternalID
}
type KeyChangeDatabase interface {

View file

@ -55,7 +55,7 @@ const deactivateAccountSQL = "" +
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
"SELECT localpart, server_name, appservice_id, account_type, is_deactivated FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE"
@ -116,7 +116,7 @@ func (s *accountsStatements) InsertAccount(
localpart string, serverName spec.ServerName,
hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
createdTimeMS := spec.AsTimestamp(time.Now())
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error
@ -135,6 +135,7 @@ func (s *accountsStatements) InsertAccount(
ServerName: serverName,
AppServiceID: appserviceID,
AccountType: accountType,
Deactivated: false,
}, nil
}
@ -167,7 +168,7 @@ func (s *accountsStatements) SelectAccountByLocalpart(
var acc api.Account
stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType, &acc.Deactivated)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db")

View file

@ -51,6 +51,11 @@ func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, erro
if err != nil {
return nil, err
}
m := sqlutil.NewMigrator(db)
err = m.Up(context.Background())
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
{&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},

View file

@ -0,0 +1,30 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
package deltas
import (
"context"
"database/sql"
"fmt"
)
func UpDropPrimaryKeyConstraint(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE userapi_devices DROP CONSTRAINT userapi_devices_pkey;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownDropPrimaryKeyConstraint(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE userapi_devices ADD CONSTRAINT userapi_devices_pkey PRIMARY KEY (access_token);`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -116,10 +116,16 @@ func NewPostgresDevicesTable(db *sql.DB, serverName spec.ServerName) (tables.Dev
return nil, err
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "userapi: add last_seen_ts",
Up: deltas.UpLastSeenTSIP,
})
m.AddMigrations(
sqlutil.Migration{
Version: "userapi: add last_seen_ts",
Up: deltas.UpLastSeenTSIP,
},
sqlutil.Migration{
Version: "userapi: drop primary key constraint",
Up: deltas.UpDropPrimaryKeyConstraint,
},
)
err = m.Up(context.Background())
if err != nil {
return nil, err

View file

@ -0,0 +1,102 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
package postgres
import (
"context"
"database/sql"
"time"
"github.com/element-hq/dendrite/internal/sqlutil"
"github.com/element-hq/dendrite/userapi/api"
"github.com/element-hq/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus"
)
const localpartExternalIDsSchema = `
-- Stores data about connections between accounts and third-party auth providers
CREATE TABLE IF NOT EXISTS userapi_localpart_external_ids (
-- The Matrix user ID for this account
localpart TEXT NOT NULL,
-- The external ID
external_id TEXT NOT NULL,
-- Auth provider ID (see OIDCProvider.IDPID)
auth_provider TEXT NOT NULL,
-- When this connection was created, as a unix timestamp.
created_ts BIGINT NOT NULL,
CONSTRAINT userapi_localpart_external_ids_external_id_auth_provider_unique UNIQUE(external_id, auth_provider),
CONSTRAINT userapi_localpart_external_ids_localpart_external_id_auth_provider_unique UNIQUE(localpart, external_id, auth_provider)
);
-- This index allows efficient lookup of the local user by the external ID
CREATE INDEX IF NOT EXISTS userapi_external_id_auth_provider_idx ON userapi_localpart_external_ids(external_id, auth_provider);
`
const insertUserExternalIDSQL = "" +
"INSERT INTO userapi_localpart_external_ids(localpart, external_id, auth_provider, created_ts) VALUES ($1, $2, $3, $4)"
const selectUserExternalIDSQL = "" +
"SELECT localpart, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2"
const deleteUserExternalIDSQL = "" +
"DELETE FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2"
type localpartExternalIDStatements struct {
db *sql.DB
insertUserExternalIDStmt *sql.Stmt
selectUserExternalIDStmt *sql.Stmt
deleteUserExternalIDStmt *sql.Stmt
}
func NewPostgresLocalpartExternalIDsTable(db *sql.DB) (tables.LocalpartExternalIDsTable, error) {
s := &localpartExternalIDStatements{
db: db,
}
_, err := db.Exec(localpartExternalIDsSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertUserExternalIDStmt, insertUserExternalIDSQL},
{&s.selectUserExternalIDStmt, selectUserExternalIDSQL},
{&s.deleteUserExternalIDStmt, deleteUserExternalIDSQL},
}.Prepare(db)
}
// SelectLocalExternalPartID selects an existing OpenID Connect connection from the database
func (u *localpartExternalIDStatements) SelectLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) {
ret := api.LocalpartExternalID{
ExternalID: externalID,
AuthProvider: authProvider,
}
err := u.selectUserExternalIDStmt.QueryRowContext(ctx, externalID, authProvider).Scan(
&ret.Localpart, &ret.CreatedTS,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
log.WithError(err).Error("Unable to retrieve localpart from the db")
return nil, err
}
return &ret, nil
}
// InsertLocalExternalPartID creates a new record representing an OpenID Connect connection between Matrix and external accounts.
func (u *localpartExternalIDStatements) InsertLocalExternalPartID(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error {
stmt := sqlutil.TxStmt(txn, u.insertUserExternalIDStmt)
_, err := stmt.ExecContext(ctx, localpart, externalID, authProvider, time.Now().Unix())
return err
}
// DeleteLocalExternalPartID deletes the existing OpenID Connect connection. After this method is called, the Matrix account will no longer be associated with the external account.
func (u *localpartExternalIDStatements) DeleteLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error {
stmt := sqlutil.TxStmt(txn, u.deleteUserExternalIDStmt)
_, err := stmt.ExecContext(ctx, externalID, authProvider)
return err
}

View file

@ -97,6 +97,10 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
if err != nil {
return nil, fmt.Errorf("NewPostgresStatsTable: %w", err)
}
localpartExternalIDsTable, err := NewPostgresLocalpartExternalIDsTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteLocalpartExternalIDsTable: %w", err)
}
m = sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
@ -123,6 +127,7 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
Notifications: notificationsTable,
RegistrationTokens: registationTokensTable,
Stats: statsTable,
LocalpartExternalIDs: localpartExternalIDsTable,
ServerName: serverName,
DB: db,
Writer: writer,

View file

@ -49,6 +49,7 @@ type Database struct {
Notifications tables.NotificationTable
Pushers tables.PusherTable
Stats tables.StatsTable
LocalpartExternalIDs tables.LocalpartExternalIDsTable
LoginTokenLifetime time.Duration
ServerName spec.ServerName
BcryptCost int
@ -352,6 +353,41 @@ func (d *Database) SaveThreePIDAssociation(
})
}
// BulkSaveThreePIDAssociation recreates 3PIDs for a user.
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) BulkSaveThreePIDAssociation(ctx context.Context, threePIDs []authtypes.ThreePID, localpart string, serverName spec.ServerName) (err error) {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
oldThreePIDs, err := d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName)
if err != nil {
return err
}
for _, t := range oldThreePIDs {
if err := d.ThreePIDs.DeleteThreePID(ctx, txn, t.Address, t.Medium); err != nil {
return err
}
}
for _, t := range threePIDs {
// if 3PID is associated with another user, return Err3PIDInUse
user, _, err := d.ThreePIDs.SelectLocalpartForThreePID(
ctx, txn, t.Address, t.Medium,
)
if err != nil {
return err
}
if len(user) > 0 && user != localpart {
return Err3PIDInUse
}
if err = d.ThreePIDs.InsertThreePID(ctx, txn, t.Address, t.Medium, localpart, serverName); err != nil {
return err
}
}
return nil
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
@ -870,6 +906,18 @@ func (d *Database) UpsertPusher(
})
}
func (d *Database) CreateLocalpartExternalID(ctx context.Context, localpart, externalID, authProvider string) error {
return d.LocalpartExternalIDs.InsertLocalExternalPartID(ctx, nil, localpart, externalID, authProvider)
}
func (d *Database) GetLocalpartForExternalID(ctx context.Context, externalID, authProvider string) (*api.LocalpartExternalID, error) {
return d.LocalpartExternalIDs.SelectLocalExternalPartID(ctx, nil, externalID, authProvider)
}
func (d *Database) DeleteLocalpartExternalID(ctx context.Context, externalID, authProvider string) error {
return d.LocalpartExternalIDs.DeleteLocalExternalPartID(ctx, nil, externalID, authProvider)
}
// GetPushers returns the pushers matching the given localpart.
func (d *Database) GetPushers(
ctx context.Context, localpart string, serverName spec.ServerName,
@ -1132,8 +1180,8 @@ func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserI
// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.
func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for keyType, keyData := range keyMap {
if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil {
for keyType, key := range keyMap {
if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, key); err != nil {
return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err)
}
}
@ -1141,7 +1189,7 @@ func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID s
})
}
// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice.
// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/device.
func (d *KeyDatabase) StoreCrossSigningSigsForTarget(
ctx context.Context,
originUserID string, originKeyID gomatrixserverlib.KeyID,

View file

@ -54,7 +54,7 @@ const deactivateAccountSQL = "" +
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
"SELECT localpart, server_name, appservice_id, account_type, is_deactivated FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0"
@ -116,7 +116,7 @@ func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName,
hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
createdTimeMS := spec.AsTimestamp(time.Now())
stmt := s.insertAccountStmt
var err error
@ -135,6 +135,7 @@ func (s *accountsStatements) InsertAccount(
ServerName: serverName,
AppServiceID: appserviceID,
AccountType: accountType,
Deactivated: false,
}, nil
}
@ -167,7 +168,7 @@ func (s *accountsStatements) SelectAccountByLocalpart(
var acc api.Account
stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType, &acc.Deactivated)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db")

View file

@ -50,6 +50,11 @@ func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error)
if err != nil {
return nil, err
}
m := sqlutil.NewMigrator(db)
err = m.Up(context.Background())
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
{&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
@ -82,8 +87,7 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
}
func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes,
) error {
ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes) error {
keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
if !ok {
return fmt.Errorf("unknown key purpose %q", keyType)

View file

@ -0,0 +1,70 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
package deltas
import (
"context"
"database/sql"
"fmt"
)
func UpDropPrimaryKeyConstraint(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp;
CREATE TABLE userapi_devices (
access_token TEXT,
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
server_name TEXT NOT NULL,
created_ts BIGINT,
display_name TEXT,
last_seen_ts BIGINT,
ip TEXT,
user_agent TEXT,
UNIQUE (localpart, device_id)
);
INSERT
INTO userapi_devices (
access_token, session_id, device_id, localpart, server_name, created_ts, display_name, last_seen_ts, ip, user_agent
) SELECT
access_token, session_id, device_id, localpart, server_name, created_ts, display_name, created_ts, '', ''
FROM userapi_devices_tmp;
DROP TABLE userapi_devices_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownDropPrimaryKeyConstraint(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp;
CREATE TABLE userapi_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
server_name TEXT NOT NULL,
created_ts BIGINT,
display_name TEXT,
last_seen_ts BIGINT,
ip TEXT,
user_agent TEXT,
UNIQUE (localpart, device_id)
);
INSERT
INTO userapi_devices (
access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent
) SELECT
access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', ''
FROM userapi_devices_tmp;
DROP TABLE userapi_devices_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -102,10 +102,16 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName spec.ServerName) (tables.Devic
return nil, err
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "userapi: add last_seen_ts",
Up: deltas.UpLastSeenTSIP,
})
m.AddMigrations(
sqlutil.Migration{
Version: "userapi: add last_seen_ts",
Up: deltas.UpLastSeenTSIP,
},
sqlutil.Migration{
Version: "userapi: drop primary key constraint",
Up: deltas.UpDropPrimaryKeyConstraint,
},
)
if err = m.Up(context.Background()); err != nil {
return nil, err
}

View file

@ -0,0 +1,102 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
package sqlite3
import (
"context"
"database/sql"
"time"
"github.com/element-hq/dendrite/internal/sqlutil"
"github.com/element-hq/dendrite/userapi/api"
"github.com/element-hq/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus"
)
const localpartExternalIDsSchema = `
-- Stores data about connections between accounts and third-party auth providers
CREATE TABLE IF NOT EXISTS userapi_localpart_external_ids (
-- The Matrix user ID for this account
localpart TEXT NOT NULL,
-- The external ID
external_id TEXT NOT NULL,
-- Auth provider ID (see OIDCProvider.IDPID)
auth_provider TEXT NOT NULL,
-- When this connection was created, as a unix timestamp.
created_ts BIGINT NOT NULL,
UNIQUE(external_id, auth_provider),
UNIQUE(localpart, external_id, auth_provider)
);
-- This index allows efficient lookup of the local user by the external ID
CREATE INDEX IF NOT EXISTS userapi_external_id_auth_provider_idx ON userapi_localpart_external_ids(external_id, auth_provider);
`
const insertLocalpartExternalIDSQL = "" +
"INSERT INTO userapi_localpart_external_ids(localpart, external_id, auth_provider, created_ts) VALUES ($1, $2, $3, $4)"
const selectLocalpartExternalIDSQL = "" +
"SELECT localpart, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2"
const deleteLocalpartExternalIDSQL = "" +
"DELETE FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2"
type localpartExternalIDStatements struct {
db *sql.DB
insertUserExternalIDStmt *sql.Stmt
selectUserExternalIDStmt *sql.Stmt
deleteUserExternalIDStmt *sql.Stmt
}
func NewSQLiteLocalpartExternalIDsTable(db *sql.DB) (tables.LocalpartExternalIDsTable, error) {
s := &localpartExternalIDStatements{
db: db,
}
_, err := db.Exec(localpartExternalIDsSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertUserExternalIDStmt, insertLocalpartExternalIDSQL},
{&s.selectUserExternalIDStmt, selectLocalpartExternalIDSQL},
{&s.deleteUserExternalIDStmt, deleteLocalpartExternalIDSQL},
}.Prepare(db)
}
// SelectLocalExternalPartID selects an existing OpenID Connect connection from the database
func (u *localpartExternalIDStatements) SelectLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) {
ret := api.LocalpartExternalID{
ExternalID: externalID,
AuthProvider: authProvider,
}
err := u.selectUserExternalIDStmt.QueryRowContext(ctx, externalID, authProvider).Scan(
&ret.Localpart, &ret.CreatedTS,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
log.WithError(err).Error("Unable to retrieve localpart from the db")
return nil, err
}
return &ret, nil
}
// InsertLocalExternalPartID creates a new record representing an OpenID Connect connection between Matrix and external accounts.
func (u *localpartExternalIDStatements) InsertLocalExternalPartID(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error {
stmt := sqlutil.TxStmt(txn, u.insertUserExternalIDStmt)
_, err := stmt.ExecContext(ctx, localpart, externalID, authProvider, time.Now().Unix())
return err
}
// DeleteLocalExternalPartID deletes the existing OpenID Connect connection. After this method is called, the Matrix account will no longer be associated with the external account.
func (u *localpartExternalIDStatements) DeleteLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error {
stmt := sqlutil.TxStmt(txn, u.deleteUserExternalIDStmt)
_, err := stmt.ExecContext(ctx, externalID, authProvider)
return err
}

View file

@ -94,6 +94,10 @@ func NewUserDatabase(ctx context.Context, conMan *sqlutil.Connections, dbPropert
if err != nil {
return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err)
}
localpartExternalIDsTable, err := NewSQLiteLocalpartExternalIDsTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteUserExternalIDsTable: %w", err)
}
m = sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
@ -119,6 +123,7 @@ func NewUserDatabase(ctx context.Context, conMan *sqlutil.Connections, dbPropert
Pushers: pusherTable,
Notifications: notificationsTable,
Stats: statsTable,
LocalpartExternalIDs: localpartExternalIDsTable,
ServerName: serverName,
DB: db,
Writer: writer,

View file

@ -14,7 +14,7 @@ import (
"github.com/element-hq/dendrite/internal/sqlutil"
"github.com/element-hq/dendrite/setup/config"
"github.com/element-hq/dendrite/userapi/storage/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
)
func NewUserDatabase(

View file

@ -127,6 +127,12 @@ type StatsTable interface {
UpsertDailyStats(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error
}
type LocalpartExternalIDsTable interface {
SelectLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error)
InsertLocalExternalPartID(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error
DeleteLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error
}
type NotificationFilter uint32
const (

View file

@ -457,34 +457,42 @@ func TestDevices(t *testing.T) {
}{
{
name: "not a local user",
inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", ServerName: "notlocal"},
inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", ServerName: "notlocal", AccessTokenUniqueConstraintDisabled: true},
wantErr: true,
},
{
name: "implicit local user",
inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, DeviceDisplayName: &displayName},
inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, DeviceDisplayName: &displayName, AccessTokenUniqueConstraintDisabled: true},
},
{
name: "explicit local user",
inputData: &api.PerformDeviceCreationRequest{Localpart: "test2", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
},
{
name: "dupe token - ok",
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true},
},
{
name: "dupe token - not ok",
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true},
wantErr: true,
inputData: &api.PerformDeviceCreationRequest{Localpart: "test2", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, AccessTokenUniqueConstraintDisabled: true},
},
{
name: "test3 second device", // used to test deletion later
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, AccessTokenUniqueConstraintDisabled: true},
},
{
name: "test3 third device", // used to test deletion later
wantNewDevID: true,
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, AccessTokenUniqueConstraintDisabled: true},
},
{
name: "dupe token - ok (unique constraint enabled)",
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true, AccessTokenUniqueConstraintDisabled: false},
},
{
name: "dupe token - not ok (unique constraint enabled)",
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true, AccessTokenUniqueConstraintDisabled: false},
wantErr: true,
},
{
name: "dupe token - ok (unique constraint disabled)",
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true, AccessTokenUniqueConstraintDisabled: true},
},
{
name: "dupe token - not ok (unique constraint disabled)",
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true, AccessTokenUniqueConstraintDisabled: true},
},
}
@ -629,7 +637,13 @@ func TestDeviceIDReuse(t *testing.T) {
res := api.PerformDeviceCreationResponse{}
// create a first device
deviceID := util.RandomString(8)
req := api.PerformDeviceCreationRequest{Localpart: "alice", ServerName: "test", DeviceID: &deviceID, NoDeviceListUpdate: true}
req := api.PerformDeviceCreationRequest{
Localpart: "alice",
ServerName: "test",
DeviceID: &deviceID,
NoDeviceListUpdate: true,
AccessTokenUniqueConstraintDisabled: true,
}
err := intAPI.PerformDeviceCreation(ctx, &req, &res)
if err != nil {
t.Fatal(err)