mirror of
https://github.com/element-hq/dendrite.git
synced 2025-03-14 14:15:35 +00:00
Merge 20b3917084
into c15dee80f2
This commit is contained in:
commit
a0788425eb
64 changed files with 3604 additions and 530 deletions
|
@ -28,7 +28,6 @@ run:
|
|||
# the dependency descriptions in go.mod.
|
||||
#modules-download-mode: (release|readonly|vendor)
|
||||
|
||||
|
||||
# output configuration options
|
||||
output:
|
||||
# colored-line-number|line-number|json|tab|checkstyle|code-climate, default is "colored-line-number"
|
||||
|
@ -41,7 +40,6 @@ output:
|
|||
# print linter name in the end of issue text, default is true
|
||||
print-linter-name: true
|
||||
|
||||
|
||||
# all available settings of specific linters
|
||||
linters-settings:
|
||||
errcheck:
|
||||
|
@ -72,22 +70,12 @@ linters-settings:
|
|||
- (github.com/golangci/golangci-lint/pkg/logutils.Log).Warnf
|
||||
- (github.com/golangci/golangci-lint/pkg/logutils.Log).Errorf
|
||||
- (github.com/golangci/golangci-lint/pkg/logutils.Log).Fatalf
|
||||
golint:
|
||||
# minimal confidence for issues, default is 0.8
|
||||
min-confidence: 0.8
|
||||
gofmt:
|
||||
# simplify code: gofmt with `-s` option, true by default
|
||||
simplify: true
|
||||
goimports:
|
||||
# put imports beginning with prefix after 3rd-party packages;
|
||||
# it's a comma-separated list of prefixes
|
||||
#local-prefixes: github.com/org/project
|
||||
gocyclo:
|
||||
# minimal code complexity to report, 30 by default (but we recommend 10-20)
|
||||
min-complexity: 25
|
||||
maligned:
|
||||
# print struct with more effective memory layout or not, false by default
|
||||
suggest-new: true
|
||||
dupl:
|
||||
# tokens count to trigger issue, 150 by default
|
||||
threshold: 100
|
||||
|
@ -96,30 +84,17 @@ linters-settings:
|
|||
min-len: 3
|
||||
# minimal occurrences count to trigger, 3 by default
|
||||
min-occurrences: 3
|
||||
depguard:
|
||||
list-type: blacklist
|
||||
include-go-root: false
|
||||
packages:
|
||||
# - github.com/davecgh/go-spew/spew
|
||||
misspell:
|
||||
# Correct spellings using locale preferences for US or UK.
|
||||
# Default is to use a neutral variety of English.
|
||||
# Setting locale to US will correct the British spelling of 'colour' to 'color'.
|
||||
locale: UK
|
||||
ignore-words:
|
||||
# - someword
|
||||
lll:
|
||||
# max line length, lines longer will be reported. Default is 120.
|
||||
# '\t' is counted as 1 character by default, and can be changed with the tab-width option
|
||||
line-length: 96
|
||||
# tab width in spaces. Default to 1.
|
||||
tab-width: 1
|
||||
unused:
|
||||
# treat code as a program (not a library) and report unused exported identifiers; default is false.
|
||||
# XXX: if you enable this setting, unused will report a lot of false-positives in text editors:
|
||||
# if it's called for subdir of a project it can't find funcs usages. All text editor integrations
|
||||
# with golangci-lint call it on a directory with the changed file.
|
||||
check-exported: false
|
||||
unparam:
|
||||
# Inspect exported functions, default is false. Set to true if no external program/library imports your code.
|
||||
# XXX: if you enable this setting, unparam will report a lot of false-positives in text editors:
|
||||
|
@ -167,7 +142,6 @@ linters:
|
|||
- gosimple
|
||||
- govet
|
||||
- ineffassign
|
||||
- megacheck
|
||||
- misspell # Check code comments, whereas misspell in CI checks *.md files
|
||||
- nakedret
|
||||
- staticcheck
|
||||
|
@ -182,19 +156,14 @@ linters:
|
|||
- gochecknoinits
|
||||
- gocritic
|
||||
- gofmt
|
||||
- golint
|
||||
- gosec # Should turn back on soon
|
||||
- interfacer
|
||||
- lll
|
||||
- maligned
|
||||
- prealloc # Should turn back on soon
|
||||
- scopelint
|
||||
- stylecheck
|
||||
- typecheck # Should turn back on soon
|
||||
- unconvert # Should turn back on soon
|
||||
- goconst # Slightly annoying, as it reports "issues" in SQL statements
|
||||
disable-all: false
|
||||
presets:
|
||||
fast: false
|
||||
|
||||
|
||||
|
@ -217,13 +186,6 @@ issues:
|
|||
- bin
|
||||
- docs
|
||||
|
||||
# List of regexps of issue texts to exclude, empty list by default.
|
||||
# But independently from this option we use default exclude patterns,
|
||||
# it can be disabled by `exclude-use-default: false`. To list all
|
||||
# excluded by default patterns execute `golangci-lint run --help`
|
||||
exclude:
|
||||
# - abcdef
|
||||
|
||||
# Excluding configuration per-path, per-linter, per-text and per-source
|
||||
exclude-rules:
|
||||
# Exclude some linters from running on tests files.
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# base installs required dependencies and runs go mod download to cache dependencies
|
||||
#
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22-alpine AS base
|
||||
FROM --platform=${BUILDPLATFORM} docker.io/golang:1.23-alpine AS base
|
||||
RUN apk --update --no-cache add bash build-base curl git
|
||||
|
||||
#
|
||||
|
|
|
@ -53,6 +53,7 @@ func NewInternalAPI(
|
|||
if err := generateAppServiceAccount(userAPI, appservice, cfg.Global.ServerName); err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"appservice": appservice.ID,
|
||||
"as_token": appservice.ASToken,
|
||||
}).WithError(err).Panicf("failed to generate bot account for appservice")
|
||||
}
|
||||
}
|
||||
|
@ -92,12 +93,13 @@ func generateAppServiceAccount(
|
|||
}
|
||||
var devRes userapi.PerformDeviceCreationResponse
|
||||
err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{
|
||||
Localpart: as.SenderLocalpart,
|
||||
ServerName: serverName,
|
||||
AccessToken: as.ASToken,
|
||||
DeviceID: &as.SenderLocalpart,
|
||||
DeviceDisplayName: &as.SenderLocalpart,
|
||||
NoDeviceListUpdate: true,
|
||||
Localpart: as.SenderLocalpart,
|
||||
ServerName: serverName,
|
||||
AccessToken: as.ASToken,
|
||||
DeviceID: &as.SenderLocalpart,
|
||||
DeviceDisplayName: &as.SenderLocalpart,
|
||||
NoDeviceListUpdate: true,
|
||||
AccessTokenUniqueConstraintDisabled: false,
|
||||
}, &devRes)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/element-hq/dendrite/clientapi"
|
||||
"github.com/element-hq/dendrite/clientapi/auth"
|
||||
"github.com/element-hq/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/element-hq/dendrite/federationapi/statistics"
|
||||
"github.com/element-hq/dendrite/internal/httputil"
|
||||
|
@ -138,7 +139,7 @@ func TestAppserviceInternalAPI(t *testing.T) {
|
|||
as := &config.ApplicationService{
|
||||
ID: "someID",
|
||||
URL: srv.URL,
|
||||
ASToken: "",
|
||||
ASToken: util.RandomString(12),
|
||||
HSToken: "",
|
||||
SenderLocalpart: "senderLocalPart",
|
||||
NamespaceMap: map[string][]config.ApplicationServiceNamespace{
|
||||
|
@ -232,7 +233,7 @@ func TestAppserviceInternalAPI_UnixSocket_Simple(t *testing.T) {
|
|||
as := &config.ApplicationService{
|
||||
ID: "someID",
|
||||
URL: fmt.Sprintf("unix://%s", socket),
|
||||
ASToken: "",
|
||||
ASToken: util.RandomString(8),
|
||||
HSToken: "",
|
||||
SenderLocalpart: "senderLocalPart",
|
||||
NamespaceMap: map[string][]config.ApplicationServiceNamespace{
|
||||
|
@ -376,7 +377,7 @@ func TestRoomserverConsumerOneInvite(t *testing.T) {
|
|||
as := &config.ApplicationService{
|
||||
ID: "someID",
|
||||
URL: srv.URL,
|
||||
ASToken: "",
|
||||
ASToken: util.RandomString(8),
|
||||
HSToken: "",
|
||||
SenderLocalpart: "senderLocalPart",
|
||||
NamespaceMap: map[string][]config.ApplicationServiceNamespace{
|
||||
|
@ -446,7 +447,8 @@ func TestOutputAppserviceEvent(t *testing.T) {
|
|||
}
|
||||
|
||||
usrAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
clientapi.AddPublicRoutes(processCtx, routers, cfg, natsInstance, nil, rsAPI, nil, nil, nil, usrAPI, nil, nil, caching.DisableMetrics)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: usrAPI}
|
||||
clientapi.AddPublicRoutes(processCtx, routers, cfg, natsInstance, nil, rsAPI, nil, nil, nil, usrAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
createAccessTokens(t, accessTokens, usrAPI, processCtx.Context(), routers)
|
||||
|
||||
room := test.NewRoom(t, alice)
|
||||
|
@ -508,7 +510,7 @@ func TestOutputAppserviceEvent(t *testing.T) {
|
|||
as := &config.ApplicationService{
|
||||
ID: "someID",
|
||||
URL: srv.URL,
|
||||
ASToken: "",
|
||||
ASToken: util.RandomString(8),
|
||||
HSToken: "",
|
||||
SenderLocalpart: "senderLocalPart",
|
||||
NamespaceMap: map[string][]config.ApplicationServiceNamespace{
|
||||
|
@ -537,7 +539,7 @@ func TestOutputAppserviceEvent(t *testing.T) {
|
|||
}
|
||||
|
||||
// Start the syncAPI to have `/joined_members` available
|
||||
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, usrAPI, rsAPI, caches, caching.DisableMetrics)
|
||||
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, usrAPI, rsAPI, caches, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
// start the consumer
|
||||
appservice.NewInternalAPI(processCtx, cfg, natsInstance, usrAPI, rsAPI)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#syntax=docker/dockerfile:1.2
|
||||
|
||||
FROM golang:1.22-bookworm as build
|
||||
FROM golang:1.23-bookworm as build
|
||||
RUN apt-get update && apt-get install -y sqlite3
|
||||
WORKDIR /build
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
#
|
||||
# Use these mounts to make use of this dockerfile:
|
||||
# COMPLEMENT_HOST_MOUNTS='/your/local/dendrite:/dendrite:ro;/your/go/path:/go:ro'
|
||||
FROM golang:1.22-bookworm
|
||||
FROM golang:1.23-bookworm
|
||||
RUN apt-get update && apt-get install -y sqlite3
|
||||
|
||||
ENV SERVER_NAME=localhost
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#syntax=docker/dockerfile:1.2
|
||||
|
||||
FROM golang:1.22-bookworm as build
|
||||
FROM golang:1.23-bookworm as build
|
||||
RUN apt-get update && apt-get install -y postgresql
|
||||
WORKDIR /build
|
||||
|
||||
|
|
|
@ -1,3 +1,8 @@
|
|||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
package clientapi
|
||||
|
||||
import (
|
||||
|
@ -11,6 +16,8 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/element-hq/dendrite/userapi/types"
|
||||
|
||||
"github.com/element-hq/dendrite/federationapi"
|
||||
"github.com/element-hq/dendrite/internal/caching"
|
||||
"github.com/element-hq/dendrite/internal/httputil"
|
||||
|
@ -27,6 +34,7 @@ import (
|
|||
"github.com/tidwall/gjson"
|
||||
|
||||
capi "github.com/element-hq/dendrite/clientapi/api"
|
||||
"github.com/element-hq/dendrite/clientapi/auth"
|
||||
"github.com/element-hq/dendrite/test"
|
||||
"github.com/element-hq/dendrite/test/testrig"
|
||||
"github.com/element-hq/dendrite/userapi"
|
||||
|
@ -48,7 +56,8 @@ func TestAdminCreateToken(t *testing.T) {
|
|||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
aliceAdmin: {},
|
||||
bob: {},
|
||||
|
@ -199,7 +208,8 @@ func TestAdminListRegistrationTokens(t *testing.T) {
|
|||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
aliceAdmin: {},
|
||||
bob: {},
|
||||
|
@ -317,7 +327,8 @@ func TestAdminGetRegistrationToken(t *testing.T) {
|
|||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
aliceAdmin: {},
|
||||
bob: {},
|
||||
|
@ -418,7 +429,8 @@ func TestAdminDeleteRegistrationToken(t *testing.T) {
|
|||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
aliceAdmin: {},
|
||||
bob: {},
|
||||
|
@ -512,7 +524,8 @@ func TestAdminUpdateRegistrationToken(t *testing.T) {
|
|||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
aliceAdmin: {},
|
||||
bob: {},
|
||||
|
@ -697,7 +710,8 @@ func TestAdminResetPassword(t *testing.T) {
|
|||
// Needed for changing the password/login
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
// Create the users in the userapi and login
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
|
@ -794,7 +808,8 @@ func TestPurgeRoom(t *testing.T) {
|
|||
rsAPI.SetFederationAPI(fsAPI, nil)
|
||||
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
// Create the room
|
||||
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
||||
|
@ -802,7 +817,7 @@ func TestPurgeRoom(t *testing.T) {
|
|||
}
|
||||
|
||||
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
// Create the users in the userapi and login
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
|
@ -872,8 +887,10 @@ func TestAdminEvacuateRoom(t *testing.T) {
|
|||
t.Fatalf("failed to send events: %v", err)
|
||||
}
|
||||
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
|
||||
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
// Create the users in the userapi and login
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
|
@ -976,8 +993,10 @@ func TestAdminEvacuateUser(t *testing.T) {
|
|||
t.Fatalf("failed to send events: %v", err)
|
||||
}
|
||||
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
|
||||
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
// Create the users in the userapi and login
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
|
@ -1059,8 +1078,10 @@ func TestAdminMarkAsStale(t *testing.T) {
|
|||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
|
||||
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
// Create the users in the userapi and login
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
|
@ -1147,8 +1168,10 @@ func TestAdminQueryEventReports(t *testing.T) {
|
|||
t.Fatalf("failed to send events: %v", err)
|
||||
}
|
||||
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
|
||||
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
alice: {},
|
||||
|
@ -1376,8 +1399,10 @@ func TestEventReportsGetDelete(t *testing.T) {
|
|||
t.Fatalf("failed to send events: %v", err)
|
||||
}
|
||||
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
|
||||
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
alice: {},
|
||||
|
@ -1473,3 +1498,840 @@ func TestEventReportsGetDelete(t *testing.T) {
|
|||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminCheckUsernameAvailable(t *testing.T) {
|
||||
alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
adminToken := "superSecretAdminToken"
|
||||
ctx := context.Background()
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||
defer close()
|
||||
natsInstance := jetstream.NATSInstance{}
|
||||
// add a vhost
|
||||
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
|
||||
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
|
||||
})
|
||||
// There's no need to add a full config for msc3861 as we need only an admin token
|
||||
cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"}
|
||||
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken}
|
||||
|
||||
routers := httputil.NewRouters()
|
||||
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics)
|
||||
userRes := &uapi.PerformAccountCreationResponse{}
|
||||
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
||||
AccountType: alice.AccountType,
|
||||
Localpart: alice.Localpart,
|
||||
ServerName: cfg.Global.ServerName,
|
||||
Password: "",
|
||||
}, userRes); err != nil {
|
||||
t.Errorf("failed to create account: %s", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
userID string
|
||||
wantOK bool
|
||||
isAvailable bool
|
||||
}{
|
||||
{name: "Missing auth", accessToken: "", wantOK: false, userID: alice.Localpart, isAvailable: false},
|
||||
{name: "Alice - user exists", accessToken: adminToken, wantOK: true, userID: alice.Localpart, isAvailable: false},
|
||||
{name: "Bob - user does not exist", accessToken: adminToken, wantOK: true, userID: "bob", isAvailable: true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // ensure we don't accidentally only test the last test case
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v1/username_available?username="+tc.userID)
|
||||
if tc.accessToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+tc.accessToken)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if tc.wantOK && rec.Code != http.StatusOK || !tc.wantOK && rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
// Nothing more to check, test is done.
|
||||
if !tc.wantOK {
|
||||
return
|
||||
}
|
||||
|
||||
b := make(map[string]bool, 1)
|
||||
_ = json.NewDecoder(rec.Body).Decode(&b)
|
||||
available, ok := b["available"]
|
||||
if !ok {
|
||||
t.Fatal("'available' not found in body")
|
||||
}
|
||||
if available != tc.isAvailable {
|
||||
t.Fatalf("expected 'available' to be %t, got %t instead", tc.isAvailable, available)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminUserDeviceRetrieveCreate(t *testing.T) {
|
||||
alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
adminToken := "superSecretAdminToken"
|
||||
ctx := context.Background()
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||
defer close()
|
||||
natsInstance := jetstream.NATSInstance{}
|
||||
// add a vhost
|
||||
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
|
||||
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
|
||||
})
|
||||
// There's no need to add a full config for msc3861 as we need only an admin token
|
||||
cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"}
|
||||
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken}
|
||||
|
||||
routers := httputil.NewRouters()
|
||||
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics)
|
||||
|
||||
for _, u := range []*test.User{alice, bob} {
|
||||
userRes := &uapi.PerformAccountCreationResponse{}
|
||||
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
||||
AccountType: u.AccountType,
|
||||
Localpart: u.Localpart,
|
||||
ServerName: cfg.Global.ServerName,
|
||||
Password: "",
|
||||
}, userRes); err != nil {
|
||||
t.Errorf("failed to create account: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Missing auth token", func(t *testing.T) {
|
||||
req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+alice.ID+"/devices")
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String())
|
||||
}
|
||||
var b spec.MatrixError
|
||||
_ = json.NewDecoder(rec.Body).Decode(&b)
|
||||
if b.ErrCode != spec.ErrorMissingToken {
|
||||
t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Retrieve device", func(t *testing.T) {
|
||||
var deviceRes uapi.PerformDeviceCreationResponse
|
||||
if err := userAPI.PerformDeviceCreation(ctx, &uapi.PerformDeviceCreationRequest{
|
||||
Localpart: alice.Localpart,
|
||||
ServerName: cfg.Global.ServerName,
|
||||
AccessTokenUniqueConstraintDisabled: true,
|
||||
}, &deviceRes); err != nil {
|
||||
t.Errorf("failed to create account: %s", err)
|
||||
}
|
||||
req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+alice.ID+"/devices")
|
||||
req.Header.Set("Authorization", "Bearer "+adminToken)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
var body struct {
|
||||
Total int `json:"total"`
|
||||
Devices []struct {
|
||||
DeviceID string `json:"device_id"`
|
||||
} `json:"devices"`
|
||||
}
|
||||
_ = json.NewDecoder(rec.Body).Decode(&body)
|
||||
if body.Total != 1 {
|
||||
t.Errorf("expected 1 device, got %d", body.Total)
|
||||
}
|
||||
if len(body.Devices) != 1 {
|
||||
t.Errorf("expected 1 device, got %d", len(body.Devices))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Create device", func(t *testing.T) {
|
||||
reqBody := struct {
|
||||
DeviceID string `json:"device_id"`
|
||||
}{DeviceID: "devBob"}
|
||||
req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v2/users/"+bob.ID+"/devices", test.WithJSONBody(t, reqBody))
|
||||
req.Header.Set("Authorization", "Bearer "+adminToken)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if rec.Code != http.StatusCreated {
|
||||
t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusCreated, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
var res uapi.QueryDevicesResponse
|
||||
_ = userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{UserID: bob.ID}, &res)
|
||||
if len(res.Devices) != 1 {
|
||||
t.Errorf("expected 1 device, got %d", len(res.Devices))
|
||||
}
|
||||
if res.Devices[0].ID != "devBob" {
|
||||
t.Errorf("expected device to be devBob, got %s", res.Devices[0].ID)
|
||||
}
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminUserDeviceDelete(t *testing.T) {
|
||||
alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
adminToken := "superSecretAdminToken"
|
||||
ctx := context.Background()
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||
defer close()
|
||||
natsInstance := jetstream.NATSInstance{}
|
||||
// add a vhost
|
||||
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
|
||||
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
|
||||
})
|
||||
// There's no need to add a full config for msc3861 as we need only an admin token
|
||||
cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"}
|
||||
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken}
|
||||
|
||||
routers := httputil.NewRouters()
|
||||
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics)
|
||||
|
||||
for _, u := range []*test.User{alice} {
|
||||
userRes := &uapi.PerformAccountCreationResponse{}
|
||||
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
||||
AccountType: u.AccountType,
|
||||
Localpart: u.Localpart,
|
||||
ServerName: cfg.Global.ServerName,
|
||||
Password: "",
|
||||
}, userRes); err != nil {
|
||||
t.Errorf("failed to create account: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Missing auth token", func(t *testing.T) {
|
||||
req := test.NewRequest(t, http.MethodDelete, "/_synapse/admin/v2/users/"+alice.ID+"/devices/anything")
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String())
|
||||
}
|
||||
var b spec.MatrixError
|
||||
_ = json.NewDecoder(rec.Body).Decode(&b)
|
||||
if b.ErrCode != spec.ErrorMissingToken {
|
||||
t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Delete existing device", func(t *testing.T) {
|
||||
var deviceRes uapi.PerformDeviceCreationResponse
|
||||
if err := userAPI.PerformDeviceCreation(ctx, &uapi.PerformDeviceCreationRequest{
|
||||
Localpart: alice.Localpart,
|
||||
ServerName: cfg.Global.ServerName,
|
||||
AccessTokenUniqueConstraintDisabled: true,
|
||||
}, &deviceRes); err != nil {
|
||||
t.Errorf("failed to create account: %s", err)
|
||||
}
|
||||
req := test.NewRequest(t, http.MethodDelete, "/_synapse/admin/v2/users/"+alice.ID+"/devices/"+deviceRes.Device.ID)
|
||||
req.Header.Set("Authorization", "Bearer "+adminToken)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
var rs uapi.QueryDevicesResponse
|
||||
_ = userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{UserID: alice.ID}, &rs)
|
||||
if len(rs.Devices) > 0 {
|
||||
t.Errorf("expected 0 devices, got %d", len(rs.Devices))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Delete non-existing user's devices", func(t *testing.T) {
|
||||
req := test.NewRequest(t, http.MethodDelete, "/_synapse/admin/v2/users/"+bob.ID+"/devices/anything")
|
||||
req.Header.Set("Authorization", "Bearer "+adminToken)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminUserDevicesDelete(t *testing.T) {
|
||||
alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
adminToken := "superSecretAdminToken"
|
||||
ctx := context.Background()
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||
defer close()
|
||||
natsInstance := jetstream.NATSInstance{}
|
||||
// add a vhost
|
||||
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
|
||||
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
|
||||
})
|
||||
// There's no need to add a full config for msc3861 as we need only an admin token
|
||||
cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"}
|
||||
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken}
|
||||
|
||||
routers := httputil.NewRouters()
|
||||
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics)
|
||||
|
||||
for _, u := range []*test.User{alice} {
|
||||
userRes := &uapi.PerformAccountCreationResponse{}
|
||||
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
||||
AccountType: u.AccountType,
|
||||
Localpart: u.Localpart,
|
||||
ServerName: cfg.Global.ServerName,
|
||||
Password: "",
|
||||
}, userRes); err != nil {
|
||||
t.Errorf("failed to create account: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
type payload struct {
|
||||
Devices []string `json:"devices"`
|
||||
}
|
||||
|
||||
t.Run("Missing auth token", func(t *testing.T) {
|
||||
req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v2/users/"+alice.ID+"/delete_devices")
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String())
|
||||
}
|
||||
var b spec.MatrixError
|
||||
_ = json.NewDecoder(rec.Body).Decode(&b)
|
||||
if b.ErrCode != spec.ErrorMissingToken {
|
||||
t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Delete existing user's devices", func(t *testing.T) {
|
||||
var deviceRes uapi.PerformDeviceCreationResponse
|
||||
if err := userAPI.PerformDeviceCreation(ctx, &uapi.PerformDeviceCreationRequest{
|
||||
Localpart: alice.Localpart,
|
||||
ServerName: cfg.Global.ServerName,
|
||||
AccessTokenUniqueConstraintDisabled: true,
|
||||
}, &deviceRes); err != nil {
|
||||
t.Errorf("failed to create account: %s", err)
|
||||
}
|
||||
req := test.NewRequest(
|
||||
t,
|
||||
http.MethodPost,
|
||||
"/_synapse/admin/v2/users/"+alice.ID+"/delete_devices",
|
||||
test.WithJSONBody(t, payload{Devices: []string{deviceRes.Device.ID}}),
|
||||
)
|
||||
req.Header.Set("Authorization", "Bearer "+adminToken)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
var rs uapi.QueryDevicesResponse
|
||||
_ = userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{UserID: alice.ID}, &rs)
|
||||
if len(rs.Devices) > 0 {
|
||||
t.Errorf("expected 0 devices, got %d", len(rs.Devices))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Delete non-existing user's devices", func(t *testing.T) {
|
||||
req := test.NewRequest(
|
||||
t,
|
||||
http.MethodPost,
|
||||
"/_synapse/admin/v2/users/"+bob.ID+"/delete_devices",
|
||||
test.WithJSONBody(t, payload{Devices: []string{"anyDevID"}}),
|
||||
)
|
||||
req.Header.Set("Authorization", "Bearer "+adminToken)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminDeactivateAccount(t *testing.T) {
|
||||
alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
adminToken := "superSecretAdminToken"
|
||||
ctx := context.Background()
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||
defer close()
|
||||
natsInstance := jetstream.NATSInstance{}
|
||||
// add a vhost
|
||||
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
|
||||
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
|
||||
})
|
||||
// There's no need to add a full config for msc3861 as we need only an admin token
|
||||
cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"}
|
||||
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken}
|
||||
|
||||
routers := httputil.NewRouters()
|
||||
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics)
|
||||
|
||||
for _, u := range []*test.User{alice} {
|
||||
userRes := &uapi.PerformAccountCreationResponse{}
|
||||
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
||||
AccountType: u.AccountType,
|
||||
Localpart: u.Localpart,
|
||||
ServerName: cfg.Global.ServerName,
|
||||
Password: "",
|
||||
}, userRes); err != nil {
|
||||
t.Errorf("failed to create account: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Missing auth token", func(t *testing.T) {
|
||||
req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/deactivate/"+alice.ID)
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String())
|
||||
}
|
||||
var b spec.MatrixError
|
||||
_ = json.NewDecoder(rec.Body).Decode(&b)
|
||||
if b.ErrCode != spec.ErrorMissingToken {
|
||||
t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Deactivate existing account", func(t *testing.T) {
|
||||
req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/deactivate/"+alice.ID)
|
||||
req.Header.Set("Authorization", "Bearer "+adminToken)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
var rs uapi.QueryAccountByLocalpartResponse
|
||||
_ = userAPI.QueryAccountByLocalpart(ctx, &uapi.QueryAccountByLocalpartRequest{Localpart: alice.Localpart, ServerName: cfg.Global.ServerName}, &rs)
|
||||
if !rs.Account.Deactivated {
|
||||
t.Fatalf("expected account is deactivated")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Deactivate non-existing account", func(t *testing.T) {
|
||||
req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/deactivate/"+bob.ID)
|
||||
req.Header.Set("Authorization", "Bearer "+adminToken)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected HTTP status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminAllowCrossSigningReplacementWithoutUIA(t *testing.T) {
|
||||
alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
adminToken := "superSecretAdminToken"
|
||||
ctx := context.Background()
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||
defer close()
|
||||
natsInstance := jetstream.NATSInstance{}
|
||||
// add a vhost
|
||||
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
|
||||
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
|
||||
})
|
||||
// There's no need to add a full config for msc3861 as we need only an admin token
|
||||
cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"}
|
||||
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken}
|
||||
|
||||
routers := httputil.NewRouters()
|
||||
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics)
|
||||
|
||||
t.Run("Missing auth token", func(t *testing.T) {
|
||||
req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/users/"+alice.ID+"/_allow_cross_signing_replacement_without_uia")
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String())
|
||||
}
|
||||
var b spec.MatrixError
|
||||
_ = json.NewDecoder(rec.Body).Decode(&b)
|
||||
if b.ErrCode != spec.ErrorMissingToken {
|
||||
t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode)
|
||||
}
|
||||
})
|
||||
|
||||
for _, u := range []*test.User{alice} {
|
||||
var userRes uapi.PerformAccountCreationResponse
|
||||
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
||||
AccountType: u.AccountType,
|
||||
Localpart: u.Localpart,
|
||||
ServerName: cfg.Global.ServerName,
|
||||
Password: "",
|
||||
}, &userRes); err != nil {
|
||||
t.Errorf("failed to create account: %s", err)
|
||||
}
|
||||
_ = userAPI.KeyDatabase.StoreCrossSigningKeysForUser(ctx, alice.ID, types.CrossSigningKeyMap{
|
||||
fclient.CrossSigningKeyPurposeMaster: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"),
|
||||
fclient.CrossSigningKeyPurposeSelfSigning: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"),
|
||||
fclient.CrossSigningKeyPurposeUserSigning: spec.Base64Bytes("Og7D7+RQS030dOsWEtS/juJLTOVojXk1DoNKadyXWyk"),
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
Name string
|
||||
User *test.User
|
||||
Code int
|
||||
}{
|
||||
{Name: "existing user", User: alice, Code: 200},
|
||||
{Name: "non-existing user", User: bob, Code: 404},
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
req := test.NewRequest(t, http.MethodPost, "/_synapse/admin/v1/users/"+tc.User.ID+"/_allow_cross_signing_replacement_without_uia")
|
||||
req.Header.Set("Authorization", "Bearer "+adminToken)
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
|
||||
if rec.Code != tc.Code {
|
||||
t.Fatalf("expected HTTP status %d, got %d: %s", tc.Code, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
if rec.Code == 200 {
|
||||
buf := make(map[string]int64, 1)
|
||||
_ = json.NewDecoder(rec.Body).Decode(&buf)
|
||||
if ts := buf["updatable_without_uia_before_ms"]; ts <= now.UnixMilli() {
|
||||
t.Fatalf("expected updatable_without_uia_before_ms is in future, got %d", ts)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminCreateOrModifyAccount(t *testing.T) {
|
||||
alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
adminToken := "superSecretAdminToken"
|
||||
ctx := context.Background()
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||
defer close()
|
||||
natsInstance := jetstream.NATSInstance{}
|
||||
// add a vhost
|
||||
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
|
||||
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
|
||||
})
|
||||
// There's no need to add a full config for msc3861 as we need only an admin token
|
||||
cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"}
|
||||
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken}
|
||||
|
||||
routers := httputil.NewRouters()
|
||||
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics)
|
||||
|
||||
for _, u := range []*test.User{alice} {
|
||||
userRes := &uapi.PerformAccountCreationResponse{}
|
||||
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
||||
AccountType: u.AccountType,
|
||||
Localpart: u.Localpart,
|
||||
ServerName: cfg.Global.ServerName,
|
||||
Password: "",
|
||||
}, userRes); err != nil {
|
||||
t.Errorf("failed to create account: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
type threePID struct {
|
||||
Medium string `json:"medium"`
|
||||
Address string `json:"address"`
|
||||
}
|
||||
type adminCreateOrModifyAccountRequest struct {
|
||||
DisplayName string `json:"displayname"`
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
ThreePIDs []threePID `json:"threepids"`
|
||||
}
|
||||
|
||||
t.Run("Missing auth token", func(t *testing.T) {
|
||||
req := test.NewRequest(t, http.MethodPut, "/_synapse/admin/v2/users/"+alice.ID)
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String())
|
||||
}
|
||||
var b spec.MatrixError
|
||||
_ = json.NewDecoder(rec.Body).Decode(&b)
|
||||
if b.ErrCode != spec.ErrorMissingToken {
|
||||
t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode)
|
||||
}
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
Name string
|
||||
User *test.User
|
||||
Payload adminCreateOrModifyAccountRequest
|
||||
Expected struct {
|
||||
DisplayName,
|
||||
AvatarURL string
|
||||
ThreePIDs []string
|
||||
}
|
||||
Code int
|
||||
}{
|
||||
{
|
||||
Name: fmt.Sprintf("Modify user %s", alice.ID),
|
||||
User: alice,
|
||||
Payload: adminCreateOrModifyAccountRequest{
|
||||
DisplayName: "alice",
|
||||
AvatarURL: "https://alice-avatar.example.com",
|
||||
ThreePIDs: []threePID{
|
||||
{
|
||||
Medium: "email",
|
||||
Address: "alice@example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
Expected: struct {
|
||||
DisplayName, AvatarURL string
|
||||
ThreePIDs []string
|
||||
}{
|
||||
// In order to avoid any confusion and undesired behaviour, we do not change display name and avatar url if account already exists
|
||||
DisplayName: alice.Localpart,
|
||||
AvatarURL: "",
|
||||
ThreePIDs: []string{"alice@example.com"},
|
||||
},
|
||||
Code: http.StatusOK,
|
||||
},
|
||||
{
|
||||
Name: fmt.Sprintf("Create user %s", bob.ID),
|
||||
User: bob,
|
||||
Payload: adminCreateOrModifyAccountRequest{
|
||||
DisplayName: "bob",
|
||||
AvatarURL: "https://bob-avatar.example.com",
|
||||
ThreePIDs: []threePID{
|
||||
{
|
||||
Medium: "email",
|
||||
Address: "bob@example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
Expected: struct {
|
||||
DisplayName, AvatarURL string
|
||||
ThreePIDs []string
|
||||
}{
|
||||
DisplayName: "bob",
|
||||
AvatarURL: "https://bob-avatar.example.com",
|
||||
ThreePIDs: []string{"bob@example.com"},
|
||||
},
|
||||
Code: http.StatusCreated,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
req := test.NewRequest(
|
||||
t,
|
||||
http.MethodPut,
|
||||
"/_synapse/admin/v2/users/"+tc.User.ID,
|
||||
test.WithJSONBody(t, tc.Payload),
|
||||
)
|
||||
req.Header.Set("Authorization", "Bearer "+adminToken)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if rec.Code != tc.Code {
|
||||
t.Fatalf("expected HTTP status %d, got %d: %s", tc.Code, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
p, _ := userAPI.QueryProfile(ctx, tc.User.ID)
|
||||
if p.DisplayName != tc.Expected.DisplayName {
|
||||
t.Fatalf("expected display name %s, got %s", tc.Expected.DisplayName, p.DisplayName)
|
||||
}
|
||||
if p.AvatarURL != tc.Expected.AvatarURL {
|
||||
t.Fatalf("expected avatar_url %s, got %s", tc.Expected.AvatarURL, p.AvatarURL)
|
||||
}
|
||||
var threePidRs uapi.QueryThreePIDsForLocalpartResponse
|
||||
_ = userAPI.QueryThreePIDsForLocalpart(
|
||||
ctx,
|
||||
&uapi.QueryThreePIDsForLocalpartRequest{Localpart: tc.User.Localpart, ServerName: cfg.Global.ServerName},
|
||||
&threePidRs,
|
||||
)
|
||||
if len(threePidRs.ThreePIDs) != 1 {
|
||||
t.Fatalf("expected 1 3pid got %d", len(threePidRs.ThreePIDs))
|
||||
}
|
||||
tp := threePidRs.ThreePIDs[0]
|
||||
if tp.Medium != "email" {
|
||||
t.Fatalf("expected 3pid medium email got %s", tp.Medium)
|
||||
}
|
||||
if tp.Address != tc.Payload.ThreePIDs[0].Address {
|
||||
t.Fatalf("expected 3pid address %s got %s", tc.Expected.ThreePIDs[0], tp.Address)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminRetrieveAccount(t *testing.T) {
|
||||
alice := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
adminToken := "superSecretAdminToken"
|
||||
ctx := context.Background()
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||
defer close()
|
||||
natsInstance := jetstream.NATSInstance{}
|
||||
// add a vhost
|
||||
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
|
||||
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
|
||||
})
|
||||
// There's no need to add a full config for msc3861 as we need only an admin token
|
||||
cfg.ClientAPI.MSCs.MSCs = []string{"msc3861"}
|
||||
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{AdminToken: adminToken}
|
||||
|
||||
routers := httputil.NewRouters()
|
||||
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil, caching.DisableMetrics)
|
||||
|
||||
for _, u := range []*test.User{alice} {
|
||||
userRes := &uapi.PerformAccountCreationResponse{}
|
||||
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
||||
AccountType: u.AccountType,
|
||||
Localpart: u.Localpart,
|
||||
ServerName: cfg.Global.ServerName,
|
||||
Password: "",
|
||||
}, userRes); err != nil {
|
||||
t.Errorf("failed to create account: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Missing auth token", func(t *testing.T) {
|
||||
req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+alice.ID)
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected http status %d, got %d: %s", http.StatusUnauthorized, rec.Code, rec.Body.String())
|
||||
}
|
||||
var b spec.MatrixError
|
||||
_ = json.NewDecoder(rec.Body).Decode(&b)
|
||||
if b.ErrCode != spec.ErrorMissingToken {
|
||||
t.Fatalf("expected error code %s, got %s", spec.ErrorMissingToken, b.ErrCode)
|
||||
}
|
||||
})
|
||||
|
||||
testCase := []struct {
|
||||
Name string
|
||||
User *test.User
|
||||
Code int
|
||||
Body string
|
||||
}{
|
||||
{
|
||||
Name: "Retrieve existing account",
|
||||
User: alice,
|
||||
Code: http.StatusOK,
|
||||
Body: fmt.Sprintf(`{"display_name":"%s","avatar_url":"","deactivated":false}`, alice.Localpart),
|
||||
},
|
||||
{
|
||||
Name: "Retrieve non-existing account",
|
||||
User: bob,
|
||||
Code: http.StatusNotFound,
|
||||
Body: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCase {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
req := test.NewRequest(t, http.MethodGet, "/_synapse/admin/v2/users/"+tc.User.ID)
|
||||
req.Header.Set("Authorization", "Bearer "+adminToken)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
routers.SynapseAdmin.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if rec.Code != tc.Code {
|
||||
t.Fatalf("expected HTTP status %d, got %d: %s", tc.Code, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
if tc.Body != "" && tc.Body != rec.Body.String() {
|
||||
t.Fatalf("expected body %s, got %s", tc.Body, rec.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -16,8 +16,6 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/element-hq/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// OWASP recommends at least 128 bits of entropy for tokens: https://www.owasp.org/index.php/Insufficient_Session-ID_Length
|
||||
|
@ -37,51 +35,6 @@ type AccountDatabase interface {
|
|||
GetAccountByPassword(ctx context.Context, localpart, password string) (*api.Account, error)
|
||||
}
|
||||
|
||||
// VerifyUserFromRequest authenticates the HTTP request,
|
||||
// on success returns Device of the requester.
|
||||
// Finds local user or an application service user.
|
||||
// Note: For an AS user, AS dummy device is returned.
|
||||
// On failure returns an JSON error response which can be sent to the client.
|
||||
func VerifyUserFromRequest(
|
||||
req *http.Request, userAPI api.QueryAcccessTokenAPI,
|
||||
) (*api.Device, *util.JSONResponse) {
|
||||
// Try to find the Application Service user
|
||||
token, err := ExtractAccessToken(req)
|
||||
if err != nil {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: spec.MissingToken(err.Error()),
|
||||
}
|
||||
}
|
||||
var res api.QueryAccessTokenResponse
|
||||
err = userAPI.QueryAccessToken(req.Context(), &api.QueryAccessTokenRequest{
|
||||
AccessToken: token,
|
||||
AppServiceUserID: req.URL.Query().Get("user_id"),
|
||||
}, &res)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccessToken failed")
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
if res.Err != "" {
|
||||
if strings.HasPrefix(strings.ToLower(res.Err), "forbidden:") { // TODO: use actual error and no string comparison
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden(res.Err),
|
||||
}
|
||||
}
|
||||
}
|
||||
if res.Device == nil {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: spec.UnknownToken("Unknown token"),
|
||||
}
|
||||
}
|
||||
return res.Device, nil
|
||||
}
|
||||
|
||||
// GenerateAccessToken creates a new access token. Returns an error if failed to generate
|
||||
// random bytes.
|
||||
func GenerateAccessToken() (string, error) {
|
||||
|
|
65
clientapi/auth/default_user_verifier.go
Normal file
65
clientapi/auth/default_user_verifier.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/element-hq/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// DefaultUserVerifier implements UserVerifier interface
|
||||
type DefaultUserVerifier struct {
|
||||
UserAPI api.QueryAccessTokenAPI
|
||||
}
|
||||
|
||||
// VerifyUserFromRequest authenticates the HTTP request,
|
||||
// on success returns Device of the requester.
|
||||
// Finds local user or an application service user.
|
||||
// Note: For an AS user, AS dummy device is returned.
|
||||
// On failure returns an JSON error response which can be sent to the client.
|
||||
func (d *DefaultUserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Device, *util.JSONResponse) {
|
||||
ctx := req.Context()
|
||||
util.GetLogger(ctx).Debug("Default VerifyUserFromRequest")
|
||||
// Try to find the Application Service user
|
||||
token, err := ExtractAccessToken(req)
|
||||
if err != nil {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: spec.MissingToken(err.Error()),
|
||||
}
|
||||
}
|
||||
var res api.QueryAccessTokenResponse
|
||||
err = d.UserAPI.QueryAccessToken(ctx, &api.QueryAccessTokenRequest{
|
||||
AccessToken: token,
|
||||
AppServiceUserID: req.URL.Query().Get("user_id"),
|
||||
}, &res)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("userAPI.QueryAccessToken failed")
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
if res.Err != "" {
|
||||
if strings.HasPrefix(strings.ToLower(res.Err), "forbidden:") { // TODO: use actual error and no string comparison
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden(res.Err),
|
||||
}
|
||||
}
|
||||
}
|
||||
if res.Device == nil {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: spec.UnknownToken("Unknown token"),
|
||||
}
|
||||
}
|
||||
return res.Device, nil
|
||||
}
|
|
@ -36,7 +36,9 @@ func AddPublicRoutes(
|
|||
fsAPI federationAPI.ClientFederationAPI,
|
||||
userAPI userapi.ClientUserAPI,
|
||||
userDirectoryProvider userapi.QuerySearchProfilesAPI,
|
||||
extRoomsProvider api.ExtraPublicRoomsProvider, enableMetrics bool,
|
||||
extRoomsProvider api.ExtraPublicRoomsProvider,
|
||||
userVerifier httputil.UserVerifier,
|
||||
enableMetrics bool,
|
||||
) {
|
||||
js, natsClient := natsInstance.Prepare(processContext, &cfg.Global.JetStream)
|
||||
|
||||
|
@ -55,6 +57,7 @@ func AddPublicRoutes(
|
|||
cfg, rsAPI, asAPI,
|
||||
userAPI, userDirectoryProvider, federation,
|
||||
syncProducer, transactionsCache, fsAPI,
|
||||
extRoomsProvider, natsClient, enableMetrics,
|
||||
extRoomsProvider, natsClient,
|
||||
userVerifier, enableMetrics,
|
||||
)
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/element-hq/dendrite/appservice"
|
||||
"github.com/element-hq/dendrite/clientapi/auth"
|
||||
"github.com/element-hq/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/element-hq/dendrite/clientapi/routing"
|
||||
"github.com/element-hq/dendrite/clientapi/threepid"
|
||||
|
@ -127,9 +128,10 @@ func TestGetPutDevices(t *testing.T) {
|
|||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
|
||||
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
alice: {},
|
||||
|
@ -176,9 +178,10 @@ func TestDeleteDevice(t *testing.T) {
|
|||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
|
||||
// We mostly need the rsAPI/ for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
alice: {},
|
||||
|
@ -281,9 +284,10 @@ func TestDeleteDevices(t *testing.T) {
|
|||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
|
||||
// We mostly need the rsAPI/ for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
alice: {},
|
||||
|
@ -449,8 +453,9 @@ func TestSetDisplayname(t *testing.T) {
|
|||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
asPI := appservice.NewInternalAPI(processCtx, cfg, natsInstance, userAPI, rsAPI)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
|
||||
AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
alice: {},
|
||||
|
@ -561,8 +566,9 @@ func TestSetAvatarURL(t *testing.T) {
|
|||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
asPI := appservice.NewInternalAPI(processCtx, cfg, natsInstance, userAPI, rsAPI)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
|
||||
AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
alice: {},
|
||||
|
@ -638,8 +644,9 @@ func TestTyping(t *testing.T) {
|
|||
rsAPI.SetFederationAPI(nil, nil)
|
||||
// Needed to create accounts
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
// We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
// Create the users in the userapi and login
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
|
@ -723,8 +730,9 @@ func TestMembership(t *testing.T) {
|
|||
// Needed to create accounts
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
rsAPI.SetUserAPI(userAPI)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
// We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
// Create the users in the userapi and login
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
|
@ -962,8 +970,9 @@ func TestCapabilities(t *testing.T) {
|
|||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
// We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
// Create the users in the userapi and login
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
|
@ -1010,9 +1019,10 @@ func TestTurnserver(t *testing.T) {
|
|||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
//rsAPI.SetUserAPI(userAPI)
|
||||
// We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
// Create the users in the userapi and login
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
|
@ -1109,8 +1119,9 @@ func Test3PID(t *testing.T) {
|
|||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
// We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
// Create the users in the userapi and login
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
|
@ -1285,9 +1296,10 @@ func TestPushRules(t *testing.T) {
|
|||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
|
||||
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
alice: {},
|
||||
|
@ -1672,9 +1684,10 @@ func TestKeys(t *testing.T) {
|
|||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
|
||||
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
alice: {},
|
||||
|
@ -2134,9 +2147,10 @@ func TestKeyBackup(t *testing.T) {
|
|||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
|
||||
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
alice: {},
|
||||
|
@ -2238,9 +2252,10 @@ func TestGetMembership(t *testing.T) {
|
|||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
|
||||
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
alice: {},
|
||||
|
@ -2301,9 +2316,10 @@ func TestCreateRoomInvite(t *testing.T) {
|
|||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
|
||||
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
alice: {},
|
||||
|
@ -2376,9 +2392,10 @@ func TestReportEvent(t *testing.T) {
|
|||
if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
||||
t.Fatalf("failed to send events: %v", err)
|
||||
}
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
|
||||
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
accessTokens := map[*test.User]userDevice{
|
||||
alice: {},
|
||||
|
|
|
@ -1,7 +1,13 @@
|
|||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -20,16 +26,24 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/constraints"
|
||||
|
||||
appserviceAPI "github.com/element-hq/dendrite/appservice/api"
|
||||
clientapi "github.com/element-hq/dendrite/clientapi/api"
|
||||
"github.com/element-hq/dendrite/clientapi/auth/authtypes"
|
||||
clienthttputil "github.com/element-hq/dendrite/clientapi/httputil"
|
||||
"github.com/element-hq/dendrite/clientapi/userutil"
|
||||
"github.com/element-hq/dendrite/internal/httputil"
|
||||
roomserverAPI "github.com/element-hq/dendrite/roomserver/api"
|
||||
"github.com/element-hq/dendrite/setup/config"
|
||||
"github.com/element-hq/dendrite/setup/jetstream"
|
||||
"github.com/element-hq/dendrite/userapi/api"
|
||||
userapi "github.com/element-hq/dendrite/userapi/api"
|
||||
"github.com/element-hq/dendrite/userapi/storage/shared"
|
||||
)
|
||||
|
||||
var validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$")
|
||||
var (
|
||||
validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$")
|
||||
deviceDisplayName = "OIDC-native client"
|
||||
)
|
||||
|
||||
func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||
if !cfg.RegistrationRequiresToken {
|
||||
|
@ -496,6 +510,485 @@ func AdminDownloadState(req *http.Request, device *api.Device, rsAPI roomserverA
|
|||
}
|
||||
}
|
||||
|
||||
func AdminCheckUsernameAvailable(
|
||||
req *http.Request,
|
||||
userAPI userapi.ClientUserAPI,
|
||||
cfg *config.ClientAPI,
|
||||
) util.JSONResponse {
|
||||
username := req.URL.Query().Get("username")
|
||||
if username == "" {
|
||||
return util.MessageResponse(http.StatusBadRequest, "Query parameter 'username' is missing or empty")
|
||||
}
|
||||
rq := userapi.QueryAccountAvailabilityRequest{Localpart: username, ServerName: cfg.Matrix.ServerName}
|
||||
rs := userapi.QueryAccountAvailabilityResponse{}
|
||||
if err := userAPI.QueryAccountAvailability(req.Context(), &rq, &rs); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: map[string]bool{"available": rs.Available},
|
||||
}
|
||||
}
|
||||
|
||||
func AdminUserDeviceRetrieveCreate(
|
||||
req *http.Request,
|
||||
userAPI userapi.ClientUserAPI,
|
||||
cfg *config.ClientAPI,
|
||||
) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
userID := vars["userID"]
|
||||
local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.InvalidParam(err.Error()),
|
||||
}
|
||||
}
|
||||
logger := util.GetLogger(req.Context())
|
||||
|
||||
switch req.Method {
|
||||
case http.MethodPost:
|
||||
var payload struct {
|
||||
DeviceID string `json:"device_id"`
|
||||
}
|
||||
if resErr := clienthttputil.UnmarshalJSONRequest(req, &payload); resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
||||
userDeviceExists := false
|
||||
var rs api.QueryDevicesResponse
|
||||
if err = userAPI.QueryDevices(req.Context(), &api.QueryDevicesRequest{UserID: userID}, &rs); err != nil {
|
||||
logger.WithError(err).Error("QueryDevices")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
if !rs.UserExists {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusNotFound,
|
||||
JSON: spec.NotFound("Given user ID does not exist"),
|
||||
}
|
||||
}
|
||||
for i := range rs.Devices {
|
||||
if d := rs.Devices[i]; d.ID == payload.DeviceID && d.UserID == userID {
|
||||
userDeviceExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !userDeviceExists {
|
||||
var rs userapi.PerformDeviceCreationResponse
|
||||
if err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{
|
||||
Localpart: local,
|
||||
ServerName: domain,
|
||||
DeviceID: &payload.DeviceID,
|
||||
DeviceDisplayName: &deviceDisplayName,
|
||||
IPAddr: "",
|
||||
UserAgent: req.UserAgent(),
|
||||
NoDeviceListUpdate: false,
|
||||
FromRegistration: false,
|
||||
AccessTokenUniqueConstraintDisabled: true,
|
||||
}, &rs); err != nil {
|
||||
logger.WithError(err).Error("PerformDeviceCreation")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
logger.WithError(err).Debug("PerformDeviceCreation succeeded")
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusCreated,
|
||||
JSON: struct{}{},
|
||||
}
|
||||
case http.MethodGet:
|
||||
var res userapi.QueryDevicesResponse
|
||||
if err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{UserID: userID}, &res); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
jsonDevices := make([]deviceJSON, 0, len(res.Devices))
|
||||
for i := range res.Devices {
|
||||
d := &res.Devices[i]
|
||||
jsonDevices = append(jsonDevices, deviceJSON{
|
||||
DeviceID: d.ID,
|
||||
DisplayName: d.DisplayName,
|
||||
LastSeenIP: d.LastSeenIP,
|
||||
LastSeenTS: d.LastSeenTS,
|
||||
})
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: struct {
|
||||
Devices []deviceJSON `json:"devices"`
|
||||
Total int `json:"total"`
|
||||
}{
|
||||
Devices: jsonDevices,
|
||||
Total: len(res.Devices),
|
||||
},
|
||||
}
|
||||
default:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusMethodNotAllowed,
|
||||
JSON: struct{}{},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AdminUserDeviceDelete(
|
||||
req *http.Request,
|
||||
userAPI userapi.ClientUserAPI,
|
||||
cfg *config.ClientAPI,
|
||||
) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
userID := vars["userID"]
|
||||
deviceID := vars["deviceID"]
|
||||
logger := util.GetLogger(req.Context())
|
||||
|
||||
// XXX: we probably have to delete session from the sessions dict
|
||||
// like we do in DeleteDeviceById. If so, we have to fi
|
||||
var device *api.Device
|
||||
{
|
||||
var rs api.QueryDevicesResponse
|
||||
if err := userAPI.QueryDevices(req.Context(), &api.QueryDevicesRequest{UserID: userID}, &rs); err != nil {
|
||||
logger.WithError(err).Error("userAPI.QueryDevices failed")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
if !rs.UserExists {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusNotFound,
|
||||
JSON: spec.NotFound("Given user ID does not exist"),
|
||||
}
|
||||
}
|
||||
for i := range rs.Devices {
|
||||
if d := rs.Devices[i]; d.ID == deviceID && d.UserID == userID {
|
||||
device = &d
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if device != nil {
|
||||
// XXX: this response struct can completely removed everywhere as it doesn't
|
||||
// have any functional purpose
|
||||
var res api.PerformDeviceDeletionResponse
|
||||
if err := userAPI.PerformDeviceDeletion(req.Context(), &api.PerformDeviceDeletionRequest{
|
||||
UserID: device.UserID,
|
||||
DeviceIDs: []string{device.ID},
|
||||
}, &res); err != nil {
|
||||
logger.WithError(err).Error("userAPI.PerformDeviceDeletion failed")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: struct{}{},
|
||||
}
|
||||
}
|
||||
|
||||
func AdminUserDevicesDelete(
|
||||
req *http.Request,
|
||||
userAPI userapi.ClientUserAPI,
|
||||
cfg *config.ClientAPI,
|
||||
) util.JSONResponse {
|
||||
logger := util.GetLogger(req.Context())
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
userID := vars["userID"]
|
||||
|
||||
if req.Body == nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, "body is required")
|
||||
}
|
||||
var payload struct {
|
||||
Devices []string `json:"devices"`
|
||||
}
|
||||
|
||||
if err = json.NewDecoder(req.Body).Decode(&payload); err != nil {
|
||||
logger.WithError(err).Error("unable to decode device deletion request")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
defer req.Body.Close() // nolint: errcheck
|
||||
|
||||
{
|
||||
// XXX: this response struct can completely removed everywhere as it doesn't
|
||||
// have any functional purpose
|
||||
var rs api.PerformDeviceDeletionResponse
|
||||
if err := userAPI.PerformDeviceDeletion(req.Context(), &api.PerformDeviceDeletionRequest{
|
||||
UserID: userID,
|
||||
DeviceIDs: payload.Devices,
|
||||
}, &rs); err != nil {
|
||||
logger.WithError(err).Error("userAPI.PerformDeviceDeletion failed")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: struct{}{},
|
||||
}
|
||||
}
|
||||
|
||||
func AdminDeactivateAccount(
|
||||
req *http.Request,
|
||||
userAPI userapi.ClientUserAPI,
|
||||
cfg *config.ClientAPI,
|
||||
) util.JSONResponse {
|
||||
logger := util.GetLogger(req.Context())
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
userID := vars["userID"]
|
||||
local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix)
|
||||
if err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
// TODO: "erase" field must also be processed here
|
||||
// see https://github.com/element-hq/synapse/blob/develop/docs/admin_api/user_admin_api.md#deactivate-account
|
||||
|
||||
var rs api.PerformAccountDeactivationResponse
|
||||
if err := userAPI.PerformAccountDeactivation(req.Context(), &api.PerformAccountDeactivationRequest{
|
||||
Localpart: local, ServerName: domain,
|
||||
}, &rs); err != nil {
|
||||
logger.WithError(err).Error("userAPI.PerformDeviceDeletion failed")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: struct{}{},
|
||||
}
|
||||
}
|
||||
|
||||
func AdminAllowCrossSigningReplacementWithoutUIA(
|
||||
req *http.Request,
|
||||
userAPI userapi.ClientUserAPI,
|
||||
) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
userIDstr, ok := vars["userID"]
|
||||
userID, err := spec.NewUserID(userIDstr, false)
|
||||
if !ok || err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusNotFound,
|
||||
JSON: spec.MissingParam("User not found."),
|
||||
}
|
||||
}
|
||||
|
||||
var rs api.QueryAccountByLocalpartResponse
|
||||
err = userAPI.QueryAccountByLocalpart(req.Context(), &api.QueryAccountByLocalpartRequest{
|
||||
Localpart: userID.Local(),
|
||||
ServerName: userID.Domain(),
|
||||
}, &rs)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountByLocalpart")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.Unknown(err.Error()),
|
||||
}
|
||||
} else if errors.Is(err, sql.ErrNoRows) {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusNotFound,
|
||||
JSON: spec.NotFound("User not found."),
|
||||
}
|
||||
}
|
||||
switch req.Method {
|
||||
case http.MethodPost:
|
||||
ts := sessions.allowCrossSigningKeysReplacement(userID.String())
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: map[string]int64{"updatable_without_uia_before_ms": ts},
|
||||
}
|
||||
default:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusMethodNotAllowed,
|
||||
JSON: spec.Unknown("Method not allowed."),
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type adminCreateOrModifyAccountRequest struct {
|
||||
DisplayName string `json:"displayname"`
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
ThreePIDs []struct {
|
||||
Medium string `json:"medium"`
|
||||
Address string `json:"address"`
|
||||
} `json:"threepids"`
|
||||
// TODO: the following fields are not used by dendrite, but they are used in Synapse.
|
||||
// Password string `json:"password"`
|
||||
// LogoutDevices bool `json:"logout_devices"`
|
||||
// ExternalIDs []struct{
|
||||
// AuthProvider string `json:"auth_provider"`
|
||||
// ExternalID string `json:"external_id"`
|
||||
// } `json:"external_ids"`
|
||||
// Admin bool `json:"admin"`
|
||||
// Deactivated bool `json:"deactivated"`
|
||||
// Locked bool `json:"locked"`
|
||||
}
|
||||
|
||||
func AdminCreateOrModifyAccount(req *http.Request, userAPI userapi.ClientUserAPI, cfg *config.ClientAPI) util.JSONResponse {
|
||||
logger := util.GetLogger(req.Context())
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
userID := vars["userID"]
|
||||
local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.InvalidParam(userID),
|
||||
}
|
||||
}
|
||||
var r adminCreateOrModifyAccountRequest
|
||||
if resErr := clienthttputil.UnmarshalJSONRequest(req, &r); resErr != nil {
|
||||
logger.Debugf("UnmarshalJSONRequest failed: %+v", *resErr)
|
||||
return *resErr
|
||||
}
|
||||
logger.Debugf("adminCreateOrModifyAccountRequest is: %#v", r)
|
||||
statusCode := http.StatusOK
|
||||
|
||||
// TODO: Ideally, the following commands should be executed in one transaction.
|
||||
// can we propagate the tx object and pass it in context?
|
||||
var res userapi.PerformAccountCreationResponse
|
||||
err = userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{
|
||||
AccountType: userapi.AccountTypeUser,
|
||||
Localpart: local,
|
||||
ServerName: domain,
|
||||
OnConflict: api.ConflictUpdate,
|
||||
AvatarURL: r.AvatarURL,
|
||||
DisplayName: r.DisplayName,
|
||||
}, &res)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("userAPI.PerformAccountCreation")
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if res.AccountCreated {
|
||||
statusCode = http.StatusCreated
|
||||
}
|
||||
|
||||
if l := len(r.ThreePIDs); l > 0 {
|
||||
logger.Debugf("Trying to bulk save 3PID associations: %+v", r.ThreePIDs)
|
||||
threePIDs := make([]authtypes.ThreePID, 0, len(r.ThreePIDs))
|
||||
for i := range r.ThreePIDs {
|
||||
tpid := &r.ThreePIDs[i]
|
||||
threePIDs = append(threePIDs, authtypes.ThreePID{Medium: tpid.Medium, Address: tpid.Address})
|
||||
}
|
||||
err = userAPI.PerformBulkSaveThreePIDAssociation(req.Context(), &userapi.PerformBulkSaveThreePIDAssociationRequest{
|
||||
ThreePIDs: threePIDs,
|
||||
Localpart: local,
|
||||
ServerName: domain,
|
||||
}, &struct{}{})
|
||||
if err == shared.Err3PIDInUse {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
} else if err != nil {
|
||||
logger.WithError(err).Error("userAPI.PerformSaveThreePIDAssociation")
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: statusCode,
|
||||
JSON: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func AdminRetrieveAccount(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||
logger := util.GetLogger(req.Context())
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
userID, ok := vars["userID"]
|
||||
if !ok {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.MissingParam("Expecting user ID."),
|
||||
}
|
||||
}
|
||||
local, domain, err := userutil.ParseUsernameParam(userID, cfg.Matrix)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.InvalidParam(err.Error()),
|
||||
}
|
||||
}
|
||||
|
||||
body := struct {
|
||||
DisplayName string `json:"display_name"`
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
Deactivated bool `json:"deactivated"`
|
||||
}{}
|
||||
|
||||
var rs api.QueryAccountByLocalpartResponse
|
||||
err = userAPI.QueryAccountByLocalpart(req.Context(), &api.QueryAccountByLocalpartRequest{Localpart: local, ServerName: domain}, &rs)
|
||||
if err == sql.ErrNoRows {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusNotFound,
|
||||
JSON: spec.NotFound(fmt.Sprintf("User '%s' not found", userID)),
|
||||
}
|
||||
} else if err != nil {
|
||||
logger.WithError(err).Error("userAPI.QueryAccountByLocalpart")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.Unknown(err.Error()),
|
||||
}
|
||||
}
|
||||
body.Deactivated = rs.Account.Deactivated
|
||||
|
||||
profile, err := userAPI.QueryProfile(req.Context(), userID)
|
||||
if err != nil {
|
||||
if err == appserviceAPI.ErrProfileNotExists {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusNotFound,
|
||||
JSON: spec.NotFound(err.Error()),
|
||||
}
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.Unknown(err.Error()),
|
||||
}
|
||||
}
|
||||
body.AvatarURL = profile.AvatarURL
|
||||
body.DisplayName = profile.DisplayName
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: body,
|
||||
}
|
||||
}
|
||||
|
||||
// GetEventReports returns reported events for a given user/room.
|
||||
func GetEventReports(
|
||||
req *http.Request,
|
||||
|
|
|
@ -9,20 +9,21 @@ package routing
|
|||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/element-hq/dendrite/clientapi/auth"
|
||||
"github.com/element-hq/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/element-hq/dendrite/clientapi/httputil"
|
||||
"github.com/element-hq/dendrite/setup/config"
|
||||
"github.com/element-hq/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
const CrossSigningResetStage = "org.matrix.cross_signing_reset"
|
||||
|
||||
type crossSigningRequest struct {
|
||||
api.PerformUploadDeviceKeysRequest
|
||||
Auth newPasswordAuth `json:"auth"`
|
||||
|
@ -38,6 +39,7 @@ func UploadCrossSigningDeviceKeys(
|
|||
keyserverAPI UploadKeysAPI, device *api.Device,
|
||||
accountAPI auth.GetAccountByPassword, cfg *config.ClientAPI,
|
||||
) util.JSONResponse {
|
||||
logger := util.GetLogger(req.Context())
|
||||
uploadReq := &crossSigningRequest{}
|
||||
uploadRes := &api.PerformUploadDeviceKeysResponse{}
|
||||
|
||||
|
@ -55,76 +57,92 @@ func UploadCrossSigningDeviceKeys(
|
|||
}, &keyResp)
|
||||
|
||||
if keyResp.Error != nil {
|
||||
logrus.WithError(keyResp.Error).Error("Failed to query keys")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.Unknown(keyResp.Error.Error()),
|
||||
}
|
||||
logger.WithError(keyResp.Error).Error("Failed to query keys")
|
||||
return convertKeyError(keyResp.Error)
|
||||
}
|
||||
|
||||
existingMasterKey, hasMasterKey := keyResp.MasterKeys[device.UserID]
|
||||
requireUIA := false
|
||||
if hasMasterKey {
|
||||
// If we have a master key, check if any of the existing keys differ. If they do,
|
||||
// we need to re-authenticate the user.
|
||||
requireUIA = keysDiffer(existingMasterKey, keyResp, uploadReq, device.UserID)
|
||||
}
|
||||
|
||||
if requireUIA {
|
||||
sessionID := uploadReq.Auth.Session
|
||||
if sessionID == "" {
|
||||
sessionID = util.RandomString(sessionIDLength)
|
||||
}
|
||||
if uploadReq.Auth.Type != authtypes.LoginTypePassword {
|
||||
if hasMasterKey {
|
||||
if !keysDiffer(existingMasterKey, keyResp, uploadReq, device.UserID) {
|
||||
// If we have a master key, check if any of the existing keys differ. If they don't
|
||||
// we return 200 as keys are still valid and there's nothing to do.
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: newUserInteractiveResponse(
|
||||
sessionID,
|
||||
[]authtypes.Flow{
|
||||
{
|
||||
Stages: []authtypes.LoginType{authtypes.LoginTypePassword},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
),
|
||||
Code: http.StatusOK,
|
||||
JSON: struct{}{},
|
||||
}
|
||||
}
|
||||
typePassword := auth.LoginTypePassword{
|
||||
GetAccountByPassword: accountAPI,
|
||||
Config: cfg,
|
||||
|
||||
// With MSC3861, UIA is not possible. Instead, the auth service has to explicitly mark the master key as replaceable.
|
||||
if cfg.MSCs.MSC3861Enabled() {
|
||||
requireUIA := !sessions.isCrossSigningKeysReplacementAllowed(device.UserID)
|
||||
if requireUIA {
|
||||
url := ""
|
||||
if m := cfg.MSCs.MSC3861; m.AccountManagementURL != "" {
|
||||
url = strings.Join([]string{m.AccountManagementURL, "?action=", CrossSigningResetStage}, "")
|
||||
} else {
|
||||
url = m.Issuer
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: newUserInteractiveResponse(
|
||||
"dummy",
|
||||
[]authtypes.Flow{
|
||||
{
|
||||
Stages: []authtypes.LoginType{CrossSigningResetStage},
|
||||
},
|
||||
},
|
||||
map[string]interface{}{
|
||||
CrossSigningResetStage: map[string]string{
|
||||
"url": url,
|
||||
},
|
||||
},
|
||||
strings.Join([]string{
|
||||
"To reset your end-to-end encryption cross-signing identity, you first need to approve it at",
|
||||
url,
|
||||
"and then try again.",
|
||||
}, " "),
|
||||
),
|
||||
}
|
||||
}
|
||||
sessions.restrictCrossSigningKeysReplacement(device.UserID)
|
||||
} else {
|
||||
sessionID := uploadReq.Auth.Session
|
||||
if sessionID == "" {
|
||||
sessionID = util.RandomString(sessionIDLength)
|
||||
}
|
||||
|
||||
if uploadReq.Auth.Type != authtypes.LoginTypePassword {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: newUserInteractiveResponse(
|
||||
sessionID,
|
||||
[]authtypes.Flow{
|
||||
{
|
||||
Stages: []authtypes.LoginType{authtypes.LoginTypePassword},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
"",
|
||||
),
|
||||
}
|
||||
}
|
||||
typePassword := auth.LoginTypePassword{
|
||||
GetAccountByPassword: accountAPI,
|
||||
Config: cfg,
|
||||
}
|
||||
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil {
|
||||
return *authErr
|
||||
}
|
||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||
}
|
||||
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil {
|
||||
return *authErr
|
||||
}
|
||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||
}
|
||||
|
||||
uploadReq.UserID = device.UserID
|
||||
keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes)
|
||||
|
||||
if err := uploadRes.Error; err != nil {
|
||||
switch {
|
||||
case err.IsInvalidSignature:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.InvalidSignature(err.Error()),
|
||||
}
|
||||
case err.IsMissingParam:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.MissingParam(err.Error()),
|
||||
}
|
||||
case err.IsInvalidParam:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.InvalidParam(err.Error()),
|
||||
}
|
||||
default:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.Unknown(err.Error()),
|
||||
}
|
||||
}
|
||||
return convertKeyError(err)
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
|
@ -160,28 +178,7 @@ func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.Clie
|
|||
keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes)
|
||||
|
||||
if err := uploadRes.Error; err != nil {
|
||||
switch {
|
||||
case err.IsInvalidSignature:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.InvalidSignature(err.Error()),
|
||||
}
|
||||
case err.IsMissingParam:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.MissingParam(err.Error()),
|
||||
}
|
||||
case err.IsInvalidParam:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.InvalidParam(err.Error()),
|
||||
}
|
||||
default:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.Unknown(err.Error()),
|
||||
}
|
||||
}
|
||||
return convertKeyError(err)
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
|
@ -189,3 +186,28 @@ func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.Clie
|
|||
JSON: struct{}{},
|
||||
}
|
||||
}
|
||||
|
||||
func convertKeyError(err *api.KeyError) util.JSONResponse {
|
||||
switch {
|
||||
case err.IsInvalidSignature:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.InvalidSignature(err.Error()),
|
||||
}
|
||||
case err.IsMissingParam:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.MissingParam(err.Error()),
|
||||
}
|
||||
case err.IsInvalidParam:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.InvalidParam(err.Error()),
|
||||
}
|
||||
default:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.Unknown(err.Error()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,15 +19,17 @@ import (
|
|||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
)
|
||||
|
||||
// TODO: add more tests to cover cases related to MSC3861
|
||||
|
||||
type mockKeyAPI struct {
|
||||
t *testing.T
|
||||
userResponses map[string]api.QueryKeysResponse
|
||||
queryKeysData map[string]api.QueryKeysResponse
|
||||
}
|
||||
|
||||
func (m mockKeyAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
|
||||
res.MasterKeys = m.userResponses[req.UserID].MasterKeys
|
||||
res.SelfSigningKeys = m.userResponses[req.UserID].SelfSigningKeys
|
||||
res.UserSigningKeys = m.userResponses[req.UserID].UserSigningKeys
|
||||
res.MasterKeys = m.queryKeysData[req.UserID].MasterKeys
|
||||
res.SelfSigningKeys = m.queryKeysData[req.UserID].SelfSigningKeys
|
||||
res.UserSigningKeys = m.queryKeysData[req.UserID].UserSigningKeys
|
||||
if m.t != nil {
|
||||
m.t.Logf("QueryKeys: %+v => %+v", req, res)
|
||||
}
|
||||
|
@ -53,13 +55,16 @@ func Test_UploadCrossSigningDeviceKeys_ValidRequest(t *testing.T) {
|
|||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
keyserverAPI := &mockKeyAPI{
|
||||
userResponses: map[string]api.QueryKeysResponse{
|
||||
queryKeysData: map[string]api.QueryKeysResponse{
|
||||
"@user:example.com": {},
|
||||
},
|
||||
}
|
||||
device := &api.Device{UserID: "@user:example.com", ID: "device"}
|
||||
cfg := &config.ClientAPI{}
|
||||
|
||||
cfg := &config.ClientAPI{
|
||||
MSCs: &config.MSCs{
|
||||
MSCs: []string{},
|
||||
},
|
||||
}
|
||||
res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg)
|
||||
if res.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, res.Code)
|
||||
|
@ -101,10 +106,13 @@ func Test_UploadCrossSigningDeviceKeys_Unauthorised(t *testing.T) {
|
|||
|
||||
keyserverAPI := &mockKeyAPI{
|
||||
t: t,
|
||||
userResponses: map[string]api.QueryKeysResponse{
|
||||
queryKeysData: map[string]api.QueryKeysResponse{
|
||||
"@user:example.com": {
|
||||
MasterKeys: map[string]fclient.CrossSigningKey{
|
||||
"@user:example.com": {UserID: "@user:example.com", Usage: []fclient.CrossSigningKeyPurpose{"master"}, Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("key1")}},
|
||||
"@user:example.com": {
|
||||
UserID: "@user:example.com",
|
||||
Usage: []fclient.CrossSigningKeyPurpose{fclient.CrossSigningKeyPurposeMaster},
|
||||
Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("key1")}},
|
||||
},
|
||||
SelfSigningKeys: nil,
|
||||
UserSigningKeys: nil,
|
||||
|
@ -112,7 +120,11 @@ func Test_UploadCrossSigningDeviceKeys_Unauthorised(t *testing.T) {
|
|||
},
|
||||
}
|
||||
device := &api.Device{UserID: "@user:example.com", ID: "device"}
|
||||
cfg := &config.ClientAPI{}
|
||||
cfg := &config.ClientAPI{
|
||||
MSCs: &config.MSCs{
|
||||
MSCs: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg)
|
||||
if res.Code != http.StatusUnauthorized {
|
||||
|
@ -132,8 +144,11 @@ func Test_UploadCrossSigningDeviceKeys_InvalidJSON(t *testing.T) {
|
|||
|
||||
keyserverAPI := &mockKeyAPI{}
|
||||
device := &api.Device{UserID: "@user:example.com", ID: "device"}
|
||||
cfg := &config.ClientAPI{}
|
||||
|
||||
cfg := &config.ClientAPI{
|
||||
MSCs: &config.MSCs{
|
||||
MSCs: []string{},
|
||||
},
|
||||
}
|
||||
res := UploadCrossSigningDeviceKeys(req, keyserverAPI, device, getAccountByPassword, cfg)
|
||||
if res.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusBadRequest, res.Code)
|
||||
|
@ -151,10 +166,14 @@ func Test_UploadCrossSigningDeviceKeys_ExistingKeysMismatch(t *testing.T) {
|
|||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
keyserverAPI := &mockKeyAPI{
|
||||
userResponses: map[string]api.QueryKeysResponse{
|
||||
queryKeysData: map[string]api.QueryKeysResponse{
|
||||
"@user:example.com": {
|
||||
MasterKeys: map[string]fclient.CrossSigningKey{
|
||||
"@user:example.com": {UserID: "@user:example.com", Usage: []fclient.CrossSigningKeyPurpose{"master"}, Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("different_key")}},
|
||||
"@user:example.com": {
|
||||
UserID: "@user:example.com",
|
||||
Usage: []fclient.CrossSigningKeyPurpose{fclient.CrossSigningKeyPurposeMaster},
|
||||
Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{"ed25519:1": spec.Base64Bytes("different_key")},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/element-hq/dendrite/clientapi/auth"
|
||||
"github.com/element-hq/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/element-hq/dendrite/internal/caching"
|
||||
"github.com/element-hq/dendrite/internal/httputil"
|
||||
|
@ -50,9 +51,10 @@ func TestLogin(t *testing.T) {
|
|||
rsAPI.SetFederationAPI(nil, nil)
|
||||
// Needed for /login
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
|
||||
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
|
||||
Setup(routers, cfg, nil, nil, userAPI, nil, nil, nil, nil, nil, nil, nil, caching.DisableMetrics)
|
||||
Setup(routers, cfg, nil, nil, userAPI, nil, nil, nil, nil, nil, nil, nil, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
// Create password
|
||||
password := util.RandomString(8)
|
||||
|
|
|
@ -67,6 +67,7 @@ func Password(
|
|||
},
|
||||
},
|
||||
nil,
|
||||
"",
|
||||
),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -172,24 +172,20 @@ func GetDisplayName(
|
|||
|
||||
// SetDisplayName implements PUT /profile/{userID}/displayname
|
||||
func SetDisplayName(
|
||||
req *http.Request, profileAPI userapi.ProfileAPI,
|
||||
req *http.Request, userAPI userapi.ClientUserAPI,
|
||||
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.ClientRoomserverAPI,
|
||||
) util.JSONResponse {
|
||||
if userID != device.UserID {
|
||||
if userID != device.UserID && device.AccountType != userapi.AccountTypeOIDCService {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("userID does not match the current user"),
|
||||
}
|
||||
}
|
||||
|
||||
var r eventutil.UserProfile
|
||||
if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
||||
logger := util.GetLogger(req.Context())
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
logger.WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
|
@ -203,6 +199,28 @@ func SetDisplayName(
|
|||
}
|
||||
}
|
||||
|
||||
if device.AccountType == userapi.AccountTypeOIDCService {
|
||||
// When a request is made on behalf of an OIDC provider service, the original device object refers
|
||||
// to the provider's pseudo-device and includes only the AccountTypeOIDCService flag. To continue,
|
||||
// we need to replace the admin's device with the user's device
|
||||
var rs userapi.QueryDevicesResponse
|
||||
err = userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{UserID: userID}, &rs)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
if len(rs.Devices) > 0 {
|
||||
device = &rs.Devices[0]
|
||||
}
|
||||
}
|
||||
|
||||
var r eventutil.UserProfile
|
||||
if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
||||
evTime, err := httputil.ParseTSParam(req)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
|
@ -211,9 +229,9 @@ func SetDisplayName(
|
|||
}
|
||||
}
|
||||
|
||||
profile, changed, err := profileAPI.SetDisplayName(req.Context(), localpart, domain, r.DisplayName)
|
||||
profile, changed, err := userAPI.SetDisplayName(req.Context(), localpart, domain, r.DisplayName)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed")
|
||||
logger.WithError(err).Error("profileAPI.SetDisplayName failed")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
|
|
|
@ -66,11 +66,17 @@ type sessionsDict struct {
|
|||
// If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2,
|
||||
// the delete request will fail for device2 since the UIA was initiated by trying to delete device1.
|
||||
deleteSessionToDeviceID map[string]string
|
||||
// crossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating
|
||||
// cross-signing keys without UIA.
|
||||
crossSigningKeysReplacement map[string]*time.Timer
|
||||
}
|
||||
|
||||
// defaultTimeout is the timeout used to clean up sessions
|
||||
const defaultTimeOut = time.Minute * 5
|
||||
|
||||
// crossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA
|
||||
const crossSigningKeysReplacementDuration = time.Minute * 10
|
||||
|
||||
// getCompletedStages returns the completed stages for a session.
|
||||
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
|
||||
d.RLock()
|
||||
|
@ -119,13 +125,54 @@ func (d *sessionsDict) deleteSession(sessionID string) {
|
|||
}
|
||||
}
|
||||
|
||||
func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
allowedUntilTS := time.Now().Add(crossSigningKeysReplacementDuration).UnixMilli()
|
||||
t, ok := d.crossSigningKeysReplacement[userID]
|
||||
if ok {
|
||||
t.Reset(crossSigningKeysReplacementDuration)
|
||||
return allowedUntilTS
|
||||
}
|
||||
d.crossSigningKeysReplacement[userID] = time.AfterFunc(
|
||||
crossSigningKeysReplacementDuration,
|
||||
func() {
|
||||
d.restrictCrossSigningKeysReplacement(userID)
|
||||
},
|
||||
)
|
||||
return allowedUntilTS
|
||||
}
|
||||
|
||||
func (d *sessionsDict) isCrossSigningKeysReplacementAllowed(userID string) bool {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
_, ok := d.crossSigningKeysReplacement[userID]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
t, ok := d.crossSigningKeysReplacement[userID]
|
||||
if ok {
|
||||
if !t.Stop() {
|
||||
select {
|
||||
case <-t.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
delete(d.crossSigningKeysReplacement, userID)
|
||||
}
|
||||
}
|
||||
|
||||
func newSessionsDict() *sessionsDict {
|
||||
return &sessionsDict{
|
||||
sessions: make(map[string][]authtypes.LoginType),
|
||||
sessionCompletedResult: make(map[string]registerResponse),
|
||||
params: make(map[string]registerRequest),
|
||||
timer: make(map[string]*time.Timer),
|
||||
deleteSessionToDeviceID: make(map[string]string),
|
||||
sessions: make(map[string][]authtypes.LoginType),
|
||||
sessionCompletedResult: make(map[string]registerResponse),
|
||||
params: make(map[string]registerRequest),
|
||||
timer: make(map[string]*time.Timer),
|
||||
deleteSessionToDeviceID: make(map[string]string),
|
||||
crossSigningKeysReplacement: make(map[string]*time.Timer),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -234,6 +281,7 @@ type userInteractiveResponse struct {
|
|||
Completed []authtypes.LoginType `json:"completed"`
|
||||
Params map[string]interface{} `json:"params"`
|
||||
Session string `json:"session"`
|
||||
Msg string `json:"msg,omitempty"`
|
||||
}
|
||||
|
||||
// newUserInteractiveResponse will return a struct to be sent back to the client
|
||||
|
@ -242,9 +290,10 @@ func newUserInteractiveResponse(
|
|||
sessionID string,
|
||||
fs []authtypes.Flow,
|
||||
params map[string]interface{},
|
||||
msg string,
|
||||
) userInteractiveResponse {
|
||||
return userInteractiveResponse{
|
||||
fs, sessions.getCompletedStages(sessionID), params, sessionID,
|
||||
fs, sessions.getCompletedStages(sessionID), params, sessionID, msg,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -817,7 +866,7 @@ func checkAndCompleteFlow(
|
|||
return util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: newUserInteractiveResponse(sessionID,
|
||||
cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params),
|
||||
cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params, ""),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -669,3 +669,57 @@ func TestRegisterAdminUsingSharedSecret(t *testing.T) {
|
|||
assert.Equal(t, expectedDisplayName, profile.DisplayName)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCrossSigningKeysReplacement(t *testing.T) {
|
||||
userID := "@user:example.com"
|
||||
|
||||
t.Run("Can add new session", func(t *testing.T) {
|
||||
s := newSessionsDict()
|
||||
assert.Empty(t, s.crossSigningKeysReplacement)
|
||||
s.allowCrossSigningKeysReplacement(userID)
|
||||
assert.Len(t, s.crossSigningKeysReplacement, 1)
|
||||
assert.Contains(t, s.crossSigningKeysReplacement, userID)
|
||||
})
|
||||
|
||||
t.Run("Can check if session exists or not", func(t *testing.T) {
|
||||
s := newSessionsDict()
|
||||
t.Run("exists", func(t *testing.T) {
|
||||
s.allowCrossSigningKeysReplacement(userID)
|
||||
assert.Len(t, s.crossSigningKeysReplacement, 1)
|
||||
assert.True(t, s.isCrossSigningKeysReplacementAllowed(userID))
|
||||
})
|
||||
|
||||
t.Run("not exists", func(t *testing.T) {
|
||||
assert.False(t, s.isCrossSigningKeysReplacementAllowed("@random:test.com"))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Can deactivate session", func(t *testing.T) {
|
||||
s := newSessionsDict()
|
||||
assert.Empty(t, s.crossSigningKeysReplacement)
|
||||
t.Run("not exists", func(t *testing.T) {
|
||||
s.restrictCrossSigningKeysReplacement("@random:test.com")
|
||||
assert.Empty(t, s.crossSigningKeysReplacement)
|
||||
})
|
||||
|
||||
t.Run("exists", func(t *testing.T) {
|
||||
s.allowCrossSigningKeysReplacement(userID)
|
||||
s.restrictCrossSigningKeysReplacement(userID)
|
||||
assert.Empty(t, s.crossSigningKeysReplacement)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Can erase expired sessions", func(t *testing.T) {
|
||||
s := newSessionsDict()
|
||||
s.allowCrossSigningKeysReplacement(userID)
|
||||
assert.Len(t, s.crossSigningKeysReplacement, 1)
|
||||
assert.True(t, s.isCrossSigningKeysReplacementAllowed(userID))
|
||||
timer := s.crossSigningKeysReplacement[userID]
|
||||
|
||||
// pretending the timer is expired
|
||||
timer.Reset(time.Millisecond)
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
|
||||
assert.Empty(t, s.crossSigningKeysReplacement)
|
||||
})
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -55,7 +55,7 @@ var latest, _ = semver.NewVersion("v6.6.6") // Dummy version, used as "HEAD"
|
|||
// due to the error:
|
||||
// When using COPY with more than one source file, the destination must be a directory and end with a /
|
||||
// We need to run a postgres anyway, so use the dockerfile associated with Complement instead.
|
||||
const DockerfilePostgreSQL = `FROM golang:1.22-bookworm as build
|
||||
const DockerfilePostgreSQL = `FROM golang:1.23-bookworm as build
|
||||
RUN apt-get update && apt-get install -y postgresql
|
||||
WORKDIR /build
|
||||
ARG BINARY
|
||||
|
@ -99,7 +99,7 @@ ENV BINARY=dendrite
|
|||
EXPOSE 8008 8448
|
||||
CMD /build/run_dendrite.sh`
|
||||
|
||||
const DockerfileSQLite = `FROM golang:1.22-bookworm as build
|
||||
const DockerfileSQLite = `FROM golang:1.23-bookworm as build
|
||||
RUN apt-get update && apt-get install -y postgresql
|
||||
WORKDIR /build
|
||||
ARG BINARY
|
||||
|
|
|
@ -303,8 +303,27 @@ media_api:
|
|||
# Configuration for enabling experimental MSCs on this homeserver.
|
||||
mscs:
|
||||
mscs:
|
||||
# - msc3861 # (Next-gen auth, see https://github.com/matrix-org/matrix-doc/pull/3861)
|
||||
# - msc2836 # (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836)
|
||||
|
||||
# This block has no effect if the feature is not activated in the list above
|
||||
# msc3861:
|
||||
# # OIDC issuer advertised by the service.
|
||||
# # See https://element-hq.github.io/matrix-authentication-service/reference/configuration.html#http
|
||||
# issuer: "https://mas.example.com/"
|
||||
|
||||
# # Credentials used for authenticating requests coming from dendrite to auth service.
|
||||
# # See https://element-hq.github.io/matrix-authentication-service/reference/configuration.html#clients
|
||||
# client_id: 01JFNM9MCHKV6A7A0C0RBHMYC0
|
||||
# client_secret: c85731184ac8f9aea76cf48146046b454473ca667a0cd1fd52a43034a0662eed
|
||||
|
||||
# # The service token used for authenticating requests coming from auth service to dendrite.
|
||||
# # See https://element-hq.github.io/matrix-authentication-service/reference/configuration.html#matrix
|
||||
# admin_token: ttJORW9oV4Wf4DJ63GdZEYekE2KElP4g
|
||||
|
||||
# # URL of the account page on the auth service side
|
||||
# account_management_url: "https://mas.example.com/account"
|
||||
|
||||
# Configuration for the Sync API.
|
||||
sync_api:
|
||||
# This option controls which HTTP header to inspect to find the real remote IP
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
"github.com/element-hq/dendrite/internal/caching"
|
||||
"github.com/element-hq/dendrite/internal/sqlutil"
|
||||
"github.com/element-hq/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
)
|
||||
|
||||
// NewDatabase opens a new database
|
||||
|
|
2
go.mod
2
go.mod
|
@ -156,6 +156,6 @@ require (
|
|||
nhooyr.io/websocket v1.8.7 // indirect
|
||||
)
|
||||
|
||||
go 1.22
|
||||
go 1.23
|
||||
|
||||
toolchain go1.23.2
|
||||
|
|
|
@ -58,17 +58,26 @@ func WithAuth() AuthAPIOption {
|
|||
}
|
||||
}
|
||||
|
||||
// UserVerifier verifies users by their access tokens. Currently, there are two interface implementations:
|
||||
// DefaultUserVerifier and MSC3861UserVerifier. The first one checks if the token exists in the server's database,
|
||||
// whereas the latter passes the token for verification to MAS and acts in accordance with MAS's response.
|
||||
type UserVerifier interface {
|
||||
// VerifyUserFromRequest authenticates the HTTP request,
|
||||
// on success returns Device of the requester.
|
||||
VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse)
|
||||
}
|
||||
|
||||
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request.
|
||||
func MakeAuthAPI(
|
||||
metricsName string, userAPI userapi.QueryAcccessTokenAPI,
|
||||
metricsName string, userVerifier UserVerifier,
|
||||
f func(*http.Request, *userapi.Device) util.JSONResponse,
|
||||
checks ...AuthAPIOption,
|
||||
) http.Handler {
|
||||
h := func(req *http.Request) util.JSONResponse {
|
||||
logger := util.GetLogger(req.Context())
|
||||
device, err := auth.VerifyUserFromRequest(req, userAPI)
|
||||
device, err := userVerifier.VerifyUserFromRequest(req)
|
||||
if err != nil {
|
||||
logger.Debugf("VerifyUserFromRequest %s -> HTTP %d", req.RemoteAddr, err.Code)
|
||||
logger.Debugf("VerifyUserFromRequest %s -> HTTP %d: JSON %+v", req.RemoteAddr, err.Code, err.JSON)
|
||||
return *err
|
||||
}
|
||||
// add the user ID to the logger
|
||||
|
@ -122,11 +131,11 @@ func MakeAuthAPI(
|
|||
// MakeAdminAPI is a wrapper around MakeAuthAPI which enforces that the request can only be
|
||||
// completed by a user that is a server administrator.
|
||||
func MakeAdminAPI(
|
||||
metricsName string, userAPI userapi.QueryAcccessTokenAPI,
|
||||
metricsName string, userVerifier UserVerifier,
|
||||
f func(*http.Request, *userapi.Device) util.JSONResponse,
|
||||
) http.Handler {
|
||||
return MakeAuthAPI(metricsName, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if device.AccountType != userapi.AccountTypeAdmin {
|
||||
return MakeAuthAPI(metricsName, userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if device == nil || device.AccountType != userapi.AccountTypeAdmin {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("This API can only be used by admin users."),
|
||||
|
@ -136,6 +145,38 @@ func MakeAdminAPI(
|
|||
})
|
||||
}
|
||||
|
||||
// MakeServiceAdminAPI is a wrapper around MakeExternalAPI which enforces that the request can only be
|
||||
// completed by a trusted service e.g. Matrix Auth Service (MAS).
|
||||
func MakeServiceAdminAPI(
|
||||
metricsName, serviceToken string,
|
||||
f func(*http.Request) util.JSONResponse,
|
||||
) http.Handler {
|
||||
h := func(req *http.Request) util.JSONResponse {
|
||||
logger := util.GetLogger(req.Context())
|
||||
token, err := auth.ExtractAccessToken(req)
|
||||
|
||||
if err != nil {
|
||||
logger.Debugf("ExtractAccessToken %s -> HTTP %d", req.RemoteAddr, http.StatusUnauthorized)
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: spec.MissingToken(err.Error()),
|
||||
}
|
||||
}
|
||||
if token != serviceToken {
|
||||
logger.Debug("Invalid service token")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.UnknownToken(token),
|
||||
}
|
||||
}
|
||||
// add the service addr to the logger
|
||||
logger = logger.WithField("service_useragent", req.UserAgent())
|
||||
req = req.WithContext(util.ContextWithLogger(req.Context(), logger))
|
||||
return f(req)
|
||||
}
|
||||
return MakeExternalAPI(metricsName, h)
|
||||
}
|
||||
|
||||
// MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler.
|
||||
// This is used for APIs that are called from the internet.
|
||||
func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {
|
||||
|
@ -200,7 +241,7 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
|
|||
|
||||
// MakeHTTPAPI adds Span metrics to the HTML Handler function
|
||||
// This is used to serve HTML alongside JSON error messages
|
||||
func MakeHTTPAPI(metricsName string, userAPI userapi.QueryAcccessTokenAPI, enableMetrics bool, f func(http.ResponseWriter, *http.Request), checks ...AuthAPIOption) http.Handler {
|
||||
func MakeHTTPAPI(metricsName string, userVerifier UserVerifier, enableMetrics bool, f func(http.ResponseWriter, *http.Request), checks ...AuthAPIOption) http.Handler {
|
||||
withSpan := func(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Method == http.MethodOptions {
|
||||
util.SetCORSHeaders(w)
|
||||
|
@ -220,7 +261,7 @@ func MakeHTTPAPI(metricsName string, userAPI userapi.QueryAcccessTokenAPI, enabl
|
|||
|
||||
if opts.WithAuth {
|
||||
logger := util.GetLogger(req.Context())
|
||||
_, jsonErr := auth.VerifyUserFromRequest(req, userAPI)
|
||||
_, jsonErr := userVerifier.VerifyUserFromRequest(req)
|
||||
if jsonErr != nil {
|
||||
w.WriteHeader(jsonErr.Code)
|
||||
if err := json.NewEncoder(w).Encode(jsonErr.JSON); err != nil {
|
||||
|
|
|
@ -10,6 +10,8 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
func TestWrapHandlerInBasicAuth(t *testing.T) {
|
||||
|
@ -99,3 +101,68 @@ func TestWrapHandlerInBasicAuth(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeServiceAdminAPI(t *testing.T) {
|
||||
serviceToken := "valid_secret_token"
|
||||
type args struct {
|
||||
f func(*http.Request) util.JSONResponse
|
||||
serviceToken string
|
||||
}
|
||||
|
||||
f := func(*http.Request) util.JSONResponse {
|
||||
return util.JSONResponse{Code: http.StatusOK}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want int
|
||||
reqAuth bool
|
||||
}{
|
||||
{
|
||||
name: "service token valid",
|
||||
args: args{
|
||||
f: f,
|
||||
serviceToken: serviceToken,
|
||||
},
|
||||
want: http.StatusOK,
|
||||
reqAuth: true,
|
||||
},
|
||||
{
|
||||
name: "service token invalid",
|
||||
args: args{
|
||||
f: f,
|
||||
serviceToken: "invalid_service_token",
|
||||
},
|
||||
want: http.StatusForbidden,
|
||||
reqAuth: true,
|
||||
},
|
||||
{
|
||||
name: "service token is missing",
|
||||
args: args{
|
||||
f: f,
|
||||
serviceToken: "",
|
||||
},
|
||||
want: http.StatusUnauthorized,
|
||||
reqAuth: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
handler := MakeServiceAdminAPI("metrics", serviceToken, tt.args.f)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://localhost/admin/v1/username_available", nil)
|
||||
if tt.reqAuth {
|
||||
req.Header.Add("Authorization", "Bearer "+tt.args.serviceToken)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
resp := w.Result()
|
||||
|
||||
if resp.StatusCode != tt.want {
|
||||
t.Errorf("Expected status code %d, got %d", resp.StatusCode, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
"github.com/element-hq/dendrite/mediaapi/routing"
|
||||
"github.com/element-hq/dendrite/mediaapi/storage"
|
||||
"github.com/element-hq/dendrite/setup/config"
|
||||
userapi "github.com/element-hq/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||
)
|
||||
|
@ -24,10 +23,10 @@ func AddPublicRoutes(
|
|||
routers httputil.Routers,
|
||||
cm *sqlutil.Connections,
|
||||
cfg *config.Dendrite,
|
||||
userAPI userapi.MediaUserAPI,
|
||||
client *fclient.Client,
|
||||
fedClient fclient.FederationClient,
|
||||
keyRing gomatrixserverlib.JSONVerifier,
|
||||
userVerifier httputil.UserVerifier,
|
||||
) {
|
||||
mediaDB, err := storage.NewMediaAPIDatasource(cm, &cfg.MediaAPI.Database)
|
||||
if err != nil {
|
||||
|
@ -35,6 +34,6 @@ func AddPublicRoutes(
|
|||
}
|
||||
|
||||
routing.Setup(
|
||||
routers, cfg, mediaDB, userAPI, client, fedClient, keyRing,
|
||||
routers, cfg, mediaDB, client, fedClient, keyRing, userVerifier,
|
||||
)
|
||||
}
|
||||
|
|
|
@ -42,10 +42,10 @@ func Setup(
|
|||
routers httputil.Routers,
|
||||
cfg *config.Dendrite,
|
||||
db storage.Database,
|
||||
userAPI userapi.MediaUserAPI,
|
||||
client *fclient.Client,
|
||||
federationClient fclient.FederationClient,
|
||||
keyRing gomatrixserverlib.JSONVerifier,
|
||||
userVerifier httputil.UserVerifier,
|
||||
) {
|
||||
rateLimits := httputil.NewRateLimits(&cfg.ClientAPI.RateLimiting)
|
||||
|
||||
|
@ -58,7 +58,7 @@ func Setup(
|
|||
}
|
||||
|
||||
uploadHandler := httputil.MakeAuthAPI(
|
||||
"upload", userAPI,
|
||||
"upload", userVerifier,
|
||||
func(req *http.Request, dev *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req, dev); r != nil {
|
||||
return *r
|
||||
|
@ -67,7 +67,7 @@ func Setup(
|
|||
},
|
||||
)
|
||||
|
||||
configHandler := httputil.MakeAuthAPI("config", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
configHandler := httputil.MakeAuthAPI("config", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req, device); r != nil {
|
||||
return *r
|
||||
}
|
||||
|
@ -97,13 +97,13 @@ func Setup(
|
|||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
// v1 client endpoints requiring auth
|
||||
downloadHandlerAuthed := httputil.MakeHTTPAPI("download", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("download_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth())
|
||||
downloadHandlerAuthed := httputil.MakeHTTPAPI("download", userVerifier, cfg.Global.Metrics.Enabled, makeDownloadAPI("download_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth())
|
||||
v1mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions)
|
||||
v1mux.Handle("/download/{serverName}/{mediaId}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions)
|
||||
v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v1mux.Handle("/thumbnail/{serverName}/{mediaId}",
|
||||
httputil.MakeHTTPAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()),
|
||||
httputil.MakeHTTPAPI("thumbnail", userVerifier, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
// same, but for federation
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
"github.com/element-hq/dendrite/internal/sqlutil"
|
||||
"github.com/element-hq/dendrite/relayapi/storage/sqlite3"
|
||||
"github.com/element-hq/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
)
|
||||
|
||||
// NewDatabase opens a new database
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/element-hq/dendrite/clientapi/auth"
|
||||
"github.com/element-hq/dendrite/federationapi/statistics"
|
||||
"github.com/element-hq/dendrite/internal/caching"
|
||||
"github.com/element-hq/dendrite/internal/eventutil"
|
||||
|
@ -267,7 +268,8 @@ func TestPurgeRoom(t *testing.T) {
|
|||
rsAPI.SetFederationAPI(fsAPI, nil)
|
||||
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, fsAPI.IsBlacklistedOrBackingOff)
|
||||
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, &userVerifier, caching.DisableMetrics)
|
||||
|
||||
// Create the room
|
||||
if err = api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
||||
|
|
|
@ -74,34 +74,45 @@ func (c *ClientAPI) Defaults(opts DefaultOpts) {
|
|||
func (c *ClientAPI) Verify(configErrs *ConfigErrors) {
|
||||
c.TURN.Verify(configErrs)
|
||||
c.RateLimiting.Verify(configErrs)
|
||||
if c.RecaptchaEnabled {
|
||||
if c.RecaptchaSiteVerifyAPI == "" {
|
||||
c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify"
|
||||
|
||||
if c.MSCs.MSC3861Enabled() {
|
||||
if !c.RegistrationDisabled || c.RecaptchaEnabled {
|
||||
configErrs.Add(
|
||||
"You have enabled the experimental feature MSC3861 which implements the delegated authentication via OIDC. " +
|
||||
"As a result, the feature conflicts with the standard Dendrite's registration and login flows and cannot be used if any of those is enabled. " +
|
||||
"You need to disable registration (client_api.registration_disabled) and recapthca (client_api.enable_registration_captcha) options to proceed.",
|
||||
)
|
||||
}
|
||||
if c.RecaptchaApiJsUrl == "" {
|
||||
c.RecaptchaApiJsUrl = "https://www.google.com/recaptcha/api.js"
|
||||
} else {
|
||||
if c.RecaptchaEnabled {
|
||||
if c.RecaptchaSiteVerifyAPI == "" {
|
||||
c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify"
|
||||
}
|
||||
if c.RecaptchaApiJsUrl == "" {
|
||||
c.RecaptchaApiJsUrl = "https://www.google.com/recaptcha/api.js"
|
||||
}
|
||||
if c.RecaptchaFormField == "" {
|
||||
c.RecaptchaFormField = "g-recaptcha-response"
|
||||
}
|
||||
if c.RecaptchaSitekeyClass == "" {
|
||||
c.RecaptchaSitekeyClass = "g-recaptcha"
|
||||
}
|
||||
checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey)
|
||||
checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey)
|
||||
checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI)
|
||||
checkNotEmpty(configErrs, "client_api.recaptcha_sitekey_class", c.RecaptchaSitekeyClass)
|
||||
}
|
||||
if c.RecaptchaFormField == "" {
|
||||
c.RecaptchaFormField = "g-recaptcha-response"
|
||||
// Ensure there is any spam counter measure when enabling registration
|
||||
if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled && !c.RecaptchaEnabled {
|
||||
configErrs.Add(
|
||||
"You have tried to enable open registration without any secondary verification methods " +
|
||||
"(such as reCAPTCHA). By enabling open registration, you are SIGNIFICANTLY " +
|
||||
"increasing the risk that your server will be used to send spam or abuse, and may result in " +
|
||||
"your server being banned from some rooms. If you are ABSOLUTELY CERTAIN you want to do this, " +
|
||||
"start Dendrite with the -really-enable-open-registration command line flag. Otherwise, you " +
|
||||
"should set the registration_disabled option in your Dendrite config.",
|
||||
)
|
||||
}
|
||||
if c.RecaptchaSitekeyClass == "" {
|
||||
c.RecaptchaSitekeyClass = "g-recaptcha"
|
||||
}
|
||||
checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey)
|
||||
checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey)
|
||||
checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI)
|
||||
checkNotEmpty(configErrs, "client_api.recaptcha_sitekey_class", c.RecaptchaSitekeyClass)
|
||||
}
|
||||
// Ensure there is any spam counter measure when enabling registration
|
||||
if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled && !c.RecaptchaEnabled {
|
||||
configErrs.Add(
|
||||
"You have tried to enable open registration without any secondary verification methods " +
|
||||
"(such as reCAPTCHA). By enabling open registration, you are SIGNIFICANTLY " +
|
||||
"increasing the risk that your server will be used to send spam or abuse, and may result in " +
|
||||
"your server being banned from some rooms. If you are ABSOLUTELY CERTAIN you want to do this, " +
|
||||
"start Dendrite with the -really-enable-open-registration command line flag. Otherwise, you " +
|
||||
"should set the registration_disabled option in your Dendrite config.",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -4,11 +4,16 @@ type MSCs struct {
|
|||
Matrix *Global `yaml:"-"`
|
||||
|
||||
// The MSCs to enable. Supported MSCs include:
|
||||
// 'msc3861': Delegate auth to an OIDC provider - https://github.com/matrix-org/matrix-spec-proposals/pull/3861
|
||||
// 'msc2444': Peeking over federation - https://github.com/matrix-org/matrix-doc/pull/2444
|
||||
// 'msc2753': Peeking via /sync - https://github.com/matrix-org/matrix-doc/pull/2753
|
||||
// 'msc2836': Threading - https://github.com/matrix-org/matrix-doc/pull/2836
|
||||
MSCs []string `yaml:"mscs"`
|
||||
|
||||
// MSC3861 contains config related to the experimental feature MSC3861.
|
||||
// It takes effect only if 'msc3861' is included in 'MSCs' array.
|
||||
MSC3861 *MSC3861 `yaml:"msc3861,omitempty"`
|
||||
|
||||
Database DatabaseOptions `yaml:"database,omitempty"`
|
||||
}
|
||||
|
||||
|
@ -34,4 +39,27 @@ func (c *MSCs) Verify(configErrs *ConfigErrors) {
|
|||
if c.Matrix.DatabaseOptions.ConnectionString == "" {
|
||||
checkNotEmpty(configErrs, "mscs.database.connection_string", string(c.Database.ConnectionString))
|
||||
}
|
||||
if m := c.MSC3861; m != nil && c.MSC3861Enabled() {
|
||||
m.Verify(configErrs)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MSCs) MSC3861Enabled() bool {
|
||||
return c.Enabled("msc3861") && c.MSC3861 != nil
|
||||
}
|
||||
|
||||
type MSC3861 struct {
|
||||
Issuer string `yaml:"issuer"`
|
||||
ClientID string `yaml:"client_id"`
|
||||
ClientSecret string `yaml:"client_secret"`
|
||||
AdminToken string `yaml:"admin_token"`
|
||||
AccountManagementURL string `yaml:"account_management_url"`
|
||||
}
|
||||
|
||||
func (m *MSC3861) Verify(configErrs *ConfigErrors) {
|
||||
checkNotEmpty(configErrs, "mscs.msc3861.issuer", string(m.Issuer))
|
||||
checkNotEmpty(configErrs, "mscs.msc3861.client_id", string(m.ClientID))
|
||||
checkNotEmpty(configErrs, "mscs.msc3861.client_secret", string(m.ClientSecret))
|
||||
checkNotEmpty(configErrs, "mscs.msc3861.admin_token", string(m.AdminToken))
|
||||
checkNotEmpty(configErrs, "mscs.msc3861.account_management_url", string(m.AccountManagementURL))
|
||||
}
|
||||
|
|
|
@ -7,9 +7,12 @@
|
|||
package setup
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
appserviceAPI "github.com/element-hq/dendrite/appservice/api"
|
||||
"github.com/element-hq/dendrite/clientapi"
|
||||
"github.com/element-hq/dendrite/clientapi/api"
|
||||
"github.com/element-hq/dendrite/clientapi/auth"
|
||||
"github.com/element-hq/dendrite/federationapi"
|
||||
federationAPI "github.com/element-hq/dendrite/federationapi/api"
|
||||
"github.com/element-hq/dendrite/internal/caching"
|
||||
|
@ -27,6 +30,7 @@ import (
|
|||
userapi "github.com/element-hq/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// Monolith represents an instantiation of all dependencies required to build
|
||||
|
@ -46,6 +50,8 @@ type Monolith struct {
|
|||
// Optional
|
||||
ExtPublicRoomsProvider api.ExtraPublicRoomsProvider
|
||||
ExtUserDirectoryProvider userapi.QuerySearchProfilesAPI
|
||||
|
||||
UserVerifierProvider httputil.UserVerifier
|
||||
}
|
||||
|
||||
// AddAllPublicRoutes attaches all public paths to the given router
|
||||
|
@ -58,6 +64,10 @@ func (m *Monolith) AddAllPublicRoutes(
|
|||
caches *caching.Caches,
|
||||
enableMetrics bool,
|
||||
) {
|
||||
if m.UserVerifierProvider == nil {
|
||||
m.UserVerifierProvider = NewUserVerifierProvider(&auth.DefaultUserVerifier{UserAPI: m.UserAPI})
|
||||
}
|
||||
|
||||
userDirectoryProvider := m.ExtUserDirectoryProvider
|
||||
if userDirectoryProvider == nil {
|
||||
userDirectoryProvider = m.UserAPI
|
||||
|
@ -65,15 +75,29 @@ func (m *Monolith) AddAllPublicRoutes(
|
|||
clientapi.AddPublicRoutes(
|
||||
processCtx, routers, cfg, natsInstance, m.FedClient, m.RoomserverAPI, m.AppserviceAPI, transactions.New(),
|
||||
m.FederationAPI, m.UserAPI, userDirectoryProvider,
|
||||
m.ExtPublicRoomsProvider, enableMetrics,
|
||||
m.ExtPublicRoomsProvider, m.UserVerifierProvider, enableMetrics,
|
||||
)
|
||||
federationapi.AddPublicRoutes(
|
||||
processCtx, routers, cfg, natsInstance, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, enableMetrics,
|
||||
)
|
||||
mediaapi.AddPublicRoutes(routers, cm, cfg, m.UserAPI, m.Client, m.FedClient, m.KeyRing)
|
||||
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, enableMetrics)
|
||||
mediaapi.AddPublicRoutes(routers, cm, cfg, m.Client, m.FedClient, m.KeyRing, m.UserVerifierProvider)
|
||||
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, m.UserVerifierProvider, enableMetrics)
|
||||
|
||||
if m.RelayAPI != nil {
|
||||
relayapi.AddPublicRoutes(routers, cfg, m.KeyRing, m.RelayAPI)
|
||||
}
|
||||
}
|
||||
|
||||
type UserVerifierProvider struct {
|
||||
httputil.UserVerifier
|
||||
}
|
||||
|
||||
func (u *UserVerifierProvider) VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse) {
|
||||
return u.UserVerifier.VerifyUserFromRequest(req)
|
||||
}
|
||||
|
||||
func NewUserVerifierProvider(userVerifier httputil.UserVerifier) *UserVerifierProvider {
|
||||
return &UserVerifierProvider{
|
||||
UserVerifier: userVerifier,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -98,7 +98,7 @@ func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsRespons
|
|||
// Enable this MSC
|
||||
func Enable(
|
||||
cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.Routers, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI,
|
||||
userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier,
|
||||
userVerifier httputil.UserVerifier, keyRing gomatrixserverlib.JSONVerifier,
|
||||
) error {
|
||||
db, err := NewDatabase(cm, &cfg.MSCs.Database)
|
||||
if err != nil {
|
||||
|
@ -124,7 +124,7 @@ func Enable(
|
|||
})
|
||||
|
||||
routers.Client.Handle("/unstable/event_relationships",
|
||||
httputil.MakeAuthAPI("eventRelationships", userAPI, eventRelationshipHandler(db, rsAPI, fsAPI)),
|
||||
httputil.MakeAuthAPI("eventRelationships", userVerifier, eventRelationshipHandler(db, rsAPI, fsAPI)),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
routers.Federation.Handle("/unstable/event_relationships", httputil.MakeExternalAPI(
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/element-hq/dendrite/clientapi/auth"
|
||||
"github.com/element-hq/dendrite/setup/process"
|
||||
"github.com/element-hq/dendrite/syncapi/synctypes"
|
||||
"github.com/gorilla/mux"
|
||||
|
@ -571,7 +572,8 @@ func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserve
|
|||
processCtx := process.NewProcessContext()
|
||||
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||
routers := httputil.NewRouters()
|
||||
err := msc2836.Enable(cfg, cm, routers, rsAPI, nil, userAPI, nil)
|
||||
userVerifier := auth.DefaultUserVerifier{UserAPI: userAPI}
|
||||
err := msc2836.Enable(cfg, cm, routers, rsAPI, nil, &userVerifier, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to enable MSC2836: %s", err)
|
||||
}
|
||||
|
|
38
setup/mscs/msc3861/msc3861.go
Normal file
38
setup/mscs/msc3861/msc3861.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
package msc3861
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/element-hq/dendrite/setup"
|
||||
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||
)
|
||||
|
||||
func Enable(m *setup.Monolith) error {
|
||||
client := fclient.NewClient()
|
||||
userVerifier, err := newMSC3861UserVerifier(
|
||||
m.UserAPI, m.Config.Global.ServerName,
|
||||
m.Config.MSCs.MSC3861, !m.Config.ClientAPI.GuestsDisabled,
|
||||
client,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.UserVerifierProvider == nil {
|
||||
return errors.New("msc3861: UserVerifierProvider is not initialised")
|
||||
}
|
||||
|
||||
provider, ok := m.UserVerifierProvider.(*setup.UserVerifierProvider)
|
||||
if !ok {
|
||||
return errors.New("msc3861: the expected type of m.UserVerifierProvider is *setup.UserVerifierProvider")
|
||||
}
|
||||
|
||||
provider.UserVerifier = userVerifier
|
||||
|
||||
return nil
|
||||
}
|
458
setup/mscs/msc3861/msc3861_user_verifier.go
Normal file
458
setup/mscs/msc3861/msc3861_user_verifier.go
Normal file
|
@ -0,0 +1,458 @@
|
|||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
package msc3861
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/element-hq/dendrite/clientapi/auth"
|
||||
"github.com/element-hq/dendrite/setup/config"
|
||||
"github.com/element-hq/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
const externalAuthProvider string = "oauth-delegated"
|
||||
|
||||
// Scopes as defined by MSC2967
|
||||
// https://github.com/matrix-org/matrix-spec-proposals/pull/2967
|
||||
const (
|
||||
scopeMatrixAPI string = "urn:matrix:org.matrix.msc2967.client:api:*"
|
||||
scopeMatrixGuest string = "urn:matrix:org.matrix.msc2967.client:api:guest"
|
||||
scopeMatrixDevicePrefix string = "urn:matrix:org.matrix.msc2967.client:device:"
|
||||
)
|
||||
|
||||
type errCode string
|
||||
|
||||
const (
|
||||
codeIntrospectionNot2xx errCode = "introspectionIsNot2xx"
|
||||
codeInvalidClientToken errCode = "invalidClientToken"
|
||||
codeAuthError errCode = "authError"
|
||||
codeMxidError errCode = "mxidError"
|
||||
codeOpenidConfigEndpointNon2xx errCode = "openidConfigEndpointNon2xx"
|
||||
codeOpenidConfigDecodingFailed errCode = "openidConfigDecodingFailed"
|
||||
)
|
||||
|
||||
// MSC3861UserVerifier implements UserVerifier interface
|
||||
type MSC3861UserVerifier struct {
|
||||
userAPI api.UserInternalAPI
|
||||
serverName spec.ServerName
|
||||
cfg *config.MSC3861
|
||||
httpClient *fclient.Client
|
||||
openIdConfig *OpenIDConfiguration
|
||||
allowGuest bool
|
||||
}
|
||||
|
||||
func newMSC3861UserVerifier(
|
||||
userAPI api.UserInternalAPI,
|
||||
serverName spec.ServerName,
|
||||
cfg *config.MSC3861,
|
||||
allowGuest bool,
|
||||
client *fclient.Client,
|
||||
) (*MSC3861UserVerifier, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("unable to create MSC3861UserVerifier object as 'cfg' param is nil")
|
||||
}
|
||||
|
||||
if client == nil {
|
||||
return nil, errors.New("unable to create MSC3861UserVerifier object as 'client' param is nil")
|
||||
}
|
||||
|
||||
openIdConfig, err := fetchOpenIDConfiguration(client, cfg.Issuer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MSC3861UserVerifier{
|
||||
userAPI: userAPI,
|
||||
serverName: serverName,
|
||||
cfg: cfg,
|
||||
openIdConfig: openIdConfig,
|
||||
allowGuest: allowGuest,
|
||||
httpClient: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type mscError struct {
|
||||
Code errCode
|
||||
Msg string
|
||||
}
|
||||
|
||||
func (r *mscError) Error() string {
|
||||
return fmt.Sprintf("%s: %s", r.Code, r.Msg)
|
||||
}
|
||||
|
||||
// VerifyUserFromRequest authenticates the HTTP request, on success returns Device of the requester.
|
||||
func (m *MSC3861UserVerifier) VerifyUserFromRequest(req *http.Request) (*api.Device, *util.JSONResponse) {
|
||||
ctx := req.Context()
|
||||
util.GetLogger(ctx).Debug("MSC3861.VerifyUserFromRequest")
|
||||
// Try to find the Application Service user
|
||||
token, err := auth.ExtractAccessToken(req)
|
||||
if err != nil {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: spec.MissingToken(err.Error()),
|
||||
}
|
||||
}
|
||||
if appServiceUserID := req.URL.Query().Get("user_id"); appServiceUserID != "" {
|
||||
var res api.QueryAccessTokenResponse
|
||||
err = m.userAPI.QueryAccessToken(ctx, &api.QueryAccessTokenRequest{
|
||||
AccessToken: token,
|
||||
AppServiceUserID: appServiceUserID,
|
||||
}, &res)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("userAPI.QueryAccessToken failed")
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
userData, err := m.getUserByAccessToken(ctx, token)
|
||||
if err != nil {
|
||||
switch e := err.(type) {
|
||||
case (*mscError):
|
||||
switch e.Code {
|
||||
case codeIntrospectionNot2xx, codeOpenidConfigDecodingFailed, codeOpenidConfigEndpointNon2xx:
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusServiceUnavailable,
|
||||
JSON: spec.Unknown(e.Error()),
|
||||
}
|
||||
case codeInvalidClientToken:
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: spec.UnknownToken(e.Error()),
|
||||
}
|
||||
case codeAuthError, codeMxidError:
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.Unknown(e.Error()),
|
||||
}
|
||||
default:
|
||||
r := util.ErrorResponse(err)
|
||||
return nil, &r
|
||||
}
|
||||
default:
|
||||
r := util.ErrorResponse(err)
|
||||
return nil, &r
|
||||
}
|
||||
}
|
||||
|
||||
// Do not record requests from MAS using the virtual `__oidc_admin` user.
|
||||
if token != m.cfg.AdminToken {
|
||||
// XXX: not sure which exact data we should record here. See the link for reference
|
||||
// https://github.com/element-hq/synapse/blob/develop/synapse/api/auth/base.py#L365
|
||||
}
|
||||
|
||||
if !m.allowGuest && userData.IsGuest {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: spec.Forbidden(strings.Join([]string{"Insufficient scope: ", scopeMatrixAPI}, "")),
|
||||
}
|
||||
}
|
||||
|
||||
return userData.Device, nil
|
||||
}
|
||||
|
||||
type requester struct {
|
||||
Device *api.Device
|
||||
UserID *spec.UserID
|
||||
Scope []string
|
||||
IsGuest bool
|
||||
}
|
||||
|
||||
// nolint: gocyclo
|
||||
func (m *MSC3861UserVerifier) getUserByAccessToken(ctx context.Context, token string) (*requester, error) {
|
||||
var userID *spec.UserID
|
||||
logger := util.GetLogger(ctx)
|
||||
|
||||
if adminToken := m.cfg.AdminToken; adminToken != "" && token == adminToken {
|
||||
// XXX: This is a temporary solution so that the admin API can be called by
|
||||
// the OIDC provider. This will be removed once we have OIDC client
|
||||
// credentials grant support in matrix-authentication-service.
|
||||
// XXX: that user doesn't exist and won't be provisioned.
|
||||
adminUser, err := createUserID("__oidc_admin", m.serverName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &requester{
|
||||
UserID: adminUser,
|
||||
Scope: []string{"urn:synapse:admin:*"},
|
||||
Device: &api.Device{UserID: adminUser.Local(), AccountType: api.AccountTypeOIDCService},
|
||||
}, nil
|
||||
}
|
||||
|
||||
introspectionResult, err := m.introspectToken(ctx, token)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("MSC3861UserVerifier:introspectToken")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !introspectionResult.Active {
|
||||
return nil, &mscError{Code: codeInvalidClientToken, Msg: "Token is not active"}
|
||||
}
|
||||
|
||||
scopes := introspectionResult.Scopes()
|
||||
hasUserScope, hasGuestScope := slices.Contains(scopes, scopeMatrixAPI), slices.Contains(scopes, scopeMatrixGuest)
|
||||
if !hasUserScope && !hasGuestScope {
|
||||
return nil, &mscError{Code: codeInvalidClientToken, Msg: "No scope in token granting user rights"}
|
||||
}
|
||||
|
||||
sub := introspectionResult.Sub
|
||||
if sub == "" {
|
||||
return nil, &mscError{Code: codeInvalidClientToken, Msg: "Invalid sub claim in the introspection result"}
|
||||
}
|
||||
|
||||
localpart := ""
|
||||
localpartExternalID, err := m.userAPI.QueryExternalUserIDByLocalpartAndProvider(ctx, sub, externalAuthProvider)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, err
|
||||
}
|
||||
if localpartExternalID != nil {
|
||||
localpart = localpartExternalID.Localpart
|
||||
}
|
||||
|
||||
if localpart == "" {
|
||||
// If we could not find a user via the external_id, it either does not exist,
|
||||
// or the external_id was never recorded
|
||||
username := introspectionResult.Username
|
||||
if username == "" {
|
||||
return nil, &mscError{Code: codeAuthError, Msg: "Invalid username claim in the introspection result"}
|
||||
}
|
||||
userID, err = createUserID(username, m.serverName)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("getUserByAccessToken:createUserID")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// First try to find a user from the username claim
|
||||
var account *api.Account
|
||||
{
|
||||
var rs api.QueryAccountByLocalpartResponse
|
||||
err = m.userAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{Localpart: userID.Local(), ServerName: userID.Domain()}, &rs)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
logger.WithError(err).Error("QueryAccountByLocalpart")
|
||||
return nil, err
|
||||
}
|
||||
account = rs.Account
|
||||
}
|
||||
|
||||
if account == nil {
|
||||
// If the user does not exist, we should create it on the fly
|
||||
var rs api.PerformAccountCreationResponse
|
||||
if err = m.userAPI.PerformAccountCreation(ctx, &api.PerformAccountCreationRequest{
|
||||
AccountType: api.AccountTypeUser,
|
||||
Localpart: userID.Local(),
|
||||
ServerName: userID.Domain(),
|
||||
}, &rs); err != nil {
|
||||
logger.WithError(err).Error("PerformAccountCreation")
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err = m.userAPI.PerformLocalpartExternalUserIDCreation(ctx, userID.Local(), sub, externalAuthProvider); err != nil {
|
||||
logger.WithError(err).Error("PerformLocalpartExternalUserIDCreation")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
localpart = userID.Local()
|
||||
}
|
||||
|
||||
if userID == nil {
|
||||
userID, err = createUserID(localpart, m.serverName)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("getUserByAccessToken:createUserID")
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
deviceIDs := make([]string, 0, 1)
|
||||
for i := range scopes {
|
||||
if s := scopes[i]; strings.HasPrefix(s, scopeMatrixDevicePrefix) {
|
||||
deviceIDs = append(deviceIDs, s[len(scopeMatrixDevicePrefix):])
|
||||
}
|
||||
}
|
||||
|
||||
if len(deviceIDs) != 1 {
|
||||
logger.Errorf("Invalid device IDs in scope: %+v", deviceIDs)
|
||||
return nil, &mscError{Code: codeAuthError, Msg: "Invalid device IDs in scope"}
|
||||
}
|
||||
|
||||
var device *api.Device
|
||||
|
||||
deviceID := deviceIDs[0]
|
||||
if len(deviceID) > 255 || len(deviceID) < 1 {
|
||||
return nil, &mscError{
|
||||
Code: codeAuthError,
|
||||
Msg: strings.Join([]string{"Invalid device ID in scope: ", deviceID}, ""),
|
||||
}
|
||||
}
|
||||
|
||||
userDeviceExists := false
|
||||
{
|
||||
var rs api.QueryDevicesResponse
|
||||
err := m.userAPI.QueryDevices(ctx, &api.QueryDevicesRequest{UserID: userID.String()}, &rs)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range rs.Devices {
|
||||
if d := &rs.Devices[i]; d.ID == deviceID {
|
||||
userDeviceExists = true
|
||||
device = d
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !userDeviceExists {
|
||||
var rs api.PerformDeviceCreationResponse
|
||||
deviceDisplayName := "OIDC-native client"
|
||||
if err := m.userAPI.PerformDeviceCreation(ctx, &api.PerformDeviceCreationRequest{
|
||||
Localpart: localpart,
|
||||
ServerName: m.serverName,
|
||||
AccessToken: "",
|
||||
DeviceID: &deviceID,
|
||||
DeviceDisplayName: &deviceDisplayName,
|
||||
AccessTokenUniqueConstraintDisabled: true,
|
||||
// TODO: Cannot add IPAddr and Useragent values here. Should we care about it here?
|
||||
}, &rs); err != nil {
|
||||
logger.WithError(err).Error("PerformDeviceCreation")
|
||||
return nil, err
|
||||
}
|
||||
device = rs.Device
|
||||
logger.Debugf("PerformDeviceCreationResponse is: %+v", rs)
|
||||
}
|
||||
|
||||
return &requester{
|
||||
Device: device,
|
||||
UserID: userID,
|
||||
Scope: scopes,
|
||||
IsGuest: hasGuestScope && !hasUserScope,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func createUserID(local string, serverName spec.ServerName) (*spec.UserID, error) {
|
||||
userID, err := spec.NewUserID(strings.Join([]string{"@", local, ":", string(serverName)}, ""), false)
|
||||
if err != nil {
|
||||
return nil, &mscError{Code: codeMxidError, Msg: err.Error()}
|
||||
}
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func (m *MSC3861UserVerifier) introspectToken(ctx context.Context, token string) (*introspectionResponse, error) {
|
||||
formBody := url.Values{"token": []string{token}}
|
||||
encoded := formBody.Encode()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, m.openIdConfig.IntrospectionEndpoint, strings.NewReader(encoded))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.SetBasicAuth(m.cfg.ClientID, m.cfg.ClientSecret)
|
||||
|
||||
resp, err := m.httpClient.DoHTTPRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close() // nolint: errcheck
|
||||
|
||||
if c := resp.StatusCode; c/100 != 2 {
|
||||
return nil, errors.New(strings.Join([]string{"The introspection endpoint returned a '", resp.Status, "' response"}, ""))
|
||||
}
|
||||
var ir introspectionResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&ir); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ir, nil
|
||||
}
|
||||
|
||||
type OpenIDConfiguration struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
JWKsURI string `json:"jwks_uri"`
|
||||
RegistrationEndpoint string `json:"registration_endpoint"`
|
||||
ScopesSupported []string `json:"scopes_supported"`
|
||||
ResponseTypesSupported []string `json:"response_types_supported"`
|
||||
ResponseModesSupported []string `json:"response_modes_supported"`
|
||||
GrantTypesSupported []string `json:"grant_types_supported"`
|
||||
TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"`
|
||||
TokenEndpointAuthSigningAlgCaluesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported"`
|
||||
RevocationEnpoint string `json:"revocation_endpoint"`
|
||||
RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported"`
|
||||
RevocationEndpointAuthSigningAlgValues []string `json:"revocation_endpoint_auth_signing_alg_values_supported"`
|
||||
IntrospectionEndpoint string `json:"introspection_endpoint"`
|
||||
IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported"`
|
||||
IntrospectionEndpointAuthSigningAlgValues []string `json:"introspection_endpoint_auth_signing_alg_values_supported"`
|
||||
CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"`
|
||||
UserinfoEndpoint string `json:"userinfo_endpoint"`
|
||||
SubjectTypesSupported []string `json:"subject_types_supported"`
|
||||
IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"`
|
||||
UserinfoSigningAlgValuesSupported []string `json:"userinfo_signing_alg_values_supported"`
|
||||
DisplayValuesSupported []string `json:"display_values_supported"`
|
||||
ClaimTypesSupported []string `json:"claim_types_supported"`
|
||||
ClaimsSupported []string `json:"claims_supported"`
|
||||
ClaimsParameterSupported bool `json:"claims_parameter_supported"`
|
||||
RequestParameterSupported bool `json:"request_parameter_supported"`
|
||||
RequestURIParameterSupported bool `json:"request_uri_parameter_supported"`
|
||||
PromptValuesSupported []string `json:"prompt_values_supported"`
|
||||
DeviceAuthorizaEndpoint string `json:"device_authorization_endpoint"`
|
||||
AccountManagementURI string `json:"account_management_uri"`
|
||||
AccountManagementActionsSupported []string `json:"account_management_actions_supported"`
|
||||
}
|
||||
|
||||
func fetchOpenIDConfiguration(httpClient *fclient.Client, authHostURL string) (*OpenIDConfiguration, error) {
|
||||
u, err := url.Parse(authHostURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u = u.JoinPath(".well-known/openid-configuration")
|
||||
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := httpClient.DoHTTPRequest(context.Background(), req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close() // nolint: errcheck
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, &mscError{Code: codeOpenidConfigEndpointNon2xx, Msg: ".well-known/openid-configuration endpoint returned non-200 response"}
|
||||
}
|
||||
var oic OpenIDConfiguration
|
||||
if err := json.NewDecoder(resp.Body).Decode(&oic); err != nil {
|
||||
return nil, &mscError{Code: codeOpenidConfigDecodingFailed, Msg: err.Error()}
|
||||
}
|
||||
return &oic, nil
|
||||
}
|
||||
|
||||
// introspectionResponse as described in the RFC https://datatracker.ietf.org/doc/html/rfc7662#section-2.2
|
||||
type introspectionResponse struct {
|
||||
Active bool `json:"active"` // required
|
||||
Scope string `json:"scope"` // optional
|
||||
Username string `json:"username"` // optional
|
||||
TokenType string `json:"token_type"` // optional
|
||||
Exp *int64 `json:"exp"` // optional
|
||||
Iat *int64 `json:"iat"` // optional
|
||||
Nfb *int64 `json:"nfb"` // optional
|
||||
Sub string `json:"sub"` // optional
|
||||
Jti string `json:"jti"` // optional
|
||||
Aud string `json:"aud"` // optional
|
||||
Iss string `json:"iss"` // optional
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) Scopes() []string {
|
||||
return strings.Split(i.Scope, " ")
|
||||
}
|
234
setup/mscs/msc3861/msc3861_user_verifier_test.go
Normal file
234
setup/mscs/msc3861/msc3861_user_verifier_test.go
Normal file
|
@ -0,0 +1,234 @@
|
|||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
package msc3861
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"errors"
|
||||
|
||||
"github.com/element-hq/dendrite/federationapi/statistics"
|
||||
"github.com/element-hq/dendrite/internal/caching"
|
||||
"github.com/element-hq/dendrite/internal/sqlutil"
|
||||
"github.com/element-hq/dendrite/roomserver"
|
||||
"github.com/element-hq/dendrite/setup/config"
|
||||
"github.com/element-hq/dendrite/setup/jetstream"
|
||||
"github.com/element-hq/dendrite/test"
|
||||
"github.com/element-hq/dendrite/test/testrig"
|
||||
"github.com/element-hq/dendrite/userapi"
|
||||
uapi "github.com/element-hq/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
)
|
||||
|
||||
var testIsBlacklistedOrBackingOff = func(s spec.ServerName) (*statistics.ServerStatistics, error) {
|
||||
return &statistics.ServerStatistics{}, nil
|
||||
}
|
||||
|
||||
type roundTripper struct {
|
||||
roundTrip func(request *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func (rt *roundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||
return rt.roundTrip(request)
|
||||
}
|
||||
|
||||
func TestVerifyUserFromRequest(t *testing.T) {
|
||||
aliceUser := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
bobUser := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
|
||||
roundTrip := func(request *http.Request) (*http.Response, error) {
|
||||
var (
|
||||
respBody string
|
||||
statusCode int
|
||||
)
|
||||
|
||||
switch request.URL.String() {
|
||||
case "https://mas.example.com/.well-known/openid-configuration":
|
||||
respBody = `{"introspection_endpoint": "https://mas.example.com/oauth2/introspect"}`
|
||||
statusCode = http.StatusOK
|
||||
case "https://mas.example.com/oauth2/introspect":
|
||||
_ = request.ParseForm()
|
||||
|
||||
switch request.Form.Get("token") {
|
||||
case "validTokenUserExistsTokenActive":
|
||||
statusCode = http.StatusOK
|
||||
resp := introspectionResponse{
|
||||
Active: true,
|
||||
Scope: "urn:matrix:org.matrix.msc2967.client:device:devAlice urn:matrix:org.matrix.msc2967.client:api:*",
|
||||
Sub: "111111111111111111",
|
||||
Username: aliceUser.Localpart,
|
||||
}
|
||||
b, _ := json.Marshal(resp)
|
||||
respBody = string(b)
|
||||
case "validTokenUserDoesNotExistTokenActive":
|
||||
statusCode = http.StatusOK
|
||||
resp := introspectionResponse{
|
||||
Active: true,
|
||||
Scope: "urn:matrix:org.matrix.msc2967.client:device:devBob urn:matrix:org.matrix.msc2967.client:api:*",
|
||||
Sub: "222222222222222222",
|
||||
Username: bobUser.Localpart,
|
||||
}
|
||||
b, _ := json.Marshal(resp)
|
||||
respBody = string(b)
|
||||
case "validTokenUserExistsTokenInactive":
|
||||
statusCode = http.StatusOK
|
||||
resp := introspectionResponse{Active: false}
|
||||
b, _ := json.Marshal(resp)
|
||||
respBody = string(b)
|
||||
default:
|
||||
return nil, errors.New("Request URL not supported by stub")
|
||||
}
|
||||
}
|
||||
|
||||
respReader := io.NopCloser(strings.NewReader(respBody))
|
||||
resp := http.Response{
|
||||
StatusCode: statusCode,
|
||||
Body: respReader,
|
||||
ContentLength: int64(len(respBody)),
|
||||
Header: map[string][]string{"Content-Type": {"application/json"}},
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
httpClient := fclient.NewClient(
|
||||
fclient.WithTransport(&roundTripper{roundTrip: roundTrip}),
|
||||
)
|
||||
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||
defer close()
|
||||
cfg.ClientAPI.MSCs.MSC3861 = &config.MSC3861{
|
||||
Issuer: "https://mas.example.com",
|
||||
}
|
||||
cfg.ClientAPI.RateLimiting.Enabled = false
|
||||
natsInstance := jetstream.NATSInstance{}
|
||||
// add a vhost
|
||||
cfg.Global.VirtualHosts = append(cfg.Global.VirtualHosts, &config.VirtualHost{
|
||||
SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"},
|
||||
})
|
||||
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
// Needed for /login
|
||||
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
|
||||
userVerifier, err := newMSC3861UserVerifier(
|
||||
userAPI,
|
||||
cfg.Global.ServerName,
|
||||
cfg.MSCs.MSC3861,
|
||||
false,
|
||||
httpClient,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
u, _ := url.Parse("https://example.com/something")
|
||||
|
||||
t.Run("existing user and active token", func(t *testing.T) {
|
||||
localpart, serverName, _ := gomatrixserverlib.SplitID('@', aliceUser.ID)
|
||||
userRes := &uapi.PerformAccountCreationResponse{}
|
||||
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
||||
AccountType: aliceUser.AccountType,
|
||||
Localpart: localpart,
|
||||
ServerName: serverName,
|
||||
}, userRes); err != nil {
|
||||
t.Errorf("failed to create account: %s", err)
|
||||
}
|
||||
if !userRes.AccountCreated {
|
||||
t.Fatalf("account not created")
|
||||
}
|
||||
httpReq := http.Request{
|
||||
URL: u,
|
||||
Header: map[string][]string{
|
||||
"Content-Type": {"application/json"},
|
||||
"Authorization": {"Bearer validTokenUserExistsTokenActive"},
|
||||
},
|
||||
}
|
||||
device, jsonResp := userVerifier.VerifyUserFromRequest(&httpReq)
|
||||
if jsonResp != nil {
|
||||
t.Fatalf("JSONResponse is not expected: %+v", jsonResp)
|
||||
}
|
||||
deviceRes := uapi.QueryDevicesResponse{}
|
||||
if err := userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{
|
||||
UserID: aliceUser.ID,
|
||||
}, &deviceRes); err != nil {
|
||||
t.Errorf("failed to query user devices")
|
||||
}
|
||||
if !deviceRes.UserExists {
|
||||
t.Fatalf("user does not exist")
|
||||
}
|
||||
if l := len(deviceRes.Devices); l != 1 {
|
||||
t.Fatalf("Incorrect number of user devices. Got %d, want 1", l)
|
||||
}
|
||||
if device.ID != deviceRes.Devices[0].ID {
|
||||
t.Fatalf("Device IDs do not match: %s != %s", device.ID, deviceRes.Devices[0].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("inactive token", func(t *testing.T) {
|
||||
httpReq := http.Request{
|
||||
URL: u,
|
||||
Header: map[string][]string{
|
||||
"Content-Type": {"application/json"},
|
||||
"Authorization": {"Bearer validTokenUserExistsTokenInactive"},
|
||||
},
|
||||
}
|
||||
device, jsonResp := userVerifier.VerifyUserFromRequest(&httpReq)
|
||||
if jsonResp == nil {
|
||||
t.Fatal("JSONResponse is expected to be nil")
|
||||
}
|
||||
if device != nil {
|
||||
t.Fatalf("Device is not nil: %+v", device)
|
||||
}
|
||||
if jsonResp.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("Incorrect status code: want=401, got=%d", jsonResp.Code)
|
||||
}
|
||||
mErr, _ := jsonResp.JSON.(spec.MatrixError)
|
||||
if mErr.ErrCode != spec.ErrorUnknownToken {
|
||||
t.Fatalf("Unexpected error code: want=%s, got=%s", spec.ErrorUnknownToken, mErr.ErrCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-existing user", func(t *testing.T) {
|
||||
httpReq := http.Request{
|
||||
URL: u,
|
||||
Header: map[string][]string{
|
||||
"Content-Type": {"application/json"},
|
||||
"Authorization": {"Bearer validTokenUserDoesNotExistTokenActive"},
|
||||
},
|
||||
}
|
||||
device, jsonResp := userVerifier.VerifyUserFromRequest(&httpReq)
|
||||
if jsonResp != nil {
|
||||
t.Fatalf("JSONResponse is not expected: %+v", jsonResp)
|
||||
}
|
||||
deviceRes := uapi.QueryDevicesResponse{}
|
||||
if err := userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{
|
||||
UserID: bobUser.ID,
|
||||
}, &deviceRes); err != nil {
|
||||
t.Errorf("failed to query user devices")
|
||||
}
|
||||
if !deviceRes.UserExists {
|
||||
t.Fatalf("user does not exist")
|
||||
}
|
||||
if l := len(deviceRes.Devices); l != 1 {
|
||||
t.Fatalf("Incorrect number of user devices. Got %d, want 1", l)
|
||||
}
|
||||
if device.ID != deviceRes.Devices[0].ID {
|
||||
t.Fatalf("Device IDs do not match: %s != %s", device.ID, deviceRes.Devices[0].ID)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/element-hq/dendrite/setup"
|
||||
"github.com/element-hq/dendrite/setup/config"
|
||||
"github.com/element-hq/dendrite/setup/mscs/msc2836"
|
||||
"github.com/element-hq/dendrite/setup/mscs/msc3861"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
@ -34,9 +35,11 @@ func Enable(cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.Rout
|
|||
func EnableMSC(cfg *config.Dendrite, cm *sqlutil.Connections, routers httputil.Routers, monolith *setup.Monolith, msc string, caches *caching.Caches) error {
|
||||
switch msc {
|
||||
case "msc2836":
|
||||
return msc2836.Enable(cfg, cm, routers, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserAPI, monolith.KeyRing)
|
||||
return msc2836.Enable(cfg, cm, routers, monolith.RoomserverAPI, monolith.FederationAPI, monolith.UserVerifierProvider, monolith.KeyRing)
|
||||
case "msc2444": // enabled inside federationapi
|
||||
case "msc2753": // enabled inside clientapi
|
||||
case "msc3861":
|
||||
return msc3861.Enable(monolith)
|
||||
default:
|
||||
logrus.Warnf("EnableMSC: unknown MSC '%s', this MSC is either not supported or is natively supported by Dendrite", msc)
|
||||
}
|
||||
|
|
|
@ -36,16 +36,17 @@ func Setup(
|
|||
lazyLoadCache caching.LazyLoadCache,
|
||||
fts fulltext.Indexer,
|
||||
rateLimits *httputil.RateLimits,
|
||||
userVerifier httputil.UserVerifier,
|
||||
) {
|
||||
v1unstablemux := csMux.PathPrefix("/{apiversion:(?:v1|unstable)}/").Subrouter()
|
||||
v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
|
||||
|
||||
// TODO: Add AS support for all handlers below.
|
||||
v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return srp.OnIncomingSyncRequest(req, device)
|
||||
}, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
// not specced, but ensure we're rate limiting requests to this endpoint
|
||||
if r := rateLimits.Limit(req, device); r != nil {
|
||||
return *r
|
||||
|
@ -58,7 +59,7 @@ func Setup(
|
|||
}, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/event/{eventID}",
|
||||
httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
httputil.MakeAuthAPI("rooms_get_event", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
|
@ -68,7 +69,7 @@ func Setup(
|
|||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/user/{userId}/filter",
|
||||
httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
httputil.MakeAuthAPI("put_filter", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
|
@ -78,7 +79,7 @@ func Setup(
|
|||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/user/{userId}/filter/{filterId}",
|
||||
httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
httputil.MakeAuthAPI("get_filter", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
|
@ -87,12 +88,12 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return srp.OnIncomingKeyChangeRequest(req, device)
|
||||
}, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomId}/context/{eventId}",
|
||||
httputil.MakeAuthAPI("context", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
httputil.MakeAuthAPI("context", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
|
@ -108,7 +109,7 @@ func Setup(
|
|||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}",
|
||||
httputil.MakeAuthAPI("relations", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
httputil.MakeAuthAPI("relations", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
|
@ -122,7 +123,7 @@ func Setup(
|
|||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}",
|
||||
httputil.MakeAuthAPI("relation_type", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
httputil.MakeAuthAPI("relation_type", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
|
@ -136,7 +137,7 @@ func Setup(
|
|||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}",
|
||||
httputil.MakeAuthAPI("relation_type_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
httputil.MakeAuthAPI("relation_type_event", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
|
@ -150,7 +151,7 @@ func Setup(
|
|||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/search",
|
||||
httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
httputil.MakeAuthAPI("search", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if !cfg.Fulltext.Enabled {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusNotImplemented,
|
||||
|
@ -173,7 +174,7 @@ func Setup(
|
|||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/members",
|
||||
httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
httputil.MakeAuthAPI("rooms_members", userVerifier, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
|
|
|
@ -42,6 +42,7 @@ func AddPublicRoutes(
|
|||
userAPI userapi.SyncUserAPI,
|
||||
rsAPI api.SyncRoomserverAPI,
|
||||
caches caching.LazyLoadCache,
|
||||
userVerifier httputil.UserVerifier,
|
||||
enableMetrics bool,
|
||||
) {
|
||||
js, natsClient := natsInstance.Prepare(processContext, &dendriteCfg.Global.JetStream)
|
||||
|
@ -149,5 +150,6 @@ func AddPublicRoutes(
|
|||
routers.Client, requestPool, syncDB, userAPI,
|
||||
rsAPI, &dendriteCfg.SyncAPI, caches, fts,
|
||||
rateLimits,
|
||||
userVerifier,
|
||||
)
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"github.com/gorilla/mux"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
|
@ -119,6 +120,20 @@ func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.Pe
|
|||
return nil
|
||||
}
|
||||
|
||||
type mockUserVerifier struct {
|
||||
accessTokenToDeviceAndResponse map[string]struct {
|
||||
Device *userapi.Device
|
||||
Response *util.JSONResponse
|
||||
}
|
||||
}
|
||||
|
||||
func (u *mockUserVerifier) VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse) {
|
||||
if pair, ok := u.accessTokenToDeviceAndResponse[req.URL.Query().Get("access_token")]; ok {
|
||||
return pair.Device, pair.Response
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestSyncAPIAccessTokens(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
testSyncAccessTokens(t, dbType)
|
||||
|
@ -146,12 +161,16 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
|
|||
jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream)
|
||||
defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream)
|
||||
msgs := toNATSMsgs(t, cfg, room.Events()...)
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics)
|
||||
uv := &mockUserVerifier{}
|
||||
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, uv, caching.DisableMetrics)
|
||||
testrig.MustPublishMsgs(t, jsctx, msgs...)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
req *http.Request
|
||||
device *userapi.Device
|
||||
response *util.JSONResponse
|
||||
wantCode int
|
||||
wantJoinedRooms []string
|
||||
}{
|
||||
|
@ -160,6 +179,11 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
|
|||
req: test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
|
||||
"timeout": "0",
|
||||
})),
|
||||
device: nil,
|
||||
response: &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: spec.UnknownToken("Unknown token"),
|
||||
},
|
||||
wantCode: 401,
|
||||
},
|
||||
{
|
||||
|
@ -168,6 +192,11 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
|
|||
"access_token": "foo",
|
||||
"timeout": "0",
|
||||
})),
|
||||
device: nil,
|
||||
response: &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: spec.UnknownToken("Unknown token"),
|
||||
},
|
||||
wantCode: 401,
|
||||
},
|
||||
{
|
||||
|
@ -176,11 +205,25 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
|
|||
"access_token": alice.AccessToken,
|
||||
"timeout": "0",
|
||||
})),
|
||||
device: &alice,
|
||||
response: nil,
|
||||
wantCode: 200,
|
||||
wantJoinedRooms: []string{room.ID},
|
||||
},
|
||||
}
|
||||
|
||||
uv.accessTokenToDeviceAndResponse = make(map[string]struct {
|
||||
Device *userapi.Device
|
||||
Response *util.JSONResponse
|
||||
}, len(testCases))
|
||||
for _, tc := range testCases {
|
||||
|
||||
uv.accessTokenToDeviceAndResponse[tc.req.URL.Query().Get("access_token")] = struct {
|
||||
Device *userapi.Device
|
||||
Response *util.JSONResponse
|
||||
}{Device: tc.device, Response: tc.response}
|
||||
}
|
||||
|
||||
syncUntil(t, routers, alice.AccessToken, false, func(syncBody string) bool {
|
||||
// wait for the last sent eventID to come down sync
|
||||
path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID())
|
||||
|
@ -241,12 +284,20 @@ func testSyncEventFormatPowerLevels(t *testing.T, dbType test.DBType) {
|
|||
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||
natsInstance := jetstream.NATSInstance{}
|
||||
uv := mockUserVerifier{
|
||||
accessTokenToDeviceAndResponse: map[string]struct {
|
||||
Device *userapi.Device
|
||||
Response *util.JSONResponse
|
||||
}{
|
||||
alice.AccessToken: {Device: &alice, Response: nil},
|
||||
},
|
||||
}
|
||||
defer close()
|
||||
|
||||
jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream)
|
||||
defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream)
|
||||
msgs := toNATSMsgs(t, cfg, room.Events()...)
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, &uv, caching.DisableMetrics)
|
||||
testrig.MustPublishMsgs(t, jsctx, msgs...)
|
||||
|
||||
testCases := []struct {
|
||||
|
@ -399,7 +450,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
|
|||
// m.room.history_visibility
|
||||
msgs := toNATSMsgs(t, cfg, room.Events()...)
|
||||
sinceTokens := make([]string, len(msgs))
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, nil, caching.DisableMetrics)
|
||||
for i, msg := range msgs {
|
||||
testrig.MustPublishMsgs(t, jsctx, msg)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
@ -487,7 +538,15 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) {
|
|||
|
||||
jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream)
|
||||
defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream)
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, caching.DisableMetrics)
|
||||
uv := mockUserVerifier{
|
||||
accessTokenToDeviceAndResponse: map[string]struct {
|
||||
Device *userapi.Device
|
||||
Response *util.JSONResponse
|
||||
}{
|
||||
alice.AccessToken: {Device: &alice, Response: nil},
|
||||
},
|
||||
}
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, &uv, caching.DisableMetrics)
|
||||
w := httptest.NewRecorder()
|
||||
routers.Client.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
|
||||
"access_token": alice.AccessToken,
|
||||
|
@ -609,7 +668,16 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
|
|||
// Use the actual internal roomserver API
|
||||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, caching.DisableMetrics)
|
||||
uv := mockUserVerifier{
|
||||
accessTokenToDeviceAndResponse: map[string]struct {
|
||||
Device *userapi.Device
|
||||
Response *util.JSONResponse
|
||||
}{
|
||||
aliceDev.AccessToken: {Device: &aliceDev, Response: nil},
|
||||
bobDev.AccessToken: {Device: &bobDev, Response: nil},
|
||||
},
|
||||
}
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, &uv, caching.DisableMetrics)
|
||||
|
||||
for _, tc := range testCases {
|
||||
testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType)
|
||||
|
@ -878,8 +946,17 @@ func TestGetMembership(t *testing.T) {
|
|||
// Use an actual roomserver for this
|
||||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
uv := mockUserVerifier{
|
||||
accessTokenToDeviceAndResponse: map[string]struct {
|
||||
Device *userapi.Device
|
||||
Response *util.JSONResponse
|
||||
}{
|
||||
aliceDev.AccessToken: {Device: &aliceDev, Response: nil},
|
||||
bobDev.AccessToken: {Device: &bobDev, Response: nil},
|
||||
},
|
||||
}
|
||||
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, caches, &uv, caching.DisableMetrics)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
@ -946,10 +1023,18 @@ func testSendToDevice(t *testing.T, dbType test.DBType) {
|
|||
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||
defer close()
|
||||
natsInstance := jetstream.NATSInstance{}
|
||||
uv := mockUserVerifier{
|
||||
accessTokenToDeviceAndResponse: map[string]struct {
|
||||
Device *userapi.Device
|
||||
Response *util.JSONResponse
|
||||
}{
|
||||
alice.AccessToken: {Device: &alice, Response: nil},
|
||||
},
|
||||
}
|
||||
|
||||
jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream)
|
||||
defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream)
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, caches, &uv, caching.DisableMetrics)
|
||||
|
||||
producer := producers.SyncAPIProducer{
|
||||
TopicSendToDeviceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
|
||||
|
@ -1172,7 +1257,16 @@ func testContext(t *testing.T, dbType test.DBType) {
|
|||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, caches, caching.DisableMetrics)
|
||||
uv := mockUserVerifier{
|
||||
accessTokenToDeviceAndResponse: map[string]struct {
|
||||
Device *userapi.Device
|
||||
Response *util.JSONResponse
|
||||
}{
|
||||
alice.AccessToken: {Device: &alice, Response: nil},
|
||||
},
|
||||
}
|
||||
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, caches, &uv, caching.DisableMetrics)
|
||||
|
||||
room := test.NewRoom(t, user)
|
||||
|
||||
|
@ -1351,9 +1445,17 @@ func TestRemoveEditedEventFromSearchIndex(t *testing.T) {
|
|||
|
||||
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
uv := mockUserVerifier{
|
||||
accessTokenToDeviceAndResponse: map[string]struct {
|
||||
Device *userapi.Device
|
||||
Response *util.JSONResponse
|
||||
}{
|
||||
alice.AccessToken: {Device: &alice, Response: nil},
|
||||
},
|
||||
}
|
||||
|
||||
room := test.NewRoom(t, user)
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, caching.DisableMetrics)
|
||||
AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, caches, &uv, caching.DisableMetrics)
|
||||
|
||||
if err := api.SendEvents(processCtx.Context(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
||||
t.Fatalf("failed to send events: %v", err)
|
||||
|
@ -1416,6 +1518,7 @@ func searchRequest(t *testing.T, router *mux.Router, accessToken, searchTerm str
|
|||
assert.NoError(t, err)
|
||||
return body
|
||||
}
|
||||
|
||||
func syncUntil(t *testing.T,
|
||||
routers httputil.Routers, accessToken string,
|
||||
skip bool,
|
||||
|
|
|
@ -31,7 +31,8 @@ type UserInternalAPI interface {
|
|||
FederationUserAPI
|
||||
|
||||
QuerySearchProfilesAPI // used by p2p demos
|
||||
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
|
||||
QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, externalID, authProvider string) (*LocalpartExternalID, error)
|
||||
PerformLocalpartExternalUserIDCreation(ctx context.Context, localpart, externalID, authProvider string) error
|
||||
}
|
||||
|
||||
// api functions required by the appservice api
|
||||
|
@ -47,7 +48,7 @@ type RoomserverUserAPI interface {
|
|||
|
||||
// api functions required by the media api
|
||||
type MediaUserAPI interface {
|
||||
QueryAcccessTokenAPI
|
||||
QueryAccessTokenAPI
|
||||
}
|
||||
|
||||
// api functions required by the federation api
|
||||
|
@ -64,7 +65,7 @@ type FederationUserAPI interface {
|
|||
|
||||
// api functions required by the sync api
|
||||
type SyncUserAPI interface {
|
||||
QueryAcccessTokenAPI
|
||||
QueryAccessTokenAPI
|
||||
SyncKeyAPI
|
||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||
PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error
|
||||
|
@ -75,7 +76,7 @@ type SyncUserAPI interface {
|
|||
|
||||
// api functions required by the client api
|
||||
type ClientUserAPI interface {
|
||||
QueryAcccessTokenAPI
|
||||
QueryAccessTokenAPI
|
||||
LoginTokenInternalAPI
|
||||
UserLoginAPI
|
||||
ClientKeyAPI
|
||||
|
@ -87,6 +88,7 @@ type ClientUserAPI interface {
|
|||
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
|
||||
QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error)
|
||||
QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error
|
||||
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
|
||||
PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error)
|
||||
PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
|
||||
PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error)
|
||||
|
@ -109,6 +111,7 @@ type ClientUserAPI interface {
|
|||
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
|
||||
PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error
|
||||
PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error
|
||||
PerformBulkSaveThreePIDAssociation(ctx context.Context, req *PerformBulkSaveThreePIDAssociationRequest, res *struct{}) error
|
||||
}
|
||||
|
||||
type KeyBackupAPI interface {
|
||||
|
@ -130,7 +133,7 @@ type QuerySearchProfilesAPI interface {
|
|||
}
|
||||
|
||||
// common function for creating authenticated endpoints (used in client/media/sync api)
|
||||
type QueryAcccessTokenAPI interface {
|
||||
type QueryAccessTokenAPI interface {
|
||||
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
|
||||
}
|
||||
|
||||
|
@ -316,6 +319,9 @@ type PerformAccountCreationRequest struct {
|
|||
Localpart string // Required: The localpart for this account. Ignored if account type is guest.
|
||||
ServerName spec.ServerName // optional: if not specified, default server name used instead
|
||||
|
||||
DisplayName string // optional: this is populated only by MAS. In the legacy flow it's not used
|
||||
AvatarURL string // optional: this is populated only by MAS. In the legacy flow it's not used
|
||||
|
||||
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
|
||||
Password string // optional: if missing then this account will be a passwordless account
|
||||
OnConflict Conflict
|
||||
|
@ -375,6 +381,11 @@ type PerformDeviceCreationRequest struct {
|
|||
// FromRegistration determines if this request comes from registering a new account
|
||||
// and is in most cases false.
|
||||
FromRegistration bool
|
||||
|
||||
// AccessTokenUniqueConstraintDisabled determines if unique constraint is applicable for the AccessToken.
|
||||
// It is false if an external auth service is in use (e.g. MAS) and server does not generate its own
|
||||
// auth tokens. Otherwise, if traditional login is in use, the value is true. Default is false.
|
||||
AccessTokenUniqueConstraintDisabled bool
|
||||
}
|
||||
|
||||
// PerformDeviceCreationResponse is the response for PerformDeviceCreation
|
||||
|
@ -455,6 +466,7 @@ type Account struct {
|
|||
ServerName spec.ServerName
|
||||
AppServiceID string
|
||||
AccountType AccountType
|
||||
Deactivated bool
|
||||
// TODO: Associations (e.g. with application services)
|
||||
}
|
||||
|
||||
|
@ -471,6 +483,14 @@ type OpenIDTokenAttributes struct {
|
|||
ExpiresAtMS int64
|
||||
}
|
||||
|
||||
// LocalpartExternalID represents a connection between Matrix account and OpenID Connect provider
|
||||
type LocalpartExternalID struct {
|
||||
Localpart string
|
||||
ExternalID string
|
||||
AuthProvider string
|
||||
CreatedTS int64
|
||||
}
|
||||
|
||||
// UserInfo is for returning information about the user an OpenID token was issued for
|
||||
type UserInfo struct {
|
||||
Sub string // The Matrix user's ID who generated the token
|
||||
|
@ -514,6 +534,8 @@ const (
|
|||
AccountTypeAdmin AccountType = 3
|
||||
// AccountTypeAppService indicates this is an appservice account
|
||||
AccountTypeAppService AccountType = 4
|
||||
// AccountTypeOIDC indicates this is an account belonging to Matrix Authentication Service (MAS)
|
||||
AccountTypeOIDCService AccountType = 5
|
||||
)
|
||||
|
||||
type QueryPushersRequest struct {
|
||||
|
@ -636,6 +658,12 @@ type PerformSaveThreePIDAssociationRequest struct {
|
|||
Medium string
|
||||
}
|
||||
|
||||
type PerformBulkSaveThreePIDAssociationRequest struct {
|
||||
ThreePIDs []authtypes.ThreePID
|
||||
Localpart string
|
||||
ServerName spec.ServerName
|
||||
}
|
||||
|
||||
type QueryAccountByLocalpartRequest struct {
|
||||
Localpart string
|
||||
ServerName spec.ServerName
|
||||
|
|
|
@ -455,10 +455,11 @@ func (a *UserInternalAPI) processOtherSignatures(
|
|||
func (a *UserInternalAPI) crossSigningKeysFromDatabase(
|
||||
ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse,
|
||||
) {
|
||||
logger := logrus.WithContext(ctx)
|
||||
for targetUserID := range req.UserToDevices {
|
||||
keys, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID)
|
||||
logger.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID)
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -471,7 +472,7 @@ func (a *UserInternalAPI) crossSigningKeysFromDatabase(
|
|||
|
||||
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID)
|
||||
logger.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID)
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
|
@ -272,7 +272,7 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque
|
|||
DeviceIDs: dids,
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing")
|
||||
util.GetLogger(ctx).WithError(err).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing")
|
||||
}
|
||||
|
||||
if res.DeviceKeys[userID] == nil {
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
@ -247,10 +248,17 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
|||
return nil
|
||||
}
|
||||
|
||||
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil {
|
||||
displayName := cmp.Or(req.DisplayName, req.Localpart)
|
||||
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, displayName); err != nil {
|
||||
return fmt.Errorf("a.DB.SetDisplayName: %w", err)
|
||||
}
|
||||
|
||||
if req.AvatarURL != "" {
|
||||
if _, _, err := a.DB.SetAvatarURL(ctx, req.Localpart, serverName, req.AvatarURL); err != nil {
|
||||
return fmt.Errorf("a.DB.SetAvatarURL: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
postRegisterJoinRooms(a.Config, acc, a.RSAPI)
|
||||
|
||||
res.AccountCreated = true
|
||||
|
@ -298,6 +306,15 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
|
|||
"device_id": req.DeviceID,
|
||||
"display_name": req.DeviceDisplayName,
|
||||
}).Info("PerformDeviceCreation")
|
||||
if !req.AccessTokenUniqueConstraintDisabled {
|
||||
dev, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
if dev.UserID != "" {
|
||||
return errors.New("unique constraint violation. Access token is not unique" + dev.AccessToken)
|
||||
}
|
||||
}
|
||||
dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -594,6 +611,14 @@ func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.
|
|||
return
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformLocalpartExternalUserIDCreation(ctx context.Context, localpart, externalID, authProvider string) (err error) {
|
||||
return a.DB.CreateLocalpartExternalID(ctx, localpart, externalID, authProvider)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryExternalUserIDByLocalpartAndProvider(ctx context.Context, externalID, authProvider string) (*api.LocalpartExternalID, error) {
|
||||
return a.DB.GetLocalpartForExternalID(ctx, externalID, authProvider)
|
||||
}
|
||||
|
||||
// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem
|
||||
// creating a 'device'.
|
||||
func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) {
|
||||
|
@ -970,4 +995,8 @@ func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, re
|
|||
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformBulkSaveThreePIDAssociation(ctx context.Context, req *api.PerformBulkSaveThreePIDAssociationRequest, res *struct{}) error {
|
||||
return a.DB.BulkSaveThreePIDAssociation(ctx, req.ThreePIDs, req.Localpart, req.ServerName)
|
||||
}
|
||||
|
||||
const pushRulesAccountDataType = "m.push_rules"
|
||||
|
|
|
@ -119,6 +119,7 @@ type Pusher interface {
|
|||
|
||||
type ThreePID interface {
|
||||
SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName spec.ServerName, medium string) (err error)
|
||||
BulkSaveThreePIDAssociation(ctx context.Context, threePIDs []authtypes.ThreePID, localpart string, serverName spec.ServerName) (err error)
|
||||
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
||||
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName spec.ServerName, err error)
|
||||
GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (threepids []authtypes.ThreePID, err error)
|
||||
|
@ -134,6 +135,12 @@ type Notification interface {
|
|||
DeleteOldNotifications(ctx context.Context) error
|
||||
}
|
||||
|
||||
type LocalpartExternalID interface {
|
||||
CreateLocalpartExternalID(ctx context.Context, localpart, externalID, authProvider string) error
|
||||
GetLocalpartForExternalID(ctx context.Context, externalID, authProvider string) (*api.LocalpartExternalID, error)
|
||||
DeleteLocalpartExternalID(ctx context.Context, externalID, authProvider string) error
|
||||
}
|
||||
|
||||
type UserDatabase interface {
|
||||
Account
|
||||
AccountData
|
||||
|
@ -147,6 +154,7 @@ type UserDatabase interface {
|
|||
Statistics
|
||||
ThreePID
|
||||
RegistrationTokens
|
||||
LocalpartExternalID
|
||||
}
|
||||
|
||||
type KeyChangeDatabase interface {
|
||||
|
|
|
@ -55,7 +55,7 @@ const deactivateAccountSQL = "" +
|
|||
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectAccountByLocalpartSQL = "" +
|
||||
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
|
||||
"SELECT localpart, server_name, appservice_id, account_type, is_deactivated FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectPasswordHashSQL = "" +
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE"
|
||||
|
@ -116,7 +116,7 @@ func (s *accountsStatements) InsertAccount(
|
|||
localpart string, serverName spec.ServerName,
|
||||
hash, appserviceID string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
createdTimeMS := spec.AsTimestamp(time.Now())
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||
|
||||
var err error
|
||||
|
@ -135,6 +135,7 @@ func (s *accountsStatements) InsertAccount(
|
|||
ServerName: serverName,
|
||||
AppServiceID: appserviceID,
|
||||
AccountType: accountType,
|
||||
Deactivated: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -167,7 +168,7 @@ func (s *accountsStatements) SelectAccountByLocalpart(
|
|||
var acc api.Account
|
||||
|
||||
stmt := s.selectAccountByLocalpartStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
|
||||
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType, &acc.Deactivated)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||
|
|
|
@ -51,6 +51,11 @@ func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, erro
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := sqlutil.NewMigrator(db)
|
||||
err = m.Up(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
|
||||
{&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func UpDropPrimaryKeyConstraint(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
ALTER TABLE userapi_devices DROP CONSTRAINT userapi_devices_pkey;`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownDropPrimaryKeyConstraint(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
ALTER TABLE userapi_devices ADD CONSTRAINT userapi_devices_pkey PRIMARY KEY (access_token);`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -116,10 +116,16 @@ func NewPostgresDevicesTable(db *sql.DB, serverName spec.ServerName) (tables.Dev
|
|||
return nil, err
|
||||
}
|
||||
m := sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: add last_seen_ts",
|
||||
Up: deltas.UpLastSeenTSIP,
|
||||
})
|
||||
m.AddMigrations(
|
||||
sqlutil.Migration{
|
||||
Version: "userapi: add last_seen_ts",
|
||||
Up: deltas.UpLastSeenTSIP,
|
||||
},
|
||||
sqlutil.Migration{
|
||||
Version: "userapi: drop primary key constraint",
|
||||
Up: deltas.UpDropPrimaryKeyConstraint,
|
||||
},
|
||||
)
|
||||
err = m.Up(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
102
userapi/storage/postgres/localpart_external_ids_table.go
Normal file
102
userapi/storage/postgres/localpart_external_ids_table.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/element-hq/dendrite/internal/sqlutil"
|
||||
"github.com/element-hq/dendrite/userapi/api"
|
||||
"github.com/element-hq/dendrite/userapi/storage/tables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const localpartExternalIDsSchema = `
|
||||
-- Stores data about connections between accounts and third-party auth providers
|
||||
CREATE TABLE IF NOT EXISTS userapi_localpart_external_ids (
|
||||
-- The Matrix user ID for this account
|
||||
localpart TEXT NOT NULL,
|
||||
-- The external ID
|
||||
external_id TEXT NOT NULL,
|
||||
-- Auth provider ID (see OIDCProvider.IDPID)
|
||||
auth_provider TEXT NOT NULL,
|
||||
-- When this connection was created, as a unix timestamp.
|
||||
created_ts BIGINT NOT NULL,
|
||||
|
||||
CONSTRAINT userapi_localpart_external_ids_external_id_auth_provider_unique UNIQUE(external_id, auth_provider),
|
||||
CONSTRAINT userapi_localpart_external_ids_localpart_external_id_auth_provider_unique UNIQUE(localpart, external_id, auth_provider)
|
||||
);
|
||||
|
||||
-- This index allows efficient lookup of the local user by the external ID
|
||||
CREATE INDEX IF NOT EXISTS userapi_external_id_auth_provider_idx ON userapi_localpart_external_ids(external_id, auth_provider);
|
||||
`
|
||||
|
||||
const insertUserExternalIDSQL = "" +
|
||||
"INSERT INTO userapi_localpart_external_ids(localpart, external_id, auth_provider, created_ts) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const selectUserExternalIDSQL = "" +
|
||||
"SELECT localpart, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2"
|
||||
|
||||
const deleteUserExternalIDSQL = "" +
|
||||
"DELETE FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2"
|
||||
|
||||
type localpartExternalIDStatements struct {
|
||||
db *sql.DB
|
||||
insertUserExternalIDStmt *sql.Stmt
|
||||
selectUserExternalIDStmt *sql.Stmt
|
||||
deleteUserExternalIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresLocalpartExternalIDsTable(db *sql.DB) (tables.LocalpartExternalIDsTable, error) {
|
||||
s := &localpartExternalIDStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(localpartExternalIDsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertUserExternalIDStmt, insertUserExternalIDSQL},
|
||||
{&s.selectUserExternalIDStmt, selectUserExternalIDSQL},
|
||||
{&s.deleteUserExternalIDStmt, deleteUserExternalIDSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
// SelectLocalExternalPartID selects an existing OpenID Connect connection from the database
|
||||
func (u *localpartExternalIDStatements) SelectLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) {
|
||||
ret := api.LocalpartExternalID{
|
||||
ExternalID: externalID,
|
||||
AuthProvider: authProvider,
|
||||
}
|
||||
err := u.selectUserExternalIDStmt.QueryRowContext(ctx, externalID, authProvider).Scan(
|
||||
&ret.Localpart, &ret.CreatedTS,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
log.WithError(err).Error("Unable to retrieve localpart from the db")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
// InsertLocalExternalPartID creates a new record representing an OpenID Connect connection between Matrix and external accounts.
|
||||
func (u *localpartExternalIDStatements) InsertLocalExternalPartID(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error {
|
||||
stmt := sqlutil.TxStmt(txn, u.insertUserExternalIDStmt)
|
||||
_, err := stmt.ExecContext(ctx, localpart, externalID, authProvider, time.Now().Unix())
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteLocalExternalPartID deletes the existing OpenID Connect connection. After this method is called, the Matrix account will no longer be associated with the external account.
|
||||
func (u *localpartExternalIDStatements) DeleteLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error {
|
||||
stmt := sqlutil.TxStmt(txn, u.deleteUserExternalIDStmt)
|
||||
_, err := stmt.ExecContext(ctx, externalID, authProvider)
|
||||
return err
|
||||
}
|
|
@ -97,6 +97,10 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresStatsTable: %w", err)
|
||||
}
|
||||
localpartExternalIDsTable, err := NewPostgresLocalpartExternalIDsTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteLocalpartExternalIDsTable: %w", err)
|
||||
}
|
||||
|
||||
m = sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
|
@ -123,6 +127,7 @@ func NewDatabase(ctx context.Context, conMan *sqlutil.Connections, dbProperties
|
|||
Notifications: notificationsTable,
|
||||
RegistrationTokens: registationTokensTable,
|
||||
Stats: statsTable,
|
||||
LocalpartExternalIDs: localpartExternalIDsTable,
|
||||
ServerName: serverName,
|
||||
DB: db,
|
||||
Writer: writer,
|
||||
|
|
|
@ -49,6 +49,7 @@ type Database struct {
|
|||
Notifications tables.NotificationTable
|
||||
Pushers tables.PusherTable
|
||||
Stats tables.StatsTable
|
||||
LocalpartExternalIDs tables.LocalpartExternalIDsTable
|
||||
LoginTokenLifetime time.Duration
|
||||
ServerName spec.ServerName
|
||||
BcryptCost int
|
||||
|
@ -352,6 +353,41 @@ func (d *Database) SaveThreePIDAssociation(
|
|||
})
|
||||
}
|
||||
|
||||
// BulkSaveThreePIDAssociation recreates 3PIDs for a user.
|
||||
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
func (d *Database) BulkSaveThreePIDAssociation(ctx context.Context, threePIDs []authtypes.ThreePID, localpart string, serverName spec.ServerName) (err error) {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
oldThreePIDs, err := d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, t := range oldThreePIDs {
|
||||
if err := d.ThreePIDs.DeleteThreePID(ctx, txn, t.Address, t.Medium); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, t := range threePIDs {
|
||||
// if 3PID is associated with another user, return Err3PIDInUse
|
||||
user, _, err := d.ThreePIDs.SelectLocalpartForThreePID(
|
||||
ctx, txn, t.Address, t.Medium,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(user) > 0 && user != localpart {
|
||||
return Err3PIDInUse
|
||||
}
|
||||
|
||||
if err = d.ThreePIDs.InsertThreePID(ctx, txn, t.Address, t.Medium, localpart, serverName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveThreePIDAssociation removes the association involving a given third-party
|
||||
// identifier.
|
||||
// If no association exists involving this third-party identifier, returns nothing.
|
||||
|
@ -870,6 +906,18 @@ func (d *Database) UpsertPusher(
|
|||
})
|
||||
}
|
||||
|
||||
func (d *Database) CreateLocalpartExternalID(ctx context.Context, localpart, externalID, authProvider string) error {
|
||||
return d.LocalpartExternalIDs.InsertLocalExternalPartID(ctx, nil, localpart, externalID, authProvider)
|
||||
}
|
||||
|
||||
func (d *Database) GetLocalpartForExternalID(ctx context.Context, externalID, authProvider string) (*api.LocalpartExternalID, error) {
|
||||
return d.LocalpartExternalIDs.SelectLocalExternalPartID(ctx, nil, externalID, authProvider)
|
||||
}
|
||||
|
||||
func (d *Database) DeleteLocalpartExternalID(ctx context.Context, externalID, authProvider string) error {
|
||||
return d.LocalpartExternalIDs.DeleteLocalExternalPartID(ctx, nil, externalID, authProvider)
|
||||
}
|
||||
|
||||
// GetPushers returns the pushers matching the given localpart.
|
||||
func (d *Database) GetPushers(
|
||||
ctx context.Context, localpart string, serverName spec.ServerName,
|
||||
|
@ -1132,8 +1180,8 @@ func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserI
|
|||
// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.
|
||||
func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
for keyType, keyData := range keyMap {
|
||||
if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil {
|
||||
for keyType, key := range keyMap {
|
||||
if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, key); err != nil {
|
||||
return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err)
|
||||
}
|
||||
}
|
||||
|
@ -1141,7 +1189,7 @@ func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID s
|
|||
})
|
||||
}
|
||||
|
||||
// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice.
|
||||
// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/device.
|
||||
func (d *KeyDatabase) StoreCrossSigningSigsForTarget(
|
||||
ctx context.Context,
|
||||
originUserID string, originKeyID gomatrixserverlib.KeyID,
|
||||
|
|
|
@ -54,7 +54,7 @@ const deactivateAccountSQL = "" +
|
|||
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectAccountByLocalpartSQL = "" +
|
||||
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
|
||||
"SELECT localpart, server_name, appservice_id, account_type, is_deactivated FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectPasswordHashSQL = "" +
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0"
|
||||
|
@ -116,7 +116,7 @@ func (s *accountsStatements) InsertAccount(
|
|||
ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName,
|
||||
hash, appserviceID string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
createdTimeMS := spec.AsTimestamp(time.Now())
|
||||
stmt := s.insertAccountStmt
|
||||
|
||||
var err error
|
||||
|
@ -135,6 +135,7 @@ func (s *accountsStatements) InsertAccount(
|
|||
ServerName: serverName,
|
||||
AppServiceID: appserviceID,
|
||||
AccountType: accountType,
|
||||
Deactivated: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -167,7 +168,7 @@ func (s *accountsStatements) SelectAccountByLocalpart(
|
|||
var acc api.Account
|
||||
|
||||
stmt := s.selectAccountByLocalpartStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
|
||||
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType, &acc.Deactivated)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||
|
|
|
@ -50,6 +50,11 @@ func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error)
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := sqlutil.NewMigrator(db)
|
||||
err = m.Up(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
|
||||
{&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
|
||||
|
@ -82,8 +87,7 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
|
|||
}
|
||||
|
||||
func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
|
||||
ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes,
|
||||
) error {
|
||||
ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes) error {
|
||||
keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown key purpose %q", keyType)
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func UpDropPrimaryKeyConstraint(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp;
|
||||
CREATE TABLE userapi_devices (
|
||||
access_token TEXT,
|
||||
session_id INTEGER,
|
||||
device_id TEXT ,
|
||||
localpart TEXT ,
|
||||
server_name TEXT NOT NULL,
|
||||
created_ts BIGINT,
|
||||
display_name TEXT,
|
||||
last_seen_ts BIGINT,
|
||||
ip TEXT,
|
||||
user_agent TEXT,
|
||||
UNIQUE (localpart, device_id)
|
||||
);
|
||||
INSERT
|
||||
INTO userapi_devices (
|
||||
access_token, session_id, device_id, localpart, server_name, created_ts, display_name, last_seen_ts, ip, user_agent
|
||||
) SELECT
|
||||
access_token, session_id, device_id, localpart, server_name, created_ts, display_name, created_ts, '', ''
|
||||
FROM userapi_devices_tmp;
|
||||
DROP TABLE userapi_devices_tmp;`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownDropPrimaryKeyConstraint(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
ALTER TABLE userapi_devices RENAME TO userapi_devices_tmp;
|
||||
CREATE TABLE userapi_devices (
|
||||
access_token TEXT PRIMARY KEY,
|
||||
session_id INTEGER,
|
||||
device_id TEXT ,
|
||||
localpart TEXT ,
|
||||
server_name TEXT NOT NULL,
|
||||
created_ts BIGINT,
|
||||
display_name TEXT,
|
||||
last_seen_ts BIGINT,
|
||||
ip TEXT,
|
||||
user_agent TEXT,
|
||||
UNIQUE (localpart, device_id)
|
||||
);
|
||||
INSERT
|
||||
INTO userapi_devices (
|
||||
access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent
|
||||
) SELECT
|
||||
access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', ''
|
||||
FROM userapi_devices_tmp;
|
||||
DROP TABLE userapi_devices_tmp;`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -102,10 +102,16 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName spec.ServerName) (tables.Devic
|
|||
return nil, err
|
||||
}
|
||||
m := sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: add last_seen_ts",
|
||||
Up: deltas.UpLastSeenTSIP,
|
||||
})
|
||||
m.AddMigrations(
|
||||
sqlutil.Migration{
|
||||
Version: "userapi: add last_seen_ts",
|
||||
Up: deltas.UpLastSeenTSIP,
|
||||
},
|
||||
sqlutil.Migration{
|
||||
Version: "userapi: drop primary key constraint",
|
||||
Up: deltas.UpDropPrimaryKeyConstraint,
|
||||
},
|
||||
)
|
||||
if err = m.Up(context.Background()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
102
userapi/storage/sqlite3/localpart_external_ids_table.go
Normal file
102
userapi/storage/sqlite3/localpart_external_ids_table.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
// Copyright 2025 New Vector Ltd.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/element-hq/dendrite/internal/sqlutil"
|
||||
"github.com/element-hq/dendrite/userapi/api"
|
||||
"github.com/element-hq/dendrite/userapi/storage/tables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const localpartExternalIDsSchema = `
|
||||
-- Stores data about connections between accounts and third-party auth providers
|
||||
CREATE TABLE IF NOT EXISTS userapi_localpart_external_ids (
|
||||
-- The Matrix user ID for this account
|
||||
localpart TEXT NOT NULL,
|
||||
-- The external ID
|
||||
external_id TEXT NOT NULL,
|
||||
-- Auth provider ID (see OIDCProvider.IDPID)
|
||||
auth_provider TEXT NOT NULL,
|
||||
-- When this connection was created, as a unix timestamp.
|
||||
created_ts BIGINT NOT NULL,
|
||||
|
||||
UNIQUE(external_id, auth_provider),
|
||||
UNIQUE(localpart, external_id, auth_provider)
|
||||
);
|
||||
|
||||
-- This index allows efficient lookup of the local user by the external ID
|
||||
CREATE INDEX IF NOT EXISTS userapi_external_id_auth_provider_idx ON userapi_localpart_external_ids(external_id, auth_provider);
|
||||
`
|
||||
|
||||
const insertLocalpartExternalIDSQL = "" +
|
||||
"INSERT INTO userapi_localpart_external_ids(localpart, external_id, auth_provider, created_ts) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const selectLocalpartExternalIDSQL = "" +
|
||||
"SELECT localpart, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2"
|
||||
|
||||
const deleteLocalpartExternalIDSQL = "" +
|
||||
"DELETE FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2"
|
||||
|
||||
type localpartExternalIDStatements struct {
|
||||
db *sql.DB
|
||||
insertUserExternalIDStmt *sql.Stmt
|
||||
selectUserExternalIDStmt *sql.Stmt
|
||||
deleteUserExternalIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSQLiteLocalpartExternalIDsTable(db *sql.DB) (tables.LocalpartExternalIDsTable, error) {
|
||||
s := &localpartExternalIDStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(localpartExternalIDsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertUserExternalIDStmt, insertLocalpartExternalIDSQL},
|
||||
{&s.selectUserExternalIDStmt, selectLocalpartExternalIDSQL},
|
||||
{&s.deleteUserExternalIDStmt, deleteLocalpartExternalIDSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
// SelectLocalExternalPartID selects an existing OpenID Connect connection from the database
|
||||
func (u *localpartExternalIDStatements) SelectLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error) {
|
||||
ret := api.LocalpartExternalID{
|
||||
ExternalID: externalID,
|
||||
AuthProvider: authProvider,
|
||||
}
|
||||
err := u.selectUserExternalIDStmt.QueryRowContext(ctx, externalID, authProvider).Scan(
|
||||
&ret.Localpart, &ret.CreatedTS,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
log.WithError(err).Error("Unable to retrieve localpart from the db")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
// InsertLocalExternalPartID creates a new record representing an OpenID Connect connection between Matrix and external accounts.
|
||||
func (u *localpartExternalIDStatements) InsertLocalExternalPartID(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error {
|
||||
stmt := sqlutil.TxStmt(txn, u.insertUserExternalIDStmt)
|
||||
_, err := stmt.ExecContext(ctx, localpart, externalID, authProvider, time.Now().Unix())
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteLocalExternalPartID deletes the existing OpenID Connect connection. After this method is called, the Matrix account will no longer be associated with the external account.
|
||||
func (u *localpartExternalIDStatements) DeleteLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error {
|
||||
stmt := sqlutil.TxStmt(txn, u.deleteUserExternalIDStmt)
|
||||
_, err := stmt.ExecContext(ctx, externalID, authProvider)
|
||||
return err
|
||||
}
|
|
@ -94,6 +94,10 @@ func NewUserDatabase(ctx context.Context, conMan *sqlutil.Connections, dbPropert
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err)
|
||||
}
|
||||
localpartExternalIDsTable, err := NewSQLiteLocalpartExternalIDsTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteUserExternalIDsTable: %w", err)
|
||||
}
|
||||
|
||||
m = sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
|
@ -119,6 +123,7 @@ func NewUserDatabase(ctx context.Context, conMan *sqlutil.Connections, dbPropert
|
|||
Pushers: pusherTable,
|
||||
Notifications: notificationsTable,
|
||||
Stats: statsTable,
|
||||
LocalpartExternalIDs: localpartExternalIDsTable,
|
||||
ServerName: serverName,
|
||||
DB: db,
|
||||
Writer: writer,
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
"github.com/element-hq/dendrite/internal/sqlutil"
|
||||
"github.com/element-hq/dendrite/setup/config"
|
||||
"github.com/element-hq/dendrite/userapi/storage/sqlite3"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
)
|
||||
|
||||
func NewUserDatabase(
|
||||
|
|
|
@ -127,6 +127,12 @@ type StatsTable interface {
|
|||
UpsertDailyStats(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error
|
||||
}
|
||||
|
||||
type LocalpartExternalIDsTable interface {
|
||||
SelectLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) (*api.LocalpartExternalID, error)
|
||||
InsertLocalExternalPartID(ctx context.Context, txn *sql.Tx, localpart, externalID, authProvider string) error
|
||||
DeleteLocalExternalPartID(ctx context.Context, txn *sql.Tx, externalID, authProvider string) error
|
||||
}
|
||||
|
||||
type NotificationFilter uint32
|
||||
|
||||
const (
|
||||
|
|
|
@ -457,34 +457,42 @@ func TestDevices(t *testing.T) {
|
|||
}{
|
||||
{
|
||||
name: "not a local user",
|
||||
inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", ServerName: "notlocal"},
|
||||
inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", ServerName: "notlocal", AccessTokenUniqueConstraintDisabled: true},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "implicit local user",
|
||||
inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, DeviceDisplayName: &displayName},
|
||||
inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, DeviceDisplayName: &displayName, AccessTokenUniqueConstraintDisabled: true},
|
||||
},
|
||||
{
|
||||
name: "explicit local user",
|
||||
inputData: &api.PerformDeviceCreationRequest{Localpart: "test2", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
|
||||
},
|
||||
{
|
||||
name: "dupe token - ok",
|
||||
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true},
|
||||
},
|
||||
{
|
||||
name: "dupe token - not ok",
|
||||
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true},
|
||||
wantErr: true,
|
||||
inputData: &api.PerformDeviceCreationRequest{Localpart: "test2", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, AccessTokenUniqueConstraintDisabled: true},
|
||||
},
|
||||
{
|
||||
name: "test3 second device", // used to test deletion later
|
||||
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
|
||||
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, AccessTokenUniqueConstraintDisabled: true},
|
||||
},
|
||||
{
|
||||
name: "test3 third device", // used to test deletion later
|
||||
wantNewDevID: true,
|
||||
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
|
||||
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, AccessTokenUniqueConstraintDisabled: true},
|
||||
},
|
||||
{
|
||||
name: "dupe token - ok (unique constraint enabled)",
|
||||
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true, AccessTokenUniqueConstraintDisabled: false},
|
||||
},
|
||||
{
|
||||
name: "dupe token - not ok (unique constraint enabled)",
|
||||
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true, AccessTokenUniqueConstraintDisabled: false},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "dupe token - ok (unique constraint disabled)",
|
||||
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true, AccessTokenUniqueConstraintDisabled: true},
|
||||
},
|
||||
{
|
||||
name: "dupe token - not ok (unique constraint disabled)",
|
||||
inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true, AccessTokenUniqueConstraintDisabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -629,7 +637,13 @@ func TestDeviceIDReuse(t *testing.T) {
|
|||
res := api.PerformDeviceCreationResponse{}
|
||||
// create a first device
|
||||
deviceID := util.RandomString(8)
|
||||
req := api.PerformDeviceCreationRequest{Localpart: "alice", ServerName: "test", DeviceID: &deviceID, NoDeviceListUpdate: true}
|
||||
req := api.PerformDeviceCreationRequest{
|
||||
Localpart: "alice",
|
||||
ServerName: "test",
|
||||
DeviceID: &deviceID,
|
||||
NoDeviceListUpdate: true,
|
||||
AccessTokenUniqueConstraintDisabled: true,
|
||||
}
|
||||
err := intAPI.PerformDeviceCreation(ctx, &req, &res)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
Loading…
Add table
Reference in a new issue