mirror of
https://github.com/goauthentik/authentik
synced 2026-05-09 00:22:24 +02:00
Compare commits
15 Commits
router-tid
...
docs-event
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b5e38dcacc | ||
|
|
aca8c883db | ||
|
|
47a54fedd0 | ||
|
|
bc00e7284b | ||
|
|
207d3557e6 | ||
|
|
1beea91bbf | ||
|
|
02bee093b7 | ||
|
|
17c957b94d | ||
|
|
fb2450169c | ||
|
|
d8eb2bd016 | ||
|
|
e8f56df048 | ||
|
|
5a6c13e991 | ||
|
|
b28af354a2 | ||
|
|
be9572b12b | ||
|
|
b6d1c055cb |
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 2025.6.4
|
||||
current_version = 2025.6.3
|
||||
tag = True
|
||||
commit = True
|
||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
|
||||
|
||||
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -31,4 +31,4 @@ If changes to the frontend have been made
|
||||
If applicable
|
||||
|
||||
- [ ] The documentation has been updated
|
||||
- [ ] The documentation has been formatted (`make docs`)
|
||||
- [ ] The documentation has been formatted (`make website`)
|
||||
|
||||
4
.github/workflows/api-ts-publish.yml
vendored
4
.github/workflows/api-ts-publish.yml
vendored
@@ -27,8 +27,8 @@ jobs:
|
||||
- name: Publish package
|
||||
working-directory: gen-ts-api/
|
||||
run: |
|
||||
npm i
|
||||
npm publish --tag generated
|
||||
npm ci
|
||||
npm publish
|
||||
env:
|
||||
NODE_AUTH_TOKEN: ${{ secrets.NPM_PUBLISH_TOKEN }}
|
||||
- name: Upgrade /web
|
||||
|
||||
94
.github/workflows/ci-api-docs.yml
vendored
94
.github/workflows/ci-api-docs.yml
vendored
@@ -1,94 +0,0 @@
|
||||
name: authentik-ci-api-docs
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- next
|
||||
- version-*
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- version-*
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
command:
|
||||
- prettier-check
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install Dependencies
|
||||
working-directory: website/
|
||||
run: npm ci
|
||||
- name: Lint
|
||||
working-directory: website/
|
||||
run: npm run ${{ matrix.command }}
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version-file: website/package.json
|
||||
cache: "npm"
|
||||
cache-dependency-path: website/package-lock.json
|
||||
- working-directory: website/
|
||||
name: Install Dependencies
|
||||
run: npm ci
|
||||
- uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
${{ github.workspace }}/website/api/.docusaurus
|
||||
${{ github.workspace }}/website/api/**/.cache
|
||||
key: |
|
||||
${{ runner.os }}-docusaurus-${{ hashFiles('**/package-lock.json') }}-${{ hashFiles('**.[jt]s', '**.[jt]sx') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-docusaurus-${{ hashFiles('**/package-lock.json') }}
|
||||
- name: Build API Docs via Docusaurus
|
||||
working-directory: website
|
||||
env:
|
||||
NODE_ENV: production
|
||||
run: npm run build -w api
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: api-docs
|
||||
path: website/api/build
|
||||
retention-days: 7
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- lint
|
||||
- build
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: api-docs
|
||||
path: website/api/build
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version-file: website/package.json
|
||||
cache: "npm"
|
||||
cache-dependency-path: website/package-lock.json
|
||||
- name: Deploy Netlify (Production)
|
||||
working-directory: website/api
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
env:
|
||||
NETLIFY_SITE_ID: authentik-api-docs.netlify.app
|
||||
NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }}
|
||||
run: npx netlify deploy --no-build --prod
|
||||
- name: Deploy Netlify (Preview)
|
||||
if: github.event_name == 'pull_request' || github.ref != 'refs/heads/main'
|
||||
working-directory: website/api
|
||||
env:
|
||||
NETLIFY_SITE_ID: authentik-api-docs.netlify.app
|
||||
NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }}
|
||||
run: |
|
||||
if [ -n "${VAR}" ]; then
|
||||
npx netlify deploy --no-build --alias=deploy-preview-${{ github.event.number }}
|
||||
fi
|
||||
@@ -1,4 +1,4 @@
|
||||
name: authentik-ci-docs
|
||||
name: authentik-ci-website
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -18,18 +18,17 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
command:
|
||||
- lint:lockfile
|
||||
- prettier-check
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install dependencies
|
||||
working-directory: website/
|
||||
- working-directory: website/
|
||||
run: npm ci
|
||||
- name: Lint
|
||||
working-directory: website/
|
||||
run: npm run ${{ matrix.command }}
|
||||
build-docs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-node@v4
|
||||
@@ -38,14 +37,19 @@ jobs:
|
||||
cache: "npm"
|
||||
cache-dependency-path: website/package-lock.json
|
||||
- working-directory: website/
|
||||
name: Install Dependencies
|
||||
run: npm ci
|
||||
- name: Build Documentation via Docusaurus
|
||||
- name: test
|
||||
working-directory: website/
|
||||
run: npm run build
|
||||
build-integrations:
|
||||
run: npm test
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
name: ${{ matrix.job }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
job:
|
||||
- build
|
||||
- build:integrations
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-node@v4
|
||||
@@ -54,11 +58,10 @@ jobs:
|
||||
cache: "npm"
|
||||
cache-dependency-path: website/package-lock.json
|
||||
- working-directory: website/
|
||||
name: Install Dependencies
|
||||
run: npm ci
|
||||
- name: Build Integrations via Docusaurus
|
||||
- name: build
|
||||
working-directory: website/
|
||||
run: npm run build -w integrations
|
||||
run: npm run ${{ matrix.job }}
|
||||
build-container:
|
||||
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
|
||||
runs-on: ubuntu-latest
|
||||
@@ -112,8 +115,8 @@ jobs:
|
||||
if: always()
|
||||
needs:
|
||||
- lint
|
||||
- build-docs
|
||||
- build-integrations
|
||||
- test
|
||||
- build
|
||||
- build-container
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
@@ -10,8 +10,7 @@ coverage
|
||||
dist
|
||||
out
|
||||
.docusaurus
|
||||
# TODO Replace after moving website to docs
|
||||
website/api/reference
|
||||
website/docs/developer-docs/api/**/*
|
||||
|
||||
## Environment
|
||||
*.env
|
||||
|
||||
11
.vscode/settings.json
vendored
11
.vscode/settings.json
vendored
@@ -7,10 +7,7 @@
|
||||
"!Enumerate sequence",
|
||||
"!Env scalar",
|
||||
"!Env sequence",
|
||||
"!File scalar",
|
||||
"!File sequence",
|
||||
"!Find sequence",
|
||||
"!FindObject sequence",
|
||||
"!Format sequence",
|
||||
"!If sequence",
|
||||
"!Index scalar",
|
||||
@@ -34,10 +31,6 @@
|
||||
"ignoreCase": false
|
||||
}
|
||||
],
|
||||
"go.testFlags": [
|
||||
"-count=1"
|
||||
],
|
||||
"github-actions.workflows.pinned.workflows": [
|
||||
".github/workflows/ci-main.yml"
|
||||
]
|
||||
"go.testFlags": ["-count=1"],
|
||||
"github-actions.workflows.pinned.workflows": [".github/workflows/ci-main.yml"]
|
||||
}
|
||||
|
||||
44
.vscode/tasks.json
vendored
44
.vscode/tasks.json
vendored
@@ -4,7 +4,12 @@
|
||||
{
|
||||
"label": "authentik/core: make",
|
||||
"command": "uv",
|
||||
"args": ["run", "make", "lint-fix", "lint"],
|
||||
"args": [
|
||||
"run",
|
||||
"make",
|
||||
"lint-fix",
|
||||
"lint"
|
||||
],
|
||||
"presentation": {
|
||||
"panel": "new"
|
||||
},
|
||||
@@ -13,7 +18,11 @@
|
||||
{
|
||||
"label": "authentik/core: run",
|
||||
"command": "uv",
|
||||
"args": ["run", "ak", "server"],
|
||||
"args": [
|
||||
"run",
|
||||
"ak",
|
||||
"server"
|
||||
],
|
||||
"group": "build",
|
||||
"presentation": {
|
||||
"panel": "dedicated",
|
||||
@@ -23,13 +32,17 @@
|
||||
{
|
||||
"label": "authentik/web: make",
|
||||
"command": "make",
|
||||
"args": ["web"],
|
||||
"args": [
|
||||
"web"
|
||||
],
|
||||
"group": "build"
|
||||
},
|
||||
{
|
||||
"label": "authentik/web: watch",
|
||||
"command": "make",
|
||||
"args": ["web-watch"],
|
||||
"args": [
|
||||
"web-watch"
|
||||
],
|
||||
"group": "build",
|
||||
"presentation": {
|
||||
"panel": "dedicated",
|
||||
@@ -39,19 +52,26 @@
|
||||
{
|
||||
"label": "authentik: install",
|
||||
"command": "make",
|
||||
"args": ["install", "-j4"],
|
||||
"args": [
|
||||
"install",
|
||||
"-j4"
|
||||
],
|
||||
"group": "build"
|
||||
},
|
||||
{
|
||||
"label": "authentik/docs: make",
|
||||
"label": "authentik/website: make",
|
||||
"command": "make",
|
||||
"args": ["docs"],
|
||||
"args": [
|
||||
"website"
|
||||
],
|
||||
"group": "build"
|
||||
},
|
||||
{
|
||||
"label": "authentik/docs: watch",
|
||||
"label": "authentik/website: watch",
|
||||
"command": "make",
|
||||
"args": ["docs-watch"],
|
||||
"args": [
|
||||
"website-watch"
|
||||
],
|
||||
"group": "build",
|
||||
"presentation": {
|
||||
"panel": "dedicated",
|
||||
@@ -61,7 +81,11 @@
|
||||
{
|
||||
"label": "authentik/api: generate",
|
||||
"command": "uv",
|
||||
"args": ["run", "make", "gen"],
|
||||
"args": [
|
||||
"run",
|
||||
"make",
|
||||
"gen"
|
||||
],
|
||||
"group": "build"
|
||||
}
|
||||
]
|
||||
|
||||
@@ -32,12 +32,8 @@ tests/wdio/ @goauthentik/frontend
|
||||
locale/ @goauthentik/backend @goauthentik/frontend
|
||||
web/xliff/ @goauthentik/backend @goauthentik/frontend
|
||||
# Docs & Website
|
||||
docs/ @goauthentik/docs
|
||||
# TODO Remove after moving website to docs
|
||||
website/ @goauthentik/docs
|
||||
CODE_OF_CONDUCT.md @goauthentik/docs
|
||||
# Security
|
||||
SECURITY.md @goauthentik/security @goauthentik/docs
|
||||
# TODO Remove after moving website to docs
|
||||
website/security/ @goauthentik/security @goauthentik/docs
|
||||
docs/security/ @goauthentik/security @goauthentik/docs
|
||||
website/docs/security/ @goauthentik/security @goauthentik/docs
|
||||
|
||||
@@ -14,11 +14,10 @@ RUN --mount=type=bind,target=/work/web/package.json,src=./web/package.json \
|
||||
--mount=type=bind,target=/work/web/packages/sfe/package.json,src=./web/packages/sfe/package.json \
|
||||
--mount=type=bind,target=/work/web/scripts,src=./web/scripts \
|
||||
--mount=type=cache,id=npm-ak,sharing=shared,target=/root/.npm \
|
||||
npm ci
|
||||
npm ci --include=dev
|
||||
|
||||
COPY ./package.json /work
|
||||
COPY ./web /work/web/
|
||||
# TODO: Update this after moving website to docs
|
||||
COPY ./website /work/website/
|
||||
COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api
|
||||
|
||||
@@ -63,7 +62,7 @@ RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
|
||||
go build -o /go/authentik ./cmd/server
|
||||
|
||||
# Stage 3: MaxMind GeoIP
|
||||
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.1.1 AS geoip
|
||||
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v7.1.0 AS geoip
|
||||
|
||||
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City GeoLite2-ASN"
|
||||
ENV GEOIPUPDATE_VERBOSE="1"
|
||||
@@ -76,7 +75,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
|
||||
/bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
|
||||
|
||||
# Stage 4: Download uv
|
||||
FROM ghcr.io/astral-sh/uv:0.8.2 AS uv
|
||||
FROM ghcr.io/astral-sh/uv:0.7.18 AS uv
|
||||
# Stage 5: Base python image
|
||||
FROM ghcr.io/goauthentik/fips-python:3.13.5-slim-bookworm-fips AS python-base
|
||||
|
||||
|
||||
57
Makefile
57
Makefile
@@ -1,4 +1,4 @@
|
||||
.PHONY: gen dev-reset all clean test web docs
|
||||
.PHONY: gen dev-reset all clean test web website
|
||||
|
||||
SHELL := /usr/bin/env bash
|
||||
.SHELLFLAGS += ${SHELLFLAGS} -e -o pipefail
|
||||
@@ -73,7 +73,7 @@ core-i18n-extract:
|
||||
--ignore website \
|
||||
-l en
|
||||
|
||||
install: node-install docs-install core-install ## Install all requires dependencies for `node`, `docs` and `core`
|
||||
install: web-install website-install core-install ## Install all requires dependencies for `web`, `website` and `core`
|
||||
|
||||
dev-drop-db:
|
||||
dropdb -U ${pg_user} -h ${pg_host} ${pg_name}
|
||||
@@ -121,7 +121,7 @@ gen-diff: ## (Release) generate the changelog diff between the current schema a
|
||||
sed -i 's/}/}/g' diff.md
|
||||
npx prettier --write diff.md
|
||||
|
||||
gen-clean-ts: ## Remove generated API client for TypeScript
|
||||
gen-clean-ts: ## Remove generated API client for Typescript
|
||||
rm -rf ${PWD}/${GEN_API_TS}/
|
||||
rm -rf ${PWD}/web/node_modules/@goauthentik/api/
|
||||
|
||||
@@ -183,23 +183,18 @@ gen-dev-config: ## Generate a local development config file
|
||||
|
||||
gen: gen-build gen-client-ts
|
||||
|
||||
#########################
|
||||
## Node.js
|
||||
#########################
|
||||
|
||||
node-install: ## Install the necessary libraries to build Node.js packages
|
||||
npm ci
|
||||
npm ci --prefix web
|
||||
|
||||
#########################
|
||||
## Web
|
||||
#########################
|
||||
|
||||
web-build: node-install ## Build the Authentik UI
|
||||
web-build: web-install ## Build the Authentik UI
|
||||
cd web && npm run build
|
||||
|
||||
web: web-lint-fix web-lint web-check-compile ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it
|
||||
|
||||
web-install: ## Install the necessary libraries to build the Authentik UI
|
||||
cd web && npm ci
|
||||
|
||||
web-test: ## Run tests for the Authentik UI
|
||||
cd web && npm run test
|
||||
|
||||
@@ -226,40 +221,22 @@ web-i18n-extract:
|
||||
cd web && npm run extract-locales
|
||||
|
||||
#########################
|
||||
## Docs
|
||||
## Website
|
||||
#########################
|
||||
|
||||
docs: docs-lint-fix docs-build ## Automatically fix formatting issues in the Authentik docs source code, lint the code, and compile it
|
||||
website: website-lint-fix website-build ## Automatically fix formatting issues in the Authentik website/docs source code, lint the code, and compile it
|
||||
|
||||
docs-install:
|
||||
npm ci --prefix website
|
||||
website-install:
|
||||
cd website && npm ci
|
||||
|
||||
docs-lint-fix: lint-codespell
|
||||
npm run prettier --prefix website
|
||||
website-lint-fix: lint-codespell
|
||||
cd website && npm run prettier
|
||||
|
||||
docs-build:
|
||||
npm run build --prefix website
|
||||
website-build:
|
||||
cd website && npm run build
|
||||
|
||||
docs-watch: ## Build and watch the topics documentation
|
||||
npm run start --prefix website
|
||||
|
||||
integrations: docs-lint-fix integrations-build ## Fix formatting issues in the integrations source code, lint the code, and compile it
|
||||
|
||||
integrations-build:
|
||||
npm run build --prefix website -w integrations
|
||||
|
||||
integrations-watch: ## Build and watch the Integrations documentation
|
||||
npm run start --prefix website -w integrations
|
||||
|
||||
docs-api-build:
|
||||
npm run build --prefix website -w api
|
||||
|
||||
docs-api-watch: ## Build and watch the API documentation
|
||||
npm run build:api --prefix website -w api
|
||||
npm run start --prefix website -w api
|
||||
|
||||
docs-api-clean: ## Clean generated API documentation
|
||||
npm run build:api:clean --prefix website -w api
|
||||
website-watch: ## Build and watch the documentation website, updating automatically
|
||||
cd website && npm run watch
|
||||
|
||||
#########################
|
||||
## Docker
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from os import environ
|
||||
|
||||
__version__ = "2025.6.4"
|
||||
__version__ = "2025.6.3"
|
||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
|
||||
|
||||
|
||||
|
||||
44
authentik/blueprints/tests/fixtures/tags.yaml
vendored
44
authentik/blueprints/tests/fixtures/tags.yaml
vendored
@@ -12,8 +12,8 @@ context:
|
||||
context1: context-nested-value
|
||||
context2: !Context context1
|
||||
entries:
|
||||
- model: !Format ["%%s", authentik_sources_oauth.oauthsource]
|
||||
state: !Format ["%%s", present]
|
||||
- model: !Format ["%s", authentik_sources_oauth.oauthsource]
|
||||
state: !Format ["%s", present]
|
||||
identifiers:
|
||||
slug: test
|
||||
attrs:
|
||||
@@ -27,23 +27,20 @@ entries:
|
||||
[slug, default-source-authentication],
|
||||
]
|
||||
enrollment_flow:
|
||||
!Find [!Format ["%%s", authentik_flows.Flow], [slug, default-source-enrollment]]
|
||||
!Find [!Format ["%s", authentik_flows.Flow], [slug, default-source-enrollment]]
|
||||
- attrs:
|
||||
expression: return True
|
||||
identifiers:
|
||||
name: !Format [foo-%%s-%%s-%%s, !Context foo, !Context bar, qux]
|
||||
name: !Format [foo-%s-%s-%s, !Context foo, !Context bar, qux]
|
||||
id: policy
|
||||
model: authentik_policies_expression.expressionpolicy
|
||||
- attrs:
|
||||
attributes:
|
||||
env_null: !Env [bar-baz, null]
|
||||
file_content: !File '%(file_name)s'
|
||||
file_default: !File ['%(file_default_name)s', 'default']
|
||||
file_non_existent: !File '/does-not-exist'
|
||||
json_parse: !ParseJSON '{"foo": "bar"}'
|
||||
policy_pk1:
|
||||
!Format [
|
||||
"%%s-%%s",
|
||||
"%s-%s",
|
||||
!Find [
|
||||
authentik_policies_expression.expressionpolicy,
|
||||
[
|
||||
@@ -54,29 +51,29 @@ entries:
|
||||
],
|
||||
suffix,
|
||||
]
|
||||
policy_pk2: !Format ["%%s-%%s", !KeyOf policy, suffix]
|
||||
policy_pk2: !Format ["%s-%s", !KeyOf policy, suffix]
|
||||
boolAnd:
|
||||
!Condition [AND, !Context foo, !Format ["%%s", "a_string"], 1]
|
||||
!Condition [AND, !Context foo, !Format ["%s", "a_string"], 1]
|
||||
boolNand:
|
||||
!Condition [NAND, !Context foo, !Format ["%%s", "a_string"], 1]
|
||||
!Condition [NAND, !Context foo, !Format ["%s", "a_string"], 1]
|
||||
boolOr:
|
||||
!Condition [
|
||||
OR,
|
||||
!Context foo,
|
||||
!Format ["%%s", "a_string"],
|
||||
!Format ["%s", "a_string"],
|
||||
null,
|
||||
]
|
||||
boolNor:
|
||||
!Condition [
|
||||
NOR,
|
||||
!Context foo,
|
||||
!Format ["%%s", "a_string"],
|
||||
!Format ["%s", "a_string"],
|
||||
null,
|
||||
]
|
||||
boolXor:
|
||||
!Condition [XOR, !Context foo, !Format ["%%s", "a_string"], 1]
|
||||
!Condition [XOR, !Context foo, !Format ["%s", "a_string"], 1]
|
||||
boolXnor:
|
||||
!Condition [XNOR, !Context foo, !Format ["%%s", "a_string"], 1]
|
||||
!Condition [XNOR, !Context foo, !Format ["%s", "a_string"], 1]
|
||||
boolComplex:
|
||||
!Condition [
|
||||
XNOR,
|
||||
@@ -92,7 +89,7 @@ entries:
|
||||
{
|
||||
with: { keys: "and_values" },
|
||||
and_nested_custom_tags:
|
||||
!Format ["foo-%%s", !Context foo],
|
||||
!Format ["foo-%s", !Context foo],
|
||||
},
|
||||
},
|
||||
null,
|
||||
@@ -101,7 +98,7 @@ entries:
|
||||
!If [
|
||||
!Condition [AND, false],
|
||||
null,
|
||||
[list, with, items, !Format ["foo-%%s", !Context foo]],
|
||||
[list, with, items, !Format ["foo-%s", !Context foo]],
|
||||
]
|
||||
if_true_simple: !If [!Context foo, true, text]
|
||||
if_short: !If [!Context foo]
|
||||
@@ -109,22 +106,22 @@ entries:
|
||||
enumerate_mapping_to_mapping: !Enumerate [
|
||||
!Context mapping,
|
||||
MAP,
|
||||
[!Format ["prefix-%%s", !Index 0], !Format ["other-prefix-%%s", !Value 0]]
|
||||
[!Format ["prefix-%s", !Index 0], !Format ["other-prefix-%s", !Value 0]]
|
||||
]
|
||||
enumerate_mapping_to_sequence: !Enumerate [
|
||||
!Context mapping,
|
||||
SEQ,
|
||||
!Format ["prefixed-pair-%%s-%%s", !Index 0, !Value 0]
|
||||
!Format ["prefixed-pair-%s-%s", !Index 0, !Value 0]
|
||||
]
|
||||
enumerate_sequence_to_sequence: !Enumerate [
|
||||
!Context sequence,
|
||||
SEQ,
|
||||
!Format ["prefixed-items-%%s-%%s", !Index 0, !Value 0]
|
||||
!Format ["prefixed-items-%s-%s", !Index 0, !Value 0]
|
||||
]
|
||||
enumerate_sequence_to_mapping: !Enumerate [
|
||||
!Context sequence,
|
||||
MAP,
|
||||
[!Format ["index: %%d", !Index 0], !Value 0]
|
||||
[!Format ["index: %d", !Index 0], !Value 0]
|
||||
]
|
||||
nested_complex_enumeration: !Enumerate [
|
||||
!Context sequence,
|
||||
@@ -135,9 +132,9 @@ entries:
|
||||
!Context mapping,
|
||||
MAP,
|
||||
[
|
||||
!Format ["%%s", !Index 0],
|
||||
!Format ["%s", !Index 0],
|
||||
[
|
||||
!Enumerate [!Value 2, SEQ, !Format ["prefixed-%%s", !Value 0]],
|
||||
!Enumerate [!Value 2, SEQ, !Format ["prefixed-%s", !Value 0]],
|
||||
{
|
||||
outer_value: !Value 1,
|
||||
outer_index: !Index 1,
|
||||
@@ -154,7 +151,6 @@ entries:
|
||||
at_index_sequence_default: !AtIndex [!Context sequence, 100, "non existent"]
|
||||
at_index_mapping: !AtIndex [!Context mapping, "key2"]
|
||||
at_index_mapping_default: !AtIndex [!Context mapping, "invalid", "non existent"]
|
||||
find_object: !AtIndex [!FindObject [authentik_providers_oauth2.scopemapping, [scope_name, openid]], managed]
|
||||
identifiers:
|
||||
name: test
|
||||
conditions:
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
"""Test blueprints v1"""
|
||||
|
||||
from os import chmod, environ, unlink, write
|
||||
from tempfile import mkstemp
|
||||
from os import environ
|
||||
|
||||
from django.test import TransactionTestCase
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.blueprints.v1.exporter import FlowExporter
|
||||
from authentik.blueprints.v1.importer import Importer, transaction_rollback
|
||||
from authentik.core.models import Group
|
||||
@@ -128,119 +126,102 @@ class TestBlueprintsV1(TransactionTestCase):
|
||||
|
||||
self.assertEqual(Prompt.objects.filter(field_key="username").count(), count_before)
|
||||
|
||||
@apply_blueprint("system/providers-oauth2.yaml")
|
||||
def test_import_yaml_tags(self):
|
||||
"""Test some yaml tags"""
|
||||
ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").delete()
|
||||
Group.objects.filter(name="test").delete()
|
||||
environ["foo"] = generate_id()
|
||||
file, file_name = mkstemp()
|
||||
write(file, b"foo")
|
||||
_, file_default_name = mkstemp()
|
||||
chmod(file_default_name, 0o000) # Remove all permissions so we can't read the file
|
||||
importer = Importer.from_string(
|
||||
load_fixture(
|
||||
"fixtures/tags.yaml",
|
||||
file_name=file_name,
|
||||
file_default_name=file_default_name,
|
||||
),
|
||||
{"bar": "baz"},
|
||||
)
|
||||
importer = Importer.from_string(load_fixture("fixtures/tags.yaml"), {"bar": "baz"})
|
||||
self.assertTrue(importer.validate()[0])
|
||||
self.assertTrue(importer.apply())
|
||||
policy = ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").first()
|
||||
self.assertTrue(policy)
|
||||
group = Group.objects.filter(name="test").first()
|
||||
self.assertIsNotNone(group)
|
||||
self.assertEqual(
|
||||
group.attributes,
|
||||
{
|
||||
"policy_pk1": str(policy.pk) + "-suffix",
|
||||
"policy_pk2": str(policy.pk) + "-suffix",
|
||||
"boolAnd": True,
|
||||
"boolNand": False,
|
||||
"boolOr": True,
|
||||
"boolNor": False,
|
||||
"boolXor": True,
|
||||
"boolXnor": False,
|
||||
"boolComplex": True,
|
||||
"if_true_complex": {
|
||||
"dictionary": {
|
||||
"with": {"keys": "and_values"},
|
||||
"and_nested_custom_tags": "foo-bar",
|
||||
}
|
||||
},
|
||||
"if_false_complex": ["list", "with", "items", "foo-bar"],
|
||||
"if_true_simple": True,
|
||||
"if_short": True,
|
||||
"if_false_simple": 2,
|
||||
"enumerate_mapping_to_mapping": {
|
||||
"prefix-key1": "other-prefix-value",
|
||||
"prefix-key2": "other-prefix-2",
|
||||
},
|
||||
"enumerate_mapping_to_sequence": [
|
||||
"prefixed-pair-key1-value",
|
||||
"prefixed-pair-key2-2",
|
||||
],
|
||||
"enumerate_sequence_to_sequence": [
|
||||
"prefixed-items-0-foo",
|
||||
"prefixed-items-1-bar",
|
||||
],
|
||||
"enumerate_sequence_to_mapping": {"index: 0": "foo", "index: 1": "bar"},
|
||||
"nested_complex_enumeration": {
|
||||
"0": {
|
||||
"key1": [
|
||||
["prefixed-f", "prefixed-o", "prefixed-o"],
|
||||
{
|
||||
"outer_value": "foo",
|
||||
"outer_index": 0,
|
||||
"middle_value": "value",
|
||||
"middle_index": "key1",
|
||||
},
|
||||
],
|
||||
"key2": [
|
||||
["prefixed-f", "prefixed-o", "prefixed-o"],
|
||||
{
|
||||
"outer_value": "foo",
|
||||
"outer_index": 0,
|
||||
"middle_value": 2,
|
||||
"middle_index": "key2",
|
||||
},
|
||||
],
|
||||
self.assertTrue(
|
||||
Group.objects.filter(
|
||||
attributes={
|
||||
"policy_pk1": str(policy.pk) + "-suffix",
|
||||
"policy_pk2": str(policy.pk) + "-suffix",
|
||||
"boolAnd": True,
|
||||
"boolNand": False,
|
||||
"boolOr": True,
|
||||
"boolNor": False,
|
||||
"boolXor": True,
|
||||
"boolXnor": False,
|
||||
"boolComplex": True,
|
||||
"if_true_complex": {
|
||||
"dictionary": {
|
||||
"with": {"keys": "and_values"},
|
||||
"and_nested_custom_tags": "foo-bar",
|
||||
}
|
||||
},
|
||||
"1": {
|
||||
"key1": [
|
||||
["prefixed-b", "prefixed-a", "prefixed-r"],
|
||||
{
|
||||
"outer_value": "bar",
|
||||
"outer_index": 1,
|
||||
"middle_value": "value",
|
||||
"middle_index": "key1",
|
||||
},
|
||||
],
|
||||
"key2": [
|
||||
["prefixed-b", "prefixed-a", "prefixed-r"],
|
||||
{
|
||||
"outer_value": "bar",
|
||||
"outer_index": 1,
|
||||
"middle_value": 2,
|
||||
"middle_index": "key2",
|
||||
},
|
||||
],
|
||||
"if_false_complex": ["list", "with", "items", "foo-bar"],
|
||||
"if_true_simple": True,
|
||||
"if_short": True,
|
||||
"if_false_simple": 2,
|
||||
"enumerate_mapping_to_mapping": {
|
||||
"prefix-key1": "other-prefix-value",
|
||||
"prefix-key2": "other-prefix-2",
|
||||
},
|
||||
},
|
||||
"nested_context": "context-nested-value",
|
||||
"env_null": None,
|
||||
"file_content": "foo",
|
||||
"file_default": "default",
|
||||
"file_non_existent": None,
|
||||
"json_parse": {"foo": "bar"},
|
||||
"at_index_sequence": "foo",
|
||||
"at_index_sequence_default": "non existent",
|
||||
"at_index_mapping": 2,
|
||||
"at_index_mapping_default": "non existent",
|
||||
"find_object": "goauthentik.io/providers/oauth2/scope-openid",
|
||||
},
|
||||
"enumerate_mapping_to_sequence": [
|
||||
"prefixed-pair-key1-value",
|
||||
"prefixed-pair-key2-2",
|
||||
],
|
||||
"enumerate_sequence_to_sequence": [
|
||||
"prefixed-items-0-foo",
|
||||
"prefixed-items-1-bar",
|
||||
],
|
||||
"enumerate_sequence_to_mapping": {"index: 0": "foo", "index: 1": "bar"},
|
||||
"nested_complex_enumeration": {
|
||||
"0": {
|
||||
"key1": [
|
||||
["prefixed-f", "prefixed-o", "prefixed-o"],
|
||||
{
|
||||
"outer_value": "foo",
|
||||
"outer_index": 0,
|
||||
"middle_value": "value",
|
||||
"middle_index": "key1",
|
||||
},
|
||||
],
|
||||
"key2": [
|
||||
["prefixed-f", "prefixed-o", "prefixed-o"],
|
||||
{
|
||||
"outer_value": "foo",
|
||||
"outer_index": 0,
|
||||
"middle_value": 2,
|
||||
"middle_index": "key2",
|
||||
},
|
||||
],
|
||||
},
|
||||
"1": {
|
||||
"key1": [
|
||||
["prefixed-b", "prefixed-a", "prefixed-r"],
|
||||
{
|
||||
"outer_value": "bar",
|
||||
"outer_index": 1,
|
||||
"middle_value": "value",
|
||||
"middle_index": "key1",
|
||||
},
|
||||
],
|
||||
"key2": [
|
||||
["prefixed-b", "prefixed-a", "prefixed-r"],
|
||||
{
|
||||
"outer_value": "bar",
|
||||
"outer_index": 1,
|
||||
"middle_value": 2,
|
||||
"middle_index": "key2",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
"nested_context": "context-nested-value",
|
||||
"env_null": None,
|
||||
"json_parse": {"foo": "bar"},
|
||||
"at_index_sequence": "foo",
|
||||
"at_index_sequence_default": "non existent",
|
||||
"at_index_mapping": 2,
|
||||
"at_index_mapping_default": "non existent",
|
||||
}
|
||||
).exists()
|
||||
)
|
||||
self.assertTrue(
|
||||
OAuthSource.objects.filter(
|
||||
@@ -248,8 +229,6 @@ class TestBlueprintsV1(TransactionTestCase):
|
||||
consumer_key=environ["foo"],
|
||||
)
|
||||
)
|
||||
unlink(file_name)
|
||||
unlink(file_default_name)
|
||||
|
||||
def test_export_validate_import_policies(self):
|
||||
"""Test export and validate it"""
|
||||
|
||||
@@ -18,15 +18,12 @@ from django.db.models import Model, Q
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import Field
|
||||
from rest_framework.serializers import Serializer
|
||||
from structlog.stdlib import get_logger
|
||||
from yaml import SafeDumper, SafeLoader, ScalarNode, SequenceNode
|
||||
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.lib.sentry import SentryIgnoredException
|
||||
from authentik.policies.models import PolicyBindingModel
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class UNSET:
|
||||
"""Used to test whether a key has not been set."""
|
||||
@@ -271,34 +268,6 @@ class Env(YAMLTag):
|
||||
return getenv(self.key) or self.default
|
||||
|
||||
|
||||
class File(YAMLTag):
|
||||
"""Lookup file with optional default"""
|
||||
|
||||
path: str
|
||||
default: Any | None
|
||||
|
||||
def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None:
|
||||
super().__init__()
|
||||
self.default = None
|
||||
if isinstance(node, ScalarNode):
|
||||
self.path = node.value
|
||||
if isinstance(node, SequenceNode):
|
||||
self.path = loader.construct_object(node.value[0])
|
||||
self.default = loader.construct_object(node.value[1])
|
||||
|
||||
def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
|
||||
try:
|
||||
with open(self.path, encoding="utf8") as _file:
|
||||
return _file.read().strip()
|
||||
except OSError as exc:
|
||||
LOGGER.warning(
|
||||
"Failed to read file. Falling back to default value",
|
||||
path=self.path,
|
||||
exc=exc,
|
||||
)
|
||||
return self.default
|
||||
|
||||
|
||||
class Context(YAMLTag):
|
||||
"""Lookup key from instance context"""
|
||||
|
||||
@@ -367,7 +336,7 @@ class Format(YAMLTag):
|
||||
|
||||
|
||||
class Find(YAMLTag):
|
||||
"""Find any object primary key"""
|
||||
"""Find any object"""
|
||||
|
||||
model_name: str | YAMLTag
|
||||
conditions: list[list]
|
||||
@@ -382,7 +351,7 @@ class Find(YAMLTag):
|
||||
values.append(loader.construct_object(node_values))
|
||||
self.conditions.append(values)
|
||||
|
||||
def _get_instance(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
|
||||
def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
|
||||
if isinstance(self.model_name, YAMLTag):
|
||||
model_name = self.model_name.resolve(entry, blueprint)
|
||||
else:
|
||||
@@ -404,29 +373,12 @@ class Find(YAMLTag):
|
||||
else:
|
||||
query_value = cond[1]
|
||||
query &= Q(**{query_key: query_value})
|
||||
return model_class.objects.filter(query).first()
|
||||
|
||||
def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
|
||||
instance = self._get_instance(entry, blueprint)
|
||||
instance = model_class.objects.filter(query).first()
|
||||
if instance:
|
||||
return instance.pk
|
||||
return None
|
||||
|
||||
|
||||
class FindObject(Find):
|
||||
"""Find any object"""
|
||||
|
||||
def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
|
||||
instance = self._get_instance(entry, blueprint)
|
||||
if not instance:
|
||||
return None
|
||||
if not isinstance(instance, SerializerModel):
|
||||
raise EntryInvalidError.from_entry(
|
||||
f"Model {self.model_name} is not resolvable through FindObject", entry
|
||||
)
|
||||
return instance.serializer(instance=instance).data
|
||||
|
||||
|
||||
class Condition(YAMLTag):
|
||||
"""Convert all values to a single boolean"""
|
||||
|
||||
@@ -722,13 +674,11 @@ class BlueprintLoader(SafeLoader):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.add_constructor("!KeyOf", KeyOf)
|
||||
self.add_constructor("!Find", Find)
|
||||
self.add_constructor("!FindObject", FindObject)
|
||||
self.add_constructor("!Context", Context)
|
||||
self.add_constructor("!Format", Format)
|
||||
self.add_constructor("!Condition", Condition)
|
||||
self.add_constructor("!If", If)
|
||||
self.add_constructor("!Env", Env)
|
||||
self.add_constructor("!File", File)
|
||||
self.add_constructor("!Enumerate", Enumerate)
|
||||
self.add_constructor("!Value", Value)
|
||||
self.add_constructor("!Index", Index)
|
||||
|
||||
@@ -52,27 +52,6 @@ class TestBrands(APITestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_brand_subdomain_same_suffix(self):
|
||||
"""Test Current brand API"""
|
||||
Brand.objects.all().delete()
|
||||
Brand.objects.create(domain="bar.baz", branding_title="custom")
|
||||
Brand.objects.create(domain="foo.bar.baz", branding_title="custom")
|
||||
self.assertJSONEqual(
|
||||
self.client.get(
|
||||
reverse("authentik_api:brand-current"), HTTP_HOST="foo.bar.baz"
|
||||
).content.decode(),
|
||||
{
|
||||
"branding_logo": "/static/dist/assets/icons/icon_left_brand.svg",
|
||||
"branding_favicon": "/static/dist/assets/icons/icon.png",
|
||||
"branding_title": "custom",
|
||||
"branding_custom_css": "",
|
||||
"matched_domain": "foo.bar.baz",
|
||||
"ui_footer_links": [],
|
||||
"ui_theme": Themes.AUTOMATIC,
|
||||
"default_locale": "",
|
||||
},
|
||||
)
|
||||
|
||||
def test_fallback(self):
|
||||
"""Test fallback brand"""
|
||||
Brand.objects.all().delete()
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Any
|
||||
|
||||
from django.db.models import F, Q
|
||||
from django.db.models import Value as V
|
||||
from django.db.models.functions import Length
|
||||
from django.http.request import HttpRequest
|
||||
from django.utils.html import _json_script_escapes
|
||||
from django.utils.safestring import mark_safe
|
||||
@@ -21,9 +20,9 @@ DEFAULT_BRAND = Brand(domain="fallback")
|
||||
def get_brand_for_request(request: HttpRequest) -> Brand:
|
||||
"""Get brand object for current request"""
|
||||
db_brands = (
|
||||
Brand.objects.annotate(host_domain=V(request.get_host()), match_length=Length("domain"))
|
||||
Brand.objects.annotate(host_domain=V(request.get_host()))
|
||||
.filter(Q(host_domain__iendswith=F("domain")) | _q_default)
|
||||
.order_by("-match_length", "default")
|
||||
.order_by("default")
|
||||
)
|
||||
brands = list(db_brands.all())
|
||||
if len(brands) < 1:
|
||||
@@ -43,6 +42,6 @@ def context_processor(request: HttpRequest) -> dict[str, Any]:
|
||||
"brand": brand,
|
||||
"brand_css": brand_css,
|
||||
"footer_links": tenant.footer_links,
|
||||
"html_meta": get_http_meta(),
|
||||
"html_meta": {**get_http_meta()},
|
||||
"version": get_full_version(),
|
||||
}
|
||||
|
||||
@@ -149,10 +149,10 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
|
||||
return applications
|
||||
|
||||
def _filter_applications_with_launch_url(
|
||||
self, paginated_apps: Iterator[Application]
|
||||
self, pagined_apps: Iterator[Application]
|
||||
) -> list[Application]:
|
||||
applications = []
|
||||
for app in paginated_apps:
|
||||
for app in pagined_apps:
|
||||
if app.get_launch_url():
|
||||
applications.append(app)
|
||||
return applications
|
||||
|
||||
@@ -11,6 +11,7 @@ from authentik.core.expression.exceptions import SkipObjectException
|
||||
from authentik.core.models import User
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.expression.evaluator import BaseEvaluator
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.policies.types import PolicyRequest
|
||||
|
||||
PROPERTY_MAPPING_TIME = Histogram(
|
||||
@@ -68,11 +69,12 @@ class PropertyMappingEvaluator(BaseEvaluator):
|
||||
# For dry-run requests we don't save exceptions
|
||||
if self.dry_run:
|
||||
return
|
||||
error_string = exception_to_string(exc)
|
||||
event = Event.new(
|
||||
EventAction.PROPERTY_MAPPING_EXCEPTION,
|
||||
expression=expression_source,
|
||||
message="Failed to execute property mapping",
|
||||
).with_exception(exc)
|
||||
message=error_string,
|
||||
)
|
||||
if "request" in self._context:
|
||||
req: PolicyRequest = self._context["request"]
|
||||
if req.http_request:
|
||||
|
||||
@@ -5,7 +5,6 @@ from contextvars import ContextVar
|
||||
from functools import partial
|
||||
from uuid import uuid4
|
||||
|
||||
from django.contrib.auth import logout
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
@@ -59,11 +58,6 @@ class AuthenticationMiddleware(MiddlewareMixin):
|
||||
request.user = SimpleLazyObject(lambda: get_user(request))
|
||||
request.auser = partial(aget_user, request)
|
||||
|
||||
user = request.user
|
||||
if user and user.is_authenticated and not user.is_active:
|
||||
logout(request)
|
||||
raise AssertionError()
|
||||
|
||||
|
||||
class ImpersonateMiddleware:
|
||||
"""Middleware to impersonate users"""
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
# Generated by Django 5.1.11 on 2025-07-03 13:08
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_core", "0048_delete_oldauthenticatedsession_content_type"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name="token",
|
||||
options={
|
||||
"permissions": [
|
||||
("view_token_key", "View token's key"),
|
||||
("set_token_key", "Set a token's key"),
|
||||
],
|
||||
"verbose_name": "Token",
|
||||
"verbose_name_plural": "Tokens",
|
||||
},
|
||||
),
|
||||
]
|
||||
@@ -953,10 +953,7 @@ class Token(SerializerModel, ManagedModel, ExpiringModel):
|
||||
models.Index(fields=["identifier"]),
|
||||
models.Index(fields=["key"]),
|
||||
]
|
||||
permissions = [
|
||||
("view_token_key", _("View token's key")),
|
||||
("set_token_key", _("Set a token's key")),
|
||||
]
|
||||
permissions = [("view_token_key", _("View token's key"))]
|
||||
|
||||
def __str__(self):
|
||||
description = f"{self.identifier}"
|
||||
|
||||
@@ -79,8 +79,8 @@ class SourceFlowManager:
|
||||
|
||||
identifier: str
|
||||
|
||||
user_connection_type: type[UserSourceConnection]
|
||||
group_connection_type: type[GroupSourceConnection]
|
||||
user_connection_type: type[UserSourceConnection] = UserSourceConnection
|
||||
group_connection_type: type[GroupSourceConnection] = GroupSourceConnection
|
||||
|
||||
user_info: dict[str, Any]
|
||||
policy_context: dict[str, Any]
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1">
|
||||
{# Darkreader breaks the site regardless of theme as its not compatible with webcomponents, and we default to a dark theme based on preferred colour-scheme #}
|
||||
<meta name="darkreader-lock">
|
||||
<base href="{{ base_url_rel }}" />
|
||||
<title>{% block title %}{% trans title|default:brand.branding_title %}{% endblock %}</title>
|
||||
<link rel="icon" href="{{ brand.branding_favicon_url }}">
|
||||
<link rel="shortcut icon" href="{{ brand.branding_favicon_url }}">
|
||||
|
||||
@@ -12,7 +12,6 @@ from rest_framework.fields import CharField, IntegerField
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.validators import UniqueValidator
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
@@ -54,7 +53,6 @@ class LicenseSerializer(ModelSerializer):
|
||||
"external_users",
|
||||
]
|
||||
extra_kwargs = {
|
||||
"key": {"validators": [UniqueValidator(queryset=License.objects.all())]},
|
||||
"name": {"read_only": True},
|
||||
"expiry": {"read_only": True},
|
||||
"internal_users": {"read_only": True},
|
||||
|
||||
@@ -65,17 +65,13 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
|
||||
data[field.name] = deepcopy(field_value)
|
||||
return cleanse_dict(data)
|
||||
|
||||
def diff(self, before: dict, after: dict, update_fields: list[str] | None = None) -> dict:
|
||||
def diff(self, before: dict, after: dict) -> dict:
|
||||
"""Generate diff between dicts"""
|
||||
diff = {}
|
||||
for key, value in before.items():
|
||||
if update_fields and key not in update_fields:
|
||||
continue
|
||||
if after.get(key) != value:
|
||||
diff[key] = {"previous_value": value, "new_value": after.get(key)}
|
||||
for key, value in after.items():
|
||||
if update_fields and key not in update_fields:
|
||||
continue
|
||||
if key not in before and key not in diff and before.get(key) != value:
|
||||
diff[key] = {"previous_value": before.get(key), "new_value": value}
|
||||
return sanitize_item(diff)
|
||||
@@ -99,7 +95,6 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
|
||||
instance: Model,
|
||||
created: bool,
|
||||
thread_kwargs: dict | None = None,
|
||||
update_fields: list[str] | None = None,
|
||||
**_,
|
||||
):
|
||||
if not self.enabled:
|
||||
@@ -113,7 +108,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
|
||||
prev_state = {}
|
||||
# Get current state
|
||||
new_state = self.serialize_simple(instance)
|
||||
diff = self.diff(prev_state, new_state, update_fields)
|
||||
diff = self.diff(prev_state, new_state)
|
||||
thread_kwargs["diff"] = diff
|
||||
return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Group, User
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.enterprise.audit.middleware import EnterpriseAuditMiddleware
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.events.utils import sanitize_item
|
||||
from authentik.lib.generators import generate_id
|
||||
@@ -209,23 +208,3 @@ class TestEnterpriseAudit(APITestCase):
|
||||
diff,
|
||||
{"users": {"remove": [user.pk]}},
|
||||
)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
|
||||
PropertyMock(return_value=True),
|
||||
)
|
||||
def test_diff_update_fields(self):
|
||||
"""Test update audit log"""
|
||||
self.client.force_login(self.user)
|
||||
diff = EnterpriseAuditMiddleware(None).diff(
|
||||
{
|
||||
"foo": "bar",
|
||||
"is_active": False,
|
||||
},
|
||||
{
|
||||
"foo": "baz",
|
||||
"is_active": True,
|
||||
},
|
||||
update_fields=["is_active"],
|
||||
)
|
||||
self.assertEqual(diff, {"is_active": {"new_value": True, "previous_value": False}})
|
||||
|
||||
@@ -6,7 +6,7 @@ from djangoql.ast import Name
|
||||
from djangoql.exceptions import DjangoQLError
|
||||
from djangoql.queryset import apply_search
|
||||
from djangoql.schema import DjangoQLSchema
|
||||
from rest_framework.filters import SearchFilter
|
||||
from rest_framework.filters import BaseFilterBackend, SearchFilter
|
||||
from rest_framework.request import Request
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
@@ -39,7 +39,7 @@ class BaseSchema(DjangoQLSchema):
|
||||
return super().resolve_name(name)
|
||||
|
||||
|
||||
class QLSearch(SearchFilter):
|
||||
class QLSearch(BaseFilterBackend):
|
||||
"""rest_framework search filter which uses DjangoQL"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -16,7 +16,7 @@ from authentik.stages.authenticator.models import Device
|
||||
|
||||
|
||||
class AuthenticatorEndpointGDTCStage(ConfigurableStage, FriendlyNamedStage, Stage):
|
||||
"""Setup Google Chrome Device Trust connection"""
|
||||
"""Setup Google Chrome Device-trust connection"""
|
||||
|
||||
credentials = models.JSONField()
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import (
|
||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
|
||||
from authentik.stages.user_login.stage import PLAN_CONTEXT_METHOD_ARGS_KNOWN_DEVICE
|
||||
|
||||
# Header we get from chrome that initiates verified access
|
||||
HEADER_DEVICE_TRUST = "X-Device-Trust"
|
||||
@@ -28,8 +27,6 @@ HEADER_ACCESS_CHALLENGE_RESPONSE = "X-Verified-Access-Challenge-Response"
|
||||
# Header value for x-device-trust that initiates the flow
|
||||
DEVICE_TRUST_VERIFIED_ACCESS = "VerifiedAccess"
|
||||
|
||||
PLAN_CONTEXT_METHOD_ARGS_ENDPOINTS = "endpoints"
|
||||
|
||||
|
||||
@method_decorator(xframe_options_sameorigin, name="dispatch")
|
||||
class GoogleChromeDeviceTrustConnector(View):
|
||||
@@ -84,14 +81,7 @@ class GoogleChromeDeviceTrustConnector(View):
|
||||
)
|
||||
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD, "trusted_endpoint")
|
||||
flow_plan.context.setdefault(PLAN_CONTEXT_METHOD_ARGS, {})
|
||||
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS].setdefault(
|
||||
PLAN_CONTEXT_METHOD_ARGS_ENDPOINTS, []
|
||||
)
|
||||
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS][PLAN_CONTEXT_METHOD_ARGS_ENDPOINTS].append(
|
||||
response
|
||||
)
|
||||
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS].setdefault(
|
||||
PLAN_CONTEXT_METHOD_ARGS_KNOWN_DEVICE, True
|
||||
)
|
||||
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS].setdefault("endpoints", [])
|
||||
flow_plan.context[PLAN_CONTEXT_METHOD_ARGS]["endpoints"].append(response)
|
||||
request.session[SESSION_KEY_PLAN] = flow_plan
|
||||
return TemplateResponse(request, "stages/authenticator_endpoint/google_chrome_dtc.html")
|
||||
|
||||
@@ -20,7 +20,7 @@ from authentik.core.models import Group, User
|
||||
from authentik.events.models import Event, EventAction, Notification
|
||||
from authentik.events.utils import model_to_dict
|
||||
from authentik.lib.sentry import should_ignore_exception
|
||||
from authentik.lib.utils.errors import exception_to_dict
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.stages.authenticator_static.models import StaticToken
|
||||
|
||||
IGNORED_MODELS = tuple(
|
||||
@@ -170,16 +170,14 @@ class AuditMiddleware:
|
||||
thread = EventNewThread(
|
||||
EventAction.SUSPICIOUS_REQUEST,
|
||||
request,
|
||||
message=str(exception),
|
||||
exception=exception_to_dict(exception),
|
||||
message=exception_to_string(exception),
|
||||
)
|
||||
thread.run()
|
||||
elif not should_ignore_exception(exception):
|
||||
thread = EventNewThread(
|
||||
EventAction.SYSTEM_EXCEPTION,
|
||||
request,
|
||||
message=str(exception),
|
||||
exception=exception_to_dict(exception),
|
||||
message=exception_to_string(exception),
|
||||
)
|
||||
thread.run()
|
||||
|
||||
|
||||
@@ -38,7 +38,6 @@ from authentik.events.utils import (
|
||||
)
|
||||
from authentik.lib.models import DomainlessURLValidator, SerializerModel
|
||||
from authentik.lib.sentry import SentryIgnoredException
|
||||
from authentik.lib.utils.errors import exception_to_dict
|
||||
from authentik.lib.utils.http import get_http_session
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.policies.models import PolicyBindingModel
|
||||
@@ -164,12 +163,6 @@ class Event(SerializerModel, ExpiringModel):
|
||||
event = Event(action=action, app=app, context=cleaned_kwargs)
|
||||
return event
|
||||
|
||||
def with_exception(self, exc: Exception) -> "Event":
|
||||
"""Add data from 'exc' to the event in a database-saveable format"""
|
||||
self.context.setdefault("message", str(exc))
|
||||
self.context["exception"] = exception_to_dict(exc)
|
||||
return self
|
||||
|
||||
def set_user(self, user: User) -> "Event":
|
||||
"""Set `.user` based on user, ensuring the correct attributes are copied.
|
||||
This should only be used when self.from_http is *not* used."""
|
||||
|
||||
@@ -127,8 +127,8 @@ class SystemTask(TenantTask):
|
||||
)
|
||||
Event.new(
|
||||
EventAction.SYSTEM_TASK_EXCEPTION,
|
||||
message=f"Task {self.__name__} encountered an error",
|
||||
).with_exception(exc).save()
|
||||
message=f"Task {self.__name__} encountered an error: {exception_to_string(exc)}",
|
||||
).save()
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -62,7 +62,6 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
|
||||
policy_engine.mode = PolicyEngineMode.MODE_ANY
|
||||
policy_engine.empty_result = False
|
||||
policy_engine.use_cache = False
|
||||
policy_engine.request.obj = event
|
||||
policy_engine.request.context["event"] = event
|
||||
policy_engine.build()
|
||||
result = policy_engine.result
|
||||
|
||||
@@ -56,6 +56,7 @@ from authentik.flows.planner import (
|
||||
)
|
||||
from authentik.flows.stage import AccessDeniedStage, StageView
|
||||
from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.lib.utils.reflection import all_subclasses, class_to_path
|
||||
from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
@@ -238,8 +239,8 @@ class FlowExecutorView(APIView):
|
||||
capture_exception(exc)
|
||||
Event.new(
|
||||
action=EventAction.SYSTEM_EXCEPTION,
|
||||
message="System exception during flow execution.",
|
||||
).with_exception(exc).from_http(self.request)
|
||||
message=exception_to_string(exc),
|
||||
).from_http(self.request)
|
||||
challenge = FlowErrorChallenge(self.request, exc)
|
||||
challenge.is_valid(raise_exception=True)
|
||||
return to_stage_response(self.request, HttpChallengeResponse(challenge))
|
||||
|
||||
@@ -8,9 +8,9 @@
|
||||
# make gen-dev-config
|
||||
# ```
|
||||
#
|
||||
# You may edit the generated file to override the configuration below.
|
||||
# You may edit the generated file to override the configuration below.
|
||||
#
|
||||
# When making modifying the default configuration file,
|
||||
# When making modifying the default configuration file,
|
||||
# ensure that the corresponding documentation is updated to match.
|
||||
#
|
||||
# @see {@link ../../website/docs/install-config/configuration/configuration.mdx Configuration documentation} for more information.
|
||||
|
||||
@@ -14,6 +14,7 @@ from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.expression.exceptions import ControlFlowException
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.lib.sync.outgoing.exceptions import NotFoundSyncException, StopSync
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from django.db.models import Model
|
||||
@@ -105,9 +106,9 @@ class BaseOutgoingSyncClient[
|
||||
# Value error can be raised when assigning invalid data to an attribute
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message="Failed to evaluate property-mapping",
|
||||
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||
mapping=exc.mapping,
|
||||
).with_exception(exc).save()
|
||||
).save()
|
||||
raise StopSync(exc, obj, exc.mapping) from exc
|
||||
if not raw_final_object:
|
||||
raise StopSync(ValueError("No mappings configured"), obj)
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
|
||||
from traceback import extract_tb
|
||||
|
||||
from structlog.tracebacks import ExceptionDictTransformer
|
||||
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
|
||||
TRACEBACK_HEADER = "Traceback (most recent call last):"
|
||||
@@ -19,8 +17,3 @@ def exception_to_string(exc: Exception) -> str:
|
||||
f"{class_to_path(exc.__class__)}: {str(exc)}",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def exception_to_dict(exc: Exception) -> dict:
|
||||
"""Format exception as a dictionary"""
|
||||
return ExceptionDictTransformer()((type(exc), exc, exc.__traceback__))
|
||||
|
||||
@@ -35,6 +35,7 @@ from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.models import InheritanceForeignKey, SerializerModel
|
||||
from authentik.lib.sentry import SentryIgnoredException
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.outposts.controllers.k8s.utils import get_namespace
|
||||
|
||||
OUR_VERSION = parse(__version__)
|
||||
@@ -325,8 +326,9 @@ class Outpost(SerializerModel, ManagedModel):
|
||||
"While setting the permissions for the service-account, a "
|
||||
"permission was not found: Check "
|
||||
"https://goauthentik.io/docs/troubleshooting/missing_permission"
|
||||
),
|
||||
).with_exception(exc).set_user(user).save()
|
||||
)
|
||||
+ exception_to_string(exc),
|
||||
).set_user(user).save()
|
||||
else:
|
||||
app_label, perm = model_or_perm.split(".")
|
||||
permission = Permission.objects.filter(
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""authentik policy engine"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Iterator
|
||||
from multiprocessing import Pipe, current_process
|
||||
from multiprocessing.connection import Connection
|
||||
from time import perf_counter
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db.models import Count, Q, QuerySet
|
||||
from django.http import HttpRequest
|
||||
from sentry_sdk import start_span
|
||||
from sentry_sdk.tracing import Span
|
||||
@@ -67,11 +67,14 @@ class PolicyEngine:
|
||||
self.__processes: list[PolicyProcessInfo] = []
|
||||
self.use_cache = True
|
||||
self.__expected_result_count = 0
|
||||
self.__static_result: PolicyResult | None = None
|
||||
|
||||
def bindings(self) -> QuerySet[PolicyBinding] | Iterable[PolicyBinding]:
|
||||
def iterate_bindings(self) -> Iterator[PolicyBinding]:
|
||||
"""Make sure all Policies are their respective classes"""
|
||||
return PolicyBinding.objects.filter(target=self.__pbm, enabled=True).order_by("order")
|
||||
return (
|
||||
PolicyBinding.objects.filter(target=self.__pbm, enabled=True)
|
||||
.order_by("order")
|
||||
.iterator()
|
||||
)
|
||||
|
||||
def _check_policy_type(self, binding: PolicyBinding):
|
||||
"""Check policy type, make sure it's not the root class as that has no logic implemented"""
|
||||
@@ -81,66 +84,30 @@ class PolicyEngine:
|
||||
def _check_cache(self, binding: PolicyBinding):
|
||||
if not self.use_cache:
|
||||
return False
|
||||
# It's a bit silly to time this, but
|
||||
with HIST_POLICIES_EXECUTION_TIME.labels(
|
||||
binding_order=binding.order,
|
||||
binding_target_type=binding.target_type,
|
||||
binding_target_name=binding.target_name,
|
||||
object_pk=str(self.request.obj.pk),
|
||||
object_type=class_to_path(self.request.obj.__class__),
|
||||
mode="cache_retrieve",
|
||||
).time():
|
||||
key = cache_key(binding, self.request)
|
||||
cached_policy = cache.get(key, None)
|
||||
if not cached_policy:
|
||||
return False
|
||||
before = perf_counter()
|
||||
key = cache_key(binding, self.request)
|
||||
cached_policy = cache.get(key, None)
|
||||
duration = max(perf_counter() - before, 0)
|
||||
if not cached_policy:
|
||||
return False
|
||||
self.logger.debug(
|
||||
"P_ENG: Taking result from cache",
|
||||
binding=binding,
|
||||
cache_key=key,
|
||||
request=self.request,
|
||||
)
|
||||
HIST_POLICIES_EXECUTION_TIME.labels(
|
||||
binding_order=binding.order,
|
||||
binding_target_type=binding.target_type,
|
||||
binding_target_name=binding.target_name,
|
||||
object_pk=str(self.request.obj.pk),
|
||||
object_type=class_to_path(self.request.obj.__class__),
|
||||
mode="cache_retrieve",
|
||||
).observe(duration)
|
||||
# It's a bit silly to time this, but
|
||||
self.__cached_policies.append(cached_policy)
|
||||
return True
|
||||
|
||||
def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
|
||||
"""Check static bindings if possible"""
|
||||
aggrs = {
|
||||
"total": Count(
|
||||
"pk", filter=Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None)
|
||||
),
|
||||
}
|
||||
if self.request.user.pk:
|
||||
all_groups = self.request.user.all_groups()
|
||||
aggrs["passing"] = Count(
|
||||
"pk",
|
||||
filter=Q(
|
||||
Q(
|
||||
Q(user=self.request.user) | Q(group__in=all_groups),
|
||||
negate=False,
|
||||
)
|
||||
| Q(
|
||||
Q(~Q(user=self.request.user), user__isnull=False)
|
||||
| Q(~Q(group__in=all_groups), group__isnull=False),
|
||||
negate=True,
|
||||
),
|
||||
enabled=True,
|
||||
),
|
||||
)
|
||||
matched_bindings = bindings.aggregate(**aggrs)
|
||||
passing = False
|
||||
if matched_bindings["total"] == 0 and matched_bindings.get("passing", 0) == 0:
|
||||
# If we didn't find any static bindings, do nothing
|
||||
return
|
||||
self.logger.debug("P_ENG: Found static bindings", **matched_bindings)
|
||||
if matched_bindings.get("passing", 0) > 0:
|
||||
# Any passing static binding -> passing
|
||||
passing = True
|
||||
elif matched_bindings["total"] > 0 and matched_bindings.get("passing", 0) < 1:
|
||||
# No matching static bindings but at least one is configured -> not passing
|
||||
passing = False
|
||||
self.__static_result = PolicyResult(passing)
|
||||
|
||||
def build(self) -> "PolicyEngine":
|
||||
"""Build wrapper which monitors performance"""
|
||||
with (
|
||||
@@ -156,12 +123,7 @@ class PolicyEngine:
|
||||
span: Span
|
||||
span.set_data("pbm", self.__pbm)
|
||||
span.set_data("request", self.request)
|
||||
bindings = self.bindings()
|
||||
policy_bindings = bindings
|
||||
if isinstance(bindings, QuerySet):
|
||||
self.compute_static_bindings(bindings)
|
||||
policy_bindings = [x for x in bindings if x.policy]
|
||||
for binding in policy_bindings:
|
||||
for binding in self.iterate_bindings():
|
||||
self.__expected_result_count += 1
|
||||
|
||||
self._check_policy_type(binding)
|
||||
@@ -191,13 +153,10 @@ class PolicyEngine:
|
||||
@property
|
||||
def result(self) -> PolicyResult:
|
||||
"""Get policy-checking result"""
|
||||
self.__processes.sort(key=lambda x: x.binding.order)
|
||||
process_results: list[PolicyResult] = [x.result for x in self.__processes if x.result]
|
||||
all_results = list(process_results + self.__cached_policies)
|
||||
if len(all_results) < self.__expected_result_count: # pragma: no cover
|
||||
raise AssertionError("Got less results than polices")
|
||||
if self.__static_result:
|
||||
all_results.append(self.__static_result)
|
||||
# No results, no policies attached -> passing
|
||||
if len(all_results) == 0:
|
||||
return PolicyResult(self.empty_result)
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Optional
|
||||
from django.http import HttpRequest
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.events.models import Event
|
||||
from authentik.flows.planner import PLAN_CONTEXT_SSO
|
||||
from authentik.lib.expression.evaluator import BaseEvaluator
|
||||
from authentik.policies.exceptions import PolicyException
|
||||
@@ -46,10 +45,6 @@ class PolicyEvaluator(BaseEvaluator):
|
||||
self.set_http_request(request.http_request)
|
||||
self._context["request"] = request
|
||||
self._context["context"] = request.context
|
||||
if request.obj and isinstance(request.obj, Event):
|
||||
self._context["ak_client_ip"] = ip_address(
|
||||
request.obj.client_ip or ClientIPMiddleware.default_ip
|
||||
)
|
||||
|
||||
def set_http_request(self, request: HttpRequest):
|
||||
"""Update context based on http request"""
|
||||
|
||||
@@ -10,7 +10,7 @@ from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.errors import exception_to_dict
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.lib.utils.reflection import class_to_path
|
||||
from authentik.policies.apps import HIST_POLICIES_EXECUTION_TIME
|
||||
from authentik.policies.exceptions import PolicyException
|
||||
@@ -95,13 +95,10 @@ class PolicyProcess(PROCESS_CLASS):
|
||||
except PolicyException as exc:
|
||||
# Either use passed original exception or whatever we have
|
||||
src_exc = exc.src_exc if exc.src_exc else exc
|
||||
error_string = exception_to_string(src_exc)
|
||||
# Create policy exception event, only when we're not debugging
|
||||
if not self.request.debug:
|
||||
self.create_event(
|
||||
EventAction.POLICY_EXCEPTION,
|
||||
message="Policy failed to execute",
|
||||
exception=exception_to_dict(src_exc),
|
||||
)
|
||||
self.create_event(EventAction.POLICY_EXCEPTION, message=error_string)
|
||||
LOGGER.debug("P_ENG(proc): error, using failure result", exc=src_exc)
|
||||
policy_result = PolicyResult(self.binding.failure_result, str(src_exc))
|
||||
policy_result.source_binding = self.binding
|
||||
@@ -146,5 +143,5 @@ class PolicyProcess(PROCESS_CLASS):
|
||||
try:
|
||||
self.connection.send(self.profiling_wrapper())
|
||||
except Exception as exc:
|
||||
LOGGER.warning("Policy failed to run", exc=exc)
|
||||
LOGGER.warning("Policy failed to run", exc=exception_to_string(exc))
|
||||
self.connection.send(PolicyResult(False, str(exc)))
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
"""policy engine tests"""
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db import connections
|
||||
from django.test import TestCase
|
||||
from django.test.utils import CaptureQueriesContext
|
||||
|
||||
from authentik.core.models import Group
|
||||
from authentik.core.tests.utils import create_test_user
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.dummy.models import DummyPolicy
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
@@ -22,7 +19,7 @@ class TestPolicyEngine(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
clear_policy_cache()
|
||||
self.user = create_test_user()
|
||||
self.user = create_test_admin_user()
|
||||
self.policy_false = DummyPolicy.objects.create(
|
||||
name=generate_id(), result=False, wait_min=0, wait_max=1
|
||||
)
|
||||
@@ -130,58 +127,3 @@ class TestPolicyEngine(TestCase):
|
||||
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
|
||||
self.assertEqual(engine.build().passing, False)
|
||||
self.assertEqual(len(cache.keys(f"{CACHE_PREFIX}{binding.policy_binding_uuid.hex}*")), 1)
|
||||
|
||||
def test_engine_static_bindings(self):
|
||||
"""Test static bindings"""
|
||||
group_a = Group.objects.create(name=generate_id())
|
||||
group_b = Group.objects.create(name=generate_id())
|
||||
group_b.users.add(self.user)
|
||||
user = create_test_user()
|
||||
|
||||
for case in [
|
||||
{
|
||||
"message": "Group, not member",
|
||||
"binding_args": {"group": group_a},
|
||||
"passing": False,
|
||||
},
|
||||
{
|
||||
"message": "Group, member",
|
||||
"binding_args": {"group": group_b},
|
||||
"passing": True,
|
||||
},
|
||||
{
|
||||
"message": "User, other",
|
||||
"binding_args": {"user": user},
|
||||
"passing": False,
|
||||
},
|
||||
{
|
||||
"message": "User, same",
|
||||
"binding_args": {"user": self.user},
|
||||
"passing": True,
|
||||
},
|
||||
]:
|
||||
with self.subTest():
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
for x in range(1000):
|
||||
PolicyBinding.objects.create(target=pbm, order=x, **case["binding_args"])
|
||||
engine = PolicyEngine(pbm, self.user)
|
||||
engine.use_cache = False
|
||||
with CaptureQueriesContext(connections["default"]) as ctx:
|
||||
engine.build()
|
||||
self.assertLess(ctx.final_queries, 1000)
|
||||
self.assertEqual(engine.result.passing, case["passing"])
|
||||
|
||||
def test_engine_group_complex(self):
|
||||
"""Test more complex group setups"""
|
||||
group_a = Group.objects.create(name=generate_id())
|
||||
group_b = Group.objects.create(name=generate_id(), parent=group_a)
|
||||
user = create_test_user()
|
||||
group_b.users.add(user)
|
||||
pbm = PolicyBindingModel.objects.create()
|
||||
PolicyBinding.objects.create(target=pbm, order=0, group=group_a)
|
||||
engine = PolicyEngine(pbm, user)
|
||||
engine.use_cache = False
|
||||
with CaptureQueriesContext(connections["default"]) as ctx:
|
||||
engine.build()
|
||||
self.assertLess(ctx.final_queries, 1000)
|
||||
self.assertTrue(engine.result.passing)
|
||||
|
||||
@@ -29,12 +29,13 @@ class TestPolicyProcess(TestCase):
|
||||
def setUp(self):
|
||||
clear_policy_cache()
|
||||
self.factory = RequestFactory()
|
||||
self.user = User.objects.create_user(username=generate_id())
|
||||
self.user = User.objects.create_user(username="policyuser")
|
||||
|
||||
def test_group_passing(self):
|
||||
"""Test binding to group"""
|
||||
group = Group.objects.create(name=generate_id())
|
||||
group = Group.objects.create(name="test-group")
|
||||
group.users.add(self.user)
|
||||
group.save()
|
||||
binding = PolicyBinding(group=group)
|
||||
|
||||
request = PolicyRequest(self.user)
|
||||
@@ -43,7 +44,8 @@ class TestPolicyProcess(TestCase):
|
||||
|
||||
def test_group_negative(self):
|
||||
"""Test binding to group"""
|
||||
group = Group.objects.create(name=generate_id())
|
||||
group = Group.objects.create(name="test-group")
|
||||
group.save()
|
||||
binding = PolicyBinding(group=group)
|
||||
|
||||
request = PolicyRequest(self.user)
|
||||
@@ -113,10 +115,8 @@ class TestPolicyProcess(TestCase):
|
||||
|
||||
def test_exception(self):
|
||||
"""Test policy execution"""
|
||||
policy = Policy.objects.create(name=generate_id())
|
||||
binding = PolicyBinding(
|
||||
policy=policy, target=Application.objects.create(name=generate_id())
|
||||
)
|
||||
policy = Policy.objects.create(name="test-execution")
|
||||
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
|
||||
|
||||
request = PolicyRequest(self.user)
|
||||
response = PolicyProcess(binding, request, None).execute()
|
||||
@@ -125,15 +125,13 @@ class TestPolicyProcess(TestCase):
|
||||
def test_execution_logging(self):
|
||||
"""Test policy execution creates event"""
|
||||
policy = DummyPolicy.objects.create(
|
||||
name=generate_id(),
|
||||
name="test-execution-logging",
|
||||
result=False,
|
||||
wait_min=0,
|
||||
wait_max=1,
|
||||
execution_logging=True,
|
||||
)
|
||||
binding = PolicyBinding(
|
||||
policy=policy, target=Application.objects.create(name=generate_id())
|
||||
)
|
||||
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
|
||||
|
||||
http_request = self.factory.get(reverse("authentik_api:user-impersonate-end"))
|
||||
http_request.user = self.user
|
||||
@@ -188,15 +186,13 @@ class TestPolicyProcess(TestCase):
|
||||
def test_execution_logging_anonymous(self):
|
||||
"""Test policy execution creates event with anonymous user"""
|
||||
policy = DummyPolicy.objects.create(
|
||||
name=generate_id(),
|
||||
name="test-execution-logging-anon",
|
||||
result=False,
|
||||
wait_min=0,
|
||||
wait_max=1,
|
||||
execution_logging=True,
|
||||
)
|
||||
binding = PolicyBinding(
|
||||
policy=policy, target=Application.objects.create(name=generate_id())
|
||||
)
|
||||
binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test"))
|
||||
|
||||
user = AnonymousUser()
|
||||
|
||||
@@ -223,9 +219,9 @@ class TestPolicyProcess(TestCase):
|
||||
|
||||
def test_raises(self):
|
||||
"""Test policy that raises error"""
|
||||
policy_raises = ExpressionPolicy.objects.create(name=generate_id(), expression="{{ 0/0 }}")
|
||||
policy_raises = ExpressionPolicy.objects.create(name="raises", expression="{{ 0/0 }}")
|
||||
binding = PolicyBinding(
|
||||
policy=policy_raises, target=Application.objects.create(name=generate_id())
|
||||
policy=policy_raises, target=Application.objects.create(name="test")
|
||||
)
|
||||
|
||||
request = PolicyRequest(self.user)
|
||||
@@ -241,4 +237,4 @@ class TestPolicyProcess(TestCase):
|
||||
self.assertEqual(len(events), 1)
|
||||
event = events.first()
|
||||
self.assertEqual(event.user["username"], self.user.username)
|
||||
self.assertIn("Policy failed to execute", event.context["message"])
|
||||
self.assertIn("division by zero", event.context["message"])
|
||||
|
||||
@@ -15,14 +15,12 @@ class OAuth2Error(SentryIgnoredException):
|
||||
|
||||
error: str
|
||||
description: str
|
||||
cause: str | None = None
|
||||
|
||||
def create_dict(self, request: HttpRequest):
|
||||
def create_dict(self):
|
||||
"""Return error as dict for JSON Rendering"""
|
||||
return {
|
||||
"error": self.error,
|
||||
"error_description": self.description,
|
||||
"request_id": request.request_id,
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -33,15 +31,9 @@ class OAuth2Error(SentryIgnoredException):
|
||||
return Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=message or self.description,
|
||||
cause=self.cause,
|
||||
error=self.error,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def with_cause(self, cause: str):
|
||||
self.cause = cause
|
||||
return self
|
||||
|
||||
|
||||
class RedirectUriError(OAuth2Error):
|
||||
"""The request fails due to a missing, invalid, or mismatching
|
||||
@@ -251,14 +243,13 @@ class TokenRevocationError(OAuth2Error):
|
||||
self.description = self.errors[error]
|
||||
|
||||
|
||||
class DeviceCodeError(TokenError):
|
||||
class DeviceCodeError(OAuth2Error):
|
||||
"""
|
||||
Device-code flow errors
|
||||
See https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
|
||||
Can also use codes form TokenError
|
||||
"""
|
||||
|
||||
errors = TokenError.errors | {
|
||||
errors = {
|
||||
"authorization_pending": (
|
||||
"The authorization request is still pending as the end user hasn't "
|
||||
"yet completed the user-interaction steps"
|
||||
@@ -270,15 +261,10 @@ class DeviceCodeError(TokenError):
|
||||
"authorization request but SHOULD wait for user interaction before "
|
||||
"restarting to avoid unnecessary polling."
|
||||
),
|
||||
"slow_down": (
|
||||
'A variant of "authorization_pending", the authorization request is'
|
||||
"still pending and polling should continue, but the interval MUST"
|
||||
"be increased by 5 seconds for this and all subsequent requests."
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(self, error: str):
|
||||
super().__init__(error)
|
||||
super().__init__()
|
||||
self.error = error
|
||||
self.description = self.errors[error]
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.providers.oauth2.constants import SCOPE_OFFLINE_ACCESS, SCOPE_OPENID, TOKEN_TYPE
|
||||
from authentik.providers.oauth2.constants import TOKEN_TYPE
|
||||
from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
|
||||
from authentik.providers.oauth2.models import (
|
||||
AccessToken,
|
||||
@@ -43,7 +43,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
||||
)
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
with self.assertRaises(AuthorizeError):
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -53,7 +53,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.error, "unsupported_response_type")
|
||||
|
||||
def test_invalid_client_id(self):
|
||||
"""Test invalid client ID"""
|
||||
@@ -69,7 +68,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
|
||||
)
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
with self.assertRaises(AuthorizeError):
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -80,30 +79,19 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.error, "request_not_supported")
|
||||
|
||||
def test_invalid_redirect_uri_missing(self):
|
||||
"""test missing redirect URI"""
|
||||
OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError) as cm:
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "redirect_uri_missing")
|
||||
|
||||
def test_invalid_redirect_uri(self):
|
||||
"""test invalid redirect URI"""
|
||||
"""test missing/invalid redirect URI"""
|
||||
OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError) as cm:
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -113,7 +101,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
|
||||
|
||||
def test_blocked_redirect_uri(self):
|
||||
"""test missing/invalid redirect URI"""
|
||||
@@ -121,9 +108,9 @@ class TestAuthorize(OAuthTestCase):
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:localhost")],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:local.invalid")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError) as cm:
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -133,7 +120,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "redirect_uri_forbidden_scheme")
|
||||
|
||||
def test_invalid_redirect_uri_empty(self):
|
||||
"""test missing/invalid redirect URI"""
|
||||
@@ -143,6 +129,9 @@ class TestAuthorize(OAuthTestCase):
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -161,9 +150,12 @@ class TestAuthorize(OAuthTestCase):
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "http://local.invalid?")],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid?")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError) as cm:
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -173,7 +165,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
|
||||
|
||||
def test_redirect_uri_invalid_regex(self):
|
||||
"""test missing/invalid redirect URI (invalid regex)"""
|
||||
@@ -181,9 +172,12 @@ class TestAuthorize(OAuthTestCase):
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "+")],
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "+")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError) as cm:
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -193,22 +187,23 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
|
||||
|
||||
def test_redirect_uri_regex(self):
|
||||
"""test valid redirect URI (regex)"""
|
||||
def test_empty_redirect_uri(self):
|
||||
"""test empty redirect URI (configure in provider)"""
|
||||
OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, ".+")],
|
||||
)
|
||||
with self.assertRaises(RedirectUriError):
|
||||
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
"response_type": "code",
|
||||
"client_id": "test",
|
||||
"redirect_uri": "http://foo.bar.baz",
|
||||
"redirect_uri": "http://localhost",
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
@@ -263,7 +258,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
GrantTypes.IMPLICIT,
|
||||
)
|
||||
# Implicit without openid scope
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
with self.assertRaises(AuthorizeError):
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -290,7 +285,7 @@ class TestAuthorize(OAuthTestCase):
|
||||
self.assertEqual(
|
||||
OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID
|
||||
)
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
with self.assertRaises(AuthorizeError):
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
@@ -300,7 +295,6 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
)
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.error, "unsupported_response_type")
|
||||
|
||||
def test_full_code(self):
|
||||
"""Test full authorization"""
|
||||
@@ -619,54 +613,3 @@ class TestAuthorize(OAuthTestCase):
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def test_openid_missing_invalid(self):
|
||||
"""test request requiring an OpenID scope to be set"""
|
||||
OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
|
||||
)
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
"response_type": "id_token",
|
||||
"client_id": "test",
|
||||
"redirect_uri": "http://localhost",
|
||||
"scope": "",
|
||||
},
|
||||
)
|
||||
with self.assertRaises(AuthorizeError) as cm:
|
||||
OAuthAuthorizationParams.from_request(request)
|
||||
self.assertEqual(cm.exception.cause, "scope_openid_missing")
|
||||
|
||||
@apply_blueprint("system/providers-oauth2.yaml")
|
||||
def test_offline_access_invalid(self):
|
||||
"""test request for offline_access with invalid response type"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
client_id="test",
|
||||
authorization_flow=create_test_flow(),
|
||||
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
|
||||
)
|
||||
provider.property_mappings.set(
|
||||
ScopeMapping.objects.filter(
|
||||
managed__in=[
|
||||
"goauthentik.io/providers/oauth2/scope-openid",
|
||||
"goauthentik.io/providers/oauth2/scope-offline_access",
|
||||
]
|
||||
)
|
||||
)
|
||||
request = self.factory.get(
|
||||
"/",
|
||||
data={
|
||||
"response_type": "id_token",
|
||||
"client_id": "test",
|
||||
"redirect_uri": "http://localhost",
|
||||
"scope": f"{SCOPE_OPENID} {SCOPE_OFFLINE_ACCESS}",
|
||||
"nonce": generate_id(),
|
||||
},
|
||||
)
|
||||
parsed = OAuthAuthorizationParams.from_request(request)
|
||||
self.assertNotIn(SCOPE_OFFLINE_ACCESS, parsed.scope)
|
||||
|
||||
@@ -68,11 +68,7 @@ class TestTokenClientCredentialsStandard(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
)
|
||||
|
||||
def test_no_provider(self):
|
||||
@@ -91,11 +87,7 @@ class TestTokenClientCredentialsStandard(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
)
|
||||
|
||||
def test_permission_denied(self):
|
||||
@@ -118,11 +110,7 @@ class TestTokenClientCredentialsStandard(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
)
|
||||
|
||||
def test_incorrect_scopes(self):
|
||||
|
||||
@@ -68,11 +68,7 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
)
|
||||
|
||||
def test_wrong_token(self):
|
||||
@@ -89,11 +85,7 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
)
|
||||
|
||||
def test_no_provider(self):
|
||||
@@ -112,11 +104,7 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
)
|
||||
|
||||
def test_permission_denied(self):
|
||||
@@ -139,11 +127,7 @@ class TestTokenClientCredentialsStandardCompat(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
)
|
||||
|
||||
def test_successful(self):
|
||||
|
||||
@@ -68,11 +68,7 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
)
|
||||
|
||||
def test_wrong_token(self):
|
||||
@@ -90,11 +86,7 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
)
|
||||
|
||||
def test_no_provider(self):
|
||||
@@ -114,11 +106,7 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
)
|
||||
|
||||
def test_permission_denied(self):
|
||||
@@ -142,11 +130,7 @@ class TestTokenClientCredentialsUserNamePassword(OAuthTestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"error": "invalid_grant",
|
||||
"error_description": TokenError.errors["invalid_grant"],
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
{"error": "invalid_grant", "error_description": TokenError.errors["invalid_grant"]},
|
||||
)
|
||||
|
||||
def test_successful(self):
|
||||
|
||||
@@ -80,7 +80,6 @@ class TestTokenPKCE(OAuthTestCase):
|
||||
"revoked, does not match the redirection URI used in the authorization "
|
||||
"request, or was issued to another client"
|
||||
),
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
@@ -137,7 +136,6 @@ class TestTokenPKCE(OAuthTestCase):
|
||||
"revoked, does not match the redirection URI used in the authorization "
|
||||
"request, or was issued to another client"
|
||||
),
|
||||
"request_id": response.headers["X-authentik-id"],
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
@@ -6,7 +6,6 @@ from django.urls import include, path
|
||||
from authentik.providers.oauth2.views.authorize import AuthorizationFlowInitView
|
||||
from authentik.providers.oauth2.views.device_init import DeviceEntryView
|
||||
from authentik.providers.oauth2.views.github import GitHubUserTeamsView, GitHubUserView
|
||||
from authentik.providers.oauth2.views.provider import ProviderInfoView
|
||||
from authentik.providers.oauth2.views.token import TokenView
|
||||
|
||||
github_urlpatterns = [
|
||||
@@ -41,9 +40,4 @@ urlpatterns = [
|
||||
),
|
||||
name="device-login",
|
||||
),
|
||||
path(
|
||||
".well-known/oauth-authorization-server/application/o/<slug:application_slug>/",
|
||||
ProviderInfoView.as_view(),
|
||||
name="providers-oauth2-authorization-server-metadata",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -190,7 +190,7 @@ class OAuthAuthorizationParams:
|
||||
allowed_redirect_urls = self.provider.redirect_uris
|
||||
if not self.redirect_uri:
|
||||
LOGGER.warning("Missing redirect uri.")
|
||||
raise RedirectUriError("", allowed_redirect_urls).with_cause("redirect_uri_missing")
|
||||
raise RedirectUriError("", allowed_redirect_urls)
|
||||
|
||||
if len(allowed_redirect_urls) < 1:
|
||||
LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
|
||||
@@ -219,14 +219,10 @@ class OAuthAuthorizationParams:
|
||||
provider=self.provider,
|
||||
)
|
||||
if not match_found:
|
||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
|
||||
"redirect_uri_no_match"
|
||||
)
|
||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
|
||||
# Check against forbidden schemes
|
||||
if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
|
||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
|
||||
"redirect_uri_forbidden_scheme"
|
||||
)
|
||||
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
|
||||
|
||||
def check_scope(self, github_compat=False):
|
||||
"""Ensure openid scope is set in Hybrid flows, or when requesting an id_token"""
|
||||
@@ -255,9 +251,7 @@ class OAuthAuthorizationParams:
|
||||
or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
|
||||
):
|
||||
LOGGER.warning("Missing 'openid' scope.")
|
||||
raise AuthorizeError(
|
||||
self.redirect_uri, "invalid_scope", self.grant_type, self.state
|
||||
).with_cause("scope_openid_missing")
|
||||
raise AuthorizeError(self.redirect_uri, "invalid_scope", self.grant_type, self.state)
|
||||
if SCOPE_OFFLINE_ACCESS in self.scope:
|
||||
# https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
|
||||
# Don't explicitly request consent with offline_access, as the spec allows for
|
||||
@@ -292,9 +286,7 @@ class OAuthAuthorizationParams:
|
||||
return
|
||||
if not self.nonce:
|
||||
LOGGER.warning("Missing nonce for OpenID Request")
|
||||
raise AuthorizeError(
|
||||
self.redirect_uri, "invalid_request", self.grant_type, self.state
|
||||
).with_cause("nonce_missing")
|
||||
raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state)
|
||||
|
||||
def check_code_challenge(self):
|
||||
"""PKCE validation of the transformation method."""
|
||||
@@ -353,10 +345,10 @@ class AuthorizationFlowInitView(PolicyAccessView):
|
||||
self.request, github_compat=self.github_compat
|
||||
)
|
||||
except AuthorizeError as error:
|
||||
LOGGER.warning(error.description, redirect_uri=error.redirect_uri, cause=error.cause)
|
||||
LOGGER.warning(error.description, redirect_uri=error.redirect_uri)
|
||||
raise RequestValidationError(error.get_response(self.request)) from None
|
||||
except OAuth2Error as error:
|
||||
LOGGER.warning(error.description, cause=error.cause)
|
||||
LOGGER.warning(error.description)
|
||||
raise RequestValidationError(
|
||||
bad_request_message(self.request, error.description, title=error.error)
|
||||
) from None
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.http import HttpRequest, HttpResponse, HttpResponseBadRequest, JsonResponse
|
||||
from django.urls import reverse
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.utils.timezone import now
|
||||
@@ -14,9 +14,7 @@ from structlog.stdlib import get_logger
|
||||
from authentik.core.models import Application
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.providers.oauth2.errors import DeviceCodeError
|
||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
|
||||
from authentik.providers.oauth2.utils import TokenResponse
|
||||
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
|
||||
|
||||
LOGGER = get_logger()
|
||||
@@ -30,36 +28,38 @@ class DeviceView(View):
|
||||
provider: OAuth2Provider
|
||||
scopes: list[str] = []
|
||||
|
||||
def parse_request(self):
|
||||
def parse_request(self) -> HttpResponse | None:
|
||||
"""Parse incoming request"""
|
||||
client_id = self.request.POST.get("client_id", None)
|
||||
if not client_id:
|
||||
raise DeviceCodeError("invalid_client")
|
||||
provider = OAuth2Provider.objects.filter(client_id=client_id).first()
|
||||
return HttpResponseBadRequest()
|
||||
provider = OAuth2Provider.objects.filter(
|
||||
client_id=client_id,
|
||||
).first()
|
||||
if not provider:
|
||||
raise DeviceCodeError("invalid_client")
|
||||
return HttpResponseBadRequest()
|
||||
try:
|
||||
_ = provider.application
|
||||
except Application.DoesNotExist:
|
||||
raise DeviceCodeError("invalid_client") from None
|
||||
return HttpResponseBadRequest()
|
||||
self.provider = provider
|
||||
self.client_id = client_id
|
||||
self.scopes = self.request.POST.get("scope", "").split(" ")
|
||||
return None
|
||||
|
||||
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||
throttle = AnonRateThrottle()
|
||||
throttle.rate = CONFIG.get("throttle.providers.oauth2.device", "20/hour")
|
||||
throttle.num_requests, throttle.duration = throttle.parse_rate(throttle.rate)
|
||||
if not throttle.allow_request(request, self):
|
||||
return TokenResponse(DeviceCodeError("slow_down").create_dict(request), status=429)
|
||||
return HttpResponse(status=429)
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
def post(self, request: HttpRequest) -> HttpResponse:
|
||||
"""Generate device token"""
|
||||
try:
|
||||
self.parse_request()
|
||||
except DeviceCodeError as exc:
|
||||
return TokenResponse(exc.create_dict(request), status=400)
|
||||
resp = self.parse_request()
|
||||
if resp:
|
||||
return resp
|
||||
until = timedelta_from_string(self.provider.access_code_validity)
|
||||
token: DeviceToken = DeviceToken.objects.create(
|
||||
expires=now() + until, provider=self.provider, _scope=" ".join(self.scopes)
|
||||
@@ -67,7 +67,7 @@ class DeviceView(View):
|
||||
device_url = self.request.build_absolute_uri(
|
||||
reverse("authentik_providers_oauth2_root:device-login")
|
||||
)
|
||||
return TokenResponse(
|
||||
return JsonResponse(
|
||||
{
|
||||
"device_code": token.device_code,
|
||||
"verification_uri": device_url,
|
||||
|
||||
@@ -598,9 +598,9 @@ class TokenView(View):
|
||||
return TokenResponse(self.create_device_code_response())
|
||||
raise TokenError("unsupported_grant_type")
|
||||
except (TokenError, DeviceCodeError) as error:
|
||||
return TokenResponse(error.create_dict(request), status=400)
|
||||
return TokenResponse(error.create_dict(), status=400)
|
||||
except UserAuthError as error:
|
||||
return TokenResponse(error.create_dict(request), status=403)
|
||||
return TokenResponse(error.create_dict(), status=403)
|
||||
|
||||
def create_code_response(self) -> dict[str, Any]:
|
||||
"""See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1"""
|
||||
|
||||
@@ -65,7 +65,7 @@ class TokenRevokeView(View):
|
||||
|
||||
return TokenResponse(data={}, status=200)
|
||||
except TokenRevocationError as exc:
|
||||
return TokenResponse(exc.create_dict(request), status=401)
|
||||
return TokenResponse(exc.create_dict(), status=401)
|
||||
except Http404:
|
||||
# Token not found should return a HTTP 200
|
||||
# https://datatracker.ietf.org/doc/html/rfc7009#section-2.2
|
||||
|
||||
@@ -102,7 +102,6 @@ class IngressReconciler(KubernetesObjectReconciler[V1Ingress]):
|
||||
# Buffer sizes for large headers with JWTs
|
||||
"nginx.ingress.kubernetes.io/proxy-buffers-number": "4",
|
||||
"nginx.ingress.kubernetes.io/proxy-buffer-size": "16k",
|
||||
"nginx.ingress.kubernetes.io/proxy-busy-buffers-size": "32k",
|
||||
# Enable TLS in traefik
|
||||
"traefik.ingress.kubernetes.io/router.tls": "true",
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ from authentik.core.models import Application
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.expression.exceptions import ControlFlowException
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.policies.api.exec import PolicyTestResultSerializer
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.policies.types import PolicyResult
|
||||
@@ -141,9 +142,9 @@ class RadiusOutpostConfigViewSet(ListModelMixin, GenericViewSet):
|
||||
# Value error can be raised when assigning invalid data to an attribute
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message="Failed to evaluate property-mapping",
|
||||
message=f"Failed to evaluate property-mapping {exception_to_string(exc)}",
|
||||
mapping=exc.mapping,
|
||||
).with_exception(exc).save()
|
||||
).save()
|
||||
return None
|
||||
return b64encode(packet.RequestPacket()).decode()
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import AnyUrl, BaseModel, ConfigDict, Field
|
||||
from pydantic import Field
|
||||
from pydanticscim.group import Group as BaseGroup
|
||||
from pydanticscim.responses import PatchOperation as BasePatchOperation
|
||||
from pydanticscim.responses import PatchRequest as BasePatchRequest
|
||||
@@ -12,95 +12,19 @@ from pydanticscim.service_provider import ChangePassword, Filter, Patch, Sort
|
||||
from pydanticscim.service_provider import (
|
||||
ServiceProviderConfiguration as BaseServiceProviderConfiguration,
|
||||
)
|
||||
from pydanticscim.user import AddressKind
|
||||
from pydanticscim.user import User as BaseUser
|
||||
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||
|
||||
|
||||
class Address(BaseModel):
|
||||
formatted: str | None = Field(
|
||||
None,
|
||||
description="The full mailing address, formatted for display "
|
||||
"or use with a mailing label. This attribute MAY contain newlines.",
|
||||
)
|
||||
streetAddress: str | None = Field(
|
||||
None,
|
||||
description="The full street address component, which may "
|
||||
"include house number, street name, P.O. box, and multi-line "
|
||||
"extended street address information. This attribute MAY contain newlines.",
|
||||
)
|
||||
locality: str | None = Field(None, description="The city or locality component.")
|
||||
region: str | None = Field(None, description="The state or region component.")
|
||||
postalCode: str | None = Field(None, description="The zip code or postal code component.")
|
||||
country: str | None = Field(None, description="The country name component.")
|
||||
type: AddressKind | None = Field(
|
||||
None,
|
||||
description="A label indicating the attribute's function, e.g., 'work' or 'home'.",
|
||||
)
|
||||
primary: bool | None = None
|
||||
|
||||
|
||||
class Manager(BaseModel):
|
||||
value: str | None = Field(
|
||||
None,
|
||||
description="The id of the SCIM resource representingthe User's manager. REQUIRED.",
|
||||
)
|
||||
ref: AnyUrl | None = Field(
|
||||
None,
|
||||
alias="$ref",
|
||||
description="The URI of the SCIM resource representing the User's manager. REQUIRED.",
|
||||
)
|
||||
displayName: str | None = Field(
|
||||
None,
|
||||
description="The displayName of the User's manager. OPTIONAL and READ-ONLY.",
|
||||
)
|
||||
|
||||
|
||||
class EnterpriseUser(BaseModel):
|
||||
employeeNumber: str | None = Field(
|
||||
None,
|
||||
description="Numeric or alphanumeric identifier assigned to a person, "
|
||||
"typically based on order of hire or association with anorganization.",
|
||||
)
|
||||
costCenter: str | None = Field(None, description="Identifies the name of a cost center.")
|
||||
organization: str | None = Field(None, description="Identifies the name of an organization.")
|
||||
division: str | None = Field(None, description="Identifies the name of a division.")
|
||||
department: str | None = Field(
|
||||
None,
|
||||
description="Numeric or alphanumeric identifier assigned to a person,"
|
||||
" typically based on order of hire or association with anorganization.",
|
||||
)
|
||||
manager: Manager | None = Field(
|
||||
None,
|
||||
description="The User's manager. A complex type that optionally allows "
|
||||
"service providers to represent organizational hierarchy by referencing"
|
||||
" the 'id' attribute of another User.",
|
||||
)
|
||||
|
||||
|
||||
class User(BaseUser):
|
||||
"""Modified User schema with added externalId field"""
|
||||
|
||||
model_config = ConfigDict(serialize_by_alias=True)
|
||||
|
||||
id: str | int | None = None
|
||||
schemas: list[str] = [SCIM_USER_SCHEMA]
|
||||
externalId: str | None = None
|
||||
meta: dict | None = None
|
||||
addresses: list[Address] | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"A physical mailing address for this User. Canonical type "
|
||||
"values of 'work', 'home', and 'other'."
|
||||
),
|
||||
)
|
||||
enterprise_user: EnterpriseUser | None = Field(
|
||||
default=None,
|
||||
alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
|
||||
serialization_alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
|
||||
)
|
||||
|
||||
|
||||
class Group(BaseGroup):
|
||||
@@ -168,7 +92,7 @@ class PatchOperation(BasePatchOperation):
|
||||
"""PatchOperation with optional path"""
|
||||
|
||||
op: PatchOp
|
||||
path: str | None = None
|
||||
path: str | None
|
||||
|
||||
|
||||
class SCIMError(BaseSCIMError):
|
||||
|
||||
@@ -28,6 +28,7 @@ from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp
|
||||
|
||||
from authentik import get_full_version
|
||||
from authentik.lib.sentry import should_ignore_exception
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
|
||||
# set the default Django settings module for the 'celery' program.
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
|
||||
@@ -82,8 +83,8 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar
|
||||
CTX_TASK_ID.set(...)
|
||||
if not should_ignore_exception(exception):
|
||||
Event.new(
|
||||
EventAction.SYSTEM_EXCEPTION, message="Failed to execute task", task_id=task_id
|
||||
).with_exception(exception).save()
|
||||
EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id
|
||||
).save()
|
||||
|
||||
|
||||
def _get_startup_tasks_default_tenant() -> list[Callable]:
|
||||
|
||||
@@ -49,8 +49,6 @@ class ReadyView(View):
|
||||
def dispatch(self, request: HttpRequest) -> HttpResponse:
|
||||
try:
|
||||
for db_conn in connections.all():
|
||||
# Force connection reload
|
||||
db_conn.connect()
|
||||
_ = db_conn.cursor()
|
||||
except OperationalError: # pragma: no cover
|
||||
return HttpResponse(status=503)
|
||||
|
||||
@@ -156,17 +156,16 @@ SPECTACULAR_SETTINGS = {
|
||||
},
|
||||
"ENUM_NAME_OVERRIDES": {
|
||||
"CountryCodeEnum": "django_countries.countries",
|
||||
"DeviceClassesEnum": "authentik.stages.authenticator_validate.models.DeviceClasses",
|
||||
"EventActions": "authentik.events.models.EventAction",
|
||||
"FlowDesignationEnum": "authentik.flows.models.FlowDesignation",
|
||||
"FlowLayoutEnum": "authentik.flows.models.FlowLayout",
|
||||
"LDAPAPIAccessMode": "authentik.providers.ldap.models.APIAccessMode",
|
||||
"OutgoingSyncDeleteAction": "authentik.lib.sync.outgoing.models.OutgoingSyncDeleteAction",
|
||||
"PolicyEngineMode": "authentik.policies.models.PolicyEngineMode",
|
||||
"PromptTypeEnum": "authentik.stages.prompt.models.FieldTypes",
|
||||
"ProxyMode": "authentik.providers.proxy.models.ProxyMode",
|
||||
"UserTypeEnum": "authentik.core.models.UserTypes",
|
||||
"PromptTypeEnum": "authentik.stages.prompt.models.FieldTypes",
|
||||
"LDAPAPIAccessMode": "authentik.providers.ldap.models.APIAccessMode",
|
||||
"UserVerificationEnum": "authentik.stages.authenticator_webauthn.models.UserVerification",
|
||||
"UserTypeEnum": "authentik.core.models.UserTypes",
|
||||
"OutgoingSyncDeleteAction": "authentik.lib.sync.outgoing.models.OutgoingSyncDeleteAction",
|
||||
},
|
||||
"ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE": False,
|
||||
"ENUM_GENERATE_CHOICE_DESCRIPTION": False,
|
||||
|
||||
@@ -4,11 +4,11 @@ from pathlib import Path
|
||||
from secrets import token_urlsafe
|
||||
from tempfile import gettempdir
|
||||
|
||||
from django.test import TransactionTestCase
|
||||
from django.test import TestCase
|
||||
from django.urls import reverse
|
||||
|
||||
|
||||
class TestRoot(TransactionTestCase):
|
||||
class TestRoot(TestCase):
|
||||
"""Test root application"""
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -8,6 +8,7 @@ from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.sync.outgoing.exceptions import StopSync
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.sources.kerberos.models import KerberosSource
|
||||
from authentik.sources.kerberos.sync import KerberosSync
|
||||
@@ -63,5 +64,5 @@ def kerberos_sync_single(self, source_pk: str):
|
||||
syncer.sync()
|
||||
self.set_status(TaskStatus.SUCCESSFUL, *syncer.messages)
|
||||
except StopSync as exc:
|
||||
LOGGER.warning("Error syncing kerberos", exc=exc, source=source)
|
||||
LOGGER.warning(exception_to_string(exc))
|
||||
self.set_error(exc)
|
||||
|
||||
@@ -12,6 +12,7 @@ from authentik.events.models import TaskStatus
|
||||
from authentik.events.system_tasks import SystemTask
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.sync.outgoing.exceptions import StopSync
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.lib.utils.reflection import class_to_path, path_to_class
|
||||
from authentik.root.celery import CELERY_APP
|
||||
from authentik.sources.ldap.models import LDAPSource
|
||||
@@ -148,5 +149,5 @@ def ldap_sync(self: SystemTask, source_pk: str, sync_class: str, page_cache_key:
|
||||
cache.delete(page_cache_key)
|
||||
except (LDAPException, StopSync) as exc:
|
||||
# No explicit event is created here as .set_status with an error will do that
|
||||
LOGGER.warning("Failed to sync LDAP", exc=exc, source=source)
|
||||
LOGGER.warning(exception_to_string(exc))
|
||||
self.set_error(exc)
|
||||
|
||||
@@ -10,7 +10,6 @@ AUTHENTIK_SOURCES_OAUTH_TYPES = [
|
||||
"authentik.sources.oauth.types.apple",
|
||||
"authentik.sources.oauth.types.azure_ad",
|
||||
"authentik.sources.oauth.types.discord",
|
||||
"authentik.sources.oauth.types.entra_id",
|
||||
"authentik.sources.oauth.types.facebook",
|
||||
"authentik.sources.oauth.types.github",
|
||||
"authentik.sources.oauth.types.gitlab",
|
||||
|
||||
@@ -232,7 +232,7 @@ class GoogleOAuthSource(CreatableType, OAuthSource):
|
||||
|
||||
|
||||
class AzureADOAuthSource(CreatableType, OAuthSource):
|
||||
"""(Deprecated) Social Login using Azure AD."""
|
||||
"""Social Login using Azure AD."""
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
@@ -240,17 +240,6 @@ class AzureADOAuthSource(CreatableType, OAuthSource):
|
||||
verbose_name_plural = _("Azure AD OAuth Sources")
|
||||
|
||||
|
||||
# TODO: When removing this, add a migration for OAuthSource that sets
|
||||
# provider_type to `entraid` if it is currently `azuread`
|
||||
class EntraIDOAuthSource(CreatableType, OAuthSource):
|
||||
"""Social Login using Entra ID."""
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
verbose_name = _("Entra ID OAuth Source")
|
||||
verbose_name_plural = _("Entra ID OAuth Sources")
|
||||
|
||||
|
||||
class OpenIDConnectOAuthSource(CreatableType, OAuthSource):
|
||||
"""Login using a Generic OpenID-Connect compliant provider."""
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""Entra ID Type tests"""
|
||||
"""azure ad Type tests"""
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
from authentik.sources.oauth.types.entra_id import EntraIDOAuthCallback, EntraIDType
|
||||
from authentik.sources.oauth.types.azure_ad import AzureADOAuthCallback, AzureADType
|
||||
|
||||
# https://docs.microsoft.com/en-us/graph/api/user-get?view=graph-rest-1.0&tabs=http#response-2
|
||||
EID_USER = {
|
||||
AAD_USER = {
|
||||
"@odata.context": "https://graph.microsoft.com/v1.0/$metadata#users/$entity",
|
||||
"@odata.id": (
|
||||
"https://graph.microsoft.com/v2/7ce9b89e-646a-41d2-9fa6-8371c6a8423d/"
|
||||
@@ -41,11 +41,11 @@ class TestTypeAzureAD(TestCase):
|
||||
|
||||
def test_enroll_context(self):
|
||||
"""Test azure_ad Enrollment context"""
|
||||
ak_context = EntraIDType().get_base_user_properties(source=self.source, info=EID_USER)
|
||||
self.assertEqual(ak_context["username"], EID_USER["userPrincipalName"])
|
||||
self.assertEqual(ak_context["email"], EID_USER["mail"])
|
||||
self.assertEqual(ak_context["name"], EID_USER["displayName"])
|
||||
ak_context = AzureADType().get_base_user_properties(source=self.source, info=AAD_USER)
|
||||
self.assertEqual(ak_context["username"], AAD_USER["userPrincipalName"])
|
||||
self.assertEqual(ak_context["email"], AAD_USER["mail"])
|
||||
self.assertEqual(ak_context["name"], AAD_USER["displayName"])
|
||||
|
||||
def test_user_id(self):
|
||||
"""Test Entra ID user ID"""
|
||||
self.assertEqual(EntraIDOAuthCallback().get_user_id(EID_USER), EID_USER["id"])
|
||||
"""Test azure AD user ID"""
|
||||
self.assertEqual(AzureADOAuthCallback().get_user_id(AAD_USER), AAD_USER["id"])
|
||||
@@ -1,17 +1,105 @@
|
||||
"""AzureAD OAuth2 Views"""
|
||||
|
||||
from authentik.sources.oauth.types.entra_id import EntraIDType
|
||||
from authentik.sources.oauth.types.registry import registry
|
||||
from typing import Any
|
||||
|
||||
# TODO: When removing this, add a migration for OAuthSource that sets
|
||||
# provider_type to `entraid` if it is currently `azuread`
|
||||
from requests import RequestException
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
|
||||
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
|
||||
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
|
||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class AzureADOAuthRedirect(OAuthRedirect):
|
||||
"""Azure AD OAuth2 Redirect"""
|
||||
|
||||
def get_additional_parameters(self, source): # pragma: no cover
|
||||
return {
|
||||
"scope": ["openid", "https://graph.microsoft.com/User.Read"],
|
||||
}
|
||||
|
||||
|
||||
class AzureADClient(UserprofileHeaderAuthClient):
|
||||
"""Fetch AzureAD group information"""
|
||||
|
||||
def get_profile_info(self, token):
|
||||
profile_data = super().get_profile_info(token)
|
||||
if "https://graph.microsoft.com/GroupMember.Read.All" not in self.source.additional_scopes:
|
||||
return profile_data
|
||||
group_response = self.session.request(
|
||||
"get",
|
||||
"https://graph.microsoft.com/v1.0/me/memberOf",
|
||||
headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
|
||||
)
|
||||
try:
|
||||
group_response.raise_for_status()
|
||||
except RequestException as exc:
|
||||
LOGGER.warning(
|
||||
"Unable to fetch user profile",
|
||||
exc=exc,
|
||||
response=exc.response.text if exc.response else str(exc),
|
||||
)
|
||||
return None
|
||||
profile_data["raw_groups"] = group_response.json()
|
||||
return profile_data
|
||||
|
||||
|
||||
class AzureADOAuthCallback(OpenIDConnectOAuth2Callback):
|
||||
"""AzureAD OAuth2 Callback"""
|
||||
|
||||
client_class = AzureADClient
|
||||
|
||||
def get_user_id(self, info: dict[str, str]) -> str:
|
||||
# Default try to get `id` for the Graph API endpoint
|
||||
# fallback to OpenID logic in case the profile URL was changed
|
||||
return info.get("id", super().get_user_id(info))
|
||||
|
||||
|
||||
@registry.register()
|
||||
class AzureADType(EntraIDType):
|
||||
class AzureADType(SourceType):
|
||||
"""Azure AD Type definition"""
|
||||
|
||||
callback_view = AzureADOAuthCallback
|
||||
redirect_view = AzureADOAuthRedirect
|
||||
verbose_name = "Azure AD"
|
||||
name = "azuread"
|
||||
|
||||
urls_customizable = True
|
||||
|
||||
authorization_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
||||
access_token_url = "https://login.microsoftonline.com/common/oauth2/v2.0/token" # nosec
|
||||
profile_url = "https://graph.microsoft.com/v1.0/me"
|
||||
oidc_well_known_url = (
|
||||
"https://login.microsoftonline.com/common/.well-known/openid-configuration"
|
||||
)
|
||||
oidc_jwks_url = "https://login.microsoftonline.com/common/discovery/keys"
|
||||
|
||||
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
|
||||
|
||||
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
||||
mail = info.get("mail", None) or info.get("otherMails", [None])[0]
|
||||
# Format group info
|
||||
groups = []
|
||||
group_id_dict = {}
|
||||
for group in info.get("raw_groups", {}).get("value", []):
|
||||
if group["@odata.type"] != "#microsoft.graph.group":
|
||||
continue
|
||||
groups.append(group["id"])
|
||||
group_id_dict[group["id"]] = group
|
||||
info["raw_groups"] = group_id_dict
|
||||
return {
|
||||
"username": info.get("userPrincipalName"),
|
||||
"email": mail,
|
||||
"name": info.get("displayName"),
|
||||
"groups": groups,
|
||||
}
|
||||
|
||||
def get_base_group_properties(self, source, group_id, **kwargs):
|
||||
raw_group = kwargs["info"]["raw_groups"][group_id]
|
||||
return {
|
||||
"name": raw_group["displayName"],
|
||||
}
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
"""EntraID OAuth2 Views"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from requests import RequestException
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
|
||||
from authentik.sources.oauth.models import AuthorizationCodeAuthMethod
|
||||
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
|
||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class EntraIDOAuthRedirect(OAuthRedirect):
|
||||
"""Entra ID OAuth2 Redirect"""
|
||||
|
||||
def get_additional_parameters(self, source): # pragma: no cover
|
||||
return {
|
||||
"scope": ["openid", "https://graph.microsoft.com/User.Read"],
|
||||
}
|
||||
|
||||
|
||||
class EntraIDClient(UserprofileHeaderAuthClient):
|
||||
"""Fetch EntraID group information"""
|
||||
|
||||
def get_profile_info(self, token):
|
||||
profile_data = super().get_profile_info(token)
|
||||
if "https://graph.microsoft.com/GroupMember.Read.All" not in self.source.additional_scopes:
|
||||
return profile_data
|
||||
group_response = self.session.request(
|
||||
"get",
|
||||
"https://graph.microsoft.com/v1.0/me/memberOf",
|
||||
headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
|
||||
)
|
||||
try:
|
||||
group_response.raise_for_status()
|
||||
except RequestException as exc:
|
||||
LOGGER.warning(
|
||||
"Unable to fetch user profile",
|
||||
exc=exc,
|
||||
response=exc.response.text if exc.response else str(exc),
|
||||
)
|
||||
return None
|
||||
profile_data["raw_groups"] = group_response.json()
|
||||
return profile_data
|
||||
|
||||
|
||||
class EntraIDOAuthCallback(OpenIDConnectOAuth2Callback):
|
||||
"""EntraID OAuth2 Callback"""
|
||||
|
||||
client_class = EntraIDClient
|
||||
|
||||
def get_user_id(self, info: dict[str, str]) -> str:
|
||||
# Default try to get `id` for the Graph API endpoint
|
||||
# fallback to OpenID logic in case the profile URL was changed
|
||||
return info.get("id", super().get_user_id(info))
|
||||
|
||||
|
||||
@registry.register()
|
||||
class EntraIDType(SourceType):
|
||||
"""Entra ID Type definition"""
|
||||
|
||||
callback_view = EntraIDOAuthCallback
|
||||
redirect_view = EntraIDOAuthRedirect
|
||||
verbose_name = "Entra ID"
|
||||
name = "entraid"
|
||||
|
||||
urls_customizable = True
|
||||
|
||||
authorization_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
||||
access_token_url = "https://login.microsoftonline.com/common/oauth2/v2.0/token" # nosec
|
||||
profile_url = "https://graph.microsoft.com/v1.0/me"
|
||||
oidc_jwks_url = "https://login.microsoftonline.com/common/discovery/keys"
|
||||
|
||||
authorization_code_auth_method = AuthorizationCodeAuthMethod.POST_BODY
|
||||
|
||||
def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]:
|
||||
mail = info.get("mail", None) or info.get("otherMails", [None])[0]
|
||||
# Format group info
|
||||
groups = []
|
||||
group_id_dict = {}
|
||||
for group in info.get("raw_groups", {}).get("value", []):
|
||||
if group["@odata.type"] != "#microsoft.graph.group":
|
||||
continue
|
||||
groups.append(group["id"])
|
||||
group_id_dict[group["id"]] = group
|
||||
info["raw_groups"] = group_id_dict
|
||||
return {
|
||||
"username": info.get("userPrincipalName"),
|
||||
"email": mail,
|
||||
"name": info.get("displayName"),
|
||||
"groups": groups,
|
||||
}
|
||||
|
||||
def get_base_group_properties(self, source, group_id, **kwargs):
|
||||
raw_group = kwargs["info"]["raw_groups"][group_id]
|
||||
return {
|
||||
"name": raw_group["displayName"],
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
def create_missing_groupplexsourceconnection(apps, schema_editor):
|
||||
db_alias = schema_editor.connection.alias
|
||||
|
||||
GroupSourceConnection = apps.get_model("authentik_core", "GroupSourceConnection")
|
||||
PlexSource = apps.get_model("authentik_sources_plex", "PlexSource")
|
||||
GroupPlexSourceConnection = apps.get_model(
|
||||
"authentik_sources_plex", "GroupPlexSourceConnection"
|
||||
)
|
||||
|
||||
for source in PlexSource.objects.using(db_alias).all():
|
||||
for gsc in GroupSourceConnection.objects.using(db_alias).filter(source=source):
|
||||
if GroupPlexSourceConnection.objects.using(db_alias).filter(pk=gsc.pk).exists():
|
||||
continue
|
||||
gpsc = GroupPlexSourceConnection(pk=gsc.pk)
|
||||
gpsc.save(using=db_alias)
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
(
|
||||
"authentik_sources_plex",
|
||||
"0005_migrate_userplexsourceconnection_identifier",
|
||||
),
|
||||
("authentik_core", "0044_usersourceconnection_new_identifier"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RunPython(
|
||||
code=create_missing_groupplexsourceconnection, reverse_code=migrations.RunPython.noop
|
||||
),
|
||||
]
|
||||
@@ -9,11 +9,7 @@ from structlog.stdlib import get_logger
|
||||
from authentik import __version__
|
||||
from authentik.core.sources.flow_manager import SourceFlowManager
|
||||
from authentik.lib.utils.http import get_http_session
|
||||
from authentik.sources.plex.models import (
|
||||
GroupPlexSourceConnection,
|
||||
PlexSource,
|
||||
UserPlexSourceConnection,
|
||||
)
|
||||
from authentik.sources.plex.models import PlexSource, UserPlexSourceConnection
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
@@ -114,7 +110,6 @@ class PlexSourceFlowManager(SourceFlowManager):
|
||||
"""Flow manager for plex sources"""
|
||||
|
||||
user_connection_type = UserPlexSourceConnection
|
||||
group_connection_type = GroupPlexSourceConnection
|
||||
|
||||
def update_user_connection(
|
||||
self, connection: UserPlexSourceConnection, **kwargs
|
||||
|
||||
@@ -18,7 +18,6 @@ class SCIMSourceGroupSerializer(SourceSerializer):
|
||||
model = SCIMSourceGroup
|
||||
fields = [
|
||||
"id",
|
||||
"external_id",
|
||||
"group",
|
||||
"group_obj",
|
||||
"source",
|
||||
@@ -32,5 +31,5 @@ class SCIMSourceGroupViewSet(UsedByMixin, ModelViewSet):
|
||||
queryset = SCIMSourceGroup.objects.all().select_related("group")
|
||||
serializer_class = SCIMSourceGroupSerializer
|
||||
filterset_fields = ["source__slug", "group__name", "group__group_uuid"]
|
||||
search_fields = ["source__slug", "group__name", "attributes", "external_id"]
|
||||
search_fields = ["source__slug", "group__name", "attributes"]
|
||||
ordering = ["group__name"]
|
||||
|
||||
@@ -18,7 +18,6 @@ class SCIMSourceUserSerializer(SourceSerializer):
|
||||
model = SCIMSourceUser
|
||||
fields = [
|
||||
"id",
|
||||
"external_id",
|
||||
"user",
|
||||
"user_obj",
|
||||
"source",
|
||||
@@ -32,5 +31,5 @@ class SCIMSourceUserViewSet(UsedByMixin, ModelViewSet):
|
||||
queryset = SCIMSourceUser.objects.all().select_related("user")
|
||||
serializer_class = SCIMSourceUserSerializer
|
||||
filterset_fields = ["source__slug", "user__username", "user__id"]
|
||||
search_fields = ["source__slug", "user__username", "attributes", "user__uuid", "external_id"]
|
||||
search_fields = ["source__slug", "user__username", "attributes"]
|
||||
ordering = ["user__username"]
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
SCIM_URN_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
|
||||
SCIM_URN_GROUP = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||
SCIM_URN_USER = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
SCIM_URN_USER_ENTERPRISE = "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
8
authentik/sources/scim/errors.py
Normal file
8
authentik/sources/scim/errors.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""SCIM Errors"""
|
||||
|
||||
from authentik.lib.sentry import SentryIgnoredException
|
||||
|
||||
|
||||
class PatchError(SentryIgnoredException):
|
||||
"""Error raised within an atomic block when an error happened
|
||||
so nothing is saved"""
|
||||
@@ -1,98 +0,0 @@
|
||||
# Generated by Django 5.1.11 on 2025-07-13 01:07
|
||||
|
||||
import uuid
|
||||
from django.db import migrations, models
|
||||
from django.apps.registry import Apps
|
||||
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
|
||||
|
||||
def migrate_ext_id(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
|
||||
SCIMSourceUser = apps.get_model("authentik_sources_scim", "SCIMSourceUser")
|
||||
SCIMSourceGroup = apps.get_model("authentik_sources_scim", "SCIMSourceGroup")
|
||||
db_alias = schema_editor.connection.alias
|
||||
for user in SCIMSourceUser.objects.using(db_alias).all():
|
||||
user.external_id = user.id
|
||||
user.save(update_fields=["external_id"])
|
||||
for group in SCIMSourceGroup.objects.using(db_alias).all():
|
||||
group.external_id = group.id
|
||||
group.save(update_fields=["external_id"])
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_sources_scim", "0002_scimsourcepropertymapping"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterUniqueTogether(
|
||||
name="scimsourcegroup",
|
||||
unique_together=set(),
|
||||
),
|
||||
migrations.AlterUniqueTogether(
|
||||
name="scimsourceuser",
|
||||
unique_together=set(),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="scimsourcegroup",
|
||||
name="external_id",
|
||||
field=models.TextField(default=None, null=True),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="scimsourceuser",
|
||||
name="external_id",
|
||||
field=models.TextField(default=None, null=True),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.AlterUniqueTogether(
|
||||
name="scimsourcegroup",
|
||||
unique_together={("external_id", "source")},
|
||||
),
|
||||
migrations.AlterUniqueTogether(
|
||||
name="scimsourceuser",
|
||||
unique_together={("external_id", "source")},
|
||||
),
|
||||
migrations.RunPython(migrate_ext_id, migrations.RunPython.noop),
|
||||
migrations.AlterField(
|
||||
model_name="scimsourcegroup",
|
||||
name="external_id",
|
||||
field=models.TextField(),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="scimsourceuser",
|
||||
name="external_id",
|
||||
field=models.TextField(),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.AddIndex(
|
||||
model_name="scimsourcegroup",
|
||||
index=models.Index(fields=["external_id"], name="authentik_s_externa_05e346_idx"),
|
||||
),
|
||||
migrations.AddIndex(
|
||||
model_name="scimsourceuser",
|
||||
index=models.Index(fields=["external_id"], name="authentik_s_externa_4bd760_idx"),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="scimsourcegroup",
|
||||
name="id",
|
||||
field=models.TextField(default=uuid.uuid4, primary_key=True, serialize=False),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="scimsourceuser",
|
||||
name="id",
|
||||
field=models.TextField(default=uuid.uuid4, primary_key=True, serialize=False),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="scimsourcegroup",
|
||||
name="last_update",
|
||||
field=models.DateTimeField(auto_now=True),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="scimsourceuser",
|
||||
name="last_update",
|
||||
field=models.DateTimeField(auto_now=True),
|
||||
),
|
||||
]
|
||||
@@ -1,7 +1,6 @@
|
||||
"""SCIM Source"""
|
||||
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from django.db import models
|
||||
from django.templatetags.static import static
|
||||
@@ -104,12 +103,10 @@ class SCIMSourcePropertyMapping(PropertyMapping):
|
||||
class SCIMSourceUser(SerializerModel):
|
||||
"""Mapping of a user and source to a SCIM user ID"""
|
||||
|
||||
id = models.TextField(primary_key=True, default=uuid4)
|
||||
external_id = models.TextField()
|
||||
id = models.TextField(primary_key=True)
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
||||
source = models.ForeignKey(SCIMSource, on_delete=models.CASCADE)
|
||||
attributes = models.JSONField(default=dict)
|
||||
last_update = models.DateTimeField(auto_now=True)
|
||||
|
||||
@property
|
||||
def serializer(self) -> BaseSerializer:
|
||||
@@ -118,10 +115,7 @@ class SCIMSourceUser(SerializerModel):
|
||||
return SCIMSourceUserSerializer
|
||||
|
||||
class Meta:
|
||||
unique_together = (("external_id", "source"),)
|
||||
indexes = [
|
||||
models.Index(fields=["external_id"]),
|
||||
]
|
||||
unique_together = (("id", "user", "source"),)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"SCIM User {self.user_id} to {self.source_id}"
|
||||
@@ -130,12 +124,10 @@ class SCIMSourceUser(SerializerModel):
|
||||
class SCIMSourceGroup(SerializerModel):
|
||||
"""Mapping of a group and source to a SCIM user ID"""
|
||||
|
||||
id = models.TextField(primary_key=True, default=uuid4)
|
||||
external_id = models.TextField()
|
||||
id = models.TextField(primary_key=True)
|
||||
group = models.ForeignKey(Group, on_delete=models.CASCADE)
|
||||
source = models.ForeignKey(SCIMSource, on_delete=models.CASCADE)
|
||||
attributes = models.JSONField(default=dict)
|
||||
last_update = models.DateTimeField(auto_now=True)
|
||||
|
||||
@property
|
||||
def serializer(self) -> BaseSerializer:
|
||||
@@ -144,10 +136,7 @@ class SCIMSourceGroup(SerializerModel):
|
||||
return SCIMSourceGroupSerializer
|
||||
|
||||
class Meta:
|
||||
unique_together = (("external_id", "source"),)
|
||||
indexes = [
|
||||
models.Index(fields=["external_id"]),
|
||||
]
|
||||
unique_together = (("id", "group", "source"),)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"SCIM Group {self.group_id} to {self.source_id}"
|
||||
|
||||
@@ -1,180 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from authentik.sources.scim.constants import (
|
||||
SCIM_URN_GROUP,
|
||||
SCIM_URN_SCHEMA,
|
||||
SCIM_URN_USER,
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
)
|
||||
|
||||
|
||||
# Token types for SCIM path parsing
|
||||
class TokenType(Enum):
|
||||
ATTRIBUTE = "ATTRIBUTE"
|
||||
DOT = "DOT"
|
||||
LBRACKET = "LBRACKET"
|
||||
RBRACKET = "RBRACKET"
|
||||
LPAREN = "LPAREN"
|
||||
RPAREN = "RPAREN"
|
||||
STRING = "STRING"
|
||||
NUMBER = "NUMBER"
|
||||
BOOLEAN = "BOOLEAN"
|
||||
NULL = "NULL"
|
||||
OPERATOR = "OPERATOR"
|
||||
AND = "AND"
|
||||
OR = "OR"
|
||||
NOT = "NOT"
|
||||
EOF = "EOF"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Token:
|
||||
type: TokenType
|
||||
value: str
|
||||
position: int = 0
|
||||
|
||||
|
||||
class SCIMPathLexer:
|
||||
"""Lexer for SCIM paths and filter expressions"""
|
||||
|
||||
OPERATORS = ["eq", "ne", "co", "sw", "ew", "gt", "lt", "ge", "le", "pr"]
|
||||
|
||||
def __init__(self, text: str):
|
||||
self.schema_urns = [
|
||||
SCIM_URN_SCHEMA,
|
||||
SCIM_URN_GROUP,
|
||||
SCIM_URN_USER,
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
]
|
||||
self.text = text
|
||||
self.pos = 0
|
||||
self.current_char = self.text[self.pos] if self.pos < len(self.text) else None
|
||||
|
||||
def advance(self):
|
||||
"""Move to next character"""
|
||||
self.pos += 1
|
||||
self.current_char = self.text[self.pos] if self.pos < len(self.text) else None
|
||||
|
||||
def skip_whitespace(self):
|
||||
"""Skip whitespace characters"""
|
||||
while self.current_char and self.current_char.isspace():
|
||||
self.advance()
|
||||
|
||||
def read_string(self, quote_char):
|
||||
"""Read a quoted string"""
|
||||
value = ""
|
||||
self.advance() # Skip opening quote
|
||||
|
||||
while self.current_char and self.current_char != quote_char:
|
||||
if self.current_char == "\\":
|
||||
self.advance()
|
||||
if self.current_char:
|
||||
value += self.current_char
|
||||
self.advance()
|
||||
else:
|
||||
value += self.current_char
|
||||
self.advance()
|
||||
|
||||
if self.current_char == quote_char:
|
||||
self.advance() # Skip closing quote
|
||||
|
||||
return value
|
||||
|
||||
def read_number(self):
|
||||
"""Read a number (integer or float)"""
|
||||
value = ""
|
||||
while self.current_char and (self.current_char.isdigit() or self.current_char == "."):
|
||||
value += self.current_char
|
||||
self.advance()
|
||||
return value
|
||||
|
||||
def read_identifier(self):
|
||||
"""Read an identifier (attribute name or operator) - supports URN format"""
|
||||
value = ""
|
||||
while self.current_char and (self.current_char.isalnum() or self.current_char in "_-:"):
|
||||
value += self.current_char
|
||||
self.advance()
|
||||
# If the identifier value so far is a schema URN, take that as the identifier and
|
||||
# treat the next part as a sub_attribute
|
||||
if value in self.schema_urns:
|
||||
self.current_char = "."
|
||||
return value
|
||||
|
||||
# Handle dots within URN identifiers (like "2.0")
|
||||
# A dot is part of the identifier if it's followed by a digit
|
||||
if (
|
||||
self.current_char == "."
|
||||
and self.pos + 1 < len(self.text)
|
||||
and self.text[self.pos + 1].isdigit()
|
||||
):
|
||||
value += self.current_char
|
||||
self.advance()
|
||||
# Continue reading digits after the dot
|
||||
while self.current_char and self.current_char.isdigit():
|
||||
value += self.current_char
|
||||
self.advance()
|
||||
|
||||
return value
|
||||
|
||||
def get_next_token(self) -> Token: # noqa PLR0911
|
||||
"""Get the next token from the input"""
|
||||
while self.current_char:
|
||||
if self.current_char.isspace():
|
||||
self.skip_whitespace()
|
||||
continue
|
||||
|
||||
if self.current_char == ".":
|
||||
self.advance()
|
||||
return Token(TokenType.DOT, ".")
|
||||
|
||||
if self.current_char == "[":
|
||||
self.advance()
|
||||
return Token(TokenType.LBRACKET, "[")
|
||||
|
||||
if self.current_char == "]":
|
||||
self.advance()
|
||||
return Token(TokenType.RBRACKET, "]")
|
||||
|
||||
if self.current_char == "(":
|
||||
self.advance()
|
||||
return Token(TokenType.LPAREN, "(")
|
||||
|
||||
if self.current_char == ")":
|
||||
self.advance()
|
||||
return Token(TokenType.RPAREN, ")")
|
||||
|
||||
if self.current_char in "\"'":
|
||||
quote_char = self.current_char
|
||||
value = self.read_string(quote_char)
|
||||
return Token(TokenType.STRING, value)
|
||||
|
||||
if self.current_char.isdigit():
|
||||
value = self.read_number()
|
||||
return Token(TokenType.NUMBER, value)
|
||||
|
||||
if self.current_char.isalpha() or self.current_char == "_":
|
||||
value = self.read_identifier()
|
||||
|
||||
# Check for special keywords
|
||||
if value.lower() == "true":
|
||||
return Token(TokenType.BOOLEAN, True)
|
||||
elif value.lower() == "false":
|
||||
return Token(TokenType.BOOLEAN, False)
|
||||
elif value.lower() == "null":
|
||||
return Token(TokenType.NULL, None)
|
||||
elif value.lower() == "and":
|
||||
return Token(TokenType.AND, "and")
|
||||
elif value.lower() == "or":
|
||||
return Token(TokenType.OR, "or")
|
||||
elif value.lower() == "not":
|
||||
return Token(TokenType.NOT, "not")
|
||||
elif value.lower() in self.OPERATORS:
|
||||
return Token(TokenType.OPERATOR, value.lower())
|
||||
else:
|
||||
return Token(TokenType.ATTRIBUTE, value)
|
||||
|
||||
# Skip unknown characters
|
||||
self.advance()
|
||||
|
||||
return Token(TokenType.EOF, "")
|
||||
@@ -1,131 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from authentik.sources.scim.patch.lexer import SCIMPathLexer, TokenType
|
||||
|
||||
|
||||
class SCIMPathParser:
|
||||
"""Parser for SCIM paths including filter expressions"""
|
||||
|
||||
def __init__(self):
|
||||
self.lexer = None
|
||||
self.current_token = None
|
||||
|
||||
def parse_path(self, path: str | None) -> list[dict[str, Any]]:
|
||||
"""Parse a SCIM path into components"""
|
||||
self.lexer = SCIMPathLexer(path)
|
||||
self.current_token = self.lexer.get_next_token()
|
||||
|
||||
components = []
|
||||
|
||||
while self.current_token.type != TokenType.EOF:
|
||||
component = self._parse_path_component()
|
||||
if component:
|
||||
components.append(component)
|
||||
|
||||
return components
|
||||
|
||||
def _parse_path_component(self) -> dict[str, Any] | None:
|
||||
"""Parse a single path component"""
|
||||
if self.current_token.type != TokenType.ATTRIBUTE:
|
||||
return None
|
||||
|
||||
attribute = self.current_token.value
|
||||
self._consume(TokenType.ATTRIBUTE)
|
||||
|
||||
filter_expr = None
|
||||
sub_attribute = None
|
||||
|
||||
# Check for filter expression
|
||||
if self.current_token.type == TokenType.LBRACKET:
|
||||
self._consume(TokenType.LBRACKET)
|
||||
filter_expr = self._parse_filter_expression()
|
||||
self._consume(TokenType.RBRACKET)
|
||||
|
||||
# Check for sub-attribute
|
||||
if self.current_token.type == TokenType.DOT:
|
||||
self._consume(TokenType.DOT)
|
||||
if self.current_token.type == TokenType.ATTRIBUTE:
|
||||
sub_attribute = self.current_token.value
|
||||
self._consume(TokenType.ATTRIBUTE)
|
||||
|
||||
return {"attribute": attribute, "filter": filter_expr, "sub_attribute": sub_attribute}
|
||||
|
||||
def _parse_filter_expression(self) -> dict[str, Any] | None:
|
||||
"""Parse a filter expression like 'primary eq true' or
|
||||
'type eq "work" and primary eq true'"""
|
||||
return self._parse_or_expression()
|
||||
|
||||
def _parse_or_expression(self) -> dict[str, Any] | None:
|
||||
"""Parse OR expressions"""
|
||||
left = self._parse_and_expression()
|
||||
|
||||
while self.current_token.type == TokenType.OR:
|
||||
self._consume(TokenType.OR)
|
||||
right = self._parse_and_expression()
|
||||
left = {"type": "logical", "operator": "or", "left": left, "right": right}
|
||||
|
||||
return left
|
||||
|
||||
def _parse_and_expression(self) -> dict[str, Any] | None:
|
||||
"""Parse AND expressions"""
|
||||
left = self._parse_primary_expression()
|
||||
|
||||
while self.current_token.type == TokenType.AND:
|
||||
self._consume(TokenType.AND)
|
||||
right = self._parse_primary_expression()
|
||||
left = {"type": "logical", "operator": "and", "left": left, "right": right}
|
||||
|
||||
return left
|
||||
|
||||
def _parse_primary_expression(self) -> dict[str, Any] | None:
|
||||
"""Parse primary expressions (attribute operator value)"""
|
||||
if self.current_token.type == TokenType.LPAREN:
|
||||
self._consume(TokenType.LPAREN)
|
||||
expr = self._parse_or_expression()
|
||||
self._consume(TokenType.RPAREN)
|
||||
return expr
|
||||
|
||||
if self.current_token.type == TokenType.NOT:
|
||||
self._consume(TokenType.NOT)
|
||||
expr = self._parse_primary_expression()
|
||||
return {"type": "logical", "operator": "not", "operand": expr}
|
||||
|
||||
if self.current_token.type != TokenType.ATTRIBUTE:
|
||||
return None
|
||||
|
||||
attribute = self.current_token.value
|
||||
self._consume(TokenType.ATTRIBUTE)
|
||||
|
||||
if self.current_token.type != TokenType.OPERATOR:
|
||||
return None
|
||||
|
||||
operator = self.current_token.value
|
||||
self._consume(TokenType.OPERATOR)
|
||||
|
||||
# Parse value
|
||||
value = None
|
||||
if self.current_token.type == TokenType.STRING:
|
||||
value = self.current_token.value
|
||||
self._consume(TokenType.STRING)
|
||||
elif self.current_token.type == TokenType.NUMBER:
|
||||
value = (
|
||||
float(self.current_token.value)
|
||||
if "." in self.current_token.value
|
||||
else int(self.current_token.value)
|
||||
)
|
||||
self._consume(TokenType.NUMBER)
|
||||
elif self.current_token.type == TokenType.BOOLEAN:
|
||||
value = self.current_token.value
|
||||
self._consume(TokenType.BOOLEAN)
|
||||
elif self.current_token.type == TokenType.NULL:
|
||||
value = None
|
||||
self._consume(TokenType.NULL)
|
||||
|
||||
return {"type": "comparison", "attribute": attribute, "operator": operator, "value": value}
|
||||
|
||||
def _consume(self, expected_type: TokenType):
|
||||
"""Consume a token of the expected type"""
|
||||
if self.current_token.type == expected_type:
|
||||
self.current_token = self.lexer.get_next_token()
|
||||
else:
|
||||
raise ValueError(f"Expected {expected_type}, got {self.current_token.type}")
|
||||
@@ -1,246 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from authentik.providers.scim.clients.schema import PatchOp, PatchOperation
|
||||
from authentik.sources.scim.constants import SCIM_URN_USER_ENTERPRISE
|
||||
from authentik.sources.scim.patch.parser import SCIMPathParser
|
||||
|
||||
|
||||
class SCIMPatchProcessor:
|
||||
"""Processes SCIM patch operations on Python dictionaries"""
|
||||
|
||||
def __init__(self):
|
||||
self.parser = SCIMPathParser()
|
||||
|
||||
def apply_patches(self, data: dict[str, Any], patches: list[PatchOperation]) -> dict[str, Any]:
|
||||
"""Apply a list of patch operations to the data"""
|
||||
result = data.copy()
|
||||
|
||||
for _patch in patches:
|
||||
patch = PatchOperation.model_validate(_patch)
|
||||
if patch.path is None:
|
||||
# Handle operations with no path - value contains attribute paths as keys
|
||||
self._apply_bulk_operation(result, patch.op, patch.value)
|
||||
elif patch.op == PatchOp.add:
|
||||
self._apply_add(result, patch.path, patch.value)
|
||||
elif patch.op == PatchOp.remove:
|
||||
self._apply_remove(result, patch.path)
|
||||
elif patch.op == PatchOp.replace:
|
||||
self._apply_replace(result, patch.path, patch.value)
|
||||
|
||||
return result
|
||||
|
||||
def _apply_bulk_operation(
|
||||
self, data: dict[str, Any], operation: PatchOp, value: dict[str, Any]
|
||||
):
|
||||
"""Apply bulk operations when path is None"""
|
||||
if not isinstance(value, dict):
|
||||
return
|
||||
for path, val in value.items():
|
||||
if operation == PatchOp.add:
|
||||
self._apply_add(data, path, val)
|
||||
elif operation == PatchOp.remove:
|
||||
self._apply_remove(data, path)
|
||||
elif operation == PatchOp.replace:
|
||||
self._apply_replace(data, path, val)
|
||||
|
||||
def _apply_add(self, data: dict[str, Any], path: str, value: Any):
|
||||
"""Apply ADD operation"""
|
||||
components = self.parser.parse_path(path)
|
||||
|
||||
if len(components) == 1 and not components[0]["filter"]:
|
||||
# Simple path
|
||||
attr = components[0]["attribute"]
|
||||
if components[0]["sub_attribute"]:
|
||||
if attr not in data:
|
||||
data[attr] = {}
|
||||
# Somewhat hacky workaround for the manager attribute of the enterprise schema
|
||||
# ideally we'd do this based on the schema
|
||||
if attr == SCIM_URN_USER_ENTERPRISE and components[0]["sub_attribute"] == "manager":
|
||||
data[attr][components[0]["sub_attribute"]] = {"value": value}
|
||||
else:
|
||||
data[attr][components[0]["sub_attribute"]] = value
|
||||
elif attr in data:
|
||||
data[attr].append(value)
|
||||
else:
|
||||
data[attr] = value
|
||||
else:
|
||||
# Complex path with filters
|
||||
self._navigate_and_modify(data, components, value, "add")
|
||||
|
||||
def _apply_remove(self, data: dict[str, Any], path: str):
|
||||
"""Apply REMOVE operation"""
|
||||
components = self.parser.parse_path(path)
|
||||
|
||||
if len(components) == 1 and not components[0]["filter"]:
|
||||
# Simple path
|
||||
attr = components[0]["attribute"]
|
||||
if components[0]["sub_attribute"]:
|
||||
if attr in data and isinstance(data[attr], dict):
|
||||
data[attr].pop(components[0]["sub_attribute"], None)
|
||||
else:
|
||||
data.pop(attr, None)
|
||||
else:
|
||||
# Complex path with filters
|
||||
self._navigate_and_modify(data, components, None, "remove")
|
||||
|
||||
def _apply_replace(self, data: dict[str, Any], path: str, value: Any):
|
||||
"""Apply REPLACE operation"""
|
||||
components = self.parser.parse_path(path)
|
||||
|
||||
if len(components) == 1 and not components[0]["filter"]:
|
||||
# Simple path
|
||||
attr = components[0]["attribute"]
|
||||
if components[0]["sub_attribute"]:
|
||||
if attr not in data:
|
||||
data[attr] = {}
|
||||
# Somewhat hacky workaround for the manager attribute of the enterprise schema
|
||||
# ideally we'd do this based on the schema
|
||||
if attr == SCIM_URN_USER_ENTERPRISE and components[0]["sub_attribute"] == "manager":
|
||||
data[attr][components[0]["sub_attribute"]] = {"value": value}
|
||||
else:
|
||||
data[attr][components[0]["sub_attribute"]] = value
|
||||
else:
|
||||
data[attr] = value
|
||||
else:
|
||||
# Complex path with filters
|
||||
self._navigate_and_modify(data, components, value, "replace")
|
||||
|
||||
def _navigate_and_modify( # noqa PLR0912
|
||||
self, data: dict[str, Any], components: list[dict[str, Any]], value: Any, operation: str
|
||||
):
|
||||
"""Navigate through complex paths and apply modifications"""
|
||||
current = data
|
||||
|
||||
for i, component in enumerate(components):
|
||||
attr = component["attribute"]
|
||||
filter_expr = component["filter"]
|
||||
sub_attr = component["sub_attribute"]
|
||||
|
||||
if filter_expr:
|
||||
# Handle array with filter
|
||||
if attr not in current:
|
||||
if operation == "add":
|
||||
current[attr] = []
|
||||
else:
|
||||
return
|
||||
|
||||
if not isinstance(current[attr], list):
|
||||
return
|
||||
|
||||
# Find matching items
|
||||
matching_items = []
|
||||
for item in current[attr]:
|
||||
if self._matches_filter(item, filter_expr):
|
||||
matching_items.append(item)
|
||||
|
||||
if not matching_items and operation == "add":
|
||||
# Create new item if none match (only for simple comparison filters)
|
||||
if filter_expr.get("type", "comparison") == "comparison":
|
||||
new_item = {filter_expr["attribute"]: filter_expr["value"]}
|
||||
current[attr].append(new_item)
|
||||
matching_items = [new_item]
|
||||
|
||||
# Apply operation to matching items
|
||||
for item in matching_items:
|
||||
if sub_attr:
|
||||
if operation in {"add", "replace"}:
|
||||
item[sub_attr] = value
|
||||
elif operation == "remove":
|
||||
item.pop(sub_attr, None)
|
||||
elif operation in {"add", "replace"}:
|
||||
if isinstance(value, dict):
|
||||
item.update(value)
|
||||
else:
|
||||
# If value is not a dict, we can't merge it
|
||||
pass
|
||||
elif operation == "remove":
|
||||
# Remove the entire item
|
||||
if item in current[attr]:
|
||||
current[attr].remove(item)
|
||||
# Handle simple attribute
|
||||
elif i == len(components) - 1:
|
||||
# Last component
|
||||
if sub_attr:
|
||||
if attr not in current:
|
||||
current[attr] = {}
|
||||
if operation in {"add", "replace"}:
|
||||
current[attr][sub_attr] = value
|
||||
elif operation == "remove":
|
||||
current[attr].pop(sub_attr, None)
|
||||
elif operation in {"add", "replace"}:
|
||||
current[attr] = value
|
||||
elif operation == "remove":
|
||||
current.pop(attr, None)
|
||||
else:
|
||||
# Navigate deeper
|
||||
if attr not in current:
|
||||
current[attr] = {}
|
||||
current = current[attr]
|
||||
|
||||
def _matches_filter(self, item: dict[str, Any], filter_expr: dict[str, Any]) -> bool:
|
||||
"""Check if an item matches the filter expression"""
|
||||
if not filter_expr:
|
||||
return True
|
||||
|
||||
filter_type = filter_expr.get("type", "comparison")
|
||||
|
||||
if filter_type == "comparison":
|
||||
return self._matches_comparison(item, filter_expr)
|
||||
elif filter_type == "logical":
|
||||
return self._matches_logical(item, filter_expr)
|
||||
|
||||
return False
|
||||
|
||||
def _matches_comparison( # noqa PLR0912
|
||||
self, item: dict[str, Any], filter_expr: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Check if an item matches a comparison filter"""
|
||||
attr = filter_expr["attribute"]
|
||||
operator = filter_expr["operator"]
|
||||
expected_value = filter_expr["value"]
|
||||
|
||||
if attr not in item:
|
||||
return False
|
||||
|
||||
actual_value = item[attr]
|
||||
|
||||
if operator == "eq":
|
||||
return actual_value == expected_value
|
||||
elif operator == "ne":
|
||||
return actual_value != expected_value
|
||||
elif operator == "co":
|
||||
return str(expected_value) in str(actual_value)
|
||||
elif operator == "sw":
|
||||
return str(actual_value).startswith(str(expected_value))
|
||||
elif operator == "ew":
|
||||
return str(actual_value).endswith(str(expected_value))
|
||||
elif operator == "gt":
|
||||
return actual_value > expected_value
|
||||
elif operator == "lt":
|
||||
return actual_value < expected_value
|
||||
elif operator == "ge":
|
||||
return actual_value >= expected_value
|
||||
elif operator == "le":
|
||||
return actual_value <= expected_value
|
||||
elif operator == "pr":
|
||||
return actual_value is not None
|
||||
|
||||
return False
|
||||
|
||||
def _matches_logical(self, item: dict[str, Any], filter_expr: dict[str, Any]) -> bool:
|
||||
"""Check if an item matches a logical filter expression"""
|
||||
operator = filter_expr["operator"]
|
||||
|
||||
if operator == "and":
|
||||
left_result = self._matches_filter(item, filter_expr["left"])
|
||||
right_result = self._matches_filter(item, filter_expr["right"])
|
||||
return left_result and right_result
|
||||
elif operator == "or":
|
||||
left_result = self._matches_filter(item, filter_expr["left"])
|
||||
right_result = self._matches_filter(item, filter_expr["right"])
|
||||
return left_result or right_result
|
||||
elif operator == "not":
|
||||
operand_result = self._matches_filter(item, filter_expr["operand"])
|
||||
return not operand_result
|
||||
|
||||
return False
|
||||
@@ -1101,6 +1101,17 @@
|
||||
"returned": "default",
|
||||
"uniqueness": "none"
|
||||
},
|
||||
{
|
||||
"name": "password",
|
||||
"type": "string",
|
||||
"multiValued": false,
|
||||
"description": "The User's cleartext password. This attribute is intended to be used as a means to specify an initial\npassword when creating a new User or to reset an existing User's password.",
|
||||
"required": false,
|
||||
"caseExact": false,
|
||||
"mutability": "writeOnly",
|
||||
"returned": "never",
|
||||
"uniqueness": "none"
|
||||
},
|
||||
{
|
||||
"name": "emails",
|
||||
"type": "complex",
|
||||
|
||||
@@ -75,9 +75,7 @@ class TestSCIMGroups(APITestCase):
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
self.assertTrue(
|
||||
SCIMSourceGroup.objects.filter(source=self.source, external_id=ext_id).exists()
|
||||
)
|
||||
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
|
||||
self.assertTrue(
|
||||
Event.objects.filter(
|
||||
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
|
||||
@@ -88,7 +86,6 @@ class TestSCIMGroups(APITestCase):
|
||||
"""Test group create"""
|
||||
user = create_test_user()
|
||||
ext_id = generate_id()
|
||||
name = generate_id()
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-groups",
|
||||
@@ -98,7 +95,7 @@ class TestSCIMGroups(APITestCase):
|
||||
),
|
||||
data=dumps(
|
||||
{
|
||||
"displayName": name,
|
||||
"displayName": generate_id(),
|
||||
"externalId": ext_id,
|
||||
"members": [{"value": str(user.uuid)}],
|
||||
}
|
||||
@@ -107,22 +104,12 @@ class TestSCIMGroups(APITestCase):
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
connection = SCIMSourceGroup.objects.filter(source=self.source, external_id=ext_id).first()
|
||||
self.assertIsNotNone(connection)
|
||||
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
|
||||
self.assertTrue(
|
||||
Event.objects.filter(
|
||||
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
|
||||
).exists()
|
||||
)
|
||||
connection.refresh_from_db()
|
||||
self.assertEqual(
|
||||
connection.attributes,
|
||||
{
|
||||
"displayName": name,
|
||||
"externalId": ext_id,
|
||||
"members": [{"value": str(user.uuid)}],
|
||||
},
|
||||
)
|
||||
|
||||
def test_group_create_members_empty(self):
|
||||
"""Test group create"""
|
||||
@@ -139,9 +126,7 @@ class TestSCIMGroups(APITestCase):
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
self.assertTrue(
|
||||
SCIMSourceGroup.objects.filter(source=self.source, external_id=ext_id).exists()
|
||||
)
|
||||
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
|
||||
self.assertTrue(
|
||||
Event.objects.filter(
|
||||
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
|
||||
@@ -151,9 +136,7 @@ class TestSCIMGroups(APITestCase):
|
||||
def test_group_create_duplicate(self):
|
||||
"""Test group create (duplicate)"""
|
||||
group = Group.objects.create(name=generate_id())
|
||||
existing = SCIMSourceGroup.objects.create(
|
||||
source=self.source, group=group, external_id=uuid4()
|
||||
)
|
||||
existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
|
||||
ext_id = generate_id()
|
||||
response = self.client.post(
|
||||
reverse(
|
||||
@@ -182,9 +165,7 @@ class TestSCIMGroups(APITestCase):
|
||||
def test_group_update(self):
|
||||
"""Test group update"""
|
||||
group = Group.objects.create(name=generate_id())
|
||||
existing = SCIMSourceGroup.objects.create(
|
||||
source=self.source, group=group, external_id=uuid4()
|
||||
)
|
||||
existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
|
||||
ext_id = generate_id()
|
||||
response = self.client.put(
|
||||
reverse(
|
||||
@@ -224,49 +205,12 @@ class TestSCIMGroups(APITestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_group_patch_modify(self):
|
||||
"""Test group patch"""
|
||||
group = Group.objects.create(name=generate_id())
|
||||
connection = SCIMSourceGroup.objects.create(
|
||||
source=self.source,
|
||||
group=group,
|
||||
external_id=uuid4(),
|
||||
attributes={"displayName": group.name, "members": []},
|
||||
)
|
||||
response = self.client.patch(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-groups",
|
||||
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
|
||||
),
|
||||
data=dumps(
|
||||
{
|
||||
"Operations": [
|
||||
{
|
||||
"op": "Add",
|
||||
"value": {"externalId": "d85051cb-0557-4aa1-98ca-51eabcee4d40"},
|
||||
}
|
||||
]
|
||||
}
|
||||
),
|
||||
content_type=SCIM_CONTENT_TYPE,
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200, response.content)
|
||||
connection = SCIMSourceGroup.objects.filter(id="d85051cb-0557-4aa1-98ca-51eabcee4d40")
|
||||
self.assertIsNotNone(connection)
|
||||
|
||||
def test_group_patch_member_add(self):
|
||||
def test_group_patch_add(self):
|
||||
"""Test group patch"""
|
||||
user = create_test_user()
|
||||
other_user = create_test_user()
|
||||
|
||||
group = Group.objects.create(name=generate_id())
|
||||
group.users.add(other_user)
|
||||
connection = SCIMSourceGroup.objects.create(
|
||||
source=self.source,
|
||||
group=group,
|
||||
external_id=uuid4(),
|
||||
attributes={"displayName": group.name, "members": [{"value": str(other_user.uuid)}]},
|
||||
)
|
||||
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
|
||||
response = self.client.patch(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-groups",
|
||||
@@ -278,7 +222,7 @@ class TestSCIMGroups(APITestCase):
|
||||
{
|
||||
"op": "Add",
|
||||
"path": "members",
|
||||
"value": [{"value": str(user.uuid)}],
|
||||
"value": {"value": str(user.uuid)},
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -286,33 +230,16 @@ class TestSCIMGroups(APITestCase):
|
||||
content_type=SCIM_CONTENT_TYPE,
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200, response.content)
|
||||
self.assertEqual(response.status_code, second=200)
|
||||
self.assertTrue(group.users.filter(pk=user.pk).exists())
|
||||
self.assertTrue(group.users.filter(pk=other_user.pk).exists())
|
||||
connection.refresh_from_db()
|
||||
self.assertEqual(
|
||||
connection.attributes,
|
||||
{
|
||||
"displayName": group.name,
|
||||
"members": sorted(
|
||||
[{"value": str(other_user.uuid)}, {"value": str(user.uuid)}],
|
||||
key=lambda u: u["value"],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
def test_group_patch_member_remove(self):
|
||||
def test_group_patch_remove(self):
|
||||
"""Test group patch"""
|
||||
user = create_test_user()
|
||||
|
||||
group = Group.objects.create(name=generate_id())
|
||||
group.users.add(user)
|
||||
connection = SCIMSourceGroup.objects.create(
|
||||
source=self.source,
|
||||
group=group,
|
||||
external_id=uuid4(),
|
||||
attributes={"displayName": group.name, "members": []},
|
||||
)
|
||||
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
|
||||
response = self.client.patch(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-groups",
|
||||
@@ -324,7 +251,7 @@ class TestSCIMGroups(APITestCase):
|
||||
{
|
||||
"op": "remove",
|
||||
"path": "members",
|
||||
"value": [{"value": str(user.uuid)}],
|
||||
"value": {"value": str(user.uuid)},
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -332,21 +259,13 @@ class TestSCIMGroups(APITestCase):
|
||||
content_type=SCIM_CONTENT_TYPE,
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200, response.content)
|
||||
self.assertEqual(response.status_code, second=200)
|
||||
self.assertFalse(group.users.filter(pk=user.pk).exists())
|
||||
connection.refresh_from_db()
|
||||
self.assertEqual(
|
||||
connection.attributes,
|
||||
{
|
||||
"displayName": group.name,
|
||||
"members": [],
|
||||
},
|
||||
)
|
||||
|
||||
def test_group_delete(self):
|
||||
"""Test group delete"""
|
||||
group = Group.objects.create(name=generate_id())
|
||||
SCIMSourceGroup.objects.create(source=self.source, group=group, external_id=uuid4())
|
||||
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
|
||||
response = self.client.delete(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-groups",
|
||||
|
||||
@@ -1,510 +0,0 @@
|
||||
from unittest import TestCase
|
||||
|
||||
from authentik.sources.scim.constants import (
|
||||
SCIM_URN_GROUP,
|
||||
SCIM_URN_SCHEMA,
|
||||
SCIM_URN_USER,
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
)
|
||||
from authentik.sources.scim.patch.lexer import SCIMPathLexer, Token, TokenType
|
||||
|
||||
|
||||
class TestTokenType(TestCase):
|
||||
"""Test TokenType enum"""
|
||||
|
||||
def test_token_type_values(self):
|
||||
"""Test that all token types have correct values"""
|
||||
self.assertEqual(TokenType.ATTRIBUTE.value, "ATTRIBUTE")
|
||||
self.assertEqual(TokenType.DOT.value, "DOT")
|
||||
self.assertEqual(TokenType.LBRACKET.value, "LBRACKET")
|
||||
self.assertEqual(TokenType.RBRACKET.value, "RBRACKET")
|
||||
self.assertEqual(TokenType.LPAREN.value, "LPAREN")
|
||||
self.assertEqual(TokenType.RPAREN.value, "RPAREN")
|
||||
self.assertEqual(TokenType.STRING.value, "STRING")
|
||||
self.assertEqual(TokenType.NUMBER.value, "NUMBER")
|
||||
self.assertEqual(TokenType.BOOLEAN.value, "BOOLEAN")
|
||||
self.assertEqual(TokenType.NULL.value, "NULL")
|
||||
self.assertEqual(TokenType.OPERATOR.value, "OPERATOR")
|
||||
self.assertEqual(TokenType.AND.value, "AND")
|
||||
self.assertEqual(TokenType.OR.value, "OR")
|
||||
self.assertEqual(TokenType.NOT.value, "NOT")
|
||||
self.assertEqual(TokenType.EOF.value, "EOF")
|
||||
|
||||
|
||||
class TestToken(TestCase):
|
||||
"""Test Token dataclass"""
|
||||
|
||||
def test_token_creation(self):
|
||||
"""Test token creation with all parameters"""
|
||||
token = Token(TokenType.ATTRIBUTE, "userName", 5)
|
||||
self.assertEqual(token.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token.value, "userName")
|
||||
self.assertEqual(token.position, 5)
|
||||
|
||||
def test_token_creation_default_position(self):
|
||||
"""Test token creation with default position"""
|
||||
token = Token(TokenType.DOT, ".")
|
||||
self.assertEqual(token.type, TokenType.DOT)
|
||||
self.assertEqual(token.value, ".")
|
||||
self.assertEqual(token.position, 0)
|
||||
|
||||
|
||||
class TestSCIMPathLexer(TestCase):
|
||||
"""Test SCIMPathLexer class"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
self.simple_lexer = SCIMPathLexer("userName")
|
||||
|
||||
def test_init(self):
|
||||
"""Test lexer initialization"""
|
||||
lexer = SCIMPathLexer("test")
|
||||
self.assertEqual(lexer.text, "test")
|
||||
self.assertEqual(lexer.pos, 0)
|
||||
self.assertEqual(lexer.current_char, "t")
|
||||
self.assertIn(SCIM_URN_SCHEMA, lexer.schema_urns)
|
||||
self.assertIn(SCIM_URN_GROUP, lexer.schema_urns)
|
||||
self.assertIn(SCIM_URN_USER, lexer.schema_urns)
|
||||
self.assertIn(SCIM_URN_USER_ENTERPRISE, lexer.schema_urns)
|
||||
self.assertEqual(
|
||||
lexer.OPERATORS, ["eq", "ne", "co", "sw", "ew", "gt", "lt", "ge", "le", "pr"]
|
||||
)
|
||||
|
||||
def test_init_empty_string(self):
|
||||
"""Test lexer initialization with empty string"""
|
||||
lexer = SCIMPathLexer("")
|
||||
self.assertEqual(lexer.text, "")
|
||||
self.assertEqual(lexer.pos, 0)
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_advance(self):
|
||||
"""Test advance method"""
|
||||
lexer = SCIMPathLexer("abc")
|
||||
self.assertEqual(lexer.current_char, "a")
|
||||
|
||||
lexer.advance()
|
||||
self.assertEqual(lexer.pos, 1)
|
||||
self.assertEqual(lexer.current_char, "b")
|
||||
|
||||
lexer.advance()
|
||||
self.assertEqual(lexer.pos, 2)
|
||||
self.assertEqual(lexer.current_char, "c")
|
||||
|
||||
lexer.advance()
|
||||
self.assertEqual(lexer.pos, 3)
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_skip_whitespace(self):
|
||||
"""Test skip_whitespace method"""
|
||||
lexer = SCIMPathLexer(" \t\n abc")
|
||||
lexer.skip_whitespace()
|
||||
self.assertEqual(lexer.current_char, "a")
|
||||
|
||||
def test_skip_whitespace_only_whitespace(self):
|
||||
"""Test skip_whitespace with only whitespace"""
|
||||
lexer = SCIMPathLexer(" \t\n ")
|
||||
lexer.skip_whitespace()
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_skip_whitespace_no_whitespace(self):
|
||||
"""Test skip_whitespace with no leading whitespace"""
|
||||
lexer = SCIMPathLexer("abc")
|
||||
original_pos = lexer.pos
|
||||
lexer.skip_whitespace()
|
||||
self.assertEqual(lexer.pos, original_pos)
|
||||
self.assertEqual(lexer.current_char, "a")
|
||||
|
||||
def test_read_string_double_quotes(self):
|
||||
"""Test reading double-quoted string"""
|
||||
lexer = SCIMPathLexer('"hello world"')
|
||||
result = lexer.read_string('"')
|
||||
self.assertEqual(result, "hello world")
|
||||
self.assertIsNone(lexer.current_char) # Should be at end
|
||||
|
||||
def test_read_string_single_quotes(self):
|
||||
"""Test reading single-quoted string"""
|
||||
lexer = SCIMPathLexer("'hello world'")
|
||||
result = lexer.read_string("'")
|
||||
self.assertEqual(result, "hello world")
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_read_string_with_escapes(self):
|
||||
"""Test reading string with escape characters"""
|
||||
lexer = SCIMPathLexer('"hello \\"world\\""')
|
||||
result = lexer.read_string('"')
|
||||
self.assertEqual(result, 'hello "world"')
|
||||
|
||||
def test_read_string_with_backslash_at_end(self):
|
||||
"""Test reading string with backslash at end"""
|
||||
lexer = SCIMPathLexer('"hello\\"')
|
||||
result = lexer.read_string('"')
|
||||
self.assertEqual(result, 'hello"')
|
||||
|
||||
def test_read_string_unclosed(self):
|
||||
"""Test reading unclosed string"""
|
||||
lexer = SCIMPathLexer('"hello world')
|
||||
result = lexer.read_string('"')
|
||||
self.assertEqual(result, "hello world")
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_read_string_empty(self):
|
||||
"""Test reading empty string"""
|
||||
lexer = SCIMPathLexer('""')
|
||||
result = lexer.read_string('"')
|
||||
self.assertEqual(result, "")
|
||||
|
||||
def test_read_number_integer(self):
|
||||
"""Test reading integer number"""
|
||||
lexer = SCIMPathLexer("123")
|
||||
result = lexer.read_number()
|
||||
self.assertEqual(result, "123")
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_read_number_float(self):
|
||||
"""Test reading float number"""
|
||||
lexer = SCIMPathLexer("123.456")
|
||||
result = lexer.read_number()
|
||||
self.assertEqual(result, "123.456")
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_read_number_with_multiple_dots(self):
|
||||
"""Test reading number with multiple dots (invalid but handled)"""
|
||||
lexer = SCIMPathLexer("123.456.789")
|
||||
result = lexer.read_number()
|
||||
self.assertEqual(result, "123.456.789")
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_read_number_starting_with_dot(self):
|
||||
"""Test reading number starting with dot"""
|
||||
lexer = SCIMPathLexer(".123")
|
||||
result = lexer.read_number()
|
||||
self.assertEqual(result, ".123")
|
||||
|
||||
def test_read_identifier_simple(self):
|
||||
"""Test reading simple identifier"""
|
||||
lexer = SCIMPathLexer("userName")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, "userName")
|
||||
self.assertIsNone(lexer.current_char)
|
||||
|
||||
def test_read_identifier_with_underscore(self):
|
||||
"""Test reading identifier with underscore"""
|
||||
lexer = SCIMPathLexer("user_name")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, "user_name")
|
||||
|
||||
def test_read_identifier_with_hyphen(self):
|
||||
"""Test reading identifier with hyphen"""
|
||||
lexer = SCIMPathLexer("user-name")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, "user-name")
|
||||
|
||||
def test_read_identifier_with_colon(self):
|
||||
"""Test reading identifier with colon (URN format)"""
|
||||
lexer = SCIMPathLexer("urn:ietf:params:scim:schemas:core:2.0:User")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, "urn:ietf:params:scim:schemas:core:2.0:User")
|
||||
|
||||
def test_read_identifier_schema_urn(self):
|
||||
"""Test reading schema URN identifier"""
|
||||
lexer = SCIMPathLexer(f"{SCIM_URN_USER}.userName")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, SCIM_URN_USER)
|
||||
self.assertEqual(lexer.current_char, ".") # Should stop at dot and set current_char to dot
|
||||
|
||||
def test_read_identifier_with_version_number(self):
|
||||
"""Test reading identifier with version number (dots followed by digits)"""
|
||||
lexer = SCIMPathLexer("urn:ietf:params:scim:schemas:core:2.0:User")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, "urn:ietf:params:scim:schemas:core:2.0:User")
|
||||
|
||||
def test_read_identifier_partial_urn_match(self):
|
||||
"""Test reading identifier that partially matches URN"""
|
||||
lexer = SCIMPathLexer("urn:ietf:params:scim:schemas:core:2.0:CustomUser")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, "urn:ietf:params:scim:schemas:core:2.0:CustomUser")
|
||||
|
||||
# Test get_next_token method
|
||||
def test_get_next_token_dot(self):
|
||||
"""Test tokenizing dot"""
|
||||
lexer = SCIMPathLexer(".")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.DOT)
|
||||
self.assertEqual(token.value, ".")
|
||||
|
||||
def test_get_next_token_lbracket(self):
|
||||
"""Test tokenizing left bracket"""
|
||||
lexer = SCIMPathLexer("[")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.LBRACKET)
|
||||
self.assertEqual(token.value, "[")
|
||||
|
||||
def test_get_next_token_rbracket(self):
|
||||
"""Test tokenizing right bracket"""
|
||||
lexer = SCIMPathLexer("]")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.RBRACKET)
|
||||
self.assertEqual(token.value, "]")
|
||||
|
||||
def test_get_next_token_lparen(self):
|
||||
"""Test tokenizing left parenthesis"""
|
||||
lexer = SCIMPathLexer("(")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.LPAREN)
|
||||
self.assertEqual(token.value, "(")
|
||||
|
||||
def test_get_next_token_rparen(self):
|
||||
"""Test tokenizing right parenthesis"""
|
||||
lexer = SCIMPathLexer(")")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.RPAREN)
|
||||
self.assertEqual(token.value, ")")
|
||||
|
||||
def test_get_next_token_string_double_quotes(self):
|
||||
"""Test tokenizing double-quoted string"""
|
||||
lexer = SCIMPathLexer('"test string"')
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.STRING)
|
||||
self.assertEqual(token.value, "test string")
|
||||
|
||||
def test_get_next_token_string_single_quotes(self):
|
||||
"""Test tokenizing single-quoted string"""
|
||||
lexer = SCIMPathLexer("'test string'")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.STRING)
|
||||
self.assertEqual(token.value, "test string")
|
||||
|
||||
def test_get_next_token_number_integer(self):
|
||||
"""Test tokenizing integer"""
|
||||
lexer = SCIMPathLexer("123")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.NUMBER)
|
||||
self.assertEqual(token.value, "123")
|
||||
|
||||
def test_get_next_token_number_float(self):
|
||||
"""Test tokenizing float"""
|
||||
lexer = SCIMPathLexer("123.45")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.NUMBER)
|
||||
self.assertEqual(token.value, "123.45")
|
||||
|
||||
def test_get_next_token_boolean_true(self):
|
||||
"""Test tokenizing boolean true"""
|
||||
lexer = SCIMPathLexer("true")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.BOOLEAN)
|
||||
self.assertTrue(token.value)
|
||||
|
||||
def test_get_next_token_boolean_false(self):
|
||||
"""Test tokenizing boolean false"""
|
||||
lexer = SCIMPathLexer("false")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.BOOLEAN)
|
||||
self.assertFalse(token.value)
|
||||
|
||||
def test_get_next_token_boolean_case_insensitive(self):
|
||||
"""Test tokenizing boolean with different cases"""
|
||||
for value in ["TRUE", "True", "FALSE", "False"]:
|
||||
with self.subTest(value=value):
|
||||
lexer = SCIMPathLexer(value)
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.BOOLEAN)
|
||||
|
||||
def test_get_next_token_null(self):
|
||||
"""Test tokenizing null"""
|
||||
lexer = SCIMPathLexer("null")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.NULL)
|
||||
self.assertIsNone(token.value)
|
||||
|
||||
def test_get_next_token_null_case_insensitive(self):
|
||||
"""Test tokenizing null with different cases"""
|
||||
for value in ["NULL", "Null"]:
|
||||
with self.subTest(value=value):
|
||||
lexer = SCIMPathLexer(value)
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.NULL)
|
||||
|
||||
def test_get_next_token_and(self):
|
||||
"""Test tokenizing AND operator"""
|
||||
lexer = SCIMPathLexer("and")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.AND)
|
||||
self.assertEqual(token.value, "and")
|
||||
|
||||
def test_get_next_token_or(self):
|
||||
"""Test tokenizing OR operator"""
|
||||
lexer = SCIMPathLexer("or")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.OR)
|
||||
self.assertEqual(token.value, "or")
|
||||
|
||||
def test_get_next_token_not(self):
|
||||
"""Test tokenizing NOT operator"""
|
||||
lexer = SCIMPathLexer("not")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.NOT)
|
||||
self.assertEqual(token.value, "not")
|
||||
|
||||
def test_get_next_token_operators(self):
|
||||
"""Test tokenizing all comparison operators"""
|
||||
operators = ["eq", "ne", "co", "sw", "ew", "gt", "lt", "ge", "le", "pr"]
|
||||
for op in operators:
|
||||
with self.subTest(operator=op):
|
||||
lexer = SCIMPathLexer(op)
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.OPERATOR)
|
||||
self.assertEqual(token.value, op)
|
||||
|
||||
def test_get_next_token_operators_case_insensitive(self):
|
||||
"""Test tokenizing operators with different cases"""
|
||||
for op in ["EQ", "Eq", "NE", "Ne"]:
|
||||
with self.subTest(operator=op):
|
||||
lexer = SCIMPathLexer(op)
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.OPERATOR)
|
||||
self.assertEqual(token.value, op.lower())
|
||||
|
||||
def test_get_next_token_attribute(self):
|
||||
"""Test tokenizing attribute name"""
|
||||
lexer = SCIMPathLexer("userName")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token.value, "userName")
|
||||
|
||||
def test_get_next_token_attribute_with_underscore(self):
|
||||
"""Test tokenizing attribute name with underscore"""
|
||||
lexer = SCIMPathLexer("_userName")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token.value, "_userName")
|
||||
|
||||
def test_get_next_token_eof(self):
|
||||
"""Test tokenizing end of file"""
|
||||
lexer = SCIMPathLexer("")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.EOF)
|
||||
self.assertEqual(token.value, "")
|
||||
|
||||
def test_get_next_token_with_whitespace(self):
|
||||
"""Test tokenizing with leading whitespace"""
|
||||
lexer = SCIMPathLexer(" userName")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token.value, "userName")
|
||||
|
||||
def test_get_next_token_skip_unknown_characters(self):
|
||||
"""Test that unknown characters are skipped"""
|
||||
lexer = SCIMPathLexer("@#$userName")
|
||||
token = lexer.get_next_token()
|
||||
self.assertEqual(token.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token.value, "userName")
|
||||
|
||||
def test_get_next_token_multiple_tokens(self):
|
||||
"""Test tokenizing multiple tokens in sequence"""
|
||||
lexer = SCIMPathLexer("userName.givenName")
|
||||
|
||||
token1 = lexer.get_next_token()
|
||||
self.assertEqual(token1.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token1.value, "userName")
|
||||
|
||||
token2 = lexer.get_next_token()
|
||||
self.assertEqual(token2.type, TokenType.DOT)
|
||||
self.assertEqual(token2.value, ".")
|
||||
|
||||
token3 = lexer.get_next_token()
|
||||
self.assertEqual(token3.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token3.value, "givenName")
|
||||
|
||||
token4 = lexer.get_next_token()
|
||||
self.assertEqual(token4.type, TokenType.EOF)
|
||||
|
||||
def test_get_next_token_complex_filter(self):
|
||||
"""Test tokenizing complex filter expression"""
|
||||
lexer = SCIMPathLexer('emails[type eq "work" and primary eq true]')
|
||||
|
||||
tokens = []
|
||||
while True:
|
||||
token = lexer.get_next_token()
|
||||
tokens.append(token)
|
||||
if token.type == TokenType.EOF:
|
||||
break
|
||||
|
||||
expected_types = [
|
||||
TokenType.ATTRIBUTE, # emails
|
||||
TokenType.LBRACKET, # [
|
||||
TokenType.ATTRIBUTE, # type
|
||||
TokenType.OPERATOR, # eq
|
||||
TokenType.STRING, # "work"
|
||||
TokenType.AND, # and
|
||||
TokenType.ATTRIBUTE, # primary
|
||||
TokenType.OPERATOR, # eq
|
||||
TokenType.BOOLEAN, # true
|
||||
TokenType.RBRACKET, # ]
|
||||
TokenType.EOF,
|
||||
]
|
||||
|
||||
self.assertEqual(len(tokens), len(expected_types))
|
||||
for token, expected_type in zip(tokens, expected_types, strict=False):
|
||||
self.assertEqual(token.type, expected_type)
|
||||
|
||||
def test_get_next_token_urn_attribute(self):
|
||||
"""Test tokenizing URN-based attribute"""
|
||||
lexer = SCIMPathLexer(f"{SCIM_URN_USER}.userName")
|
||||
|
||||
token1 = lexer.get_next_token()
|
||||
self.assertEqual(token1.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token1.value, SCIM_URN_USER)
|
||||
|
||||
token2 = lexer.get_next_token()
|
||||
self.assertEqual(token2.type, TokenType.DOT)
|
||||
|
||||
token3 = lexer.get_next_token()
|
||||
self.assertEqual(token3.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token3.value, "userName")
|
||||
|
||||
def test_get_next_token_enterprise_urn(self):
|
||||
"""Test tokenizing enterprise URN"""
|
||||
lexer = SCIMPathLexer(f"{SCIM_URN_USER_ENTERPRISE}.manager")
|
||||
|
||||
token1 = lexer.get_next_token()
|
||||
self.assertEqual(token1.type, TokenType.ATTRIBUTE)
|
||||
self.assertEqual(token1.value, SCIM_URN_USER_ENTERPRISE)
|
||||
|
||||
token2 = lexer.get_next_token()
|
||||
self.assertEqual(token2.type, TokenType.DOT)
|
||||
|
||||
def test_lexer_state_after_eof(self):
|
||||
"""Test lexer state after reaching EOF"""
|
||||
lexer = SCIMPathLexer("a")
|
||||
|
||||
# Get first token
|
||||
token1 = lexer.get_next_token()
|
||||
self.assertEqual(token1.type, TokenType.ATTRIBUTE)
|
||||
|
||||
# Get EOF token
|
||||
token2 = lexer.get_next_token()
|
||||
self.assertEqual(token2.type, TokenType.EOF)
|
||||
|
||||
# Should continue returning EOF
|
||||
token3 = lexer.get_next_token()
|
||||
self.assertEqual(token3.type, TokenType.EOF)
|
||||
|
||||
def test_read_identifier_edge_cases(self):
|
||||
"""Test read_identifier with edge cases"""
|
||||
# Test identifier ending with colon
|
||||
lexer = SCIMPathLexer("test:")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, "test:")
|
||||
|
||||
# Test identifier with numbers
|
||||
lexer = SCIMPathLexer("test123")
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, "test123")
|
||||
|
||||
def test_complex_urn_parsing(self):
|
||||
"""Test parsing complex URN with version numbers"""
|
||||
urn = "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
lexer = SCIMPathLexer(urn)
|
||||
result = lexer.read_identifier()
|
||||
self.assertEqual(result, urn)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,7 +10,6 @@ from authentik.core.tests.utils import create_test_user
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.scim.clients.schema import User as SCIMUserSchema
|
||||
from authentik.sources.scim.constants import SCIM_URN_USER_ENTERPRISE
|
||||
from authentik.sources.scim.models import SCIMSource, SCIMSourcePropertyMapping, SCIMSourceUser
|
||||
from authentik.sources.scim.views.v2.base import SCIM_CONTENT_TYPE
|
||||
|
||||
@@ -82,9 +81,7 @@ class TestSCIMUsers(APITestCase):
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
self.assertTrue(
|
||||
SCIMSourceUser.objects.filter(source=self.source, external_id=ext_id).exists()
|
||||
)
|
||||
self.assertTrue(SCIMSourceUser.objects.filter(source=self.source, id=ext_id).exists())
|
||||
self.assertTrue(
|
||||
Event.objects.filter(
|
||||
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
|
||||
@@ -177,16 +174,14 @@ class TestSCIMUsers(APITestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 201)
|
||||
self.assertEqual(
|
||||
SCIMSourceUser.objects.get(source=self.source, external_id=ext_id).user.attributes[
|
||||
"phone"
|
||||
],
|
||||
SCIMSourceUser.objects.get(source=self.source, id=ext_id).user.attributes["phone"],
|
||||
"0123456789",
|
||||
)
|
||||
|
||||
def test_user_update(self):
|
||||
"""Test user update"""
|
||||
user = create_test_user()
|
||||
existing = SCIMSourceUser.objects.create(source=self.source, user=user, external_id=uuid4())
|
||||
existing = SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
|
||||
ext_id = generate_id()
|
||||
response = self.client.put(
|
||||
reverse(
|
||||
@@ -214,51 +209,10 @@ class TestSCIMUsers(APITestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_user_update_patch(self):
|
||||
"""Test user update (patch)"""
|
||||
user = create_test_user()
|
||||
existing = SCIMSourceUser.objects.create(
|
||||
source=self.source,
|
||||
user=user,
|
||||
external_id=uuid4(),
|
||||
attributes={
|
||||
"userName": generate_id(),
|
||||
},
|
||||
)
|
||||
response = self.client.patch(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-users",
|
||||
kwargs={
|
||||
"source_slug": self.source.slug,
|
||||
"user_id": str(user.uuid),
|
||||
},
|
||||
),
|
||||
data=dumps(
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||
"Operations": [
|
||||
{
|
||||
"op": "Add",
|
||||
"path": f"{SCIM_URN_USER_ENTERPRISE}:manager",
|
||||
"value": "86b2ed3e-30cd-4881-bb58-c4e910821339",
|
||||
}
|
||||
],
|
||||
}
|
||||
),
|
||||
content_type=SCIM_CONTENT_TYPE,
|
||||
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
existing.refresh_from_db()
|
||||
self.assertEqual(
|
||||
existing.attributes[SCIM_URN_USER_ENTERPRISE],
|
||||
{"manager": {"value": "86b2ed3e-30cd-4881-bb58-c4e910821339"}},
|
||||
)
|
||||
|
||||
def test_user_delete(self):
|
||||
"""Test user delete"""
|
||||
user = create_test_user()
|
||||
SCIMSourceUser.objects.create(source=self.source, user=user, external_id=uuid4())
|
||||
SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
|
||||
response = self.client.delete(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-users",
|
||||
|
||||
@@ -1,488 +0,0 @@
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.tests.utils import create_test_user
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.sources.scim.constants import SCIM_URN_USER_ENTERPRISE
|
||||
from authentik.sources.scim.models import SCIMSource, SCIMSourceUser
|
||||
from authentik.sources.scim.patch.processor import SCIMPatchProcessor
|
||||
|
||||
|
||||
class TestSCIMUsersPatch(APITestCase):
|
||||
"""Test SCIM User Patch"""
|
||||
|
||||
def test_add(self):
|
||||
req = {
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||
"Operations": [
|
||||
{"op": "Add", "path": "name.givenName", "value": "aqwer"},
|
||||
{"op": "Add", "path": "name.familyName", "value": "qwerqqqq"},
|
||||
{"op": "Add", "path": "name.formatted", "value": "aqwer qwerqqqq"},
|
||||
],
|
||||
}
|
||||
user = create_test_user()
|
||||
source = SCIMSource.objects.create(slug=generate_id())
|
||||
connection = SCIMSourceUser.objects.create(
|
||||
user=user,
|
||||
id=generate_id(),
|
||||
source=source,
|
||||
attributes={
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||
self.assertEqual(
|
||||
updated,
|
||||
{
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"name": {
|
||||
"givenName": "aqwer",
|
||||
"familyName": "qwerqqqq",
|
||||
"formatted": "aqwer qwerqqqq",
|
||||
},
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
|
||||
def test_add_no_path(self):
|
||||
"""Test add patch with no path set"""
|
||||
req = {
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||
"Operations": [
|
||||
{"op": "Add", "value": {"externalId": "aqwer"}},
|
||||
],
|
||||
}
|
||||
user = create_test_user()
|
||||
source = SCIMSource.objects.create(slug=generate_id())
|
||||
connection = SCIMSourceUser.objects.create(
|
||||
user=user,
|
||||
id=generate_id(),
|
||||
source=source,
|
||||
attributes={
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||
self.assertEqual(
|
||||
updated,
|
||||
{
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "aqwer",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
|
||||
def test_replace(self):
|
||||
req = {
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||
"Operations": [
|
||||
{"op": "Replace", "path": "name", "value": {"givenName": "aqwer"}},
|
||||
],
|
||||
}
|
||||
user = create_test_user()
|
||||
source = SCIMSource.objects.create(slug=generate_id())
|
||||
connection = SCIMSourceUser.objects.create(
|
||||
user=user,
|
||||
id=generate_id(),
|
||||
source=source,
|
||||
attributes={
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||
self.assertEqual(
|
||||
updated,
|
||||
{
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"name": {
|
||||
"givenName": "aqwer",
|
||||
},
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
|
||||
def test_replace_no_path(self):
|
||||
"""Test value replace with no path"""
|
||||
req = {
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||
"Operations": [
|
||||
{"op": "Replace", "value": {"externalId": "aqwer"}},
|
||||
],
|
||||
}
|
||||
user = create_test_user()
|
||||
source = SCIMSource.objects.create(slug=generate_id())
|
||||
connection = SCIMSourceUser.objects.create(
|
||||
user=user,
|
||||
id=generate_id(),
|
||||
source=source,
|
||||
attributes={
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||
self.assertEqual(
|
||||
updated,
|
||||
{
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "aqwer",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
|
||||
def test_remove(self):
|
||||
req = {
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||
"Operations": [
|
||||
{"op": "Remove", "path": "name", "value": {"givenName": "aqwer"}},
|
||||
],
|
||||
}
|
||||
user = create_test_user()
|
||||
source = SCIMSource.objects.create(slug=generate_id())
|
||||
connection = SCIMSourceUser.objects.create(
|
||||
user=user,
|
||||
id=generate_id(),
|
||||
source=source,
|
||||
attributes={
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"name": {
|
||||
"givenName": "aqwer",
|
||||
},
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||
self.assertEqual(
|
||||
updated,
|
||||
{
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
|
||||
def test_large(self):
|
||||
"""Large amount of patch operations"""
|
||||
req = {
|
||||
"Operations": [
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "emails[primary eq true].value",
|
||||
"value": "dandre_kling@wintheiser.info",
|
||||
},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "phoneNumbers[primary eq true].value",
|
||||
"value": "72-634-1548",
|
||||
},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "phoneNumbers[primary eq true].display",
|
||||
"value": "72-634-1548",
|
||||
},
|
||||
{"op": "replace", "path": "ims[primary eq true].value", "value": "GXSGJKWGHVVS"},
|
||||
{"op": "replace", "path": "ims[primary eq true].display", "value": "IMCHDKUQIPYB"},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "photos[primary eq true].display",
|
||||
"value": "TWAWLHHSUNIV",
|
||||
},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "addresses[primary eq true].formatted",
|
||||
"value": "TMINZQAJQDCL",
|
||||
},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "addresses[primary eq true].streetAddress",
|
||||
"value": "081 Wisoky Key",
|
||||
},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "addresses[primary eq true].locality",
|
||||
"value": "DPFASBZRPMDP",
|
||||
},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "addresses[primary eq true].region",
|
||||
"value": "WHSTJSPIPTCF",
|
||||
},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "addresses[primary eq true].postalCode",
|
||||
"value": "ko28 1qa",
|
||||
},
|
||||
{"op": "replace", "path": "addresses[primary eq true].country", "value": "Taiwan"},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "entitlements[primary eq true].value",
|
||||
"value": "NGBJMUYZVVBX",
|
||||
},
|
||||
{"op": "replace", "path": "roles[primary eq true].value", "value": "XEELVFMMWCVM"},
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "x509Certificates[primary eq true].value",
|
||||
"value": "UYISMEDOXUZY",
|
||||
},
|
||||
{
|
||||
"op": "replace",
|
||||
"value": {
|
||||
"externalId": "7faaefb0-0774-4d8e-8f6d-863c361bc72c",
|
||||
"name.formatted": "Dell",
|
||||
"name.familyName": "Gay",
|
||||
"name.givenName": "Kyler",
|
||||
"name.middleName": "Hannah",
|
||||
"name.honorificPrefix": "Cassie",
|
||||
"name.honorificSuffix": "Yolanda",
|
||||
"displayName": "DPRLIJSFQMTL",
|
||||
"nickName": "BKSPMIRMFBTI",
|
||||
"title": "NBZCOAXVYJUY",
|
||||
"userType": "ZGJMYZRUORZE",
|
||||
"preferredLanguage": "as-IN",
|
||||
"locale": "JLOJHLPWZODG",
|
||||
"timezone": "America/Argentina/Rio_Gallegos",
|
||||
"active": True,
|
||||
f"{SCIM_URN_USER_ENTERPRISE}:employeeNumber": "PDFWRRZBQOHB",
|
||||
f"{SCIM_URN_USER_ENTERPRISE}:costCenter": "HACMZWSEDOTQ",
|
||||
f"{SCIM_URN_USER_ENTERPRISE}:organization": "LXVHJUOLNCLS",
|
||||
f"{SCIM_URN_USER_ENTERPRISE}:division": "JASVTPKPBPMG",
|
||||
f"{SCIM_URN_USER_ENTERPRISE}:department": "GMSBFLMNPABY",
|
||||
},
|
||||
},
|
||||
],
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||
}
|
||||
user = create_test_user()
|
||||
source = SCIMSource.objects.create(slug=generate_id())
|
||||
connection = SCIMSourceUser.objects.create(
|
||||
user=user,
|
||||
id=generate_id(),
|
||||
source=source,
|
||||
attributes={
|
||||
"active": True,
|
||||
"addresses": [
|
||||
{
|
||||
"primary": "true",
|
||||
"formatted": "BLJMCNXHYLZK",
|
||||
"streetAddress": "7801 Jacobs Fork",
|
||||
"locality": "HZJBJWFAKXDD",
|
||||
"region": "GJXCXPMIIKWK",
|
||||
"postalCode": "pv82 8ua",
|
||||
"country": "India",
|
||||
}
|
||||
],
|
||||
"displayName": "KEFXCHKHAFOT",
|
||||
"emails": [{"primary": "true", "value": "scot@zemlak.uk"}],
|
||||
"entitlements": [{"primary": "true", "value": "FTTUXWYDAAQC"}],
|
||||
"externalId": "448d2786-7bf6-4e03-a4ef-64cbaf162fa7",
|
||||
"ims": [{"primary": "true", "value": "IGWZUUMCMKXS", "display": "PJVGMMKYYHRU"}],
|
||||
"locale": "PJNYJHWJILTI",
|
||||
"name": {
|
||||
"formatted": "Ladarius",
|
||||
"familyName": "Manley",
|
||||
"givenName": "Mazie",
|
||||
"middleName": "Vernon",
|
||||
"honorificPrefix": "Melyssa",
|
||||
"honorificSuffix": "Demarcus",
|
||||
},
|
||||
"nickName": "HTPKOXMWZKHL",
|
||||
"phoneNumbers": [
|
||||
{"primary": "true", "value": "50-608-7660", "display": "50-608-7660"}
|
||||
],
|
||||
"photos": [{"primary": "true", "display": "KCONLNLSYTBP"}],
|
||||
"preferredLanguage": "wae",
|
||||
"profileUrl": "HPSEOIPXMGOH",
|
||||
"roles": [{"primary": "true", "value": "TLGYITOIZGKP"}],
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"timezone": "America/Indiana/Petersburg",
|
||||
"title": "EJWFXLHNHMCD",
|
||||
SCIM_URN_USER_ENTERPRISE: {
|
||||
"employeeNumber": "XHDMEJUURJNR",
|
||||
"costCenter": "RXUYBXOTRCZH",
|
||||
"organization": "CEXWXMBRYAHN",
|
||||
"division": "XMPFMDCLRKCW",
|
||||
"department": "BKMNJVMCJUYS",
|
||||
"manager": "PNGSGXLYVWMV",
|
||||
},
|
||||
"userName": "imelda.auer@kshlerin.co.uk",
|
||||
"userType": "PZFXORVSUAPU",
|
||||
"x509Certificates": [{"primary": "true", "value": "KOVKWGIVVEHH"}],
|
||||
},
|
||||
)
|
||||
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||
self.assertEqual(
|
||||
updated,
|
||||
{
|
||||
"active": True,
|
||||
"addresses": [
|
||||
{
|
||||
"primary": "true",
|
||||
"formatted": "BLJMCNXHYLZK",
|
||||
"streetAddress": "7801 Jacobs Fork",
|
||||
"locality": "HZJBJWFAKXDD",
|
||||
"region": "GJXCXPMIIKWK",
|
||||
"postalCode": "pv82 8ua",
|
||||
"country": "India",
|
||||
}
|
||||
],
|
||||
"displayName": "DPRLIJSFQMTL",
|
||||
"emails": [{"primary": "true", "value": "scot@zemlak.uk"}],
|
||||
"entitlements": [{"primary": "true", "value": "FTTUXWYDAAQC"}],
|
||||
"externalId": "7faaefb0-0774-4d8e-8f6d-863c361bc72c",
|
||||
"ims": [{"primary": "true", "value": "IGWZUUMCMKXS", "display": "PJVGMMKYYHRU"}],
|
||||
"locale": "JLOJHLPWZODG",
|
||||
"name": {
|
||||
"formatted": "Dell",
|
||||
"familyName": "Gay",
|
||||
"givenName": "Kyler",
|
||||
"middleName": "Hannah",
|
||||
"honorificPrefix": "Cassie",
|
||||
"honorificSuffix": "Yolanda",
|
||||
},
|
||||
"nickName": "BKSPMIRMFBTI",
|
||||
"phoneNumbers": [
|
||||
{"primary": "true", "value": "50-608-7660", "display": "50-608-7660"}
|
||||
],
|
||||
"photos": [{"primary": "true", "display": "KCONLNLSYTBP"}],
|
||||
"preferredLanguage": "as-IN",
|
||||
"profileUrl": "HPSEOIPXMGOH",
|
||||
"roles": [{"primary": "true", "value": "TLGYITOIZGKP"}],
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"timezone": "America/Argentina/Rio_Gallegos",
|
||||
"title": "NBZCOAXVYJUY",
|
||||
SCIM_URN_USER_ENTERPRISE: {
|
||||
"employeeNumber": "PDFWRRZBQOHB",
|
||||
"costCenter": "HACMZWSEDOTQ",
|
||||
"organization": "LXVHJUOLNCLS",
|
||||
"division": "JASVTPKPBPMG",
|
||||
"department": "GMSBFLMNPABY",
|
||||
"manager": "PNGSGXLYVWMV",
|
||||
},
|
||||
"userName": "imelda.auer@kshlerin.co.uk",
|
||||
"userType": "ZGJMYZRUORZE",
|
||||
"x509Certificates": [{"primary": "true", "value": "KOVKWGIVVEHH"}],
|
||||
},
|
||||
)
|
||||
|
||||
def test_schema_urn_manager(self):
|
||||
req = {
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
|
||||
"Operations": [
|
||||
{
|
||||
"op": "Add",
|
||||
"value": {
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:manager": "foo"
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
user = create_test_user()
|
||||
source = SCIMSource.objects.create(slug=generate_id())
|
||||
connection = SCIMSourceUser.objects.create(
|
||||
user=user,
|
||||
id=generate_id(),
|
||||
source=source,
|
||||
attributes={
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
},
|
||||
)
|
||||
updated = SCIMPatchProcessor().apply_patches(connection.attributes, req["Operations"])
|
||||
self.assertEqual(
|
||||
updated,
|
||||
{
|
||||
"meta": {"resourceType": "User"},
|
||||
"active": True,
|
||||
"schemas": [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
SCIM_URN_USER_ENTERPRISE,
|
||||
],
|
||||
"userName": "test@t.goauthentik.io",
|
||||
"externalId": "test",
|
||||
"displayName": "Test MS",
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User": {
|
||||
"manager": {"value": "foo"}
|
||||
},
|
||||
},
|
||||
)
|
||||
@@ -1,7 +1,6 @@
|
||||
"""SCIM Utils"""
|
||||
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.paginator import Page, Paginator
|
||||
@@ -22,7 +21,6 @@ from authentik.core.sources.mapper import SourceMapper
|
||||
from authentik.lib.sync.mapper import PropertyMappingManager
|
||||
from authentik.sources.scim.models import SCIMSource
|
||||
from authentik.sources.scim.views.v2.auth import SCIMTokenAuth
|
||||
from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError
|
||||
|
||||
SCIM_CONTENT_TYPE = "application/scim+json"
|
||||
|
||||
@@ -56,13 +54,6 @@ class SCIMView(APIView):
|
||||
def get_authenticators(self):
|
||||
return [SCIMTokenAuth(self)]
|
||||
|
||||
def remove_excluded_attributes(self, data: dict):
|
||||
"""Remove attributes specified in excludedAttributes"""
|
||||
excluded: str = self.request.query_params.get("excludedAttributes", "")
|
||||
for key in excluded.split(","):
|
||||
data.pop(key.strip(), None)
|
||||
return data
|
||||
|
||||
def filter_parse(self, request: Request):
|
||||
"""Parse the path of a Patch Operation"""
|
||||
path = request.query_params.get("filter")
|
||||
@@ -112,12 +103,6 @@ class SCIMObjectView(SCIMView):
|
||||
# a source attribute before
|
||||
self.mapper = SourceMapper(self.source)
|
||||
self.manager = self.mapper.get_manager(self.model, ["data"])
|
||||
for key, value in kwargs.items():
|
||||
if key.endswith("_id"):
|
||||
try:
|
||||
UUID(value)
|
||||
except ValueError:
|
||||
raise SCIMNotFoundError("Invalid ID") from None
|
||||
|
||||
def build_object_properties(self, data: dict[str, Any]) -> dict[str, Any | dict[str, Any]]:
|
||||
return self.mapper.build_object_properties(
|
||||
|
||||
@@ -17,7 +17,6 @@ from authentik.core.models import Group, User
|
||||
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOp, PatchOperation
|
||||
from authentik.providers.scim.clients.schema import Group as SCIMGroupModel
|
||||
from authentik.sources.scim.models import SCIMSourceGroup
|
||||
from authentik.sources.scim.patch.processor import SCIMPatchProcessor
|
||||
from authentik.sources.scim.views.v2.base import SCIMObjectView
|
||||
from authentik.sources.scim.views.v2.exceptions import (
|
||||
SCIMConflictError,
|
||||
@@ -36,12 +35,11 @@ class GroupsView(SCIMObjectView):
|
||||
payload = SCIMGroupModel(
|
||||
schemas=[SCIM_GROUP_SCHEMA],
|
||||
id=str(scim_group.group.pk),
|
||||
externalId=scim_group.external_id,
|
||||
externalId=scim_group.id,
|
||||
displayName=scim_group.group.name,
|
||||
members=[],
|
||||
meta={
|
||||
"resourceType": "Group",
|
||||
"lastModified": scim_group.last_update,
|
||||
"location": self.request.build_absolute_uri(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-groups",
|
||||
@@ -56,11 +54,7 @@ class GroupsView(SCIMObjectView):
|
||||
for member in scim_group.group.users.order_by("pk"):
|
||||
member: User
|
||||
payload.members.append(GroupMember(value=str(member.uuid)))
|
||||
final_payload = payload.model_dump(mode="json", exclude_unset=True)
|
||||
final_payload.update(scim_group.attributes)
|
||||
return self.remove_excluded_attributes(
|
||||
SCIMGroupModel.model_validate(final_payload).model_dump(mode="json", exclude_unset=True)
|
||||
)
|
||||
return payload.model_dump(mode="json", exclude_unset=True)
|
||||
|
||||
def get(self, request: Request, group_id: str | None = None, **kwargs) -> Response:
|
||||
"""List Group handler"""
|
||||
@@ -87,7 +81,7 @@ class GroupsView(SCIMObjectView):
|
||||
)
|
||||
|
||||
@atomic
|
||||
def update_group(self, connection: SCIMSourceGroup | None, data: QueryDict, apply_members=True):
|
||||
def update_group(self, connection: SCIMSourceGroup | None, data: QueryDict):
|
||||
"""Partial update a group"""
|
||||
properties = self.build_object_properties(data)
|
||||
|
||||
@@ -100,7 +94,7 @@ class GroupsView(SCIMObjectView):
|
||||
|
||||
group.update_attributes(properties)
|
||||
|
||||
if "members" in data and apply_members:
|
||||
if "members" in data:
|
||||
query = Q()
|
||||
for _member in data.get("members", []):
|
||||
try:
|
||||
@@ -111,18 +105,14 @@ class GroupsView(SCIMObjectView):
|
||||
query |= Q(uuid=member.value)
|
||||
if query:
|
||||
group.users.set(User.objects.filter(query))
|
||||
data["members"] = self._convert_members(group)
|
||||
if not connection:
|
||||
connection, _ = SCIMSourceGroup.objects.update_or_create(
|
||||
external_id=data.get("externalId") or str(uuid4()),
|
||||
connection, _ = SCIMSourceGroup.objects.get_or_create(
|
||||
source=self.source,
|
||||
group=group,
|
||||
defaults={
|
||||
"attributes": data,
|
||||
},
|
||||
attributes=data,
|
||||
id=data.get("externalId") or str(uuid4()),
|
||||
)
|
||||
else:
|
||||
connection.external_id = data.get("externalId", connection.external_id)
|
||||
connection.attributes = data
|
||||
connection.save()
|
||||
return connection
|
||||
@@ -149,12 +139,6 @@ class GroupsView(SCIMObjectView):
|
||||
connection = self.update_group(connection, request.data)
|
||||
return Response(self.group_to_scim(connection), status=200)
|
||||
|
||||
def _convert_members(self, group: Group):
|
||||
users = []
|
||||
for user in group.users.all().order_by("uuid"):
|
||||
users.append({"value": str(user.uuid)})
|
||||
return sorted(users, key=lambda u: u["value"])
|
||||
|
||||
@atomic
|
||||
def patch(self, request: Request, group_id: str, **kwargs) -> Response:
|
||||
"""Patch group handler"""
|
||||
@@ -187,13 +171,6 @@ class GroupsView(SCIMObjectView):
|
||||
query |= Q(uuid=member["value"])
|
||||
if query:
|
||||
connection.group.users.remove(*User.objects.filter(query))
|
||||
patcher = SCIMPatchProcessor()
|
||||
patched_data = patcher.apply_patches(
|
||||
connection.attributes, request.data.get("Operations", [])
|
||||
)
|
||||
patched_data["members"] = self._convert_members(connection.group)
|
||||
if patched_data != connection.attributes:
|
||||
self.update_group(connection, patched_data, apply_members=False)
|
||||
return Response(self.group_to_scim(connection), status=200)
|
||||
|
||||
@atomic
|
||||
|
||||
@@ -33,7 +33,9 @@ class ServiceProviderConfigView(SCIMView):
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"],
|
||||
"authenticationSchemes": auth_schemas,
|
||||
"patch": {"supported": True},
|
||||
# We only support patch for groups currently, so don't broadly advertise it.
|
||||
# Implementations that require Group patch will use it regardless of this flag.
|
||||
"patch": {"supported": False},
|
||||
"bulk": {"supported": False, "maxOperations": 0, "maxPayloadSize": 0},
|
||||
"filter": {
|
||||
"supported": True,
|
||||
|
||||
@@ -15,7 +15,6 @@ from authentik.core.models import User
|
||||
from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
|
||||
from authentik.providers.scim.clients.schema import User as SCIMUserModel
|
||||
from authentik.sources.scim.models import SCIMSourceUser
|
||||
from authentik.sources.scim.patch.processor import SCIMPatchProcessor
|
||||
from authentik.sources.scim.views.v2.base import SCIMObjectView
|
||||
from authentik.sources.scim.views.v2.exceptions import SCIMConflictError, SCIMNotFoundError
|
||||
|
||||
@@ -30,7 +29,7 @@ class UsersView(SCIMObjectView):
|
||||
payload = SCIMUserModel(
|
||||
schemas=[SCIM_USER_SCHEMA],
|
||||
id=str(scim_user.user.uuid),
|
||||
externalId=scim_user.external_id,
|
||||
externalId=scim_user.id,
|
||||
userName=scim_user.user.username,
|
||||
name=Name(
|
||||
formatted=scim_user.user.name,
|
||||
@@ -45,7 +44,8 @@ class UsersView(SCIMObjectView):
|
||||
meta={
|
||||
"resourceType": "User",
|
||||
"created": scim_user.user.date_joined,
|
||||
"lastModified": scim_user.last_update,
|
||||
# TODO: use events to find last edit?
|
||||
"lastModified": scim_user.user.date_joined,
|
||||
"location": self.request.build_absolute_uri(
|
||||
reverse(
|
||||
"authentik_sources_scim:v2-users",
|
||||
@@ -59,9 +59,7 @@ class UsersView(SCIMObjectView):
|
||||
)
|
||||
final_payload = payload.model_dump(mode="json", exclude_unset=True)
|
||||
final_payload.update(scim_user.attributes)
|
||||
return self.remove_excluded_attributes(
|
||||
SCIMUserModel.model_validate(final_payload).model_dump(mode="json", exclude_unset=True)
|
||||
)
|
||||
return final_payload
|
||||
|
||||
def get(self, request: Request, user_id: str | None = None, **kwargs) -> Response:
|
||||
"""List User handler"""
|
||||
@@ -103,16 +101,13 @@ class UsersView(SCIMObjectView):
|
||||
user.update_attributes(properties)
|
||||
|
||||
if not connection:
|
||||
connection, _ = SCIMSourceUser.objects.update_or_create(
|
||||
external_id=data.get("externalId") or str(uuid4()),
|
||||
connection, _ = SCIMSourceUser.objects.get_or_create(
|
||||
source=self.source,
|
||||
user=user,
|
||||
defaults={
|
||||
"attributes": data,
|
||||
},
|
||||
attributes=data,
|
||||
id=data.get("externalId") or str(uuid4()),
|
||||
)
|
||||
else:
|
||||
connection.external_id = data.get("externalId", connection.external_id)
|
||||
connection.attributes = data
|
||||
connection.save()
|
||||
return connection
|
||||
@@ -132,18 +127,6 @@ class UsersView(SCIMObjectView):
|
||||
connection = self.update_user(None, request.data)
|
||||
return Response(self.user_to_scim(connection), status=201)
|
||||
|
||||
def patch(self, request: Request, user_id: str, **kwargs):
|
||||
connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
|
||||
if not connection:
|
||||
raise SCIMNotFoundError("User not found.")
|
||||
patcher = SCIMPatchProcessor()
|
||||
patched_data = patcher.apply_patches(
|
||||
connection.attributes, request.data.get("Operations", [])
|
||||
)
|
||||
if patched_data != connection.attributes:
|
||||
self.update_user(connection, patched_data)
|
||||
return Response(self.user_to_scim(connection), status=200)
|
||||
|
||||
def put(self, request: Request, user_id: str, **kwargs) -> Response:
|
||||
"""Update user handler"""
|
||||
connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
|
||||
|
||||
@@ -13,6 +13,7 @@ from authentik.flows.exceptions import StageInvalidException
|
||||
from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.lib.utils.time import timedelta_string_validator
|
||||
from authentik.stages.authenticator.models import SideChannelDevice
|
||||
from authentik.stages.email.utils import TemplateEmailMessage
|
||||
@@ -159,8 +160,9 @@ class EmailDevice(SerializerModel, SideChannelDevice):
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=_("Exception occurred while rendering E-mail template"),
|
||||
error=exception_to_string(exc),
|
||||
template=stage.template,
|
||||
).with_exception(exc).from_http(self.request)
|
||||
).from_http(self.request)
|
||||
raise StageInvalidException from exc
|
||||
|
||||
def __str__(self):
|
||||
|
||||
@@ -17,6 +17,7 @@ from authentik.flows.challenge import (
|
||||
from authentik.flows.exceptions import StageInvalidException
|
||||
from authentik.flows.stage import ChallengeStageView
|
||||
from authentik.lib.utils.email import mask_email
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.stages.authenticator_email.models import (
|
||||
AuthenticatorEmailStage,
|
||||
@@ -99,8 +100,9 @@ class AuthenticatorEmailStageView(ChallengeStageView):
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=_("Exception occurred while rendering E-mail template"),
|
||||
error=exception_to_string(exc),
|
||||
template=stage.template,
|
||||
).with_exception(exc).from_http(self.request)
|
||||
).from_http(self.request)
|
||||
raise StageInvalidException from exc
|
||||
|
||||
def _has_email(self) -> str | None:
|
||||
|
||||
@@ -4,7 +4,7 @@ from hashlib import sha256
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.db import models
|
||||
from django.http import HttpRequest, HttpResponseBadRequest
|
||||
from django.http import HttpResponseBadRequest
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.views import View
|
||||
from requests.exceptions import RequestException
|
||||
@@ -19,6 +19,7 @@ from authentik.events.models import Event, EventAction, NotificationWebhookMappi
|
||||
from authentik.events.utils import sanitize_item
|
||||
from authentik.flows.models import ConfigurableStage, FriendlyNamedStage, Stage
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.lib.utils.errors import exception_to_string
|
||||
from authentik.lib.utils.http import get_http_session
|
||||
from authentik.stages.authenticator.models import SideChannelDevice
|
||||
|
||||
@@ -68,44 +69,32 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage):
|
||||
help_text=_("Optionally modify the payload being sent to custom providers."),
|
||||
)
|
||||
|
||||
def send(self, request: HttpRequest, token: str, device: "SMSDevice"):
|
||||
def send(self, token: str, device: "SMSDevice"):
|
||||
"""Send message via selected provider"""
|
||||
if self.provider == SMSProviders.TWILIO:
|
||||
return self.send_twilio(request, token, device)
|
||||
return self.send_twilio(token, device)
|
||||
if self.provider == SMSProviders.GENERIC:
|
||||
return self.send_generic(request, token, device)
|
||||
return self.send_generic(token, device)
|
||||
raise ValueError(f"invalid provider {self.provider}")
|
||||
|
||||
def get_message(self, token: str) -> str:
|
||||
"""Get SMS message"""
|
||||
return _("Use this code to authenticate in authentik: {token}".format_map({"token": token}))
|
||||
|
||||
def send_twilio(self, request: HttpRequest, token: str, device: "SMSDevice"):
|
||||
def send_twilio(self, token: str, device: "SMSDevice"):
|
||||
"""send sms via twilio provider"""
|
||||
client = Client(self.account_sid, self.auth)
|
||||
message_body = str(self.get_message(token))
|
||||
if self.mapping:
|
||||
payload = sanitize_item(
|
||||
self.mapping.evaluate(
|
||||
user=device.user,
|
||||
request=request,
|
||||
device=device,
|
||||
token=token,
|
||||
stage=self,
|
||||
)
|
||||
)
|
||||
message_body = payload.get("message", message_body)
|
||||
|
||||
try:
|
||||
message = client.messages.create(
|
||||
to=device.phone_number, from_=self.from_number, body=message_body
|
||||
to=device.phone_number, from_=self.from_number, body=str(self.get_message(token))
|
||||
)
|
||||
LOGGER.debug("Sent SMS", to=device, message=message.sid)
|
||||
except TwilioRestException as exc:
|
||||
LOGGER.warning("Error sending token by Twilio SMS", exc=exc, msg=exc.msg)
|
||||
raise ValidationError(exc.msg) from None
|
||||
|
||||
def send_generic(self, request: HttpRequest, token: str, device: "SMSDevice"):
|
||||
def send_generic(self, token: str, device: "SMSDevice"):
|
||||
"""Send SMS via outside API"""
|
||||
payload = {
|
||||
"From": self.from_number,
|
||||
@@ -118,7 +107,7 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage):
|
||||
payload = sanitize_item(
|
||||
self.mapping.evaluate(
|
||||
user=device.user,
|
||||
request=request,
|
||||
request=None,
|
||||
device=device,
|
||||
token=token,
|
||||
stage=self,
|
||||
@@ -153,9 +142,10 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage):
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message="Error sending SMS",
|
||||
exc=exception_to_string(exc),
|
||||
status_code=response.status_code,
|
||||
body=response.text,
|
||||
).with_exception(exc).set_user(device.user).save()
|
||||
).set_user(device.user).save()
|
||||
if response.status_code >= HttpResponseBadRequest.status_code:
|
||||
raise ValidationError(response.text) from None
|
||||
raise
|
||||
|
||||
@@ -71,7 +71,7 @@ class AuthenticatorSMSStageView(ChallengeStageView):
|
||||
raise ValidationError(_("Invalid phone number"))
|
||||
# No code yet, but we have a phone number, so send a verification message
|
||||
device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE]
|
||||
stage.send(self.request, device.token, device)
|
||||
stage.send(device.token, device)
|
||||
|
||||
def _has_phone_number(self) -> str | None:
|
||||
context = self.executor.plan.context
|
||||
|
||||
@@ -9,7 +9,7 @@ from django.http.response import Http404
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django.utils.translation import gettext as __
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.fields import CharField, ChoiceField, DateTimeField
|
||||
from rest_framework.fields import CharField, DateTimeField
|
||||
from rest_framework.serializers import ValidationError
|
||||
from structlog.stdlib import get_logger
|
||||
from webauthn import options_to_json
|
||||
@@ -18,7 +18,7 @@ from webauthn.authentication.verify_authentication_response import verify_authen
|
||||
from webauthn.helpers import parse_authentication_credential_json
|
||||
from webauthn.helpers.base64url_to_bytes import base64url_to_bytes
|
||||
from webauthn.helpers.exceptions import InvalidAuthenticationResponse, InvalidJSONStructure
|
||||
from webauthn.helpers.structs import PublicKeyCredentialType, UserVerificationRequirement
|
||||
from webauthn.helpers.structs import UserVerificationRequirement
|
||||
|
||||
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
||||
from authentik.core.models import Application, User
|
||||
@@ -48,7 +48,7 @@ if TYPE_CHECKING:
|
||||
class DeviceChallenge(PassiveSerializer):
|
||||
"""Single device challenge"""
|
||||
|
||||
device_class = ChoiceField(choices=DeviceClasses.choices)
|
||||
device_class = CharField()
|
||||
device_uid = CharField()
|
||||
challenge = JSONDictField()
|
||||
last_used = DateTimeField(allow_null=True)
|
||||
@@ -124,7 +124,7 @@ def select_challenge(request: HttpRequest, device: Device):
|
||||
def select_challenge_sms(request: HttpRequest, device: SMSDevice):
|
||||
"""Send SMS"""
|
||||
device.generate_token()
|
||||
device.stage.send(request, device.token, device)
|
||||
device.stage.send(device.token, device)
|
||||
|
||||
|
||||
def select_challenge_email(request: HttpRequest, device: EmailDevice):
|
||||
@@ -157,12 +157,6 @@ def validate_challenge_webauthn(data: dict, stage_view: StageView, user: User) -
|
||||
request = stage_view.request
|
||||
challenge = stage_view.executor.plan.context.get(PLAN_CONTEXT_WEBAUTHN_CHALLENGE)
|
||||
stage: AuthenticatorValidateStage = stage_view.executor.current_stage
|
||||
if "MinuteMaid" in request.META.get("HTTP_USER_AGENT", ""):
|
||||
# Workaround for Android sign-in, when signing into Google Workspace on android while
|
||||
# adding the account to the system (not in Chrome), for some reason `type` is not set
|
||||
# so in that case we fall back to `public-key`
|
||||
# since that's the only option we support anyways
|
||||
data.setdefault("type", PublicKeyCredentialType.PUBLIC_KEY)
|
||||
try:
|
||||
credential = parse_authentication_credential_json(data)
|
||||
except InvalidJSONStructure as exc:
|
||||
|
||||
@@ -173,7 +173,6 @@ class AuthenticatorValidateStageDuoTests(FlowTestCase):
|
||||
{
|
||||
"auth_method": "auth_mfa",
|
||||
"auth_method_args": {
|
||||
"known_device": False,
|
||||
"mfa_devices": [
|
||||
{
|
||||
"app": "authentik_stages_authenticator_duo",
|
||||
@@ -181,7 +180,7 @@ class AuthenticatorValidateStageDuoTests(FlowTestCase):
|
||||
"name": "",
|
||||
"pk": duo_device.pk,
|
||||
}
|
||||
],
|
||||
]
|
||||
},
|
||||
"http_request": {
|
||||
"args": {},
|
||||
|
||||
@@ -153,13 +153,13 @@ class AuthenticatorValidateStageTests(FlowTestCase):
|
||||
plan.append_stage(stage)
|
||||
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
|
||||
{
|
||||
"device_class": DeviceClasses.STATIC,
|
||||
"device_class": "static",
|
||||
"device_uid": "1",
|
||||
"challenge": {},
|
||||
"last_used": now(),
|
||||
},
|
||||
{
|
||||
"device_class": DeviceClasses.TOTP,
|
||||
"device_class": "totp",
|
||||
"device_uid": "2",
|
||||
"challenge": {},
|
||||
"last_used": now(),
|
||||
@@ -172,7 +172,7 @@ class AuthenticatorValidateStageTests(FlowTestCase):
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
|
||||
data={
|
||||
"selected_challenge": {
|
||||
"device_class": DeviceClasses.WEBAUTHN,
|
||||
"device_class": "baz",
|
||||
"device_uid": "quox",
|
||||
"challenge": {},
|
||||
"last_used": None,
|
||||
|
||||
@@ -162,7 +162,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
|
||||
session = self.client.session
|
||||
plan = FlowPlan(flow_pk=flow.pk.hex)
|
||||
plan.append_stage(stage)
|
||||
plan.append_stage(UserLoginStage.objects.create(name=generate_id()))
|
||||
plan.append_stage(UserLoginStage(name=generate_id()))
|
||||
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
|
||||
session[SESSION_KEY_PLAN] = plan
|
||||
session.save()
|
||||
@@ -282,7 +282,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
|
||||
session = self.client.session
|
||||
plan = FlowPlan(flow_pk=flow.pk.hex)
|
||||
plan.append_stage(stage)
|
||||
plan.append_stage(UserLoginStage.objects.create(name=generate_id()))
|
||||
plan.append_stage(UserLoginStage(name=generate_id()))
|
||||
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
|
||||
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
|
||||
{
|
||||
@@ -359,7 +359,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
|
||||
session = self.client.session
|
||||
plan = FlowPlan(flow_pk=flow.pk.hex)
|
||||
plan.append_stage(stage)
|
||||
plan.append_stage(UserLoginStage.objects.create(name=generate_id()))
|
||||
plan.append_stage(UserLoginStage(name=generate_id()))
|
||||
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
|
||||
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
|
||||
{
|
||||
@@ -441,7 +441,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
|
||||
session = self.client.session
|
||||
plan = FlowPlan(flow_pk=flow.pk.hex)
|
||||
plan.append_stage(stage)
|
||||
plan.append_stage(UserLoginStage.objects.create(name=generate_id()))
|
||||
plan.append_stage(UserLoginStage(name=generate_id()))
|
||||
plan.context[PLAN_CONTEXT_DEVICE_CHALLENGES] = [
|
||||
{
|
||||
"device_class": device.__class__.__name__.lower().replace("device", ""),
|
||||
|
||||
File diff suppressed because one or more lines are too long
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user