mas: store crossSigngingKeysReplacement period in sessionsDict struct instead of db

This commit is contained in:
Roman Isaev 2025-01-24 23:22:02 +00:00
parent b5f34dfe47
commit 453445695c
No known key found for this signature in database
GPG key ID: 7BE2B6A6C89AEC7F
3 changed files with 61 additions and 39 deletions

View file

@ -35,10 +35,6 @@ import (
"github.com/element-hq/dendrite/userapi/storage/shared"
)
const (
replacementPeriod time.Duration = 10 * time.Minute
)
var (
validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$")
deviceDisplayName = "OIDC-native client"
@ -807,27 +803,10 @@ func AdminAllowCrossSigningReplacementWithoutUIA(
switch req.Method {
case http.MethodPost:
rq := userapi.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIARequest{
UserID: userID.String(),
Duration: replacementPeriod,
}
var rs userapi.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIAResponse
err = userAPI.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA(req.Context(), &rq, &rs)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.PerformAllowingMasterCrossSigningKeyReplacementWithoutUIA")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown(err.Error()),
}
} else if errors.Is(err, sql.ErrNoRows) {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound("User not found."),
}
}
ts := sessions.allowCrossSigningKeysReplacement(userID.String())
return util.JSONResponse{
Code: http.StatusOK,
JSON: map[string]int64{"updatable_without_uia_before_ms": rs.Timestamp},
JSON: map[string]int64{"updatable_without_uia_before_ms": ts},
}
default:
return util.JSONResponse{

View file

@ -49,11 +49,6 @@ func UploadCrossSigningDeviceKeys(
return *resErr
}
sessionID := uploadReq.Auth.Session
if sessionID == "" {
sessionID = util.RandomString(sessionIDLength)
}
// Query existing keys to determine if UIA is required
keyResp := api.QueryKeysResponse{}
keyserverAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{
@ -68,7 +63,6 @@ func UploadCrossSigningDeviceKeys(
}
existingMasterKey, hasMasterKey := keyResp.MasterKeys[device.UserID]
requireUIA := true
if hasMasterKey {
if !keysDiffer(existingMasterKey, keyResp, uploadReq, device.UserID) {
@ -89,10 +83,8 @@ func UploadCrossSigningDeviceKeys(
logger.WithError(masterKeyResp.Error).Error("Failed to query master key")
return convertKeyError(masterKeyResp.Error)
}
if k := masterKeyResp.Key; k != nil && k.UpdatableWithoutUIABeforeMs != nil {
requireUIA = !(time.Now().UnixMilli() < *k.UpdatableWithoutUIABeforeMs)
}
requireUIA := !sessions.isCrossSigningKeysReplacementAllowed(device.UserID) && masterKeyResp.Key != nil
if requireUIA {
url := ""
if m := cfg.MSCs.MSC3861; m.AccountManagementURL != "" {
@ -122,9 +114,13 @@ func UploadCrossSigningDeviceKeys(
),
}
}
// XXX: is it necessary?
sessions.addCompletedSessionStage(sessionID, CrossSigningResetStage)
sessions.restrictCrossSigningKeysReplacement(device.UserID)
} else {
sessionID := uploadReq.Auth.Session
if sessionID == "" {
sessionID = util.RandomString(sessionIDLength)
}
if uploadReq.Auth.Type != authtypes.LoginTypePassword {
return util.JSONResponse{
Code: http.StatusUnauthorized,

View file

@ -66,11 +66,17 @@ type sessionsDict struct {
// If a UIA session is started by trying to delete device1, and then UIA is completed by deleting device2,
// the delete request will fail for device2 since the UIA was initiated by trying to delete device1.
deleteSessionToDeviceID map[string]string
// allowedForCrossSigningKeysReplacement is a collection of sessions that MAS has authorised for updating
// cross-signing keys without UIA.
allowedForCrossSigningKeysReplacement map[string]*time.Timer
}
// defaultTimeout is the timeout used to clean up sessions
const defaultTimeOut = time.Minute * 5
// allowedForCrossSigningKeysReplacementDuration is the timeout used for replacing cross signing keys without UIA
const allowedForCrossSigningKeysReplacementDuration = time.Minute * 10
// getCompletedStages returns the completed stages for a session.
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
d.RLock()
@ -119,13 +125,54 @@ func (d *sessionsDict) deleteSession(sessionID string) {
}
}
func (d *sessionsDict) allowCrossSigningKeysReplacement(userID string) int64 {
d.Lock()
defer d.Unlock()
ts := time.Now().Add(allowedForCrossSigningKeysReplacementDuration).UnixMilli()
t, ok := d.allowedForCrossSigningKeysReplacement[userID]
if ok {
t.Reset(allowedForCrossSigningKeysReplacementDuration)
return ts
}
d.allowedForCrossSigningKeysReplacement[userID] = time.AfterFunc(
allowedForCrossSigningKeysReplacementDuration,
func() {
d.restrictCrossSigningKeysReplacement(userID)
},
)
return ts
}
func (d *sessionsDict) isCrossSigningKeysReplacementAllowed(userID string) bool {
d.RLock()
defer d.RUnlock()
_, ok := d.allowedForCrossSigningKeysReplacement[userID]
return ok
}
func (d *sessionsDict) restrictCrossSigningKeysReplacement(userID string) {
d.Lock()
defer d.Unlock()
t, ok := d.allowedForCrossSigningKeysReplacement[userID]
if ok {
if !t.Stop() {
select {
case <-t.C:
default:
}
}
delete(d.allowedForCrossSigningKeysReplacement, userID)
}
}
func newSessionsDict() *sessionsDict {
return &sessionsDict{
sessions: make(map[string][]authtypes.LoginType),
sessionCompletedResult: make(map[string]registerResponse),
params: make(map[string]registerRequest),
timer: make(map[string]*time.Timer),
deleteSessionToDeviceID: make(map[string]string),
sessions: make(map[string][]authtypes.LoginType),
sessionCompletedResult: make(map[string]registerResponse),
params: make(map[string]registerRequest),
timer: make(map[string]*time.Timer),
deleteSessionToDeviceID: make(map[string]string),
allowedForCrossSigningKeysReplacement: make(map[string]*time.Timer),
}
}