feat: expand mfa checks also to queries

Signed-off-by: Julian Koberg <julian.koberg@kiteworks.com>
This commit is contained in:
Julian Koberg
2025-09-23 13:49:16 +02:00
parent dda6104722
commit 04b829e8a7
6 changed files with 90 additions and 63 deletions

View File

@@ -22,29 +22,20 @@ var mfaKey = mfaKeyType{}
// This operation does not overwrite existing context values.
func EnhanceRequest(req *http.Request) *http.Request {
ctx := req.Context()
if Get(ctx) {
if Has(ctx) {
return req
}
return req.WithContext(Set(ctx, req.Header.Get(MFAHeader) == "true"))
}
// EnsureOrReject sets the MFA required header and status code 403 if not multi factor authenticated.
func EnsureOrReject(ctx context.Context, w http.ResponseWriter) (hasMFA bool) {
hasMFA = Get(ctx)
if !hasMFA {
SetRequiredStatus(w)
}
return
}
// SetRequiredStatus sets the MFA required header and the statuscode to 403
func SetRequiredStatus(w http.ResponseWriter) {
w.Header().Set(MFARequiredHeader, "true")
w.WriteHeader(http.StatusForbidden)
}
// Get gets the mfa status from the context.
func Get(ctx context.Context) bool {
// Has returns the mfa status from the context.
func Has(ctx context.Context) bool {
mfa, ok := ctx.Value(mfaKey).(bool)
if !ok {
return false

View File

@@ -19,13 +19,10 @@ func exampleUsage() http.HandlerFunc {
ctx := r.Context()
// now you can check if the user has MFA enabled
hasMFA := mfa.Get(ctx)
_ = hasMFA // use it as needed
// use convenience method to automatically set headers and response codes
if !mfa.EnsureOrReject(ctx, w) {
if !mfa.Has(ctx) {
// use this line to log access denied information
// mfa package will not log anything by itself
mfa.SetRequiredStatus(w)
return
}
// user has MFA enabled, you can now proceed with sensitive operation

View File

@@ -133,9 +133,10 @@ func (g Graph) GetAllDrives(version APIVersion) http.HandlerFunc {
// GetAllDrivesV1 attempts to retrieve the current users drives;
// it includes another user's drives, if the current user has the permission.
func (g Graph) GetAllDrivesV1(w http.ResponseWriter, r *http.Request) {
if !mfa.EnsureOrReject(r.Context(), w) {
if !mfa.Has(r.Context()) {
logger := g.logger.SubloggerWithRequestID(r.Context())
logger.Error().Str("path", r.URL.Path).Msg("MFA required but not satisfied")
mfa.SetRequiredStatus(w)
return
}
@@ -159,9 +160,10 @@ func (g Graph) GetAllDrivesV1(w http.ResponseWriter, r *http.Request) {
// it includes the grantedtoV2 property
// it uses unified roles instead of the cs3 representations
func (g Graph) GetAllDrivesV1Beta1(w http.ResponseWriter, r *http.Request) {
if !mfa.EnsureOrReject(r.Context(), w) {
if !mfa.Has(r.Context()) {
logger := g.logger.SubloggerWithRequestID(r.Context())
logger.Error().Str("path", r.URL.Path).Msg("MFA required but not satisfied")
mfa.SetRequiredStatus(w)
return
}

View File

@@ -8,6 +8,7 @@ import (
"path"
"strings"
"github.com/CiscoM31/godata"
gateway "github.com/cs3org/go-cs3apis/cs3/gateway/v1beta1"
storageprovider "github.com/cs3org/go-cs3apis/cs3/storage/provider/v1beta1"
"github.com/go-chi/chi/v5"
@@ -152,3 +153,39 @@ func parseIDParam(r *http.Request, param string) (storageprovider.ResourceId, er
}
return id, nil
}
// regular users can only search for terms with a minimum length
func hasAcceptableSearch(query *godata.GoDataQuery, minSearchLength int) bool {
if query == nil || query.Search == nil {
return false
}
if strings.HasPrefix(query.Search.RawValue, "\"") {
// if search starts with double quotes then it must finish with double quotes
// add +2 to the minimum search length in this case
minSearchLength += 2
}
return len(query.Search.RawValue) >= minSearchLength
}
// regular users can only filter by userType
func hasAcceptableFilter(query *godata.GoDataQuery) bool {
switch {
case query == nil || query.Filter == nil:
return true
case query.Filter.Tree.Token.Type != godata.ExpressionTokenLogical:
return false
case query.Filter.Tree.Token.Value != "eq":
return false
case query.Filter.Tree.Children[0].Token.Value != "userType":
return false
}
return true
}
// regular users can only use basic queries without any expansions, computes or applies
func hasAcceptableQuery(query *godata.GoDataQuery) bool {
return query != nil && query.Apply == nil && query.Expand == nil && query.Compute == nil
}

View File

@@ -33,35 +33,34 @@ func (g Graph) GetGroups(w http.ResponseWriter, r *http.Request) {
return
}
ctxHasFullPerms := g.contextUserHasFullAccountPerms(r.Context())
searchHasAcceptableLength := false
if odataReq.Query != nil && odataReq.Query.Search != nil {
minSearchLength := g.config.API.IdentitySearchMinLength
if strings.HasPrefix(odataReq.Query.Search.RawValue, "\"") {
// if search starts with double quotes then it must finish with double quotes
// add +2 to the minimum search length in this case
minSearchLength += 2
}
searchHasAcceptableLength = len(odataReq.Query.Search.RawValue) >= minSearchLength
}
if !searchHasAcceptableLength {
hasMFA := mfa.Has(r.Context())
if !hasAcceptableSearch(odataReq.Query, g.config.API.IdentitySearchMinLength) {
if !ctxHasFullPerms {
// for regular user the search term must have a minimum length
logger.Debug().Interface("query", r.URL.Query()).Msgf("search with less than %d chars for a regular user", g.config.API.IdentitySearchMinLength)
errorcode.AccessDenied.Render(w, r, http.StatusForbidden, "search term too short")
return
}
if !mfa.EnsureOrReject(r.Context(), w) {
if !hasMFA {
logger.Error().Str("path", r.URL.Path).Msg("MFA required but not satisfied")
mfa.SetRequiredStatus(w)
return
}
}
if !ctxHasFullPerms && (odataReq.Query.Filter != nil || odataReq.Query.Apply != nil || odataReq.Query.Expand != nil || odataReq.Query.Compute != nil) {
// regular users can't use filter, apply, expand and compute
logger.Debug().Interface("query", r.URL.Query()).Msg("forbidden query elements for a regular user")
errorcode.AccessDenied.Render(w, r, http.StatusForbidden, "query has forbidden elements for regular users")
return
if !hasAcceptableQuery(odataReq.Query) {
if !ctxHasFullPerms {
// regular users can't use filter, apply, expand and compute
logger.Debug().Interface("query", r.URL.Query()).Msg("forbidden query elements for a regular user")
errorcode.AccessDenied.Render(w, r, http.StatusForbidden, "query has forbidden elements for regular users")
return
}
if !hasMFA {
logger.Error().Str("path", r.URL.Path).Msg("MFA required but not satisfied")
mfa.SetRequiredStatus(w)
return
}
}
groups, err := g.identityBackend.GetGroups(r.Context(), odataReq)

View File

@@ -222,48 +222,49 @@ func (g Graph) GetUsers(w http.ResponseWriter, r *http.Request) {
}
ctxHasFullPerms := g.contextUserHasFullAccountPerms(r.Context())
searchHasAcceptableLength := false
if odataReq.Query != nil && odataReq.Query.Search != nil {
minSearchLength := g.config.API.IdentitySearchMinLength
if strings.HasPrefix(odataReq.Query.Search.RawValue, "\"") {
// if search starts with double quotes then it must finish with double quotes
// add +2 to the minimum search length in this case
minSearchLength += 2
}
searchHasAcceptableLength = len(odataReq.Query.Search.RawValue) >= minSearchLength
}
if !searchHasAcceptableLength {
hasMFA := mfa.Has(r.Context())
if !hasAcceptableSearch(odataReq.Query, g.config.API.IdentitySearchMinLength) {
if !ctxHasFullPerms {
// for regular user the search term must have a minimum length
logger.Debug().Interface("query", r.URL.Query()).Msgf("search with less than %d chars for a regular user", g.config.API.IdentitySearchMinLength)
errorcode.AccessDenied.Render(w, r, http.StatusForbidden, "search term too short")
return
}
if !mfa.EnsureOrReject(r.Context(), w) {
if !hasMFA {
logger.Error().Str("path", r.URL.Path).Msg("MFA required but not satisfied")
mfa.SetRequiredStatus(w)
return
}
}
if !ctxHasFullPerms && odataReq.Query.Filter != nil {
// regular users are allowed to filter only by userType
filter := odataReq.Query.Filter
switch {
case filter.Tree.Token.Type != godata.ExpressionTokenLogical:
fallthrough
case filter.Tree.Token.Value != "eq":
fallthrough
case filter.Tree.Children[0].Token.Value != "userType":
if !hasAcceptableFilter(odataReq.Query) {
if !ctxHasFullPerms {
// regular users are allowed to filter only by userType
logger.Debug().Interface("query", r.URL.Query()).Msg("forbidden filter for a regular user")
errorcode.AccessDenied.Render(w, r, http.StatusForbidden, "filter has forbidden elements for regular users")
return
}
if !hasMFA {
logger.Error().Str("path", r.URL.Path).Msg("MFA required but not satisfied")
mfa.SetRequiredStatus(w)
return
}
}
if !ctxHasFullPerms && (odataReq.Query.Apply != nil || odataReq.Query.Expand != nil || odataReq.Query.Compute != nil) {
// regular users can't use filter, apply, expand and compute
logger.Debug().Interface("query", r.URL.Query()).Msg("forbidden query elements for a regular user")
errorcode.AccessDenied.Render(w, r, http.StatusForbidden, "query has forbidden elements for regular users")
return
if !hasAcceptableQuery(odataReq.Query) {
if !ctxHasFullPerms {
// regular users can't use filter, apply, expand and compute
logger.Debug().Interface("query", r.URL.Query()).Msg("forbidden query elements for a regular user")
errorcode.AccessDenied.Render(w, r, http.StatusForbidden, "query has forbidden elements for regular users")
return
}
if !hasMFA {
logger.Error().Str("path", r.URL.Path).Msg("MFA required but not satisfied")
mfa.SetRequiredStatus(w)
return
}
}
logger.Debug().Interface("query", r.URL.Query()).Msg("calling get users on backend")