mirror of
https://github.com/unraid/api.git
synced 2026-01-06 08:39:54 -06:00
Compare commits
24 Commits
4.15.1-bui
...
4.18.1-bui
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9b750d2a48 | ||
|
|
09b0051c73 | ||
|
|
bc15bd3d70 | ||
|
|
7c3aee8f3f | ||
|
|
c7c3bb57ea | ||
|
|
99dbad57d5 | ||
|
|
c42f79d406 | ||
|
|
4d8588b173 | ||
|
|
0d1d27064e | ||
|
|
0fe2c2c1c8 | ||
|
|
a8e4119270 | ||
|
|
372a4ebb42 | ||
|
|
4e945f5f56 | ||
|
|
6356f9c41d | ||
|
|
a1ee915ca5 | ||
|
|
c147a6b507 | ||
|
|
9d42b36f74 | ||
|
|
26a95af953 | ||
|
|
0ead267838 | ||
|
|
163763f9e5 | ||
|
|
6469d002b7 | ||
|
|
ab11e7ff7f | ||
|
|
7316dc753f | ||
|
|
1bf74e9d6c |
2
.github/workflows/build-plugin.yml
vendored
2
.github/workflows/build-plugin.yml
vendored
@@ -152,7 +152,7 @@ jobs:
|
||||
with:
|
||||
workflow: release-production.yml
|
||||
inputs: '{ "version": "${{ steps.vars.outputs.API_VERSION }}" }'
|
||||
token: ${{ secrets.WORKFLOW_TRIGGER_PAT }}
|
||||
token: ${{ secrets.UNRAID_BOT_GITHUB_ADMIN_TOKEN }}
|
||||
|
||||
- name: Upload to Cloudflare
|
||||
if: inputs.RELEASE_CREATED == 'false'
|
||||
|
||||
2
.github/workflows/deploy-storybook.yml
vendored
2
.github/workflows/deploy-storybook.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '22.18.0'
|
||||
node-version: '22.19.0'
|
||||
|
||||
- uses: pnpm/action-setup@v4
|
||||
name: Install pnpm
|
||||
|
||||
34
.github/workflows/main.yml
vendored
34
.github/workflows/main.yml
vendored
@@ -117,42 +117,62 @@ jobs:
|
||||
# Verify libvirt is running using sudo to bypass group membership delays
|
||||
sudo virsh list --all || true
|
||||
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
- name: Build UI Package First
|
||||
run: |
|
||||
echo "🔧 Building UI package for web tests dependency..."
|
||||
cd ../unraid-ui && pnpm run build
|
||||
|
||||
- name: Run Tests Concurrently
|
||||
run: |
|
||||
set -e
|
||||
|
||||
# Run all tests in parallel with labeled output
|
||||
# Run all tests in parallel with labeled output and coverage generation
|
||||
echo "🚀 Starting API coverage tests..."
|
||||
pnpm run coverage > api-test.log 2>&1 &
|
||||
API_PID=$!
|
||||
|
||||
echo "🚀 Starting Connect plugin tests..."
|
||||
(cd ../packages/unraid-api-plugin-connect && pnpm test) > connect-test.log 2>&1 &
|
||||
(cd ../packages/unraid-api-plugin-connect && pnpm test --coverage 2>/dev/null || pnpm test) > connect-test.log 2>&1 &
|
||||
CONNECT_PID=$!
|
||||
|
||||
echo "🚀 Starting Shared package tests..."
|
||||
(cd ../packages/unraid-shared && pnpm test) > shared-test.log 2>&1 &
|
||||
(cd ../packages/unraid-shared && pnpm test --coverage 2>/dev/null || pnpm test) > shared-test.log 2>&1 &
|
||||
SHARED_PID=$!
|
||||
|
||||
echo "🚀 Starting Web package coverage tests..."
|
||||
(cd ../web && (pnpm test --coverage || pnpm test)) > web-test.log 2>&1 &
|
||||
WEB_PID=$!
|
||||
|
||||
echo "🚀 Starting UI package coverage tests..."
|
||||
(cd ../unraid-ui && pnpm test --coverage 2>/dev/null || pnpm test) > ui-test.log 2>&1 &
|
||||
UI_PID=$!
|
||||
|
||||
# Wait for all processes and capture exit codes
|
||||
wait $API_PID && echo "✅ API tests completed" || { echo "❌ API tests failed"; API_EXIT=1; }
|
||||
wait $CONNECT_PID && echo "✅ Connect tests completed" || { echo "❌ Connect tests failed"; CONNECT_EXIT=1; }
|
||||
wait $SHARED_PID && echo "✅ Shared tests completed" || { echo "❌ Shared tests failed"; SHARED_EXIT=1; }
|
||||
wait $WEB_PID && echo "✅ Web tests completed" || { echo "❌ Web tests failed"; WEB_EXIT=1; }
|
||||
wait $UI_PID && echo "✅ UI tests completed" || { echo "❌ UI tests failed"; UI_EXIT=1; }
|
||||
|
||||
# Display all outputs
|
||||
echo "📋 API Test Results:" && cat api-test.log
|
||||
echo "📋 Connect Plugin Test Results:" && cat connect-test.log
|
||||
echo "📋 Shared Package Test Results:" && cat shared-test.log
|
||||
echo "📋 Web Package Test Results:" && cat web-test.log
|
||||
echo "📋 UI Package Test Results:" && cat ui-test.log
|
||||
|
||||
# Exit with error if any test failed
|
||||
if [[ ${API_EXIT:-0} -eq 1 || ${CONNECT_EXIT:-0} -eq 1 || ${SHARED_EXIT:-0} -eq 1 ]]; then
|
||||
if [[ ${API_EXIT:-0} -eq 1 || ${CONNECT_EXIT:-0} -eq 1 || ${SHARED_EXIT:-0} -eq 1 || ${WEB_EXIT:-0} -eq 1 || ${UI_EXIT:-0} -eq 1 ]]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Upload all coverage reports to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
files: ./coverage/coverage-final.json,../web/coverage/coverage-final.json,../unraid-ui/coverage/coverage-final.json,../packages/unraid-api-plugin-connect/coverage/coverage-final.json,../packages/unraid-shared/coverage/coverage-final.json
|
||||
fail_ci_if_error: false
|
||||
|
||||
build-api:
|
||||
name: Build API
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
2
.github/workflows/release-production.yml
vendored
2
.github/workflows/release-production.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
||||
prerelease: false
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '22.18.0'
|
||||
node-version: '22.19.0'
|
||||
- run: |
|
||||
cat << 'EOF' > release-notes.txt
|
||||
${{ steps.release-info.outputs.body }}
|
||||
|
||||
4
.github/workflows/test-libvirt.yml
vendored
4
.github/workflows/test-libvirt.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.13.6"
|
||||
python-version: "3.13.7"
|
||||
|
||||
- name: Cache APT Packages
|
||||
uses: awalsh128/cache-apt-pkgs-action@v1.5.3
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 10.14.0
|
||||
version: 10.15.1
|
||||
run_install: false
|
||||
|
||||
- name: Get pnpm store directory
|
||||
|
||||
@@ -1 +1 @@
|
||||
{".":"4.15.1"}
|
||||
{".":"4.18.1"}
|
||||
|
||||
@@ -233,8 +233,8 @@
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
padding: 0;
|
||||
color: var(--gray12);
|
||||
border: 1px solid var(--gray4);
|
||||
color: hsl(var(--foreground));
|
||||
border: 1px solid hsl(var(--border));
|
||||
transform: var(--toast-close-button-transform);
|
||||
border-radius: 50%;
|
||||
cursor: pointer;
|
||||
@@ -243,7 +243,7 @@
|
||||
}
|
||||
|
||||
[data-sonner-toast] [data-close-button] {
|
||||
background: var(--gray1);
|
||||
background: hsl(var(--background));
|
||||
}
|
||||
|
||||
:where([data-sonner-toast]) :where([data-close-button]):focus-visible {
|
||||
@@ -255,8 +255,8 @@
|
||||
}
|
||||
|
||||
[data-sonner-toast]:hover [data-close-button]:hover {
|
||||
background: var(--gray2);
|
||||
border-color: var(--gray5);
|
||||
background: hsl(var(--muted));
|
||||
border-color: hsl(var(--border));
|
||||
}
|
||||
|
||||
/* Leave a ghost div to avoid setting hover to false when swiping out */
|
||||
@@ -414,10 +414,27 @@
|
||||
}
|
||||
|
||||
[data-sonner-toaster][data-theme='light'] {
|
||||
--normal-bg: #fff;
|
||||
--normal-border: var(--gray4);
|
||||
--normal-text: var(--gray12);
|
||||
--normal-bg: hsl(var(--background));
|
||||
--normal-border: hsl(var(--border));
|
||||
--normal-text: hsl(var(--foreground));
|
||||
|
||||
--success-bg: hsl(var(--background));
|
||||
--success-border: hsl(var(--border));
|
||||
--success-text: hsl(140, 100%, 27%);
|
||||
|
||||
--info-bg: hsl(var(--background));
|
||||
--info-border: hsl(var(--border));
|
||||
--info-text: hsl(210, 92%, 45%);
|
||||
|
||||
--warning-bg: hsl(var(--background));
|
||||
--warning-border: hsl(var(--border));
|
||||
--warning-text: hsl(31, 92%, 45%);
|
||||
|
||||
--error-bg: hsl(var(--background));
|
||||
--error-border: hsl(var(--border));
|
||||
--error-text: hsl(360, 100%, 45%);
|
||||
|
||||
/* Old colors, preserved for reference
|
||||
--success-bg: hsl(143, 85%, 96%);
|
||||
--success-border: hsl(145, 92%, 91%);
|
||||
--success-text: hsl(140, 100%, 27%);
|
||||
@@ -432,26 +449,43 @@
|
||||
|
||||
--error-bg: hsl(359, 100%, 97%);
|
||||
--error-border: hsl(359, 100%, 94%);
|
||||
--error-text: hsl(360, 100%, 45%);
|
||||
--error-text: hsl(360, 100%, 45%); */
|
||||
}
|
||||
|
||||
[data-sonner-toaster][data-theme='light'] [data-sonner-toast][data-invert='true'] {
|
||||
--normal-bg: #000;
|
||||
--normal-border: hsl(0, 0%, 20%);
|
||||
--normal-text: var(--gray1);
|
||||
--normal-bg: hsl(0 0% 3.9%);
|
||||
--normal-border: hsl(0 0% 14.9%);
|
||||
--normal-text: hsl(0 0% 98%);
|
||||
}
|
||||
|
||||
[data-sonner-toaster][data-theme='dark'] [data-sonner-toast][data-invert='true'] {
|
||||
--normal-bg: #fff;
|
||||
--normal-border: var(--gray3);
|
||||
--normal-text: var(--gray12);
|
||||
--normal-bg: hsl(0 0% 100%);
|
||||
--normal-border: hsl(0 0% 89.8%);
|
||||
--normal-text: hsl(0 0% 3.9%);
|
||||
}
|
||||
|
||||
[data-sonner-toaster][data-theme='dark'] {
|
||||
--normal-bg: #000;
|
||||
--normal-border: hsl(0, 0%, 20%);
|
||||
--normal-text: var(--gray1);
|
||||
--normal-bg: hsl(var(--background));
|
||||
--normal-border: hsl(var(--border));
|
||||
--normal-text: hsl(var(--foreground));
|
||||
|
||||
--success-bg: hsl(var(--background));
|
||||
--success-border: hsl(var(--border));
|
||||
--success-text: hsl(150, 86%, 65%);
|
||||
|
||||
--info-bg: hsl(var(--background));
|
||||
--info-border: hsl(var(--border));
|
||||
--info-text: hsl(216, 87%, 65%);
|
||||
|
||||
--warning-bg: hsl(var(--background));
|
||||
--warning-border: hsl(var(--border));
|
||||
--warning-text: hsl(46, 87%, 65%);
|
||||
|
||||
--error-bg: hsl(var(--background));
|
||||
--error-border: hsl(var(--border));
|
||||
--error-text: hsl(358, 100%, 81%);
|
||||
|
||||
/* Old colors, preserved for reference
|
||||
--success-bg: hsl(150, 100%, 6%);
|
||||
--success-border: hsl(147, 100%, 12%);
|
||||
--success-text: hsl(150, 86%, 65%);
|
||||
@@ -466,7 +500,7 @@
|
||||
|
||||
--error-bg: hsl(358, 76%, 10%);
|
||||
--error-border: hsl(357, 89%, 16%);
|
||||
--error-text: hsl(358, 100%, 81%);
|
||||
--error-text: hsl(358, 100%, 81%); */
|
||||
}
|
||||
|
||||
[data-rich-colors='true'][data-sonner-toast][data-type='success'] {
|
||||
@@ -541,7 +575,7 @@
|
||||
|
||||
.sonner-loading-bar {
|
||||
animation: sonner-spin 1.2s linear infinite;
|
||||
background: var(--gray11);
|
||||
background: hsl(var(--muted-foreground));
|
||||
border-radius: 6px;
|
||||
height: 8%;
|
||||
left: -10%;
|
||||
|
||||
@@ -157,4 +157,7 @@ Enables GraphQL playground at `http://tower.local/graphql`
|
||||
|
||||
- We are using tailwind v4 we do not need a tailwind config anymore
|
||||
- always search the internet for tailwind v4 documentation when making tailwind related style changes
|
||||
- never run or restart the API server or web server. I will handle the lifecylce, simply wait and ask me to do this for you
|
||||
- never run or restart the API server or web server. I will handle the lifecycle, simply wait and ask me to do this for you
|
||||
- Never use the `any` type. Always prefer proper typing
|
||||
- Avoid using casting whenever possible, prefer proper typing from the start
|
||||
- **IMPORTANT:** cache-manager v7 expects TTL values in **milliseconds**, not seconds. Always use milliseconds when setting cache TTL (e.g., 600000 for 10 minutes, not 600)
|
||||
|
||||
@@ -1,5 +1,58 @@
|
||||
# Changelog
|
||||
|
||||
## [4.18.1](https://github.com/unraid/api/compare/v4.18.0...v4.18.1) (2025-09-03)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* OIDC and API Key management issues ([#1642](https://github.com/unraid/api/issues/1642)) ([0fe2c2c](https://github.com/unraid/api/commit/0fe2c2c1c85dcc547e4b1217a3b5636d7dd6d4b4))
|
||||
* rm redundant emission to `$HOME/.pm2/logs` ([#1640](https://github.com/unraid/api/issues/1640)) ([a8e4119](https://github.com/unraid/api/commit/a8e4119270868a1dabccd405853a7340f8dcd8a5))
|
||||
|
||||
## [4.18.0](https://github.com/unraid/api/compare/v4.17.0...v4.18.0) (2025-09-02)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* **api:** enhance OIDC redirect URI handling in service and tests ([#1618](https://github.com/unraid/api/issues/1618)) ([4e945f5](https://github.com/unraid/api/commit/4e945f5f56ce059eb275a9576caf3194a5df8a90))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* api key creation cli ([#1637](https://github.com/unraid/api/issues/1637)) ([c147a6b](https://github.com/unraid/api/commit/c147a6b5075969e77798210c4a5cfd1fa5b96ae3))
|
||||
* **cli:** support `--log-level` for `start` and `restart` cmds ([#1623](https://github.com/unraid/api/issues/1623)) ([a1ee915](https://github.com/unraid/api/commit/a1ee915ca52e5a063eccf8facbada911a63f37f6))
|
||||
* confusing server -> status query ([#1635](https://github.com/unraid/api/issues/1635)) ([9d42b36](https://github.com/unraid/api/commit/9d42b36f74274cad72490da5152fdb98fdc5b89b))
|
||||
* use unraid css variables in sonner ([#1634](https://github.com/unraid/api/issues/1634)) ([26a95af](https://github.com/unraid/api/commit/26a95af9539d05a837112d62dc6b7dd46761c83f))
|
||||
|
||||
## [4.17.0](https://github.com/unraid/api/compare/v4.16.0...v4.17.0) (2025-08-27)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* add tailwind class sort plugin ([#1562](https://github.com/unraid/api/issues/1562)) ([ab11e7f](https://github.com/unraid/api/commit/ab11e7ff7ff74da1f1cd5e49938459d00bfc846b))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* cleanup obsoleted legacy api keys on api startup (cli / connect) ([#1630](https://github.com/unraid/api/issues/1630)) ([6469d00](https://github.com/unraid/api/commit/6469d002b7b18e49c77ee650a4255974ab43e790))
|
||||
|
||||
## [4.16.0](https://github.com/unraid/api/compare/v4.15.1...v4.16.0) (2025-08-27)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* add `parityCheckStatus` field to `array` query ([#1611](https://github.com/unraid/api/issues/1611)) ([c508366](https://github.com/unraid/api/commit/c508366702b9fa20d9ed05559fe73da282116aa6))
|
||||
* generated UI API key management + OAuth-like API Key Flows ([#1609](https://github.com/unraid/api/issues/1609)) ([674323f](https://github.com/unraid/api/commit/674323fd87bbcc55932e6b28f6433a2de79b7ab0))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* **connect:** clear `wanport` upon disabling remote access ([#1624](https://github.com/unraid/api/issues/1624)) ([9df6a3f](https://github.com/unraid/api/commit/9df6a3f5ebb0319aa7e3fe3be6159d39ec6f587f))
|
||||
* **connect:** valid LAN FQDN while remote access is enabled ([#1625](https://github.com/unraid/api/issues/1625)) ([aa58888](https://github.com/unraid/api/commit/aa588883cc2e2fe4aa4aea1d035236c888638f5b))
|
||||
* correctly parse periods in share names from ini file ([#1629](https://github.com/unraid/api/issues/1629)) ([7d67a40](https://github.com/unraid/api/commit/7d67a404333a38d6e1ba5c3febf02be8b1b71901))
|
||||
* **rc.unraid-api:** remove profile sourcing ([#1622](https://github.com/unraid/api/issues/1622)) ([6947b5d](https://github.com/unraid/api/commit/6947b5d4aff70319116eb65cf4c639444f3749e9))
|
||||
* remove unused api key calls ([#1628](https://github.com/unraid/api/issues/1628)) ([9cd0d6a](https://github.com/unraid/api/commit/9cd0d6ac658475efa25683ef6e3f2e1d68f7e903))
|
||||
* retry VMs init for up to 2 min ([#1612](https://github.com/unraid/api/issues/1612)) ([b2e7801](https://github.com/unraid/api/commit/b2e78012384e6b3f2630341281fc811026be23b9))
|
||||
|
||||
## [4.15.1](https://github.com/unraid/api/compare/v4.15.0...v4.15.1) (2025-08-20)
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
###########################################################
|
||||
# Development/Build Image
|
||||
###########################################################
|
||||
FROM node:22.18.0-bookworm-slim AS development
|
||||
FROM node:22.19.0-bookworm-slim AS development
|
||||
|
||||
# Install build tools and dependencies
|
||||
RUN apt-get update -y && apt-get install -y \
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"version": "4.15.1",
|
||||
"version": "4.18.1",
|
||||
"extraOrigins": [],
|
||||
"sandbox": true,
|
||||
"ssoSubIds": [],
|
||||
|
||||
@@ -17,5 +17,6 @@
|
||||
],
|
||||
"buttonText": "Login With Unraid.net"
|
||||
}
|
||||
]
|
||||
],
|
||||
"defaultAllowedOrigins": []
|
||||
}
|
||||
@@ -21,7 +21,14 @@ unraid-api start [--log-level <level>]
|
||||
Starts the Unraid API service.
|
||||
|
||||
Options:
|
||||
- `--log-level`: Set logging level (trace|debug|info|warn|error)
|
||||
|
||||
- `--log-level`: Set logging level (trace|debug|info|warn|error|fatal)
|
||||
|
||||
Alternative: You can also set the log level using the `LOG_LEVEL` environment variable:
|
||||
|
||||
```bash
|
||||
LOG_LEVEL=trace unraid-api start
|
||||
```
|
||||
|
||||
### Stop
|
||||
|
||||
@@ -36,11 +43,21 @@ Stops the Unraid API service.
|
||||
### Restart
|
||||
|
||||
```bash
|
||||
unraid-api restart
|
||||
unraid-api restart [--log-level <level>]
|
||||
```
|
||||
|
||||
Restarts the Unraid API service.
|
||||
|
||||
Options:
|
||||
|
||||
- `--log-level`: Set logging level (trace|debug|info|warn|error|fatal)
|
||||
|
||||
Alternative: You can also set the log level using the `LOG_LEVEL` environment variable:
|
||||
|
||||
```bash
|
||||
LOG_LEVEL=trace unraid-api restart
|
||||
```
|
||||
|
||||
### Logs
|
||||
|
||||
```bash
|
||||
|
||||
252
api/docs/public/programmatic-api-key-management.md
Normal file
252
api/docs/public/programmatic-api-key-management.md
Normal file
@@ -0,0 +1,252 @@
|
||||
---
|
||||
title: Programmatic API Key Management
|
||||
description: Create, use, and delete API keys programmatically for automated workflows
|
||||
sidebar_position: 4
|
||||
---
|
||||
|
||||
# Programmatic API Key Management
|
||||
|
||||
This guide explains how to create, use, and delete API keys programmatically using the Unraid API CLI, enabling automated workflows and scripts.
|
||||
|
||||
## Overview
|
||||
|
||||
The `unraid-api apikey` command supports both interactive and non-interactive modes, making it suitable for:
|
||||
|
||||
- Automated deployment scripts
|
||||
- CI/CD pipelines
|
||||
- Temporary access provisioning
|
||||
- Infrastructure as code workflows
|
||||
|
||||
:::tip[Quick Start]
|
||||
Jump to the [Complete Workflow Example](#complete-workflow-example) to see everything in action.
|
||||
:::
|
||||
|
||||
## Creating API Keys Programmatically
|
||||
|
||||
### Basic Creation with JSON Output
|
||||
|
||||
Use the `--json` flag to get machine-readable output:
|
||||
|
||||
```bash
|
||||
unraid-api apikey --create --name "workflow key" --roles ADMIN --json
|
||||
```
|
||||
|
||||
**Output:**
|
||||
|
||||
```json
|
||||
{
|
||||
"key": "your-generated-api-key-here",
|
||||
"name": "workflow key",
|
||||
"id": "generated-uuid"
|
||||
}
|
||||
```
|
||||
|
||||
### Advanced Creation with Permissions
|
||||
|
||||
```bash
|
||||
unraid-api apikey --create \
|
||||
--name "limited access key" \
|
||||
--permissions "DOCKER:READ_ANY,ARRAY:READ_ANY" \
|
||||
--description "Read-only access for monitoring" \
|
||||
--json
|
||||
```
|
||||
|
||||
### Handling Existing Keys
|
||||
|
||||
If a key with the same name exists, use `--overwrite`:
|
||||
|
||||
```bash
|
||||
unraid-api apikey --create --name "existing key" --roles ADMIN --overwrite --json
|
||||
```
|
||||
|
||||
:::warning[Key Replacement]
|
||||
The `--overwrite` flag will permanently replace the existing key. The old key will be immediately invalidated.
|
||||
:::
|
||||
|
||||
## Deleting API Keys Programmatically
|
||||
|
||||
### Non-Interactive Deletion
|
||||
|
||||
Delete a key by name without prompts:
|
||||
|
||||
```bash
|
||||
unraid-api apikey --delete --name "workflow key"
|
||||
```
|
||||
|
||||
**Output:**
|
||||
|
||||
```
|
||||
Successfully deleted 1 API key
|
||||
```
|
||||
|
||||
### JSON Output for Deletion
|
||||
|
||||
Use `--json` flag for machine-readable delete confirmation:
|
||||
|
||||
```bash
|
||||
unraid-api apikey --delete --name "workflow key" --json
|
||||
```
|
||||
|
||||
**Success Output:**
|
||||
|
||||
```json
|
||||
{
|
||||
"deleted": 1,
|
||||
"keys": [
|
||||
{
|
||||
"id": "generated-uuid",
|
||||
"name": "workflow key"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Error Output:**
|
||||
|
||||
```json
|
||||
{
|
||||
"deleted": 0,
|
||||
"error": "No API key found with name: nonexistent key"
|
||||
}
|
||||
```
|
||||
|
||||
### Error Handling
|
||||
|
||||
When the specified key doesn't exist:
|
||||
|
||||
```bash
|
||||
unraid-api apikey --delete --name "nonexistent key"
|
||||
# Output: No API keys found to delete
|
||||
```
|
||||
|
||||
**JSON Error Output:**
|
||||
|
||||
```json
|
||||
{
|
||||
"deleted": 0,
|
||||
"message": "No API keys found to delete"
|
||||
}
|
||||
```
|
||||
|
||||
## Complete Workflow Example
|
||||
|
||||
Here's a complete example for temporary access provisioning:
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# 1. Create temporary API key
|
||||
echo "Creating temporary API key..."
|
||||
KEY_DATA=$(unraid-api apikey --create \
|
||||
--name "temp deployment key" \
|
||||
--roles ADMIN \
|
||||
--description "Temporary key for deployment $(date)" \
|
||||
--json)
|
||||
|
||||
# 2. Extract the API key
|
||||
API_KEY=$(echo "$KEY_DATA" | jq -r '.key')
|
||||
echo "API key created successfully"
|
||||
|
||||
# 3. Use the key for operations
|
||||
echo "Configuring services..."
|
||||
curl -H "Authorization: Bearer $API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"provider": "azure", "clientId": "your-client-id"}' \
|
||||
http://localhost:3001/graphql
|
||||
|
||||
# 4. Clean up (always runs, even on error)
|
||||
trap 'echo "Cleaning up..."; unraid-api apikey --delete --name "temp deployment key"' EXIT
|
||||
|
||||
echo "Deployment completed successfully"
|
||||
```
|
||||
|
||||
## Command Reference
|
||||
|
||||
### Create Command Options
|
||||
|
||||
| Flag | Description | Example |
|
||||
| ----------------------- | ----------------------- | --------------------------------- |
|
||||
| `--name <name>` | Key name (required) | `--name "my key"` |
|
||||
| `--roles <roles>` | Comma-separated roles | `--roles ADMIN,VIEWER` |
|
||||
| `--permissions <perms>` | Resource:action pairs | `--permissions "DOCKER:READ_ANY"` |
|
||||
| `--description <desc>` | Key description | `--description "CI/CD key"` |
|
||||
| `--overwrite` | Replace existing key | `--overwrite` |
|
||||
| `--json` | Machine-readable output | `--json` |
|
||||
|
||||
### Available Roles
|
||||
|
||||
- `ADMIN` - Full system access
|
||||
- `CONNECT` - Unraid Connect features
|
||||
- `VIEWER` - Read-only access
|
||||
- `GUEST` - Limited access
|
||||
|
||||
### Available Resources and Actions
|
||||
|
||||
**Resources:** `ACTIVATION_CODE`, `API_KEY`, `ARRAY`, `CLOUD`, `CONFIG`, `CONNECT`, `CONNECT__REMOTE_ACCESS`, `CUSTOMIZATIONS`, `DASHBOARD`, `DISK`, `DISPLAY`, `DOCKER`, `FLASH`, `INFO`, `LOGS`, `ME`, `NETWORK`, `NOTIFICATIONS`, `ONLINE`, `OS`, `OWNER`, `PERMISSION`, `REGISTRATION`, `SERVERS`, `SERVICES`, `SHARE`, `VARS`, `VMS`, `WELCOME`
|
||||
|
||||
**Actions:** `CREATE_ANY`, `CREATE_OWN`, `READ_ANY`, `READ_OWN`, `UPDATE_ANY`, `UPDATE_OWN`, `DELETE_ANY`, `DELETE_OWN`
|
||||
|
||||
### Delete Command Options
|
||||
|
||||
| Flag | Description | Example |
|
||||
| --------------- | ------------------------ | ----------------- |
|
||||
| `--delete` | Enable delete mode | `--delete` |
|
||||
| `--name <name>` | Key to delete (optional) | `--name "my key"` |
|
||||
|
||||
**Note:** If `--name` is omitted, the command runs interactively.
|
||||
|
||||
## Best Practices
|
||||
|
||||
:::info[Security Best Practices]
|
||||
**Minimal Permissions**
|
||||
|
||||
- Use specific permissions instead of ADMIN role when possible
|
||||
- Example: `--permissions "DOCKER:READ_ANY"` instead of `--roles ADMIN`
|
||||
|
||||
**Key Lifecycle Management**
|
||||
|
||||
- Always clean up temporary keys after use
|
||||
- Store API keys securely (environment variables, secrets management)
|
||||
- Use descriptive names and descriptions for audit trails
|
||||
:::
|
||||
|
||||
### Error Handling
|
||||
|
||||
- Check exit codes (`$?`) after each command
|
||||
- Use `set -e` in bash scripts to fail fast
|
||||
- Implement proper cleanup with `trap`
|
||||
|
||||
### Key Naming
|
||||
|
||||
- Use descriptive names that include purpose and date
|
||||
- Names must contain only letters, numbers, and spaces
|
||||
- Unicode letters are supported
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
:::note[Common Error Messages]
|
||||
|
||||
**"API key name must contain only letters, numbers, and spaces"**
|
||||
|
||||
- **Solution:** Remove special characters like hyphens, underscores, or symbols
|
||||
|
||||
**"API key with name 'x' already exists"**
|
||||
|
||||
- **Solution:** Use `--overwrite` flag or choose a different name
|
||||
|
||||
**"Please add at least one role or permission to the key"**
|
||||
|
||||
- **Solution:** Specify either `--roles` or `--permissions` (or both)
|
||||
|
||||
:::
|
||||
|
||||
### Debug Mode
|
||||
|
||||
For troubleshooting, run with debug logging:
|
||||
|
||||
```bash
|
||||
LOG_LEVEL=debug unraid-api apikey --create --name "debug key" --roles ADMIN
|
||||
```
|
||||
@@ -13,7 +13,9 @@
|
||||
"watch": false,
|
||||
"interpreter": "/usr/local/bin/node",
|
||||
"ignore_watch": ["node_modules", "src", ".env.*", "myservers.cfg"],
|
||||
"log_file": "/var/log/graphql-api.log",
|
||||
"out_file": "/var/log/graphql-api.log",
|
||||
"error_file": "/var/log/graphql-api.log",
|
||||
"merge_logs": true,
|
||||
"kill_timeout": 10000
|
||||
}
|
||||
]
|
||||
|
||||
@@ -1798,6 +1798,8 @@ type Server implements Node {
|
||||
guid: String!
|
||||
apikey: String!
|
||||
name: String!
|
||||
|
||||
"""Whether this server is online or offline"""
|
||||
status: ServerStatus!
|
||||
wanip: String!
|
||||
lanip: String!
|
||||
@@ -1854,7 +1856,7 @@ type OidcProvider {
|
||||
"""
|
||||
OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration
|
||||
"""
|
||||
issuer: String!
|
||||
issuer: String
|
||||
|
||||
"""
|
||||
OAuth2 authorization endpoint URL. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration
|
||||
@@ -1907,6 +1909,16 @@ enum AuthorizationRuleMode {
|
||||
AND
|
||||
}
|
||||
|
||||
type OidcConfiguration {
|
||||
"""List of configured OIDC providers"""
|
||||
providers: [OidcProvider!]!
|
||||
|
||||
"""
|
||||
Default allowed redirect origins that apply to all OIDC providers (e.g., Tailscale domains)
|
||||
"""
|
||||
defaultAllowedOrigins: [String!]
|
||||
}
|
||||
|
||||
type OidcSessionValidation {
|
||||
valid: Boolean!
|
||||
username: String
|
||||
@@ -2307,8 +2319,6 @@ type Query {
|
||||
getApiKeyCreationFormSchema: ApiKeyFormSettings!
|
||||
config: Config!
|
||||
flash: Flash!
|
||||
logFiles: [LogFile!]!
|
||||
logFile(path: String!, lines: Int, startLine: Int): LogFileContent!
|
||||
me: UserAccount!
|
||||
|
||||
"""Get all notifications"""
|
||||
@@ -2335,6 +2345,8 @@ type Query {
|
||||
disk(id: PrefixedID!): Disk!
|
||||
rclone: RCloneBackupSettings!
|
||||
info: Info!
|
||||
logFiles: [LogFile!]!
|
||||
logFile(path: String!, lines: Int, startLine: Int): LogFileContent!
|
||||
settings: Settings!
|
||||
isSSOEnabled: Boolean!
|
||||
|
||||
@@ -2347,6 +2359,9 @@ type Query {
|
||||
"""Get a specific OIDC provider by ID"""
|
||||
oidcProvider(id: PrefixedID!): OidcProvider
|
||||
|
||||
"""Get the full OIDC configuration (admin only)"""
|
||||
oidcConfiguration: OidcConfiguration!
|
||||
|
||||
"""Validate an OIDC session token (internal use for CLI validation)"""
|
||||
validateOidcSession(token: String!): OidcSessionValidation!
|
||||
metrics: Metrics!
|
||||
@@ -2590,13 +2605,13 @@ input AccessUrlInput {
|
||||
}
|
||||
|
||||
type Subscription {
|
||||
logFile(path: String!): LogFileContent!
|
||||
notificationAdded: Notification!
|
||||
notificationsOverview: NotificationOverview!
|
||||
ownerSubscription: Owner!
|
||||
serversSubscription: Server!
|
||||
parityHistorySubscription: ParityCheck!
|
||||
arraySubscription: UnraidArray!
|
||||
logFile(path: String!): LogFileContent!
|
||||
systemMetricsCpu: CpuUtilization!
|
||||
systemMetricsMemory: MemoryUtilization!
|
||||
upsUpdates: UPSDevice!
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@unraid/api",
|
||||
"version": "4.15.1",
|
||||
"version": "4.18.1",
|
||||
"main": "src/cli/index.ts",
|
||||
"type": "module",
|
||||
"corepack": {
|
||||
@@ -10,7 +10,7 @@
|
||||
"author": "Lime Technology, Inc. <unraid.net>",
|
||||
"license": "GPL-2.0-or-later",
|
||||
"engines": {
|
||||
"pnpm": "10.14.0"
|
||||
"pnpm": "10.15.1"
|
||||
},
|
||||
"scripts": {
|
||||
"// Development": "",
|
||||
@@ -51,7 +51,7 @@
|
||||
"unraid-api": "dist/cli.js"
|
||||
},
|
||||
"dependencies": {
|
||||
"@apollo/client": "3.13.9",
|
||||
"@apollo/client": "3.14.0",
|
||||
"@apollo/server": "4.12.2",
|
||||
"@as-integrations/fastify": "2.1.1",
|
||||
"@fastify/cookie": "11.0.2",
|
||||
@@ -73,7 +73,7 @@
|
||||
"@nestjs/platform-fastify": "11.1.6",
|
||||
"@nestjs/schedule": "6.0.0",
|
||||
"@nestjs/throttler": "6.4.0",
|
||||
"@reduxjs/toolkit": "2.8.2",
|
||||
"@reduxjs/toolkit": "2.9.0",
|
||||
"@runonflux/nat-upnp": "1.0.2",
|
||||
"@types/diff": "8.0.0",
|
||||
"@unraid/libvirt": "2.1.0",
|
||||
@@ -82,7 +82,7 @@
|
||||
"atomically": "2.0.3",
|
||||
"bycontract": "2.0.11",
|
||||
"bytes": "3.1.2",
|
||||
"cache-manager": "7.1.1",
|
||||
"cache-manager": "7.2.0",
|
||||
"cacheable-lookup": "7.0.0",
|
||||
"camelcase-keys": "9.1.3",
|
||||
"casbin": "5.38.0",
|
||||
@@ -98,7 +98,8 @@
|
||||
"cross-fetch": "4.1.0",
|
||||
"diff": "8.0.2",
|
||||
"dockerode": "4.0.7",
|
||||
"dotenv": "17.2.1",
|
||||
"dotenv": "17.2.2",
|
||||
"escape-html": "1.0.3",
|
||||
"execa": "9.6.0",
|
||||
"exit-hook": "4.0.0",
|
||||
"fastify": "5.5.0",
|
||||
@@ -106,7 +107,7 @@
|
||||
"fs-extra": "11.3.1",
|
||||
"glob": "11.0.3",
|
||||
"global-agent": "3.0.0",
|
||||
"got": "14.4.7",
|
||||
"got": "14.4.8",
|
||||
"graphql": "16.11.0",
|
||||
"graphql-fields": "2.0.3",
|
||||
"graphql-scalars": "1.24.2",
|
||||
@@ -115,31 +116,31 @@
|
||||
"graphql-ws": "6.0.6",
|
||||
"ini": "5.0.0",
|
||||
"ip": "2.0.1",
|
||||
"jose": "6.0.12",
|
||||
"jose": "6.1.0",
|
||||
"json-bigint-patch": "0.0.8",
|
||||
"lodash-es": "4.17.21",
|
||||
"multi-ini": "2.3.2",
|
||||
"mustache": "4.2.0",
|
||||
"nest-authz": "2.17.0",
|
||||
"nest-commander": "3.18.0",
|
||||
"nest-commander": "3.19.0",
|
||||
"nestjs-pino": "4.4.0",
|
||||
"node-cache": "5.1.2",
|
||||
"node-window-polyfill": "1.0.4",
|
||||
"openid-client": "6.6.2",
|
||||
"openid-client": "6.7.1",
|
||||
"p-retry": "6.2.1",
|
||||
"passport-custom": "1.1.1",
|
||||
"passport-http-header-strategy": "1.1.0",
|
||||
"path-type": "6.0.0",
|
||||
"pino": "9.8.0",
|
||||
"pino": "9.9.1",
|
||||
"pino-http": "10.5.0",
|
||||
"pino-pretty": "13.1.1",
|
||||
"pm2": "6.0.8",
|
||||
"pm2": "6.0.10",
|
||||
"reflect-metadata": "^0.1.14",
|
||||
"rxjs": "7.8.2",
|
||||
"semver": "7.7.2",
|
||||
"strftime": "0.10.3",
|
||||
"systeminformation": "5.27.7",
|
||||
"undici": "7.13.0",
|
||||
"systeminformation": "5.27.8",
|
||||
"undici": "7.15.0",
|
||||
"uuid": "11.1.0",
|
||||
"ws": "8.18.3",
|
||||
"zen-observable-ts": "1.1.0",
|
||||
@@ -154,7 +155,7 @@
|
||||
}
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "9.33.0",
|
||||
"@eslint/js": "9.34.0",
|
||||
"@graphql-codegen/add": "5.0.3",
|
||||
"@graphql-codegen/cli": "5.0.7",
|
||||
"@graphql-codegen/fragment-matcher": "5.1.0",
|
||||
@@ -162,19 +163,19 @@
|
||||
"@graphql-codegen/typed-document-node": "5.1.2",
|
||||
"@graphql-codegen/typescript": "4.1.6",
|
||||
"@graphql-codegen/typescript-operations": "4.6.1",
|
||||
"@graphql-codegen/typescript-resolvers": "4.5.1",
|
||||
"@graphql-codegen/typescript-resolvers": "4.5.2",
|
||||
"@graphql-typed-document-node/core": "3.2.0",
|
||||
"@ianvs/prettier-plugin-sort-imports": "4.6.1",
|
||||
"@ianvs/prettier-plugin-sort-imports": "4.7.0",
|
||||
"@nestjs/testing": "11.1.6",
|
||||
"@originjs/vite-plugin-commonjs": "1.0.3",
|
||||
"@rollup/plugin-node-resolve": "16.0.1",
|
||||
"@swc/core": "1.13.3",
|
||||
"@swc/core": "1.13.5",
|
||||
"@types/async-exit-hook": "2.0.2",
|
||||
"@types/bytes": "3.1.5",
|
||||
"@types/cli-table": "0.3.4",
|
||||
"@types/command-exists": "1.2.3",
|
||||
"@types/cors": "2.8.19",
|
||||
"@types/dockerode": "3.3.42",
|
||||
"@types/dockerode": "3.3.43",
|
||||
"@types/graphql-fields": "1.3.9",
|
||||
"@types/graphql-type-uuid": "0.2.6",
|
||||
"@types/ini": "4.1.1",
|
||||
@@ -182,37 +183,37 @@
|
||||
"@types/lodash": "4.17.20",
|
||||
"@types/lodash-es": "4.17.12",
|
||||
"@types/mustache": "4.2.6",
|
||||
"@types/node": "22.17.1",
|
||||
"@types/node": "22.18.0",
|
||||
"@types/pify": "6.1.0",
|
||||
"@types/semver": "7.7.0",
|
||||
"@types/semver": "7.7.1",
|
||||
"@types/sendmail": "1.4.7",
|
||||
"@types/stoppable": "1.1.3",
|
||||
"@types/strftime": "0.9.8",
|
||||
"@types/supertest": "6.0.3",
|
||||
"@types/uuid": "10.0.0",
|
||||
"@types/ws": "8.18.1",
|
||||
"@types/wtfnode": "0.7.3",
|
||||
"@types/wtfnode": "0.10.0",
|
||||
"@vitest/coverage-v8": "3.2.4",
|
||||
"@vitest/ui": "3.2.4",
|
||||
"eslint": "9.33.0",
|
||||
"eslint": "9.34.0",
|
||||
"eslint-plugin-import": "2.32.0",
|
||||
"eslint-plugin-no-relative-import-paths": "1.6.1",
|
||||
"eslint-plugin-prettier": "5.5.4",
|
||||
"jiti": "2.5.1",
|
||||
"nodemon": "3.1.10",
|
||||
"prettier": "3.6.2",
|
||||
"rollup-plugin-node-externals": "8.0.1",
|
||||
"rollup-plugin-node-externals": "8.1.1",
|
||||
"supertest": "7.1.4",
|
||||
"tsx": "4.20.3",
|
||||
"tsx": "4.20.5",
|
||||
"type-fest": "4.41.0",
|
||||
"typescript": "5.9.2",
|
||||
"typescript-eslint": "8.39.1",
|
||||
"unplugin-swc": "1.5.5",
|
||||
"vite": "7.1.1",
|
||||
"typescript-eslint": "8.42.0",
|
||||
"unplugin-swc": "1.5.7",
|
||||
"vite": "7.1.4",
|
||||
"vite-plugin-node": "7.0.0",
|
||||
"vite-tsconfig-paths": "5.1.4",
|
||||
"vitest": "3.2.4",
|
||||
"zx": "8.8.0"
|
||||
"zx": "8.8.1"
|
||||
},
|
||||
"overrides": {
|
||||
"eslint": {
|
||||
@@ -227,5 +228,5 @@
|
||||
}
|
||||
},
|
||||
"private": true,
|
||||
"packageManager": "pnpm@10.14.0"
|
||||
"packageManager": "pnpm@10.15.1"
|
||||
}
|
||||
|
||||
@@ -29,8 +29,24 @@ const stream = SUPPRESS_LOGS
|
||||
singleLine: true,
|
||||
hideObject: false,
|
||||
colorize: true,
|
||||
colorizeObjects: true,
|
||||
levelFirst: false,
|
||||
ignore: 'hostname,pid',
|
||||
destination: logDestination,
|
||||
translateTime: 'HH:mm:ss',
|
||||
customPrettifiers: {
|
||||
time: (timestamp: string | object) => `[${timestamp}`,
|
||||
level: (logLevel: string | object, key: string, log: any, extras: any) => {
|
||||
// Use labelColorized which preserves the colors
|
||||
const { labelColorized } = extras;
|
||||
const context = log.context || log.logger || 'app';
|
||||
return `${labelColorized} ${context}]`;
|
||||
},
|
||||
},
|
||||
messageFormat: (log: any, messageKey: string) => {
|
||||
const msg = log[messageKey] || log.msg || '';
|
||||
return msg;
|
||||
},
|
||||
})
|
||||
: logDestination;
|
||||
|
||||
|
||||
@@ -13,10 +13,11 @@ export const pubsub = new PubSub({ eventEmitter });
|
||||
|
||||
/**
|
||||
* Create a pubsub subscription.
|
||||
* @param channel The pubsub channel to subscribe to.
|
||||
* @param channel The pubsub channel to subscribe to. Can be either a predefined GRAPHQL_PUBSUB_CHANNEL
|
||||
* or a dynamic string for runtime-generated topics (e.g., log file paths like "LOG_FILE:/var/log/test.log")
|
||||
*/
|
||||
export const createSubscription = <T = any>(
|
||||
channel: GRAPHQL_PUBSUB_CHANNEL
|
||||
channel: GRAPHQL_PUBSUB_CHANNEL | string
|
||||
): AsyncIterableIterator<T> => {
|
||||
return pubsub.asyncIterableIterator<T>(channel);
|
||||
};
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { CacheModule } from '@nestjs/cache-manager';
|
||||
import { Test } from '@nestjs/testing';
|
||||
|
||||
import { describe, expect, it } from 'vitest';
|
||||
@@ -9,7 +10,7 @@ describe('Module Dependencies Integration', () => {
|
||||
let module;
|
||||
try {
|
||||
module = await Test.createTestingModule({
|
||||
imports: [RestModule],
|
||||
imports: [CacheModule.register({ isGlobal: true }), RestModule],
|
||||
}).compile();
|
||||
|
||||
expect(module).toBeDefined();
|
||||
|
||||
@@ -34,6 +34,15 @@ import { UnraidFileModifierModule } from '@app/unraid-api/unraid-file-modifier/u
|
||||
req: () => undefined,
|
||||
res: () => undefined,
|
||||
},
|
||||
formatters: {
|
||||
log: (obj) => {
|
||||
// Map NestJS context to Pino context field for pino-pretty
|
||||
if (obj.context && !obj.logger) {
|
||||
return { ...obj, logger: obj.context };
|
||||
}
|
||||
return obj;
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
AuthModule,
|
||||
|
||||
@@ -681,4 +681,104 @@ describe('ApiKeyService', () => {
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('convertRolesStringArrayToRoles', () => {
|
||||
beforeEach(async () => {
|
||||
vi.mocked(getters.paths).mockReturnValue({
|
||||
'auth-keys': mockBasePath,
|
||||
} as ReturnType<typeof getters.paths>);
|
||||
|
||||
// Create a fresh mock logger for each test
|
||||
mockLogger = {
|
||||
log: vi.fn(),
|
||||
error: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
verbose: vi.fn(),
|
||||
};
|
||||
|
||||
apiKeyService = new ApiKeyService();
|
||||
// Replace the logger with our mock
|
||||
(apiKeyService as any).logger = mockLogger;
|
||||
});
|
||||
|
||||
it('should convert uppercase role strings to Role enum values', () => {
|
||||
const roles = ['ADMIN', 'CONNECT', 'VIEWER'];
|
||||
const result = apiKeyService.convertRolesStringArrayToRoles(roles);
|
||||
|
||||
expect(result).toEqual([Role.ADMIN, Role.CONNECT, Role.VIEWER]);
|
||||
});
|
||||
|
||||
it('should convert lowercase role strings to Role enum values', () => {
|
||||
const roles = ['admin', 'connect', 'guest'];
|
||||
const result = apiKeyService.convertRolesStringArrayToRoles(roles);
|
||||
|
||||
expect(result).toEqual([Role.ADMIN, Role.CONNECT, Role.GUEST]);
|
||||
});
|
||||
|
||||
it('should convert mixed case role strings to Role enum values', () => {
|
||||
const roles = ['Admin', 'CoNnEcT', 'ViEwEr'];
|
||||
const result = apiKeyService.convertRolesStringArrayToRoles(roles);
|
||||
|
||||
expect(result).toEqual([Role.ADMIN, Role.CONNECT, Role.VIEWER]);
|
||||
});
|
||||
|
||||
it('should handle roles with whitespace', () => {
|
||||
const roles = [' ADMIN ', ' CONNECT ', 'VIEWER '];
|
||||
const result = apiKeyService.convertRolesStringArrayToRoles(roles);
|
||||
|
||||
expect(result).toEqual([Role.ADMIN, Role.CONNECT, Role.VIEWER]);
|
||||
});
|
||||
|
||||
it('should filter out invalid roles and warn', () => {
|
||||
const roles = ['ADMIN', 'INVALID_ROLE', 'VIEWER', 'ANOTHER_INVALID'];
|
||||
const result = apiKeyService.convertRolesStringArrayToRoles(roles);
|
||||
|
||||
expect(result).toEqual([Role.ADMIN, Role.VIEWER]);
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
'Ignoring invalid roles: INVALID_ROLE, ANOTHER_INVALID'
|
||||
);
|
||||
});
|
||||
|
||||
it('should return empty array when all roles are invalid', () => {
|
||||
const roles = ['INVALID1', 'INVALID2', 'INVALID3'];
|
||||
const result = apiKeyService.convertRolesStringArrayToRoles(roles);
|
||||
|
||||
expect(result).toEqual([]);
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
'Ignoring invalid roles: INVALID1, INVALID2, INVALID3'
|
||||
);
|
||||
});
|
||||
|
||||
it('should return empty array for empty input', () => {
|
||||
const result = apiKeyService.convertRolesStringArrayToRoles([]);
|
||||
|
||||
expect(result).toEqual([]);
|
||||
expect(mockLogger.warn).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle all valid Role enum values', () => {
|
||||
const roles = Object.values(Role);
|
||||
const result = apiKeyService.convertRolesStringArrayToRoles(roles);
|
||||
|
||||
expect(result).toEqual(Object.values(Role));
|
||||
expect(mockLogger.warn).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should deduplicate roles', () => {
|
||||
const roles = ['ADMIN', 'admin', 'ADMIN', 'VIEWER', 'viewer'];
|
||||
const result = apiKeyService.convertRolesStringArrayToRoles(roles);
|
||||
|
||||
// Note: Current implementation doesn't deduplicate, but this test documents the behavior
|
||||
expect(result).toEqual([Role.ADMIN, Role.ADMIN, Role.ADMIN, Role.VIEWER, Role.VIEWER]);
|
||||
});
|
||||
|
||||
it('should handle mixed valid and invalid roles correctly', () => {
|
||||
const roles = ['ADMIN', 'invalid', 'CONNECT', 'bad_role', 'GUEST', 'VIEWER'];
|
||||
const result = apiKeyService.convertRolesStringArrayToRoles(roles);
|
||||
|
||||
expect(result).toEqual([Role.ADMIN, Role.CONNECT, Role.GUEST, Role.VIEWER]);
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith('Ignoring invalid roles: invalid, bad_role');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -35,11 +35,29 @@ export class ApiKeyService implements OnModuleInit {
|
||||
|
||||
async onModuleInit() {
|
||||
this.memoryApiKeys = await this.loadAllFromDisk();
|
||||
await this.cleanupLegacyInternalKeys();
|
||||
if (environment.IS_MAIN_PROCESS) {
|
||||
this.setupWatch();
|
||||
}
|
||||
}
|
||||
|
||||
private async cleanupLegacyInternalKeys() {
|
||||
const legacyNames = ['CliInternal', 'ConnectInternal'];
|
||||
const keysToDelete = this.memoryApiKeys.filter((key) => legacyNames.includes(key.name));
|
||||
|
||||
if (keysToDelete.length > 0) {
|
||||
try {
|
||||
await this.deleteApiKeys(keysToDelete.map((key) => key.id));
|
||||
this.logger.log(`Cleaned up ${keysToDelete.length} legacy internal keys`);
|
||||
} catch (error) {
|
||||
this.logger.debug(
|
||||
error,
|
||||
`Failed to delete legacy internal keys: ${keysToDelete.map((key) => key.name).join(', ')}`
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async findAll(): Promise<ApiKey[]> {
|
||||
return this.memoryApiKeys;
|
||||
}
|
||||
@@ -92,9 +110,25 @@ export class ApiKeyService implements OnModuleInit {
|
||||
}
|
||||
|
||||
public convertRolesStringArrayToRoles(roles: string[]): Role[] {
|
||||
return roles
|
||||
.map((roleStr) => Role[roleStr.trim().toUpperCase() as keyof typeof Role])
|
||||
.filter(Boolean);
|
||||
const validRoles: Role[] = [];
|
||||
const invalidRoles: string[] = [];
|
||||
|
||||
for (const roleStr of roles) {
|
||||
const upperRole = roleStr.trim().toUpperCase();
|
||||
const role = Role[upperRole as keyof typeof Role];
|
||||
|
||||
if (role && ApiKeyService.validRoles.has(role)) {
|
||||
validRoles.push(role);
|
||||
} else {
|
||||
invalidRoles.push(roleStr);
|
||||
}
|
||||
}
|
||||
|
||||
if (invalidRoles.length > 0) {
|
||||
this.logger.warn(`Ignoring invalid roles: ${invalidRoles.join(', ')}`);
|
||||
}
|
||||
|
||||
return validRoles;
|
||||
}
|
||||
|
||||
async create({
|
||||
|
||||
192
api/src/unraid-api/cli/__test__/api-key.command.test.ts
Normal file
192
api/src/unraid-api/cli/__test__/api-key.command.test.ts
Normal file
@@ -0,0 +1,192 @@
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { InquirerService } from 'nest-commander';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { ApiKeyService } from '@app/unraid-api/auth/api-key.service.js';
|
||||
import { AddApiKeyQuestionSet } from '@app/unraid-api/cli/apikey/add-api-key.questions.js';
|
||||
import { ApiKeyCommand } from '@app/unraid-api/cli/apikey/api-key.command.js';
|
||||
import { LogService } from '@app/unraid-api/cli/log.service.js';
|
||||
|
||||
describe('ApiKeyCommand', () => {
|
||||
let command: ApiKeyCommand;
|
||||
let apiKeyService: ApiKeyService;
|
||||
let logService: LogService;
|
||||
let inquirerService: InquirerService;
|
||||
let questionSet: AddApiKeyQuestionSet;
|
||||
|
||||
beforeEach(async () => {
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
ApiKeyCommand,
|
||||
AddApiKeyQuestionSet,
|
||||
{
|
||||
provide: ApiKeyService,
|
||||
useValue: {
|
||||
findByField: vi.fn(),
|
||||
create: vi.fn(),
|
||||
findAll: vi.fn(),
|
||||
deleteApiKeys: vi.fn(),
|
||||
convertRolesStringArrayToRoles: vi.fn((roles) => roles),
|
||||
convertPermissionsStringArrayToPermissions: vi.fn((perms) => perms),
|
||||
getAllValidPermissions: vi.fn(() => []),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: LogService,
|
||||
useValue: {
|
||||
log: vi.fn(),
|
||||
error: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: InquirerService,
|
||||
useValue: {
|
||||
prompt: vi.fn(),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
command = module.get<ApiKeyCommand>(ApiKeyCommand);
|
||||
apiKeyService = module.get<ApiKeyService>(ApiKeyService);
|
||||
logService = module.get<LogService>(LogService);
|
||||
inquirerService = module.get<InquirerService>(InquirerService);
|
||||
questionSet = module.get<AddApiKeyQuestionSet>(AddApiKeyQuestionSet);
|
||||
});
|
||||
|
||||
describe('AddApiKeyQuestionSet', () => {
|
||||
describe('shouldAskOverwrite', () => {
|
||||
it('should return true when an API key with the given name exists', () => {
|
||||
vi.mocked(apiKeyService.findByField).mockReturnValue({
|
||||
key: 'existing-key',
|
||||
name: 'test-key',
|
||||
description: 'Test key',
|
||||
roles: [],
|
||||
permissions: [],
|
||||
} as any);
|
||||
|
||||
const result = questionSet.shouldAskOverwrite({ name: 'test-key' });
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(apiKeyService.findByField).toHaveBeenCalledWith('name', 'test-key');
|
||||
});
|
||||
|
||||
it('should return false when no API key with the given name exists', () => {
|
||||
vi.mocked(apiKeyService.findByField).mockReturnValue(null);
|
||||
|
||||
const result = questionSet.shouldAskOverwrite({ name: 'non-existent-key' });
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(apiKeyService.findByField).toHaveBeenCalledWith('name', 'non-existent-key');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('run', () => {
|
||||
it('should find and return existing key when not creating', async () => {
|
||||
const mockKey = { key: 'test-api-key-123', name: 'test-key' };
|
||||
vi.mocked(apiKeyService.findByField).mockReturnValue(mockKey as any);
|
||||
|
||||
await command.run([], { name: 'test-key', create: false });
|
||||
|
||||
expect(apiKeyService.findByField).toHaveBeenCalledWith('name', 'test-key');
|
||||
expect(logService.log).toHaveBeenCalledWith('test-api-key-123');
|
||||
});
|
||||
|
||||
it('should create new key when key does not exist and create flag is set', async () => {
|
||||
vi.mocked(apiKeyService.findByField).mockReturnValue(null);
|
||||
vi.mocked(apiKeyService.create).mockResolvedValue({ key: 'new-api-key-456' } as any);
|
||||
|
||||
await command.run([], {
|
||||
name: 'new-key',
|
||||
create: true,
|
||||
roles: ['ADMIN'] as any,
|
||||
description: 'Test description',
|
||||
});
|
||||
|
||||
expect(apiKeyService.create).toHaveBeenCalledWith({
|
||||
name: 'new-key',
|
||||
description: 'Test description',
|
||||
roles: ['ADMIN'],
|
||||
permissions: undefined,
|
||||
overwrite: false,
|
||||
});
|
||||
expect(logService.log).toHaveBeenCalledWith('new-api-key-456');
|
||||
});
|
||||
|
||||
it('should error when key exists and overwrite is not set in non-interactive mode', async () => {
|
||||
const mockKey = { key: 'existing-key', name: 'test-key' };
|
||||
vi.mocked(apiKeyService.findByField)
|
||||
.mockReturnValueOnce(null) // First call in line 131
|
||||
.mockReturnValueOnce(mockKey as any); // Second call in non-interactive check
|
||||
const exitSpy = vi.spyOn(process, 'exit').mockImplementation(() => {
|
||||
throw new Error('process.exit');
|
||||
});
|
||||
|
||||
await expect(
|
||||
command.run([], {
|
||||
name: 'test-key',
|
||||
create: true,
|
||||
roles: ['ADMIN'] as any,
|
||||
})
|
||||
).rejects.toThrow();
|
||||
|
||||
expect(logService.error).toHaveBeenCalledWith(
|
||||
"API key with name 'test-key' already exists. Use --overwrite to replace it."
|
||||
);
|
||||
expect(exitSpy).toHaveBeenCalledWith(1);
|
||||
exitSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should create key with overwrite when key exists and overwrite is set', async () => {
|
||||
const mockKey = { key: 'existing-key', name: 'test-key' };
|
||||
vi.mocked(apiKeyService.findByField)
|
||||
.mockReturnValueOnce(null) // First call in line 131
|
||||
.mockReturnValueOnce(mockKey as any); // Second call in non-interactive check
|
||||
vi.mocked(apiKeyService.create).mockResolvedValue({ key: 'overwritten-key' } as any);
|
||||
|
||||
await command.run([], {
|
||||
name: 'test-key',
|
||||
create: true,
|
||||
roles: ['ADMIN'] as any,
|
||||
overwrite: true,
|
||||
});
|
||||
|
||||
expect(apiKeyService.create).toHaveBeenCalledWith({
|
||||
name: 'test-key',
|
||||
description: 'CLI generated key: test-key',
|
||||
roles: ['ADMIN'],
|
||||
permissions: undefined,
|
||||
overwrite: true,
|
||||
});
|
||||
expect(logService.log).toHaveBeenCalledWith('overwritten-key');
|
||||
});
|
||||
|
||||
it('should prompt for missing fields when creating without sufficient info', async () => {
|
||||
vi.mocked(apiKeyService.findByField).mockReturnValue(null);
|
||||
vi.mocked(inquirerService.prompt).mockResolvedValue({
|
||||
name: 'prompted-key',
|
||||
roles: ['USER'],
|
||||
permissions: [],
|
||||
description: 'Prompted description',
|
||||
overwrite: false,
|
||||
} as any);
|
||||
vi.mocked(apiKeyService.create).mockResolvedValue({ key: 'prompted-api-key' } as any);
|
||||
|
||||
await command.run([], { name: '', create: true });
|
||||
|
||||
expect(inquirerService.prompt).toHaveBeenCalledWith('add-api-key', {
|
||||
name: '',
|
||||
create: true,
|
||||
});
|
||||
expect(apiKeyService.create).toHaveBeenCalledWith({
|
||||
name: 'prompted-key',
|
||||
description: 'Prompted description',
|
||||
roles: ['USER'],
|
||||
permissions: [],
|
||||
overwrite: false,
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -39,6 +39,12 @@ export class AddApiKeyQuestionSet {
|
||||
return this.apiKeyService.convertRolesStringArrayToRoles(val);
|
||||
}
|
||||
|
||||
@WhenFor({ name: 'roles' })
|
||||
shouldAskRoles(options: { roles?: Role[]; permissions?: Permission[] }): boolean {
|
||||
// Ask for roles if they weren't provided or are empty
|
||||
return !options.roles || options.roles.length === 0;
|
||||
}
|
||||
|
||||
@ChoicesFor({ name: 'roles' })
|
||||
async getRoles() {
|
||||
return Object.values(Role);
|
||||
@@ -53,6 +59,12 @@ export class AddApiKeyQuestionSet {
|
||||
return this.apiKeyService.convertPermissionsStringArrayToPermissions(val);
|
||||
}
|
||||
|
||||
@WhenFor({ name: 'permissions' })
|
||||
shouldAskPermissions(options: { roles?: Role[]; permissions?: Permission[] }): boolean {
|
||||
// Ask for permissions if they weren't provided or are empty
|
||||
return !options.permissions || options.permissions.length === 0;
|
||||
}
|
||||
|
||||
@ChoicesFor({ name: 'permissions' })
|
||||
async getPermissions() {
|
||||
return this.apiKeyService
|
||||
@@ -72,6 +84,6 @@ export class AddApiKeyQuestionSet {
|
||||
|
||||
@WhenFor({ name: 'overwrite' })
|
||||
shouldAskOverwrite(options: { name: string }): boolean {
|
||||
return Boolean(this.apiKeyService.findByKey(options.name));
|
||||
return Boolean(this.apiKeyService.findByField('name', options.name));
|
||||
}
|
||||
}
|
||||
|
||||
434
api/src/unraid-api/cli/apikey/api-key.command.spec.ts
Normal file
434
api/src/unraid-api/cli/apikey/api-key.command.spec.ts
Normal file
@@ -0,0 +1,434 @@
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { AuthAction, Resource, Role } from '@unraid/shared/graphql.model.js';
|
||||
import { InquirerService } from 'nest-commander';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { ApiKeyService } from '@app/unraid-api/auth/api-key.service.js';
|
||||
import { ApiKeyCommand } from '@app/unraid-api/cli/apikey/api-key.command.js';
|
||||
import { LogService } from '@app/unraid-api/cli/log.service.js';
|
||||
|
||||
describe('ApiKeyCommand', () => {
|
||||
let command: ApiKeyCommand;
|
||||
let apiKeyService: ApiKeyService;
|
||||
let logService: LogService;
|
||||
|
||||
beforeEach(async () => {
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
ApiKeyCommand,
|
||||
{
|
||||
provide: ApiKeyService,
|
||||
useValue: {
|
||||
findByField: vi.fn(),
|
||||
create: vi.fn(),
|
||||
convertRolesStringArrayToRoles: vi.fn(),
|
||||
convertPermissionsStringArrayToPermissions: vi.fn(),
|
||||
findAll: vi.fn(),
|
||||
deleteApiKeys: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: LogService,
|
||||
useValue: {
|
||||
log: vi.fn(),
|
||||
error: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: InquirerService,
|
||||
useValue: {
|
||||
prompt: vi.fn(),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
command = module.get<ApiKeyCommand>(ApiKeyCommand);
|
||||
apiKeyService = module.get<ApiKeyService>(ApiKeyService);
|
||||
logService = module.get<LogService>(LogService);
|
||||
});
|
||||
|
||||
describe('parseRoles', () => {
|
||||
it('should parse valid roles correctly', () => {
|
||||
const mockConvert = vi
|
||||
.spyOn(apiKeyService, 'convertRolesStringArrayToRoles')
|
||||
.mockReturnValue([Role.ADMIN, Role.CONNECT]);
|
||||
|
||||
const result = command.parseRoles('ADMIN,CONNECT');
|
||||
|
||||
expect(mockConvert).toHaveBeenCalledWith(['ADMIN', 'CONNECT']);
|
||||
expect(result).toEqual([Role.ADMIN, Role.CONNECT]);
|
||||
});
|
||||
|
||||
it('should return GUEST role when no roles provided', () => {
|
||||
const result = command.parseRoles('');
|
||||
|
||||
expect(result).toEqual([Role.GUEST]);
|
||||
});
|
||||
|
||||
it('should handle roles with spaces', () => {
|
||||
const mockConvert = vi
|
||||
.spyOn(apiKeyService, 'convertRolesStringArrayToRoles')
|
||||
.mockReturnValue([Role.ADMIN, Role.VIEWER]);
|
||||
|
||||
const result = command.parseRoles('ADMIN, VIEWER');
|
||||
|
||||
expect(mockConvert).toHaveBeenCalledWith(['ADMIN', ' VIEWER']);
|
||||
expect(result).toEqual([Role.ADMIN, Role.VIEWER]);
|
||||
});
|
||||
|
||||
it('should throw error when no valid roles found', () => {
|
||||
vi.spyOn(apiKeyService, 'convertRolesStringArrayToRoles').mockReturnValue([]);
|
||||
|
||||
expect(() => command.parseRoles('INVALID_ROLE')).toThrow(
|
||||
`Invalid roles. Valid options are: ${Object.values(Role).join(', ')}`
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle mixed valid and invalid roles with warning', () => {
|
||||
const mockConvert = vi
|
||||
.spyOn(apiKeyService, 'convertRolesStringArrayToRoles')
|
||||
.mockImplementation((roles) => {
|
||||
const validRoles: Role[] = [];
|
||||
const invalidRoles: string[] = [];
|
||||
|
||||
for (const roleStr of roles) {
|
||||
const upperRole = roleStr.trim().toUpperCase();
|
||||
const role = Role[upperRole as keyof typeof Role];
|
||||
|
||||
if (role) {
|
||||
validRoles.push(role);
|
||||
} else {
|
||||
invalidRoles.push(roleStr);
|
||||
}
|
||||
}
|
||||
|
||||
if (invalidRoles.length > 0) {
|
||||
logService.warn(`Ignoring invalid roles: ${invalidRoles.join(', ')}`);
|
||||
}
|
||||
|
||||
return validRoles;
|
||||
});
|
||||
|
||||
const result = command.parseRoles('ADMIN,INVALID,VIEWER');
|
||||
|
||||
expect(mockConvert).toHaveBeenCalledWith(['ADMIN', 'INVALID', 'VIEWER']);
|
||||
expect(logService.warn).toHaveBeenCalledWith('Ignoring invalid roles: INVALID');
|
||||
expect(result).toEqual([Role.ADMIN, Role.VIEWER]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('run', () => {
|
||||
it('should create API key with roles without prompting', async () => {
|
||||
const mockKey = {
|
||||
id: 'test-id',
|
||||
key: 'test-key-123',
|
||||
name: 'TEST',
|
||||
roles: [Role.ADMIN],
|
||||
createdAt: new Date().toISOString(),
|
||||
permissions: [],
|
||||
};
|
||||
vi.spyOn(apiKeyService, 'findByField').mockReturnValue(null);
|
||||
vi.spyOn(apiKeyService, 'create').mockResolvedValue(mockKey);
|
||||
|
||||
await command.run([], {
|
||||
name: 'TEST',
|
||||
create: true,
|
||||
roles: [Role.ADMIN],
|
||||
permissions: undefined,
|
||||
description: 'Test description',
|
||||
});
|
||||
|
||||
expect(apiKeyService.create).toHaveBeenCalledWith({
|
||||
name: 'TEST',
|
||||
description: 'Test description',
|
||||
roles: [Role.ADMIN],
|
||||
permissions: undefined,
|
||||
overwrite: false,
|
||||
});
|
||||
expect(logService.log).toHaveBeenCalledWith('test-key-123');
|
||||
});
|
||||
|
||||
it('should create API key with permissions only without prompting', async () => {
|
||||
const mockKey = {
|
||||
id: 'test-id',
|
||||
key: 'test-key-456',
|
||||
name: 'TEST_PERMS',
|
||||
roles: [],
|
||||
createdAt: new Date().toISOString(),
|
||||
permissions: [],
|
||||
};
|
||||
const mockPermissions = [
|
||||
{
|
||||
resource: Resource.DOCKER,
|
||||
actions: [AuthAction.READ_ANY],
|
||||
},
|
||||
];
|
||||
|
||||
vi.spyOn(apiKeyService, 'findByField').mockReturnValue(null);
|
||||
vi.spyOn(apiKeyService, 'create').mockResolvedValue(mockKey);
|
||||
|
||||
await command.run([], {
|
||||
name: 'TEST_PERMS',
|
||||
create: true,
|
||||
roles: undefined,
|
||||
permissions: mockPermissions,
|
||||
description: 'Test with permissions',
|
||||
});
|
||||
|
||||
expect(apiKeyService.create).toHaveBeenCalledWith({
|
||||
name: 'TEST_PERMS',
|
||||
description: 'Test with permissions',
|
||||
roles: undefined,
|
||||
permissions: mockPermissions,
|
||||
overwrite: false,
|
||||
});
|
||||
expect(logService.log).toHaveBeenCalledWith('test-key-456');
|
||||
});
|
||||
|
||||
it('should use default description when not provided', async () => {
|
||||
const mockKey = {
|
||||
id: 'test-id',
|
||||
key: 'test-key-789',
|
||||
name: 'NO_DESC',
|
||||
roles: [Role.VIEWER],
|
||||
createdAt: new Date().toISOString(),
|
||||
permissions: [],
|
||||
};
|
||||
vi.spyOn(apiKeyService, 'findByField').mockReturnValue(null);
|
||||
vi.spyOn(apiKeyService, 'create').mockResolvedValue(mockKey);
|
||||
|
||||
await command.run([], {
|
||||
name: 'NO_DESC',
|
||||
create: true,
|
||||
roles: [Role.VIEWER],
|
||||
permissions: undefined,
|
||||
});
|
||||
|
||||
expect(apiKeyService.create).toHaveBeenCalledWith({
|
||||
name: 'NO_DESC',
|
||||
description: 'CLI generated key: NO_DESC',
|
||||
roles: [Role.VIEWER],
|
||||
permissions: undefined,
|
||||
overwrite: false,
|
||||
});
|
||||
});
|
||||
|
||||
it('should return existing key when found', async () => {
|
||||
const existingKey = {
|
||||
id: 'existing-id',
|
||||
key: 'existing-key-123',
|
||||
name: 'EXISTING',
|
||||
roles: [Role.ADMIN],
|
||||
createdAt: new Date().toISOString(),
|
||||
permissions: [],
|
||||
};
|
||||
vi.spyOn(apiKeyService, 'findByField').mockReturnValue(existingKey);
|
||||
|
||||
await command.run([], {
|
||||
name: 'EXISTING',
|
||||
create: false,
|
||||
});
|
||||
|
||||
expect(apiKeyService.findByField).toHaveBeenCalledWith('name', 'EXISTING');
|
||||
expect(logService.log).toHaveBeenCalledWith('existing-key-123');
|
||||
expect(apiKeyService.create).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle uppercase role conversion', () => {
|
||||
const mockConvert = vi
|
||||
.spyOn(apiKeyService, 'convertRolesStringArrayToRoles')
|
||||
.mockImplementation((roles) => {
|
||||
return roles
|
||||
.map((roleStr) => Role[roleStr.trim().toUpperCase() as keyof typeof Role])
|
||||
.filter(Boolean);
|
||||
});
|
||||
|
||||
const result = command.parseRoles('admin,connect');
|
||||
|
||||
expect(mockConvert).toHaveBeenCalledWith(['admin', 'connect']);
|
||||
expect(result).toEqual([Role.ADMIN, Role.CONNECT]);
|
||||
});
|
||||
|
||||
it('should handle lowercase role conversion', () => {
|
||||
const mockConvert = vi
|
||||
.spyOn(apiKeyService, 'convertRolesStringArrayToRoles')
|
||||
.mockImplementation((roles) => {
|
||||
return roles
|
||||
.map((roleStr) => Role[roleStr.trim().toUpperCase() as keyof typeof Role])
|
||||
.filter(Boolean);
|
||||
});
|
||||
|
||||
const result = command.parseRoles('viewer');
|
||||
|
||||
expect(mockConvert).toHaveBeenCalledWith(['viewer']);
|
||||
expect(result).toEqual([Role.VIEWER]);
|
||||
});
|
||||
|
||||
it('should handle mixed case role conversion', () => {
|
||||
const mockConvert = vi
|
||||
.spyOn(apiKeyService, 'convertRolesStringArrayToRoles')
|
||||
.mockImplementation((roles) => {
|
||||
return roles
|
||||
.map((roleStr) => Role[roleStr.trim().toUpperCase() as keyof typeof Role])
|
||||
.filter(Boolean);
|
||||
});
|
||||
|
||||
const result = command.parseRoles('Admin,CoNnEcT');
|
||||
|
||||
expect(mockConvert).toHaveBeenCalledWith(['Admin', 'CoNnEcT']);
|
||||
expect(result).toEqual([Role.ADMIN, Role.CONNECT]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('JSON output functionality', () => {
|
||||
let consoleSpy: ReturnType<typeof vi.spyOn>;
|
||||
|
||||
beforeEach(() => {
|
||||
consoleSpy = vi.spyOn(console, 'log').mockImplementation(() => {});
|
||||
});
|
||||
|
||||
it('should output JSON when creating key with --json flag', async () => {
|
||||
const mockKey = {
|
||||
id: 'test-id-123',
|
||||
key: 'test-key-456',
|
||||
name: 'JSON_TEST',
|
||||
roles: [Role.ADMIN],
|
||||
createdAt: new Date().toISOString(),
|
||||
permissions: [],
|
||||
};
|
||||
vi.spyOn(apiKeyService, 'findByField').mockReturnValue(null);
|
||||
vi.spyOn(apiKeyService, 'create').mockResolvedValue(mockKey);
|
||||
|
||||
await command.run([], {
|
||||
name: 'JSON_TEST',
|
||||
create: true,
|
||||
roles: [Role.ADMIN],
|
||||
json: true,
|
||||
});
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
JSON.stringify({ key: 'test-key-456', name: 'JSON_TEST', id: 'test-id-123' })
|
||||
);
|
||||
expect(logService.log).not.toHaveBeenCalledWith('test-key-456');
|
||||
});
|
||||
|
||||
it('should output JSON when fetching existing key with --json flag', async () => {
|
||||
const existingKey = {
|
||||
id: 'existing-id-456',
|
||||
key: 'existing-key-789',
|
||||
name: 'EXISTING_JSON',
|
||||
roles: [Role.VIEWER],
|
||||
createdAt: new Date().toISOString(),
|
||||
permissions: [],
|
||||
};
|
||||
vi.spyOn(apiKeyService, 'findByField').mockReturnValue(existingKey);
|
||||
|
||||
await command.run([], {
|
||||
name: 'EXISTING_JSON',
|
||||
create: false,
|
||||
json: true,
|
||||
});
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
JSON.stringify({ key: 'existing-key-789', name: 'EXISTING_JSON', id: 'existing-id-456' })
|
||||
);
|
||||
expect(logService.log).not.toHaveBeenCalledWith('existing-key-789');
|
||||
});
|
||||
|
||||
it('should output JSON when deleting key with --json flag', async () => {
|
||||
const existingKeys = [
|
||||
{
|
||||
id: 'delete-id-123',
|
||||
name: 'DELETE_JSON',
|
||||
key: 'delete-key-456',
|
||||
roles: [Role.GUEST],
|
||||
createdAt: new Date().toISOString(),
|
||||
permissions: [],
|
||||
},
|
||||
];
|
||||
vi.spyOn(apiKeyService, 'findAll').mockResolvedValue(existingKeys);
|
||||
vi.spyOn(apiKeyService, 'deleteApiKeys').mockResolvedValue();
|
||||
|
||||
await command.run([], {
|
||||
name: 'DELETE_JSON',
|
||||
delete: true,
|
||||
json: true,
|
||||
});
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
JSON.stringify({
|
||||
deleted: 1,
|
||||
keys: [{ id: 'delete-id-123', name: 'DELETE_JSON' }],
|
||||
})
|
||||
);
|
||||
expect(logService.log).not.toHaveBeenCalledWith('Successfully deleted 1 API key');
|
||||
});
|
||||
|
||||
it('should output JSON error when deleting non-existent key with --json flag', async () => {
|
||||
vi.spyOn(apiKeyService, 'findAll').mockResolvedValue([]);
|
||||
|
||||
await command.run([], {
|
||||
name: 'NONEXISTENT',
|
||||
delete: true,
|
||||
json: true,
|
||||
});
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
JSON.stringify({ deleted: 0, message: 'No API keys found to delete' })
|
||||
);
|
||||
expect(logService.log).not.toHaveBeenCalledWith('No API keys found to delete');
|
||||
});
|
||||
|
||||
it('should not suppress creation message when not using JSON', async () => {
|
||||
const mockKey = {
|
||||
id: 'test-id',
|
||||
key: 'test-key',
|
||||
name: 'NO_JSON_TEST',
|
||||
roles: [Role.ADMIN],
|
||||
createdAt: new Date().toISOString(),
|
||||
permissions: [],
|
||||
};
|
||||
vi.spyOn(apiKeyService, 'findByField').mockReturnValue(null);
|
||||
vi.spyOn(apiKeyService, 'create').mockResolvedValue(mockKey);
|
||||
|
||||
await command.run([], {
|
||||
name: 'NO_JSON_TEST',
|
||||
create: true,
|
||||
roles: [Role.ADMIN],
|
||||
json: false,
|
||||
});
|
||||
|
||||
expect(logService.log).toHaveBeenCalledWith('Creating API Key...');
|
||||
expect(logService.log).toHaveBeenCalledWith('test-key');
|
||||
expect(consoleSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should suppress creation message when using JSON', async () => {
|
||||
const mockKey = {
|
||||
id: 'test-id',
|
||||
key: 'test-key',
|
||||
name: 'JSON_SUPPRESS_TEST',
|
||||
roles: [Role.ADMIN],
|
||||
createdAt: new Date().toISOString(),
|
||||
permissions: [],
|
||||
};
|
||||
vi.spyOn(apiKeyService, 'findByField').mockReturnValue(null);
|
||||
vi.spyOn(apiKeyService, 'create').mockResolvedValue(mockKey);
|
||||
|
||||
await command.run([], {
|
||||
name: 'JSON_SUPPRESS_TEST',
|
||||
create: true,
|
||||
roles: [Role.ADMIN],
|
||||
json: true,
|
||||
});
|
||||
|
||||
expect(logService.log).not.toHaveBeenCalledWith('Creating API Key...');
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
JSON.stringify({ key: 'test-key', name: 'JSON_SUPPRESS_TEST', id: 'test-id' })
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -10,11 +10,13 @@ import { Permission } from '@app/unraid-api/graph/resolvers/api-key/api-key.mode
|
||||
|
||||
interface KeyOptions {
|
||||
name: string;
|
||||
create: boolean;
|
||||
create?: boolean;
|
||||
delete?: boolean;
|
||||
description?: string;
|
||||
roles?: Role[];
|
||||
permissions?: Permission[];
|
||||
overwrite?: boolean;
|
||||
json?: boolean;
|
||||
}
|
||||
|
||||
@Command({
|
||||
@@ -52,22 +54,15 @@ export class ApiKeyCommand extends CommandRunner {
|
||||
})
|
||||
parseRoles(roles: string): Role[] {
|
||||
if (!roles) return [Role.GUEST];
|
||||
const validRoles: Set<Role> = new Set(Object.values(Role));
|
||||
|
||||
const requestedRoles = roles.split(',').map((role) => role.trim().toLocaleLowerCase() as Role);
|
||||
const validRequestedRoles = requestedRoles.filter((role) => validRoles.has(role));
|
||||
const roleArray = roles.split(',').filter(Boolean);
|
||||
const validRoles = this.apiKeyService.convertRolesStringArrayToRoles(roleArray);
|
||||
|
||||
if (validRequestedRoles.length === 0) {
|
||||
throw new Error(`Invalid roles. Valid options are: ${Array.from(validRoles).join(', ')}`);
|
||||
if (validRoles.length === 0) {
|
||||
throw new Error(`Invalid roles. Valid options are: ${Object.values(Role).join(', ')}`);
|
||||
}
|
||||
|
||||
const invalidRoles = requestedRoles.filter((role) => !validRoles.has(role));
|
||||
|
||||
if (invalidRoles.length > 0) {
|
||||
this.logger.warn(`Ignoring invalid roles: ${invalidRoles.join(', ')}`);
|
||||
}
|
||||
|
||||
return validRequestedRoles;
|
||||
return validRoles;
|
||||
}
|
||||
|
||||
@Option({
|
||||
@@ -98,48 +93,137 @@ ACTIONS: ${Object.values(AuthAction).join(', ')}`,
|
||||
return true;
|
||||
}
|
||||
|
||||
/** Prompt the user to select API keys to delete. Then, delete the selected keys. */
|
||||
private async deleteKeys() {
|
||||
@Option({
|
||||
flags: '--overwrite',
|
||||
description: 'Overwrite existing API key if it exists',
|
||||
})
|
||||
parseOverwrite(): boolean {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Option({
|
||||
flags: '--json',
|
||||
description: 'Output machine-readable JSON format',
|
||||
})
|
||||
parseJson(): boolean {
|
||||
return true;
|
||||
}
|
||||
|
||||
/** Helper to output either JSON or regular log message */
|
||||
private output(message: string, jsonData?: object, jsonOutput?: boolean): void {
|
||||
if (jsonOutput && jsonData) {
|
||||
console.log(JSON.stringify(jsonData));
|
||||
} else {
|
||||
this.logger.log(message);
|
||||
}
|
||||
}
|
||||
|
||||
/** Helper to output either JSON or regular error message */
|
||||
private outputError(message: string, jsonData?: object, jsonOutput?: boolean): void {
|
||||
if (jsonOutput && jsonData) {
|
||||
console.log(JSON.stringify(jsonData));
|
||||
} else {
|
||||
this.logger.error(message);
|
||||
}
|
||||
}
|
||||
|
||||
/** Delete API keys either by name (non-interactive) or by prompting user selection (interactive). */
|
||||
private async deleteKeys(name?: string, jsonOutput?: boolean) {
|
||||
const allKeys = await this.apiKeyService.findAll();
|
||||
if (allKeys.length === 0) {
|
||||
this.logger.log('No API keys found to delete');
|
||||
this.output(
|
||||
'No API keys found to delete',
|
||||
{ deleted: 0, message: 'No API keys found to delete' },
|
||||
jsonOutput
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const answers = await this.inquirerService.prompt<DeleteApiKeyAnswers>(
|
||||
DeleteApiKeyQuestionSet.name,
|
||||
{}
|
||||
);
|
||||
if (!answers.selectedKeys || answers.selectedKeys.length === 0) {
|
||||
this.logger.log('No keys selected for deletion');
|
||||
return;
|
||||
let selectedKeyIds: string[];
|
||||
let deletedKeys: { id: string; name: string }[] = [];
|
||||
|
||||
if (name) {
|
||||
// Non-interactive mode: delete by name
|
||||
const keyToDelete = allKeys.find((key) => key.name === name);
|
||||
if (!keyToDelete) {
|
||||
this.outputError(
|
||||
`No API key found with name: ${name}`,
|
||||
{ deleted: 0, error: `No API key found with name: ${name}` },
|
||||
jsonOutput
|
||||
);
|
||||
process.exit(1);
|
||||
}
|
||||
selectedKeyIds = [keyToDelete.id];
|
||||
deletedKeys = [{ id: keyToDelete.id, name: keyToDelete.name }];
|
||||
} else {
|
||||
// Interactive mode: prompt user to select keys
|
||||
const answers = await this.inquirerService.prompt<DeleteApiKeyAnswers>(
|
||||
DeleteApiKeyQuestionSet.name,
|
||||
{}
|
||||
);
|
||||
if (!answers.selectedKeys || answers.selectedKeys.length === 0) {
|
||||
this.output(
|
||||
'No keys selected for deletion',
|
||||
{ deleted: 0, message: 'No keys selected for deletion' },
|
||||
jsonOutput
|
||||
);
|
||||
return;
|
||||
}
|
||||
selectedKeyIds = answers.selectedKeys;
|
||||
deletedKeys = allKeys
|
||||
.filter((key) => selectedKeyIds.includes(key.id))
|
||||
.map((key) => ({ id: key.id, name: key.name }));
|
||||
}
|
||||
|
||||
try {
|
||||
await this.apiKeyService.deleteApiKeys(answers.selectedKeys);
|
||||
this.logger.log(`Successfully deleted ${answers.selectedKeys.length} API keys`);
|
||||
await this.apiKeyService.deleteApiKeys(selectedKeyIds);
|
||||
const message = `Successfully deleted ${selectedKeyIds.length} API key${selectedKeyIds.length === 1 ? '' : 's'}`;
|
||||
this.output(message, { deleted: selectedKeyIds.length, keys: deletedKeys }, jsonOutput);
|
||||
} catch (error) {
|
||||
this.logger.error(error as any);
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
this.outputError(errorMessage, { deleted: 0, error: errorMessage }, jsonOutput);
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
async run(
|
||||
_: string[],
|
||||
options: KeyOptions = { create: false, name: '', delete: false }
|
||||
): Promise<void> {
|
||||
async run(_: string[], options: KeyOptions = { name: '', delete: false }): Promise<void> {
|
||||
try {
|
||||
if (options.delete) {
|
||||
await this.deleteKeys();
|
||||
await this.deleteKeys(options.name, options.json);
|
||||
return;
|
||||
}
|
||||
|
||||
const key = this.apiKeyService.findByField('name', options.name);
|
||||
if (key) {
|
||||
this.logger.log(key.key);
|
||||
} else if (options.create) {
|
||||
options = await this.inquirerService.prompt(AddApiKeyQuestionSet.name, options);
|
||||
this.logger.log('Creating API Key...' + JSON.stringify(options));
|
||||
this.output(key.key, { key: key.key, name: key.name, id: key.id }, options.json);
|
||||
} else if (options.create === true) {
|
||||
// Check if we have minimum required info from flags (name + at least one role or permission)
|
||||
const hasMinimumInfo =
|
||||
options.name &&
|
||||
((options.roles && options.roles.length > 0) ||
|
||||
(options.permissions && options.permissions.length > 0));
|
||||
|
||||
if (!hasMinimumInfo) {
|
||||
// Interactive mode - prompt for missing fields
|
||||
options = await this.inquirerService.prompt(AddApiKeyQuestionSet.name, options);
|
||||
} else {
|
||||
// Non-interactive mode - check if key exists and handle overwrite
|
||||
const existingKey = this.apiKeyService.findByField('name', options.name);
|
||||
if (existingKey && !options.overwrite) {
|
||||
this.outputError(
|
||||
`API key with name '${options.name}' already exists. Use --overwrite to replace it.`,
|
||||
{
|
||||
error: `API key with name '${options.name}' already exists. Use --overwrite to replace it.`,
|
||||
},
|
||||
options.json
|
||||
);
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
if (!options.json) {
|
||||
this.logger.log('Creating API Key...');
|
||||
}
|
||||
|
||||
if (!options.roles && !options.permissions) {
|
||||
this.logger.error('Please add at least one role or permission to the key.');
|
||||
@@ -154,10 +238,10 @@ ACTIONS: ${Object.values(AuthAction).join(', ')}`,
|
||||
description: options.description || `CLI generated key: ${options.name}`,
|
||||
roles: options.roles,
|
||||
permissions: options.permissions,
|
||||
overwrite: true,
|
||||
overwrite: options.overwrite ?? false,
|
||||
});
|
||||
|
||||
this.logger.log(key.key);
|
||||
this.output(key.key, { key: key.key, name: key.name, id: key.id }, options.json);
|
||||
} else {
|
||||
this.logger.log('No Key Found');
|
||||
process.exit(1);
|
||||
|
||||
@@ -448,6 +448,20 @@ export enum ConfigErrorState {
|
||||
WITHDRAWN = 'WITHDRAWN'
|
||||
}
|
||||
|
||||
export type ConfigFile = {
|
||||
__typename?: 'ConfigFile';
|
||||
content: Scalars['String']['output'];
|
||||
name: Scalars['String']['output'];
|
||||
path: Scalars['String']['output'];
|
||||
/** Human-readable file size (e.g., "1.5 KB", "2.3 MB") */
|
||||
sizeReadable: Scalars['String']['output'];
|
||||
};
|
||||
|
||||
export type ConfigFilesResponse = {
|
||||
__typename?: 'ConfigFilesResponse';
|
||||
files: Array<ConfigFile>;
|
||||
};
|
||||
|
||||
export type Connect = Node & {
|
||||
__typename?: 'Connect';
|
||||
/** The status of dynamic remote access */
|
||||
@@ -1432,6 +1446,14 @@ export type OidcAuthorizationRule = {
|
||||
value: Array<Scalars['String']['output']>;
|
||||
};
|
||||
|
||||
export type OidcConfiguration = {
|
||||
__typename?: 'OidcConfiguration';
|
||||
/** Default allowed redirect origins that apply to all OIDC providers (e.g., Tailscale domains) */
|
||||
defaultAllowedOrigins?: Maybe<Array<Scalars['String']['output']>>;
|
||||
/** List of configured OIDC providers */
|
||||
providers: Array<OidcProvider>;
|
||||
};
|
||||
|
||||
export type OidcProvider = {
|
||||
__typename?: 'OidcProvider';
|
||||
/** OAuth2 authorization endpoint URL. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration */
|
||||
@@ -1455,7 +1477,7 @@ export type OidcProvider = {
|
||||
/** The unique identifier for the OIDC provider */
|
||||
id: Scalars['PrefixedID']['output'];
|
||||
/** OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration */
|
||||
issuer: Scalars['String']['output'];
|
||||
issuer?: Maybe<Scalars['String']['output']>;
|
||||
/** JSON Web Key Set URI for token validation. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration */
|
||||
jwksUri?: Maybe<Scalars['String']['output']>;
|
||||
/** Display name of the OIDC provider */
|
||||
@@ -1623,6 +1645,7 @@ export type PublicPartnerInfo = {
|
||||
|
||||
export type Query = {
|
||||
__typename?: 'Query';
|
||||
allConfigFiles: ConfigFilesResponse;
|
||||
apiKey?: Maybe<ApiKey>;
|
||||
/** All possible permissions for API keys */
|
||||
apiKeyPossiblePermissions: Array<Permission>;
|
||||
@@ -1632,6 +1655,7 @@ export type Query = {
|
||||
array: UnraidArray;
|
||||
cloud: Cloud;
|
||||
config: Config;
|
||||
configFile?: Maybe<ConfigFile>;
|
||||
connect: Connect;
|
||||
customization?: Maybe<Customization>;
|
||||
disk: Disk;
|
||||
@@ -1654,6 +1678,8 @@ export type Query = {
|
||||
network: Network;
|
||||
/** Get all notifications */
|
||||
notifications: Notifications;
|
||||
/** Get the full OIDC configuration (admin only) */
|
||||
oidcConfiguration: OidcConfiguration;
|
||||
/** Get a specific OIDC provider by ID */
|
||||
oidcProvider?: Maybe<OidcProvider>;
|
||||
/** Get all configured OIDC providers (admin only) */
|
||||
@@ -1693,6 +1719,11 @@ export type QueryApiKeyArgs = {
|
||||
};
|
||||
|
||||
|
||||
export type QueryConfigFileArgs = {
|
||||
name: Scalars['String']['input'];
|
||||
};
|
||||
|
||||
|
||||
export type QueryDiskArgs = {
|
||||
id: Scalars['PrefixedID']['input'];
|
||||
};
|
||||
@@ -1933,6 +1964,7 @@ export type Server = Node & {
|
||||
name: Scalars['String']['output'];
|
||||
owner: ProfileModel;
|
||||
remoteurl: Scalars['String']['output'];
|
||||
/** Whether this server is online or offline */
|
||||
status: ServerStatus;
|
||||
wanip: Scalars['String']['output'];
|
||||
};
|
||||
|
||||
@@ -1,9 +1,23 @@
|
||||
import { Command, CommandRunner } from 'nest-commander';
|
||||
import { Command, CommandRunner, Option } from 'nest-commander';
|
||||
|
||||
import { ECOSYSTEM_PATH } from '@app/environment.js';
|
||||
import type { LogLevel } from '@app/core/log.js';
|
||||
import { levels } from '@app/core/log.js';
|
||||
import { ECOSYSTEM_PATH, LOG_LEVEL } from '@app/environment.js';
|
||||
import { LogService } from '@app/unraid-api/cli/log.service.js';
|
||||
import { PM2Service } from '@app/unraid-api/cli/pm2.service.js';
|
||||
|
||||
export interface LogLevelOptions {
|
||||
logLevel?: LogLevel;
|
||||
}
|
||||
|
||||
export function parseLogLevelOption(val: string, allowedLevels: string[] = [...levels]): LogLevel {
|
||||
const normalized = val.toLowerCase() as LogLevel;
|
||||
if (allowedLevels.includes(normalized)) {
|
||||
return normalized;
|
||||
}
|
||||
throw new Error(`Invalid --log-level "${val}". Allowed: ${allowedLevels.join(', ')}`);
|
||||
}
|
||||
|
||||
@Command({ name: 'restart', description: 'Restart the Unraid API' })
|
||||
export class RestartCommand extends CommandRunner {
|
||||
constructor(
|
||||
@@ -13,11 +27,12 @@ export class RestartCommand extends CommandRunner {
|
||||
super();
|
||||
}
|
||||
|
||||
async run(): Promise<void> {
|
||||
async run(_?: string[], options: LogLevelOptions = {}): Promise<void> {
|
||||
try {
|
||||
this.logger.info('Restarting the Unraid API...');
|
||||
const env = { LOG_LEVEL: options.logLevel };
|
||||
const { stderr, stdout } = await this.pm2.run(
|
||||
{ tag: 'PM2 Restart', raw: true },
|
||||
{ tag: 'PM2 Restart', raw: true, extendEnv: true, env },
|
||||
'restart',
|
||||
ECOSYSTEM_PATH,
|
||||
'--update-env'
|
||||
@@ -40,4 +55,13 @@ export class RestartCommand extends CommandRunner {
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
@Option({
|
||||
flags: `--log-level <${levels.join('|')}>`,
|
||||
description: 'log level to use',
|
||||
defaultValue: LOG_LEVEL.toLowerCase(),
|
||||
})
|
||||
parseLogLevel(val: string): LogLevel {
|
||||
return parseLogLevelOption(val);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
import { Command, CommandRunner, Option } from 'nest-commander';
|
||||
|
||||
import type { LogLevel } from '@app/core/log.js';
|
||||
import type { LogLevelOptions } from '@app/unraid-api/cli/restart.command.js';
|
||||
import { levels } from '@app/core/log.js';
|
||||
import { ECOSYSTEM_PATH } from '@app/environment.js';
|
||||
import { ECOSYSTEM_PATH, LOG_LEVEL } from '@app/environment.js';
|
||||
import { LogService } from '@app/unraid-api/cli/log.service.js';
|
||||
import { PM2Service } from '@app/unraid-api/cli/pm2.service.js';
|
||||
|
||||
interface StartCommandOptions {
|
||||
'log-level'?: string;
|
||||
}
|
||||
import { parseLogLevelOption } from '@app/unraid-api/cli/restart.command.js';
|
||||
|
||||
@Command({ name: 'start', description: 'Start the Unraid API' })
|
||||
export class StartCommand extends CommandRunner {
|
||||
@@ -27,17 +25,12 @@ export class StartCommand extends CommandRunner {
|
||||
await this.pm2.run({ tag: 'PM2 Delete' }, 'delete', ECOSYSTEM_PATH);
|
||||
}
|
||||
|
||||
async run(_: string[], options: StartCommandOptions): Promise<void> {
|
||||
async run(_: string[], options: LogLevelOptions): Promise<void> {
|
||||
this.logger.info('Starting the Unraid API');
|
||||
await this.cleanupPM2State();
|
||||
|
||||
const env: Record<string, string> = {};
|
||||
if (options['log-level']) {
|
||||
env.LOG_LEVEL = options['log-level'];
|
||||
}
|
||||
|
||||
const env = { LOG_LEVEL: options.logLevel };
|
||||
const { stderr, stdout } = await this.pm2.run(
|
||||
{ tag: 'PM2 Start', env, raw: true },
|
||||
{ tag: 'PM2 Start', raw: true, extendEnv: true, env },
|
||||
'start',
|
||||
ECOSYSTEM_PATH,
|
||||
'--update-env'
|
||||
@@ -54,9 +47,9 @@ export class StartCommand extends CommandRunner {
|
||||
@Option({
|
||||
flags: `--log-level <${levels.join('|')}>`,
|
||||
description: 'log level to use',
|
||||
defaultValue: 'info',
|
||||
defaultValue: LOG_LEVEL.toLowerCase(),
|
||||
})
|
||||
parseLogLevel(val: string): LogLevel {
|
||||
return levels.includes(val as LogLevel) ? (val as LogLevel) : 'info';
|
||||
return parseLogLevelOption(val);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,16 @@ export class CpuLoad {
|
||||
description: 'The percentage of time the CPU spent servicing hardware interrupts.',
|
||||
})
|
||||
percentIrq!: number;
|
||||
|
||||
@Field(() => Float, {
|
||||
description: 'The percentage of time the CPU spent running virtual machines (guest).',
|
||||
})
|
||||
percentGuest!: number;
|
||||
|
||||
@Field(() => Float, {
|
||||
description: 'The percentage of CPU time stolen by the hypervisor.',
|
||||
})
|
||||
percentSteal!: number;
|
||||
}
|
||||
|
||||
@ObjectType({ implements: () => Node })
|
||||
|
||||
246
api/src/unraid-api/graph/resolvers/info/cpu/cpu.service.spec.ts
Normal file
246
api/src/unraid-api/graph/resolvers/info/cpu/cpu.service.spec.ts
Normal file
@@ -0,0 +1,246 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { CpuService } from '@app/unraid-api/graph/resolvers/info/cpu/cpu.service.js';
|
||||
|
||||
vi.mock('systeminformation', () => ({
|
||||
cpu: vi.fn().mockResolvedValue({
|
||||
manufacturer: 'Intel',
|
||||
brand: 'Core i7-9700K',
|
||||
vendor: 'Intel',
|
||||
family: '6',
|
||||
model: '158',
|
||||
stepping: '12',
|
||||
revision: '',
|
||||
voltage: '1.2V',
|
||||
speed: 3.6,
|
||||
speedMin: 800,
|
||||
speedMax: 4900,
|
||||
cores: 16,
|
||||
physicalCores: 8,
|
||||
processors: 1,
|
||||
socket: 'LGA1151',
|
||||
cache: {
|
||||
l1d: 32768,
|
||||
l1i: 32768,
|
||||
l2: 262144,
|
||||
l3: 12582912,
|
||||
},
|
||||
}),
|
||||
cpuFlags: vi.fn().mockResolvedValue('fpu vme de pse tsc msr pae mce cx8'),
|
||||
currentLoad: vi.fn().mockResolvedValue({
|
||||
avgLoad: 2.5,
|
||||
currentLoad: 25.5,
|
||||
currentLoadUser: 15.0,
|
||||
currentLoadSystem: 8.0,
|
||||
currentLoadNice: 0.5,
|
||||
currentLoadIdle: 74.5,
|
||||
currentLoadIrq: 1.0,
|
||||
currentLoadSteal: 0.2,
|
||||
currentLoadGuest: 0.3,
|
||||
rawCurrentLoad: 25500,
|
||||
rawCurrentLoadUser: 15000,
|
||||
rawCurrentLoadSystem: 8000,
|
||||
rawCurrentLoadNice: 500,
|
||||
rawCurrentLoadIdle: 74500,
|
||||
rawCurrentLoadIrq: 1000,
|
||||
rawCurrentLoadSteal: 200,
|
||||
rawCurrentLoadGuest: 300,
|
||||
cpus: [
|
||||
{
|
||||
load: 30.0,
|
||||
loadUser: 20.0,
|
||||
loadSystem: 10.0,
|
||||
loadNice: 0,
|
||||
loadIdle: 70.0,
|
||||
loadIrq: 0,
|
||||
loadSteal: 0,
|
||||
loadGuest: 0,
|
||||
rawLoad: 30000,
|
||||
rawLoadUser: 20000,
|
||||
rawLoadSystem: 10000,
|
||||
rawLoadNice: 0,
|
||||
rawLoadIdle: 70000,
|
||||
rawLoadIrq: 0,
|
||||
rawLoadSteal: 0,
|
||||
rawLoadGuest: 0,
|
||||
},
|
||||
{
|
||||
load: 21.0,
|
||||
loadUser: 15.0,
|
||||
loadSystem: 6.0,
|
||||
loadNice: 0,
|
||||
loadIdle: 79.0,
|
||||
loadIrq: 0,
|
||||
loadSteal: 0,
|
||||
loadGuest: 0,
|
||||
rawLoad: 21000,
|
||||
rawLoadUser: 15000,
|
||||
rawLoadSystem: 6000,
|
||||
rawLoadNice: 0,
|
||||
rawLoadIdle: 79000,
|
||||
rawLoadIrq: 0,
|
||||
rawLoadSteal: 0,
|
||||
rawLoadGuest: 0,
|
||||
},
|
||||
],
|
||||
}),
|
||||
}));
|
||||
|
||||
describe('CpuService', () => {
|
||||
let service: CpuService;
|
||||
|
||||
beforeEach(() => {
|
||||
service = new CpuService();
|
||||
});
|
||||
|
||||
describe('generateCpu', () => {
|
||||
it('should return CPU information with correct structure', async () => {
|
||||
const result = await service.generateCpu();
|
||||
|
||||
expect(result).toEqual({
|
||||
id: 'info/cpu',
|
||||
manufacturer: 'Intel',
|
||||
brand: 'Core i7-9700K',
|
||||
vendor: 'Intel',
|
||||
family: '6',
|
||||
model: '158',
|
||||
stepping: 12,
|
||||
revision: '',
|
||||
voltage: '1.2V',
|
||||
speed: 3.6,
|
||||
speedmin: 800,
|
||||
speedmax: 4900,
|
||||
cores: 8,
|
||||
threads: 16,
|
||||
processors: 1,
|
||||
socket: 'LGA1151',
|
||||
cache: {
|
||||
l1d: 32768,
|
||||
l1i: 32768,
|
||||
l2: 262144,
|
||||
l3: 12582912,
|
||||
},
|
||||
flags: ['fpu', 'vme', 'de', 'pse', 'tsc', 'msr', 'pae', 'mce', 'cx8'],
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle missing speed values', async () => {
|
||||
const { cpu } = await import('systeminformation');
|
||||
vi.mocked(cpu).mockResolvedValueOnce({
|
||||
manufacturer: 'Intel',
|
||||
brand: 'Core i7-9700K',
|
||||
vendor: 'Intel',
|
||||
family: '6',
|
||||
model: '158',
|
||||
stepping: '12',
|
||||
revision: '',
|
||||
voltage: '1.2V',
|
||||
speed: 3.6,
|
||||
cores: 16,
|
||||
physicalCores: 8,
|
||||
processors: 1,
|
||||
socket: 'LGA1151',
|
||||
cache: { l1d: 32768, l1i: 32768, l2: 262144, l3: 12582912 },
|
||||
} as any);
|
||||
|
||||
const result = await service.generateCpu();
|
||||
|
||||
expect(result.speedmin).toBe(-1);
|
||||
expect(result.speedmax).toBe(-1);
|
||||
});
|
||||
|
||||
it('should handle cpuFlags error gracefully', async () => {
|
||||
const { cpuFlags } = await import('systeminformation');
|
||||
vi.mocked(cpuFlags).mockRejectedValueOnce(new Error('flags error'));
|
||||
|
||||
const result = await service.generateCpu();
|
||||
|
||||
expect(result.flags).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('generateCpuLoad', () => {
|
||||
it('should return CPU utilization with all load metrics', async () => {
|
||||
const result = await service.generateCpuLoad();
|
||||
|
||||
expect(result).toEqual({
|
||||
id: 'info/cpu-load',
|
||||
percentTotal: 25.5,
|
||||
cpus: [
|
||||
{
|
||||
percentTotal: 30.0,
|
||||
percentUser: 20.0,
|
||||
percentSystem: 10.0,
|
||||
percentNice: 0,
|
||||
percentIdle: 70.0,
|
||||
percentIrq: 0,
|
||||
percentGuest: 0,
|
||||
percentSteal: 0,
|
||||
},
|
||||
{
|
||||
percentTotal: 21.0,
|
||||
percentUser: 15.0,
|
||||
percentSystem: 6.0,
|
||||
percentNice: 0,
|
||||
percentIdle: 79.0,
|
||||
percentIrq: 0,
|
||||
percentGuest: 0,
|
||||
percentSteal: 0,
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('should include guest and steal metrics when present', async () => {
|
||||
const { currentLoad } = await import('systeminformation');
|
||||
vi.mocked(currentLoad).mockResolvedValueOnce({
|
||||
avgLoad: 2.5,
|
||||
currentLoad: 25.5,
|
||||
currentLoadUser: 15.0,
|
||||
currentLoadSystem: 8.0,
|
||||
currentLoadNice: 0.5,
|
||||
currentLoadIdle: 74.5,
|
||||
currentLoadIrq: 1.0,
|
||||
currentLoadSteal: 0.2,
|
||||
currentLoadGuest: 0.3,
|
||||
rawCurrentLoad: 25500,
|
||||
rawCurrentLoadUser: 15000,
|
||||
rawCurrentLoadSystem: 8000,
|
||||
rawCurrentLoadNice: 500,
|
||||
rawCurrentLoadIdle: 74500,
|
||||
rawCurrentLoadIrq: 1000,
|
||||
rawCurrentLoadSteal: 200,
|
||||
rawCurrentLoadGuest: 300,
|
||||
cpus: [
|
||||
{
|
||||
load: 30.0,
|
||||
loadUser: 20.0,
|
||||
loadSystem: 10.0,
|
||||
loadNice: 0,
|
||||
loadIdle: 70.0,
|
||||
loadIrq: 0,
|
||||
loadGuest: 2.5,
|
||||
loadSteal: 1.2,
|
||||
rawLoad: 30000,
|
||||
rawLoadUser: 20000,
|
||||
rawLoadSystem: 10000,
|
||||
rawLoadNice: 0,
|
||||
rawLoadIdle: 70000,
|
||||
rawLoadIrq: 0,
|
||||
rawLoadGuest: 2500,
|
||||
rawLoadSteal: 1200,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const result = await service.generateCpuLoad();
|
||||
|
||||
expect(result.cpus[0]).toEqual(
|
||||
expect.objectContaining({
|
||||
percentGuest: 2.5,
|
||||
percentSteal: 1.2,
|
||||
})
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -37,6 +37,8 @@ export class CpuService {
|
||||
percentNice: cpu.loadNice,
|
||||
percentIdle: cpu.loadIdle,
|
||||
percentIrq: cpu.loadIrq,
|
||||
percentGuest: cpu.loadGuest || 0,
|
||||
percentSteal: cpu.loadSteal || 0,
|
||||
})),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -0,0 +1,213 @@
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import {
|
||||
LogWatcherManager,
|
||||
WatcherState,
|
||||
} from '@app/unraid-api/graph/resolvers/logs/log-watcher-manager.service.js';
|
||||
|
||||
describe('LogWatcherManager', () => {
|
||||
let manager: LogWatcherManager;
|
||||
let mockWatcher: any;
|
||||
|
||||
beforeEach(async () => {
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [LogWatcherManager],
|
||||
}).compile();
|
||||
|
||||
manager = module.get<LogWatcherManager>(LogWatcherManager);
|
||||
|
||||
mockWatcher = {
|
||||
close: vi.fn(),
|
||||
on: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
describe('state management', () => {
|
||||
it('should set watcher as initializing', () => {
|
||||
manager.setInitializing('test-key');
|
||||
const entry = manager.getEntry('test-key');
|
||||
expect(entry).toBeDefined();
|
||||
expect(entry?.state).toBe(WatcherState.INITIALIZING);
|
||||
});
|
||||
|
||||
it('should set watcher as active with position', () => {
|
||||
manager.setActive('test-key', mockWatcher as any, 1000);
|
||||
const entry = manager.getEntry('test-key');
|
||||
expect(entry).toBeDefined();
|
||||
expect(entry?.state).toBe(WatcherState.ACTIVE);
|
||||
if (manager.isActive(entry)) {
|
||||
expect(entry.watcher).toBe(mockWatcher);
|
||||
expect(entry.position).toBe(1000);
|
||||
}
|
||||
});
|
||||
|
||||
it('should set watcher as stopping', () => {
|
||||
manager.setStopping('test-key');
|
||||
const entry = manager.getEntry('test-key');
|
||||
expect(entry).toBeDefined();
|
||||
expect(entry?.state).toBe(WatcherState.STOPPING);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isWatchingOrInitializing', () => {
|
||||
it('should return true for initializing watcher', () => {
|
||||
manager.setInitializing('test-key');
|
||||
expect(manager.isWatchingOrInitializing('test-key')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true for active watcher', () => {
|
||||
manager.setActive('test-key', mockWatcher as any, 0);
|
||||
expect(manager.isWatchingOrInitializing('test-key')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false for stopping watcher', () => {
|
||||
manager.setStopping('test-key');
|
||||
expect(manager.isWatchingOrInitializing('test-key')).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false for non-existent watcher', () => {
|
||||
expect(manager.isWatchingOrInitializing('test-key')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('handlePostInitialization', () => {
|
||||
it('should activate watcher when not stopped', () => {
|
||||
manager.setInitializing('test-key');
|
||||
const result = manager.handlePostInitialization('test-key', mockWatcher as any, 500);
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(mockWatcher.close).not.toHaveBeenCalled();
|
||||
|
||||
const entry = manager.getEntry('test-key');
|
||||
expect(entry?.state).toBe(WatcherState.ACTIVE);
|
||||
if (manager.isActive(entry)) {
|
||||
expect(entry.position).toBe(500);
|
||||
}
|
||||
});
|
||||
|
||||
it('should cleanup watcher when marked as stopping', () => {
|
||||
manager.setStopping('test-key');
|
||||
const result = manager.handlePostInitialization('test-key', mockWatcher as any, 500);
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(mockWatcher.close).toHaveBeenCalled();
|
||||
expect(manager.getEntry('test-key')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should cleanup watcher when entry is missing', () => {
|
||||
const result = manager.handlePostInitialization('test-key', mockWatcher as any, 500);
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(mockWatcher.close).toHaveBeenCalled();
|
||||
expect(manager.getEntry('test-key')).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('stopWatcher', () => {
|
||||
it('should mark initializing watcher as stopping', () => {
|
||||
manager.setInitializing('test-key');
|
||||
manager.stopWatcher('test-key');
|
||||
|
||||
const entry = manager.getEntry('test-key');
|
||||
expect(entry?.state).toBe(WatcherState.STOPPING);
|
||||
});
|
||||
|
||||
it('should close and remove active watcher', () => {
|
||||
manager.setActive('test-key', mockWatcher as any, 0);
|
||||
manager.stopWatcher('test-key');
|
||||
|
||||
expect(mockWatcher.close).toHaveBeenCalled();
|
||||
expect(manager.getEntry('test-key')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should do nothing for non-existent watcher', () => {
|
||||
manager.stopWatcher('test-key');
|
||||
expect(mockWatcher.close).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('position management', () => {
|
||||
it('should update position for active watcher', () => {
|
||||
manager.setActive('test-key', mockWatcher as any, 100);
|
||||
manager.updatePosition('test-key', 200);
|
||||
|
||||
const position = manager.getPosition('test-key');
|
||||
expect(position).toBe(200);
|
||||
});
|
||||
|
||||
it('should not update position for non-active watcher', () => {
|
||||
manager.setInitializing('test-key');
|
||||
manager.updatePosition('test-key', 200);
|
||||
|
||||
const position = manager.getPosition('test-key');
|
||||
expect(position).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should get position for active watcher', () => {
|
||||
manager.setActive('test-key', mockWatcher as any, 300);
|
||||
expect(manager.getPosition('test-key')).toBe(300);
|
||||
});
|
||||
|
||||
it('should return undefined for non-active watcher', () => {
|
||||
manager.setStopping('test-key');
|
||||
expect(manager.getPosition('test-key')).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('stopAllWatchers', () => {
|
||||
it('should close all active watchers and clear map', () => {
|
||||
const mockWatcher1 = { close: vi.fn() };
|
||||
const mockWatcher2 = { close: vi.fn() };
|
||||
const mockWatcher3 = { close: vi.fn() };
|
||||
|
||||
manager.setActive('key1', mockWatcher1 as any, 0);
|
||||
manager.setInitializing('key2');
|
||||
manager.setActive('key3', mockWatcher2 as any, 0);
|
||||
manager.setStopping('key4');
|
||||
manager.setActive('key5', mockWatcher3 as any, 0);
|
||||
|
||||
manager.stopAllWatchers();
|
||||
|
||||
expect(mockWatcher1.close).toHaveBeenCalled();
|
||||
expect(mockWatcher2.close).toHaveBeenCalled();
|
||||
expect(mockWatcher3.close).toHaveBeenCalled();
|
||||
|
||||
expect(manager.getEntry('key1')).toBeUndefined();
|
||||
expect(manager.getEntry('key2')).toBeUndefined();
|
||||
expect(manager.getEntry('key3')).toBeUndefined();
|
||||
expect(manager.getEntry('key4')).toBeUndefined();
|
||||
expect(manager.getEntry('key5')).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('in-flight processing', () => {
|
||||
it('should prevent concurrent processing', () => {
|
||||
manager.setActive('test-key', mockWatcher as any, 0);
|
||||
|
||||
// First call should succeed
|
||||
expect(manager.startProcessing('test-key')).toBe(true);
|
||||
|
||||
// Second call should fail (already in flight)
|
||||
expect(manager.startProcessing('test-key')).toBe(false);
|
||||
|
||||
// After finishing, should be able to start again
|
||||
manager.finishProcessing('test-key');
|
||||
expect(manager.startProcessing('test-key')).toBe(true);
|
||||
});
|
||||
|
||||
it('should not start processing for non-active watcher', () => {
|
||||
manager.setInitializing('test-key');
|
||||
expect(manager.startProcessing('test-key')).toBe(false);
|
||||
|
||||
manager.setStopping('test-key');
|
||||
expect(manager.startProcessing('test-key')).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle finish processing for non-existent watcher', () => {
|
||||
// Should not throw
|
||||
expect(() => manager.finishProcessing('non-existent')).not.toThrow();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,183 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import * as chokidar from 'chokidar';
|
||||
|
||||
export enum WatcherState {
|
||||
INITIALIZING = 'initializing',
|
||||
ACTIVE = 'active',
|
||||
STOPPING = 'stopping',
|
||||
}
|
||||
|
||||
export type WatcherEntry =
|
||||
| { state: WatcherState.INITIALIZING }
|
||||
| { state: WatcherState.ACTIVE; watcher: chokidar.FSWatcher; position: number; inFlight: boolean }
|
||||
| { state: WatcherState.STOPPING };
|
||||
|
||||
/**
|
||||
* Service responsible for managing log file watchers and their lifecycle.
|
||||
* Handles race conditions during watcher initialization and cleanup.
|
||||
*/
|
||||
@Injectable()
|
||||
export class LogWatcherManager {
|
||||
private readonly logger = new Logger(LogWatcherManager.name);
|
||||
private readonly watchers = new Map<string, WatcherEntry>();
|
||||
|
||||
/**
|
||||
* Set a watcher as initializing
|
||||
*/
|
||||
setInitializing(key: string): void {
|
||||
this.watchers.set(key, { state: WatcherState.INITIALIZING });
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a watcher as active with its FSWatcher and position
|
||||
*/
|
||||
setActive(key: string, watcher: chokidar.FSWatcher, position: number): void {
|
||||
this.watchers.set(key, { state: WatcherState.ACTIVE, watcher, position, inFlight: false });
|
||||
}
|
||||
|
||||
/**
|
||||
* Mark a watcher as stopping (used during initialization race conditions)
|
||||
*/
|
||||
setStopping(key: string): void {
|
||||
this.watchers.set(key, { state: WatcherState.STOPPING });
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a watcher entry by key
|
||||
*/
|
||||
getEntry(key: string): WatcherEntry | undefined {
|
||||
return this.watchers.get(key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a watcher entry
|
||||
*/
|
||||
removeEntry(key: string): void {
|
||||
this.watchers.delete(key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a watcher is active and return typed entry
|
||||
*/
|
||||
isActive(entry: WatcherEntry | undefined): entry is {
|
||||
state: WatcherState.ACTIVE;
|
||||
watcher: chokidar.FSWatcher;
|
||||
position: number;
|
||||
inFlight: boolean;
|
||||
} {
|
||||
return entry?.state === WatcherState.ACTIVE;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a watcher exists and is either initializing or active
|
||||
*/
|
||||
isWatchingOrInitializing(key: string): boolean {
|
||||
const entry = this.getEntry(key);
|
||||
return (
|
||||
entry !== undefined &&
|
||||
(entry.state === WatcherState.ACTIVE || entry.state === WatcherState.INITIALIZING)
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle cleanup after initialization completes.
|
||||
* Returns true if the watcher should continue, false if it should be cleaned up.
|
||||
*/
|
||||
handlePostInitialization(key: string, watcher: chokidar.FSWatcher, position: number): boolean {
|
||||
const currentEntry = this.getEntry(key);
|
||||
|
||||
if (!currentEntry || currentEntry.state === WatcherState.STOPPING) {
|
||||
// We were stopped during initialization, clean up immediately
|
||||
this.logger.debug(`Watcher for ${key} was stopped during initialization, cleaning up`);
|
||||
watcher.close();
|
||||
this.removeEntry(key);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Store the active watcher and position
|
||||
this.setActive(key, watcher, position);
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop a watcher, handling all possible states
|
||||
*/
|
||||
stopWatcher(key: string): void {
|
||||
const entry = this.getEntry(key);
|
||||
|
||||
if (!entry) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (entry.state === WatcherState.INITIALIZING) {
|
||||
// Mark as stopping so the initialization will clean up
|
||||
this.setStopping(key);
|
||||
this.logger.debug(`Marked watcher as stopping during initialization: ${key}`);
|
||||
} else if (entry.state === WatcherState.ACTIVE) {
|
||||
// Close the active watcher
|
||||
entry.watcher.close();
|
||||
this.removeEntry(key);
|
||||
this.logger.debug(`Stopped active watcher: ${key}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the position for an active watcher
|
||||
*/
|
||||
updatePosition(key: string, newPosition: number): void {
|
||||
const entry = this.getEntry(key);
|
||||
if (this.isActive(entry)) {
|
||||
entry.position = newPosition;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Start processing a change event (set inFlight to true)
|
||||
* Returns true if processing can proceed, false if already in flight
|
||||
*/
|
||||
startProcessing(key: string): boolean {
|
||||
const entry = this.getEntry(key);
|
||||
if (this.isActive(entry)) {
|
||||
if (entry.inFlight) {
|
||||
return false; // Already processing
|
||||
}
|
||||
entry.inFlight = true;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Finish processing a change event (set inFlight to false)
|
||||
*/
|
||||
finishProcessing(key: string): void {
|
||||
const entry = this.getEntry(key);
|
||||
if (this.isActive(entry)) {
|
||||
entry.inFlight = false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the position for an active watcher
|
||||
*/
|
||||
getPosition(key: string): number | undefined {
|
||||
const entry = this.getEntry(key);
|
||||
if (this.isActive(entry)) {
|
||||
return entry.position;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up all watchers (useful for module cleanup)
|
||||
*/
|
||||
stopAllWatchers(): void {
|
||||
for (const entry of this.watchers.values()) {
|
||||
if (this.isActive(entry)) {
|
||||
entry.watcher.close();
|
||||
}
|
||||
}
|
||||
this.watchers.clear();
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,13 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { LogWatcherManager } from '@app/unraid-api/graph/resolvers/logs/log-watcher-manager.service.js';
|
||||
import { LogsResolver } from '@app/unraid-api/graph/resolvers/logs/logs.resolver.js';
|
||||
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
|
||||
import { ServicesModule } from '@app/unraid-api/graph/services/services.module.js';
|
||||
|
||||
@Module({
|
||||
providers: [LogsResolver, LogsService],
|
||||
exports: [LogsService],
|
||||
imports: [ServicesModule],
|
||||
providers: [LogsResolver, LogsService, LogWatcherManager],
|
||||
exports: [LogsService, LogWatcherManager],
|
||||
})
|
||||
export class LogsModule {}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { beforeEach, describe, expect, it } from 'vitest';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { LogsResolver } from '@app/unraid-api/graph/resolvers/logs/logs.resolver.js';
|
||||
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
|
||||
import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js';
|
||||
|
||||
describe('LogsResolver', () => {
|
||||
let resolver: LogsResolver;
|
||||
@@ -18,6 +19,13 @@ describe('LogsResolver', () => {
|
||||
// Add mock implementations for service methods used by resolver
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: SubscriptionHelperService,
|
||||
useValue: {
|
||||
// Add mock implementations for subscription helper methods
|
||||
createTrackedSubscription: vi.fn(),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
resolver = module.get<LogsResolver>(LogsResolver);
|
||||
|
||||
@@ -3,13 +3,16 @@ import { Args, Int, Query, Resolver, Subscription } from '@nestjs/graphql';
|
||||
import { AuthAction, Resource } from '@unraid/shared/graphql.model.js';
|
||||
import { UsePermissions } from '@unraid/shared/use-permissions.directive.js';
|
||||
|
||||
import { createSubscription, PUBSUB_CHANNEL } from '@app/core/pubsub.js';
|
||||
import { LogFile, LogFileContent } from '@app/unraid-api/graph/resolvers/logs/logs.model.js';
|
||||
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
|
||||
import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js';
|
||||
|
||||
@Resolver(() => LogFile)
|
||||
export class LogsResolver {
|
||||
constructor(private readonly logsService: LogsService) {}
|
||||
constructor(
|
||||
private readonly logsService: LogsService,
|
||||
private readonly subscriptionHelper: SubscriptionHelperService
|
||||
) {}
|
||||
|
||||
@Query(() => [LogFile])
|
||||
@UsePermissions({
|
||||
@@ -38,27 +41,12 @@ export class LogsResolver {
|
||||
action: AuthAction.READ_ANY,
|
||||
resource: Resource.LOGS,
|
||||
})
|
||||
async logFileSubscription(@Args('path') path: string) {
|
||||
// Start watching the file
|
||||
this.logsService.getLogFileSubscriptionChannel(path);
|
||||
logFileSubscription(@Args('path') path: string) {
|
||||
// Register the topic and get the key
|
||||
const topicKey = this.logsService.registerLogFileSubscription(path);
|
||||
|
||||
// Create the async iterator
|
||||
const asyncIterator = createSubscription(PUBSUB_CHANNEL.LOG_FILE);
|
||||
|
||||
// Store the original return method to wrap it
|
||||
const originalReturn = asyncIterator.return;
|
||||
|
||||
// Override the return method to clean up resources
|
||||
asyncIterator.return = async () => {
|
||||
// Stop watching the file when subscription ends
|
||||
this.logsService.stopWatchingLogFile(path);
|
||||
|
||||
// Call the original return method
|
||||
return originalReturn
|
||||
? originalReturn.call(asyncIterator)
|
||||
: Promise.resolve({ value: undefined, done: true });
|
||||
};
|
||||
|
||||
return asyncIterator;
|
||||
// Use the helper service to create a tracked subscription
|
||||
// This automatically handles subscribe/unsubscribe with reference counting
|
||||
return this.subscriptionHelper.createTrackedSubscription(topicKey);
|
||||
}
|
||||
}
|
||||
|
||||
201
api/src/unraid-api/graph/resolvers/logs/logs.service.spec.ts
Normal file
201
api/src/unraid-api/graph/resolvers/logs/logs.service.spec.ts
Normal file
@@ -0,0 +1,201 @@
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
import * as fs from 'node:fs/promises';
|
||||
|
||||
import * as chokidar from 'chokidar';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { LogWatcherManager } from '@app/unraid-api/graph/resolvers/logs/log-watcher-manager.service.js';
|
||||
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
|
||||
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
|
||||
|
||||
vi.mock('node:fs/promises');
|
||||
vi.mock('chokidar');
|
||||
vi.mock('@app/store/index.js', () => ({
|
||||
getters: {
|
||||
paths: () => ({
|
||||
'unraid-log-base': '/var/log',
|
||||
}),
|
||||
},
|
||||
}));
|
||||
vi.mock('@app/core/pubsub.js', () => ({
|
||||
pubsub: {
|
||||
publish: vi.fn(),
|
||||
},
|
||||
PUBSUB_CHANNEL: {},
|
||||
}));
|
||||
|
||||
describe('LogsService', () => {
|
||||
let service: LogsService;
|
||||
let mockWatcher: any;
|
||||
let subscriptionTracker: any;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create a mock watcher
|
||||
mockWatcher = {
|
||||
on: vi.fn(),
|
||||
close: vi.fn(),
|
||||
};
|
||||
|
||||
// Mock chokidar.watch to return our mock watcher
|
||||
vi.mocked(chokidar.watch).mockReturnValue(mockWatcher as any);
|
||||
|
||||
// Mock fs.stat to return a file size
|
||||
vi.mocked(fs.stat).mockResolvedValue({ size: 1000 } as any);
|
||||
|
||||
subscriptionTracker = {
|
||||
getSubscriberCount: vi.fn().mockReturnValue(0),
|
||||
registerTopic: vi.fn(),
|
||||
};
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
LogsService,
|
||||
LogWatcherManager,
|
||||
{
|
||||
provide: SubscriptionTrackerService,
|
||||
useValue: subscriptionTracker,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
service = module.get<LogsService>(LogsService);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should be defined', () => {
|
||||
expect(service).toBeDefined();
|
||||
});
|
||||
|
||||
it('should handle race condition when stopping watcher during initialization', async () => {
|
||||
// Setup: Register the subscription which will trigger registerTopic
|
||||
service.registerLogFileSubscription('test.log');
|
||||
|
||||
// Get the onStart callback that was registered
|
||||
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
|
||||
const onStartCallback = registerTopicCall[1];
|
||||
const onStopCallback = registerTopicCall[2];
|
||||
|
||||
// Create a promise to control when stat resolves
|
||||
let statResolve: any;
|
||||
const statPromise = new Promise((resolve) => {
|
||||
statResolve = resolve;
|
||||
});
|
||||
vi.mocked(fs.stat).mockReturnValue(statPromise as any);
|
||||
|
||||
// Start the watcher (this will call startWatchingLogFile internally)
|
||||
onStartCallback();
|
||||
|
||||
// At this point, the watcher should be marked as 'initializing'
|
||||
// Now call stop before the stat promise resolves
|
||||
onStopCallback();
|
||||
|
||||
// Now resolve the stat promise to complete initialization
|
||||
statResolve({ size: 1000 });
|
||||
|
||||
// Wait for any async operations to complete
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
|
||||
// The watcher should have been closed due to the race condition check
|
||||
expect(mockWatcher.close).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not leak watcher if stopped multiple times during initialization', async () => {
|
||||
// Setup: Register the subscription
|
||||
service.registerLogFileSubscription('test.log');
|
||||
|
||||
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
|
||||
const onStartCallback = registerTopicCall[1];
|
||||
const onStopCallback = registerTopicCall[2];
|
||||
|
||||
// Create controlled stat promise
|
||||
let statResolve: any;
|
||||
const statPromise = new Promise((resolve) => {
|
||||
statResolve = resolve;
|
||||
});
|
||||
vi.mocked(fs.stat).mockReturnValue(statPromise as any);
|
||||
|
||||
// Start the watcher
|
||||
onStartCallback();
|
||||
|
||||
// Call stop multiple times during initialization
|
||||
onStopCallback();
|
||||
onStopCallback();
|
||||
onStopCallback();
|
||||
|
||||
// Complete initialization
|
||||
statResolve({ size: 1000 });
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
|
||||
// Should only close once
|
||||
expect(mockWatcher.close).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should properly handle normal start and stop without race condition', async () => {
|
||||
// Setup: Register the subscription
|
||||
service.registerLogFileSubscription('test.log');
|
||||
|
||||
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
|
||||
const onStartCallback = registerTopicCall[1];
|
||||
const onStopCallback = registerTopicCall[2];
|
||||
|
||||
// Make stat resolve immediately
|
||||
vi.mocked(fs.stat).mockResolvedValue({ size: 1000 } as any);
|
||||
|
||||
// Start the watcher and let it complete initialization
|
||||
onStartCallback();
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
|
||||
// Watcher should be created but not closed
|
||||
expect(chokidar.watch).toHaveBeenCalled();
|
||||
expect(mockWatcher.close).not.toHaveBeenCalled();
|
||||
|
||||
// Now stop it normally
|
||||
onStopCallback();
|
||||
|
||||
// Watcher should be closed
|
||||
expect(mockWatcher.close).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should handle error during initialization without leaking watchers', async () => {
|
||||
// Setup: Register the subscription
|
||||
service.registerLogFileSubscription('test.log');
|
||||
|
||||
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
|
||||
const onStartCallback = registerTopicCall[1];
|
||||
|
||||
// Make stat reject with an error
|
||||
vi.mocked(fs.stat).mockRejectedValue(new Error('File not found'));
|
||||
|
||||
// Start the watcher (should fail during initialization)
|
||||
onStartCallback();
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
|
||||
// Watcher should never be created due to stat error
|
||||
expect(chokidar.watch).not.toHaveBeenCalled();
|
||||
expect(mockWatcher.close).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not create duplicate watchers when started multiple times', async () => {
|
||||
// Setup: Register the subscription
|
||||
service.registerLogFileSubscription('test.log');
|
||||
|
||||
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
|
||||
const onStartCallback = registerTopicCall[1];
|
||||
|
||||
// Make stat resolve immediately
|
||||
vi.mocked(fs.stat).mockResolvedValue({ size: 1000 } as any);
|
||||
|
||||
// Start the watcher multiple times
|
||||
onStartCallback();
|
||||
onStartCallback();
|
||||
onStartCallback();
|
||||
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
|
||||
// Should only create one watcher
|
||||
expect(chokidar.watch).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
@@ -1,13 +1,15 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { createReadStream } from 'node:fs';
|
||||
import { readdir, readFile, stat } from 'node:fs/promises';
|
||||
import { readdir, stat } from 'node:fs/promises';
|
||||
import { basename, join } from 'node:path';
|
||||
import { createInterface } from 'node:readline';
|
||||
|
||||
import * as chokidar from 'chokidar';
|
||||
|
||||
import { pubsub, PUBSUB_CHANNEL } from '@app/core/pubsub.js';
|
||||
import { pubsub } from '@app/core/pubsub.js';
|
||||
import { getters } from '@app/store/index.js';
|
||||
import { LogWatcherManager } from '@app/unraid-api/graph/resolvers/logs/log-watcher-manager.service.js';
|
||||
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
|
||||
|
||||
interface LogFile {
|
||||
name: string;
|
||||
@@ -26,12 +28,13 @@ interface LogFileContent {
|
||||
@Injectable()
|
||||
export class LogsService {
|
||||
private readonly logger = new Logger(LogsService.name);
|
||||
private readonly logWatchers = new Map<
|
||||
string,
|
||||
{ watcher: chokidar.FSWatcher; position: number; subscriptionCount: number }
|
||||
>();
|
||||
private readonly DEFAULT_LINES = 100;
|
||||
|
||||
constructor(
|
||||
private readonly subscriptionTracker: SubscriptionTrackerService,
|
||||
private readonly watcherManager: LogWatcherManager
|
||||
) {}
|
||||
|
||||
/**
|
||||
* Get the base path for log files
|
||||
*/
|
||||
@@ -111,135 +114,208 @@ export class LogsService {
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the subscription channel for a log file
|
||||
* Register and get the topic key for a log file subscription
|
||||
* @param path Path to the log file
|
||||
* @returns The subscription topic key
|
||||
*/
|
||||
getLogFileSubscriptionChannel(path: string): PUBSUB_CHANNEL {
|
||||
registerLogFileSubscription(path: string): string {
|
||||
const normalizedPath = join(this.logBasePath, basename(path));
|
||||
const topicKey = this.getTopicKey(normalizedPath);
|
||||
|
||||
// Start watching the file if not already watching
|
||||
if (!this.logWatchers.has(normalizedPath)) {
|
||||
this.startWatchingLogFile(normalizedPath);
|
||||
} else {
|
||||
// Increment subscription count for existing watcher
|
||||
const watcher = this.logWatchers.get(normalizedPath);
|
||||
if (watcher) {
|
||||
watcher.subscriptionCount++;
|
||||
this.logger.debug(
|
||||
`Incremented subscription count for ${normalizedPath} to ${watcher.subscriptionCount}`
|
||||
);
|
||||
}
|
||||
// Register the topic if not already registered
|
||||
if (!this.subscriptionTracker.getSubscriberCount(topicKey)) {
|
||||
this.logger.debug(`Registering log file subscription topic: ${topicKey}`);
|
||||
|
||||
this.subscriptionTracker.registerTopic(
|
||||
topicKey,
|
||||
// onStart handler
|
||||
() => {
|
||||
this.logger.debug(`Starting log file watcher for topic: ${topicKey}`);
|
||||
this.startWatchingLogFile(normalizedPath);
|
||||
},
|
||||
// onStop handler
|
||||
() => {
|
||||
this.logger.debug(`Stopping log file watcher for topic: ${topicKey}`);
|
||||
this.stopWatchingLogFile(normalizedPath);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return PUBSUB_CHANNEL.LOG_FILE;
|
||||
return topicKey;
|
||||
}
|
||||
|
||||
/**
|
||||
* Start watching a log file for changes using chokidar
|
||||
* @param path Path to the log file
|
||||
*/
|
||||
private async startWatchingLogFile(path: string): Promise<void> {
|
||||
try {
|
||||
// Get initial file size
|
||||
const stats = await stat(path);
|
||||
let position = stats.size;
|
||||
private startWatchingLogFile(path: string): void {
|
||||
const watcherKey = path;
|
||||
|
||||
// Create a watcher for the file using chokidar
|
||||
const watcher = chokidar.watch(path, {
|
||||
persistent: true,
|
||||
awaitWriteFinish: {
|
||||
stabilityThreshold: 300,
|
||||
pollInterval: 100,
|
||||
},
|
||||
});
|
||||
// Check if already watching or initializing
|
||||
if (this.watcherManager.isWatchingOrInitializing(watcherKey)) {
|
||||
this.logger.debug(`Already watching or initializing log file: ${watcherKey}`);
|
||||
return;
|
||||
}
|
||||
|
||||
watcher.on('change', async () => {
|
||||
try {
|
||||
const newStats = await stat(path);
|
||||
// Mark as initializing immediately to prevent race conditions
|
||||
this.watcherManager.setInitializing(watcherKey);
|
||||
|
||||
// If the file has grown
|
||||
if (newStats.size > position) {
|
||||
// Read only the new content
|
||||
const stream = createReadStream(path, {
|
||||
start: position,
|
||||
end: newStats.size - 1,
|
||||
});
|
||||
// Get initial file size and set up watcher
|
||||
stat(path)
|
||||
.then((stats) => {
|
||||
const position = stats.size;
|
||||
|
||||
let newContent = '';
|
||||
stream.on('data', (chunk) => {
|
||||
newContent += chunk.toString();
|
||||
});
|
||||
// Create a watcher for the file using chokidar
|
||||
const watcher = chokidar.watch(path, {
|
||||
persistent: true,
|
||||
awaitWriteFinish: {
|
||||
stabilityThreshold: 300,
|
||||
pollInterval: 100,
|
||||
},
|
||||
});
|
||||
|
||||
stream.on('end', () => {
|
||||
if (newContent) {
|
||||
pubsub.publish(PUBSUB_CHANNEL.LOG_FILE, {
|
||||
watcher.on('change', async () => {
|
||||
// Check if we're already processing a change event for this file
|
||||
if (!this.watcherManager.startProcessing(watcherKey)) {
|
||||
// Already processing, ignore this event
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const newStats = await stat(path);
|
||||
|
||||
// Get the current position
|
||||
const currentPosition = this.watcherManager.getPosition(watcherKey);
|
||||
if (currentPosition === undefined) {
|
||||
// Watcher was stopped or not active, ignore the event
|
||||
return;
|
||||
}
|
||||
|
||||
// If the file has grown
|
||||
if (newStats.size > currentPosition) {
|
||||
// Read only the new content
|
||||
const stream = createReadStream(path, {
|
||||
start: currentPosition,
|
||||
end: newStats.size - 1,
|
||||
});
|
||||
|
||||
let newContent = '';
|
||||
stream.on('data', (chunk) => {
|
||||
newContent += chunk.toString();
|
||||
});
|
||||
|
||||
stream.on('end', () => {
|
||||
try {
|
||||
if (newContent) {
|
||||
// Use topic-specific channel
|
||||
const topicKey = this.getTopicKey(path);
|
||||
pubsub.publish(topicKey, {
|
||||
logFile: {
|
||||
path,
|
||||
content: newContent,
|
||||
totalLines: 0, // We don't need to count lines for updates
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Update position for next read (while still holding the guard)
|
||||
this.watcherManager.updatePosition(watcherKey, newStats.size);
|
||||
} finally {
|
||||
// Clear the in-flight flag
|
||||
this.watcherManager.finishProcessing(watcherKey);
|
||||
}
|
||||
});
|
||||
|
||||
stream.on('error', (error) => {
|
||||
this.logger.error(`Error reading stream for ${path}: ${error}`);
|
||||
// Clear the in-flight flag on error
|
||||
this.watcherManager.finishProcessing(watcherKey);
|
||||
});
|
||||
} else if (newStats.size < currentPosition) {
|
||||
// File was truncated, reset position and read from beginning
|
||||
this.logger.debug(`File ${path} was truncated, resetting position`);
|
||||
|
||||
try {
|
||||
// Read the entire file content
|
||||
const content = await this.getLogFileContent(
|
||||
path,
|
||||
this.DEFAULT_LINES,
|
||||
undefined
|
||||
);
|
||||
|
||||
// Use topic-specific channel
|
||||
const topicKey = this.getTopicKey(path);
|
||||
pubsub.publish(topicKey, {
|
||||
logFile: {
|
||||
path,
|
||||
content: newContent,
|
||||
totalLines: 0, // We don't need to count lines for updates
|
||||
...content,
|
||||
},
|
||||
});
|
||||
|
||||
// Update position (while still holding the guard)
|
||||
this.watcherManager.updatePosition(watcherKey, newStats.size);
|
||||
} finally {
|
||||
// Clear the in-flight flag
|
||||
this.watcherManager.finishProcessing(watcherKey);
|
||||
}
|
||||
|
||||
// Update position for next read
|
||||
position = newStats.size;
|
||||
});
|
||||
} else if (newStats.size < position) {
|
||||
// File was truncated, reset position and read from beginning
|
||||
position = 0;
|
||||
this.logger.debug(`File ${path} was truncated, resetting position`);
|
||||
|
||||
// Read the entire file content
|
||||
const content = await this.getLogFileContent(path);
|
||||
|
||||
pubsub.publish(PUBSUB_CHANNEL.LOG_FILE, {
|
||||
logFile: content,
|
||||
});
|
||||
|
||||
position = newStats.size;
|
||||
} else {
|
||||
// File size unchanged, clear the in-flight flag
|
||||
this.watcherManager.finishProcessing(watcherKey);
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
this.logger.error(`Error processing file change for ${path}: ${error}`);
|
||||
// Clear the in-flight flag on error
|
||||
this.watcherManager.finishProcessing(watcherKey);
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
this.logger.error(`Error processing file change for ${path}: ${error}`);
|
||||
});
|
||||
|
||||
watcher.on('error', (error) => {
|
||||
this.logger.error(`Chokidar watcher error for ${path}: ${error}`);
|
||||
});
|
||||
|
||||
// Check if we were stopped during initialization and handle cleanup
|
||||
if (!this.watcherManager.handlePostInitialization(watcherKey, watcher, position)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Publish initial snapshot
|
||||
this.getLogFileContent(path, this.DEFAULT_LINES, undefined)
|
||||
.then((content) => {
|
||||
const topicKey = this.getTopicKey(path);
|
||||
pubsub.publish(topicKey, {
|
||||
logFile: {
|
||||
...content,
|
||||
},
|
||||
});
|
||||
})
|
||||
.catch((error) => {
|
||||
this.logger.error(`Error publishing initial log content for ${path}: ${error}`);
|
||||
});
|
||||
|
||||
this.logger.debug(`Started watching log file with chokidar: ${path}`);
|
||||
})
|
||||
.catch((error) => {
|
||||
this.logger.error(`Error setting up file watcher for ${path}: ${error}`);
|
||||
// Clean up the initializing entry on error
|
||||
this.watcherManager.removeEntry(watcherKey);
|
||||
});
|
||||
}
|
||||
|
||||
watcher.on('error', (error) => {
|
||||
this.logger.error(`Chokidar watcher error for ${path}: ${error}`);
|
||||
});
|
||||
|
||||
// Store the watcher and current position with initial subscription count of 1
|
||||
this.logWatchers.set(path, { watcher, position, subscriptionCount: 1 });
|
||||
|
||||
this.logger.debug(
|
||||
`Started watching log file with chokidar: ${path} (subscription count: 1)`
|
||||
);
|
||||
} catch (error: unknown) {
|
||||
this.logger.error(`Error setting up chokidar file watcher for ${path}: ${error}`);
|
||||
}
|
||||
/**
|
||||
* Get the topic key for a log file subscription
|
||||
* @param path Path to the log file (should already be normalized)
|
||||
* @returns The topic key
|
||||
*/
|
||||
private getTopicKey(path: string): string {
|
||||
// Assume path is already normalized (full path)
|
||||
return `LOG_FILE:${path}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop watching a log file
|
||||
* @param path Path to the log file
|
||||
*/
|
||||
public stopWatchingLogFile(path: string): void {
|
||||
const normalizedPath = join(this.logBasePath, basename(path));
|
||||
const watcher = this.logWatchers.get(normalizedPath);
|
||||
|
||||
if (watcher) {
|
||||
// Decrement subscription count
|
||||
watcher.subscriptionCount--;
|
||||
this.logger.debug(
|
||||
`Decremented subscription count for ${normalizedPath} to ${watcher.subscriptionCount}`
|
||||
);
|
||||
|
||||
// Only close the watcher when subscription count reaches 0
|
||||
if (watcher.subscriptionCount <= 0) {
|
||||
watcher.watcher.close();
|
||||
this.logWatchers.delete(normalizedPath);
|
||||
this.logger.debug(`Stopped watching log file: ${normalizedPath} (no more subscribers)`);
|
||||
}
|
||||
}
|
||||
private stopWatchingLogFile(path: string): void {
|
||||
this.watcherManager.stopWatcher(path);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -9,7 +9,7 @@ import { CpuService } from '@app/unraid-api/graph/resolvers/info/cpu/cpu.service
|
||||
import { MemoryService } from '@app/unraid-api/graph/resolvers/info/memory/memory.service.js';
|
||||
import { MetricsResolver } from '@app/unraid-api/graph/resolvers/metrics/metrics.resolver.js';
|
||||
import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js';
|
||||
import { SubscriptionPollingService } from '@app/unraid-api/graph/services/subscription-polling.service.js';
|
||||
import { SubscriptionManagerService } from '@app/unraid-api/graph/services/subscription-manager.service.js';
|
||||
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
|
||||
|
||||
describe('MetricsResolver Integration Tests', () => {
|
||||
@@ -25,7 +25,7 @@ describe('MetricsResolver Integration Tests', () => {
|
||||
MemoryService,
|
||||
SubscriptionTrackerService,
|
||||
SubscriptionHelperService,
|
||||
SubscriptionPollingService,
|
||||
SubscriptionManagerService,
|
||||
],
|
||||
}).compile();
|
||||
|
||||
@@ -36,8 +36,8 @@ describe('MetricsResolver Integration Tests', () => {
|
||||
|
||||
afterEach(async () => {
|
||||
// Clean up polling service
|
||||
const pollingService = module.get<SubscriptionPollingService>(SubscriptionPollingService);
|
||||
pollingService.stopAll();
|
||||
const subscriptionManager = module.get<SubscriptionManagerService>(SubscriptionManagerService);
|
||||
subscriptionManager.stopAll();
|
||||
await module.close();
|
||||
});
|
||||
|
||||
@@ -202,10 +202,13 @@ describe('MetricsResolver Integration Tests', () => {
|
||||
it('should handle errors in CPU polling gracefully', async () => {
|
||||
const service = module.get<CpuService>(CpuService);
|
||||
const trackerService = module.get<SubscriptionTrackerService>(SubscriptionTrackerService);
|
||||
const pollingService = module.get<SubscriptionPollingService>(SubscriptionPollingService);
|
||||
const subscriptionManager =
|
||||
module.get<SubscriptionManagerService>(SubscriptionManagerService);
|
||||
|
||||
// Mock logger to capture error logs
|
||||
const loggerSpy = vi.spyOn(pollingService['logger'], 'error').mockImplementation(() => {});
|
||||
const loggerSpy = vi
|
||||
.spyOn(subscriptionManager['logger'], 'error')
|
||||
.mockImplementation(() => {});
|
||||
vi.spyOn(service, 'generateCpuLoad').mockRejectedValueOnce(new Error('CPU error'));
|
||||
|
||||
// Trigger polling
|
||||
@@ -215,7 +218,7 @@ describe('MetricsResolver Integration Tests', () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 1100));
|
||||
|
||||
expect(loggerSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Error in polling task'),
|
||||
expect.stringContaining('Error in subscription callback'),
|
||||
expect.any(Error)
|
||||
);
|
||||
|
||||
@@ -226,10 +229,13 @@ describe('MetricsResolver Integration Tests', () => {
|
||||
it('should handle errors in memory polling gracefully', async () => {
|
||||
const service = module.get<MemoryService>(MemoryService);
|
||||
const trackerService = module.get<SubscriptionTrackerService>(SubscriptionTrackerService);
|
||||
const pollingService = module.get<SubscriptionPollingService>(SubscriptionPollingService);
|
||||
const subscriptionManager =
|
||||
module.get<SubscriptionManagerService>(SubscriptionManagerService);
|
||||
|
||||
// Mock logger to capture error logs
|
||||
const loggerSpy = vi.spyOn(pollingService['logger'], 'error').mockImplementation(() => {});
|
||||
const loggerSpy = vi
|
||||
.spyOn(subscriptionManager['logger'], 'error')
|
||||
.mockImplementation(() => {});
|
||||
vi.spyOn(service, 'generateMemoryLoad').mockRejectedValueOnce(new Error('Memory error'));
|
||||
|
||||
// Trigger polling
|
||||
@@ -239,7 +245,7 @@ describe('MetricsResolver Integration Tests', () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 2100));
|
||||
|
||||
expect(loggerSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Error in polling task'),
|
||||
expect.stringContaining('Error in subscription callback'),
|
||||
expect.any(Error)
|
||||
);
|
||||
|
||||
@@ -251,22 +257,30 @@ describe('MetricsResolver Integration Tests', () => {
|
||||
describe('Polling cleanup on module destroy', () => {
|
||||
it('should clean up timers when module is destroyed', async () => {
|
||||
const trackerService = module.get<SubscriptionTrackerService>(SubscriptionTrackerService);
|
||||
const pollingService = module.get<SubscriptionPollingService>(SubscriptionPollingService);
|
||||
const subscriptionManager =
|
||||
module.get<SubscriptionManagerService>(SubscriptionManagerService);
|
||||
|
||||
// Start polling
|
||||
trackerService.subscribe(PUBSUB_CHANNEL.CPU_UTILIZATION);
|
||||
trackerService.subscribe(PUBSUB_CHANNEL.MEMORY_UTILIZATION);
|
||||
|
||||
// Verify polling is active
|
||||
expect(pollingService.isPolling(PUBSUB_CHANNEL.CPU_UTILIZATION)).toBe(true);
|
||||
expect(pollingService.isPolling(PUBSUB_CHANNEL.MEMORY_UTILIZATION)).toBe(true);
|
||||
// Wait a bit for subscriptions to be fully set up
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Verify subscriptions are active
|
||||
expect(subscriptionManager.isSubscriptionActive(PUBSUB_CHANNEL.CPU_UTILIZATION)).toBe(true);
|
||||
expect(subscriptionManager.isSubscriptionActive(PUBSUB_CHANNEL.MEMORY_UTILIZATION)).toBe(
|
||||
true
|
||||
);
|
||||
|
||||
// Clean up the module
|
||||
await module.close();
|
||||
|
||||
// Timers should be cleaned up
|
||||
expect(pollingService.isPolling(PUBSUB_CHANNEL.CPU_UTILIZATION)).toBe(false);
|
||||
expect(pollingService.isPolling(PUBSUB_CHANNEL.MEMORY_UTILIZATION)).toBe(false);
|
||||
// Subscriptions should be cleaned up
|
||||
expect(subscriptionManager.isSubscriptionActive(PUBSUB_CHANNEL.CPU_UTILIZATION)).toBe(false);
|
||||
expect(subscriptionManager.isSubscriptionActive(PUBSUB_CHANNEL.MEMORY_UTILIZATION)).toBe(
|
||||
false
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -32,6 +32,8 @@ describe('MetricsResolver', () => {
|
||||
loadNice: 0,
|
||||
loadIdle: 70.0,
|
||||
loadIrq: 0,
|
||||
loadGuest: 0,
|
||||
loadSteal: 0,
|
||||
},
|
||||
{
|
||||
load: 21.0,
|
||||
@@ -40,6 +42,8 @@ describe('MetricsResolver', () => {
|
||||
loadNice: 0,
|
||||
loadIdle: 79.0,
|
||||
loadIrq: 0,
|
||||
loadGuest: 0,
|
||||
loadSteal: 0,
|
||||
},
|
||||
],
|
||||
}),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { AuthModule } from '@app/unraid-api/auth/auth.module.js';
|
||||
import { ApiConfigModule } from '@app/unraid-api/config/api-config.module.js';
|
||||
import { ApiKeyModule } from '@app/unraid-api/graph/resolvers/api-key/api-key.module.js';
|
||||
import { ApiKeyResolver } from '@app/unraid-api/graph/resolvers/api-key/api-key.resolver.js';
|
||||
import { ArrayModule } from '@app/unraid-api/graph/resolvers/array/array.module.js';
|
||||
@@ -11,8 +12,7 @@ import { DockerModule } from '@app/unraid-api/graph/resolvers/docker/docker.modu
|
||||
import { FlashBackupModule } from '@app/unraid-api/graph/resolvers/flash-backup/flash-backup.module.js';
|
||||
import { FlashResolver } from '@app/unraid-api/graph/resolvers/flash/flash.resolver.js';
|
||||
import { InfoModule } from '@app/unraid-api/graph/resolvers/info/info.module.js';
|
||||
import { LogsResolver } from '@app/unraid-api/graph/resolvers/logs/logs.resolver.js';
|
||||
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
|
||||
import { LogsModule } from '@app/unraid-api/graph/resolvers/logs/logs.module.js';
|
||||
import { MetricsModule } from '@app/unraid-api/graph/resolvers/metrics/metrics.module.js';
|
||||
import { RootMutationsResolver } from '@app/unraid-api/graph/resolvers/mutation/mutation.resolver.js';
|
||||
import { NotificationsResolver } from '@app/unraid-api/graph/resolvers/notifications/notifications.resolver.js';
|
||||
@@ -39,12 +39,14 @@ import { MeResolver } from '@app/unraid-api/graph/user/user.resolver.js';
|
||||
ServicesModule,
|
||||
ArrayModule,
|
||||
ApiKeyModule,
|
||||
ApiConfigModule,
|
||||
AuthModule,
|
||||
CustomizationModule,
|
||||
DockerModule,
|
||||
DisksModule,
|
||||
FlashBackupModule,
|
||||
InfoModule,
|
||||
LogsModule,
|
||||
RCloneModule,
|
||||
SettingsModule,
|
||||
SsoModule,
|
||||
@@ -54,8 +56,6 @@ import { MeResolver } from '@app/unraid-api/graph/user/user.resolver.js';
|
||||
providers: [
|
||||
ConfigResolver,
|
||||
FlashResolver,
|
||||
LogsResolver,
|
||||
LogsService,
|
||||
MeResolver,
|
||||
NotificationsResolver,
|
||||
NotificationsService,
|
||||
|
||||
@@ -38,7 +38,9 @@ export class Server extends Node {
|
||||
@Field()
|
||||
name!: string;
|
||||
|
||||
@Field(() => ServerStatus)
|
||||
@Field(() => ServerStatus, {
|
||||
description: 'Whether this server is online or offline',
|
||||
})
|
||||
status!: ServerStatus;
|
||||
|
||||
@Field()
|
||||
|
||||
@@ -24,7 +24,7 @@ export class ServerResolver {
|
||||
resource: Resource.SERVERS,
|
||||
})
|
||||
public async server(): Promise<ServerModel | null> {
|
||||
return this.getLocalServer()[0] || null;
|
||||
return this.getLocalServer() || null;
|
||||
}
|
||||
|
||||
@Query(() => [ServerModel])
|
||||
@@ -33,7 +33,7 @@ export class ServerResolver {
|
||||
resource: Resource.SERVERS,
|
||||
})
|
||||
public async servers(): Promise<ServerModel[]> {
|
||||
return this.getLocalServer();
|
||||
return [this.getLocalServer()];
|
||||
}
|
||||
|
||||
@Subscription(() => ServerModel)
|
||||
@@ -45,7 +45,7 @@ export class ServerResolver {
|
||||
return createSubscription(PUBSUB_CHANNEL.SERVERS);
|
||||
}
|
||||
|
||||
private getLocalServer(): ServerModel[] {
|
||||
private getLocalServer(): ServerModel {
|
||||
const emhttp = getters.emhttp();
|
||||
const connectConfig = this.configService.get('connect');
|
||||
|
||||
@@ -64,22 +64,17 @@ export class ServerResolver {
|
||||
avatar: '',
|
||||
};
|
||||
|
||||
return [
|
||||
{
|
||||
id: 'local',
|
||||
owner,
|
||||
guid: guid || '',
|
||||
apikey: connectConfig?.config?.apikey ?? '',
|
||||
name: name ?? 'Local Server',
|
||||
status:
|
||||
connectConfig?.mothership?.status === MinigraphStatus.CONNECTED
|
||||
? ServerStatus.ONLINE
|
||||
: ServerStatus.OFFLINE,
|
||||
wanip,
|
||||
lanip,
|
||||
localurl,
|
||||
remoteurl,
|
||||
},
|
||||
];
|
||||
return {
|
||||
id: 'local',
|
||||
owner,
|
||||
guid: guid || '',
|
||||
apikey: connectConfig?.config?.apikey ?? '',
|
||||
name: name ?? 'Local Server',
|
||||
status: ServerStatus.ONLINE,
|
||||
wanip,
|
||||
lanip,
|
||||
localurl,
|
||||
remoteurl,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,8 +16,8 @@ import {
|
||||
} from '@app/unraid-api/graph/resolvers/settings/settings.model.js';
|
||||
import { ApiSettings } from '@app/unraid-api/graph/resolvers/settings/settings.service.js';
|
||||
import { SsoSettings } from '@app/unraid-api/graph/resolvers/settings/sso-settings.model.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
|
||||
@Resolver(() => Settings)
|
||||
export class SettingsResolver {
|
||||
|
||||
@@ -7,7 +7,7 @@ import { type ApiConfig } from '@unraid/shared/services/api-config.js';
|
||||
import { UserSettingsService } from '@unraid/shared/services/user-settings.js';
|
||||
import { execa } from 'execa';
|
||||
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { createLabeledControl } from '@app/unraid-api/graph/utils/form-utils.js';
|
||||
import { SettingSlice } from '@app/unraid-api/types/json-forms.js';
|
||||
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
|
||||
import { OidcClaimsService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
|
||||
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
|
||||
|
||||
@Module({
|
||||
providers: [OidcAuthorizationService, OidcTokenExchangeService, OidcClaimsService],
|
||||
exports: [OidcAuthorizationService, OidcTokenExchangeService, OidcClaimsService],
|
||||
})
|
||||
export class OidcAuthModule {}
|
||||
@@ -1,70 +1,26 @@
|
||||
import { UnauthorizedException } from '@nestjs/common';
|
||||
import { ConfigService } from '@nestjs/config';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import * as client from 'openid-client';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcAuthService } from '@app/unraid-api/graph/resolvers/sso/oidc-auth.service.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
|
||||
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
|
||||
import {
|
||||
AuthorizationOperator,
|
||||
AuthorizationRuleMode,
|
||||
OidcAuthorizationRule,
|
||||
OidcProvider,
|
||||
} from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/oidc-state.service.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/oidc-validation.service.js';
|
||||
} from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
|
||||
describe('OidcAuthService', () => {
|
||||
let service: OidcAuthService;
|
||||
let oidcConfig: any;
|
||||
let sessionService: any;
|
||||
let configService: any;
|
||||
let stateService: any;
|
||||
let validationService: any;
|
||||
describe('OidcAuthorizationService', () => {
|
||||
let service: OidcAuthorizationService;
|
||||
let module: TestingModule;
|
||||
|
||||
beforeEach(async () => {
|
||||
module = await Test.createTestingModule({
|
||||
providers: [
|
||||
OidcAuthService,
|
||||
{
|
||||
provide: ConfigService,
|
||||
useValue: {
|
||||
get: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcConfigPersistence,
|
||||
useValue: {
|
||||
getProvider: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcSessionService,
|
||||
useValue: {
|
||||
createSession: vi.fn(),
|
||||
},
|
||||
},
|
||||
OidcStateService,
|
||||
{
|
||||
provide: OidcValidationService,
|
||||
useValue: {
|
||||
validateProvider: vi.fn(),
|
||||
performDiscovery: vi.fn(),
|
||||
},
|
||||
},
|
||||
],
|
||||
providers: [OidcAuthorizationService],
|
||||
}).compile();
|
||||
|
||||
service = module.get<OidcAuthService>(OidcAuthService);
|
||||
oidcConfig = module.get(OidcConfigPersistence);
|
||||
sessionService = module.get(OidcSessionService);
|
||||
configService = module.get(ConfigService);
|
||||
stateService = module.get(OidcStateService);
|
||||
validationService = module.get<OidcValidationService>(OidcValidationService);
|
||||
service = module.get<OidcAuthorizationService>(OidcAuthorizationService);
|
||||
});
|
||||
|
||||
describe('Authorization Rule Evaluation', () => {
|
||||
@@ -1189,467 +1145,4 @@ describe('OidcAuthService', () => {
|
||||
).resolves.toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Manual Configuration (No Discovery)', () => {
|
||||
it('should create manual configuration when discovery fails but manual endpoints are provided', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-provider',
|
||||
name: 'Manual Provider',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
issuer: 'https://manual.example.com',
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
jwksUri: 'https://manual.example.com/jwks',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
// Access the private method
|
||||
const getOrCreateConfig = async (provider: OidcProvider) => {
|
||||
return (service as any).getOrCreateConfig(provider);
|
||||
};
|
||||
|
||||
const config = await getOrCreateConfig(provider);
|
||||
|
||||
// Verify the configuration was created with the correct endpoints
|
||||
expect(config).toBeDefined();
|
||||
expect(config.serverMetadata().authorization_endpoint).toBe(
|
||||
'https://manual.example.com/auth'
|
||||
);
|
||||
expect(config.serverMetadata().token_endpoint).toBe('https://manual.example.com/token');
|
||||
expect(config.serverMetadata().jwks_uri).toBe('https://manual.example.com/jwks');
|
||||
expect(config.serverMetadata().issuer).toBe('https://manual.example.com');
|
||||
});
|
||||
|
||||
it('should create manual configuration with fallback issuer when not provided', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-provider-no-issuer',
|
||||
name: 'Manual Provider No Issuer',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
issuer: '', // Empty issuer should skip discovery and use manual endpoints
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// No need to mock discovery since it won't be called with empty issuer
|
||||
|
||||
// Access the private method
|
||||
const getOrCreateConfig = async (provider: OidcProvider) => {
|
||||
return (service as any).getOrCreateConfig(provider);
|
||||
};
|
||||
|
||||
const config = await getOrCreateConfig(provider);
|
||||
|
||||
// Verify the configuration was created with fallback issuer
|
||||
expect(config).toBeDefined();
|
||||
expect(config.serverMetadata().issuer).toBe('manual-manual-provider-no-issuer');
|
||||
expect(config.serverMetadata().authorization_endpoint).toBe(
|
||||
'https://manual.example.com/auth'
|
||||
);
|
||||
expect(config.serverMetadata().token_endpoint).toBe('https://manual.example.com/token');
|
||||
});
|
||||
|
||||
it('should handle manual configuration with client secret properly', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-with-secret',
|
||||
name: 'Manual With Secret',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'secret-123',
|
||||
issuer: 'https://manual.example.com',
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
// Access the private method
|
||||
const getOrCreateConfig = async (provider: OidcProvider) => {
|
||||
return (service as any).getOrCreateConfig(provider);
|
||||
};
|
||||
|
||||
const config = await getOrCreateConfig(provider);
|
||||
|
||||
// Verify configuration was created successfully
|
||||
expect(config).toBeDefined();
|
||||
expect(config.clientMetadata().client_secret).toBe('secret-123');
|
||||
});
|
||||
|
||||
it('should handle manual configuration without client secret (public client)', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-public-client',
|
||||
name: 'Manual Public Client',
|
||||
clientId: 'public-client-id',
|
||||
// No client secret
|
||||
issuer: 'https://manual.example.com',
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
// Access the private method
|
||||
const getOrCreateConfig = async (provider: OidcProvider) => {
|
||||
return (service as any).getOrCreateConfig(provider);
|
||||
};
|
||||
|
||||
const config = await getOrCreateConfig(provider);
|
||||
|
||||
// Verify configuration was created successfully for public client
|
||||
expect(config).toBeDefined();
|
||||
expect(config.clientMetadata().client_secret).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw error when discovery fails and no manual endpoints provided', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'no-manual-endpoints',
|
||||
name: 'No Manual Endpoints',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://broken.example.com',
|
||||
// Missing authorizationEndpoint and tokenEndpoint
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
// Access the private method
|
||||
const getOrCreateConfig = async (provider: OidcProvider) => {
|
||||
return (service as any).getOrCreateConfig(provider);
|
||||
};
|
||||
|
||||
await expect(getOrCreateConfig(provider)).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should throw error when only authorization endpoint is provided', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'partial-manual-endpoints',
|
||||
name: 'Partial Manual Endpoints',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://broken.example.com',
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
// Missing tokenEndpoint
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
// Access the private method
|
||||
const getOrCreateConfig = async (provider: OidcProvider) => {
|
||||
return (service as any).getOrCreateConfig(provider);
|
||||
};
|
||||
|
||||
await expect(getOrCreateConfig(provider)).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should cache manual configuration properly', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'cache-test',
|
||||
name: 'Cache Test',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-secret',
|
||||
issuer: 'https://manual.example.com',
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
// Access the private method
|
||||
const getOrCreateConfig = async (provider: OidcProvider) => {
|
||||
return (service as any).getOrCreateConfig(provider);
|
||||
};
|
||||
|
||||
// First call should create configuration
|
||||
const config1 = await getOrCreateConfig(provider);
|
||||
|
||||
// Second call should return cached configuration
|
||||
const config2 = await getOrCreateConfig(provider);
|
||||
|
||||
expect(config1).toBe(config2); // Should be the exact same instance
|
||||
expect(validationService.performDiscovery).toHaveBeenCalledTimes(1); // Only called once due to caching
|
||||
});
|
||||
|
||||
it('should handle HTTP endpoints with allowInsecureRequests', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'http-endpoints',
|
||||
name: 'HTTP Endpoints',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-secret',
|
||||
issuer: 'http://manual.example.com', // HTTP instead of HTTPS
|
||||
authorizationEndpoint: 'http://manual.example.com/auth',
|
||||
tokenEndpoint: 'http://manual.example.com/token',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
// Access the private method
|
||||
const getOrCreateConfig = async (provider: OidcProvider) => {
|
||||
return (service as any).getOrCreateConfig(provider);
|
||||
};
|
||||
|
||||
const config = await getOrCreateConfig(provider);
|
||||
|
||||
// Verify configuration was created successfully even with HTTP
|
||||
expect(config).toBeDefined();
|
||||
expect(config.serverMetadata().token_endpoint).toBe('http://manual.example.com/token');
|
||||
expect(config.serverMetadata().authorization_endpoint).toBe(
|
||||
'http://manual.example.com/auth'
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAuthorizationUrl', () => {
|
||||
it('should generate authorization URL with custom authorization endpoint', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'test-provider',
|
||||
name: 'Test Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://example.com',
|
||||
authorizationEndpoint: 'https://custom.example.com/auth',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
const authUrl = await service.getAuthorizationUrl(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'localhost:3001'
|
||||
);
|
||||
|
||||
expect(authUrl).toContain('https://custom.example.com/auth');
|
||||
expect(authUrl).toContain('client_id=test-client-id');
|
||||
expect(authUrl).toContain('response_type=code');
|
||||
expect(authUrl).toContain('scope=openid+profile');
|
||||
// State should start with provider ID followed by secure state token
|
||||
expect(authUrl).toMatch(/state=test-provider%3A[a-f0-9]+\.[0-9]+\.[a-f0-9]+/);
|
||||
expect(authUrl).toContain('redirect_uri=');
|
||||
});
|
||||
|
||||
it('should encode provider ID in state parameter', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'encode-test-provider',
|
||||
name: 'Encode Test Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://example.com',
|
||||
authorizationEndpoint: 'https://example.com/auth',
|
||||
scopes: ['openid', 'email'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
const authUrl = await service.getAuthorizationUrl('encode-test-provider', 'original-state');
|
||||
|
||||
// Verify that the state parameter includes provider ID at the start
|
||||
expect(authUrl).toMatch(/state=encode-test-provider%3A[a-f0-9]+\.[0-9]+\.[a-f0-9]+/);
|
||||
});
|
||||
|
||||
it('should throw error when provider not found', async () => {
|
||||
oidcConfig.getProvider.mockResolvedValue(null);
|
||||
|
||||
await expect(
|
||||
service.getAuthorizationUrl('nonexistent-provider', 'test-state')
|
||||
).rejects.toThrow('Provider nonexistent-provider not found');
|
||||
});
|
||||
|
||||
it('should handle custom scopes properly', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'custom-scopes-provider',
|
||||
name: 'Custom Scopes Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://example.com',
|
||||
authorizationEndpoint: 'https://example.com/auth',
|
||||
scopes: ['openid', 'profile', 'groups', 'custom:scope'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
const authUrl = await service.getAuthorizationUrl('custom-scopes-provider', 'test-state');
|
||||
|
||||
expect(authUrl).toContain('scope=openid+profile+groups+custom%3Ascope');
|
||||
});
|
||||
});
|
||||
|
||||
describe('handleCallback', () => {
|
||||
it('should throw error when provider not found in callback', async () => {
|
||||
oidcConfig.getProvider.mockResolvedValue(null);
|
||||
|
||||
await expect(
|
||||
service.handleCallback('nonexistent-provider', 'code', 'redirect-uri')
|
||||
).rejects.toThrow('Provider nonexistent-provider not found');
|
||||
});
|
||||
|
||||
it('should handle malformed state parameter', async () => {
|
||||
await expect(
|
||||
service.handleCallback('invalid-state', 'code', 'redirect-uri')
|
||||
).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should call getProvider with the provided provider ID', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'test-provider',
|
||||
name: 'Test Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://example.com',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// This will fail during token exchange, but we're testing the provider lookup logic
|
||||
await expect(
|
||||
service.handleCallback('test-provider', 'code', 'redirect-uri')
|
||||
).rejects.toThrow(UnauthorizedException);
|
||||
|
||||
// Verify the provider was looked up with the correct ID
|
||||
expect(oidcConfig.getProvider).toHaveBeenCalledWith('test-provider');
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateProvider', () => {
|
||||
it('should delegate to validation service and return result', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'validate-provider',
|
||||
name: 'Validate Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://example.com',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const expectedResult = {
|
||||
isValid: true,
|
||||
authorizationEndpoint: 'https://example.com/auth',
|
||||
tokenEndpoint: 'https://example.com/token',
|
||||
};
|
||||
|
||||
validationService.validateProvider.mockResolvedValue(expectedResult);
|
||||
|
||||
const result = await service.validateProvider(provider);
|
||||
|
||||
expect(result).toEqual(expectedResult);
|
||||
expect(validationService.validateProvider).toHaveBeenCalledWith(provider);
|
||||
});
|
||||
|
||||
it('should clear config cache before validation', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'cache-clear-provider',
|
||||
name: 'Cache Clear Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://example.com',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const expectedResult = {
|
||||
isValid: false,
|
||||
error: 'Validation failed',
|
||||
};
|
||||
|
||||
validationService.validateProvider.mockResolvedValue(expectedResult);
|
||||
|
||||
const result = await service.validateProvider(provider);
|
||||
|
||||
expect(result).toEqual(expectedResult);
|
||||
// Verify the cache was cleared by checking the method was called
|
||||
expect(validationService.validateProvider).toHaveBeenCalledWith(provider);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getRedirectUri (private method)', () => {
|
||||
it('should generate correct redirect URI with localhost (development)', () => {
|
||||
const getRedirectUri = (service as any).getRedirectUri.bind(service);
|
||||
const redirectUri = getRedirectUri('http://localhost:3000');
|
||||
|
||||
expect(redirectUri).toBe('http://localhost:3000/graphql/api/auth/oidc/callback');
|
||||
});
|
||||
|
||||
it('should generate correct redirect URI with non-localhost host', () => {
|
||||
const getRedirectUri = (service as any).getRedirectUri.bind(service);
|
||||
const redirectUri = getRedirectUri('https://example.com');
|
||||
|
||||
expect(redirectUri).toBe('https://example.com/graphql/api/auth/oidc/callback');
|
||||
});
|
||||
|
||||
it('should handle HTTP protocol for non-localhost hosts', () => {
|
||||
const getRedirectUri = (service as any).getRedirectUri.bind(service);
|
||||
const redirectUri = getRedirectUri('http://tower.local');
|
||||
|
||||
expect(redirectUri).toBe('http://tower.local/graphql/api/auth/oidc/callback');
|
||||
});
|
||||
|
||||
it('should handle non-standard ports correctly', () => {
|
||||
const getRedirectUri = (service as any).getRedirectUri.bind(service);
|
||||
const redirectUri = getRedirectUri('http://example.com:8080');
|
||||
|
||||
expect(redirectUri).toBe('http://example.com:8080/graphql/api/auth/oidc/callback');
|
||||
});
|
||||
|
||||
it('should use default redirect URI when no request host provided', () => {
|
||||
const getRedirectUri = (service as any).getRedirectUri.bind(service);
|
||||
|
||||
// Mock the ConfigService to return a default value
|
||||
configService.get.mockReturnValue('http://tower.local');
|
||||
|
||||
const redirectUri = getRedirectUri();
|
||||
|
||||
expect(redirectUri).toBe('http://tower.local/graphql/api/auth/oidc/callback');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,170 @@
|
||||
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
|
||||
|
||||
import {
|
||||
AuthorizationOperator,
|
||||
AuthorizationRuleMode,
|
||||
OidcAuthorizationRule,
|
||||
OidcProvider,
|
||||
} from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
|
||||
interface JwtClaims {
|
||||
sub?: string;
|
||||
email?: string;
|
||||
name?: string;
|
||||
hd?: string; // Google hosted domain
|
||||
[claim: string]: unknown;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class OidcAuthorizationService {
|
||||
private readonly logger = new Logger(OidcAuthorizationService.name);
|
||||
|
||||
/**
|
||||
* Check authorization based on rules
|
||||
* This will throw a helpful error if misconfigured or unauthorized
|
||||
*/
|
||||
async checkAuthorization(provider: OidcProvider, claims: JwtClaims): Promise<void> {
|
||||
this.logger.debug(
|
||||
`Checking authorization for provider ${provider.id} with ${provider.authorizationRules?.length || 0} rules`
|
||||
);
|
||||
this.logger.debug(`Available claims: ${Object.keys(claims).join(', ')}`);
|
||||
this.logger.debug(
|
||||
`Authorization rule mode: ${provider.authorizationRuleMode || AuthorizationRuleMode.OR}`
|
||||
);
|
||||
|
||||
// If no authorization rules are specified, throw a helpful error
|
||||
if (!provider.authorizationRules || provider.authorizationRules.length === 0) {
|
||||
throw new UnauthorizedException(
|
||||
`Login failed: The ${provider.name} provider has no authorization rules configured. ` +
|
||||
`Please configure authorization rules.`
|
||||
);
|
||||
}
|
||||
|
||||
this.logger.debug('Authorization rules to evaluate: %o', provider.authorizationRules);
|
||||
|
||||
// Evaluate the rules
|
||||
const ruleMode = provider.authorizationRuleMode || AuthorizationRuleMode.OR;
|
||||
const isAuthorized = this.evaluateAuthorizationRules(
|
||||
provider.authorizationRules,
|
||||
claims,
|
||||
ruleMode
|
||||
);
|
||||
|
||||
this.logger.debug(`Authorization result: ${isAuthorized}`);
|
||||
|
||||
if (!isAuthorized) {
|
||||
// Log authorization failure with safe claim representation (no PII)
|
||||
const availableClaimKeys = Object.keys(claims).join(', ');
|
||||
this.logger.warn(
|
||||
`Authorization failed for provider ${provider.name}, user ${claims.sub}, available claim keys: [${availableClaimKeys}]`
|
||||
);
|
||||
throw new UnauthorizedException(
|
||||
`Access denied: Your account does not meet the authorization requirements for ${provider.name}.`
|
||||
);
|
||||
}
|
||||
|
||||
this.logger.debug(`Authorization successful for user ${claims.sub}`);
|
||||
}
|
||||
|
||||
private evaluateAuthorizationRules(
|
||||
rules: OidcAuthorizationRule[],
|
||||
claims: JwtClaims,
|
||||
mode: AuthorizationRuleMode = AuthorizationRuleMode.OR
|
||||
): boolean {
|
||||
// No rules means no authorization
|
||||
if (rules.length === 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (mode === AuthorizationRuleMode.AND) {
|
||||
// All rules must pass (AND logic)
|
||||
return rules.every((rule) => this.evaluateRule(rule, claims));
|
||||
} else {
|
||||
// Any rule can pass (OR logic) - default behavior
|
||||
// Multiple rules act as alternative authorization paths
|
||||
return rules.some((rule) => this.evaluateRule(rule, claims));
|
||||
}
|
||||
}
|
||||
|
||||
private evaluateRule(rule: OidcAuthorizationRule, claims: JwtClaims): boolean {
|
||||
const claimValue = claims[rule.claim];
|
||||
|
||||
this.logger.verbose(
|
||||
`Evaluating rule for claim ${rule.claim}: { claimType: ${typeof claimValue}, isArray: ${Array.isArray(claimValue)}, ruleOperator: ${rule.operator}, ruleValuesCount: ${rule.value.length} }`
|
||||
);
|
||||
|
||||
if (claimValue === undefined || claimValue === null) {
|
||||
this.logger.verbose(`Claim ${rule.claim} not found in token`);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Handle non-array, non-string objects
|
||||
if (typeof claimValue === 'object' && claimValue !== null && !Array.isArray(claimValue)) {
|
||||
this.logger.warn(
|
||||
`unexpected JWT claim value encountered - claim ${rule.claim} has unsupported object type (keys: [${Object.keys(claimValue as Record<string, unknown>).join(', ')}])`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Handle array claims - evaluate rule against each array element
|
||||
if (Array.isArray(claimValue)) {
|
||||
this.logger.verbose(
|
||||
`Processing array claim ${rule.claim} with ${claimValue.length} elements`
|
||||
);
|
||||
|
||||
// For array claims, check if ANY element in the array matches the rule
|
||||
const arrayResult = claimValue.some((element) => {
|
||||
// Skip non-string elements
|
||||
if (
|
||||
typeof element !== 'string' &&
|
||||
typeof element !== 'number' &&
|
||||
typeof element !== 'boolean'
|
||||
) {
|
||||
this.logger.verbose(`Skipping non-primitive element in array: ${typeof element}`);
|
||||
return false;
|
||||
}
|
||||
|
||||
const elementValue = String(element);
|
||||
return this.evaluateSingleValue(elementValue, rule);
|
||||
});
|
||||
|
||||
this.logger.verbose(`Array evaluation result for claim ${rule.claim}: ${arrayResult}`);
|
||||
return arrayResult;
|
||||
}
|
||||
|
||||
// Handle single value claims (string, number, boolean)
|
||||
const value = String(claimValue);
|
||||
this.logger.verbose(`Processing single value claim ${rule.claim}`);
|
||||
|
||||
return this.evaluateSingleValue(value, rule);
|
||||
}
|
||||
|
||||
private evaluateSingleValue(value: string, rule: OidcAuthorizationRule): boolean {
|
||||
let result: boolean;
|
||||
switch (rule.operator) {
|
||||
case AuthorizationOperator.EQUALS:
|
||||
result = rule.value.some((v) => value === v);
|
||||
this.logger.verbose(`EQUALS check: evaluated for claim ${rule.claim}: ${result}`);
|
||||
return result;
|
||||
|
||||
case AuthorizationOperator.CONTAINS:
|
||||
result = rule.value.some((v) => value.includes(v));
|
||||
this.logger.verbose(`CONTAINS check: evaluated for claim ${rule.claim}: ${result}`);
|
||||
return result;
|
||||
|
||||
case AuthorizationOperator.STARTS_WITH:
|
||||
result = rule.value.some((v) => value.startsWith(v));
|
||||
this.logger.verbose(`STARTS_WITH check: evaluated for claim ${rule.claim}: ${result}`);
|
||||
return result;
|
||||
|
||||
case AuthorizationOperator.ENDS_WITH:
|
||||
result = rule.value.some((v) => value.endsWith(v));
|
||||
this.logger.verbose(`ENDS_WITH check: evaluated for claim ${rule.claim}: ${result}`);
|
||||
return result;
|
||||
|
||||
default:
|
||||
this.logger.error(`Unknown authorization operator: ${rule.operator}`);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
import { UnauthorizedException } from '@nestjs/common';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { decodeJwt } from 'jose';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import {
|
||||
JwtClaims,
|
||||
OidcClaimsService,
|
||||
} from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
|
||||
|
||||
// Mock jose
|
||||
vi.mock('jose', () => ({
|
||||
decodeJwt: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('OidcClaimsService', () => {
|
||||
let service: OidcClaimsService;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [OidcClaimsService],
|
||||
}).compile();
|
||||
|
||||
service = module.get<OidcClaimsService>(OidcClaimsService);
|
||||
});
|
||||
|
||||
describe('parseIdToken', () => {
|
||||
it('should parse valid ID token', () => {
|
||||
const mockClaims: JwtClaims = {
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
name: 'Test User',
|
||||
iat: 1234567890,
|
||||
exp: 1234567890,
|
||||
};
|
||||
|
||||
(decodeJwt as any).mockReturnValue(mockClaims);
|
||||
|
||||
const result = service.parseIdToken('valid.jwt.token');
|
||||
|
||||
expect(result).toEqual(mockClaims);
|
||||
expect(decodeJwt).toHaveBeenCalledWith('valid.jwt.token');
|
||||
});
|
||||
|
||||
it('should return null when no token provided', () => {
|
||||
const result = service.parseIdToken(undefined);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null when token parsing fails', () => {
|
||||
(decodeJwt as any).mockImplementation(() => {
|
||||
throw new Error('Invalid token');
|
||||
});
|
||||
|
||||
const result = service.parseIdToken('invalid.token');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should handle claims with array values', () => {
|
||||
const mockClaims: JwtClaims = {
|
||||
sub: 'user123',
|
||||
groups: ['admin', 'user'],
|
||||
roles: ['role1', 'role2', 'role3'],
|
||||
};
|
||||
|
||||
(decodeJwt as any).mockReturnValue(mockClaims);
|
||||
|
||||
const result = service.parseIdToken('token.with.arrays');
|
||||
|
||||
expect(result).toEqual(mockClaims);
|
||||
});
|
||||
|
||||
it('should log warning for complex object claims', () => {
|
||||
const loggerSpy = vi.spyOn(service['logger'], 'warn');
|
||||
|
||||
const mockClaims: JwtClaims = {
|
||||
sub: 'user123',
|
||||
complexClaim: {
|
||||
nested: 'value',
|
||||
another: 'field',
|
||||
},
|
||||
};
|
||||
|
||||
(decodeJwt as any).mockReturnValue(mockClaims);
|
||||
|
||||
service.parseIdToken('token.with.complex');
|
||||
|
||||
expect(loggerSpy).toHaveBeenCalledWith(expect.stringContaining('complex object structure'));
|
||||
});
|
||||
|
||||
it('should handle Google-specific claims', () => {
|
||||
const mockClaims: JwtClaims = {
|
||||
sub: 'google-user-id',
|
||||
email: 'user@company.com',
|
||||
name: 'Google User',
|
||||
hd: 'company.com', // Google hosted domain
|
||||
};
|
||||
|
||||
(decodeJwt as any).mockReturnValue(mockClaims);
|
||||
|
||||
const result = service.parseIdToken('google.jwt.token');
|
||||
|
||||
expect(result).toEqual(mockClaims);
|
||||
expect(result?.hd).toBe('company.com');
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateClaims', () => {
|
||||
it('should return user sub when claims are valid', () => {
|
||||
const claims: JwtClaims = {
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
};
|
||||
|
||||
const result = service.validateClaims(claims);
|
||||
expect(result).toBe('user123');
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException when claims are null', () => {
|
||||
expect(() => service.validateClaims(null)).toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException when sub is missing', () => {
|
||||
const claims: JwtClaims = {
|
||||
email: 'user@example.com',
|
||||
name: 'User',
|
||||
};
|
||||
|
||||
expect(() => service.validateClaims(claims)).toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException when sub is empty', () => {
|
||||
const claims: JwtClaims = {
|
||||
sub: '',
|
||||
email: 'user@example.com',
|
||||
};
|
||||
|
||||
expect(() => service.validateClaims(claims)).toThrow(UnauthorizedException);
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractUserInfo', () => {
|
||||
it('should extract basic user information', () => {
|
||||
const claims: JwtClaims = {
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
name: 'Test User',
|
||||
};
|
||||
|
||||
const result = service.extractUserInfo(claims);
|
||||
|
||||
expect(result).toEqual({
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
name: 'Test User',
|
||||
domain: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
it('should extract Google hosted domain', () => {
|
||||
const claims: JwtClaims = {
|
||||
sub: 'google-user',
|
||||
email: 'user@company.com',
|
||||
name: 'Google User',
|
||||
hd: 'company.com',
|
||||
};
|
||||
|
||||
const result = service.extractUserInfo(claims);
|
||||
|
||||
expect(result).toEqual({
|
||||
sub: 'google-user',
|
||||
email: 'user@company.com',
|
||||
name: 'Google User',
|
||||
domain: 'company.com',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle missing optional fields', () => {
|
||||
const claims: JwtClaims = {
|
||||
sub: 'user123',
|
||||
};
|
||||
|
||||
const result = service.extractUserInfo(claims);
|
||||
|
||||
expect(result).toEqual({
|
||||
sub: 'user123',
|
||||
email: undefined,
|
||||
name: undefined,
|
||||
domain: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
it('should ignore extra claims', () => {
|
||||
const claims: JwtClaims = {
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
name: 'Test User',
|
||||
extra: 'claim',
|
||||
another: 'field',
|
||||
groups: ['admin'],
|
||||
};
|
||||
|
||||
const result = service.extractUserInfo(claims);
|
||||
|
||||
expect(result).toEqual({
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
name: 'Test User',
|
||||
domain: undefined,
|
||||
});
|
||||
expect(result).not.toHaveProperty('extra');
|
||||
expect(result).not.toHaveProperty('groups');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,80 @@
|
||||
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
|
||||
|
||||
import { decodeJwt } from 'jose';
|
||||
|
||||
export interface JwtClaims {
|
||||
sub?: string;
|
||||
email?: string;
|
||||
name?: string;
|
||||
hd?: string; // Google hosted domain
|
||||
[claim: string]: unknown;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class OidcClaimsService {
|
||||
private readonly logger = new Logger(OidcClaimsService.name);
|
||||
|
||||
parseIdToken(idToken: string | undefined): JwtClaims | null {
|
||||
if (!idToken) {
|
||||
this.logger.error('No ID token received from provider');
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
// Use jose to properly decode the JWT
|
||||
const claims = decodeJwt(idToken) as JwtClaims;
|
||||
|
||||
// Log claims safely without PII - only structure, not values
|
||||
if (claims) {
|
||||
const claimKeys = Object.keys(claims).join(', ');
|
||||
this.logger.debug(`ID token decoded successfully. Available claims: [${claimKeys}]`);
|
||||
|
||||
// Log claim types without exposing sensitive values
|
||||
for (const [key, value] of Object.entries(claims)) {
|
||||
const valueType = Array.isArray(value) ? `array[${value.length}]` : typeof value;
|
||||
|
||||
// Only log structure, not actual values (avoid PII)
|
||||
this.logger.debug(`Claim '${key}': type=${valueType}`);
|
||||
|
||||
// Check for unexpected claim types
|
||||
if (valueType === 'object' && value !== null && !Array.isArray(value)) {
|
||||
this.logger.warn(`Claim '${key}' contains complex object structure`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return claims;
|
||||
} catch (e) {
|
||||
this.logger.warn(`Failed to parse ID token: ${e}`);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
validateClaims(claims: JwtClaims | null): string {
|
||||
if (!claims?.sub) {
|
||||
this.logger.error(
|
||||
'No subject in token - claims available: ' +
|
||||
(claims ? Object.keys(claims).join(', ') : 'none')
|
||||
);
|
||||
throw new UnauthorizedException('No subject in token');
|
||||
}
|
||||
|
||||
const userSub = claims.sub;
|
||||
this.logger.debug(`Processing authentication for user: ${userSub}`);
|
||||
return userSub;
|
||||
}
|
||||
|
||||
extractUserInfo(claims: JwtClaims): {
|
||||
sub: string;
|
||||
email?: string;
|
||||
name?: string;
|
||||
domain?: string;
|
||||
} {
|
||||
return {
|
||||
sub: claims.sub!,
|
||||
email: claims.email,
|
||||
name: claims.name,
|
||||
domain: claims.hd,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,224 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
|
||||
import * as client from 'openid-client';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
|
||||
vi.mock('openid-client', () => ({
|
||||
authorizationCodeGrant: vi.fn(),
|
||||
allowInsecureRequests: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('OidcTokenExchangeService', () => {
|
||||
let service: OidcTokenExchangeService;
|
||||
let mockConfig: client.Configuration;
|
||||
let mockProvider: OidcProvider;
|
||||
|
||||
beforeEach(() => {
|
||||
service = new OidcTokenExchangeService();
|
||||
|
||||
mockConfig = {
|
||||
serverMetadata: vi.fn().mockReturnValue({
|
||||
issuer: 'https://example.com',
|
||||
token_endpoint: 'https://example.com/token',
|
||||
response_types_supported: ['code'],
|
||||
grant_types_supported: ['authorization_code'],
|
||||
token_endpoint_auth_methods_supported: ['client_secret_post'],
|
||||
}),
|
||||
} as unknown as client.Configuration;
|
||||
|
||||
mockProvider = {
|
||||
id: 'test-provider',
|
||||
issuer: 'https://example.com',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
} as OidcProvider;
|
||||
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('exchangeCodeForTokens', () => {
|
||||
it('should handle malformed fullCallbackUrl gracefully', async () => {
|
||||
const code = 'test-code';
|
||||
const state = 'test-state';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
const malformedUrl = 'not://a valid url';
|
||||
|
||||
const mockTokens = {
|
||||
access_token: 'test-access-token',
|
||||
id_token: 'test-id-token',
|
||||
};
|
||||
|
||||
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
|
||||
|
||||
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
|
||||
const loggerDebugSpy = vi.spyOn(Logger.prototype, 'debug').mockImplementation(() => {});
|
||||
|
||||
const result = await service.exchangeCodeForTokens(
|
||||
mockConfig,
|
||||
mockProvider,
|
||||
code,
|
||||
state,
|
||||
redirectUri,
|
||||
malformedUrl
|
||||
);
|
||||
|
||||
expect(result).toEqual(mockTokens);
|
||||
expect(loggerWarnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to parse fullCallbackUrl'),
|
||||
expect.any(Error)
|
||||
);
|
||||
expect(client.authorizationCodeGrant).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle empty fullCallbackUrl without throwing', async () => {
|
||||
const code = 'test-code';
|
||||
const state = 'test-state';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
const mockTokens = {
|
||||
access_token: 'test-access-token',
|
||||
id_token: 'test-id-token',
|
||||
};
|
||||
|
||||
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
|
||||
|
||||
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
|
||||
|
||||
const result = await service.exchangeCodeForTokens(
|
||||
mockConfig,
|
||||
mockProvider,
|
||||
code,
|
||||
state,
|
||||
redirectUri,
|
||||
''
|
||||
);
|
||||
|
||||
expect(result).toEqual(mockTokens);
|
||||
expect(loggerWarnSpy).not.toHaveBeenCalled();
|
||||
expect(client.authorizationCodeGrant).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle whitespace-only fullCallbackUrl without throwing', async () => {
|
||||
const code = 'test-code';
|
||||
const state = 'test-state';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
const mockTokens = {
|
||||
access_token: 'test-access-token',
|
||||
id_token: 'test-id-token',
|
||||
};
|
||||
|
||||
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
|
||||
|
||||
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
|
||||
|
||||
const result = await service.exchangeCodeForTokens(
|
||||
mockConfig,
|
||||
mockProvider,
|
||||
code,
|
||||
state,
|
||||
redirectUri,
|
||||
' '
|
||||
);
|
||||
|
||||
expect(result).toEqual(mockTokens);
|
||||
expect(loggerWarnSpy).not.toHaveBeenCalled();
|
||||
expect(client.authorizationCodeGrant).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should copy parameters from valid fullCallbackUrl', async () => {
|
||||
const code = 'test-code';
|
||||
const state = 'test-state';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
const fullCallbackUrl =
|
||||
'https://example.com/callback?code=test-code&state=test-state&scope=openid&authuser=0';
|
||||
|
||||
const mockTokens = {
|
||||
access_token: 'test-access-token',
|
||||
id_token: 'test-id-token',
|
||||
};
|
||||
|
||||
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
|
||||
|
||||
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
|
||||
const loggerDebugSpy = vi.spyOn(Logger.prototype, 'debug').mockImplementation(() => {});
|
||||
|
||||
const result = await service.exchangeCodeForTokens(
|
||||
mockConfig,
|
||||
mockProvider,
|
||||
code,
|
||||
state,
|
||||
redirectUri,
|
||||
fullCallbackUrl
|
||||
);
|
||||
|
||||
expect(result).toEqual(mockTokens);
|
||||
expect(loggerWarnSpy).not.toHaveBeenCalled();
|
||||
|
||||
const authCodeGrantCall = vi.mocked(client.authorizationCodeGrant).mock.calls[0];
|
||||
const cleanUrl = authCodeGrantCall[1] as URL;
|
||||
|
||||
expect(cleanUrl.searchParams.get('scope')).toBe('openid');
|
||||
expect(cleanUrl.searchParams.get('authuser')).toBe('0');
|
||||
});
|
||||
|
||||
it('should handle undefined fullCallbackUrl', async () => {
|
||||
const code = 'test-code';
|
||||
const state = 'test-state';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
const mockTokens = {
|
||||
access_token: 'test-access-token',
|
||||
id_token: 'test-id-token',
|
||||
};
|
||||
|
||||
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
|
||||
|
||||
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
|
||||
|
||||
const result = await service.exchangeCodeForTokens(
|
||||
mockConfig,
|
||||
mockProvider,
|
||||
code,
|
||||
state,
|
||||
redirectUri,
|
||||
undefined
|
||||
);
|
||||
|
||||
expect(result).toEqual(mockTokens);
|
||||
expect(loggerWarnSpy).not.toHaveBeenCalled();
|
||||
expect(client.authorizationCodeGrant).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle non-string fullCallbackUrl types gracefully', async () => {
|
||||
const code = 'test-code';
|
||||
const state = 'test-state';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
const mockTokens = {
|
||||
access_token: 'test-access-token',
|
||||
id_token: 'test-id-token',
|
||||
};
|
||||
|
||||
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
|
||||
|
||||
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
|
||||
|
||||
const result = await service.exchangeCodeForTokens(
|
||||
mockConfig,
|
||||
mockProvider,
|
||||
code,
|
||||
state,
|
||||
redirectUri,
|
||||
123 as any
|
||||
);
|
||||
|
||||
expect(result).toEqual(mockTokens);
|
||||
expect(loggerWarnSpy).not.toHaveBeenCalled();
|
||||
expect(client.authorizationCodeGrant).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,174 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import * as client from 'openid-client';
|
||||
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
import { ErrorExtractor } from '@app/unraid-api/utils/error-extractor.util.js';
|
||||
|
||||
// Extended type for our internal use - openid-client v6 doesn't directly expose
|
||||
// skip options for aud/iss checks, so we'll handle validation errors differently
|
||||
type ExtendedGrantChecks = client.AuthorizationCodeGrantChecks;
|
||||
|
||||
@Injectable()
|
||||
export class OidcTokenExchangeService {
|
||||
private readonly logger = new Logger(OidcTokenExchangeService.name);
|
||||
|
||||
async exchangeCodeForTokens(
|
||||
config: client.Configuration,
|
||||
provider: OidcProvider,
|
||||
code: string,
|
||||
state: string,
|
||||
redirectUri: string,
|
||||
fullCallbackUrl?: string
|
||||
): Promise<client.TokenEndpointResponse> {
|
||||
this.logger.debug(`Provider ${provider.id} config loaded`);
|
||||
this.logger.debug(`Redirect URI: ${redirectUri}`);
|
||||
|
||||
// Build current URL for token exchange
|
||||
// CRITICAL: The URL used here MUST match the redirect_uri that was sent to the authorization endpoint
|
||||
// Google expects the exact same redirect_uri during token exchange
|
||||
const currentUrl = new URL(redirectUri);
|
||||
currentUrl.searchParams.set('code', code);
|
||||
currentUrl.searchParams.set('state', state);
|
||||
|
||||
// Copy additional parameters from the actual callback if provided
|
||||
if (fullCallbackUrl && typeof fullCallbackUrl === 'string' && fullCallbackUrl.trim()) {
|
||||
try {
|
||||
const actualUrl = new URL(fullCallbackUrl);
|
||||
// Copy over additional params that Google might have added (scope, authuser, prompt, etc)
|
||||
// but DO NOT change the base URL or path
|
||||
['scope', 'authuser', 'prompt', 'hd', 'session_state', 'iss'].forEach((param) => {
|
||||
const value = actualUrl.searchParams.get(param);
|
||||
if (value && !currentUrl.searchParams.has(param)) {
|
||||
currentUrl.searchParams.set(param, value);
|
||||
}
|
||||
});
|
||||
} catch (urlError) {
|
||||
this.logger.warn(`Failed to parse fullCallbackUrl: ${fullCallbackUrl}`, urlError);
|
||||
// Continue with the existing currentUrl flow without additional params
|
||||
}
|
||||
}
|
||||
|
||||
// Google returns iss in the response, openid-client v6 expects it
|
||||
// If not present, add it based on the provider's issuer
|
||||
if (!currentUrl.searchParams.has('iss') && provider.issuer) {
|
||||
currentUrl.searchParams.set('iss', provider.issuer);
|
||||
}
|
||||
|
||||
this.logger.debug(`Token exchange URL (matches redirect_uri): ${currentUrl.href}`);
|
||||
|
||||
// For openid-client v6, we need to prepare the authorization response
|
||||
const authorizationResponse = new URLSearchParams(currentUrl.search);
|
||||
|
||||
// Set the original client state for openid-client
|
||||
authorizationResponse.set('state', state);
|
||||
|
||||
// Create a new URL with the cleaned parameters
|
||||
const cleanUrl = new URL(redirectUri);
|
||||
cleanUrl.search = authorizationResponse.toString();
|
||||
|
||||
this.logger.debug(`Clean URL for token exchange: ${cleanUrl.href}`);
|
||||
|
||||
try {
|
||||
this.logger.debug(`Starting token exchange with openid-client`);
|
||||
this.logger.debug(`Config issuer: ${config.serverMetadata().issuer}`);
|
||||
this.logger.debug(`Config token endpoint: ${config.serverMetadata().token_endpoint}`);
|
||||
|
||||
// Log the complete token exchange request details
|
||||
const tokenEndpoint = config.serverMetadata().token_endpoint;
|
||||
this.logger.debug(`Full token endpoint URL: ${tokenEndpoint}`);
|
||||
this.logger.debug(`Authorization code: ${code.substring(0, 10)}...`);
|
||||
this.logger.debug(`Redirect URI in token request: ${redirectUri}`);
|
||||
this.logger.debug(`Client ID: ${provider.clientId}`);
|
||||
this.logger.debug(`Client secret configured: ${provider.clientSecret ? 'Yes' : 'No'}`);
|
||||
this.logger.debug(`Expected state value: ${state}`);
|
||||
|
||||
// Log the server metadata to check for any configuration issues
|
||||
const metadata = config.serverMetadata();
|
||||
this.logger.debug(
|
||||
`Server supports response types: ${metadata.response_types_supported?.join(', ') || 'not specified'}`
|
||||
);
|
||||
this.logger.debug(
|
||||
`Server grant types: ${metadata.grant_types_supported?.join(', ') || 'not specified'}`
|
||||
);
|
||||
this.logger.debug(
|
||||
`Token endpoint auth methods: ${metadata.token_endpoint_auth_methods_supported?.join(', ') || 'not specified'}`
|
||||
);
|
||||
|
||||
// For HTTP endpoints, we need to call allowInsecureRequests on the config
|
||||
if (provider.issuer) {
|
||||
try {
|
||||
const serverUrl = new URL(provider.issuer);
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(
|
||||
`Allowing insecure requests for HTTP endpoint: ${provider.id}`
|
||||
);
|
||||
// allowInsecureRequests is deprecated but still needed for HTTP endpoints
|
||||
client.allowInsecureRequests(config);
|
||||
}
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
`Invalid issuer URL for provider ${provider.id}: ${provider.issuer}`
|
||||
);
|
||||
// Continue without special HTTP options
|
||||
}
|
||||
}
|
||||
|
||||
const requestChecks: ExtendedGrantChecks = {
|
||||
expectedState: state,
|
||||
};
|
||||
|
||||
// Log what we're about to send
|
||||
this.logger.debug(`Executing authorizationCodeGrant with:`);
|
||||
this.logger.debug(`- Clean URL: ${cleanUrl.href}`);
|
||||
this.logger.debug(`- Expected state: ${state}`);
|
||||
this.logger.debug(`- Grant type: authorization_code`);
|
||||
|
||||
const tokens = await client.authorizationCodeGrant(config, cleanUrl, requestChecks);
|
||||
|
||||
this.logger.debug(
|
||||
`Token exchange successful, received tokens: ${Object.keys(tokens).join(', ')}`
|
||||
);
|
||||
|
||||
return tokens;
|
||||
} catch (tokenError) {
|
||||
// Extract and log error details using the utility
|
||||
const extracted = ErrorExtractor.extract(tokenError);
|
||||
this.logger.error('Token exchange failed');
|
||||
ErrorExtractor.formatForLogging(extracted, this.logger);
|
||||
|
||||
// Special handling for content-type and parsing errors
|
||||
if (ErrorExtractor.isOAuthResponseError(extracted)) {
|
||||
this.logger.error('Token endpoint returned invalid or non-JSON response.');
|
||||
this.logger.error('This typically means:');
|
||||
this.logger.error(
|
||||
'1. The token endpoint URL is incorrect (check for typos or wrong paths)'
|
||||
);
|
||||
this.logger.error('2. The server returned an HTML error page instead of JSON');
|
||||
this.logger.error('3. Authentication failed (invalid client_id or client_secret)');
|
||||
this.logger.error('4. A proxy/firewall is intercepting the request');
|
||||
this.logger.error('5. The OAuth server returned malformed JSON');
|
||||
this.logger.error(
|
||||
`Configured token endpoint: ${config.serverMetadata().token_endpoint}`
|
||||
);
|
||||
this.logger.error('Please verify your OIDC provider configuration.');
|
||||
}
|
||||
|
||||
// Check if error message contains the "unexpected JWT claim" text
|
||||
if (ErrorExtractor.isJwtClaimError(extracted)) {
|
||||
this.logger.error(
|
||||
`unexpected JWT claim value encountered during token validation by openid-client`
|
||||
);
|
||||
this.logger.error(
|
||||
`This error typically means the 'iss' claim in the JWT doesn't match the expected issuer`
|
||||
);
|
||||
this.logger.error(`Check that your provider's issuer URL is configured correctly`);
|
||||
this.logger.error(`Expected issuer: ${config.serverMetadata().issuer}`);
|
||||
this.logger.error(`Provider configured issuer: ${provider.issuer}`);
|
||||
}
|
||||
|
||||
// Re-throw the original error with all its properties intact
|
||||
throw tokenError;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
|
||||
describe('OidcClientConfigService', () => {
|
||||
let service: OidcClientConfigService;
|
||||
let validationService: any;
|
||||
|
||||
beforeEach(async () => {
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
OidcClientConfigService,
|
||||
{
|
||||
provide: OidcValidationService,
|
||||
useValue: {
|
||||
performDiscovery: vi.fn(),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
service = module.get<OidcClientConfigService>(OidcClientConfigService);
|
||||
validationService = module.get(OidcValidationService);
|
||||
});
|
||||
|
||||
describe('Manual Configuration', () => {
|
||||
it('should create manual configuration when discovery fails but manual endpoints are provided', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-provider',
|
||||
name: 'Manual Provider',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
issuer: 'https://manual.example.com',
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
jwksUri: 'https://manual.example.com/jwks',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
const config = await service.getOrCreateConfig(provider);
|
||||
|
||||
// Verify the configuration was created with the correct endpoints
|
||||
expect(config).toBeDefined();
|
||||
expect(config.serverMetadata().authorization_endpoint).toBe(
|
||||
'https://manual.example.com/auth'
|
||||
);
|
||||
expect(config.serverMetadata().token_endpoint).toBe('https://manual.example.com/token');
|
||||
expect(config.serverMetadata().jwks_uri).toBe('https://manual.example.com/jwks');
|
||||
expect(config.serverMetadata().issuer).toBe('https://manual.example.com');
|
||||
});
|
||||
|
||||
it('should create manual configuration with fallback issuer when not provided', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-provider-no-issuer',
|
||||
name: 'Manual Provider No Issuer',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
issuer: '', // Empty issuer should skip discovery and use manual endpoints
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const config = await service.getOrCreateConfig(provider);
|
||||
|
||||
// Verify the configuration was created with inferred issuer from endpoints
|
||||
expect(config).toBeDefined();
|
||||
expect(config.serverMetadata().issuer).toBe('https://manual.example.com');
|
||||
expect(config.serverMetadata().authorization_endpoint).toBe(
|
||||
'https://manual.example.com/auth'
|
||||
);
|
||||
expect(config.serverMetadata().token_endpoint).toBe('https://manual.example.com/token');
|
||||
});
|
||||
|
||||
it('should handle manual configuration with client secret properly', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-with-secret',
|
||||
name: 'Manual With Secret',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'secret-123',
|
||||
issuer: 'https://manual.example.com',
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
const config = await service.getOrCreateConfig(provider);
|
||||
|
||||
// Verify configuration was created successfully
|
||||
expect(config).toBeDefined();
|
||||
expect(config.clientMetadata().client_secret).toBe('secret-123');
|
||||
});
|
||||
|
||||
it('should handle manual configuration without client secret (public client)', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-public-client',
|
||||
name: 'Manual Public Client',
|
||||
clientId: 'public-client-id',
|
||||
// No client secret
|
||||
issuer: 'https://manual.example.com',
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
const config = await service.getOrCreateConfig(provider);
|
||||
|
||||
// Verify configuration was created successfully
|
||||
expect(config).toBeDefined();
|
||||
expect(config.clientMetadata().client_secret).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should cache configurations', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'cached-provider',
|
||||
name: 'Cached Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: '',
|
||||
authorizationEndpoint: 'https://cached.example.com/auth',
|
||||
tokenEndpoint: 'https://cached.example.com/token',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
// First call
|
||||
const config1 = await service.getOrCreateConfig(provider);
|
||||
|
||||
// Second call - should return cached value
|
||||
const config2 = await service.getOrCreateConfig(provider);
|
||||
|
||||
// Should be the exact same object
|
||||
expect(config1).toBe(config2);
|
||||
expect(service.getCacheSize()).toBe(1);
|
||||
});
|
||||
|
||||
it('should clear cache for specific provider', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'provider-to-clear',
|
||||
name: 'Provider to Clear',
|
||||
clientId: 'test-client-id',
|
||||
issuer: '',
|
||||
authorizationEndpoint: 'https://clear.example.com/auth',
|
||||
tokenEndpoint: 'https://clear.example.com/token',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
await service.getOrCreateConfig(provider);
|
||||
expect(service.getCacheSize()).toBe(1);
|
||||
|
||||
service.clearCache('provider-to-clear');
|
||||
expect(service.getCacheSize()).toBe(0);
|
||||
});
|
||||
|
||||
it('should clear entire cache', async () => {
|
||||
const provider1: OidcProvider = {
|
||||
id: 'provider1',
|
||||
name: 'Provider 1',
|
||||
clientId: 'client1',
|
||||
issuer: '',
|
||||
authorizationEndpoint: 'https://p1.example.com/auth',
|
||||
tokenEndpoint: 'https://p1.example.com/token',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const provider2: OidcProvider = {
|
||||
id: 'provider2',
|
||||
name: 'Provider 2',
|
||||
clientId: 'client2',
|
||||
issuer: '',
|
||||
authorizationEndpoint: 'https://p2.example.com/auth',
|
||||
tokenEndpoint: 'https://p2.example.com/token',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
await service.getOrCreateConfig(provider1);
|
||||
await service.getOrCreateConfig(provider2);
|
||||
expect(service.getCacheSize()).toBe(2);
|
||||
|
||||
service.clearCache();
|
||||
expect(service.getCacheSize()).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Discovery Configuration', () => {
|
||||
it('should use discovery when issuer is provided', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'discovery-provider',
|
||||
name: 'Discovery Provider',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-secret',
|
||||
issuer: 'https://discovery.example.com',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const mockConfig = {
|
||||
serverMetadata: vi.fn().mockReturnValue({
|
||||
issuer: 'https://discovery.example.com',
|
||||
authorization_endpoint: 'https://discovery.example.com/authorize',
|
||||
token_endpoint: 'https://discovery.example.com/token',
|
||||
jwks_uri: 'https://discovery.example.com/.well-known/jwks.json',
|
||||
userinfo_endpoint: 'https://discovery.example.com/userinfo',
|
||||
}),
|
||||
clientMetadata: vi.fn().mockReturnValue({}),
|
||||
};
|
||||
|
||||
validationService.performDiscovery.mockResolvedValue(mockConfig);
|
||||
|
||||
const config = await service.getOrCreateConfig(provider);
|
||||
|
||||
expect(validationService.performDiscovery).toHaveBeenCalledWith(provider, undefined);
|
||||
expect(config).toBe(mockConfig);
|
||||
});
|
||||
|
||||
it('should allow HTTP for discovery when issuer uses HTTP', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'http-discovery-provider',
|
||||
name: 'HTTP Discovery Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'http://discovery.example.com',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const mockConfig = {
|
||||
serverMetadata: vi.fn().mockReturnValue({
|
||||
issuer: 'http://discovery.example.com',
|
||||
authorization_endpoint: 'http://discovery.example.com/authorize',
|
||||
token_endpoint: 'http://discovery.example.com/token',
|
||||
}),
|
||||
clientMetadata: vi.fn().mockReturnValue({}),
|
||||
};
|
||||
|
||||
validationService.performDiscovery.mockResolvedValue(mockConfig);
|
||||
|
||||
const config = await service.getOrCreateConfig(provider);
|
||||
|
||||
expect(validationService.performDiscovery).toHaveBeenCalledWith(
|
||||
provider,
|
||||
expect.objectContaining({
|
||||
execute: expect.any(Array),
|
||||
})
|
||||
);
|
||||
expect(config).toBe(mockConfig);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,168 @@
|
||||
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
|
||||
|
||||
import * as client from 'openid-client';
|
||||
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
import { ErrorExtractor } from '@app/unraid-api/utils/error-extractor.util.js';
|
||||
|
||||
@Injectable()
|
||||
export class OidcClientConfigService {
|
||||
private readonly logger = new Logger(OidcClientConfigService.name);
|
||||
private readonly configCache = new Map<string, client.Configuration>();
|
||||
|
||||
constructor(private readonly validationService: OidcValidationService) {}
|
||||
|
||||
async getOrCreateConfig(provider: OidcProvider): Promise<client.Configuration> {
|
||||
const cacheKey = provider.id;
|
||||
|
||||
if (this.configCache.has(cacheKey)) {
|
||||
return this.configCache.get(cacheKey)!;
|
||||
}
|
||||
|
||||
try {
|
||||
// Use the validation service to perform discovery with HTTP support
|
||||
if (provider.issuer) {
|
||||
this.logger.debug(`Attempting discovery for ${provider.id} at ${provider.issuer}`);
|
||||
|
||||
// Create client options with HTTP support if needed
|
||||
const serverUrl = new URL(provider.issuer);
|
||||
let clientOptions: client.DiscoveryRequestOptions | undefined;
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(`Allowing HTTP for ${provider.id} as specified by user`);
|
||||
clientOptions = {
|
||||
execute: [client.allowInsecureRequests],
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const config = await this.validationService.performDiscovery(
|
||||
provider,
|
||||
clientOptions
|
||||
);
|
||||
this.logger.debug(`Discovery successful for ${provider.id}`);
|
||||
this.logger.debug(
|
||||
`Authorization endpoint: ${config.serverMetadata().authorization_endpoint}`
|
||||
);
|
||||
this.logger.debug(`Token endpoint: ${config.serverMetadata().token_endpoint}`);
|
||||
this.logger.debug(`JWKS URI: ${config.serverMetadata().jwks_uri || 'Not provided'}`);
|
||||
this.logger.debug(
|
||||
`Userinfo endpoint: ${config.serverMetadata().userinfo_endpoint || 'Not provided'}`
|
||||
);
|
||||
this.configCache.set(cacheKey, config);
|
||||
return config;
|
||||
} catch (discoveryError) {
|
||||
const extracted = ErrorExtractor.extract(discoveryError);
|
||||
this.logger.warn(`Discovery failed for ${provider.id}: ${extracted.message}`);
|
||||
|
||||
// Log more details about the discovery error
|
||||
const discoveryUrl = `${provider.issuer}/.well-known/openid-configuration`;
|
||||
this.logger.debug(`Discovery URL attempted: ${discoveryUrl}`);
|
||||
|
||||
// Use error extractor for consistent logging
|
||||
ErrorExtractor.formatForLogging(extracted, this.logger);
|
||||
|
||||
// If discovery fails but we have manual endpoints, use them
|
||||
if (provider.authorizationEndpoint && provider.tokenEndpoint) {
|
||||
this.logger.log(`Using manual endpoints for ${provider.id}`);
|
||||
return this.createManualConfiguration(provider, cacheKey);
|
||||
} else {
|
||||
throw new Error(
|
||||
`OIDC discovery failed and no manual endpoints provided for ${provider.id}`
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Manual configuration when no issuer is provided
|
||||
if (provider.authorizationEndpoint && provider.tokenEndpoint) {
|
||||
this.logger.log(`Using manual endpoints for ${provider.id} (no issuer provided)`);
|
||||
return this.createManualConfiguration(provider, cacheKey);
|
||||
}
|
||||
|
||||
// If we reach here, neither discovery nor manual endpoints are available
|
||||
throw new Error(
|
||||
`No configuration method available for ${provider.id}: requires either valid issuer for discovery or manual endpoints`
|
||||
);
|
||||
} catch (error) {
|
||||
const extracted = ErrorExtractor.extract(error);
|
||||
this.logger.error(
|
||||
`Failed to create OIDC configuration for ${provider.id}: ${extracted.message}`
|
||||
);
|
||||
|
||||
// Log more details in debug mode
|
||||
if (extracted.stack) {
|
||||
this.logger.debug(`Stack trace: ${extracted.stack}`);
|
||||
}
|
||||
|
||||
throw new UnauthorizedException('Provider configuration error');
|
||||
}
|
||||
}
|
||||
|
||||
private createManualConfiguration(provider: OidcProvider, cacheKey: string): client.Configuration {
|
||||
// Create manual configuration with a valid issuer URL
|
||||
const inferredIssuer =
|
||||
provider.issuer && provider.issuer.trim() !== ''
|
||||
? provider.issuer
|
||||
: new URL(provider.authorizationEndpoint ?? provider.tokenEndpoint!).origin;
|
||||
const serverMetadata: client.ServerMetadata = {
|
||||
issuer: inferredIssuer,
|
||||
authorization_endpoint: provider.authorizationEndpoint!,
|
||||
token_endpoint: provider.tokenEndpoint!,
|
||||
jwks_uri: provider.jwksUri,
|
||||
};
|
||||
|
||||
const clientMetadata: Partial<client.ClientMetadata> = {
|
||||
client_secret: provider.clientSecret,
|
||||
};
|
||||
|
||||
// Configure client auth method
|
||||
const clientAuth = provider.clientSecret
|
||||
? client.ClientSecretPost(provider.clientSecret)
|
||||
: client.None();
|
||||
|
||||
try {
|
||||
const config = new client.Configuration(
|
||||
serverMetadata,
|
||||
provider.clientId,
|
||||
clientMetadata,
|
||||
clientAuth
|
||||
);
|
||||
|
||||
// Allow HTTP if any configured endpoint uses http
|
||||
const endpoints = [
|
||||
serverMetadata.authorization_endpoint,
|
||||
serverMetadata.token_endpoint,
|
||||
].filter(Boolean) as string[];
|
||||
const hasHttp = endpoints.some((e) => new URL(e).protocol === 'http:');
|
||||
if (hasHttp) {
|
||||
this.logger.debug(`Allowing HTTP for manual endpoints on ${provider.id}`);
|
||||
// allowInsecureRequests is deprecated but still needed for HTTP endpoints
|
||||
client.allowInsecureRequests(config);
|
||||
}
|
||||
|
||||
this.logger.debug(`Manual configuration created for ${provider.id}`);
|
||||
this.logger.debug(`Authorization endpoint: ${serverMetadata.authorization_endpoint}`);
|
||||
this.logger.debug(`Token endpoint: ${serverMetadata.token_endpoint}`);
|
||||
|
||||
this.configCache.set(cacheKey, config);
|
||||
return config;
|
||||
} catch (manualConfigError) {
|
||||
const extracted = ErrorExtractor.extract(manualConfigError);
|
||||
this.logger.error(`Failed to create manual configuration: ${extracted.message}`);
|
||||
throw new Error(`Manual configuration failed for ${provider.id}`);
|
||||
}
|
||||
}
|
||||
|
||||
clearCache(providerId?: string): void {
|
||||
if (providerId) {
|
||||
this.configCache.delete(providerId);
|
||||
} else {
|
||||
this.configCache.clear();
|
||||
}
|
||||
}
|
||||
|
||||
getCacheSize(): number {
|
||||
return this.configCache.size;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
|
||||
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
|
||||
import { OidcBaseModule } from '@app/unraid-api/graph/resolvers/sso/core/oidc-base.module.js';
|
||||
|
||||
@Module({
|
||||
imports: [OidcBaseModule],
|
||||
providers: [OidcClientConfigService, OidcRedirectUriService],
|
||||
exports: [OidcClientConfigService, OidcRedirectUriService],
|
||||
})
|
||||
export class OidcClientModule {}
|
||||
@@ -0,0 +1,222 @@
|
||||
import { UnauthorizedException } from '@nestjs/common';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { validateRedirectUri } from '@app/unraid-api/utils/redirect-uri-validator.js';
|
||||
|
||||
// Mock the redirect URI validator
|
||||
vi.mock('@app/unraid-api/utils/redirect-uri-validator.js', () => ({
|
||||
validateRedirectUri: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('OidcRedirectUriService', () => {
|
||||
let service: OidcRedirectUriService;
|
||||
let oidcConfig: any;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
OidcRedirectUriService,
|
||||
{
|
||||
provide: OidcConfigPersistence,
|
||||
useValue: {
|
||||
getConfig: vi.fn().mockResolvedValue({
|
||||
providers: [],
|
||||
defaultAllowedOrigins: ['https://allowed.example.com'],
|
||||
}),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
service = module.get<OidcRedirectUriService>(OidcRedirectUriService);
|
||||
oidcConfig = module.get(OidcConfigPersistence);
|
||||
});
|
||||
|
||||
describe('getRedirectUri', () => {
|
||||
it('should return valid redirect URI when validation passes', async () => {
|
||||
const requestOrigin = 'https://example.com';
|
||||
const requestHeaders = {
|
||||
'x-forwarded-proto': 'https',
|
||||
'x-forwarded-host': 'example.com',
|
||||
};
|
||||
|
||||
(validateRedirectUri as any).mockReturnValue({
|
||||
isValid: true,
|
||||
validatedUri: 'https://example.com',
|
||||
});
|
||||
|
||||
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
|
||||
|
||||
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
|
||||
expect(validateRedirectUri).toHaveBeenCalledWith(
|
||||
'https://example.com',
|
||||
'https',
|
||||
'example.com',
|
||||
expect.anything(),
|
||||
['https://allowed.example.com']
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException when validation fails', async () => {
|
||||
const requestOrigin = 'https://evil.com';
|
||||
const requestHeaders = {
|
||||
'x-forwarded-proto': 'https',
|
||||
'x-forwarded-host': 'example.com',
|
||||
};
|
||||
|
||||
(validateRedirectUri as any).mockReturnValue({
|
||||
isValid: false,
|
||||
reason: 'Origin not allowed',
|
||||
});
|
||||
|
||||
await expect(service.getRedirectUri(requestOrigin, requestHeaders)).rejects.toThrow(
|
||||
UnauthorizedException
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle missing allowed origins', async () => {
|
||||
oidcConfig.getConfig.mockResolvedValue({
|
||||
providers: [],
|
||||
defaultAllowedOrigins: undefined,
|
||||
});
|
||||
|
||||
const requestOrigin = 'https://example.com';
|
||||
const requestHeaders = {
|
||||
'x-forwarded-proto': 'https',
|
||||
'x-forwarded-host': 'example.com',
|
||||
};
|
||||
|
||||
(validateRedirectUri as any).mockReturnValue({
|
||||
isValid: true,
|
||||
validatedUri: 'https://example.com',
|
||||
});
|
||||
|
||||
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
|
||||
|
||||
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
|
||||
expect(validateRedirectUri).toHaveBeenCalledWith(
|
||||
'https://example.com',
|
||||
'https',
|
||||
'example.com',
|
||||
expect.anything(),
|
||||
undefined
|
||||
);
|
||||
});
|
||||
|
||||
it('should extract protocol from headers correctly', async () => {
|
||||
const requestOrigin = 'https://example.com';
|
||||
const requestHeaders = {
|
||||
'x-forwarded-proto': ['https', 'http'],
|
||||
host: 'example.com',
|
||||
};
|
||||
|
||||
(validateRedirectUri as any).mockReturnValue({
|
||||
isValid: true,
|
||||
validatedUri: 'https://example.com',
|
||||
});
|
||||
|
||||
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
|
||||
|
||||
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
|
||||
expect(validateRedirectUri).toHaveBeenCalledWith(
|
||||
'https://example.com',
|
||||
'https', // Should use first value from array
|
||||
'example.com',
|
||||
expect.anything(),
|
||||
expect.anything()
|
||||
);
|
||||
});
|
||||
|
||||
it('should use host header as fallback', async () => {
|
||||
const requestOrigin = 'https://example.com';
|
||||
const requestHeaders = {
|
||||
host: 'example.com',
|
||||
};
|
||||
|
||||
(validateRedirectUri as any).mockReturnValue({
|
||||
isValid: true,
|
||||
validatedUri: 'https://example.com',
|
||||
});
|
||||
|
||||
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
|
||||
|
||||
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
|
||||
expect(validateRedirectUri).toHaveBeenCalledWith(
|
||||
'https://example.com',
|
||||
'https', // Inferred from requestOrigin when x-forwarded-proto not present
|
||||
'example.com',
|
||||
expect.anything(),
|
||||
expect.anything()
|
||||
);
|
||||
});
|
||||
|
||||
it('should prefer x-forwarded-host over host header', async () => {
|
||||
const requestOrigin = 'https://example.com';
|
||||
const requestHeaders = {
|
||||
'x-forwarded-host': 'forwarded.example.com',
|
||||
host: 'original.example.com',
|
||||
};
|
||||
|
||||
(validateRedirectUri as any).mockReturnValue({
|
||||
isValid: true,
|
||||
validatedUri: 'https://example.com',
|
||||
});
|
||||
|
||||
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
|
||||
|
||||
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
|
||||
expect(validateRedirectUri).toHaveBeenCalledWith(
|
||||
'https://example.com',
|
||||
'https', // Inferred from requestOrigin when x-forwarded-proto not present
|
||||
'forwarded.example.com', // Should use x-forwarded-host
|
||||
expect.anything(),
|
||||
expect.anything()
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw when URL construction fails', async () => {
|
||||
const requestOrigin = 'https://example.com';
|
||||
const requestHeaders = {};
|
||||
|
||||
(validateRedirectUri as any).mockReturnValue({
|
||||
isValid: true,
|
||||
validatedUri: 'invalid-url', // Invalid URL
|
||||
});
|
||||
|
||||
await expect(service.getRedirectUri(requestOrigin, requestHeaders)).rejects.toThrow(
|
||||
UnauthorizedException
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle array values in headers correctly', async () => {
|
||||
const requestOrigin = 'https://example.com';
|
||||
const requestHeaders = {
|
||||
'x-forwarded-proto': ['https'],
|
||||
'x-forwarded-host': ['forwarded.example.com', 'another.example.com'],
|
||||
host: ['original.example.com'],
|
||||
};
|
||||
|
||||
(validateRedirectUri as any).mockReturnValue({
|
||||
isValid: true,
|
||||
validatedUri: 'https://example.com',
|
||||
});
|
||||
|
||||
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
|
||||
|
||||
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
|
||||
expect(validateRedirectUri).toHaveBeenCalledWith(
|
||||
'https://example.com',
|
||||
'https',
|
||||
'forwarded.example.com', // Should use first value from array
|
||||
expect.anything(),
|
||||
expect.anything()
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,97 @@
|
||||
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
|
||||
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { validateRedirectUri } from '@app/unraid-api/utils/redirect-uri-validator.js';
|
||||
|
||||
@Injectable()
|
||||
export class OidcRedirectUriService {
|
||||
private readonly logger = new Logger(OidcRedirectUriService.name);
|
||||
private readonly CALLBACK_PATH = '/graphql/api/auth/oidc/callback';
|
||||
|
||||
constructor(private readonly oidcConfig: OidcConfigPersistence) {}
|
||||
|
||||
async getRedirectUri(
|
||||
requestOrigin: string,
|
||||
requestHeaders: Record<string, string | string[] | undefined>
|
||||
): Promise<string> {
|
||||
// Extract protocol and host from headers for validation
|
||||
const { protocol, host } = this.getRequestOriginInfo(requestHeaders, requestOrigin);
|
||||
|
||||
// Get the global allowed origins from OIDC config
|
||||
const config = await this.oidcConfig.getConfig();
|
||||
const allowedOrigins = config?.defaultAllowedOrigins;
|
||||
|
||||
// Debug logging to trace the issue
|
||||
this.logger.debug(
|
||||
`OIDC Config loaded: ${JSON.stringify(config ? { hasConfig: true, allowedOrigins } : { hasConfig: false })}`
|
||||
);
|
||||
this.logger.debug(
|
||||
`Validating redirect URI: ${requestOrigin} against host: ${protocol}://${host}`
|
||||
);
|
||||
this.logger.debug(`Allowed origins from config: ${JSON.stringify(allowedOrigins || [])}`);
|
||||
|
||||
// Validate the provided requestOrigin using centralized validator
|
||||
// Pass the global allowed origins if available
|
||||
const validation = validateRedirectUri(
|
||||
requestOrigin,
|
||||
protocol,
|
||||
host,
|
||||
this.logger,
|
||||
allowedOrigins
|
||||
);
|
||||
|
||||
if (!validation.isValid) {
|
||||
this.logger.warn(`Invalid redirect_uri in GraphQL OIDC flow: ${validation.reason}`);
|
||||
throw new UnauthorizedException(
|
||||
`Invalid redirect_uri: ${requestOrigin}. Please add this callback URI to Settings → Management Access → Allowed Redirect URIs`
|
||||
);
|
||||
}
|
||||
|
||||
// Ensure the validated URI has the correct callback path
|
||||
try {
|
||||
const url = new URL(validation.validatedUri);
|
||||
// Only use origin to prevent path manipulation
|
||||
const redirectUri = `${url.origin}${this.CALLBACK_PATH}`;
|
||||
this.logger.debug(`Using validated redirect URI: ${redirectUri}`);
|
||||
return redirectUri;
|
||||
} catch (e) {
|
||||
this.logger.error(
|
||||
`Failed to construct redirect URI from validated URI: ${validation.validatedUri}`
|
||||
);
|
||||
throw new UnauthorizedException('Invalid redirect_uri');
|
||||
}
|
||||
}
|
||||
|
||||
private getRequestOriginInfo(
|
||||
requestHeaders: Record<string, string | string[] | undefined>,
|
||||
requestOrigin?: string
|
||||
): {
|
||||
protocol: string;
|
||||
host: string | undefined;
|
||||
} {
|
||||
// Extract protocol from x-forwarded-proto or infer from requestOrigin, default to http
|
||||
const forwardedProto = requestHeaders['x-forwarded-proto'];
|
||||
const protocol = forwardedProto
|
||||
? Array.isArray(forwardedProto)
|
||||
? forwardedProto[0]
|
||||
: forwardedProto
|
||||
: requestOrigin?.startsWith('https')
|
||||
? 'https'
|
||||
: 'http';
|
||||
|
||||
// Extract host from x-forwarded-host or host header
|
||||
const forwardedHost = requestHeaders['x-forwarded-host'];
|
||||
const hostHeader = requestHeaders['host'];
|
||||
const host = forwardedHost
|
||||
? Array.isArray(forwardedHost)
|
||||
? forwardedHost[0]
|
||||
: forwardedHost
|
||||
: hostHeader
|
||||
? Array.isArray(hostHeader)
|
||||
? hostHeader[0]
|
||||
: hostHeader
|
||||
: undefined;
|
||||
|
||||
return { protocol, host };
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { UserSettingsModule } from '@unraid/shared/services/user-settings.js';
|
||||
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
|
||||
|
||||
@Module({
|
||||
imports: [UserSettingsModule],
|
||||
providers: [OidcConfigPersistence, OidcValidationService],
|
||||
exports: [OidcConfigPersistence, OidcValidationService],
|
||||
})
|
||||
export class OidcBaseModule {}
|
||||
@@ -0,0 +1,87 @@
|
||||
import { ConfigService } from '@nestjs/config';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { UserSettingsService } from '@unraid/shared/services/user-settings.js';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
|
||||
import { OidcUrlPatterns } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-url-patterns.util.js';
|
||||
|
||||
describe('OidcConfigPersistence', () => {
|
||||
let service: OidcConfigPersistence;
|
||||
let mockConfigService: ConfigService;
|
||||
let mockUserSettingsService: UserSettingsService;
|
||||
let mockValidationService: OidcValidationService;
|
||||
|
||||
beforeEach(async () => {
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
OidcConfigPersistence,
|
||||
{
|
||||
provide: ConfigService,
|
||||
useValue: {
|
||||
get: vi.fn(),
|
||||
set: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: UserSettingsService,
|
||||
useValue: {
|
||||
register: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcValidationService,
|
||||
useValue: {
|
||||
validateProvider: vi.fn(),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
service = module.get<OidcConfigPersistence>(OidcConfigPersistence);
|
||||
mockConfigService = module.get<ConfigService>(ConfigService);
|
||||
mockUserSettingsService = module.get<UserSettingsService>(UserSettingsService);
|
||||
mockValidationService = module.get<OidcValidationService>(OidcValidationService);
|
||||
|
||||
// Mock persist method to avoid file system operations
|
||||
vi.spyOn(service, 'persist').mockResolvedValue(true);
|
||||
});
|
||||
|
||||
describe('URL validation integration', () => {
|
||||
it('should validate issuer URLs using the shared utility', () => {
|
||||
// Test that our shared utility correctly validates URLs
|
||||
// This ensures the pattern we use in the form schema works correctly
|
||||
const examples = OidcUrlPatterns.getExamples();
|
||||
|
||||
// Test valid URLs
|
||||
examples.valid.forEach((url) => {
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl(url)).toBe(true);
|
||||
});
|
||||
|
||||
// Test invalid URLs
|
||||
examples.invalid.forEach((url) => {
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl(url)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
it('should validate the pattern constant matches the regex', () => {
|
||||
// Ensure the pattern string can be compiled into a valid regex
|
||||
expect(() => new RegExp(OidcUrlPatterns.ISSUER_URL_PATTERN)).not.toThrow();
|
||||
|
||||
// Ensure the static regex matches the pattern
|
||||
const manualRegex = new RegExp(OidcUrlPatterns.ISSUER_URL_PATTERN);
|
||||
expect(OidcUrlPatterns.ISSUER_URL_REGEX.source).toBe(manualRegex.source);
|
||||
});
|
||||
|
||||
it('should reject the specific URL from the bug report', () => {
|
||||
// Test the exact scenario that caused the original bug
|
||||
const problematicUrl = 'https://accounts.google.com/';
|
||||
const correctUrl = 'https://accounts.google.com';
|
||||
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl(problematicUrl)).toBe(false);
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl(correctUrl)).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { ConfigService } from '@nestjs/config';
|
||||
|
||||
import { RuleEffect } from '@jsonforms/core';
|
||||
@@ -6,12 +6,13 @@ import { mergeSettingSlices } from '@unraid/shared/jsonforms/settings.js';
|
||||
import { ConfigFilePersister } from '@unraid/shared/services/config-file.js';
|
||||
import { UserSettingsService } from '@unraid/shared/services/user-settings.js';
|
||||
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
|
||||
import {
|
||||
AuthorizationOperator,
|
||||
OidcAuthorizationRule,
|
||||
OidcProvider,
|
||||
} from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/oidc-validation.service.js';
|
||||
} from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
import { OidcUrlPatterns } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-url-patterns.util.js';
|
||||
import {
|
||||
createAccordionLayout,
|
||||
createLabeledControl,
|
||||
@@ -21,6 +22,7 @@ import { SettingSlice } from '@app/unraid-api/types/json-forms.js';
|
||||
|
||||
export interface OidcConfig {
|
||||
providers: OidcProvider[];
|
||||
defaultAllowedOrigins?: string[];
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
@@ -52,6 +54,7 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
defaultConfig(): OidcConfig {
|
||||
return {
|
||||
providers: [this.getUnraidNetSsoProvider()],
|
||||
defaultAllowedOrigins: [],
|
||||
};
|
||||
}
|
||||
|
||||
@@ -93,6 +96,7 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
|
||||
return {
|
||||
providers: [unraidNetSsoProvider],
|
||||
defaultAllowedOrigins: [],
|
||||
};
|
||||
}
|
||||
|
||||
@@ -119,6 +123,42 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
provider.authorizationRules || currentDefaults.authorizationRules,
|
||||
};
|
||||
}
|
||||
|
||||
// Fix dangerous authorization rules for non-unraid.net providers
|
||||
if (provider.authorizationRules && provider.authorizationRules.length > 0) {
|
||||
// Filter out dangerous rules that would allow all emails
|
||||
const safeRules = provider.authorizationRules.filter((rule) => {
|
||||
// Remove rules that have "email endsWith @" which allows all emails
|
||||
if (
|
||||
rule.claim === 'email' &&
|
||||
rule.operator === AuthorizationOperator.ENDS_WITH &&
|
||||
rule.value &&
|
||||
rule.value.length === 1 &&
|
||||
rule.value[0] === '@'
|
||||
) {
|
||||
this.logger.warn(
|
||||
`Removing dangerous authorization rule from provider "${provider.name}": email endsWith "@" allows all emails`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
// Remove rules with empty or invalid values
|
||||
if (
|
||||
!rule.value ||
|
||||
rule.value.length === 0 ||
|
||||
rule.value.every((v) => !v || !v.trim())
|
||||
) {
|
||||
this.logger.warn(
|
||||
`Removing invalid authorization rule from provider "${provider.name}": empty values`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
// Update provider with safe rules
|
||||
provider.authorizationRules = safeRules;
|
||||
}
|
||||
|
||||
return provider;
|
||||
});
|
||||
|
||||
@@ -155,6 +195,34 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
provider.authorizationRules = rules;
|
||||
}
|
||||
|
||||
// Skip providers without authorization rules (they will be ignored)
|
||||
if (!provider.authorizationRules || provider.authorizationRules.length === 0) {
|
||||
this.logger.warn(
|
||||
`Provider "${provider.name}" has no authorization rules and will be ignored. Configure authorization rules to enable this provider.`
|
||||
);
|
||||
}
|
||||
|
||||
// Validate each rule has valid values (only if rules exist)
|
||||
if (provider.authorizationRules && provider.authorizationRules.length > 0) {
|
||||
for (const rule of provider.authorizationRules) {
|
||||
if (!rule.claim || !rule.claim.trim()) {
|
||||
throw new Error(
|
||||
`Provider "${provider.name}": Authorization rule claim cannot be empty`
|
||||
);
|
||||
}
|
||||
if (!rule.operator) {
|
||||
throw new Error(
|
||||
`Provider "${provider.name}": Authorization rule operator is required`
|
||||
);
|
||||
}
|
||||
if (!rule.value || rule.value.length === 0 || rule.value.every((v) => !v || !v.trim())) {
|
||||
throw new Error(
|
||||
`Provider "${provider.name}": Authorization rule for claim "${rule.claim}" must have at least one non-empty value`
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up the provider object - remove UI-only fields
|
||||
const cleanedProvider: OidcProvider = {
|
||||
id: provider.id,
|
||||
@@ -191,46 +259,52 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
allowedDomains?: string[];
|
||||
allowedEmails?: string[];
|
||||
allowedUserIds?: string[];
|
||||
googleWorkspaceDomain?: string;
|
||||
}): OidcAuthorizationRule[] {
|
||||
const rules: OidcAuthorizationRule[] = [];
|
||||
|
||||
// Convert email domains to endsWith rules
|
||||
// Only add if domains are provided AND not empty AND have non-empty values
|
||||
if (simpleAuth?.allowedDomains && simpleAuth.allowedDomains.length > 0) {
|
||||
rules.push({
|
||||
claim: 'email',
|
||||
operator: AuthorizationOperator.ENDS_WITH,
|
||||
value: simpleAuth.allowedDomains.map((domain: string) =>
|
||||
domain.startsWith('@') ? domain : `@${domain}`
|
||||
),
|
||||
});
|
||||
const validDomains = simpleAuth.allowedDomains.filter(
|
||||
(domain: string) => domain && domain.trim()
|
||||
);
|
||||
if (validDomains.length > 0) {
|
||||
rules.push({
|
||||
claim: 'email',
|
||||
operator: AuthorizationOperator.ENDS_WITH,
|
||||
value: validDomains.map((domain: string) =>
|
||||
domain.startsWith('@') ? domain : `@${domain}`
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Convert specific emails to equals rules
|
||||
// Only add if emails are provided AND not empty AND have non-empty values
|
||||
if (simpleAuth?.allowedEmails && simpleAuth.allowedEmails.length > 0) {
|
||||
rules.push({
|
||||
claim: 'email',
|
||||
operator: AuthorizationOperator.EQUALS,
|
||||
value: simpleAuth.allowedEmails,
|
||||
});
|
||||
const validEmails = simpleAuth.allowedEmails.filter(
|
||||
(email: string) => email && email.trim()
|
||||
);
|
||||
if (validEmails.length > 0) {
|
||||
rules.push({
|
||||
claim: 'email',
|
||||
operator: AuthorizationOperator.EQUALS,
|
||||
value: validEmails,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Convert user IDs to sub equals rules
|
||||
// Only add if user IDs are provided AND not empty AND have non-empty values
|
||||
if (simpleAuth?.allowedUserIds && simpleAuth.allowedUserIds.length > 0) {
|
||||
rules.push({
|
||||
claim: 'sub',
|
||||
operator: AuthorizationOperator.EQUALS,
|
||||
value: simpleAuth.allowedUserIds,
|
||||
});
|
||||
}
|
||||
|
||||
// Google Workspace domain (hd claim)
|
||||
if (simpleAuth?.googleWorkspaceDomain) {
|
||||
rules.push({
|
||||
claim: 'hd',
|
||||
operator: AuthorizationOperator.EQUALS,
|
||||
value: [simpleAuth.googleWorkspaceDomain],
|
||||
});
|
||||
const validUserIds = simpleAuth.allowedUserIds.filter((id: string) => id && id.trim());
|
||||
if (validUserIds.length > 0) {
|
||||
rules.push({
|
||||
claim: 'sub',
|
||||
operator: AuthorizationOperator.EQUALS,
|
||||
value: validUserIds,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return rules;
|
||||
@@ -286,7 +360,6 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
allowedDomains?: string[];
|
||||
allowedEmails?: string[];
|
||||
allowedUserIds?: string[];
|
||||
googleWorkspaceDomain?: string;
|
||||
}
|
||||
);
|
||||
// Return provider with generated rules, removing UI-only fields
|
||||
@@ -304,6 +377,39 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
}),
|
||||
};
|
||||
|
||||
// Validate authorization rules for providers that have them
|
||||
for (const provider of processedConfig.providers) {
|
||||
if (!provider.authorizationRules || provider.authorizationRules.length === 0) {
|
||||
this.logger.warn(
|
||||
`Provider "${provider.name}" has no authorization rules and will be ignored. Configure authorization rules to enable this provider.`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Validate each rule has valid values
|
||||
for (const rule of provider.authorizationRules) {
|
||||
if (!rule.claim || !rule.claim.trim()) {
|
||||
throw new Error(
|
||||
`Provider "${provider.name}": Authorization rule claim cannot be empty`
|
||||
);
|
||||
}
|
||||
if (!rule.operator) {
|
||||
throw new Error(
|
||||
`Provider "${provider.name}": Authorization rule operator is required`
|
||||
);
|
||||
}
|
||||
if (
|
||||
!rule.value ||
|
||||
rule.value.length === 0 ||
|
||||
rule.value.every((v) => !v || !v.trim())
|
||||
) {
|
||||
throw new Error(
|
||||
`Provider "${provider.name}": Authorization rule for claim "${rule.claim}" must have at least one non-empty value`
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate OIDC discovery for all providers with issuer URLs
|
||||
const validationErrors: string[] = [];
|
||||
for (const provider of processedConfig.providers) {
|
||||
@@ -419,10 +525,6 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
if (rule.claim === 'sub' && rule.operator === AuthorizationOperator.EQUALS) {
|
||||
return true;
|
||||
}
|
||||
// Google Workspace domain
|
||||
if (rule.claim === 'hd' && rule.operator === AuthorizationOperator.EQUALS) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
@@ -431,13 +533,11 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
allowedDomains: string[];
|
||||
allowedEmails: string[];
|
||||
allowedUserIds: string[];
|
||||
googleWorkspaceDomain?: string;
|
||||
} {
|
||||
const simpleAuth = {
|
||||
allowedDomains: [] as string[],
|
||||
allowedEmails: [] as string[],
|
||||
allowedUserIds: [] as string[],
|
||||
googleWorkspaceDomain: undefined as string | undefined,
|
||||
};
|
||||
|
||||
rules.forEach((rule) => {
|
||||
@@ -449,12 +549,6 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
simpleAuth.allowedEmails = rule.value;
|
||||
} else if (rule.claim === 'sub' && rule.operator === AuthorizationOperator.EQUALS) {
|
||||
simpleAuth.allowedUserIds = rule.value;
|
||||
} else if (
|
||||
rule.claim === 'hd' &&
|
||||
rule.operator === AuthorizationOperator.EQUALS &&
|
||||
rule.value.length > 0
|
||||
) {
|
||||
simpleAuth.googleWorkspaceDomain = rule.value[0];
|
||||
}
|
||||
});
|
||||
|
||||
@@ -462,7 +556,36 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
}
|
||||
|
||||
private buildSlice(): SettingSlice {
|
||||
return mergeSettingSlices([this.oidcProvidersSlice()], { as: 'sso' });
|
||||
const providersSlice = this.oidcProvidersSlice();
|
||||
|
||||
// Add defaultAllowedOrigins to the properties
|
||||
providersSlice.properties.defaultAllowedOrigins = {
|
||||
type: 'array',
|
||||
items: { type: 'string' },
|
||||
title: 'Default Allowed Redirect Origins',
|
||||
default: [],
|
||||
description:
|
||||
'Additional trusted redirect origins to allow redirects from custom ports, reverse proxies, Tailscale, etc.',
|
||||
};
|
||||
|
||||
// Add the control for defaultAllowedOrigins before the providers control using UnraidSettingsLayout
|
||||
if (providersSlice.elements?.[0]?.elements) {
|
||||
providersSlice.elements[0].elements.unshift(
|
||||
createLabeledControl({
|
||||
scope: '#/properties/sso/properties/defaultAllowedOrigins',
|
||||
label: 'Allowed OIDC Redirect Origins',
|
||||
description:
|
||||
'Add trusted origins for OIDC redirection. These are URLs that the OIDC provider can redirect to after authentication when accessing Unraid through custom ports, reverse proxies, or Tailscale. Each origin should include the protocol and optionally a port (e.g., https://unraid.local:8443)',
|
||||
controlOptions: {
|
||||
format: 'array',
|
||||
inputType: 'text',
|
||||
placeholder: 'https://unraid.local:8443',
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
return mergeSettingSlices([providersSlice], { as: 'sso' });
|
||||
}
|
||||
|
||||
private oidcProvidersSlice(): SettingSlice {
|
||||
@@ -498,7 +621,22 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
type: 'string',
|
||||
title: 'Issuer URL',
|
||||
format: 'uri',
|
||||
description: 'OIDC issuer URL (e.g., https://accounts.google.com)',
|
||||
allOf: [
|
||||
{
|
||||
pattern: OidcUrlPatterns.ISSUER_URL_PATTERN,
|
||||
errorMessage:
|
||||
'Must be a valid HTTP or HTTPS URL without trailing slashes or whitespace',
|
||||
},
|
||||
{
|
||||
not: {
|
||||
pattern: '\\.well-known',
|
||||
},
|
||||
errorMessage:
|
||||
'Cannot contain /.well-known/ paths. Use the base issuer URL instead (e.g., https://accounts.google.com instead of https://accounts.google.com/.well-known/openid-configuration)',
|
||||
},
|
||||
],
|
||||
description:
|
||||
'OIDC issuer URL (e.g., https://accounts.google.com). Cannot contain /.well-known/ paths - use the base issuer URL instead of the full discovery endpoint. Must not end with a trailing slash.',
|
||||
},
|
||||
authorizationEndpoint: {
|
||||
anyOf: [
|
||||
@@ -999,7 +1137,7 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
scope: '#/properties/claim',
|
||||
label: 'JWT Claim:',
|
||||
description:
|
||||
'JWT claim to check (e.g., email, sub, groups, hd for Google hosted domain)',
|
||||
'JWT claim to check (e.g., email, sub, groups)',
|
||||
controlOptions: {
|
||||
inputType: 'text',
|
||||
placeholder: 'email',
|
||||
@@ -0,0 +1,14 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { OidcAuthModule } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-auth.module.js';
|
||||
import { OidcClientModule } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client.module.js';
|
||||
import { OidcBaseModule } from '@app/unraid-api/graph/resolvers/sso/core/oidc-base.module.js';
|
||||
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
|
||||
import { OidcSessionModule } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.module.js';
|
||||
|
||||
@Module({
|
||||
imports: [OidcBaseModule, OidcSessionModule, OidcAuthModule, OidcClientModule],
|
||||
providers: [OidcService],
|
||||
exports: [OidcService, OidcBaseModule, OidcSessionModule, OidcAuthModule, OidcClientModule],
|
||||
})
|
||||
export class OidcCoreModule {}
|
||||
@@ -0,0 +1,160 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { ConfigService } from '@nestjs/config';
|
||||
|
||||
import * as client from 'openid-client';
|
||||
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
import { OidcErrorHelper } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-error.helper.js';
|
||||
|
||||
@Injectable()
|
||||
export class OidcValidationService {
|
||||
private readonly logger = new Logger(OidcValidationService.name);
|
||||
|
||||
constructor(private readonly configService: ConfigService) {}
|
||||
|
||||
/**
|
||||
* Validate OIDC provider configuration by attempting discovery
|
||||
* Returns validation result with helpful error messages for debugging
|
||||
*/
|
||||
async validateProvider(
|
||||
provider: OidcProvider
|
||||
): Promise<{ isValid: boolean; error?: string; details?: unknown }> {
|
||||
try {
|
||||
// Validate issuer URL is present
|
||||
if (!provider.issuer) {
|
||||
return {
|
||||
isValid: false,
|
||||
error: 'No issuer URL provided. Please specify the OIDC provider issuer URL.',
|
||||
details: { type: 'MISSING_ISSUER' },
|
||||
};
|
||||
}
|
||||
|
||||
// Validate issuer URL is valid
|
||||
let serverUrl: URL;
|
||||
try {
|
||||
serverUrl = new URL(provider.issuer);
|
||||
} catch (urlError) {
|
||||
return {
|
||||
isValid: false,
|
||||
error: `Invalid issuer URL format: '${provider.issuer}'. Please provide a valid URL.`,
|
||||
details: {
|
||||
type: 'INVALID_URL',
|
||||
originalError: urlError instanceof Error ? urlError.message : String(urlError),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Configure client options for HTTP if needed
|
||||
let clientOptions: any = undefined;
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.warn(
|
||||
`HTTP issuer URL detected for provider ${provider.id}: ${provider.issuer} - This is insecure`
|
||||
);
|
||||
clientOptions = {
|
||||
execute: [client.allowInsecureRequests],
|
||||
};
|
||||
}
|
||||
|
||||
// Attempt OIDC discovery
|
||||
await this.performDiscovery(provider, clientOptions);
|
||||
return { isValid: true };
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
|
||||
|
||||
// Log the raw error for debugging
|
||||
this.logger.log(`Raw discovery error for ${provider.id}: ${errorMessage}`);
|
||||
|
||||
// Use the helper to parse the error
|
||||
const { userFriendlyError, details } = OidcErrorHelper.parseDiscoveryError(
|
||||
error,
|
||||
provider.issuer
|
||||
);
|
||||
|
||||
this.logger.error(`Validation failed for provider ${provider.id}: ${errorMessage}`);
|
||||
|
||||
// Add debug logging for HTTP status errors
|
||||
if (errorMessage.includes('unexpected HTTP response status code')) {
|
||||
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
|
||||
? provider.issuer.replace('/.well-known/openid-configuration', '')
|
||||
: provider.issuer;
|
||||
this.logger.log(`Attempted to fetch: ${baseUrl}/.well-known/openid-configuration`);
|
||||
this.logger.error(`Full error details: ${errorMessage}`);
|
||||
}
|
||||
|
||||
return {
|
||||
isValid: false,
|
||||
error: userFriendlyError,
|
||||
details,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async performDiscovery(provider: OidcProvider, clientOptions?: any): Promise<client.Configuration> {
|
||||
if (!provider.issuer) {
|
||||
throw new Error('No issuer URL provided');
|
||||
}
|
||||
|
||||
// Configure client auth method
|
||||
const clientAuth = provider.clientSecret
|
||||
? client.ClientSecretPost(provider.clientSecret)
|
||||
: undefined;
|
||||
|
||||
const serverUrl = new URL(provider.issuer);
|
||||
const discoveryUrl = `${provider.issuer}/.well-known/openid-configuration`;
|
||||
|
||||
this.logger.log(`Starting discovery for provider ${provider.id}`);
|
||||
this.logger.log(`Discovery URL: ${discoveryUrl}`);
|
||||
this.logger.log(`Client ID: ${provider.clientId}`);
|
||||
this.logger.log(`Client secret configured: ${provider.clientSecret ? 'Yes' : 'No'}`);
|
||||
|
||||
// Use provided client options or create default options with HTTP support if needed
|
||||
if (!clientOptions && serverUrl.protocol === 'http:') {
|
||||
this.logger.warn(
|
||||
`Allowing HTTP for ${provider.id} - This is insecure and should only be used for testing`
|
||||
);
|
||||
// For openid-client v6, use allowInsecureRequests in the execute array
|
||||
// This is deprecated but needed for local development with HTTP endpoints
|
||||
clientOptions = {
|
||||
execute: [client.allowInsecureRequests],
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const config = await client.discovery(
|
||||
serverUrl,
|
||||
provider.clientId,
|
||||
undefined, // client metadata
|
||||
clientAuth,
|
||||
clientOptions
|
||||
);
|
||||
|
||||
this.logger.log(`Discovery successful for ${provider.id}`);
|
||||
this.logger.log(`Discovery response metadata:`);
|
||||
this.logger.log(` - issuer: ${config.serverMetadata().issuer}`);
|
||||
this.logger.log(
|
||||
` - authorization_endpoint: ${config.serverMetadata().authorization_endpoint}`
|
||||
);
|
||||
this.logger.log(` - token_endpoint: ${config.serverMetadata().token_endpoint}`);
|
||||
this.logger.log(
|
||||
` - userinfo_endpoint: ${config.serverMetadata().userinfo_endpoint || 'not provided'}`
|
||||
);
|
||||
this.logger.log(` - jwks_uri: ${config.serverMetadata().jwks_uri || 'not provided'}`);
|
||||
this.logger.log(
|
||||
` - response_types_supported: ${config.serverMetadata().response_types_supported?.join(', ') || 'not provided'}`
|
||||
);
|
||||
this.logger.log(
|
||||
` - scopes_supported: ${config.serverMetadata().scopes_supported?.join(', ') || 'not provided'}`
|
||||
);
|
||||
|
||||
return config;
|
||||
} catch (discoveryError) {
|
||||
this.logger.error(`Discovery failed for ${provider.id} at ${discoveryUrl}`);
|
||||
|
||||
if (discoveryError instanceof Error) {
|
||||
this.logger.error('Discovery error: %o', discoveryError);
|
||||
}
|
||||
|
||||
throw discoveryError;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,485 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { ConfigModule, ConfigService } from '@nestjs/config';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import * as client from 'openid-client';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
|
||||
import { OidcClaimsService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
|
||||
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
|
||||
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
|
||||
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
|
||||
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
|
||||
|
||||
describe('OidcService Integration Tests - Enhanced Logging', () => {
|
||||
let service: OidcService;
|
||||
let configPersistence: OidcConfigPersistence;
|
||||
let loggerSpy: any;
|
||||
let debugLogs: string[] = [];
|
||||
let errorLogs: string[] = [];
|
||||
let warnLogs: string[] = [];
|
||||
let logLogs: string[] = [];
|
||||
|
||||
beforeEach(async () => {
|
||||
// Clear log arrays
|
||||
debugLogs = [];
|
||||
errorLogs = [];
|
||||
warnLogs = [];
|
||||
logLogs = [];
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
imports: [
|
||||
ConfigModule.forRoot({
|
||||
isGlobal: true,
|
||||
load: [() => ({ BASE_URL: 'http://test.local' })],
|
||||
}),
|
||||
],
|
||||
providers: [
|
||||
OidcService,
|
||||
OidcValidationService,
|
||||
OidcClientConfigService,
|
||||
OidcTokenExchangeService,
|
||||
{
|
||||
provide: OidcAuthorizationService,
|
||||
useValue: {
|
||||
checkAuthorization: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcConfigPersistence,
|
||||
useValue: {
|
||||
getProvider: vi.fn(),
|
||||
saveProvider: vi.fn(),
|
||||
getConfig: vi.fn().mockReturnValue({
|
||||
providers: [],
|
||||
defaultAllowedOrigins: [],
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcSessionService,
|
||||
useValue: {
|
||||
createSession: vi.fn().mockResolvedValue('mock-token'),
|
||||
validateSession: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcStateService,
|
||||
useValue: {
|
||||
generateSecureState: vi.fn().mockResolvedValue('secure-state'),
|
||||
validateSecureState: vi.fn().mockResolvedValue({
|
||||
isValid: true,
|
||||
clientState: 'test-state',
|
||||
redirectUri: 'https://myapp.example.com/graphql/api/auth/oidc/callback',
|
||||
}),
|
||||
extractProviderFromState: vi.fn().mockReturnValue('test-provider'),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcRedirectUriService,
|
||||
useValue: {
|
||||
getRedirectUri: vi
|
||||
.fn()
|
||||
.mockResolvedValue(
|
||||
'https://myapp.example.com/graphql/api/auth/oidc/callback'
|
||||
),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcClaimsService,
|
||||
useValue: {
|
||||
parseIdToken: vi.fn().mockReturnValue({
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
}),
|
||||
validateClaims: vi.fn().mockReturnValue('user123'),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
service = module.get<OidcService>(OidcService);
|
||||
configPersistence = module.get<OidcConfigPersistence>(OidcConfigPersistence);
|
||||
|
||||
// Spy on logger methods to capture logs
|
||||
loggerSpy = {
|
||||
debug: vi
|
||||
.spyOn(Logger.prototype, 'debug')
|
||||
.mockImplementation((message: string, ...args: any[]) => {
|
||||
debugLogs.push(message);
|
||||
}),
|
||||
error: vi
|
||||
.spyOn(Logger.prototype, 'error')
|
||||
.mockImplementation((message: string, ...args: any[]) => {
|
||||
errorLogs.push(message);
|
||||
}),
|
||||
warn: vi
|
||||
.spyOn(Logger.prototype, 'warn')
|
||||
.mockImplementation((message: string, ...args: any[]) => {
|
||||
warnLogs.push(message);
|
||||
}),
|
||||
log: vi
|
||||
.spyOn(Logger.prototype, 'log')
|
||||
.mockImplementation((message: string, ...args: any[]) => {
|
||||
logLogs.push(message);
|
||||
}),
|
||||
verbose: vi.spyOn(Logger.prototype, 'verbose').mockImplementation(() => {}),
|
||||
};
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('Token Exchange Error Logging', () => {
|
||||
it('should log detailed error information when token exchange fails with Google (trailing slash issue)', async () => {
|
||||
// This simulates the issue from #1616 where a trailing slash causes failure
|
||||
const provider: OidcProvider = {
|
||||
id: 'google-test',
|
||||
name: 'Google Test',
|
||||
issuer: 'https://accounts.google.com/', // Trailing slash will cause issue
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid', 'email', 'profile'],
|
||||
authorizationRules: [
|
||||
{
|
||||
claim: 'email',
|
||||
operator: 'ENDS_WITH' as any,
|
||||
value: ['@example.com'],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
|
||||
|
||||
try {
|
||||
await service.handleCallback({
|
||||
providerId: 'google-test',
|
||||
code: 'test-code',
|
||||
state: 'test-state',
|
||||
requestOrigin: 'http://test.local',
|
||||
fullCallbackUrl:
|
||||
'http://test.local/graphql/api/auth/oidc/callback?code=test-code&state=test-state',
|
||||
requestHeaders: { host: 'test.local' },
|
||||
});
|
||||
} catch (error) {
|
||||
// We expect this to fail
|
||||
}
|
||||
|
||||
// Verify that the service attempted to handle the callback
|
||||
// Note: Detailed token exchange logging now happens in OidcTokenExchangeService
|
||||
expect(errorLogs.length).toBeGreaterThan(0);
|
||||
// Changed logging format to use error extractor
|
||||
expect(errorLogs.some((log) => log.includes('Token exchange failed'))).toBe(true);
|
||||
});
|
||||
|
||||
it('should log discovery failure details with invalid issuer URL', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'invalid-issuer',
|
||||
name: 'Invalid Issuer Test',
|
||||
issuer: 'https://invalid-oidc-provider.example.com', // Non-existent domain
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid', 'email'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const validationService = new OidcValidationService(new ConfigService());
|
||||
const result = await validationService.validateProvider(provider);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
// Should now have more specific error message
|
||||
expect(result.error).toBeDefined();
|
||||
// The error should mention the domain cannot be resolved or connection failed
|
||||
expect(result.error).toMatch(
|
||||
/Cannot resolve domain name|Failed to connect to OIDC provider/
|
||||
);
|
||||
expect(result.details).toBeDefined();
|
||||
expect(result.details).toHaveProperty('type');
|
||||
// Should be either DNS_ERROR or FETCH_ERROR depending on the cause
|
||||
expect(['DNS_ERROR', 'FETCH_ERROR']).toContain((result.details as any).type);
|
||||
});
|
||||
|
||||
it('should log detailed HTTP error responses from discovery', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'http-error-test',
|
||||
name: 'HTTP Error Test',
|
||||
issuer: 'https://httpstat.us/500', // Returns 500 error
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
|
||||
|
||||
try {
|
||||
await service.validateProvider(provider);
|
||||
} catch (error) {
|
||||
// Expected to fail
|
||||
}
|
||||
|
||||
// Check that HTTP status details are logged (now in log level)
|
||||
expect(logLogs.some((log) => log.includes('Discovery URL:'))).toBe(true);
|
||||
expect(logLogs.some((log) => log.includes('Client ID:'))).toBe(true);
|
||||
});
|
||||
|
||||
it('should log authorization URL building details', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'auth-url-test',
|
||||
name: 'Auth URL Test',
|
||||
issuer: 'https://accounts.google.com',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid', 'email', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
|
||||
|
||||
try {
|
||||
await service.getAuthorizationUrl({
|
||||
providerId: 'auth-url-test',
|
||||
state: 'test-state',
|
||||
requestOrigin: 'http://test.local',
|
||||
requestHeaders: { host: 'test.local' },
|
||||
});
|
||||
|
||||
// Verify URL building logs
|
||||
expect(logLogs.some((log) => log.includes('Built authorization URL'))).toBe(true);
|
||||
expect(logLogs.some((log) => log.includes('Authorization parameters:'))).toBe(true);
|
||||
} catch (error) {
|
||||
// May fail due to real discovery, but we're interested in the logs
|
||||
}
|
||||
});
|
||||
|
||||
it('should log detailed information for manual endpoint configuration', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-endpoints',
|
||||
name: 'Manual Endpoints Test',
|
||||
issuer: undefined,
|
||||
authorizationEndpoint: 'https://auth.example.com/authorize',
|
||||
tokenEndpoint: 'https://auth.example.com/token',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
|
||||
|
||||
const authUrl = await service.getAuthorizationUrl({
|
||||
providerId: 'manual-endpoints',
|
||||
state: 'test-state',
|
||||
requestOrigin: 'http://test.local',
|
||||
requestHeaders: {
|
||||
'x-forwarded-host': 'test.local',
|
||||
'x-forwarded-proto': 'http',
|
||||
},
|
||||
});
|
||||
|
||||
// Verify manual endpoint logs
|
||||
expect(debugLogs.some((log) => log.includes('Built authorization URL'))).toBe(true);
|
||||
expect(debugLogs.some((log) => log.includes('client_id=test-client-id'))).toBe(true);
|
||||
expect(authUrl).toContain('https://auth.example.com/authorize');
|
||||
});
|
||||
|
||||
it('should log JWT claim validation failures with detailed context', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'jwt-validation-test',
|
||||
name: 'JWT Validation Test',
|
||||
issuer: 'https://accounts.google.com',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid', 'email'],
|
||||
authorizationRules: [
|
||||
{
|
||||
claim: 'email',
|
||||
operator: 'ENDS_WITH' as any,
|
||||
value: ['@restricted.com'],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
|
||||
|
||||
// Mock a scenario where JWT validation fails
|
||||
try {
|
||||
await service.handleCallback({
|
||||
providerId: 'jwt-validation-test',
|
||||
code: 'test-code',
|
||||
state: 'test-state',
|
||||
requestOrigin: 'http://test.local',
|
||||
fullCallbackUrl:
|
||||
'http://test.local/graphql/api/auth/oidc/callback?code=test-code&state=test-state',
|
||||
requestHeaders: { host: 'test.local' },
|
||||
});
|
||||
} catch (error) {
|
||||
// Expected to fail
|
||||
}
|
||||
|
||||
// The JWT error handling is now in OidcTokenExchangeService
|
||||
// We should see some error logged
|
||||
expect(errorLogs.length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Discovery Endpoint Logging', () => {
|
||||
it('should log all discovery metadata when successful', async () => {
|
||||
// Use a real OIDC provider that works
|
||||
const provider: OidcProvider = {
|
||||
id: 'microsoft',
|
||||
name: 'Microsoft',
|
||||
issuer: 'https://login.microsoftonline.com/common/v2.0',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid', 'email', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const validationService = new OidcValidationService(new ConfigService());
|
||||
|
||||
try {
|
||||
await validationService.performDiscovery(provider);
|
||||
} catch (error) {
|
||||
// May fail due to network, but we're checking logs
|
||||
}
|
||||
|
||||
// Verify discovery logging (now in log level)
|
||||
expect(logLogs.some((log) => log.includes('Starting discovery'))).toBe(true);
|
||||
expect(logLogs.some((log) => log.includes('Discovery URL:'))).toBe(true);
|
||||
});
|
||||
|
||||
it('should log discovery failures with malformed JSON response', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'malformed-json',
|
||||
name: 'Malformed JSON Test',
|
||||
issuer: 'https://example.com/malformed',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
// Mock global fetch to return HTML instead of JSON
|
||||
const originalFetch = global.fetch;
|
||||
global.fetch = vi.fn().mockImplementation(() =>
|
||||
Promise.resolve(
|
||||
new Response('<html><body>Not JSON</body></html>', {
|
||||
status: 200,
|
||||
headers: { 'content-type': 'text/html' },
|
||||
})
|
||||
)
|
||||
);
|
||||
|
||||
const validationService = new OidcValidationService(new ConfigService());
|
||||
const result = await validationService.validateProvider(provider);
|
||||
|
||||
// Restore original fetch
|
||||
global.fetch = originalFetch;
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toBeDefined();
|
||||
// The openid-client library will fail when it gets HTML instead of JSON
|
||||
// It returns "unexpected response content-type" error
|
||||
expect(result.error).toMatch(
|
||||
/Invalid OIDC discovery|malformed|doesn't conform|unexpected|content-type/i
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle and log HTTP vs HTTPS protocol differences', async () => {
|
||||
const httpProvider: OidcProvider = {
|
||||
id: 'http-local',
|
||||
name: 'HTTP Local Test',
|
||||
issuer: 'http://localhost:8080', // HTTP endpoint
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
// Create a validation service and spy on its logger
|
||||
const validationService = new OidcValidationService(new ConfigService());
|
||||
|
||||
try {
|
||||
await validationService.validateProvider(httpProvider);
|
||||
} catch (error) {
|
||||
// Expected to fail if localhost:8080 isn't running
|
||||
}
|
||||
|
||||
// The HTTP logging happens in the validation service
|
||||
// We should check that HTTP issuers are detected
|
||||
expect(httpProvider.issuer).toMatch(/^http:/);
|
||||
// Verify that we're testing an HTTP endpoint
|
||||
expect(httpProvider.issuer).toBe('http://localhost:8080');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Request/Response Detail Logging', () => {
|
||||
it('should log complete request parameters for token exchange', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'token-params-test',
|
||||
name: 'Token Params Test',
|
||||
issuer: 'https://accounts.google.com',
|
||||
clientId: 'detailed-client-id',
|
||||
clientSecret: 'detailed-client-secret',
|
||||
scopes: ['openid', 'email', 'profile', 'offline_access'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
|
||||
|
||||
try {
|
||||
await service.handleCallback({
|
||||
providerId: 'token-params-test',
|
||||
code: 'authorization-code-12345',
|
||||
state: 'state-with-signature',
|
||||
requestOrigin: 'https://myapp.example.com',
|
||||
fullCallbackUrl:
|
||||
'https://myapp.example.com/graphql/api/auth/oidc/callback?code=authorization-code-12345&state=state-with-signature&scope=openid+email+profile',
|
||||
requestHeaders: { host: 'myapp.example.com' },
|
||||
});
|
||||
} catch (error) {
|
||||
// Expected to fail
|
||||
}
|
||||
|
||||
// Verify that we attempted the operation
|
||||
// Detailed parameter logging is now in OidcTokenExchangeService
|
||||
expect(debugLogs.length).toBeGreaterThan(0);
|
||||
expect(debugLogs.some((log) => log.includes('Client ID: detailed-client-id'))).toBe(true);
|
||||
expect(debugLogs.some((log) => log.includes('Client secret configured: Yes'))).toBe(true);
|
||||
});
|
||||
|
||||
it('should capture and log all error properties from openid-client', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'error-properties-test',
|
||||
name: 'Error Properties Test',
|
||||
issuer: 'https://expired-cert.badssl.com/', // SSL cert error
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const validationService = new OidcValidationService(new ConfigService());
|
||||
const result = await validationService.validateProvider(provider);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toBeDefined();
|
||||
// Should detect SSL/certificate issues or connection failure
|
||||
expect(result.error).toMatch(
|
||||
/SSL\/TLS certificate error|Failed to connect to OIDC provider|certificate/
|
||||
);
|
||||
expect(result.details).toBeDefined();
|
||||
expect(result.details).toHaveProperty('type');
|
||||
// Should be either SSL_ERROR or FETCH_ERROR
|
||||
expect(['SSL_ERROR', 'FETCH_ERROR']).toContain((result.details as any).type);
|
||||
});
|
||||
});
|
||||
});
|
||||
381
api/src/unraid-api/graph/resolvers/sso/core/oidc.service.test.ts
Normal file
381
api/src/unraid-api/graph/resolvers/sso/core/oidc.service.test.ts
Normal file
@@ -0,0 +1,381 @@
|
||||
import { CacheModule } from '@nestjs/cache-manager';
|
||||
import { UnauthorizedException } from '@nestjs/common';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
|
||||
import { OidcClaimsService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
|
||||
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
|
||||
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
|
||||
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
|
||||
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
|
||||
|
||||
// Mock openid-client
|
||||
vi.mock('openid-client', () => ({
|
||||
buildAuthorizationUrl: vi.fn((config, params) => {
|
||||
const url = new URL(config.serverMetadata().authorization_endpoint);
|
||||
Object.entries(params).forEach(([key, value]) => {
|
||||
if (value !== undefined) {
|
||||
url.searchParams.set(key, String(value));
|
||||
}
|
||||
});
|
||||
return url;
|
||||
}),
|
||||
allowInsecureRequests: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('OidcService Integration', () => {
|
||||
let service: OidcService;
|
||||
let oidcConfig: any;
|
||||
let sessionService: any;
|
||||
let stateService: OidcStateService;
|
||||
let redirectUriService: any;
|
||||
let clientConfigService: any;
|
||||
let tokenExchangeService: any;
|
||||
let claimsService: any;
|
||||
let authorizationService: any;
|
||||
|
||||
beforeEach(async () => {
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
imports: [CacheModule.register()],
|
||||
providers: [
|
||||
OidcService,
|
||||
{
|
||||
provide: OidcConfigPersistence,
|
||||
useValue: {
|
||||
getProvider: vi.fn(),
|
||||
getConfig: vi.fn().mockResolvedValue({
|
||||
providers: [],
|
||||
defaultAllowedOrigins: ['https://example.com'],
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcSessionService,
|
||||
useValue: {
|
||||
createSession: vi.fn().mockResolvedValue('padded-token-123'),
|
||||
},
|
||||
},
|
||||
OidcStateService,
|
||||
{
|
||||
provide: OidcValidationService,
|
||||
useValue: {
|
||||
validateProvider: vi.fn().mockResolvedValue({ isValid: true }),
|
||||
performDiscovery: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcAuthorizationService,
|
||||
useValue: {
|
||||
checkAuthorization: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcRedirectUriService,
|
||||
useValue: {
|
||||
getRedirectUri: vi.fn().mockResolvedValue('https://example.com/callback'),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcClientConfigService,
|
||||
useValue: {
|
||||
getOrCreateConfig: vi.fn(),
|
||||
clearCache: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcTokenExchangeService,
|
||||
useValue: {
|
||||
exchangeCodeForTokens: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcClaimsService,
|
||||
useValue: {
|
||||
parseIdToken: vi.fn(),
|
||||
validateClaims: vi.fn(),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
service = module.get<OidcService>(OidcService);
|
||||
oidcConfig = module.get(OidcConfigPersistence);
|
||||
sessionService = module.get(OidcSessionService);
|
||||
stateService = module.get<OidcStateService>(OidcStateService);
|
||||
redirectUriService = module.get(OidcRedirectUriService);
|
||||
clientConfigService = module.get(OidcClientConfigService);
|
||||
tokenExchangeService = module.get(OidcTokenExchangeService);
|
||||
claimsService = module.get(OidcClaimsService);
|
||||
authorizationService = module.get(OidcAuthorizationService);
|
||||
});
|
||||
|
||||
describe('getAuthorizationUrl', () => {
|
||||
it('should generate authorization URL with custom endpoints', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-secret',
|
||||
authorizationEndpoint: 'https://custom.example.com/auth',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
const params = {
|
||||
providerId: 'custom-provider',
|
||||
state: 'client-state-123',
|
||||
requestOrigin: 'https://example.com',
|
||||
requestHeaders: { host: 'example.com' },
|
||||
};
|
||||
|
||||
const url = await service.getAuthorizationUrl(params);
|
||||
|
||||
expect(redirectUriService.getRedirectUri).toHaveBeenCalledWith('https://example.com', {
|
||||
host: 'example.com',
|
||||
});
|
||||
|
||||
const urlObj = new URL(url);
|
||||
expect(urlObj.origin).toBe('https://custom.example.com');
|
||||
expect(urlObj.pathname).toBe('/auth');
|
||||
expect(urlObj.searchParams.get('client_id')).toBe('test-client-id');
|
||||
expect(urlObj.searchParams.get('redirect_uri')).toBe('https://example.com/callback');
|
||||
expect(urlObj.searchParams.get('scope')).toBe('openid profile');
|
||||
expect(urlObj.searchParams.get('response_type')).toBe('code');
|
||||
expect(urlObj.searchParams.has('state')).toBe(true);
|
||||
});
|
||||
|
||||
it('should use OIDC discovery when no custom authorization endpoint', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'discovery-provider',
|
||||
name: 'Discovery Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://discovery.example.com',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
// Create a mock configuration object
|
||||
const mockConfig = {
|
||||
serverMetadata: vi.fn().mockReturnValue({
|
||||
authorization_endpoint: 'https://discovery.example.com/authorize',
|
||||
}),
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
clientConfigService.getOrCreateConfig.mockResolvedValue(mockConfig);
|
||||
|
||||
const params = {
|
||||
providerId: 'discovery-provider',
|
||||
state: 'client-state-123',
|
||||
requestOrigin: 'https://example.com',
|
||||
requestHeaders: {},
|
||||
};
|
||||
|
||||
const url = await service.getAuthorizationUrl(params);
|
||||
|
||||
expect(clientConfigService.getOrCreateConfig).toHaveBeenCalledWith(provider);
|
||||
expect(url).toContain('https://discovery.example.com/authorize');
|
||||
});
|
||||
|
||||
it('should throw when provider not found', async () => {
|
||||
oidcConfig.getProvider.mockResolvedValue(null);
|
||||
|
||||
const params = {
|
||||
providerId: 'non-existent',
|
||||
state: 'state',
|
||||
requestOrigin: 'https://example.com',
|
||||
requestHeaders: {},
|
||||
};
|
||||
|
||||
await expect(service.getAuthorizationUrl(params)).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
});
|
||||
|
||||
describe('handleCallback', () => {
|
||||
it('should handle successful callback flow', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'test-provider',
|
||||
name: 'Test Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://test.example.com',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const mockConfig = {
|
||||
serverMetadata: vi.fn().mockReturnValue({
|
||||
issuer: 'https://test.example.com',
|
||||
token_endpoint: 'https://test.example.com/token',
|
||||
}),
|
||||
};
|
||||
|
||||
const mockTokens = {
|
||||
id_token: 'id.token.here',
|
||||
access_token: 'access.token.here',
|
||||
};
|
||||
|
||||
const mockClaims = {
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
clientConfigService.getOrCreateConfig.mockResolvedValue(mockConfig);
|
||||
tokenExchangeService.exchangeCodeForTokens.mockResolvedValue(mockTokens);
|
||||
claimsService.parseIdToken.mockReturnValue(mockClaims);
|
||||
claimsService.validateClaims.mockReturnValue('user123');
|
||||
|
||||
// Mock the OidcStateExtractor's static method
|
||||
const OidcStateExtractor = await import(
|
||||
'@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js'
|
||||
);
|
||||
vi.spyOn(OidcStateExtractor.OidcStateExtractor, 'extractAndValidateState').mockResolvedValue(
|
||||
{
|
||||
providerId: 'test-provider',
|
||||
originalState: 'original-state',
|
||||
clientState: 'original-state',
|
||||
redirectUri: 'https://example.com/callback',
|
||||
}
|
||||
);
|
||||
|
||||
const params = {
|
||||
providerId: 'test-provider',
|
||||
code: 'auth-code-123',
|
||||
state: 'secure-state',
|
||||
requestOrigin: 'https://example.com',
|
||||
fullCallbackUrl: 'https://example.com/callback?code=auth-code-123&state=secure-state',
|
||||
requestHeaders: {},
|
||||
};
|
||||
|
||||
const token = await service.handleCallback(params);
|
||||
|
||||
expect(token).toBe('padded-token-123');
|
||||
expect(tokenExchangeService.exchangeCodeForTokens).toHaveBeenCalled();
|
||||
expect(claimsService.parseIdToken).toHaveBeenCalledWith('id.token.here');
|
||||
expect(claimsService.validateClaims).toHaveBeenCalledWith(mockClaims);
|
||||
expect(authorizationService.checkAuthorization).toHaveBeenCalledWith(provider, mockClaims);
|
||||
expect(sessionService.createSession).toHaveBeenCalledWith('test-provider', 'user123');
|
||||
});
|
||||
|
||||
it('should throw when provider not found', async () => {
|
||||
oidcConfig.getProvider.mockResolvedValue(null);
|
||||
|
||||
const params = {
|
||||
providerId: 'non-existent',
|
||||
code: 'code',
|
||||
state: 'state',
|
||||
requestOrigin: 'https://example.com',
|
||||
fullCallbackUrl: 'https://example.com/callback',
|
||||
requestHeaders: {},
|
||||
};
|
||||
|
||||
await expect(service.handleCallback(params)).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should handle authorization rejection', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'test-provider',
|
||||
name: 'Test Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://test.example.com',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const mockConfig = {
|
||||
serverMetadata: vi.fn().mockReturnValue({
|
||||
issuer: 'https://test.example.com',
|
||||
token_endpoint: 'https://test.example.com/token',
|
||||
}),
|
||||
};
|
||||
|
||||
const mockTokens = {
|
||||
id_token: 'id.token.here',
|
||||
};
|
||||
|
||||
const mockClaims = {
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
clientConfigService.getOrCreateConfig.mockResolvedValue(mockConfig);
|
||||
tokenExchangeService.exchangeCodeForTokens.mockResolvedValue(mockTokens);
|
||||
claimsService.parseIdToken.mockReturnValue(mockClaims);
|
||||
claimsService.validateClaims.mockReturnValue('user123');
|
||||
authorizationService.checkAuthorization.mockRejectedValue(
|
||||
new UnauthorizedException('Not authorized')
|
||||
);
|
||||
|
||||
// Mock the OidcStateExtractor's static method
|
||||
const OidcStateExtractor = await import(
|
||||
'@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js'
|
||||
);
|
||||
vi.spyOn(OidcStateExtractor.OidcStateExtractor, 'extractAndValidateState').mockResolvedValue(
|
||||
{
|
||||
providerId: 'test-provider',
|
||||
originalState: 'original-state',
|
||||
clientState: 'original-state',
|
||||
redirectUri: 'https://example.com/callback',
|
||||
}
|
||||
);
|
||||
|
||||
const params = {
|
||||
providerId: 'test-provider',
|
||||
code: 'auth-code-123',
|
||||
state: 'secure-state',
|
||||
requestOrigin: 'https://example.com',
|
||||
fullCallbackUrl: 'https://example.com/callback',
|
||||
requestHeaders: {},
|
||||
};
|
||||
|
||||
await expect(service.handleCallback(params)).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateProvider', () => {
|
||||
it('should clear cache and validate provider', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'test-provider',
|
||||
name: 'Test Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://test.example.com',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const result = await service.validateProvider(provider);
|
||||
|
||||
expect(clientConfigService.clearCache).toHaveBeenCalledWith('test-provider');
|
||||
// The validation service mock already returns { isValid: true }
|
||||
expect(result).toEqual({ isValid: true });
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractProviderFromState', () => {
|
||||
it('should extract provider from state', () => {
|
||||
const state = 'provider-id:original-state';
|
||||
|
||||
const result = service.extractProviderFromState(state);
|
||||
|
||||
expect(result.providerId).toBeDefined();
|
||||
expect(result.originalState).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getStateService', () => {
|
||||
it('should return state service', () => {
|
||||
const result = service.getStateService();
|
||||
expect(result).toBe(stateService);
|
||||
});
|
||||
});
|
||||
});
|
||||
243
api/src/unraid-api/graph/resolvers/sso/core/oidc.service.ts
Normal file
243
api/src/unraid-api/graph/resolvers/sso/core/oidc.service.ts
Normal file
@@ -0,0 +1,243 @@
|
||||
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
|
||||
|
||||
import * as client from 'openid-client';
|
||||
|
||||
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
|
||||
import { OidcClaimsService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
|
||||
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
|
||||
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
|
||||
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
|
||||
import { OidcStateExtractor } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js';
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
|
||||
import { ErrorExtractor } from '@app/unraid-api/utils/error-extractor.util.js';
|
||||
|
||||
export interface GetAuthorizationUrlParams {
|
||||
providerId: string;
|
||||
state: string;
|
||||
requestOrigin: string;
|
||||
requestHeaders: Record<string, string | string[] | undefined>;
|
||||
}
|
||||
|
||||
export interface HandleCallbackParams {
|
||||
providerId: string;
|
||||
code: string;
|
||||
state: string;
|
||||
requestOrigin: string;
|
||||
fullCallbackUrl: string;
|
||||
requestHeaders: Record<string, string | string[] | undefined>;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class OidcService {
|
||||
private readonly logger = new Logger(OidcService.name);
|
||||
|
||||
constructor(
|
||||
private readonly oidcConfig: OidcConfigPersistence,
|
||||
private readonly sessionService: OidcSessionService,
|
||||
private readonly stateService: OidcStateService,
|
||||
private readonly validationService: OidcValidationService,
|
||||
private readonly authorizationService: OidcAuthorizationService,
|
||||
private readonly redirectUriService: OidcRedirectUriService,
|
||||
private readonly clientConfigService: OidcClientConfigService,
|
||||
private readonly tokenExchangeService: OidcTokenExchangeService,
|
||||
private readonly claimsService: OidcClaimsService
|
||||
) {}
|
||||
|
||||
async getAuthorizationUrl(params: GetAuthorizationUrlParams): Promise<string> {
|
||||
const { providerId, state, requestOrigin, requestHeaders } = params;
|
||||
|
||||
const provider = await this.oidcConfig.getProvider(providerId);
|
||||
if (!provider) {
|
||||
throw new UnauthorizedException(`Provider ${providerId} not found`);
|
||||
}
|
||||
|
||||
// Use requestOrigin with validation
|
||||
const redirectUri = await this.redirectUriService.getRedirectUri(requestOrigin, requestHeaders);
|
||||
|
||||
this.logger.debug(`Using redirect URI for authorization: ${redirectUri}`);
|
||||
this.logger.debug(`Request origin was: ${requestOrigin}`);
|
||||
|
||||
// Generate secure state with cryptographic signature, including redirect URI
|
||||
const secureState = await this.stateService.generateSecureState(providerId, state, redirectUri);
|
||||
|
||||
// Build authorization URL
|
||||
if (provider.authorizationEndpoint) {
|
||||
// Use custom authorization endpoint
|
||||
const authUrl = new URL(provider.authorizationEndpoint);
|
||||
|
||||
// Standard OAuth2 parameters
|
||||
authUrl.searchParams.set('client_id', provider.clientId);
|
||||
authUrl.searchParams.set('redirect_uri', redirectUri);
|
||||
authUrl.searchParams.set('scope', provider.scopes.join(' '));
|
||||
authUrl.searchParams.set('state', secureState);
|
||||
authUrl.searchParams.set('response_type', 'code');
|
||||
|
||||
this.logger.debug(`Built authorization URL for provider ${provider.id}`);
|
||||
this.logger.debug(
|
||||
`Authorization parameters: client_id=${provider.clientId}, redirect_uri=${redirectUri}, scope=${provider.scopes.join(' ')}, response_type=code`
|
||||
);
|
||||
|
||||
return authUrl.href;
|
||||
}
|
||||
|
||||
// Use OIDC discovery for providers without custom endpoints
|
||||
const config = await this.clientConfigService.getOrCreateConfig(provider);
|
||||
const parameters: Record<string, string> = {
|
||||
redirect_uri: redirectUri,
|
||||
scope: provider.scopes.join(' '),
|
||||
state: secureState,
|
||||
response_type: 'code',
|
||||
};
|
||||
|
||||
// For HTTP endpoints, we need to call allowInsecureRequests on the config
|
||||
if (provider.issuer) {
|
||||
try {
|
||||
const serverUrl = new URL(provider.issuer);
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(`Allowing insecure requests for HTTP endpoint: ${provider.id}`);
|
||||
// allowInsecureRequests is deprecated but still needed for HTTP endpoints
|
||||
client.allowInsecureRequests(config);
|
||||
}
|
||||
} catch (error) {
|
||||
this.logger.warn(`Invalid issuer URL for provider ${provider.id}: ${provider.issuer}`);
|
||||
// Continue without special HTTP options
|
||||
}
|
||||
}
|
||||
|
||||
const authUrl = client.buildAuthorizationUrl(config, parameters);
|
||||
|
||||
this.logger.log(`Built authorization URL via discovery for provider ${provider.id}`);
|
||||
this.logger.log(`Authorization parameters: ${JSON.stringify(parameters)}`);
|
||||
|
||||
return authUrl.href;
|
||||
}
|
||||
|
||||
extractProviderFromState(state: string): { providerId: string; originalState: string } {
|
||||
return OidcStateExtractor.extractProviderFromState(state, this.stateService);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the state service for external utilities
|
||||
*/
|
||||
getStateService(): OidcStateService {
|
||||
return this.stateService;
|
||||
}
|
||||
|
||||
async handleCallback(params: HandleCallbackParams): Promise<string> {
|
||||
const { providerId, code, state, fullCallbackUrl } = params;
|
||||
|
||||
const provider = await this.oidcConfig.getProvider(providerId);
|
||||
if (!provider) {
|
||||
throw new UnauthorizedException(`Provider ${providerId} not found`);
|
||||
}
|
||||
|
||||
// Extract and validate state, including the stored redirect URI
|
||||
const stateInfo = await OidcStateExtractor.extractAndValidateState(state, this.stateService);
|
||||
if (!stateInfo.redirectUri) {
|
||||
throw new UnauthorizedException('Missing redirect URI in state');
|
||||
}
|
||||
|
||||
// Use the redirect URI that was stored during authorization
|
||||
const redirectUri = stateInfo.redirectUri;
|
||||
this.logger.debug(`Using stored redirect URI from state: ${redirectUri}`);
|
||||
|
||||
try {
|
||||
// Always use openid-client for consistency
|
||||
const config = await this.clientConfigService.getOrCreateConfig(provider);
|
||||
|
||||
// Log configuration details
|
||||
this.logger.debug(`Provider ${providerId} config loaded`);
|
||||
this.logger.debug(`Redirect URI: ${redirectUri}`);
|
||||
|
||||
// Build current URL for token exchange
|
||||
// CRITICAL: The URL used here MUST match the redirect_uri that was sent to the authorization endpoint
|
||||
// Google expects the exact same redirect_uri during token exchange
|
||||
const currentUrl = new URL(redirectUri);
|
||||
currentUrl.searchParams.set('code', code);
|
||||
currentUrl.searchParams.set('state', state);
|
||||
|
||||
// Copy additional parameters from the actual callback if provided
|
||||
if (fullCallbackUrl) {
|
||||
const actualUrl = new URL(fullCallbackUrl);
|
||||
// Copy over additional params that Google might have added (scope, authuser, prompt, etc)
|
||||
// but DO NOT change the base URL or path
|
||||
['scope', 'authuser', 'prompt', 'hd', 'session_state', 'iss'].forEach((param) => {
|
||||
const value = actualUrl.searchParams.get(param);
|
||||
if (value && !currentUrl.searchParams.has(param)) {
|
||||
currentUrl.searchParams.set(param, value);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Google returns iss in the response, openid-client v6 expects it
|
||||
// If not present, add it based on the provider's issuer
|
||||
if (!currentUrl.searchParams.has('iss') && provider.issuer) {
|
||||
currentUrl.searchParams.set('iss', provider.issuer);
|
||||
}
|
||||
|
||||
this.logger.debug(`Token exchange URL (matches redirect_uri): ${currentUrl.href}`);
|
||||
|
||||
// State was already validated in extractAndValidateState above, use that result
|
||||
// The clientState should be present after successful validation, but handle the edge case
|
||||
if (!stateInfo.clientState) {
|
||||
this.logger.warn('Client state missing after successful validation');
|
||||
throw new UnauthorizedException('Invalid state: missing client state');
|
||||
}
|
||||
const originalState = stateInfo.clientState;
|
||||
this.logger.debug(`Exchanging code for tokens with provider ${providerId}`);
|
||||
this.logger.debug(`Client state extracted: ${originalState}`);
|
||||
|
||||
// Use the token exchange service
|
||||
const tokens = await this.tokenExchangeService.exchangeCodeForTokens(
|
||||
config,
|
||||
provider,
|
||||
code,
|
||||
originalState,
|
||||
redirectUri,
|
||||
fullCallbackUrl
|
||||
);
|
||||
|
||||
// Parse ID token to get user info
|
||||
const claims = this.claimsService.parseIdToken(tokens.id_token);
|
||||
const userSub = this.claimsService.validateClaims(claims);
|
||||
|
||||
// Check authorization based on rules
|
||||
// This will throw a helpful error if misconfigured or unauthorized
|
||||
await this.authorizationService.checkAuthorization(provider, claims!);
|
||||
|
||||
// Create session and return padded token
|
||||
const paddedToken = await this.sessionService.createSession(providerId, userSub);
|
||||
|
||||
this.logger.log(`Successfully authenticated user ${userSub} via provider ${providerId}`);
|
||||
|
||||
return paddedToken;
|
||||
} catch (error) {
|
||||
const extracted = ErrorExtractor.extract(error);
|
||||
this.logger.error(`OAuth callback error: ${extracted.message}`);
|
||||
// Re-throw the original error if it's already an UnauthorizedException
|
||||
if (error instanceof UnauthorizedException) {
|
||||
throw error;
|
||||
}
|
||||
// Otherwise throw a generic error
|
||||
throw new UnauthorizedException('Authentication failed');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate OIDC provider configuration by attempting discovery
|
||||
* Returns validation result with helpful error messages for debugging
|
||||
*/
|
||||
async validateProvider(
|
||||
provider: OidcProvider
|
||||
): Promise<{ isValid: boolean; error?: string; details?: unknown }> {
|
||||
// Clear any cached config for this provider to force fresh validation
|
||||
this.clientConfigService.clearCache(provider.id);
|
||||
|
||||
// Delegate to the validation service
|
||||
return this.validationService.validateProvider(provider);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
import { Field, ObjectType } from '@nestjs/graphql';
|
||||
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
|
||||
@ObjectType()
|
||||
export class OidcConfiguration {
|
||||
@Field(() => [OidcProvider], { description: 'List of configured OIDC providers' })
|
||||
providers!: OidcProvider[];
|
||||
|
||||
@Field(() => [String], {
|
||||
nullable: true,
|
||||
description:
|
||||
'Default allowed redirect origins that apply to all OIDC providers (e.g., Tailscale domains)',
|
||||
})
|
||||
defaultAllowedOrigins?: string[];
|
||||
}
|
||||
@@ -80,9 +80,11 @@ export class OidcProvider {
|
||||
@Field(() => String, {
|
||||
description:
|
||||
'OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration',
|
||||
nullable: true,
|
||||
})
|
||||
@IsUrl()
|
||||
issuer!: string;
|
||||
@IsOptional()
|
||||
issuer?: string;
|
||||
|
||||
@Field(() => String, {
|
||||
nullable: true,
|
||||
@@ -1,4 +1,4 @@
|
||||
import type { OidcConfig } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
|
||||
import type { OidcConfig } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
|
||||
declare module '@unraid/shared/services/user-settings.js' {
|
||||
interface UserSettings {
|
||||
@@ -1,701 +0,0 @@
|
||||
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
|
||||
import { ConfigService } from '@nestjs/config';
|
||||
|
||||
import { decodeJwt } from 'jose';
|
||||
import * as client from 'openid-client';
|
||||
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
|
||||
import {
|
||||
AuthorizationOperator,
|
||||
AuthorizationRuleMode,
|
||||
OidcAuthorizationRule,
|
||||
OidcProvider,
|
||||
} from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/oidc-state.service.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/oidc-validation.service.js';
|
||||
|
||||
interface JwtClaims {
|
||||
sub?: string;
|
||||
email?: string;
|
||||
name?: string;
|
||||
hd?: string; // Google hosted domain
|
||||
[claim: string]: unknown;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class OidcAuthService {
|
||||
private readonly logger = new Logger(OidcAuthService.name);
|
||||
private readonly configCache = new Map<string, client.Configuration>();
|
||||
|
||||
constructor(
|
||||
private readonly configService: ConfigService,
|
||||
private readonly oidcConfig: OidcConfigPersistence,
|
||||
private readonly sessionService: OidcSessionService,
|
||||
private readonly stateService: OidcStateService,
|
||||
private readonly validationService: OidcValidationService
|
||||
) {}
|
||||
|
||||
async getAuthorizationUrl(
|
||||
providerId: string,
|
||||
state: string,
|
||||
requestOrigin?: string
|
||||
): Promise<string> {
|
||||
const provider = await this.oidcConfig.getProvider(providerId);
|
||||
if (!provider) {
|
||||
throw new UnauthorizedException(`Provider ${providerId} not found`);
|
||||
}
|
||||
|
||||
const redirectUri = this.getRedirectUri(requestOrigin);
|
||||
|
||||
// Generate secure state with cryptographic signature
|
||||
const secureState = this.stateService.generateSecureState(providerId, state);
|
||||
|
||||
// Build authorization URL
|
||||
if (provider.authorizationEndpoint) {
|
||||
// Use custom authorization endpoint
|
||||
const authUrl = new URL(provider.authorizationEndpoint);
|
||||
|
||||
// Standard OAuth2 parameters
|
||||
authUrl.searchParams.set('client_id', provider.clientId);
|
||||
authUrl.searchParams.set('redirect_uri', redirectUri);
|
||||
authUrl.searchParams.set('scope', provider.scopes.join(' '));
|
||||
authUrl.searchParams.set('state', secureState);
|
||||
authUrl.searchParams.set('response_type', 'code');
|
||||
|
||||
return authUrl.href;
|
||||
}
|
||||
|
||||
// Use OIDC discovery for providers without custom endpoints
|
||||
const config = await this.getOrCreateConfig(provider);
|
||||
const parameters: Record<string, string> = {
|
||||
redirect_uri: redirectUri,
|
||||
scope: provider.scopes.join(' '),
|
||||
state: secureState,
|
||||
response_type: 'code',
|
||||
};
|
||||
|
||||
// For HTTP endpoints, we need to pass the allowInsecureRequests option
|
||||
const serverUrl = new URL(provider.issuer || '');
|
||||
let clientOptions: any = undefined;
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(
|
||||
`Building authorization URL with allowInsecureRequests for ${provider.id}`
|
||||
);
|
||||
clientOptions = {
|
||||
execute: [client.allowInsecureRequests],
|
||||
};
|
||||
}
|
||||
|
||||
const authUrl = client.buildAuthorizationUrl(config, parameters);
|
||||
|
||||
return authUrl.href;
|
||||
}
|
||||
|
||||
extractProviderFromState(state: string): { providerId: string; originalState: string } {
|
||||
// Extract provider from state prefix (no decryption needed)
|
||||
const providerId = this.stateService.extractProviderFromState(state);
|
||||
|
||||
if (providerId) {
|
||||
return {
|
||||
providerId,
|
||||
originalState: state,
|
||||
};
|
||||
}
|
||||
|
||||
// Fallback for unknown formats
|
||||
return {
|
||||
providerId: '',
|
||||
originalState: state,
|
||||
};
|
||||
}
|
||||
|
||||
async handleCallback(
|
||||
providerId: string,
|
||||
code: string,
|
||||
state: string,
|
||||
requestOrigin?: string,
|
||||
fullCallbackUrl?: string
|
||||
): Promise<string> {
|
||||
const provider = await this.oidcConfig.getProvider(providerId);
|
||||
if (!provider) {
|
||||
throw new UnauthorizedException(`Provider ${providerId} not found`);
|
||||
}
|
||||
|
||||
try {
|
||||
const redirectUri = this.getRedirectUri(requestOrigin);
|
||||
|
||||
// Always use openid-client for consistency
|
||||
const config = await this.getOrCreateConfig(provider);
|
||||
|
||||
// Log configuration details
|
||||
this.logger.debug(`Provider ${providerId} config loaded`);
|
||||
this.logger.debug(`Redirect URI: ${redirectUri}`);
|
||||
|
||||
// Build current URL for token exchange
|
||||
// CRITICAL: The URL used here MUST match the redirect_uri that was sent to the authorization endpoint
|
||||
// Google expects the exact same redirect_uri during token exchange
|
||||
const currentUrl = new URL(redirectUri);
|
||||
currentUrl.searchParams.set('code', code);
|
||||
currentUrl.searchParams.set('state', state);
|
||||
|
||||
// Copy additional parameters from the actual callback if provided
|
||||
if (fullCallbackUrl) {
|
||||
const actualUrl = new URL(fullCallbackUrl);
|
||||
// Copy over additional params that Google might have added (scope, authuser, prompt, etc)
|
||||
// but DO NOT change the base URL or path
|
||||
['scope', 'authuser', 'prompt', 'hd', 'session_state', 'iss'].forEach((param) => {
|
||||
const value = actualUrl.searchParams.get(param);
|
||||
if (value && !currentUrl.searchParams.has(param)) {
|
||||
currentUrl.searchParams.set(param, value);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Google returns iss in the response, openid-client v6 expects it
|
||||
// If not present, add it based on the provider's issuer
|
||||
if (!currentUrl.searchParams.has('iss') && provider.issuer) {
|
||||
currentUrl.searchParams.set('iss', provider.issuer);
|
||||
}
|
||||
|
||||
this.logger.debug(`Token exchange URL (matches redirect_uri): ${currentUrl.href}`);
|
||||
|
||||
// Validate secure state
|
||||
const stateValidation = this.stateService.validateSecureState(state, providerId);
|
||||
if (!stateValidation.isValid) {
|
||||
this.logger.error(`State validation failed: ${stateValidation.error}`);
|
||||
throw new UnauthorizedException(stateValidation.error || 'Invalid state parameter');
|
||||
}
|
||||
|
||||
const originalState = stateValidation.clientState!;
|
||||
this.logger.debug(`Exchanging code for tokens with provider ${providerId}`);
|
||||
this.logger.debug(`Client state extracted: ${originalState}`);
|
||||
|
||||
// For openid-client v6, we need to prepare the authorization response
|
||||
const authorizationResponse = new URLSearchParams(currentUrl.search);
|
||||
|
||||
// Set the original client state for openid-client
|
||||
authorizationResponse.set('state', originalState);
|
||||
|
||||
// Create a new URL with the cleaned parameters
|
||||
const cleanUrl = new URL(redirectUri);
|
||||
cleanUrl.search = authorizationResponse.toString();
|
||||
|
||||
this.logger.debug(`Clean URL for token exchange: ${cleanUrl.href}`);
|
||||
|
||||
let tokens;
|
||||
try {
|
||||
this.logger.debug(`Starting token exchange with openid-client`);
|
||||
this.logger.debug(`Config issuer: ${config.serverMetadata().issuer}`);
|
||||
this.logger.debug(`Config token endpoint: ${config.serverMetadata().token_endpoint}`);
|
||||
|
||||
// For HTTP endpoints, we need to pass the allowInsecureRequests option
|
||||
const serverUrl = new URL(provider.issuer || '');
|
||||
let clientOptions: any = undefined;
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(`Token exchange with allowInsecureRequests for ${provider.id}`);
|
||||
clientOptions = {
|
||||
execute: [client.allowInsecureRequests],
|
||||
};
|
||||
}
|
||||
|
||||
tokens = await client.authorizationCodeGrant(
|
||||
config,
|
||||
cleanUrl,
|
||||
{
|
||||
expectedState: originalState,
|
||||
},
|
||||
clientOptions
|
||||
);
|
||||
this.logger.debug(
|
||||
`Token exchange successful, received tokens: ${Object.keys(tokens).join(', ')}`
|
||||
);
|
||||
} catch (tokenError) {
|
||||
const errorMessage =
|
||||
tokenError instanceof Error ? tokenError.message : String(tokenError);
|
||||
this.logger.error(`Token exchange failed: ${errorMessage}`);
|
||||
|
||||
// Check if error message contains the "unexpected JWT claim" text
|
||||
if (errorMessage.includes('unexpected JWT claim value encountered')) {
|
||||
this.logger.error(
|
||||
`unexpected JWT claim value encountered during token validation by openid-client`
|
||||
);
|
||||
this.logger.debug(
|
||||
`Token exchange error details: ${JSON.stringify(tokenError, null, 2)}`
|
||||
);
|
||||
|
||||
// Log the actual vs expected issuer
|
||||
this.logger.error(
|
||||
`This error typically means the 'iss' claim in the JWT doesn't match the expected issuer`
|
||||
);
|
||||
this.logger.error(`Check that your provider's issuer URL is configured correctly`);
|
||||
}
|
||||
|
||||
throw tokenError;
|
||||
}
|
||||
|
||||
// Parse ID token to get user info
|
||||
let claims: JwtClaims | null = null;
|
||||
if (tokens.id_token) {
|
||||
try {
|
||||
// Use jose to properly decode the JWT
|
||||
claims = decodeJwt(tokens.id_token) as JwtClaims;
|
||||
|
||||
// Log claims safely without PII - only structure, not values
|
||||
if (claims) {
|
||||
const claimKeys = Object.keys(claims).join(', ');
|
||||
this.logger.debug(
|
||||
`ID token decoded successfully. Available claims: [${claimKeys}]`
|
||||
);
|
||||
|
||||
// Log claim types without exposing sensitive values
|
||||
for (const [key, value] of Object.entries(claims)) {
|
||||
const valueType = Array.isArray(value)
|
||||
? `array[${value.length}]`
|
||||
: typeof value;
|
||||
|
||||
// Only log structure, not actual values (avoid PII)
|
||||
this.logger.debug(`Claim '${key}': type=${valueType}`);
|
||||
|
||||
// Check for unexpected claim types
|
||||
if (valueType === 'object' && value !== null && !Array.isArray(value)) {
|
||||
this.logger.warn(`Claim '${key}' contains complex object structure`);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
this.logger.warn(`Failed to parse ID token: ${e}`);
|
||||
}
|
||||
} else {
|
||||
this.logger.error('No ID token received from provider');
|
||||
}
|
||||
|
||||
if (!claims?.sub) {
|
||||
this.logger.error(
|
||||
'No subject in token - claims available: ' +
|
||||
(claims ? Object.keys(claims).join(', ') : 'none')
|
||||
);
|
||||
throw new UnauthorizedException('No subject in token');
|
||||
}
|
||||
|
||||
const userSub = claims.sub;
|
||||
this.logger.debug(`Processing authentication for user: ${userSub}`);
|
||||
|
||||
// Check authorization based on rules
|
||||
// This will throw a helpful error if misconfigured or unauthorized
|
||||
await this.checkAuthorization(provider, claims);
|
||||
|
||||
// Create session and return padded token
|
||||
const paddedToken = await this.sessionService.createSession(providerId, userSub);
|
||||
|
||||
this.logger.log(`Successfully authenticated user ${userSub} via provider ${providerId}`);
|
||||
|
||||
return paddedToken;
|
||||
} catch (error) {
|
||||
this.logger.error(
|
||||
`OAuth callback error: ${error instanceof Error ? error.message : 'Unknown error'}`
|
||||
);
|
||||
// Re-throw the original error if it's already an UnauthorizedException
|
||||
if (error instanceof UnauthorizedException) {
|
||||
throw error;
|
||||
}
|
||||
// Otherwise throw a generic error
|
||||
throw new UnauthorizedException('Authentication failed');
|
||||
}
|
||||
}
|
||||
|
||||
private async getOrCreateConfig(provider: OidcProvider): Promise<client.Configuration> {
|
||||
const cacheKey = provider.id;
|
||||
|
||||
if (this.configCache.has(cacheKey)) {
|
||||
return this.configCache.get(cacheKey)!;
|
||||
}
|
||||
|
||||
try {
|
||||
// Use the validation service to perform discovery with HTTP support
|
||||
if (provider.issuer) {
|
||||
this.logger.debug(`Attempting discovery for ${provider.id} at ${provider.issuer}`);
|
||||
|
||||
// Create client options with HTTP support if needed
|
||||
const serverUrl = new URL(provider.issuer);
|
||||
let clientOptions: any = undefined;
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(`Allowing HTTP for ${provider.id} as specified by user`);
|
||||
clientOptions = {
|
||||
execute: [client.allowInsecureRequests],
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const config = await this.validationService.performDiscovery(
|
||||
provider,
|
||||
clientOptions
|
||||
);
|
||||
this.logger.debug(`Discovery successful for ${provider.id}`);
|
||||
this.logger.debug(
|
||||
`Authorization endpoint: ${config.serverMetadata().authorization_endpoint}`
|
||||
);
|
||||
this.logger.debug(`Token endpoint: ${config.serverMetadata().token_endpoint}`);
|
||||
this.configCache.set(cacheKey, config);
|
||||
return config;
|
||||
} catch (discoveryError) {
|
||||
const errorMessage =
|
||||
discoveryError instanceof Error ? discoveryError.message : 'Unknown error';
|
||||
this.logger.warn(`Discovery failed for ${provider.id}: ${errorMessage}`);
|
||||
|
||||
// Log more details about the discovery error
|
||||
this.logger.debug(
|
||||
`Discovery URL attempted: ${provider.issuer}/.well-known/openid-configuration`
|
||||
);
|
||||
this.logger.debug(
|
||||
`Full discovery error: ${JSON.stringify(discoveryError, null, 2)}`
|
||||
);
|
||||
|
||||
// Log stack trace for better debugging
|
||||
if (discoveryError instanceof Error && discoveryError.stack) {
|
||||
this.logger.debug(`Stack trace: ${discoveryError.stack}`);
|
||||
}
|
||||
|
||||
// If discovery fails but we have manual endpoints, use them
|
||||
if (provider.authorizationEndpoint && provider.tokenEndpoint) {
|
||||
this.logger.log(`Using manual endpoints for ${provider.id}`);
|
||||
|
||||
// Create manual configuration
|
||||
const serverMetadata: client.ServerMetadata = {
|
||||
issuer: provider.issuer || `manual-${provider.id}`,
|
||||
authorization_endpoint: provider.authorizationEndpoint,
|
||||
token_endpoint: provider.tokenEndpoint,
|
||||
jwks_uri: provider.jwksUri,
|
||||
};
|
||||
|
||||
const clientMetadata: Partial<client.ClientMetadata> = {
|
||||
client_secret: provider.clientSecret,
|
||||
};
|
||||
|
||||
// Configure client auth method
|
||||
const clientAuth = provider.clientSecret
|
||||
? client.ClientSecretPost(provider.clientSecret)
|
||||
: client.None();
|
||||
|
||||
try {
|
||||
const config = new client.Configuration(
|
||||
serverMetadata,
|
||||
provider.clientId,
|
||||
clientMetadata,
|
||||
clientAuth
|
||||
);
|
||||
|
||||
// Use manual configuration with HTTP support if needed
|
||||
const serverUrl = new URL(provider.tokenEndpoint);
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(
|
||||
`Allowing HTTP for manual endpoints on ${provider.id}`
|
||||
);
|
||||
client.allowInsecureRequests(config);
|
||||
}
|
||||
|
||||
this.logger.debug(`Manual configuration created for ${provider.id}`);
|
||||
this.logger.debug(
|
||||
`Authorization endpoint: ${serverMetadata.authorization_endpoint}`
|
||||
);
|
||||
this.logger.debug(`Token endpoint: ${serverMetadata.token_endpoint}`);
|
||||
|
||||
this.configCache.set(cacheKey, config);
|
||||
return config;
|
||||
} catch (manualConfigError) {
|
||||
this.logger.error(
|
||||
`Failed to create manual configuration: ${manualConfigError instanceof Error ? manualConfigError.message : 'Unknown error'}`
|
||||
);
|
||||
throw new Error(`Manual configuration failed for ${provider.id}`);
|
||||
}
|
||||
} else {
|
||||
throw new Error(
|
||||
`OIDC discovery failed and no manual endpoints provided for ${provider.id}`
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Manual configuration when no issuer is provided
|
||||
if (provider.authorizationEndpoint && provider.tokenEndpoint) {
|
||||
this.logger.log(`Using manual endpoints for ${provider.id} (no issuer provided)`);
|
||||
|
||||
// Create manual configuration
|
||||
const serverMetadata: client.ServerMetadata = {
|
||||
issuer: provider.issuer || `manual-${provider.id}`,
|
||||
authorization_endpoint: provider.authorizationEndpoint,
|
||||
token_endpoint: provider.tokenEndpoint,
|
||||
jwks_uri: provider.jwksUri,
|
||||
};
|
||||
|
||||
const clientMetadata: Partial<client.ClientMetadata> = {
|
||||
client_secret: provider.clientSecret,
|
||||
};
|
||||
|
||||
// Configure client auth method
|
||||
const clientAuth = provider.clientSecret
|
||||
? client.ClientSecretPost(provider.clientSecret)
|
||||
: client.None();
|
||||
|
||||
try {
|
||||
const config = new client.Configuration(
|
||||
serverMetadata,
|
||||
provider.clientId,
|
||||
clientMetadata,
|
||||
clientAuth
|
||||
);
|
||||
|
||||
// Use manual configuration with HTTP support if needed
|
||||
const serverUrl = new URL(provider.tokenEndpoint);
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(`Allowing HTTP for manual endpoints on ${provider.id}`);
|
||||
client.allowInsecureRequests(config);
|
||||
}
|
||||
|
||||
this.logger.debug(`Manual configuration created for ${provider.id}`);
|
||||
this.logger.debug(
|
||||
`Authorization endpoint: ${serverMetadata.authorization_endpoint}`
|
||||
);
|
||||
this.logger.debug(`Token endpoint: ${serverMetadata.token_endpoint}`);
|
||||
|
||||
this.configCache.set(cacheKey, config);
|
||||
return config;
|
||||
} catch (manualConfigError) {
|
||||
this.logger.error(
|
||||
`Failed to create manual configuration: ${manualConfigError instanceof Error ? manualConfigError.message : 'Unknown error'}`
|
||||
);
|
||||
throw new Error(`Manual configuration failed for ${provider.id}`);
|
||||
}
|
||||
}
|
||||
|
||||
// If we reach here, neither discovery nor manual endpoints are available
|
||||
throw new Error(
|
||||
`No configuration method available for ${provider.id}: requires either valid issuer for discovery or manual endpoints`
|
||||
);
|
||||
} catch (error) {
|
||||
this.logger.error(
|
||||
`Failed to create OIDC configuration for ${provider.id}: ${
|
||||
error instanceof Error ? error.message : 'Unknown error'
|
||||
}`
|
||||
);
|
||||
|
||||
// Log more details in debug mode
|
||||
if (error instanceof Error && error.stack) {
|
||||
this.logger.debug(`Stack trace: ${error.stack}`);
|
||||
}
|
||||
|
||||
throw new UnauthorizedException('Provider configuration error');
|
||||
}
|
||||
}
|
||||
|
||||
private async checkAuthorization(provider: OidcProvider, claims: JwtClaims): Promise<void> {
|
||||
this.logger.debug(
|
||||
`Checking authorization for provider ${provider.id} with ${provider.authorizationRules?.length || 0} rules`
|
||||
);
|
||||
this.logger.debug(`Available claims: ${Object.keys(claims).join(', ')}`);
|
||||
this.logger.debug(
|
||||
`Authorization rule mode: ${provider.authorizationRuleMode || AuthorizationRuleMode.OR}`
|
||||
);
|
||||
|
||||
// If no authorization rules are specified, throw a helpful error
|
||||
if (!provider.authorizationRules || provider.authorizationRules.length === 0) {
|
||||
throw new UnauthorizedException(
|
||||
`Login failed: The ${provider.name} provider has no authorization rules configured. ` +
|
||||
`Please configure authorization rules.`
|
||||
);
|
||||
}
|
||||
|
||||
this.logger.debug(
|
||||
`Authorization rules to evaluate: ${JSON.stringify(provider.authorizationRules, null, 2)}`
|
||||
);
|
||||
|
||||
// Evaluate the rules
|
||||
const ruleMode = provider.authorizationRuleMode || AuthorizationRuleMode.OR;
|
||||
const isAuthorized = this.evaluateAuthorizationRules(
|
||||
provider.authorizationRules,
|
||||
claims,
|
||||
ruleMode
|
||||
);
|
||||
|
||||
this.logger.debug(`Authorization result: ${isAuthorized}`);
|
||||
|
||||
if (!isAuthorized) {
|
||||
// Log authorization failure with safe claim representation (no PII)
|
||||
const availableClaimKeys = Object.keys(claims).join(', ');
|
||||
this.logger.warn(
|
||||
`Authorization failed for provider ${provider.name}, user ${claims.sub}, available claim keys: [${availableClaimKeys}]`
|
||||
);
|
||||
throw new UnauthorizedException(
|
||||
`Access denied: Your account does not meet the authorization requirements for ${provider.name}.`
|
||||
);
|
||||
}
|
||||
|
||||
this.logger.debug(`Authorization successful for user ${claims.sub}`);
|
||||
}
|
||||
|
||||
private evaluateAuthorizationRules(
|
||||
rules: OidcAuthorizationRule[],
|
||||
claims: JwtClaims,
|
||||
mode: AuthorizationRuleMode = AuthorizationRuleMode.OR
|
||||
): boolean {
|
||||
// No rules means no authorization
|
||||
if (rules.length === 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (mode === AuthorizationRuleMode.AND) {
|
||||
// All rules must pass (AND logic)
|
||||
return rules.every((rule) => this.evaluateRule(rule, claims));
|
||||
} else {
|
||||
// Any rule can pass (OR logic) - default behavior
|
||||
// Multiple rules act as alternative authorization paths
|
||||
return rules.some((rule) => this.evaluateRule(rule, claims));
|
||||
}
|
||||
}
|
||||
|
||||
private evaluateRule(rule: OidcAuthorizationRule, claims: JwtClaims): boolean {
|
||||
const claimValue = claims[rule.claim];
|
||||
|
||||
this.logger.verbose(
|
||||
`Evaluating rule for claim ${rule.claim}: ${JSON.stringify({
|
||||
claimValue,
|
||||
claimType: typeof claimValue,
|
||||
isArray: Array.isArray(claimValue),
|
||||
ruleOperator: rule.operator,
|
||||
ruleValues: rule.value,
|
||||
})}`
|
||||
);
|
||||
|
||||
if (claimValue === undefined || claimValue === null) {
|
||||
this.logger.verbose(`Claim ${rule.claim} not found in token`);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Handle non-array, non-string objects
|
||||
if (typeof claimValue === 'object' && claimValue !== null && !Array.isArray(claimValue)) {
|
||||
this.logger.warn(
|
||||
`unexpected JWT claim value encountered - claim ${rule.claim} has unsupported object type (keys: [${Object.keys(claimValue as Record<string, unknown>).join(', ')}])`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Handle array claims - evaluate rule against each array element
|
||||
if (Array.isArray(claimValue)) {
|
||||
this.logger.verbose(
|
||||
`Processing array claim ${rule.claim} with ${claimValue.length} elements`
|
||||
);
|
||||
|
||||
// For array claims, check if ANY element in the array matches the rule
|
||||
const arrayResult = claimValue.some((element) => {
|
||||
// Skip non-string elements
|
||||
if (
|
||||
typeof element !== 'string' &&
|
||||
typeof element !== 'number' &&
|
||||
typeof element !== 'boolean'
|
||||
) {
|
||||
this.logger.verbose(`Skipping non-primitive element in array: ${typeof element}`);
|
||||
return false;
|
||||
}
|
||||
|
||||
const elementValue = String(element);
|
||||
return this.evaluateSingleValue(elementValue, rule);
|
||||
});
|
||||
|
||||
this.logger.verbose(`Array evaluation result for claim ${rule.claim}: ${arrayResult}`);
|
||||
return arrayResult;
|
||||
}
|
||||
|
||||
// Handle single value claims (string, number, boolean)
|
||||
const value = String(claimValue);
|
||||
this.logger.verbose(`Processing single value claim ${rule.claim} with value: "${value}"`);
|
||||
|
||||
return this.evaluateSingleValue(value, rule);
|
||||
}
|
||||
|
||||
private evaluateSingleValue(value: string, rule: OidcAuthorizationRule): boolean {
|
||||
let result: boolean;
|
||||
switch (rule.operator) {
|
||||
case AuthorizationOperator.EQUALS:
|
||||
result = rule.value.some((v) => value === v);
|
||||
this.logger.verbose(
|
||||
`EQUALS check: "${value}" matches any of [${rule.value.join(', ')}]: ${result}`
|
||||
);
|
||||
return result;
|
||||
|
||||
case AuthorizationOperator.CONTAINS:
|
||||
result = rule.value.some((v) => value.includes(v));
|
||||
this.logger.verbose(
|
||||
`CONTAINS check: "${value}" contains any of [${rule.value.join(', ')}]: ${result}`
|
||||
);
|
||||
return result;
|
||||
|
||||
case AuthorizationOperator.STARTS_WITH:
|
||||
result = rule.value.some((v) => value.startsWith(v));
|
||||
this.logger.verbose(
|
||||
`STARTS_WITH check: "${value}" starts with any of [${rule.value.join(', ')}]: ${result}`
|
||||
);
|
||||
return result;
|
||||
|
||||
case AuthorizationOperator.ENDS_WITH:
|
||||
result = rule.value.some((v) => value.endsWith(v));
|
||||
this.logger.verbose(
|
||||
`ENDS_WITH check: "${value}" ends with any of [${rule.value.join(', ')}]: ${result}`
|
||||
);
|
||||
return result;
|
||||
|
||||
default:
|
||||
this.logger.error(`Unknown authorization operator: ${rule.operator}`);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate OIDC provider configuration by attempting discovery
|
||||
* Returns validation result with helpful error messages for debugging
|
||||
*/
|
||||
async validateProvider(
|
||||
provider: OidcProvider
|
||||
): Promise<{ isValid: boolean; error?: string; details?: unknown }> {
|
||||
// Clear any cached config for this provider to force fresh validation
|
||||
this.configCache.delete(provider.id);
|
||||
|
||||
// Delegate to the validation service
|
||||
return this.validationService.validateProvider(provider);
|
||||
}
|
||||
|
||||
private getRedirectUri(requestOrigin?: string): string {
|
||||
// If we have the full origin (protocol://host), use it directly
|
||||
if (requestOrigin) {
|
||||
// Parse the origin to extract protocol and host
|
||||
try {
|
||||
const url = new URL(requestOrigin);
|
||||
const { protocol, hostname, port } = url;
|
||||
|
||||
// Reconstruct the URL, removing default ports
|
||||
let cleanOrigin = `${protocol}//${hostname}`;
|
||||
|
||||
// Add port if it's not the default for the protocol
|
||||
if (
|
||||
port &&
|
||||
!(protocol === 'https:' && port === '443') &&
|
||||
!(protocol === 'http:' && port === '80')
|
||||
) {
|
||||
cleanOrigin += `:${port}`;
|
||||
}
|
||||
|
||||
// Special handling for localhost development with Nuxt proxy
|
||||
if (hostname === 'localhost' && port === '3000') {
|
||||
return `${cleanOrigin}/graphql/api/auth/oidc/callback`;
|
||||
}
|
||||
|
||||
return `${cleanOrigin}/graphql/api/auth/oidc/callback`;
|
||||
} catch (e) {
|
||||
this.logger.warn(`Failed to parse request origin: ${requestOrigin}, error: ${e}`);
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to configured BASE_URL or default
|
||||
const baseUrl = this.configService.get('BASE_URL', 'http://tower.local');
|
||||
return `${baseUrl}/graphql/api/auth/oidc/callback`;
|
||||
}
|
||||
}
|
||||
@@ -1,204 +0,0 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/oidc-state.service.js';
|
||||
|
||||
describe('OidcStateService', () => {
|
||||
let service: OidcStateService;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
vi.useFakeTimers();
|
||||
// Create a single instance for all tests in a describe block
|
||||
service = new OidcStateService();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
describe('generateSecureState', () => {
|
||||
it('should generate a state with provider prefix and signed token', () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = service.generateSecureState(providerId, clientState);
|
||||
|
||||
expect(state).toBeTruthy();
|
||||
expect(typeof state).toBe('string');
|
||||
expect(state.startsWith(`${providerId}:`)).toBe(true);
|
||||
|
||||
// Extract signed portion and verify format (nonce.timestamp.signature)
|
||||
const signed = state.substring(providerId.length + 1);
|
||||
expect(signed.split('.').length).toBe(3);
|
||||
});
|
||||
|
||||
it('should generate unique states for each call', () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state1 = service.generateSecureState(providerId, clientState);
|
||||
const state2 = service.generateSecureState(providerId, clientState);
|
||||
|
||||
expect(state1).not.toBe(state2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateSecureState', () => {
|
||||
it('should validate a valid state token', () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = service.generateSecureState(providerId, clientState);
|
||||
const result = service.validateSecureState(state, providerId);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.clientState).toBe(clientState);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should reject state with wrong provider ID', () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = service.generateSecureState(providerId, clientState);
|
||||
const result = service.validateSecureState(state, 'wrong-provider');
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toBe('Provider ID mismatch in state');
|
||||
});
|
||||
|
||||
it('should reject expired state tokens', () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = service.generateSecureState(providerId, clientState);
|
||||
|
||||
// Fast forward time beyond expiration (11 minutes)
|
||||
vi.advanceTimersByTime(11 * 60 * 1000);
|
||||
|
||||
const result = service.validateSecureState(state, providerId);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toBe('State token has expired');
|
||||
});
|
||||
|
||||
it('should reject reused state tokens', () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = service.generateSecureState(providerId, clientState);
|
||||
|
||||
// First validation should succeed
|
||||
const result1 = service.validateSecureState(state, providerId);
|
||||
expect(result1.isValid).toBe(true);
|
||||
|
||||
// Second validation should fail (replay attack prevention)
|
||||
const result2 = service.validateSecureState(state, providerId);
|
||||
expect(result2.isValid).toBe(false);
|
||||
expect(result2.error).toBe('State token not found or already used');
|
||||
});
|
||||
|
||||
it('should reject invalid state tokens', () => {
|
||||
const result = service.validateSecureState('invalid.state.token', 'test-provider');
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toBe('Invalid state format');
|
||||
});
|
||||
|
||||
it('should reject tampered state tokens', () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = service.generateSecureState(providerId, clientState);
|
||||
|
||||
// Tamper with the signature
|
||||
const parts = state.split('.');
|
||||
parts[2] = parts[2].slice(0, -4) + 'XXXX';
|
||||
const tamperedState = parts.join('.');
|
||||
|
||||
const result = service.validateSecureState(tamperedState, providerId);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toBe('Invalid state signature');
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractProviderFromState', () => {
|
||||
it('should extract provider from state prefix', () => {
|
||||
const state = 'provider-id:eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature';
|
||||
const result = service.extractProviderFromState(state);
|
||||
|
||||
expect(result).toBe('provider-id');
|
||||
});
|
||||
|
||||
it('should handle states with multiple colons', () => {
|
||||
const state = 'provider-id:jwt:with:colons';
|
||||
const result = service.extractProviderFromState(state);
|
||||
|
||||
expect(result).toBe('provider-id');
|
||||
});
|
||||
|
||||
it('should return null for invalid format', () => {
|
||||
const result = service.extractProviderFromState('invalid-state');
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractProviderFromLegacyState', () => {
|
||||
it('should extract provider from legacy colon-separated format', () => {
|
||||
const result = service.extractProviderFromLegacyState('provider-id:client-state');
|
||||
|
||||
expect(result.providerId).toBe('provider-id');
|
||||
expect(result.originalState).toBe('client-state');
|
||||
});
|
||||
|
||||
it('should handle multiple colons in legacy format', () => {
|
||||
const result = service.extractProviderFromLegacyState(
|
||||
'provider-id:client:state:with:colons'
|
||||
);
|
||||
|
||||
expect(result.providerId).toBe('provider-id');
|
||||
expect(result.originalState).toBe('client:state:with:colons');
|
||||
});
|
||||
|
||||
it('should return empty provider for JWT format', () => {
|
||||
const jwtState = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature';
|
||||
const result = service.extractProviderFromLegacyState(jwtState);
|
||||
|
||||
expect(result.providerId).toBe('');
|
||||
expect(result.originalState).toBe(jwtState);
|
||||
});
|
||||
|
||||
it('should return empty provider for unknown format', () => {
|
||||
const result = service.extractProviderFromLegacyState('some-random-state');
|
||||
|
||||
expect(result.providerId).toBe('');
|
||||
expect(result.originalState).toBe('some-random-state');
|
||||
});
|
||||
});
|
||||
|
||||
describe('cleanupExpiredStates', () => {
|
||||
it('should clean up expired states periodically', () => {
|
||||
const providerId = 'test-provider';
|
||||
|
||||
// Generate multiple states
|
||||
service.generateSecureState(providerId, 'state1');
|
||||
service.generateSecureState(providerId, 'state2');
|
||||
service.generateSecureState(providerId, 'state3');
|
||||
|
||||
// Fast forward past expiration
|
||||
vi.advanceTimersByTime(11 * 60 * 1000);
|
||||
|
||||
// Generate a new state that shouldn't be cleaned
|
||||
const validState = service.generateSecureState(providerId, 'state4');
|
||||
|
||||
// Trigger cleanup (happens every minute)
|
||||
vi.advanceTimersByTime(60 * 1000);
|
||||
|
||||
// The new state should still be valid
|
||||
const result = service.validateSecureState(validState, providerId);
|
||||
expect(result.isValid).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,164 +0,0 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { ConfigService } from '@nestjs/config';
|
||||
|
||||
import * as client from 'openid-client';
|
||||
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
|
||||
|
||||
@Injectable()
|
||||
export class OidcValidationService {
|
||||
private readonly logger = new Logger(OidcValidationService.name);
|
||||
|
||||
constructor(private readonly configService: ConfigService) {}
|
||||
|
||||
/**
|
||||
* Validate OIDC provider configuration by attempting discovery
|
||||
* Returns validation result with helpful error messages for debugging
|
||||
*/
|
||||
async validateProvider(
|
||||
provider: OidcProvider
|
||||
): Promise<{ isValid: boolean; error?: string; details?: unknown }> {
|
||||
try {
|
||||
// Validate issuer URL is present
|
||||
if (!provider.issuer) {
|
||||
return {
|
||||
isValid: false,
|
||||
error: 'No issuer URL provided. Please specify the OIDC provider issuer URL.',
|
||||
details: { type: 'MISSING_ISSUER' },
|
||||
};
|
||||
}
|
||||
|
||||
// Validate issuer URL is valid
|
||||
let serverUrl: URL;
|
||||
try {
|
||||
serverUrl = new URL(provider.issuer);
|
||||
} catch (urlError) {
|
||||
return {
|
||||
isValid: false,
|
||||
error: `Invalid issuer URL format: '${provider.issuer}'. Please provide a valid URL.`,
|
||||
details: {
|
||||
type: 'INVALID_URL',
|
||||
originalError: urlError instanceof Error ? urlError.message : String(urlError),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Configure client options for HTTP if needed
|
||||
let clientOptions: any = undefined;
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(
|
||||
`HTTP issuer URL detected for provider ${provider.id}: ${provider.issuer}`
|
||||
);
|
||||
clientOptions = {
|
||||
execute: [client.allowInsecureRequests],
|
||||
};
|
||||
}
|
||||
|
||||
// Attempt OIDC discovery
|
||||
await this.performDiscovery(provider, clientOptions);
|
||||
return { isValid: true };
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
|
||||
|
||||
// Log the raw error for debugging
|
||||
this.logger.debug(`Raw discovery error for ${provider.id}: ${errorMessage}`);
|
||||
|
||||
// Provide specific error messages for common issues
|
||||
let userFriendlyError = errorMessage;
|
||||
let details: Record<string, unknown> = {};
|
||||
|
||||
if (errorMessage.includes('getaddrinfo ENOTFOUND')) {
|
||||
userFriendlyError = `Cannot resolve domain name. Please check that '${provider.issuer}' is accessible and spelled correctly.`;
|
||||
details = { type: 'DNS_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('ECONNREFUSED')) {
|
||||
userFriendlyError = `Connection refused. The server at '${provider.issuer}' is not accepting connections.`;
|
||||
details = { type: 'CONNECTION_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('ECONNRESET') || errorMessage.includes('ETIMEDOUT')) {
|
||||
userFriendlyError = `Connection timeout. The server at '${provider.issuer}' is not responding.`;
|
||||
details = { type: 'TIMEOUT_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('404') || errorMessage.includes('Not Found')) {
|
||||
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
|
||||
? provider.issuer.replace('/.well-known/openid-configuration', '')
|
||||
: provider.issuer;
|
||||
userFriendlyError = `OIDC discovery endpoint not found. Please verify that '${baseUrl}/.well-known/openid-configuration' exists.`;
|
||||
details = { type: 'DISCOVERY_NOT_FOUND', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('401') || errorMessage.includes('403')) {
|
||||
userFriendlyError = `Access denied to discovery endpoint. Please check the issuer URL and any authentication requirements.`;
|
||||
details = { type: 'AUTHENTICATION_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('unexpected HTTP response status code')) {
|
||||
// Extract status code if possible
|
||||
const statusMatch = errorMessage.match(/status code (\d+)/);
|
||||
const statusCode = statusMatch ? statusMatch[1] : 'unknown';
|
||||
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
|
||||
? provider.issuer.replace('/.well-known/openid-configuration', '')
|
||||
: provider.issuer;
|
||||
userFriendlyError = `HTTP ${statusCode} error from discovery endpoint. Please check that '${baseUrl}/.well-known/openid-configuration' returns a valid OIDC discovery document.`;
|
||||
details = { type: 'HTTP_STATUS_ERROR', statusCode, originalError: errorMessage };
|
||||
} else if (
|
||||
errorMessage.includes('certificate') ||
|
||||
errorMessage.includes('SSL') ||
|
||||
errorMessage.includes('TLS')
|
||||
) {
|
||||
userFriendlyError = `SSL/TLS certificate error. The server certificate may be invalid or expired.`;
|
||||
details = { type: 'SSL_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('JSON') || errorMessage.includes('parse')) {
|
||||
userFriendlyError = `Invalid OIDC discovery response. The server returned malformed JSON.`;
|
||||
details = { type: 'INVALID_JSON', originalError: errorMessage };
|
||||
} else if (error && (error as any).code === 'OAUTH_RESPONSE_IS_NOT_CONFORM') {
|
||||
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
|
||||
? provider.issuer.replace('/.well-known/openid-configuration', '')
|
||||
: provider.issuer;
|
||||
userFriendlyError = `Invalid OIDC discovery document. The server at '${baseUrl}/.well-known/openid-configuration' returned a response that doesn't conform to the OpenID Connect Discovery specification. Please verify the endpoint returns valid OIDC metadata.`;
|
||||
details = { type: 'INVALID_OIDC_DOCUMENT', originalError: errorMessage };
|
||||
}
|
||||
|
||||
this.logger.warn(`OIDC validation failed for provider ${provider.id}: ${errorMessage}`);
|
||||
|
||||
// Add debug logging for HTTP status errors
|
||||
if (errorMessage.includes('unexpected HTTP response status code')) {
|
||||
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
|
||||
? provider.issuer.replace('/.well-known/openid-configuration', '')
|
||||
: provider.issuer;
|
||||
this.logger.debug(`Attempted to fetch: ${baseUrl}/.well-known/openid-configuration`);
|
||||
this.logger.debug(`Full error details: ${errorMessage}`);
|
||||
}
|
||||
|
||||
return {
|
||||
isValid: false,
|
||||
error: userFriendlyError,
|
||||
details,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async performDiscovery(provider: OidcProvider, clientOptions?: any): Promise<client.Configuration> {
|
||||
if (!provider.issuer) {
|
||||
throw new Error('No issuer URL provided');
|
||||
}
|
||||
|
||||
// Configure client auth method
|
||||
const clientAuth = provider.clientSecret
|
||||
? client.ClientSecretPost(provider.clientSecret)
|
||||
: undefined;
|
||||
|
||||
const serverUrl = new URL(provider.issuer);
|
||||
|
||||
// Use provided client options or create default options with HTTP support if needed
|
||||
if (!clientOptions && serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(`Allowing HTTP for ${provider.id} as specified by user`);
|
||||
// For openid-client v6, use allowInsecureRequests in the execute array
|
||||
// This is deprecated but needed for local development with HTTP endpoints
|
||||
clientOptions = {
|
||||
execute: [client.allowInsecureRequests],
|
||||
};
|
||||
}
|
||||
|
||||
return client.discovery(
|
||||
serverUrl,
|
||||
provider.clientId,
|
||||
undefined, // client metadata
|
||||
clientAuth,
|
||||
clientOptions
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
|
||||
|
||||
@Module({
|
||||
providers: [OidcSessionService, OidcStateService],
|
||||
exports: [OidcSessionService, OidcStateService],
|
||||
})
|
||||
export class OidcSessionModule {}
|
||||
@@ -4,7 +4,7 @@ import { Test } from '@nestjs/testing';
|
||||
import type { Cache } from 'cache-manager';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
|
||||
|
||||
describe('OidcSessionService', () => {
|
||||
let service: OidcSessionService;
|
||||
@@ -15,7 +15,7 @@ export interface OidcSession {
|
||||
@Injectable()
|
||||
export class OidcSessionService {
|
||||
private readonly logger = new Logger(OidcSessionService.name);
|
||||
private readonly SESSION_TTL_SECONDS = 2 * 60; // 2 minutes for one-time token security
|
||||
private readonly SESSION_TTL_MS = 2 * 60 * 1000; // 2 minutes in milliseconds (cache-manager v7 expects milliseconds)
|
||||
|
||||
constructor(@Inject(CACHE_MANAGER) private readonly cacheManager: Cache) {}
|
||||
|
||||
@@ -28,12 +28,21 @@ export class OidcSessionService {
|
||||
providerId,
|
||||
providerUserId,
|
||||
createdAt: now,
|
||||
expiresAt: new Date(now.getTime() + this.SESSION_TTL_SECONDS * 1000),
|
||||
expiresAt: new Date(now.getTime() + this.SESSION_TTL_MS),
|
||||
};
|
||||
|
||||
// Store in cache with TTL
|
||||
await this.cacheManager.set(sessionId, session, this.SESSION_TTL_SECONDS * 1000);
|
||||
this.logger.log(`Created OIDC session ${sessionId} for provider ${providerId}`);
|
||||
// Store in cache with TTL (in milliseconds for cache-manager v7)
|
||||
await this.cacheManager.set(sessionId, session, this.SESSION_TTL_MS);
|
||||
|
||||
// Verify it was stored
|
||||
const verifyStored = await this.cacheManager.get(sessionId);
|
||||
if (verifyStored) {
|
||||
this.logger.debug(`Session successfully stored and verified with ID: ${sessionId}`);
|
||||
} else {
|
||||
this.logger.error(`CRITICAL: Session was NOT stored in cache for ID: ${sessionId}`);
|
||||
}
|
||||
|
||||
this.logger.log(`Created OIDC session for provider ${providerId}`);
|
||||
|
||||
return this.createPaddedToken(sessionId);
|
||||
}
|
||||
@@ -44,15 +53,16 @@ export class OidcSessionService {
|
||||
return { valid: false };
|
||||
}
|
||||
|
||||
this.logger.debug(`Looking for session with ID: ${sessionId}`);
|
||||
const session = await this.cacheManager.get<OidcSession>(sessionId);
|
||||
if (!session) {
|
||||
this.logger.debug(`Session ${sessionId} not found`);
|
||||
this.logger.debug(`Session not found for ID: ${sessionId}`);
|
||||
return { valid: false };
|
||||
}
|
||||
|
||||
const now = new Date();
|
||||
if (now > new Date(session.expiresAt)) {
|
||||
this.logger.debug(`Session ${sessionId} expired`);
|
||||
this.logger.debug(`Session expired`);
|
||||
await this.cacheManager.del(sessionId);
|
||||
return { valid: false };
|
||||
}
|
||||
@@ -62,7 +72,7 @@ export class OidcSessionService {
|
||||
await this.cacheManager.del(sessionId);
|
||||
|
||||
this.logger.log(
|
||||
`Validated and invalidated session ${sessionId} for provider ${session.providerId} (one-time use)`
|
||||
`Validated and invalidated session for provider ${session.providerId} (one-time use)`
|
||||
);
|
||||
return { valid: true, username: 'root' };
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
import { CacheModule } from '@nestjs/cache-manager';
|
||||
import { UnauthorizedException } from '@nestjs/common';
|
||||
import { Test } from '@nestjs/testing';
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcStateExtractor } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js';
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
|
||||
|
||||
describe('OidcStateExtractor', () => {
|
||||
let stateService: OidcStateService;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
const module = await Test.createTestingModule({
|
||||
imports: [CacheModule.register()],
|
||||
providers: [OidcStateService],
|
||||
}).compile();
|
||||
|
||||
stateService = module.get<OidcStateService>(OidcStateService);
|
||||
});
|
||||
|
||||
describe('extractProviderFromState', () => {
|
||||
it('should extract provider ID from valid state', () => {
|
||||
const state = 'provider123:nonce.timestamp.signature';
|
||||
const result = OidcStateExtractor.extractProviderFromState(state, stateService);
|
||||
|
||||
expect(result.providerId).toBe('provider123');
|
||||
expect(result.originalState).toBe(state);
|
||||
});
|
||||
|
||||
it('should handle state without provider prefix', () => {
|
||||
const state = 'invalid-state-format';
|
||||
const result = OidcStateExtractor.extractProviderFromState(state, stateService);
|
||||
|
||||
expect(result.providerId).toBe('');
|
||||
expect(result.originalState).toBe(state);
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractAndValidateState', () => {
|
||||
it('should extract and validate a valid state with redirectUri', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
// Generate a valid state
|
||||
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// Extract and validate
|
||||
const result = await OidcStateExtractor.extractAndValidateState(state, stateService);
|
||||
|
||||
expect(result.providerId).toBe(providerId);
|
||||
expect(result.originalState).toBe(state);
|
||||
expect(result.clientState).toBe(clientState);
|
||||
expect(result.redirectUri).toBe(redirectUri);
|
||||
});
|
||||
|
||||
it('should extract and validate a valid state without redirectUri', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
// Generate a valid state without redirectUri
|
||||
const state = await stateService.generateSecureState(providerId, clientState);
|
||||
|
||||
// Extract and validate
|
||||
const result = await OidcStateExtractor.extractAndValidateState(state, stateService);
|
||||
|
||||
expect(result.providerId).toBe(providerId);
|
||||
expect(result.originalState).toBe(state);
|
||||
expect(result.clientState).toBe(clientState);
|
||||
expect(result.redirectUri).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException for invalid state format', async () => {
|
||||
const invalidState = 'invalid-format';
|
||||
|
||||
await expect(async () => {
|
||||
await OidcStateExtractor.extractAndValidateState(invalidState, stateService);
|
||||
}).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException for expired state', async () => {
|
||||
vi.useFakeTimers();
|
||||
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
// Generate a valid state
|
||||
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// Fast forward time beyond expiration (11 minutes)
|
||||
vi.advanceTimersByTime(11 * 60 * 1000);
|
||||
|
||||
await expect(async () => {
|
||||
await OidcStateExtractor.extractAndValidateState(state, stateService);
|
||||
}).rejects.toThrow(UnauthorizedException);
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException for wrong provider ID', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const wrongProviderId = 'wrong-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
// Generate a valid state
|
||||
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// Create a fake state with wrong provider prefix
|
||||
const tamperedState = state.replace(providerId, wrongProviderId);
|
||||
|
||||
await expect(async () => {
|
||||
await OidcStateExtractor.extractAndValidateState(tamperedState, stateService);
|
||||
}).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException for tampered state', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
// Generate a valid state
|
||||
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// Tamper with the signature
|
||||
const tamperedState = state.slice(0, -5) + 'xxxxx';
|
||||
|
||||
await expect(async () => {
|
||||
await OidcStateExtractor.extractAndValidateState(tamperedState, stateService);
|
||||
}).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException for reused state (replay attack)', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
// Generate a valid state
|
||||
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// First validation should succeed
|
||||
const result1 = await OidcStateExtractor.extractAndValidateState(state, stateService);
|
||||
expect(result1.providerId).toBe(providerId);
|
||||
|
||||
// Second validation should fail (replay attack)
|
||||
await expect(async () => {
|
||||
await OidcStateExtractor.extractAndValidateState(state, stateService);
|
||||
}).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,60 @@
|
||||
import { UnauthorizedException } from '@nestjs/common';
|
||||
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
|
||||
|
||||
export interface StateExtractionResult {
|
||||
providerId: string;
|
||||
originalState: string;
|
||||
clientState?: string;
|
||||
redirectUri?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility to extract and validate OIDC state information consistently
|
||||
* across authorize and callback endpoints
|
||||
*/
|
||||
export class OidcStateExtractor {
|
||||
/**
|
||||
* Extract provider ID from state without validation (for routing purposes)
|
||||
*/
|
||||
static extractProviderFromState(
|
||||
state: string,
|
||||
stateService: OidcStateService
|
||||
): { providerId: string; originalState: string } {
|
||||
// Use the state service's extraction method
|
||||
const providerId = stateService.extractProviderFromState(state);
|
||||
|
||||
return {
|
||||
providerId: providerId || '',
|
||||
originalState: state,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract provider ID and validate the full encrypted state
|
||||
*/
|
||||
static async extractAndValidateState(
|
||||
state: string,
|
||||
stateService: OidcStateService
|
||||
): Promise<StateExtractionResult> {
|
||||
// First extract provider ID for routing
|
||||
const { providerId } = this.extractProviderFromState(state, stateService);
|
||||
|
||||
if (!providerId) {
|
||||
throw new UnauthorizedException('Invalid state format: missing provider ID');
|
||||
}
|
||||
|
||||
// Then validate the full encrypted state
|
||||
const stateValidation = await stateService.validateSecureState(state, providerId);
|
||||
if (!stateValidation.isValid) {
|
||||
throw new UnauthorizedException(`Invalid state: ${stateValidation.error}`);
|
||||
}
|
||||
|
||||
return {
|
||||
providerId,
|
||||
originalState: state,
|
||||
clientState: stateValidation.clientState,
|
||||
redirectUri: stateValidation.redirectUri,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,238 @@
|
||||
import { CacheModule } from '@nestjs/cache-manager';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
|
||||
|
||||
describe('OidcStateService', () => {
|
||||
let service: OidcStateService;
|
||||
let module: TestingModule;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.clearAllMocks();
|
||||
vi.useFakeTimers();
|
||||
// Set a deterministic system time for consistent testing
|
||||
vi.setSystemTime(new Date('2024-01-01T00:00:00Z'));
|
||||
|
||||
module = await Test.createTestingModule({
|
||||
imports: [CacheModule.register()],
|
||||
providers: [OidcStateService],
|
||||
}).compile();
|
||||
|
||||
service = module.get<OidcStateService>(OidcStateService);
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
vi.useRealTimers();
|
||||
// Close the testing module to prevent handle leaks
|
||||
if (module) {
|
||||
await module.close();
|
||||
}
|
||||
});
|
||||
|
||||
describe('generateSecureState', () => {
|
||||
it('should generate a state with provider prefix and signed token', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
expect(state).toBeTruthy();
|
||||
expect(typeof state).toBe('string');
|
||||
expect(state.startsWith(`${providerId}:`)).toBe(true);
|
||||
|
||||
// Extract signed portion and verify format (nonce.timestamp.signature)
|
||||
const signed = state.substring(providerId.length + 1);
|
||||
expect(signed.split('.').length).toBe(3);
|
||||
});
|
||||
|
||||
it('should generate unique states for each call', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
const state1 = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
const state2 = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
expect(state1).not.toBe(state2);
|
||||
});
|
||||
|
||||
it('should work without redirectUri parameter (backwards compatibility)', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
|
||||
expect(state).toBeTruthy();
|
||||
expect(state.startsWith(`${providerId}:`)).toBe(true);
|
||||
});
|
||||
|
||||
it('should store state data in cache and retrieve it', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
const validation = await service.validateSecureState(state, providerId);
|
||||
|
||||
expect(validation.isValid).toBe(true);
|
||||
expect(validation.clientState).toBe(clientState);
|
||||
expect(validation.redirectUri).toBe(redirectUri);
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateSecureState', () => {
|
||||
it('should validate a valid state token', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
const result = await service.validateSecureState(state, providerId);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.clientState).toBe(clientState);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should validate a state token with redirectUri', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
const result = await service.validateSecureState(state, providerId);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.clientState).toBe(clientState);
|
||||
expect(result.redirectUri).toBe(redirectUri);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should reject state with wrong provider ID', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
const result = await service.validateSecureState(state, 'different-provider');
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('Provider ID mismatch');
|
||||
});
|
||||
|
||||
it('should reject expired state tokens', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
|
||||
// Advance time by 11 minutes (past the 10-minute TTL)
|
||||
vi.advanceTimersByTime(11 * 60 * 1000);
|
||||
|
||||
const result = await service.validateSecureState(state, providerId);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('expired');
|
||||
});
|
||||
|
||||
it('should reject reused state tokens', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
|
||||
// First validation should succeed
|
||||
const result1 = await service.validateSecureState(state, providerId);
|
||||
expect(result1.isValid).toBe(true);
|
||||
|
||||
// Second validation should fail (replay attack prevention)
|
||||
const result2 = await service.validateSecureState(state, providerId);
|
||||
expect(result2.isValid).toBe(false);
|
||||
expect(result2.error).toContain('not found or already used');
|
||||
});
|
||||
|
||||
it('should reject invalid state tokens', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const invalidState = `${providerId}:invalid-format`;
|
||||
|
||||
const result = await service.validateSecureState(invalidState, providerId);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toBeTruthy();
|
||||
});
|
||||
|
||||
it('should reject tampered state tokens', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
// Tamper with the signature
|
||||
const tamperedState = state.substring(0, state.length - 5) + 'xxxxx';
|
||||
|
||||
const result = await service.validateSecureState(tamperedState, providerId);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('signature');
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractProviderFromState', () => {
|
||||
it('should extract provider ID from state', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
const extracted = service.extractProviderFromState(state);
|
||||
|
||||
expect(extracted).toBe(providerId);
|
||||
});
|
||||
|
||||
it('should return null for invalid state format', () => {
|
||||
const invalidState = 'invalid-state-without-colon';
|
||||
const extracted = service.extractProviderFromState(invalidState);
|
||||
|
||||
expect(extracted).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractProviderFromLegacyState', () => {
|
||||
it('should handle legacy state format', () => {
|
||||
const legacyState = 'provider-id:client-state-value';
|
||||
const result = service.extractProviderFromLegacyState(legacyState);
|
||||
|
||||
expect(result.providerId).toBe('provider-id');
|
||||
expect(result.originalState).toBe('client-state-value');
|
||||
});
|
||||
|
||||
it('should handle new signed state format', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
const result = service.extractProviderFromLegacyState(state);
|
||||
|
||||
// New format should not be recognized as legacy
|
||||
expect(result.providerId).toBe('');
|
||||
expect(result.originalState).toBe(state);
|
||||
});
|
||||
});
|
||||
|
||||
describe('cache TTL', () => {
|
||||
it('should remove state from cache after successful validation', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
|
||||
// First validation should succeed
|
||||
const result1 = await service.validateSecureState(state, providerId);
|
||||
expect(result1.isValid).toBe(true);
|
||||
|
||||
// Second validation should fail (state was removed after first use)
|
||||
const result2 = await service.validateSecureState(state, providerId);
|
||||
expect(result2.isValid).toBe(false);
|
||||
expect(result2.error).toContain('not found or already used');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,241 @@
|
||||
import { CacheModule } from '@nestjs/cache-manager';
|
||||
import { Test } from '@nestjs/testing';
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
|
||||
|
||||
describe('OidcStateService', () => {
|
||||
let service: OidcStateService;
|
||||
|
||||
beforeEach(async () => {
|
||||
const module = await Test.createTestingModule({
|
||||
imports: [CacheModule.register()],
|
||||
providers: [OidcStateService],
|
||||
}).compile();
|
||||
|
||||
service = module.get<OidcStateService>(OidcStateService);
|
||||
});
|
||||
|
||||
describe('state generation and validation flow', () => {
|
||||
it('should generate state with redirect URI and validate it successfully', async () => {
|
||||
const providerId = 'unraid.net';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'http://devgen-dev1.local/graphql/api/auth/oidc/callback';
|
||||
|
||||
// Generate state
|
||||
const state = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// Verify state format: providerId:nonce.timestamp.signature
|
||||
expect(state).toMatch(/^unraid\.net:[a-f0-9]+\.\d+\.[a-f0-9]+$/);
|
||||
|
||||
// Extract and verify parts
|
||||
const [extractedProviderId, signedState] = state.split(':');
|
||||
expect(extractedProviderId).toBe(providerId);
|
||||
|
||||
// Parse the signed state components
|
||||
const [nonce, timestamp, signature] = signedState.split('.');
|
||||
|
||||
// Verify nonce is a 32-character hex string (16 bytes)
|
||||
expect(nonce).toMatch(/^[a-f0-9]{32}$/);
|
||||
|
||||
// Verify timestamp is a valid number and recent
|
||||
const timestampNum = parseInt(timestamp, 10);
|
||||
expect(timestampNum).toBeGreaterThan(Date.now() - 1000); // Generated within last second
|
||||
expect(timestampNum).toBeLessThanOrEqual(Date.now());
|
||||
|
||||
// Verify signature is a 64-character hex string (SHA256 output)
|
||||
expect(signature).toMatch(/^[a-f0-9]{64}$/);
|
||||
|
||||
// Validate the state
|
||||
const validation = await service.validateSecureState(state, providerId);
|
||||
|
||||
expect(validation.isValid).toBe(true);
|
||||
expect(validation.clientState).toBe(clientState);
|
||||
expect(validation.redirectUri).toBe(redirectUri);
|
||||
});
|
||||
|
||||
it('should verify signed state integrity with HMAC', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'test-state';
|
||||
const redirectUri = 'http://localhost:3000/callback';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// Tamper with the signature
|
||||
const [provider, signedState] = state.split(':');
|
||||
const [nonce, timestamp] = signedState.split('.');
|
||||
const tamperedSignature = 'a'.repeat(64); // Invalid signature
|
||||
const tamperedState = `${provider}:${nonce}.${timestamp}.${tamperedSignature}`;
|
||||
|
||||
const validation = await service.validateSecureState(tamperedState, providerId);
|
||||
|
||||
expect(validation.isValid).toBe(false);
|
||||
expect(validation.error).toContain('Invalid state signature');
|
||||
});
|
||||
|
||||
it('should fail validation when nonce is not in cache', async () => {
|
||||
const providerId = 'unraid.net';
|
||||
// Create a fake state that looks valid but has unknown nonce
|
||||
const fakeState = `unraid.net:fakenonce123.${Date.now()}.fakesignature456`;
|
||||
|
||||
const validation = await service.validateSecureState(fakeState, providerId);
|
||||
|
||||
expect(validation.isValid).toBe(false);
|
||||
expect(validation.error).toContain('Invalid state signature');
|
||||
});
|
||||
|
||||
it('should prevent replay attacks by removing nonce after validation', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'test-state';
|
||||
const redirectUri = 'http://localhost:3000/callback';
|
||||
|
||||
// Generate and validate state once
|
||||
const state = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
const firstValidation = await service.validateSecureState(state, providerId);
|
||||
expect(firstValidation.isValid).toBe(true);
|
||||
|
||||
// Try to validate the same state again (replay attack)
|
||||
const secondValidation = await service.validateSecureState(state, providerId);
|
||||
expect(secondValidation.isValid).toBe(false);
|
||||
expect(secondValidation.error).toContain('State token not found or already used');
|
||||
});
|
||||
|
||||
it('should handle state with missing redirect URI', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'test-state';
|
||||
// No redirect URI provided
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
const validation = await service.validateSecureState(state, providerId);
|
||||
|
||||
expect(validation.isValid).toBe(true);
|
||||
expect(validation.clientState).toBe(clientState);
|
||||
expect(validation.redirectUri).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should reject state with wrong provider ID', async () => {
|
||||
const providerId = 'provider-a';
|
||||
const wrongProviderId = 'provider-b';
|
||||
const clientState = 'test-state';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
const validation = await service.validateSecureState(state, wrongProviderId);
|
||||
|
||||
expect(validation.isValid).toBe(false);
|
||||
expect(validation.error).toContain('Provider ID mismatch');
|
||||
});
|
||||
|
||||
it('should extract provider from state correctly', async () => {
|
||||
const providerId = 'unraid.net';
|
||||
const state = await service.generateSecureState(providerId, 'test', 'http://example.com');
|
||||
|
||||
const extracted = service.extractProviderFromState(state);
|
||||
expect(extracted).toBe(providerId);
|
||||
});
|
||||
|
||||
it('should handle state expiration', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'test-state';
|
||||
|
||||
// Generate state
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
|
||||
// Mock timestamp to simulate expired state
|
||||
const parts = state.split(':')[1].split('.');
|
||||
const nonce = parts[0];
|
||||
const expiredTimestamp = Date.now() - 700000; // 11+ minutes ago
|
||||
const fakeState = `${providerId}:${nonce}.${expiredTimestamp}.fakesignature`;
|
||||
|
||||
const validation = await service.validateSecureState(fakeState, providerId);
|
||||
expect(validation.isValid).toBe(false);
|
||||
expect(validation.error).toContain('Invalid state signature'); // Will fail on signature first
|
||||
});
|
||||
});
|
||||
|
||||
describe('redirect URI extraction from state', () => {
|
||||
it('should store and retrieve redirect URI from state token', async () => {
|
||||
const providerId = 'unraid.net';
|
||||
const clientState = 'original-client-state';
|
||||
const redirectUri = 'http://devgen-dev1.local/graphql/api/auth/oidc/callback';
|
||||
|
||||
// This simulates the authorize flow
|
||||
const stateToken = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// Log the generated state for debugging
|
||||
console.log('Generated state token:', stateToken);
|
||||
|
||||
// This simulates the callback flow
|
||||
const validation = await service.validateSecureState(stateToken, providerId);
|
||||
|
||||
expect(validation.isValid).toBe(true);
|
||||
expect(validation.redirectUri).toBe(redirectUri);
|
||||
expect(validation.clientState).toBe(clientState);
|
||||
});
|
||||
|
||||
it('should handle dynamic redirect URIs for different origins', async () => {
|
||||
const providerId = 'google';
|
||||
const clientState = 'state123';
|
||||
|
||||
// Test with different origins
|
||||
const origins = [
|
||||
'http://localhost:3000/graphql/api/auth/oidc/callback',
|
||||
'https://myserver.local/graphql/api/auth/oidc/callback',
|
||||
'http://192.168.1.100/graphql/api/auth/oidc/callback',
|
||||
];
|
||||
|
||||
for (const redirectUri of origins) {
|
||||
const state = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
const validation = await service.validateSecureState(state, providerId);
|
||||
|
||||
expect(validation.isValid).toBe(true);
|
||||
expect(validation.redirectUri).toBe(redirectUri);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('cache management', () => {
|
||||
it('should handle TTL expiration correctly', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'test-state';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
|
||||
// First validation should succeed
|
||||
const validation1 = await service.validateSecureState(state, providerId);
|
||||
expect(validation1.isValid).toBe(true);
|
||||
|
||||
// State should be removed after first use (replay protection)
|
||||
const validation2 = await service.validateSecureState(state, providerId);
|
||||
expect(validation2.isValid).toBe(false);
|
||||
});
|
||||
|
||||
it('should store complete state data in cache with redirect URI', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-123';
|
||||
const redirectUri = 'http://example.com/callback';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// Extract nonce from the generated state
|
||||
const [, signedState] = state.split(':');
|
||||
const [nonce] = signedState.split('.');
|
||||
|
||||
// Access the cache directly to verify stored data
|
||||
const cacheKey = `oidc_state:${nonce}`;
|
||||
const cachedData = await service['cacheManager'].get(cacheKey);
|
||||
|
||||
expect(cachedData).toBeDefined();
|
||||
expect(cachedData).toMatchObject({
|
||||
nonce,
|
||||
clientState,
|
||||
providerId,
|
||||
redirectUri,
|
||||
});
|
||||
// @ts-expect-error - cachedData is of type StateData
|
||||
expect(cachedData.timestamp).toBeGreaterThan(Date.now() - 1000);
|
||||
// @ts-expect-error - cachedData is of type StateData
|
||||
expect(cachedData.timestamp).toBeLessThanOrEqual(Date.now());
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,4 +1,5 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { Cache, CACHE_MANAGER } from '@nestjs/cache-manager';
|
||||
import { Inject, Injectable, Logger } from '@nestjs/common';
|
||||
import crypto from 'crypto';
|
||||
|
||||
interface StateData {
|
||||
@@ -6,26 +7,34 @@ interface StateData {
|
||||
clientState: string;
|
||||
timestamp: number;
|
||||
providerId: string;
|
||||
redirectUri?: string;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class OidcStateService {
|
||||
private static instanceCount = 0;
|
||||
private readonly instanceId: number;
|
||||
private readonly logger = new Logger(OidcStateService.name);
|
||||
private readonly stateCache = new Map<string, StateData>();
|
||||
private readonly hmacSecret: string;
|
||||
private readonly STATE_TTL_SECONDS = 600; // 10 minutes
|
||||
private readonly STATE_TTL_MS = 600000; // 10 minutes in milliseconds (cache-manager v7+ expects milliseconds, not seconds)
|
||||
private readonly STATE_CACHE_PREFIX = 'oidc_state:';
|
||||
|
||||
constructor(@Inject(CACHE_MANAGER) private cacheManager: Cache) {
|
||||
// Track instance creation
|
||||
this.instanceId = ++OidcStateService.instanceCount;
|
||||
|
||||
constructor() {
|
||||
// Always generate a new secret on API restart for security
|
||||
// This ensures state tokens cannot be reused across restarts
|
||||
this.hmacSecret = crypto.randomBytes(32).toString('hex');
|
||||
this.logger.debug('Generated new OIDC state secret for this session');
|
||||
|
||||
// Clean up expired states periodically
|
||||
setInterval(() => this.cleanupExpiredStates(), 60000); // Every minute
|
||||
this.logger.warn(`OidcStateService instance #${this.instanceId} created with new HMAC secret`);
|
||||
this.logger.debug(`HMAC secret first 8 chars: ${this.hmacSecret.substring(0, 8)}`);
|
||||
}
|
||||
|
||||
generateSecureState(providerId: string, clientState: string): string {
|
||||
async generateSecureState(
|
||||
providerId: string,
|
||||
clientState: string,
|
||||
redirectUri?: string
|
||||
): Promise<string> {
|
||||
const nonce = crypto.randomBytes(16).toString('hex');
|
||||
const timestamp = Date.now();
|
||||
|
||||
@@ -35,8 +44,21 @@ export class OidcStateService {
|
||||
clientState,
|
||||
timestamp,
|
||||
providerId,
|
||||
redirectUri,
|
||||
};
|
||||
this.stateCache.set(nonce, stateData);
|
||||
|
||||
// Store in cache with TTL (in milliseconds for cache-manager v7)
|
||||
const cacheKey = `${this.STATE_CACHE_PREFIX}${nonce}`;
|
||||
this.logger.debug(`Storing state with key: ${cacheKey}, TTL: ${this.STATE_TTL_MS}ms`);
|
||||
await this.cacheManager.set(cacheKey, stateData, this.STATE_TTL_MS);
|
||||
|
||||
// Verify it was stored
|
||||
const verifyStored = await this.cacheManager.get(cacheKey);
|
||||
if (verifyStored) {
|
||||
this.logger.debug(`State successfully stored and verified for key: ${cacheKey}`);
|
||||
} else {
|
||||
this.logger.error(`CRITICAL: State was NOT stored in cache for key: ${cacheKey}`);
|
||||
}
|
||||
|
||||
// Create signed state: nonce.timestamp.signature
|
||||
const dataToSign = `${nonce}.${timestamp}`;
|
||||
@@ -45,14 +67,18 @@ export class OidcStateService {
|
||||
const signedState = `${dataToSign}.${signature}`;
|
||||
|
||||
this.logger.debug(`Generated secure state for provider ${providerId} with nonce ${nonce}`);
|
||||
this.logger.debug(
|
||||
`Instance #${this.instanceId}, HMAC secret first 8 chars: ${this.hmacSecret.substring(0, 8)}`
|
||||
);
|
||||
this.logger.debug(`Stored redirectUri: ${redirectUri}`);
|
||||
// Return state with provider ID prefix (unencrypted) for routing
|
||||
return `${providerId}:${signedState}`;
|
||||
}
|
||||
|
||||
validateSecureState(
|
||||
async validateSecureState(
|
||||
state: string,
|
||||
expectedProviderId: string
|
||||
): { isValid: boolean; clientState?: string; error?: string } {
|
||||
): Promise<{ isValid: boolean; clientState?: string; redirectUri?: string; error?: string }> {
|
||||
try {
|
||||
// Extract provider ID and signed state
|
||||
const parts = state.split(':');
|
||||
@@ -107,7 +133,7 @@ export class OidcStateService {
|
||||
// Check timestamp expiration
|
||||
const now = Date.now();
|
||||
const age = now - timestamp;
|
||||
if (age > this.STATE_TTL_SECONDS * 1000) {
|
||||
if (age > this.STATE_TTL_MS) {
|
||||
this.logger.warn(`State validation failed: token expired (age: ${age}ms)`);
|
||||
return {
|
||||
isValid: false,
|
||||
@@ -116,11 +142,21 @@ export class OidcStateService {
|
||||
}
|
||||
|
||||
// Check if state exists in cache (prevents replay attacks)
|
||||
const cachedState = this.stateCache.get(nonce);
|
||||
const cacheKey = `${this.STATE_CACHE_PREFIX}${nonce}`;
|
||||
this.logger.debug(`Looking for nonce ${nonce} in cache with key: ${cacheKey}`);
|
||||
this.logger.debug(
|
||||
`Instance #${this.instanceId}, HMAC secret first 8 chars: ${this.hmacSecret.substring(0, 8)}`
|
||||
);
|
||||
this.logger.debug(`Cache manager type: ${this.cacheManager.constructor.name}`);
|
||||
|
||||
const cachedState = await this.cacheManager.get<StateData>(cacheKey);
|
||||
|
||||
if (!cachedState) {
|
||||
this.logger.warn(
|
||||
`State validation failed: nonce ${nonce} not found in cache (possible replay attack)`
|
||||
);
|
||||
this.logger.warn(`Cache key checked: ${cacheKey}`);
|
||||
|
||||
return {
|
||||
isValid: false,
|
||||
error: 'State token not found or already used',
|
||||
@@ -137,12 +173,13 @@ export class OidcStateService {
|
||||
}
|
||||
|
||||
// Remove from cache to prevent reuse
|
||||
this.stateCache.delete(nonce);
|
||||
await this.cacheManager.del(cacheKey);
|
||||
|
||||
this.logger.debug(`State validation successful for provider ${expectedProviderId}`);
|
||||
return {
|
||||
isValid: true,
|
||||
clientState: cachedState.clientState,
|
||||
redirectUri: cachedState.redirectUri,
|
||||
};
|
||||
} catch (error) {
|
||||
this.logger.error(
|
||||
@@ -182,20 +219,5 @@ export class OidcStateService {
|
||||
return null;
|
||||
}
|
||||
|
||||
private cleanupExpiredStates(): void {
|
||||
const now = Date.now();
|
||||
let cleaned = 0;
|
||||
|
||||
for (const [nonce, stateData] of this.stateCache.entries()) {
|
||||
const age = now - stateData.timestamp;
|
||||
if (age > this.STATE_TTL_SECONDS * 1000) {
|
||||
this.stateCache.delete(nonce);
|
||||
cleaned++;
|
||||
}
|
||||
}
|
||||
|
||||
if (cleaned > 0) {
|
||||
this.logger.debug(`Cleaned up ${cleaned} expired state entries`);
|
||||
}
|
||||
}
|
||||
// Cleanup is now handled by cache TTL
|
||||
}
|
||||
@@ -1,33 +1,13 @@
|
||||
import { CacheModule } from '@nestjs/cache-manager';
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { UserSettingsModule } from '@unraid/shared/services/user-settings.js';
|
||||
|
||||
import { OidcAuthService } from '@app/unraid-api/graph/resolvers/sso/oidc-auth.service.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/oidc-state.service.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/oidc-validation.service.js';
|
||||
import { OidcCoreModule } from '@app/unraid-api/graph/resolvers/sso/core/oidc-core.module.js';
|
||||
import { SsoResolver } from '@app/unraid-api/graph/resolvers/sso/sso.resolver.js';
|
||||
|
||||
import '@app/unraid-api/graph/resolvers/sso/sso-settings.types.js';
|
||||
import '@app/unraid-api/graph/resolvers/sso/models/sso-settings.types.js';
|
||||
|
||||
@Module({
|
||||
imports: [UserSettingsModule, CacheModule.register()],
|
||||
providers: [
|
||||
SsoResolver,
|
||||
OidcConfigPersistence,
|
||||
OidcSessionService,
|
||||
OidcStateService,
|
||||
OidcAuthService,
|
||||
OidcValidationService,
|
||||
],
|
||||
exports: [
|
||||
OidcConfigPersistence,
|
||||
OidcSessionService,
|
||||
OidcStateService,
|
||||
OidcAuthService,
|
||||
OidcValidationService,
|
||||
],
|
||||
imports: [OidcCoreModule],
|
||||
providers: [SsoResolver],
|
||||
exports: [OidcCoreModule],
|
||||
})
|
||||
export class SsoModule {}
|
||||
|
||||
@@ -6,11 +6,12 @@ import { PrefixedID } from '@unraid/shared/prefixed-id-scalar.js';
|
||||
import { UsePermissions } from '@unraid/shared/use-permissions.directive.js';
|
||||
|
||||
import { Public } from '@app/unraid-api/auth/public.decorator.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
|
||||
import { OidcSessionValidation } from '@app/unraid-api/graph/resolvers/sso/oidc-session-validation.model.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
|
||||
import { PublicOidcProvider } from '@app/unraid-api/graph/resolvers/sso/public-oidc-provider.model.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { OidcConfiguration } from '@app/unraid-api/graph/resolvers/sso/models/oidc-configuration.model.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
import { OidcSessionValidation } from '@app/unraid-api/graph/resolvers/sso/models/oidc-session-validation.model.js';
|
||||
import { PublicOidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/public-oidc-provider.model.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
|
||||
|
||||
@Resolver()
|
||||
export class SsoResolver {
|
||||
@@ -88,6 +89,19 @@ export class SsoResolver {
|
||||
return this.oidcConfig.getProvider(id);
|
||||
}
|
||||
|
||||
@Query(() => OidcConfiguration, { description: 'Get the full OIDC configuration (admin only)' })
|
||||
@UsePermissions({
|
||||
action: AuthAction.READ_ANY,
|
||||
resource: Resource.CONFIG,
|
||||
})
|
||||
public async oidcConfiguration(): Promise<OidcConfiguration> {
|
||||
const config = await this.oidcConfig.getConfig();
|
||||
return {
|
||||
providers: config?.providers || [],
|
||||
defaultAllowedOrigins: config?.defaultAllowedOrigins || [],
|
||||
};
|
||||
}
|
||||
|
||||
@Query(() => OidcSessionValidation, {
|
||||
description: 'Validate an OIDC session token (internal use for CLI validation)',
|
||||
})
|
||||
|
||||
@@ -0,0 +1,234 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
|
||||
export interface OidcErrorDetails {
|
||||
userFriendlyError: string;
|
||||
details: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export class OidcErrorHelper {
|
||||
private static readonly logger = new Logger(OidcErrorHelper.name);
|
||||
|
||||
/**
|
||||
* Parse fetch errors and return user-friendly error messages
|
||||
*/
|
||||
static parseFetchError(error: unknown, issuerUrl?: string): OidcErrorDetails {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
let userFriendlyError = errorMessage;
|
||||
let details: Record<string, unknown> = { originalError: errorMessage };
|
||||
|
||||
// Extract cause information if available
|
||||
if (error instanceof Error && 'cause' in error) {
|
||||
const cause = (error as any).cause;
|
||||
if (cause) {
|
||||
this.logger.log('Fetch error cause: %o', cause);
|
||||
|
||||
const errorCode = cause.code || '';
|
||||
const causeMessage = cause.message || '';
|
||||
|
||||
// Map error codes to user-friendly messages
|
||||
switch (errorCode) {
|
||||
case 'ENOTFOUND':
|
||||
userFriendlyError = `Cannot resolve domain name. Please check that '${issuerUrl}' is accessible and spelled correctly.`;
|
||||
details = {
|
||||
type: 'DNS_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage || errorCode,
|
||||
};
|
||||
break;
|
||||
|
||||
case 'ECONNREFUSED':
|
||||
userFriendlyError = `Connection refused. The server at '${issuerUrl}' is not accepting connections.`;
|
||||
details = {
|
||||
type: 'CONNECTION_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage || errorCode,
|
||||
};
|
||||
break;
|
||||
|
||||
case 'CERT_HAS_EXPIRED':
|
||||
userFriendlyError = `SSL/TLS certificate error. The server certificate may be invalid or expired.`;
|
||||
details = {
|
||||
type: 'SSL_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage || errorCode,
|
||||
};
|
||||
break;
|
||||
|
||||
case 'ETIMEDOUT':
|
||||
userFriendlyError = `Connection timeout. The server at '${issuerUrl}' is not responding.`;
|
||||
details = {
|
||||
type: 'TIMEOUT_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage || errorCode,
|
||||
};
|
||||
break;
|
||||
|
||||
default:
|
||||
// Check message patterns if code doesn't match
|
||||
if (causeMessage.includes('ENOTFOUND')) {
|
||||
userFriendlyError = `Cannot resolve domain name. Please check that '${issuerUrl}' is accessible and spelled correctly.`;
|
||||
details = {
|
||||
type: 'DNS_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage,
|
||||
};
|
||||
} else if (causeMessage.includes('ECONNREFUSED')) {
|
||||
userFriendlyError = `Connection refused. The server at '${issuerUrl}' is not accepting connections.`;
|
||||
details = {
|
||||
type: 'CONNECTION_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage,
|
||||
};
|
||||
} else if (
|
||||
causeMessage.includes('certificate') ||
|
||||
causeMessage.includes('SSL') ||
|
||||
causeMessage.includes('TLS')
|
||||
) {
|
||||
userFriendlyError = `SSL/TLS certificate error. The server certificate may be invalid or expired.`;
|
||||
details = {
|
||||
type: 'SSL_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage,
|
||||
};
|
||||
} else if (causeMessage.includes('ETIMEDOUT')) {
|
||||
userFriendlyError = `Connection timeout. The server at '${issuerUrl}' is not responding.`;
|
||||
details = {
|
||||
type: 'TIMEOUT_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage,
|
||||
};
|
||||
} else {
|
||||
userFriendlyError = `Failed to connect to OIDC provider at '${issuerUrl}'. ${causeMessage || errorCode || 'Unknown network error'}`;
|
||||
details = {
|
||||
type: 'FETCH_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage || errorCode,
|
||||
};
|
||||
}
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
// Generic fetch failed without cause
|
||||
userFriendlyError = `Failed to connect to OIDC provider at '${issuerUrl}'. Please verify the URL is correct and accessible.`;
|
||||
details = { type: 'FETCH_ERROR', originalError: errorMessage };
|
||||
}
|
||||
} else if (errorMessage.includes('fetch failed')) {
|
||||
// Fetch failed but no cause information
|
||||
userFriendlyError = `Failed to connect to OIDC provider at '${issuerUrl}'. Please verify the URL is correct and accessible.`;
|
||||
details = { type: 'FETCH_ERROR', originalError: errorMessage };
|
||||
}
|
||||
|
||||
return { userFriendlyError, details };
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse HTTP status errors and return user-friendly error messages
|
||||
*/
|
||||
static parseHttpError(errorMessage: string, issuerUrl?: string): OidcErrorDetails {
|
||||
let userFriendlyError = errorMessage;
|
||||
let details: Record<string, unknown> = { originalError: errorMessage };
|
||||
|
||||
if (errorMessage.includes('404') || errorMessage.includes('Not Found')) {
|
||||
const baseUrl = issuerUrl?.endsWith('/.well-known/openid-configuration')
|
||||
? issuerUrl.replace('/.well-known/openid-configuration', '')
|
||||
: issuerUrl;
|
||||
userFriendlyError = `OIDC discovery endpoint not found. Please verify that '${baseUrl}/.well-known/openid-configuration' exists.`;
|
||||
details = { type: 'DISCOVERY_NOT_FOUND', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('401') || errorMessage.includes('403')) {
|
||||
userFriendlyError = `Access denied to discovery endpoint. Please check the issuer URL and any authentication requirements.`;
|
||||
details = { type: 'AUTHENTICATION_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('unexpected HTTP response status code')) {
|
||||
// Extract status code if possible
|
||||
const statusMatch = errorMessage.match(/status code (\d+)/);
|
||||
const statusCode = statusMatch ? statusMatch[1] : 'unknown';
|
||||
const baseUrl = issuerUrl?.endsWith('/.well-known/openid-configuration')
|
||||
? issuerUrl.replace('/.well-known/openid-configuration', '')
|
||||
: issuerUrl;
|
||||
userFriendlyError = `HTTP ${statusCode} error from discovery endpoint. Please check that '${baseUrl}/.well-known/openid-configuration' returns a valid OIDC discovery document.`;
|
||||
details = { type: 'HTTP_STATUS_ERROR', statusCode, originalError: errorMessage };
|
||||
}
|
||||
|
||||
return { userFriendlyError, details };
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse generic OIDC errors and return user-friendly error messages
|
||||
*/
|
||||
static parseGenericError(error: unknown, issuerUrl?: string): OidcErrorDetails {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
let userFriendlyError = errorMessage;
|
||||
let details: Record<string, unknown> = { originalError: errorMessage };
|
||||
|
||||
// Check for specific error patterns
|
||||
if (errorMessage.includes('getaddrinfo ENOTFOUND')) {
|
||||
userFriendlyError = `Cannot resolve domain name. Please check that '${issuerUrl}' is accessible and spelled correctly.`;
|
||||
details = { type: 'DNS_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('ECONNREFUSED')) {
|
||||
userFriendlyError = `Connection refused. The server at '${issuerUrl}' is not accepting connections.`;
|
||||
details = { type: 'CONNECTION_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('ECONNRESET') || errorMessage.includes('ETIMEDOUT')) {
|
||||
userFriendlyError = `Connection timeout. The server at '${issuerUrl}' is not responding.`;
|
||||
details = { type: 'TIMEOUT_ERROR', originalError: errorMessage };
|
||||
} else if (
|
||||
errorMessage.includes('certificate') ||
|
||||
errorMessage.includes('SSL') ||
|
||||
errorMessage.includes('TLS')
|
||||
) {
|
||||
userFriendlyError = `SSL/TLS certificate error. The server certificate may be invalid or expired.`;
|
||||
details = { type: 'SSL_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('JSON') || errorMessage.includes('parse')) {
|
||||
userFriendlyError = `Invalid OIDC discovery response. The server returned malformed JSON.`;
|
||||
details = { type: 'INVALID_JSON', originalError: errorMessage };
|
||||
} else if (error && (error as any).code === 'OAUTH_RESPONSE_IS_NOT_CONFORM') {
|
||||
const baseUrl = issuerUrl?.endsWith('/.well-known/openid-configuration')
|
||||
? issuerUrl.replace('/.well-known/openid-configuration', '')
|
||||
: issuerUrl;
|
||||
userFriendlyError = `Invalid OIDC discovery document. The server at '${baseUrl}/.well-known/openid-configuration' returned a response that doesn't conform to the OpenID Connect Discovery specification. Please verify the endpoint returns valid OIDC metadata.`;
|
||||
details = { type: 'INVALID_OIDC_DOCUMENT', originalError: errorMessage };
|
||||
}
|
||||
|
||||
return { userFriendlyError, details };
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse OIDC discovery errors and return user-friendly error messages
|
||||
*/
|
||||
static parseDiscoveryError(error: unknown, issuerUrl?: string): OidcErrorDetails {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
|
||||
// Log additional error details for debugging
|
||||
if (error instanceof Error) {
|
||||
this.logger.log(`Error type: ${error.constructor.name}`);
|
||||
if ('stack' in error && error.stack) {
|
||||
this.logger.debug(`Stack trace: ${error.stack}`);
|
||||
}
|
||||
if ('response' in error) {
|
||||
const response = (error as any).response;
|
||||
if (response) {
|
||||
this.logger.log(`Response status: ${response.status}`);
|
||||
this.logger.log(`Response body: ${response.body}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for fetch-specific errors first
|
||||
if (errorMessage.includes('fetch failed')) {
|
||||
return this.parseFetchError(error, issuerUrl);
|
||||
}
|
||||
|
||||
// Check for HTTP status errors
|
||||
const httpError = this.parseHttpError(errorMessage, issuerUrl);
|
||||
// Proper type-narrowing guard for accessing details.type
|
||||
if (
|
||||
httpError.details &&
|
||||
typeof httpError.details === 'object' &&
|
||||
'type' in httpError.details &&
|
||||
httpError.details.type !== undefined
|
||||
) {
|
||||
return httpError;
|
||||
}
|
||||
|
||||
// Fall back to generic error parsing
|
||||
return this.parseGenericError(error, issuerUrl);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,228 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import type { FastifyRequest } from '@app/unraid-api/types/fastify.js';
|
||||
import { OidcRequestHandler } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-request-handler.util.js';
|
||||
|
||||
describe('OidcRequestHandler', () => {
|
||||
let mockLogger: Logger;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockLogger = {
|
||||
debug: vi.fn(),
|
||||
log: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
} as any;
|
||||
});
|
||||
|
||||
describe('extractRequestInfo', () => {
|
||||
it('should extract request info from headers', () => {
|
||||
const mockReq = {
|
||||
headers: {
|
||||
'x-forwarded-proto': 'https',
|
||||
'x-forwarded-host': 'example.com:8443',
|
||||
},
|
||||
protocol: 'http',
|
||||
url: '/callback?code=123&state=456',
|
||||
} as unknown as FastifyRequest;
|
||||
|
||||
const result = OidcRequestHandler.extractRequestInfo(mockReq);
|
||||
|
||||
expect(result.protocol).toBe('https');
|
||||
expect(result.host).toBe('example.com:8443');
|
||||
expect(result.fullUrl).toBe('https://example.com:8443/callback?code=123&state=456');
|
||||
expect(result.baseUrl).toBe('https://example.com:8443');
|
||||
});
|
||||
|
||||
it('should fall back to request properties when headers are missing', () => {
|
||||
const mockReq = {
|
||||
headers: {
|
||||
host: 'localhost:3000',
|
||||
},
|
||||
protocol: 'http',
|
||||
url: '/callback?code=123&state=456',
|
||||
} as FastifyRequest;
|
||||
|
||||
const result = OidcRequestHandler.extractRequestInfo(mockReq);
|
||||
|
||||
expect(result.protocol).toBe('http');
|
||||
expect(result.host).toBe('localhost:3000');
|
||||
expect(result.fullUrl).toBe('http://localhost:3000/callback?code=123&state=456');
|
||||
expect(result.baseUrl).toBe('http://localhost:3000');
|
||||
});
|
||||
|
||||
it('should use defaults when all headers are missing', () => {
|
||||
const mockReq = {
|
||||
headers: {},
|
||||
url: '/callback?code=123&state=456',
|
||||
} as FastifyRequest;
|
||||
|
||||
const result = OidcRequestHandler.extractRequestInfo(mockReq);
|
||||
|
||||
expect(result.protocol).toBe('http');
|
||||
expect(result.host).toBe('localhost:3000');
|
||||
expect(result.fullUrl).toBe('http://localhost:3000/callback?code=123&state=456');
|
||||
expect(result.baseUrl).toBe('http://localhost:3000');
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateAuthorizeParams', () => {
|
||||
it('should validate valid parameters', () => {
|
||||
const result = OidcRequestHandler.validateAuthorizeParams(
|
||||
'provider123',
|
||||
'state456',
|
||||
'https://example.com/callback'
|
||||
);
|
||||
|
||||
expect(result.providerId).toBe('provider123');
|
||||
expect(result.state).toBe('state456');
|
||||
expect(result.redirectUri).toBe('https://example.com/callback');
|
||||
});
|
||||
|
||||
it('should throw error for missing provider ID', () => {
|
||||
expect(() => {
|
||||
OidcRequestHandler.validateAuthorizeParams(
|
||||
undefined,
|
||||
'state456',
|
||||
'https://example.com/callback'
|
||||
);
|
||||
}).toThrow('Provider ID is required');
|
||||
});
|
||||
|
||||
it('should throw error for missing state', () => {
|
||||
expect(() => {
|
||||
OidcRequestHandler.validateAuthorizeParams(
|
||||
'provider123',
|
||||
undefined,
|
||||
'https://example.com/callback'
|
||||
);
|
||||
}).toThrow('State parameter is required');
|
||||
});
|
||||
|
||||
it('should throw error for missing redirect URI', () => {
|
||||
expect(() => {
|
||||
OidcRequestHandler.validateAuthorizeParams('provider123', 'state456', undefined);
|
||||
}).toThrow('Redirect URI is required');
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateCallbackParams', () => {
|
||||
it('should validate valid parameters', () => {
|
||||
const result = OidcRequestHandler.validateCallbackParams('code123', 'state456');
|
||||
|
||||
expect(result.code).toBe('code123');
|
||||
expect(result.state).toBe('state456');
|
||||
});
|
||||
|
||||
it('should throw error for missing code', () => {
|
||||
expect(() => {
|
||||
OidcRequestHandler.validateCallbackParams(undefined, 'state456');
|
||||
}).toThrow('Missing required parameters');
|
||||
});
|
||||
|
||||
it('should throw error for missing state', () => {
|
||||
expect(() => {
|
||||
OidcRequestHandler.validateCallbackParams('code123', undefined);
|
||||
}).toThrow('Missing required parameters');
|
||||
});
|
||||
|
||||
it('should throw error for empty code', () => {
|
||||
expect(() => {
|
||||
OidcRequestHandler.validateCallbackParams('', 'state456');
|
||||
}).toThrow('Missing required parameters');
|
||||
});
|
||||
|
||||
it('should throw error for empty state', () => {
|
||||
expect(() => {
|
||||
OidcRequestHandler.validateCallbackParams('code123', '');
|
||||
}).toThrow('Missing required parameters');
|
||||
});
|
||||
});
|
||||
|
||||
describe('handleAuthorize', () => {
|
||||
it('should handle authorization flow', async () => {
|
||||
const mockAuthService = {
|
||||
getAuthorizationUrl: vi
|
||||
.fn()
|
||||
.mockResolvedValue('https://provider.com/auth?client_id=123'),
|
||||
};
|
||||
|
||||
const mockReq = {
|
||||
headers: { 'x-forwarded-proto': 'https', 'x-forwarded-host': 'example.com' },
|
||||
url: '/authorize',
|
||||
} as unknown as FastifyRequest;
|
||||
|
||||
const authUrl = await OidcRequestHandler.handleAuthorize(
|
||||
'provider123',
|
||||
'state456',
|
||||
'https://example.com/callback',
|
||||
mockReq,
|
||||
mockAuthService as any,
|
||||
mockLogger
|
||||
);
|
||||
|
||||
expect(authUrl).toBe('https://provider.com/auth?client_id=123');
|
||||
expect(mockAuthService.getAuthorizationUrl).toHaveBeenCalledWith({
|
||||
providerId: 'provider123',
|
||||
state: 'state456',
|
||||
requestOrigin: 'https://example.com/callback',
|
||||
requestHeaders: {
|
||||
'x-forwarded-proto': 'https',
|
||||
'x-forwarded-host': 'example.com',
|
||||
},
|
||||
});
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith(
|
||||
'Authorization request - Provider: provider123'
|
||||
);
|
||||
expect(mockLogger.log).toHaveBeenCalledWith(
|
||||
'Redirecting to OIDC provider: https://provider.com/auth?client_id=123'
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('handleCallback', () => {
|
||||
it('should handle callback flow', async () => {
|
||||
const mockStateService = {
|
||||
extractProviderFromState: vi.fn().mockReturnValue('provider123'),
|
||||
};
|
||||
|
||||
const mockAuthService = {
|
||||
getStateService: vi.fn().mockReturnValue(mockStateService),
|
||||
handleCallback: vi.fn().mockResolvedValue('paddedToken123'),
|
||||
};
|
||||
|
||||
const mockReq: Pick<FastifyRequest, 'id' | 'headers' | 'url'> = {
|
||||
id: '123',
|
||||
headers: { 'x-forwarded-proto': 'https', 'x-forwarded-host': 'example.com' },
|
||||
url: '/callback?code=123&state=456',
|
||||
};
|
||||
|
||||
const result = await OidcRequestHandler.handleCallback(
|
||||
'code123',
|
||||
'state456',
|
||||
mockReq as unknown as FastifyRequest,
|
||||
mockAuthService as any,
|
||||
mockLogger
|
||||
);
|
||||
|
||||
expect(result.providerId).toBe('provider123');
|
||||
expect(result.paddedToken).toBe('paddedToken123');
|
||||
expect(result.requestInfo.fullUrl).toBe('https://example.com/callback?code=123&state=456');
|
||||
expect(mockAuthService.handleCallback).toHaveBeenCalledWith({
|
||||
providerId: 'provider123',
|
||||
code: 'code123',
|
||||
state: 'state456',
|
||||
requestOrigin: 'https://example.com',
|
||||
fullCallbackUrl: 'https://example.com/callback?code=123&state=456',
|
||||
requestHeaders: {
|
||||
'x-forwarded-proto': 'https',
|
||||
'x-forwarded-host': 'example.com',
|
||||
},
|
||||
});
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith('Callback request - Provider: provider123');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,155 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
|
||||
import type { FastifyRequest } from '@app/unraid-api/types/fastify.js';
|
||||
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
|
||||
import { OidcStateExtractor } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js';
|
||||
|
||||
export interface RequestInfo {
|
||||
protocol: string;
|
||||
host: string;
|
||||
fullUrl: string;
|
||||
baseUrl: string;
|
||||
}
|
||||
|
||||
export interface OidcFlowResult {
|
||||
providerId: string;
|
||||
requestInfo: RequestInfo;
|
||||
}
|
||||
|
||||
export interface OidcCallbackResult extends OidcFlowResult {
|
||||
paddedToken: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility class to handle common OIDC request processing logic
|
||||
* between authorize and callback endpoints
|
||||
*/
|
||||
export class OidcRequestHandler {
|
||||
/**
|
||||
* Extract request information from Fastify request headers
|
||||
*/
|
||||
static extractRequestInfo(req: FastifyRequest): RequestInfo {
|
||||
// Handle potentially comma-separated forwarded headers (take first value)
|
||||
const forwardedProto = String(req.headers['x-forwarded-proto'] || '')
|
||||
.split(',')[0]
|
||||
?.trim();
|
||||
const forwardedHost = String(req.headers['x-forwarded-host'] || '')
|
||||
.split(',')[0]
|
||||
?.trim();
|
||||
|
||||
const protocol = forwardedProto || req.protocol || 'http';
|
||||
const host = forwardedHost || req.headers.host || 'localhost:3000';
|
||||
const fullUrl = `${protocol}://${host}${req.url}`;
|
||||
const baseUrl = `${protocol}://${host}`;
|
||||
|
||||
return {
|
||||
protocol,
|
||||
host,
|
||||
fullUrl,
|
||||
baseUrl,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle OIDC authorization flow
|
||||
*/
|
||||
static async handleAuthorize(
|
||||
providerId: string,
|
||||
state: string,
|
||||
redirectUri: string,
|
||||
req: FastifyRequest,
|
||||
oidcService: OidcService,
|
||||
logger: Logger
|
||||
): Promise<string> {
|
||||
const requestInfo = this.extractRequestInfo(req);
|
||||
|
||||
logger.debug(`Authorization request - Provider: ${providerId}`);
|
||||
logger.debug(`Authorization request - Full URL: ${requestInfo.fullUrl}`);
|
||||
logger.debug(`Authorization request - Redirect URI: ${redirectUri}`);
|
||||
|
||||
// Get authorization URL using the validated redirect URI and request headers
|
||||
const authUrl = await oidcService.getAuthorizationUrl({
|
||||
providerId,
|
||||
state,
|
||||
requestOrigin: redirectUri,
|
||||
requestHeaders: req.headers as Record<string, string | string[] | undefined>,
|
||||
});
|
||||
|
||||
logger.log(`Redirecting to OIDC provider: ${authUrl}`);
|
||||
return authUrl;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle OIDC callback flow
|
||||
*/
|
||||
static async handleCallback(
|
||||
code: string,
|
||||
state: string,
|
||||
req: FastifyRequest,
|
||||
oidcService: OidcService,
|
||||
logger: Logger
|
||||
): Promise<OidcCallbackResult> {
|
||||
// Extract provider ID from state for routing
|
||||
const { providerId } = OidcStateExtractor.extractProviderFromState(
|
||||
state,
|
||||
oidcService.getStateService()
|
||||
);
|
||||
|
||||
const requestInfo = this.extractRequestInfo(req);
|
||||
|
||||
logger.debug(`Callback request - Provider: ${providerId}`);
|
||||
logger.debug(`Callback request - Full URL: ${requestInfo.fullUrl}`);
|
||||
logger.debug(`Redirect URI will be retrieved from encrypted state`);
|
||||
|
||||
// Handle the callback using stored redirect URI from state and request headers
|
||||
const paddedToken = await oidcService.handleCallback({
|
||||
providerId,
|
||||
code,
|
||||
state,
|
||||
requestOrigin: requestInfo.baseUrl,
|
||||
fullCallbackUrl: requestInfo.fullUrl,
|
||||
requestHeaders: req.headers as Record<string, string | string[] | undefined>,
|
||||
});
|
||||
|
||||
return {
|
||||
providerId,
|
||||
requestInfo,
|
||||
paddedToken,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate required parameters for authorization flow
|
||||
*/
|
||||
static validateAuthorizeParams(
|
||||
providerId: string | undefined,
|
||||
state: string | undefined,
|
||||
redirectUri: string | undefined
|
||||
): { providerId: string; state: string; redirectUri: string } {
|
||||
if (!providerId) {
|
||||
throw new Error('Provider ID is required');
|
||||
}
|
||||
if (!state) {
|
||||
throw new Error('State parameter is required');
|
||||
}
|
||||
if (!redirectUri) {
|
||||
throw new Error('Redirect URI is required');
|
||||
}
|
||||
|
||||
return { providerId, state, redirectUri };
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate required parameters for callback flow
|
||||
*/
|
||||
static validateCallbackParams(
|
||||
code: string | undefined,
|
||||
state: string | undefined
|
||||
): { code: string; state: string } {
|
||||
if (!code || !state) {
|
||||
throw new Error('Missing required parameters');
|
||||
}
|
||||
|
||||
return { code, state };
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,205 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { OidcUrlPatterns } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-url-patterns.util.js';
|
||||
|
||||
describe('OidcUrlPatterns', () => {
|
||||
describe('ISSUER_URL_PATTERN', () => {
|
||||
it('should be defined as a string', () => {
|
||||
expect(typeof OidcUrlPatterns.ISSUER_URL_PATTERN).toBe('string');
|
||||
expect(OidcUrlPatterns.ISSUER_URL_PATTERN).toBe('^https?://[^/\\s]+(?:/[^/\\s]*)*[^/\\s]$');
|
||||
});
|
||||
});
|
||||
|
||||
describe('ISSUER_URL_REGEX', () => {
|
||||
it('should be a RegExp instance', () => {
|
||||
expect(OidcUrlPatterns.ISSUER_URL_REGEX).toBeInstanceOf(RegExp);
|
||||
});
|
||||
|
||||
it('should match the pattern string', () => {
|
||||
const regex = new RegExp(OidcUrlPatterns.ISSUER_URL_PATTERN);
|
||||
expect(OidcUrlPatterns.ISSUER_URL_REGEX.source).toBe(regex.source);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isValidIssuerUrl', () => {
|
||||
it('should accept valid URLs without trailing slash', () => {
|
||||
const validUrls = [
|
||||
'https://accounts.google.com',
|
||||
'https://auth.example.com/oidc',
|
||||
'https://auth.example.com/realms/master',
|
||||
'http://localhost:8080',
|
||||
'http://localhost:8080/auth',
|
||||
'https://login.microsoftonline.com/common/v2.0',
|
||||
];
|
||||
|
||||
validUrls.forEach((url) => {
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl(url)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
it('should reject URLs with trailing slashes', () => {
|
||||
const invalidUrls = [
|
||||
'https://accounts.google.com/',
|
||||
'https://auth.example.com/oidc/',
|
||||
'https://auth.example.com/realms/master/',
|
||||
'http://localhost:8080/',
|
||||
'http://localhost:8080/auth/',
|
||||
'https://login.microsoftonline.com/common/v2.0/',
|
||||
];
|
||||
|
||||
invalidUrls.forEach((url) => {
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl(url)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
it('should reject URLs with whitespace', () => {
|
||||
const invalidUrls = [
|
||||
'https://accounts.google.com ',
|
||||
' https://accounts.google.com',
|
||||
'https://accounts. google.com',
|
||||
'https://accounts.google.com\t',
|
||||
'https://accounts.google.com\n',
|
||||
];
|
||||
|
||||
invalidUrls.forEach((url) => {
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl(url)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
it('should accept both HTTP and HTTPS protocols', () => {
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl('https://example.com')).toBe(true);
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl('http://example.com')).toBe(true);
|
||||
});
|
||||
|
||||
it('should reject other protocols', () => {
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl('ftp://example.com')).toBe(false);
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl('ws://example.com')).toBe(false);
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl('file://example.com')).toBe(false);
|
||||
});
|
||||
|
||||
it('should accept .well-known URLs without trailing slashes', () => {
|
||||
const wellKnownUrls = [
|
||||
'https://example.com/.well-known/openid-configuration',
|
||||
'https://auth.example.com/path/.well-known/openid-configuration',
|
||||
'https://example.com/.well-known/jwks.json',
|
||||
'https://keycloak.example.com/realms/master/.well-known/openid-configuration',
|
||||
];
|
||||
|
||||
wellKnownUrls.forEach((url) => {
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl(url)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
it('should reject .well-known URLs with trailing slashes', () => {
|
||||
const invalidWellKnownUrls = [
|
||||
'https://example.com/.well-known/openid-configuration/',
|
||||
'https://auth.example.com/path/.well-known/openid-configuration/',
|
||||
'https://example.com/.well-known/jwks.json/',
|
||||
'https://keycloak.example.com/realms/master/.well-known/openid-configuration/',
|
||||
];
|
||||
|
||||
invalidWellKnownUrls.forEach((url) => {
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl(url)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle complex real-world scenarios', () => {
|
||||
// Google
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl('https://accounts.google.com')).toBe(true);
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl('https://accounts.google.com/')).toBe(false);
|
||||
|
||||
// Microsoft
|
||||
expect(
|
||||
OidcUrlPatterns.isValidIssuerUrl('https://login.microsoftonline.com/tenant-id/v2.0')
|
||||
).toBe(true);
|
||||
expect(
|
||||
OidcUrlPatterns.isValidIssuerUrl('https://login.microsoftonline.com/tenant-id/v2.0/')
|
||||
).toBe(false);
|
||||
|
||||
// Auth0
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl('https://tenant.auth0.com')).toBe(true);
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl('https://tenant.auth0.com/')).toBe(false);
|
||||
|
||||
// Keycloak
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl('https://keycloak.example.com/realms/master')).toBe(
|
||||
true
|
||||
);
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl('https://keycloak.example.com/realms/master/')).toBe(
|
||||
false
|
||||
);
|
||||
|
||||
// AWS Cognito
|
||||
expect(
|
||||
OidcUrlPatterns.isValidIssuerUrl(
|
||||
'https://cognito-idp.us-west-2.amazonaws.com/us-west-2_example'
|
||||
)
|
||||
).toBe(true);
|
||||
expect(
|
||||
OidcUrlPatterns.isValidIssuerUrl(
|
||||
'https://cognito-idp.us-west-2.amazonaws.com/us-west-2_example/'
|
||||
)
|
||||
).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getExamples', () => {
|
||||
it('should return valid and invalid URL examples', () => {
|
||||
const examples = OidcUrlPatterns.getExamples();
|
||||
|
||||
expect(examples).toHaveProperty('valid');
|
||||
expect(examples).toHaveProperty('invalid');
|
||||
expect(Array.isArray(examples.valid)).toBe(true);
|
||||
expect(Array.isArray(examples.invalid)).toBe(true);
|
||||
expect(examples.valid.length).toBeGreaterThan(0);
|
||||
expect(examples.invalid.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('should have all valid examples pass validation', () => {
|
||||
const examples = OidcUrlPatterns.getExamples();
|
||||
|
||||
examples.valid.forEach((url) => {
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl(url)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
it('should have all invalid examples fail validation', () => {
|
||||
const examples = OidcUrlPatterns.getExamples();
|
||||
|
||||
examples.invalid.forEach((url) => {
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl(url)).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('integration with the bug report scenario', () => {
|
||||
it('should specifically catch the Google trailing slash issue from the bug report', () => {
|
||||
// The exact scenario from the bug report
|
||||
const problematicUrl = 'https://accounts.google.com/';
|
||||
const correctUrl = 'https://accounts.google.com';
|
||||
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl(problematicUrl)).toBe(false);
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl(correctUrl)).toBe(true);
|
||||
});
|
||||
|
||||
it('should prevent the double slash in discovery URL construction', () => {
|
||||
// Simulate what would happen in discovery URL construction
|
||||
const issuerWithSlash = 'https://accounts.google.com/';
|
||||
const issuerWithoutSlash = 'https://accounts.google.com';
|
||||
|
||||
// This is what would happen in the discovery process
|
||||
const discoveryWithSlash = `${issuerWithSlash}/.well-known/openid-configuration`;
|
||||
const discoveryWithoutSlash = `${issuerWithoutSlash}/.well-known/openid-configuration`;
|
||||
|
||||
expect(discoveryWithSlash).toBe(
|
||||
'https://accounts.google.com//.well-known/openid-configuration'
|
||||
); // Double slash - bad
|
||||
expect(discoveryWithoutSlash).toBe(
|
||||
'https://accounts.google.com/.well-known/openid-configuration'
|
||||
); // Single slash - good
|
||||
|
||||
// Our validation should prevent the first scenario
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl(issuerWithSlash)).toBe(false);
|
||||
expect(OidcUrlPatterns.isValidIssuerUrl(issuerWithoutSlash)).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,59 @@
|
||||
/**
|
||||
* Utility for OIDC URL validation patterns
|
||||
*/
|
||||
export class OidcUrlPatterns {
|
||||
/**
|
||||
* Regex pattern for validating OIDC issuer URLs
|
||||
* - Allows HTTP and HTTPS protocols
|
||||
* - Prevents trailing slashes
|
||||
* - Prevents whitespace
|
||||
* - Allows paths but not ending with slash
|
||||
*/
|
||||
static readonly ISSUER_URL_PATTERN = '^https?://[^/\\s]+(?:/[^/\\s]*)*[^/\\s]$';
|
||||
|
||||
/**
|
||||
* Compiled regex for issuer URL validation
|
||||
*/
|
||||
static readonly ISSUER_URL_REGEX = new RegExp(OidcUrlPatterns.ISSUER_URL_PATTERN);
|
||||
|
||||
/**
|
||||
* Validate an issuer URL against the pattern
|
||||
* @param url The URL to validate
|
||||
* @returns True if the URL is valid, false otherwise
|
||||
*/
|
||||
static isValidIssuerUrl(url: string): boolean {
|
||||
return this.ISSUER_URL_REGEX.test(url);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get examples of valid and invalid issuer URLs for documentation/testing
|
||||
*/
|
||||
static getExamples() {
|
||||
return {
|
||||
valid: [
|
||||
// Standard issuer URLs (most common)
|
||||
'https://accounts.google.com',
|
||||
'https://auth.example.com/oidc',
|
||||
'https://auth.example.com/realms/master',
|
||||
'http://localhost:8080',
|
||||
'http://localhost:8080/auth',
|
||||
'https://login.microsoftonline.com/common/v2.0',
|
||||
'https://cognito-idp.us-west-2.amazonaws.com/us-west-2_example',
|
||||
// Well-known URLs are valid at the URL pattern level (schema-level validation handles rejection)
|
||||
'https://example.com/.well-known/openid-configuration',
|
||||
'https://auth.example.com/path/.well-known/openid-configuration',
|
||||
'https://example.com/.well-known/jwks.json',
|
||||
],
|
||||
invalid: [
|
||||
'https://accounts.google.com/', // Trailing slash
|
||||
'https://auth.example.com/oidc/', // Trailing slash
|
||||
'https://auth.example.com/realms/master/', // Trailing slash
|
||||
'http://localhost:8080/', // Trailing slash
|
||||
'https://accounts.google.com ', // Trailing whitespace
|
||||
' https://accounts.google.com', // Leading whitespace
|
||||
'https://accounts. google.com', // Internal whitespace
|
||||
'ftp://example.com', // Invalid protocol
|
||||
],
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -2,12 +2,12 @@ import { Module } from '@nestjs/common';
|
||||
import { ScheduleModule } from '@nestjs/schedule';
|
||||
|
||||
import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js';
|
||||
import { SubscriptionPollingService } from '@app/unraid-api/graph/services/subscription-polling.service.js';
|
||||
import { SubscriptionManagerService } from '@app/unraid-api/graph/services/subscription-manager.service.js';
|
||||
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
|
||||
|
||||
@Module({
|
||||
imports: [],
|
||||
providers: [SubscriptionTrackerService, SubscriptionHelperService, SubscriptionPollingService],
|
||||
exports: [SubscriptionTrackerService, SubscriptionHelperService, SubscriptionPollingService],
|
||||
providers: [SubscriptionTrackerService, SubscriptionHelperService, SubscriptionManagerService],
|
||||
exports: [SubscriptionTrackerService, SubscriptionHelperService], // SubscriptionManagerService is internal
|
||||
})
|
||||
export class ServicesModule {}
|
||||
|
||||
@@ -4,7 +4,25 @@ import { createSubscription, PUBSUB_CHANNEL } from '@app/core/pubsub.js';
|
||||
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
|
||||
|
||||
/**
|
||||
* Helper service for creating tracked GraphQL subscriptions with automatic cleanup
|
||||
* High-level helper service for creating GraphQL subscriptions with automatic cleanup.
|
||||
*
|
||||
* This service provides a convenient way to create GraphQL subscriptions that:
|
||||
* - Automatically track subscriber count via SubscriptionTrackerService
|
||||
* - Properly clean up resources when subscriptions end
|
||||
* - Handle errors gracefully
|
||||
*
|
||||
* **When to use this service:**
|
||||
* - In GraphQL resolvers when implementing subscriptions
|
||||
* - When you need automatic reference counting for shared resources
|
||||
* - When you want to ensure proper cleanup on subscription termination
|
||||
*
|
||||
* @example
|
||||
* // In a GraphQL resolver
|
||||
* \@Subscription(() => MetricsUpdate)
|
||||
* async metricsSubscription() {
|
||||
* // Topic must be registered first via SubscriptionTrackerService
|
||||
* return this.subscriptionHelper.createTrackedSubscription(PUBSUB_CHANNEL.METRICS);
|
||||
* }
|
||||
*/
|
||||
@Injectable()
|
||||
export class SubscriptionHelperService {
|
||||
@@ -15,7 +33,7 @@ export class SubscriptionHelperService {
|
||||
* @param topic The subscription topic/channel to subscribe to
|
||||
* @returns A proxy async iterator with automatic cleanup
|
||||
*/
|
||||
public createTrackedSubscription<T = any>(topic: PUBSUB_CHANNEL): AsyncIterableIterator<T> {
|
||||
public createTrackedSubscription<T = any>(topic: PUBSUB_CHANNEL | string): AsyncIterableIterator<T> {
|
||||
const innerIterator = createSubscription<T>(topic);
|
||||
|
||||
// Subscribe when the subscription starts
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
import { Injectable, Logger, OnModuleDestroy } from '@nestjs/common';
|
||||
import { SchedulerRegistry } from '@nestjs/schedule';
|
||||
|
||||
/**
|
||||
* Configuration for managed subscriptions
|
||||
*/
|
||||
export interface SubscriptionConfig {
|
||||
/** Unique identifier for the subscription */
|
||||
name: string;
|
||||
|
||||
/**
|
||||
* Polling interval in milliseconds.
|
||||
* - If set to a number, the callback will be called at that interval
|
||||
* - If null/undefined, the subscription is event-based (no polling)
|
||||
*/
|
||||
intervalMs?: number | null;
|
||||
|
||||
/** Function to call periodically (for polling) or once (for setup) */
|
||||
callback: () => Promise<void>;
|
||||
|
||||
/** Optional function called when the subscription starts */
|
||||
onStart?: () => Promise<void>;
|
||||
|
||||
/** Optional function called when the subscription stops */
|
||||
onStop?: () => Promise<void>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Low-level service for managing both polling and event-based subscriptions.
|
||||
*
|
||||
* ⚠️ **IMPORTANT**: This is an internal service. Do not use directly in resolvers or business logic.
|
||||
* Instead, use one of the higher-level services:
|
||||
* - **SubscriptionTrackerService**: For subscriptions that need reference counting
|
||||
* - **SubscriptionHelperService**: For GraphQL subscriptions with automatic cleanup
|
||||
*
|
||||
* This service provides the underlying implementation for:
|
||||
* - **Polling subscriptions**: Execute a callback at regular intervals
|
||||
* - **Event-based subscriptions**: Set up event listeners or watchers that persist until stopped
|
||||
*
|
||||
* @internal
|
||||
*/
|
||||
@Injectable()
|
||||
export class SubscriptionManagerService implements OnModuleDestroy {
|
||||
private readonly logger = new Logger(SubscriptionManagerService.name);
|
||||
private readonly activeSubscriptions = new Map<
|
||||
string,
|
||||
{ isPolling: boolean; config?: SubscriptionConfig }
|
||||
>();
|
||||
|
||||
constructor(private readonly schedulerRegistry: SchedulerRegistry) {}
|
||||
|
||||
async onModuleDestroy() {
|
||||
await this.stopAll();
|
||||
}
|
||||
|
||||
/**
|
||||
* Start a managed subscription (polling or event-based).
|
||||
*
|
||||
* @param config - The subscription configuration
|
||||
* @throws Will throw an error if the onStart callback fails
|
||||
*/
|
||||
async startSubscription(config: SubscriptionConfig): Promise<void> {
|
||||
const { name, intervalMs, callback, onStart } = config;
|
||||
|
||||
// Clean up any existing subscription with the same name
|
||||
await this.stopSubscription(name);
|
||||
|
||||
// Initialize subscription state with config
|
||||
this.activeSubscriptions.set(name, { isPolling: false, config });
|
||||
|
||||
// Call onStart callback if provided
|
||||
if (onStart) {
|
||||
try {
|
||||
await onStart();
|
||||
this.logger.debug(`Called onStart for '${name}'`);
|
||||
} catch (error) {
|
||||
this.logger.error(`Error in onStart for '${name}'`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
// If intervalMs is null, this is a continuous/event-based subscription
|
||||
if (intervalMs === null || intervalMs === undefined) {
|
||||
this.logger.debug(`Started continuous subscription for '${name}' (no polling)`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Create the polling function with guard against overlapping executions
|
||||
const pollFunction = async () => {
|
||||
const subscription = this.activeSubscriptions.get(name);
|
||||
if (!subscription || subscription.isPolling) {
|
||||
return;
|
||||
}
|
||||
|
||||
subscription.isPolling = true;
|
||||
try {
|
||||
await callback();
|
||||
} catch (error) {
|
||||
this.logger.error(`Error in subscription callback '${name}'`, error);
|
||||
} finally {
|
||||
if (subscription) {
|
||||
subscription.isPolling = false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Create and register the interval
|
||||
const interval = setInterval(pollFunction, intervalMs);
|
||||
this.schedulerRegistry.addInterval(name, interval);
|
||||
|
||||
this.logger.debug(`Started polling for '${name}' every ${intervalMs}ms`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop a managed subscription.
|
||||
*
|
||||
* This will:
|
||||
* 1. Stop any active polling interval
|
||||
* 2. Call the onStop callback if provided
|
||||
* 3. Clean up internal state
|
||||
*
|
||||
* @param name - The unique identifier of the subscription to stop
|
||||
*/
|
||||
async stopSubscription(name: string): Promise<void> {
|
||||
// Get the config before deleting
|
||||
const subscription = this.activeSubscriptions.get(name);
|
||||
const onStop = subscription?.config?.onStop;
|
||||
|
||||
try {
|
||||
if (this.schedulerRegistry.doesExist('interval', name)) {
|
||||
this.schedulerRegistry.deleteInterval(name);
|
||||
this.logger.debug(`Stopped polling interval for '${name}'`);
|
||||
}
|
||||
} catch (error) {
|
||||
// Interval doesn't exist, which is fine
|
||||
}
|
||||
|
||||
// Call onStop callback if provided
|
||||
if (onStop) {
|
||||
try {
|
||||
await onStop();
|
||||
this.logger.debug(`Called onStop for '${name}'`);
|
||||
} catch (error) {
|
||||
this.logger.error(`Error in onStop for '${name}'`, error);
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up subscription state
|
||||
this.activeSubscriptions.delete(name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop all active subscriptions.
|
||||
*
|
||||
* This is automatically called when the module is destroyed.
|
||||
*/
|
||||
async stopAll(): Promise<void> {
|
||||
// Get all active subscription keys (both polling and event-based)
|
||||
const activeKeys = Array.from(this.activeSubscriptions.keys());
|
||||
|
||||
// Stop each subscription and await cleanup
|
||||
await Promise.all(activeKeys.map((key) => this.stopSubscription(key)));
|
||||
|
||||
// Clear the map after all subscriptions are stopped
|
||||
this.activeSubscriptions.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a subscription is active.
|
||||
*
|
||||
* @param name - The unique identifier of the subscription
|
||||
* @returns true if the subscription exists (either polling or event-based)
|
||||
*/
|
||||
isSubscriptionActive(name: string): boolean {
|
||||
// Check both for polling intervals and event-based subscriptions
|
||||
return this.activeSubscriptions.has(name) || this.schedulerRegistry.doesExist('interval', name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the total number of active subscriptions.
|
||||
*
|
||||
* @returns The count of all active subscriptions (polling and event-based)
|
||||
*/
|
||||
getActiveSubscriptionCount(): number {
|
||||
return this.activeSubscriptions.size;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a list of all active subscription names.
|
||||
*
|
||||
* @returns Array of subscription identifiers
|
||||
*/
|
||||
getActiveSubscriptionNames(): string[] {
|
||||
return Array.from(this.activeSubscriptions.keys());
|
||||
}
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
import { Injectable, Logger, OnModuleDestroy } from '@nestjs/common';
|
||||
import { SchedulerRegistry } from '@nestjs/schedule';
|
||||
|
||||
export interface PollingConfig {
|
||||
name: string;
|
||||
intervalMs: number;
|
||||
callback: () => Promise<void>;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class SubscriptionPollingService implements OnModuleDestroy {
|
||||
private readonly logger = new Logger(SubscriptionPollingService.name);
|
||||
private readonly activePollers = new Map<string, { isPolling: boolean }>();
|
||||
|
||||
constructor(private readonly schedulerRegistry: SchedulerRegistry) {}
|
||||
|
||||
onModuleDestroy() {
|
||||
this.stopAll();
|
||||
}
|
||||
|
||||
/**
|
||||
* Start polling for a specific subscription topic
|
||||
*/
|
||||
startPolling(config: PollingConfig): void {
|
||||
const { name, intervalMs, callback } = config;
|
||||
|
||||
// Clean up any existing interval
|
||||
this.stopPolling(name);
|
||||
|
||||
// Initialize polling state
|
||||
this.activePollers.set(name, { isPolling: false });
|
||||
|
||||
// Create the polling function with guard against overlapping executions
|
||||
const pollFunction = async () => {
|
||||
const poller = this.activePollers.get(name);
|
||||
if (!poller || poller.isPolling) {
|
||||
return;
|
||||
}
|
||||
|
||||
poller.isPolling = true;
|
||||
try {
|
||||
await callback();
|
||||
} catch (error) {
|
||||
this.logger.error(`Error in polling task '${name}'`, error);
|
||||
} finally {
|
||||
if (poller) {
|
||||
poller.isPolling = false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Create and register the interval
|
||||
const interval = setInterval(pollFunction, intervalMs);
|
||||
this.schedulerRegistry.addInterval(name, interval);
|
||||
|
||||
this.logger.debug(`Started polling for '${name}' every ${intervalMs}ms`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop polling for a specific subscription topic
|
||||
*/
|
||||
stopPolling(name: string): void {
|
||||
try {
|
||||
if (this.schedulerRegistry.doesExist('interval', name)) {
|
||||
this.schedulerRegistry.deleteInterval(name);
|
||||
this.logger.debug(`Stopped polling for '${name}'`);
|
||||
}
|
||||
} catch (error) {
|
||||
// Interval doesn't exist, which is fine
|
||||
}
|
||||
|
||||
// Clean up polling state
|
||||
this.activePollers.delete(name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop all active polling tasks
|
||||
*/
|
||||
stopAll(): void {
|
||||
const intervals = this.schedulerRegistry.getIntervals();
|
||||
intervals.forEach((key) => this.stopPolling(key));
|
||||
this.activePollers.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if polling is active for a given name
|
||||
*/
|
||||
isPolling(name: string): boolean {
|
||||
return this.schedulerRegistry.doesExist('interval', name);
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,44 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import { SubscriptionPollingService } from '@app/unraid-api/graph/services/subscription-polling.service.js';
|
||||
import { SubscriptionManagerService } from '@app/unraid-api/graph/services/subscription-manager.service.js';
|
||||
|
||||
/**
|
||||
* Service for managing subscriptions with automatic reference counting.
|
||||
*
|
||||
* This service tracks the number of active subscribers for each topic and automatically
|
||||
* starts/stops the underlying subscription based on subscriber count.
|
||||
*
|
||||
* **When to use this service:**
|
||||
* - When you have multiple GraphQL subscriptions that share the same data source
|
||||
* - When you need to start a resource (polling, file watcher, etc.) only when there are active subscribers
|
||||
* - When you need automatic cleanup when the last subscriber disconnects
|
||||
*
|
||||
* @example
|
||||
* // Register a polling subscription
|
||||
* subscriptionTracker.registerTopic(
|
||||
* 'metrics-update',
|
||||
* async () => {
|
||||
* const metrics = await fetchMetrics();
|
||||
* pubsub.publish('metrics-update', { metrics });
|
||||
* },
|
||||
* 5000 // Poll every 5 seconds
|
||||
* );
|
||||
*
|
||||
* @example
|
||||
* // Register an event-based subscription (e.g., file watching)
|
||||
* subscriptionTracker.registerTopic(
|
||||
* 'log-file-updates',
|
||||
* () => startFileWatcher('/var/log/app.log'), // onStart
|
||||
* () => stopFileWatcher('/var/log/app.log') // onStop
|
||||
* );
|
||||
*/
|
||||
@Injectable()
|
||||
export class SubscriptionTrackerService {
|
||||
private readonly logger = new Logger(SubscriptionTrackerService.name);
|
||||
private subscriberCounts = new Map<string, number>();
|
||||
private topicHandlers = new Map<string, { onStart: () => void; onStop: () => void }>();
|
||||
|
||||
constructor(private readonly pollingService: SubscriptionPollingService) {}
|
||||
constructor(private readonly subscriptionManager: SubscriptionManagerService) {}
|
||||
|
||||
/**
|
||||
* Register a topic with optional polling support
|
||||
@@ -29,8 +59,8 @@ export class SubscriptionTrackerService {
|
||||
callback: async () => callbackOrOnStart(),
|
||||
};
|
||||
this.topicHandlers.set(topic, {
|
||||
onStart: () => this.pollingService.startPolling(pollingConfig),
|
||||
onStop: () => this.pollingService.stopPolling(topic),
|
||||
onStart: () => this.subscriptionManager.startSubscription(pollingConfig),
|
||||
onStop: () => this.subscriptionManager.stopSubscription(topic),
|
||||
});
|
||||
} else {
|
||||
// Legacy API: onStart and onStop handlers
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { CacheModule } from '@nestjs/cache-manager';
|
||||
import { Test } from '@nestjs/testing';
|
||||
|
||||
import { CANONICAL_INTERNAL_CLIENT_TOKEN } from '@unraid/shared';
|
||||
@@ -60,7 +61,7 @@ vi.mock('execa', () => ({
|
||||
describe('RestModule Integration', () => {
|
||||
it('should compile with RestService having access to ApiReportService', async () => {
|
||||
const module = await Test.createTestingModule({
|
||||
imports: [RestModule],
|
||||
imports: [CacheModule.register({ isGlobal: true }), RestModule],
|
||||
})
|
||||
// Override services that have complex dependencies for testing
|
||||
.overrideProvider(CANONICAL_INTERNAL_CLIENT_TOKEN)
|
||||
|
||||
487
api/src/unraid-api/rest/rest.controller.test.ts
Normal file
487
api/src/unraid-api/rest/rest.controller.test.ts
Normal file
@@ -0,0 +1,487 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { ConfigService } from '@nestjs/config';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import type { FastifyReply, FastifyRequest } from '@app/unraid-api/types/fastify.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
|
||||
import { OidcRequestHandler } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-request-handler.util.js';
|
||||
import { RestController } from '@app/unraid-api/rest/rest.controller.js';
|
||||
import { RestService } from '@app/unraid-api/rest/rest.service.js';
|
||||
|
||||
describe('RestController', () => {
|
||||
let controller: RestController;
|
||||
let oidcService: OidcService;
|
||||
let oidcConfig: OidcConfigPersistence;
|
||||
let mockReply: Partial<FastifyReply>;
|
||||
|
||||
// Helper function to create a mock request with the desired hostname
|
||||
const createMockRequest = (hostname?: string, headers: Record<string, any> = {}): FastifyRequest => {
|
||||
return {
|
||||
headers,
|
||||
hostname,
|
||||
url: '/test',
|
||||
protocol: 'https',
|
||||
} as FastifyRequest;
|
||||
};
|
||||
|
||||
beforeEach(async () => {
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
controllers: [RestController],
|
||||
providers: [
|
||||
{
|
||||
provide: RestService,
|
||||
useValue: {
|
||||
getLogs: vi.fn(),
|
||||
getCustomizationStream: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcService,
|
||||
useValue: {
|
||||
getAuthorizationUrl: vi.fn(),
|
||||
handleCallback: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcConfigPersistence,
|
||||
useValue: {
|
||||
getConfig: vi.fn().mockResolvedValue({
|
||||
defaultAllowedOrigins: [],
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: ConfigService,
|
||||
useValue: {
|
||||
get: vi.fn(),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
controller = module.get<RestController>(RestController);
|
||||
oidcService = module.get<OidcService>(OidcService);
|
||||
oidcConfig = module.get<OidcConfigPersistence>(OidcConfigPersistence);
|
||||
|
||||
mockReply = {
|
||||
status: vi.fn().mockReturnThis(),
|
||||
header: vi.fn().mockReturnThis(),
|
||||
send: vi.fn().mockReturnThis(),
|
||||
type: vi.fn().mockReturnThis(),
|
||||
};
|
||||
});
|
||||
|
||||
describe('oidcAuthorize', () => {
|
||||
describe('redirect URI validation', () => {
|
||||
beforeEach(() => {
|
||||
// Mock OidcRequestHandler.handleAuthorize to return a valid auth URL
|
||||
vi.spyOn(OidcRequestHandler, 'handleAuthorize').mockResolvedValue(
|
||||
'https://provider.com/authorize?client_id=test&redirect_uri=...'
|
||||
);
|
||||
});
|
||||
|
||||
it('should accept redirect_uri with same hostname but different port', async () => {
|
||||
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(302);
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalledWith(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
oidcService,
|
||||
expect.any(Logger)
|
||||
);
|
||||
});
|
||||
|
||||
it('should accept redirect_uri with same hostname and standard HTTPS port', async () => {
|
||||
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://unraid.mytailnet.ts.net/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(302);
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should accept redirect_uri with same hostname and explicit port 443', async () => {
|
||||
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://unraid.mytailnet.ts.net:443/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(302);
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should reject redirect_uri with different hostname', async () => {
|
||||
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://evil.com/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(400);
|
||||
expect(mockReply.send).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
'Invalid redirect_uri: https://evil.com/graphql/api/auth/oidc/callback'
|
||||
)
|
||||
);
|
||||
expect(OidcRequestHandler.handleAuthorize).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should reject redirect_uri with subdomain difference', async () => {
|
||||
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://evil.unraid.mytailnet.ts.net/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(400);
|
||||
expect(mockReply.send).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
'Invalid redirect_uri: https://evil.unraid.mytailnet.ts.net/graphql/api/auth/oidc/callback'
|
||||
)
|
||||
);
|
||||
expect(OidcRequestHandler.handleAuthorize).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle hostname from host header when hostname is not available', async () => {
|
||||
const mockRequest = createMockRequest(undefined, {
|
||||
host: 'unraid.mytailnet.ts.net:8080',
|
||||
});
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(302);
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should reject malformed redirect_uri', async () => {
|
||||
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'not-a-valid-url',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(400);
|
||||
expect(mockReply.send).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Invalid redirect_uri: not-a-valid-url')
|
||||
);
|
||||
expect(OidcRequestHandler.handleAuthorize).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle case-insensitive hostname comparison', async () => {
|
||||
const mockRequest = createMockRequest('UnRaid.MyTailnet.TS.net');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(302);
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should preserve exact redirect_uri including custom port in call to handleAuthorize', async () => {
|
||||
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
|
||||
const customRedirectUri =
|
||||
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback';
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
customRedirectUri,
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
// Verify the exact redirect URI with port is passed through
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalledWith(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
customRedirectUri, // Should be exactly as provided, with :1443
|
||||
mockRequest,
|
||||
oidcService,
|
||||
expect.any(Logger)
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow localhost with different ports', async () => {
|
||||
const mockRequest = createMockRequest('localhost');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'http://localhost:3000/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(302);
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalledWith(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'http://localhost:3000/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
oidcService,
|
||||
expect.any(Logger)
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow IP addresses with different ports', async () => {
|
||||
const mockRequest = createMockRequest('192.168.1.100');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'http://192.168.1.100:8080/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(302);
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should accept redirect_uri with different hostname if in allowed origins', async () => {
|
||||
const mockRequest = createMockRequest('devgen-dev1.local');
|
||||
|
||||
// Mock the config to include the allowed origin
|
||||
vi.mocked(oidcConfig.getConfig).mockResolvedValueOnce({
|
||||
defaultAllowedOrigins: ['https://devgen-bad-dev1.local'],
|
||||
} as any);
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://devgen-bad-dev1.local/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(302);
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalledWith(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://devgen-bad-dev1.local/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
oidcService,
|
||||
expect.any(Logger)
|
||||
);
|
||||
});
|
||||
|
||||
describe('integration with centralized validator', () => {
|
||||
it('should use the same validation logic as validateRedirectUri function', async () => {
|
||||
const testCases = [
|
||||
{
|
||||
name: 'accepts HTTPS upgrade from allowed origins',
|
||||
requestHost: 'devgen-dev1.local',
|
||||
redirectUri: 'https://allowed-host.local/graphql/api/auth/oidc/callback',
|
||||
allowedOrigins: ['http://allowed-host.local'],
|
||||
expectedStatus: 302,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
{
|
||||
name: 'rejects hostname not in allowed origins',
|
||||
requestHost: 'devgen-dev1.local',
|
||||
redirectUri: 'https://evil.com/graphql/api/auth/oidc/callback',
|
||||
allowedOrigins: ['https://good-host.local'],
|
||||
expectedStatus: 400,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
{
|
||||
name: 'accepts multiple allowed origins',
|
||||
requestHost: 'devgen-dev1.local',
|
||||
redirectUri: 'https://second.local/graphql/api/auth/oidc/callback',
|
||||
allowedOrigins: [
|
||||
'https://first.local',
|
||||
'https://second.local',
|
||||
'https://third.local',
|
||||
],
|
||||
expectedStatus: 302,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
{
|
||||
name: 'respects protocol and hostname from headers',
|
||||
requestHost: undefined,
|
||||
headers: {
|
||||
'x-forwarded-proto': 'https',
|
||||
'x-forwarded-host': 'proxy.local',
|
||||
},
|
||||
redirectUri: 'https://proxy.local/graphql/api/auth/oidc/callback',
|
||||
allowedOrigins: [],
|
||||
expectedStatus: 302,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
];
|
||||
|
||||
for (const testCase of testCases) {
|
||||
// Reset mocks for each test case
|
||||
vi.clearAllMocks();
|
||||
|
||||
const mockRequest = createMockRequest(
|
||||
testCase.requestHost,
|
||||
testCase.headers || {}
|
||||
);
|
||||
|
||||
vi.mocked(oidcConfig.getConfig).mockResolvedValueOnce({
|
||||
defaultAllowedOrigins: testCase.allowedOrigins,
|
||||
} as any);
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
testCase.redirectUri,
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(testCase.expectedStatus);
|
||||
|
||||
if (testCase.shouldSucceed) {
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
|
||||
} else {
|
||||
expect(mockReply.send).toHaveBeenCalledWith(
|
||||
expect.stringContaining(testCase.redirectUri)
|
||||
);
|
||||
expect(OidcRequestHandler.handleAuthorize).not.toHaveBeenCalled();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
it('should handle edge cases consistently with centralized validator', async () => {
|
||||
// Test with empty allowed origins
|
||||
vi.mocked(oidcConfig.getConfig).mockResolvedValueOnce({
|
||||
defaultAllowedOrigins: [],
|
||||
} as any);
|
||||
|
||||
const mockRequest = createMockRequest('host.local');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://different.local/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(400);
|
||||
expect(mockReply.send).toHaveBeenCalledWith(
|
||||
expect.stringContaining('https://different.local/graphql/api/auth/oidc/callback')
|
||||
);
|
||||
});
|
||||
|
||||
it('should validate that error messages guide users to settings', async () => {
|
||||
vi.mocked(oidcConfig.getConfig).mockResolvedValueOnce({
|
||||
defaultAllowedOrigins: [],
|
||||
} as any);
|
||||
|
||||
const mockRequest = createMockRequest('host.local');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://different.local/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.send).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Settings → Management Access → Allowed Redirect URIs')
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('parameter validation', () => {
|
||||
it('should return 400 if redirect_uri is missing', async () => {
|
||||
const mockRequest = createMockRequest('unraid.local');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
undefined as any,
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(400);
|
||||
// The controller catches validation errors and returns a generic message
|
||||
expect(mockReply.send).toHaveBeenCalledWith('Invalid provider or configuration');
|
||||
});
|
||||
|
||||
it('should return 400 if providerId is missing', async () => {
|
||||
const mockRequest = createMockRequest('unraid.local');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
undefined as any,
|
||||
'test-state',
|
||||
'https://unraid.local/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(400);
|
||||
expect(mockReply.send).toHaveBeenCalledWith('Invalid provider or configuration');
|
||||
});
|
||||
|
||||
it('should return 400 if state is missing', async () => {
|
||||
const mockRequest = createMockRequest('unraid.local');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
undefined as any,
|
||||
'https://unraid.local/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(400);
|
||||
expect(mockReply.send).toHaveBeenCalledWith('Invalid provider or configuration');
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -2,18 +2,25 @@ import { Controller, Get, Logger, Param, Query, Req, Res, UnauthorizedException
|
||||
|
||||
import { AuthAction, Resource } from '@unraid/shared/graphql.model.js';
|
||||
import { UsePermissions } from '@unraid/shared/use-permissions.directive.js';
|
||||
import escapeHtml from 'escape-html';
|
||||
|
||||
import type { FastifyReply, FastifyRequest } from '@app/unraid-api/types/fastify.js';
|
||||
import { Public } from '@app/unraid-api/auth/public.decorator.js';
|
||||
import { OidcAuthService } from '@app/unraid-api/graph/resolvers/sso/oidc-auth.service.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
|
||||
import { OidcRequestHandler } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-request-handler.util.js';
|
||||
import { RestService } from '@app/unraid-api/rest/rest.service.js';
|
||||
import { validateRedirectUri } from '@app/unraid-api/utils/redirect-uri-validator.js';
|
||||
|
||||
@Controller()
|
||||
export class RestController {
|
||||
protected logger = new Logger(RestController.name);
|
||||
protected oidcLogger = new Logger('OidcRestController');
|
||||
|
||||
constructor(
|
||||
private readonly restService: RestService,
|
||||
private readonly oidcAuthService: OidcAuthService
|
||||
private readonly oidcService: OidcService,
|
||||
private readonly oidcConfig: OidcConfigPersistence
|
||||
) {}
|
||||
|
||||
@Get('/')
|
||||
@@ -65,38 +72,69 @@ export class RestController {
|
||||
async oidcAuthorize(
|
||||
@Param('providerId') providerId: string,
|
||||
@Query('state') state: string,
|
||||
@Query('redirect_uri') redirectUri: string,
|
||||
@Req() req: FastifyRequest,
|
||||
@Res() res: FastifyReply
|
||||
) {
|
||||
try {
|
||||
if (!state) {
|
||||
return res.status(400).send('State parameter is required');
|
||||
// Validate required parameters
|
||||
const params = OidcRequestHandler.validateAuthorizeParams(providerId, state, redirectUri);
|
||||
|
||||
// IMPORTANT: Use the redirect_uri from query params directly
|
||||
// Do NOT parse headers or try to build/validate against headers
|
||||
// The frontend provides the complete redirect_uri
|
||||
if (!params.redirectUri) {
|
||||
return res.status(400).send('redirect_uri parameter is required');
|
||||
}
|
||||
|
||||
// Get the host and protocol from the request headers
|
||||
const protocol = (req.headers['x-forwarded-proto'] as string) || req.protocol || 'http';
|
||||
const host = (req.headers['x-forwarded-host'] as string) || req.headers.host || undefined;
|
||||
const requestInfo = host ? `${protocol}://${host}` : undefined;
|
||||
// Security validation: validate redirect_uri with support for allowed origins
|
||||
const protocol = (req.headers['x-forwarded-proto'] as string) || 'http';
|
||||
const host = (req.headers['x-forwarded-host'] as string) || req.headers.host || req.hostname;
|
||||
|
||||
const authUrl = await this.oidcAuthService.getAuthorizationUrl(
|
||||
providerId,
|
||||
state,
|
||||
requestInfo
|
||||
// Get allowed origins from OIDC config
|
||||
const config = await this.oidcConfig.getConfig();
|
||||
const allowedOrigins = config?.defaultAllowedOrigins;
|
||||
|
||||
// Validate the redirect URI using the centralized validator
|
||||
const validation = validateRedirectUri(
|
||||
params.redirectUri,
|
||||
protocol,
|
||||
host,
|
||||
this.oidcLogger,
|
||||
allowedOrigins
|
||||
);
|
||||
|
||||
if (!validation.isValid) {
|
||||
this.oidcLogger.warn(`Invalid redirect_uri: ${validation.reason}`);
|
||||
return res
|
||||
.status(400)
|
||||
.send(
|
||||
`Invalid redirect_uri: ${escapeHtml(params.redirectUri)}. ${escapeHtml(validation.reason || 'Unknown validation error')}. Please add this callback URI to Settings → Management Access → Allowed Redirect URIs`
|
||||
);
|
||||
}
|
||||
|
||||
// Handle authorization flow using the exact redirect_uri from query params
|
||||
const authUrl = await OidcRequestHandler.handleAuthorize(
|
||||
params.providerId,
|
||||
params.state,
|
||||
params.redirectUri,
|
||||
req,
|
||||
this.oidcService,
|
||||
this.oidcLogger
|
||||
);
|
||||
this.logger.log(`Redirecting to OIDC provider: ${authUrl}`);
|
||||
|
||||
// Manually set redirect headers for better proxy compatibility
|
||||
res.status(302);
|
||||
res.header('Location', authUrl);
|
||||
return res.send();
|
||||
} catch (error: unknown) {
|
||||
this.logger.error(`OIDC authorize error for provider ${providerId}:`, error);
|
||||
this.oidcLogger.error(`OIDC authorize error for provider ${providerId}:`, error);
|
||||
|
||||
// Log more details about the error
|
||||
if (error instanceof Error) {
|
||||
this.logger.error(`Error message: ${error.message}`);
|
||||
this.oidcLogger.error(`Error message: ${error.message}`);
|
||||
if (error.stack) {
|
||||
this.logger.debug(`Stack trace: ${error.stack}`);
|
||||
this.oidcLogger.debug(`Stack trace: ${error.stack}`);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,32 +155,20 @@ export class RestController {
|
||||
@Res() res: FastifyReply
|
||||
) {
|
||||
try {
|
||||
if (!code || !state) {
|
||||
return res.status(400).send('Missing required parameters');
|
||||
}
|
||||
// Validate required parameters
|
||||
const params = OidcRequestHandler.validateCallbackParams(code, state);
|
||||
|
||||
// Extract provider ID from state
|
||||
const { providerId } = this.oidcAuthService.extractProviderFromState(state);
|
||||
|
||||
// Get the full callback URL as received, respecting reverse proxy headers
|
||||
const protocol = (req.headers['x-forwarded-proto'] as string) || req.protocol || 'http';
|
||||
const host =
|
||||
(req.headers['x-forwarded-host'] as string) || req.headers.host || 'localhost:3000';
|
||||
const fullUrl = `${protocol}://${host}${req.url}`;
|
||||
const requestInfo = `${protocol}://${host}`;
|
||||
|
||||
this.logger.debug(`Full callback URL from request: ${fullUrl}`);
|
||||
|
||||
const paddedToken = await this.oidcAuthService.handleCallback(
|
||||
providerId,
|
||||
code,
|
||||
state,
|
||||
requestInfo,
|
||||
fullUrl
|
||||
// Handle callback flow
|
||||
const result = await OidcRequestHandler.handleCallback(
|
||||
params.code,
|
||||
params.state,
|
||||
req,
|
||||
this.oidcService,
|
||||
this.oidcLogger
|
||||
);
|
||||
|
||||
// Redirect to login page with the token in hash to keep it out of server logs
|
||||
const loginUrl = `/login#token=${encodeURIComponent(paddedToken)}`;
|
||||
const loginUrl = `/login#token=${encodeURIComponent(result.paddedToken)}`;
|
||||
|
||||
// Manually set redirect headers for better proxy compatibility
|
||||
res.header('Cache-Control', 'no-store');
|
||||
@@ -152,16 +178,16 @@ export class RestController {
|
||||
res.header('Location', loginUrl);
|
||||
return res.send();
|
||||
} catch (error: unknown) {
|
||||
this.logger.error(`OIDC callback error: ${error}`);
|
||||
this.oidcLogger.error(`OIDC callback error: ${error}`);
|
||||
|
||||
// Use a generic error message to avoid leaking sensitive information
|
||||
const errorMessage = 'Authentication failed';
|
||||
|
||||
// Log detailed error for debugging but don't expose to user
|
||||
if (error instanceof UnauthorizedException) {
|
||||
this.logger.debug(`UnauthorizedException occurred during OIDC callback`);
|
||||
this.oidcLogger.debug(`UnauthorizedException occurred during OIDC callback`);
|
||||
} else if (error instanceof Error) {
|
||||
this.logger.debug(`Error during OIDC callback: ${error.message}`);
|
||||
this.oidcLogger.debug(`Error during OIDC callback: ${error.message}`);
|
||||
}
|
||||
|
||||
const loginUrl = `/login#error=${encodeURIComponent(errorMessage)}`;
|
||||
|
||||
406
api/src/unraid-api/utils/error-extractor.util.test.ts
Normal file
406
api/src/unraid-api/utils/error-extractor.util.test.ts
Normal file
@@ -0,0 +1,406 @@
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { ErrorExtractor } from '@app/unraid-api/utils/error-extractor.util.js';
|
||||
|
||||
describe('ErrorExtractor', () => {
|
||||
describe('extract', () => {
|
||||
it('should handle null and undefined errors', () => {
|
||||
const nullResult = ErrorExtractor.extract(null);
|
||||
expect(nullResult.message).toBe('Unknown error');
|
||||
expect(nullResult.type).toBe('Unknown');
|
||||
|
||||
const undefinedResult = ErrorExtractor.extract(undefined);
|
||||
expect(undefinedResult.message).toBe('Unknown error');
|
||||
expect(undefinedResult.type).toBe('Unknown');
|
||||
});
|
||||
|
||||
it('should extract basic Error properties', () => {
|
||||
const error = new Error('Test error message');
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.message).toBe('Test error message');
|
||||
expect(result.type).toBe('Error');
|
||||
expect(result.stack).toBeDefined();
|
||||
});
|
||||
|
||||
it('should extract custom error types', () => {
|
||||
class CustomError extends Error {}
|
||||
const error = new CustomError('Custom error');
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.type).toBe('CustomError');
|
||||
});
|
||||
|
||||
it('should extract error code', () => {
|
||||
const error: any = new Error('Error with code');
|
||||
error.code = 'ERR_CODE';
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.code).toBe('ERR_CODE');
|
||||
});
|
||||
|
||||
it('should extract HTTP response details', () => {
|
||||
const error: any = new Error('HTTP error');
|
||||
error.response = {
|
||||
status: 404,
|
||||
statusText: 'Not Found',
|
||||
body: { error: 'Resource not found' },
|
||||
headers: { 'content-type': 'application/json' },
|
||||
};
|
||||
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.status).toBe(404);
|
||||
expect(result.statusText).toBe('Not Found');
|
||||
expect(result.responseBody).toEqual({ error: 'Resource not found' });
|
||||
expect(result.responseHeaders).toEqual({ 'content-type': 'application/json' });
|
||||
});
|
||||
|
||||
it('should truncate long response body strings', () => {
|
||||
const error: any = new Error('Error with long body');
|
||||
const longString = 'x'.repeat(2000);
|
||||
error.body = longString;
|
||||
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.responseBody).toBe('x'.repeat(1000) + '... (truncated)');
|
||||
});
|
||||
|
||||
it('should extract OAuth error details', () => {
|
||||
const error: any = new Error('OAuth error');
|
||||
error.error = 'invalid_grant';
|
||||
error.error_description = 'The provided authorization code is invalid';
|
||||
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.oauthError).toBe('invalid_grant');
|
||||
expect(result.oauthErrorDescription).toBe('The provided authorization code is invalid');
|
||||
});
|
||||
|
||||
it('should extract cause chain', () => {
|
||||
const rootCause = new Error('Root cause');
|
||||
const middleCause: any = new Error('Middle cause');
|
||||
middleCause.cause = rootCause;
|
||||
const topError: any = new Error('Top error');
|
||||
topError.cause = middleCause;
|
||||
|
||||
const result = ErrorExtractor.extract(topError);
|
||||
|
||||
expect(result.causeChain).toHaveLength(2);
|
||||
expect(result.causeChain![0]).toEqual({
|
||||
depth: 1,
|
||||
type: 'Error',
|
||||
message: 'Middle cause',
|
||||
});
|
||||
expect(result.causeChain![1]).toEqual({
|
||||
depth: 2,
|
||||
type: 'Error',
|
||||
message: 'Root cause',
|
||||
});
|
||||
});
|
||||
|
||||
it('should limit cause chain depth', () => {
|
||||
// Create a deep nested error chain
|
||||
let deepestError: any = new Error('Level 10');
|
||||
|
||||
for (let i = 9; i >= 0; i--) {
|
||||
const error: any = new Error(`Level ${i}`);
|
||||
error.cause = deepestError;
|
||||
deepestError = error;
|
||||
}
|
||||
|
||||
const topError: any = new Error('Top');
|
||||
topError.cause = deepestError;
|
||||
|
||||
const result = ErrorExtractor.extract(topError);
|
||||
|
||||
expect(result.causeChain).toHaveLength(5); // MAX_CAUSE_DEPTH
|
||||
});
|
||||
|
||||
it('should extract cause with code', () => {
|
||||
const cause: any = new Error('Cause with code');
|
||||
cause.code = 'ECONNREFUSED';
|
||||
const error: any = new Error('Main error');
|
||||
error.cause = cause;
|
||||
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.causeChain![0].code).toBe('ECONNREFUSED');
|
||||
});
|
||||
|
||||
it('should extract additional properties', () => {
|
||||
const error: any = new Error('Error with extras');
|
||||
error.customProp1 = 'value1';
|
||||
error.customProp2 = 123;
|
||||
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.additionalProperties).toEqual({
|
||||
customProp1: 'value1',
|
||||
customProp2: 123,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle string errors', () => {
|
||||
const result = ErrorExtractor.extract('String error message');
|
||||
|
||||
expect(result.message).toBe('String error message');
|
||||
expect(result.type).toBe('String');
|
||||
});
|
||||
|
||||
it('should handle object errors', () => {
|
||||
const error = { code: 'ERROR', message: 'Object error' };
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.message).toBe(JSON.stringify(error));
|
||||
expect(result.type).toBe('Object');
|
||||
});
|
||||
|
||||
it('should handle primitive errors', () => {
|
||||
const result = ErrorExtractor.extract(42);
|
||||
|
||||
expect(result.message).toBe('42');
|
||||
expect(result.type).toBe('number');
|
||||
});
|
||||
|
||||
it('should handle openid-client error structure', () => {
|
||||
const error: any = new Error('unexpected response content-type');
|
||||
error.code = 'OAUTH_RESPONSE_IS_NOT_JSON';
|
||||
error.response = {
|
||||
status: 200,
|
||||
headers: { 'content-type': 'text/html' },
|
||||
body: '<html>Error page</html>',
|
||||
};
|
||||
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.code).toBe('OAUTH_RESPONSE_IS_NOT_JSON');
|
||||
expect(result.responseHeaders!['content-type']).toBe('text/html');
|
||||
expect(result.responseBody).toContain('<html>');
|
||||
});
|
||||
});
|
||||
|
||||
describe('isOAuthResponseError', () => {
|
||||
it('should identify OAuth response errors by code', () => {
|
||||
const extracted = {
|
||||
message: 'Some error',
|
||||
type: 'Error',
|
||||
code: 'OAUTH_RESPONSE_IS_NOT_JSON',
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isOAuthResponseError(extracted)).toBe(true);
|
||||
});
|
||||
|
||||
it('should identify OAuth response errors by message', () => {
|
||||
const extracted = {
|
||||
message: 'unexpected response content-type from server',
|
||||
type: 'Error',
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isOAuthResponseError(extracted)).toBe(true);
|
||||
});
|
||||
|
||||
it('should identify parsing errors', () => {
|
||||
const extracted = {
|
||||
message: 'JSON parsing error occurred',
|
||||
type: 'Error',
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isOAuthResponseError(extracted)).toBe(true);
|
||||
});
|
||||
|
||||
it('should not identify non-OAuth errors', () => {
|
||||
const extracted = {
|
||||
message: 'Some other error',
|
||||
type: 'Error',
|
||||
code: 'OTHER_ERROR',
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isOAuthResponseError(extracted)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isJwtClaimError', () => {
|
||||
it('should identify JWT claim errors', () => {
|
||||
const extracted = {
|
||||
message: 'unexpected JWT claim value encountered',
|
||||
type: 'Error',
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isJwtClaimError(extracted)).toBe(true);
|
||||
});
|
||||
|
||||
it('should not identify non-JWT errors', () => {
|
||||
const extracted = {
|
||||
message: 'Some other error',
|
||||
type: 'Error',
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isJwtClaimError(extracted)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isNetworkError', () => {
|
||||
it('should identify network errors by code', () => {
|
||||
const codes = ['ECONNREFUSED', 'ENOTFOUND', 'ETIMEDOUT', 'ECONNRESET'];
|
||||
|
||||
for (const code of codes) {
|
||||
const extracted = {
|
||||
message: 'Error',
|
||||
type: 'Error',
|
||||
code,
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isNetworkError(extracted)).toBe(true);
|
||||
}
|
||||
});
|
||||
|
||||
it('should identify network errors by message', () => {
|
||||
const messages = ['network timeout occurred', 'failed to connect to server'];
|
||||
|
||||
for (const message of messages) {
|
||||
const extracted = {
|
||||
message,
|
||||
type: 'Error',
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isNetworkError(extracted)).toBe(true);
|
||||
}
|
||||
});
|
||||
|
||||
it('should not identify non-network errors', () => {
|
||||
const extracted = {
|
||||
message: 'Invalid credentials',
|
||||
type: 'Error',
|
||||
code: 'AUTH_ERROR',
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isNetworkError(extracted)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('formatForLogging', () => {
|
||||
it('should log basic error information', () => {
|
||||
const logger = {
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
};
|
||||
|
||||
const extracted = {
|
||||
message: 'Test error',
|
||||
type: 'CustomError',
|
||||
code: 'ERR_TEST',
|
||||
};
|
||||
|
||||
ErrorExtractor.formatForLogging(extracted, logger);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith('Error type: CustomError');
|
||||
expect(logger.error).toHaveBeenCalledWith('Error message: Test error');
|
||||
expect(logger.error).toHaveBeenCalledWith('Error code: ERR_TEST');
|
||||
});
|
||||
|
||||
it('should log HTTP response details', () => {
|
||||
const logger = {
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
};
|
||||
|
||||
const extracted = {
|
||||
message: 'HTTP error',
|
||||
type: 'Error',
|
||||
status: 500,
|
||||
statusText: 'Internal Server Error',
|
||||
responseBody: { error: 'Server error' },
|
||||
responseHeaders: { 'content-type': 'application/json' },
|
||||
};
|
||||
|
||||
ErrorExtractor.formatForLogging(extracted, logger);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith('HTTP Status: 500 Internal Server Error');
|
||||
expect(logger.error).toHaveBeenCalledWith('Response body: %o', { error: 'Server error' });
|
||||
expect(logger.error).toHaveBeenCalledWith('Response Content-Type: application/json');
|
||||
});
|
||||
|
||||
it('should log OAuth error details', () => {
|
||||
const logger = {
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
};
|
||||
|
||||
const extracted = {
|
||||
message: 'OAuth error',
|
||||
type: 'Error',
|
||||
oauthError: 'invalid_client',
|
||||
oauthErrorDescription: 'Client authentication failed',
|
||||
};
|
||||
|
||||
ErrorExtractor.formatForLogging(extracted, logger);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith('OAuth error: invalid_client');
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
'OAuth error description: Client authentication failed'
|
||||
);
|
||||
});
|
||||
|
||||
it('should log cause chain', () => {
|
||||
const logger = {
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
};
|
||||
|
||||
const extracted = {
|
||||
message: 'Top error',
|
||||
type: 'Error',
|
||||
causeChain: [
|
||||
{ depth: 1, type: 'Error', message: 'Cause 1', code: 'CODE1' },
|
||||
{ depth: 2, type: 'Error', message: 'Cause 2' },
|
||||
],
|
||||
};
|
||||
|
||||
ErrorExtractor.formatForLogging(extracted, logger);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith('Error cause chain:');
|
||||
expect(logger.error).toHaveBeenCalledWith(' [Cause 1] Error: Cause 1');
|
||||
expect(logger.error).toHaveBeenCalledWith(' [Cause 1] Code: CODE1');
|
||||
expect(logger.error).toHaveBeenCalledWith(' [Cause 2] Error: Cause 2');
|
||||
});
|
||||
|
||||
it('should log additional properties and stack in debug', () => {
|
||||
const logger = {
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
};
|
||||
|
||||
const extracted = {
|
||||
message: 'Error',
|
||||
type: 'Error',
|
||||
additionalProperties: { custom: 'value' },
|
||||
stack: 'Stack trace here',
|
||||
};
|
||||
|
||||
ErrorExtractor.formatForLogging(extracted, logger);
|
||||
|
||||
expect(logger.debug).toHaveBeenCalledWith('Additional error properties: %o', {
|
||||
custom: 'value',
|
||||
});
|
||||
expect(logger.debug).toHaveBeenCalledWith('Stack trace: Stack trace here');
|
||||
});
|
||||
|
||||
it('should handle case-insensitive Content-Type header', () => {
|
||||
const logger = {
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
};
|
||||
|
||||
const extracted = {
|
||||
message: 'Error',
|
||||
type: 'Error',
|
||||
responseHeaders: { 'Content-Type': 'text/html' },
|
||||
};
|
||||
|
||||
ErrorExtractor.formatForLogging(extracted, logger);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith('Response Content-Type: text/html');
|
||||
});
|
||||
});
|
||||
});
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user