code review fixes

This commit is contained in:
Roman Isaev 2025-01-23 01:25:04 +00:00
parent 17b7677071
commit 8a05a66cd7
No known key found for this signature in database
GPG key ID: 7BE2B6A6C89AEC7F
9 changed files with 24 additions and 24 deletions

View file

@ -376,7 +376,7 @@ func Setup(
})).Methods(http.MethodPost)
} else {
// If msc3861 is enabled, these endpoints are either redundant or replaced by Matrix Auth Service (MAS)
// Once we migrate to MAS completely, these edndpoints should be removed
// Once we migrate to MAS completely, these endpoints should be removed
v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
if r := rateLimits.Limit(req, nil); r != nil {

View file

@ -76,10 +76,10 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors) {
c.RateLimiting.Verify(configErrs)
if c.MSCs.MSC3861Enabled() {
if c.RecaptchaEnabled || !c.RegistrationDisabled {
if !c.RegistrationDisabled || c.RecaptchaEnabled {
configErrs.Add(
"You have enabled the experimental feature MSC3861 which implements the delegated authentication via OIDC." +
"As a result, the feature conflicts with the standard Dendrite's registration and login flows and cannot be used if any of those is enabled." +
"You have enabled the experimental feature MSC3861 which implements the delegated authentication via OIDC. " +
"As a result, the feature conflicts with the standard Dendrite's registration and login flows and cannot be used if any of those is enabled. " +
"You need to disable registration (client_api.registration_disabled) and recapthca (client_api.enable_registration_captcha) options to proceed.",
)
}

View file

@ -10,7 +10,8 @@ type MSCs struct {
// 'msc2836': Threading - https://github.com/matrix-org/matrix-doc/pull/2836
MSCs []string `yaml:"mscs"`
// MSC3861 contains config related to the experimental feature MSC3861. It takes effect only if 'msc3861' is included in 'MSCs' array
// MSC3861 contains config related to the experimental feature MSC3861.
// It takes effect only if 'msc3861' is included in 'MSCs' array.
MSC3861 *MSC3861 `yaml:"msc3861,omitempty"`
Database DatabaseOptions `yaml:"database,omitempty"`

View file

@ -121,14 +121,14 @@ func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.Pe
}
type userVerifier struct {
m map[string]struct {
accessTokenToDeviceAndResponse map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}
}
func (u *userVerifier) VerifyUserFromRequest(req *http.Request) (*userapi.Device, *util.JSONResponse) {
if pair, ok := u.m[req.URL.Query().Get("access_token")]; ok {
if pair, ok := u.accessTokenToDeviceAndResponse[req.URL.Query().Get("access_token")]; ok {
return pair.Device, pair.Response
}
return nil, nil
@ -212,13 +212,13 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
},
}
uv.m = make(map[string]struct {
uv.accessTokenToDeviceAndResponse = make(map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}, len(testCases))
for _, tc := range testCases {
uv.m[tc.req.URL.Query().Get("access_token")] = struct {
uv.accessTokenToDeviceAndResponse[tc.req.URL.Query().Get("access_token")] = struct {
Device *userapi.Device
Response *util.JSONResponse
}{Device: tc.device, Response: tc.response}
@ -285,7 +285,7 @@ func testSyncEventFormatPowerLevels(t *testing.T, dbType test.DBType) {
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
natsInstance := jetstream.NATSInstance{}
uv := userVerifier{
m: map[string]struct {
accessTokenToDeviceAndResponse: map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}{
@ -539,7 +539,7 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) {
jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream)
uv := userVerifier{
m: map[string]struct {
accessTokenToDeviceAndResponse: map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}{
@ -669,7 +669,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
uv := userVerifier{
m: map[string]struct {
accessTokenToDeviceAndResponse: map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}{
@ -947,7 +947,7 @@ func TestGetMembership(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
uv := userVerifier{
m: map[string]struct {
accessTokenToDeviceAndResponse: map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}{
@ -1024,7 +1024,7 @@ func testSendToDevice(t *testing.T, dbType test.DBType) {
defer close()
natsInstance := jetstream.NATSInstance{}
uv := userVerifier{
m: map[string]struct {
accessTokenToDeviceAndResponse: map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}{
@ -1258,7 +1258,7 @@ func testContext(t *testing.T, dbType test.DBType) {
rsAPI.SetFederationAPI(nil, nil)
uv := userVerifier{
m: map[string]struct {
accessTokenToDeviceAndResponse: map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}{
@ -1446,7 +1446,7 @@ func TestRemoveEditedEventFromSearchIndex(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
rsAPI.SetFederationAPI(nil, nil)
uv := userVerifier{
m: map[string]struct {
accessTokenToDeviceAndResponse: map[string]struct {
Device *userapi.Device
Response *util.JSONResponse
}{

View file

@ -483,7 +483,7 @@ type LocalpartExternalID struct {
Localpart string
ExternalID string
AuthProvider string
CreatedTs int64
CreatedTS int64
}
// UserInfo is for returning information about the user an OpenID token was issued for

View file

@ -116,8 +116,7 @@ func (s *accountsStatements) InsertAccount(
localpart string, serverName spec.ServerName,
hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
// TODO: can we replace "UnixNano() / 1M" with "UnixMilli()"?
createdTimeMS := time.Now().UnixNano() / 1000000
createdTimeMS := spec.AsTimestamp(time.Now())
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error

View file

@ -38,7 +38,7 @@ const selectUserExternalIDSQL = "" +
"SELECT localpart, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2"
const deleteUserExternalIDSQL = "" +
"SELECT localpart, external_id, auth_provider, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2"
"DELETE FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2"
type localpartExternalIDStatements struct {
db *sql.DB
@ -69,7 +69,7 @@ func (u *localpartExternalIDStatements) Select(ctx context.Context, txn *sql.Tx,
AuthProvider: authProvider,
}
err := u.selectUserExternalIDStmt.QueryRowContext(ctx, externalID, authProvider).Scan(
&ret.Localpart, &ret.CreatedTs,
&ret.Localpart, &ret.CreatedTS,
)
if err != nil {
if err == sql.ErrNoRows {

View file

@ -116,7 +116,7 @@ func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName,
hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
createdTimeMS := spec.AsTimestamp(time.Now())
stmt := s.insertAccountStmt
var err error

View file

@ -38,7 +38,7 @@ const selectLocalpartExternalIDSQL = "" +
"SELECT localpart, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2"
const deleteLocalpartExternalIDSQL = "" +
"SELECT localpart, external_id, auth_provider, created_ts FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2"
"DELETE FROM userapi_localpart_external_ids WHERE external_id = $1 AND auth_provider = $2"
type localpartExternalIDStatements struct {
db *sql.DB
@ -69,7 +69,7 @@ func (u *localpartExternalIDStatements) Select(ctx context.Context, txn *sql.Tx,
AuthProvider: authProvider,
}
err := u.selectUserExternalIDStmt.QueryRowContext(ctx, externalID, authProvider).Scan(
&ret.Localpart, &ret.CreatedTs,
&ret.Localpart, &ret.CreatedTS,
)
if err != nil {
if err == sql.ErrNoRows {