diff --git a/ocis-pkg/mfa/mfa.go b/ocis-pkg/mfa/mfa.go index f1b1bc67e5b..e75a46002a7 100644 --- a/ocis-pkg/mfa/mfa.go +++ b/ocis-pkg/mfa/mfa.go @@ -22,29 +22,20 @@ var mfaKey = mfaKeyType{} // This operation does not overwrite existing context values. func EnhanceRequest(req *http.Request) *http.Request { ctx := req.Context() - if Get(ctx) { + if Has(ctx) { return req } return req.WithContext(Set(ctx, req.Header.Get(MFAHeader) == "true")) } -// EnsureOrReject sets the MFA required header and status code 403 if not multi factor authenticated. -func EnsureOrReject(ctx context.Context, w http.ResponseWriter) (hasMFA bool) { - hasMFA = Get(ctx) - if !hasMFA { - SetRequiredStatus(w) - } - return -} - // SetRequiredStatus sets the MFA required header and the statuscode to 403 func SetRequiredStatus(w http.ResponseWriter) { w.Header().Set(MFARequiredHeader, "true") w.WriteHeader(http.StatusForbidden) } -// Get gets the mfa status from the context. -func Get(ctx context.Context) bool { +// Has returns the mfa status from the context. +func Has(ctx context.Context) bool { mfa, ok := ctx.Value(mfaKey).(bool) if !ok { return false diff --git a/ocis-pkg/mfa/mfa_test.go b/ocis-pkg/mfa/mfa_test.go index f43815e737d..50e8261a70d 100644 --- a/ocis-pkg/mfa/mfa_test.go +++ b/ocis-pkg/mfa/mfa_test.go @@ -19,13 +19,10 @@ func exampleUsage() http.HandlerFunc { ctx := r.Context() // now you can check if the user has MFA enabled - hasMFA := mfa.Get(ctx) - _ = hasMFA // use it as needed - - // use convenience method to automatically set headers and response codes - if !mfa.EnsureOrReject(ctx, w) { + if !mfa.Has(ctx) { // use this line to log access denied information // mfa package will not log anything by itself + mfa.SetRequiredStatus(w) return } // user has MFA enabled, you can now proceed with sensitive operation diff --git a/services/graph/pkg/service/v0/drives.go b/services/graph/pkg/service/v0/drives.go index e5aa3e77c35..a0eaffe8d69 100644 --- a/services/graph/pkg/service/v0/drives.go +++ b/services/graph/pkg/service/v0/drives.go @@ -133,9 +133,10 @@ func (g Graph) GetAllDrives(version APIVersion) http.HandlerFunc { // GetAllDrivesV1 attempts to retrieve the current users drives; // it includes another user's drives, if the current user has the permission. func (g Graph) GetAllDrivesV1(w http.ResponseWriter, r *http.Request) { - if !mfa.EnsureOrReject(r.Context(), w) { + if !mfa.Has(r.Context()) { logger := g.logger.SubloggerWithRequestID(r.Context()) logger.Error().Str("path", r.URL.Path).Msg("MFA required but not satisfied") + mfa.SetRequiredStatus(w) return } @@ -159,9 +160,10 @@ func (g Graph) GetAllDrivesV1(w http.ResponseWriter, r *http.Request) { // it includes the grantedtoV2 property // it uses unified roles instead of the cs3 representations func (g Graph) GetAllDrivesV1Beta1(w http.ResponseWriter, r *http.Request) { - if !mfa.EnsureOrReject(r.Context(), w) { + if !mfa.Has(r.Context()) { logger := g.logger.SubloggerWithRequestID(r.Context()) logger.Error().Str("path", r.URL.Path).Msg("MFA required but not satisfied") + mfa.SetRequiredStatus(w) return } diff --git a/services/graph/pkg/service/v0/graph.go b/services/graph/pkg/service/v0/graph.go index 6d87e9c087d..6ceebe3a962 100644 --- a/services/graph/pkg/service/v0/graph.go +++ b/services/graph/pkg/service/v0/graph.go @@ -8,6 +8,7 @@ import ( "path" "strings" + "github.com/CiscoM31/godata" gateway "github.com/cs3org/go-cs3apis/cs3/gateway/v1beta1" storageprovider "github.com/cs3org/go-cs3apis/cs3/storage/provider/v1beta1" "github.com/go-chi/chi/v5" @@ -152,3 +153,39 @@ func parseIDParam(r *http.Request, param string) (storageprovider.ResourceId, er } return id, nil } + +// regular users can only search for terms with a minimum length +func hasAcceptableSearch(query *godata.GoDataQuery, minSearchLength int) bool { + if query == nil || query.Search == nil { + return false + } + + if strings.HasPrefix(query.Search.RawValue, "\"") { + // if search starts with double quotes then it must finish with double quotes + // add +2 to the minimum search length in this case + minSearchLength += 2 + } + + return len(query.Search.RawValue) >= minSearchLength +} + +// regular users can only filter by userType +func hasAcceptableFilter(query *godata.GoDataQuery) bool { + switch { + case query == nil || query.Filter == nil: + return true + case query.Filter.Tree.Token.Type != godata.ExpressionTokenLogical: + return false + case query.Filter.Tree.Token.Value != "eq": + return false + case query.Filter.Tree.Children[0].Token.Value != "userType": + return false + } + + return true +} + +// regular users can only use basic queries without any expansions, computes or applies +func hasAcceptableQuery(query *godata.GoDataQuery) bool { + return query != nil && query.Apply == nil && query.Expand == nil && query.Compute == nil +} diff --git a/services/graph/pkg/service/v0/groups.go b/services/graph/pkg/service/v0/groups.go index a81faa37c85..8f386eefa84 100644 --- a/services/graph/pkg/service/v0/groups.go +++ b/services/graph/pkg/service/v0/groups.go @@ -33,35 +33,34 @@ func (g Graph) GetGroups(w http.ResponseWriter, r *http.Request) { return } ctxHasFullPerms := g.contextUserHasFullAccountPerms(r.Context()) - searchHasAcceptableLength := false - if odataReq.Query != nil && odataReq.Query.Search != nil { - minSearchLength := g.config.API.IdentitySearchMinLength - if strings.HasPrefix(odataReq.Query.Search.RawValue, "\"") { - // if search starts with double quotes then it must finish with double quotes - // add +2 to the minimum search length in this case - minSearchLength += 2 - } - searchHasAcceptableLength = len(odataReq.Query.Search.RawValue) >= minSearchLength - } - if !searchHasAcceptableLength { + hasMFA := mfa.Has(r.Context()) + if !hasAcceptableSearch(odataReq.Query, g.config.API.IdentitySearchMinLength) { if !ctxHasFullPerms { // for regular user the search term must have a minimum length logger.Debug().Interface("query", r.URL.Query()).Msgf("search with less than %d chars for a regular user", g.config.API.IdentitySearchMinLength) errorcode.AccessDenied.Render(w, r, http.StatusForbidden, "search term too short") return } - - if !mfa.EnsureOrReject(r.Context(), w) { + if !hasMFA { logger.Error().Str("path", r.URL.Path).Msg("MFA required but not satisfied") + mfa.SetRequiredStatus(w) return } } - if !ctxHasFullPerms && (odataReq.Query.Filter != nil || odataReq.Query.Apply != nil || odataReq.Query.Expand != nil || odataReq.Query.Compute != nil) { - // regular users can't use filter, apply, expand and compute - logger.Debug().Interface("query", r.URL.Query()).Msg("forbidden query elements for a regular user") - errorcode.AccessDenied.Render(w, r, http.StatusForbidden, "query has forbidden elements for regular users") - return + if !hasAcceptableQuery(odataReq.Query) { + if !ctxHasFullPerms { + // regular users can't use filter, apply, expand and compute + logger.Debug().Interface("query", r.URL.Query()).Msg("forbidden query elements for a regular user") + errorcode.AccessDenied.Render(w, r, http.StatusForbidden, "query has forbidden elements for regular users") + return + } + + if !hasMFA { + logger.Error().Str("path", r.URL.Path).Msg("MFA required but not satisfied") + mfa.SetRequiredStatus(w) + return + } } groups, err := g.identityBackend.GetGroups(r.Context(), odataReq) diff --git a/services/graph/pkg/service/v0/users.go b/services/graph/pkg/service/v0/users.go index 5dad76424cd..567c2bd982f 100644 --- a/services/graph/pkg/service/v0/users.go +++ b/services/graph/pkg/service/v0/users.go @@ -222,48 +222,49 @@ func (g Graph) GetUsers(w http.ResponseWriter, r *http.Request) { } ctxHasFullPerms := g.contextUserHasFullAccountPerms(r.Context()) - searchHasAcceptableLength := false - if odataReq.Query != nil && odataReq.Query.Search != nil { - minSearchLength := g.config.API.IdentitySearchMinLength - if strings.HasPrefix(odataReq.Query.Search.RawValue, "\"") { - // if search starts with double quotes then it must finish with double quotes - // add +2 to the minimum search length in this case - minSearchLength += 2 - } - searchHasAcceptableLength = len(odataReq.Query.Search.RawValue) >= minSearchLength - } - if !searchHasAcceptableLength { + hasMFA := mfa.Has(r.Context()) + if !hasAcceptableSearch(odataReq.Query, g.config.API.IdentitySearchMinLength) { if !ctxHasFullPerms { // for regular user the search term must have a minimum length logger.Debug().Interface("query", r.URL.Query()).Msgf("search with less than %d chars for a regular user", g.config.API.IdentitySearchMinLength) errorcode.AccessDenied.Render(w, r, http.StatusForbidden, "search term too short") return } - if !mfa.EnsureOrReject(r.Context(), w) { + if !hasMFA { logger.Error().Str("path", r.URL.Path).Msg("MFA required but not satisfied") + mfa.SetRequiredStatus(w) return } } - if !ctxHasFullPerms && odataReq.Query.Filter != nil { - // regular users are allowed to filter only by userType - filter := odataReq.Query.Filter - switch { - case filter.Tree.Token.Type != godata.ExpressionTokenLogical: - fallthrough - case filter.Tree.Token.Value != "eq": - fallthrough - case filter.Tree.Children[0].Token.Value != "userType": + if !hasAcceptableFilter(odataReq.Query) { + if !ctxHasFullPerms { + // regular users are allowed to filter only by userType logger.Debug().Interface("query", r.URL.Query()).Msg("forbidden filter for a regular user") errorcode.AccessDenied.Render(w, r, http.StatusForbidden, "filter has forbidden elements for regular users") return } + + if !hasMFA { + logger.Error().Str("path", r.URL.Path).Msg("MFA required but not satisfied") + mfa.SetRequiredStatus(w) + return + } } - if !ctxHasFullPerms && (odataReq.Query.Apply != nil || odataReq.Query.Expand != nil || odataReq.Query.Compute != nil) { - // regular users can't use filter, apply, expand and compute - logger.Debug().Interface("query", r.URL.Query()).Msg("forbidden query elements for a regular user") - errorcode.AccessDenied.Render(w, r, http.StatusForbidden, "query has forbidden elements for regular users") - return + + if !hasAcceptableQuery(odataReq.Query) { + if !ctxHasFullPerms { + // regular users can't use filter, apply, expand and compute + logger.Debug().Interface("query", r.URL.Query()).Msg("forbidden query elements for a regular user") + errorcode.AccessDenied.Render(w, r, http.StatusForbidden, "query has forbidden elements for regular users") + return + } + + if !hasMFA { + logger.Error().Str("path", r.URL.Path).Msg("MFA required but not satisfied") + mfa.SetRequiredStatus(w) + return + } } logger.Debug().Interface("query", r.URL.Query()).Msg("calling get users on backend")