feat(api): enhance OIDC redirect URI handling in service and tests (#1618)

- Updated `getRedirectUri` method in `OidcAuthService` to handle various
edge cases for redirect URIs, including full URIs, malformed URLs, and
default ports.
- Added comprehensive tests for `OidcAuthService` to validate redirect
URI construction and error handling.
- Modified `RestController` to utilize `redirect_uri` query parameter
for authorization requests.
- Updated frontend components to include `redirect_uri` in authorization
URLs, ensuring correct handling of different protocols and ports.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Stronger OIDC redirect_uri validation and an admin GraphQL endpoint to
view full OIDC configuration.
* OIDC Debug Logs UI (panel, button, modal), enhanced log viewer with
presets/filters, ANSI-colored rendering, and a File Viewer component.
* New GraphQL queries to list and fetch config files; API Config
Download page.

* **Refactor**
* Centralized, modular OIDC flows and safer redirect handling;
topic-based log subscriptions with a watcher manager for scalable live
logs.

* **Documentation**
  * Cache TTL guidance clarified to use milliseconds.

* **Chores**
* Added ansi_up and escape-html deps; improved log formatting; added
root codegen script.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
This commit is contained in:
Eli Bosley
2025-09-02 10:40:20 -04:00
committed by GitHub
parent 6356f9c41d
commit 4e945f5f56
100 changed files with 9543 additions and 3396 deletions

View File

@@ -157,4 +157,7 @@ Enables GraphQL playground at `http://tower.local/graphql`
- We are using tailwind v4 we do not need a tailwind config anymore - We are using tailwind v4 we do not need a tailwind config anymore
- always search the internet for tailwind v4 documentation when making tailwind related style changes - always search the internet for tailwind v4 documentation when making tailwind related style changes
- never run or restart the API server or web server. I will handle the lifecylce, simply wait and ask me to do this for you - never run or restart the API server or web server. I will handle the lifecycle, simply wait and ask me to do this for you
- Never use the `any` type. Always prefer proper typing
- Avoid using casting whenever possible, prefer proper typing from the start
- **IMPORTANT:** cache-manager v7 expects TTL values in **milliseconds**, not seconds. Always use milliseconds when setting cache TTL (e.g., 600000 for 10 minutes, not 600)

View File

@@ -17,5 +17,6 @@
], ],
"buttonText": "Login With Unraid.net" "buttonText": "Login With Unraid.net"
} }
] ],
"defaultAllowedOrigins": []
} }

View File

@@ -1798,6 +1798,8 @@ type Server implements Node {
guid: String! guid: String!
apikey: String! apikey: String!
name: String! name: String!
"""Whether this server is online or offline"""
status: ServerStatus! status: ServerStatus!
wanip: String! wanip: String!
lanip: String! lanip: String!
@@ -1854,7 +1856,7 @@ type OidcProvider {
""" """
OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration
""" """
issuer: String! issuer: String
""" """
OAuth2 authorization endpoint URL. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration OAuth2 authorization endpoint URL. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration
@@ -1907,6 +1909,16 @@ enum AuthorizationRuleMode {
AND AND
} }
type OidcConfiguration {
"""List of configured OIDC providers"""
providers: [OidcProvider!]!
"""
Default allowed redirect origins that apply to all OIDC providers (e.g., Tailscale domains)
"""
defaultAllowedOrigins: [String!]
}
type OidcSessionValidation { type OidcSessionValidation {
valid: Boolean! valid: Boolean!
username: String username: String
@@ -2307,8 +2319,6 @@ type Query {
getApiKeyCreationFormSchema: ApiKeyFormSettings! getApiKeyCreationFormSchema: ApiKeyFormSettings!
config: Config! config: Config!
flash: Flash! flash: Flash!
logFiles: [LogFile!]!
logFile(path: String!, lines: Int, startLine: Int): LogFileContent!
me: UserAccount! me: UserAccount!
"""Get all notifications""" """Get all notifications"""
@@ -2335,6 +2345,8 @@ type Query {
disk(id: PrefixedID!): Disk! disk(id: PrefixedID!): Disk!
rclone: RCloneBackupSettings! rclone: RCloneBackupSettings!
info: Info! info: Info!
logFiles: [LogFile!]!
logFile(path: String!, lines: Int, startLine: Int): LogFileContent!
settings: Settings! settings: Settings!
isSSOEnabled: Boolean! isSSOEnabled: Boolean!
@@ -2347,6 +2359,9 @@ type Query {
"""Get a specific OIDC provider by ID""" """Get a specific OIDC provider by ID"""
oidcProvider(id: PrefixedID!): OidcProvider oidcProvider(id: PrefixedID!): OidcProvider
"""Get the full OIDC configuration (admin only)"""
oidcConfiguration: OidcConfiguration!
"""Validate an OIDC session token (internal use for CLI validation)""" """Validate an OIDC session token (internal use for CLI validation)"""
validateOidcSession(token: String!): OidcSessionValidation! validateOidcSession(token: String!): OidcSessionValidation!
metrics: Metrics! metrics: Metrics!
@@ -2590,13 +2605,13 @@ input AccessUrlInput {
} }
type Subscription { type Subscription {
logFile(path: String!): LogFileContent!
notificationAdded: Notification! notificationAdded: Notification!
notificationsOverview: NotificationOverview! notificationsOverview: NotificationOverview!
ownerSubscription: Owner! ownerSubscription: Owner!
serversSubscription: Server! serversSubscription: Server!
parityHistorySubscription: ParityCheck! parityHistorySubscription: ParityCheck!
arraySubscription: UnraidArray! arraySubscription: UnraidArray!
logFile(path: String!): LogFileContent!
systemMetricsCpu: CpuUtilization! systemMetricsCpu: CpuUtilization!
systemMetricsMemory: MemoryUtilization! systemMetricsMemory: MemoryUtilization!
upsUpdates: UPSDevice! upsUpdates: UPSDevice!

View File

@@ -99,6 +99,7 @@
"diff": "8.0.2", "diff": "8.0.2",
"dockerode": "4.0.7", "dockerode": "4.0.7",
"dotenv": "17.2.1", "dotenv": "17.2.1",
"escape-html": "1.0.3",
"execa": "9.6.0", "execa": "9.6.0",
"exit-hook": "4.0.0", "exit-hook": "4.0.0",
"fastify": "5.5.0", "fastify": "5.5.0",

View File

@@ -29,8 +29,24 @@ const stream = SUPPRESS_LOGS
singleLine: true, singleLine: true,
hideObject: false, hideObject: false,
colorize: true, colorize: true,
colorizeObjects: true,
levelFirst: false,
ignore: 'hostname,pid', ignore: 'hostname,pid',
destination: logDestination, destination: logDestination,
translateTime: 'HH:mm:ss',
customPrettifiers: {
time: (timestamp: string | object) => `[${timestamp}`,
level: (logLevel: string | object, key: string, log: any, extras: any) => {
// Use labelColorized which preserves the colors
const { labelColorized } = extras;
const context = log.context || log.logger || 'app';
return `${labelColorized} ${context}]`;
},
},
messageFormat: (log: any, messageKey: string) => {
const msg = log[messageKey] || log.msg || '';
return msg;
},
}) })
: logDestination; : logDestination;

View File

@@ -13,10 +13,11 @@ export const pubsub = new PubSub({ eventEmitter });
/** /**
* Create a pubsub subscription. * Create a pubsub subscription.
* @param channel The pubsub channel to subscribe to. * @param channel The pubsub channel to subscribe to. Can be either a predefined GRAPHQL_PUBSUB_CHANNEL
* or a dynamic string for runtime-generated topics (e.g., log file paths like "LOG_FILE:/var/log/test.log")
*/ */
export const createSubscription = <T = any>( export const createSubscription = <T = any>(
channel: GRAPHQL_PUBSUB_CHANNEL channel: GRAPHQL_PUBSUB_CHANNEL | string
): AsyncIterableIterator<T> => { ): AsyncIterableIterator<T> => {
return pubsub.asyncIterableIterator<T>(channel); return pubsub.asyncIterableIterator<T>(channel);
}; };

View File

@@ -1,3 +1,4 @@
import { CacheModule } from '@nestjs/cache-manager';
import { Test } from '@nestjs/testing'; import { Test } from '@nestjs/testing';
import { describe, expect, it } from 'vitest'; import { describe, expect, it } from 'vitest';
@@ -9,7 +10,7 @@ describe('Module Dependencies Integration', () => {
let module; let module;
try { try {
module = await Test.createTestingModule({ module = await Test.createTestingModule({
imports: [RestModule], imports: [CacheModule.register({ isGlobal: true }), RestModule],
}).compile(); }).compile();
expect(module).toBeDefined(); expect(module).toBeDefined();

View File

@@ -34,6 +34,15 @@ import { UnraidFileModifierModule } from '@app/unraid-api/unraid-file-modifier/u
req: () => undefined, req: () => undefined,
res: () => undefined, res: () => undefined,
}, },
formatters: {
log: (obj) => {
// Map NestJS context to Pino context field for pino-pretty
if (obj.context && !obj.logger) {
return { ...obj, logger: obj.context };
}
return obj;
},
},
}, },
}), }),
AuthModule, AuthModule,

View File

@@ -448,6 +448,20 @@ export enum ConfigErrorState {
WITHDRAWN = 'WITHDRAWN' WITHDRAWN = 'WITHDRAWN'
} }
export type ConfigFile = {
__typename?: 'ConfigFile';
content: Scalars['String']['output'];
name: Scalars['String']['output'];
path: Scalars['String']['output'];
/** Human-readable file size (e.g., "1.5 KB", "2.3 MB") */
sizeReadable: Scalars['String']['output'];
};
export type ConfigFilesResponse = {
__typename?: 'ConfigFilesResponse';
files: Array<ConfigFile>;
};
export type Connect = Node & { export type Connect = Node & {
__typename?: 'Connect'; __typename?: 'Connect';
/** The status of dynamic remote access */ /** The status of dynamic remote access */
@@ -1432,6 +1446,14 @@ export type OidcAuthorizationRule = {
value: Array<Scalars['String']['output']>; value: Array<Scalars['String']['output']>;
}; };
export type OidcConfiguration = {
__typename?: 'OidcConfiguration';
/** Default allowed redirect origins that apply to all OIDC providers (e.g., Tailscale domains) */
defaultAllowedOrigins?: Maybe<Array<Scalars['String']['output']>>;
/** List of configured OIDC providers */
providers: Array<OidcProvider>;
};
export type OidcProvider = { export type OidcProvider = {
__typename?: 'OidcProvider'; __typename?: 'OidcProvider';
/** OAuth2 authorization endpoint URL. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration */ /** OAuth2 authorization endpoint URL. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration */
@@ -1455,7 +1477,7 @@ export type OidcProvider = {
/** The unique identifier for the OIDC provider */ /** The unique identifier for the OIDC provider */
id: Scalars['PrefixedID']['output']; id: Scalars['PrefixedID']['output'];
/** OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration */ /** OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration */
issuer: Scalars['String']['output']; issuer?: Maybe<Scalars['String']['output']>;
/** JSON Web Key Set URI for token validation. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration */ /** JSON Web Key Set URI for token validation. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration */
jwksUri?: Maybe<Scalars['String']['output']>; jwksUri?: Maybe<Scalars['String']['output']>;
/** Display name of the OIDC provider */ /** Display name of the OIDC provider */
@@ -1623,6 +1645,7 @@ export type PublicPartnerInfo = {
export type Query = { export type Query = {
__typename?: 'Query'; __typename?: 'Query';
allConfigFiles: ConfigFilesResponse;
apiKey?: Maybe<ApiKey>; apiKey?: Maybe<ApiKey>;
/** All possible permissions for API keys */ /** All possible permissions for API keys */
apiKeyPossiblePermissions: Array<Permission>; apiKeyPossiblePermissions: Array<Permission>;
@@ -1632,6 +1655,7 @@ export type Query = {
array: UnraidArray; array: UnraidArray;
cloud: Cloud; cloud: Cloud;
config: Config; config: Config;
configFile?: Maybe<ConfigFile>;
connect: Connect; connect: Connect;
customization?: Maybe<Customization>; customization?: Maybe<Customization>;
disk: Disk; disk: Disk;
@@ -1654,6 +1678,8 @@ export type Query = {
network: Network; network: Network;
/** Get all notifications */ /** Get all notifications */
notifications: Notifications; notifications: Notifications;
/** Get the full OIDC configuration (admin only) */
oidcConfiguration: OidcConfiguration;
/** Get a specific OIDC provider by ID */ /** Get a specific OIDC provider by ID */
oidcProvider?: Maybe<OidcProvider>; oidcProvider?: Maybe<OidcProvider>;
/** Get all configured OIDC providers (admin only) */ /** Get all configured OIDC providers (admin only) */
@@ -1693,6 +1719,11 @@ export type QueryApiKeyArgs = {
}; };
export type QueryConfigFileArgs = {
name: Scalars['String']['input'];
};
export type QueryDiskArgs = { export type QueryDiskArgs = {
id: Scalars['PrefixedID']['input']; id: Scalars['PrefixedID']['input'];
}; };
@@ -1933,6 +1964,7 @@ export type Server = Node & {
name: Scalars['String']['output']; name: Scalars['String']['output'];
owner: ProfileModel; owner: ProfileModel;
remoteurl: Scalars['String']['output']; remoteurl: Scalars['String']['output'];
/** Whether this server is online or offline */
status: ServerStatus; status: ServerStatus;
wanip: Scalars['String']['output']; wanip: Scalars['String']['output'];
}; };

View File

@@ -0,0 +1,213 @@
import { Test, TestingModule } from '@nestjs/testing';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import {
LogWatcherManager,
WatcherState,
} from '@app/unraid-api/graph/resolvers/logs/log-watcher-manager.service.js';
describe('LogWatcherManager', () => {
let manager: LogWatcherManager;
let mockWatcher: any;
beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
providers: [LogWatcherManager],
}).compile();
manager = module.get<LogWatcherManager>(LogWatcherManager);
mockWatcher = {
close: vi.fn(),
on: vi.fn(),
};
});
describe('state management', () => {
it('should set watcher as initializing', () => {
manager.setInitializing('test-key');
const entry = manager.getEntry('test-key');
expect(entry).toBeDefined();
expect(entry?.state).toBe(WatcherState.INITIALIZING);
});
it('should set watcher as active with position', () => {
manager.setActive('test-key', mockWatcher as any, 1000);
const entry = manager.getEntry('test-key');
expect(entry).toBeDefined();
expect(entry?.state).toBe(WatcherState.ACTIVE);
if (manager.isActive(entry)) {
expect(entry.watcher).toBe(mockWatcher);
expect(entry.position).toBe(1000);
}
});
it('should set watcher as stopping', () => {
manager.setStopping('test-key');
const entry = manager.getEntry('test-key');
expect(entry).toBeDefined();
expect(entry?.state).toBe(WatcherState.STOPPING);
});
});
describe('isWatchingOrInitializing', () => {
it('should return true for initializing watcher', () => {
manager.setInitializing('test-key');
expect(manager.isWatchingOrInitializing('test-key')).toBe(true);
});
it('should return true for active watcher', () => {
manager.setActive('test-key', mockWatcher as any, 0);
expect(manager.isWatchingOrInitializing('test-key')).toBe(true);
});
it('should return false for stopping watcher', () => {
manager.setStopping('test-key');
expect(manager.isWatchingOrInitializing('test-key')).toBe(false);
});
it('should return false for non-existent watcher', () => {
expect(manager.isWatchingOrInitializing('test-key')).toBe(false);
});
});
describe('handlePostInitialization', () => {
it('should activate watcher when not stopped', () => {
manager.setInitializing('test-key');
const result = manager.handlePostInitialization('test-key', mockWatcher as any, 500);
expect(result).toBe(true);
expect(mockWatcher.close).not.toHaveBeenCalled();
const entry = manager.getEntry('test-key');
expect(entry?.state).toBe(WatcherState.ACTIVE);
if (manager.isActive(entry)) {
expect(entry.position).toBe(500);
}
});
it('should cleanup watcher when marked as stopping', () => {
manager.setStopping('test-key');
const result = manager.handlePostInitialization('test-key', mockWatcher as any, 500);
expect(result).toBe(false);
expect(mockWatcher.close).toHaveBeenCalled();
expect(manager.getEntry('test-key')).toBeUndefined();
});
it('should cleanup watcher when entry is missing', () => {
const result = manager.handlePostInitialization('test-key', mockWatcher as any, 500);
expect(result).toBe(false);
expect(mockWatcher.close).toHaveBeenCalled();
expect(manager.getEntry('test-key')).toBeUndefined();
});
});
describe('stopWatcher', () => {
it('should mark initializing watcher as stopping', () => {
manager.setInitializing('test-key');
manager.stopWatcher('test-key');
const entry = manager.getEntry('test-key');
expect(entry?.state).toBe(WatcherState.STOPPING);
});
it('should close and remove active watcher', () => {
manager.setActive('test-key', mockWatcher as any, 0);
manager.stopWatcher('test-key');
expect(mockWatcher.close).toHaveBeenCalled();
expect(manager.getEntry('test-key')).toBeUndefined();
});
it('should do nothing for non-existent watcher', () => {
manager.stopWatcher('test-key');
expect(mockWatcher.close).not.toHaveBeenCalled();
});
});
describe('position management', () => {
it('should update position for active watcher', () => {
manager.setActive('test-key', mockWatcher as any, 100);
manager.updatePosition('test-key', 200);
const position = manager.getPosition('test-key');
expect(position).toBe(200);
});
it('should not update position for non-active watcher', () => {
manager.setInitializing('test-key');
manager.updatePosition('test-key', 200);
const position = manager.getPosition('test-key');
expect(position).toBeUndefined();
});
it('should get position for active watcher', () => {
manager.setActive('test-key', mockWatcher as any, 300);
expect(manager.getPosition('test-key')).toBe(300);
});
it('should return undefined for non-active watcher', () => {
manager.setStopping('test-key');
expect(manager.getPosition('test-key')).toBeUndefined();
});
});
describe('stopAllWatchers', () => {
it('should close all active watchers and clear map', () => {
const mockWatcher1 = { close: vi.fn() };
const mockWatcher2 = { close: vi.fn() };
const mockWatcher3 = { close: vi.fn() };
manager.setActive('key1', mockWatcher1 as any, 0);
manager.setInitializing('key2');
manager.setActive('key3', mockWatcher2 as any, 0);
manager.setStopping('key4');
manager.setActive('key5', mockWatcher3 as any, 0);
manager.stopAllWatchers();
expect(mockWatcher1.close).toHaveBeenCalled();
expect(mockWatcher2.close).toHaveBeenCalled();
expect(mockWatcher3.close).toHaveBeenCalled();
expect(manager.getEntry('key1')).toBeUndefined();
expect(manager.getEntry('key2')).toBeUndefined();
expect(manager.getEntry('key3')).toBeUndefined();
expect(manager.getEntry('key4')).toBeUndefined();
expect(manager.getEntry('key5')).toBeUndefined();
});
});
describe('in-flight processing', () => {
it('should prevent concurrent processing', () => {
manager.setActive('test-key', mockWatcher as any, 0);
// First call should succeed
expect(manager.startProcessing('test-key')).toBe(true);
// Second call should fail (already in flight)
expect(manager.startProcessing('test-key')).toBe(false);
// After finishing, should be able to start again
manager.finishProcessing('test-key');
expect(manager.startProcessing('test-key')).toBe(true);
});
it('should not start processing for non-active watcher', () => {
manager.setInitializing('test-key');
expect(manager.startProcessing('test-key')).toBe(false);
manager.setStopping('test-key');
expect(manager.startProcessing('test-key')).toBe(false);
});
it('should handle finish processing for non-existent watcher', () => {
// Should not throw
expect(() => manager.finishProcessing('non-existent')).not.toThrow();
});
});
});

View File

@@ -0,0 +1,183 @@
import { Injectable, Logger } from '@nestjs/common';
import * as chokidar from 'chokidar';
export enum WatcherState {
INITIALIZING = 'initializing',
ACTIVE = 'active',
STOPPING = 'stopping',
}
export type WatcherEntry =
| { state: WatcherState.INITIALIZING }
| { state: WatcherState.ACTIVE; watcher: chokidar.FSWatcher; position: number; inFlight: boolean }
| { state: WatcherState.STOPPING };
/**
* Service responsible for managing log file watchers and their lifecycle.
* Handles race conditions during watcher initialization and cleanup.
*/
@Injectable()
export class LogWatcherManager {
private readonly logger = new Logger(LogWatcherManager.name);
private readonly watchers = new Map<string, WatcherEntry>();
/**
* Set a watcher as initializing
*/
setInitializing(key: string): void {
this.watchers.set(key, { state: WatcherState.INITIALIZING });
}
/**
* Set a watcher as active with its FSWatcher and position
*/
setActive(key: string, watcher: chokidar.FSWatcher, position: number): void {
this.watchers.set(key, { state: WatcherState.ACTIVE, watcher, position, inFlight: false });
}
/**
* Mark a watcher as stopping (used during initialization race conditions)
*/
setStopping(key: string): void {
this.watchers.set(key, { state: WatcherState.STOPPING });
}
/**
* Get a watcher entry by key
*/
getEntry(key: string): WatcherEntry | undefined {
return this.watchers.get(key);
}
/**
* Remove a watcher entry
*/
removeEntry(key: string): void {
this.watchers.delete(key);
}
/**
* Check if a watcher is active and return typed entry
*/
isActive(entry: WatcherEntry | undefined): entry is {
state: WatcherState.ACTIVE;
watcher: chokidar.FSWatcher;
position: number;
inFlight: boolean;
} {
return entry?.state === WatcherState.ACTIVE;
}
/**
* Check if a watcher exists and is either initializing or active
*/
isWatchingOrInitializing(key: string): boolean {
const entry = this.getEntry(key);
return (
entry !== undefined &&
(entry.state === WatcherState.ACTIVE || entry.state === WatcherState.INITIALIZING)
);
}
/**
* Handle cleanup after initialization completes.
* Returns true if the watcher should continue, false if it should be cleaned up.
*/
handlePostInitialization(key: string, watcher: chokidar.FSWatcher, position: number): boolean {
const currentEntry = this.getEntry(key);
if (!currentEntry || currentEntry.state === WatcherState.STOPPING) {
// We were stopped during initialization, clean up immediately
this.logger.debug(`Watcher for ${key} was stopped during initialization, cleaning up`);
watcher.close();
this.removeEntry(key);
return false;
}
// Store the active watcher and position
this.setActive(key, watcher, position);
return true;
}
/**
* Stop a watcher, handling all possible states
*/
stopWatcher(key: string): void {
const entry = this.getEntry(key);
if (!entry) {
return;
}
if (entry.state === WatcherState.INITIALIZING) {
// Mark as stopping so the initialization will clean up
this.setStopping(key);
this.logger.debug(`Marked watcher as stopping during initialization: ${key}`);
} else if (entry.state === WatcherState.ACTIVE) {
// Close the active watcher
entry.watcher.close();
this.removeEntry(key);
this.logger.debug(`Stopped active watcher: ${key}`);
}
}
/**
* Update the position for an active watcher
*/
updatePosition(key: string, newPosition: number): void {
const entry = this.getEntry(key);
if (this.isActive(entry)) {
entry.position = newPosition;
}
}
/**
* Start processing a change event (set inFlight to true)
* Returns true if processing can proceed, false if already in flight
*/
startProcessing(key: string): boolean {
const entry = this.getEntry(key);
if (this.isActive(entry)) {
if (entry.inFlight) {
return false; // Already processing
}
entry.inFlight = true;
return true;
}
return false;
}
/**
* Finish processing a change event (set inFlight to false)
*/
finishProcessing(key: string): void {
const entry = this.getEntry(key);
if (this.isActive(entry)) {
entry.inFlight = false;
}
}
/**
* Get the position for an active watcher
*/
getPosition(key: string): number | undefined {
const entry = this.getEntry(key);
if (this.isActive(entry)) {
return entry.position;
}
return undefined;
}
/**
* Clean up all watchers (useful for module cleanup)
*/
stopAllWatchers(): void {
for (const entry of this.watchers.values()) {
if (this.isActive(entry)) {
entry.watcher.close();
}
}
this.watchers.clear();
}
}

View File

@@ -1,10 +1,13 @@
import { Module } from '@nestjs/common'; import { Module } from '@nestjs/common';
import { LogWatcherManager } from '@app/unraid-api/graph/resolvers/logs/log-watcher-manager.service.js';
import { LogsResolver } from '@app/unraid-api/graph/resolvers/logs/logs.resolver.js'; import { LogsResolver } from '@app/unraid-api/graph/resolvers/logs/logs.resolver.js';
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js'; import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
import { ServicesModule } from '@app/unraid-api/graph/services/services.module.js';
@Module({ @Module({
providers: [LogsResolver, LogsService], imports: [ServicesModule],
exports: [LogsService], providers: [LogsResolver, LogsService, LogWatcherManager],
exports: [LogsService, LogWatcherManager],
}) })
export class LogsModule {} export class LogsModule {}

View File

@@ -1,9 +1,10 @@
import { Test, TestingModule } from '@nestjs/testing'; import { Test, TestingModule } from '@nestjs/testing';
import { beforeEach, describe, expect, it } from 'vitest'; import { beforeEach, describe, expect, it, vi } from 'vitest';
import { LogsResolver } from '@app/unraid-api/graph/resolvers/logs/logs.resolver.js'; import { LogsResolver } from '@app/unraid-api/graph/resolvers/logs/logs.resolver.js';
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js'; import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js';
describe('LogsResolver', () => { describe('LogsResolver', () => {
let resolver: LogsResolver; let resolver: LogsResolver;
@@ -18,6 +19,13 @@ describe('LogsResolver', () => {
// Add mock implementations for service methods used by resolver // Add mock implementations for service methods used by resolver
}, },
}, },
{
provide: SubscriptionHelperService,
useValue: {
// Add mock implementations for subscription helper methods
createTrackedSubscription: vi.fn(),
},
},
], ],
}).compile(); }).compile();
resolver = module.get<LogsResolver>(LogsResolver); resolver = module.get<LogsResolver>(LogsResolver);

View File

@@ -3,13 +3,16 @@ import { Args, Int, Query, Resolver, Subscription } from '@nestjs/graphql';
import { AuthAction, Resource } from '@unraid/shared/graphql.model.js'; import { AuthAction, Resource } from '@unraid/shared/graphql.model.js';
import { UsePermissions } from '@unraid/shared/use-permissions.directive.js'; import { UsePermissions } from '@unraid/shared/use-permissions.directive.js';
import { createSubscription, PUBSUB_CHANNEL } from '@app/core/pubsub.js';
import { LogFile, LogFileContent } from '@app/unraid-api/graph/resolvers/logs/logs.model.js'; import { LogFile, LogFileContent } from '@app/unraid-api/graph/resolvers/logs/logs.model.js';
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js'; import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js';
@Resolver(() => LogFile) @Resolver(() => LogFile)
export class LogsResolver { export class LogsResolver {
constructor(private readonly logsService: LogsService) {} constructor(
private readonly logsService: LogsService,
private readonly subscriptionHelper: SubscriptionHelperService
) {}
@Query(() => [LogFile]) @Query(() => [LogFile])
@UsePermissions({ @UsePermissions({
@@ -38,27 +41,12 @@ export class LogsResolver {
action: AuthAction.READ_ANY, action: AuthAction.READ_ANY,
resource: Resource.LOGS, resource: Resource.LOGS,
}) })
async logFileSubscription(@Args('path') path: string) { logFileSubscription(@Args('path') path: string) {
// Start watching the file // Register the topic and get the key
this.logsService.getLogFileSubscriptionChannel(path); const topicKey = this.logsService.registerLogFileSubscription(path);
// Create the async iterator // Use the helper service to create a tracked subscription
const asyncIterator = createSubscription(PUBSUB_CHANNEL.LOG_FILE); // This automatically handles subscribe/unsubscribe with reference counting
return this.subscriptionHelper.createTrackedSubscription(topicKey);
// Store the original return method to wrap it
const originalReturn = asyncIterator.return;
// Override the return method to clean up resources
asyncIterator.return = async () => {
// Stop watching the file when subscription ends
this.logsService.stopWatchingLogFile(path);
// Call the original return method
return originalReturn
? originalReturn.call(asyncIterator)
: Promise.resolve({ value: undefined, done: true });
};
return asyncIterator;
} }
} }

View File

@@ -0,0 +1,201 @@
import { Test, TestingModule } from '@nestjs/testing';
import * as fs from 'node:fs/promises';
import * as chokidar from 'chokidar';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { LogWatcherManager } from '@app/unraid-api/graph/resolvers/logs/log-watcher-manager.service.js';
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
vi.mock('node:fs/promises');
vi.mock('chokidar');
vi.mock('@app/store/index.js', () => ({
getters: {
paths: () => ({
'unraid-log-base': '/var/log',
}),
},
}));
vi.mock('@app/core/pubsub.js', () => ({
pubsub: {
publish: vi.fn(),
},
PUBSUB_CHANNEL: {},
}));
describe('LogsService', () => {
let service: LogsService;
let mockWatcher: any;
let subscriptionTracker: any;
beforeEach(async () => {
// Create a mock watcher
mockWatcher = {
on: vi.fn(),
close: vi.fn(),
};
// Mock chokidar.watch to return our mock watcher
vi.mocked(chokidar.watch).mockReturnValue(mockWatcher as any);
// Mock fs.stat to return a file size
vi.mocked(fs.stat).mockResolvedValue({ size: 1000 } as any);
subscriptionTracker = {
getSubscriberCount: vi.fn().mockReturnValue(0),
registerTopic: vi.fn(),
};
const module: TestingModule = await Test.createTestingModule({
providers: [
LogsService,
LogWatcherManager,
{
provide: SubscriptionTrackerService,
useValue: subscriptionTracker,
},
],
}).compile();
service = module.get<LogsService>(LogsService);
});
afterEach(() => {
vi.clearAllMocks();
});
it('should be defined', () => {
expect(service).toBeDefined();
});
it('should handle race condition when stopping watcher during initialization', async () => {
// Setup: Register the subscription which will trigger registerTopic
service.registerLogFileSubscription('test.log');
// Get the onStart callback that was registered
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
const onStartCallback = registerTopicCall[1];
const onStopCallback = registerTopicCall[2];
// Create a promise to control when stat resolves
let statResolve: any;
const statPromise = new Promise((resolve) => {
statResolve = resolve;
});
vi.mocked(fs.stat).mockReturnValue(statPromise as any);
// Start the watcher (this will call startWatchingLogFile internally)
onStartCallback();
// At this point, the watcher should be marked as 'initializing'
// Now call stop before the stat promise resolves
onStopCallback();
// Now resolve the stat promise to complete initialization
statResolve({ size: 1000 });
// Wait for any async operations to complete
await new Promise((resolve) => setImmediate(resolve));
// The watcher should have been closed due to the race condition check
expect(mockWatcher.close).toHaveBeenCalled();
});
it('should not leak watcher if stopped multiple times during initialization', async () => {
// Setup: Register the subscription
service.registerLogFileSubscription('test.log');
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
const onStartCallback = registerTopicCall[1];
const onStopCallback = registerTopicCall[2];
// Create controlled stat promise
let statResolve: any;
const statPromise = new Promise((resolve) => {
statResolve = resolve;
});
vi.mocked(fs.stat).mockReturnValue(statPromise as any);
// Start the watcher
onStartCallback();
// Call stop multiple times during initialization
onStopCallback();
onStopCallback();
onStopCallback();
// Complete initialization
statResolve({ size: 1000 });
await new Promise((resolve) => setImmediate(resolve));
// Should only close once
expect(mockWatcher.close).toHaveBeenCalledTimes(1);
});
it('should properly handle normal start and stop without race condition', async () => {
// Setup: Register the subscription
service.registerLogFileSubscription('test.log');
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
const onStartCallback = registerTopicCall[1];
const onStopCallback = registerTopicCall[2];
// Make stat resolve immediately
vi.mocked(fs.stat).mockResolvedValue({ size: 1000 } as any);
// Start the watcher and let it complete initialization
onStartCallback();
await new Promise((resolve) => setImmediate(resolve));
// Watcher should be created but not closed
expect(chokidar.watch).toHaveBeenCalled();
expect(mockWatcher.close).not.toHaveBeenCalled();
// Now stop it normally
onStopCallback();
// Watcher should be closed
expect(mockWatcher.close).toHaveBeenCalledTimes(1);
});
it('should handle error during initialization without leaking watchers', async () => {
// Setup: Register the subscription
service.registerLogFileSubscription('test.log');
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
const onStartCallback = registerTopicCall[1];
// Make stat reject with an error
vi.mocked(fs.stat).mockRejectedValue(new Error('File not found'));
// Start the watcher (should fail during initialization)
onStartCallback();
await new Promise((resolve) => setImmediate(resolve));
// Watcher should never be created due to stat error
expect(chokidar.watch).not.toHaveBeenCalled();
expect(mockWatcher.close).not.toHaveBeenCalled();
});
it('should not create duplicate watchers when started multiple times', async () => {
// Setup: Register the subscription
service.registerLogFileSubscription('test.log');
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
const onStartCallback = registerTopicCall[1];
// Make stat resolve immediately
vi.mocked(fs.stat).mockResolvedValue({ size: 1000 } as any);
// Start the watcher multiple times
onStartCallback();
onStartCallback();
onStartCallback();
await new Promise((resolve) => setImmediate(resolve));
// Should only create one watcher
expect(chokidar.watch).toHaveBeenCalledTimes(1);
});
});

View File

@@ -1,13 +1,15 @@
import { Injectable, Logger } from '@nestjs/common'; import { Injectable, Logger } from '@nestjs/common';
import { createReadStream } from 'node:fs'; import { createReadStream } from 'node:fs';
import { readdir, readFile, stat } from 'node:fs/promises'; import { readdir, stat } from 'node:fs/promises';
import { basename, join } from 'node:path'; import { basename, join } from 'node:path';
import { createInterface } from 'node:readline'; import { createInterface } from 'node:readline';
import * as chokidar from 'chokidar'; import * as chokidar from 'chokidar';
import { pubsub, PUBSUB_CHANNEL } from '@app/core/pubsub.js'; import { pubsub } from '@app/core/pubsub.js';
import { getters } from '@app/store/index.js'; import { getters } from '@app/store/index.js';
import { LogWatcherManager } from '@app/unraid-api/graph/resolvers/logs/log-watcher-manager.service.js';
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
interface LogFile { interface LogFile {
name: string; name: string;
@@ -26,12 +28,13 @@ interface LogFileContent {
@Injectable() @Injectable()
export class LogsService { export class LogsService {
private readonly logger = new Logger(LogsService.name); private readonly logger = new Logger(LogsService.name);
private readonly logWatchers = new Map<
string,
{ watcher: chokidar.FSWatcher; position: number; subscriptionCount: number }
>();
private readonly DEFAULT_LINES = 100; private readonly DEFAULT_LINES = 100;
constructor(
private readonly subscriptionTracker: SubscriptionTrackerService,
private readonly watcherManager: LogWatcherManager
) {}
/** /**
* Get the base path for log files * Get the base path for log files
*/ */
@@ -111,135 +114,208 @@ export class LogsService {
} }
/** /**
* Get the subscription channel for a log file * Register and get the topic key for a log file subscription
* @param path Path to the log file * @param path Path to the log file
* @returns The subscription topic key
*/ */
getLogFileSubscriptionChannel(path: string): PUBSUB_CHANNEL { registerLogFileSubscription(path: string): string {
const normalizedPath = join(this.logBasePath, basename(path)); const normalizedPath = join(this.logBasePath, basename(path));
const topicKey = this.getTopicKey(normalizedPath);
// Start watching the file if not already watching // Register the topic if not already registered
if (!this.logWatchers.has(normalizedPath)) { if (!this.subscriptionTracker.getSubscriberCount(topicKey)) {
this.startWatchingLogFile(normalizedPath); this.logger.debug(`Registering log file subscription topic: ${topicKey}`);
} else {
// Increment subscription count for existing watcher this.subscriptionTracker.registerTopic(
const watcher = this.logWatchers.get(normalizedPath); topicKey,
if (watcher) { // onStart handler
watcher.subscriptionCount++; () => {
this.logger.debug( this.logger.debug(`Starting log file watcher for topic: ${topicKey}`);
`Incremented subscription count for ${normalizedPath} to ${watcher.subscriptionCount}` this.startWatchingLogFile(normalizedPath);
); },
} // onStop handler
() => {
this.logger.debug(`Stopping log file watcher for topic: ${topicKey}`);
this.stopWatchingLogFile(normalizedPath);
}
);
} }
return PUBSUB_CHANNEL.LOG_FILE; return topicKey;
} }
/** /**
* Start watching a log file for changes using chokidar * Start watching a log file for changes using chokidar
* @param path Path to the log file * @param path Path to the log file
*/ */
private async startWatchingLogFile(path: string): Promise<void> { private startWatchingLogFile(path: string): void {
try { const watcherKey = path;
// Get initial file size
const stats = await stat(path);
let position = stats.size;
// Create a watcher for the file using chokidar // Check if already watching or initializing
const watcher = chokidar.watch(path, { if (this.watcherManager.isWatchingOrInitializing(watcherKey)) {
persistent: true, this.logger.debug(`Already watching or initializing log file: ${watcherKey}`);
awaitWriteFinish: { return;
stabilityThreshold: 300, }
pollInterval: 100,
},
});
watcher.on('change', async () => { // Mark as initializing immediately to prevent race conditions
try { this.watcherManager.setInitializing(watcherKey);
const newStats = await stat(path);
// If the file has grown // Get initial file size and set up watcher
if (newStats.size > position) { stat(path)
// Read only the new content .then((stats) => {
const stream = createReadStream(path, { const position = stats.size;
start: position,
end: newStats.size - 1,
});
let newContent = ''; // Create a watcher for the file using chokidar
stream.on('data', (chunk) => { const watcher = chokidar.watch(path, {
newContent += chunk.toString(); persistent: true,
}); awaitWriteFinish: {
stabilityThreshold: 300,
pollInterval: 100,
},
});
stream.on('end', () => { watcher.on('change', async () => {
if (newContent) { // Check if we're already processing a change event for this file
pubsub.publish(PUBSUB_CHANNEL.LOG_FILE, { if (!this.watcherManager.startProcessing(watcherKey)) {
// Already processing, ignore this event
return;
}
try {
const newStats = await stat(path);
// Get the current position
const currentPosition = this.watcherManager.getPosition(watcherKey);
if (currentPosition === undefined) {
// Watcher was stopped or not active, ignore the event
return;
}
// If the file has grown
if (newStats.size > currentPosition) {
// Read only the new content
const stream = createReadStream(path, {
start: currentPosition,
end: newStats.size - 1,
});
let newContent = '';
stream.on('data', (chunk) => {
newContent += chunk.toString();
});
stream.on('end', () => {
try {
if (newContent) {
// Use topic-specific channel
const topicKey = this.getTopicKey(path);
pubsub.publish(topicKey, {
logFile: {
path,
content: newContent,
totalLines: 0, // We don't need to count lines for updates
},
});
}
// Update position for next read (while still holding the guard)
this.watcherManager.updatePosition(watcherKey, newStats.size);
} finally {
// Clear the in-flight flag
this.watcherManager.finishProcessing(watcherKey);
}
});
stream.on('error', (error) => {
this.logger.error(`Error reading stream for ${path}: ${error}`);
// Clear the in-flight flag on error
this.watcherManager.finishProcessing(watcherKey);
});
} else if (newStats.size < currentPosition) {
// File was truncated, reset position and read from beginning
this.logger.debug(`File ${path} was truncated, resetting position`);
try {
// Read the entire file content
const content = await this.getLogFileContent(
path,
this.DEFAULT_LINES,
undefined
);
// Use topic-specific channel
const topicKey = this.getTopicKey(path);
pubsub.publish(topicKey, {
logFile: { logFile: {
path, ...content,
content: newContent,
totalLines: 0, // We don't need to count lines for updates
}, },
}); });
// Update position (while still holding the guard)
this.watcherManager.updatePosition(watcherKey, newStats.size);
} finally {
// Clear the in-flight flag
this.watcherManager.finishProcessing(watcherKey);
} }
} else {
// Update position for next read // File size unchanged, clear the in-flight flag
position = newStats.size; this.watcherManager.finishProcessing(watcherKey);
}); }
} else if (newStats.size < position) { } catch (error: unknown) {
// File was truncated, reset position and read from beginning this.logger.error(`Error processing file change for ${path}: ${error}`);
position = 0; // Clear the in-flight flag on error
this.logger.debug(`File ${path} was truncated, resetting position`); this.watcherManager.finishProcessing(watcherKey);
// Read the entire file content
const content = await this.getLogFileContent(path);
pubsub.publish(PUBSUB_CHANNEL.LOG_FILE, {
logFile: content,
});
position = newStats.size;
} }
} catch (error: unknown) { });
this.logger.error(`Error processing file change for ${path}: ${error}`);
watcher.on('error', (error) => {
this.logger.error(`Chokidar watcher error for ${path}: ${error}`);
});
// Check if we were stopped during initialization and handle cleanup
if (!this.watcherManager.handlePostInitialization(watcherKey, watcher, position)) {
return;
} }
// Publish initial snapshot
this.getLogFileContent(path, this.DEFAULT_LINES, undefined)
.then((content) => {
const topicKey = this.getTopicKey(path);
pubsub.publish(topicKey, {
logFile: {
...content,
},
});
})
.catch((error) => {
this.logger.error(`Error publishing initial log content for ${path}: ${error}`);
});
this.logger.debug(`Started watching log file with chokidar: ${path}`);
})
.catch((error) => {
this.logger.error(`Error setting up file watcher for ${path}: ${error}`);
// Clean up the initializing entry on error
this.watcherManager.removeEntry(watcherKey);
}); });
}
watcher.on('error', (error) => { /**
this.logger.error(`Chokidar watcher error for ${path}: ${error}`); * Get the topic key for a log file subscription
}); * @param path Path to the log file (should already be normalized)
* @returns The topic key
// Store the watcher and current position with initial subscription count of 1 */
this.logWatchers.set(path, { watcher, position, subscriptionCount: 1 }); private getTopicKey(path: string): string {
// Assume path is already normalized (full path)
this.logger.debug( return `LOG_FILE:${path}`;
`Started watching log file with chokidar: ${path} (subscription count: 1)`
);
} catch (error: unknown) {
this.logger.error(`Error setting up chokidar file watcher for ${path}: ${error}`);
}
} }
/** /**
* Stop watching a log file * Stop watching a log file
* @param path Path to the log file * @param path Path to the log file
*/ */
public stopWatchingLogFile(path: string): void { private stopWatchingLogFile(path: string): void {
const normalizedPath = join(this.logBasePath, basename(path)); this.watcherManager.stopWatcher(path);
const watcher = this.logWatchers.get(normalizedPath);
if (watcher) {
// Decrement subscription count
watcher.subscriptionCount--;
this.logger.debug(
`Decremented subscription count for ${normalizedPath} to ${watcher.subscriptionCount}`
);
// Only close the watcher when subscription count reaches 0
if (watcher.subscriptionCount <= 0) {
watcher.watcher.close();
this.logWatchers.delete(normalizedPath);
this.logger.debug(`Stopped watching log file: ${normalizedPath} (no more subscribers)`);
}
}
} }
/** /**

View File

@@ -9,7 +9,7 @@ import { CpuService } from '@app/unraid-api/graph/resolvers/info/cpu/cpu.service
import { MemoryService } from '@app/unraid-api/graph/resolvers/info/memory/memory.service.js'; import { MemoryService } from '@app/unraid-api/graph/resolvers/info/memory/memory.service.js';
import { MetricsResolver } from '@app/unraid-api/graph/resolvers/metrics/metrics.resolver.js'; import { MetricsResolver } from '@app/unraid-api/graph/resolvers/metrics/metrics.resolver.js';
import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js'; import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js';
import { SubscriptionPollingService } from '@app/unraid-api/graph/services/subscription-polling.service.js'; import { SubscriptionManagerService } from '@app/unraid-api/graph/services/subscription-manager.service.js';
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js'; import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
describe('MetricsResolver Integration Tests', () => { describe('MetricsResolver Integration Tests', () => {
@@ -25,7 +25,7 @@ describe('MetricsResolver Integration Tests', () => {
MemoryService, MemoryService,
SubscriptionTrackerService, SubscriptionTrackerService,
SubscriptionHelperService, SubscriptionHelperService,
SubscriptionPollingService, SubscriptionManagerService,
], ],
}).compile(); }).compile();
@@ -36,8 +36,8 @@ describe('MetricsResolver Integration Tests', () => {
afterEach(async () => { afterEach(async () => {
// Clean up polling service // Clean up polling service
const pollingService = module.get<SubscriptionPollingService>(SubscriptionPollingService); const subscriptionManager = module.get<SubscriptionManagerService>(SubscriptionManagerService);
pollingService.stopAll(); subscriptionManager.stopAll();
await module.close(); await module.close();
}); });
@@ -202,10 +202,13 @@ describe('MetricsResolver Integration Tests', () => {
it('should handle errors in CPU polling gracefully', async () => { it('should handle errors in CPU polling gracefully', async () => {
const service = module.get<CpuService>(CpuService); const service = module.get<CpuService>(CpuService);
const trackerService = module.get<SubscriptionTrackerService>(SubscriptionTrackerService); const trackerService = module.get<SubscriptionTrackerService>(SubscriptionTrackerService);
const pollingService = module.get<SubscriptionPollingService>(SubscriptionPollingService); const subscriptionManager =
module.get<SubscriptionManagerService>(SubscriptionManagerService);
// Mock logger to capture error logs // Mock logger to capture error logs
const loggerSpy = vi.spyOn(pollingService['logger'], 'error').mockImplementation(() => {}); const loggerSpy = vi
.spyOn(subscriptionManager['logger'], 'error')
.mockImplementation(() => {});
vi.spyOn(service, 'generateCpuLoad').mockRejectedValueOnce(new Error('CPU error')); vi.spyOn(service, 'generateCpuLoad').mockRejectedValueOnce(new Error('CPU error'));
// Trigger polling // Trigger polling
@@ -215,7 +218,7 @@ describe('MetricsResolver Integration Tests', () => {
await new Promise((resolve) => setTimeout(resolve, 1100)); await new Promise((resolve) => setTimeout(resolve, 1100));
expect(loggerSpy).toHaveBeenCalledWith( expect(loggerSpy).toHaveBeenCalledWith(
expect.stringContaining('Error in polling task'), expect.stringContaining('Error in subscription callback'),
expect.any(Error) expect.any(Error)
); );
@@ -226,10 +229,13 @@ describe('MetricsResolver Integration Tests', () => {
it('should handle errors in memory polling gracefully', async () => { it('should handle errors in memory polling gracefully', async () => {
const service = module.get<MemoryService>(MemoryService); const service = module.get<MemoryService>(MemoryService);
const trackerService = module.get<SubscriptionTrackerService>(SubscriptionTrackerService); const trackerService = module.get<SubscriptionTrackerService>(SubscriptionTrackerService);
const pollingService = module.get<SubscriptionPollingService>(SubscriptionPollingService); const subscriptionManager =
module.get<SubscriptionManagerService>(SubscriptionManagerService);
// Mock logger to capture error logs // Mock logger to capture error logs
const loggerSpy = vi.spyOn(pollingService['logger'], 'error').mockImplementation(() => {}); const loggerSpy = vi
.spyOn(subscriptionManager['logger'], 'error')
.mockImplementation(() => {});
vi.spyOn(service, 'generateMemoryLoad').mockRejectedValueOnce(new Error('Memory error')); vi.spyOn(service, 'generateMemoryLoad').mockRejectedValueOnce(new Error('Memory error'));
// Trigger polling // Trigger polling
@@ -239,7 +245,7 @@ describe('MetricsResolver Integration Tests', () => {
await new Promise((resolve) => setTimeout(resolve, 2100)); await new Promise((resolve) => setTimeout(resolve, 2100));
expect(loggerSpy).toHaveBeenCalledWith( expect(loggerSpy).toHaveBeenCalledWith(
expect.stringContaining('Error in polling task'), expect.stringContaining('Error in subscription callback'),
expect.any(Error) expect.any(Error)
); );
@@ -251,22 +257,30 @@ describe('MetricsResolver Integration Tests', () => {
describe('Polling cleanup on module destroy', () => { describe('Polling cleanup on module destroy', () => {
it('should clean up timers when module is destroyed', async () => { it('should clean up timers when module is destroyed', async () => {
const trackerService = module.get<SubscriptionTrackerService>(SubscriptionTrackerService); const trackerService = module.get<SubscriptionTrackerService>(SubscriptionTrackerService);
const pollingService = module.get<SubscriptionPollingService>(SubscriptionPollingService); const subscriptionManager =
module.get<SubscriptionManagerService>(SubscriptionManagerService);
// Start polling // Start polling
trackerService.subscribe(PUBSUB_CHANNEL.CPU_UTILIZATION); trackerService.subscribe(PUBSUB_CHANNEL.CPU_UTILIZATION);
trackerService.subscribe(PUBSUB_CHANNEL.MEMORY_UTILIZATION); trackerService.subscribe(PUBSUB_CHANNEL.MEMORY_UTILIZATION);
// Verify polling is active // Wait a bit for subscriptions to be fully set up
expect(pollingService.isPolling(PUBSUB_CHANNEL.CPU_UTILIZATION)).toBe(true); await new Promise((resolve) => setTimeout(resolve, 100));
expect(pollingService.isPolling(PUBSUB_CHANNEL.MEMORY_UTILIZATION)).toBe(true);
// Verify subscriptions are active
expect(subscriptionManager.isSubscriptionActive(PUBSUB_CHANNEL.CPU_UTILIZATION)).toBe(true);
expect(subscriptionManager.isSubscriptionActive(PUBSUB_CHANNEL.MEMORY_UTILIZATION)).toBe(
true
);
// Clean up the module // Clean up the module
await module.close(); await module.close();
// Timers should be cleaned up // Subscriptions should be cleaned up
expect(pollingService.isPolling(PUBSUB_CHANNEL.CPU_UTILIZATION)).toBe(false); expect(subscriptionManager.isSubscriptionActive(PUBSUB_CHANNEL.CPU_UTILIZATION)).toBe(false);
expect(pollingService.isPolling(PUBSUB_CHANNEL.MEMORY_UTILIZATION)).toBe(false); expect(subscriptionManager.isSubscriptionActive(PUBSUB_CHANNEL.MEMORY_UTILIZATION)).toBe(
false
);
}); });
}); });
}); });

View File

@@ -1,6 +1,7 @@
import { Module } from '@nestjs/common'; import { Module } from '@nestjs/common';
import { AuthModule } from '@app/unraid-api/auth/auth.module.js'; import { AuthModule } from '@app/unraid-api/auth/auth.module.js';
import { ApiConfigModule } from '@app/unraid-api/config/api-config.module.js';
import { ApiKeyModule } from '@app/unraid-api/graph/resolvers/api-key/api-key.module.js'; import { ApiKeyModule } from '@app/unraid-api/graph/resolvers/api-key/api-key.module.js';
import { ApiKeyResolver } from '@app/unraid-api/graph/resolvers/api-key/api-key.resolver.js'; import { ApiKeyResolver } from '@app/unraid-api/graph/resolvers/api-key/api-key.resolver.js';
import { ArrayModule } from '@app/unraid-api/graph/resolvers/array/array.module.js'; import { ArrayModule } from '@app/unraid-api/graph/resolvers/array/array.module.js';
@@ -11,8 +12,7 @@ import { DockerModule } from '@app/unraid-api/graph/resolvers/docker/docker.modu
import { FlashBackupModule } from '@app/unraid-api/graph/resolvers/flash-backup/flash-backup.module.js'; import { FlashBackupModule } from '@app/unraid-api/graph/resolvers/flash-backup/flash-backup.module.js';
import { FlashResolver } from '@app/unraid-api/graph/resolvers/flash/flash.resolver.js'; import { FlashResolver } from '@app/unraid-api/graph/resolvers/flash/flash.resolver.js';
import { InfoModule } from '@app/unraid-api/graph/resolvers/info/info.module.js'; import { InfoModule } from '@app/unraid-api/graph/resolvers/info/info.module.js';
import { LogsResolver } from '@app/unraid-api/graph/resolvers/logs/logs.resolver.js'; import { LogsModule } from '@app/unraid-api/graph/resolvers/logs/logs.module.js';
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
import { MetricsModule } from '@app/unraid-api/graph/resolvers/metrics/metrics.module.js'; import { MetricsModule } from '@app/unraid-api/graph/resolvers/metrics/metrics.module.js';
import { RootMutationsResolver } from '@app/unraid-api/graph/resolvers/mutation/mutation.resolver.js'; import { RootMutationsResolver } from '@app/unraid-api/graph/resolvers/mutation/mutation.resolver.js';
import { NotificationsResolver } from '@app/unraid-api/graph/resolvers/notifications/notifications.resolver.js'; import { NotificationsResolver } from '@app/unraid-api/graph/resolvers/notifications/notifications.resolver.js';
@@ -39,12 +39,14 @@ import { MeResolver } from '@app/unraid-api/graph/user/user.resolver.js';
ServicesModule, ServicesModule,
ArrayModule, ArrayModule,
ApiKeyModule, ApiKeyModule,
ApiConfigModule,
AuthModule, AuthModule,
CustomizationModule, CustomizationModule,
DockerModule, DockerModule,
DisksModule, DisksModule,
FlashBackupModule, FlashBackupModule,
InfoModule, InfoModule,
LogsModule,
RCloneModule, RCloneModule,
SettingsModule, SettingsModule,
SsoModule, SsoModule,
@@ -54,8 +56,6 @@ import { MeResolver } from '@app/unraid-api/graph/user/user.resolver.js';
providers: [ providers: [
ConfigResolver, ConfigResolver,
FlashResolver, FlashResolver,
LogsResolver,
LogsService,
MeResolver, MeResolver,
NotificationsResolver, NotificationsResolver,
NotificationsService, NotificationsService,

View File

@@ -16,8 +16,8 @@ import {
} from '@app/unraid-api/graph/resolvers/settings/settings.model.js'; } from '@app/unraid-api/graph/resolvers/settings/settings.model.js';
import { ApiSettings } from '@app/unraid-api/graph/resolvers/settings/settings.service.js'; import { ApiSettings } from '@app/unraid-api/graph/resolvers/settings/settings.service.js';
import { SsoSettings } from '@app/unraid-api/graph/resolvers/settings/sso-settings.model.js'; import { SsoSettings } from '@app/unraid-api/graph/resolvers/settings/sso-settings.model.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js'; import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js'; import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
@Resolver(() => Settings) @Resolver(() => Settings)
export class SettingsResolver { export class SettingsResolver {

View File

@@ -7,7 +7,7 @@ import { type ApiConfig } from '@unraid/shared/services/api-config.js';
import { UserSettingsService } from '@unraid/shared/services/user-settings.js'; import { UserSettingsService } from '@unraid/shared/services/user-settings.js';
import { execa } from 'execa'; import { execa } from 'execa';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js'; import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { createLabeledControl } from '@app/unraid-api/graph/utils/form-utils.js'; import { createLabeledControl } from '@app/unraid-api/graph/utils/form-utils.js';
import { SettingSlice } from '@app/unraid-api/types/json-forms.js'; import { SettingSlice } from '@app/unraid-api/types/json-forms.js';

View File

@@ -0,0 +1,11 @@
import { Module } from '@nestjs/common';
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
import { OidcClaimsService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
@Module({
providers: [OidcAuthorizationService, OidcTokenExchangeService, OidcClaimsService],
exports: [OidcAuthorizationService, OidcTokenExchangeService, OidcClaimsService],
})
export class OidcAuthModule {}

View File

@@ -1,70 +1,26 @@
import { UnauthorizedException } from '@nestjs/common'; import { UnauthorizedException } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { Test, TestingModule } from '@nestjs/testing'; import { Test, TestingModule } from '@nestjs/testing';
import * as client from 'openid-client';
import { beforeEach, describe, expect, it, vi } from 'vitest'; import { beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcAuthService } from '@app/unraid-api/graph/resolvers/sso/oidc-auth.service.js'; import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
import { import {
AuthorizationOperator, AuthorizationOperator,
AuthorizationRuleMode, AuthorizationRuleMode,
OidcAuthorizationRule, OidcAuthorizationRule,
OidcProvider, OidcProvider,
} from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js'; } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/oidc-state.service.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/oidc-validation.service.js';
describe('OidcAuthService', () => { describe('OidcAuthorizationService', () => {
let service: OidcAuthService; let service: OidcAuthorizationService;
let oidcConfig: any;
let sessionService: any;
let configService: any;
let stateService: any;
let validationService: any;
let module: TestingModule; let module: TestingModule;
beforeEach(async () => { beforeEach(async () => {
module = await Test.createTestingModule({ module = await Test.createTestingModule({
providers: [ providers: [OidcAuthorizationService],
OidcAuthService,
{
provide: ConfigService,
useValue: {
get: vi.fn(),
},
},
{
provide: OidcConfigPersistence,
useValue: {
getProvider: vi.fn(),
},
},
{
provide: OidcSessionService,
useValue: {
createSession: vi.fn(),
},
},
OidcStateService,
{
provide: OidcValidationService,
useValue: {
validateProvider: vi.fn(),
performDiscovery: vi.fn(),
},
},
],
}).compile(); }).compile();
service = module.get<OidcAuthService>(OidcAuthService); service = module.get<OidcAuthorizationService>(OidcAuthorizationService);
oidcConfig = module.get(OidcConfigPersistence);
sessionService = module.get(OidcSessionService);
configService = module.get(ConfigService);
stateService = module.get(OidcStateService);
validationService = module.get<OidcValidationService>(OidcValidationService);
}); });
describe('Authorization Rule Evaluation', () => { describe('Authorization Rule Evaluation', () => {
@@ -1189,467 +1145,4 @@ describe('OidcAuthService', () => {
).resolves.toBeUndefined(); ).resolves.toBeUndefined();
}); });
}); });
describe('Manual Configuration (No Discovery)', () => {
it('should create manual configuration when discovery fails but manual endpoints are provided', async () => {
const provider: OidcProvider = {
id: 'manual-provider',
name: 'Manual Provider',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
issuer: 'https://manual.example.com',
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
jwksUri: 'https://manual.example.com/jwks',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// Mock discovery to fail
validationService.performDiscovery = vi
.fn()
.mockRejectedValue(new Error('Discovery failed'));
// Access the private method
const getOrCreateConfig = async (provider: OidcProvider) => {
return (service as any).getOrCreateConfig(provider);
};
const config = await getOrCreateConfig(provider);
// Verify the configuration was created with the correct endpoints
expect(config).toBeDefined();
expect(config.serverMetadata().authorization_endpoint).toBe(
'https://manual.example.com/auth'
);
expect(config.serverMetadata().token_endpoint).toBe('https://manual.example.com/token');
expect(config.serverMetadata().jwks_uri).toBe('https://manual.example.com/jwks');
expect(config.serverMetadata().issuer).toBe('https://manual.example.com');
});
it('should create manual configuration with fallback issuer when not provided', async () => {
const provider: OidcProvider = {
id: 'manual-provider-no-issuer',
name: 'Manual Provider No Issuer',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
issuer: '', // Empty issuer should skip discovery and use manual endpoints
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// No need to mock discovery since it won't be called with empty issuer
// Access the private method
const getOrCreateConfig = async (provider: OidcProvider) => {
return (service as any).getOrCreateConfig(provider);
};
const config = await getOrCreateConfig(provider);
// Verify the configuration was created with fallback issuer
expect(config).toBeDefined();
expect(config.serverMetadata().issuer).toBe('manual-manual-provider-no-issuer');
expect(config.serverMetadata().authorization_endpoint).toBe(
'https://manual.example.com/auth'
);
expect(config.serverMetadata().token_endpoint).toBe('https://manual.example.com/token');
});
it('should handle manual configuration with client secret properly', async () => {
const provider: OidcProvider = {
id: 'manual-with-secret',
name: 'Manual With Secret',
clientId: 'test-client-id',
clientSecret: 'secret-123',
issuer: 'https://manual.example.com',
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// Mock discovery to fail
validationService.performDiscovery = vi
.fn()
.mockRejectedValue(new Error('Discovery failed'));
// Access the private method
const getOrCreateConfig = async (provider: OidcProvider) => {
return (service as any).getOrCreateConfig(provider);
};
const config = await getOrCreateConfig(provider);
// Verify configuration was created successfully
expect(config).toBeDefined();
expect(config.clientMetadata().client_secret).toBe('secret-123');
});
it('should handle manual configuration without client secret (public client)', async () => {
const provider: OidcProvider = {
id: 'manual-public-client',
name: 'Manual Public Client',
clientId: 'public-client-id',
// No client secret
issuer: 'https://manual.example.com',
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// Mock discovery to fail
validationService.performDiscovery = vi
.fn()
.mockRejectedValue(new Error('Discovery failed'));
// Access the private method
const getOrCreateConfig = async (provider: OidcProvider) => {
return (service as any).getOrCreateConfig(provider);
};
const config = await getOrCreateConfig(provider);
// Verify configuration was created successfully for public client
expect(config).toBeDefined();
expect(config.clientMetadata().client_secret).toBeUndefined();
});
it('should throw error when discovery fails and no manual endpoints provided', async () => {
const provider: OidcProvider = {
id: 'no-manual-endpoints',
name: 'No Manual Endpoints',
clientId: 'test-client-id',
issuer: 'https://broken.example.com',
// Missing authorizationEndpoint and tokenEndpoint
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// Mock discovery to fail
validationService.performDiscovery = vi
.fn()
.mockRejectedValue(new Error('Discovery failed'));
// Access the private method
const getOrCreateConfig = async (provider: OidcProvider) => {
return (service as any).getOrCreateConfig(provider);
};
await expect(getOrCreateConfig(provider)).rejects.toThrow(UnauthorizedException);
});
it('should throw error when only authorization endpoint is provided', async () => {
const provider: OidcProvider = {
id: 'partial-manual-endpoints',
name: 'Partial Manual Endpoints',
clientId: 'test-client-id',
issuer: 'https://broken.example.com',
authorizationEndpoint: 'https://manual.example.com/auth',
// Missing tokenEndpoint
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// Mock discovery to fail
validationService.performDiscovery = vi
.fn()
.mockRejectedValue(new Error('Discovery failed'));
// Access the private method
const getOrCreateConfig = async (provider: OidcProvider) => {
return (service as any).getOrCreateConfig(provider);
};
await expect(getOrCreateConfig(provider)).rejects.toThrow(UnauthorizedException);
});
it('should cache manual configuration properly', async () => {
const provider: OidcProvider = {
id: 'cache-test',
name: 'Cache Test',
clientId: 'test-client-id',
clientSecret: 'test-secret',
issuer: 'https://manual.example.com',
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// Mock discovery to fail
validationService.performDiscovery = vi
.fn()
.mockRejectedValue(new Error('Discovery failed'));
// Access the private method
const getOrCreateConfig = async (provider: OidcProvider) => {
return (service as any).getOrCreateConfig(provider);
};
// First call should create configuration
const config1 = await getOrCreateConfig(provider);
// Second call should return cached configuration
const config2 = await getOrCreateConfig(provider);
expect(config1).toBe(config2); // Should be the exact same instance
expect(validationService.performDiscovery).toHaveBeenCalledTimes(1); // Only called once due to caching
});
it('should handle HTTP endpoints with allowInsecureRequests', async () => {
const provider: OidcProvider = {
id: 'http-endpoints',
name: 'HTTP Endpoints',
clientId: 'test-client-id',
clientSecret: 'test-secret',
issuer: 'http://manual.example.com', // HTTP instead of HTTPS
authorizationEndpoint: 'http://manual.example.com/auth',
tokenEndpoint: 'http://manual.example.com/token',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// Mock discovery to fail
validationService.performDiscovery = vi
.fn()
.mockRejectedValue(new Error('Discovery failed'));
// Access the private method
const getOrCreateConfig = async (provider: OidcProvider) => {
return (service as any).getOrCreateConfig(provider);
};
const config = await getOrCreateConfig(provider);
// Verify configuration was created successfully even with HTTP
expect(config).toBeDefined();
expect(config.serverMetadata().token_endpoint).toBe('http://manual.example.com/token');
expect(config.serverMetadata().authorization_endpoint).toBe(
'http://manual.example.com/auth'
);
});
});
describe('getAuthorizationUrl', () => {
it('should generate authorization URL with custom authorization endpoint', async () => {
const provider: OidcProvider = {
id: 'test-provider',
name: 'Test Provider',
clientId: 'test-client-id',
issuer: 'https://example.com',
authorizationEndpoint: 'https://custom.example.com/auth',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
const authUrl = await service.getAuthorizationUrl(
'test-provider',
'test-state',
'localhost:3001'
);
expect(authUrl).toContain('https://custom.example.com/auth');
expect(authUrl).toContain('client_id=test-client-id');
expect(authUrl).toContain('response_type=code');
expect(authUrl).toContain('scope=openid+profile');
// State should start with provider ID followed by secure state token
expect(authUrl).toMatch(/state=test-provider%3A[a-f0-9]+\.[0-9]+\.[a-f0-9]+/);
expect(authUrl).toContain('redirect_uri=');
});
it('should encode provider ID in state parameter', async () => {
const provider: OidcProvider = {
id: 'encode-test-provider',
name: 'Encode Test Provider',
clientId: 'test-client-id',
issuer: 'https://example.com',
authorizationEndpoint: 'https://example.com/auth',
scopes: ['openid', 'email'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
const authUrl = await service.getAuthorizationUrl('encode-test-provider', 'original-state');
// Verify that the state parameter includes provider ID at the start
expect(authUrl).toMatch(/state=encode-test-provider%3A[a-f0-9]+\.[0-9]+\.[a-f0-9]+/);
});
it('should throw error when provider not found', async () => {
oidcConfig.getProvider.mockResolvedValue(null);
await expect(
service.getAuthorizationUrl('nonexistent-provider', 'test-state')
).rejects.toThrow('Provider nonexistent-provider not found');
});
it('should handle custom scopes properly', async () => {
const provider: OidcProvider = {
id: 'custom-scopes-provider',
name: 'Custom Scopes Provider',
clientId: 'test-client-id',
issuer: 'https://example.com',
authorizationEndpoint: 'https://example.com/auth',
scopes: ['openid', 'profile', 'groups', 'custom:scope'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
const authUrl = await service.getAuthorizationUrl('custom-scopes-provider', 'test-state');
expect(authUrl).toContain('scope=openid+profile+groups+custom%3Ascope');
});
});
describe('handleCallback', () => {
it('should throw error when provider not found in callback', async () => {
oidcConfig.getProvider.mockResolvedValue(null);
await expect(
service.handleCallback('nonexistent-provider', 'code', 'redirect-uri')
).rejects.toThrow('Provider nonexistent-provider not found');
});
it('should handle malformed state parameter', async () => {
await expect(
service.handleCallback('invalid-state', 'code', 'redirect-uri')
).rejects.toThrow(UnauthorizedException);
});
it('should call getProvider with the provided provider ID', async () => {
const provider: OidcProvider = {
id: 'test-provider',
name: 'Test Provider',
clientId: 'test-client-id',
issuer: 'https://example.com',
scopes: ['openid'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// This will fail during token exchange, but we're testing the provider lookup logic
await expect(
service.handleCallback('test-provider', 'code', 'redirect-uri')
).rejects.toThrow(UnauthorizedException);
// Verify the provider was looked up with the correct ID
expect(oidcConfig.getProvider).toHaveBeenCalledWith('test-provider');
});
});
describe('validateProvider', () => {
it('should delegate to validation service and return result', async () => {
const provider: OidcProvider = {
id: 'validate-provider',
name: 'Validate Provider',
clientId: 'test-client-id',
issuer: 'https://example.com',
scopes: ['openid'],
authorizationRules: [],
};
const expectedResult = {
isValid: true,
authorizationEndpoint: 'https://example.com/auth',
tokenEndpoint: 'https://example.com/token',
};
validationService.validateProvider.mockResolvedValue(expectedResult);
const result = await service.validateProvider(provider);
expect(result).toEqual(expectedResult);
expect(validationService.validateProvider).toHaveBeenCalledWith(provider);
});
it('should clear config cache before validation', async () => {
const provider: OidcProvider = {
id: 'cache-clear-provider',
name: 'Cache Clear Provider',
clientId: 'test-client-id',
issuer: 'https://example.com',
scopes: ['openid'],
authorizationRules: [],
};
const expectedResult = {
isValid: false,
error: 'Validation failed',
};
validationService.validateProvider.mockResolvedValue(expectedResult);
const result = await service.validateProvider(provider);
expect(result).toEqual(expectedResult);
// Verify the cache was cleared by checking the method was called
expect(validationService.validateProvider).toHaveBeenCalledWith(provider);
});
});
describe('getRedirectUri (private method)', () => {
it('should generate correct redirect URI with localhost (development)', () => {
const getRedirectUri = (service as any).getRedirectUri.bind(service);
const redirectUri = getRedirectUri('http://localhost:3000');
expect(redirectUri).toBe('http://localhost:3000/graphql/api/auth/oidc/callback');
});
it('should generate correct redirect URI with non-localhost host', () => {
const getRedirectUri = (service as any).getRedirectUri.bind(service);
const redirectUri = getRedirectUri('https://example.com');
expect(redirectUri).toBe('https://example.com/graphql/api/auth/oidc/callback');
});
it('should handle HTTP protocol for non-localhost hosts', () => {
const getRedirectUri = (service as any).getRedirectUri.bind(service);
const redirectUri = getRedirectUri('http://tower.local');
expect(redirectUri).toBe('http://tower.local/graphql/api/auth/oidc/callback');
});
it('should handle non-standard ports correctly', () => {
const getRedirectUri = (service as any).getRedirectUri.bind(service);
const redirectUri = getRedirectUri('http://example.com:8080');
expect(redirectUri).toBe('http://example.com:8080/graphql/api/auth/oidc/callback');
});
it('should use default redirect URI when no request host provided', () => {
const getRedirectUri = (service as any).getRedirectUri.bind(service);
// Mock the ConfigService to return a default value
configService.get.mockReturnValue('http://tower.local');
const redirectUri = getRedirectUri();
expect(redirectUri).toBe('http://tower.local/graphql/api/auth/oidc/callback');
});
});
}); });

View File

@@ -0,0 +1,170 @@
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
import {
AuthorizationOperator,
AuthorizationRuleMode,
OidcAuthorizationRule,
OidcProvider,
} from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
interface JwtClaims {
sub?: string;
email?: string;
name?: string;
hd?: string; // Google hosted domain
[claim: string]: unknown;
}
@Injectable()
export class OidcAuthorizationService {
private readonly logger = new Logger(OidcAuthorizationService.name);
/**
* Check authorization based on rules
* This will throw a helpful error if misconfigured or unauthorized
*/
async checkAuthorization(provider: OidcProvider, claims: JwtClaims): Promise<void> {
this.logger.debug(
`Checking authorization for provider ${provider.id} with ${provider.authorizationRules?.length || 0} rules`
);
this.logger.debug(`Available claims: ${Object.keys(claims).join(', ')}`);
this.logger.debug(
`Authorization rule mode: ${provider.authorizationRuleMode || AuthorizationRuleMode.OR}`
);
// If no authorization rules are specified, throw a helpful error
if (!provider.authorizationRules || provider.authorizationRules.length === 0) {
throw new UnauthorizedException(
`Login failed: The ${provider.name} provider has no authorization rules configured. ` +
`Please configure authorization rules.`
);
}
this.logger.debug('Authorization rules to evaluate: %o', provider.authorizationRules);
// Evaluate the rules
const ruleMode = provider.authorizationRuleMode || AuthorizationRuleMode.OR;
const isAuthorized = this.evaluateAuthorizationRules(
provider.authorizationRules,
claims,
ruleMode
);
this.logger.debug(`Authorization result: ${isAuthorized}`);
if (!isAuthorized) {
// Log authorization failure with safe claim representation (no PII)
const availableClaimKeys = Object.keys(claims).join(', ');
this.logger.warn(
`Authorization failed for provider ${provider.name}, user ${claims.sub}, available claim keys: [${availableClaimKeys}]`
);
throw new UnauthorizedException(
`Access denied: Your account does not meet the authorization requirements for ${provider.name}.`
);
}
this.logger.debug(`Authorization successful for user ${claims.sub}`);
}
private evaluateAuthorizationRules(
rules: OidcAuthorizationRule[],
claims: JwtClaims,
mode: AuthorizationRuleMode = AuthorizationRuleMode.OR
): boolean {
// No rules means no authorization
if (rules.length === 0) {
return false;
}
if (mode === AuthorizationRuleMode.AND) {
// All rules must pass (AND logic)
return rules.every((rule) => this.evaluateRule(rule, claims));
} else {
// Any rule can pass (OR logic) - default behavior
// Multiple rules act as alternative authorization paths
return rules.some((rule) => this.evaluateRule(rule, claims));
}
}
private evaluateRule(rule: OidcAuthorizationRule, claims: JwtClaims): boolean {
const claimValue = claims[rule.claim];
this.logger.verbose(
`Evaluating rule for claim ${rule.claim}: { claimType: ${typeof claimValue}, isArray: ${Array.isArray(claimValue)}, ruleOperator: ${rule.operator}, ruleValuesCount: ${rule.value.length} }`
);
if (claimValue === undefined || claimValue === null) {
this.logger.verbose(`Claim ${rule.claim} not found in token`);
return false;
}
// Handle non-array, non-string objects
if (typeof claimValue === 'object' && claimValue !== null && !Array.isArray(claimValue)) {
this.logger.warn(
`unexpected JWT claim value encountered - claim ${rule.claim} has unsupported object type (keys: [${Object.keys(claimValue as Record<string, unknown>).join(', ')}])`
);
return false;
}
// Handle array claims - evaluate rule against each array element
if (Array.isArray(claimValue)) {
this.logger.verbose(
`Processing array claim ${rule.claim} with ${claimValue.length} elements`
);
// For array claims, check if ANY element in the array matches the rule
const arrayResult = claimValue.some((element) => {
// Skip non-string elements
if (
typeof element !== 'string' &&
typeof element !== 'number' &&
typeof element !== 'boolean'
) {
this.logger.verbose(`Skipping non-primitive element in array: ${typeof element}`);
return false;
}
const elementValue = String(element);
return this.evaluateSingleValue(elementValue, rule);
});
this.logger.verbose(`Array evaluation result for claim ${rule.claim}: ${arrayResult}`);
return arrayResult;
}
// Handle single value claims (string, number, boolean)
const value = String(claimValue);
this.logger.verbose(`Processing single value claim ${rule.claim}`);
return this.evaluateSingleValue(value, rule);
}
private evaluateSingleValue(value: string, rule: OidcAuthorizationRule): boolean {
let result: boolean;
switch (rule.operator) {
case AuthorizationOperator.EQUALS:
result = rule.value.some((v) => value === v);
this.logger.verbose(`EQUALS check: evaluated for claim ${rule.claim}: ${result}`);
return result;
case AuthorizationOperator.CONTAINS:
result = rule.value.some((v) => value.includes(v));
this.logger.verbose(`CONTAINS check: evaluated for claim ${rule.claim}: ${result}`);
return result;
case AuthorizationOperator.STARTS_WITH:
result = rule.value.some((v) => value.startsWith(v));
this.logger.verbose(`STARTS_WITH check: evaluated for claim ${rule.claim}: ${result}`);
return result;
case AuthorizationOperator.ENDS_WITH:
result = rule.value.some((v) => value.endsWith(v));
this.logger.verbose(`ENDS_WITH check: evaluated for claim ${rule.claim}: ${result}`);
return result;
default:
this.logger.error(`Unknown authorization operator: ${rule.operator}`);
return false;
}
}
}

View File

@@ -0,0 +1,218 @@
import { UnauthorizedException } from '@nestjs/common';
import { Test, TestingModule } from '@nestjs/testing';
import { decodeJwt } from 'jose';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import {
JwtClaims,
OidcClaimsService,
} from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
// Mock jose
vi.mock('jose', () => ({
decodeJwt: vi.fn(),
}));
describe('OidcClaimsService', () => {
let service: OidcClaimsService;
beforeEach(async () => {
vi.clearAllMocks();
const module: TestingModule = await Test.createTestingModule({
providers: [OidcClaimsService],
}).compile();
service = module.get<OidcClaimsService>(OidcClaimsService);
});
describe('parseIdToken', () => {
it('should parse valid ID token', () => {
const mockClaims: JwtClaims = {
sub: 'user123',
email: 'user@example.com',
name: 'Test User',
iat: 1234567890,
exp: 1234567890,
};
(decodeJwt as any).mockReturnValue(mockClaims);
const result = service.parseIdToken('valid.jwt.token');
expect(result).toEqual(mockClaims);
expect(decodeJwt).toHaveBeenCalledWith('valid.jwt.token');
});
it('should return null when no token provided', () => {
const result = service.parseIdToken(undefined);
expect(result).toBeNull();
});
it('should return null when token parsing fails', () => {
(decodeJwt as any).mockImplementation(() => {
throw new Error('Invalid token');
});
const result = service.parseIdToken('invalid.token');
expect(result).toBeNull();
});
it('should handle claims with array values', () => {
const mockClaims: JwtClaims = {
sub: 'user123',
groups: ['admin', 'user'],
roles: ['role1', 'role2', 'role3'],
};
(decodeJwt as any).mockReturnValue(mockClaims);
const result = service.parseIdToken('token.with.arrays');
expect(result).toEqual(mockClaims);
});
it('should log warning for complex object claims', () => {
const loggerSpy = vi.spyOn(service['logger'], 'warn');
const mockClaims: JwtClaims = {
sub: 'user123',
complexClaim: {
nested: 'value',
another: 'field',
},
};
(decodeJwt as any).mockReturnValue(mockClaims);
service.parseIdToken('token.with.complex');
expect(loggerSpy).toHaveBeenCalledWith(expect.stringContaining('complex object structure'));
});
it('should handle Google-specific claims', () => {
const mockClaims: JwtClaims = {
sub: 'google-user-id',
email: 'user@company.com',
name: 'Google User',
hd: 'company.com', // Google hosted domain
};
(decodeJwt as any).mockReturnValue(mockClaims);
const result = service.parseIdToken('google.jwt.token');
expect(result).toEqual(mockClaims);
expect(result?.hd).toBe('company.com');
});
});
describe('validateClaims', () => {
it('should return user sub when claims are valid', () => {
const claims: JwtClaims = {
sub: 'user123',
email: 'user@example.com',
};
const result = service.validateClaims(claims);
expect(result).toBe('user123');
});
it('should throw UnauthorizedException when claims are null', () => {
expect(() => service.validateClaims(null)).toThrow(UnauthorizedException);
});
it('should throw UnauthorizedException when sub is missing', () => {
const claims: JwtClaims = {
email: 'user@example.com',
name: 'User',
};
expect(() => service.validateClaims(claims)).toThrow(UnauthorizedException);
});
it('should throw UnauthorizedException when sub is empty', () => {
const claims: JwtClaims = {
sub: '',
email: 'user@example.com',
};
expect(() => service.validateClaims(claims)).toThrow(UnauthorizedException);
});
});
describe('extractUserInfo', () => {
it('should extract basic user information', () => {
const claims: JwtClaims = {
sub: 'user123',
email: 'user@example.com',
name: 'Test User',
};
const result = service.extractUserInfo(claims);
expect(result).toEqual({
sub: 'user123',
email: 'user@example.com',
name: 'Test User',
domain: undefined,
});
});
it('should extract Google hosted domain', () => {
const claims: JwtClaims = {
sub: 'google-user',
email: 'user@company.com',
name: 'Google User',
hd: 'company.com',
};
const result = service.extractUserInfo(claims);
expect(result).toEqual({
sub: 'google-user',
email: 'user@company.com',
name: 'Google User',
domain: 'company.com',
});
});
it('should handle missing optional fields', () => {
const claims: JwtClaims = {
sub: 'user123',
};
const result = service.extractUserInfo(claims);
expect(result).toEqual({
sub: 'user123',
email: undefined,
name: undefined,
domain: undefined,
});
});
it('should ignore extra claims', () => {
const claims: JwtClaims = {
sub: 'user123',
email: 'user@example.com',
name: 'Test User',
extra: 'claim',
another: 'field',
groups: ['admin'],
};
const result = service.extractUserInfo(claims);
expect(result).toEqual({
sub: 'user123',
email: 'user@example.com',
name: 'Test User',
domain: undefined,
});
expect(result).not.toHaveProperty('extra');
expect(result).not.toHaveProperty('groups');
});
});
});

View File

@@ -0,0 +1,80 @@
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
import { decodeJwt } from 'jose';
export interface JwtClaims {
sub?: string;
email?: string;
name?: string;
hd?: string; // Google hosted domain
[claim: string]: unknown;
}
@Injectable()
export class OidcClaimsService {
private readonly logger = new Logger(OidcClaimsService.name);
parseIdToken(idToken: string | undefined): JwtClaims | null {
if (!idToken) {
this.logger.error('No ID token received from provider');
return null;
}
try {
// Use jose to properly decode the JWT
const claims = decodeJwt(idToken) as JwtClaims;
// Log claims safely without PII - only structure, not values
if (claims) {
const claimKeys = Object.keys(claims).join(', ');
this.logger.debug(`ID token decoded successfully. Available claims: [${claimKeys}]`);
// Log claim types without exposing sensitive values
for (const [key, value] of Object.entries(claims)) {
const valueType = Array.isArray(value) ? `array[${value.length}]` : typeof value;
// Only log structure, not actual values (avoid PII)
this.logger.debug(`Claim '${key}': type=${valueType}`);
// Check for unexpected claim types
if (valueType === 'object' && value !== null && !Array.isArray(value)) {
this.logger.warn(`Claim '${key}' contains complex object structure`);
}
}
}
return claims;
} catch (e) {
this.logger.warn(`Failed to parse ID token: ${e}`);
return null;
}
}
validateClaims(claims: JwtClaims | null): string {
if (!claims?.sub) {
this.logger.error(
'No subject in token - claims available: ' +
(claims ? Object.keys(claims).join(', ') : 'none')
);
throw new UnauthorizedException('No subject in token');
}
const userSub = claims.sub;
this.logger.debug(`Processing authentication for user: ${userSub}`);
return userSub;
}
extractUserInfo(claims: JwtClaims): {
sub: string;
email?: string;
name?: string;
domain?: string;
} {
return {
sub: claims.sub!,
email: claims.email,
name: claims.name,
domain: claims.hd,
};
}
}

View File

@@ -0,0 +1,224 @@
import { Logger } from '@nestjs/common';
import * as client from 'openid-client';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
vi.mock('openid-client', () => ({
authorizationCodeGrant: vi.fn(),
allowInsecureRequests: vi.fn(),
}));
describe('OidcTokenExchangeService', () => {
let service: OidcTokenExchangeService;
let mockConfig: client.Configuration;
let mockProvider: OidcProvider;
beforeEach(() => {
service = new OidcTokenExchangeService();
mockConfig = {
serverMetadata: vi.fn().mockReturnValue({
issuer: 'https://example.com',
token_endpoint: 'https://example.com/token',
response_types_supported: ['code'],
grant_types_supported: ['authorization_code'],
token_endpoint_auth_methods_supported: ['client_secret_post'],
}),
} as unknown as client.Configuration;
mockProvider = {
id: 'test-provider',
issuer: 'https://example.com',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
} as OidcProvider;
vi.clearAllMocks();
});
describe('exchangeCodeForTokens', () => {
it('should handle malformed fullCallbackUrl gracefully', async () => {
const code = 'test-code';
const state = 'test-state';
const redirectUri = 'https://example.com/callback';
const malformedUrl = 'not://a valid url';
const mockTokens = {
access_token: 'test-access-token',
id_token: 'test-id-token',
};
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
const loggerDebugSpy = vi.spyOn(Logger.prototype, 'debug').mockImplementation(() => {});
const result = await service.exchangeCodeForTokens(
mockConfig,
mockProvider,
code,
state,
redirectUri,
malformedUrl
);
expect(result).toEqual(mockTokens);
expect(loggerWarnSpy).toHaveBeenCalledWith(
expect.stringContaining('Failed to parse fullCallbackUrl'),
expect.any(Error)
);
expect(client.authorizationCodeGrant).toHaveBeenCalled();
});
it('should handle empty fullCallbackUrl without throwing', async () => {
const code = 'test-code';
const state = 'test-state';
const redirectUri = 'https://example.com/callback';
const mockTokens = {
access_token: 'test-access-token',
id_token: 'test-id-token',
};
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
const result = await service.exchangeCodeForTokens(
mockConfig,
mockProvider,
code,
state,
redirectUri,
''
);
expect(result).toEqual(mockTokens);
expect(loggerWarnSpy).not.toHaveBeenCalled();
expect(client.authorizationCodeGrant).toHaveBeenCalled();
});
it('should handle whitespace-only fullCallbackUrl without throwing', async () => {
const code = 'test-code';
const state = 'test-state';
const redirectUri = 'https://example.com/callback';
const mockTokens = {
access_token: 'test-access-token',
id_token: 'test-id-token',
};
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
const result = await service.exchangeCodeForTokens(
mockConfig,
mockProvider,
code,
state,
redirectUri,
' '
);
expect(result).toEqual(mockTokens);
expect(loggerWarnSpy).not.toHaveBeenCalled();
expect(client.authorizationCodeGrant).toHaveBeenCalled();
});
it('should copy parameters from valid fullCallbackUrl', async () => {
const code = 'test-code';
const state = 'test-state';
const redirectUri = 'https://example.com/callback';
const fullCallbackUrl =
'https://example.com/callback?code=test-code&state=test-state&scope=openid&authuser=0';
const mockTokens = {
access_token: 'test-access-token',
id_token: 'test-id-token',
};
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
const loggerDebugSpy = vi.spyOn(Logger.prototype, 'debug').mockImplementation(() => {});
const result = await service.exchangeCodeForTokens(
mockConfig,
mockProvider,
code,
state,
redirectUri,
fullCallbackUrl
);
expect(result).toEqual(mockTokens);
expect(loggerWarnSpy).not.toHaveBeenCalled();
const authCodeGrantCall = vi.mocked(client.authorizationCodeGrant).mock.calls[0];
const cleanUrl = authCodeGrantCall[1] as URL;
expect(cleanUrl.searchParams.get('scope')).toBe('openid');
expect(cleanUrl.searchParams.get('authuser')).toBe('0');
});
it('should handle undefined fullCallbackUrl', async () => {
const code = 'test-code';
const state = 'test-state';
const redirectUri = 'https://example.com/callback';
const mockTokens = {
access_token: 'test-access-token',
id_token: 'test-id-token',
};
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
const result = await service.exchangeCodeForTokens(
mockConfig,
mockProvider,
code,
state,
redirectUri,
undefined
);
expect(result).toEqual(mockTokens);
expect(loggerWarnSpy).not.toHaveBeenCalled();
expect(client.authorizationCodeGrant).toHaveBeenCalled();
});
it('should handle non-string fullCallbackUrl types gracefully', async () => {
const code = 'test-code';
const state = 'test-state';
const redirectUri = 'https://example.com/callback';
const mockTokens = {
access_token: 'test-access-token',
id_token: 'test-id-token',
};
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
const result = await service.exchangeCodeForTokens(
mockConfig,
mockProvider,
code,
state,
redirectUri,
123 as any
);
expect(result).toEqual(mockTokens);
expect(loggerWarnSpy).not.toHaveBeenCalled();
expect(client.authorizationCodeGrant).toHaveBeenCalled();
});
});
});

View File

@@ -0,0 +1,174 @@
import { Injectable, Logger } from '@nestjs/common';
import * as client from 'openid-client';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
import { ErrorExtractor } from '@app/unraid-api/utils/error-extractor.util.js';
// Extended type for our internal use - openid-client v6 doesn't directly expose
// skip options for aud/iss checks, so we'll handle validation errors differently
type ExtendedGrantChecks = client.AuthorizationCodeGrantChecks;
@Injectable()
export class OidcTokenExchangeService {
private readonly logger = new Logger(OidcTokenExchangeService.name);
async exchangeCodeForTokens(
config: client.Configuration,
provider: OidcProvider,
code: string,
state: string,
redirectUri: string,
fullCallbackUrl?: string
): Promise<client.TokenEndpointResponse> {
this.logger.debug(`Provider ${provider.id} config loaded`);
this.logger.debug(`Redirect URI: ${redirectUri}`);
// Build current URL for token exchange
// CRITICAL: The URL used here MUST match the redirect_uri that was sent to the authorization endpoint
// Google expects the exact same redirect_uri during token exchange
const currentUrl = new URL(redirectUri);
currentUrl.searchParams.set('code', code);
currentUrl.searchParams.set('state', state);
// Copy additional parameters from the actual callback if provided
if (fullCallbackUrl && typeof fullCallbackUrl === 'string' && fullCallbackUrl.trim()) {
try {
const actualUrl = new URL(fullCallbackUrl);
// Copy over additional params that Google might have added (scope, authuser, prompt, etc)
// but DO NOT change the base URL or path
['scope', 'authuser', 'prompt', 'hd', 'session_state', 'iss'].forEach((param) => {
const value = actualUrl.searchParams.get(param);
if (value && !currentUrl.searchParams.has(param)) {
currentUrl.searchParams.set(param, value);
}
});
} catch (urlError) {
this.logger.warn(`Failed to parse fullCallbackUrl: ${fullCallbackUrl}`, urlError);
// Continue with the existing currentUrl flow without additional params
}
}
// Google returns iss in the response, openid-client v6 expects it
// If not present, add it based on the provider's issuer
if (!currentUrl.searchParams.has('iss') && provider.issuer) {
currentUrl.searchParams.set('iss', provider.issuer);
}
this.logger.debug(`Token exchange URL (matches redirect_uri): ${currentUrl.href}`);
// For openid-client v6, we need to prepare the authorization response
const authorizationResponse = new URLSearchParams(currentUrl.search);
// Set the original client state for openid-client
authorizationResponse.set('state', state);
// Create a new URL with the cleaned parameters
const cleanUrl = new URL(redirectUri);
cleanUrl.search = authorizationResponse.toString();
this.logger.debug(`Clean URL for token exchange: ${cleanUrl.href}`);
try {
this.logger.debug(`Starting token exchange with openid-client`);
this.logger.debug(`Config issuer: ${config.serverMetadata().issuer}`);
this.logger.debug(`Config token endpoint: ${config.serverMetadata().token_endpoint}`);
// Log the complete token exchange request details
const tokenEndpoint = config.serverMetadata().token_endpoint;
this.logger.debug(`Full token endpoint URL: ${tokenEndpoint}`);
this.logger.debug(`Authorization code: ${code.substring(0, 10)}...`);
this.logger.debug(`Redirect URI in token request: ${redirectUri}`);
this.logger.debug(`Client ID: ${provider.clientId}`);
this.logger.debug(`Client secret configured: ${provider.clientSecret ? 'Yes' : 'No'}`);
this.logger.debug(`Expected state value: ${state}`);
// Log the server metadata to check for any configuration issues
const metadata = config.serverMetadata();
this.logger.debug(
`Server supports response types: ${metadata.response_types_supported?.join(', ') || 'not specified'}`
);
this.logger.debug(
`Server grant types: ${metadata.grant_types_supported?.join(', ') || 'not specified'}`
);
this.logger.debug(
`Token endpoint auth methods: ${metadata.token_endpoint_auth_methods_supported?.join(', ') || 'not specified'}`
);
// For HTTP endpoints, we need to call allowInsecureRequests on the config
if (provider.issuer) {
try {
const serverUrl = new URL(provider.issuer);
if (serverUrl.protocol === 'http:') {
this.logger.debug(
`Allowing insecure requests for HTTP endpoint: ${provider.id}`
);
// allowInsecureRequests is deprecated but still needed for HTTP endpoints
client.allowInsecureRequests(config);
}
} catch (error) {
this.logger.warn(
`Invalid issuer URL for provider ${provider.id}: ${provider.issuer}`
);
// Continue without special HTTP options
}
}
const requestChecks: ExtendedGrantChecks = {
expectedState: state,
};
// Log what we're about to send
this.logger.debug(`Executing authorizationCodeGrant with:`);
this.logger.debug(`- Clean URL: ${cleanUrl.href}`);
this.logger.debug(`- Expected state: ${state}`);
this.logger.debug(`- Grant type: authorization_code`);
const tokens = await client.authorizationCodeGrant(config, cleanUrl, requestChecks);
this.logger.debug(
`Token exchange successful, received tokens: ${Object.keys(tokens).join(', ')}`
);
return tokens;
} catch (tokenError) {
// Extract and log error details using the utility
const extracted = ErrorExtractor.extract(tokenError);
this.logger.error('Token exchange failed');
ErrorExtractor.formatForLogging(extracted, this.logger);
// Special handling for content-type and parsing errors
if (ErrorExtractor.isOAuthResponseError(extracted)) {
this.logger.error('Token endpoint returned invalid or non-JSON response.');
this.logger.error('This typically means:');
this.logger.error(
'1. The token endpoint URL is incorrect (check for typos or wrong paths)'
);
this.logger.error('2. The server returned an HTML error page instead of JSON');
this.logger.error('3. Authentication failed (invalid client_id or client_secret)');
this.logger.error('4. A proxy/firewall is intercepting the request');
this.logger.error('5. The OAuth server returned malformed JSON');
this.logger.error(
`Configured token endpoint: ${config.serverMetadata().token_endpoint}`
);
this.logger.error('Please verify your OIDC provider configuration.');
}
// Check if error message contains the "unexpected JWT claim" text
if (ErrorExtractor.isJwtClaimError(extracted)) {
this.logger.error(
`unexpected JWT claim value encountered during token validation by openid-client`
);
this.logger.error(
`This error typically means the 'iss' claim in the JWT doesn't match the expected issuer`
);
this.logger.error(`Check that your provider's issuer URL is configured correctly`);
this.logger.error(`Expected issuer: ${config.serverMetadata().issuer}`);
this.logger.error(`Provider configured issuer: ${provider.issuer}`);
}
// Re-throw the original error with all its properties intact
throw tokenError;
}
}
}

View File

@@ -0,0 +1,267 @@
import { Test, TestingModule } from '@nestjs/testing';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
describe('OidcClientConfigService', () => {
let service: OidcClientConfigService;
let validationService: any;
beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
providers: [
OidcClientConfigService,
{
provide: OidcValidationService,
useValue: {
performDiscovery: vi.fn(),
},
},
],
}).compile();
service = module.get<OidcClientConfigService>(OidcClientConfigService);
validationService = module.get(OidcValidationService);
});
describe('Manual Configuration', () => {
it('should create manual configuration when discovery fails but manual endpoints are provided', async () => {
const provider: OidcProvider = {
id: 'manual-provider',
name: 'Manual Provider',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
issuer: 'https://manual.example.com',
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
jwksUri: 'https://manual.example.com/jwks',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
// Mock discovery to fail
validationService.performDiscovery.mockRejectedValue(new Error('Discovery failed'));
const config = await service.getOrCreateConfig(provider);
// Verify the configuration was created with the correct endpoints
expect(config).toBeDefined();
expect(config.serverMetadata().authorization_endpoint).toBe(
'https://manual.example.com/auth'
);
expect(config.serverMetadata().token_endpoint).toBe('https://manual.example.com/token');
expect(config.serverMetadata().jwks_uri).toBe('https://manual.example.com/jwks');
expect(config.serverMetadata().issuer).toBe('https://manual.example.com');
});
it('should create manual configuration with fallback issuer when not provided', async () => {
const provider: OidcProvider = {
id: 'manual-provider-no-issuer',
name: 'Manual Provider No Issuer',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
issuer: '', // Empty issuer should skip discovery and use manual endpoints
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
const config = await service.getOrCreateConfig(provider);
// Verify the configuration was created with inferred issuer from endpoints
expect(config).toBeDefined();
expect(config.serverMetadata().issuer).toBe('https://manual.example.com');
expect(config.serverMetadata().authorization_endpoint).toBe(
'https://manual.example.com/auth'
);
expect(config.serverMetadata().token_endpoint).toBe('https://manual.example.com/token');
});
it('should handle manual configuration with client secret properly', async () => {
const provider: OidcProvider = {
id: 'manual-with-secret',
name: 'Manual With Secret',
clientId: 'test-client-id',
clientSecret: 'secret-123',
issuer: 'https://manual.example.com',
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
// Mock discovery to fail
validationService.performDiscovery.mockRejectedValue(new Error('Discovery failed'));
const config = await service.getOrCreateConfig(provider);
// Verify configuration was created successfully
expect(config).toBeDefined();
expect(config.clientMetadata().client_secret).toBe('secret-123');
});
it('should handle manual configuration without client secret (public client)', async () => {
const provider: OidcProvider = {
id: 'manual-public-client',
name: 'Manual Public Client',
clientId: 'public-client-id',
// No client secret
issuer: 'https://manual.example.com',
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
// Mock discovery to fail
validationService.performDiscovery.mockRejectedValue(new Error('Discovery failed'));
const config = await service.getOrCreateConfig(provider);
// Verify configuration was created successfully
expect(config).toBeDefined();
expect(config.clientMetadata().client_secret).toBeUndefined();
});
it('should cache configurations', async () => {
const provider: OidcProvider = {
id: 'cached-provider',
name: 'Cached Provider',
clientId: 'test-client-id',
issuer: '',
authorizationEndpoint: 'https://cached.example.com/auth',
tokenEndpoint: 'https://cached.example.com/token',
scopes: ['openid'],
authorizationRules: [],
};
// First call
const config1 = await service.getOrCreateConfig(provider);
// Second call - should return cached value
const config2 = await service.getOrCreateConfig(provider);
// Should be the exact same object
expect(config1).toBe(config2);
expect(service.getCacheSize()).toBe(1);
});
it('should clear cache for specific provider', async () => {
const provider: OidcProvider = {
id: 'provider-to-clear',
name: 'Provider to Clear',
clientId: 'test-client-id',
issuer: '',
authorizationEndpoint: 'https://clear.example.com/auth',
tokenEndpoint: 'https://clear.example.com/token',
scopes: ['openid'],
authorizationRules: [],
};
await service.getOrCreateConfig(provider);
expect(service.getCacheSize()).toBe(1);
service.clearCache('provider-to-clear');
expect(service.getCacheSize()).toBe(0);
});
it('should clear entire cache', async () => {
const provider1: OidcProvider = {
id: 'provider1',
name: 'Provider 1',
clientId: 'client1',
issuer: '',
authorizationEndpoint: 'https://p1.example.com/auth',
tokenEndpoint: 'https://p1.example.com/token',
scopes: ['openid'],
authorizationRules: [],
};
const provider2: OidcProvider = {
id: 'provider2',
name: 'Provider 2',
clientId: 'client2',
issuer: '',
authorizationEndpoint: 'https://p2.example.com/auth',
tokenEndpoint: 'https://p2.example.com/token',
scopes: ['openid'],
authorizationRules: [],
};
await service.getOrCreateConfig(provider1);
await service.getOrCreateConfig(provider2);
expect(service.getCacheSize()).toBe(2);
service.clearCache();
expect(service.getCacheSize()).toBe(0);
});
});
describe('Discovery Configuration', () => {
it('should use discovery when issuer is provided', async () => {
const provider: OidcProvider = {
id: 'discovery-provider',
name: 'Discovery Provider',
clientId: 'test-client-id',
clientSecret: 'test-secret',
issuer: 'https://discovery.example.com',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
const mockConfig = {
serverMetadata: vi.fn().mockReturnValue({
issuer: 'https://discovery.example.com',
authorization_endpoint: 'https://discovery.example.com/authorize',
token_endpoint: 'https://discovery.example.com/token',
jwks_uri: 'https://discovery.example.com/.well-known/jwks.json',
userinfo_endpoint: 'https://discovery.example.com/userinfo',
}),
clientMetadata: vi.fn().mockReturnValue({}),
};
validationService.performDiscovery.mockResolvedValue(mockConfig);
const config = await service.getOrCreateConfig(provider);
expect(validationService.performDiscovery).toHaveBeenCalledWith(provider, undefined);
expect(config).toBe(mockConfig);
});
it('should allow HTTP for discovery when issuer uses HTTP', async () => {
const provider: OidcProvider = {
id: 'http-discovery-provider',
name: 'HTTP Discovery Provider',
clientId: 'test-client-id',
issuer: 'http://discovery.example.com',
scopes: ['openid'],
authorizationRules: [],
};
const mockConfig = {
serverMetadata: vi.fn().mockReturnValue({
issuer: 'http://discovery.example.com',
authorization_endpoint: 'http://discovery.example.com/authorize',
token_endpoint: 'http://discovery.example.com/token',
}),
clientMetadata: vi.fn().mockReturnValue({}),
};
validationService.performDiscovery.mockResolvedValue(mockConfig);
const config = await service.getOrCreateConfig(provider);
expect(validationService.performDiscovery).toHaveBeenCalledWith(
provider,
expect.objectContaining({
execute: expect.any(Array),
})
);
expect(config).toBe(mockConfig);
});
});
});

View File

@@ -0,0 +1,168 @@
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
import * as client from 'openid-client';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
import { ErrorExtractor } from '@app/unraid-api/utils/error-extractor.util.js';
@Injectable()
export class OidcClientConfigService {
private readonly logger = new Logger(OidcClientConfigService.name);
private readonly configCache = new Map<string, client.Configuration>();
constructor(private readonly validationService: OidcValidationService) {}
async getOrCreateConfig(provider: OidcProvider): Promise<client.Configuration> {
const cacheKey = provider.id;
if (this.configCache.has(cacheKey)) {
return this.configCache.get(cacheKey)!;
}
try {
// Use the validation service to perform discovery with HTTP support
if (provider.issuer) {
this.logger.debug(`Attempting discovery for ${provider.id} at ${provider.issuer}`);
// Create client options with HTTP support if needed
const serverUrl = new URL(provider.issuer);
let clientOptions: client.DiscoveryRequestOptions | undefined;
if (serverUrl.protocol === 'http:') {
this.logger.debug(`Allowing HTTP for ${provider.id} as specified by user`);
clientOptions = {
execute: [client.allowInsecureRequests],
};
}
try {
const config = await this.validationService.performDiscovery(
provider,
clientOptions
);
this.logger.debug(`Discovery successful for ${provider.id}`);
this.logger.debug(
`Authorization endpoint: ${config.serverMetadata().authorization_endpoint}`
);
this.logger.debug(`Token endpoint: ${config.serverMetadata().token_endpoint}`);
this.logger.debug(`JWKS URI: ${config.serverMetadata().jwks_uri || 'Not provided'}`);
this.logger.debug(
`Userinfo endpoint: ${config.serverMetadata().userinfo_endpoint || 'Not provided'}`
);
this.configCache.set(cacheKey, config);
return config;
} catch (discoveryError) {
const extracted = ErrorExtractor.extract(discoveryError);
this.logger.warn(`Discovery failed for ${provider.id}: ${extracted.message}`);
// Log more details about the discovery error
const discoveryUrl = `${provider.issuer}/.well-known/openid-configuration`;
this.logger.debug(`Discovery URL attempted: ${discoveryUrl}`);
// Use error extractor for consistent logging
ErrorExtractor.formatForLogging(extracted, this.logger);
// If discovery fails but we have manual endpoints, use them
if (provider.authorizationEndpoint && provider.tokenEndpoint) {
this.logger.log(`Using manual endpoints for ${provider.id}`);
return this.createManualConfiguration(provider, cacheKey);
} else {
throw new Error(
`OIDC discovery failed and no manual endpoints provided for ${provider.id}`
);
}
}
}
// Manual configuration when no issuer is provided
if (provider.authorizationEndpoint && provider.tokenEndpoint) {
this.logger.log(`Using manual endpoints for ${provider.id} (no issuer provided)`);
return this.createManualConfiguration(provider, cacheKey);
}
// If we reach here, neither discovery nor manual endpoints are available
throw new Error(
`No configuration method available for ${provider.id}: requires either valid issuer for discovery or manual endpoints`
);
} catch (error) {
const extracted = ErrorExtractor.extract(error);
this.logger.error(
`Failed to create OIDC configuration for ${provider.id}: ${extracted.message}`
);
// Log more details in debug mode
if (extracted.stack) {
this.logger.debug(`Stack trace: ${extracted.stack}`);
}
throw new UnauthorizedException('Provider configuration error');
}
}
private createManualConfiguration(provider: OidcProvider, cacheKey: string): client.Configuration {
// Create manual configuration with a valid issuer URL
const inferredIssuer =
provider.issuer && provider.issuer.trim() !== ''
? provider.issuer
: new URL(provider.authorizationEndpoint ?? provider.tokenEndpoint!).origin;
const serverMetadata: client.ServerMetadata = {
issuer: inferredIssuer,
authorization_endpoint: provider.authorizationEndpoint!,
token_endpoint: provider.tokenEndpoint!,
jwks_uri: provider.jwksUri,
};
const clientMetadata: Partial<client.ClientMetadata> = {
client_secret: provider.clientSecret,
};
// Configure client auth method
const clientAuth = provider.clientSecret
? client.ClientSecretPost(provider.clientSecret)
: client.None();
try {
const config = new client.Configuration(
serverMetadata,
provider.clientId,
clientMetadata,
clientAuth
);
// Allow HTTP if any configured endpoint uses http
const endpoints = [
serverMetadata.authorization_endpoint,
serverMetadata.token_endpoint,
].filter(Boolean) as string[];
const hasHttp = endpoints.some((e) => new URL(e).protocol === 'http:');
if (hasHttp) {
this.logger.debug(`Allowing HTTP for manual endpoints on ${provider.id}`);
// allowInsecureRequests is deprecated but still needed for HTTP endpoints
client.allowInsecureRequests(config);
}
this.logger.debug(`Manual configuration created for ${provider.id}`);
this.logger.debug(`Authorization endpoint: ${serverMetadata.authorization_endpoint}`);
this.logger.debug(`Token endpoint: ${serverMetadata.token_endpoint}`);
this.configCache.set(cacheKey, config);
return config;
} catch (manualConfigError) {
const extracted = ErrorExtractor.extract(manualConfigError);
this.logger.error(`Failed to create manual configuration: ${extracted.message}`);
throw new Error(`Manual configuration failed for ${provider.id}`);
}
}
clearCache(providerId?: string): void {
if (providerId) {
this.configCache.delete(providerId);
} else {
this.configCache.clear();
}
}
getCacheSize(): number {
return this.configCache.size;
}
}

View File

@@ -0,0 +1,12 @@
import { Module } from '@nestjs/common';
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
import { OidcBaseModule } from '@app/unraid-api/graph/resolvers/sso/core/oidc-base.module.js';
@Module({
imports: [OidcBaseModule],
providers: [OidcClientConfigService, OidcRedirectUriService],
exports: [OidcClientConfigService, OidcRedirectUriService],
})
export class OidcClientModule {}

View File

@@ -0,0 +1,222 @@
import { UnauthorizedException } from '@nestjs/common';
import { Test, TestingModule } from '@nestjs/testing';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { validateRedirectUri } from '@app/unraid-api/utils/redirect-uri-validator.js';
// Mock the redirect URI validator
vi.mock('@app/unraid-api/utils/redirect-uri-validator.js', () => ({
validateRedirectUri: vi.fn(),
}));
describe('OidcRedirectUriService', () => {
let service: OidcRedirectUriService;
let oidcConfig: any;
beforeEach(async () => {
vi.clearAllMocks();
const module: TestingModule = await Test.createTestingModule({
providers: [
OidcRedirectUriService,
{
provide: OidcConfigPersistence,
useValue: {
getConfig: vi.fn().mockResolvedValue({
providers: [],
defaultAllowedOrigins: ['https://allowed.example.com'],
}),
},
},
],
}).compile();
service = module.get<OidcRedirectUriService>(OidcRedirectUriService);
oidcConfig = module.get(OidcConfigPersistence);
});
describe('getRedirectUri', () => {
it('should return valid redirect URI when validation passes', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {
'x-forwarded-proto': 'https',
'x-forwarded-host': 'example.com',
};
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
});
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
'https://example.com',
'https',
'example.com',
expect.anything(),
['https://allowed.example.com']
);
});
it('should throw UnauthorizedException when validation fails', async () => {
const requestOrigin = 'https://evil.com';
const requestHeaders = {
'x-forwarded-proto': 'https',
'x-forwarded-host': 'example.com',
};
(validateRedirectUri as any).mockReturnValue({
isValid: false,
reason: 'Origin not allowed',
});
await expect(service.getRedirectUri(requestOrigin, requestHeaders)).rejects.toThrow(
UnauthorizedException
);
});
it('should handle missing allowed origins', async () => {
oidcConfig.getConfig.mockResolvedValue({
providers: [],
defaultAllowedOrigins: undefined,
});
const requestOrigin = 'https://example.com';
const requestHeaders = {
'x-forwarded-proto': 'https',
'x-forwarded-host': 'example.com',
};
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
});
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
'https://example.com',
'https',
'example.com',
expect.anything(),
undefined
);
});
it('should extract protocol from headers correctly', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {
'x-forwarded-proto': ['https', 'http'],
host: 'example.com',
};
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
});
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
'https://example.com',
'https', // Should use first value from array
'example.com',
expect.anything(),
expect.anything()
);
});
it('should use host header as fallback', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {
host: 'example.com',
};
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
});
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
'https://example.com',
'https', // Inferred from requestOrigin when x-forwarded-proto not present
'example.com',
expect.anything(),
expect.anything()
);
});
it('should prefer x-forwarded-host over host header', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {
'x-forwarded-host': 'forwarded.example.com',
host: 'original.example.com',
};
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
});
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
'https://example.com',
'https', // Inferred from requestOrigin when x-forwarded-proto not present
'forwarded.example.com', // Should use x-forwarded-host
expect.anything(),
expect.anything()
);
});
it('should throw when URL construction fails', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {};
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'invalid-url', // Invalid URL
});
await expect(service.getRedirectUri(requestOrigin, requestHeaders)).rejects.toThrow(
UnauthorizedException
);
});
it('should handle array values in headers correctly', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {
'x-forwarded-proto': ['https'],
'x-forwarded-host': ['forwarded.example.com', 'another.example.com'],
host: ['original.example.com'],
};
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
});
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
'https://example.com',
'https',
'forwarded.example.com', // Should use first value from array
expect.anything(),
expect.anything()
);
});
});
});

View File

@@ -0,0 +1,97 @@
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { validateRedirectUri } from '@app/unraid-api/utils/redirect-uri-validator.js';
@Injectable()
export class OidcRedirectUriService {
private readonly logger = new Logger(OidcRedirectUriService.name);
private readonly CALLBACK_PATH = '/graphql/api/auth/oidc/callback';
constructor(private readonly oidcConfig: OidcConfigPersistence) {}
async getRedirectUri(
requestOrigin: string,
requestHeaders: Record<string, string | string[] | undefined>
): Promise<string> {
// Extract protocol and host from headers for validation
const { protocol, host } = this.getRequestOriginInfo(requestHeaders, requestOrigin);
// Get the global allowed origins from OIDC config
const config = await this.oidcConfig.getConfig();
const allowedOrigins = config?.defaultAllowedOrigins;
// Debug logging to trace the issue
this.logger.debug(
`OIDC Config loaded: ${JSON.stringify(config ? { hasConfig: true, allowedOrigins } : { hasConfig: false })}`
);
this.logger.debug(
`Validating redirect URI: ${requestOrigin} against host: ${protocol}://${host}`
);
this.logger.debug(`Allowed origins from config: ${JSON.stringify(allowedOrigins || [])}`);
// Validate the provided requestOrigin using centralized validator
// Pass the global allowed origins if available
const validation = validateRedirectUri(
requestOrigin,
protocol,
host,
this.logger,
allowedOrigins
);
if (!validation.isValid) {
this.logger.warn(`Invalid redirect_uri in GraphQL OIDC flow: ${validation.reason}`);
throw new UnauthorizedException(
`Invalid redirect_uri: ${requestOrigin}. Please add this callback URI to Settings → Management Access → Allowed Redirect URIs`
);
}
// Ensure the validated URI has the correct callback path
try {
const url = new URL(validation.validatedUri);
// Only use origin to prevent path manipulation
const redirectUri = `${url.origin}${this.CALLBACK_PATH}`;
this.logger.debug(`Using validated redirect URI: ${redirectUri}`);
return redirectUri;
} catch (e) {
this.logger.error(
`Failed to construct redirect URI from validated URI: ${validation.validatedUri}`
);
throw new UnauthorizedException('Invalid redirect_uri');
}
}
private getRequestOriginInfo(
requestHeaders: Record<string, string | string[] | undefined>,
requestOrigin?: string
): {
protocol: string;
host: string | undefined;
} {
// Extract protocol from x-forwarded-proto or infer from requestOrigin, default to http
const forwardedProto = requestHeaders['x-forwarded-proto'];
const protocol = forwardedProto
? Array.isArray(forwardedProto)
? forwardedProto[0]
: forwardedProto
: requestOrigin?.startsWith('https')
? 'https'
: 'http';
// Extract host from x-forwarded-host or host header
const forwardedHost = requestHeaders['x-forwarded-host'];
const hostHeader = requestHeaders['host'];
const host = forwardedHost
? Array.isArray(forwardedHost)
? forwardedHost[0]
: forwardedHost
: hostHeader
? Array.isArray(hostHeader)
? hostHeader[0]
: hostHeader
: undefined;
return { protocol, host };
}
}

View File

@@ -0,0 +1,13 @@
import { Module } from '@nestjs/common';
import { UserSettingsModule } from '@unraid/shared/services/user-settings.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
@Module({
imports: [UserSettingsModule],
providers: [OidcConfigPersistence, OidcValidationService],
exports: [OidcConfigPersistence, OidcValidationService],
})
export class OidcBaseModule {}

View File

@@ -1,4 +1,4 @@
import { Injectable, Logger } from '@nestjs/common'; import { Injectable } from '@nestjs/common';
import { ConfigService } from '@nestjs/config'; import { ConfigService } from '@nestjs/config';
import { RuleEffect } from '@jsonforms/core'; import { RuleEffect } from '@jsonforms/core';
@@ -6,12 +6,12 @@ import { mergeSettingSlices } from '@unraid/shared/jsonforms/settings.js';
import { ConfigFilePersister } from '@unraid/shared/services/config-file.js'; import { ConfigFilePersister } from '@unraid/shared/services/config-file.js';
import { UserSettingsService } from '@unraid/shared/services/user-settings.js'; import { UserSettingsService } from '@unraid/shared/services/user-settings.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
import { import {
AuthorizationOperator, AuthorizationOperator,
OidcAuthorizationRule, OidcAuthorizationRule,
OidcProvider, OidcProvider,
} from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js'; } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/oidc-validation.service.js';
import { import {
createAccordionLayout, createAccordionLayout,
createLabeledControl, createLabeledControl,
@@ -21,6 +21,7 @@ import { SettingSlice } from '@app/unraid-api/types/json-forms.js';
export interface OidcConfig { export interface OidcConfig {
providers: OidcProvider[]; providers: OidcProvider[];
defaultAllowedOrigins?: string[];
} }
@Injectable() @Injectable()
@@ -52,6 +53,7 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
defaultConfig(): OidcConfig { defaultConfig(): OidcConfig {
return { return {
providers: [this.getUnraidNetSsoProvider()], providers: [this.getUnraidNetSsoProvider()],
defaultAllowedOrigins: [],
}; };
} }
@@ -93,6 +95,7 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
return { return {
providers: [unraidNetSsoProvider], providers: [unraidNetSsoProvider],
defaultAllowedOrigins: [],
}; };
} }
@@ -119,6 +122,42 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
provider.authorizationRules || currentDefaults.authorizationRules, provider.authorizationRules || currentDefaults.authorizationRules,
}; };
} }
// Fix dangerous authorization rules for non-unraid.net providers
if (provider.authorizationRules && provider.authorizationRules.length > 0) {
// Filter out dangerous rules that would allow all emails
const safeRules = provider.authorizationRules.filter((rule) => {
// Remove rules that have "email endsWith @" which allows all emails
if (
rule.claim === 'email' &&
rule.operator === AuthorizationOperator.ENDS_WITH &&
rule.value &&
rule.value.length === 1 &&
rule.value[0] === '@'
) {
this.logger.warn(
`Removing dangerous authorization rule from provider "${provider.name}": email endsWith "@" allows all emails`
);
return false;
}
// Remove rules with empty or invalid values
if (
!rule.value ||
rule.value.length === 0 ||
rule.value.every((v) => !v || !v.trim())
) {
this.logger.warn(
`Removing invalid authorization rule from provider "${provider.name}": empty values`
);
return false;
}
return true;
});
// Update provider with safe rules
provider.authorizationRules = safeRules;
}
return provider; return provider;
}); });
@@ -155,6 +194,28 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
provider.authorizationRules = rules; provider.authorizationRules = rules;
} }
// Validate that authorization rules are present and valid for ALL providers
if (!provider.authorizationRules || provider.authorizationRules.length === 0) {
throw new Error(
`Provider "${provider.name}" requires authorization rules. Please configure who can access your server.`
);
}
// Validate each rule has valid values
for (const rule of provider.authorizationRules) {
if (!rule.claim || !rule.claim.trim()) {
throw new Error(`Provider "${provider.name}": Authorization rule claim cannot be empty`);
}
if (!rule.operator) {
throw new Error(`Provider "${provider.name}": Authorization rule operator is required`);
}
if (!rule.value || rule.value.length === 0 || rule.value.every((v) => !v || !v.trim())) {
throw new Error(
`Provider "${provider.name}": Authorization rule for claim "${rule.claim}" must have at least one non-empty value`
);
}
}
// Clean up the provider object - remove UI-only fields // Clean up the provider object - remove UI-only fields
const cleanedProvider: OidcProvider = { const cleanedProvider: OidcProvider = {
id: provider.id, id: provider.id,
@@ -191,46 +252,52 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
allowedDomains?: string[]; allowedDomains?: string[];
allowedEmails?: string[]; allowedEmails?: string[];
allowedUserIds?: string[]; allowedUserIds?: string[];
googleWorkspaceDomain?: string;
}): OidcAuthorizationRule[] { }): OidcAuthorizationRule[] {
const rules: OidcAuthorizationRule[] = []; const rules: OidcAuthorizationRule[] = [];
// Convert email domains to endsWith rules // Convert email domains to endsWith rules
// Only add if domains are provided AND not empty AND have non-empty values
if (simpleAuth?.allowedDomains && simpleAuth.allowedDomains.length > 0) { if (simpleAuth?.allowedDomains && simpleAuth.allowedDomains.length > 0) {
rules.push({ const validDomains = simpleAuth.allowedDomains.filter(
claim: 'email', (domain: string) => domain && domain.trim()
operator: AuthorizationOperator.ENDS_WITH, );
value: simpleAuth.allowedDomains.map((domain: string) => if (validDomains.length > 0) {
domain.startsWith('@') ? domain : `@${domain}` rules.push({
), claim: 'email',
}); operator: AuthorizationOperator.ENDS_WITH,
value: validDomains.map((domain: string) =>
domain.startsWith('@') ? domain : `@${domain}`
),
});
}
} }
// Convert specific emails to equals rules // Convert specific emails to equals rules
// Only add if emails are provided AND not empty AND have non-empty values
if (simpleAuth?.allowedEmails && simpleAuth.allowedEmails.length > 0) { if (simpleAuth?.allowedEmails && simpleAuth.allowedEmails.length > 0) {
rules.push({ const validEmails = simpleAuth.allowedEmails.filter(
claim: 'email', (email: string) => email && email.trim()
operator: AuthorizationOperator.EQUALS, );
value: simpleAuth.allowedEmails, if (validEmails.length > 0) {
}); rules.push({
claim: 'email',
operator: AuthorizationOperator.EQUALS,
value: validEmails,
});
}
} }
// Convert user IDs to sub equals rules // Convert user IDs to sub equals rules
// Only add if user IDs are provided AND not empty AND have non-empty values
if (simpleAuth?.allowedUserIds && simpleAuth.allowedUserIds.length > 0) { if (simpleAuth?.allowedUserIds && simpleAuth.allowedUserIds.length > 0) {
rules.push({ const validUserIds = simpleAuth.allowedUserIds.filter((id: string) => id && id.trim());
claim: 'sub', if (validUserIds.length > 0) {
operator: AuthorizationOperator.EQUALS, rules.push({
value: simpleAuth.allowedUserIds, claim: 'sub',
}); operator: AuthorizationOperator.EQUALS,
} value: validUserIds,
});
// Google Workspace domain (hd claim) }
if (simpleAuth?.googleWorkspaceDomain) {
rules.push({
claim: 'hd',
operator: AuthorizationOperator.EQUALS,
value: [simpleAuth.googleWorkspaceDomain],
});
} }
return rules; return rules;
@@ -286,7 +353,6 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
allowedDomains?: string[]; allowedDomains?: string[];
allowedEmails?: string[]; allowedEmails?: string[];
allowedUserIds?: string[]; allowedUserIds?: string[];
googleWorkspaceDomain?: string;
} }
); );
// Return provider with generated rules, removing UI-only fields // Return provider with generated rules, removing UI-only fields
@@ -304,6 +370,38 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
}), }),
}; };
// Validate authorization rules for ALL providers including unraid.net
for (const provider of processedConfig.providers) {
if (!provider.authorizationRules || provider.authorizationRules.length === 0) {
throw new Error(
`Provider "${provider.name}" requires authorization rules. Please configure who can access your server.`
);
}
// Validate each rule has valid values
for (const rule of provider.authorizationRules) {
if (!rule.claim || !rule.claim.trim()) {
throw new Error(
`Provider "${provider.name}": Authorization rule claim cannot be empty`
);
}
if (!rule.operator) {
throw new Error(
`Provider "${provider.name}": Authorization rule operator is required`
);
}
if (
!rule.value ||
rule.value.length === 0 ||
rule.value.every((v) => !v || !v.trim())
) {
throw new Error(
`Provider "${provider.name}": Authorization rule for claim "${rule.claim}" must have at least one non-empty value`
);
}
}
}
// Validate OIDC discovery for all providers with issuer URLs // Validate OIDC discovery for all providers with issuer URLs
const validationErrors: string[] = []; const validationErrors: string[] = [];
for (const provider of processedConfig.providers) { for (const provider of processedConfig.providers) {
@@ -419,10 +517,6 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
if (rule.claim === 'sub' && rule.operator === AuthorizationOperator.EQUALS) { if (rule.claim === 'sub' && rule.operator === AuthorizationOperator.EQUALS) {
return true; return true;
} }
// Google Workspace domain
if (rule.claim === 'hd' && rule.operator === AuthorizationOperator.EQUALS) {
return true;
}
return false; return false;
}); });
} }
@@ -431,13 +525,11 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
allowedDomains: string[]; allowedDomains: string[];
allowedEmails: string[]; allowedEmails: string[];
allowedUserIds: string[]; allowedUserIds: string[];
googleWorkspaceDomain?: string;
} { } {
const simpleAuth = { const simpleAuth = {
allowedDomains: [] as string[], allowedDomains: [] as string[],
allowedEmails: [] as string[], allowedEmails: [] as string[],
allowedUserIds: [] as string[], allowedUserIds: [] as string[],
googleWorkspaceDomain: undefined as string | undefined,
}; };
rules.forEach((rule) => { rules.forEach((rule) => {
@@ -449,12 +541,6 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
simpleAuth.allowedEmails = rule.value; simpleAuth.allowedEmails = rule.value;
} else if (rule.claim === 'sub' && rule.operator === AuthorizationOperator.EQUALS) { } else if (rule.claim === 'sub' && rule.operator === AuthorizationOperator.EQUALS) {
simpleAuth.allowedUserIds = rule.value; simpleAuth.allowedUserIds = rule.value;
} else if (
rule.claim === 'hd' &&
rule.operator === AuthorizationOperator.EQUALS &&
rule.value.length > 0
) {
simpleAuth.googleWorkspaceDomain = rule.value[0];
} }
}); });
@@ -462,7 +548,36 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
} }
private buildSlice(): SettingSlice { private buildSlice(): SettingSlice {
return mergeSettingSlices([this.oidcProvidersSlice()], { as: 'sso' }); const providersSlice = this.oidcProvidersSlice();
// Add defaultAllowedOrigins to the properties
providersSlice.properties.defaultAllowedOrigins = {
type: 'array',
items: { type: 'string' },
title: 'Default Allowed Redirect Origins',
default: [],
description:
'Additional trusted redirect origins to allow redirects from custom ports, reverse proxies, Tailscale, etc.',
};
// Add the control for defaultAllowedOrigins before the providers control using UnraidSettingsLayout
if (providersSlice.elements?.[0]?.elements) {
providersSlice.elements[0].elements.unshift(
createLabeledControl({
scope: '#/properties/sso/properties/defaultAllowedOrigins',
label: 'Allowed Redirect Origins',
description:
'Add trusted origins here when accessing Unraid through custom ports, reverse proxies, or Tailscale. Each origin should include the protocol and optionally a port (e.g., https://unraid.local:8443)',
controlOptions: {
format: 'array',
inputType: 'text',
placeholder: 'https://unraid.local:8443',
},
})
);
}
return mergeSettingSlices([providersSlice], { as: 'sso' });
} }
private oidcProvidersSlice(): SettingSlice { private oidcProvidersSlice(): SettingSlice {
@@ -999,7 +1114,7 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
scope: '#/properties/claim', scope: '#/properties/claim',
label: 'JWT Claim:', label: 'JWT Claim:',
description: description:
'JWT claim to check (e.g., email, sub, groups, hd for Google hosted domain)', 'JWT claim to check (e.g., email, sub, groups)',
controlOptions: { controlOptions: {
inputType: 'text', inputType: 'text',
placeholder: 'email', placeholder: 'email',

View File

@@ -0,0 +1,14 @@
import { Module } from '@nestjs/common';
import { OidcAuthModule } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-auth.module.js';
import { OidcClientModule } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client.module.js';
import { OidcBaseModule } from '@app/unraid-api/graph/resolvers/sso/core/oidc-base.module.js';
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
import { OidcSessionModule } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.module.js';
@Module({
imports: [OidcBaseModule, OidcSessionModule, OidcAuthModule, OidcClientModule],
providers: [OidcService],
exports: [OidcService, OidcBaseModule, OidcSessionModule, OidcAuthModule, OidcClientModule],
})
export class OidcCoreModule {}

View File

@@ -0,0 +1,160 @@
import { Injectable, Logger } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import * as client from 'openid-client';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
import { OidcErrorHelper } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-error.helper.js';
@Injectable()
export class OidcValidationService {
private readonly logger = new Logger(OidcValidationService.name);
constructor(private readonly configService: ConfigService) {}
/**
* Validate OIDC provider configuration by attempting discovery
* Returns validation result with helpful error messages for debugging
*/
async validateProvider(
provider: OidcProvider
): Promise<{ isValid: boolean; error?: string; details?: unknown }> {
try {
// Validate issuer URL is present
if (!provider.issuer) {
return {
isValid: false,
error: 'No issuer URL provided. Please specify the OIDC provider issuer URL.',
details: { type: 'MISSING_ISSUER' },
};
}
// Validate issuer URL is valid
let serverUrl: URL;
try {
serverUrl = new URL(provider.issuer);
} catch (urlError) {
return {
isValid: false,
error: `Invalid issuer URL format: '${provider.issuer}'. Please provide a valid URL.`,
details: {
type: 'INVALID_URL',
originalError: urlError instanceof Error ? urlError.message : String(urlError),
},
};
}
// Configure client options for HTTP if needed
let clientOptions: any = undefined;
if (serverUrl.protocol === 'http:') {
this.logger.warn(
`HTTP issuer URL detected for provider ${provider.id}: ${provider.issuer} - This is insecure`
);
clientOptions = {
execute: [client.allowInsecureRequests],
};
}
// Attempt OIDC discovery
await this.performDiscovery(provider, clientOptions);
return { isValid: true };
} catch (error) {
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
// Log the raw error for debugging
this.logger.log(`Raw discovery error for ${provider.id}: ${errorMessage}`);
// Use the helper to parse the error
const { userFriendlyError, details } = OidcErrorHelper.parseDiscoveryError(
error,
provider.issuer
);
this.logger.error(`Validation failed for provider ${provider.id}: ${errorMessage}`);
// Add debug logging for HTTP status errors
if (errorMessage.includes('unexpected HTTP response status code')) {
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
? provider.issuer.replace('/.well-known/openid-configuration', '')
: provider.issuer;
this.logger.log(`Attempted to fetch: ${baseUrl}/.well-known/openid-configuration`);
this.logger.error(`Full error details: ${errorMessage}`);
}
return {
isValid: false,
error: userFriendlyError,
details,
};
}
}
async performDiscovery(provider: OidcProvider, clientOptions?: any): Promise<client.Configuration> {
if (!provider.issuer) {
throw new Error('No issuer URL provided');
}
// Configure client auth method
const clientAuth = provider.clientSecret
? client.ClientSecretPost(provider.clientSecret)
: undefined;
const serverUrl = new URL(provider.issuer);
const discoveryUrl = `${provider.issuer}/.well-known/openid-configuration`;
this.logger.log(`Starting discovery for provider ${provider.id}`);
this.logger.log(`Discovery URL: ${discoveryUrl}`);
this.logger.log(`Client ID: ${provider.clientId}`);
this.logger.log(`Client secret configured: ${provider.clientSecret ? 'Yes' : 'No'}`);
// Use provided client options or create default options with HTTP support if needed
if (!clientOptions && serverUrl.protocol === 'http:') {
this.logger.warn(
`Allowing HTTP for ${provider.id} - This is insecure and should only be used for testing`
);
// For openid-client v6, use allowInsecureRequests in the execute array
// This is deprecated but needed for local development with HTTP endpoints
clientOptions = {
execute: [client.allowInsecureRequests],
};
}
try {
const config = await client.discovery(
serverUrl,
provider.clientId,
undefined, // client metadata
clientAuth,
clientOptions
);
this.logger.log(`Discovery successful for ${provider.id}`);
this.logger.log(`Discovery response metadata:`);
this.logger.log(` - issuer: ${config.serverMetadata().issuer}`);
this.logger.log(
` - authorization_endpoint: ${config.serverMetadata().authorization_endpoint}`
);
this.logger.log(` - token_endpoint: ${config.serverMetadata().token_endpoint}`);
this.logger.log(
` - userinfo_endpoint: ${config.serverMetadata().userinfo_endpoint || 'not provided'}`
);
this.logger.log(` - jwks_uri: ${config.serverMetadata().jwks_uri || 'not provided'}`);
this.logger.log(
` - response_types_supported: ${config.serverMetadata().response_types_supported?.join(', ') || 'not provided'}`
);
this.logger.log(
` - scopes_supported: ${config.serverMetadata().scopes_supported?.join(', ') || 'not provided'}`
);
return config;
} catch (discoveryError) {
this.logger.error(`Discovery failed for ${provider.id} at ${discoveryUrl}`);
if (discoveryError instanceof Error) {
this.logger.error('Discovery error: %o', discoveryError);
}
throw discoveryError;
}
}
}

View File

@@ -0,0 +1,485 @@
import { Logger } from '@nestjs/common';
import { ConfigModule, ConfigService } from '@nestjs/config';
import { Test, TestingModule } from '@nestjs/testing';
import * as client from 'openid-client';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
import { OidcClaimsService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
describe('OidcService Integration Tests - Enhanced Logging', () => {
let service: OidcService;
let configPersistence: OidcConfigPersistence;
let loggerSpy: any;
let debugLogs: string[] = [];
let errorLogs: string[] = [];
let warnLogs: string[] = [];
let logLogs: string[] = [];
beforeEach(async () => {
// Clear log arrays
debugLogs = [];
errorLogs = [];
warnLogs = [];
logLogs = [];
const module: TestingModule = await Test.createTestingModule({
imports: [
ConfigModule.forRoot({
isGlobal: true,
load: [() => ({ BASE_URL: 'http://test.local' })],
}),
],
providers: [
OidcService,
OidcValidationService,
OidcClientConfigService,
OidcTokenExchangeService,
{
provide: OidcAuthorizationService,
useValue: {
checkAuthorization: vi.fn(),
},
},
{
provide: OidcConfigPersistence,
useValue: {
getProvider: vi.fn(),
saveProvider: vi.fn(),
getConfig: vi.fn().mockReturnValue({
providers: [],
defaultAllowedOrigins: [],
}),
},
},
{
provide: OidcSessionService,
useValue: {
createSession: vi.fn().mockResolvedValue('mock-token'),
validateSession: vi.fn(),
},
},
{
provide: OidcStateService,
useValue: {
generateSecureState: vi.fn().mockResolvedValue('secure-state'),
validateSecureState: vi.fn().mockResolvedValue({
isValid: true,
clientState: 'test-state',
redirectUri: 'https://myapp.example.com/graphql/api/auth/oidc/callback',
}),
extractProviderFromState: vi.fn().mockReturnValue('test-provider'),
},
},
{
provide: OidcRedirectUriService,
useValue: {
getRedirectUri: vi
.fn()
.mockResolvedValue(
'https://myapp.example.com/graphql/api/auth/oidc/callback'
),
},
},
{
provide: OidcClaimsService,
useValue: {
parseIdToken: vi.fn().mockReturnValue({
sub: 'user123',
email: 'user@example.com',
}),
validateClaims: vi.fn().mockReturnValue('user123'),
},
},
],
}).compile();
service = module.get<OidcService>(OidcService);
configPersistence = module.get<OidcConfigPersistence>(OidcConfigPersistence);
// Spy on logger methods to capture logs
loggerSpy = {
debug: vi
.spyOn(Logger.prototype, 'debug')
.mockImplementation((message: string, ...args: any[]) => {
debugLogs.push(message);
}),
error: vi
.spyOn(Logger.prototype, 'error')
.mockImplementation((message: string, ...args: any[]) => {
errorLogs.push(message);
}),
warn: vi
.spyOn(Logger.prototype, 'warn')
.mockImplementation((message: string, ...args: any[]) => {
warnLogs.push(message);
}),
log: vi
.spyOn(Logger.prototype, 'log')
.mockImplementation((message: string, ...args: any[]) => {
logLogs.push(message);
}),
verbose: vi.spyOn(Logger.prototype, 'verbose').mockImplementation(() => {}),
};
});
afterEach(() => {
vi.restoreAllMocks();
});
describe('Token Exchange Error Logging', () => {
it('should log detailed error information when token exchange fails with Google (trailing slash issue)', async () => {
// This simulates the issue from #1616 where a trailing slash causes failure
const provider: OidcProvider = {
id: 'google-test',
name: 'Google Test',
issuer: 'https://accounts.google.com/', // Trailing slash will cause issue
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid', 'email', 'profile'],
authorizationRules: [
{
claim: 'email',
operator: 'ENDS_WITH' as any,
value: ['@example.com'],
},
],
};
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
try {
await service.handleCallback({
providerId: 'google-test',
code: 'test-code',
state: 'test-state',
requestOrigin: 'http://test.local',
fullCallbackUrl:
'http://test.local/graphql/api/auth/oidc/callback?code=test-code&state=test-state',
requestHeaders: { host: 'test.local' },
});
} catch (error) {
// We expect this to fail
}
// Verify that the service attempted to handle the callback
// Note: Detailed token exchange logging now happens in OidcTokenExchangeService
expect(errorLogs.length).toBeGreaterThan(0);
// Changed logging format to use error extractor
expect(errorLogs.some((log) => log.includes('Token exchange failed'))).toBe(true);
});
it('should log discovery failure details with invalid issuer URL', async () => {
const provider: OidcProvider = {
id: 'invalid-issuer',
name: 'Invalid Issuer Test',
issuer: 'https://invalid-oidc-provider.example.com', // Non-existent domain
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid', 'email'],
authorizationRules: [],
};
const validationService = new OidcValidationService(new ConfigService());
const result = await validationService.validateProvider(provider);
expect(result.isValid).toBe(false);
// Should now have more specific error message
expect(result.error).toBeDefined();
// The error should mention the domain cannot be resolved or connection failed
expect(result.error).toMatch(
/Cannot resolve domain name|Failed to connect to OIDC provider/
);
expect(result.details).toBeDefined();
expect(result.details).toHaveProperty('type');
// Should be either DNS_ERROR or FETCH_ERROR depending on the cause
expect(['DNS_ERROR', 'FETCH_ERROR']).toContain((result.details as any).type);
});
it('should log detailed HTTP error responses from discovery', async () => {
const provider: OidcProvider = {
id: 'http-error-test',
name: 'HTTP Error Test',
issuer: 'https://httpstat.us/500', // Returns 500 error
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid'],
authorizationRules: [],
};
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
try {
await service.validateProvider(provider);
} catch (error) {
// Expected to fail
}
// Check that HTTP status details are logged (now in log level)
expect(logLogs.some((log) => log.includes('Discovery URL:'))).toBe(true);
expect(logLogs.some((log) => log.includes('Client ID:'))).toBe(true);
});
it('should log authorization URL building details', async () => {
const provider: OidcProvider = {
id: 'auth-url-test',
name: 'Auth URL Test',
issuer: 'https://accounts.google.com',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid', 'email', 'profile'],
authorizationRules: [],
};
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
try {
await service.getAuthorizationUrl({
providerId: 'auth-url-test',
state: 'test-state',
requestOrigin: 'http://test.local',
requestHeaders: { host: 'test.local' },
});
// Verify URL building logs
expect(logLogs.some((log) => log.includes('Built authorization URL'))).toBe(true);
expect(logLogs.some((log) => log.includes('Authorization parameters:'))).toBe(true);
} catch (error) {
// May fail due to real discovery, but we're interested in the logs
}
});
it('should log detailed information for manual endpoint configuration', async () => {
const provider: OidcProvider = {
id: 'manual-endpoints',
name: 'Manual Endpoints Test',
issuer: undefined,
authorizationEndpoint: 'https://auth.example.com/authorize',
tokenEndpoint: 'https://auth.example.com/token',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid'],
authorizationRules: [],
};
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
const authUrl = await service.getAuthorizationUrl({
providerId: 'manual-endpoints',
state: 'test-state',
requestOrigin: 'http://test.local',
requestHeaders: {
'x-forwarded-host': 'test.local',
'x-forwarded-proto': 'http',
},
});
// Verify manual endpoint logs
expect(debugLogs.some((log) => log.includes('Built authorization URL'))).toBe(true);
expect(debugLogs.some((log) => log.includes('client_id=test-client-id'))).toBe(true);
expect(authUrl).toContain('https://auth.example.com/authorize');
});
it('should log JWT claim validation failures with detailed context', async () => {
const provider: OidcProvider = {
id: 'jwt-validation-test',
name: 'JWT Validation Test',
issuer: 'https://accounts.google.com',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid', 'email'],
authorizationRules: [
{
claim: 'email',
operator: 'ENDS_WITH' as any,
value: ['@restricted.com'],
},
],
};
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
// Mock a scenario where JWT validation fails
try {
await service.handleCallback({
providerId: 'jwt-validation-test',
code: 'test-code',
state: 'test-state',
requestOrigin: 'http://test.local',
fullCallbackUrl:
'http://test.local/graphql/api/auth/oidc/callback?code=test-code&state=test-state',
requestHeaders: { host: 'test.local' },
});
} catch (error) {
// Expected to fail
}
// The JWT error handling is now in OidcTokenExchangeService
// We should see some error logged
expect(errorLogs.length).toBeGreaterThan(0);
});
});
describe('Discovery Endpoint Logging', () => {
it('should log all discovery metadata when successful', async () => {
// Use a real OIDC provider that works
const provider: OidcProvider = {
id: 'microsoft',
name: 'Microsoft',
issuer: 'https://login.microsoftonline.com/common/v2.0',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid', 'email', 'profile'],
authorizationRules: [],
};
const validationService = new OidcValidationService(new ConfigService());
try {
await validationService.performDiscovery(provider);
} catch (error) {
// May fail due to network, but we're checking logs
}
// Verify discovery logging (now in log level)
expect(logLogs.some((log) => log.includes('Starting discovery'))).toBe(true);
expect(logLogs.some((log) => log.includes('Discovery URL:'))).toBe(true);
});
it('should log discovery failures with malformed JSON response', async () => {
const provider: OidcProvider = {
id: 'malformed-json',
name: 'Malformed JSON Test',
issuer: 'https://example.com/malformed',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid'],
authorizationRules: [],
};
// Mock global fetch to return HTML instead of JSON
const originalFetch = global.fetch;
global.fetch = vi.fn().mockImplementation(() =>
Promise.resolve(
new Response('<html><body>Not JSON</body></html>', {
status: 200,
headers: { 'content-type': 'text/html' },
})
)
);
const validationService = new OidcValidationService(new ConfigService());
const result = await validationService.validateProvider(provider);
// Restore original fetch
global.fetch = originalFetch;
expect(result.isValid).toBe(false);
expect(result.error).toBeDefined();
// The openid-client library will fail when it gets HTML instead of JSON
// It returns "unexpected response content-type" error
expect(result.error).toMatch(
/Invalid OIDC discovery|malformed|doesn't conform|unexpected|content-type/i
);
});
it('should handle and log HTTP vs HTTPS protocol differences', async () => {
const httpProvider: OidcProvider = {
id: 'http-local',
name: 'HTTP Local Test',
issuer: 'http://localhost:8080', // HTTP endpoint
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid'],
authorizationRules: [],
};
// Create a validation service and spy on its logger
const validationService = new OidcValidationService(new ConfigService());
try {
await validationService.validateProvider(httpProvider);
} catch (error) {
// Expected to fail if localhost:8080 isn't running
}
// The HTTP logging happens in the validation service
// We should check that HTTP issuers are detected
expect(httpProvider.issuer).toMatch(/^http:/);
// Verify that we're testing an HTTP endpoint
expect(httpProvider.issuer).toBe('http://localhost:8080');
});
});
describe('Request/Response Detail Logging', () => {
it('should log complete request parameters for token exchange', async () => {
const provider: OidcProvider = {
id: 'token-params-test',
name: 'Token Params Test',
issuer: 'https://accounts.google.com',
clientId: 'detailed-client-id',
clientSecret: 'detailed-client-secret',
scopes: ['openid', 'email', 'profile', 'offline_access'],
authorizationRules: [],
};
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
try {
await service.handleCallback({
providerId: 'token-params-test',
code: 'authorization-code-12345',
state: 'state-with-signature',
requestOrigin: 'https://myapp.example.com',
fullCallbackUrl:
'https://myapp.example.com/graphql/api/auth/oidc/callback?code=authorization-code-12345&state=state-with-signature&scope=openid+email+profile',
requestHeaders: { host: 'myapp.example.com' },
});
} catch (error) {
// Expected to fail
}
// Verify that we attempted the operation
// Detailed parameter logging is now in OidcTokenExchangeService
expect(debugLogs.length).toBeGreaterThan(0);
expect(debugLogs.some((log) => log.includes('Client ID: detailed-client-id'))).toBe(true);
expect(debugLogs.some((log) => log.includes('Client secret configured: Yes'))).toBe(true);
});
it('should capture and log all error properties from openid-client', async () => {
const provider: OidcProvider = {
id: 'error-properties-test',
name: 'Error Properties Test',
issuer: 'https://expired-cert.badssl.com/', // SSL cert error
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid'],
authorizationRules: [],
};
const validationService = new OidcValidationService(new ConfigService());
const result = await validationService.validateProvider(provider);
expect(result.isValid).toBe(false);
expect(result.error).toBeDefined();
// Should detect SSL/certificate issues or connection failure
expect(result.error).toMatch(
/SSL\/TLS certificate error|Failed to connect to OIDC provider|certificate/
);
expect(result.details).toBeDefined();
expect(result.details).toHaveProperty('type');
// Should be either SSL_ERROR or FETCH_ERROR
expect(['SSL_ERROR', 'FETCH_ERROR']).toContain((result.details as any).type);
});
});
});

View File

@@ -0,0 +1,381 @@
import { CacheModule } from '@nestjs/cache-manager';
import { UnauthorizedException } from '@nestjs/common';
import { Test, TestingModule } from '@nestjs/testing';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
import { OidcClaimsService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
// Mock openid-client
vi.mock('openid-client', () => ({
buildAuthorizationUrl: vi.fn((config, params) => {
const url = new URL(config.serverMetadata().authorization_endpoint);
Object.entries(params).forEach(([key, value]) => {
if (value !== undefined) {
url.searchParams.set(key, String(value));
}
});
return url;
}),
allowInsecureRequests: vi.fn(),
}));
describe('OidcService Integration', () => {
let service: OidcService;
let oidcConfig: any;
let sessionService: any;
let stateService: OidcStateService;
let redirectUriService: any;
let clientConfigService: any;
let tokenExchangeService: any;
let claimsService: any;
let authorizationService: any;
beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
imports: [CacheModule.register()],
providers: [
OidcService,
{
provide: OidcConfigPersistence,
useValue: {
getProvider: vi.fn(),
getConfig: vi.fn().mockResolvedValue({
providers: [],
defaultAllowedOrigins: ['https://example.com'],
}),
},
},
{
provide: OidcSessionService,
useValue: {
createSession: vi.fn().mockResolvedValue('padded-token-123'),
},
},
OidcStateService,
{
provide: OidcValidationService,
useValue: {
validateProvider: vi.fn().mockResolvedValue({ isValid: true }),
performDiscovery: vi.fn(),
},
},
{
provide: OidcAuthorizationService,
useValue: {
checkAuthorization: vi.fn(),
},
},
{
provide: OidcRedirectUriService,
useValue: {
getRedirectUri: vi.fn().mockResolvedValue('https://example.com/callback'),
},
},
{
provide: OidcClientConfigService,
useValue: {
getOrCreateConfig: vi.fn(),
clearCache: vi.fn(),
},
},
{
provide: OidcTokenExchangeService,
useValue: {
exchangeCodeForTokens: vi.fn(),
},
},
{
provide: OidcClaimsService,
useValue: {
parseIdToken: vi.fn(),
validateClaims: vi.fn(),
},
},
],
}).compile();
service = module.get<OidcService>(OidcService);
oidcConfig = module.get(OidcConfigPersistence);
sessionService = module.get(OidcSessionService);
stateService = module.get<OidcStateService>(OidcStateService);
redirectUriService = module.get(OidcRedirectUriService);
clientConfigService = module.get(OidcClientConfigService);
tokenExchangeService = module.get(OidcTokenExchangeService);
claimsService = module.get(OidcClaimsService);
authorizationService = module.get(OidcAuthorizationService);
});
describe('getAuthorizationUrl', () => {
it('should generate authorization URL with custom endpoints', async () => {
const provider: OidcProvider = {
id: 'custom-provider',
name: 'Custom Provider',
clientId: 'test-client-id',
clientSecret: 'test-secret',
authorizationEndpoint: 'https://custom.example.com/auth',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
const params = {
providerId: 'custom-provider',
state: 'client-state-123',
requestOrigin: 'https://example.com',
requestHeaders: { host: 'example.com' },
};
const url = await service.getAuthorizationUrl(params);
expect(redirectUriService.getRedirectUri).toHaveBeenCalledWith('https://example.com', {
host: 'example.com',
});
const urlObj = new URL(url);
expect(urlObj.origin).toBe('https://custom.example.com');
expect(urlObj.pathname).toBe('/auth');
expect(urlObj.searchParams.get('client_id')).toBe('test-client-id');
expect(urlObj.searchParams.get('redirect_uri')).toBe('https://example.com/callback');
expect(urlObj.searchParams.get('scope')).toBe('openid profile');
expect(urlObj.searchParams.get('response_type')).toBe('code');
expect(urlObj.searchParams.has('state')).toBe(true);
});
it('should use OIDC discovery when no custom authorization endpoint', async () => {
const provider: OidcProvider = {
id: 'discovery-provider',
name: 'Discovery Provider',
clientId: 'test-client-id',
issuer: 'https://discovery.example.com',
scopes: ['openid'],
authorizationRules: [],
};
// Create a mock configuration object
const mockConfig = {
serverMetadata: vi.fn().mockReturnValue({
authorization_endpoint: 'https://discovery.example.com/authorize',
}),
};
oidcConfig.getProvider.mockResolvedValue(provider);
clientConfigService.getOrCreateConfig.mockResolvedValue(mockConfig);
const params = {
providerId: 'discovery-provider',
state: 'client-state-123',
requestOrigin: 'https://example.com',
requestHeaders: {},
};
const url = await service.getAuthorizationUrl(params);
expect(clientConfigService.getOrCreateConfig).toHaveBeenCalledWith(provider);
expect(url).toContain('https://discovery.example.com/authorize');
});
it('should throw when provider not found', async () => {
oidcConfig.getProvider.mockResolvedValue(null);
const params = {
providerId: 'non-existent',
state: 'state',
requestOrigin: 'https://example.com',
requestHeaders: {},
};
await expect(service.getAuthorizationUrl(params)).rejects.toThrow(UnauthorizedException);
});
});
describe('handleCallback', () => {
it('should handle successful callback flow', async () => {
const provider: OidcProvider = {
id: 'test-provider',
name: 'Test Provider',
clientId: 'test-client-id',
issuer: 'https://test.example.com',
scopes: ['openid'],
authorizationRules: [],
};
const mockConfig = {
serverMetadata: vi.fn().mockReturnValue({
issuer: 'https://test.example.com',
token_endpoint: 'https://test.example.com/token',
}),
};
const mockTokens = {
id_token: 'id.token.here',
access_token: 'access.token.here',
};
const mockClaims = {
sub: 'user123',
email: 'user@example.com',
};
oidcConfig.getProvider.mockResolvedValue(provider);
clientConfigService.getOrCreateConfig.mockResolvedValue(mockConfig);
tokenExchangeService.exchangeCodeForTokens.mockResolvedValue(mockTokens);
claimsService.parseIdToken.mockReturnValue(mockClaims);
claimsService.validateClaims.mockReturnValue('user123');
// Mock the OidcStateExtractor's static method
const OidcStateExtractor = await import(
'@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js'
);
vi.spyOn(OidcStateExtractor.OidcStateExtractor, 'extractAndValidateState').mockResolvedValue(
{
providerId: 'test-provider',
originalState: 'original-state',
clientState: 'original-state',
redirectUri: 'https://example.com/callback',
}
);
const params = {
providerId: 'test-provider',
code: 'auth-code-123',
state: 'secure-state',
requestOrigin: 'https://example.com',
fullCallbackUrl: 'https://example.com/callback?code=auth-code-123&state=secure-state',
requestHeaders: {},
};
const token = await service.handleCallback(params);
expect(token).toBe('padded-token-123');
expect(tokenExchangeService.exchangeCodeForTokens).toHaveBeenCalled();
expect(claimsService.parseIdToken).toHaveBeenCalledWith('id.token.here');
expect(claimsService.validateClaims).toHaveBeenCalledWith(mockClaims);
expect(authorizationService.checkAuthorization).toHaveBeenCalledWith(provider, mockClaims);
expect(sessionService.createSession).toHaveBeenCalledWith('test-provider', 'user123');
});
it('should throw when provider not found', async () => {
oidcConfig.getProvider.mockResolvedValue(null);
const params = {
providerId: 'non-existent',
code: 'code',
state: 'state',
requestOrigin: 'https://example.com',
fullCallbackUrl: 'https://example.com/callback',
requestHeaders: {},
};
await expect(service.handleCallback(params)).rejects.toThrow(UnauthorizedException);
});
it('should handle authorization rejection', async () => {
const provider: OidcProvider = {
id: 'test-provider',
name: 'Test Provider',
clientId: 'test-client-id',
issuer: 'https://test.example.com',
scopes: ['openid'],
authorizationRules: [],
};
const mockConfig = {
serverMetadata: vi.fn().mockReturnValue({
issuer: 'https://test.example.com',
token_endpoint: 'https://test.example.com/token',
}),
};
const mockTokens = {
id_token: 'id.token.here',
};
const mockClaims = {
sub: 'user123',
email: 'user@example.com',
};
oidcConfig.getProvider.mockResolvedValue(provider);
clientConfigService.getOrCreateConfig.mockResolvedValue(mockConfig);
tokenExchangeService.exchangeCodeForTokens.mockResolvedValue(mockTokens);
claimsService.parseIdToken.mockReturnValue(mockClaims);
claimsService.validateClaims.mockReturnValue('user123');
authorizationService.checkAuthorization.mockRejectedValue(
new UnauthorizedException('Not authorized')
);
// Mock the OidcStateExtractor's static method
const OidcStateExtractor = await import(
'@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js'
);
vi.spyOn(OidcStateExtractor.OidcStateExtractor, 'extractAndValidateState').mockResolvedValue(
{
providerId: 'test-provider',
originalState: 'original-state',
clientState: 'original-state',
redirectUri: 'https://example.com/callback',
}
);
const params = {
providerId: 'test-provider',
code: 'auth-code-123',
state: 'secure-state',
requestOrigin: 'https://example.com',
fullCallbackUrl: 'https://example.com/callback',
requestHeaders: {},
};
await expect(service.handleCallback(params)).rejects.toThrow(UnauthorizedException);
});
});
describe('validateProvider', () => {
it('should clear cache and validate provider', async () => {
const provider: OidcProvider = {
id: 'test-provider',
name: 'Test Provider',
clientId: 'test-client-id',
issuer: 'https://test.example.com',
scopes: ['openid'],
authorizationRules: [],
};
const result = await service.validateProvider(provider);
expect(clientConfigService.clearCache).toHaveBeenCalledWith('test-provider');
// The validation service mock already returns { isValid: true }
expect(result).toEqual({ isValid: true });
});
});
describe('extractProviderFromState', () => {
it('should extract provider from state', () => {
const state = 'provider-id:original-state';
const result = service.extractProviderFromState(state);
expect(result.providerId).toBeDefined();
expect(result.originalState).toBeDefined();
});
});
describe('getStateService', () => {
it('should return state service', () => {
const result = service.getStateService();
expect(result).toBe(stateService);
});
});
});

View File

@@ -0,0 +1,243 @@
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
import * as client from 'openid-client';
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
import { OidcClaimsService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
import { OidcStateExtractor } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
import { ErrorExtractor } from '@app/unraid-api/utils/error-extractor.util.js';
export interface GetAuthorizationUrlParams {
providerId: string;
state: string;
requestOrigin: string;
requestHeaders: Record<string, string | string[] | undefined>;
}
export interface HandleCallbackParams {
providerId: string;
code: string;
state: string;
requestOrigin: string;
fullCallbackUrl: string;
requestHeaders: Record<string, string | string[] | undefined>;
}
@Injectable()
export class OidcService {
private readonly logger = new Logger(OidcService.name);
constructor(
private readonly oidcConfig: OidcConfigPersistence,
private readonly sessionService: OidcSessionService,
private readonly stateService: OidcStateService,
private readonly validationService: OidcValidationService,
private readonly authorizationService: OidcAuthorizationService,
private readonly redirectUriService: OidcRedirectUriService,
private readonly clientConfigService: OidcClientConfigService,
private readonly tokenExchangeService: OidcTokenExchangeService,
private readonly claimsService: OidcClaimsService
) {}
async getAuthorizationUrl(params: GetAuthorizationUrlParams): Promise<string> {
const { providerId, state, requestOrigin, requestHeaders } = params;
const provider = await this.oidcConfig.getProvider(providerId);
if (!provider) {
throw new UnauthorizedException(`Provider ${providerId} not found`);
}
// Use requestOrigin with validation
const redirectUri = await this.redirectUriService.getRedirectUri(requestOrigin, requestHeaders);
this.logger.debug(`Using redirect URI for authorization: ${redirectUri}`);
this.logger.debug(`Request origin was: ${requestOrigin}`);
// Generate secure state with cryptographic signature, including redirect URI
const secureState = await this.stateService.generateSecureState(providerId, state, redirectUri);
// Build authorization URL
if (provider.authorizationEndpoint) {
// Use custom authorization endpoint
const authUrl = new URL(provider.authorizationEndpoint);
// Standard OAuth2 parameters
authUrl.searchParams.set('client_id', provider.clientId);
authUrl.searchParams.set('redirect_uri', redirectUri);
authUrl.searchParams.set('scope', provider.scopes.join(' '));
authUrl.searchParams.set('state', secureState);
authUrl.searchParams.set('response_type', 'code');
this.logger.debug(`Built authorization URL for provider ${provider.id}`);
this.logger.debug(
`Authorization parameters: client_id=${provider.clientId}, redirect_uri=${redirectUri}, scope=${provider.scopes.join(' ')}, response_type=code`
);
return authUrl.href;
}
// Use OIDC discovery for providers without custom endpoints
const config = await this.clientConfigService.getOrCreateConfig(provider);
const parameters: Record<string, string> = {
redirect_uri: redirectUri,
scope: provider.scopes.join(' '),
state: secureState,
response_type: 'code',
};
// For HTTP endpoints, we need to call allowInsecureRequests on the config
if (provider.issuer) {
try {
const serverUrl = new URL(provider.issuer);
if (serverUrl.protocol === 'http:') {
this.logger.debug(`Allowing insecure requests for HTTP endpoint: ${provider.id}`);
// allowInsecureRequests is deprecated but still needed for HTTP endpoints
client.allowInsecureRequests(config);
}
} catch (error) {
this.logger.warn(`Invalid issuer URL for provider ${provider.id}: ${provider.issuer}`);
// Continue without special HTTP options
}
}
const authUrl = client.buildAuthorizationUrl(config, parameters);
this.logger.log(`Built authorization URL via discovery for provider ${provider.id}`);
this.logger.log(`Authorization parameters: ${JSON.stringify(parameters)}`);
return authUrl.href;
}
extractProviderFromState(state: string): { providerId: string; originalState: string } {
return OidcStateExtractor.extractProviderFromState(state, this.stateService);
}
/**
* Get the state service for external utilities
*/
getStateService(): OidcStateService {
return this.stateService;
}
async handleCallback(params: HandleCallbackParams): Promise<string> {
const { providerId, code, state, fullCallbackUrl } = params;
const provider = await this.oidcConfig.getProvider(providerId);
if (!provider) {
throw new UnauthorizedException(`Provider ${providerId} not found`);
}
// Extract and validate state, including the stored redirect URI
const stateInfo = await OidcStateExtractor.extractAndValidateState(state, this.stateService);
if (!stateInfo.redirectUri) {
throw new UnauthorizedException('Missing redirect URI in state');
}
// Use the redirect URI that was stored during authorization
const redirectUri = stateInfo.redirectUri;
this.logger.debug(`Using stored redirect URI from state: ${redirectUri}`);
try {
// Always use openid-client for consistency
const config = await this.clientConfigService.getOrCreateConfig(provider);
// Log configuration details
this.logger.debug(`Provider ${providerId} config loaded`);
this.logger.debug(`Redirect URI: ${redirectUri}`);
// Build current URL for token exchange
// CRITICAL: The URL used here MUST match the redirect_uri that was sent to the authorization endpoint
// Google expects the exact same redirect_uri during token exchange
const currentUrl = new URL(redirectUri);
currentUrl.searchParams.set('code', code);
currentUrl.searchParams.set('state', state);
// Copy additional parameters from the actual callback if provided
if (fullCallbackUrl) {
const actualUrl = new URL(fullCallbackUrl);
// Copy over additional params that Google might have added (scope, authuser, prompt, etc)
// but DO NOT change the base URL or path
['scope', 'authuser', 'prompt', 'hd', 'session_state', 'iss'].forEach((param) => {
const value = actualUrl.searchParams.get(param);
if (value && !currentUrl.searchParams.has(param)) {
currentUrl.searchParams.set(param, value);
}
});
}
// Google returns iss in the response, openid-client v6 expects it
// If not present, add it based on the provider's issuer
if (!currentUrl.searchParams.has('iss') && provider.issuer) {
currentUrl.searchParams.set('iss', provider.issuer);
}
this.logger.debug(`Token exchange URL (matches redirect_uri): ${currentUrl.href}`);
// State was already validated in extractAndValidateState above, use that result
// The clientState should be present after successful validation, but handle the edge case
if (!stateInfo.clientState) {
this.logger.warn('Client state missing after successful validation');
throw new UnauthorizedException('Invalid state: missing client state');
}
const originalState = stateInfo.clientState;
this.logger.debug(`Exchanging code for tokens with provider ${providerId}`);
this.logger.debug(`Client state extracted: ${originalState}`);
// Use the token exchange service
const tokens = await this.tokenExchangeService.exchangeCodeForTokens(
config,
provider,
code,
originalState,
redirectUri,
fullCallbackUrl
);
// Parse ID token to get user info
const claims = this.claimsService.parseIdToken(tokens.id_token);
const userSub = this.claimsService.validateClaims(claims);
// Check authorization based on rules
// This will throw a helpful error if misconfigured or unauthorized
await this.authorizationService.checkAuthorization(provider, claims!);
// Create session and return padded token
const paddedToken = await this.sessionService.createSession(providerId, userSub);
this.logger.log(`Successfully authenticated user ${userSub} via provider ${providerId}`);
return paddedToken;
} catch (error) {
const extracted = ErrorExtractor.extract(error);
this.logger.error(`OAuth callback error: ${extracted.message}`);
// Re-throw the original error if it's already an UnauthorizedException
if (error instanceof UnauthorizedException) {
throw error;
}
// Otherwise throw a generic error
throw new UnauthorizedException('Authentication failed');
}
}
/**
* Validate OIDC provider configuration by attempting discovery
* Returns validation result with helpful error messages for debugging
*/
async validateProvider(
provider: OidcProvider
): Promise<{ isValid: boolean; error?: string; details?: unknown }> {
// Clear any cached config for this provider to force fresh validation
this.clientConfigService.clearCache(provider.id);
// Delegate to the validation service
return this.validationService.validateProvider(provider);
}
}

View File

@@ -0,0 +1,16 @@
import { Field, ObjectType } from '@nestjs/graphql';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
@ObjectType()
export class OidcConfiguration {
@Field(() => [OidcProvider], { description: 'List of configured OIDC providers' })
providers!: OidcProvider[];
@Field(() => [String], {
nullable: true,
description:
'Default allowed redirect origins that apply to all OIDC providers (e.g., Tailscale domains)',
})
defaultAllowedOrigins?: string[];
}

View File

@@ -80,9 +80,11 @@ export class OidcProvider {
@Field(() => String, { @Field(() => String, {
description: description:
'OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration', 'OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration',
nullable: true,
}) })
@IsUrl() @IsUrl()
issuer!: string; @IsOptional()
issuer?: string;
@Field(() => String, { @Field(() => String, {
nullable: true, nullable: true,

View File

@@ -1,4 +1,4 @@
import type { OidcConfig } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js'; import type { OidcConfig } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
declare module '@unraid/shared/services/user-settings.js' { declare module '@unraid/shared/services/user-settings.js' {
interface UserSettings { interface UserSettings {

View File

@@ -1,701 +0,0 @@
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { decodeJwt } from 'jose';
import * as client from 'openid-client';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
import {
AuthorizationOperator,
AuthorizationRuleMode,
OidcAuthorizationRule,
OidcProvider,
} from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/oidc-state.service.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/oidc-validation.service.js';
interface JwtClaims {
sub?: string;
email?: string;
name?: string;
hd?: string; // Google hosted domain
[claim: string]: unknown;
}
@Injectable()
export class OidcAuthService {
private readonly logger = new Logger(OidcAuthService.name);
private readonly configCache = new Map<string, client.Configuration>();
constructor(
private readonly configService: ConfigService,
private readonly oidcConfig: OidcConfigPersistence,
private readonly sessionService: OidcSessionService,
private readonly stateService: OidcStateService,
private readonly validationService: OidcValidationService
) {}
async getAuthorizationUrl(
providerId: string,
state: string,
requestOrigin?: string
): Promise<string> {
const provider = await this.oidcConfig.getProvider(providerId);
if (!provider) {
throw new UnauthorizedException(`Provider ${providerId} not found`);
}
const redirectUri = this.getRedirectUri(requestOrigin);
// Generate secure state with cryptographic signature
const secureState = this.stateService.generateSecureState(providerId, state);
// Build authorization URL
if (provider.authorizationEndpoint) {
// Use custom authorization endpoint
const authUrl = new URL(provider.authorizationEndpoint);
// Standard OAuth2 parameters
authUrl.searchParams.set('client_id', provider.clientId);
authUrl.searchParams.set('redirect_uri', redirectUri);
authUrl.searchParams.set('scope', provider.scopes.join(' '));
authUrl.searchParams.set('state', secureState);
authUrl.searchParams.set('response_type', 'code');
return authUrl.href;
}
// Use OIDC discovery for providers without custom endpoints
const config = await this.getOrCreateConfig(provider);
const parameters: Record<string, string> = {
redirect_uri: redirectUri,
scope: provider.scopes.join(' '),
state: secureState,
response_type: 'code',
};
// For HTTP endpoints, we need to pass the allowInsecureRequests option
const serverUrl = new URL(provider.issuer || '');
let clientOptions: any = undefined;
if (serverUrl.protocol === 'http:') {
this.logger.debug(
`Building authorization URL with allowInsecureRequests for ${provider.id}`
);
clientOptions = {
execute: [client.allowInsecureRequests],
};
}
const authUrl = client.buildAuthorizationUrl(config, parameters);
return authUrl.href;
}
extractProviderFromState(state: string): { providerId: string; originalState: string } {
// Extract provider from state prefix (no decryption needed)
const providerId = this.stateService.extractProviderFromState(state);
if (providerId) {
return {
providerId,
originalState: state,
};
}
// Fallback for unknown formats
return {
providerId: '',
originalState: state,
};
}
async handleCallback(
providerId: string,
code: string,
state: string,
requestOrigin?: string,
fullCallbackUrl?: string
): Promise<string> {
const provider = await this.oidcConfig.getProvider(providerId);
if (!provider) {
throw new UnauthorizedException(`Provider ${providerId} not found`);
}
try {
const redirectUri = this.getRedirectUri(requestOrigin);
// Always use openid-client for consistency
const config = await this.getOrCreateConfig(provider);
// Log configuration details
this.logger.debug(`Provider ${providerId} config loaded`);
this.logger.debug(`Redirect URI: ${redirectUri}`);
// Build current URL for token exchange
// CRITICAL: The URL used here MUST match the redirect_uri that was sent to the authorization endpoint
// Google expects the exact same redirect_uri during token exchange
const currentUrl = new URL(redirectUri);
currentUrl.searchParams.set('code', code);
currentUrl.searchParams.set('state', state);
// Copy additional parameters from the actual callback if provided
if (fullCallbackUrl) {
const actualUrl = new URL(fullCallbackUrl);
// Copy over additional params that Google might have added (scope, authuser, prompt, etc)
// but DO NOT change the base URL or path
['scope', 'authuser', 'prompt', 'hd', 'session_state', 'iss'].forEach((param) => {
const value = actualUrl.searchParams.get(param);
if (value && !currentUrl.searchParams.has(param)) {
currentUrl.searchParams.set(param, value);
}
});
}
// Google returns iss in the response, openid-client v6 expects it
// If not present, add it based on the provider's issuer
if (!currentUrl.searchParams.has('iss') && provider.issuer) {
currentUrl.searchParams.set('iss', provider.issuer);
}
this.logger.debug(`Token exchange URL (matches redirect_uri): ${currentUrl.href}`);
// Validate secure state
const stateValidation = this.stateService.validateSecureState(state, providerId);
if (!stateValidation.isValid) {
this.logger.error(`State validation failed: ${stateValidation.error}`);
throw new UnauthorizedException(stateValidation.error || 'Invalid state parameter');
}
const originalState = stateValidation.clientState!;
this.logger.debug(`Exchanging code for tokens with provider ${providerId}`);
this.logger.debug(`Client state extracted: ${originalState}`);
// For openid-client v6, we need to prepare the authorization response
const authorizationResponse = new URLSearchParams(currentUrl.search);
// Set the original client state for openid-client
authorizationResponse.set('state', originalState);
// Create a new URL with the cleaned parameters
const cleanUrl = new URL(redirectUri);
cleanUrl.search = authorizationResponse.toString();
this.logger.debug(`Clean URL for token exchange: ${cleanUrl.href}`);
let tokens;
try {
this.logger.debug(`Starting token exchange with openid-client`);
this.logger.debug(`Config issuer: ${config.serverMetadata().issuer}`);
this.logger.debug(`Config token endpoint: ${config.serverMetadata().token_endpoint}`);
// For HTTP endpoints, we need to pass the allowInsecureRequests option
const serverUrl = new URL(provider.issuer || '');
let clientOptions: any = undefined;
if (serverUrl.protocol === 'http:') {
this.logger.debug(`Token exchange with allowInsecureRequests for ${provider.id}`);
clientOptions = {
execute: [client.allowInsecureRequests],
};
}
tokens = await client.authorizationCodeGrant(
config,
cleanUrl,
{
expectedState: originalState,
},
clientOptions
);
this.logger.debug(
`Token exchange successful, received tokens: ${Object.keys(tokens).join(', ')}`
);
} catch (tokenError) {
const errorMessage =
tokenError instanceof Error ? tokenError.message : String(tokenError);
this.logger.error(`Token exchange failed: ${errorMessage}`);
// Check if error message contains the "unexpected JWT claim" text
if (errorMessage.includes('unexpected JWT claim value encountered')) {
this.logger.error(
`unexpected JWT claim value encountered during token validation by openid-client`
);
this.logger.debug(
`Token exchange error details: ${JSON.stringify(tokenError, null, 2)}`
);
// Log the actual vs expected issuer
this.logger.error(
`This error typically means the 'iss' claim in the JWT doesn't match the expected issuer`
);
this.logger.error(`Check that your provider's issuer URL is configured correctly`);
}
throw tokenError;
}
// Parse ID token to get user info
let claims: JwtClaims | null = null;
if (tokens.id_token) {
try {
// Use jose to properly decode the JWT
claims = decodeJwt(tokens.id_token) as JwtClaims;
// Log claims safely without PII - only structure, not values
if (claims) {
const claimKeys = Object.keys(claims).join(', ');
this.logger.debug(
`ID token decoded successfully. Available claims: [${claimKeys}]`
);
// Log claim types without exposing sensitive values
for (const [key, value] of Object.entries(claims)) {
const valueType = Array.isArray(value)
? `array[${value.length}]`
: typeof value;
// Only log structure, not actual values (avoid PII)
this.logger.debug(`Claim '${key}': type=${valueType}`);
// Check for unexpected claim types
if (valueType === 'object' && value !== null && !Array.isArray(value)) {
this.logger.warn(`Claim '${key}' contains complex object structure`);
}
}
}
} catch (e) {
this.logger.warn(`Failed to parse ID token: ${e}`);
}
} else {
this.logger.error('No ID token received from provider');
}
if (!claims?.sub) {
this.logger.error(
'No subject in token - claims available: ' +
(claims ? Object.keys(claims).join(', ') : 'none')
);
throw new UnauthorizedException('No subject in token');
}
const userSub = claims.sub;
this.logger.debug(`Processing authentication for user: ${userSub}`);
// Check authorization based on rules
// This will throw a helpful error if misconfigured or unauthorized
await this.checkAuthorization(provider, claims);
// Create session and return padded token
const paddedToken = await this.sessionService.createSession(providerId, userSub);
this.logger.log(`Successfully authenticated user ${userSub} via provider ${providerId}`);
return paddedToken;
} catch (error) {
this.logger.error(
`OAuth callback error: ${error instanceof Error ? error.message : 'Unknown error'}`
);
// Re-throw the original error if it's already an UnauthorizedException
if (error instanceof UnauthorizedException) {
throw error;
}
// Otherwise throw a generic error
throw new UnauthorizedException('Authentication failed');
}
}
private async getOrCreateConfig(provider: OidcProvider): Promise<client.Configuration> {
const cacheKey = provider.id;
if (this.configCache.has(cacheKey)) {
return this.configCache.get(cacheKey)!;
}
try {
// Use the validation service to perform discovery with HTTP support
if (provider.issuer) {
this.logger.debug(`Attempting discovery for ${provider.id} at ${provider.issuer}`);
// Create client options with HTTP support if needed
const serverUrl = new URL(provider.issuer);
let clientOptions: any = undefined;
if (serverUrl.protocol === 'http:') {
this.logger.debug(`Allowing HTTP for ${provider.id} as specified by user`);
clientOptions = {
execute: [client.allowInsecureRequests],
};
}
try {
const config = await this.validationService.performDiscovery(
provider,
clientOptions
);
this.logger.debug(`Discovery successful for ${provider.id}`);
this.logger.debug(
`Authorization endpoint: ${config.serverMetadata().authorization_endpoint}`
);
this.logger.debug(`Token endpoint: ${config.serverMetadata().token_endpoint}`);
this.configCache.set(cacheKey, config);
return config;
} catch (discoveryError) {
const errorMessage =
discoveryError instanceof Error ? discoveryError.message : 'Unknown error';
this.logger.warn(`Discovery failed for ${provider.id}: ${errorMessage}`);
// Log more details about the discovery error
this.logger.debug(
`Discovery URL attempted: ${provider.issuer}/.well-known/openid-configuration`
);
this.logger.debug(
`Full discovery error: ${JSON.stringify(discoveryError, null, 2)}`
);
// Log stack trace for better debugging
if (discoveryError instanceof Error && discoveryError.stack) {
this.logger.debug(`Stack trace: ${discoveryError.stack}`);
}
// If discovery fails but we have manual endpoints, use them
if (provider.authorizationEndpoint && provider.tokenEndpoint) {
this.logger.log(`Using manual endpoints for ${provider.id}`);
// Create manual configuration
const serverMetadata: client.ServerMetadata = {
issuer: provider.issuer || `manual-${provider.id}`,
authorization_endpoint: provider.authorizationEndpoint,
token_endpoint: provider.tokenEndpoint,
jwks_uri: provider.jwksUri,
};
const clientMetadata: Partial<client.ClientMetadata> = {
client_secret: provider.clientSecret,
};
// Configure client auth method
const clientAuth = provider.clientSecret
? client.ClientSecretPost(provider.clientSecret)
: client.None();
try {
const config = new client.Configuration(
serverMetadata,
provider.clientId,
clientMetadata,
clientAuth
);
// Use manual configuration with HTTP support if needed
const serverUrl = new URL(provider.tokenEndpoint);
if (serverUrl.protocol === 'http:') {
this.logger.debug(
`Allowing HTTP for manual endpoints on ${provider.id}`
);
client.allowInsecureRequests(config);
}
this.logger.debug(`Manual configuration created for ${provider.id}`);
this.logger.debug(
`Authorization endpoint: ${serverMetadata.authorization_endpoint}`
);
this.logger.debug(`Token endpoint: ${serverMetadata.token_endpoint}`);
this.configCache.set(cacheKey, config);
return config;
} catch (manualConfigError) {
this.logger.error(
`Failed to create manual configuration: ${manualConfigError instanceof Error ? manualConfigError.message : 'Unknown error'}`
);
throw new Error(`Manual configuration failed for ${provider.id}`);
}
} else {
throw new Error(
`OIDC discovery failed and no manual endpoints provided for ${provider.id}`
);
}
}
}
// Manual configuration when no issuer is provided
if (provider.authorizationEndpoint && provider.tokenEndpoint) {
this.logger.log(`Using manual endpoints for ${provider.id} (no issuer provided)`);
// Create manual configuration
const serverMetadata: client.ServerMetadata = {
issuer: provider.issuer || `manual-${provider.id}`,
authorization_endpoint: provider.authorizationEndpoint,
token_endpoint: provider.tokenEndpoint,
jwks_uri: provider.jwksUri,
};
const clientMetadata: Partial<client.ClientMetadata> = {
client_secret: provider.clientSecret,
};
// Configure client auth method
const clientAuth = provider.clientSecret
? client.ClientSecretPost(provider.clientSecret)
: client.None();
try {
const config = new client.Configuration(
serverMetadata,
provider.clientId,
clientMetadata,
clientAuth
);
// Use manual configuration with HTTP support if needed
const serverUrl = new URL(provider.tokenEndpoint);
if (serverUrl.protocol === 'http:') {
this.logger.debug(`Allowing HTTP for manual endpoints on ${provider.id}`);
client.allowInsecureRequests(config);
}
this.logger.debug(`Manual configuration created for ${provider.id}`);
this.logger.debug(
`Authorization endpoint: ${serverMetadata.authorization_endpoint}`
);
this.logger.debug(`Token endpoint: ${serverMetadata.token_endpoint}`);
this.configCache.set(cacheKey, config);
return config;
} catch (manualConfigError) {
this.logger.error(
`Failed to create manual configuration: ${manualConfigError instanceof Error ? manualConfigError.message : 'Unknown error'}`
);
throw new Error(`Manual configuration failed for ${provider.id}`);
}
}
// If we reach here, neither discovery nor manual endpoints are available
throw new Error(
`No configuration method available for ${provider.id}: requires either valid issuer for discovery or manual endpoints`
);
} catch (error) {
this.logger.error(
`Failed to create OIDC configuration for ${provider.id}: ${
error instanceof Error ? error.message : 'Unknown error'
}`
);
// Log more details in debug mode
if (error instanceof Error && error.stack) {
this.logger.debug(`Stack trace: ${error.stack}`);
}
throw new UnauthorizedException('Provider configuration error');
}
}
private async checkAuthorization(provider: OidcProvider, claims: JwtClaims): Promise<void> {
this.logger.debug(
`Checking authorization for provider ${provider.id} with ${provider.authorizationRules?.length || 0} rules`
);
this.logger.debug(`Available claims: ${Object.keys(claims).join(', ')}`);
this.logger.debug(
`Authorization rule mode: ${provider.authorizationRuleMode || AuthorizationRuleMode.OR}`
);
// If no authorization rules are specified, throw a helpful error
if (!provider.authorizationRules || provider.authorizationRules.length === 0) {
throw new UnauthorizedException(
`Login failed: The ${provider.name} provider has no authorization rules configured. ` +
`Please configure authorization rules.`
);
}
this.logger.debug(
`Authorization rules to evaluate: ${JSON.stringify(provider.authorizationRules, null, 2)}`
);
// Evaluate the rules
const ruleMode = provider.authorizationRuleMode || AuthorizationRuleMode.OR;
const isAuthorized = this.evaluateAuthorizationRules(
provider.authorizationRules,
claims,
ruleMode
);
this.logger.debug(`Authorization result: ${isAuthorized}`);
if (!isAuthorized) {
// Log authorization failure with safe claim representation (no PII)
const availableClaimKeys = Object.keys(claims).join(', ');
this.logger.warn(
`Authorization failed for provider ${provider.name}, user ${claims.sub}, available claim keys: [${availableClaimKeys}]`
);
throw new UnauthorizedException(
`Access denied: Your account does not meet the authorization requirements for ${provider.name}.`
);
}
this.logger.debug(`Authorization successful for user ${claims.sub}`);
}
private evaluateAuthorizationRules(
rules: OidcAuthorizationRule[],
claims: JwtClaims,
mode: AuthorizationRuleMode = AuthorizationRuleMode.OR
): boolean {
// No rules means no authorization
if (rules.length === 0) {
return false;
}
if (mode === AuthorizationRuleMode.AND) {
// All rules must pass (AND logic)
return rules.every((rule) => this.evaluateRule(rule, claims));
} else {
// Any rule can pass (OR logic) - default behavior
// Multiple rules act as alternative authorization paths
return rules.some((rule) => this.evaluateRule(rule, claims));
}
}
private evaluateRule(rule: OidcAuthorizationRule, claims: JwtClaims): boolean {
const claimValue = claims[rule.claim];
this.logger.verbose(
`Evaluating rule for claim ${rule.claim}: ${JSON.stringify({
claimValue,
claimType: typeof claimValue,
isArray: Array.isArray(claimValue),
ruleOperator: rule.operator,
ruleValues: rule.value,
})}`
);
if (claimValue === undefined || claimValue === null) {
this.logger.verbose(`Claim ${rule.claim} not found in token`);
return false;
}
// Handle non-array, non-string objects
if (typeof claimValue === 'object' && claimValue !== null && !Array.isArray(claimValue)) {
this.logger.warn(
`unexpected JWT claim value encountered - claim ${rule.claim} has unsupported object type (keys: [${Object.keys(claimValue as Record<string, unknown>).join(', ')}])`
);
return false;
}
// Handle array claims - evaluate rule against each array element
if (Array.isArray(claimValue)) {
this.logger.verbose(
`Processing array claim ${rule.claim} with ${claimValue.length} elements`
);
// For array claims, check if ANY element in the array matches the rule
const arrayResult = claimValue.some((element) => {
// Skip non-string elements
if (
typeof element !== 'string' &&
typeof element !== 'number' &&
typeof element !== 'boolean'
) {
this.logger.verbose(`Skipping non-primitive element in array: ${typeof element}`);
return false;
}
const elementValue = String(element);
return this.evaluateSingleValue(elementValue, rule);
});
this.logger.verbose(`Array evaluation result for claim ${rule.claim}: ${arrayResult}`);
return arrayResult;
}
// Handle single value claims (string, number, boolean)
const value = String(claimValue);
this.logger.verbose(`Processing single value claim ${rule.claim} with value: "${value}"`);
return this.evaluateSingleValue(value, rule);
}
private evaluateSingleValue(value: string, rule: OidcAuthorizationRule): boolean {
let result: boolean;
switch (rule.operator) {
case AuthorizationOperator.EQUALS:
result = rule.value.some((v) => value === v);
this.logger.verbose(
`EQUALS check: "${value}" matches any of [${rule.value.join(', ')}]: ${result}`
);
return result;
case AuthorizationOperator.CONTAINS:
result = rule.value.some((v) => value.includes(v));
this.logger.verbose(
`CONTAINS check: "${value}" contains any of [${rule.value.join(', ')}]: ${result}`
);
return result;
case AuthorizationOperator.STARTS_WITH:
result = rule.value.some((v) => value.startsWith(v));
this.logger.verbose(
`STARTS_WITH check: "${value}" starts with any of [${rule.value.join(', ')}]: ${result}`
);
return result;
case AuthorizationOperator.ENDS_WITH:
result = rule.value.some((v) => value.endsWith(v));
this.logger.verbose(
`ENDS_WITH check: "${value}" ends with any of [${rule.value.join(', ')}]: ${result}`
);
return result;
default:
this.logger.error(`Unknown authorization operator: ${rule.operator}`);
return false;
}
}
/**
* Validate OIDC provider configuration by attempting discovery
* Returns validation result with helpful error messages for debugging
*/
async validateProvider(
provider: OidcProvider
): Promise<{ isValid: boolean; error?: string; details?: unknown }> {
// Clear any cached config for this provider to force fresh validation
this.configCache.delete(provider.id);
// Delegate to the validation service
return this.validationService.validateProvider(provider);
}
private getRedirectUri(requestOrigin?: string): string {
// If we have the full origin (protocol://host), use it directly
if (requestOrigin) {
// Parse the origin to extract protocol and host
try {
const url = new URL(requestOrigin);
const { protocol, hostname, port } = url;
// Reconstruct the URL, removing default ports
let cleanOrigin = `${protocol}//${hostname}`;
// Add port if it's not the default for the protocol
if (
port &&
!(protocol === 'https:' && port === '443') &&
!(protocol === 'http:' && port === '80')
) {
cleanOrigin += `:${port}`;
}
// Special handling for localhost development with Nuxt proxy
if (hostname === 'localhost' && port === '3000') {
return `${cleanOrigin}/graphql/api/auth/oidc/callback`;
}
return `${cleanOrigin}/graphql/api/auth/oidc/callback`;
} catch (e) {
this.logger.warn(`Failed to parse request origin: ${requestOrigin}, error: ${e}`);
}
}
// Fall back to configured BASE_URL or default
const baseUrl = this.configService.get('BASE_URL', 'http://tower.local');
return `${baseUrl}/graphql/api/auth/oidc/callback`;
}
}

View File

@@ -1,204 +0,0 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/oidc-state.service.js';
describe('OidcStateService', () => {
let service: OidcStateService;
beforeEach(() => {
vi.clearAllMocks();
vi.useFakeTimers();
// Create a single instance for all tests in a describe block
service = new OidcStateService();
});
afterEach(() => {
vi.useRealTimers();
});
describe('generateSecureState', () => {
it('should generate a state with provider prefix and signed token', () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = service.generateSecureState(providerId, clientState);
expect(state).toBeTruthy();
expect(typeof state).toBe('string');
expect(state.startsWith(`${providerId}:`)).toBe(true);
// Extract signed portion and verify format (nonce.timestamp.signature)
const signed = state.substring(providerId.length + 1);
expect(signed.split('.').length).toBe(3);
});
it('should generate unique states for each call', () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state1 = service.generateSecureState(providerId, clientState);
const state2 = service.generateSecureState(providerId, clientState);
expect(state1).not.toBe(state2);
});
});
describe('validateSecureState', () => {
it('should validate a valid state token', () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = service.generateSecureState(providerId, clientState);
const result = service.validateSecureState(state, providerId);
expect(result.isValid).toBe(true);
expect(result.clientState).toBe(clientState);
expect(result.error).toBeUndefined();
});
it('should reject state with wrong provider ID', () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = service.generateSecureState(providerId, clientState);
const result = service.validateSecureState(state, 'wrong-provider');
expect(result.isValid).toBe(false);
expect(result.error).toBe('Provider ID mismatch in state');
});
it('should reject expired state tokens', () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = service.generateSecureState(providerId, clientState);
// Fast forward time beyond expiration (11 minutes)
vi.advanceTimersByTime(11 * 60 * 1000);
const result = service.validateSecureState(state, providerId);
expect(result.isValid).toBe(false);
expect(result.error).toBe('State token has expired');
});
it('should reject reused state tokens', () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = service.generateSecureState(providerId, clientState);
// First validation should succeed
const result1 = service.validateSecureState(state, providerId);
expect(result1.isValid).toBe(true);
// Second validation should fail (replay attack prevention)
const result2 = service.validateSecureState(state, providerId);
expect(result2.isValid).toBe(false);
expect(result2.error).toBe('State token not found or already used');
});
it('should reject invalid state tokens', () => {
const result = service.validateSecureState('invalid.state.token', 'test-provider');
expect(result.isValid).toBe(false);
expect(result.error).toBe('Invalid state format');
});
it('should reject tampered state tokens', () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = service.generateSecureState(providerId, clientState);
// Tamper with the signature
const parts = state.split('.');
parts[2] = parts[2].slice(0, -4) + 'XXXX';
const tamperedState = parts.join('.');
const result = service.validateSecureState(tamperedState, providerId);
expect(result.isValid).toBe(false);
expect(result.error).toBe('Invalid state signature');
});
});
describe('extractProviderFromState', () => {
it('should extract provider from state prefix', () => {
const state = 'provider-id:eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature';
const result = service.extractProviderFromState(state);
expect(result).toBe('provider-id');
});
it('should handle states with multiple colons', () => {
const state = 'provider-id:jwt:with:colons';
const result = service.extractProviderFromState(state);
expect(result).toBe('provider-id');
});
it('should return null for invalid format', () => {
const result = service.extractProviderFromState('invalid-state');
expect(result).toBeNull();
});
});
describe('extractProviderFromLegacyState', () => {
it('should extract provider from legacy colon-separated format', () => {
const result = service.extractProviderFromLegacyState('provider-id:client-state');
expect(result.providerId).toBe('provider-id');
expect(result.originalState).toBe('client-state');
});
it('should handle multiple colons in legacy format', () => {
const result = service.extractProviderFromLegacyState(
'provider-id:client:state:with:colons'
);
expect(result.providerId).toBe('provider-id');
expect(result.originalState).toBe('client:state:with:colons');
});
it('should return empty provider for JWT format', () => {
const jwtState = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature';
const result = service.extractProviderFromLegacyState(jwtState);
expect(result.providerId).toBe('');
expect(result.originalState).toBe(jwtState);
});
it('should return empty provider for unknown format', () => {
const result = service.extractProviderFromLegacyState('some-random-state');
expect(result.providerId).toBe('');
expect(result.originalState).toBe('some-random-state');
});
});
describe('cleanupExpiredStates', () => {
it('should clean up expired states periodically', () => {
const providerId = 'test-provider';
// Generate multiple states
service.generateSecureState(providerId, 'state1');
service.generateSecureState(providerId, 'state2');
service.generateSecureState(providerId, 'state3');
// Fast forward past expiration
vi.advanceTimersByTime(11 * 60 * 1000);
// Generate a new state that shouldn't be cleaned
const validState = service.generateSecureState(providerId, 'state4');
// Trigger cleanup (happens every minute)
vi.advanceTimersByTime(60 * 1000);
// The new state should still be valid
const result = service.validateSecureState(validState, providerId);
expect(result.isValid).toBe(true);
});
});
});

View File

@@ -1,164 +0,0 @@
import { Injectable, Logger } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import * as client from 'openid-client';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
@Injectable()
export class OidcValidationService {
private readonly logger = new Logger(OidcValidationService.name);
constructor(private readonly configService: ConfigService) {}
/**
* Validate OIDC provider configuration by attempting discovery
* Returns validation result with helpful error messages for debugging
*/
async validateProvider(
provider: OidcProvider
): Promise<{ isValid: boolean; error?: string; details?: unknown }> {
try {
// Validate issuer URL is present
if (!provider.issuer) {
return {
isValid: false,
error: 'No issuer URL provided. Please specify the OIDC provider issuer URL.',
details: { type: 'MISSING_ISSUER' },
};
}
// Validate issuer URL is valid
let serverUrl: URL;
try {
serverUrl = new URL(provider.issuer);
} catch (urlError) {
return {
isValid: false,
error: `Invalid issuer URL format: '${provider.issuer}'. Please provide a valid URL.`,
details: {
type: 'INVALID_URL',
originalError: urlError instanceof Error ? urlError.message : String(urlError),
},
};
}
// Configure client options for HTTP if needed
let clientOptions: any = undefined;
if (serverUrl.protocol === 'http:') {
this.logger.debug(
`HTTP issuer URL detected for provider ${provider.id}: ${provider.issuer}`
);
clientOptions = {
execute: [client.allowInsecureRequests],
};
}
// Attempt OIDC discovery
await this.performDiscovery(provider, clientOptions);
return { isValid: true };
} catch (error) {
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
// Log the raw error for debugging
this.logger.debug(`Raw discovery error for ${provider.id}: ${errorMessage}`);
// Provide specific error messages for common issues
let userFriendlyError = errorMessage;
let details: Record<string, unknown> = {};
if (errorMessage.includes('getaddrinfo ENOTFOUND')) {
userFriendlyError = `Cannot resolve domain name. Please check that '${provider.issuer}' is accessible and spelled correctly.`;
details = { type: 'DNS_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('ECONNREFUSED')) {
userFriendlyError = `Connection refused. The server at '${provider.issuer}' is not accepting connections.`;
details = { type: 'CONNECTION_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('ECONNRESET') || errorMessage.includes('ETIMEDOUT')) {
userFriendlyError = `Connection timeout. The server at '${provider.issuer}' is not responding.`;
details = { type: 'TIMEOUT_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('404') || errorMessage.includes('Not Found')) {
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
? provider.issuer.replace('/.well-known/openid-configuration', '')
: provider.issuer;
userFriendlyError = `OIDC discovery endpoint not found. Please verify that '${baseUrl}/.well-known/openid-configuration' exists.`;
details = { type: 'DISCOVERY_NOT_FOUND', originalError: errorMessage };
} else if (errorMessage.includes('401') || errorMessage.includes('403')) {
userFriendlyError = `Access denied to discovery endpoint. Please check the issuer URL and any authentication requirements.`;
details = { type: 'AUTHENTICATION_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('unexpected HTTP response status code')) {
// Extract status code if possible
const statusMatch = errorMessage.match(/status code (\d+)/);
const statusCode = statusMatch ? statusMatch[1] : 'unknown';
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
? provider.issuer.replace('/.well-known/openid-configuration', '')
: provider.issuer;
userFriendlyError = `HTTP ${statusCode} error from discovery endpoint. Please check that '${baseUrl}/.well-known/openid-configuration' returns a valid OIDC discovery document.`;
details = { type: 'HTTP_STATUS_ERROR', statusCode, originalError: errorMessage };
} else if (
errorMessage.includes('certificate') ||
errorMessage.includes('SSL') ||
errorMessage.includes('TLS')
) {
userFriendlyError = `SSL/TLS certificate error. The server certificate may be invalid or expired.`;
details = { type: 'SSL_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('JSON') || errorMessage.includes('parse')) {
userFriendlyError = `Invalid OIDC discovery response. The server returned malformed JSON.`;
details = { type: 'INVALID_JSON', originalError: errorMessage };
} else if (error && (error as any).code === 'OAUTH_RESPONSE_IS_NOT_CONFORM') {
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
? provider.issuer.replace('/.well-known/openid-configuration', '')
: provider.issuer;
userFriendlyError = `Invalid OIDC discovery document. The server at '${baseUrl}/.well-known/openid-configuration' returned a response that doesn't conform to the OpenID Connect Discovery specification. Please verify the endpoint returns valid OIDC metadata.`;
details = { type: 'INVALID_OIDC_DOCUMENT', originalError: errorMessage };
}
this.logger.warn(`OIDC validation failed for provider ${provider.id}: ${errorMessage}`);
// Add debug logging for HTTP status errors
if (errorMessage.includes('unexpected HTTP response status code')) {
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
? provider.issuer.replace('/.well-known/openid-configuration', '')
: provider.issuer;
this.logger.debug(`Attempted to fetch: ${baseUrl}/.well-known/openid-configuration`);
this.logger.debug(`Full error details: ${errorMessage}`);
}
return {
isValid: false,
error: userFriendlyError,
details,
};
}
}
async performDiscovery(provider: OidcProvider, clientOptions?: any): Promise<client.Configuration> {
if (!provider.issuer) {
throw new Error('No issuer URL provided');
}
// Configure client auth method
const clientAuth = provider.clientSecret
? client.ClientSecretPost(provider.clientSecret)
: undefined;
const serverUrl = new URL(provider.issuer);
// Use provided client options or create default options with HTTP support if needed
if (!clientOptions && serverUrl.protocol === 'http:') {
this.logger.debug(`Allowing HTTP for ${provider.id} as specified by user`);
// For openid-client v6, use allowInsecureRequests in the execute array
// This is deprecated but needed for local development with HTTP endpoints
clientOptions = {
execute: [client.allowInsecureRequests],
};
}
return client.discovery(
serverUrl,
provider.clientId,
undefined, // client metadata
clientAuth,
clientOptions
);
}
}

View File

@@ -0,0 +1,10 @@
import { Module } from '@nestjs/common';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
@Module({
providers: [OidcSessionService, OidcStateService],
exports: [OidcSessionService, OidcStateService],
})
export class OidcSessionModule {}

View File

@@ -4,7 +4,7 @@ import { Test } from '@nestjs/testing';
import type { Cache } from 'cache-manager'; import type { Cache } from 'cache-manager';
import { beforeEach, describe, expect, it, vi } from 'vitest'; import { beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js'; import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
describe('OidcSessionService', () => { describe('OidcSessionService', () => {
let service: OidcSessionService; let service: OidcSessionService;

View File

@@ -15,7 +15,7 @@ export interface OidcSession {
@Injectable() @Injectable()
export class OidcSessionService { export class OidcSessionService {
private readonly logger = new Logger(OidcSessionService.name); private readonly logger = new Logger(OidcSessionService.name);
private readonly SESSION_TTL_SECONDS = 2 * 60; // 2 minutes for one-time token security private readonly SESSION_TTL_MS = 2 * 60 * 1000; // 2 minutes in milliseconds (cache-manager v7 expects milliseconds)
constructor(@Inject(CACHE_MANAGER) private readonly cacheManager: Cache) {} constructor(@Inject(CACHE_MANAGER) private readonly cacheManager: Cache) {}
@@ -28,12 +28,21 @@ export class OidcSessionService {
providerId, providerId,
providerUserId, providerUserId,
createdAt: now, createdAt: now,
expiresAt: new Date(now.getTime() + this.SESSION_TTL_SECONDS * 1000), expiresAt: new Date(now.getTime() + this.SESSION_TTL_MS),
}; };
// Store in cache with TTL // Store in cache with TTL (in milliseconds for cache-manager v7)
await this.cacheManager.set(sessionId, session, this.SESSION_TTL_SECONDS * 1000); await this.cacheManager.set(sessionId, session, this.SESSION_TTL_MS);
this.logger.log(`Created OIDC session ${sessionId} for provider ${providerId}`);
// Verify it was stored
const verifyStored = await this.cacheManager.get(sessionId);
if (verifyStored) {
this.logger.debug(`Session successfully stored and verified with ID: ${sessionId}`);
} else {
this.logger.error(`CRITICAL: Session was NOT stored in cache for ID: ${sessionId}`);
}
this.logger.log(`Created OIDC session for provider ${providerId}`);
return this.createPaddedToken(sessionId); return this.createPaddedToken(sessionId);
} }
@@ -44,15 +53,16 @@ export class OidcSessionService {
return { valid: false }; return { valid: false };
} }
this.logger.debug(`Looking for session with ID: ${sessionId}`);
const session = await this.cacheManager.get<OidcSession>(sessionId); const session = await this.cacheManager.get<OidcSession>(sessionId);
if (!session) { if (!session) {
this.logger.debug(`Session ${sessionId} not found`); this.logger.debug(`Session not found for ID: ${sessionId}`);
return { valid: false }; return { valid: false };
} }
const now = new Date(); const now = new Date();
if (now > new Date(session.expiresAt)) { if (now > new Date(session.expiresAt)) {
this.logger.debug(`Session ${sessionId} expired`); this.logger.debug(`Session expired`);
await this.cacheManager.del(sessionId); await this.cacheManager.del(sessionId);
return { valid: false }; return { valid: false };
} }
@@ -62,7 +72,7 @@ export class OidcSessionService {
await this.cacheManager.del(sessionId); await this.cacheManager.del(sessionId);
this.logger.log( this.logger.log(
`Validated and invalidated session ${sessionId} for provider ${session.providerId} (one-time use)` `Validated and invalidated session for provider ${session.providerId} (one-time use)`
); );
return { valid: true, username: 'root' }; return { valid: true, username: 'root' };
} }

View File

@@ -0,0 +1,155 @@
import { CacheModule } from '@nestjs/cache-manager';
import { UnauthorizedException } from '@nestjs/common';
import { Test } from '@nestjs/testing';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcStateExtractor } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
describe('OidcStateExtractor', () => {
let stateService: OidcStateService;
beforeEach(async () => {
vi.clearAllMocks();
const module = await Test.createTestingModule({
imports: [CacheModule.register()],
providers: [OidcStateService],
}).compile();
stateService = module.get<OidcStateService>(OidcStateService);
});
describe('extractProviderFromState', () => {
it('should extract provider ID from valid state', () => {
const state = 'provider123:nonce.timestamp.signature';
const result = OidcStateExtractor.extractProviderFromState(state, stateService);
expect(result.providerId).toBe('provider123');
expect(result.originalState).toBe(state);
});
it('should handle state without provider prefix', () => {
const state = 'invalid-state-format';
const result = OidcStateExtractor.extractProviderFromState(state, stateService);
expect(result.providerId).toBe('');
expect(result.originalState).toBe(state);
});
});
describe('extractAndValidateState', () => {
it('should extract and validate a valid state with redirectUri', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
// Generate a valid state
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
// Extract and validate
const result = await OidcStateExtractor.extractAndValidateState(state, stateService);
expect(result.providerId).toBe(providerId);
expect(result.originalState).toBe(state);
expect(result.clientState).toBe(clientState);
expect(result.redirectUri).toBe(redirectUri);
});
it('should extract and validate a valid state without redirectUri', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
// Generate a valid state without redirectUri
const state = await stateService.generateSecureState(providerId, clientState);
// Extract and validate
const result = await OidcStateExtractor.extractAndValidateState(state, stateService);
expect(result.providerId).toBe(providerId);
expect(result.originalState).toBe(state);
expect(result.clientState).toBe(clientState);
expect(result.redirectUri).toBeUndefined();
});
it('should throw UnauthorizedException for invalid state format', async () => {
const invalidState = 'invalid-format';
await expect(async () => {
await OidcStateExtractor.extractAndValidateState(invalidState, stateService);
}).rejects.toThrow(UnauthorizedException);
});
it('should throw UnauthorizedException for expired state', async () => {
vi.useFakeTimers();
const providerId = 'test-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
// Generate a valid state
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
// Fast forward time beyond expiration (11 minutes)
vi.advanceTimersByTime(11 * 60 * 1000);
await expect(async () => {
await OidcStateExtractor.extractAndValidateState(state, stateService);
}).rejects.toThrow(UnauthorizedException);
vi.useRealTimers();
});
it('should throw UnauthorizedException for wrong provider ID', async () => {
const providerId = 'test-provider';
const wrongProviderId = 'wrong-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
// Generate a valid state
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
// Create a fake state with wrong provider prefix
const tamperedState = state.replace(providerId, wrongProviderId);
await expect(async () => {
await OidcStateExtractor.extractAndValidateState(tamperedState, stateService);
}).rejects.toThrow(UnauthorizedException);
});
it('should throw UnauthorizedException for tampered state', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
// Generate a valid state
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
// Tamper with the signature
const tamperedState = state.slice(0, -5) + 'xxxxx';
await expect(async () => {
await OidcStateExtractor.extractAndValidateState(tamperedState, stateService);
}).rejects.toThrow(UnauthorizedException);
});
it('should throw UnauthorizedException for reused state (replay attack)', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
// Generate a valid state
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
// First validation should succeed
const result1 = await OidcStateExtractor.extractAndValidateState(state, stateService);
expect(result1.providerId).toBe(providerId);
// Second validation should fail (replay attack)
await expect(async () => {
await OidcStateExtractor.extractAndValidateState(state, stateService);
}).rejects.toThrow(UnauthorizedException);
});
});
});

View File

@@ -0,0 +1,60 @@
import { UnauthorizedException } from '@nestjs/common';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
export interface StateExtractionResult {
providerId: string;
originalState: string;
clientState?: string;
redirectUri?: string;
}
/**
* Utility to extract and validate OIDC state information consistently
* across authorize and callback endpoints
*/
export class OidcStateExtractor {
/**
* Extract provider ID from state without validation (for routing purposes)
*/
static extractProviderFromState(
state: string,
stateService: OidcStateService
): { providerId: string; originalState: string } {
// Use the state service's extraction method
const providerId = stateService.extractProviderFromState(state);
return {
providerId: providerId || '',
originalState: state,
};
}
/**
* Extract provider ID and validate the full encrypted state
*/
static async extractAndValidateState(
state: string,
stateService: OidcStateService
): Promise<StateExtractionResult> {
// First extract provider ID for routing
const { providerId } = this.extractProviderFromState(state, stateService);
if (!providerId) {
throw new UnauthorizedException('Invalid state format: missing provider ID');
}
// Then validate the full encrypted state
const stateValidation = await stateService.validateSecureState(state, providerId);
if (!stateValidation.isValid) {
throw new UnauthorizedException(`Invalid state: ${stateValidation.error}`);
}
return {
providerId,
originalState: state,
clientState: stateValidation.clientState,
redirectUri: stateValidation.redirectUri,
};
}
}

View File

@@ -0,0 +1,238 @@
import { CacheModule } from '@nestjs/cache-manager';
import { Test, TestingModule } from '@nestjs/testing';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
describe('OidcStateService', () => {
let service: OidcStateService;
let module: TestingModule;
beforeEach(async () => {
vi.clearAllMocks();
vi.useFakeTimers();
// Set a deterministic system time for consistent testing
vi.setSystemTime(new Date('2024-01-01T00:00:00Z'));
module = await Test.createTestingModule({
imports: [CacheModule.register()],
providers: [OidcStateService],
}).compile();
service = module.get<OidcStateService>(OidcStateService);
});
afterEach(async () => {
vi.useRealTimers();
// Close the testing module to prevent handle leaks
if (module) {
await module.close();
}
});
describe('generateSecureState', () => {
it('should generate a state with provider prefix and signed token', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
const state = await service.generateSecureState(providerId, clientState, redirectUri);
expect(state).toBeTruthy();
expect(typeof state).toBe('string');
expect(state.startsWith(`${providerId}:`)).toBe(true);
// Extract signed portion and verify format (nonce.timestamp.signature)
const signed = state.substring(providerId.length + 1);
expect(signed.split('.').length).toBe(3);
});
it('should generate unique states for each call', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
const state1 = await service.generateSecureState(providerId, clientState, redirectUri);
const state2 = await service.generateSecureState(providerId, clientState, redirectUri);
expect(state1).not.toBe(state2);
});
it('should work without redirectUri parameter (backwards compatibility)', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
expect(state).toBeTruthy();
expect(state.startsWith(`${providerId}:`)).toBe(true);
});
it('should store state data in cache and retrieve it', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
const state = await service.generateSecureState(providerId, clientState, redirectUri);
const validation = await service.validateSecureState(state, providerId);
expect(validation.isValid).toBe(true);
expect(validation.clientState).toBe(clientState);
expect(validation.redirectUri).toBe(redirectUri);
});
});
describe('validateSecureState', () => {
it('should validate a valid state token', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
const result = await service.validateSecureState(state, providerId);
expect(result.isValid).toBe(true);
expect(result.clientState).toBe(clientState);
expect(result.error).toBeUndefined();
});
it('should validate a state token with redirectUri', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
const state = await service.generateSecureState(providerId, clientState, redirectUri);
const result = await service.validateSecureState(state, providerId);
expect(result.isValid).toBe(true);
expect(result.clientState).toBe(clientState);
expect(result.redirectUri).toBe(redirectUri);
expect(result.error).toBeUndefined();
});
it('should reject state with wrong provider ID', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
const result = await service.validateSecureState(state, 'different-provider');
expect(result.isValid).toBe(false);
expect(result.error).toContain('Provider ID mismatch');
});
it('should reject expired state tokens', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
// Advance time by 11 minutes (past the 10-minute TTL)
vi.advanceTimersByTime(11 * 60 * 1000);
const result = await service.validateSecureState(state, providerId);
expect(result.isValid).toBe(false);
expect(result.error).toContain('expired');
});
it('should reject reused state tokens', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
// First validation should succeed
const result1 = await service.validateSecureState(state, providerId);
expect(result1.isValid).toBe(true);
// Second validation should fail (replay attack prevention)
const result2 = await service.validateSecureState(state, providerId);
expect(result2.isValid).toBe(false);
expect(result2.error).toContain('not found or already used');
});
it('should reject invalid state tokens', async () => {
const providerId = 'test-provider';
const invalidState = `${providerId}:invalid-format`;
const result = await service.validateSecureState(invalidState, providerId);
expect(result.isValid).toBe(false);
expect(result.error).toBeTruthy();
});
it('should reject tampered state tokens', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
// Tamper with the signature
const tamperedState = state.substring(0, state.length - 5) + 'xxxxx';
const result = await service.validateSecureState(tamperedState, providerId);
expect(result.isValid).toBe(false);
expect(result.error).toContain('signature');
});
});
describe('extractProviderFromState', () => {
it('should extract provider ID from state', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
const extracted = service.extractProviderFromState(state);
expect(extracted).toBe(providerId);
});
it('should return null for invalid state format', () => {
const invalidState = 'invalid-state-without-colon';
const extracted = service.extractProviderFromState(invalidState);
expect(extracted).toBeNull();
});
});
describe('extractProviderFromLegacyState', () => {
it('should handle legacy state format', () => {
const legacyState = 'provider-id:client-state-value';
const result = service.extractProviderFromLegacyState(legacyState);
expect(result.providerId).toBe('provider-id');
expect(result.originalState).toBe('client-state-value');
});
it('should handle new signed state format', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
const result = service.extractProviderFromLegacyState(state);
// New format should not be recognized as legacy
expect(result.providerId).toBe('');
expect(result.originalState).toBe(state);
});
});
describe('cache TTL', () => {
it('should remove state from cache after successful validation', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
// First validation should succeed
const result1 = await service.validateSecureState(state, providerId);
expect(result1.isValid).toBe(true);
// Second validation should fail (state was removed after first use)
const result2 = await service.validateSecureState(state, providerId);
expect(result2.isValid).toBe(false);
expect(result2.error).toContain('not found or already used');
});
});
});

View File

@@ -0,0 +1,241 @@
import { CacheModule } from '@nestjs/cache-manager';
import { Test } from '@nestjs/testing';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
describe('OidcStateService', () => {
let service: OidcStateService;
beforeEach(async () => {
const module = await Test.createTestingModule({
imports: [CacheModule.register()],
providers: [OidcStateService],
}).compile();
service = module.get<OidcStateService>(OidcStateService);
});
describe('state generation and validation flow', () => {
it('should generate state with redirect URI and validate it successfully', async () => {
const providerId = 'unraid.net';
const clientState = 'client-state-123';
const redirectUri = 'http://devgen-dev1.local/graphql/api/auth/oidc/callback';
// Generate state
const state = await service.generateSecureState(providerId, clientState, redirectUri);
// Verify state format: providerId:nonce.timestamp.signature
expect(state).toMatch(/^unraid\.net:[a-f0-9]+\.\d+\.[a-f0-9]+$/);
// Extract and verify parts
const [extractedProviderId, signedState] = state.split(':');
expect(extractedProviderId).toBe(providerId);
// Parse the signed state components
const [nonce, timestamp, signature] = signedState.split('.');
// Verify nonce is a 32-character hex string (16 bytes)
expect(nonce).toMatch(/^[a-f0-9]{32}$/);
// Verify timestamp is a valid number and recent
const timestampNum = parseInt(timestamp, 10);
expect(timestampNum).toBeGreaterThan(Date.now() - 1000); // Generated within last second
expect(timestampNum).toBeLessThanOrEqual(Date.now());
// Verify signature is a 64-character hex string (SHA256 output)
expect(signature).toMatch(/^[a-f0-9]{64}$/);
// Validate the state
const validation = await service.validateSecureState(state, providerId);
expect(validation.isValid).toBe(true);
expect(validation.clientState).toBe(clientState);
expect(validation.redirectUri).toBe(redirectUri);
});
it('should verify signed state integrity with HMAC', async () => {
const providerId = 'test-provider';
const clientState = 'test-state';
const redirectUri = 'http://localhost:3000/callback';
const state = await service.generateSecureState(providerId, clientState, redirectUri);
// Tamper with the signature
const [provider, signedState] = state.split(':');
const [nonce, timestamp] = signedState.split('.');
const tamperedSignature = 'a'.repeat(64); // Invalid signature
const tamperedState = `${provider}:${nonce}.${timestamp}.${tamperedSignature}`;
const validation = await service.validateSecureState(tamperedState, providerId);
expect(validation.isValid).toBe(false);
expect(validation.error).toContain('Invalid state signature');
});
it('should fail validation when nonce is not in cache', async () => {
const providerId = 'unraid.net';
// Create a fake state that looks valid but has unknown nonce
const fakeState = `unraid.net:fakenonce123.${Date.now()}.fakesignature456`;
const validation = await service.validateSecureState(fakeState, providerId);
expect(validation.isValid).toBe(false);
expect(validation.error).toContain('Invalid state signature');
});
it('should prevent replay attacks by removing nonce after validation', async () => {
const providerId = 'test-provider';
const clientState = 'test-state';
const redirectUri = 'http://localhost:3000/callback';
// Generate and validate state once
const state = await service.generateSecureState(providerId, clientState, redirectUri);
const firstValidation = await service.validateSecureState(state, providerId);
expect(firstValidation.isValid).toBe(true);
// Try to validate the same state again (replay attack)
const secondValidation = await service.validateSecureState(state, providerId);
expect(secondValidation.isValid).toBe(false);
expect(secondValidation.error).toContain('State token not found or already used');
});
it('should handle state with missing redirect URI', async () => {
const providerId = 'test-provider';
const clientState = 'test-state';
// No redirect URI provided
const state = await service.generateSecureState(providerId, clientState);
const validation = await service.validateSecureState(state, providerId);
expect(validation.isValid).toBe(true);
expect(validation.clientState).toBe(clientState);
expect(validation.redirectUri).toBeUndefined();
});
it('should reject state with wrong provider ID', async () => {
const providerId = 'provider-a';
const wrongProviderId = 'provider-b';
const clientState = 'test-state';
const state = await service.generateSecureState(providerId, clientState);
const validation = await service.validateSecureState(state, wrongProviderId);
expect(validation.isValid).toBe(false);
expect(validation.error).toContain('Provider ID mismatch');
});
it('should extract provider from state correctly', async () => {
const providerId = 'unraid.net';
const state = await service.generateSecureState(providerId, 'test', 'http://example.com');
const extracted = service.extractProviderFromState(state);
expect(extracted).toBe(providerId);
});
it('should handle state expiration', async () => {
const providerId = 'test-provider';
const clientState = 'test-state';
// Generate state
const state = await service.generateSecureState(providerId, clientState);
// Mock timestamp to simulate expired state
const parts = state.split(':')[1].split('.');
const nonce = parts[0];
const expiredTimestamp = Date.now() - 700000; // 11+ minutes ago
const fakeState = `${providerId}:${nonce}.${expiredTimestamp}.fakesignature`;
const validation = await service.validateSecureState(fakeState, providerId);
expect(validation.isValid).toBe(false);
expect(validation.error).toContain('Invalid state signature'); // Will fail on signature first
});
});
describe('redirect URI extraction from state', () => {
it('should store and retrieve redirect URI from state token', async () => {
const providerId = 'unraid.net';
const clientState = 'original-client-state';
const redirectUri = 'http://devgen-dev1.local/graphql/api/auth/oidc/callback';
// This simulates the authorize flow
const stateToken = await service.generateSecureState(providerId, clientState, redirectUri);
// Log the generated state for debugging
console.log('Generated state token:', stateToken);
// This simulates the callback flow
const validation = await service.validateSecureState(stateToken, providerId);
expect(validation.isValid).toBe(true);
expect(validation.redirectUri).toBe(redirectUri);
expect(validation.clientState).toBe(clientState);
});
it('should handle dynamic redirect URIs for different origins', async () => {
const providerId = 'google';
const clientState = 'state123';
// Test with different origins
const origins = [
'http://localhost:3000/graphql/api/auth/oidc/callback',
'https://myserver.local/graphql/api/auth/oidc/callback',
'http://192.168.1.100/graphql/api/auth/oidc/callback',
];
for (const redirectUri of origins) {
const state = await service.generateSecureState(providerId, clientState, redirectUri);
const validation = await service.validateSecureState(state, providerId);
expect(validation.isValid).toBe(true);
expect(validation.redirectUri).toBe(redirectUri);
}
});
});
describe('cache management', () => {
it('should handle TTL expiration correctly', async () => {
const providerId = 'test-provider';
const clientState = 'test-state';
const state = await service.generateSecureState(providerId, clientState);
// First validation should succeed
const validation1 = await service.validateSecureState(state, providerId);
expect(validation1.isValid).toBe(true);
// State should be removed after first use (replay protection)
const validation2 = await service.validateSecureState(state, providerId);
expect(validation2.isValid).toBe(false);
});
it('should store complete state data in cache with redirect URI', async () => {
const providerId = 'test-provider';
const clientState = 'client-123';
const redirectUri = 'http://example.com/callback';
const state = await service.generateSecureState(providerId, clientState, redirectUri);
// Extract nonce from the generated state
const [, signedState] = state.split(':');
const [nonce] = signedState.split('.');
// Access the cache directly to verify stored data
const cacheKey = `oidc_state:${nonce}`;
const cachedData = await service['cacheManager'].get(cacheKey);
expect(cachedData).toBeDefined();
expect(cachedData).toMatchObject({
nonce,
clientState,
providerId,
redirectUri,
});
// @ts-expect-error - cachedData is of type StateData
expect(cachedData.timestamp).toBeGreaterThan(Date.now() - 1000);
// @ts-expect-error - cachedData is of type StateData
expect(cachedData.timestamp).toBeLessThanOrEqual(Date.now());
});
});
});

View File

@@ -1,4 +1,5 @@
import { Injectable, Logger } from '@nestjs/common'; import { Cache, CACHE_MANAGER } from '@nestjs/cache-manager';
import { Inject, Injectable, Logger } from '@nestjs/common';
import crypto from 'crypto'; import crypto from 'crypto';
interface StateData { interface StateData {
@@ -6,26 +7,34 @@ interface StateData {
clientState: string; clientState: string;
timestamp: number; timestamp: number;
providerId: string; providerId: string;
redirectUri?: string;
} }
@Injectable() @Injectable()
export class OidcStateService { export class OidcStateService {
private static instanceCount = 0;
private readonly instanceId: number;
private readonly logger = new Logger(OidcStateService.name); private readonly logger = new Logger(OidcStateService.name);
private readonly stateCache = new Map<string, StateData>();
private readonly hmacSecret: string; private readonly hmacSecret: string;
private readonly STATE_TTL_SECONDS = 600; // 10 minutes private readonly STATE_TTL_MS = 600000; // 10 minutes in milliseconds (cache-manager v7+ expects milliseconds, not seconds)
private readonly STATE_CACHE_PREFIX = 'oidc_state:';
constructor(@Inject(CACHE_MANAGER) private cacheManager: Cache) {
// Track instance creation
this.instanceId = ++OidcStateService.instanceCount;
constructor() {
// Always generate a new secret on API restart for security // Always generate a new secret on API restart for security
// This ensures state tokens cannot be reused across restarts // This ensures state tokens cannot be reused across restarts
this.hmacSecret = crypto.randomBytes(32).toString('hex'); this.hmacSecret = crypto.randomBytes(32).toString('hex');
this.logger.debug('Generated new OIDC state secret for this session'); this.logger.warn(`OidcStateService instance #${this.instanceId} created with new HMAC secret`);
this.logger.debug(`HMAC secret first 8 chars: ${this.hmacSecret.substring(0, 8)}`);
// Clean up expired states periodically
setInterval(() => this.cleanupExpiredStates(), 60000); // Every minute
} }
generateSecureState(providerId: string, clientState: string): string { async generateSecureState(
providerId: string,
clientState: string,
redirectUri?: string
): Promise<string> {
const nonce = crypto.randomBytes(16).toString('hex'); const nonce = crypto.randomBytes(16).toString('hex');
const timestamp = Date.now(); const timestamp = Date.now();
@@ -35,8 +44,21 @@ export class OidcStateService {
clientState, clientState,
timestamp, timestamp,
providerId, providerId,
redirectUri,
}; };
this.stateCache.set(nonce, stateData);
// Store in cache with TTL (in milliseconds for cache-manager v7)
const cacheKey = `${this.STATE_CACHE_PREFIX}${nonce}`;
this.logger.debug(`Storing state with key: ${cacheKey}, TTL: ${this.STATE_TTL_MS}ms`);
await this.cacheManager.set(cacheKey, stateData, this.STATE_TTL_MS);
// Verify it was stored
const verifyStored = await this.cacheManager.get(cacheKey);
if (verifyStored) {
this.logger.debug(`State successfully stored and verified for key: ${cacheKey}`);
} else {
this.logger.error(`CRITICAL: State was NOT stored in cache for key: ${cacheKey}`);
}
// Create signed state: nonce.timestamp.signature // Create signed state: nonce.timestamp.signature
const dataToSign = `${nonce}.${timestamp}`; const dataToSign = `${nonce}.${timestamp}`;
@@ -45,14 +67,18 @@ export class OidcStateService {
const signedState = `${dataToSign}.${signature}`; const signedState = `${dataToSign}.${signature}`;
this.logger.debug(`Generated secure state for provider ${providerId} with nonce ${nonce}`); this.logger.debug(`Generated secure state for provider ${providerId} with nonce ${nonce}`);
this.logger.debug(
`Instance #${this.instanceId}, HMAC secret first 8 chars: ${this.hmacSecret.substring(0, 8)}`
);
this.logger.debug(`Stored redirectUri: ${redirectUri}`);
// Return state with provider ID prefix (unencrypted) for routing // Return state with provider ID prefix (unencrypted) for routing
return `${providerId}:${signedState}`; return `${providerId}:${signedState}`;
} }
validateSecureState( async validateSecureState(
state: string, state: string,
expectedProviderId: string expectedProviderId: string
): { isValid: boolean; clientState?: string; error?: string } { ): Promise<{ isValid: boolean; clientState?: string; redirectUri?: string; error?: string }> {
try { try {
// Extract provider ID and signed state // Extract provider ID and signed state
const parts = state.split(':'); const parts = state.split(':');
@@ -107,7 +133,7 @@ export class OidcStateService {
// Check timestamp expiration // Check timestamp expiration
const now = Date.now(); const now = Date.now();
const age = now - timestamp; const age = now - timestamp;
if (age > this.STATE_TTL_SECONDS * 1000) { if (age > this.STATE_TTL_MS) {
this.logger.warn(`State validation failed: token expired (age: ${age}ms)`); this.logger.warn(`State validation failed: token expired (age: ${age}ms)`);
return { return {
isValid: false, isValid: false,
@@ -116,11 +142,21 @@ export class OidcStateService {
} }
// Check if state exists in cache (prevents replay attacks) // Check if state exists in cache (prevents replay attacks)
const cachedState = this.stateCache.get(nonce); const cacheKey = `${this.STATE_CACHE_PREFIX}${nonce}`;
this.logger.debug(`Looking for nonce ${nonce} in cache with key: ${cacheKey}`);
this.logger.debug(
`Instance #${this.instanceId}, HMAC secret first 8 chars: ${this.hmacSecret.substring(0, 8)}`
);
this.logger.debug(`Cache manager type: ${this.cacheManager.constructor.name}`);
const cachedState = await this.cacheManager.get<StateData>(cacheKey);
if (!cachedState) { if (!cachedState) {
this.logger.warn( this.logger.warn(
`State validation failed: nonce ${nonce} not found in cache (possible replay attack)` `State validation failed: nonce ${nonce} not found in cache (possible replay attack)`
); );
this.logger.warn(`Cache key checked: ${cacheKey}`);
return { return {
isValid: false, isValid: false,
error: 'State token not found or already used', error: 'State token not found or already used',
@@ -137,12 +173,13 @@ export class OidcStateService {
} }
// Remove from cache to prevent reuse // Remove from cache to prevent reuse
this.stateCache.delete(nonce); await this.cacheManager.del(cacheKey);
this.logger.debug(`State validation successful for provider ${expectedProviderId}`); this.logger.debug(`State validation successful for provider ${expectedProviderId}`);
return { return {
isValid: true, isValid: true,
clientState: cachedState.clientState, clientState: cachedState.clientState,
redirectUri: cachedState.redirectUri,
}; };
} catch (error) { } catch (error) {
this.logger.error( this.logger.error(
@@ -182,20 +219,5 @@ export class OidcStateService {
return null; return null;
} }
private cleanupExpiredStates(): void { // Cleanup is now handled by cache TTL
const now = Date.now();
let cleaned = 0;
for (const [nonce, stateData] of this.stateCache.entries()) {
const age = now - stateData.timestamp;
if (age > this.STATE_TTL_SECONDS * 1000) {
this.stateCache.delete(nonce);
cleaned++;
}
}
if (cleaned > 0) {
this.logger.debug(`Cleaned up ${cleaned} expired state entries`);
}
}
} }

View File

@@ -1,33 +1,13 @@
import { CacheModule } from '@nestjs/cache-manager';
import { Module } from '@nestjs/common'; import { Module } from '@nestjs/common';
import { UserSettingsModule } from '@unraid/shared/services/user-settings.js'; import { OidcCoreModule } from '@app/unraid-api/graph/resolvers/sso/core/oidc-core.module.js';
import { OidcAuthService } from '@app/unraid-api/graph/resolvers/sso/oidc-auth.service.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/oidc-state.service.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/oidc-validation.service.js';
import { SsoResolver } from '@app/unraid-api/graph/resolvers/sso/sso.resolver.js'; import { SsoResolver } from '@app/unraid-api/graph/resolvers/sso/sso.resolver.js';
import '@app/unraid-api/graph/resolvers/sso/sso-settings.types.js'; import '@app/unraid-api/graph/resolvers/sso/models/sso-settings.types.js';
@Module({ @Module({
imports: [UserSettingsModule, CacheModule.register()], imports: [OidcCoreModule],
providers: [ providers: [SsoResolver],
SsoResolver, exports: [OidcCoreModule],
OidcConfigPersistence,
OidcSessionService,
OidcStateService,
OidcAuthService,
OidcValidationService,
],
exports: [
OidcConfigPersistence,
OidcSessionService,
OidcStateService,
OidcAuthService,
OidcValidationService,
],
}) })
export class SsoModule {} export class SsoModule {}

View File

@@ -6,11 +6,12 @@ import { PrefixedID } from '@unraid/shared/prefixed-id-scalar.js';
import { UsePermissions } from '@unraid/shared/use-permissions.directive.js'; import { UsePermissions } from '@unraid/shared/use-permissions.directive.js';
import { Public } from '@app/unraid-api/auth/public.decorator.js'; import { Public } from '@app/unraid-api/auth/public.decorator.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js'; import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js'; import { OidcConfiguration } from '@app/unraid-api/graph/resolvers/sso/models/oidc-configuration.model.js';
import { OidcSessionValidation } from '@app/unraid-api/graph/resolvers/sso/oidc-session-validation.model.js'; import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js'; import { OidcSessionValidation } from '@app/unraid-api/graph/resolvers/sso/models/oidc-session-validation.model.js';
import { PublicOidcProvider } from '@app/unraid-api/graph/resolvers/sso/public-oidc-provider.model.js'; import { PublicOidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/public-oidc-provider.model.js';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
@Resolver() @Resolver()
export class SsoResolver { export class SsoResolver {
@@ -88,6 +89,19 @@ export class SsoResolver {
return this.oidcConfig.getProvider(id); return this.oidcConfig.getProvider(id);
} }
@Query(() => OidcConfiguration, { description: 'Get the full OIDC configuration (admin only)' })
@UsePermissions({
action: AuthAction.READ_ANY,
resource: Resource.CONFIG,
})
public async oidcConfiguration(): Promise<OidcConfiguration> {
const config = await this.oidcConfig.getConfig();
return {
providers: config?.providers || [],
defaultAllowedOrigins: config?.defaultAllowedOrigins || [],
};
}
@Query(() => OidcSessionValidation, { @Query(() => OidcSessionValidation, {
description: 'Validate an OIDC session token (internal use for CLI validation)', description: 'Validate an OIDC session token (internal use for CLI validation)',
}) })

View File

@@ -0,0 +1,234 @@
import { Logger } from '@nestjs/common';
export interface OidcErrorDetails {
userFriendlyError: string;
details: Record<string, unknown>;
}
export class OidcErrorHelper {
private static readonly logger = new Logger(OidcErrorHelper.name);
/**
* Parse fetch errors and return user-friendly error messages
*/
static parseFetchError(error: unknown, issuerUrl?: string): OidcErrorDetails {
const errorMessage = error instanceof Error ? error.message : String(error);
let userFriendlyError = errorMessage;
let details: Record<string, unknown> = { originalError: errorMessage };
// Extract cause information if available
if (error instanceof Error && 'cause' in error) {
const cause = (error as any).cause;
if (cause) {
this.logger.log('Fetch error cause: %o', cause);
const errorCode = cause.code || '';
const causeMessage = cause.message || '';
// Map error codes to user-friendly messages
switch (errorCode) {
case 'ENOTFOUND':
userFriendlyError = `Cannot resolve domain name. Please check that '${issuerUrl}' is accessible and spelled correctly.`;
details = {
type: 'DNS_ERROR',
originalError: errorMessage,
cause: causeMessage || errorCode,
};
break;
case 'ECONNREFUSED':
userFriendlyError = `Connection refused. The server at '${issuerUrl}' is not accepting connections.`;
details = {
type: 'CONNECTION_ERROR',
originalError: errorMessage,
cause: causeMessage || errorCode,
};
break;
case 'CERT_HAS_EXPIRED':
userFriendlyError = `SSL/TLS certificate error. The server certificate may be invalid or expired.`;
details = {
type: 'SSL_ERROR',
originalError: errorMessage,
cause: causeMessage || errorCode,
};
break;
case 'ETIMEDOUT':
userFriendlyError = `Connection timeout. The server at '${issuerUrl}' is not responding.`;
details = {
type: 'TIMEOUT_ERROR',
originalError: errorMessage,
cause: causeMessage || errorCode,
};
break;
default:
// Check message patterns if code doesn't match
if (causeMessage.includes('ENOTFOUND')) {
userFriendlyError = `Cannot resolve domain name. Please check that '${issuerUrl}' is accessible and spelled correctly.`;
details = {
type: 'DNS_ERROR',
originalError: errorMessage,
cause: causeMessage,
};
} else if (causeMessage.includes('ECONNREFUSED')) {
userFriendlyError = `Connection refused. The server at '${issuerUrl}' is not accepting connections.`;
details = {
type: 'CONNECTION_ERROR',
originalError: errorMessage,
cause: causeMessage,
};
} else if (
causeMessage.includes('certificate') ||
causeMessage.includes('SSL') ||
causeMessage.includes('TLS')
) {
userFriendlyError = `SSL/TLS certificate error. The server certificate may be invalid or expired.`;
details = {
type: 'SSL_ERROR',
originalError: errorMessage,
cause: causeMessage,
};
} else if (causeMessage.includes('ETIMEDOUT')) {
userFriendlyError = `Connection timeout. The server at '${issuerUrl}' is not responding.`;
details = {
type: 'TIMEOUT_ERROR',
originalError: errorMessage,
cause: causeMessage,
};
} else {
userFriendlyError = `Failed to connect to OIDC provider at '${issuerUrl}'. ${causeMessage || errorCode || 'Unknown network error'}`;
details = {
type: 'FETCH_ERROR',
originalError: errorMessage,
cause: causeMessage || errorCode,
};
}
break;
}
} else {
// Generic fetch failed without cause
userFriendlyError = `Failed to connect to OIDC provider at '${issuerUrl}'. Please verify the URL is correct and accessible.`;
details = { type: 'FETCH_ERROR', originalError: errorMessage };
}
} else if (errorMessage.includes('fetch failed')) {
// Fetch failed but no cause information
userFriendlyError = `Failed to connect to OIDC provider at '${issuerUrl}'. Please verify the URL is correct and accessible.`;
details = { type: 'FETCH_ERROR', originalError: errorMessage };
}
return { userFriendlyError, details };
}
/**
* Parse HTTP status errors and return user-friendly error messages
*/
static parseHttpError(errorMessage: string, issuerUrl?: string): OidcErrorDetails {
let userFriendlyError = errorMessage;
let details: Record<string, unknown> = { originalError: errorMessage };
if (errorMessage.includes('404') || errorMessage.includes('Not Found')) {
const baseUrl = issuerUrl?.endsWith('/.well-known/openid-configuration')
? issuerUrl.replace('/.well-known/openid-configuration', '')
: issuerUrl;
userFriendlyError = `OIDC discovery endpoint not found. Please verify that '${baseUrl}/.well-known/openid-configuration' exists.`;
details = { type: 'DISCOVERY_NOT_FOUND', originalError: errorMessage };
} else if (errorMessage.includes('401') || errorMessage.includes('403')) {
userFriendlyError = `Access denied to discovery endpoint. Please check the issuer URL and any authentication requirements.`;
details = { type: 'AUTHENTICATION_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('unexpected HTTP response status code')) {
// Extract status code if possible
const statusMatch = errorMessage.match(/status code (\d+)/);
const statusCode = statusMatch ? statusMatch[1] : 'unknown';
const baseUrl = issuerUrl?.endsWith('/.well-known/openid-configuration')
? issuerUrl.replace('/.well-known/openid-configuration', '')
: issuerUrl;
userFriendlyError = `HTTP ${statusCode} error from discovery endpoint. Please check that '${baseUrl}/.well-known/openid-configuration' returns a valid OIDC discovery document.`;
details = { type: 'HTTP_STATUS_ERROR', statusCode, originalError: errorMessage };
}
return { userFriendlyError, details };
}
/**
* Parse generic OIDC errors and return user-friendly error messages
*/
static parseGenericError(error: unknown, issuerUrl?: string): OidcErrorDetails {
const errorMessage = error instanceof Error ? error.message : String(error);
let userFriendlyError = errorMessage;
let details: Record<string, unknown> = { originalError: errorMessage };
// Check for specific error patterns
if (errorMessage.includes('getaddrinfo ENOTFOUND')) {
userFriendlyError = `Cannot resolve domain name. Please check that '${issuerUrl}' is accessible and spelled correctly.`;
details = { type: 'DNS_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('ECONNREFUSED')) {
userFriendlyError = `Connection refused. The server at '${issuerUrl}' is not accepting connections.`;
details = { type: 'CONNECTION_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('ECONNRESET') || errorMessage.includes('ETIMEDOUT')) {
userFriendlyError = `Connection timeout. The server at '${issuerUrl}' is not responding.`;
details = { type: 'TIMEOUT_ERROR', originalError: errorMessage };
} else if (
errorMessage.includes('certificate') ||
errorMessage.includes('SSL') ||
errorMessage.includes('TLS')
) {
userFriendlyError = `SSL/TLS certificate error. The server certificate may be invalid or expired.`;
details = { type: 'SSL_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('JSON') || errorMessage.includes('parse')) {
userFriendlyError = `Invalid OIDC discovery response. The server returned malformed JSON.`;
details = { type: 'INVALID_JSON', originalError: errorMessage };
} else if (error && (error as any).code === 'OAUTH_RESPONSE_IS_NOT_CONFORM') {
const baseUrl = issuerUrl?.endsWith('/.well-known/openid-configuration')
? issuerUrl.replace('/.well-known/openid-configuration', '')
: issuerUrl;
userFriendlyError = `Invalid OIDC discovery document. The server at '${baseUrl}/.well-known/openid-configuration' returned a response that doesn't conform to the OpenID Connect Discovery specification. Please verify the endpoint returns valid OIDC metadata.`;
details = { type: 'INVALID_OIDC_DOCUMENT', originalError: errorMessage };
}
return { userFriendlyError, details };
}
/**
* Parse OIDC discovery errors and return user-friendly error messages
*/
static parseDiscoveryError(error: unknown, issuerUrl?: string): OidcErrorDetails {
const errorMessage = error instanceof Error ? error.message : String(error);
// Log additional error details for debugging
if (error instanceof Error) {
this.logger.log(`Error type: ${error.constructor.name}`);
if ('stack' in error && error.stack) {
this.logger.debug(`Stack trace: ${error.stack}`);
}
if ('response' in error) {
const response = (error as any).response;
if (response) {
this.logger.log(`Response status: ${response.status}`);
this.logger.log(`Response body: ${response.body}`);
}
}
}
// Check for fetch-specific errors first
if (errorMessage.includes('fetch failed')) {
return this.parseFetchError(error, issuerUrl);
}
// Check for HTTP status errors
const httpError = this.parseHttpError(errorMessage, issuerUrl);
// Proper type-narrowing guard for accessing details.type
if (
httpError.details &&
typeof httpError.details === 'object' &&
'type' in httpError.details &&
httpError.details.type !== undefined
) {
return httpError;
}
// Fall back to generic error parsing
return this.parseGenericError(error, issuerUrl);
}
}

View File

@@ -0,0 +1,228 @@
import { Logger } from '@nestjs/common';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import type { FastifyRequest } from '@app/unraid-api/types/fastify.js';
import { OidcRequestHandler } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-request-handler.util.js';
describe('OidcRequestHandler', () => {
let mockLogger: Logger;
beforeEach(() => {
vi.clearAllMocks();
mockLogger = {
debug: vi.fn(),
log: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
} as any;
});
describe('extractRequestInfo', () => {
it('should extract request info from headers', () => {
const mockReq = {
headers: {
'x-forwarded-proto': 'https',
'x-forwarded-host': 'example.com:8443',
},
protocol: 'http',
url: '/callback?code=123&state=456',
} as unknown as FastifyRequest;
const result = OidcRequestHandler.extractRequestInfo(mockReq);
expect(result.protocol).toBe('https');
expect(result.host).toBe('example.com:8443');
expect(result.fullUrl).toBe('https://example.com:8443/callback?code=123&state=456');
expect(result.baseUrl).toBe('https://example.com:8443');
});
it('should fall back to request properties when headers are missing', () => {
const mockReq = {
headers: {
host: 'localhost:3000',
},
protocol: 'http',
url: '/callback?code=123&state=456',
} as FastifyRequest;
const result = OidcRequestHandler.extractRequestInfo(mockReq);
expect(result.protocol).toBe('http');
expect(result.host).toBe('localhost:3000');
expect(result.fullUrl).toBe('http://localhost:3000/callback?code=123&state=456');
expect(result.baseUrl).toBe('http://localhost:3000');
});
it('should use defaults when all headers are missing', () => {
const mockReq = {
headers: {},
url: '/callback?code=123&state=456',
} as FastifyRequest;
const result = OidcRequestHandler.extractRequestInfo(mockReq);
expect(result.protocol).toBe('http');
expect(result.host).toBe('localhost:3000');
expect(result.fullUrl).toBe('http://localhost:3000/callback?code=123&state=456');
expect(result.baseUrl).toBe('http://localhost:3000');
});
});
describe('validateAuthorizeParams', () => {
it('should validate valid parameters', () => {
const result = OidcRequestHandler.validateAuthorizeParams(
'provider123',
'state456',
'https://example.com/callback'
);
expect(result.providerId).toBe('provider123');
expect(result.state).toBe('state456');
expect(result.redirectUri).toBe('https://example.com/callback');
});
it('should throw error for missing provider ID', () => {
expect(() => {
OidcRequestHandler.validateAuthorizeParams(
undefined,
'state456',
'https://example.com/callback'
);
}).toThrow('Provider ID is required');
});
it('should throw error for missing state', () => {
expect(() => {
OidcRequestHandler.validateAuthorizeParams(
'provider123',
undefined,
'https://example.com/callback'
);
}).toThrow('State parameter is required');
});
it('should throw error for missing redirect URI', () => {
expect(() => {
OidcRequestHandler.validateAuthorizeParams('provider123', 'state456', undefined);
}).toThrow('Redirect URI is required');
});
});
describe('validateCallbackParams', () => {
it('should validate valid parameters', () => {
const result = OidcRequestHandler.validateCallbackParams('code123', 'state456');
expect(result.code).toBe('code123');
expect(result.state).toBe('state456');
});
it('should throw error for missing code', () => {
expect(() => {
OidcRequestHandler.validateCallbackParams(undefined, 'state456');
}).toThrow('Missing required parameters');
});
it('should throw error for missing state', () => {
expect(() => {
OidcRequestHandler.validateCallbackParams('code123', undefined);
}).toThrow('Missing required parameters');
});
it('should throw error for empty code', () => {
expect(() => {
OidcRequestHandler.validateCallbackParams('', 'state456');
}).toThrow('Missing required parameters');
});
it('should throw error for empty state', () => {
expect(() => {
OidcRequestHandler.validateCallbackParams('code123', '');
}).toThrow('Missing required parameters');
});
});
describe('handleAuthorize', () => {
it('should handle authorization flow', async () => {
const mockAuthService = {
getAuthorizationUrl: vi
.fn()
.mockResolvedValue('https://provider.com/auth?client_id=123'),
};
const mockReq = {
headers: { 'x-forwarded-proto': 'https', 'x-forwarded-host': 'example.com' },
url: '/authorize',
} as unknown as FastifyRequest;
const authUrl = await OidcRequestHandler.handleAuthorize(
'provider123',
'state456',
'https://example.com/callback',
mockReq,
mockAuthService as any,
mockLogger
);
expect(authUrl).toBe('https://provider.com/auth?client_id=123');
expect(mockAuthService.getAuthorizationUrl).toHaveBeenCalledWith({
providerId: 'provider123',
state: 'state456',
requestOrigin: 'https://example.com/callback',
requestHeaders: {
'x-forwarded-proto': 'https',
'x-forwarded-host': 'example.com',
},
});
expect(mockLogger.debug).toHaveBeenCalledWith(
'Authorization request - Provider: provider123'
);
expect(mockLogger.log).toHaveBeenCalledWith(
'Redirecting to OIDC provider: https://provider.com/auth?client_id=123'
);
});
});
describe('handleCallback', () => {
it('should handle callback flow', async () => {
const mockStateService = {
extractProviderFromState: vi.fn().mockReturnValue('provider123'),
};
const mockAuthService = {
getStateService: vi.fn().mockReturnValue(mockStateService),
handleCallback: vi.fn().mockResolvedValue('paddedToken123'),
};
const mockReq: Pick<FastifyRequest, 'id' | 'headers' | 'url'> = {
id: '123',
headers: { 'x-forwarded-proto': 'https', 'x-forwarded-host': 'example.com' },
url: '/callback?code=123&state=456',
};
const result = await OidcRequestHandler.handleCallback(
'code123',
'state456',
mockReq as unknown as FastifyRequest,
mockAuthService as any,
mockLogger
);
expect(result.providerId).toBe('provider123');
expect(result.paddedToken).toBe('paddedToken123');
expect(result.requestInfo.fullUrl).toBe('https://example.com/callback?code=123&state=456');
expect(mockAuthService.handleCallback).toHaveBeenCalledWith({
providerId: 'provider123',
code: 'code123',
state: 'state456',
requestOrigin: 'https://example.com',
fullCallbackUrl: 'https://example.com/callback?code=123&state=456',
requestHeaders: {
'x-forwarded-proto': 'https',
'x-forwarded-host': 'example.com',
},
});
expect(mockLogger.debug).toHaveBeenCalledWith('Callback request - Provider: provider123');
});
});
});

View File

@@ -0,0 +1,155 @@
import { Logger } from '@nestjs/common';
import type { FastifyRequest } from '@app/unraid-api/types/fastify.js';
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
import { OidcStateExtractor } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js';
export interface RequestInfo {
protocol: string;
host: string;
fullUrl: string;
baseUrl: string;
}
export interface OidcFlowResult {
providerId: string;
requestInfo: RequestInfo;
}
export interface OidcCallbackResult extends OidcFlowResult {
paddedToken: string;
}
/**
* Utility class to handle common OIDC request processing logic
* between authorize and callback endpoints
*/
export class OidcRequestHandler {
/**
* Extract request information from Fastify request headers
*/
static extractRequestInfo(req: FastifyRequest): RequestInfo {
// Handle potentially comma-separated forwarded headers (take first value)
const forwardedProto = String(req.headers['x-forwarded-proto'] || '')
.split(',')[0]
?.trim();
const forwardedHost = String(req.headers['x-forwarded-host'] || '')
.split(',')[0]
?.trim();
const protocol = forwardedProto || req.protocol || 'http';
const host = forwardedHost || req.headers.host || 'localhost:3000';
const fullUrl = `${protocol}://${host}${req.url}`;
const baseUrl = `${protocol}://${host}`;
return {
protocol,
host,
fullUrl,
baseUrl,
};
}
/**
* Handle OIDC authorization flow
*/
static async handleAuthorize(
providerId: string,
state: string,
redirectUri: string,
req: FastifyRequest,
oidcService: OidcService,
logger: Logger
): Promise<string> {
const requestInfo = this.extractRequestInfo(req);
logger.debug(`Authorization request - Provider: ${providerId}`);
logger.debug(`Authorization request - Full URL: ${requestInfo.fullUrl}`);
logger.debug(`Authorization request - Redirect URI: ${redirectUri}`);
// Get authorization URL using the validated redirect URI and request headers
const authUrl = await oidcService.getAuthorizationUrl({
providerId,
state,
requestOrigin: redirectUri,
requestHeaders: req.headers as Record<string, string | string[] | undefined>,
});
logger.log(`Redirecting to OIDC provider: ${authUrl}`);
return authUrl;
}
/**
* Handle OIDC callback flow
*/
static async handleCallback(
code: string,
state: string,
req: FastifyRequest,
oidcService: OidcService,
logger: Logger
): Promise<OidcCallbackResult> {
// Extract provider ID from state for routing
const { providerId } = OidcStateExtractor.extractProviderFromState(
state,
oidcService.getStateService()
);
const requestInfo = this.extractRequestInfo(req);
logger.debug(`Callback request - Provider: ${providerId}`);
logger.debug(`Callback request - Full URL: ${requestInfo.fullUrl}`);
logger.debug(`Redirect URI will be retrieved from encrypted state`);
// Handle the callback using stored redirect URI from state and request headers
const paddedToken = await oidcService.handleCallback({
providerId,
code,
state,
requestOrigin: requestInfo.baseUrl,
fullCallbackUrl: requestInfo.fullUrl,
requestHeaders: req.headers as Record<string, string | string[] | undefined>,
});
return {
providerId,
requestInfo,
paddedToken,
};
}
/**
* Validate required parameters for authorization flow
*/
static validateAuthorizeParams(
providerId: string | undefined,
state: string | undefined,
redirectUri: string | undefined
): { providerId: string; state: string; redirectUri: string } {
if (!providerId) {
throw new Error('Provider ID is required');
}
if (!state) {
throw new Error('State parameter is required');
}
if (!redirectUri) {
throw new Error('Redirect URI is required');
}
return { providerId, state, redirectUri };
}
/**
* Validate required parameters for callback flow
*/
static validateCallbackParams(
code: string | undefined,
state: string | undefined
): { code: string; state: string } {
if (!code || !state) {
throw new Error('Missing required parameters');
}
return { code, state };
}
}

View File

@@ -2,12 +2,12 @@ import { Module } from '@nestjs/common';
import { ScheduleModule } from '@nestjs/schedule'; import { ScheduleModule } from '@nestjs/schedule';
import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js'; import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js';
import { SubscriptionPollingService } from '@app/unraid-api/graph/services/subscription-polling.service.js'; import { SubscriptionManagerService } from '@app/unraid-api/graph/services/subscription-manager.service.js';
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js'; import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
@Module({ @Module({
imports: [], imports: [],
providers: [SubscriptionTrackerService, SubscriptionHelperService, SubscriptionPollingService], providers: [SubscriptionTrackerService, SubscriptionHelperService, SubscriptionManagerService],
exports: [SubscriptionTrackerService, SubscriptionHelperService, SubscriptionPollingService], exports: [SubscriptionTrackerService, SubscriptionHelperService], // SubscriptionManagerService is internal
}) })
export class ServicesModule {} export class ServicesModule {}

View File

@@ -4,7 +4,25 @@ import { createSubscription, PUBSUB_CHANNEL } from '@app/core/pubsub.js';
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js'; import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
/** /**
* Helper service for creating tracked GraphQL subscriptions with automatic cleanup * High-level helper service for creating GraphQL subscriptions with automatic cleanup.
*
* This service provides a convenient way to create GraphQL subscriptions that:
* - Automatically track subscriber count via SubscriptionTrackerService
* - Properly clean up resources when subscriptions end
* - Handle errors gracefully
*
* **When to use this service:**
* - In GraphQL resolvers when implementing subscriptions
* - When you need automatic reference counting for shared resources
* - When you want to ensure proper cleanup on subscription termination
*
* @example
* // In a GraphQL resolver
* \@Subscription(() => MetricsUpdate)
* async metricsSubscription() {
* // Topic must be registered first via SubscriptionTrackerService
* return this.subscriptionHelper.createTrackedSubscription(PUBSUB_CHANNEL.METRICS);
* }
*/ */
@Injectable() @Injectable()
export class SubscriptionHelperService { export class SubscriptionHelperService {
@@ -15,7 +33,7 @@ export class SubscriptionHelperService {
* @param topic The subscription topic/channel to subscribe to * @param topic The subscription topic/channel to subscribe to
* @returns A proxy async iterator with automatic cleanup * @returns A proxy async iterator with automatic cleanup
*/ */
public createTrackedSubscription<T = any>(topic: PUBSUB_CHANNEL): AsyncIterableIterator<T> { public createTrackedSubscription<T = any>(topic: PUBSUB_CHANNEL | string): AsyncIterableIterator<T> {
const innerIterator = createSubscription<T>(topic); const innerIterator = createSubscription<T>(topic);
// Subscribe when the subscription starts // Subscribe when the subscription starts

View File

@@ -0,0 +1,196 @@
import { Injectable, Logger, OnModuleDestroy } from '@nestjs/common';
import { SchedulerRegistry } from '@nestjs/schedule';
/**
* Configuration for managed subscriptions
*/
export interface SubscriptionConfig {
/** Unique identifier for the subscription */
name: string;
/**
* Polling interval in milliseconds.
* - If set to a number, the callback will be called at that interval
* - If null/undefined, the subscription is event-based (no polling)
*/
intervalMs?: number | null;
/** Function to call periodically (for polling) or once (for setup) */
callback: () => Promise<void>;
/** Optional function called when the subscription starts */
onStart?: () => Promise<void>;
/** Optional function called when the subscription stops */
onStop?: () => Promise<void>;
}
/**
* Low-level service for managing both polling and event-based subscriptions.
*
* ⚠️ **IMPORTANT**: This is an internal service. Do not use directly in resolvers or business logic.
* Instead, use one of the higher-level services:
* - **SubscriptionTrackerService**: For subscriptions that need reference counting
* - **SubscriptionHelperService**: For GraphQL subscriptions with automatic cleanup
*
* This service provides the underlying implementation for:
* - **Polling subscriptions**: Execute a callback at regular intervals
* - **Event-based subscriptions**: Set up event listeners or watchers that persist until stopped
*
* @internal
*/
@Injectable()
export class SubscriptionManagerService implements OnModuleDestroy {
private readonly logger = new Logger(SubscriptionManagerService.name);
private readonly activeSubscriptions = new Map<
string,
{ isPolling: boolean; config?: SubscriptionConfig }
>();
constructor(private readonly schedulerRegistry: SchedulerRegistry) {}
async onModuleDestroy() {
await this.stopAll();
}
/**
* Start a managed subscription (polling or event-based).
*
* @param config - The subscription configuration
* @throws Will throw an error if the onStart callback fails
*/
async startSubscription(config: SubscriptionConfig): Promise<void> {
const { name, intervalMs, callback, onStart } = config;
// Clean up any existing subscription with the same name
await this.stopSubscription(name);
// Initialize subscription state with config
this.activeSubscriptions.set(name, { isPolling: false, config });
// Call onStart callback if provided
if (onStart) {
try {
await onStart();
this.logger.debug(`Called onStart for '${name}'`);
} catch (error) {
this.logger.error(`Error in onStart for '${name}'`, error);
throw error;
}
}
// If intervalMs is null, this is a continuous/event-based subscription
if (intervalMs === null || intervalMs === undefined) {
this.logger.debug(`Started continuous subscription for '${name}' (no polling)`);
return;
}
// Create the polling function with guard against overlapping executions
const pollFunction = async () => {
const subscription = this.activeSubscriptions.get(name);
if (!subscription || subscription.isPolling) {
return;
}
subscription.isPolling = true;
try {
await callback();
} catch (error) {
this.logger.error(`Error in subscription callback '${name}'`, error);
} finally {
if (subscription) {
subscription.isPolling = false;
}
}
};
// Create and register the interval
const interval = setInterval(pollFunction, intervalMs);
this.schedulerRegistry.addInterval(name, interval);
this.logger.debug(`Started polling for '${name}' every ${intervalMs}ms`);
}
/**
* Stop a managed subscription.
*
* This will:
* 1. Stop any active polling interval
* 2. Call the onStop callback if provided
* 3. Clean up internal state
*
* @param name - The unique identifier of the subscription to stop
*/
async stopSubscription(name: string): Promise<void> {
// Get the config before deleting
const subscription = this.activeSubscriptions.get(name);
const onStop = subscription?.config?.onStop;
try {
if (this.schedulerRegistry.doesExist('interval', name)) {
this.schedulerRegistry.deleteInterval(name);
this.logger.debug(`Stopped polling interval for '${name}'`);
}
} catch (error) {
// Interval doesn't exist, which is fine
}
// Call onStop callback if provided
if (onStop) {
try {
await onStop();
this.logger.debug(`Called onStop for '${name}'`);
} catch (error) {
this.logger.error(`Error in onStop for '${name}'`, error);
}
}
// Clean up subscription state
this.activeSubscriptions.delete(name);
}
/**
* Stop all active subscriptions.
*
* This is automatically called when the module is destroyed.
*/
async stopAll(): Promise<void> {
// Get all active subscription keys (both polling and event-based)
const activeKeys = Array.from(this.activeSubscriptions.keys());
// Stop each subscription and await cleanup
await Promise.all(activeKeys.map((key) => this.stopSubscription(key)));
// Clear the map after all subscriptions are stopped
this.activeSubscriptions.clear();
}
/**
* Check if a subscription is active.
*
* @param name - The unique identifier of the subscription
* @returns true if the subscription exists (either polling or event-based)
*/
isSubscriptionActive(name: string): boolean {
// Check both for polling intervals and event-based subscriptions
return this.activeSubscriptions.has(name) || this.schedulerRegistry.doesExist('interval', name);
}
/**
* Get the total number of active subscriptions.
*
* @returns The count of all active subscriptions (polling and event-based)
*/
getActiveSubscriptionCount(): number {
return this.activeSubscriptions.size;
}
/**
* Get a list of all active subscription names.
*
* @returns Array of subscription identifiers
*/
getActiveSubscriptionNames(): string[] {
return Array.from(this.activeSubscriptions.keys());
}
}

View File

@@ -1,91 +0,0 @@
import { Injectable, Logger, OnModuleDestroy } from '@nestjs/common';
import { SchedulerRegistry } from '@nestjs/schedule';
export interface PollingConfig {
name: string;
intervalMs: number;
callback: () => Promise<void>;
}
@Injectable()
export class SubscriptionPollingService implements OnModuleDestroy {
private readonly logger = new Logger(SubscriptionPollingService.name);
private readonly activePollers = new Map<string, { isPolling: boolean }>();
constructor(private readonly schedulerRegistry: SchedulerRegistry) {}
onModuleDestroy() {
this.stopAll();
}
/**
* Start polling for a specific subscription topic
*/
startPolling(config: PollingConfig): void {
const { name, intervalMs, callback } = config;
// Clean up any existing interval
this.stopPolling(name);
// Initialize polling state
this.activePollers.set(name, { isPolling: false });
// Create the polling function with guard against overlapping executions
const pollFunction = async () => {
const poller = this.activePollers.get(name);
if (!poller || poller.isPolling) {
return;
}
poller.isPolling = true;
try {
await callback();
} catch (error) {
this.logger.error(`Error in polling task '${name}'`, error);
} finally {
if (poller) {
poller.isPolling = false;
}
}
};
// Create and register the interval
const interval = setInterval(pollFunction, intervalMs);
this.schedulerRegistry.addInterval(name, interval);
this.logger.debug(`Started polling for '${name}' every ${intervalMs}ms`);
}
/**
* Stop polling for a specific subscription topic
*/
stopPolling(name: string): void {
try {
if (this.schedulerRegistry.doesExist('interval', name)) {
this.schedulerRegistry.deleteInterval(name);
this.logger.debug(`Stopped polling for '${name}'`);
}
} catch (error) {
// Interval doesn't exist, which is fine
}
// Clean up polling state
this.activePollers.delete(name);
}
/**
* Stop all active polling tasks
*/
stopAll(): void {
const intervals = this.schedulerRegistry.getIntervals();
intervals.forEach((key) => this.stopPolling(key));
this.activePollers.clear();
}
/**
* Check if polling is active for a given name
*/
isPolling(name: string): boolean {
return this.schedulerRegistry.doesExist('interval', name);
}
}

View File

@@ -1,14 +1,44 @@
import { Injectable, Logger } from '@nestjs/common'; import { Injectable, Logger } from '@nestjs/common';
import { SubscriptionPollingService } from '@app/unraid-api/graph/services/subscription-polling.service.js'; import { SubscriptionManagerService } from '@app/unraid-api/graph/services/subscription-manager.service.js';
/**
* Service for managing subscriptions with automatic reference counting.
*
* This service tracks the number of active subscribers for each topic and automatically
* starts/stops the underlying subscription based on subscriber count.
*
* **When to use this service:**
* - When you have multiple GraphQL subscriptions that share the same data source
* - When you need to start a resource (polling, file watcher, etc.) only when there are active subscribers
* - When you need automatic cleanup when the last subscriber disconnects
*
* @example
* // Register a polling subscription
* subscriptionTracker.registerTopic(
* 'metrics-update',
* async () => {
* const metrics = await fetchMetrics();
* pubsub.publish('metrics-update', { metrics });
* },
* 5000 // Poll every 5 seconds
* );
*
* @example
* // Register an event-based subscription (e.g., file watching)
* subscriptionTracker.registerTopic(
* 'log-file-updates',
* () => startFileWatcher('/var/log/app.log'), // onStart
* () => stopFileWatcher('/var/log/app.log') // onStop
* );
*/
@Injectable() @Injectable()
export class SubscriptionTrackerService { export class SubscriptionTrackerService {
private readonly logger = new Logger(SubscriptionTrackerService.name); private readonly logger = new Logger(SubscriptionTrackerService.name);
private subscriberCounts = new Map<string, number>(); private subscriberCounts = new Map<string, number>();
private topicHandlers = new Map<string, { onStart: () => void; onStop: () => void }>(); private topicHandlers = new Map<string, { onStart: () => void; onStop: () => void }>();
constructor(private readonly pollingService: SubscriptionPollingService) {} constructor(private readonly subscriptionManager: SubscriptionManagerService) {}
/** /**
* Register a topic with optional polling support * Register a topic with optional polling support
@@ -29,8 +59,8 @@ export class SubscriptionTrackerService {
callback: async () => callbackOrOnStart(), callback: async () => callbackOrOnStart(),
}; };
this.topicHandlers.set(topic, { this.topicHandlers.set(topic, {
onStart: () => this.pollingService.startPolling(pollingConfig), onStart: () => this.subscriptionManager.startSubscription(pollingConfig),
onStop: () => this.pollingService.stopPolling(topic), onStop: () => this.subscriptionManager.stopSubscription(topic),
}); });
} else { } else {
// Legacy API: onStart and onStop handlers // Legacy API: onStart and onStop handlers

View File

@@ -1,3 +1,4 @@
import { CacheModule } from '@nestjs/cache-manager';
import { Test } from '@nestjs/testing'; import { Test } from '@nestjs/testing';
import { CANONICAL_INTERNAL_CLIENT_TOKEN } from '@unraid/shared'; import { CANONICAL_INTERNAL_CLIENT_TOKEN } from '@unraid/shared';
@@ -60,7 +61,7 @@ vi.mock('execa', () => ({
describe('RestModule Integration', () => { describe('RestModule Integration', () => {
it('should compile with RestService having access to ApiReportService', async () => { it('should compile with RestService having access to ApiReportService', async () => {
const module = await Test.createTestingModule({ const module = await Test.createTestingModule({
imports: [RestModule], imports: [CacheModule.register({ isGlobal: true }), RestModule],
}) })
// Override services that have complex dependencies for testing // Override services that have complex dependencies for testing
.overrideProvider(CANONICAL_INTERNAL_CLIENT_TOKEN) .overrideProvider(CANONICAL_INTERNAL_CLIENT_TOKEN)

View File

@@ -0,0 +1,487 @@
import { Logger } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { Test, TestingModule } from '@nestjs/testing';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import type { FastifyReply, FastifyRequest } from '@app/unraid-api/types/fastify.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
import { OidcRequestHandler } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-request-handler.util.js';
import { RestController } from '@app/unraid-api/rest/rest.controller.js';
import { RestService } from '@app/unraid-api/rest/rest.service.js';
describe('RestController', () => {
let controller: RestController;
let oidcService: OidcService;
let oidcConfig: OidcConfigPersistence;
let mockReply: Partial<FastifyReply>;
// Helper function to create a mock request with the desired hostname
const createMockRequest = (hostname?: string, headers: Record<string, any> = {}): FastifyRequest => {
return {
headers,
hostname,
url: '/test',
protocol: 'https',
} as FastifyRequest;
};
beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
controllers: [RestController],
providers: [
{
provide: RestService,
useValue: {
getLogs: vi.fn(),
getCustomizationStream: vi.fn(),
},
},
{
provide: OidcService,
useValue: {
getAuthorizationUrl: vi.fn(),
handleCallback: vi.fn(),
},
},
{
provide: OidcConfigPersistence,
useValue: {
getConfig: vi.fn().mockResolvedValue({
defaultAllowedOrigins: [],
}),
},
},
{
provide: ConfigService,
useValue: {
get: vi.fn(),
},
},
],
}).compile();
controller = module.get<RestController>(RestController);
oidcService = module.get<OidcService>(OidcService);
oidcConfig = module.get<OidcConfigPersistence>(OidcConfigPersistence);
mockReply = {
status: vi.fn().mockReturnThis(),
header: vi.fn().mockReturnThis(),
send: vi.fn().mockReturnThis(),
type: vi.fn().mockReturnThis(),
};
});
describe('oidcAuthorize', () => {
describe('redirect URI validation', () => {
beforeEach(() => {
// Mock OidcRequestHandler.handleAuthorize to return a valid auth URL
vi.spyOn(OidcRequestHandler, 'handleAuthorize').mockResolvedValue(
'https://provider.com/authorize?client_id=test&redirect_uri=...'
);
});
it('should accept redirect_uri with same hostname but different port', async () => {
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(302);
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalledWith(
'test-provider',
'test-state',
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback',
mockRequest,
oidcService,
expect.any(Logger)
);
});
it('should accept redirect_uri with same hostname and standard HTTPS port', async () => {
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://unraid.mytailnet.ts.net/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(302);
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
});
it('should accept redirect_uri with same hostname and explicit port 443', async () => {
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://unraid.mytailnet.ts.net:443/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(302);
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
});
it('should reject redirect_uri with different hostname', async () => {
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://evil.com/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(400);
expect(mockReply.send).toHaveBeenCalledWith(
expect.stringContaining(
'Invalid redirect_uri: https://evil.com/graphql/api/auth/oidc/callback'
)
);
expect(OidcRequestHandler.handleAuthorize).not.toHaveBeenCalled();
});
it('should reject redirect_uri with subdomain difference', async () => {
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://evil.unraid.mytailnet.ts.net/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(400);
expect(mockReply.send).toHaveBeenCalledWith(
expect.stringContaining(
'Invalid redirect_uri: https://evil.unraid.mytailnet.ts.net/graphql/api/auth/oidc/callback'
)
);
expect(OidcRequestHandler.handleAuthorize).not.toHaveBeenCalled();
});
it('should handle hostname from host header when hostname is not available', async () => {
const mockRequest = createMockRequest(undefined, {
host: 'unraid.mytailnet.ts.net:8080',
});
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(302);
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
});
it('should reject malformed redirect_uri', async () => {
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'not-a-valid-url',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(400);
expect(mockReply.send).toHaveBeenCalledWith(
expect.stringContaining('Invalid redirect_uri: not-a-valid-url')
);
expect(OidcRequestHandler.handleAuthorize).not.toHaveBeenCalled();
});
it('should handle case-insensitive hostname comparison', async () => {
const mockRequest = createMockRequest('UnRaid.MyTailnet.TS.net');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(302);
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
});
it('should preserve exact redirect_uri including custom port in call to handleAuthorize', async () => {
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
const customRedirectUri =
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback';
await controller.oidcAuthorize(
'test-provider',
'test-state',
customRedirectUri,
mockRequest,
mockReply as FastifyReply
);
// Verify the exact redirect URI with port is passed through
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalledWith(
'test-provider',
'test-state',
customRedirectUri, // Should be exactly as provided, with :1443
mockRequest,
oidcService,
expect.any(Logger)
);
});
it('should allow localhost with different ports', async () => {
const mockRequest = createMockRequest('localhost');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'http://localhost:3000/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(302);
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalledWith(
'test-provider',
'test-state',
'http://localhost:3000/graphql/api/auth/oidc/callback',
mockRequest,
oidcService,
expect.any(Logger)
);
});
it('should allow IP addresses with different ports', async () => {
const mockRequest = createMockRequest('192.168.1.100');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'http://192.168.1.100:8080/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(302);
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
});
it('should accept redirect_uri with different hostname if in allowed origins', async () => {
const mockRequest = createMockRequest('devgen-dev1.local');
// Mock the config to include the allowed origin
vi.mocked(oidcConfig.getConfig).mockResolvedValueOnce({
defaultAllowedOrigins: ['https://devgen-bad-dev1.local'],
} as any);
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://devgen-bad-dev1.local/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(302);
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalledWith(
'test-provider',
'test-state',
'https://devgen-bad-dev1.local/graphql/api/auth/oidc/callback',
mockRequest,
oidcService,
expect.any(Logger)
);
});
describe('integration with centralized validator', () => {
it('should use the same validation logic as validateRedirectUri function', async () => {
const testCases = [
{
name: 'accepts HTTPS upgrade from allowed origins',
requestHost: 'devgen-dev1.local',
redirectUri: 'https://allowed-host.local/graphql/api/auth/oidc/callback',
allowedOrigins: ['http://allowed-host.local'],
expectedStatus: 302,
shouldSucceed: true,
},
{
name: 'rejects hostname not in allowed origins',
requestHost: 'devgen-dev1.local',
redirectUri: 'https://evil.com/graphql/api/auth/oidc/callback',
allowedOrigins: ['https://good-host.local'],
expectedStatus: 400,
shouldSucceed: false,
},
{
name: 'accepts multiple allowed origins',
requestHost: 'devgen-dev1.local',
redirectUri: 'https://second.local/graphql/api/auth/oidc/callback',
allowedOrigins: [
'https://first.local',
'https://second.local',
'https://third.local',
],
expectedStatus: 302,
shouldSucceed: true,
},
{
name: 'respects protocol and hostname from headers',
requestHost: undefined,
headers: {
'x-forwarded-proto': 'https',
'x-forwarded-host': 'proxy.local',
},
redirectUri: 'https://proxy.local/graphql/api/auth/oidc/callback',
allowedOrigins: [],
expectedStatus: 302,
shouldSucceed: true,
},
];
for (const testCase of testCases) {
// Reset mocks for each test case
vi.clearAllMocks();
const mockRequest = createMockRequest(
testCase.requestHost,
testCase.headers || {}
);
vi.mocked(oidcConfig.getConfig).mockResolvedValueOnce({
defaultAllowedOrigins: testCase.allowedOrigins,
} as any);
await controller.oidcAuthorize(
'test-provider',
'test-state',
testCase.redirectUri,
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(testCase.expectedStatus);
if (testCase.shouldSucceed) {
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
} else {
expect(mockReply.send).toHaveBeenCalledWith(
expect.stringContaining(testCase.redirectUri)
);
expect(OidcRequestHandler.handleAuthorize).not.toHaveBeenCalled();
}
}
});
it('should handle edge cases consistently with centralized validator', async () => {
// Test with empty allowed origins
vi.mocked(oidcConfig.getConfig).mockResolvedValueOnce({
defaultAllowedOrigins: [],
} as any);
const mockRequest = createMockRequest('host.local');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://different.local/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(400);
expect(mockReply.send).toHaveBeenCalledWith(
expect.stringContaining('https://different.local/graphql/api/auth/oidc/callback')
);
});
it('should validate that error messages guide users to settings', async () => {
vi.mocked(oidcConfig.getConfig).mockResolvedValueOnce({
defaultAllowedOrigins: [],
} as any);
const mockRequest = createMockRequest('host.local');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://different.local/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.send).toHaveBeenCalledWith(
expect.stringContaining('Settings → Management Access → Allowed Redirect URIs')
);
});
});
});
describe('parameter validation', () => {
it('should return 400 if redirect_uri is missing', async () => {
const mockRequest = createMockRequest('unraid.local');
await controller.oidcAuthorize(
'test-provider',
'test-state',
undefined as any,
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(400);
// The controller catches validation errors and returns a generic message
expect(mockReply.send).toHaveBeenCalledWith('Invalid provider or configuration');
});
it('should return 400 if providerId is missing', async () => {
const mockRequest = createMockRequest('unraid.local');
await controller.oidcAuthorize(
undefined as any,
'test-state',
'https://unraid.local/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(400);
expect(mockReply.send).toHaveBeenCalledWith('Invalid provider or configuration');
});
it('should return 400 if state is missing', async () => {
const mockRequest = createMockRequest('unraid.local');
await controller.oidcAuthorize(
'test-provider',
undefined as any,
'https://unraid.local/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(400);
expect(mockReply.send).toHaveBeenCalledWith('Invalid provider or configuration');
});
});
});
});

View File

@@ -2,18 +2,25 @@ import { Controller, Get, Logger, Param, Query, Req, Res, UnauthorizedException
import { AuthAction, Resource } from '@unraid/shared/graphql.model.js'; import { AuthAction, Resource } from '@unraid/shared/graphql.model.js';
import { UsePermissions } from '@unraid/shared/use-permissions.directive.js'; import { UsePermissions } from '@unraid/shared/use-permissions.directive.js';
import escapeHtml from 'escape-html';
import type { FastifyReply, FastifyRequest } from '@app/unraid-api/types/fastify.js'; import type { FastifyReply, FastifyRequest } from '@app/unraid-api/types/fastify.js';
import { Public } from '@app/unraid-api/auth/public.decorator.js'; import { Public } from '@app/unraid-api/auth/public.decorator.js';
import { OidcAuthService } from '@app/unraid-api/graph/resolvers/sso/oidc-auth.service.js'; import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
import { OidcRequestHandler } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-request-handler.util.js';
import { RestService } from '@app/unraid-api/rest/rest.service.js'; import { RestService } from '@app/unraid-api/rest/rest.service.js';
import { validateRedirectUri } from '@app/unraid-api/utils/redirect-uri-validator.js';
@Controller() @Controller()
export class RestController { export class RestController {
protected logger = new Logger(RestController.name); protected logger = new Logger(RestController.name);
protected oidcLogger = new Logger('OidcRestController');
constructor( constructor(
private readonly restService: RestService, private readonly restService: RestService,
private readonly oidcAuthService: OidcAuthService private readonly oidcService: OidcService,
private readonly oidcConfig: OidcConfigPersistence
) {} ) {}
@Get('/') @Get('/')
@@ -65,38 +72,69 @@ export class RestController {
async oidcAuthorize( async oidcAuthorize(
@Param('providerId') providerId: string, @Param('providerId') providerId: string,
@Query('state') state: string, @Query('state') state: string,
@Query('redirect_uri') redirectUri: string,
@Req() req: FastifyRequest, @Req() req: FastifyRequest,
@Res() res: FastifyReply @Res() res: FastifyReply
) { ) {
try { try {
if (!state) { // Validate required parameters
return res.status(400).send('State parameter is required'); const params = OidcRequestHandler.validateAuthorizeParams(providerId, state, redirectUri);
// IMPORTANT: Use the redirect_uri from query params directly
// Do NOT parse headers or try to build/validate against headers
// The frontend provides the complete redirect_uri
if (!params.redirectUri) {
return res.status(400).send('redirect_uri parameter is required');
} }
// Get the host and protocol from the request headers // Security validation: validate redirect_uri with support for allowed origins
const protocol = (req.headers['x-forwarded-proto'] as string) || req.protocol || 'http'; const protocol = (req.headers['x-forwarded-proto'] as string) || 'http';
const host = (req.headers['x-forwarded-host'] as string) || req.headers.host || undefined; const host = (req.headers['x-forwarded-host'] as string) || req.headers.host || req.hostname;
const requestInfo = host ? `${protocol}://${host}` : undefined;
const authUrl = await this.oidcAuthService.getAuthorizationUrl( // Get allowed origins from OIDC config
providerId, const config = await this.oidcConfig.getConfig();
state, const allowedOrigins = config?.defaultAllowedOrigins;
requestInfo
// Validate the redirect URI using the centralized validator
const validation = validateRedirectUri(
params.redirectUri,
protocol,
host,
this.oidcLogger,
allowedOrigins
);
if (!validation.isValid) {
this.oidcLogger.warn(`Invalid redirect_uri: ${validation.reason}`);
return res
.status(400)
.send(
`Invalid redirect_uri: ${escapeHtml(params.redirectUri)}. ${escapeHtml(validation.reason || 'Unknown validation error')}. Please add this callback URI to Settings → Management Access → Allowed Redirect URIs`
);
}
// Handle authorization flow using the exact redirect_uri from query params
const authUrl = await OidcRequestHandler.handleAuthorize(
params.providerId,
params.state,
params.redirectUri,
req,
this.oidcService,
this.oidcLogger
); );
this.logger.log(`Redirecting to OIDC provider: ${authUrl}`);
// Manually set redirect headers for better proxy compatibility // Manually set redirect headers for better proxy compatibility
res.status(302); res.status(302);
res.header('Location', authUrl); res.header('Location', authUrl);
return res.send(); return res.send();
} catch (error: unknown) { } catch (error: unknown) {
this.logger.error(`OIDC authorize error for provider ${providerId}:`, error); this.oidcLogger.error(`OIDC authorize error for provider ${providerId}:`, error);
// Log more details about the error // Log more details about the error
if (error instanceof Error) { if (error instanceof Error) {
this.logger.error(`Error message: ${error.message}`); this.oidcLogger.error(`Error message: ${error.message}`);
if (error.stack) { if (error.stack) {
this.logger.debug(`Stack trace: ${error.stack}`); this.oidcLogger.debug(`Stack trace: ${error.stack}`);
} }
} }
@@ -117,32 +155,20 @@ export class RestController {
@Res() res: FastifyReply @Res() res: FastifyReply
) { ) {
try { try {
if (!code || !state) { // Validate required parameters
return res.status(400).send('Missing required parameters'); const params = OidcRequestHandler.validateCallbackParams(code, state);
}
// Extract provider ID from state // Handle callback flow
const { providerId } = this.oidcAuthService.extractProviderFromState(state); const result = await OidcRequestHandler.handleCallback(
params.code,
// Get the full callback URL as received, respecting reverse proxy headers params.state,
const protocol = (req.headers['x-forwarded-proto'] as string) || req.protocol || 'http'; req,
const host = this.oidcService,
(req.headers['x-forwarded-host'] as string) || req.headers.host || 'localhost:3000'; this.oidcLogger
const fullUrl = `${protocol}://${host}${req.url}`;
const requestInfo = `${protocol}://${host}`;
this.logger.debug(`Full callback URL from request: ${fullUrl}`);
const paddedToken = await this.oidcAuthService.handleCallback(
providerId,
code,
state,
requestInfo,
fullUrl
); );
// Redirect to login page with the token in hash to keep it out of server logs // Redirect to login page with the token in hash to keep it out of server logs
const loginUrl = `/login#token=${encodeURIComponent(paddedToken)}`; const loginUrl = `/login#token=${encodeURIComponent(result.paddedToken)}`;
// Manually set redirect headers for better proxy compatibility // Manually set redirect headers for better proxy compatibility
res.header('Cache-Control', 'no-store'); res.header('Cache-Control', 'no-store');
@@ -152,16 +178,16 @@ export class RestController {
res.header('Location', loginUrl); res.header('Location', loginUrl);
return res.send(); return res.send();
} catch (error: unknown) { } catch (error: unknown) {
this.logger.error(`OIDC callback error: ${error}`); this.oidcLogger.error(`OIDC callback error: ${error}`);
// Use a generic error message to avoid leaking sensitive information // Use a generic error message to avoid leaking sensitive information
const errorMessage = 'Authentication failed'; const errorMessage = 'Authentication failed';
// Log detailed error for debugging but don't expose to user // Log detailed error for debugging but don't expose to user
if (error instanceof UnauthorizedException) { if (error instanceof UnauthorizedException) {
this.logger.debug(`UnauthorizedException occurred during OIDC callback`); this.oidcLogger.debug(`UnauthorizedException occurred during OIDC callback`);
} else if (error instanceof Error) { } else if (error instanceof Error) {
this.logger.debug(`Error during OIDC callback: ${error.message}`); this.oidcLogger.debug(`Error during OIDC callback: ${error.message}`);
} }
const loginUrl = `/login#error=${encodeURIComponent(errorMessage)}`; const loginUrl = `/login#error=${encodeURIComponent(errorMessage)}`;

View File

@@ -0,0 +1,406 @@
import { describe, expect, it, vi } from 'vitest';
import { ErrorExtractor } from '@app/unraid-api/utils/error-extractor.util.js';
describe('ErrorExtractor', () => {
describe('extract', () => {
it('should handle null and undefined errors', () => {
const nullResult = ErrorExtractor.extract(null);
expect(nullResult.message).toBe('Unknown error');
expect(nullResult.type).toBe('Unknown');
const undefinedResult = ErrorExtractor.extract(undefined);
expect(undefinedResult.message).toBe('Unknown error');
expect(undefinedResult.type).toBe('Unknown');
});
it('should extract basic Error properties', () => {
const error = new Error('Test error message');
const result = ErrorExtractor.extract(error);
expect(result.message).toBe('Test error message');
expect(result.type).toBe('Error');
expect(result.stack).toBeDefined();
});
it('should extract custom error types', () => {
class CustomError extends Error {}
const error = new CustomError('Custom error');
const result = ErrorExtractor.extract(error);
expect(result.type).toBe('CustomError');
});
it('should extract error code', () => {
const error: any = new Error('Error with code');
error.code = 'ERR_CODE';
const result = ErrorExtractor.extract(error);
expect(result.code).toBe('ERR_CODE');
});
it('should extract HTTP response details', () => {
const error: any = new Error('HTTP error');
error.response = {
status: 404,
statusText: 'Not Found',
body: { error: 'Resource not found' },
headers: { 'content-type': 'application/json' },
};
const result = ErrorExtractor.extract(error);
expect(result.status).toBe(404);
expect(result.statusText).toBe('Not Found');
expect(result.responseBody).toEqual({ error: 'Resource not found' });
expect(result.responseHeaders).toEqual({ 'content-type': 'application/json' });
});
it('should truncate long response body strings', () => {
const error: any = new Error('Error with long body');
const longString = 'x'.repeat(2000);
error.body = longString;
const result = ErrorExtractor.extract(error);
expect(result.responseBody).toBe('x'.repeat(1000) + '... (truncated)');
});
it('should extract OAuth error details', () => {
const error: any = new Error('OAuth error');
error.error = 'invalid_grant';
error.error_description = 'The provided authorization code is invalid';
const result = ErrorExtractor.extract(error);
expect(result.oauthError).toBe('invalid_grant');
expect(result.oauthErrorDescription).toBe('The provided authorization code is invalid');
});
it('should extract cause chain', () => {
const rootCause = new Error('Root cause');
const middleCause: any = new Error('Middle cause');
middleCause.cause = rootCause;
const topError: any = new Error('Top error');
topError.cause = middleCause;
const result = ErrorExtractor.extract(topError);
expect(result.causeChain).toHaveLength(2);
expect(result.causeChain![0]).toEqual({
depth: 1,
type: 'Error',
message: 'Middle cause',
});
expect(result.causeChain![1]).toEqual({
depth: 2,
type: 'Error',
message: 'Root cause',
});
});
it('should limit cause chain depth', () => {
// Create a deep nested error chain
let deepestError: any = new Error('Level 10');
for (let i = 9; i >= 0; i--) {
const error: any = new Error(`Level ${i}`);
error.cause = deepestError;
deepestError = error;
}
const topError: any = new Error('Top');
topError.cause = deepestError;
const result = ErrorExtractor.extract(topError);
expect(result.causeChain).toHaveLength(5); // MAX_CAUSE_DEPTH
});
it('should extract cause with code', () => {
const cause: any = new Error('Cause with code');
cause.code = 'ECONNREFUSED';
const error: any = new Error('Main error');
error.cause = cause;
const result = ErrorExtractor.extract(error);
expect(result.causeChain![0].code).toBe('ECONNREFUSED');
});
it('should extract additional properties', () => {
const error: any = new Error('Error with extras');
error.customProp1 = 'value1';
error.customProp2 = 123;
const result = ErrorExtractor.extract(error);
expect(result.additionalProperties).toEqual({
customProp1: 'value1',
customProp2: 123,
});
});
it('should handle string errors', () => {
const result = ErrorExtractor.extract('String error message');
expect(result.message).toBe('String error message');
expect(result.type).toBe('String');
});
it('should handle object errors', () => {
const error = { code: 'ERROR', message: 'Object error' };
const result = ErrorExtractor.extract(error);
expect(result.message).toBe(JSON.stringify(error));
expect(result.type).toBe('Object');
});
it('should handle primitive errors', () => {
const result = ErrorExtractor.extract(42);
expect(result.message).toBe('42');
expect(result.type).toBe('number');
});
it('should handle openid-client error structure', () => {
const error: any = new Error('unexpected response content-type');
error.code = 'OAUTH_RESPONSE_IS_NOT_JSON';
error.response = {
status: 200,
headers: { 'content-type': 'text/html' },
body: '<html>Error page</html>',
};
const result = ErrorExtractor.extract(error);
expect(result.code).toBe('OAUTH_RESPONSE_IS_NOT_JSON');
expect(result.responseHeaders!['content-type']).toBe('text/html');
expect(result.responseBody).toContain('<html>');
});
});
describe('isOAuthResponseError', () => {
it('should identify OAuth response errors by code', () => {
const extracted = {
message: 'Some error',
type: 'Error',
code: 'OAUTH_RESPONSE_IS_NOT_JSON',
};
expect(ErrorExtractor.isOAuthResponseError(extracted)).toBe(true);
});
it('should identify OAuth response errors by message', () => {
const extracted = {
message: 'unexpected response content-type from server',
type: 'Error',
};
expect(ErrorExtractor.isOAuthResponseError(extracted)).toBe(true);
});
it('should identify parsing errors', () => {
const extracted = {
message: 'JSON parsing error occurred',
type: 'Error',
};
expect(ErrorExtractor.isOAuthResponseError(extracted)).toBe(true);
});
it('should not identify non-OAuth errors', () => {
const extracted = {
message: 'Some other error',
type: 'Error',
code: 'OTHER_ERROR',
};
expect(ErrorExtractor.isOAuthResponseError(extracted)).toBe(false);
});
});
describe('isJwtClaimError', () => {
it('should identify JWT claim errors', () => {
const extracted = {
message: 'unexpected JWT claim value encountered',
type: 'Error',
};
expect(ErrorExtractor.isJwtClaimError(extracted)).toBe(true);
});
it('should not identify non-JWT errors', () => {
const extracted = {
message: 'Some other error',
type: 'Error',
};
expect(ErrorExtractor.isJwtClaimError(extracted)).toBe(false);
});
});
describe('isNetworkError', () => {
it('should identify network errors by code', () => {
const codes = ['ECONNREFUSED', 'ENOTFOUND', 'ETIMEDOUT', 'ECONNRESET'];
for (const code of codes) {
const extracted = {
message: 'Error',
type: 'Error',
code,
};
expect(ErrorExtractor.isNetworkError(extracted)).toBe(true);
}
});
it('should identify network errors by message', () => {
const messages = ['network timeout occurred', 'failed to connect to server'];
for (const message of messages) {
const extracted = {
message,
type: 'Error',
};
expect(ErrorExtractor.isNetworkError(extracted)).toBe(true);
}
});
it('should not identify non-network errors', () => {
const extracted = {
message: 'Invalid credentials',
type: 'Error',
code: 'AUTH_ERROR',
};
expect(ErrorExtractor.isNetworkError(extracted)).toBe(false);
});
});
describe('formatForLogging', () => {
it('should log basic error information', () => {
const logger = {
error: vi.fn(),
debug: vi.fn(),
};
const extracted = {
message: 'Test error',
type: 'CustomError',
code: 'ERR_TEST',
};
ErrorExtractor.formatForLogging(extracted, logger);
expect(logger.error).toHaveBeenCalledWith('Error type: CustomError');
expect(logger.error).toHaveBeenCalledWith('Error message: Test error');
expect(logger.error).toHaveBeenCalledWith('Error code: ERR_TEST');
});
it('should log HTTP response details', () => {
const logger = {
error: vi.fn(),
debug: vi.fn(),
};
const extracted = {
message: 'HTTP error',
type: 'Error',
status: 500,
statusText: 'Internal Server Error',
responseBody: { error: 'Server error' },
responseHeaders: { 'content-type': 'application/json' },
};
ErrorExtractor.formatForLogging(extracted, logger);
expect(logger.error).toHaveBeenCalledWith('HTTP Status: 500 Internal Server Error');
expect(logger.error).toHaveBeenCalledWith('Response body: %o', { error: 'Server error' });
expect(logger.error).toHaveBeenCalledWith('Response Content-Type: application/json');
});
it('should log OAuth error details', () => {
const logger = {
error: vi.fn(),
debug: vi.fn(),
};
const extracted = {
message: 'OAuth error',
type: 'Error',
oauthError: 'invalid_client',
oauthErrorDescription: 'Client authentication failed',
};
ErrorExtractor.formatForLogging(extracted, logger);
expect(logger.error).toHaveBeenCalledWith('OAuth error: invalid_client');
expect(logger.error).toHaveBeenCalledWith(
'OAuth error description: Client authentication failed'
);
});
it('should log cause chain', () => {
const logger = {
error: vi.fn(),
debug: vi.fn(),
};
const extracted = {
message: 'Top error',
type: 'Error',
causeChain: [
{ depth: 1, type: 'Error', message: 'Cause 1', code: 'CODE1' },
{ depth: 2, type: 'Error', message: 'Cause 2' },
],
};
ErrorExtractor.formatForLogging(extracted, logger);
expect(logger.error).toHaveBeenCalledWith('Error cause chain:');
expect(logger.error).toHaveBeenCalledWith(' [Cause 1] Error: Cause 1');
expect(logger.error).toHaveBeenCalledWith(' [Cause 1] Code: CODE1');
expect(logger.error).toHaveBeenCalledWith(' [Cause 2] Error: Cause 2');
});
it('should log additional properties and stack in debug', () => {
const logger = {
error: vi.fn(),
debug: vi.fn(),
};
const extracted = {
message: 'Error',
type: 'Error',
additionalProperties: { custom: 'value' },
stack: 'Stack trace here',
};
ErrorExtractor.formatForLogging(extracted, logger);
expect(logger.debug).toHaveBeenCalledWith('Additional error properties: %o', {
custom: 'value',
});
expect(logger.debug).toHaveBeenCalledWith('Stack trace: Stack trace here');
});
it('should handle case-insensitive Content-Type header', () => {
const logger = {
error: vi.fn(),
debug: vi.fn(),
};
const extracted = {
message: 'Error',
type: 'Error',
responseHeaders: { 'Content-Type': 'text/html' },
};
ErrorExtractor.formatForLogging(extracted, logger);
expect(logger.error).toHaveBeenCalledWith('Response Content-Type: text/html');
});
});
});

View File

@@ -0,0 +1,277 @@
export interface ExtractedError {
message: string;
type: string;
code?: string;
status?: number;
statusText?: string;
responseBody?: unknown;
responseHeaders?: Record<string, string>;
oauthError?: string;
oauthErrorDescription?: string;
causeChain?: Array<{
depth: number;
type: string;
message: string;
code?: string;
}>;
additionalProperties?: Record<string, unknown>;
stack?: string;
}
export class ErrorExtractor {
private static readonly MAX_CAUSE_DEPTH = 5;
private static readonly MAX_BODY_PREVIEW_LENGTH = 1000;
static extract(error: unknown): ExtractedError {
const result: ExtractedError = {
message: 'Unknown error',
type: 'Unknown',
};
if (!error) {
return result;
}
if (error instanceof Error) {
result.message = error.message;
result.type = error.constructor.name;
result.stack = error.stack;
// Extract error code if present
if ('code' in error && error.code) {
result.code = String(error.code);
}
// Extract HTTP response details
if ('response' in error && error.response) {
this.extractResponseDetails(error.response as any, result);
}
// Extract OAuth-specific errors
if ('error' in error && error.error) {
result.oauthError = String(error.error);
}
if ('error_description' in error && error.error_description) {
result.oauthErrorDescription = String(error.error_description);
}
// Extract response body if directly available
if ('body' in error && error.body) {
result.responseBody = this.formatResponseBody(error.body as any);
}
// Extract cause chain
if ('cause' in error && error.cause) {
result.causeChain = this.extractCauseChain(error.cause);
}
// Extract additional properties
const standardKeys = [
'message',
'stack',
'cause',
'code',
'response',
'body',
'error',
'error_description',
];
const additionalKeys = Object.keys(error).filter((key) => !standardKeys.includes(key));
if (additionalKeys.length > 0) {
result.additionalProperties = {};
for (const key of additionalKeys) {
const value = (error as any)[key];
if (value !== undefined && value !== null) {
result.additionalProperties[key] = value;
}
}
}
} else if (typeof error === 'string') {
result.message = error;
result.type = 'String';
} else if (typeof error === 'object' && error !== null) {
result.message = JSON.stringify(error);
result.type = 'Object';
} else {
result.message = String(error);
result.type = typeof error;
}
return result;
}
private static extractResponseDetails(response: any, result: ExtractedError): void {
if (!response) return;
if (response.status) {
result.status = response.status;
}
if (response.statusText) {
result.statusText = response.statusText;
}
if (response.body) {
result.responseBody = this.formatResponseBody(response.body);
}
if (response.headers) {
result.responseHeaders = this.extractHeaders(response.headers);
}
}
private static formatResponseBody(body: unknown): unknown {
if (!body) return undefined;
if (typeof body === 'string') {
if (body.length > this.MAX_BODY_PREVIEW_LENGTH) {
return body.substring(0, this.MAX_BODY_PREVIEW_LENGTH) + '... (truncated)';
}
return body;
}
return body;
}
private static extractHeaders(headers: any): Record<string, string> | undefined {
if (!headers) return undefined;
const result: Record<string, string> = {};
// Handle different header formats
if (typeof headers === 'object') {
for (const key of Object.keys(headers)) {
const value = headers[key];
if (value !== undefined && value !== null) {
result[key] = String(value);
}
}
}
return Object.keys(result).length > 0 ? result : undefined;
}
private static extractCauseChain(
cause: unknown
): Array<{ depth: number; type: string; message: string; code?: string }> | undefined {
const chain: Array<{ depth: number; type: string; message: string; code?: string }> = [];
let currentCause = cause;
let depth = 1;
while (currentCause && depth <= this.MAX_CAUSE_DEPTH) {
const causeInfo: { depth: number; type: string; message: string; code?: string } = {
depth,
type: 'Unknown',
message: 'Unknown cause',
};
if (currentCause instanceof Error) {
causeInfo.type = currentCause.constructor.name;
causeInfo.message = currentCause.message;
if ('code' in currentCause && currentCause.code) {
causeInfo.code = String(currentCause.code);
}
} else if (typeof currentCause === 'string') {
causeInfo.type = 'String';
causeInfo.message = currentCause;
} else {
causeInfo.type = typeof currentCause;
causeInfo.message = String(currentCause);
}
chain.push(causeInfo);
// Get next cause in chain
currentCause =
currentCause && typeof currentCause === 'object' && 'cause' in currentCause
? (currentCause as any).cause
: undefined;
depth++;
}
return chain.length > 0 ? chain : undefined;
}
static isOAuthResponseError(extracted: ExtractedError): boolean {
return Boolean(
extracted.code === 'OAUTH_RESPONSE_IS_NOT_JSON' ||
extracted.code === 'OAUTH_PARSE_ERROR' ||
extracted.message.includes('unexpected response content-type') ||
extracted.message.includes('parsing error')
);
}
static isJwtClaimError(extracted: ExtractedError): boolean {
return extracted.message.includes('unexpected JWT claim value encountered');
}
static isNetworkError(extracted: ExtractedError): boolean {
return Boolean(
extracted.code === 'ECONNREFUSED' ||
extracted.code === 'ENOTFOUND' ||
extracted.code === 'ETIMEDOUT' ||
extracted.code === 'ECONNRESET' ||
extracted.message.includes('network') ||
extracted.message.includes('connect')
);
}
static formatForLogging(
extracted: ExtractedError,
logger: {
error: (msg: string, ...args: any[]) => void;
debug: (msg: string, ...args: any[]) => void;
}
): void {
logger.error(`Error type: ${extracted.type}`);
logger.error(`Error message: ${extracted.message}`);
if (extracted.code) {
logger.error(`Error code: ${extracted.code}`);
}
if (extracted.status) {
logger.error(`HTTP Status: ${extracted.status} ${extracted.statusText || ''}`);
}
if (extracted.responseBody) {
logger.error('Response body: %o', extracted.responseBody);
}
if (extracted.responseHeaders) {
const contentType =
extracted.responseHeaders['content-type'] || extracted.responseHeaders['Content-Type'];
if (contentType) {
logger.error(`Response Content-Type: ${contentType}`);
}
}
if (extracted.oauthError) {
logger.error(`OAuth error: ${extracted.oauthError}`);
if (extracted.oauthErrorDescription) {
logger.error(`OAuth error description: ${extracted.oauthErrorDescription}`);
}
}
if (extracted.causeChain) {
logger.error('Error cause chain:');
for (const cause of extracted.causeChain) {
logger.error(` [Cause ${cause.depth}] ${cause.type}: ${cause.message}`);
if (cause.code) {
logger.error(` [Cause ${cause.depth}] Code: ${cause.code}`);
}
}
}
if (extracted.additionalProperties && Object.keys(extracted.additionalProperties).length > 0) {
logger.debug('Additional error properties: %o', extracted.additionalProperties);
}
if (extracted.stack) {
logger.debug(`Stack trace: ${extracted.stack}`);
}
}
}

View File

@@ -0,0 +1,506 @@
import { Logger } from '@nestjs/common';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { validateRedirectUri } from '@app/unraid-api/utils/redirect-uri-validator.js';
describe('validateRedirectUri', () => {
let mockLogger: Logger;
beforeEach(() => {
mockLogger = {
debug: vi.fn(),
warn: vi.fn(),
} as any;
});
describe('basic validation', () => {
it('should return base URL when no redirect URI is provided', () => {
const result = validateRedirectUri(undefined, 'https', 'example.com', mockLogger);
expect(result).toEqual({
isValid: true,
validatedUri: 'https://example.com',
reason: 'No redirect URI provided',
});
});
it('should handle missing base URL', () => {
const result = validateRedirectUri('https://example.com', 'https', undefined, mockLogger);
expect(result).toEqual({
isValid: false,
validatedUri: '',
reason: 'No base URL available',
});
});
});
describe('hostname validation', () => {
it('should accept matching hostname with same port', () => {
const result = validateRedirectUri(
'https://example.com:3000',
'https',
'example.com:3000',
mockLogger
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://example.com:3000');
expect(mockLogger.debug).toHaveBeenCalledWith(
'Validated redirect_uri: https://example.com:3000'
);
});
it('should accept matching hostname with different ports', () => {
const result = validateRedirectUri(
'https://example.com:3001',
'https',
'example.com:3000',
mockLogger
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://example.com:3001');
expect(mockLogger.debug).toHaveBeenCalledWith(
'Validated redirect_uri: https://example.com:3001'
);
});
it('should accept matching hostname when expected has no port but provided does', () => {
const result = validateRedirectUri(
'https://example.com:3000',
'https',
'example.com',
mockLogger
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://example.com:3000');
});
it('should reject different hostnames', () => {
const result = validateRedirectUri('https://evil.com', 'https', 'example.com', mockLogger);
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('https://example.com');
expect(result.reason).toContain('Hostname or protocol mismatch');
expect(mockLogger.warn).toHaveBeenCalled();
});
it('should reject subdomain differences', () => {
const result = validateRedirectUri(
'https://sub.example.com',
'https',
'example.com',
mockLogger
);
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('https://example.com');
});
it('should handle case-insensitive hostname comparison', () => {
const result = validateRedirectUri(
'https://EXAMPLE.COM:3000',
'https',
'example.com',
mockLogger
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://EXAMPLE.COM:3000');
});
});
describe('protocol validation', () => {
it('should reject HTTP when expecting HTTPS (prevent downgrade attacks)', () => {
const result = validateRedirectUri('http://example.com', 'https', 'example.com', mockLogger);
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('https://example.com');
expect(result.reason).toContain('Hostname or protocol mismatch');
});
it('should allow HTTPS when expecting HTTP (common with reverse proxies)', () => {
const result = validateRedirectUri('https://example.com', 'http', 'example.com', mockLogger);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://example.com');
});
it('should accept matching protocols', () => {
const result = validateRedirectUri('http://example.com', 'http', 'example.com', mockLogger);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('http://example.com');
});
});
describe('malformed URL handling', () => {
it('should reject invalid URL format', () => {
const result = validateRedirectUri('not-a-valid-url', 'https', 'example.com', mockLogger);
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('https://example.com');
expect(result.reason).toContain('Invalid redirect_uri format');
expect(mockLogger.warn).toHaveBeenCalledWith('Invalid redirect_uri format: not-a-valid-url');
});
it('should reject javascript protocol', () => {
const result = validateRedirectUri(
'javascript:alert(1)',
'https',
'example.com',
mockLogger
);
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('https://example.com');
});
it('should handle URLs with paths and query params', () => {
const result = validateRedirectUri(
'https://example.com:3000/callback?foo=bar',
'https',
'example.com',
mockLogger
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://example.com:3000/callback?foo=bar');
});
});
describe('security scenarios', () => {
it('should prevent open redirect to attacker domain', () => {
const result = validateRedirectUri(
'https://attacker.com/steal-token',
'https',
'legitimate.com',
mockLogger
);
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('https://legitimate.com');
});
it('should prevent homograph attacks with similar looking domains', () => {
const result = validateRedirectUri(
'https://examp1e.com', // with number 1 instead of letter l
'https',
'example.com',
mockLogger
);
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('https://example.com');
});
it('should handle localhost variations correctly', () => {
const result = validateRedirectUri(
'http://localhost:3001',
'http',
'localhost:3000',
mockLogger
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('http://localhost:3001');
});
it('should handle IP addresses correctly', () => {
const result = validateRedirectUri(
'http://192.168.1.100:3001',
'http',
'192.168.1.100:3000',
mockLogger
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('http://192.168.1.100:3001');
});
it('should reject IP when expecting domain', () => {
const result = validateRedirectUri(
'https://192.168.1.100',
'https',
'example.com',
mockLogger
);
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('https://example.com');
});
});
describe('allowed origins validation', () => {
it('should accept redirect URI with different hostname if in allowed origins', () => {
const result = validateRedirectUri(
'https://devgen-bad-dev1.local/graphql/api/auth/oidc/callback',
'http',
'devgen-dev1.local',
mockLogger,
['https://devgen-bad-dev1.local']
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe(
'https://devgen-bad-dev1.local/graphql/api/auth/oidc/callback'
);
});
it('should reject redirect URI with hostname not in allowed origins', () => {
const result = validateRedirectUri(
'https://evil.com/callback',
'http',
'devgen-dev1.local',
mockLogger,
['https://devgen-bad-dev1.local', 'https://another-allowed.local']
);
expect(result.isValid).toBe(false);
expect(result.reason).toContain('Hostname or protocol mismatch');
});
it('should handle multiple allowed origins', () => {
const result = validateRedirectUri(
'https://second-allowed.local/callback',
'http',
'devgen-dev1.local',
mockLogger,
[
'https://first-allowed.local',
'https://second-allowed.local',
'https://third-allowed.local',
]
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://second-allowed.local/callback');
});
it('should allow HTTPS upgrade for allowed origins', () => {
const result = validateRedirectUri(
'https://allowed-host.local/callback',
'http',
'devgen-dev1.local',
mockLogger,
['http://allowed-host.local']
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://allowed-host.local/callback');
});
it('should validate extra hostname from config with proper logging', () => {
// Simulate a config with an extra hostname like "unraid.local"
const allowedOriginsFromConfig = ['https://unraid.local:8443'];
const result = validateRedirectUri(
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
'http', // Expected from headers
'devgen-dev1.local', // Primary hostname
mockLogger,
allowedOriginsFromConfig
);
// This should pass if the extra hostname is configured correctly
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://unraid.local:8443/graphql/api/auth/oidc/callback');
// Verify that debug logging shows the check
expect(mockLogger.debug).toHaveBeenCalledWith('Checking against 1 allowed origins');
expect(mockLogger.debug).toHaveBeenCalledWith(
'Checking allowed origin: https://unraid.local:8443'
);
expect(mockLogger.debug).toHaveBeenCalledWith(
' Origin match: https://unraid.local:8443 matches https://unraid.local:8443'
);
expect(mockLogger.debug).toHaveBeenCalledWith(
'Validated redirect_uri against allowed origin: https://unraid.local:8443/graphql/api/auth/oidc/callback'
);
});
it('should fail validation when extra hostname is not in config', () => {
// Config has no extra hostnames
const allowedOriginsFromConfig: string[] = [];
const result = validateRedirectUri(
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
// This should fail since unraid.local is not in allowed origins
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('http://devgen-dev1.local');
expect(result.reason).toContain('Hostname or protocol mismatch');
// Verify error logging
expect(mockLogger.warn).toHaveBeenCalledWith(
expect.stringContaining('Hostname or protocol mismatch')
);
});
describe('enhanced matching modes', () => {
it('should validate URL prefix match', () => {
const allowedOriginsFromConfig = ['https://unraid.local:8443/graphql/api/auth/oidc/'];
const result = validateRedirectUri(
'https://unraid.local:8443/graphql/api/auth/oidc/callback?code=123',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe(
'https://unraid.local:8443/graphql/api/auth/oidc/callback?code=123'
);
expect(mockLogger.debug).toHaveBeenCalledWith(
' URL prefix match: https://unraid.local:8443/graphql/api/auth/oidc/callback?code=123 matches prefix https://unraid.local:8443/graphql/api/auth/oidc/'
);
});
it('should validate origin match (protocol + hostname + port)', () => {
const allowedOriginsFromConfig = ['https://unraid.local:8443'];
const result = validateRedirectUri(
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe(
'https://unraid.local:8443/graphql/api/auth/oidc/callback'
);
expect(mockLogger.debug).toHaveBeenCalledWith(
' Origin match: https://unraid.local:8443 matches https://unraid.local:8443'
);
});
it('should validate hostname-only match (original behavior)', () => {
const allowedOriginsFromConfig = ['https://unraid.local'];
const result = validateRedirectUri(
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe(
'https://unraid.local:8443/graphql/api/auth/oidc/callback'
);
expect(mockLogger.debug).toHaveBeenCalledWith(
' Hostname comparison: provided=unraid.local, allowed=unraid.local'
);
});
it('should prefer exact URL match over origin match', () => {
const allowedOriginsFromConfig = [
'https://unraid.local:8443/graphql/api/auth/oidc/callback', // Exact match first
'https://unraid.local:8443', // Origin match second
];
const result = validateRedirectUri(
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result.isValid).toBe(true);
// Should match the first item in the array due to order of specificity
expect(mockLogger.debug).toHaveBeenCalledWith(
' Exact URL match: https://unraid.local:8443/graphql/api/auth/oidc/callback matches https://unraid.local:8443/graphql/api/auth/oidc/callback'
);
});
it('should prefer origin match over hostname match', () => {
const allowedOriginsFromConfig = [
'https://unraid.local:8443', // Origin match first
'https://unraid.local', // Hostname match second
];
const result = validateRedirectUri(
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result.isValid).toBe(true);
expect(mockLogger.debug).toHaveBeenCalledWith(
' Origin match: https://unraid.local:8443 matches https://unraid.local:8443'
);
});
it('should handle protocol upgrades for all matching modes', () => {
const allowedOriginsFromConfig = [
'http://unraid.local:8443/graphql/api/auth/oidc/callback',
'http://unraid.local:8443',
'http://unraid.local',
];
// Test exact URL with protocol upgrade
const result1 = validateRedirectUri(
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result1.isValid).toBe(true);
// Test origin with protocol upgrade
const result2 = validateRedirectUri(
'https://unraid.local:8443/different/path',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result2.isValid).toBe(true);
// Test hostname with protocol upgrade
const result3 = validateRedirectUri(
'https://unraid.local:9443/different/path',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result3.isValid).toBe(true);
});
it('should handle invalid allowed origin formats gracefully', () => {
const allowedOriginsFromConfig = ['not-a-valid-url', 'https://valid.url'];
const result = validateRedirectUri(
'https://valid.url/graphql/api/auth/oidc/callback',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result.isValid).toBe(true);
expect(mockLogger.warn).toHaveBeenCalledWith(
'Invalid allowed origin format: not-a-valid-url'
);
});
});
});
});

View File

@@ -0,0 +1,187 @@
import { Logger } from '@nestjs/common';
export interface RedirectUriValidationResult {
isValid: boolean;
validatedUri: string;
reason?: string;
}
/**
* Validates a redirect URI against the expected origin from request headers.
* This is critical for OAuth security to prevent authorization code interception.
*
* Security considerations:
* - Prevents redirecting OAuth codes to external domains
* - Allows port variations (needed for nginx/socket proxy scenarios)
* - Validates protocol to prevent downgrade attacks
* - Optionally validates against additional allowed origins
*
* @param redirectUri - The redirect URI provided by the client
* @param expectedProtocol - The protocol from request headers (http/https)
* @param expectedHost - The host from request headers (may or may not include port)
* @param logger - Optional logger for debugging
* @param allowedOrigins - Optional list of additional allowed origins
* @returns Validation result with the URI to use
*/
export function validateRedirectUri(
redirectUri: string | undefined,
expectedProtocol: string,
expectedHost: string | undefined,
logger?: Logger,
allowedOrigins?: string[]
): RedirectUriValidationResult {
const baseUrl = expectedHost ? `${expectedProtocol}://${expectedHost}` : undefined;
// If no redirect URI provided, use the base URL
if (!redirectUri || !baseUrl) {
return {
isValid: !redirectUri,
validatedUri: baseUrl || '',
reason: !redirectUri ? 'No redirect URI provided' : 'No base URL available',
};
}
try {
// Parse both URLs to validate hostname
const providedUrl = new URL(redirectUri);
const expectedUrl = new URL(baseUrl);
// Security: Validate hostname matches, but allow port differences
// This handles cases where nginx/socket proxy doesn't preserve port info
const providedHostname = providedUrl.hostname.toLowerCase();
const expectedHostname = expectedUrl.hostname.toLowerCase();
// Check protocol matches, but allow HTTPS when expecting HTTP (common with reverse proxies)
// Never allow HTTP when expecting HTTPS (would be a downgrade attack)
const protocolMatches =
providedUrl.protocol === expectedUrl.protocol ||
(expectedUrl.protocol === 'http:' && providedUrl.protocol === 'https:');
const hostnameMatches = providedHostname === expectedHostname;
// Check against primary expected origin
if (protocolMatches && hostnameMatches) {
// Trust the redirect_uri with its port information
logger?.debug(`Validated redirect_uri: ${redirectUri}`);
return {
isValid: true,
validatedUri: redirectUri,
};
}
// Check against additional allowed origins if provided
if (allowedOrigins && allowedOrigins.length > 0) {
logger?.debug(`Checking against ${allowedOrigins.length} allowed origins`);
for (const allowedOrigin of allowedOrigins) {
try {
const allowedUrl = new URL(allowedOrigin);
const allowedOriginStr = allowedUrl.origin.toLowerCase();
const allowedHostname = allowedUrl.hostname.toLowerCase();
logger?.debug(`Checking allowed origin: ${allowedOrigin}`);
// Try multiple matching strategies in order of specificity
// 1. Exact URL match (if allowed origin includes path/query)
if (allowedOrigin.includes('/') && allowedOrigin.length > allowedOriginStr.length) {
const allowedUrlNormalized = allowedOrigin.toLowerCase();
const providedUrlNormalized = redirectUri.toLowerCase();
// Exact match
if (providedUrlNormalized === allowedUrlNormalized) {
logger?.debug(` Exact URL match: ${redirectUri} matches ${allowedOrigin}`);
logger?.debug(
`Validated redirect_uri against allowed origin: ${redirectUri}`
);
return {
isValid: true,
validatedUri: redirectUri,
};
}
// Prefix match (if allowed origin ends with /)
if (
allowedUrlNormalized.endsWith('/') &&
providedUrlNormalized.startsWith(allowedUrlNormalized)
) {
logger?.debug(
` URL prefix match: ${redirectUri} matches prefix ${allowedOrigin}`
);
logger?.debug(
`Validated redirect_uri against allowed origin: ${redirectUri}`
);
return {
isValid: true,
validatedUri: redirectUri,
};
}
}
// 2. Origin match (protocol + hostname + port)
const providedOrigin = providedUrl.origin.toLowerCase();
if (providedOrigin === allowedOriginStr) {
// Allow HTTPS when expecting HTTP (common with reverse proxies)
const originProtocolMatches =
providedUrl.protocol === allowedUrl.protocol ||
(allowedUrl.protocol === 'http:' && providedUrl.protocol === 'https:');
if (originProtocolMatches) {
logger?.debug(
` Origin match: ${providedOrigin} matches ${allowedOriginStr}`
);
logger?.debug(
`Validated redirect_uri against allowed origin: ${redirectUri}`
);
return {
isValid: true,
validatedUri: redirectUri,
};
}
}
// 3. Hostname match (original behavior, but with better logging)
const allowedProtocolMatches =
providedUrl.protocol === allowedUrl.protocol ||
(allowedUrl.protocol === 'http:' && providedUrl.protocol === 'https:');
const allowedHostnameMatches = providedHostname === allowedHostname;
logger?.debug(
` Hostname comparison: provided=${providedHostname}, allowed=${allowedHostname}`
);
logger?.debug(
` Protocol comparison: provided=${providedUrl.protocol}, allowed=${allowedUrl.protocol}`
);
logger?.debug(
` Protocol matches: ${allowedProtocolMatches}, Hostname matches: ${allowedHostnameMatches}`
);
if (allowedProtocolMatches && allowedHostnameMatches) {
logger?.debug(`Validated redirect_uri against allowed origin: ${redirectUri}`);
return {
isValid: true,
validatedUri: redirectUri,
};
}
} catch (e) {
logger?.warn(`Invalid allowed origin format: ${allowedOrigin}`);
}
}
}
// If we get here, validation failed
const reason = `Hostname or protocol mismatch. Expected: ${expectedUrl.protocol}//${expectedHostname}, Got: ${providedUrl.protocol}//${providedHostname}`;
logger?.warn(`Rejected redirect_uri: ${reason}`);
return {
isValid: false,
validatedUri: baseUrl,
reason,
};
} catch (error) {
const reason = `Invalid redirect_uri format: ${redirectUri}`;
logger?.warn(reason);
return {
isValid: false,
validatedUri: baseUrl,
reason,
};
}
}

View File

@@ -5,6 +5,7 @@
"scripts": { "scripts": {
"build": "pnpm -r build", "build": "pnpm -r build",
"build:watch": " pnpm -r --parallel build:watch", "build:watch": " pnpm -r --parallel build:watch",
"codegen": "pnpm -r codegen",
"dev": "pnpm -r dev", "dev": "pnpm -r dev",
"unraid:deploy": "pnpm -r unraid:deploy", "unraid:deploy": "pnpm -r unraid:deploy",
"test": "pnpm -r test", "test": "pnpm -r test",

View File

@@ -1,94 +1,87 @@
/* eslint-disable */ /* eslint-disable */
import type { import type { ResultOf, DocumentTypeDecoration, TypedDocumentNode } from '@graphql-typed-document-node/core';
DocumentTypeDecoration,
ResultOf,
TypedDocumentNode,
} from '@graphql-typed-document-node/core';
import type { FragmentDefinitionNode } from 'graphql'; import type { FragmentDefinitionNode } from 'graphql';
import type { Incremental } from './graphql.js'; import type { Incremental } from './graphql.js';
export type FragmentType<TDocumentType extends DocumentTypeDecoration<any, any>> =
TDocumentType extends DocumentTypeDecoration<infer TType, any> export type FragmentType<TDocumentType extends DocumentTypeDecoration<any, any>> = TDocumentType extends DocumentTypeDecoration<
? [TType] extends [{ ' $fragmentName'?: infer TKey }] infer TType,
? TKey extends string any
? { ' $fragmentRefs'?: { [key in TKey]: TType } } >
: never ? [TType] extends [{ ' $fragmentName'?: infer TKey }]
: never ? TKey extends string
: never; ? { ' $fragmentRefs'?: { [key in TKey]: TType } }
: never
: never
: never;
// return non-nullable if `fragmentType` is non-nullable // return non-nullable if `fragmentType` is non-nullable
export function useFragment<TType>( export function useFragment<TType>(
_documentNode: DocumentTypeDecoration<TType, any>, _documentNode: DocumentTypeDecoration<TType, any>,
fragmentType: FragmentType<DocumentTypeDecoration<TType, any>> fragmentType: FragmentType<DocumentTypeDecoration<TType, any>>
): TType; ): TType;
// return nullable if `fragmentType` is undefined // return nullable if `fragmentType` is undefined
export function useFragment<TType>( export function useFragment<TType>(
_documentNode: DocumentTypeDecoration<TType, any>, _documentNode: DocumentTypeDecoration<TType, any>,
fragmentType: FragmentType<DocumentTypeDecoration<TType, any>> | undefined fragmentType: FragmentType<DocumentTypeDecoration<TType, any>> | undefined
): TType | undefined; ): TType | undefined;
// return nullable if `fragmentType` is nullable // return nullable if `fragmentType` is nullable
export function useFragment<TType>( export function useFragment<TType>(
_documentNode: DocumentTypeDecoration<TType, any>, _documentNode: DocumentTypeDecoration<TType, any>,
fragmentType: FragmentType<DocumentTypeDecoration<TType, any>> | null fragmentType: FragmentType<DocumentTypeDecoration<TType, any>> | null
): TType | null; ): TType | null;
// return nullable if `fragmentType` is nullable or undefined // return nullable if `fragmentType` is nullable or undefined
export function useFragment<TType>( export function useFragment<TType>(
_documentNode: DocumentTypeDecoration<TType, any>, _documentNode: DocumentTypeDecoration<TType, any>,
fragmentType: FragmentType<DocumentTypeDecoration<TType, any>> | null | undefined fragmentType: FragmentType<DocumentTypeDecoration<TType, any>> | null | undefined
): TType | null | undefined; ): TType | null | undefined;
// return array of non-nullable if `fragmentType` is array of non-nullable // return array of non-nullable if `fragmentType` is array of non-nullable
export function useFragment<TType>( export function useFragment<TType>(
_documentNode: DocumentTypeDecoration<TType, any>, _documentNode: DocumentTypeDecoration<TType, any>,
fragmentType: Array<FragmentType<DocumentTypeDecoration<TType, any>>> fragmentType: Array<FragmentType<DocumentTypeDecoration<TType, any>>>
): Array<TType>; ): Array<TType>;
// return array of nullable if `fragmentType` is array of nullable // return array of nullable if `fragmentType` is array of nullable
export function useFragment<TType>( export function useFragment<TType>(
_documentNode: DocumentTypeDecoration<TType, any>, _documentNode: DocumentTypeDecoration<TType, any>,
fragmentType: Array<FragmentType<DocumentTypeDecoration<TType, any>>> | null | undefined fragmentType: Array<FragmentType<DocumentTypeDecoration<TType, any>>> | null | undefined
): Array<TType> | null | undefined; ): Array<TType> | null | undefined;
// return readonly array of non-nullable if `fragmentType` is array of non-nullable // return readonly array of non-nullable if `fragmentType` is array of non-nullable
export function useFragment<TType>( export function useFragment<TType>(
_documentNode: DocumentTypeDecoration<TType, any>, _documentNode: DocumentTypeDecoration<TType, any>,
fragmentType: ReadonlyArray<FragmentType<DocumentTypeDecoration<TType, any>>> fragmentType: ReadonlyArray<FragmentType<DocumentTypeDecoration<TType, any>>>
): ReadonlyArray<TType>; ): ReadonlyArray<TType>;
// return readonly array of nullable if `fragmentType` is array of nullable // return readonly array of nullable if `fragmentType` is array of nullable
export function useFragment<TType>( export function useFragment<TType>(
_documentNode: DocumentTypeDecoration<TType, any>, _documentNode: DocumentTypeDecoration<TType, any>,
fragmentType: ReadonlyArray<FragmentType<DocumentTypeDecoration<TType, any>>> | null | undefined fragmentType: ReadonlyArray<FragmentType<DocumentTypeDecoration<TType, any>>> | null | undefined
): ReadonlyArray<TType> | null | undefined; ): ReadonlyArray<TType> | null | undefined;
export function useFragment<TType>( export function useFragment<TType>(
_documentNode: DocumentTypeDecoration<TType, any>, _documentNode: DocumentTypeDecoration<TType, any>,
fragmentType: fragmentType: FragmentType<DocumentTypeDecoration<TType, any>> | Array<FragmentType<DocumentTypeDecoration<TType, any>>> | ReadonlyArray<FragmentType<DocumentTypeDecoration<TType, any>>> | null | undefined
| FragmentType<DocumentTypeDecoration<TType, any>>
| Array<FragmentType<DocumentTypeDecoration<TType, any>>>
| ReadonlyArray<FragmentType<DocumentTypeDecoration<TType, any>>>
| null
| undefined
): TType | Array<TType> | ReadonlyArray<TType> | null | undefined { ): TType | Array<TType> | ReadonlyArray<TType> | null | undefined {
return fragmentType as any; return fragmentType as any;
} }
export function makeFragmentData<F extends DocumentTypeDecoration<any, any>, FT extends ResultOf<F>>(
data: FT, export function makeFragmentData<
_fragment: F F extends DocumentTypeDecoration<any, any>,
): FragmentType<F> { FT extends ResultOf<F>
return data as FragmentType<F>; >(data: FT, _fragment: F): FragmentType<F> {
return data as FragmentType<F>;
} }
export function isFragmentReady<TQuery, TFrag>( export function isFragmentReady<TQuery, TFrag>(
queryNode: DocumentTypeDecoration<TQuery, any>, queryNode: DocumentTypeDecoration<TQuery, any>,
fragmentNode: TypedDocumentNode<TFrag>, fragmentNode: TypedDocumentNode<TFrag>,
data: FragmentType<TypedDocumentNode<Incremental<TFrag>, any>> | null | undefined data: FragmentType<TypedDocumentNode<Incremental<TFrag>, any>> | null | undefined
): data is FragmentType<typeof fragmentNode> { ): data is FragmentType<typeof fragmentNode> {
const deferredFields = ( const deferredFields = (queryNode as { __meta__?: { deferredFields: Record<string, (keyof TFrag)[]> } }).__meta__
queryNode as { __meta__?: { deferredFields: Record<string, (keyof TFrag)[]> } } ?.deferredFields;
).__meta__?.deferredFields;
if (!deferredFields) return true; if (!deferredFields) return true;
const fragDef = fragmentNode.definitions[0] as FragmentDefinitionNode | undefined; const fragDef = fragmentNode.definitions[0] as FragmentDefinitionNode | undefined;
const fragName = fragDef?.name?.value; const fragName = fragDef?.name?.value;
const fields = (fragName && deferredFields[fragName]) || []; const fields = (fragName && deferredFields[fragName]) || [];
return fields.length > 0 && fields.every((field) => data && field in data); return fields.length > 0 && fields.every(field => data && field in data);
} }

View File

@@ -1,7 +1,6 @@
/* eslint-disable */ /* eslint-disable */
import type { TypedDocumentNode as DocumentNode } from '@graphql-typed-document-node/core';
import * as types from './graphql.js'; import * as types from './graphql.js';
import type { TypedDocumentNode as DocumentNode } from '@graphql-typed-document-node/core';
/** /**
* Map of all GraphQL operations in the project. * Map of all GraphQL operations in the project.
@@ -15,17 +14,14 @@ import * as types from './graphql.js';
* Learn more about it here: https://the-guild.dev/graphql/codegen/plugins/presets/preset-client#reducing-bundle-size * Learn more about it here: https://the-guild.dev/graphql/codegen/plugins/presets/preset-client#reducing-bundle-size
*/ */
type Documents = { type Documents = {
'\n fragment RemoteGraphQLEventFragment on RemoteGraphQLEvent {\n remoteGraphQLEventData: data {\n type\n body\n sha256\n }\n }\n': typeof types.RemoteGraphQlEventFragmentFragmentDoc; "\n fragment RemoteGraphQLEventFragment on RemoteGraphQLEvent {\n remoteGraphQLEventData: data {\n type\n body\n sha256\n }\n }\n": typeof types.RemoteGraphQlEventFragmentFragmentDoc,
'\n subscription events {\n events {\n __typename\n ... on ClientConnectedEvent {\n connectedData: data {\n type\n version\n apiKey\n }\n connectedEvent: type\n }\n ... on ClientDisconnectedEvent {\n disconnectedData: data {\n type\n version\n apiKey\n }\n disconnectedEvent: type\n }\n ...RemoteGraphQLEventFragment\n }\n }\n': typeof types.EventsDocument; "\n subscription events {\n events {\n __typename\n ... on ClientConnectedEvent {\n connectedData: data {\n type\n version\n apiKey\n }\n connectedEvent: type\n }\n ... on ClientDisconnectedEvent {\n disconnectedData: data {\n type\n version\n apiKey\n }\n disconnectedEvent: type\n }\n ...RemoteGraphQLEventFragment\n }\n }\n": typeof types.EventsDocument,
'\n mutation sendRemoteGraphQLResponse($input: RemoteGraphQLServerInput!) {\n remoteGraphQLResponse(input: $input)\n }\n': typeof types.SendRemoteGraphQlResponseDocument; "\n mutation sendRemoteGraphQLResponse($input: RemoteGraphQLServerInput!) {\n remoteGraphQLResponse(input: $input)\n }\n": typeof types.SendRemoteGraphQlResponseDocument,
}; };
const documents: Documents = { const documents: Documents = {
'\n fragment RemoteGraphQLEventFragment on RemoteGraphQLEvent {\n remoteGraphQLEventData: data {\n type\n body\n sha256\n }\n }\n': "\n fragment RemoteGraphQLEventFragment on RemoteGraphQLEvent {\n remoteGraphQLEventData: data {\n type\n body\n sha256\n }\n }\n": types.RemoteGraphQlEventFragmentFragmentDoc,
types.RemoteGraphQlEventFragmentFragmentDoc, "\n subscription events {\n events {\n __typename\n ... on ClientConnectedEvent {\n connectedData: data {\n type\n version\n apiKey\n }\n connectedEvent: type\n }\n ... on ClientDisconnectedEvent {\n disconnectedData: data {\n type\n version\n apiKey\n }\n disconnectedEvent: type\n }\n ...RemoteGraphQLEventFragment\n }\n }\n": types.EventsDocument,
'\n subscription events {\n events {\n __typename\n ... on ClientConnectedEvent {\n connectedData: data {\n type\n version\n apiKey\n }\n connectedEvent: type\n }\n ... on ClientDisconnectedEvent {\n disconnectedData: data {\n type\n version\n apiKey\n }\n disconnectedEvent: type\n }\n ...RemoteGraphQLEventFragment\n }\n }\n': "\n mutation sendRemoteGraphQLResponse($input: RemoteGraphQLServerInput!) {\n remoteGraphQLResponse(input: $input)\n }\n": types.SendRemoteGraphQlResponseDocument,
types.EventsDocument,
'\n mutation sendRemoteGraphQLResponse($input: RemoteGraphQLServerInput!) {\n remoteGraphQLResponse(input: $input)\n }\n':
types.SendRemoteGraphQlResponseDocument,
}; };
/** /**
@@ -45,25 +41,18 @@ export function graphql(source: string): unknown;
/** /**
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients. * The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
*/ */
export function graphql( export function graphql(source: "\n fragment RemoteGraphQLEventFragment on RemoteGraphQLEvent {\n remoteGraphQLEventData: data {\n type\n body\n sha256\n }\n }\n"): (typeof documents)["\n fragment RemoteGraphQLEventFragment on RemoteGraphQLEvent {\n remoteGraphQLEventData: data {\n type\n body\n sha256\n }\n }\n"];
source: '\n fragment RemoteGraphQLEventFragment on RemoteGraphQLEvent {\n remoteGraphQLEventData: data {\n type\n body\n sha256\n }\n }\n'
): (typeof documents)['\n fragment RemoteGraphQLEventFragment on RemoteGraphQLEvent {\n remoteGraphQLEventData: data {\n type\n body\n sha256\n }\n }\n'];
/** /**
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients. * The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
*/ */
export function graphql( export function graphql(source: "\n subscription events {\n events {\n __typename\n ... on ClientConnectedEvent {\n connectedData: data {\n type\n version\n apiKey\n }\n connectedEvent: type\n }\n ... on ClientDisconnectedEvent {\n disconnectedData: data {\n type\n version\n apiKey\n }\n disconnectedEvent: type\n }\n ...RemoteGraphQLEventFragment\n }\n }\n"): (typeof documents)["\n subscription events {\n events {\n __typename\n ... on ClientConnectedEvent {\n connectedData: data {\n type\n version\n apiKey\n }\n connectedEvent: type\n }\n ... on ClientDisconnectedEvent {\n disconnectedData: data {\n type\n version\n apiKey\n }\n disconnectedEvent: type\n }\n ...RemoteGraphQLEventFragment\n }\n }\n"];
source: '\n subscription events {\n events {\n __typename\n ... on ClientConnectedEvent {\n connectedData: data {\n type\n version\n apiKey\n }\n connectedEvent: type\n }\n ... on ClientDisconnectedEvent {\n disconnectedData: data {\n type\n version\n apiKey\n }\n disconnectedEvent: type\n }\n ...RemoteGraphQLEventFragment\n }\n }\n'
): (typeof documents)['\n subscription events {\n events {\n __typename\n ... on ClientConnectedEvent {\n connectedData: data {\n type\n version\n apiKey\n }\n connectedEvent: type\n }\n ... on ClientDisconnectedEvent {\n disconnectedData: data {\n type\n version\n apiKey\n }\n disconnectedEvent: type\n }\n ...RemoteGraphQLEventFragment\n }\n }\n'];
/** /**
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients. * The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
*/ */
export function graphql( export function graphql(source: "\n mutation sendRemoteGraphQLResponse($input: RemoteGraphQLServerInput!) {\n remoteGraphQLResponse(input: $input)\n }\n"): (typeof documents)["\n mutation sendRemoteGraphQLResponse($input: RemoteGraphQLServerInput!) {\n remoteGraphQLResponse(input: $input)\n }\n"];
source: '\n mutation sendRemoteGraphQLResponse($input: RemoteGraphQLServerInput!) {\n remoteGraphQLResponse(input: $input)\n }\n'
): (typeof documents)['\n mutation sendRemoteGraphQLResponse($input: RemoteGraphQLServerInput!) {\n remoteGraphQLResponse(input: $input)\n }\n'];
export function graphql(source: string) { export function graphql(source: string) {
return (documents as any)[source] ?? {}; return (documents as any)[source] ?? {};
} }
export type DocumentType<TDocumentNode extends DocumentNode<any, any>> = export type DocumentType<TDocumentNode extends DocumentNode<any, any>> = TDocumentNode extends DocumentNode< infer TType, any> ? TType : never;
TDocumentNode extends DocumentNode<infer TType, any> ? TType : never;

View File

@@ -1,2 +1,2 @@
export * from './fragment-masking.js'; export * from "./fragment-masking.js";
export * from './gql.js'; export * from "./gql.js";

View File

@@ -0,0 +1,6 @@
Menu="ManagementAccess:160"
Title="API Config Download"
Icon="icon-download"
Tag="download"
---
<unraid-config-download />

85
pnpm-lock.yaml generated
View File

@@ -178,6 +178,9 @@ importers:
dotenv: dotenv:
specifier: 17.2.1 specifier: 17.2.1
version: 17.2.1 version: 17.2.1
escape-html:
specifier: 1.0.3
version: 1.0.3
execa: execa:
specifier: 9.6.0 specifier: 9.6.0
version: 9.6.0 version: 9.6.0
@@ -1091,6 +1094,9 @@ importers:
ajv: ajv:
specifier: 8.17.1 specifier: 8.17.1
version: 8.17.1 version: 8.17.1
ansi_up:
specifier: ^6.0.6
version: 6.0.6
class-variance-authority: class-variance-authority:
specifier: 0.7.1 specifier: 0.7.1
version: 0.7.1 version: 0.7.1
@@ -6139,6 +6145,9 @@ packages:
resolution: {integrity: sha512-bN798gFfQX+viw3R7yrGWRqnrN2oRkEkUjjl4JNn4E8GxxbjtG3FbrEIIY3l8/hrwUwIeCZvi4QuOTP4MErVug==} resolution: {integrity: sha512-bN798gFfQX+viw3R7yrGWRqnrN2oRkEkUjjl4JNn4E8GxxbjtG3FbrEIIY3l8/hrwUwIeCZvi4QuOTP4MErVug==}
engines: {node: '>=12'} engines: {node: '>=12'}
ansi_up@6.0.6:
resolution: {integrity: sha512-yIa1x3Ecf8jWP4UWEunNjqNX6gzE4vg2gGz+xqRGY+TBSucnYp6RRdPV4brmtg6bQ1ljD48mZ5iGSEj7QEpRKA==}
ansis@4.0.0-node10: ansis@4.0.0-node10:
resolution: {integrity: sha512-BRrU0Bo1X9dFGw6KgGz6hWrqQuOlVEDOzkb0QSLZY9sXHqA7pNj7yHPVJRz7y/rj4EOJ3d/D5uxH+ee9leYgsg==} resolution: {integrity: sha512-BRrU0Bo1X9dFGw6KgGz6hWrqQuOlVEDOzkb0QSLZY9sXHqA7pNj7yHPVJRz7y/rj4EOJ3d/D5uxH+ee9leYgsg==}
engines: {node: '>=10'} engines: {node: '>=10'}
@@ -7694,10 +7703,6 @@ packages:
errx@0.1.0: errx@0.1.0:
resolution: {integrity: sha512-fZmsRiDNv07K6s2KkKFTiD2aIvECa7++PKyD5NC32tpRw46qZA3sOz+aM+/V9V0GDHxVTKLziveV4JhzBHDp9Q==} resolution: {integrity: sha512-fZmsRiDNv07K6s2KkKFTiD2aIvECa7++PKyD5NC32tpRw46qZA3sOz+aM+/V9V0GDHxVTKLziveV4JhzBHDp9Q==}
es-abstract@1.23.9:
resolution: {integrity: sha512-py07lI0wjxAC/DcfK1S6G7iANonniZwTISvdPzk9hzeH0IZIshbuuFxLIU96OyF89Yb9hiqWn8M/bY83KY5vzA==}
engines: {node: '>= 0.4'}
es-abstract@1.24.0: es-abstract@1.24.0:
resolution: {integrity: sha512-WSzPgsdLtTcQwm4CROfS5ju2Wa1QQcVeT37jFjYzdFz1r9ahadC8B8/a4qxJxM+09F18iumCdRmlr96ZYkQvEg==} resolution: {integrity: sha512-WSzPgsdLtTcQwm4CROfS5ju2Wa1QQcVeT37jFjYzdFz1r9ahadC8B8/a4qxJxM+09F18iumCdRmlr96ZYkQvEg==}
engines: {node: '>= 0.4'} engines: {node: '>= 0.4'}
@@ -19163,6 +19168,8 @@ snapshots:
ansi-styles@6.2.1: {} ansi-styles@6.2.1: {}
ansi_up@6.0.6: {}
ansis@4.0.0-node10: {} ansis@4.0.0-node10: {}
ansis@4.1.0: {} ansis@4.1.0: {}
@@ -19249,7 +19256,7 @@ snapshots:
call-bind: 1.0.8 call-bind: 1.0.8
call-bound: 1.0.4 call-bound: 1.0.4
define-properties: 1.2.1 define-properties: 1.2.1
es-abstract: 1.23.9 es-abstract: 1.24.0
es-errors: 1.3.0 es-errors: 1.3.0
es-object-atoms: 1.1.1 es-object-atoms: 1.1.1
es-shim-unscopables: 1.1.0 es-shim-unscopables: 1.1.0
@@ -19258,14 +19265,14 @@ snapshots:
dependencies: dependencies:
call-bind: 1.0.8 call-bind: 1.0.8
define-properties: 1.2.1 define-properties: 1.2.1
es-abstract: 1.23.9 es-abstract: 1.24.0
es-shim-unscopables: 1.1.0 es-shim-unscopables: 1.1.0
array.prototype.flatmap@1.3.3: array.prototype.flatmap@1.3.3:
dependencies: dependencies:
call-bind: 1.0.8 call-bind: 1.0.8
define-properties: 1.2.1 define-properties: 1.2.1
es-abstract: 1.23.9 es-abstract: 1.24.0
es-shim-unscopables: 1.1.0 es-shim-unscopables: 1.1.0
arraybuffer.prototype.slice@1.0.4: arraybuffer.prototype.slice@1.0.4:
@@ -19273,7 +19280,7 @@ snapshots:
array-buffer-byte-length: 1.0.2 array-buffer-byte-length: 1.0.2
call-bind: 1.0.8 call-bind: 1.0.8
define-properties: 1.2.1 define-properties: 1.2.1
es-abstract: 1.23.9 es-abstract: 1.24.0
es-errors: 1.3.0 es-errors: 1.3.0
get-intrinsic: 1.3.0 get-intrinsic: 1.3.0
is-array-buffer: 3.0.5 is-array-buffer: 3.0.5
@@ -20819,60 +20826,6 @@ snapshots:
errx@0.1.0: {} errx@0.1.0: {}
es-abstract@1.23.9:
dependencies:
array-buffer-byte-length: 1.0.2
arraybuffer.prototype.slice: 1.0.4
available-typed-arrays: 1.0.7
call-bind: 1.0.8
call-bound: 1.0.4
data-view-buffer: 1.0.2
data-view-byte-length: 1.0.2
data-view-byte-offset: 1.0.1
es-define-property: 1.0.1
es-errors: 1.3.0
es-object-atoms: 1.1.1
es-set-tostringtag: 2.1.0
es-to-primitive: 1.3.0
function.prototype.name: 1.1.8
get-intrinsic: 1.3.0
get-proto: 1.0.1
get-symbol-description: 1.1.0
globalthis: 1.0.4
gopd: 1.2.0
has-property-descriptors: 1.0.2
has-proto: 1.2.0
has-symbols: 1.1.0
hasown: 2.0.2
internal-slot: 1.1.0
is-array-buffer: 3.0.5
is-callable: 1.2.7
is-data-view: 1.0.2
is-regex: 1.2.1
is-shared-array-buffer: 1.0.4
is-string: 1.1.1
is-typed-array: 1.1.15
is-weakref: 1.1.1
math-intrinsics: 1.1.0
object-inspect: 1.13.4
object-keys: 1.1.1
object.assign: 4.1.7
own-keys: 1.0.1
regexp.prototype.flags: 1.5.4
safe-array-concat: 1.1.3
safe-push-apply: 1.0.0
safe-regex-test: 1.1.0
set-proto: 1.0.0
string.prototype.trim: 1.2.10
string.prototype.trimend: 1.0.9
string.prototype.trimstart: 1.0.8
typed-array-buffer: 1.0.3
typed-array-byte-length: 1.0.3
typed-array-byte-offset: 1.0.4
typed-array-length: 1.0.7
unbox-primitive: 1.1.0
which-typed-array: 1.1.19
es-abstract@1.24.0: es-abstract@1.24.0:
dependencies: dependencies:
array-buffer-byte-length: 1.0.2 array-buffer-byte-length: 1.0.2
@@ -24175,14 +24128,14 @@ snapshots:
dependencies: dependencies:
call-bind: 1.0.8 call-bind: 1.0.8
define-properties: 1.2.1 define-properties: 1.2.1
es-abstract: 1.23.9 es-abstract: 1.24.0
es-object-atoms: 1.1.1 es-object-atoms: 1.1.1
object.groupby@1.0.3: object.groupby@1.0.3:
dependencies: dependencies:
call-bind: 1.0.8 call-bind: 1.0.8
define-properties: 1.2.1 define-properties: 1.2.1
es-abstract: 1.23.9 es-abstract: 1.24.0
object.values@1.2.1: object.values@1.2.1:
dependencies: dependencies:
@@ -25364,7 +25317,7 @@ snapshots:
dependencies: dependencies:
call-bind: 1.0.8 call-bind: 1.0.8
define-properties: 1.2.1 define-properties: 1.2.1
es-abstract: 1.23.9 es-abstract: 1.24.0
es-errors: 1.3.0 es-errors: 1.3.0
es-object-atoms: 1.1.1 es-object-atoms: 1.1.1
get-intrinsic: 1.3.0 get-intrinsic: 1.3.0
@@ -26146,7 +26099,7 @@ snapshots:
call-bound: 1.0.4 call-bound: 1.0.4
define-data-property: 1.1.4 define-data-property: 1.1.4
define-properties: 1.2.1 define-properties: 1.2.1
es-abstract: 1.23.9 es-abstract: 1.24.0
es-object-atoms: 1.1.1 es-object-atoms: 1.1.1
has-property-descriptors: 1.0.2 has-property-descriptors: 1.0.2

View File

@@ -42,7 +42,7 @@ import {
AccordionTrigger, AccordionTrigger,
} from '@/components/ui/accordion'; } from '@/components/ui/accordion';
import { jsonFormsAjv } from '@/forms/config'; import { jsonFormsAjv } from '@/forms/config';
import type { Layout, UISchemaElement } from '@jsonforms/core'; import type { BaseUISchemaElement, Labelable, Layout, UISchemaElement } from '@jsonforms/core';
import { isVisible } from '@jsonforms/core'; import { isVisible } from '@jsonforms/core';
import { DispatchRenderer, useJsonFormsLayout } from '@jsonforms/vue'; import { DispatchRenderer, useJsonFormsLayout } from '@jsonforms/vue';
import type { RendererProps } from '@jsonforms/vue'; import type { RendererProps } from '@jsonforms/vue';
@@ -61,8 +61,9 @@ const elements = computed(() => {
const allElements = props.uischema?.elements || []; const allElements = props.uischema?.elements || [];
// Filter elements based on visibility rules // Filter elements based on visibility rules
return allElements.filter((element: UISchemaElement & Record<string, unknown>) => { return allElements.filter((element) => {
if (!element.rule) { const elementWithRule = element as BaseUISchemaElement;
if (!elementWithRule.rule) {
// No rule means always visible // No rule means always visible
return true; return true;
} }
@@ -71,13 +72,13 @@ const elements = computed(() => {
try { try {
// Get the root data from JSONForms context for rule evaluation // Get the root data from JSONForms context for rule evaluation
const rootData = jsonFormsContext?.core?.data || {}; const rootData = jsonFormsContext?.core?.data || {};
const formData = props.data || layout.data || rootData; const formData = props.data || rootData;
const formPath = props.path || layout.path || ''; const formPath = props.path || layout.value.path || '';
const visible = isVisible(element as UISchemaElement, formData, formPath, jsonFormsAjv); const visible = isVisible(element, formData, formPath, jsonFormsAjv);
return visible; return visible;
} catch (error) { } catch (error) {
console.warn('[AccordionLayout] Error evaluating visibility:', error, element.rule); console.warn('[AccordionLayout] Error evaluating visibility:', error, elementWithRule.rule);
return true; // Default to visible on error return true; // Default to visible on error
} }
}); });
@@ -127,31 +128,21 @@ const defaultOpenItems = computed(() => {
}); });
// Get title for accordion item from element options // Get title for accordion item from element options
const getAccordionTitle = ( const getAccordionTitle = (element: UISchemaElement, index: number): string => {
element: UISchemaElement & Record<string, unknown>, const el = element as BaseUISchemaElement & Labelable;
index: number const options = el.options;
): string => { const accordionTitle = options?.accordion?.title;
return ( const title = options?.title;
(element as { options?: { accordion?: { title?: string }; title?: string }; text?: string }).options const text = el.label;
?.accordion?.title || return accordionTitle || title || text || `Section ${index + 1}`;
(element as { options?: { accordion?: { title?: string }; title?: string }; text?: string }).options
?.title ||
(element as { options?: { accordion?: { title?: string }; title?: string }; text?: string }).text ||
`Section ${index + 1}`
);
}; };
// Get description for accordion item from element options // Get description for accordion item from element options
const getAccordionDescription = ( const getAccordionDescription = (element: UISchemaElement, _index: number): string => {
element: UISchemaElement & Record<string, unknown>, const el = element as BaseUISchemaElement;
index: number const options = el.options;
): string => { const accordionDescription = options?.accordion?.description;
return ( const description = options?.description;
(element as { options?: { accordion?: { description?: string }; description?: string } }).options return accordionDescription || description || '';
?.accordion?.description ||
(element as { options?: { accordion?: { description?: string }; description?: string } }).options
?.description ||
''
);
}; };
</script> </script>

View File

@@ -0,0 +1,424 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { mount, flushPromises } from '@vue/test-utils';
import { nextTick } from 'vue';
import DOMPurify from 'isomorphic-dompurify';
import { AnsiUp } from 'ansi_up';
import { useQuery } from '@vue/apollo-composable';
import SingleLogViewer from '~/components/Logs/SingleLogViewer.vue';
import {
createMockUseQuery,
createMockLogFileQuery
} from '../../helpers/apollo-mocks';
// Mock the UI components
vi.mock('@unraid/ui', () => ({
Button: { template: '<button><slot /></button>' },
Tooltip: { template: '<div><slot /></div>' },
TooltipContent: { template: '<div><slot /></div>' },
TooltipProvider: { template: '<div><slot /></div>' },
TooltipTrigger: { template: '<div><slot /></div>' }
}));
// Mock the GraphQL query
vi.mock('@vue/apollo-composable', () => ({
useApolloClient: vi.fn(() => ({
client: {
query: vi.fn()
}
})),
useQuery: vi.fn()
}));
// Mock the theme store
vi.mock('~/store/theme', () => ({
useThemeStore: vi.fn(() => ({
darkMode: false
}))
}));
describe('SingleLogViewer - ANSI Color Support', () => {
let ansiConverter: AnsiUp;
beforeEach(() => {
// Create a fresh converter instance for each test
ansiConverter = new AnsiUp();
ansiConverter.use_classes = true;
ansiConverter.escape_html = true;
});
describe('ANSI to HTML Conversion', () => {
it('should convert ANSI color codes to CSS classes', () => {
const testCases = [
{
input: '\x1b[31mRed text\x1b[0m',
expected: '<span class="ansi-red-fg">Red text</span>',
description: 'red foreground'
},
{
input: '\x1b[32mGreen text\x1b[0m',
expected: '<span class="ansi-green-fg">Green text</span>',
description: 'green foreground'
},
{
input: '\x1b[33mYellow text\x1b[0m',
expected: '<span class="ansi-yellow-fg">Yellow text</span>',
description: 'yellow foreground'
},
{
input: '\x1b[34mBlue text\x1b[0m',
expected: '<span class="ansi-blue-fg">Blue text</span>',
description: 'blue foreground'
},
{
input: '\x1b[91mBright red\x1b[0m',
expected: '<span class="ansi-bright-red-fg">Bright red</span>',
description: 'bright red foreground'
},
{
input: '\x1b[41mRed background\x1b[0m',
expected: '<span class="ansi-red-bg">Red background</span>',
description: 'red background'
},
{
input: '\x1b[1mBold text\x1b[0m',
expected: '<span style="font-weight:bold">Bold text</span>',
description: 'bold text (ansi_up uses inline style for bold)'
},
{
input: '\x1b[3mItalic text\x1b[0m',
expected: '<span style="font-style:italic">Italic text</span>',
description: 'italic text (ansi_up uses inline style for italic)'
},
{
input: '\x1b[4mUnderlined text\x1b[0m',
expected: '<span style="text-decoration:underline">Underlined text</span>',
description: 'underlined text (ansi_up uses inline style for underline)'
}
];
testCases.forEach(({ input, expected, description }) => {
const result = ansiConverter.ansi_to_html(input);
expect(result, `Failed for ${description}`).toBe(expected);
});
});
it('should handle multiple ANSI codes in one string', () => {
const input = '\x1b[31mRed\x1b[0m \x1b[32mGreen\x1b[0m \x1b[34mBlue\x1b[0m';
const expected = '<span class="ansi-red-fg">Red</span> <span class="ansi-green-fg">Green</span> <span class="ansi-blue-fg">Blue</span>';
const result = ansiConverter.ansi_to_html(input);
expect(result).toBe(expected);
});
it('should handle nested ANSI codes', () => {
const input = '\x1b[1m\x1b[31mBold Red Text\x1b[0m';
const result = ansiConverter.ansi_to_html(input);
// ansi_up uses inline style for bold
expect(result).toContain('font-weight:bold');
expect(result).toContain('ansi-red-fg');
});
it('should escape HTML entities for security', () => {
const input = '\x1b[31m<script>alert("XSS")</script>\x1b[0m';
const result = ansiConverter.ansi_to_html(input);
expect(result).not.toContain('<script>');
expect(result).toContain('&lt;script&gt;');
});
});
describe('DOMPurify Sanitization', () => {
it('should preserve CSS classes after sanitization', () => {
const htmlWithClasses = '<span class="ansi-red-fg">Red text</span>';
const sanitized = DOMPurify.sanitize(htmlWithClasses, {
ALLOWED_TAGS: ['span', 'br'],
ALLOWED_ATTR: ['class']
});
expect(sanitized).toBe(htmlWithClasses);
});
it('should remove inline styles when configured', () => {
const htmlWithStyles = '<span style="color: red;">Red text</span>';
const sanitized = DOMPurify.sanitize(htmlWithStyles, {
ALLOWED_TAGS: ['span', 'br'],
ALLOWED_ATTR: ['class'] // Note: 'style' is not allowed
});
expect(sanitized).toBe('<span>Red text</span>');
});
it('should remove dangerous tags while preserving safe content', () => {
const dangerous = '<span class="ansi-red-fg">Safe</span><script>alert("XSS")</script>';
const sanitized = DOMPurify.sanitize(dangerous, {
ALLOWED_TAGS: ['span', 'br'],
ALLOWED_ATTR: ['class']
});
expect(sanitized).toBe('<span class="ansi-red-fg">Safe</span>');
});
it('should handle complex nested structures', () => {
const complex = '<span class="ansi-bold"><span class="ansi-red-fg">Bold Red</span></span>';
const sanitized = DOMPurify.sanitize(complex, {
ALLOWED_TAGS: ['span', 'br'],
ALLOWED_ATTR: ['class']
});
expect(sanitized).toBe(complex);
});
});
describe('CSS Class Definitions', () => {
it('should have CSS rules for all standard ANSI colors', async () => {
// Mock useQuery to return empty data for this test
// @ts-expect-error Mock implementation for testing
vi.mocked(useQuery).mockReturnValue(createMockUseQuery());
const wrapper = mount(SingleLogViewer, {
props: {
logFilePath: '/test/log.txt',
lineCount: 100,
autoScroll: false
},
global: {
stubs: {
Button: true,
Tooltip: true,
TooltipContent: true,
TooltipProvider: true,
TooltipTrigger: true
}
}
});
// Wait for component to mount
await nextTick();
// Check that the component mounts without errors
expect(wrapper.exists()).toBe(true);
wrapper.unmount();
});
});
describe('Integration Tests', () => {
beforeEach(() => {
// Reset mocks before each test
vi.clearAllMocks();
});
it('should properly render ANSI colored log content', async () => {
// Create mock data
const content = '\x1b[31m[ERROR]\x1b[0m Failed to connect\n\x1b[32m[SUCCESS]\x1b[0m Connected';
const mockQuery = createMockLogFileQuery(content, 2, 1);
// Mock useQuery to return our data
// @ts-expect-error Mock implementation for testing
vi.mocked(useQuery).mockReturnValue(mockQuery);
const wrapper = mount(SingleLogViewer, {
props: {
logFilePath: '/test/log.txt',
lineCount: 100,
autoScroll: false
}
});
// Wait for the component to mount and process initial data
await wrapper.vm.$nextTick();
// Trigger the watcher by modifying the result
// @ts-expect-error Accessing mock properties
if (mockQuery.result.value) {
// @ts-expect-error Modifying mock properties
mockQuery.result.value = {
logFile: {
content,
totalLines: 2,
startLine: 1
}
};
}
// Wait for watchers to process
await wrapper.vm.$nextTick();
await flushPromises();
await wrapper.vm.$nextTick();
// Get the pre element that contains the log content
const preElement = wrapper.find('pre.hljs');
expect(preElement.exists()).toBe(true);
// Check that the rendered HTML contains the CSS classes
const html = preElement.html();
if (!html.includes('ansi-red-fg')) {
console.log('Pre element HTML:', html);
console.log('Full wrapper HTML:', wrapper.html());
}
expect(html).toContain('ansi-red-fg');
expect(html).toContain('[ERROR]');
expect(html).toContain('ansi-green-fg');
expect(html).toContain('[SUCCESS]');
wrapper.unmount();
});
it('should handle log content with mixed ANSI and plain text', async () => {
const content = 'Plain text \x1b[33mWarning\x1b[0m more plain text';
const mockQuery = createMockLogFileQuery(content, 1, 1);
// @ts-expect-error Mock implementation for testing
vi.mocked(useQuery).mockReturnValue(mockQuery);
const wrapper = mount(SingleLogViewer, {
props: {
logFilePath: '/test/log.txt',
lineCount: 100,
autoScroll: false
}
});
// Wait for mount and trigger the watcher
await wrapper.vm.$nextTick();
// @ts-expect-error Accessing mock properties
if (mockQuery.result.value) {
// @ts-expect-error Modifying mock properties
mockQuery.result.value = {
logFile: {
content,
totalLines: 1,
startLine: 1
}
};
}
// Wait for processing
await wrapper.vm.$nextTick();
await flushPromises();
await wrapper.vm.$nextTick();
const preElement = wrapper.find('pre.hljs');
expect(preElement.exists()).toBe(true);
const html = preElement.html();
expect(html).toContain('Plain text');
expect(html).toContain('ansi-yellow-fg');
expect(html).toContain('Warning');
expect(html).toContain('more plain text');
wrapper.unmount();
});
it('should apply client-side filtering while preserving ANSI colors', async () => {
const content = '\x1b[31m[ERROR]\x1b[0m Connection failed\n\x1b[32m[INFO]\x1b[0m Connected\n\x1b[31m[ERROR]\x1b[0m Timeout';
const mockQuery = createMockLogFileQuery(content, 3, 1);
// @ts-expect-error Mock implementation for testing
vi.mocked(useQuery).mockReturnValue(mockQuery);
const wrapper = mount(SingleLogViewer, {
props: {
logFilePath: '/test/log.txt',
lineCount: 100,
autoScroll: false,
clientFilter: 'ERROR'
}
});
// Wait for mount and trigger the watcher
await wrapper.vm.$nextTick();
// @ts-expect-error Accessing mock properties
if (mockQuery.result.value) {
// @ts-expect-error Modifying mock properties
mockQuery.result.value = {
logFile: {
content,
totalLines: 3,
startLine: 1
}
};
}
// Wait for processing
await wrapper.vm.$nextTick();
await flushPromises();
await wrapper.vm.$nextTick();
const preElement = wrapper.find('pre.hljs');
expect(preElement.exists()).toBe(true);
const html = preElement.html();
// Should contain ERROR lines with red color
expect(html).toContain('ansi-red-fg');
expect(html).toContain('ERROR');
// Should not contain INFO line (due to filter)
expect(html).not.toContain('INFO');
expect(html).not.toContain('ansi-green-fg');
wrapper.unmount();
});
});
describe('Performance Tests', () => {
it('should handle large amounts of ANSI colored text efficiently', () => {
const lines = 1000;
const largeInput = Array(lines)
.fill(null)
.map((_, i) => {
const colors = ['31', '32', '33', '34', '35', '36'];
const color = colors[i % colors.length];
return `\x1b[${color}mLine ${i}: Some log message with color\x1b[0m`;
})
.join('\n');
const result = ansiConverter.ansi_to_html(largeInput);
// Should contain the expected number of color spans
const colorMatches = result.match(/class="ansi-/g);
expect(colorMatches).toHaveLength(lines);
});
it('should efficiently sanitize large HTML with many CSS classes', () => {
const lines = 1000;
const largeHtml = Array(lines)
.fill(null)
.map((_, i) => `<span class="ansi-red-fg">Line ${i}</span>`)
.join('\n');
const sanitized = DOMPurify.sanitize(largeHtml, {
ALLOWED_TAGS: ['span', 'br'],
ALLOWED_ATTR: ['class']
});
// Should preserve all spans
const spanMatches = sanitized.match(/<span/g);
expect(spanMatches).toHaveLength(lines);
});
});
describe('Edge Cases', () => {
it('should handle empty input gracefully', () => {
const result = ansiConverter.ansi_to_html('');
expect(result).toBe('');
});
it('should handle input with no ANSI codes', () => {
const plainText = 'This is plain text without any colors';
const result = ansiConverter.ansi_to_html(plainText);
expect(result).toBe(plainText);
});
it('should handle malformed ANSI codes', () => {
const malformed = '\x1b[999mInvalid color code\x1b[0m';
// Should not throw an error
expect(() => ansiConverter.ansi_to_html(malformed)).not.toThrow();
});
it('should handle incomplete ANSI sequences', () => {
const incomplete = '\x1b[31mRed text without reset';
const result = ansiConverter.ansi_to_html(incomplete);
expect(result).toContain('ansi-red-fg');
});
it('should handle ANSI codes at the beginning and end of lines', () => {
const input = '\x1b[31mStart\nMiddle\nEnd\x1b[0m';
const result = ansiConverter.ansi_to_html(input);
expect(result).toContain('ansi-red-fg');
});
});
});

View File

@@ -59,6 +59,8 @@ const mockLocation = {
hash: '', hash: '',
origin: 'http://mock-origin.com', origin: 'http://mock-origin.com',
pathname: '/login', pathname: '/login',
protocol: 'http:',
host: 'mock-origin.com',
get href() { get href() {
return mockLocationHref; return mockLocationHref;
}, },
@@ -253,7 +255,8 @@ describe('SsoButtons', () => {
expect(sessionStorage.setItem).toHaveBeenCalledWith('sso_provider', 'unraid-net'); expect(sessionStorage.setItem).toHaveBeenCalledWith('sso_provider', 'unraid-net');
const generatedState = (sessionStorage.setItem as Mock).mock.calls[0][1]; const generatedState = (sessionStorage.setItem as Mock).mock.calls[0][1];
const expectedUrl = `/graphql/api/auth/oidc/authorize/unraid-net?state=${encodeURIComponent(generatedState)}`; const redirectUri = `${mockLocation.origin}/graphql/api/auth/oidc/callback`;
const expectedUrl = `/graphql/api/auth/oidc/authorize/unraid-net?state=${encodeURIComponent(generatedState)}&redirect_uri=${encodeURIComponent(redirectUri)}`;
expect(mockLocation.href).toBe(expectedUrl); expect(mockLocation.href).toBe(expectedUrl);
}); });
@@ -377,6 +380,57 @@ describe('SsoButtons', () => {
expect(mockLocation.href).toBe(expectedUrl); expect(mockLocation.href).toBe(expectedUrl);
}); });
it('handles HTTPS with non-standard port correctly', async () => {
const mockProviders = [
{
id: 'tsidp',
name: 'Tailscale IDP',
buttonText: 'Sign in with Tailscale',
buttonIcon: null,
buttonVariant: 'secondary',
buttonStyle: null
}
];
// Set up location with HTTPS and non-standard port
mockLocation.protocol = 'https:';
mockLocation.host = 'unraid.mytailnet.ts.net:1443';
mockLocation.origin = 'https://unraid.mytailnet.ts.net:1443';
mockUseQuery.mockReturnValue({
result: { value: { publicOidcProviders: mockProviders } },
refetch: vi.fn().mockResolvedValue({ data: { publicOidcProviders: mockProviders } }),
});
const wrapper = mount(SsoButtons, {
global: {
stubs: {
SsoProviderButton: SsoProviderButtonStub,
Button: { template: '<button><slot /></button>' }
},
},
});
await flushPromises();
vi.runAllTimers();
await flushPromises();
const button = wrapper.find('button');
await button.trigger('click');
// Should include the correct redirect URI with HTTPS and port 1443
const generatedState = (sessionStorage.setItem as Mock).mock.calls[0][1];
const redirectUri = 'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback';
const expectedUrl = `/graphql/api/auth/oidc/authorize/tsidp?state=${encodeURIComponent(generatedState)}&redirect_uri=${encodeURIComponent(redirectUri)}`;
expect(mockLocation.href).toBe(expectedUrl);
// Reset location mock for other tests
mockLocation.protocol = 'http:';
mockLocation.host = 'mock-origin.com';
mockLocation.origin = 'http://mock-origin.com';
});
it('handles multiple OIDC providers', async () => { it('handles multiple OIDC providers', async () => {
const mockProviders = [ const mockProviders = [
{ {

View File

@@ -1,175 +0,0 @@
import { describe, it, expect } from 'vitest';
import { useApiKeyAuthorization } from '~/composables/useApiKeyAuthorization';
import { AuthAction, Resource, Role } from '~/composables/gql/graphql';
describe('useApiKeyAuthorization', () => {
describe('parameter parsing', () => {
it('should parse query parameters correctly', () => {
const params = new URLSearchParams('?name=TestApp&scopes=docker:read,vms:*&redirect_uri=https://example.com&state=abc123');
const { authParams } = useApiKeyAuthorization(params);
expect(authParams.value.name).toBe('TestApp');
expect(authParams.value.scopes).toEqual(['docker:read', 'vms:*']);
expect(authParams.value.redirectUri).toBe('https://example.com');
expect(authParams.value.state).toBe('abc123');
});
it('should handle missing parameters with defaults', () => {
const params = new URLSearchParams('');
const { authParams } = useApiKeyAuthorization(params);
expect(authParams.value.name).toBe('Unknown Application');
expect(authParams.value.scopes).toEqual([]);
expect(authParams.value.redirectUri).toBe('');
expect(authParams.value.state).toBe('');
});
});
describe('formatPermissions', () => {
it('should format role scopes correctly', () => {
const params = new URLSearchParams('?scopes=role:admin,role:viewer');
const { formattedPermissions } = useApiKeyAuthorization(params);
expect(formattedPermissions.value).toEqual([
{
scope: 'role:admin',
name: 'ADMIN',
description: 'Grant admin role access',
isRole: true,
},
{
scope: 'role:viewer',
name: 'VIEWER',
description: 'Grant viewer role access',
isRole: true,
},
]);
});
it('should format resource:action scopes correctly', () => {
const params = new URLSearchParams('?scopes=docker:read,vms:*');
const { formattedPermissions } = useApiKeyAuthorization(params);
expect(formattedPermissions.value).toEqual([
{
scope: 'docker:read',
name: 'Docker - Read',
description: 'Read access to Docker',
isRole: false,
},
{
scope: 'vms:*',
name: 'Vms - Full',
description: 'Full access to Vms',
isRole: false,
},
]);
});
});
describe('convertScopesToPermissions', () => {
it('should convert role scopes to roles', () => {
const params = new URLSearchParams('?scopes=role:admin');
const { convertScopesToPermissions } = useApiKeyAuthorization(params);
const result = convertScopesToPermissions(['role:admin']);
expect(result.roles).toContain(Role.ADMIN);
expect(result.permissions).toEqual([]);
});
it('should convert resource scopes to permissions', () => {
const params = new URLSearchParams('?scopes=docker:read');
const { convertScopesToPermissions } = useApiKeyAuthorization(params);
const result = convertScopesToPermissions(['docker:read']);
expect(result.permissions).toEqual([
{
resource: Resource.DOCKER,
actions: [AuthAction.READ_ANY],
},
]);
expect(result.roles).toEqual([]);
});
it('should handle wildcard actions', () => {
const params = new URLSearchParams('?scopes=vms:*');
const { convertScopesToPermissions } = useApiKeyAuthorization(params);
const result = convertScopesToPermissions(['vms:*']);
expect(result.permissions).toEqual([
{
resource: Resource.VMS,
actions: [AuthAction.CREATE_ANY, AuthAction.READ_ANY, AuthAction.UPDATE_ANY, AuthAction.DELETE_ANY],
},
]);
});
it('should merge multiple actions for same resource', () => {
const params = new URLSearchParams('');
const { convertScopesToPermissions } = useApiKeyAuthorization(params);
const result = convertScopesToPermissions(['docker:read', 'docker:update']);
expect(result.permissions).toEqual([
{
resource: Resource.DOCKER,
actions: [AuthAction.READ_ANY, AuthAction.UPDATE_ANY],
},
]);
});
});
describe('redirect URI validation', () => {
it('should accept HTTPS URLs', () => {
const params = new URLSearchParams('?redirect_uri=https://example.com/callback');
const { hasValidRedirectUri } = useApiKeyAuthorization(params);
expect(hasValidRedirectUri.value).toBe(true);
});
it('should accept localhost URLs', () => {
const params = new URLSearchParams('?redirect_uri=http://localhost:3000/callback');
const { hasValidRedirectUri } = useApiKeyAuthorization(params);
expect(hasValidRedirectUri.value).toBe(true);
});
it('should accept HTTP URLs (non-localhost)', () => {
const params = new URLSearchParams('?redirect_uri=http://example.com/callback');
const { hasValidRedirectUri } = useApiKeyAuthorization(params);
expect(hasValidRedirectUri.value).toBe(true);
});
it('should reject invalid URLs', () => {
const params = new URLSearchParams('?redirect_uri=not-a-url');
const { hasValidRedirectUri } = useApiKeyAuthorization(params);
expect(hasValidRedirectUri.value).toBe(false);
});
});
describe('buildCallbackUrl', () => {
it('should build callback URL with API key', () => {
const params = new URLSearchParams('');
const { buildCallbackUrl } = useApiKeyAuthorization(params);
const url = buildCallbackUrl('https://example.com/callback', 'test-key', undefined, 'state123');
expect(url).toBe('https://example.com/callback?api_key=test-key&state=state123');
});
it('should build callback URL with error', () => {
const params = new URLSearchParams('');
const { buildCallbackUrl } = useApiKeyAuthorization(params);
const url = buildCallbackUrl('https://example.com/callback', undefined, 'access_denied', 'state123');
expect(url).toBe('https://example.com/callback?error=access_denied&state=state123');
});
it('should throw for invalid redirect URI', () => {
const params = new URLSearchParams('');
const { buildCallbackUrl } = useApiKeyAuthorization(params);
expect(() => buildCallbackUrl('not-a-url', 'key')).toThrow('Invalid redirect URI');
});
});
});

View File

@@ -0,0 +1,86 @@
import { ref } from 'vue';
import { vi } from 'vitest';
// Types for mock data
export interface MockLogFile {
content: string;
totalLines: number;
startLine: number;
}
export interface MockQueryResult {
logFile: MockLogFile;
}
/**
* Creates a mock useQuery return value with optional result data
* Using unknown return type to avoid complex Apollo type issues in tests
*/
export function createMockUseQuery<TData = unknown>(
resultData: TData | null = null,
options: {
loading?: boolean;
error?: unknown;
} = {}
): unknown {
return {
result: ref(resultData),
loading: ref(options.loading ?? false),
error: ref(options.error ?? null),
refetch: vi.fn(() => Promise.resolve({
data: resultData || {},
loading: false,
networkStatus: 7,
stale: false,
error: undefined
})),
subscribeToMore: vi.fn(),
networkStatus: ref(7),
start: vi.fn(),
stop: vi.fn(),
restart: vi.fn(),
forceDisabled: ref(false),
document: ref(null),
variables: ref({}),
options: {},
query: ref(null),
fetchMore: vi.fn(),
updateQuery: vi.fn(),
onResult: vi.fn(),
onError: vi.fn()
};
}
/**
* Creates a mock useQuery specifically for log file data
*/
export function createMockLogFileQuery(
content: string,
totalLines: number,
startLine: number = 1
): unknown {
const result: MockQueryResult = {
logFile: {
content,
totalLines,
startLine
}
};
return createMockUseQuery(result);
}
/**
* Factory function to create the mock module object for @vue/apollo-composable
* Call this at the top level of test files: vi.mock('@vue/apollo-composable', () => apolloComposableMockFactory())
*/
export function apolloComposableMockFactory() {
return {
useApolloClient: vi.fn(() => ({
client: {
query: vi.fn()
}
})),
useQuery: vi.fn(() => createMockUseQuery())
};
}

View File

@@ -16,6 +16,7 @@ import { useServerStore } from '~/store/server';
// import type { ConnectSettingsValues } from '~/composables/gql/graphql'; // import type { ConnectSettingsValues } from '~/composables/gql/graphql';
import { getConnectSettingsForm, updateConnectSettings } from './graphql/settings.query'; import { getConnectSettingsForm, updateConnectSettings } from './graphql/settings.query';
import OidcDebugLogs from './OidcDebugLogs.vue';
const { connectPluginInstalled } = storeToRefs(useServerStore()); const { connectPluginInstalled } = storeToRefs(useServerStore());
@@ -120,6 +121,9 @@ const onChange = ({ data }: { data: Record<string, unknown> }) => {
:readonly="isUpdating" :readonly="isUpdating"
@change="onChange" @change="onChange"
/> />
<!-- OIDC Debug Logs -->
<OidcDebugLogs />
<!-- form submission & fallback reaction message --> <!-- form submission & fallback reaction message -->
<div class="mt-6 grid grid-cols-settings gap-y-6 items-baseline"> <div class="mt-6 grid grid-cols-settings gap-y-6 items-baseline">
<div class="text-sm text-end"> <div class="text-sm text-end">

View File

@@ -0,0 +1,59 @@
<script setup lang="ts">
import { ref } from 'vue';
import SingleLogViewer from '../Logs/SingleLogViewer.vue';
import LogViewerToolbar from '../Logs/LogViewerToolbar.vue';
const showLogs = ref(false);
const autoScroll = ref(true);
const filterText = ref('OIDC');
const logViewerRef = ref<InstanceType<typeof SingleLogViewer> | null>(null);
const logFilePath = '/var/log/graphql-api.log';
const refreshLogs = () => {
logViewerRef.value?.refreshLogContent();
};
</script>
<template>
<div class="mt-6 border-2 border-solid rounded-md shadow-md bg-background border-muted">
<LogViewerToolbar
v-model:filter-text="filterText"
v-model:is-expanded="showLogs"
title="OIDC Debug Logs"
description="View real-time OIDC authentication and configuration logs"
:show-toggle="true"
:show-refresh="true"
filter-placeholder="Filter logs..."
@refresh="refreshLogs"
/>
<div v-if="showLogs" class="p-4 pt-0">
<div class="border rounded-lg bg-muted/30 h-[400px] overflow-hidden">
<SingleLogViewer
ref="logViewerRef"
:log-file-path="logFilePath"
:line-count="100"
:auto-scroll="autoScroll"
:client-filter="filterText"
highlight-language="plaintext"
class="h-full"
/>
</div>
<div class="mt-2 flex justify-between items-center text-xs text-muted-foreground">
<span>
{{ filterText ? `Filtering logs for: "${filterText}"` : 'Showing all log entries' }}
</span>
<label class="flex items-center gap-2 cursor-pointer">
<input
v-model="autoScroll"
type="checkbox"
class="rounded border-gray-300"
>
<span>Auto-scroll</span>
</label>
</div>
</div>
</div>
</template>

View File

@@ -0,0 +1,82 @@
<script setup lang="ts">
import { computed } from 'vue';
import { useContentHighlighting } from '~/composables/useContentHighlighting';
const props = defineProps<{
content: string;
language?: string;
showLineNumbers?: boolean;
maxHeight?: string;
class?: string;
}>();
const { highlightContent } = useContentHighlighting();
const highlightedContent = computed(() => {
return highlightContent(props.content, props.language);
});
const lines = computed(() => {
return props.content.split('\n');
});
</script>
<template>
<div
:class="[
'file-viewer-container',
'relative rounded border bg-background text-foreground overflow-hidden',
props.class
]"
:style="{ height: maxHeight || '300px' }"
>
<div class="absolute inset-0 overflow-auto">
<div class="flex min-w-full">
<!-- Line numbers -->
<div
v-if="showLineNumbers"
class="flex-shrink-0 select-none border-r bg-muted/50 px-2 py-2 text-xs font-mono text-muted-foreground"
>
<div v-for="(_, index) in lines" :key="index" class="leading-5 text-right pr-2">
{{ index + 1 }}
</div>
</div>
<!-- Content -->
<div class="flex-1 min-w-0">
<pre
class="p-3 text-xs font-mono leading-5 whitespace-pre m-0"
v-html="highlightedContent"
/>
</div>
</div>
</div>
</div>
</template>
<style scoped>
/* Add some basic styling for the highlighted content */
:deep(.hljs) {
background: transparent;
}
/* ANSI color classes */
:deep(.ansi-bright-black) { color: #666; }
:deep(.ansi-bright-red) { color: #ff6b6b; }
:deep(.ansi-bright-green) { color: #51cf66; }
:deep(.ansi-bright-yellow) { color: #ffd43b; }
:deep(.ansi-bright-blue) { color: #339af0; }
:deep(.ansi-bright-magenta) { color: #f06292; }
:deep(.ansi-bright-cyan) { color: #22d3ee; }
:deep(.ansi-bright-white) { color: #f8f9fa; }
/* Standard ANSI colors for dark theme */
:deep(.ansi-black) { color: #000; }
:deep(.ansi-red) { color: #e03131; }
:deep(.ansi-green) { color: #2f9e44; }
:deep(.ansi-yellow) { color: #f59f00; }
:deep(.ansi-blue) { color: #1971c2; }
:deep(.ansi-magenta) { color: #c2255c; }
:deep(.ansi-cyan) { color: #0891b2; }
:deep(.ansi-white) { color: #495057; }
</style>

View File

@@ -0,0 +1,67 @@
<script setup lang="ts">
import { ref, computed } from 'vue';
import { Dialog } from '@unraid/ui';
import SingleLogViewer from './SingleLogViewer.vue';
interface Props {
modelValue: boolean;
logFilePath: string;
filter?: string;
title?: string;
description?: string;
lineCount?: number;
autoScroll?: boolean;
highlightLanguage?: string;
size?: 'sm' | 'md' | 'lg' | 'xl' | 'full';
}
const props = withDefaults(defineProps<Props>(), {
lineCount: 100,
autoScroll: true,
highlightLanguage: 'plaintext',
size: 'xl',
title: 'Log Viewer',
filter: undefined,
description: undefined,
});
const emit = defineEmits<{
'update:modelValue': [value: boolean];
}>();
const logViewerRef = ref<InstanceType<typeof SingleLogViewer> | null>(null);
const fullLogPath = computed(() => {
if (props.logFilePath.startsWith('/')) {
return props.logFilePath;
}
return `/var/log/${props.logFilePath}`;
});
const handleOpenChange = (open: boolean) => {
emit('update:modelValue', open);
};
</script>
<template>
<Dialog
:model-value="modelValue"
:title="title"
:description="description"
:size="size"
:show-footer="false"
@update:model-value="handleOpenChange"
>
<div class="h-[600px] flex flex-col">
<SingleLogViewer
ref="logViewerRef"
:log-file-path="fullLogPath"
:line-count="lineCount"
:auto-scroll="autoScroll"
:highlight-language="highlightLanguage"
:filter="filter"
class="flex-1"
/>
</div>
</Dialog>
</template>

View File

@@ -0,0 +1,86 @@
<script setup lang="ts">
import { computed } from 'vue';
import { Input, Label, Select } from '@unraid/ui';
import type { SelectItemType } from '@unraid/ui';
import { MagnifyingGlassIcon } from '@heroicons/vue/24/outline';
const props = withDefaults(defineProps<{
modelValue: string;
preset?: string;
showPresets?: boolean;
presetFilters?: SelectItemType[];
inputClass?: string;
placeholder?: string;
label?: string;
showIcon?: boolean;
}>(), {
preset: 'none',
showPresets: false,
presetFilters: () => [
{ value: 'none', label: 'No Filter' },
{ value: 'OIDC', label: 'OIDC Logs' },
{ value: 'ERROR', label: 'Errors' },
{ value: 'WARNING', label: 'Warnings' },
{ value: 'AUTH', label: 'Authentication' },
],
placeholder: 'Filter logs...',
label: 'Filter',
showIcon: true,
inputClass: ''
});
const emit = defineEmits<{
'update:modelValue': [value: string];
'update:preset': [value: string];
}>();
const filterText = computed({
get: () => props.modelValue,
set: (value) => emit('update:modelValue', value)
});
const presetValue = computed({
get: () => props.preset || 'none',
set: (value) => {
emit('update:preset', value);
if (value && value !== 'none') {
emit('update:modelValue', value);
} else if (value === 'none') {
emit('update:modelValue', '');
}
}
});
</script>
<template>
<div class="flex gap-2 items-end">
<div v-if="showPresets" class="min-w-[150px]">
<Label v-if="label" :for="`preset-filter-${$.uid}`">Quick {{ label }}</Label>
<Select
:id="`preset-filter-${$.uid}`"
v-model="presetValue"
:items="presetFilters"
placeholder="Select filter"
class="w-full"
/>
</div>
<div class="relative flex-1">
<Label v-if="label && !showPresets" :for="`filter-input-${$.uid}`">{{ label }}</Label>
<Label v-else-if="label && showPresets" :for="`filter-input-${$.uid}`">Custom {{ label }}</Label>
<div class="relative">
<MagnifyingGlassIcon
v-if="showIcon"
class="absolute left-2 top-1/2 -translate-y-1/2 h-4 w-4 text-muted-foreground"
/>
<Input
:id="`filter-input-${$.uid}`"
v-model="filterText"
type="text"
:placeholder="placeholder"
:class="[showIcon ? 'pl-8' : '', inputClass]"
/>
</div>
</div>
</div>
</template>

View File

@@ -11,6 +11,7 @@ import {
import { GET_LOG_FILES } from './log.query'; import { GET_LOG_FILES } from './log.query';
import SingleLogViewer from './SingleLogViewer.vue'; import SingleLogViewer from './SingleLogViewer.vue';
import LogViewerToolbar from './LogViewerToolbar.vue';
// Types // Types
interface LogFile { interface LogFile {
@@ -20,10 +21,12 @@ interface LogFile {
} }
// Component state // Component state
const selectedLogFile = ref<string>(''); const selectedLogFile = ref<string | null>(null);
const lineCount = ref<number>(100); const lineCount = ref<number>(100);
const autoScroll = ref<boolean>(true); const autoScroll = ref<boolean>(true);
const highlightLanguage = ref<string>('plaintext'); const highlightLanguage = ref<string>('plaintext');
const filterText = ref<string>('');
const presetFilter = ref<string>('none');
// Available highlight languages // Available highlight languages
const highlightLanguages = [ const highlightLanguages = [
@@ -39,6 +42,15 @@ const highlightLanguages = [
{ value: 'php', label: 'PHP' }, { value: 'php', label: 'PHP' },
]; ];
// Preset filter options
const presetFilters = [
{ value: 'none', label: 'No Filter' },
{ value: 'OIDC', label: 'OIDC Logs' },
{ value: 'ERROR', label: 'Errors' },
{ value: 'WARNING', label: 'Warnings' },
{ value: 'AUTH', label: 'Authentication' },
];
// Fetch log files // Fetch log files
const { const {
result: logFilesResult, result: logFilesResult,
@@ -102,15 +114,32 @@ watch(selectedLogFile, (newValue) => {
highlightLanguage.value = autoDetectLanguage(newValue); highlightLanguage.value = autoDetectLanguage(newValue);
} }
}); });
// Watch for preset filter changes to update the filter text
watch(presetFilter, (newValue) => {
if (newValue && newValue !== 'none') {
filterText.value = newValue;
} else if (newValue === 'none') {
filterText.value = '';
}
});
</script> </script>
<template> <template>
<div <div
class="flex flex-col h-[500px] resize-y bg-background text-foreground rounded-lg border border-border overflow-hidden" class="flex flex-col h-[500px] resize-y bg-background text-foreground rounded-lg border border-border overflow-hidden"
> >
<LogViewerToolbar
v-model:filter-text="filterText"
v-model:preset-filter="presetFilter"
title="Log Viewer"
:show-presets="true"
:preset-filters="presetFilters"
:show-toggle="false"
:show-refresh="false"
/>
<div class="p-4 border-b border-border"> <div class="p-4 border-b border-border">
<h2 class="text-lg font-semibold mb-4">Log Viewer</h2>
<div class="flex flex-wrap gap-4 items-end"> <div class="flex flex-wrap gap-4 items-end">
<div class="flex-1 min-w-[200px]"> <div class="flex-1 min-w-[200px]">
<Label for="log-file-select">Log File</Label> <Label for="log-file-select">Log File</Label>
@@ -186,6 +215,7 @@ watch(selectedLogFile, (newValue) => {
:line-count="lineCount" :line-count="lineCount"
:auto-scroll="autoScroll" :auto-scroll="autoScroll"
:highlight-language="highlightLanguage" :highlight-language="highlightLanguage"
:client-filter="filterText"
class="h-full" class="h-full"
/> />
</div> </div>

View File

@@ -0,0 +1,122 @@
<script setup lang="ts">
import { computed } from 'vue';
import { Button } from '@unraid/ui';
import type { SelectItemType } from '@unraid/ui';
import { ArrowPathIcon, EyeIcon, EyeSlashIcon, ChevronDownIcon, ChevronUpIcon } from '@heroicons/vue/24/outline';
import LogFilterInput from './LogFilterInput.vue';
const props = withDefaults(defineProps<{
title?: string;
description?: string;
filterText: string;
showToggle?: boolean;
isExpanded?: boolean;
showRefresh?: boolean;
showPresets?: boolean;
presetFilter?: string;
presetFilters?: SelectItemType[];
filterPlaceholder?: string;
filterLabel?: string;
compact?: boolean;
}>(), {
title: '',
description: '',
showToggle: false,
isExpanded: true,
showRefresh: true,
showPresets: false,
presetFilter: 'none',
presetFilters: () => [],
filterPlaceholder: 'Filter logs...',
filterLabel: 'Filter',
compact: false
});
const emit = defineEmits<{
'update:filterText': [value: string];
'update:presetFilter': [value: string];
'update:isExpanded': [value: boolean];
'refresh': [];
}>();
const filterValue = computed({
get: () => props.filterText,
set: (value) => emit('update:filterText', value)
});
const presetValue = computed({
get: () => props.presetFilter || 'none',
set: (value) => emit('update:presetFilter', value)
});
const toggleExpanded = () => {
emit('update:isExpanded', !props.isExpanded);
};
const handleRefresh = () => {
emit('refresh');
};
</script>
<template>
<div :class="['border-b', compact ? 'p-3' : 'p-4 pb-3', 'border-muted']">
<div class="flex justify-between items-center">
<div v-if="title || description">
<h3 v-if="title" :class="[compact ? 'text-sm' : 'text-base', 'font-semibold']">{{ title }}</h3>
<p v-if="description" :class="['mt-1 text-muted-foreground', compact ? 'text-xs' : 'text-sm']">
{{ description }}
</p>
</div>
<div class="flex gap-2 items-center" :class="!title && !description ? 'w-full' : ''">
<div :class="!title && !description ? 'flex-1' : ''">
<LogFilterInput
v-model="filterValue"
v-model:preset="presetValue"
:show-presets="showPresets"
:preset-filters="presetFilters"
:placeholder="filterPlaceholder"
:label="compact || (!title && !description) ? filterLabel : ''"
:input-class="compact ? 'h-7 text-sm' : 'h-8'"
/>
</div>
<Button
v-if="showRefresh"
variant="outline"
:size="compact ? 'sm' : 'sm'"
title="Refresh logs"
@click="handleRefresh"
>
<ArrowPathIcon :class="compact ? 'h-3 w-3' : 'h-4 w-4'" />
</Button>
<Button
v-if="showToggle"
variant="outline"
:size="compact ? 'sm' : 'sm'"
@click="toggleExpanded"
>
<component
:is="isExpanded ? EyeSlashIcon : EyeIcon"
:class="compact ? 'h-3 w-3' : 'h-4 w-4'"
/>
<span v-if="!compact" class="ml-2">{{ isExpanded ? 'Hide' : 'Show' }} Logs</span>
</Button>
<Button
v-else-if="showToggle === false && typeof isExpanded === 'boolean'"
variant="ghost"
:size="compact ? 'sm' : 'sm'"
class="p-1"
@click="toggleExpanded"
>
<component
:is="isExpanded ? ChevronUpIcon : ChevronDownIcon"
:class="compact ? 'h-3 w-3' : 'h-4 w-4'"
/>
</Button>
</div>
</div>
</div>
</template>

View File

@@ -0,0 +1,28 @@
<script setup lang="ts">
import { ref } from 'vue';
import { Button } from '@unraid/ui';
import { BugAntIcon } from '@heroicons/vue/24/outline';
import FilteredLogModal from './FilteredLogModal.vue';
const showOidcLogs = ref(false);
</script>
<template>
<div>
<Button variant="outline" size="sm" @click="showOidcLogs = true">
<BugAntIcon class="w-4 h-4 mr-2" />
View OIDC Debug Logs
</Button>
<FilteredLogModal
v-model="showOidcLogs"
log-file-path="graphql-api.log"
filter="OIDC"
title="OIDC Debug Logs"
description="Viewing OIDC authentication and configuration logs"
:line-count="200"
:auto-scroll="true"
highlight-language="plaintext"
/>
</div>
</template>

View File

@@ -5,22 +5,7 @@ import { vInfiniteScroll } from '@vueuse/components';
import { ArrowDownTrayIcon, ArrowPathIcon } from '@heroicons/vue/24/outline'; import { ArrowDownTrayIcon, ArrowPathIcon } from '@heroicons/vue/24/outline';
import { Button, Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@unraid/ui'; import { Button, Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@unraid/ui';
import hljs from 'highlight.js/lib/core'; import { useContentHighlighting } from '~/composables/useContentHighlighting';
import DOMPurify from 'isomorphic-dompurify';
import 'highlight.js/styles/github-dark.css'; // You can choose a different style
import apache from 'highlight.js/lib/languages/apache';
import bash from 'highlight.js/lib/languages/bash';
import ini from 'highlight.js/lib/languages/ini';
import javascript from 'highlight.js/lib/languages/javascript';
import json from 'highlight.js/lib/languages/json';
import nginx from 'highlight.js/lib/languages/nginx';
import php from 'highlight.js/lib/languages/php';
// Register the languages you want to support
import plaintext from 'highlight.js/lib/languages/plaintext';
import xml from 'highlight.js/lib/languages/xml';
import yaml from 'highlight.js/lib/languages/yaml';
import type { LogFileContentQuery, LogFileContentQueryVariables } from '~/composables/gql/graphql'; import type { LogFileContentQuery, LogFileContentQueryVariables } from '~/composables/gql/graphql';
@@ -32,28 +17,17 @@ import { LOG_FILE_SUBSCRIPTION } from './log.subscription';
const themeStore = useThemeStore(); const themeStore = useThemeStore();
const isDarkMode = computed(() => themeStore.darkMode); const isDarkMode = computed(() => themeStore.darkMode);
// Register the languages // Use shared highlighting logic
hljs.registerLanguage('plaintext', plaintext); const { highlightContent } = useContentHighlighting();
hljs.registerLanguage('bash', bash);
hljs.registerLanguage('ini', ini);
hljs.registerLanguage('xml', xml);
hljs.registerLanguage('json', json);
hljs.registerLanguage('yaml', yaml);
hljs.registerLanguage('nginx', nginx);
hljs.registerLanguage('apache', apache);
hljs.registerLanguage('javascript', javascript);
hljs.registerLanguage('php', php);
const props = defineProps<{ const props = defineProps<{
logFilePath: string; logFilePath: string;
lineCount: number; lineCount: number;
autoScroll: boolean; autoScroll: boolean;
highlightLanguage?: string; // Optional prop to specify the language for highlighting highlightLanguage?: string; // Optional prop to specify the language for highlighting
clientFilter?: string; // Optional client-side filter to apply to log content
}>(); }>();
// Default language for highlighting
const defaultLanguage = 'plaintext';
const DEFAULT_CHUNK_SIZE = 100; const DEFAULT_CHUNK_SIZE = 100;
const scrollViewportRef = ref<HTMLElement | null>(null); const scrollViewportRef = ref<HTMLElement | null>(null);
const state = reactive({ const state = reactive({
@@ -111,44 +85,8 @@ onMounted(() => {
observer.observe(scrollViewportRef.value as unknown as Node, { childList: true, subtree: true }); observer.observe(scrollViewportRef.value as unknown as Node, { childList: true, subtree: true });
} }
if (props.logFilePath) { // Start the log subscription
subscribeToMore({ startLogSubscription();
document: LOG_FILE_SUBSCRIPTION,
variables: { path: props.logFilePath },
updateQuery: (prev, { subscriptionData }) => {
if (!subscriptionData.data || !prev) return prev;
// Set subscription as active when we receive data
state.isSubscriptionActive = true;
const existingContent = prev.logFile?.content || '';
const newContent = subscriptionData.data.logFile.content;
// Update the local state with the new content
if (newContent && state.loadedContentChunks.length > 0) {
const lastChunk = state.loadedContentChunks[state.loadedContentChunks.length - 1];
lastChunk.content += newContent;
// Force scroll to bottom if auto-scroll is enabled
if (props.autoScroll) {
nextTick(() => forceScrollToBottom());
}
}
return {
...prev,
logFile: {
...prev.logFile,
content: existingContent + newContent,
totalLines: (prev.logFile?.totalLines || 0) + (newContent.split('\n').length - 1),
},
};
},
});
// Set subscription as active
state.isSubscriptionActive = true;
}
}); });
// Cleanup observer on unmount // Cleanup observer on unmount
@@ -188,75 +126,33 @@ watch(
{ deep: true } { deep: true }
); );
// Function to highlight log content // Function to highlight log content using shared composable
const highlightLog = (content: string): string => { const highlightLog = (content: string): string => {
try { return highlightContent(content, props.highlightLanguage);
// Determine which language to use for highlighting
const language = props.highlightLanguage || defaultLanguage;
// Apply syntax highlighting
let highlighted = hljs.highlight(content, { language }).value;
// Apply additional custom highlighting for common log patterns
// Highlight timestamps (various formats)
highlighted = highlighted.replace(
/\b(\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:?\d{2})?)\b/g,
'<span class="hljs-timestamp">$1</span>'
);
// Highlight IP addresses
highlighted = highlighted.replace(
/\b(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\b/g,
'<span class="hljs-ip">$1</span>'
);
// Split the content into lines
let lines = highlighted.split('\n');
// Process each line to add error, warning, and success highlighting
lines = lines.map((line) => {
if (/(error|exception|fail|failed|failure)/i.test(line)) {
// Highlight error keywords
line = line.replace(
/\b(error|exception|fail|failed|failure)\b/gi,
'<span class="hljs-error-keyword">$1</span>'
);
// Wrap the entire line
return `<span class="hljs-error">${line}</span>`;
} else if (/(warning|warn)/i.test(line)) {
// Highlight warning keywords
line = line.replace(/\b(warning|warn)\b/gi, '<span class="hljs-warning-keyword">$1</span>');
// Wrap the entire line
return `<span class="hljs-warning">${line}</span>`;
} else if (/(success|successful|completed|done)/i.test(line)) {
// Highlight success keywords
line = line.replace(
/\b(success|successful|completed|done)\b/gi,
'<span class="hljs-success-keyword">$1</span>'
);
// Wrap the entire line
return `<span class="hljs-success">${line}</span>`;
}
return line;
});
// Join the lines back together
highlighted = lines.join('\n');
// Sanitize the highlighted HTML
return DOMPurify.sanitize(highlighted);
} catch (error) {
console.error('Error highlighting log content:', error);
// Fallback to sanitized but not highlighted content
return DOMPurify.sanitize(content);
}
}; };
// Apply client-side filtering
const filteredContent = computed(() => {
// Join chunks ensuring proper newline handling
const rawContent = state.loadedContentChunks
.map((chunk) => chunk.content)
.filter(content => content) // Remove empty chunks
.join(''); // Content should already have proper newlines
// Apply client-side filter if provided
if (props.clientFilter && props.clientFilter.trim()) {
const filterLower = props.clientFilter.toLowerCase();
const lines = rawContent.split('\n');
const filtered = lines.filter(line => line.toLowerCase().includes(filterLower));
return filtered.join('\n');
}
return rawContent;
});
// Computed properties // Computed properties
const logContent = computed(() => { const logContent = computed(() => {
const rawContent = state.loadedContentChunks.map((chunk) => chunk.content).join(''); return highlightLog(filteredContent.value);
return highlightLog(rawContent);
}); });
const totalLines = computed(() => logContentResult.value?.logFile?.totalLines || 0); const totalLines = computed(() => logContentResult.value?.logFile?.totalLines || 0);
@@ -339,15 +235,84 @@ const downloadLogFile = async () => {
} }
}; };
// Refresh logs // Clear all state to initial values
const refreshLogContent = () => { const clearState = () => {
state.loadedContentChunks = []; state.loadedContentChunks = [];
state.currentStartLine = undefined; state.currentStartLine = undefined;
state.isAtTop = false; state.isAtTop = false;
state.canLoadMore = false; state.canLoadMore = false;
state.initialLoadComplete = false; state.initialLoadComplete = false;
state.isLoadingMore = false; state.isLoadingMore = false;
refetchLogContent(); };
// Helper function to start log subscription
const startLogSubscription = () => {
if (!props.logFilePath) return;
try {
subscribeToMore({
document: LOG_FILE_SUBSCRIPTION,
variables: { path: props.logFilePath },
updateQuery: (prev, { subscriptionData }) => {
if (!subscriptionData.data || !prev) return prev;
// Set subscription as active when we receive data
state.isSubscriptionActive = true;
const existingContent = prev.logFile?.content || '';
const newContent = subscriptionData.data.logFile.content;
// Update the local state with the new content
if (newContent && state.loadedContentChunks.length > 0) {
const lastChunk = state.loadedContentChunks[state.loadedContentChunks.length - 1];
// Ensure there's a newline between the existing content and new content if needed
if (lastChunk.content && !lastChunk.content.endsWith('\n') && newContent) {
lastChunk.content += '\n' + newContent;
} else {
lastChunk.content += newContent;
}
// Force scroll to bottom if auto-scroll is enabled
if (props.autoScroll) {
nextTick(() => forceScrollToBottom());
}
}
return {
...prev,
logFile: {
...prev.logFile,
content: existingContent + newContent,
totalLines: (prev.logFile?.totalLines || 0) + (newContent.split('\n').length - 1),
},
};
},
});
// Set subscription as active
state.isSubscriptionActive = true;
} catch (error) {
console.error('Error starting log subscription:', error);
state.isSubscriptionActive = false;
}
};
// Refresh logs
const refreshLogContent = async () => {
// Clear the state
clearState();
// Refetch with explicit variables to ensure we get the latest logs
await refetchLogContent({
path: props.logFilePath,
lines: props.lineCount || DEFAULT_CHUNK_SIZE,
startLine: undefined, // Explicitly pass undefined to get the latest lines
});
// Restart the subscription with the same variables used for refetch
// Note: subscribeToMore in Vue Apollo doesn't return an unsubscribe function
// The previous subscription is automatically replaced when calling subscribeToMore again
startLogSubscription();
nextTick(() => { nextTick(() => {
forceScrollToBottom(); forceScrollToBottom();
@@ -436,7 +401,7 @@ defineExpose({ refreshLogContent });
</div> </div>
<pre <pre
class="font-mono whitespace-pre-wrap p-4 m-0 text-xs leading-6 hljs" class="font-mono whitespace-pre p-4 m-0 text-xs leading-6 hljs"
:class="{ 'theme-dark': isDarkMode, 'theme-light': !isDarkMode }" :class="{ 'theme-dark': isDarkMode, 'theme-light': !isDarkMode }"
v-html="logContent" v-html="logContent"
/> />
@@ -612,4 +577,51 @@ defineExpose({ refreshLogContent });
color: var(--log-success-color); color: var(--log-success-color);
font-weight: bold; font-weight: bold;
} }
/* ANSI color styles for ansi_up output - using format: ansi-{color}-fg/bg */
/* Foreground colors */
:deep(.ansi-black-fg) { color: #000; }
:deep(.ansi-red-fg) { color: #c91b00; }
:deep(.ansi-green-fg) { color: #00c200; }
:deep(.ansi-yellow-fg) { color: #c7c400; }
:deep(.ansi-blue-fg) { color: #0225c7; }
:deep(.ansi-magenta-fg) { color: #c930c7; }
:deep(.ansi-cyan-fg) { color: #00c5c7; }
:deep(.ansi-white-fg) { color: #c7c7c7; }
/* Bright foreground colors */
:deep(.ansi-bright-black-fg) { color: #676767; }
:deep(.ansi-bright-red-fg) { color: #ff6d67; }
:deep(.ansi-bright-green-fg) { color: #5ff967; }
:deep(.ansi-bright-yellow-fg) { color: #fefb67; }
:deep(.ansi-bright-blue-fg) { color: #6871ff; }
:deep(.ansi-bright-magenta-fg) { color: #ff76ff; }
:deep(.ansi-bright-cyan-fg) { color: #5ffdff; }
:deep(.ansi-bright-white-fg) { color: #fff; }
/* Background colors */
:deep(.ansi-black-bg) { background-color: #000; }
:deep(.ansi-red-bg) { background-color: #c91b00; }
:deep(.ansi-green-bg) { background-color: #00c200; }
:deep(.ansi-yellow-bg) { background-color: #c7c400; }
:deep(.ansi-blue-bg) { background-color: #0225c7; }
:deep(.ansi-magenta-bg) { background-color: #c930c7; }
:deep(.ansi-cyan-bg) { background-color: #00c5c7; }
:deep(.ansi-white-bg) { background-color: #c7c7c7; }
/* Bright background colors */
:deep(.ansi-bright-black-bg) { background-color: #676767; }
:deep(.ansi-bright-red-bg) { background-color: #ff6d67; }
:deep(.ansi-bright-green-bg) { background-color: #5ff967; }
:deep(.ansi-bright-yellow-bg) { background-color: #fefb67; }
:deep(.ansi-bright-blue-bg) { background-color: #6871ff; }
:deep(.ansi-bright-magenta-bg) { background-color: #ff76ff; }
:deep(.ansi-bright-cyan-bg) { background-color: #5ffdff; }
:deep(.ansi-bright-white-bg) { background-color: #fff; }
/* Additional ansi_up classes */
:deep(.ansi-bold) { font-weight: bold; }
:deep(.ansi-italic) { font-style: italic; }
:deep(.ansi-underline) { text-decoration: underline; }
:deep(.ansi-strike) { text-decoration: line-through; }
</style> </style>

View File

@@ -70,8 +70,11 @@ export function useSsoAuth() {
sessionStorage.setItem('sso_state', state); sessionStorage.setItem('sso_state', state);
sessionStorage.setItem('sso_provider', providerId); sessionStorage.setItem('sso_provider', providerId);
// Redirect to OIDC authorization endpoint with just the state token // Build the redirect URI based on current window location
const authUrl = `/graphql/api/auth/oidc/authorize/${encodeURIComponent(providerId)}?state=${encodeURIComponent(state)}`; const redirectUri = `${window.location.protocol}//${window.location.host}/graphql/api/auth/oidc/callback`;
// Redirect to OIDC authorization endpoint with state token and redirect URI
const authUrl = `/graphql/api/auth/oidc/authorize/${encodeURIComponent(providerId)}?state=${encodeURIComponent(state)}&redirect_uri=${encodeURIComponent(redirectUri)}`;
window.location.href = authUrl; window.location.href = authUrl;
}; };

View File

@@ -49,6 +49,8 @@ type Documents = {
"\n query InfoVersions {\n info {\n id\n os {\n id\n hostname\n }\n versions {\n id\n core {\n unraid\n api\n }\n }\n }\n }\n": typeof types.InfoVersionsDocument, "\n query InfoVersions {\n info {\n id\n os {\n id\n hostname\n }\n versions {\n id\n core {\n unraid\n api\n }\n }\n }\n }\n": typeof types.InfoVersionsDocument,
"\n query OidcProviders {\n settings {\n sso {\n oidcProviders {\n id\n name\n clientId\n issuer\n authorizationEndpoint\n tokenEndpoint\n jwksUri\n scopes\n authorizationRules {\n claim\n operator\n value\n }\n authorizationRuleMode\n buttonText\n buttonIcon\n }\n }\n }\n }\n": typeof types.OidcProvidersDocument, "\n query OidcProviders {\n settings {\n sso {\n oidcProviders {\n id\n name\n clientId\n issuer\n authorizationEndpoint\n tokenEndpoint\n jwksUri\n scopes\n authorizationRules {\n claim\n operator\n value\n }\n authorizationRuleMode\n buttonText\n buttonIcon\n }\n }\n }\n }\n": typeof types.OidcProvidersDocument,
"\n query PublicOidcProviders {\n publicOidcProviders {\n id\n name\n buttonText\n buttonIcon\n buttonVariant\n buttonStyle\n }\n }\n": typeof types.PublicOidcProvidersDocument, "\n query PublicOidcProviders {\n publicOidcProviders {\n id\n name\n buttonText\n buttonIcon\n buttonVariant\n buttonStyle\n }\n }\n": typeof types.PublicOidcProvidersDocument,
"\n query AllConfigFiles {\n allConfigFiles {\n files {\n name\n content\n path\n sizeReadable\n }\n }\n }\n": typeof types.AllConfigFilesDocument,
"\n query ConfigFile($name: String!) {\n configFile(name: $name) {\n name\n content\n path\n sizeReadable\n }\n }\n": typeof types.ConfigFileDocument,
"\n query serverInfo {\n info {\n os {\n hostname\n }\n }\n vars {\n comment\n }\n }\n": typeof types.ServerInfoDocument, "\n query serverInfo {\n info {\n os {\n hostname\n }\n }\n vars {\n comment\n }\n }\n": typeof types.ServerInfoDocument,
"\n mutation ConnectSignIn($input: ConnectSignInInput!) {\n connectSignIn(input: $input)\n }\n": typeof types.ConnectSignInDocument, "\n mutation ConnectSignIn($input: ConnectSignInInput!) {\n connectSignIn(input: $input)\n }\n": typeof types.ConnectSignInDocument,
"\n mutation SignOut {\n connectSignOut\n }\n": typeof types.SignOutDocument, "\n mutation SignOut {\n connectSignOut\n }\n": typeof types.SignOutDocument,
@@ -94,6 +96,8 @@ const documents: Documents = {
"\n query InfoVersions {\n info {\n id\n os {\n id\n hostname\n }\n versions {\n id\n core {\n unraid\n api\n }\n }\n }\n }\n": types.InfoVersionsDocument, "\n query InfoVersions {\n info {\n id\n os {\n id\n hostname\n }\n versions {\n id\n core {\n unraid\n api\n }\n }\n }\n }\n": types.InfoVersionsDocument,
"\n query OidcProviders {\n settings {\n sso {\n oidcProviders {\n id\n name\n clientId\n issuer\n authorizationEndpoint\n tokenEndpoint\n jwksUri\n scopes\n authorizationRules {\n claim\n operator\n value\n }\n authorizationRuleMode\n buttonText\n buttonIcon\n }\n }\n }\n }\n": types.OidcProvidersDocument, "\n query OidcProviders {\n settings {\n sso {\n oidcProviders {\n id\n name\n clientId\n issuer\n authorizationEndpoint\n tokenEndpoint\n jwksUri\n scopes\n authorizationRules {\n claim\n operator\n value\n }\n authorizationRuleMode\n buttonText\n buttonIcon\n }\n }\n }\n }\n": types.OidcProvidersDocument,
"\n query PublicOidcProviders {\n publicOidcProviders {\n id\n name\n buttonText\n buttonIcon\n buttonVariant\n buttonStyle\n }\n }\n": types.PublicOidcProvidersDocument, "\n query PublicOidcProviders {\n publicOidcProviders {\n id\n name\n buttonText\n buttonIcon\n buttonVariant\n buttonStyle\n }\n }\n": types.PublicOidcProvidersDocument,
"\n query AllConfigFiles {\n allConfigFiles {\n files {\n name\n content\n path\n sizeReadable\n }\n }\n }\n": types.AllConfigFilesDocument,
"\n query ConfigFile($name: String!) {\n configFile(name: $name) {\n name\n content\n path\n sizeReadable\n }\n }\n": types.ConfigFileDocument,
"\n query serverInfo {\n info {\n os {\n hostname\n }\n }\n vars {\n comment\n }\n }\n": types.ServerInfoDocument, "\n query serverInfo {\n info {\n os {\n hostname\n }\n }\n vars {\n comment\n }\n }\n": types.ServerInfoDocument,
"\n mutation ConnectSignIn($input: ConnectSignInInput!) {\n connectSignIn(input: $input)\n }\n": types.ConnectSignInDocument, "\n mutation ConnectSignIn($input: ConnectSignInInput!) {\n connectSignIn(input: $input)\n }\n": types.ConnectSignInDocument,
"\n mutation SignOut {\n connectSignOut\n }\n": types.SignOutDocument, "\n mutation SignOut {\n connectSignOut\n }\n": types.SignOutDocument,
@@ -258,6 +262,14 @@ export function graphql(source: "\n query OidcProviders {\n settings {\n
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients. * The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
*/ */
export function graphql(source: "\n query PublicOidcProviders {\n publicOidcProviders {\n id\n name\n buttonText\n buttonIcon\n buttonVariant\n buttonStyle\n }\n }\n"): (typeof documents)["\n query PublicOidcProviders {\n publicOidcProviders {\n id\n name\n buttonText\n buttonIcon\n buttonVariant\n buttonStyle\n }\n }\n"]; export function graphql(source: "\n query PublicOidcProviders {\n publicOidcProviders {\n id\n name\n buttonText\n buttonIcon\n buttonVariant\n buttonStyle\n }\n }\n"): (typeof documents)["\n query PublicOidcProviders {\n publicOidcProviders {\n id\n name\n buttonText\n buttonIcon\n buttonVariant\n buttonStyle\n }\n }\n"];
/**
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
*/
export function graphql(source: "\n query AllConfigFiles {\n allConfigFiles {\n files {\n name\n content\n path\n sizeReadable\n }\n }\n }\n"): (typeof documents)["\n query AllConfigFiles {\n allConfigFiles {\n files {\n name\n content\n path\n sizeReadable\n }\n }\n }\n"];
/**
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
*/
export function graphql(source: "\n query ConfigFile($name: String!) {\n configFile(name: $name) {\n name\n content\n path\n sizeReadable\n }\n }\n"): (typeof documents)["\n query ConfigFile($name: String!) {\n configFile(name: $name) {\n name\n content\n path\n sizeReadable\n }\n }\n"];
/** /**
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients. * The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
*/ */

View File

@@ -448,6 +448,20 @@ export enum ConfigErrorState {
WITHDRAWN = 'WITHDRAWN' WITHDRAWN = 'WITHDRAWN'
} }
export type ConfigFile = {
__typename?: 'ConfigFile';
content: Scalars['String']['output'];
name: Scalars['String']['output'];
path: Scalars['String']['output'];
/** Human-readable file size (e.g., "1.5 KB", "2.3 MB") */
sizeReadable: Scalars['String']['output'];
};
export type ConfigFilesResponse = {
__typename?: 'ConfigFilesResponse';
files: Array<ConfigFile>;
};
export type Connect = Node & { export type Connect = Node & {
__typename?: 'Connect'; __typename?: 'Connect';
/** The status of dynamic remote access */ /** The status of dynamic remote access */
@@ -1432,6 +1446,14 @@ export type OidcAuthorizationRule = {
value: Array<Scalars['String']['output']>; value: Array<Scalars['String']['output']>;
}; };
export type OidcConfiguration = {
__typename?: 'OidcConfiguration';
/** Default allowed redirect origins that apply to all OIDC providers (e.g., Tailscale domains) */
defaultAllowedOrigins?: Maybe<Array<Scalars['String']['output']>>;
/** List of configured OIDC providers */
providers: Array<OidcProvider>;
};
export type OidcProvider = { export type OidcProvider = {
__typename?: 'OidcProvider'; __typename?: 'OidcProvider';
/** OAuth2 authorization endpoint URL. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration */ /** OAuth2 authorization endpoint URL. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration */
@@ -1455,7 +1477,7 @@ export type OidcProvider = {
/** The unique identifier for the OIDC provider */ /** The unique identifier for the OIDC provider */
id: Scalars['PrefixedID']['output']; id: Scalars['PrefixedID']['output'];
/** OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration */ /** OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration */
issuer: Scalars['String']['output']; issuer?: Maybe<Scalars['String']['output']>;
/** JSON Web Key Set URI for token validation. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration */ /** JSON Web Key Set URI for token validation. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration */
jwksUri?: Maybe<Scalars['String']['output']>; jwksUri?: Maybe<Scalars['String']['output']>;
/** Display name of the OIDC provider */ /** Display name of the OIDC provider */
@@ -1623,6 +1645,7 @@ export type PublicPartnerInfo = {
export type Query = { export type Query = {
__typename?: 'Query'; __typename?: 'Query';
allConfigFiles: ConfigFilesResponse;
apiKey?: Maybe<ApiKey>; apiKey?: Maybe<ApiKey>;
/** All possible permissions for API keys */ /** All possible permissions for API keys */
apiKeyPossiblePermissions: Array<Permission>; apiKeyPossiblePermissions: Array<Permission>;
@@ -1632,6 +1655,7 @@ export type Query = {
array: UnraidArray; array: UnraidArray;
cloud: Cloud; cloud: Cloud;
config: Config; config: Config;
configFile?: Maybe<ConfigFile>;
connect: Connect; connect: Connect;
customization?: Maybe<Customization>; customization?: Maybe<Customization>;
disk: Disk; disk: Disk;
@@ -1654,6 +1678,8 @@ export type Query = {
network: Network; network: Network;
/** Get all notifications */ /** Get all notifications */
notifications: Notifications; notifications: Notifications;
/** Get the full OIDC configuration (admin only) */
oidcConfiguration: OidcConfiguration;
/** Get a specific OIDC provider by ID */ /** Get a specific OIDC provider by ID */
oidcProvider?: Maybe<OidcProvider>; oidcProvider?: Maybe<OidcProvider>;
/** Get all configured OIDC providers (admin only) */ /** Get all configured OIDC providers (admin only) */
@@ -1693,6 +1719,11 @@ export type QueryApiKeyArgs = {
}; };
export type QueryConfigFileArgs = {
name: Scalars['String']['input'];
};
export type QueryDiskArgs = { export type QueryDiskArgs = {
id: Scalars['PrefixedID']['input']; id: Scalars['PrefixedID']['input'];
}; };
@@ -1933,6 +1964,7 @@ export type Server = Node & {
name: Scalars['String']['output']; name: Scalars['String']['output'];
owner: ProfileModel; owner: ProfileModel;
remoteurl: Scalars['String']['output']; remoteurl: Scalars['String']['output'];
/** Whether this server is online or offline */
status: ServerStatus; status: ServerStatus;
wanip: Scalars['String']['output']; wanip: Scalars['String']['output'];
}; };
@@ -2757,13 +2789,25 @@ export type InfoVersionsQuery = { __typename?: 'Query', info: { __typename?: 'In
export type OidcProvidersQueryVariables = Exact<{ [key: string]: never; }>; export type OidcProvidersQueryVariables = Exact<{ [key: string]: never; }>;
export type OidcProvidersQuery = { __typename?: 'Query', settings: { __typename?: 'Settings', sso: { __typename?: 'SsoSettings', oidcProviders: Array<{ __typename?: 'OidcProvider', id: string, name: string, clientId: string, issuer: string, authorizationEndpoint?: string | null, tokenEndpoint?: string | null, jwksUri?: string | null, scopes: Array<string>, authorizationRuleMode?: AuthorizationRuleMode | null, buttonText?: string | null, buttonIcon?: string | null, authorizationRules?: Array<{ __typename?: 'OidcAuthorizationRule', claim: string, operator: AuthorizationOperator, value: Array<string> }> | null }> } } }; export type OidcProvidersQuery = { __typename?: 'Query', settings: { __typename?: 'Settings', sso: { __typename?: 'SsoSettings', oidcProviders: Array<{ __typename?: 'OidcProvider', id: string, name: string, clientId: string, issuer?: string | null, authorizationEndpoint?: string | null, tokenEndpoint?: string | null, jwksUri?: string | null, scopes: Array<string>, authorizationRuleMode?: AuthorizationRuleMode | null, buttonText?: string | null, buttonIcon?: string | null, authorizationRules?: Array<{ __typename?: 'OidcAuthorizationRule', claim: string, operator: AuthorizationOperator, value: Array<string> }> | null }> } } };
export type PublicOidcProvidersQueryVariables = Exact<{ [key: string]: never; }>; export type PublicOidcProvidersQueryVariables = Exact<{ [key: string]: never; }>;
export type PublicOidcProvidersQuery = { __typename?: 'Query', publicOidcProviders: Array<{ __typename?: 'PublicOidcProvider', id: string, name: string, buttonText?: string | null, buttonIcon?: string | null, buttonVariant?: string | null, buttonStyle?: string | null }> }; export type PublicOidcProvidersQuery = { __typename?: 'Query', publicOidcProviders: Array<{ __typename?: 'PublicOidcProvider', id: string, name: string, buttonText?: string | null, buttonIcon?: string | null, buttonVariant?: string | null, buttonStyle?: string | null }> };
export type AllConfigFilesQueryVariables = Exact<{ [key: string]: never; }>;
export type AllConfigFilesQuery = { __typename?: 'Query', allConfigFiles: { __typename?: 'ConfigFilesResponse', files: Array<{ __typename?: 'ConfigFile', name: string, content: string, path: string, sizeReadable: string }> } };
export type ConfigFileQueryVariables = Exact<{
name: Scalars['String']['input'];
}>;
export type ConfigFileQuery = { __typename?: 'Query', configFile?: { __typename?: 'ConfigFile', name: string, content: string, path: string, sizeReadable: string } | null };
export type ServerInfoQueryVariables = Exact<{ [key: string]: never; }>; export type ServerInfoQueryVariables = Exact<{ [key: string]: never; }>;
@@ -2842,6 +2886,8 @@ export const ListRCloneRemotesDocument = {"kind":"Document","definitions":[{"kin
export const InfoVersionsDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"query","name":{"kind":"Name","value":"InfoVersions"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"info"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"id"}},{"kind":"Field","name":{"kind":"Name","value":"os"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"id"}},{"kind":"Field","name":{"kind":"Name","value":"hostname"}}]}},{"kind":"Field","name":{"kind":"Name","value":"versions"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"id"}},{"kind":"Field","name":{"kind":"Name","value":"core"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"unraid"}},{"kind":"Field","name":{"kind":"Name","value":"api"}}]}}]}}]}}]}}]} as unknown as DocumentNode<InfoVersionsQuery, InfoVersionsQueryVariables>; export const InfoVersionsDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"query","name":{"kind":"Name","value":"InfoVersions"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"info"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"id"}},{"kind":"Field","name":{"kind":"Name","value":"os"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"id"}},{"kind":"Field","name":{"kind":"Name","value":"hostname"}}]}},{"kind":"Field","name":{"kind":"Name","value":"versions"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"id"}},{"kind":"Field","name":{"kind":"Name","value":"core"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"unraid"}},{"kind":"Field","name":{"kind":"Name","value":"api"}}]}}]}}]}}]}}]} as unknown as DocumentNode<InfoVersionsQuery, InfoVersionsQueryVariables>;
export const OidcProvidersDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"query","name":{"kind":"Name","value":"OidcProviders"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"settings"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"sso"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"oidcProviders"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"id"}},{"kind":"Field","name":{"kind":"Name","value":"name"}},{"kind":"Field","name":{"kind":"Name","value":"clientId"}},{"kind":"Field","name":{"kind":"Name","value":"issuer"}},{"kind":"Field","name":{"kind":"Name","value":"authorizationEndpoint"}},{"kind":"Field","name":{"kind":"Name","value":"tokenEndpoint"}},{"kind":"Field","name":{"kind":"Name","value":"jwksUri"}},{"kind":"Field","name":{"kind":"Name","value":"scopes"}},{"kind":"Field","name":{"kind":"Name","value":"authorizationRules"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"claim"}},{"kind":"Field","name":{"kind":"Name","value":"operator"}},{"kind":"Field","name":{"kind":"Name","value":"value"}}]}},{"kind":"Field","name":{"kind":"Name","value":"authorizationRuleMode"}},{"kind":"Field","name":{"kind":"Name","value":"buttonText"}},{"kind":"Field","name":{"kind":"Name","value":"buttonIcon"}}]}}]}}]}}]}}]} as unknown as DocumentNode<OidcProvidersQuery, OidcProvidersQueryVariables>; export const OidcProvidersDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"query","name":{"kind":"Name","value":"OidcProviders"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"settings"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"sso"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"oidcProviders"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"id"}},{"kind":"Field","name":{"kind":"Name","value":"name"}},{"kind":"Field","name":{"kind":"Name","value":"clientId"}},{"kind":"Field","name":{"kind":"Name","value":"issuer"}},{"kind":"Field","name":{"kind":"Name","value":"authorizationEndpoint"}},{"kind":"Field","name":{"kind":"Name","value":"tokenEndpoint"}},{"kind":"Field","name":{"kind":"Name","value":"jwksUri"}},{"kind":"Field","name":{"kind":"Name","value":"scopes"}},{"kind":"Field","name":{"kind":"Name","value":"authorizationRules"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"claim"}},{"kind":"Field","name":{"kind":"Name","value":"operator"}},{"kind":"Field","name":{"kind":"Name","value":"value"}}]}},{"kind":"Field","name":{"kind":"Name","value":"authorizationRuleMode"}},{"kind":"Field","name":{"kind":"Name","value":"buttonText"}},{"kind":"Field","name":{"kind":"Name","value":"buttonIcon"}}]}}]}}]}}]}}]} as unknown as DocumentNode<OidcProvidersQuery, OidcProvidersQueryVariables>;
export const PublicOidcProvidersDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"query","name":{"kind":"Name","value":"PublicOidcProviders"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"publicOidcProviders"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"id"}},{"kind":"Field","name":{"kind":"Name","value":"name"}},{"kind":"Field","name":{"kind":"Name","value":"buttonText"}},{"kind":"Field","name":{"kind":"Name","value":"buttonIcon"}},{"kind":"Field","name":{"kind":"Name","value":"buttonVariant"}},{"kind":"Field","name":{"kind":"Name","value":"buttonStyle"}}]}}]}}]} as unknown as DocumentNode<PublicOidcProvidersQuery, PublicOidcProvidersQueryVariables>; export const PublicOidcProvidersDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"query","name":{"kind":"Name","value":"PublicOidcProviders"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"publicOidcProviders"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"id"}},{"kind":"Field","name":{"kind":"Name","value":"name"}},{"kind":"Field","name":{"kind":"Name","value":"buttonText"}},{"kind":"Field","name":{"kind":"Name","value":"buttonIcon"}},{"kind":"Field","name":{"kind":"Name","value":"buttonVariant"}},{"kind":"Field","name":{"kind":"Name","value":"buttonStyle"}}]}}]}}]} as unknown as DocumentNode<PublicOidcProvidersQuery, PublicOidcProvidersQueryVariables>;
export const AllConfigFilesDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"query","name":{"kind":"Name","value":"AllConfigFiles"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"allConfigFiles"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"files"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"name"}},{"kind":"Field","name":{"kind":"Name","value":"content"}},{"kind":"Field","name":{"kind":"Name","value":"path"}},{"kind":"Field","name":{"kind":"Name","value":"sizeReadable"}}]}}]}}]}}]} as unknown as DocumentNode<AllConfigFilesQuery, AllConfigFilesQueryVariables>;
export const ConfigFileDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"query","name":{"kind":"Name","value":"ConfigFile"},"variableDefinitions":[{"kind":"VariableDefinition","variable":{"kind":"Variable","name":{"kind":"Name","value":"name"}},"type":{"kind":"NonNullType","type":{"kind":"NamedType","name":{"kind":"Name","value":"String"}}}}],"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"configFile"},"arguments":[{"kind":"Argument","name":{"kind":"Name","value":"name"},"value":{"kind":"Variable","name":{"kind":"Name","value":"name"}}}],"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"name"}},{"kind":"Field","name":{"kind":"Name","value":"content"}},{"kind":"Field","name":{"kind":"Name","value":"path"}},{"kind":"Field","name":{"kind":"Name","value":"sizeReadable"}}]}}]}}]} as unknown as DocumentNode<ConfigFileQuery, ConfigFileQueryVariables>;
export const ServerInfoDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"query","name":{"kind":"Name","value":"serverInfo"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"info"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"os"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"hostname"}}]}}]}},{"kind":"Field","name":{"kind":"Name","value":"vars"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"comment"}}]}}]}}]} as unknown as DocumentNode<ServerInfoQuery, ServerInfoQueryVariables>; export const ServerInfoDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"query","name":{"kind":"Name","value":"serverInfo"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"info"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"os"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"hostname"}}]}}]}},{"kind":"Field","name":{"kind":"Name","value":"vars"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"comment"}}]}}]}}]} as unknown as DocumentNode<ServerInfoQuery, ServerInfoQueryVariables>;
export const ConnectSignInDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"mutation","name":{"kind":"Name","value":"ConnectSignIn"},"variableDefinitions":[{"kind":"VariableDefinition","variable":{"kind":"Variable","name":{"kind":"Name","value":"input"}},"type":{"kind":"NonNullType","type":{"kind":"NamedType","name":{"kind":"Name","value":"ConnectSignInInput"}}}}],"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"connectSignIn"},"arguments":[{"kind":"Argument","name":{"kind":"Name","value":"input"},"value":{"kind":"Variable","name":{"kind":"Name","value":"input"}}}]}]}}]} as unknown as DocumentNode<ConnectSignInMutation, ConnectSignInMutationVariables>; export const ConnectSignInDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"mutation","name":{"kind":"Name","value":"ConnectSignIn"},"variableDefinitions":[{"kind":"VariableDefinition","variable":{"kind":"Variable","name":{"kind":"Name","value":"input"}},"type":{"kind":"NonNullType","type":{"kind":"NamedType","name":{"kind":"Name","value":"ConnectSignInInput"}}}}],"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"connectSignIn"},"arguments":[{"kind":"Argument","name":{"kind":"Name","value":"input"},"value":{"kind":"Variable","name":{"kind":"Name","value":"input"}}}]}]}}]} as unknown as DocumentNode<ConnectSignInMutation, ConnectSignInMutationVariables>;
export const SignOutDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"mutation","name":{"kind":"Name","value":"SignOut"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"connectSignOut"}}]}}]} as unknown as DocumentNode<SignOutMutation, SignOutMutationVariables>; export const SignOutDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"mutation","name":{"kind":"Name","value":"SignOut"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"connectSignOut"}}]}}]} as unknown as DocumentNode<SignOutMutation, SignOutMutationVariables>;

View File

@@ -1,211 +0,0 @@
import { computed, ref } from 'vue';
import { AuthAction, Resource, Role } from '~/composables/gql/graphql';
export interface ScopeConversion {
permissions: Array<{ resource: Resource; actions: AuthAction[] }>;
roles: Role[];
}
/**
* Convert scope strings to permissions and roles
* Scopes can be in format:
* - "role:admin" for roles
* - "docker:read" for resource permissions
* - "docker:*" for all actions on a resource
*/
function convertScopesToPermissions(scopes: string[]): ScopeConversion {
const permissions: Array<{ resource: Resource; actions: AuthAction[] }> = [];
const roles: Role[] = [];
for (const scope of scopes) {
if (scope.startsWith('role:')) {
// Handle role scope
const roleStr = scope.substring(5).toUpperCase();
if (Object.values(Role).includes(roleStr as Role)) {
roles.push(roleStr as Role);
} else {
console.warn(`Unknown role in scope: ${scope}`);
}
} else {
// Handle permission scope
const [resourceStr, actionStr] = scope.split(':');
if (resourceStr && actionStr) {
const resourceUpper = resourceStr.toUpperCase();
const resource = Object.values(Resource).find(r => r === resourceUpper) as Resource;
if (!resource) {
console.warn(`Unknown resource in scope: ${scope}`);
continue;
}
// Handle wildcard or specific action
let actions: AuthAction[];
if (actionStr === '*') {
// Wildcard means all CRUD actions
actions = [
AuthAction.CREATE_ANY,
AuthAction.READ_ANY,
AuthAction.UPDATE_ANY,
AuthAction.DELETE_ANY
];
} else {
// Convert action string to AuthAction enum
// Scopes come in as 'read', 'create', etc. - convert to 'READ_ANY', 'CREATE_ANY'
const enumValue = `${actionStr.toUpperCase()}_ANY` as AuthAction;
if (Object.values(AuthAction).includes(enumValue)) {
actions = [enumValue];
} else {
console.warn(`Unknown action in scope: ${scope}`);
continue;
}
}
// Merge with existing permissions for the same resource
const existing = permissions.find(p => p.resource === resource);
if (existing) {
actions.forEach(a => {
if (!existing.actions.includes(a)) {
existing.actions.push(a);
}
});
} else {
permissions.push({ resource, actions });
}
}
}
}
return { permissions, roles };
}
export interface ApiKeyAuthorizationParams {
name: string;
description: string;
scopes: string[];
redirectUri: string;
state: string;
}
export interface FormattedPermission {
scope: string;
name: string;
description: string;
isRole: boolean;
}
/**
* Composable for handling API key authorization flow
*/
export function useApiKeyAuthorization(urlSearchParams?: URLSearchParams) {
// Parse query parameters with SSR safety
const params = urlSearchParams || (
typeof window !== 'undefined'
? new URLSearchParams(window.location.search)
: new URLSearchParams()
);
const authParams = ref<ApiKeyAuthorizationParams>({
name: params.get('name') || 'Unknown Application',
description: params.get('description') || '',
scopes: (params.get('scopes') || '').split(',').filter(Boolean),
redirectUri: params.get('redirect_uri') || '',
state: params.get('state') || '',
});
// Validate redirect URI - allow any valid URL including app URLs and custom schemes
const isValidRedirectUri = (uri: string): boolean => {
if (!uri) return false;
try {
// Just check if it's a valid URL format, don't restrict protocols or hosts
new URL(uri);
return true;
} catch {
return false;
}
};
// Format scopes for display
const formatPermissions = (scopes: string[]): FormattedPermission[] => {
return scopes.map(scope => {
if (scope.startsWith('role:')) {
const role = scope.substring(5);
return {
scope,
name: role.toUpperCase(),
description: `Grant ${role} role access`,
isRole: true,
};
} else {
const [resource, action] = scope.split(':');
if (resource && action) {
const resourceName = resource.charAt(0).toUpperCase() + resource.slice(1);
const actionDesc = action === '*'
? 'Full'
: action.charAt(0).toUpperCase() + action.slice(1);
return {
scope,
name: `${resourceName} - ${actionDesc}`,
description: `${actionDesc} access to ${resourceName}`,
isRole: false,
};
}
}
return {
scope,
name: scope,
description: scope,
isRole: false
};
});
};
// Use the shared convertScopesToFrontendFormPermissions function from @unraid/shared
// This ensures consistent scope parsing across frontend and backend
// Build redirect URL with API key or error
const buildCallbackUrl = (
redirectUri: string,
apiKey?: string,
error?: string,
state?: string
): string => {
try {
const url = new URL(redirectUri);
if (apiKey) {
url.searchParams.set('api_key', apiKey);
}
if (error) {
url.searchParams.set('error', error);
}
if (state) {
url.searchParams.set('state', state);
}
return url.toString();
} catch {
throw new Error('Invalid redirect URI');
}
};
// Computed properties
const formattedPermissions = computed(() =>
formatPermissions(authParams.value.scopes)
);
const hasValidRedirectUri = computed(() =>
isValidRedirectUri(authParams.value.redirectUri)
);
const defaultKeyName = computed(() => authParams.value.name);
return {
authParams,
formattedPermissions,
hasValidRedirectUri,
defaultKeyName,
isValidRedirectUri,
formatPermissions,
convertScopesToPermissions,
buildCallbackUrl,
};
}

View File

@@ -4,7 +4,7 @@ import {
scopesToFormData, scopesToFormData,
buildCallbackUrl as buildUrl, buildCallbackUrl as buildUrl,
generateAuthorizationUrl as generateUrl generateAuthorizationUrl as generateUrl
} from '~/utils/authorizationScopes.js'; } from '~/utils/authorizationScopes';
export interface ApiKeyAuthorizationParams { export interface ApiKeyAuthorizationParams {
name: string; name: string;
@@ -14,22 +14,6 @@ export interface ApiKeyAuthorizationParams {
state: string; state: string;
} }
// Re-export types from utils for convenience
export type {
AuthorizationFormData,
AuthorizationLinkParams,
RawPermission
} from '~/utils/authorizationScopes.js';
// Re-export functions for direct use
export {
encodePermissionsToScopes,
decodeScopesToPermissions,
extractActionVerb,
generateAuthorizationUrl,
buildCallbackUrl
} from '~/utils/authorizationScopes.js';
/** /**
* Composable for authorization link handling with reactive state * Composable for authorization link handling with reactive state
*/ */

View File

@@ -0,0 +1,83 @@
import hljs from 'highlight.js/lib/core';
import DOMPurify from 'isomorphic-dompurify';
import { AnsiUp } from 'ansi_up';
import 'highlight.js/styles/github-dark.css';
import apache from 'highlight.js/lib/languages/apache';
import bash from 'highlight.js/lib/languages/bash';
import ini from 'highlight.js/lib/languages/ini';
import javascript from 'highlight.js/lib/languages/javascript';
import json from 'highlight.js/lib/languages/json';
import nginx from 'highlight.js/lib/languages/nginx';
import php from 'highlight.js/lib/languages/php';
import plaintext from 'highlight.js/lib/languages/plaintext';
import xml from 'highlight.js/lib/languages/xml';
import yaml from 'highlight.js/lib/languages/yaml';
// Register the languages (only once)
let languagesRegistered = false;
const registerLanguages = () => {
if (!languagesRegistered) {
hljs.registerLanguage('plaintext', plaintext);
hljs.registerLanguage('bash', bash);
hljs.registerLanguage('ini', ini);
hljs.registerLanguage('xml', xml);
hljs.registerLanguage('json', json);
hljs.registerLanguage('yaml', yaml);
hljs.registerLanguage('nginx', nginx);
hljs.registerLanguage('apache', apache);
hljs.registerLanguage('javascript', javascript);
hljs.registerLanguage('php', php);
languagesRegistered = true;
}
};
export const useContentHighlighting = () => {
// Initialize ANSI to HTML converter with CSS classes
const ansiConverter = new AnsiUp();
ansiConverter.use_classes = true;
ansiConverter.escape_html = true;
// Register languages on first use
registerLanguages();
// Function to highlight content
const highlightContent = (content: string, language?: string): string => {
try {
let highlighted: string;
// Check if content contains ANSI escape sequences
// eslint-disable-next-line no-control-regex
const hasAnsiSequences = /\x1b\[/.test(content);
if (hasAnsiSequences) {
// Use ANSI converter for content with ANSI codes
highlighted = ansiConverter.ansi_to_html(content);
} else if (language) {
// Use highlight.js for specific language if provided
const result = hljs.highlight(content, { language, ignoreIllegals: true });
highlighted = result.value;
} else {
// Use highlight.js auto-detection for non-ANSI content
const result = hljs.highlightAuto(content);
highlighted = result.value;
}
// Sanitize the highlighted HTML while preserving class attributes for syntax highlighting
return DOMPurify.sanitize(highlighted, {
ALLOWED_TAGS: ['span', 'br', 'code', 'pre'],
ALLOWED_ATTR: ['class'] // Allow class attribute for hljs and ANSI color classes
});
} catch (error) {
console.error('Error highlighting content:', error);
// Fallback to sanitized but not highlighted content
return DOMPurify.sanitize(content);
}
};
return {
highlightContent
};
};

View File

@@ -107,6 +107,7 @@
"@vueuse/components": "13.8.0", "@vueuse/components": "13.8.0",
"@vueuse/integrations": "13.8.0", "@vueuse/integrations": "13.8.0",
"ajv": "8.17.1", "ajv": "8.17.1",
"ansi_up": "^6.0.6",
"class-variance-authority": "0.7.1", "class-variance-authority": "0.7.1",
"clsx": "2.1.1", "clsx": "2.1.1",
"crypto-js": "4.2.0", "crypto-js": "4.2.0",