feat(api): enhance OIDC redirect URI handling in service and tests (#1618)

- Updated `getRedirectUri` method in `OidcAuthService` to handle various
edge cases for redirect URIs, including full URIs, malformed URLs, and
default ports.
- Added comprehensive tests for `OidcAuthService` to validate redirect
URI construction and error handling.
- Modified `RestController` to utilize `redirect_uri` query parameter
for authorization requests.
- Updated frontend components to include `redirect_uri` in authorization
URLs, ensuring correct handling of different protocols and ports.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Stronger OIDC redirect_uri validation and an admin GraphQL endpoint to
view full OIDC configuration.
* OIDC Debug Logs UI (panel, button, modal), enhanced log viewer with
presets/filters, ANSI-colored rendering, and a File Viewer component.
* New GraphQL queries to list and fetch config files; API Config
Download page.

* **Refactor**
* Centralized, modular OIDC flows and safer redirect handling;
topic-based log subscriptions with a watcher manager for scalable live
logs.

* **Documentation**
  * Cache TTL guidance clarified to use milliseconds.

* **Chores**
* Added ansi_up and escape-html deps; improved log formatting; added
root codegen script.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
This commit is contained in:
Eli Bosley
2025-09-02 10:40:20 -04:00
committed by GitHub
parent 6356f9c41d
commit 4e945f5f56
100 changed files with 9543 additions and 3396 deletions

View File

@@ -17,5 +17,6 @@
],
"buttonText": "Login With Unraid.net"
}
]
],
"defaultAllowedOrigins": []
}

View File

@@ -1798,6 +1798,8 @@ type Server implements Node {
guid: String!
apikey: String!
name: String!
"""Whether this server is online or offline"""
status: ServerStatus!
wanip: String!
lanip: String!
@@ -1854,7 +1856,7 @@ type OidcProvider {
"""
OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration
"""
issuer: String!
issuer: String
"""
OAuth2 authorization endpoint URL. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration
@@ -1907,6 +1909,16 @@ enum AuthorizationRuleMode {
AND
}
type OidcConfiguration {
"""List of configured OIDC providers"""
providers: [OidcProvider!]!
"""
Default allowed redirect origins that apply to all OIDC providers (e.g., Tailscale domains)
"""
defaultAllowedOrigins: [String!]
}
type OidcSessionValidation {
valid: Boolean!
username: String
@@ -2307,8 +2319,6 @@ type Query {
getApiKeyCreationFormSchema: ApiKeyFormSettings!
config: Config!
flash: Flash!
logFiles: [LogFile!]!
logFile(path: String!, lines: Int, startLine: Int): LogFileContent!
me: UserAccount!
"""Get all notifications"""
@@ -2335,6 +2345,8 @@ type Query {
disk(id: PrefixedID!): Disk!
rclone: RCloneBackupSettings!
info: Info!
logFiles: [LogFile!]!
logFile(path: String!, lines: Int, startLine: Int): LogFileContent!
settings: Settings!
isSSOEnabled: Boolean!
@@ -2347,6 +2359,9 @@ type Query {
"""Get a specific OIDC provider by ID"""
oidcProvider(id: PrefixedID!): OidcProvider
"""Get the full OIDC configuration (admin only)"""
oidcConfiguration: OidcConfiguration!
"""Validate an OIDC session token (internal use for CLI validation)"""
validateOidcSession(token: String!): OidcSessionValidation!
metrics: Metrics!
@@ -2590,13 +2605,13 @@ input AccessUrlInput {
}
type Subscription {
logFile(path: String!): LogFileContent!
notificationAdded: Notification!
notificationsOverview: NotificationOverview!
ownerSubscription: Owner!
serversSubscription: Server!
parityHistorySubscription: ParityCheck!
arraySubscription: UnraidArray!
logFile(path: String!): LogFileContent!
systemMetricsCpu: CpuUtilization!
systemMetricsMemory: MemoryUtilization!
upsUpdates: UPSDevice!

View File

@@ -99,6 +99,7 @@
"diff": "8.0.2",
"dockerode": "4.0.7",
"dotenv": "17.2.1",
"escape-html": "1.0.3",
"execa": "9.6.0",
"exit-hook": "4.0.0",
"fastify": "5.5.0",

View File

@@ -29,8 +29,24 @@ const stream = SUPPRESS_LOGS
singleLine: true,
hideObject: false,
colorize: true,
colorizeObjects: true,
levelFirst: false,
ignore: 'hostname,pid',
destination: logDestination,
translateTime: 'HH:mm:ss',
customPrettifiers: {
time: (timestamp: string | object) => `[${timestamp}`,
level: (logLevel: string | object, key: string, log: any, extras: any) => {
// Use labelColorized which preserves the colors
const { labelColorized } = extras;
const context = log.context || log.logger || 'app';
return `${labelColorized} ${context}]`;
},
},
messageFormat: (log: any, messageKey: string) => {
const msg = log[messageKey] || log.msg || '';
return msg;
},
})
: logDestination;

View File

@@ -13,10 +13,11 @@ export const pubsub = new PubSub({ eventEmitter });
/**
* Create a pubsub subscription.
* @param channel The pubsub channel to subscribe to.
* @param channel The pubsub channel to subscribe to. Can be either a predefined GRAPHQL_PUBSUB_CHANNEL
* or a dynamic string for runtime-generated topics (e.g., log file paths like "LOG_FILE:/var/log/test.log")
*/
export const createSubscription = <T = any>(
channel: GRAPHQL_PUBSUB_CHANNEL
channel: GRAPHQL_PUBSUB_CHANNEL | string
): AsyncIterableIterator<T> => {
return pubsub.asyncIterableIterator<T>(channel);
};

View File

@@ -1,3 +1,4 @@
import { CacheModule } from '@nestjs/cache-manager';
import { Test } from '@nestjs/testing';
import { describe, expect, it } from 'vitest';
@@ -9,7 +10,7 @@ describe('Module Dependencies Integration', () => {
let module;
try {
module = await Test.createTestingModule({
imports: [RestModule],
imports: [CacheModule.register({ isGlobal: true }), RestModule],
}).compile();
expect(module).toBeDefined();

View File

@@ -34,6 +34,15 @@ import { UnraidFileModifierModule } from '@app/unraid-api/unraid-file-modifier/u
req: () => undefined,
res: () => undefined,
},
formatters: {
log: (obj) => {
// Map NestJS context to Pino context field for pino-pretty
if (obj.context && !obj.logger) {
return { ...obj, logger: obj.context };
}
return obj;
},
},
},
}),
AuthModule,

View File

@@ -448,6 +448,20 @@ export enum ConfigErrorState {
WITHDRAWN = 'WITHDRAWN'
}
export type ConfigFile = {
__typename?: 'ConfigFile';
content: Scalars['String']['output'];
name: Scalars['String']['output'];
path: Scalars['String']['output'];
/** Human-readable file size (e.g., "1.5 KB", "2.3 MB") */
sizeReadable: Scalars['String']['output'];
};
export type ConfigFilesResponse = {
__typename?: 'ConfigFilesResponse';
files: Array<ConfigFile>;
};
export type Connect = Node & {
__typename?: 'Connect';
/** The status of dynamic remote access */
@@ -1432,6 +1446,14 @@ export type OidcAuthorizationRule = {
value: Array<Scalars['String']['output']>;
};
export type OidcConfiguration = {
__typename?: 'OidcConfiguration';
/** Default allowed redirect origins that apply to all OIDC providers (e.g., Tailscale domains) */
defaultAllowedOrigins?: Maybe<Array<Scalars['String']['output']>>;
/** List of configured OIDC providers */
providers: Array<OidcProvider>;
};
export type OidcProvider = {
__typename?: 'OidcProvider';
/** OAuth2 authorization endpoint URL. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration */
@@ -1455,7 +1477,7 @@ export type OidcProvider = {
/** The unique identifier for the OIDC provider */
id: Scalars['PrefixedID']['output'];
/** OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration */
issuer: Scalars['String']['output'];
issuer?: Maybe<Scalars['String']['output']>;
/** JSON Web Key Set URI for token validation. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration */
jwksUri?: Maybe<Scalars['String']['output']>;
/** Display name of the OIDC provider */
@@ -1623,6 +1645,7 @@ export type PublicPartnerInfo = {
export type Query = {
__typename?: 'Query';
allConfigFiles: ConfigFilesResponse;
apiKey?: Maybe<ApiKey>;
/** All possible permissions for API keys */
apiKeyPossiblePermissions: Array<Permission>;
@@ -1632,6 +1655,7 @@ export type Query = {
array: UnraidArray;
cloud: Cloud;
config: Config;
configFile?: Maybe<ConfigFile>;
connect: Connect;
customization?: Maybe<Customization>;
disk: Disk;
@@ -1654,6 +1678,8 @@ export type Query = {
network: Network;
/** Get all notifications */
notifications: Notifications;
/** Get the full OIDC configuration (admin only) */
oidcConfiguration: OidcConfiguration;
/** Get a specific OIDC provider by ID */
oidcProvider?: Maybe<OidcProvider>;
/** Get all configured OIDC providers (admin only) */
@@ -1693,6 +1719,11 @@ export type QueryApiKeyArgs = {
};
export type QueryConfigFileArgs = {
name: Scalars['String']['input'];
};
export type QueryDiskArgs = {
id: Scalars['PrefixedID']['input'];
};
@@ -1933,6 +1964,7 @@ export type Server = Node & {
name: Scalars['String']['output'];
owner: ProfileModel;
remoteurl: Scalars['String']['output'];
/** Whether this server is online or offline */
status: ServerStatus;
wanip: Scalars['String']['output'];
};

View File

@@ -0,0 +1,213 @@
import { Test, TestingModule } from '@nestjs/testing';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import {
LogWatcherManager,
WatcherState,
} from '@app/unraid-api/graph/resolvers/logs/log-watcher-manager.service.js';
describe('LogWatcherManager', () => {
let manager: LogWatcherManager;
let mockWatcher: any;
beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
providers: [LogWatcherManager],
}).compile();
manager = module.get<LogWatcherManager>(LogWatcherManager);
mockWatcher = {
close: vi.fn(),
on: vi.fn(),
};
});
describe('state management', () => {
it('should set watcher as initializing', () => {
manager.setInitializing('test-key');
const entry = manager.getEntry('test-key');
expect(entry).toBeDefined();
expect(entry?.state).toBe(WatcherState.INITIALIZING);
});
it('should set watcher as active with position', () => {
manager.setActive('test-key', mockWatcher as any, 1000);
const entry = manager.getEntry('test-key');
expect(entry).toBeDefined();
expect(entry?.state).toBe(WatcherState.ACTIVE);
if (manager.isActive(entry)) {
expect(entry.watcher).toBe(mockWatcher);
expect(entry.position).toBe(1000);
}
});
it('should set watcher as stopping', () => {
manager.setStopping('test-key');
const entry = manager.getEntry('test-key');
expect(entry).toBeDefined();
expect(entry?.state).toBe(WatcherState.STOPPING);
});
});
describe('isWatchingOrInitializing', () => {
it('should return true for initializing watcher', () => {
manager.setInitializing('test-key');
expect(manager.isWatchingOrInitializing('test-key')).toBe(true);
});
it('should return true for active watcher', () => {
manager.setActive('test-key', mockWatcher as any, 0);
expect(manager.isWatchingOrInitializing('test-key')).toBe(true);
});
it('should return false for stopping watcher', () => {
manager.setStopping('test-key');
expect(manager.isWatchingOrInitializing('test-key')).toBe(false);
});
it('should return false for non-existent watcher', () => {
expect(manager.isWatchingOrInitializing('test-key')).toBe(false);
});
});
describe('handlePostInitialization', () => {
it('should activate watcher when not stopped', () => {
manager.setInitializing('test-key');
const result = manager.handlePostInitialization('test-key', mockWatcher as any, 500);
expect(result).toBe(true);
expect(mockWatcher.close).not.toHaveBeenCalled();
const entry = manager.getEntry('test-key');
expect(entry?.state).toBe(WatcherState.ACTIVE);
if (manager.isActive(entry)) {
expect(entry.position).toBe(500);
}
});
it('should cleanup watcher when marked as stopping', () => {
manager.setStopping('test-key');
const result = manager.handlePostInitialization('test-key', mockWatcher as any, 500);
expect(result).toBe(false);
expect(mockWatcher.close).toHaveBeenCalled();
expect(manager.getEntry('test-key')).toBeUndefined();
});
it('should cleanup watcher when entry is missing', () => {
const result = manager.handlePostInitialization('test-key', mockWatcher as any, 500);
expect(result).toBe(false);
expect(mockWatcher.close).toHaveBeenCalled();
expect(manager.getEntry('test-key')).toBeUndefined();
});
});
describe('stopWatcher', () => {
it('should mark initializing watcher as stopping', () => {
manager.setInitializing('test-key');
manager.stopWatcher('test-key');
const entry = manager.getEntry('test-key');
expect(entry?.state).toBe(WatcherState.STOPPING);
});
it('should close and remove active watcher', () => {
manager.setActive('test-key', mockWatcher as any, 0);
manager.stopWatcher('test-key');
expect(mockWatcher.close).toHaveBeenCalled();
expect(manager.getEntry('test-key')).toBeUndefined();
});
it('should do nothing for non-existent watcher', () => {
manager.stopWatcher('test-key');
expect(mockWatcher.close).not.toHaveBeenCalled();
});
});
describe('position management', () => {
it('should update position for active watcher', () => {
manager.setActive('test-key', mockWatcher as any, 100);
manager.updatePosition('test-key', 200);
const position = manager.getPosition('test-key');
expect(position).toBe(200);
});
it('should not update position for non-active watcher', () => {
manager.setInitializing('test-key');
manager.updatePosition('test-key', 200);
const position = manager.getPosition('test-key');
expect(position).toBeUndefined();
});
it('should get position for active watcher', () => {
manager.setActive('test-key', mockWatcher as any, 300);
expect(manager.getPosition('test-key')).toBe(300);
});
it('should return undefined for non-active watcher', () => {
manager.setStopping('test-key');
expect(manager.getPosition('test-key')).toBeUndefined();
});
});
describe('stopAllWatchers', () => {
it('should close all active watchers and clear map', () => {
const mockWatcher1 = { close: vi.fn() };
const mockWatcher2 = { close: vi.fn() };
const mockWatcher3 = { close: vi.fn() };
manager.setActive('key1', mockWatcher1 as any, 0);
manager.setInitializing('key2');
manager.setActive('key3', mockWatcher2 as any, 0);
manager.setStopping('key4');
manager.setActive('key5', mockWatcher3 as any, 0);
manager.stopAllWatchers();
expect(mockWatcher1.close).toHaveBeenCalled();
expect(mockWatcher2.close).toHaveBeenCalled();
expect(mockWatcher3.close).toHaveBeenCalled();
expect(manager.getEntry('key1')).toBeUndefined();
expect(manager.getEntry('key2')).toBeUndefined();
expect(manager.getEntry('key3')).toBeUndefined();
expect(manager.getEntry('key4')).toBeUndefined();
expect(manager.getEntry('key5')).toBeUndefined();
});
});
describe('in-flight processing', () => {
it('should prevent concurrent processing', () => {
manager.setActive('test-key', mockWatcher as any, 0);
// First call should succeed
expect(manager.startProcessing('test-key')).toBe(true);
// Second call should fail (already in flight)
expect(manager.startProcessing('test-key')).toBe(false);
// After finishing, should be able to start again
manager.finishProcessing('test-key');
expect(manager.startProcessing('test-key')).toBe(true);
});
it('should not start processing for non-active watcher', () => {
manager.setInitializing('test-key');
expect(manager.startProcessing('test-key')).toBe(false);
manager.setStopping('test-key');
expect(manager.startProcessing('test-key')).toBe(false);
});
it('should handle finish processing for non-existent watcher', () => {
// Should not throw
expect(() => manager.finishProcessing('non-existent')).not.toThrow();
});
});
});

View File

@@ -0,0 +1,183 @@
import { Injectable, Logger } from '@nestjs/common';
import * as chokidar from 'chokidar';
export enum WatcherState {
INITIALIZING = 'initializing',
ACTIVE = 'active',
STOPPING = 'stopping',
}
export type WatcherEntry =
| { state: WatcherState.INITIALIZING }
| { state: WatcherState.ACTIVE; watcher: chokidar.FSWatcher; position: number; inFlight: boolean }
| { state: WatcherState.STOPPING };
/**
* Service responsible for managing log file watchers and their lifecycle.
* Handles race conditions during watcher initialization and cleanup.
*/
@Injectable()
export class LogWatcherManager {
private readonly logger = new Logger(LogWatcherManager.name);
private readonly watchers = new Map<string, WatcherEntry>();
/**
* Set a watcher as initializing
*/
setInitializing(key: string): void {
this.watchers.set(key, { state: WatcherState.INITIALIZING });
}
/**
* Set a watcher as active with its FSWatcher and position
*/
setActive(key: string, watcher: chokidar.FSWatcher, position: number): void {
this.watchers.set(key, { state: WatcherState.ACTIVE, watcher, position, inFlight: false });
}
/**
* Mark a watcher as stopping (used during initialization race conditions)
*/
setStopping(key: string): void {
this.watchers.set(key, { state: WatcherState.STOPPING });
}
/**
* Get a watcher entry by key
*/
getEntry(key: string): WatcherEntry | undefined {
return this.watchers.get(key);
}
/**
* Remove a watcher entry
*/
removeEntry(key: string): void {
this.watchers.delete(key);
}
/**
* Check if a watcher is active and return typed entry
*/
isActive(entry: WatcherEntry | undefined): entry is {
state: WatcherState.ACTIVE;
watcher: chokidar.FSWatcher;
position: number;
inFlight: boolean;
} {
return entry?.state === WatcherState.ACTIVE;
}
/**
* Check if a watcher exists and is either initializing or active
*/
isWatchingOrInitializing(key: string): boolean {
const entry = this.getEntry(key);
return (
entry !== undefined &&
(entry.state === WatcherState.ACTIVE || entry.state === WatcherState.INITIALIZING)
);
}
/**
* Handle cleanup after initialization completes.
* Returns true if the watcher should continue, false if it should be cleaned up.
*/
handlePostInitialization(key: string, watcher: chokidar.FSWatcher, position: number): boolean {
const currentEntry = this.getEntry(key);
if (!currentEntry || currentEntry.state === WatcherState.STOPPING) {
// We were stopped during initialization, clean up immediately
this.logger.debug(`Watcher for ${key} was stopped during initialization, cleaning up`);
watcher.close();
this.removeEntry(key);
return false;
}
// Store the active watcher and position
this.setActive(key, watcher, position);
return true;
}
/**
* Stop a watcher, handling all possible states
*/
stopWatcher(key: string): void {
const entry = this.getEntry(key);
if (!entry) {
return;
}
if (entry.state === WatcherState.INITIALIZING) {
// Mark as stopping so the initialization will clean up
this.setStopping(key);
this.logger.debug(`Marked watcher as stopping during initialization: ${key}`);
} else if (entry.state === WatcherState.ACTIVE) {
// Close the active watcher
entry.watcher.close();
this.removeEntry(key);
this.logger.debug(`Stopped active watcher: ${key}`);
}
}
/**
* Update the position for an active watcher
*/
updatePosition(key: string, newPosition: number): void {
const entry = this.getEntry(key);
if (this.isActive(entry)) {
entry.position = newPosition;
}
}
/**
* Start processing a change event (set inFlight to true)
* Returns true if processing can proceed, false if already in flight
*/
startProcessing(key: string): boolean {
const entry = this.getEntry(key);
if (this.isActive(entry)) {
if (entry.inFlight) {
return false; // Already processing
}
entry.inFlight = true;
return true;
}
return false;
}
/**
* Finish processing a change event (set inFlight to false)
*/
finishProcessing(key: string): void {
const entry = this.getEntry(key);
if (this.isActive(entry)) {
entry.inFlight = false;
}
}
/**
* Get the position for an active watcher
*/
getPosition(key: string): number | undefined {
const entry = this.getEntry(key);
if (this.isActive(entry)) {
return entry.position;
}
return undefined;
}
/**
* Clean up all watchers (useful for module cleanup)
*/
stopAllWatchers(): void {
for (const entry of this.watchers.values()) {
if (this.isActive(entry)) {
entry.watcher.close();
}
}
this.watchers.clear();
}
}

View File

@@ -1,10 +1,13 @@
import { Module } from '@nestjs/common';
import { LogWatcherManager } from '@app/unraid-api/graph/resolvers/logs/log-watcher-manager.service.js';
import { LogsResolver } from '@app/unraid-api/graph/resolvers/logs/logs.resolver.js';
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
import { ServicesModule } from '@app/unraid-api/graph/services/services.module.js';
@Module({
providers: [LogsResolver, LogsService],
exports: [LogsService],
imports: [ServicesModule],
providers: [LogsResolver, LogsService, LogWatcherManager],
exports: [LogsService, LogWatcherManager],
})
export class LogsModule {}

View File

@@ -1,9 +1,10 @@
import { Test, TestingModule } from '@nestjs/testing';
import { beforeEach, describe, expect, it } from 'vitest';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { LogsResolver } from '@app/unraid-api/graph/resolvers/logs/logs.resolver.js';
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js';
describe('LogsResolver', () => {
let resolver: LogsResolver;
@@ -18,6 +19,13 @@ describe('LogsResolver', () => {
// Add mock implementations for service methods used by resolver
},
},
{
provide: SubscriptionHelperService,
useValue: {
// Add mock implementations for subscription helper methods
createTrackedSubscription: vi.fn(),
},
},
],
}).compile();
resolver = module.get<LogsResolver>(LogsResolver);

View File

@@ -3,13 +3,16 @@ import { Args, Int, Query, Resolver, Subscription } from '@nestjs/graphql';
import { AuthAction, Resource } from '@unraid/shared/graphql.model.js';
import { UsePermissions } from '@unraid/shared/use-permissions.directive.js';
import { createSubscription, PUBSUB_CHANNEL } from '@app/core/pubsub.js';
import { LogFile, LogFileContent } from '@app/unraid-api/graph/resolvers/logs/logs.model.js';
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js';
@Resolver(() => LogFile)
export class LogsResolver {
constructor(private readonly logsService: LogsService) {}
constructor(
private readonly logsService: LogsService,
private readonly subscriptionHelper: SubscriptionHelperService
) {}
@Query(() => [LogFile])
@UsePermissions({
@@ -38,27 +41,12 @@ export class LogsResolver {
action: AuthAction.READ_ANY,
resource: Resource.LOGS,
})
async logFileSubscription(@Args('path') path: string) {
// Start watching the file
this.logsService.getLogFileSubscriptionChannel(path);
logFileSubscription(@Args('path') path: string) {
// Register the topic and get the key
const topicKey = this.logsService.registerLogFileSubscription(path);
// Create the async iterator
const asyncIterator = createSubscription(PUBSUB_CHANNEL.LOG_FILE);
// Store the original return method to wrap it
const originalReturn = asyncIterator.return;
// Override the return method to clean up resources
asyncIterator.return = async () => {
// Stop watching the file when subscription ends
this.logsService.stopWatchingLogFile(path);
// Call the original return method
return originalReturn
? originalReturn.call(asyncIterator)
: Promise.resolve({ value: undefined, done: true });
};
return asyncIterator;
// Use the helper service to create a tracked subscription
// This automatically handles subscribe/unsubscribe with reference counting
return this.subscriptionHelper.createTrackedSubscription(topicKey);
}
}

View File

@@ -0,0 +1,201 @@
import { Test, TestingModule } from '@nestjs/testing';
import * as fs from 'node:fs/promises';
import * as chokidar from 'chokidar';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { LogWatcherManager } from '@app/unraid-api/graph/resolvers/logs/log-watcher-manager.service.js';
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
vi.mock('node:fs/promises');
vi.mock('chokidar');
vi.mock('@app/store/index.js', () => ({
getters: {
paths: () => ({
'unraid-log-base': '/var/log',
}),
},
}));
vi.mock('@app/core/pubsub.js', () => ({
pubsub: {
publish: vi.fn(),
},
PUBSUB_CHANNEL: {},
}));
describe('LogsService', () => {
let service: LogsService;
let mockWatcher: any;
let subscriptionTracker: any;
beforeEach(async () => {
// Create a mock watcher
mockWatcher = {
on: vi.fn(),
close: vi.fn(),
};
// Mock chokidar.watch to return our mock watcher
vi.mocked(chokidar.watch).mockReturnValue(mockWatcher as any);
// Mock fs.stat to return a file size
vi.mocked(fs.stat).mockResolvedValue({ size: 1000 } as any);
subscriptionTracker = {
getSubscriberCount: vi.fn().mockReturnValue(0),
registerTopic: vi.fn(),
};
const module: TestingModule = await Test.createTestingModule({
providers: [
LogsService,
LogWatcherManager,
{
provide: SubscriptionTrackerService,
useValue: subscriptionTracker,
},
],
}).compile();
service = module.get<LogsService>(LogsService);
});
afterEach(() => {
vi.clearAllMocks();
});
it('should be defined', () => {
expect(service).toBeDefined();
});
it('should handle race condition when stopping watcher during initialization', async () => {
// Setup: Register the subscription which will trigger registerTopic
service.registerLogFileSubscription('test.log');
// Get the onStart callback that was registered
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
const onStartCallback = registerTopicCall[1];
const onStopCallback = registerTopicCall[2];
// Create a promise to control when stat resolves
let statResolve: any;
const statPromise = new Promise((resolve) => {
statResolve = resolve;
});
vi.mocked(fs.stat).mockReturnValue(statPromise as any);
// Start the watcher (this will call startWatchingLogFile internally)
onStartCallback();
// At this point, the watcher should be marked as 'initializing'
// Now call stop before the stat promise resolves
onStopCallback();
// Now resolve the stat promise to complete initialization
statResolve({ size: 1000 });
// Wait for any async operations to complete
await new Promise((resolve) => setImmediate(resolve));
// The watcher should have been closed due to the race condition check
expect(mockWatcher.close).toHaveBeenCalled();
});
it('should not leak watcher if stopped multiple times during initialization', async () => {
// Setup: Register the subscription
service.registerLogFileSubscription('test.log');
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
const onStartCallback = registerTopicCall[1];
const onStopCallback = registerTopicCall[2];
// Create controlled stat promise
let statResolve: any;
const statPromise = new Promise((resolve) => {
statResolve = resolve;
});
vi.mocked(fs.stat).mockReturnValue(statPromise as any);
// Start the watcher
onStartCallback();
// Call stop multiple times during initialization
onStopCallback();
onStopCallback();
onStopCallback();
// Complete initialization
statResolve({ size: 1000 });
await new Promise((resolve) => setImmediate(resolve));
// Should only close once
expect(mockWatcher.close).toHaveBeenCalledTimes(1);
});
it('should properly handle normal start and stop without race condition', async () => {
// Setup: Register the subscription
service.registerLogFileSubscription('test.log');
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
const onStartCallback = registerTopicCall[1];
const onStopCallback = registerTopicCall[2];
// Make stat resolve immediately
vi.mocked(fs.stat).mockResolvedValue({ size: 1000 } as any);
// Start the watcher and let it complete initialization
onStartCallback();
await new Promise((resolve) => setImmediate(resolve));
// Watcher should be created but not closed
expect(chokidar.watch).toHaveBeenCalled();
expect(mockWatcher.close).not.toHaveBeenCalled();
// Now stop it normally
onStopCallback();
// Watcher should be closed
expect(mockWatcher.close).toHaveBeenCalledTimes(1);
});
it('should handle error during initialization without leaking watchers', async () => {
// Setup: Register the subscription
service.registerLogFileSubscription('test.log');
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
const onStartCallback = registerTopicCall[1];
// Make stat reject with an error
vi.mocked(fs.stat).mockRejectedValue(new Error('File not found'));
// Start the watcher (should fail during initialization)
onStartCallback();
await new Promise((resolve) => setImmediate(resolve));
// Watcher should never be created due to stat error
expect(chokidar.watch).not.toHaveBeenCalled();
expect(mockWatcher.close).not.toHaveBeenCalled();
});
it('should not create duplicate watchers when started multiple times', async () => {
// Setup: Register the subscription
service.registerLogFileSubscription('test.log');
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
const onStartCallback = registerTopicCall[1];
// Make stat resolve immediately
vi.mocked(fs.stat).mockResolvedValue({ size: 1000 } as any);
// Start the watcher multiple times
onStartCallback();
onStartCallback();
onStartCallback();
await new Promise((resolve) => setImmediate(resolve));
// Should only create one watcher
expect(chokidar.watch).toHaveBeenCalledTimes(1);
});
});

View File

@@ -1,13 +1,15 @@
import { Injectable, Logger } from '@nestjs/common';
import { createReadStream } from 'node:fs';
import { readdir, readFile, stat } from 'node:fs/promises';
import { readdir, stat } from 'node:fs/promises';
import { basename, join } from 'node:path';
import { createInterface } from 'node:readline';
import * as chokidar from 'chokidar';
import { pubsub, PUBSUB_CHANNEL } from '@app/core/pubsub.js';
import { pubsub } from '@app/core/pubsub.js';
import { getters } from '@app/store/index.js';
import { LogWatcherManager } from '@app/unraid-api/graph/resolvers/logs/log-watcher-manager.service.js';
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
interface LogFile {
name: string;
@@ -26,12 +28,13 @@ interface LogFileContent {
@Injectable()
export class LogsService {
private readonly logger = new Logger(LogsService.name);
private readonly logWatchers = new Map<
string,
{ watcher: chokidar.FSWatcher; position: number; subscriptionCount: number }
>();
private readonly DEFAULT_LINES = 100;
constructor(
private readonly subscriptionTracker: SubscriptionTrackerService,
private readonly watcherManager: LogWatcherManager
) {}
/**
* Get the base path for log files
*/
@@ -111,135 +114,208 @@ export class LogsService {
}
/**
* Get the subscription channel for a log file
* Register and get the topic key for a log file subscription
* @param path Path to the log file
* @returns The subscription topic key
*/
getLogFileSubscriptionChannel(path: string): PUBSUB_CHANNEL {
registerLogFileSubscription(path: string): string {
const normalizedPath = join(this.logBasePath, basename(path));
const topicKey = this.getTopicKey(normalizedPath);
// Start watching the file if not already watching
if (!this.logWatchers.has(normalizedPath)) {
this.startWatchingLogFile(normalizedPath);
} else {
// Increment subscription count for existing watcher
const watcher = this.logWatchers.get(normalizedPath);
if (watcher) {
watcher.subscriptionCount++;
this.logger.debug(
`Incremented subscription count for ${normalizedPath} to ${watcher.subscriptionCount}`
);
}
// Register the topic if not already registered
if (!this.subscriptionTracker.getSubscriberCount(topicKey)) {
this.logger.debug(`Registering log file subscription topic: ${topicKey}`);
this.subscriptionTracker.registerTopic(
topicKey,
// onStart handler
() => {
this.logger.debug(`Starting log file watcher for topic: ${topicKey}`);
this.startWatchingLogFile(normalizedPath);
},
// onStop handler
() => {
this.logger.debug(`Stopping log file watcher for topic: ${topicKey}`);
this.stopWatchingLogFile(normalizedPath);
}
);
}
return PUBSUB_CHANNEL.LOG_FILE;
return topicKey;
}
/**
* Start watching a log file for changes using chokidar
* @param path Path to the log file
*/
private async startWatchingLogFile(path: string): Promise<void> {
try {
// Get initial file size
const stats = await stat(path);
let position = stats.size;
private startWatchingLogFile(path: string): void {
const watcherKey = path;
// Create a watcher for the file using chokidar
const watcher = chokidar.watch(path, {
persistent: true,
awaitWriteFinish: {
stabilityThreshold: 300,
pollInterval: 100,
},
});
// Check if already watching or initializing
if (this.watcherManager.isWatchingOrInitializing(watcherKey)) {
this.logger.debug(`Already watching or initializing log file: ${watcherKey}`);
return;
}
watcher.on('change', async () => {
try {
const newStats = await stat(path);
// Mark as initializing immediately to prevent race conditions
this.watcherManager.setInitializing(watcherKey);
// If the file has grown
if (newStats.size > position) {
// Read only the new content
const stream = createReadStream(path, {
start: position,
end: newStats.size - 1,
});
// Get initial file size and set up watcher
stat(path)
.then((stats) => {
const position = stats.size;
let newContent = '';
stream.on('data', (chunk) => {
newContent += chunk.toString();
});
// Create a watcher for the file using chokidar
const watcher = chokidar.watch(path, {
persistent: true,
awaitWriteFinish: {
stabilityThreshold: 300,
pollInterval: 100,
},
});
stream.on('end', () => {
if (newContent) {
pubsub.publish(PUBSUB_CHANNEL.LOG_FILE, {
watcher.on('change', async () => {
// Check if we're already processing a change event for this file
if (!this.watcherManager.startProcessing(watcherKey)) {
// Already processing, ignore this event
return;
}
try {
const newStats = await stat(path);
// Get the current position
const currentPosition = this.watcherManager.getPosition(watcherKey);
if (currentPosition === undefined) {
// Watcher was stopped or not active, ignore the event
return;
}
// If the file has grown
if (newStats.size > currentPosition) {
// Read only the new content
const stream = createReadStream(path, {
start: currentPosition,
end: newStats.size - 1,
});
let newContent = '';
stream.on('data', (chunk) => {
newContent += chunk.toString();
});
stream.on('end', () => {
try {
if (newContent) {
// Use topic-specific channel
const topicKey = this.getTopicKey(path);
pubsub.publish(topicKey, {
logFile: {
path,
content: newContent,
totalLines: 0, // We don't need to count lines for updates
},
});
}
// Update position for next read (while still holding the guard)
this.watcherManager.updatePosition(watcherKey, newStats.size);
} finally {
// Clear the in-flight flag
this.watcherManager.finishProcessing(watcherKey);
}
});
stream.on('error', (error) => {
this.logger.error(`Error reading stream for ${path}: ${error}`);
// Clear the in-flight flag on error
this.watcherManager.finishProcessing(watcherKey);
});
} else if (newStats.size < currentPosition) {
// File was truncated, reset position and read from beginning
this.logger.debug(`File ${path} was truncated, resetting position`);
try {
// Read the entire file content
const content = await this.getLogFileContent(
path,
this.DEFAULT_LINES,
undefined
);
// Use topic-specific channel
const topicKey = this.getTopicKey(path);
pubsub.publish(topicKey, {
logFile: {
path,
content: newContent,
totalLines: 0, // We don't need to count lines for updates
...content,
},
});
// Update position (while still holding the guard)
this.watcherManager.updatePosition(watcherKey, newStats.size);
} finally {
// Clear the in-flight flag
this.watcherManager.finishProcessing(watcherKey);
}
// Update position for next read
position = newStats.size;
});
} else if (newStats.size < position) {
// File was truncated, reset position and read from beginning
position = 0;
this.logger.debug(`File ${path} was truncated, resetting position`);
// Read the entire file content
const content = await this.getLogFileContent(path);
pubsub.publish(PUBSUB_CHANNEL.LOG_FILE, {
logFile: content,
});
position = newStats.size;
} else {
// File size unchanged, clear the in-flight flag
this.watcherManager.finishProcessing(watcherKey);
}
} catch (error: unknown) {
this.logger.error(`Error processing file change for ${path}: ${error}`);
// Clear the in-flight flag on error
this.watcherManager.finishProcessing(watcherKey);
}
} catch (error: unknown) {
this.logger.error(`Error processing file change for ${path}: ${error}`);
});
watcher.on('error', (error) => {
this.logger.error(`Chokidar watcher error for ${path}: ${error}`);
});
// Check if we were stopped during initialization and handle cleanup
if (!this.watcherManager.handlePostInitialization(watcherKey, watcher, position)) {
return;
}
// Publish initial snapshot
this.getLogFileContent(path, this.DEFAULT_LINES, undefined)
.then((content) => {
const topicKey = this.getTopicKey(path);
pubsub.publish(topicKey, {
logFile: {
...content,
},
});
})
.catch((error) => {
this.logger.error(`Error publishing initial log content for ${path}: ${error}`);
});
this.logger.debug(`Started watching log file with chokidar: ${path}`);
})
.catch((error) => {
this.logger.error(`Error setting up file watcher for ${path}: ${error}`);
// Clean up the initializing entry on error
this.watcherManager.removeEntry(watcherKey);
});
}
watcher.on('error', (error) => {
this.logger.error(`Chokidar watcher error for ${path}: ${error}`);
});
// Store the watcher and current position with initial subscription count of 1
this.logWatchers.set(path, { watcher, position, subscriptionCount: 1 });
this.logger.debug(
`Started watching log file with chokidar: ${path} (subscription count: 1)`
);
} catch (error: unknown) {
this.logger.error(`Error setting up chokidar file watcher for ${path}: ${error}`);
}
/**
* Get the topic key for a log file subscription
* @param path Path to the log file (should already be normalized)
* @returns The topic key
*/
private getTopicKey(path: string): string {
// Assume path is already normalized (full path)
return `LOG_FILE:${path}`;
}
/**
* Stop watching a log file
* @param path Path to the log file
*/
public stopWatchingLogFile(path: string): void {
const normalizedPath = join(this.logBasePath, basename(path));
const watcher = this.logWatchers.get(normalizedPath);
if (watcher) {
// Decrement subscription count
watcher.subscriptionCount--;
this.logger.debug(
`Decremented subscription count for ${normalizedPath} to ${watcher.subscriptionCount}`
);
// Only close the watcher when subscription count reaches 0
if (watcher.subscriptionCount <= 0) {
watcher.watcher.close();
this.logWatchers.delete(normalizedPath);
this.logger.debug(`Stopped watching log file: ${normalizedPath} (no more subscribers)`);
}
}
private stopWatchingLogFile(path: string): void {
this.watcherManager.stopWatcher(path);
}
/**

View File

@@ -9,7 +9,7 @@ import { CpuService } from '@app/unraid-api/graph/resolvers/info/cpu/cpu.service
import { MemoryService } from '@app/unraid-api/graph/resolvers/info/memory/memory.service.js';
import { MetricsResolver } from '@app/unraid-api/graph/resolvers/metrics/metrics.resolver.js';
import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js';
import { SubscriptionPollingService } from '@app/unraid-api/graph/services/subscription-polling.service.js';
import { SubscriptionManagerService } from '@app/unraid-api/graph/services/subscription-manager.service.js';
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
describe('MetricsResolver Integration Tests', () => {
@@ -25,7 +25,7 @@ describe('MetricsResolver Integration Tests', () => {
MemoryService,
SubscriptionTrackerService,
SubscriptionHelperService,
SubscriptionPollingService,
SubscriptionManagerService,
],
}).compile();
@@ -36,8 +36,8 @@ describe('MetricsResolver Integration Tests', () => {
afterEach(async () => {
// Clean up polling service
const pollingService = module.get<SubscriptionPollingService>(SubscriptionPollingService);
pollingService.stopAll();
const subscriptionManager = module.get<SubscriptionManagerService>(SubscriptionManagerService);
subscriptionManager.stopAll();
await module.close();
});
@@ -202,10 +202,13 @@ describe('MetricsResolver Integration Tests', () => {
it('should handle errors in CPU polling gracefully', async () => {
const service = module.get<CpuService>(CpuService);
const trackerService = module.get<SubscriptionTrackerService>(SubscriptionTrackerService);
const pollingService = module.get<SubscriptionPollingService>(SubscriptionPollingService);
const subscriptionManager =
module.get<SubscriptionManagerService>(SubscriptionManagerService);
// Mock logger to capture error logs
const loggerSpy = vi.spyOn(pollingService['logger'], 'error').mockImplementation(() => {});
const loggerSpy = vi
.spyOn(subscriptionManager['logger'], 'error')
.mockImplementation(() => {});
vi.spyOn(service, 'generateCpuLoad').mockRejectedValueOnce(new Error('CPU error'));
// Trigger polling
@@ -215,7 +218,7 @@ describe('MetricsResolver Integration Tests', () => {
await new Promise((resolve) => setTimeout(resolve, 1100));
expect(loggerSpy).toHaveBeenCalledWith(
expect.stringContaining('Error in polling task'),
expect.stringContaining('Error in subscription callback'),
expect.any(Error)
);
@@ -226,10 +229,13 @@ describe('MetricsResolver Integration Tests', () => {
it('should handle errors in memory polling gracefully', async () => {
const service = module.get<MemoryService>(MemoryService);
const trackerService = module.get<SubscriptionTrackerService>(SubscriptionTrackerService);
const pollingService = module.get<SubscriptionPollingService>(SubscriptionPollingService);
const subscriptionManager =
module.get<SubscriptionManagerService>(SubscriptionManagerService);
// Mock logger to capture error logs
const loggerSpy = vi.spyOn(pollingService['logger'], 'error').mockImplementation(() => {});
const loggerSpy = vi
.spyOn(subscriptionManager['logger'], 'error')
.mockImplementation(() => {});
vi.spyOn(service, 'generateMemoryLoad').mockRejectedValueOnce(new Error('Memory error'));
// Trigger polling
@@ -239,7 +245,7 @@ describe('MetricsResolver Integration Tests', () => {
await new Promise((resolve) => setTimeout(resolve, 2100));
expect(loggerSpy).toHaveBeenCalledWith(
expect.stringContaining('Error in polling task'),
expect.stringContaining('Error in subscription callback'),
expect.any(Error)
);
@@ -251,22 +257,30 @@ describe('MetricsResolver Integration Tests', () => {
describe('Polling cleanup on module destroy', () => {
it('should clean up timers when module is destroyed', async () => {
const trackerService = module.get<SubscriptionTrackerService>(SubscriptionTrackerService);
const pollingService = module.get<SubscriptionPollingService>(SubscriptionPollingService);
const subscriptionManager =
module.get<SubscriptionManagerService>(SubscriptionManagerService);
// Start polling
trackerService.subscribe(PUBSUB_CHANNEL.CPU_UTILIZATION);
trackerService.subscribe(PUBSUB_CHANNEL.MEMORY_UTILIZATION);
// Verify polling is active
expect(pollingService.isPolling(PUBSUB_CHANNEL.CPU_UTILIZATION)).toBe(true);
expect(pollingService.isPolling(PUBSUB_CHANNEL.MEMORY_UTILIZATION)).toBe(true);
// Wait a bit for subscriptions to be fully set up
await new Promise((resolve) => setTimeout(resolve, 100));
// Verify subscriptions are active
expect(subscriptionManager.isSubscriptionActive(PUBSUB_CHANNEL.CPU_UTILIZATION)).toBe(true);
expect(subscriptionManager.isSubscriptionActive(PUBSUB_CHANNEL.MEMORY_UTILIZATION)).toBe(
true
);
// Clean up the module
await module.close();
// Timers should be cleaned up
expect(pollingService.isPolling(PUBSUB_CHANNEL.CPU_UTILIZATION)).toBe(false);
expect(pollingService.isPolling(PUBSUB_CHANNEL.MEMORY_UTILIZATION)).toBe(false);
// Subscriptions should be cleaned up
expect(subscriptionManager.isSubscriptionActive(PUBSUB_CHANNEL.CPU_UTILIZATION)).toBe(false);
expect(subscriptionManager.isSubscriptionActive(PUBSUB_CHANNEL.MEMORY_UTILIZATION)).toBe(
false
);
});
});
});

View File

@@ -1,6 +1,7 @@
import { Module } from '@nestjs/common';
import { AuthModule } from '@app/unraid-api/auth/auth.module.js';
import { ApiConfigModule } from '@app/unraid-api/config/api-config.module.js';
import { ApiKeyModule } from '@app/unraid-api/graph/resolvers/api-key/api-key.module.js';
import { ApiKeyResolver } from '@app/unraid-api/graph/resolvers/api-key/api-key.resolver.js';
import { ArrayModule } from '@app/unraid-api/graph/resolvers/array/array.module.js';
@@ -11,8 +12,7 @@ import { DockerModule } from '@app/unraid-api/graph/resolvers/docker/docker.modu
import { FlashBackupModule } from '@app/unraid-api/graph/resolvers/flash-backup/flash-backup.module.js';
import { FlashResolver } from '@app/unraid-api/graph/resolvers/flash/flash.resolver.js';
import { InfoModule } from '@app/unraid-api/graph/resolvers/info/info.module.js';
import { LogsResolver } from '@app/unraid-api/graph/resolvers/logs/logs.resolver.js';
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
import { LogsModule } from '@app/unraid-api/graph/resolvers/logs/logs.module.js';
import { MetricsModule } from '@app/unraid-api/graph/resolvers/metrics/metrics.module.js';
import { RootMutationsResolver } from '@app/unraid-api/graph/resolvers/mutation/mutation.resolver.js';
import { NotificationsResolver } from '@app/unraid-api/graph/resolvers/notifications/notifications.resolver.js';
@@ -39,12 +39,14 @@ import { MeResolver } from '@app/unraid-api/graph/user/user.resolver.js';
ServicesModule,
ArrayModule,
ApiKeyModule,
ApiConfigModule,
AuthModule,
CustomizationModule,
DockerModule,
DisksModule,
FlashBackupModule,
InfoModule,
LogsModule,
RCloneModule,
SettingsModule,
SsoModule,
@@ -54,8 +56,6 @@ import { MeResolver } from '@app/unraid-api/graph/user/user.resolver.js';
providers: [
ConfigResolver,
FlashResolver,
LogsResolver,
LogsService,
MeResolver,
NotificationsResolver,
NotificationsService,

View File

@@ -16,8 +16,8 @@ import {
} from '@app/unraid-api/graph/resolvers/settings/settings.model.js';
import { ApiSettings } from '@app/unraid-api/graph/resolvers/settings/settings.service.js';
import { SsoSettings } from '@app/unraid-api/graph/resolvers/settings/sso-settings.model.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
@Resolver(() => Settings)
export class SettingsResolver {

View File

@@ -7,7 +7,7 @@ import { type ApiConfig } from '@unraid/shared/services/api-config.js';
import { UserSettingsService } from '@unraid/shared/services/user-settings.js';
import { execa } from 'execa';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { createLabeledControl } from '@app/unraid-api/graph/utils/form-utils.js';
import { SettingSlice } from '@app/unraid-api/types/json-forms.js';

View File

@@ -0,0 +1,11 @@
import { Module } from '@nestjs/common';
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
import { OidcClaimsService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
@Module({
providers: [OidcAuthorizationService, OidcTokenExchangeService, OidcClaimsService],
exports: [OidcAuthorizationService, OidcTokenExchangeService, OidcClaimsService],
})
export class OidcAuthModule {}

View File

@@ -1,70 +1,26 @@
import { UnauthorizedException } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { Test, TestingModule } from '@nestjs/testing';
import * as client from 'openid-client';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcAuthService } from '@app/unraid-api/graph/resolvers/sso/oidc-auth.service.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
import {
AuthorizationOperator,
AuthorizationRuleMode,
OidcAuthorizationRule,
OidcProvider,
} from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/oidc-state.service.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/oidc-validation.service.js';
} from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
describe('OidcAuthService', () => {
let service: OidcAuthService;
let oidcConfig: any;
let sessionService: any;
let configService: any;
let stateService: any;
let validationService: any;
describe('OidcAuthorizationService', () => {
let service: OidcAuthorizationService;
let module: TestingModule;
beforeEach(async () => {
module = await Test.createTestingModule({
providers: [
OidcAuthService,
{
provide: ConfigService,
useValue: {
get: vi.fn(),
},
},
{
provide: OidcConfigPersistence,
useValue: {
getProvider: vi.fn(),
},
},
{
provide: OidcSessionService,
useValue: {
createSession: vi.fn(),
},
},
OidcStateService,
{
provide: OidcValidationService,
useValue: {
validateProvider: vi.fn(),
performDiscovery: vi.fn(),
},
},
],
providers: [OidcAuthorizationService],
}).compile();
service = module.get<OidcAuthService>(OidcAuthService);
oidcConfig = module.get(OidcConfigPersistence);
sessionService = module.get(OidcSessionService);
configService = module.get(ConfigService);
stateService = module.get(OidcStateService);
validationService = module.get<OidcValidationService>(OidcValidationService);
service = module.get<OidcAuthorizationService>(OidcAuthorizationService);
});
describe('Authorization Rule Evaluation', () => {
@@ -1189,467 +1145,4 @@ describe('OidcAuthService', () => {
).resolves.toBeUndefined();
});
});
describe('Manual Configuration (No Discovery)', () => {
it('should create manual configuration when discovery fails but manual endpoints are provided', async () => {
const provider: OidcProvider = {
id: 'manual-provider',
name: 'Manual Provider',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
issuer: 'https://manual.example.com',
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
jwksUri: 'https://manual.example.com/jwks',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// Mock discovery to fail
validationService.performDiscovery = vi
.fn()
.mockRejectedValue(new Error('Discovery failed'));
// Access the private method
const getOrCreateConfig = async (provider: OidcProvider) => {
return (service as any).getOrCreateConfig(provider);
};
const config = await getOrCreateConfig(provider);
// Verify the configuration was created with the correct endpoints
expect(config).toBeDefined();
expect(config.serverMetadata().authorization_endpoint).toBe(
'https://manual.example.com/auth'
);
expect(config.serverMetadata().token_endpoint).toBe('https://manual.example.com/token');
expect(config.serverMetadata().jwks_uri).toBe('https://manual.example.com/jwks');
expect(config.serverMetadata().issuer).toBe('https://manual.example.com');
});
it('should create manual configuration with fallback issuer when not provided', async () => {
const provider: OidcProvider = {
id: 'manual-provider-no-issuer',
name: 'Manual Provider No Issuer',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
issuer: '', // Empty issuer should skip discovery and use manual endpoints
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// No need to mock discovery since it won't be called with empty issuer
// Access the private method
const getOrCreateConfig = async (provider: OidcProvider) => {
return (service as any).getOrCreateConfig(provider);
};
const config = await getOrCreateConfig(provider);
// Verify the configuration was created with fallback issuer
expect(config).toBeDefined();
expect(config.serverMetadata().issuer).toBe('manual-manual-provider-no-issuer');
expect(config.serverMetadata().authorization_endpoint).toBe(
'https://manual.example.com/auth'
);
expect(config.serverMetadata().token_endpoint).toBe('https://manual.example.com/token');
});
it('should handle manual configuration with client secret properly', async () => {
const provider: OidcProvider = {
id: 'manual-with-secret',
name: 'Manual With Secret',
clientId: 'test-client-id',
clientSecret: 'secret-123',
issuer: 'https://manual.example.com',
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// Mock discovery to fail
validationService.performDiscovery = vi
.fn()
.mockRejectedValue(new Error('Discovery failed'));
// Access the private method
const getOrCreateConfig = async (provider: OidcProvider) => {
return (service as any).getOrCreateConfig(provider);
};
const config = await getOrCreateConfig(provider);
// Verify configuration was created successfully
expect(config).toBeDefined();
expect(config.clientMetadata().client_secret).toBe('secret-123');
});
it('should handle manual configuration without client secret (public client)', async () => {
const provider: OidcProvider = {
id: 'manual-public-client',
name: 'Manual Public Client',
clientId: 'public-client-id',
// No client secret
issuer: 'https://manual.example.com',
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// Mock discovery to fail
validationService.performDiscovery = vi
.fn()
.mockRejectedValue(new Error('Discovery failed'));
// Access the private method
const getOrCreateConfig = async (provider: OidcProvider) => {
return (service as any).getOrCreateConfig(provider);
};
const config = await getOrCreateConfig(provider);
// Verify configuration was created successfully for public client
expect(config).toBeDefined();
expect(config.clientMetadata().client_secret).toBeUndefined();
});
it('should throw error when discovery fails and no manual endpoints provided', async () => {
const provider: OidcProvider = {
id: 'no-manual-endpoints',
name: 'No Manual Endpoints',
clientId: 'test-client-id',
issuer: 'https://broken.example.com',
// Missing authorizationEndpoint and tokenEndpoint
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// Mock discovery to fail
validationService.performDiscovery = vi
.fn()
.mockRejectedValue(new Error('Discovery failed'));
// Access the private method
const getOrCreateConfig = async (provider: OidcProvider) => {
return (service as any).getOrCreateConfig(provider);
};
await expect(getOrCreateConfig(provider)).rejects.toThrow(UnauthorizedException);
});
it('should throw error when only authorization endpoint is provided', async () => {
const provider: OidcProvider = {
id: 'partial-manual-endpoints',
name: 'Partial Manual Endpoints',
clientId: 'test-client-id',
issuer: 'https://broken.example.com',
authorizationEndpoint: 'https://manual.example.com/auth',
// Missing tokenEndpoint
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// Mock discovery to fail
validationService.performDiscovery = vi
.fn()
.mockRejectedValue(new Error('Discovery failed'));
// Access the private method
const getOrCreateConfig = async (provider: OidcProvider) => {
return (service as any).getOrCreateConfig(provider);
};
await expect(getOrCreateConfig(provider)).rejects.toThrow(UnauthorizedException);
});
it('should cache manual configuration properly', async () => {
const provider: OidcProvider = {
id: 'cache-test',
name: 'Cache Test',
clientId: 'test-client-id',
clientSecret: 'test-secret',
issuer: 'https://manual.example.com',
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// Mock discovery to fail
validationService.performDiscovery = vi
.fn()
.mockRejectedValue(new Error('Discovery failed'));
// Access the private method
const getOrCreateConfig = async (provider: OidcProvider) => {
return (service as any).getOrCreateConfig(provider);
};
// First call should create configuration
const config1 = await getOrCreateConfig(provider);
// Second call should return cached configuration
const config2 = await getOrCreateConfig(provider);
expect(config1).toBe(config2); // Should be the exact same instance
expect(validationService.performDiscovery).toHaveBeenCalledTimes(1); // Only called once due to caching
});
it('should handle HTTP endpoints with allowInsecureRequests', async () => {
const provider: OidcProvider = {
id: 'http-endpoints',
name: 'HTTP Endpoints',
clientId: 'test-client-id',
clientSecret: 'test-secret',
issuer: 'http://manual.example.com', // HTTP instead of HTTPS
authorizationEndpoint: 'http://manual.example.com/auth',
tokenEndpoint: 'http://manual.example.com/token',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// Mock discovery to fail
validationService.performDiscovery = vi
.fn()
.mockRejectedValue(new Error('Discovery failed'));
// Access the private method
const getOrCreateConfig = async (provider: OidcProvider) => {
return (service as any).getOrCreateConfig(provider);
};
const config = await getOrCreateConfig(provider);
// Verify configuration was created successfully even with HTTP
expect(config).toBeDefined();
expect(config.serverMetadata().token_endpoint).toBe('http://manual.example.com/token');
expect(config.serverMetadata().authorization_endpoint).toBe(
'http://manual.example.com/auth'
);
});
});
describe('getAuthorizationUrl', () => {
it('should generate authorization URL with custom authorization endpoint', async () => {
const provider: OidcProvider = {
id: 'test-provider',
name: 'Test Provider',
clientId: 'test-client-id',
issuer: 'https://example.com',
authorizationEndpoint: 'https://custom.example.com/auth',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
const authUrl = await service.getAuthorizationUrl(
'test-provider',
'test-state',
'localhost:3001'
);
expect(authUrl).toContain('https://custom.example.com/auth');
expect(authUrl).toContain('client_id=test-client-id');
expect(authUrl).toContain('response_type=code');
expect(authUrl).toContain('scope=openid+profile');
// State should start with provider ID followed by secure state token
expect(authUrl).toMatch(/state=test-provider%3A[a-f0-9]+\.[0-9]+\.[a-f0-9]+/);
expect(authUrl).toContain('redirect_uri=');
});
it('should encode provider ID in state parameter', async () => {
const provider: OidcProvider = {
id: 'encode-test-provider',
name: 'Encode Test Provider',
clientId: 'test-client-id',
issuer: 'https://example.com',
authorizationEndpoint: 'https://example.com/auth',
scopes: ['openid', 'email'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
const authUrl = await service.getAuthorizationUrl('encode-test-provider', 'original-state');
// Verify that the state parameter includes provider ID at the start
expect(authUrl).toMatch(/state=encode-test-provider%3A[a-f0-9]+\.[0-9]+\.[a-f0-9]+/);
});
it('should throw error when provider not found', async () => {
oidcConfig.getProvider.mockResolvedValue(null);
await expect(
service.getAuthorizationUrl('nonexistent-provider', 'test-state')
).rejects.toThrow('Provider nonexistent-provider not found');
});
it('should handle custom scopes properly', async () => {
const provider: OidcProvider = {
id: 'custom-scopes-provider',
name: 'Custom Scopes Provider',
clientId: 'test-client-id',
issuer: 'https://example.com',
authorizationEndpoint: 'https://example.com/auth',
scopes: ['openid', 'profile', 'groups', 'custom:scope'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
const authUrl = await service.getAuthorizationUrl('custom-scopes-provider', 'test-state');
expect(authUrl).toContain('scope=openid+profile+groups+custom%3Ascope');
});
});
describe('handleCallback', () => {
it('should throw error when provider not found in callback', async () => {
oidcConfig.getProvider.mockResolvedValue(null);
await expect(
service.handleCallback('nonexistent-provider', 'code', 'redirect-uri')
).rejects.toThrow('Provider nonexistent-provider not found');
});
it('should handle malformed state parameter', async () => {
await expect(
service.handleCallback('invalid-state', 'code', 'redirect-uri')
).rejects.toThrow(UnauthorizedException);
});
it('should call getProvider with the provided provider ID', async () => {
const provider: OidcProvider = {
id: 'test-provider',
name: 'Test Provider',
clientId: 'test-client-id',
issuer: 'https://example.com',
scopes: ['openid'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
// This will fail during token exchange, but we're testing the provider lookup logic
await expect(
service.handleCallback('test-provider', 'code', 'redirect-uri')
).rejects.toThrow(UnauthorizedException);
// Verify the provider was looked up with the correct ID
expect(oidcConfig.getProvider).toHaveBeenCalledWith('test-provider');
});
});
describe('validateProvider', () => {
it('should delegate to validation service and return result', async () => {
const provider: OidcProvider = {
id: 'validate-provider',
name: 'Validate Provider',
clientId: 'test-client-id',
issuer: 'https://example.com',
scopes: ['openid'],
authorizationRules: [],
};
const expectedResult = {
isValid: true,
authorizationEndpoint: 'https://example.com/auth',
tokenEndpoint: 'https://example.com/token',
};
validationService.validateProvider.mockResolvedValue(expectedResult);
const result = await service.validateProvider(provider);
expect(result).toEqual(expectedResult);
expect(validationService.validateProvider).toHaveBeenCalledWith(provider);
});
it('should clear config cache before validation', async () => {
const provider: OidcProvider = {
id: 'cache-clear-provider',
name: 'Cache Clear Provider',
clientId: 'test-client-id',
issuer: 'https://example.com',
scopes: ['openid'],
authorizationRules: [],
};
const expectedResult = {
isValid: false,
error: 'Validation failed',
};
validationService.validateProvider.mockResolvedValue(expectedResult);
const result = await service.validateProvider(provider);
expect(result).toEqual(expectedResult);
// Verify the cache was cleared by checking the method was called
expect(validationService.validateProvider).toHaveBeenCalledWith(provider);
});
});
describe('getRedirectUri (private method)', () => {
it('should generate correct redirect URI with localhost (development)', () => {
const getRedirectUri = (service as any).getRedirectUri.bind(service);
const redirectUri = getRedirectUri('http://localhost:3000');
expect(redirectUri).toBe('http://localhost:3000/graphql/api/auth/oidc/callback');
});
it('should generate correct redirect URI with non-localhost host', () => {
const getRedirectUri = (service as any).getRedirectUri.bind(service);
const redirectUri = getRedirectUri('https://example.com');
expect(redirectUri).toBe('https://example.com/graphql/api/auth/oidc/callback');
});
it('should handle HTTP protocol for non-localhost hosts', () => {
const getRedirectUri = (service as any).getRedirectUri.bind(service);
const redirectUri = getRedirectUri('http://tower.local');
expect(redirectUri).toBe('http://tower.local/graphql/api/auth/oidc/callback');
});
it('should handle non-standard ports correctly', () => {
const getRedirectUri = (service as any).getRedirectUri.bind(service);
const redirectUri = getRedirectUri('http://example.com:8080');
expect(redirectUri).toBe('http://example.com:8080/graphql/api/auth/oidc/callback');
});
it('should use default redirect URI when no request host provided', () => {
const getRedirectUri = (service as any).getRedirectUri.bind(service);
// Mock the ConfigService to return a default value
configService.get.mockReturnValue('http://tower.local');
const redirectUri = getRedirectUri();
expect(redirectUri).toBe('http://tower.local/graphql/api/auth/oidc/callback');
});
});
});

View File

@@ -0,0 +1,170 @@
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
import {
AuthorizationOperator,
AuthorizationRuleMode,
OidcAuthorizationRule,
OidcProvider,
} from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
interface JwtClaims {
sub?: string;
email?: string;
name?: string;
hd?: string; // Google hosted domain
[claim: string]: unknown;
}
@Injectable()
export class OidcAuthorizationService {
private readonly logger = new Logger(OidcAuthorizationService.name);
/**
* Check authorization based on rules
* This will throw a helpful error if misconfigured or unauthorized
*/
async checkAuthorization(provider: OidcProvider, claims: JwtClaims): Promise<void> {
this.logger.debug(
`Checking authorization for provider ${provider.id} with ${provider.authorizationRules?.length || 0} rules`
);
this.logger.debug(`Available claims: ${Object.keys(claims).join(', ')}`);
this.logger.debug(
`Authorization rule mode: ${provider.authorizationRuleMode || AuthorizationRuleMode.OR}`
);
// If no authorization rules are specified, throw a helpful error
if (!provider.authorizationRules || provider.authorizationRules.length === 0) {
throw new UnauthorizedException(
`Login failed: The ${provider.name} provider has no authorization rules configured. ` +
`Please configure authorization rules.`
);
}
this.logger.debug('Authorization rules to evaluate: %o', provider.authorizationRules);
// Evaluate the rules
const ruleMode = provider.authorizationRuleMode || AuthorizationRuleMode.OR;
const isAuthorized = this.evaluateAuthorizationRules(
provider.authorizationRules,
claims,
ruleMode
);
this.logger.debug(`Authorization result: ${isAuthorized}`);
if (!isAuthorized) {
// Log authorization failure with safe claim representation (no PII)
const availableClaimKeys = Object.keys(claims).join(', ');
this.logger.warn(
`Authorization failed for provider ${provider.name}, user ${claims.sub}, available claim keys: [${availableClaimKeys}]`
);
throw new UnauthorizedException(
`Access denied: Your account does not meet the authorization requirements for ${provider.name}.`
);
}
this.logger.debug(`Authorization successful for user ${claims.sub}`);
}
private evaluateAuthorizationRules(
rules: OidcAuthorizationRule[],
claims: JwtClaims,
mode: AuthorizationRuleMode = AuthorizationRuleMode.OR
): boolean {
// No rules means no authorization
if (rules.length === 0) {
return false;
}
if (mode === AuthorizationRuleMode.AND) {
// All rules must pass (AND logic)
return rules.every((rule) => this.evaluateRule(rule, claims));
} else {
// Any rule can pass (OR logic) - default behavior
// Multiple rules act as alternative authorization paths
return rules.some((rule) => this.evaluateRule(rule, claims));
}
}
private evaluateRule(rule: OidcAuthorizationRule, claims: JwtClaims): boolean {
const claimValue = claims[rule.claim];
this.logger.verbose(
`Evaluating rule for claim ${rule.claim}: { claimType: ${typeof claimValue}, isArray: ${Array.isArray(claimValue)}, ruleOperator: ${rule.operator}, ruleValuesCount: ${rule.value.length} }`
);
if (claimValue === undefined || claimValue === null) {
this.logger.verbose(`Claim ${rule.claim} not found in token`);
return false;
}
// Handle non-array, non-string objects
if (typeof claimValue === 'object' && claimValue !== null && !Array.isArray(claimValue)) {
this.logger.warn(
`unexpected JWT claim value encountered - claim ${rule.claim} has unsupported object type (keys: [${Object.keys(claimValue as Record<string, unknown>).join(', ')}])`
);
return false;
}
// Handle array claims - evaluate rule against each array element
if (Array.isArray(claimValue)) {
this.logger.verbose(
`Processing array claim ${rule.claim} with ${claimValue.length} elements`
);
// For array claims, check if ANY element in the array matches the rule
const arrayResult = claimValue.some((element) => {
// Skip non-string elements
if (
typeof element !== 'string' &&
typeof element !== 'number' &&
typeof element !== 'boolean'
) {
this.logger.verbose(`Skipping non-primitive element in array: ${typeof element}`);
return false;
}
const elementValue = String(element);
return this.evaluateSingleValue(elementValue, rule);
});
this.logger.verbose(`Array evaluation result for claim ${rule.claim}: ${arrayResult}`);
return arrayResult;
}
// Handle single value claims (string, number, boolean)
const value = String(claimValue);
this.logger.verbose(`Processing single value claim ${rule.claim}`);
return this.evaluateSingleValue(value, rule);
}
private evaluateSingleValue(value: string, rule: OidcAuthorizationRule): boolean {
let result: boolean;
switch (rule.operator) {
case AuthorizationOperator.EQUALS:
result = rule.value.some((v) => value === v);
this.logger.verbose(`EQUALS check: evaluated for claim ${rule.claim}: ${result}`);
return result;
case AuthorizationOperator.CONTAINS:
result = rule.value.some((v) => value.includes(v));
this.logger.verbose(`CONTAINS check: evaluated for claim ${rule.claim}: ${result}`);
return result;
case AuthorizationOperator.STARTS_WITH:
result = rule.value.some((v) => value.startsWith(v));
this.logger.verbose(`STARTS_WITH check: evaluated for claim ${rule.claim}: ${result}`);
return result;
case AuthorizationOperator.ENDS_WITH:
result = rule.value.some((v) => value.endsWith(v));
this.logger.verbose(`ENDS_WITH check: evaluated for claim ${rule.claim}: ${result}`);
return result;
default:
this.logger.error(`Unknown authorization operator: ${rule.operator}`);
return false;
}
}
}

View File

@@ -0,0 +1,218 @@
import { UnauthorizedException } from '@nestjs/common';
import { Test, TestingModule } from '@nestjs/testing';
import { decodeJwt } from 'jose';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import {
JwtClaims,
OidcClaimsService,
} from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
// Mock jose
vi.mock('jose', () => ({
decodeJwt: vi.fn(),
}));
describe('OidcClaimsService', () => {
let service: OidcClaimsService;
beforeEach(async () => {
vi.clearAllMocks();
const module: TestingModule = await Test.createTestingModule({
providers: [OidcClaimsService],
}).compile();
service = module.get<OidcClaimsService>(OidcClaimsService);
});
describe('parseIdToken', () => {
it('should parse valid ID token', () => {
const mockClaims: JwtClaims = {
sub: 'user123',
email: 'user@example.com',
name: 'Test User',
iat: 1234567890,
exp: 1234567890,
};
(decodeJwt as any).mockReturnValue(mockClaims);
const result = service.parseIdToken('valid.jwt.token');
expect(result).toEqual(mockClaims);
expect(decodeJwt).toHaveBeenCalledWith('valid.jwt.token');
});
it('should return null when no token provided', () => {
const result = service.parseIdToken(undefined);
expect(result).toBeNull();
});
it('should return null when token parsing fails', () => {
(decodeJwt as any).mockImplementation(() => {
throw new Error('Invalid token');
});
const result = service.parseIdToken('invalid.token');
expect(result).toBeNull();
});
it('should handle claims with array values', () => {
const mockClaims: JwtClaims = {
sub: 'user123',
groups: ['admin', 'user'],
roles: ['role1', 'role2', 'role3'],
};
(decodeJwt as any).mockReturnValue(mockClaims);
const result = service.parseIdToken('token.with.arrays');
expect(result).toEqual(mockClaims);
});
it('should log warning for complex object claims', () => {
const loggerSpy = vi.spyOn(service['logger'], 'warn');
const mockClaims: JwtClaims = {
sub: 'user123',
complexClaim: {
nested: 'value',
another: 'field',
},
};
(decodeJwt as any).mockReturnValue(mockClaims);
service.parseIdToken('token.with.complex');
expect(loggerSpy).toHaveBeenCalledWith(expect.stringContaining('complex object structure'));
});
it('should handle Google-specific claims', () => {
const mockClaims: JwtClaims = {
sub: 'google-user-id',
email: 'user@company.com',
name: 'Google User',
hd: 'company.com', // Google hosted domain
};
(decodeJwt as any).mockReturnValue(mockClaims);
const result = service.parseIdToken('google.jwt.token');
expect(result).toEqual(mockClaims);
expect(result?.hd).toBe('company.com');
});
});
describe('validateClaims', () => {
it('should return user sub when claims are valid', () => {
const claims: JwtClaims = {
sub: 'user123',
email: 'user@example.com',
};
const result = service.validateClaims(claims);
expect(result).toBe('user123');
});
it('should throw UnauthorizedException when claims are null', () => {
expect(() => service.validateClaims(null)).toThrow(UnauthorizedException);
});
it('should throw UnauthorizedException when sub is missing', () => {
const claims: JwtClaims = {
email: 'user@example.com',
name: 'User',
};
expect(() => service.validateClaims(claims)).toThrow(UnauthorizedException);
});
it('should throw UnauthorizedException when sub is empty', () => {
const claims: JwtClaims = {
sub: '',
email: 'user@example.com',
};
expect(() => service.validateClaims(claims)).toThrow(UnauthorizedException);
});
});
describe('extractUserInfo', () => {
it('should extract basic user information', () => {
const claims: JwtClaims = {
sub: 'user123',
email: 'user@example.com',
name: 'Test User',
};
const result = service.extractUserInfo(claims);
expect(result).toEqual({
sub: 'user123',
email: 'user@example.com',
name: 'Test User',
domain: undefined,
});
});
it('should extract Google hosted domain', () => {
const claims: JwtClaims = {
sub: 'google-user',
email: 'user@company.com',
name: 'Google User',
hd: 'company.com',
};
const result = service.extractUserInfo(claims);
expect(result).toEqual({
sub: 'google-user',
email: 'user@company.com',
name: 'Google User',
domain: 'company.com',
});
});
it('should handle missing optional fields', () => {
const claims: JwtClaims = {
sub: 'user123',
};
const result = service.extractUserInfo(claims);
expect(result).toEqual({
sub: 'user123',
email: undefined,
name: undefined,
domain: undefined,
});
});
it('should ignore extra claims', () => {
const claims: JwtClaims = {
sub: 'user123',
email: 'user@example.com',
name: 'Test User',
extra: 'claim',
another: 'field',
groups: ['admin'],
};
const result = service.extractUserInfo(claims);
expect(result).toEqual({
sub: 'user123',
email: 'user@example.com',
name: 'Test User',
domain: undefined,
});
expect(result).not.toHaveProperty('extra');
expect(result).not.toHaveProperty('groups');
});
});
});

View File

@@ -0,0 +1,80 @@
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
import { decodeJwt } from 'jose';
export interface JwtClaims {
sub?: string;
email?: string;
name?: string;
hd?: string; // Google hosted domain
[claim: string]: unknown;
}
@Injectable()
export class OidcClaimsService {
private readonly logger = new Logger(OidcClaimsService.name);
parseIdToken(idToken: string | undefined): JwtClaims | null {
if (!idToken) {
this.logger.error('No ID token received from provider');
return null;
}
try {
// Use jose to properly decode the JWT
const claims = decodeJwt(idToken) as JwtClaims;
// Log claims safely without PII - only structure, not values
if (claims) {
const claimKeys = Object.keys(claims).join(', ');
this.logger.debug(`ID token decoded successfully. Available claims: [${claimKeys}]`);
// Log claim types without exposing sensitive values
for (const [key, value] of Object.entries(claims)) {
const valueType = Array.isArray(value) ? `array[${value.length}]` : typeof value;
// Only log structure, not actual values (avoid PII)
this.logger.debug(`Claim '${key}': type=${valueType}`);
// Check for unexpected claim types
if (valueType === 'object' && value !== null && !Array.isArray(value)) {
this.logger.warn(`Claim '${key}' contains complex object structure`);
}
}
}
return claims;
} catch (e) {
this.logger.warn(`Failed to parse ID token: ${e}`);
return null;
}
}
validateClaims(claims: JwtClaims | null): string {
if (!claims?.sub) {
this.logger.error(
'No subject in token - claims available: ' +
(claims ? Object.keys(claims).join(', ') : 'none')
);
throw new UnauthorizedException('No subject in token');
}
const userSub = claims.sub;
this.logger.debug(`Processing authentication for user: ${userSub}`);
return userSub;
}
extractUserInfo(claims: JwtClaims): {
sub: string;
email?: string;
name?: string;
domain?: string;
} {
return {
sub: claims.sub!,
email: claims.email,
name: claims.name,
domain: claims.hd,
};
}
}

View File

@@ -0,0 +1,224 @@
import { Logger } from '@nestjs/common';
import * as client from 'openid-client';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
vi.mock('openid-client', () => ({
authorizationCodeGrant: vi.fn(),
allowInsecureRequests: vi.fn(),
}));
describe('OidcTokenExchangeService', () => {
let service: OidcTokenExchangeService;
let mockConfig: client.Configuration;
let mockProvider: OidcProvider;
beforeEach(() => {
service = new OidcTokenExchangeService();
mockConfig = {
serverMetadata: vi.fn().mockReturnValue({
issuer: 'https://example.com',
token_endpoint: 'https://example.com/token',
response_types_supported: ['code'],
grant_types_supported: ['authorization_code'],
token_endpoint_auth_methods_supported: ['client_secret_post'],
}),
} as unknown as client.Configuration;
mockProvider = {
id: 'test-provider',
issuer: 'https://example.com',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
} as OidcProvider;
vi.clearAllMocks();
});
describe('exchangeCodeForTokens', () => {
it('should handle malformed fullCallbackUrl gracefully', async () => {
const code = 'test-code';
const state = 'test-state';
const redirectUri = 'https://example.com/callback';
const malformedUrl = 'not://a valid url';
const mockTokens = {
access_token: 'test-access-token',
id_token: 'test-id-token',
};
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
const loggerDebugSpy = vi.spyOn(Logger.prototype, 'debug').mockImplementation(() => {});
const result = await service.exchangeCodeForTokens(
mockConfig,
mockProvider,
code,
state,
redirectUri,
malformedUrl
);
expect(result).toEqual(mockTokens);
expect(loggerWarnSpy).toHaveBeenCalledWith(
expect.stringContaining('Failed to parse fullCallbackUrl'),
expect.any(Error)
);
expect(client.authorizationCodeGrant).toHaveBeenCalled();
});
it('should handle empty fullCallbackUrl without throwing', async () => {
const code = 'test-code';
const state = 'test-state';
const redirectUri = 'https://example.com/callback';
const mockTokens = {
access_token: 'test-access-token',
id_token: 'test-id-token',
};
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
const result = await service.exchangeCodeForTokens(
mockConfig,
mockProvider,
code,
state,
redirectUri,
''
);
expect(result).toEqual(mockTokens);
expect(loggerWarnSpy).not.toHaveBeenCalled();
expect(client.authorizationCodeGrant).toHaveBeenCalled();
});
it('should handle whitespace-only fullCallbackUrl without throwing', async () => {
const code = 'test-code';
const state = 'test-state';
const redirectUri = 'https://example.com/callback';
const mockTokens = {
access_token: 'test-access-token',
id_token: 'test-id-token',
};
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
const result = await service.exchangeCodeForTokens(
mockConfig,
mockProvider,
code,
state,
redirectUri,
' '
);
expect(result).toEqual(mockTokens);
expect(loggerWarnSpy).not.toHaveBeenCalled();
expect(client.authorizationCodeGrant).toHaveBeenCalled();
});
it('should copy parameters from valid fullCallbackUrl', async () => {
const code = 'test-code';
const state = 'test-state';
const redirectUri = 'https://example.com/callback';
const fullCallbackUrl =
'https://example.com/callback?code=test-code&state=test-state&scope=openid&authuser=0';
const mockTokens = {
access_token: 'test-access-token',
id_token: 'test-id-token',
};
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
const loggerDebugSpy = vi.spyOn(Logger.prototype, 'debug').mockImplementation(() => {});
const result = await service.exchangeCodeForTokens(
mockConfig,
mockProvider,
code,
state,
redirectUri,
fullCallbackUrl
);
expect(result).toEqual(mockTokens);
expect(loggerWarnSpy).not.toHaveBeenCalled();
const authCodeGrantCall = vi.mocked(client.authorizationCodeGrant).mock.calls[0];
const cleanUrl = authCodeGrantCall[1] as URL;
expect(cleanUrl.searchParams.get('scope')).toBe('openid');
expect(cleanUrl.searchParams.get('authuser')).toBe('0');
});
it('should handle undefined fullCallbackUrl', async () => {
const code = 'test-code';
const state = 'test-state';
const redirectUri = 'https://example.com/callback';
const mockTokens = {
access_token: 'test-access-token',
id_token: 'test-id-token',
};
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
const result = await service.exchangeCodeForTokens(
mockConfig,
mockProvider,
code,
state,
redirectUri,
undefined
);
expect(result).toEqual(mockTokens);
expect(loggerWarnSpy).not.toHaveBeenCalled();
expect(client.authorizationCodeGrant).toHaveBeenCalled();
});
it('should handle non-string fullCallbackUrl types gracefully', async () => {
const code = 'test-code';
const state = 'test-state';
const redirectUri = 'https://example.com/callback';
const mockTokens = {
access_token: 'test-access-token',
id_token: 'test-id-token',
};
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
const result = await service.exchangeCodeForTokens(
mockConfig,
mockProvider,
code,
state,
redirectUri,
123 as any
);
expect(result).toEqual(mockTokens);
expect(loggerWarnSpy).not.toHaveBeenCalled();
expect(client.authorizationCodeGrant).toHaveBeenCalled();
});
});
});

View File

@@ -0,0 +1,174 @@
import { Injectable, Logger } from '@nestjs/common';
import * as client from 'openid-client';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
import { ErrorExtractor } from '@app/unraid-api/utils/error-extractor.util.js';
// Extended type for our internal use - openid-client v6 doesn't directly expose
// skip options for aud/iss checks, so we'll handle validation errors differently
type ExtendedGrantChecks = client.AuthorizationCodeGrantChecks;
@Injectable()
export class OidcTokenExchangeService {
private readonly logger = new Logger(OidcTokenExchangeService.name);
async exchangeCodeForTokens(
config: client.Configuration,
provider: OidcProvider,
code: string,
state: string,
redirectUri: string,
fullCallbackUrl?: string
): Promise<client.TokenEndpointResponse> {
this.logger.debug(`Provider ${provider.id} config loaded`);
this.logger.debug(`Redirect URI: ${redirectUri}`);
// Build current URL for token exchange
// CRITICAL: The URL used here MUST match the redirect_uri that was sent to the authorization endpoint
// Google expects the exact same redirect_uri during token exchange
const currentUrl = new URL(redirectUri);
currentUrl.searchParams.set('code', code);
currentUrl.searchParams.set('state', state);
// Copy additional parameters from the actual callback if provided
if (fullCallbackUrl && typeof fullCallbackUrl === 'string' && fullCallbackUrl.trim()) {
try {
const actualUrl = new URL(fullCallbackUrl);
// Copy over additional params that Google might have added (scope, authuser, prompt, etc)
// but DO NOT change the base URL or path
['scope', 'authuser', 'prompt', 'hd', 'session_state', 'iss'].forEach((param) => {
const value = actualUrl.searchParams.get(param);
if (value && !currentUrl.searchParams.has(param)) {
currentUrl.searchParams.set(param, value);
}
});
} catch (urlError) {
this.logger.warn(`Failed to parse fullCallbackUrl: ${fullCallbackUrl}`, urlError);
// Continue with the existing currentUrl flow without additional params
}
}
// Google returns iss in the response, openid-client v6 expects it
// If not present, add it based on the provider's issuer
if (!currentUrl.searchParams.has('iss') && provider.issuer) {
currentUrl.searchParams.set('iss', provider.issuer);
}
this.logger.debug(`Token exchange URL (matches redirect_uri): ${currentUrl.href}`);
// For openid-client v6, we need to prepare the authorization response
const authorizationResponse = new URLSearchParams(currentUrl.search);
// Set the original client state for openid-client
authorizationResponse.set('state', state);
// Create a new URL with the cleaned parameters
const cleanUrl = new URL(redirectUri);
cleanUrl.search = authorizationResponse.toString();
this.logger.debug(`Clean URL for token exchange: ${cleanUrl.href}`);
try {
this.logger.debug(`Starting token exchange with openid-client`);
this.logger.debug(`Config issuer: ${config.serverMetadata().issuer}`);
this.logger.debug(`Config token endpoint: ${config.serverMetadata().token_endpoint}`);
// Log the complete token exchange request details
const tokenEndpoint = config.serverMetadata().token_endpoint;
this.logger.debug(`Full token endpoint URL: ${tokenEndpoint}`);
this.logger.debug(`Authorization code: ${code.substring(0, 10)}...`);
this.logger.debug(`Redirect URI in token request: ${redirectUri}`);
this.logger.debug(`Client ID: ${provider.clientId}`);
this.logger.debug(`Client secret configured: ${provider.clientSecret ? 'Yes' : 'No'}`);
this.logger.debug(`Expected state value: ${state}`);
// Log the server metadata to check for any configuration issues
const metadata = config.serverMetadata();
this.logger.debug(
`Server supports response types: ${metadata.response_types_supported?.join(', ') || 'not specified'}`
);
this.logger.debug(
`Server grant types: ${metadata.grant_types_supported?.join(', ') || 'not specified'}`
);
this.logger.debug(
`Token endpoint auth methods: ${metadata.token_endpoint_auth_methods_supported?.join(', ') || 'not specified'}`
);
// For HTTP endpoints, we need to call allowInsecureRequests on the config
if (provider.issuer) {
try {
const serverUrl = new URL(provider.issuer);
if (serverUrl.protocol === 'http:') {
this.logger.debug(
`Allowing insecure requests for HTTP endpoint: ${provider.id}`
);
// allowInsecureRequests is deprecated but still needed for HTTP endpoints
client.allowInsecureRequests(config);
}
} catch (error) {
this.logger.warn(
`Invalid issuer URL for provider ${provider.id}: ${provider.issuer}`
);
// Continue without special HTTP options
}
}
const requestChecks: ExtendedGrantChecks = {
expectedState: state,
};
// Log what we're about to send
this.logger.debug(`Executing authorizationCodeGrant with:`);
this.logger.debug(`- Clean URL: ${cleanUrl.href}`);
this.logger.debug(`- Expected state: ${state}`);
this.logger.debug(`- Grant type: authorization_code`);
const tokens = await client.authorizationCodeGrant(config, cleanUrl, requestChecks);
this.logger.debug(
`Token exchange successful, received tokens: ${Object.keys(tokens).join(', ')}`
);
return tokens;
} catch (tokenError) {
// Extract and log error details using the utility
const extracted = ErrorExtractor.extract(tokenError);
this.logger.error('Token exchange failed');
ErrorExtractor.formatForLogging(extracted, this.logger);
// Special handling for content-type and parsing errors
if (ErrorExtractor.isOAuthResponseError(extracted)) {
this.logger.error('Token endpoint returned invalid or non-JSON response.');
this.logger.error('This typically means:');
this.logger.error(
'1. The token endpoint URL is incorrect (check for typos or wrong paths)'
);
this.logger.error('2. The server returned an HTML error page instead of JSON');
this.logger.error('3. Authentication failed (invalid client_id or client_secret)');
this.logger.error('4. A proxy/firewall is intercepting the request');
this.logger.error('5. The OAuth server returned malformed JSON');
this.logger.error(
`Configured token endpoint: ${config.serverMetadata().token_endpoint}`
);
this.logger.error('Please verify your OIDC provider configuration.');
}
// Check if error message contains the "unexpected JWT claim" text
if (ErrorExtractor.isJwtClaimError(extracted)) {
this.logger.error(
`unexpected JWT claim value encountered during token validation by openid-client`
);
this.logger.error(
`This error typically means the 'iss' claim in the JWT doesn't match the expected issuer`
);
this.logger.error(`Check that your provider's issuer URL is configured correctly`);
this.logger.error(`Expected issuer: ${config.serverMetadata().issuer}`);
this.logger.error(`Provider configured issuer: ${provider.issuer}`);
}
// Re-throw the original error with all its properties intact
throw tokenError;
}
}
}

View File

@@ -0,0 +1,267 @@
import { Test, TestingModule } from '@nestjs/testing';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
describe('OidcClientConfigService', () => {
let service: OidcClientConfigService;
let validationService: any;
beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
providers: [
OidcClientConfigService,
{
provide: OidcValidationService,
useValue: {
performDiscovery: vi.fn(),
},
},
],
}).compile();
service = module.get<OidcClientConfigService>(OidcClientConfigService);
validationService = module.get(OidcValidationService);
});
describe('Manual Configuration', () => {
it('should create manual configuration when discovery fails but manual endpoints are provided', async () => {
const provider: OidcProvider = {
id: 'manual-provider',
name: 'Manual Provider',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
issuer: 'https://manual.example.com',
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
jwksUri: 'https://manual.example.com/jwks',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
// Mock discovery to fail
validationService.performDiscovery.mockRejectedValue(new Error('Discovery failed'));
const config = await service.getOrCreateConfig(provider);
// Verify the configuration was created with the correct endpoints
expect(config).toBeDefined();
expect(config.serverMetadata().authorization_endpoint).toBe(
'https://manual.example.com/auth'
);
expect(config.serverMetadata().token_endpoint).toBe('https://manual.example.com/token');
expect(config.serverMetadata().jwks_uri).toBe('https://manual.example.com/jwks');
expect(config.serverMetadata().issuer).toBe('https://manual.example.com');
});
it('should create manual configuration with fallback issuer when not provided', async () => {
const provider: OidcProvider = {
id: 'manual-provider-no-issuer',
name: 'Manual Provider No Issuer',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
issuer: '', // Empty issuer should skip discovery and use manual endpoints
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
const config = await service.getOrCreateConfig(provider);
// Verify the configuration was created with inferred issuer from endpoints
expect(config).toBeDefined();
expect(config.serverMetadata().issuer).toBe('https://manual.example.com');
expect(config.serverMetadata().authorization_endpoint).toBe(
'https://manual.example.com/auth'
);
expect(config.serverMetadata().token_endpoint).toBe('https://manual.example.com/token');
});
it('should handle manual configuration with client secret properly', async () => {
const provider: OidcProvider = {
id: 'manual-with-secret',
name: 'Manual With Secret',
clientId: 'test-client-id',
clientSecret: 'secret-123',
issuer: 'https://manual.example.com',
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
// Mock discovery to fail
validationService.performDiscovery.mockRejectedValue(new Error('Discovery failed'));
const config = await service.getOrCreateConfig(provider);
// Verify configuration was created successfully
expect(config).toBeDefined();
expect(config.clientMetadata().client_secret).toBe('secret-123');
});
it('should handle manual configuration without client secret (public client)', async () => {
const provider: OidcProvider = {
id: 'manual-public-client',
name: 'Manual Public Client',
clientId: 'public-client-id',
// No client secret
issuer: 'https://manual.example.com',
authorizationEndpoint: 'https://manual.example.com/auth',
tokenEndpoint: 'https://manual.example.com/token',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
// Mock discovery to fail
validationService.performDiscovery.mockRejectedValue(new Error('Discovery failed'));
const config = await service.getOrCreateConfig(provider);
// Verify configuration was created successfully
expect(config).toBeDefined();
expect(config.clientMetadata().client_secret).toBeUndefined();
});
it('should cache configurations', async () => {
const provider: OidcProvider = {
id: 'cached-provider',
name: 'Cached Provider',
clientId: 'test-client-id',
issuer: '',
authorizationEndpoint: 'https://cached.example.com/auth',
tokenEndpoint: 'https://cached.example.com/token',
scopes: ['openid'],
authorizationRules: [],
};
// First call
const config1 = await service.getOrCreateConfig(provider);
// Second call - should return cached value
const config2 = await service.getOrCreateConfig(provider);
// Should be the exact same object
expect(config1).toBe(config2);
expect(service.getCacheSize()).toBe(1);
});
it('should clear cache for specific provider', async () => {
const provider: OidcProvider = {
id: 'provider-to-clear',
name: 'Provider to Clear',
clientId: 'test-client-id',
issuer: '',
authorizationEndpoint: 'https://clear.example.com/auth',
tokenEndpoint: 'https://clear.example.com/token',
scopes: ['openid'],
authorizationRules: [],
};
await service.getOrCreateConfig(provider);
expect(service.getCacheSize()).toBe(1);
service.clearCache('provider-to-clear');
expect(service.getCacheSize()).toBe(0);
});
it('should clear entire cache', async () => {
const provider1: OidcProvider = {
id: 'provider1',
name: 'Provider 1',
clientId: 'client1',
issuer: '',
authorizationEndpoint: 'https://p1.example.com/auth',
tokenEndpoint: 'https://p1.example.com/token',
scopes: ['openid'],
authorizationRules: [],
};
const provider2: OidcProvider = {
id: 'provider2',
name: 'Provider 2',
clientId: 'client2',
issuer: '',
authorizationEndpoint: 'https://p2.example.com/auth',
tokenEndpoint: 'https://p2.example.com/token',
scopes: ['openid'],
authorizationRules: [],
};
await service.getOrCreateConfig(provider1);
await service.getOrCreateConfig(provider2);
expect(service.getCacheSize()).toBe(2);
service.clearCache();
expect(service.getCacheSize()).toBe(0);
});
});
describe('Discovery Configuration', () => {
it('should use discovery when issuer is provided', async () => {
const provider: OidcProvider = {
id: 'discovery-provider',
name: 'Discovery Provider',
clientId: 'test-client-id',
clientSecret: 'test-secret',
issuer: 'https://discovery.example.com',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
const mockConfig = {
serverMetadata: vi.fn().mockReturnValue({
issuer: 'https://discovery.example.com',
authorization_endpoint: 'https://discovery.example.com/authorize',
token_endpoint: 'https://discovery.example.com/token',
jwks_uri: 'https://discovery.example.com/.well-known/jwks.json',
userinfo_endpoint: 'https://discovery.example.com/userinfo',
}),
clientMetadata: vi.fn().mockReturnValue({}),
};
validationService.performDiscovery.mockResolvedValue(mockConfig);
const config = await service.getOrCreateConfig(provider);
expect(validationService.performDiscovery).toHaveBeenCalledWith(provider, undefined);
expect(config).toBe(mockConfig);
});
it('should allow HTTP for discovery when issuer uses HTTP', async () => {
const provider: OidcProvider = {
id: 'http-discovery-provider',
name: 'HTTP Discovery Provider',
clientId: 'test-client-id',
issuer: 'http://discovery.example.com',
scopes: ['openid'],
authorizationRules: [],
};
const mockConfig = {
serverMetadata: vi.fn().mockReturnValue({
issuer: 'http://discovery.example.com',
authorization_endpoint: 'http://discovery.example.com/authorize',
token_endpoint: 'http://discovery.example.com/token',
}),
clientMetadata: vi.fn().mockReturnValue({}),
};
validationService.performDiscovery.mockResolvedValue(mockConfig);
const config = await service.getOrCreateConfig(provider);
expect(validationService.performDiscovery).toHaveBeenCalledWith(
provider,
expect.objectContaining({
execute: expect.any(Array),
})
);
expect(config).toBe(mockConfig);
});
});
});

View File

@@ -0,0 +1,168 @@
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
import * as client from 'openid-client';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
import { ErrorExtractor } from '@app/unraid-api/utils/error-extractor.util.js';
@Injectable()
export class OidcClientConfigService {
private readonly logger = new Logger(OidcClientConfigService.name);
private readonly configCache = new Map<string, client.Configuration>();
constructor(private readonly validationService: OidcValidationService) {}
async getOrCreateConfig(provider: OidcProvider): Promise<client.Configuration> {
const cacheKey = provider.id;
if (this.configCache.has(cacheKey)) {
return this.configCache.get(cacheKey)!;
}
try {
// Use the validation service to perform discovery with HTTP support
if (provider.issuer) {
this.logger.debug(`Attempting discovery for ${provider.id} at ${provider.issuer}`);
// Create client options with HTTP support if needed
const serverUrl = new URL(provider.issuer);
let clientOptions: client.DiscoveryRequestOptions | undefined;
if (serverUrl.protocol === 'http:') {
this.logger.debug(`Allowing HTTP for ${provider.id} as specified by user`);
clientOptions = {
execute: [client.allowInsecureRequests],
};
}
try {
const config = await this.validationService.performDiscovery(
provider,
clientOptions
);
this.logger.debug(`Discovery successful for ${provider.id}`);
this.logger.debug(
`Authorization endpoint: ${config.serverMetadata().authorization_endpoint}`
);
this.logger.debug(`Token endpoint: ${config.serverMetadata().token_endpoint}`);
this.logger.debug(`JWKS URI: ${config.serverMetadata().jwks_uri || 'Not provided'}`);
this.logger.debug(
`Userinfo endpoint: ${config.serverMetadata().userinfo_endpoint || 'Not provided'}`
);
this.configCache.set(cacheKey, config);
return config;
} catch (discoveryError) {
const extracted = ErrorExtractor.extract(discoveryError);
this.logger.warn(`Discovery failed for ${provider.id}: ${extracted.message}`);
// Log more details about the discovery error
const discoveryUrl = `${provider.issuer}/.well-known/openid-configuration`;
this.logger.debug(`Discovery URL attempted: ${discoveryUrl}`);
// Use error extractor for consistent logging
ErrorExtractor.formatForLogging(extracted, this.logger);
// If discovery fails but we have manual endpoints, use them
if (provider.authorizationEndpoint && provider.tokenEndpoint) {
this.logger.log(`Using manual endpoints for ${provider.id}`);
return this.createManualConfiguration(provider, cacheKey);
} else {
throw new Error(
`OIDC discovery failed and no manual endpoints provided for ${provider.id}`
);
}
}
}
// Manual configuration when no issuer is provided
if (provider.authorizationEndpoint && provider.tokenEndpoint) {
this.logger.log(`Using manual endpoints for ${provider.id} (no issuer provided)`);
return this.createManualConfiguration(provider, cacheKey);
}
// If we reach here, neither discovery nor manual endpoints are available
throw new Error(
`No configuration method available for ${provider.id}: requires either valid issuer for discovery or manual endpoints`
);
} catch (error) {
const extracted = ErrorExtractor.extract(error);
this.logger.error(
`Failed to create OIDC configuration for ${provider.id}: ${extracted.message}`
);
// Log more details in debug mode
if (extracted.stack) {
this.logger.debug(`Stack trace: ${extracted.stack}`);
}
throw new UnauthorizedException('Provider configuration error');
}
}
private createManualConfiguration(provider: OidcProvider, cacheKey: string): client.Configuration {
// Create manual configuration with a valid issuer URL
const inferredIssuer =
provider.issuer && provider.issuer.trim() !== ''
? provider.issuer
: new URL(provider.authorizationEndpoint ?? provider.tokenEndpoint!).origin;
const serverMetadata: client.ServerMetadata = {
issuer: inferredIssuer,
authorization_endpoint: provider.authorizationEndpoint!,
token_endpoint: provider.tokenEndpoint!,
jwks_uri: provider.jwksUri,
};
const clientMetadata: Partial<client.ClientMetadata> = {
client_secret: provider.clientSecret,
};
// Configure client auth method
const clientAuth = provider.clientSecret
? client.ClientSecretPost(provider.clientSecret)
: client.None();
try {
const config = new client.Configuration(
serverMetadata,
provider.clientId,
clientMetadata,
clientAuth
);
// Allow HTTP if any configured endpoint uses http
const endpoints = [
serverMetadata.authorization_endpoint,
serverMetadata.token_endpoint,
].filter(Boolean) as string[];
const hasHttp = endpoints.some((e) => new URL(e).protocol === 'http:');
if (hasHttp) {
this.logger.debug(`Allowing HTTP for manual endpoints on ${provider.id}`);
// allowInsecureRequests is deprecated but still needed for HTTP endpoints
client.allowInsecureRequests(config);
}
this.logger.debug(`Manual configuration created for ${provider.id}`);
this.logger.debug(`Authorization endpoint: ${serverMetadata.authorization_endpoint}`);
this.logger.debug(`Token endpoint: ${serverMetadata.token_endpoint}`);
this.configCache.set(cacheKey, config);
return config;
} catch (manualConfigError) {
const extracted = ErrorExtractor.extract(manualConfigError);
this.logger.error(`Failed to create manual configuration: ${extracted.message}`);
throw new Error(`Manual configuration failed for ${provider.id}`);
}
}
clearCache(providerId?: string): void {
if (providerId) {
this.configCache.delete(providerId);
} else {
this.configCache.clear();
}
}
getCacheSize(): number {
return this.configCache.size;
}
}

View File

@@ -0,0 +1,12 @@
import { Module } from '@nestjs/common';
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
import { OidcBaseModule } from '@app/unraid-api/graph/resolvers/sso/core/oidc-base.module.js';
@Module({
imports: [OidcBaseModule],
providers: [OidcClientConfigService, OidcRedirectUriService],
exports: [OidcClientConfigService, OidcRedirectUriService],
})
export class OidcClientModule {}

View File

@@ -0,0 +1,222 @@
import { UnauthorizedException } from '@nestjs/common';
import { Test, TestingModule } from '@nestjs/testing';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { validateRedirectUri } from '@app/unraid-api/utils/redirect-uri-validator.js';
// Mock the redirect URI validator
vi.mock('@app/unraid-api/utils/redirect-uri-validator.js', () => ({
validateRedirectUri: vi.fn(),
}));
describe('OidcRedirectUriService', () => {
let service: OidcRedirectUriService;
let oidcConfig: any;
beforeEach(async () => {
vi.clearAllMocks();
const module: TestingModule = await Test.createTestingModule({
providers: [
OidcRedirectUriService,
{
provide: OidcConfigPersistence,
useValue: {
getConfig: vi.fn().mockResolvedValue({
providers: [],
defaultAllowedOrigins: ['https://allowed.example.com'],
}),
},
},
],
}).compile();
service = module.get<OidcRedirectUriService>(OidcRedirectUriService);
oidcConfig = module.get(OidcConfigPersistence);
});
describe('getRedirectUri', () => {
it('should return valid redirect URI when validation passes', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {
'x-forwarded-proto': 'https',
'x-forwarded-host': 'example.com',
};
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
});
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
'https://example.com',
'https',
'example.com',
expect.anything(),
['https://allowed.example.com']
);
});
it('should throw UnauthorizedException when validation fails', async () => {
const requestOrigin = 'https://evil.com';
const requestHeaders = {
'x-forwarded-proto': 'https',
'x-forwarded-host': 'example.com',
};
(validateRedirectUri as any).mockReturnValue({
isValid: false,
reason: 'Origin not allowed',
});
await expect(service.getRedirectUri(requestOrigin, requestHeaders)).rejects.toThrow(
UnauthorizedException
);
});
it('should handle missing allowed origins', async () => {
oidcConfig.getConfig.mockResolvedValue({
providers: [],
defaultAllowedOrigins: undefined,
});
const requestOrigin = 'https://example.com';
const requestHeaders = {
'x-forwarded-proto': 'https',
'x-forwarded-host': 'example.com',
};
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
});
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
'https://example.com',
'https',
'example.com',
expect.anything(),
undefined
);
});
it('should extract protocol from headers correctly', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {
'x-forwarded-proto': ['https', 'http'],
host: 'example.com',
};
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
});
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
'https://example.com',
'https', // Should use first value from array
'example.com',
expect.anything(),
expect.anything()
);
});
it('should use host header as fallback', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {
host: 'example.com',
};
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
});
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
'https://example.com',
'https', // Inferred from requestOrigin when x-forwarded-proto not present
'example.com',
expect.anything(),
expect.anything()
);
});
it('should prefer x-forwarded-host over host header', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {
'x-forwarded-host': 'forwarded.example.com',
host: 'original.example.com',
};
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
});
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
'https://example.com',
'https', // Inferred from requestOrigin when x-forwarded-proto not present
'forwarded.example.com', // Should use x-forwarded-host
expect.anything(),
expect.anything()
);
});
it('should throw when URL construction fails', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {};
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'invalid-url', // Invalid URL
});
await expect(service.getRedirectUri(requestOrigin, requestHeaders)).rejects.toThrow(
UnauthorizedException
);
});
it('should handle array values in headers correctly', async () => {
const requestOrigin = 'https://example.com';
const requestHeaders = {
'x-forwarded-proto': ['https'],
'x-forwarded-host': ['forwarded.example.com', 'another.example.com'],
host: ['original.example.com'],
};
(validateRedirectUri as any).mockReturnValue({
isValid: true,
validatedUri: 'https://example.com',
});
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
expect(validateRedirectUri).toHaveBeenCalledWith(
'https://example.com',
'https',
'forwarded.example.com', // Should use first value from array
expect.anything(),
expect.anything()
);
});
});
});

View File

@@ -0,0 +1,97 @@
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { validateRedirectUri } from '@app/unraid-api/utils/redirect-uri-validator.js';
@Injectable()
export class OidcRedirectUriService {
private readonly logger = new Logger(OidcRedirectUriService.name);
private readonly CALLBACK_PATH = '/graphql/api/auth/oidc/callback';
constructor(private readonly oidcConfig: OidcConfigPersistence) {}
async getRedirectUri(
requestOrigin: string,
requestHeaders: Record<string, string | string[] | undefined>
): Promise<string> {
// Extract protocol and host from headers for validation
const { protocol, host } = this.getRequestOriginInfo(requestHeaders, requestOrigin);
// Get the global allowed origins from OIDC config
const config = await this.oidcConfig.getConfig();
const allowedOrigins = config?.defaultAllowedOrigins;
// Debug logging to trace the issue
this.logger.debug(
`OIDC Config loaded: ${JSON.stringify(config ? { hasConfig: true, allowedOrigins } : { hasConfig: false })}`
);
this.logger.debug(
`Validating redirect URI: ${requestOrigin} against host: ${protocol}://${host}`
);
this.logger.debug(`Allowed origins from config: ${JSON.stringify(allowedOrigins || [])}`);
// Validate the provided requestOrigin using centralized validator
// Pass the global allowed origins if available
const validation = validateRedirectUri(
requestOrigin,
protocol,
host,
this.logger,
allowedOrigins
);
if (!validation.isValid) {
this.logger.warn(`Invalid redirect_uri in GraphQL OIDC flow: ${validation.reason}`);
throw new UnauthorizedException(
`Invalid redirect_uri: ${requestOrigin}. Please add this callback URI to Settings → Management Access → Allowed Redirect URIs`
);
}
// Ensure the validated URI has the correct callback path
try {
const url = new URL(validation.validatedUri);
// Only use origin to prevent path manipulation
const redirectUri = `${url.origin}${this.CALLBACK_PATH}`;
this.logger.debug(`Using validated redirect URI: ${redirectUri}`);
return redirectUri;
} catch (e) {
this.logger.error(
`Failed to construct redirect URI from validated URI: ${validation.validatedUri}`
);
throw new UnauthorizedException('Invalid redirect_uri');
}
}
private getRequestOriginInfo(
requestHeaders: Record<string, string | string[] | undefined>,
requestOrigin?: string
): {
protocol: string;
host: string | undefined;
} {
// Extract protocol from x-forwarded-proto or infer from requestOrigin, default to http
const forwardedProto = requestHeaders['x-forwarded-proto'];
const protocol = forwardedProto
? Array.isArray(forwardedProto)
? forwardedProto[0]
: forwardedProto
: requestOrigin?.startsWith('https')
? 'https'
: 'http';
// Extract host from x-forwarded-host or host header
const forwardedHost = requestHeaders['x-forwarded-host'];
const hostHeader = requestHeaders['host'];
const host = forwardedHost
? Array.isArray(forwardedHost)
? forwardedHost[0]
: forwardedHost
: hostHeader
? Array.isArray(hostHeader)
? hostHeader[0]
: hostHeader
: undefined;
return { protocol, host };
}
}

View File

@@ -0,0 +1,13 @@
import { Module } from '@nestjs/common';
import { UserSettingsModule } from '@unraid/shared/services/user-settings.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
@Module({
imports: [UserSettingsModule],
providers: [OidcConfigPersistence, OidcValidationService],
exports: [OidcConfigPersistence, OidcValidationService],
})
export class OidcBaseModule {}

View File

@@ -1,4 +1,4 @@
import { Injectable, Logger } from '@nestjs/common';
import { Injectable } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { RuleEffect } from '@jsonforms/core';
@@ -6,12 +6,12 @@ import { mergeSettingSlices } from '@unraid/shared/jsonforms/settings.js';
import { ConfigFilePersister } from '@unraid/shared/services/config-file.js';
import { UserSettingsService } from '@unraid/shared/services/user-settings.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
import {
AuthorizationOperator,
OidcAuthorizationRule,
OidcProvider,
} from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/oidc-validation.service.js';
} from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
import {
createAccordionLayout,
createLabeledControl,
@@ -21,6 +21,7 @@ import { SettingSlice } from '@app/unraid-api/types/json-forms.js';
export interface OidcConfig {
providers: OidcProvider[];
defaultAllowedOrigins?: string[];
}
@Injectable()
@@ -52,6 +53,7 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
defaultConfig(): OidcConfig {
return {
providers: [this.getUnraidNetSsoProvider()],
defaultAllowedOrigins: [],
};
}
@@ -93,6 +95,7 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
return {
providers: [unraidNetSsoProvider],
defaultAllowedOrigins: [],
};
}
@@ -119,6 +122,42 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
provider.authorizationRules || currentDefaults.authorizationRules,
};
}
// Fix dangerous authorization rules for non-unraid.net providers
if (provider.authorizationRules && provider.authorizationRules.length > 0) {
// Filter out dangerous rules that would allow all emails
const safeRules = provider.authorizationRules.filter((rule) => {
// Remove rules that have "email endsWith @" which allows all emails
if (
rule.claim === 'email' &&
rule.operator === AuthorizationOperator.ENDS_WITH &&
rule.value &&
rule.value.length === 1 &&
rule.value[0] === '@'
) {
this.logger.warn(
`Removing dangerous authorization rule from provider "${provider.name}": email endsWith "@" allows all emails`
);
return false;
}
// Remove rules with empty or invalid values
if (
!rule.value ||
rule.value.length === 0 ||
rule.value.every((v) => !v || !v.trim())
) {
this.logger.warn(
`Removing invalid authorization rule from provider "${provider.name}": empty values`
);
return false;
}
return true;
});
// Update provider with safe rules
provider.authorizationRules = safeRules;
}
return provider;
});
@@ -155,6 +194,28 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
provider.authorizationRules = rules;
}
// Validate that authorization rules are present and valid for ALL providers
if (!provider.authorizationRules || provider.authorizationRules.length === 0) {
throw new Error(
`Provider "${provider.name}" requires authorization rules. Please configure who can access your server.`
);
}
// Validate each rule has valid values
for (const rule of provider.authorizationRules) {
if (!rule.claim || !rule.claim.trim()) {
throw new Error(`Provider "${provider.name}": Authorization rule claim cannot be empty`);
}
if (!rule.operator) {
throw new Error(`Provider "${provider.name}": Authorization rule operator is required`);
}
if (!rule.value || rule.value.length === 0 || rule.value.every((v) => !v || !v.trim())) {
throw new Error(
`Provider "${provider.name}": Authorization rule for claim "${rule.claim}" must have at least one non-empty value`
);
}
}
// Clean up the provider object - remove UI-only fields
const cleanedProvider: OidcProvider = {
id: provider.id,
@@ -191,46 +252,52 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
allowedDomains?: string[];
allowedEmails?: string[];
allowedUserIds?: string[];
googleWorkspaceDomain?: string;
}): OidcAuthorizationRule[] {
const rules: OidcAuthorizationRule[] = [];
// Convert email domains to endsWith rules
// Only add if domains are provided AND not empty AND have non-empty values
if (simpleAuth?.allowedDomains && simpleAuth.allowedDomains.length > 0) {
rules.push({
claim: 'email',
operator: AuthorizationOperator.ENDS_WITH,
value: simpleAuth.allowedDomains.map((domain: string) =>
domain.startsWith('@') ? domain : `@${domain}`
),
});
const validDomains = simpleAuth.allowedDomains.filter(
(domain: string) => domain && domain.trim()
);
if (validDomains.length > 0) {
rules.push({
claim: 'email',
operator: AuthorizationOperator.ENDS_WITH,
value: validDomains.map((domain: string) =>
domain.startsWith('@') ? domain : `@${domain}`
),
});
}
}
// Convert specific emails to equals rules
// Only add if emails are provided AND not empty AND have non-empty values
if (simpleAuth?.allowedEmails && simpleAuth.allowedEmails.length > 0) {
rules.push({
claim: 'email',
operator: AuthorizationOperator.EQUALS,
value: simpleAuth.allowedEmails,
});
const validEmails = simpleAuth.allowedEmails.filter(
(email: string) => email && email.trim()
);
if (validEmails.length > 0) {
rules.push({
claim: 'email',
operator: AuthorizationOperator.EQUALS,
value: validEmails,
});
}
}
// Convert user IDs to sub equals rules
// Only add if user IDs are provided AND not empty AND have non-empty values
if (simpleAuth?.allowedUserIds && simpleAuth.allowedUserIds.length > 0) {
rules.push({
claim: 'sub',
operator: AuthorizationOperator.EQUALS,
value: simpleAuth.allowedUserIds,
});
}
// Google Workspace domain (hd claim)
if (simpleAuth?.googleWorkspaceDomain) {
rules.push({
claim: 'hd',
operator: AuthorizationOperator.EQUALS,
value: [simpleAuth.googleWorkspaceDomain],
});
const validUserIds = simpleAuth.allowedUserIds.filter((id: string) => id && id.trim());
if (validUserIds.length > 0) {
rules.push({
claim: 'sub',
operator: AuthorizationOperator.EQUALS,
value: validUserIds,
});
}
}
return rules;
@@ -286,7 +353,6 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
allowedDomains?: string[];
allowedEmails?: string[];
allowedUserIds?: string[];
googleWorkspaceDomain?: string;
}
);
// Return provider with generated rules, removing UI-only fields
@@ -304,6 +370,38 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
}),
};
// Validate authorization rules for ALL providers including unraid.net
for (const provider of processedConfig.providers) {
if (!provider.authorizationRules || provider.authorizationRules.length === 0) {
throw new Error(
`Provider "${provider.name}" requires authorization rules. Please configure who can access your server.`
);
}
// Validate each rule has valid values
for (const rule of provider.authorizationRules) {
if (!rule.claim || !rule.claim.trim()) {
throw new Error(
`Provider "${provider.name}": Authorization rule claim cannot be empty`
);
}
if (!rule.operator) {
throw new Error(
`Provider "${provider.name}": Authorization rule operator is required`
);
}
if (
!rule.value ||
rule.value.length === 0 ||
rule.value.every((v) => !v || !v.trim())
) {
throw new Error(
`Provider "${provider.name}": Authorization rule for claim "${rule.claim}" must have at least one non-empty value`
);
}
}
}
// Validate OIDC discovery for all providers with issuer URLs
const validationErrors: string[] = [];
for (const provider of processedConfig.providers) {
@@ -419,10 +517,6 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
if (rule.claim === 'sub' && rule.operator === AuthorizationOperator.EQUALS) {
return true;
}
// Google Workspace domain
if (rule.claim === 'hd' && rule.operator === AuthorizationOperator.EQUALS) {
return true;
}
return false;
});
}
@@ -431,13 +525,11 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
allowedDomains: string[];
allowedEmails: string[];
allowedUserIds: string[];
googleWorkspaceDomain?: string;
} {
const simpleAuth = {
allowedDomains: [] as string[],
allowedEmails: [] as string[],
allowedUserIds: [] as string[],
googleWorkspaceDomain: undefined as string | undefined,
};
rules.forEach((rule) => {
@@ -449,12 +541,6 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
simpleAuth.allowedEmails = rule.value;
} else if (rule.claim === 'sub' && rule.operator === AuthorizationOperator.EQUALS) {
simpleAuth.allowedUserIds = rule.value;
} else if (
rule.claim === 'hd' &&
rule.operator === AuthorizationOperator.EQUALS &&
rule.value.length > 0
) {
simpleAuth.googleWorkspaceDomain = rule.value[0];
}
});
@@ -462,7 +548,36 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
}
private buildSlice(): SettingSlice {
return mergeSettingSlices([this.oidcProvidersSlice()], { as: 'sso' });
const providersSlice = this.oidcProvidersSlice();
// Add defaultAllowedOrigins to the properties
providersSlice.properties.defaultAllowedOrigins = {
type: 'array',
items: { type: 'string' },
title: 'Default Allowed Redirect Origins',
default: [],
description:
'Additional trusted redirect origins to allow redirects from custom ports, reverse proxies, Tailscale, etc.',
};
// Add the control for defaultAllowedOrigins before the providers control using UnraidSettingsLayout
if (providersSlice.elements?.[0]?.elements) {
providersSlice.elements[0].elements.unshift(
createLabeledControl({
scope: '#/properties/sso/properties/defaultAllowedOrigins',
label: 'Allowed Redirect Origins',
description:
'Add trusted origins here when accessing Unraid through custom ports, reverse proxies, or Tailscale. Each origin should include the protocol and optionally a port (e.g., https://unraid.local:8443)',
controlOptions: {
format: 'array',
inputType: 'text',
placeholder: 'https://unraid.local:8443',
},
})
);
}
return mergeSettingSlices([providersSlice], { as: 'sso' });
}
private oidcProvidersSlice(): SettingSlice {
@@ -999,7 +1114,7 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
scope: '#/properties/claim',
label: 'JWT Claim:',
description:
'JWT claim to check (e.g., email, sub, groups, hd for Google hosted domain)',
'JWT claim to check (e.g., email, sub, groups)',
controlOptions: {
inputType: 'text',
placeholder: 'email',

View File

@@ -0,0 +1,14 @@
import { Module } from '@nestjs/common';
import { OidcAuthModule } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-auth.module.js';
import { OidcClientModule } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client.module.js';
import { OidcBaseModule } from '@app/unraid-api/graph/resolvers/sso/core/oidc-base.module.js';
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
import { OidcSessionModule } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.module.js';
@Module({
imports: [OidcBaseModule, OidcSessionModule, OidcAuthModule, OidcClientModule],
providers: [OidcService],
exports: [OidcService, OidcBaseModule, OidcSessionModule, OidcAuthModule, OidcClientModule],
})
export class OidcCoreModule {}

View File

@@ -0,0 +1,160 @@
import { Injectable, Logger } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import * as client from 'openid-client';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
import { OidcErrorHelper } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-error.helper.js';
@Injectable()
export class OidcValidationService {
private readonly logger = new Logger(OidcValidationService.name);
constructor(private readonly configService: ConfigService) {}
/**
* Validate OIDC provider configuration by attempting discovery
* Returns validation result with helpful error messages for debugging
*/
async validateProvider(
provider: OidcProvider
): Promise<{ isValid: boolean; error?: string; details?: unknown }> {
try {
// Validate issuer URL is present
if (!provider.issuer) {
return {
isValid: false,
error: 'No issuer URL provided. Please specify the OIDC provider issuer URL.',
details: { type: 'MISSING_ISSUER' },
};
}
// Validate issuer URL is valid
let serverUrl: URL;
try {
serverUrl = new URL(provider.issuer);
} catch (urlError) {
return {
isValid: false,
error: `Invalid issuer URL format: '${provider.issuer}'. Please provide a valid URL.`,
details: {
type: 'INVALID_URL',
originalError: urlError instanceof Error ? urlError.message : String(urlError),
},
};
}
// Configure client options for HTTP if needed
let clientOptions: any = undefined;
if (serverUrl.protocol === 'http:') {
this.logger.warn(
`HTTP issuer URL detected for provider ${provider.id}: ${provider.issuer} - This is insecure`
);
clientOptions = {
execute: [client.allowInsecureRequests],
};
}
// Attempt OIDC discovery
await this.performDiscovery(provider, clientOptions);
return { isValid: true };
} catch (error) {
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
// Log the raw error for debugging
this.logger.log(`Raw discovery error for ${provider.id}: ${errorMessage}`);
// Use the helper to parse the error
const { userFriendlyError, details } = OidcErrorHelper.parseDiscoveryError(
error,
provider.issuer
);
this.logger.error(`Validation failed for provider ${provider.id}: ${errorMessage}`);
// Add debug logging for HTTP status errors
if (errorMessage.includes('unexpected HTTP response status code')) {
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
? provider.issuer.replace('/.well-known/openid-configuration', '')
: provider.issuer;
this.logger.log(`Attempted to fetch: ${baseUrl}/.well-known/openid-configuration`);
this.logger.error(`Full error details: ${errorMessage}`);
}
return {
isValid: false,
error: userFriendlyError,
details,
};
}
}
async performDiscovery(provider: OidcProvider, clientOptions?: any): Promise<client.Configuration> {
if (!provider.issuer) {
throw new Error('No issuer URL provided');
}
// Configure client auth method
const clientAuth = provider.clientSecret
? client.ClientSecretPost(provider.clientSecret)
: undefined;
const serverUrl = new URL(provider.issuer);
const discoveryUrl = `${provider.issuer}/.well-known/openid-configuration`;
this.logger.log(`Starting discovery for provider ${provider.id}`);
this.logger.log(`Discovery URL: ${discoveryUrl}`);
this.logger.log(`Client ID: ${provider.clientId}`);
this.logger.log(`Client secret configured: ${provider.clientSecret ? 'Yes' : 'No'}`);
// Use provided client options or create default options with HTTP support if needed
if (!clientOptions && serverUrl.protocol === 'http:') {
this.logger.warn(
`Allowing HTTP for ${provider.id} - This is insecure and should only be used for testing`
);
// For openid-client v6, use allowInsecureRequests in the execute array
// This is deprecated but needed for local development with HTTP endpoints
clientOptions = {
execute: [client.allowInsecureRequests],
};
}
try {
const config = await client.discovery(
serverUrl,
provider.clientId,
undefined, // client metadata
clientAuth,
clientOptions
);
this.logger.log(`Discovery successful for ${provider.id}`);
this.logger.log(`Discovery response metadata:`);
this.logger.log(` - issuer: ${config.serverMetadata().issuer}`);
this.logger.log(
` - authorization_endpoint: ${config.serverMetadata().authorization_endpoint}`
);
this.logger.log(` - token_endpoint: ${config.serverMetadata().token_endpoint}`);
this.logger.log(
` - userinfo_endpoint: ${config.serverMetadata().userinfo_endpoint || 'not provided'}`
);
this.logger.log(` - jwks_uri: ${config.serverMetadata().jwks_uri || 'not provided'}`);
this.logger.log(
` - response_types_supported: ${config.serverMetadata().response_types_supported?.join(', ') || 'not provided'}`
);
this.logger.log(
` - scopes_supported: ${config.serverMetadata().scopes_supported?.join(', ') || 'not provided'}`
);
return config;
} catch (discoveryError) {
this.logger.error(`Discovery failed for ${provider.id} at ${discoveryUrl}`);
if (discoveryError instanceof Error) {
this.logger.error('Discovery error: %o', discoveryError);
}
throw discoveryError;
}
}
}

View File

@@ -0,0 +1,485 @@
import { Logger } from '@nestjs/common';
import { ConfigModule, ConfigService } from '@nestjs/config';
import { Test, TestingModule } from '@nestjs/testing';
import * as client from 'openid-client';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
import { OidcClaimsService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
describe('OidcService Integration Tests - Enhanced Logging', () => {
let service: OidcService;
let configPersistence: OidcConfigPersistence;
let loggerSpy: any;
let debugLogs: string[] = [];
let errorLogs: string[] = [];
let warnLogs: string[] = [];
let logLogs: string[] = [];
beforeEach(async () => {
// Clear log arrays
debugLogs = [];
errorLogs = [];
warnLogs = [];
logLogs = [];
const module: TestingModule = await Test.createTestingModule({
imports: [
ConfigModule.forRoot({
isGlobal: true,
load: [() => ({ BASE_URL: 'http://test.local' })],
}),
],
providers: [
OidcService,
OidcValidationService,
OidcClientConfigService,
OidcTokenExchangeService,
{
provide: OidcAuthorizationService,
useValue: {
checkAuthorization: vi.fn(),
},
},
{
provide: OidcConfigPersistence,
useValue: {
getProvider: vi.fn(),
saveProvider: vi.fn(),
getConfig: vi.fn().mockReturnValue({
providers: [],
defaultAllowedOrigins: [],
}),
},
},
{
provide: OidcSessionService,
useValue: {
createSession: vi.fn().mockResolvedValue('mock-token'),
validateSession: vi.fn(),
},
},
{
provide: OidcStateService,
useValue: {
generateSecureState: vi.fn().mockResolvedValue('secure-state'),
validateSecureState: vi.fn().mockResolvedValue({
isValid: true,
clientState: 'test-state',
redirectUri: 'https://myapp.example.com/graphql/api/auth/oidc/callback',
}),
extractProviderFromState: vi.fn().mockReturnValue('test-provider'),
},
},
{
provide: OidcRedirectUriService,
useValue: {
getRedirectUri: vi
.fn()
.mockResolvedValue(
'https://myapp.example.com/graphql/api/auth/oidc/callback'
),
},
},
{
provide: OidcClaimsService,
useValue: {
parseIdToken: vi.fn().mockReturnValue({
sub: 'user123',
email: 'user@example.com',
}),
validateClaims: vi.fn().mockReturnValue('user123'),
},
},
],
}).compile();
service = module.get<OidcService>(OidcService);
configPersistence = module.get<OidcConfigPersistence>(OidcConfigPersistence);
// Spy on logger methods to capture logs
loggerSpy = {
debug: vi
.spyOn(Logger.prototype, 'debug')
.mockImplementation((message: string, ...args: any[]) => {
debugLogs.push(message);
}),
error: vi
.spyOn(Logger.prototype, 'error')
.mockImplementation((message: string, ...args: any[]) => {
errorLogs.push(message);
}),
warn: vi
.spyOn(Logger.prototype, 'warn')
.mockImplementation((message: string, ...args: any[]) => {
warnLogs.push(message);
}),
log: vi
.spyOn(Logger.prototype, 'log')
.mockImplementation((message: string, ...args: any[]) => {
logLogs.push(message);
}),
verbose: vi.spyOn(Logger.prototype, 'verbose').mockImplementation(() => {}),
};
});
afterEach(() => {
vi.restoreAllMocks();
});
describe('Token Exchange Error Logging', () => {
it('should log detailed error information when token exchange fails with Google (trailing slash issue)', async () => {
// This simulates the issue from #1616 where a trailing slash causes failure
const provider: OidcProvider = {
id: 'google-test',
name: 'Google Test',
issuer: 'https://accounts.google.com/', // Trailing slash will cause issue
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid', 'email', 'profile'],
authorizationRules: [
{
claim: 'email',
operator: 'ENDS_WITH' as any,
value: ['@example.com'],
},
],
};
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
try {
await service.handleCallback({
providerId: 'google-test',
code: 'test-code',
state: 'test-state',
requestOrigin: 'http://test.local',
fullCallbackUrl:
'http://test.local/graphql/api/auth/oidc/callback?code=test-code&state=test-state',
requestHeaders: { host: 'test.local' },
});
} catch (error) {
// We expect this to fail
}
// Verify that the service attempted to handle the callback
// Note: Detailed token exchange logging now happens in OidcTokenExchangeService
expect(errorLogs.length).toBeGreaterThan(0);
// Changed logging format to use error extractor
expect(errorLogs.some((log) => log.includes('Token exchange failed'))).toBe(true);
});
it('should log discovery failure details with invalid issuer URL', async () => {
const provider: OidcProvider = {
id: 'invalid-issuer',
name: 'Invalid Issuer Test',
issuer: 'https://invalid-oidc-provider.example.com', // Non-existent domain
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid', 'email'],
authorizationRules: [],
};
const validationService = new OidcValidationService(new ConfigService());
const result = await validationService.validateProvider(provider);
expect(result.isValid).toBe(false);
// Should now have more specific error message
expect(result.error).toBeDefined();
// The error should mention the domain cannot be resolved or connection failed
expect(result.error).toMatch(
/Cannot resolve domain name|Failed to connect to OIDC provider/
);
expect(result.details).toBeDefined();
expect(result.details).toHaveProperty('type');
// Should be either DNS_ERROR or FETCH_ERROR depending on the cause
expect(['DNS_ERROR', 'FETCH_ERROR']).toContain((result.details as any).type);
});
it('should log detailed HTTP error responses from discovery', async () => {
const provider: OidcProvider = {
id: 'http-error-test',
name: 'HTTP Error Test',
issuer: 'https://httpstat.us/500', // Returns 500 error
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid'],
authorizationRules: [],
};
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
try {
await service.validateProvider(provider);
} catch (error) {
// Expected to fail
}
// Check that HTTP status details are logged (now in log level)
expect(logLogs.some((log) => log.includes('Discovery URL:'))).toBe(true);
expect(logLogs.some((log) => log.includes('Client ID:'))).toBe(true);
});
it('should log authorization URL building details', async () => {
const provider: OidcProvider = {
id: 'auth-url-test',
name: 'Auth URL Test',
issuer: 'https://accounts.google.com',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid', 'email', 'profile'],
authorizationRules: [],
};
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
try {
await service.getAuthorizationUrl({
providerId: 'auth-url-test',
state: 'test-state',
requestOrigin: 'http://test.local',
requestHeaders: { host: 'test.local' },
});
// Verify URL building logs
expect(logLogs.some((log) => log.includes('Built authorization URL'))).toBe(true);
expect(logLogs.some((log) => log.includes('Authorization parameters:'))).toBe(true);
} catch (error) {
// May fail due to real discovery, but we're interested in the logs
}
});
it('should log detailed information for manual endpoint configuration', async () => {
const provider: OidcProvider = {
id: 'manual-endpoints',
name: 'Manual Endpoints Test',
issuer: undefined,
authorizationEndpoint: 'https://auth.example.com/authorize',
tokenEndpoint: 'https://auth.example.com/token',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid'],
authorizationRules: [],
};
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
const authUrl = await service.getAuthorizationUrl({
providerId: 'manual-endpoints',
state: 'test-state',
requestOrigin: 'http://test.local',
requestHeaders: {
'x-forwarded-host': 'test.local',
'x-forwarded-proto': 'http',
},
});
// Verify manual endpoint logs
expect(debugLogs.some((log) => log.includes('Built authorization URL'))).toBe(true);
expect(debugLogs.some((log) => log.includes('client_id=test-client-id'))).toBe(true);
expect(authUrl).toContain('https://auth.example.com/authorize');
});
it('should log JWT claim validation failures with detailed context', async () => {
const provider: OidcProvider = {
id: 'jwt-validation-test',
name: 'JWT Validation Test',
issuer: 'https://accounts.google.com',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid', 'email'],
authorizationRules: [
{
claim: 'email',
operator: 'ENDS_WITH' as any,
value: ['@restricted.com'],
},
],
};
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
// Mock a scenario where JWT validation fails
try {
await service.handleCallback({
providerId: 'jwt-validation-test',
code: 'test-code',
state: 'test-state',
requestOrigin: 'http://test.local',
fullCallbackUrl:
'http://test.local/graphql/api/auth/oidc/callback?code=test-code&state=test-state',
requestHeaders: { host: 'test.local' },
});
} catch (error) {
// Expected to fail
}
// The JWT error handling is now in OidcTokenExchangeService
// We should see some error logged
expect(errorLogs.length).toBeGreaterThan(0);
});
});
describe('Discovery Endpoint Logging', () => {
it('should log all discovery metadata when successful', async () => {
// Use a real OIDC provider that works
const provider: OidcProvider = {
id: 'microsoft',
name: 'Microsoft',
issuer: 'https://login.microsoftonline.com/common/v2.0',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid', 'email', 'profile'],
authorizationRules: [],
};
const validationService = new OidcValidationService(new ConfigService());
try {
await validationService.performDiscovery(provider);
} catch (error) {
// May fail due to network, but we're checking logs
}
// Verify discovery logging (now in log level)
expect(logLogs.some((log) => log.includes('Starting discovery'))).toBe(true);
expect(logLogs.some((log) => log.includes('Discovery URL:'))).toBe(true);
});
it('should log discovery failures with malformed JSON response', async () => {
const provider: OidcProvider = {
id: 'malformed-json',
name: 'Malformed JSON Test',
issuer: 'https://example.com/malformed',
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid'],
authorizationRules: [],
};
// Mock global fetch to return HTML instead of JSON
const originalFetch = global.fetch;
global.fetch = vi.fn().mockImplementation(() =>
Promise.resolve(
new Response('<html><body>Not JSON</body></html>', {
status: 200,
headers: { 'content-type': 'text/html' },
})
)
);
const validationService = new OidcValidationService(new ConfigService());
const result = await validationService.validateProvider(provider);
// Restore original fetch
global.fetch = originalFetch;
expect(result.isValid).toBe(false);
expect(result.error).toBeDefined();
// The openid-client library will fail when it gets HTML instead of JSON
// It returns "unexpected response content-type" error
expect(result.error).toMatch(
/Invalid OIDC discovery|malformed|doesn't conform|unexpected|content-type/i
);
});
it('should handle and log HTTP vs HTTPS protocol differences', async () => {
const httpProvider: OidcProvider = {
id: 'http-local',
name: 'HTTP Local Test',
issuer: 'http://localhost:8080', // HTTP endpoint
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid'],
authorizationRules: [],
};
// Create a validation service and spy on its logger
const validationService = new OidcValidationService(new ConfigService());
try {
await validationService.validateProvider(httpProvider);
} catch (error) {
// Expected to fail if localhost:8080 isn't running
}
// The HTTP logging happens in the validation service
// We should check that HTTP issuers are detected
expect(httpProvider.issuer).toMatch(/^http:/);
// Verify that we're testing an HTTP endpoint
expect(httpProvider.issuer).toBe('http://localhost:8080');
});
});
describe('Request/Response Detail Logging', () => {
it('should log complete request parameters for token exchange', async () => {
const provider: OidcProvider = {
id: 'token-params-test',
name: 'Token Params Test',
issuer: 'https://accounts.google.com',
clientId: 'detailed-client-id',
clientSecret: 'detailed-client-secret',
scopes: ['openid', 'email', 'profile', 'offline_access'],
authorizationRules: [],
};
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
try {
await service.handleCallback({
providerId: 'token-params-test',
code: 'authorization-code-12345',
state: 'state-with-signature',
requestOrigin: 'https://myapp.example.com',
fullCallbackUrl:
'https://myapp.example.com/graphql/api/auth/oidc/callback?code=authorization-code-12345&state=state-with-signature&scope=openid+email+profile',
requestHeaders: { host: 'myapp.example.com' },
});
} catch (error) {
// Expected to fail
}
// Verify that we attempted the operation
// Detailed parameter logging is now in OidcTokenExchangeService
expect(debugLogs.length).toBeGreaterThan(0);
expect(debugLogs.some((log) => log.includes('Client ID: detailed-client-id'))).toBe(true);
expect(debugLogs.some((log) => log.includes('Client secret configured: Yes'))).toBe(true);
});
it('should capture and log all error properties from openid-client', async () => {
const provider: OidcProvider = {
id: 'error-properties-test',
name: 'Error Properties Test',
issuer: 'https://expired-cert.badssl.com/', // SSL cert error
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
scopes: ['openid'],
authorizationRules: [],
};
const validationService = new OidcValidationService(new ConfigService());
const result = await validationService.validateProvider(provider);
expect(result.isValid).toBe(false);
expect(result.error).toBeDefined();
// Should detect SSL/certificate issues or connection failure
expect(result.error).toMatch(
/SSL\/TLS certificate error|Failed to connect to OIDC provider|certificate/
);
expect(result.details).toBeDefined();
expect(result.details).toHaveProperty('type');
// Should be either SSL_ERROR or FETCH_ERROR
expect(['SSL_ERROR', 'FETCH_ERROR']).toContain((result.details as any).type);
});
});
});

View File

@@ -0,0 +1,381 @@
import { CacheModule } from '@nestjs/cache-manager';
import { UnauthorizedException } from '@nestjs/common';
import { Test, TestingModule } from '@nestjs/testing';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
import { OidcClaimsService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
// Mock openid-client
vi.mock('openid-client', () => ({
buildAuthorizationUrl: vi.fn((config, params) => {
const url = new URL(config.serverMetadata().authorization_endpoint);
Object.entries(params).forEach(([key, value]) => {
if (value !== undefined) {
url.searchParams.set(key, String(value));
}
});
return url;
}),
allowInsecureRequests: vi.fn(),
}));
describe('OidcService Integration', () => {
let service: OidcService;
let oidcConfig: any;
let sessionService: any;
let stateService: OidcStateService;
let redirectUriService: any;
let clientConfigService: any;
let tokenExchangeService: any;
let claimsService: any;
let authorizationService: any;
beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
imports: [CacheModule.register()],
providers: [
OidcService,
{
provide: OidcConfigPersistence,
useValue: {
getProvider: vi.fn(),
getConfig: vi.fn().mockResolvedValue({
providers: [],
defaultAllowedOrigins: ['https://example.com'],
}),
},
},
{
provide: OidcSessionService,
useValue: {
createSession: vi.fn().mockResolvedValue('padded-token-123'),
},
},
OidcStateService,
{
provide: OidcValidationService,
useValue: {
validateProvider: vi.fn().mockResolvedValue({ isValid: true }),
performDiscovery: vi.fn(),
},
},
{
provide: OidcAuthorizationService,
useValue: {
checkAuthorization: vi.fn(),
},
},
{
provide: OidcRedirectUriService,
useValue: {
getRedirectUri: vi.fn().mockResolvedValue('https://example.com/callback'),
},
},
{
provide: OidcClientConfigService,
useValue: {
getOrCreateConfig: vi.fn(),
clearCache: vi.fn(),
},
},
{
provide: OidcTokenExchangeService,
useValue: {
exchangeCodeForTokens: vi.fn(),
},
},
{
provide: OidcClaimsService,
useValue: {
parseIdToken: vi.fn(),
validateClaims: vi.fn(),
},
},
],
}).compile();
service = module.get<OidcService>(OidcService);
oidcConfig = module.get(OidcConfigPersistence);
sessionService = module.get(OidcSessionService);
stateService = module.get<OidcStateService>(OidcStateService);
redirectUriService = module.get(OidcRedirectUriService);
clientConfigService = module.get(OidcClientConfigService);
tokenExchangeService = module.get(OidcTokenExchangeService);
claimsService = module.get(OidcClaimsService);
authorizationService = module.get(OidcAuthorizationService);
});
describe('getAuthorizationUrl', () => {
it('should generate authorization URL with custom endpoints', async () => {
const provider: OidcProvider = {
id: 'custom-provider',
name: 'Custom Provider',
clientId: 'test-client-id',
clientSecret: 'test-secret',
authorizationEndpoint: 'https://custom.example.com/auth',
scopes: ['openid', 'profile'],
authorizationRules: [],
};
oidcConfig.getProvider.mockResolvedValue(provider);
const params = {
providerId: 'custom-provider',
state: 'client-state-123',
requestOrigin: 'https://example.com',
requestHeaders: { host: 'example.com' },
};
const url = await service.getAuthorizationUrl(params);
expect(redirectUriService.getRedirectUri).toHaveBeenCalledWith('https://example.com', {
host: 'example.com',
});
const urlObj = new URL(url);
expect(urlObj.origin).toBe('https://custom.example.com');
expect(urlObj.pathname).toBe('/auth');
expect(urlObj.searchParams.get('client_id')).toBe('test-client-id');
expect(urlObj.searchParams.get('redirect_uri')).toBe('https://example.com/callback');
expect(urlObj.searchParams.get('scope')).toBe('openid profile');
expect(urlObj.searchParams.get('response_type')).toBe('code');
expect(urlObj.searchParams.has('state')).toBe(true);
});
it('should use OIDC discovery when no custom authorization endpoint', async () => {
const provider: OidcProvider = {
id: 'discovery-provider',
name: 'Discovery Provider',
clientId: 'test-client-id',
issuer: 'https://discovery.example.com',
scopes: ['openid'],
authorizationRules: [],
};
// Create a mock configuration object
const mockConfig = {
serverMetadata: vi.fn().mockReturnValue({
authorization_endpoint: 'https://discovery.example.com/authorize',
}),
};
oidcConfig.getProvider.mockResolvedValue(provider);
clientConfigService.getOrCreateConfig.mockResolvedValue(mockConfig);
const params = {
providerId: 'discovery-provider',
state: 'client-state-123',
requestOrigin: 'https://example.com',
requestHeaders: {},
};
const url = await service.getAuthorizationUrl(params);
expect(clientConfigService.getOrCreateConfig).toHaveBeenCalledWith(provider);
expect(url).toContain('https://discovery.example.com/authorize');
});
it('should throw when provider not found', async () => {
oidcConfig.getProvider.mockResolvedValue(null);
const params = {
providerId: 'non-existent',
state: 'state',
requestOrigin: 'https://example.com',
requestHeaders: {},
};
await expect(service.getAuthorizationUrl(params)).rejects.toThrow(UnauthorizedException);
});
});
describe('handleCallback', () => {
it('should handle successful callback flow', async () => {
const provider: OidcProvider = {
id: 'test-provider',
name: 'Test Provider',
clientId: 'test-client-id',
issuer: 'https://test.example.com',
scopes: ['openid'],
authorizationRules: [],
};
const mockConfig = {
serverMetadata: vi.fn().mockReturnValue({
issuer: 'https://test.example.com',
token_endpoint: 'https://test.example.com/token',
}),
};
const mockTokens = {
id_token: 'id.token.here',
access_token: 'access.token.here',
};
const mockClaims = {
sub: 'user123',
email: 'user@example.com',
};
oidcConfig.getProvider.mockResolvedValue(provider);
clientConfigService.getOrCreateConfig.mockResolvedValue(mockConfig);
tokenExchangeService.exchangeCodeForTokens.mockResolvedValue(mockTokens);
claimsService.parseIdToken.mockReturnValue(mockClaims);
claimsService.validateClaims.mockReturnValue('user123');
// Mock the OidcStateExtractor's static method
const OidcStateExtractor = await import(
'@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js'
);
vi.spyOn(OidcStateExtractor.OidcStateExtractor, 'extractAndValidateState').mockResolvedValue(
{
providerId: 'test-provider',
originalState: 'original-state',
clientState: 'original-state',
redirectUri: 'https://example.com/callback',
}
);
const params = {
providerId: 'test-provider',
code: 'auth-code-123',
state: 'secure-state',
requestOrigin: 'https://example.com',
fullCallbackUrl: 'https://example.com/callback?code=auth-code-123&state=secure-state',
requestHeaders: {},
};
const token = await service.handleCallback(params);
expect(token).toBe('padded-token-123');
expect(tokenExchangeService.exchangeCodeForTokens).toHaveBeenCalled();
expect(claimsService.parseIdToken).toHaveBeenCalledWith('id.token.here');
expect(claimsService.validateClaims).toHaveBeenCalledWith(mockClaims);
expect(authorizationService.checkAuthorization).toHaveBeenCalledWith(provider, mockClaims);
expect(sessionService.createSession).toHaveBeenCalledWith('test-provider', 'user123');
});
it('should throw when provider not found', async () => {
oidcConfig.getProvider.mockResolvedValue(null);
const params = {
providerId: 'non-existent',
code: 'code',
state: 'state',
requestOrigin: 'https://example.com',
fullCallbackUrl: 'https://example.com/callback',
requestHeaders: {},
};
await expect(service.handleCallback(params)).rejects.toThrow(UnauthorizedException);
});
it('should handle authorization rejection', async () => {
const provider: OidcProvider = {
id: 'test-provider',
name: 'Test Provider',
clientId: 'test-client-id',
issuer: 'https://test.example.com',
scopes: ['openid'],
authorizationRules: [],
};
const mockConfig = {
serverMetadata: vi.fn().mockReturnValue({
issuer: 'https://test.example.com',
token_endpoint: 'https://test.example.com/token',
}),
};
const mockTokens = {
id_token: 'id.token.here',
};
const mockClaims = {
sub: 'user123',
email: 'user@example.com',
};
oidcConfig.getProvider.mockResolvedValue(provider);
clientConfigService.getOrCreateConfig.mockResolvedValue(mockConfig);
tokenExchangeService.exchangeCodeForTokens.mockResolvedValue(mockTokens);
claimsService.parseIdToken.mockReturnValue(mockClaims);
claimsService.validateClaims.mockReturnValue('user123');
authorizationService.checkAuthorization.mockRejectedValue(
new UnauthorizedException('Not authorized')
);
// Mock the OidcStateExtractor's static method
const OidcStateExtractor = await import(
'@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js'
);
vi.spyOn(OidcStateExtractor.OidcStateExtractor, 'extractAndValidateState').mockResolvedValue(
{
providerId: 'test-provider',
originalState: 'original-state',
clientState: 'original-state',
redirectUri: 'https://example.com/callback',
}
);
const params = {
providerId: 'test-provider',
code: 'auth-code-123',
state: 'secure-state',
requestOrigin: 'https://example.com',
fullCallbackUrl: 'https://example.com/callback',
requestHeaders: {},
};
await expect(service.handleCallback(params)).rejects.toThrow(UnauthorizedException);
});
});
describe('validateProvider', () => {
it('should clear cache and validate provider', async () => {
const provider: OidcProvider = {
id: 'test-provider',
name: 'Test Provider',
clientId: 'test-client-id',
issuer: 'https://test.example.com',
scopes: ['openid'],
authorizationRules: [],
};
const result = await service.validateProvider(provider);
expect(clientConfigService.clearCache).toHaveBeenCalledWith('test-provider');
// The validation service mock already returns { isValid: true }
expect(result).toEqual({ isValid: true });
});
});
describe('extractProviderFromState', () => {
it('should extract provider from state', () => {
const state = 'provider-id:original-state';
const result = service.extractProviderFromState(state);
expect(result.providerId).toBeDefined();
expect(result.originalState).toBeDefined();
});
});
describe('getStateService', () => {
it('should return state service', () => {
const result = service.getStateService();
expect(result).toBe(stateService);
});
});
});

View File

@@ -0,0 +1,243 @@
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
import * as client from 'openid-client';
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
import { OidcClaimsService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
import { OidcStateExtractor } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
import { ErrorExtractor } from '@app/unraid-api/utils/error-extractor.util.js';
export interface GetAuthorizationUrlParams {
providerId: string;
state: string;
requestOrigin: string;
requestHeaders: Record<string, string | string[] | undefined>;
}
export interface HandleCallbackParams {
providerId: string;
code: string;
state: string;
requestOrigin: string;
fullCallbackUrl: string;
requestHeaders: Record<string, string | string[] | undefined>;
}
@Injectable()
export class OidcService {
private readonly logger = new Logger(OidcService.name);
constructor(
private readonly oidcConfig: OidcConfigPersistence,
private readonly sessionService: OidcSessionService,
private readonly stateService: OidcStateService,
private readonly validationService: OidcValidationService,
private readonly authorizationService: OidcAuthorizationService,
private readonly redirectUriService: OidcRedirectUriService,
private readonly clientConfigService: OidcClientConfigService,
private readonly tokenExchangeService: OidcTokenExchangeService,
private readonly claimsService: OidcClaimsService
) {}
async getAuthorizationUrl(params: GetAuthorizationUrlParams): Promise<string> {
const { providerId, state, requestOrigin, requestHeaders } = params;
const provider = await this.oidcConfig.getProvider(providerId);
if (!provider) {
throw new UnauthorizedException(`Provider ${providerId} not found`);
}
// Use requestOrigin with validation
const redirectUri = await this.redirectUriService.getRedirectUri(requestOrigin, requestHeaders);
this.logger.debug(`Using redirect URI for authorization: ${redirectUri}`);
this.logger.debug(`Request origin was: ${requestOrigin}`);
// Generate secure state with cryptographic signature, including redirect URI
const secureState = await this.stateService.generateSecureState(providerId, state, redirectUri);
// Build authorization URL
if (provider.authorizationEndpoint) {
// Use custom authorization endpoint
const authUrl = new URL(provider.authorizationEndpoint);
// Standard OAuth2 parameters
authUrl.searchParams.set('client_id', provider.clientId);
authUrl.searchParams.set('redirect_uri', redirectUri);
authUrl.searchParams.set('scope', provider.scopes.join(' '));
authUrl.searchParams.set('state', secureState);
authUrl.searchParams.set('response_type', 'code');
this.logger.debug(`Built authorization URL for provider ${provider.id}`);
this.logger.debug(
`Authorization parameters: client_id=${provider.clientId}, redirect_uri=${redirectUri}, scope=${provider.scopes.join(' ')}, response_type=code`
);
return authUrl.href;
}
// Use OIDC discovery for providers without custom endpoints
const config = await this.clientConfigService.getOrCreateConfig(provider);
const parameters: Record<string, string> = {
redirect_uri: redirectUri,
scope: provider.scopes.join(' '),
state: secureState,
response_type: 'code',
};
// For HTTP endpoints, we need to call allowInsecureRequests on the config
if (provider.issuer) {
try {
const serverUrl = new URL(provider.issuer);
if (serverUrl.protocol === 'http:') {
this.logger.debug(`Allowing insecure requests for HTTP endpoint: ${provider.id}`);
// allowInsecureRequests is deprecated but still needed for HTTP endpoints
client.allowInsecureRequests(config);
}
} catch (error) {
this.logger.warn(`Invalid issuer URL for provider ${provider.id}: ${provider.issuer}`);
// Continue without special HTTP options
}
}
const authUrl = client.buildAuthorizationUrl(config, parameters);
this.logger.log(`Built authorization URL via discovery for provider ${provider.id}`);
this.logger.log(`Authorization parameters: ${JSON.stringify(parameters)}`);
return authUrl.href;
}
extractProviderFromState(state: string): { providerId: string; originalState: string } {
return OidcStateExtractor.extractProviderFromState(state, this.stateService);
}
/**
* Get the state service for external utilities
*/
getStateService(): OidcStateService {
return this.stateService;
}
async handleCallback(params: HandleCallbackParams): Promise<string> {
const { providerId, code, state, fullCallbackUrl } = params;
const provider = await this.oidcConfig.getProvider(providerId);
if (!provider) {
throw new UnauthorizedException(`Provider ${providerId} not found`);
}
// Extract and validate state, including the stored redirect URI
const stateInfo = await OidcStateExtractor.extractAndValidateState(state, this.stateService);
if (!stateInfo.redirectUri) {
throw new UnauthorizedException('Missing redirect URI in state');
}
// Use the redirect URI that was stored during authorization
const redirectUri = stateInfo.redirectUri;
this.logger.debug(`Using stored redirect URI from state: ${redirectUri}`);
try {
// Always use openid-client for consistency
const config = await this.clientConfigService.getOrCreateConfig(provider);
// Log configuration details
this.logger.debug(`Provider ${providerId} config loaded`);
this.logger.debug(`Redirect URI: ${redirectUri}`);
// Build current URL for token exchange
// CRITICAL: The URL used here MUST match the redirect_uri that was sent to the authorization endpoint
// Google expects the exact same redirect_uri during token exchange
const currentUrl = new URL(redirectUri);
currentUrl.searchParams.set('code', code);
currentUrl.searchParams.set('state', state);
// Copy additional parameters from the actual callback if provided
if (fullCallbackUrl) {
const actualUrl = new URL(fullCallbackUrl);
// Copy over additional params that Google might have added (scope, authuser, prompt, etc)
// but DO NOT change the base URL or path
['scope', 'authuser', 'prompt', 'hd', 'session_state', 'iss'].forEach((param) => {
const value = actualUrl.searchParams.get(param);
if (value && !currentUrl.searchParams.has(param)) {
currentUrl.searchParams.set(param, value);
}
});
}
// Google returns iss in the response, openid-client v6 expects it
// If not present, add it based on the provider's issuer
if (!currentUrl.searchParams.has('iss') && provider.issuer) {
currentUrl.searchParams.set('iss', provider.issuer);
}
this.logger.debug(`Token exchange URL (matches redirect_uri): ${currentUrl.href}`);
// State was already validated in extractAndValidateState above, use that result
// The clientState should be present after successful validation, but handle the edge case
if (!stateInfo.clientState) {
this.logger.warn('Client state missing after successful validation');
throw new UnauthorizedException('Invalid state: missing client state');
}
const originalState = stateInfo.clientState;
this.logger.debug(`Exchanging code for tokens with provider ${providerId}`);
this.logger.debug(`Client state extracted: ${originalState}`);
// Use the token exchange service
const tokens = await this.tokenExchangeService.exchangeCodeForTokens(
config,
provider,
code,
originalState,
redirectUri,
fullCallbackUrl
);
// Parse ID token to get user info
const claims = this.claimsService.parseIdToken(tokens.id_token);
const userSub = this.claimsService.validateClaims(claims);
// Check authorization based on rules
// This will throw a helpful error if misconfigured or unauthorized
await this.authorizationService.checkAuthorization(provider, claims!);
// Create session and return padded token
const paddedToken = await this.sessionService.createSession(providerId, userSub);
this.logger.log(`Successfully authenticated user ${userSub} via provider ${providerId}`);
return paddedToken;
} catch (error) {
const extracted = ErrorExtractor.extract(error);
this.logger.error(`OAuth callback error: ${extracted.message}`);
// Re-throw the original error if it's already an UnauthorizedException
if (error instanceof UnauthorizedException) {
throw error;
}
// Otherwise throw a generic error
throw new UnauthorizedException('Authentication failed');
}
}
/**
* Validate OIDC provider configuration by attempting discovery
* Returns validation result with helpful error messages for debugging
*/
async validateProvider(
provider: OidcProvider
): Promise<{ isValid: boolean; error?: string; details?: unknown }> {
// Clear any cached config for this provider to force fresh validation
this.clientConfigService.clearCache(provider.id);
// Delegate to the validation service
return this.validationService.validateProvider(provider);
}
}

View File

@@ -0,0 +1,16 @@
import { Field, ObjectType } from '@nestjs/graphql';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
@ObjectType()
export class OidcConfiguration {
@Field(() => [OidcProvider], { description: 'List of configured OIDC providers' })
providers!: OidcProvider[];
@Field(() => [String], {
nullable: true,
description:
'Default allowed redirect origins that apply to all OIDC providers (e.g., Tailscale domains)',
})
defaultAllowedOrigins?: string[];
}

View File

@@ -80,9 +80,11 @@ export class OidcProvider {
@Field(() => String, {
description:
'OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration',
nullable: true,
})
@IsUrl()
issuer!: string;
@IsOptional()
issuer?: string;
@Field(() => String, {
nullable: true,

View File

@@ -1,4 +1,4 @@
import type { OidcConfig } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
import type { OidcConfig } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
declare module '@unraid/shared/services/user-settings.js' {
interface UserSettings {

View File

@@ -1,701 +0,0 @@
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { decodeJwt } from 'jose';
import * as client from 'openid-client';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
import {
AuthorizationOperator,
AuthorizationRuleMode,
OidcAuthorizationRule,
OidcProvider,
} from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/oidc-state.service.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/oidc-validation.service.js';
interface JwtClaims {
sub?: string;
email?: string;
name?: string;
hd?: string; // Google hosted domain
[claim: string]: unknown;
}
@Injectable()
export class OidcAuthService {
private readonly logger = new Logger(OidcAuthService.name);
private readonly configCache = new Map<string, client.Configuration>();
constructor(
private readonly configService: ConfigService,
private readonly oidcConfig: OidcConfigPersistence,
private readonly sessionService: OidcSessionService,
private readonly stateService: OidcStateService,
private readonly validationService: OidcValidationService
) {}
async getAuthorizationUrl(
providerId: string,
state: string,
requestOrigin?: string
): Promise<string> {
const provider = await this.oidcConfig.getProvider(providerId);
if (!provider) {
throw new UnauthorizedException(`Provider ${providerId} not found`);
}
const redirectUri = this.getRedirectUri(requestOrigin);
// Generate secure state with cryptographic signature
const secureState = this.stateService.generateSecureState(providerId, state);
// Build authorization URL
if (provider.authorizationEndpoint) {
// Use custom authorization endpoint
const authUrl = new URL(provider.authorizationEndpoint);
// Standard OAuth2 parameters
authUrl.searchParams.set('client_id', provider.clientId);
authUrl.searchParams.set('redirect_uri', redirectUri);
authUrl.searchParams.set('scope', provider.scopes.join(' '));
authUrl.searchParams.set('state', secureState);
authUrl.searchParams.set('response_type', 'code');
return authUrl.href;
}
// Use OIDC discovery for providers without custom endpoints
const config = await this.getOrCreateConfig(provider);
const parameters: Record<string, string> = {
redirect_uri: redirectUri,
scope: provider.scopes.join(' '),
state: secureState,
response_type: 'code',
};
// For HTTP endpoints, we need to pass the allowInsecureRequests option
const serverUrl = new URL(provider.issuer || '');
let clientOptions: any = undefined;
if (serverUrl.protocol === 'http:') {
this.logger.debug(
`Building authorization URL with allowInsecureRequests for ${provider.id}`
);
clientOptions = {
execute: [client.allowInsecureRequests],
};
}
const authUrl = client.buildAuthorizationUrl(config, parameters);
return authUrl.href;
}
extractProviderFromState(state: string): { providerId: string; originalState: string } {
// Extract provider from state prefix (no decryption needed)
const providerId = this.stateService.extractProviderFromState(state);
if (providerId) {
return {
providerId,
originalState: state,
};
}
// Fallback for unknown formats
return {
providerId: '',
originalState: state,
};
}
async handleCallback(
providerId: string,
code: string,
state: string,
requestOrigin?: string,
fullCallbackUrl?: string
): Promise<string> {
const provider = await this.oidcConfig.getProvider(providerId);
if (!provider) {
throw new UnauthorizedException(`Provider ${providerId} not found`);
}
try {
const redirectUri = this.getRedirectUri(requestOrigin);
// Always use openid-client for consistency
const config = await this.getOrCreateConfig(provider);
// Log configuration details
this.logger.debug(`Provider ${providerId} config loaded`);
this.logger.debug(`Redirect URI: ${redirectUri}`);
// Build current URL for token exchange
// CRITICAL: The URL used here MUST match the redirect_uri that was sent to the authorization endpoint
// Google expects the exact same redirect_uri during token exchange
const currentUrl = new URL(redirectUri);
currentUrl.searchParams.set('code', code);
currentUrl.searchParams.set('state', state);
// Copy additional parameters from the actual callback if provided
if (fullCallbackUrl) {
const actualUrl = new URL(fullCallbackUrl);
// Copy over additional params that Google might have added (scope, authuser, prompt, etc)
// but DO NOT change the base URL or path
['scope', 'authuser', 'prompt', 'hd', 'session_state', 'iss'].forEach((param) => {
const value = actualUrl.searchParams.get(param);
if (value && !currentUrl.searchParams.has(param)) {
currentUrl.searchParams.set(param, value);
}
});
}
// Google returns iss in the response, openid-client v6 expects it
// If not present, add it based on the provider's issuer
if (!currentUrl.searchParams.has('iss') && provider.issuer) {
currentUrl.searchParams.set('iss', provider.issuer);
}
this.logger.debug(`Token exchange URL (matches redirect_uri): ${currentUrl.href}`);
// Validate secure state
const stateValidation = this.stateService.validateSecureState(state, providerId);
if (!stateValidation.isValid) {
this.logger.error(`State validation failed: ${stateValidation.error}`);
throw new UnauthorizedException(stateValidation.error || 'Invalid state parameter');
}
const originalState = stateValidation.clientState!;
this.logger.debug(`Exchanging code for tokens with provider ${providerId}`);
this.logger.debug(`Client state extracted: ${originalState}`);
// For openid-client v6, we need to prepare the authorization response
const authorizationResponse = new URLSearchParams(currentUrl.search);
// Set the original client state for openid-client
authorizationResponse.set('state', originalState);
// Create a new URL with the cleaned parameters
const cleanUrl = new URL(redirectUri);
cleanUrl.search = authorizationResponse.toString();
this.logger.debug(`Clean URL for token exchange: ${cleanUrl.href}`);
let tokens;
try {
this.logger.debug(`Starting token exchange with openid-client`);
this.logger.debug(`Config issuer: ${config.serverMetadata().issuer}`);
this.logger.debug(`Config token endpoint: ${config.serverMetadata().token_endpoint}`);
// For HTTP endpoints, we need to pass the allowInsecureRequests option
const serverUrl = new URL(provider.issuer || '');
let clientOptions: any = undefined;
if (serverUrl.protocol === 'http:') {
this.logger.debug(`Token exchange with allowInsecureRequests for ${provider.id}`);
clientOptions = {
execute: [client.allowInsecureRequests],
};
}
tokens = await client.authorizationCodeGrant(
config,
cleanUrl,
{
expectedState: originalState,
},
clientOptions
);
this.logger.debug(
`Token exchange successful, received tokens: ${Object.keys(tokens).join(', ')}`
);
} catch (tokenError) {
const errorMessage =
tokenError instanceof Error ? tokenError.message : String(tokenError);
this.logger.error(`Token exchange failed: ${errorMessage}`);
// Check if error message contains the "unexpected JWT claim" text
if (errorMessage.includes('unexpected JWT claim value encountered')) {
this.logger.error(
`unexpected JWT claim value encountered during token validation by openid-client`
);
this.logger.debug(
`Token exchange error details: ${JSON.stringify(tokenError, null, 2)}`
);
// Log the actual vs expected issuer
this.logger.error(
`This error typically means the 'iss' claim in the JWT doesn't match the expected issuer`
);
this.logger.error(`Check that your provider's issuer URL is configured correctly`);
}
throw tokenError;
}
// Parse ID token to get user info
let claims: JwtClaims | null = null;
if (tokens.id_token) {
try {
// Use jose to properly decode the JWT
claims = decodeJwt(tokens.id_token) as JwtClaims;
// Log claims safely without PII - only structure, not values
if (claims) {
const claimKeys = Object.keys(claims).join(', ');
this.logger.debug(
`ID token decoded successfully. Available claims: [${claimKeys}]`
);
// Log claim types without exposing sensitive values
for (const [key, value] of Object.entries(claims)) {
const valueType = Array.isArray(value)
? `array[${value.length}]`
: typeof value;
// Only log structure, not actual values (avoid PII)
this.logger.debug(`Claim '${key}': type=${valueType}`);
// Check for unexpected claim types
if (valueType === 'object' && value !== null && !Array.isArray(value)) {
this.logger.warn(`Claim '${key}' contains complex object structure`);
}
}
}
} catch (e) {
this.logger.warn(`Failed to parse ID token: ${e}`);
}
} else {
this.logger.error('No ID token received from provider');
}
if (!claims?.sub) {
this.logger.error(
'No subject in token - claims available: ' +
(claims ? Object.keys(claims).join(', ') : 'none')
);
throw new UnauthorizedException('No subject in token');
}
const userSub = claims.sub;
this.logger.debug(`Processing authentication for user: ${userSub}`);
// Check authorization based on rules
// This will throw a helpful error if misconfigured or unauthorized
await this.checkAuthorization(provider, claims);
// Create session and return padded token
const paddedToken = await this.sessionService.createSession(providerId, userSub);
this.logger.log(`Successfully authenticated user ${userSub} via provider ${providerId}`);
return paddedToken;
} catch (error) {
this.logger.error(
`OAuth callback error: ${error instanceof Error ? error.message : 'Unknown error'}`
);
// Re-throw the original error if it's already an UnauthorizedException
if (error instanceof UnauthorizedException) {
throw error;
}
// Otherwise throw a generic error
throw new UnauthorizedException('Authentication failed');
}
}
private async getOrCreateConfig(provider: OidcProvider): Promise<client.Configuration> {
const cacheKey = provider.id;
if (this.configCache.has(cacheKey)) {
return this.configCache.get(cacheKey)!;
}
try {
// Use the validation service to perform discovery with HTTP support
if (provider.issuer) {
this.logger.debug(`Attempting discovery for ${provider.id} at ${provider.issuer}`);
// Create client options with HTTP support if needed
const serverUrl = new URL(provider.issuer);
let clientOptions: any = undefined;
if (serverUrl.protocol === 'http:') {
this.logger.debug(`Allowing HTTP for ${provider.id} as specified by user`);
clientOptions = {
execute: [client.allowInsecureRequests],
};
}
try {
const config = await this.validationService.performDiscovery(
provider,
clientOptions
);
this.logger.debug(`Discovery successful for ${provider.id}`);
this.logger.debug(
`Authorization endpoint: ${config.serverMetadata().authorization_endpoint}`
);
this.logger.debug(`Token endpoint: ${config.serverMetadata().token_endpoint}`);
this.configCache.set(cacheKey, config);
return config;
} catch (discoveryError) {
const errorMessage =
discoveryError instanceof Error ? discoveryError.message : 'Unknown error';
this.logger.warn(`Discovery failed for ${provider.id}: ${errorMessage}`);
// Log more details about the discovery error
this.logger.debug(
`Discovery URL attempted: ${provider.issuer}/.well-known/openid-configuration`
);
this.logger.debug(
`Full discovery error: ${JSON.stringify(discoveryError, null, 2)}`
);
// Log stack trace for better debugging
if (discoveryError instanceof Error && discoveryError.stack) {
this.logger.debug(`Stack trace: ${discoveryError.stack}`);
}
// If discovery fails but we have manual endpoints, use them
if (provider.authorizationEndpoint && provider.tokenEndpoint) {
this.logger.log(`Using manual endpoints for ${provider.id}`);
// Create manual configuration
const serverMetadata: client.ServerMetadata = {
issuer: provider.issuer || `manual-${provider.id}`,
authorization_endpoint: provider.authorizationEndpoint,
token_endpoint: provider.tokenEndpoint,
jwks_uri: provider.jwksUri,
};
const clientMetadata: Partial<client.ClientMetadata> = {
client_secret: provider.clientSecret,
};
// Configure client auth method
const clientAuth = provider.clientSecret
? client.ClientSecretPost(provider.clientSecret)
: client.None();
try {
const config = new client.Configuration(
serverMetadata,
provider.clientId,
clientMetadata,
clientAuth
);
// Use manual configuration with HTTP support if needed
const serverUrl = new URL(provider.tokenEndpoint);
if (serverUrl.protocol === 'http:') {
this.logger.debug(
`Allowing HTTP for manual endpoints on ${provider.id}`
);
client.allowInsecureRequests(config);
}
this.logger.debug(`Manual configuration created for ${provider.id}`);
this.logger.debug(
`Authorization endpoint: ${serverMetadata.authorization_endpoint}`
);
this.logger.debug(`Token endpoint: ${serverMetadata.token_endpoint}`);
this.configCache.set(cacheKey, config);
return config;
} catch (manualConfigError) {
this.logger.error(
`Failed to create manual configuration: ${manualConfigError instanceof Error ? manualConfigError.message : 'Unknown error'}`
);
throw new Error(`Manual configuration failed for ${provider.id}`);
}
} else {
throw new Error(
`OIDC discovery failed and no manual endpoints provided for ${provider.id}`
);
}
}
}
// Manual configuration when no issuer is provided
if (provider.authorizationEndpoint && provider.tokenEndpoint) {
this.logger.log(`Using manual endpoints for ${provider.id} (no issuer provided)`);
// Create manual configuration
const serverMetadata: client.ServerMetadata = {
issuer: provider.issuer || `manual-${provider.id}`,
authorization_endpoint: provider.authorizationEndpoint,
token_endpoint: provider.tokenEndpoint,
jwks_uri: provider.jwksUri,
};
const clientMetadata: Partial<client.ClientMetadata> = {
client_secret: provider.clientSecret,
};
// Configure client auth method
const clientAuth = provider.clientSecret
? client.ClientSecretPost(provider.clientSecret)
: client.None();
try {
const config = new client.Configuration(
serverMetadata,
provider.clientId,
clientMetadata,
clientAuth
);
// Use manual configuration with HTTP support if needed
const serverUrl = new URL(provider.tokenEndpoint);
if (serverUrl.protocol === 'http:') {
this.logger.debug(`Allowing HTTP for manual endpoints on ${provider.id}`);
client.allowInsecureRequests(config);
}
this.logger.debug(`Manual configuration created for ${provider.id}`);
this.logger.debug(
`Authorization endpoint: ${serverMetadata.authorization_endpoint}`
);
this.logger.debug(`Token endpoint: ${serverMetadata.token_endpoint}`);
this.configCache.set(cacheKey, config);
return config;
} catch (manualConfigError) {
this.logger.error(
`Failed to create manual configuration: ${manualConfigError instanceof Error ? manualConfigError.message : 'Unknown error'}`
);
throw new Error(`Manual configuration failed for ${provider.id}`);
}
}
// If we reach here, neither discovery nor manual endpoints are available
throw new Error(
`No configuration method available for ${provider.id}: requires either valid issuer for discovery or manual endpoints`
);
} catch (error) {
this.logger.error(
`Failed to create OIDC configuration for ${provider.id}: ${
error instanceof Error ? error.message : 'Unknown error'
}`
);
// Log more details in debug mode
if (error instanceof Error && error.stack) {
this.logger.debug(`Stack trace: ${error.stack}`);
}
throw new UnauthorizedException('Provider configuration error');
}
}
private async checkAuthorization(provider: OidcProvider, claims: JwtClaims): Promise<void> {
this.logger.debug(
`Checking authorization for provider ${provider.id} with ${provider.authorizationRules?.length || 0} rules`
);
this.logger.debug(`Available claims: ${Object.keys(claims).join(', ')}`);
this.logger.debug(
`Authorization rule mode: ${provider.authorizationRuleMode || AuthorizationRuleMode.OR}`
);
// If no authorization rules are specified, throw a helpful error
if (!provider.authorizationRules || provider.authorizationRules.length === 0) {
throw new UnauthorizedException(
`Login failed: The ${provider.name} provider has no authorization rules configured. ` +
`Please configure authorization rules.`
);
}
this.logger.debug(
`Authorization rules to evaluate: ${JSON.stringify(provider.authorizationRules, null, 2)}`
);
// Evaluate the rules
const ruleMode = provider.authorizationRuleMode || AuthorizationRuleMode.OR;
const isAuthorized = this.evaluateAuthorizationRules(
provider.authorizationRules,
claims,
ruleMode
);
this.logger.debug(`Authorization result: ${isAuthorized}`);
if (!isAuthorized) {
// Log authorization failure with safe claim representation (no PII)
const availableClaimKeys = Object.keys(claims).join(', ');
this.logger.warn(
`Authorization failed for provider ${provider.name}, user ${claims.sub}, available claim keys: [${availableClaimKeys}]`
);
throw new UnauthorizedException(
`Access denied: Your account does not meet the authorization requirements for ${provider.name}.`
);
}
this.logger.debug(`Authorization successful for user ${claims.sub}`);
}
private evaluateAuthorizationRules(
rules: OidcAuthorizationRule[],
claims: JwtClaims,
mode: AuthorizationRuleMode = AuthorizationRuleMode.OR
): boolean {
// No rules means no authorization
if (rules.length === 0) {
return false;
}
if (mode === AuthorizationRuleMode.AND) {
// All rules must pass (AND logic)
return rules.every((rule) => this.evaluateRule(rule, claims));
} else {
// Any rule can pass (OR logic) - default behavior
// Multiple rules act as alternative authorization paths
return rules.some((rule) => this.evaluateRule(rule, claims));
}
}
private evaluateRule(rule: OidcAuthorizationRule, claims: JwtClaims): boolean {
const claimValue = claims[rule.claim];
this.logger.verbose(
`Evaluating rule for claim ${rule.claim}: ${JSON.stringify({
claimValue,
claimType: typeof claimValue,
isArray: Array.isArray(claimValue),
ruleOperator: rule.operator,
ruleValues: rule.value,
})}`
);
if (claimValue === undefined || claimValue === null) {
this.logger.verbose(`Claim ${rule.claim} not found in token`);
return false;
}
// Handle non-array, non-string objects
if (typeof claimValue === 'object' && claimValue !== null && !Array.isArray(claimValue)) {
this.logger.warn(
`unexpected JWT claim value encountered - claim ${rule.claim} has unsupported object type (keys: [${Object.keys(claimValue as Record<string, unknown>).join(', ')}])`
);
return false;
}
// Handle array claims - evaluate rule against each array element
if (Array.isArray(claimValue)) {
this.logger.verbose(
`Processing array claim ${rule.claim} with ${claimValue.length} elements`
);
// For array claims, check if ANY element in the array matches the rule
const arrayResult = claimValue.some((element) => {
// Skip non-string elements
if (
typeof element !== 'string' &&
typeof element !== 'number' &&
typeof element !== 'boolean'
) {
this.logger.verbose(`Skipping non-primitive element in array: ${typeof element}`);
return false;
}
const elementValue = String(element);
return this.evaluateSingleValue(elementValue, rule);
});
this.logger.verbose(`Array evaluation result for claim ${rule.claim}: ${arrayResult}`);
return arrayResult;
}
// Handle single value claims (string, number, boolean)
const value = String(claimValue);
this.logger.verbose(`Processing single value claim ${rule.claim} with value: "${value}"`);
return this.evaluateSingleValue(value, rule);
}
private evaluateSingleValue(value: string, rule: OidcAuthorizationRule): boolean {
let result: boolean;
switch (rule.operator) {
case AuthorizationOperator.EQUALS:
result = rule.value.some((v) => value === v);
this.logger.verbose(
`EQUALS check: "${value}" matches any of [${rule.value.join(', ')}]: ${result}`
);
return result;
case AuthorizationOperator.CONTAINS:
result = rule.value.some((v) => value.includes(v));
this.logger.verbose(
`CONTAINS check: "${value}" contains any of [${rule.value.join(', ')}]: ${result}`
);
return result;
case AuthorizationOperator.STARTS_WITH:
result = rule.value.some((v) => value.startsWith(v));
this.logger.verbose(
`STARTS_WITH check: "${value}" starts with any of [${rule.value.join(', ')}]: ${result}`
);
return result;
case AuthorizationOperator.ENDS_WITH:
result = rule.value.some((v) => value.endsWith(v));
this.logger.verbose(
`ENDS_WITH check: "${value}" ends with any of [${rule.value.join(', ')}]: ${result}`
);
return result;
default:
this.logger.error(`Unknown authorization operator: ${rule.operator}`);
return false;
}
}
/**
* Validate OIDC provider configuration by attempting discovery
* Returns validation result with helpful error messages for debugging
*/
async validateProvider(
provider: OidcProvider
): Promise<{ isValid: boolean; error?: string; details?: unknown }> {
// Clear any cached config for this provider to force fresh validation
this.configCache.delete(provider.id);
// Delegate to the validation service
return this.validationService.validateProvider(provider);
}
private getRedirectUri(requestOrigin?: string): string {
// If we have the full origin (protocol://host), use it directly
if (requestOrigin) {
// Parse the origin to extract protocol and host
try {
const url = new URL(requestOrigin);
const { protocol, hostname, port } = url;
// Reconstruct the URL, removing default ports
let cleanOrigin = `${protocol}//${hostname}`;
// Add port if it's not the default for the protocol
if (
port &&
!(protocol === 'https:' && port === '443') &&
!(protocol === 'http:' && port === '80')
) {
cleanOrigin += `:${port}`;
}
// Special handling for localhost development with Nuxt proxy
if (hostname === 'localhost' && port === '3000') {
return `${cleanOrigin}/graphql/api/auth/oidc/callback`;
}
return `${cleanOrigin}/graphql/api/auth/oidc/callback`;
} catch (e) {
this.logger.warn(`Failed to parse request origin: ${requestOrigin}, error: ${e}`);
}
}
// Fall back to configured BASE_URL or default
const baseUrl = this.configService.get('BASE_URL', 'http://tower.local');
return `${baseUrl}/graphql/api/auth/oidc/callback`;
}
}

View File

@@ -1,204 +0,0 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/oidc-state.service.js';
describe('OidcStateService', () => {
let service: OidcStateService;
beforeEach(() => {
vi.clearAllMocks();
vi.useFakeTimers();
// Create a single instance for all tests in a describe block
service = new OidcStateService();
});
afterEach(() => {
vi.useRealTimers();
});
describe('generateSecureState', () => {
it('should generate a state with provider prefix and signed token', () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = service.generateSecureState(providerId, clientState);
expect(state).toBeTruthy();
expect(typeof state).toBe('string');
expect(state.startsWith(`${providerId}:`)).toBe(true);
// Extract signed portion and verify format (nonce.timestamp.signature)
const signed = state.substring(providerId.length + 1);
expect(signed.split('.').length).toBe(3);
});
it('should generate unique states for each call', () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state1 = service.generateSecureState(providerId, clientState);
const state2 = service.generateSecureState(providerId, clientState);
expect(state1).not.toBe(state2);
});
});
describe('validateSecureState', () => {
it('should validate a valid state token', () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = service.generateSecureState(providerId, clientState);
const result = service.validateSecureState(state, providerId);
expect(result.isValid).toBe(true);
expect(result.clientState).toBe(clientState);
expect(result.error).toBeUndefined();
});
it('should reject state with wrong provider ID', () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = service.generateSecureState(providerId, clientState);
const result = service.validateSecureState(state, 'wrong-provider');
expect(result.isValid).toBe(false);
expect(result.error).toBe('Provider ID mismatch in state');
});
it('should reject expired state tokens', () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = service.generateSecureState(providerId, clientState);
// Fast forward time beyond expiration (11 minutes)
vi.advanceTimersByTime(11 * 60 * 1000);
const result = service.validateSecureState(state, providerId);
expect(result.isValid).toBe(false);
expect(result.error).toBe('State token has expired');
});
it('should reject reused state tokens', () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = service.generateSecureState(providerId, clientState);
// First validation should succeed
const result1 = service.validateSecureState(state, providerId);
expect(result1.isValid).toBe(true);
// Second validation should fail (replay attack prevention)
const result2 = service.validateSecureState(state, providerId);
expect(result2.isValid).toBe(false);
expect(result2.error).toBe('State token not found or already used');
});
it('should reject invalid state tokens', () => {
const result = service.validateSecureState('invalid.state.token', 'test-provider');
expect(result.isValid).toBe(false);
expect(result.error).toBe('Invalid state format');
});
it('should reject tampered state tokens', () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = service.generateSecureState(providerId, clientState);
// Tamper with the signature
const parts = state.split('.');
parts[2] = parts[2].slice(0, -4) + 'XXXX';
const tamperedState = parts.join('.');
const result = service.validateSecureState(tamperedState, providerId);
expect(result.isValid).toBe(false);
expect(result.error).toBe('Invalid state signature');
});
});
describe('extractProviderFromState', () => {
it('should extract provider from state prefix', () => {
const state = 'provider-id:eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature';
const result = service.extractProviderFromState(state);
expect(result).toBe('provider-id');
});
it('should handle states with multiple colons', () => {
const state = 'provider-id:jwt:with:colons';
const result = service.extractProviderFromState(state);
expect(result).toBe('provider-id');
});
it('should return null for invalid format', () => {
const result = service.extractProviderFromState('invalid-state');
expect(result).toBeNull();
});
});
describe('extractProviderFromLegacyState', () => {
it('should extract provider from legacy colon-separated format', () => {
const result = service.extractProviderFromLegacyState('provider-id:client-state');
expect(result.providerId).toBe('provider-id');
expect(result.originalState).toBe('client-state');
});
it('should handle multiple colons in legacy format', () => {
const result = service.extractProviderFromLegacyState(
'provider-id:client:state:with:colons'
);
expect(result.providerId).toBe('provider-id');
expect(result.originalState).toBe('client:state:with:colons');
});
it('should return empty provider for JWT format', () => {
const jwtState = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature';
const result = service.extractProviderFromLegacyState(jwtState);
expect(result.providerId).toBe('');
expect(result.originalState).toBe(jwtState);
});
it('should return empty provider for unknown format', () => {
const result = service.extractProviderFromLegacyState('some-random-state');
expect(result.providerId).toBe('');
expect(result.originalState).toBe('some-random-state');
});
});
describe('cleanupExpiredStates', () => {
it('should clean up expired states periodically', () => {
const providerId = 'test-provider';
// Generate multiple states
service.generateSecureState(providerId, 'state1');
service.generateSecureState(providerId, 'state2');
service.generateSecureState(providerId, 'state3');
// Fast forward past expiration
vi.advanceTimersByTime(11 * 60 * 1000);
// Generate a new state that shouldn't be cleaned
const validState = service.generateSecureState(providerId, 'state4');
// Trigger cleanup (happens every minute)
vi.advanceTimersByTime(60 * 1000);
// The new state should still be valid
const result = service.validateSecureState(validState, providerId);
expect(result.isValid).toBe(true);
});
});
});

View File

@@ -1,164 +0,0 @@
import { Injectable, Logger } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import * as client from 'openid-client';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
@Injectable()
export class OidcValidationService {
private readonly logger = new Logger(OidcValidationService.name);
constructor(private readonly configService: ConfigService) {}
/**
* Validate OIDC provider configuration by attempting discovery
* Returns validation result with helpful error messages for debugging
*/
async validateProvider(
provider: OidcProvider
): Promise<{ isValid: boolean; error?: string; details?: unknown }> {
try {
// Validate issuer URL is present
if (!provider.issuer) {
return {
isValid: false,
error: 'No issuer URL provided. Please specify the OIDC provider issuer URL.',
details: { type: 'MISSING_ISSUER' },
};
}
// Validate issuer URL is valid
let serverUrl: URL;
try {
serverUrl = new URL(provider.issuer);
} catch (urlError) {
return {
isValid: false,
error: `Invalid issuer URL format: '${provider.issuer}'. Please provide a valid URL.`,
details: {
type: 'INVALID_URL',
originalError: urlError instanceof Error ? urlError.message : String(urlError),
},
};
}
// Configure client options for HTTP if needed
let clientOptions: any = undefined;
if (serverUrl.protocol === 'http:') {
this.logger.debug(
`HTTP issuer URL detected for provider ${provider.id}: ${provider.issuer}`
);
clientOptions = {
execute: [client.allowInsecureRequests],
};
}
// Attempt OIDC discovery
await this.performDiscovery(provider, clientOptions);
return { isValid: true };
} catch (error) {
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
// Log the raw error for debugging
this.logger.debug(`Raw discovery error for ${provider.id}: ${errorMessage}`);
// Provide specific error messages for common issues
let userFriendlyError = errorMessage;
let details: Record<string, unknown> = {};
if (errorMessage.includes('getaddrinfo ENOTFOUND')) {
userFriendlyError = `Cannot resolve domain name. Please check that '${provider.issuer}' is accessible and spelled correctly.`;
details = { type: 'DNS_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('ECONNREFUSED')) {
userFriendlyError = `Connection refused. The server at '${provider.issuer}' is not accepting connections.`;
details = { type: 'CONNECTION_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('ECONNRESET') || errorMessage.includes('ETIMEDOUT')) {
userFriendlyError = `Connection timeout. The server at '${provider.issuer}' is not responding.`;
details = { type: 'TIMEOUT_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('404') || errorMessage.includes('Not Found')) {
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
? provider.issuer.replace('/.well-known/openid-configuration', '')
: provider.issuer;
userFriendlyError = `OIDC discovery endpoint not found. Please verify that '${baseUrl}/.well-known/openid-configuration' exists.`;
details = { type: 'DISCOVERY_NOT_FOUND', originalError: errorMessage };
} else if (errorMessage.includes('401') || errorMessage.includes('403')) {
userFriendlyError = `Access denied to discovery endpoint. Please check the issuer URL and any authentication requirements.`;
details = { type: 'AUTHENTICATION_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('unexpected HTTP response status code')) {
// Extract status code if possible
const statusMatch = errorMessage.match(/status code (\d+)/);
const statusCode = statusMatch ? statusMatch[1] : 'unknown';
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
? provider.issuer.replace('/.well-known/openid-configuration', '')
: provider.issuer;
userFriendlyError = `HTTP ${statusCode} error from discovery endpoint. Please check that '${baseUrl}/.well-known/openid-configuration' returns a valid OIDC discovery document.`;
details = { type: 'HTTP_STATUS_ERROR', statusCode, originalError: errorMessage };
} else if (
errorMessage.includes('certificate') ||
errorMessage.includes('SSL') ||
errorMessage.includes('TLS')
) {
userFriendlyError = `SSL/TLS certificate error. The server certificate may be invalid or expired.`;
details = { type: 'SSL_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('JSON') || errorMessage.includes('parse')) {
userFriendlyError = `Invalid OIDC discovery response. The server returned malformed JSON.`;
details = { type: 'INVALID_JSON', originalError: errorMessage };
} else if (error && (error as any).code === 'OAUTH_RESPONSE_IS_NOT_CONFORM') {
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
? provider.issuer.replace('/.well-known/openid-configuration', '')
: provider.issuer;
userFriendlyError = `Invalid OIDC discovery document. The server at '${baseUrl}/.well-known/openid-configuration' returned a response that doesn't conform to the OpenID Connect Discovery specification. Please verify the endpoint returns valid OIDC metadata.`;
details = { type: 'INVALID_OIDC_DOCUMENT', originalError: errorMessage };
}
this.logger.warn(`OIDC validation failed for provider ${provider.id}: ${errorMessage}`);
// Add debug logging for HTTP status errors
if (errorMessage.includes('unexpected HTTP response status code')) {
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
? provider.issuer.replace('/.well-known/openid-configuration', '')
: provider.issuer;
this.logger.debug(`Attempted to fetch: ${baseUrl}/.well-known/openid-configuration`);
this.logger.debug(`Full error details: ${errorMessage}`);
}
return {
isValid: false,
error: userFriendlyError,
details,
};
}
}
async performDiscovery(provider: OidcProvider, clientOptions?: any): Promise<client.Configuration> {
if (!provider.issuer) {
throw new Error('No issuer URL provided');
}
// Configure client auth method
const clientAuth = provider.clientSecret
? client.ClientSecretPost(provider.clientSecret)
: undefined;
const serverUrl = new URL(provider.issuer);
// Use provided client options or create default options with HTTP support if needed
if (!clientOptions && serverUrl.protocol === 'http:') {
this.logger.debug(`Allowing HTTP for ${provider.id} as specified by user`);
// For openid-client v6, use allowInsecureRequests in the execute array
// This is deprecated but needed for local development with HTTP endpoints
clientOptions = {
execute: [client.allowInsecureRequests],
};
}
return client.discovery(
serverUrl,
provider.clientId,
undefined, // client metadata
clientAuth,
clientOptions
);
}
}

View File

@@ -0,0 +1,10 @@
import { Module } from '@nestjs/common';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
@Module({
providers: [OidcSessionService, OidcStateService],
exports: [OidcSessionService, OidcStateService],
})
export class OidcSessionModule {}

View File

@@ -4,7 +4,7 @@ import { Test } from '@nestjs/testing';
import type { Cache } from 'cache-manager';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
describe('OidcSessionService', () => {
let service: OidcSessionService;

View File

@@ -15,7 +15,7 @@ export interface OidcSession {
@Injectable()
export class OidcSessionService {
private readonly logger = new Logger(OidcSessionService.name);
private readonly SESSION_TTL_SECONDS = 2 * 60; // 2 minutes for one-time token security
private readonly SESSION_TTL_MS = 2 * 60 * 1000; // 2 minutes in milliseconds (cache-manager v7 expects milliseconds)
constructor(@Inject(CACHE_MANAGER) private readonly cacheManager: Cache) {}
@@ -28,12 +28,21 @@ export class OidcSessionService {
providerId,
providerUserId,
createdAt: now,
expiresAt: new Date(now.getTime() + this.SESSION_TTL_SECONDS * 1000),
expiresAt: new Date(now.getTime() + this.SESSION_TTL_MS),
};
// Store in cache with TTL
await this.cacheManager.set(sessionId, session, this.SESSION_TTL_SECONDS * 1000);
this.logger.log(`Created OIDC session ${sessionId} for provider ${providerId}`);
// Store in cache with TTL (in milliseconds for cache-manager v7)
await this.cacheManager.set(sessionId, session, this.SESSION_TTL_MS);
// Verify it was stored
const verifyStored = await this.cacheManager.get(sessionId);
if (verifyStored) {
this.logger.debug(`Session successfully stored and verified with ID: ${sessionId}`);
} else {
this.logger.error(`CRITICAL: Session was NOT stored in cache for ID: ${sessionId}`);
}
this.logger.log(`Created OIDC session for provider ${providerId}`);
return this.createPaddedToken(sessionId);
}
@@ -44,15 +53,16 @@ export class OidcSessionService {
return { valid: false };
}
this.logger.debug(`Looking for session with ID: ${sessionId}`);
const session = await this.cacheManager.get<OidcSession>(sessionId);
if (!session) {
this.logger.debug(`Session ${sessionId} not found`);
this.logger.debug(`Session not found for ID: ${sessionId}`);
return { valid: false };
}
const now = new Date();
if (now > new Date(session.expiresAt)) {
this.logger.debug(`Session ${sessionId} expired`);
this.logger.debug(`Session expired`);
await this.cacheManager.del(sessionId);
return { valid: false };
}
@@ -62,7 +72,7 @@ export class OidcSessionService {
await this.cacheManager.del(sessionId);
this.logger.log(
`Validated and invalidated session ${sessionId} for provider ${session.providerId} (one-time use)`
`Validated and invalidated session for provider ${session.providerId} (one-time use)`
);
return { valid: true, username: 'root' };
}

View File

@@ -0,0 +1,155 @@
import { CacheModule } from '@nestjs/cache-manager';
import { UnauthorizedException } from '@nestjs/common';
import { Test } from '@nestjs/testing';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcStateExtractor } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
describe('OidcStateExtractor', () => {
let stateService: OidcStateService;
beforeEach(async () => {
vi.clearAllMocks();
const module = await Test.createTestingModule({
imports: [CacheModule.register()],
providers: [OidcStateService],
}).compile();
stateService = module.get<OidcStateService>(OidcStateService);
});
describe('extractProviderFromState', () => {
it('should extract provider ID from valid state', () => {
const state = 'provider123:nonce.timestamp.signature';
const result = OidcStateExtractor.extractProviderFromState(state, stateService);
expect(result.providerId).toBe('provider123');
expect(result.originalState).toBe(state);
});
it('should handle state without provider prefix', () => {
const state = 'invalid-state-format';
const result = OidcStateExtractor.extractProviderFromState(state, stateService);
expect(result.providerId).toBe('');
expect(result.originalState).toBe(state);
});
});
describe('extractAndValidateState', () => {
it('should extract and validate a valid state with redirectUri', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
// Generate a valid state
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
// Extract and validate
const result = await OidcStateExtractor.extractAndValidateState(state, stateService);
expect(result.providerId).toBe(providerId);
expect(result.originalState).toBe(state);
expect(result.clientState).toBe(clientState);
expect(result.redirectUri).toBe(redirectUri);
});
it('should extract and validate a valid state without redirectUri', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
// Generate a valid state without redirectUri
const state = await stateService.generateSecureState(providerId, clientState);
// Extract and validate
const result = await OidcStateExtractor.extractAndValidateState(state, stateService);
expect(result.providerId).toBe(providerId);
expect(result.originalState).toBe(state);
expect(result.clientState).toBe(clientState);
expect(result.redirectUri).toBeUndefined();
});
it('should throw UnauthorizedException for invalid state format', async () => {
const invalidState = 'invalid-format';
await expect(async () => {
await OidcStateExtractor.extractAndValidateState(invalidState, stateService);
}).rejects.toThrow(UnauthorizedException);
});
it('should throw UnauthorizedException for expired state', async () => {
vi.useFakeTimers();
const providerId = 'test-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
// Generate a valid state
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
// Fast forward time beyond expiration (11 minutes)
vi.advanceTimersByTime(11 * 60 * 1000);
await expect(async () => {
await OidcStateExtractor.extractAndValidateState(state, stateService);
}).rejects.toThrow(UnauthorizedException);
vi.useRealTimers();
});
it('should throw UnauthorizedException for wrong provider ID', async () => {
const providerId = 'test-provider';
const wrongProviderId = 'wrong-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
// Generate a valid state
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
// Create a fake state with wrong provider prefix
const tamperedState = state.replace(providerId, wrongProviderId);
await expect(async () => {
await OidcStateExtractor.extractAndValidateState(tamperedState, stateService);
}).rejects.toThrow(UnauthorizedException);
});
it('should throw UnauthorizedException for tampered state', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
// Generate a valid state
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
// Tamper with the signature
const tamperedState = state.slice(0, -5) + 'xxxxx';
await expect(async () => {
await OidcStateExtractor.extractAndValidateState(tamperedState, stateService);
}).rejects.toThrow(UnauthorizedException);
});
it('should throw UnauthorizedException for reused state (replay attack)', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
// Generate a valid state
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
// First validation should succeed
const result1 = await OidcStateExtractor.extractAndValidateState(state, stateService);
expect(result1.providerId).toBe(providerId);
// Second validation should fail (replay attack)
await expect(async () => {
await OidcStateExtractor.extractAndValidateState(state, stateService);
}).rejects.toThrow(UnauthorizedException);
});
});
});

View File

@@ -0,0 +1,60 @@
import { UnauthorizedException } from '@nestjs/common';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
export interface StateExtractionResult {
providerId: string;
originalState: string;
clientState?: string;
redirectUri?: string;
}
/**
* Utility to extract and validate OIDC state information consistently
* across authorize and callback endpoints
*/
export class OidcStateExtractor {
/**
* Extract provider ID from state without validation (for routing purposes)
*/
static extractProviderFromState(
state: string,
stateService: OidcStateService
): { providerId: string; originalState: string } {
// Use the state service's extraction method
const providerId = stateService.extractProviderFromState(state);
return {
providerId: providerId || '',
originalState: state,
};
}
/**
* Extract provider ID and validate the full encrypted state
*/
static async extractAndValidateState(
state: string,
stateService: OidcStateService
): Promise<StateExtractionResult> {
// First extract provider ID for routing
const { providerId } = this.extractProviderFromState(state, stateService);
if (!providerId) {
throw new UnauthorizedException('Invalid state format: missing provider ID');
}
// Then validate the full encrypted state
const stateValidation = await stateService.validateSecureState(state, providerId);
if (!stateValidation.isValid) {
throw new UnauthorizedException(`Invalid state: ${stateValidation.error}`);
}
return {
providerId,
originalState: state,
clientState: stateValidation.clientState,
redirectUri: stateValidation.redirectUri,
};
}
}

View File

@@ -0,0 +1,238 @@
import { CacheModule } from '@nestjs/cache-manager';
import { Test, TestingModule } from '@nestjs/testing';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
describe('OidcStateService', () => {
let service: OidcStateService;
let module: TestingModule;
beforeEach(async () => {
vi.clearAllMocks();
vi.useFakeTimers();
// Set a deterministic system time for consistent testing
vi.setSystemTime(new Date('2024-01-01T00:00:00Z'));
module = await Test.createTestingModule({
imports: [CacheModule.register()],
providers: [OidcStateService],
}).compile();
service = module.get<OidcStateService>(OidcStateService);
});
afterEach(async () => {
vi.useRealTimers();
// Close the testing module to prevent handle leaks
if (module) {
await module.close();
}
});
describe('generateSecureState', () => {
it('should generate a state with provider prefix and signed token', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
const state = await service.generateSecureState(providerId, clientState, redirectUri);
expect(state).toBeTruthy();
expect(typeof state).toBe('string');
expect(state.startsWith(`${providerId}:`)).toBe(true);
// Extract signed portion and verify format (nonce.timestamp.signature)
const signed = state.substring(providerId.length + 1);
expect(signed.split('.').length).toBe(3);
});
it('should generate unique states for each call', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
const state1 = await service.generateSecureState(providerId, clientState, redirectUri);
const state2 = await service.generateSecureState(providerId, clientState, redirectUri);
expect(state1).not.toBe(state2);
});
it('should work without redirectUri parameter (backwards compatibility)', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
expect(state).toBeTruthy();
expect(state.startsWith(`${providerId}:`)).toBe(true);
});
it('should store state data in cache and retrieve it', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
const state = await service.generateSecureState(providerId, clientState, redirectUri);
const validation = await service.validateSecureState(state, providerId);
expect(validation.isValid).toBe(true);
expect(validation.clientState).toBe(clientState);
expect(validation.redirectUri).toBe(redirectUri);
});
});
describe('validateSecureState', () => {
it('should validate a valid state token', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
const result = await service.validateSecureState(state, providerId);
expect(result.isValid).toBe(true);
expect(result.clientState).toBe(clientState);
expect(result.error).toBeUndefined();
});
it('should validate a state token with redirectUri', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const redirectUri = 'https://example.com/callback';
const state = await service.generateSecureState(providerId, clientState, redirectUri);
const result = await service.validateSecureState(state, providerId);
expect(result.isValid).toBe(true);
expect(result.clientState).toBe(clientState);
expect(result.redirectUri).toBe(redirectUri);
expect(result.error).toBeUndefined();
});
it('should reject state with wrong provider ID', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
const result = await service.validateSecureState(state, 'different-provider');
expect(result.isValid).toBe(false);
expect(result.error).toContain('Provider ID mismatch');
});
it('should reject expired state tokens', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
// Advance time by 11 minutes (past the 10-minute TTL)
vi.advanceTimersByTime(11 * 60 * 1000);
const result = await service.validateSecureState(state, providerId);
expect(result.isValid).toBe(false);
expect(result.error).toContain('expired');
});
it('should reject reused state tokens', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
// First validation should succeed
const result1 = await service.validateSecureState(state, providerId);
expect(result1.isValid).toBe(true);
// Second validation should fail (replay attack prevention)
const result2 = await service.validateSecureState(state, providerId);
expect(result2.isValid).toBe(false);
expect(result2.error).toContain('not found or already used');
});
it('should reject invalid state tokens', async () => {
const providerId = 'test-provider';
const invalidState = `${providerId}:invalid-format`;
const result = await service.validateSecureState(invalidState, providerId);
expect(result.isValid).toBe(false);
expect(result.error).toBeTruthy();
});
it('should reject tampered state tokens', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
// Tamper with the signature
const tamperedState = state.substring(0, state.length - 5) + 'xxxxx';
const result = await service.validateSecureState(tamperedState, providerId);
expect(result.isValid).toBe(false);
expect(result.error).toContain('signature');
});
});
describe('extractProviderFromState', () => {
it('should extract provider ID from state', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
const extracted = service.extractProviderFromState(state);
expect(extracted).toBe(providerId);
});
it('should return null for invalid state format', () => {
const invalidState = 'invalid-state-without-colon';
const extracted = service.extractProviderFromState(invalidState);
expect(extracted).toBeNull();
});
});
describe('extractProviderFromLegacyState', () => {
it('should handle legacy state format', () => {
const legacyState = 'provider-id:client-state-value';
const result = service.extractProviderFromLegacyState(legacyState);
expect(result.providerId).toBe('provider-id');
expect(result.originalState).toBe('client-state-value');
});
it('should handle new signed state format', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
const result = service.extractProviderFromLegacyState(state);
// New format should not be recognized as legacy
expect(result.providerId).toBe('');
expect(result.originalState).toBe(state);
});
});
describe('cache TTL', () => {
it('should remove state from cache after successful validation', async () => {
const providerId = 'test-provider';
const clientState = 'client-state-123';
const state = await service.generateSecureState(providerId, clientState);
// First validation should succeed
const result1 = await service.validateSecureState(state, providerId);
expect(result1.isValid).toBe(true);
// Second validation should fail (state was removed after first use)
const result2 = await service.validateSecureState(state, providerId);
expect(result2.isValid).toBe(false);
expect(result2.error).toContain('not found or already used');
});
});
});

View File

@@ -0,0 +1,241 @@
import { CacheModule } from '@nestjs/cache-manager';
import { Test } from '@nestjs/testing';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
describe('OidcStateService', () => {
let service: OidcStateService;
beforeEach(async () => {
const module = await Test.createTestingModule({
imports: [CacheModule.register()],
providers: [OidcStateService],
}).compile();
service = module.get<OidcStateService>(OidcStateService);
});
describe('state generation and validation flow', () => {
it('should generate state with redirect URI and validate it successfully', async () => {
const providerId = 'unraid.net';
const clientState = 'client-state-123';
const redirectUri = 'http://devgen-dev1.local/graphql/api/auth/oidc/callback';
// Generate state
const state = await service.generateSecureState(providerId, clientState, redirectUri);
// Verify state format: providerId:nonce.timestamp.signature
expect(state).toMatch(/^unraid\.net:[a-f0-9]+\.\d+\.[a-f0-9]+$/);
// Extract and verify parts
const [extractedProviderId, signedState] = state.split(':');
expect(extractedProviderId).toBe(providerId);
// Parse the signed state components
const [nonce, timestamp, signature] = signedState.split('.');
// Verify nonce is a 32-character hex string (16 bytes)
expect(nonce).toMatch(/^[a-f0-9]{32}$/);
// Verify timestamp is a valid number and recent
const timestampNum = parseInt(timestamp, 10);
expect(timestampNum).toBeGreaterThan(Date.now() - 1000); // Generated within last second
expect(timestampNum).toBeLessThanOrEqual(Date.now());
// Verify signature is a 64-character hex string (SHA256 output)
expect(signature).toMatch(/^[a-f0-9]{64}$/);
// Validate the state
const validation = await service.validateSecureState(state, providerId);
expect(validation.isValid).toBe(true);
expect(validation.clientState).toBe(clientState);
expect(validation.redirectUri).toBe(redirectUri);
});
it('should verify signed state integrity with HMAC', async () => {
const providerId = 'test-provider';
const clientState = 'test-state';
const redirectUri = 'http://localhost:3000/callback';
const state = await service.generateSecureState(providerId, clientState, redirectUri);
// Tamper with the signature
const [provider, signedState] = state.split(':');
const [nonce, timestamp] = signedState.split('.');
const tamperedSignature = 'a'.repeat(64); // Invalid signature
const tamperedState = `${provider}:${nonce}.${timestamp}.${tamperedSignature}`;
const validation = await service.validateSecureState(tamperedState, providerId);
expect(validation.isValid).toBe(false);
expect(validation.error).toContain('Invalid state signature');
});
it('should fail validation when nonce is not in cache', async () => {
const providerId = 'unraid.net';
// Create a fake state that looks valid but has unknown nonce
const fakeState = `unraid.net:fakenonce123.${Date.now()}.fakesignature456`;
const validation = await service.validateSecureState(fakeState, providerId);
expect(validation.isValid).toBe(false);
expect(validation.error).toContain('Invalid state signature');
});
it('should prevent replay attacks by removing nonce after validation', async () => {
const providerId = 'test-provider';
const clientState = 'test-state';
const redirectUri = 'http://localhost:3000/callback';
// Generate and validate state once
const state = await service.generateSecureState(providerId, clientState, redirectUri);
const firstValidation = await service.validateSecureState(state, providerId);
expect(firstValidation.isValid).toBe(true);
// Try to validate the same state again (replay attack)
const secondValidation = await service.validateSecureState(state, providerId);
expect(secondValidation.isValid).toBe(false);
expect(secondValidation.error).toContain('State token not found or already used');
});
it('should handle state with missing redirect URI', async () => {
const providerId = 'test-provider';
const clientState = 'test-state';
// No redirect URI provided
const state = await service.generateSecureState(providerId, clientState);
const validation = await service.validateSecureState(state, providerId);
expect(validation.isValid).toBe(true);
expect(validation.clientState).toBe(clientState);
expect(validation.redirectUri).toBeUndefined();
});
it('should reject state with wrong provider ID', async () => {
const providerId = 'provider-a';
const wrongProviderId = 'provider-b';
const clientState = 'test-state';
const state = await service.generateSecureState(providerId, clientState);
const validation = await service.validateSecureState(state, wrongProviderId);
expect(validation.isValid).toBe(false);
expect(validation.error).toContain('Provider ID mismatch');
});
it('should extract provider from state correctly', async () => {
const providerId = 'unraid.net';
const state = await service.generateSecureState(providerId, 'test', 'http://example.com');
const extracted = service.extractProviderFromState(state);
expect(extracted).toBe(providerId);
});
it('should handle state expiration', async () => {
const providerId = 'test-provider';
const clientState = 'test-state';
// Generate state
const state = await service.generateSecureState(providerId, clientState);
// Mock timestamp to simulate expired state
const parts = state.split(':')[1].split('.');
const nonce = parts[0];
const expiredTimestamp = Date.now() - 700000; // 11+ minutes ago
const fakeState = `${providerId}:${nonce}.${expiredTimestamp}.fakesignature`;
const validation = await service.validateSecureState(fakeState, providerId);
expect(validation.isValid).toBe(false);
expect(validation.error).toContain('Invalid state signature'); // Will fail on signature first
});
});
describe('redirect URI extraction from state', () => {
it('should store and retrieve redirect URI from state token', async () => {
const providerId = 'unraid.net';
const clientState = 'original-client-state';
const redirectUri = 'http://devgen-dev1.local/graphql/api/auth/oidc/callback';
// This simulates the authorize flow
const stateToken = await service.generateSecureState(providerId, clientState, redirectUri);
// Log the generated state for debugging
console.log('Generated state token:', stateToken);
// This simulates the callback flow
const validation = await service.validateSecureState(stateToken, providerId);
expect(validation.isValid).toBe(true);
expect(validation.redirectUri).toBe(redirectUri);
expect(validation.clientState).toBe(clientState);
});
it('should handle dynamic redirect URIs for different origins', async () => {
const providerId = 'google';
const clientState = 'state123';
// Test with different origins
const origins = [
'http://localhost:3000/graphql/api/auth/oidc/callback',
'https://myserver.local/graphql/api/auth/oidc/callback',
'http://192.168.1.100/graphql/api/auth/oidc/callback',
];
for (const redirectUri of origins) {
const state = await service.generateSecureState(providerId, clientState, redirectUri);
const validation = await service.validateSecureState(state, providerId);
expect(validation.isValid).toBe(true);
expect(validation.redirectUri).toBe(redirectUri);
}
});
});
describe('cache management', () => {
it('should handle TTL expiration correctly', async () => {
const providerId = 'test-provider';
const clientState = 'test-state';
const state = await service.generateSecureState(providerId, clientState);
// First validation should succeed
const validation1 = await service.validateSecureState(state, providerId);
expect(validation1.isValid).toBe(true);
// State should be removed after first use (replay protection)
const validation2 = await service.validateSecureState(state, providerId);
expect(validation2.isValid).toBe(false);
});
it('should store complete state data in cache with redirect URI', async () => {
const providerId = 'test-provider';
const clientState = 'client-123';
const redirectUri = 'http://example.com/callback';
const state = await service.generateSecureState(providerId, clientState, redirectUri);
// Extract nonce from the generated state
const [, signedState] = state.split(':');
const [nonce] = signedState.split('.');
// Access the cache directly to verify stored data
const cacheKey = `oidc_state:${nonce}`;
const cachedData = await service['cacheManager'].get(cacheKey);
expect(cachedData).toBeDefined();
expect(cachedData).toMatchObject({
nonce,
clientState,
providerId,
redirectUri,
});
// @ts-expect-error - cachedData is of type StateData
expect(cachedData.timestamp).toBeGreaterThan(Date.now() - 1000);
// @ts-expect-error - cachedData is of type StateData
expect(cachedData.timestamp).toBeLessThanOrEqual(Date.now());
});
});
});

View File

@@ -1,4 +1,5 @@
import { Injectable, Logger } from '@nestjs/common';
import { Cache, CACHE_MANAGER } from '@nestjs/cache-manager';
import { Inject, Injectable, Logger } from '@nestjs/common';
import crypto from 'crypto';
interface StateData {
@@ -6,26 +7,34 @@ interface StateData {
clientState: string;
timestamp: number;
providerId: string;
redirectUri?: string;
}
@Injectable()
export class OidcStateService {
private static instanceCount = 0;
private readonly instanceId: number;
private readonly logger = new Logger(OidcStateService.name);
private readonly stateCache = new Map<string, StateData>();
private readonly hmacSecret: string;
private readonly STATE_TTL_SECONDS = 600; // 10 minutes
private readonly STATE_TTL_MS = 600000; // 10 minutes in milliseconds (cache-manager v7+ expects milliseconds, not seconds)
private readonly STATE_CACHE_PREFIX = 'oidc_state:';
constructor(@Inject(CACHE_MANAGER) private cacheManager: Cache) {
// Track instance creation
this.instanceId = ++OidcStateService.instanceCount;
constructor() {
// Always generate a new secret on API restart for security
// This ensures state tokens cannot be reused across restarts
this.hmacSecret = crypto.randomBytes(32).toString('hex');
this.logger.debug('Generated new OIDC state secret for this session');
// Clean up expired states periodically
setInterval(() => this.cleanupExpiredStates(), 60000); // Every minute
this.logger.warn(`OidcStateService instance #${this.instanceId} created with new HMAC secret`);
this.logger.debug(`HMAC secret first 8 chars: ${this.hmacSecret.substring(0, 8)}`);
}
generateSecureState(providerId: string, clientState: string): string {
async generateSecureState(
providerId: string,
clientState: string,
redirectUri?: string
): Promise<string> {
const nonce = crypto.randomBytes(16).toString('hex');
const timestamp = Date.now();
@@ -35,8 +44,21 @@ export class OidcStateService {
clientState,
timestamp,
providerId,
redirectUri,
};
this.stateCache.set(nonce, stateData);
// Store in cache with TTL (in milliseconds for cache-manager v7)
const cacheKey = `${this.STATE_CACHE_PREFIX}${nonce}`;
this.logger.debug(`Storing state with key: ${cacheKey}, TTL: ${this.STATE_TTL_MS}ms`);
await this.cacheManager.set(cacheKey, stateData, this.STATE_TTL_MS);
// Verify it was stored
const verifyStored = await this.cacheManager.get(cacheKey);
if (verifyStored) {
this.logger.debug(`State successfully stored and verified for key: ${cacheKey}`);
} else {
this.logger.error(`CRITICAL: State was NOT stored in cache for key: ${cacheKey}`);
}
// Create signed state: nonce.timestamp.signature
const dataToSign = `${nonce}.${timestamp}`;
@@ -45,14 +67,18 @@ export class OidcStateService {
const signedState = `${dataToSign}.${signature}`;
this.logger.debug(`Generated secure state for provider ${providerId} with nonce ${nonce}`);
this.logger.debug(
`Instance #${this.instanceId}, HMAC secret first 8 chars: ${this.hmacSecret.substring(0, 8)}`
);
this.logger.debug(`Stored redirectUri: ${redirectUri}`);
// Return state with provider ID prefix (unencrypted) for routing
return `${providerId}:${signedState}`;
}
validateSecureState(
async validateSecureState(
state: string,
expectedProviderId: string
): { isValid: boolean; clientState?: string; error?: string } {
): Promise<{ isValid: boolean; clientState?: string; redirectUri?: string; error?: string }> {
try {
// Extract provider ID and signed state
const parts = state.split(':');
@@ -107,7 +133,7 @@ export class OidcStateService {
// Check timestamp expiration
const now = Date.now();
const age = now - timestamp;
if (age > this.STATE_TTL_SECONDS * 1000) {
if (age > this.STATE_TTL_MS) {
this.logger.warn(`State validation failed: token expired (age: ${age}ms)`);
return {
isValid: false,
@@ -116,11 +142,21 @@ export class OidcStateService {
}
// Check if state exists in cache (prevents replay attacks)
const cachedState = this.stateCache.get(nonce);
const cacheKey = `${this.STATE_CACHE_PREFIX}${nonce}`;
this.logger.debug(`Looking for nonce ${nonce} in cache with key: ${cacheKey}`);
this.logger.debug(
`Instance #${this.instanceId}, HMAC secret first 8 chars: ${this.hmacSecret.substring(0, 8)}`
);
this.logger.debug(`Cache manager type: ${this.cacheManager.constructor.name}`);
const cachedState = await this.cacheManager.get<StateData>(cacheKey);
if (!cachedState) {
this.logger.warn(
`State validation failed: nonce ${nonce} not found in cache (possible replay attack)`
);
this.logger.warn(`Cache key checked: ${cacheKey}`);
return {
isValid: false,
error: 'State token not found or already used',
@@ -137,12 +173,13 @@ export class OidcStateService {
}
// Remove from cache to prevent reuse
this.stateCache.delete(nonce);
await this.cacheManager.del(cacheKey);
this.logger.debug(`State validation successful for provider ${expectedProviderId}`);
return {
isValid: true,
clientState: cachedState.clientState,
redirectUri: cachedState.redirectUri,
};
} catch (error) {
this.logger.error(
@@ -182,20 +219,5 @@ export class OidcStateService {
return null;
}
private cleanupExpiredStates(): void {
const now = Date.now();
let cleaned = 0;
for (const [nonce, stateData] of this.stateCache.entries()) {
const age = now - stateData.timestamp;
if (age > this.STATE_TTL_SECONDS * 1000) {
this.stateCache.delete(nonce);
cleaned++;
}
}
if (cleaned > 0) {
this.logger.debug(`Cleaned up ${cleaned} expired state entries`);
}
}
// Cleanup is now handled by cache TTL
}

View File

@@ -1,33 +1,13 @@
import { CacheModule } from '@nestjs/cache-manager';
import { Module } from '@nestjs/common';
import { UserSettingsModule } from '@unraid/shared/services/user-settings.js';
import { OidcAuthService } from '@app/unraid-api/graph/resolvers/sso/oidc-auth.service.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/oidc-state.service.js';
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/oidc-validation.service.js';
import { OidcCoreModule } from '@app/unraid-api/graph/resolvers/sso/core/oidc-core.module.js';
import { SsoResolver } from '@app/unraid-api/graph/resolvers/sso/sso.resolver.js';
import '@app/unraid-api/graph/resolvers/sso/sso-settings.types.js';
import '@app/unraid-api/graph/resolvers/sso/models/sso-settings.types.js';
@Module({
imports: [UserSettingsModule, CacheModule.register()],
providers: [
SsoResolver,
OidcConfigPersistence,
OidcSessionService,
OidcStateService,
OidcAuthService,
OidcValidationService,
],
exports: [
OidcConfigPersistence,
OidcSessionService,
OidcStateService,
OidcAuthService,
OidcValidationService,
],
imports: [OidcCoreModule],
providers: [SsoResolver],
exports: [OidcCoreModule],
})
export class SsoModule {}

View File

@@ -6,11 +6,12 @@ import { PrefixedID } from '@unraid/shared/prefixed-id-scalar.js';
import { UsePermissions } from '@unraid/shared/use-permissions.directive.js';
import { Public } from '@app/unraid-api/auth/public.decorator.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
import { OidcSessionValidation } from '@app/unraid-api/graph/resolvers/sso/oidc-session-validation.model.js';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
import { PublicOidcProvider } from '@app/unraid-api/graph/resolvers/sso/public-oidc-provider.model.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { OidcConfiguration } from '@app/unraid-api/graph/resolvers/sso/models/oidc-configuration.model.js';
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
import { OidcSessionValidation } from '@app/unraid-api/graph/resolvers/sso/models/oidc-session-validation.model.js';
import { PublicOidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/public-oidc-provider.model.js';
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
@Resolver()
export class SsoResolver {
@@ -88,6 +89,19 @@ export class SsoResolver {
return this.oidcConfig.getProvider(id);
}
@Query(() => OidcConfiguration, { description: 'Get the full OIDC configuration (admin only)' })
@UsePermissions({
action: AuthAction.READ_ANY,
resource: Resource.CONFIG,
})
public async oidcConfiguration(): Promise<OidcConfiguration> {
const config = await this.oidcConfig.getConfig();
return {
providers: config?.providers || [],
defaultAllowedOrigins: config?.defaultAllowedOrigins || [],
};
}
@Query(() => OidcSessionValidation, {
description: 'Validate an OIDC session token (internal use for CLI validation)',
})

View File

@@ -0,0 +1,234 @@
import { Logger } from '@nestjs/common';
export interface OidcErrorDetails {
userFriendlyError: string;
details: Record<string, unknown>;
}
export class OidcErrorHelper {
private static readonly logger = new Logger(OidcErrorHelper.name);
/**
* Parse fetch errors and return user-friendly error messages
*/
static parseFetchError(error: unknown, issuerUrl?: string): OidcErrorDetails {
const errorMessage = error instanceof Error ? error.message : String(error);
let userFriendlyError = errorMessage;
let details: Record<string, unknown> = { originalError: errorMessage };
// Extract cause information if available
if (error instanceof Error && 'cause' in error) {
const cause = (error as any).cause;
if (cause) {
this.logger.log('Fetch error cause: %o', cause);
const errorCode = cause.code || '';
const causeMessage = cause.message || '';
// Map error codes to user-friendly messages
switch (errorCode) {
case 'ENOTFOUND':
userFriendlyError = `Cannot resolve domain name. Please check that '${issuerUrl}' is accessible and spelled correctly.`;
details = {
type: 'DNS_ERROR',
originalError: errorMessage,
cause: causeMessage || errorCode,
};
break;
case 'ECONNREFUSED':
userFriendlyError = `Connection refused. The server at '${issuerUrl}' is not accepting connections.`;
details = {
type: 'CONNECTION_ERROR',
originalError: errorMessage,
cause: causeMessage || errorCode,
};
break;
case 'CERT_HAS_EXPIRED':
userFriendlyError = `SSL/TLS certificate error. The server certificate may be invalid or expired.`;
details = {
type: 'SSL_ERROR',
originalError: errorMessage,
cause: causeMessage || errorCode,
};
break;
case 'ETIMEDOUT':
userFriendlyError = `Connection timeout. The server at '${issuerUrl}' is not responding.`;
details = {
type: 'TIMEOUT_ERROR',
originalError: errorMessage,
cause: causeMessage || errorCode,
};
break;
default:
// Check message patterns if code doesn't match
if (causeMessage.includes('ENOTFOUND')) {
userFriendlyError = `Cannot resolve domain name. Please check that '${issuerUrl}' is accessible and spelled correctly.`;
details = {
type: 'DNS_ERROR',
originalError: errorMessage,
cause: causeMessage,
};
} else if (causeMessage.includes('ECONNREFUSED')) {
userFriendlyError = `Connection refused. The server at '${issuerUrl}' is not accepting connections.`;
details = {
type: 'CONNECTION_ERROR',
originalError: errorMessage,
cause: causeMessage,
};
} else if (
causeMessage.includes('certificate') ||
causeMessage.includes('SSL') ||
causeMessage.includes('TLS')
) {
userFriendlyError = `SSL/TLS certificate error. The server certificate may be invalid or expired.`;
details = {
type: 'SSL_ERROR',
originalError: errorMessage,
cause: causeMessage,
};
} else if (causeMessage.includes('ETIMEDOUT')) {
userFriendlyError = `Connection timeout. The server at '${issuerUrl}' is not responding.`;
details = {
type: 'TIMEOUT_ERROR',
originalError: errorMessage,
cause: causeMessage,
};
} else {
userFriendlyError = `Failed to connect to OIDC provider at '${issuerUrl}'. ${causeMessage || errorCode || 'Unknown network error'}`;
details = {
type: 'FETCH_ERROR',
originalError: errorMessage,
cause: causeMessage || errorCode,
};
}
break;
}
} else {
// Generic fetch failed without cause
userFriendlyError = `Failed to connect to OIDC provider at '${issuerUrl}'. Please verify the URL is correct and accessible.`;
details = { type: 'FETCH_ERROR', originalError: errorMessage };
}
} else if (errorMessage.includes('fetch failed')) {
// Fetch failed but no cause information
userFriendlyError = `Failed to connect to OIDC provider at '${issuerUrl}'. Please verify the URL is correct and accessible.`;
details = { type: 'FETCH_ERROR', originalError: errorMessage };
}
return { userFriendlyError, details };
}
/**
* Parse HTTP status errors and return user-friendly error messages
*/
static parseHttpError(errorMessage: string, issuerUrl?: string): OidcErrorDetails {
let userFriendlyError = errorMessage;
let details: Record<string, unknown> = { originalError: errorMessage };
if (errorMessage.includes('404') || errorMessage.includes('Not Found')) {
const baseUrl = issuerUrl?.endsWith('/.well-known/openid-configuration')
? issuerUrl.replace('/.well-known/openid-configuration', '')
: issuerUrl;
userFriendlyError = `OIDC discovery endpoint not found. Please verify that '${baseUrl}/.well-known/openid-configuration' exists.`;
details = { type: 'DISCOVERY_NOT_FOUND', originalError: errorMessage };
} else if (errorMessage.includes('401') || errorMessage.includes('403')) {
userFriendlyError = `Access denied to discovery endpoint. Please check the issuer URL and any authentication requirements.`;
details = { type: 'AUTHENTICATION_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('unexpected HTTP response status code')) {
// Extract status code if possible
const statusMatch = errorMessage.match(/status code (\d+)/);
const statusCode = statusMatch ? statusMatch[1] : 'unknown';
const baseUrl = issuerUrl?.endsWith('/.well-known/openid-configuration')
? issuerUrl.replace('/.well-known/openid-configuration', '')
: issuerUrl;
userFriendlyError = `HTTP ${statusCode} error from discovery endpoint. Please check that '${baseUrl}/.well-known/openid-configuration' returns a valid OIDC discovery document.`;
details = { type: 'HTTP_STATUS_ERROR', statusCode, originalError: errorMessage };
}
return { userFriendlyError, details };
}
/**
* Parse generic OIDC errors and return user-friendly error messages
*/
static parseGenericError(error: unknown, issuerUrl?: string): OidcErrorDetails {
const errorMessage = error instanceof Error ? error.message : String(error);
let userFriendlyError = errorMessage;
let details: Record<string, unknown> = { originalError: errorMessage };
// Check for specific error patterns
if (errorMessage.includes('getaddrinfo ENOTFOUND')) {
userFriendlyError = `Cannot resolve domain name. Please check that '${issuerUrl}' is accessible and spelled correctly.`;
details = { type: 'DNS_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('ECONNREFUSED')) {
userFriendlyError = `Connection refused. The server at '${issuerUrl}' is not accepting connections.`;
details = { type: 'CONNECTION_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('ECONNRESET') || errorMessage.includes('ETIMEDOUT')) {
userFriendlyError = `Connection timeout. The server at '${issuerUrl}' is not responding.`;
details = { type: 'TIMEOUT_ERROR', originalError: errorMessage };
} else if (
errorMessage.includes('certificate') ||
errorMessage.includes('SSL') ||
errorMessage.includes('TLS')
) {
userFriendlyError = `SSL/TLS certificate error. The server certificate may be invalid or expired.`;
details = { type: 'SSL_ERROR', originalError: errorMessage };
} else if (errorMessage.includes('JSON') || errorMessage.includes('parse')) {
userFriendlyError = `Invalid OIDC discovery response. The server returned malformed JSON.`;
details = { type: 'INVALID_JSON', originalError: errorMessage };
} else if (error && (error as any).code === 'OAUTH_RESPONSE_IS_NOT_CONFORM') {
const baseUrl = issuerUrl?.endsWith('/.well-known/openid-configuration')
? issuerUrl.replace('/.well-known/openid-configuration', '')
: issuerUrl;
userFriendlyError = `Invalid OIDC discovery document. The server at '${baseUrl}/.well-known/openid-configuration' returned a response that doesn't conform to the OpenID Connect Discovery specification. Please verify the endpoint returns valid OIDC metadata.`;
details = { type: 'INVALID_OIDC_DOCUMENT', originalError: errorMessage };
}
return { userFriendlyError, details };
}
/**
* Parse OIDC discovery errors and return user-friendly error messages
*/
static parseDiscoveryError(error: unknown, issuerUrl?: string): OidcErrorDetails {
const errorMessage = error instanceof Error ? error.message : String(error);
// Log additional error details for debugging
if (error instanceof Error) {
this.logger.log(`Error type: ${error.constructor.name}`);
if ('stack' in error && error.stack) {
this.logger.debug(`Stack trace: ${error.stack}`);
}
if ('response' in error) {
const response = (error as any).response;
if (response) {
this.logger.log(`Response status: ${response.status}`);
this.logger.log(`Response body: ${response.body}`);
}
}
}
// Check for fetch-specific errors first
if (errorMessage.includes('fetch failed')) {
return this.parseFetchError(error, issuerUrl);
}
// Check for HTTP status errors
const httpError = this.parseHttpError(errorMessage, issuerUrl);
// Proper type-narrowing guard for accessing details.type
if (
httpError.details &&
typeof httpError.details === 'object' &&
'type' in httpError.details &&
httpError.details.type !== undefined
) {
return httpError;
}
// Fall back to generic error parsing
return this.parseGenericError(error, issuerUrl);
}
}

View File

@@ -0,0 +1,228 @@
import { Logger } from '@nestjs/common';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import type { FastifyRequest } from '@app/unraid-api/types/fastify.js';
import { OidcRequestHandler } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-request-handler.util.js';
describe('OidcRequestHandler', () => {
let mockLogger: Logger;
beforeEach(() => {
vi.clearAllMocks();
mockLogger = {
debug: vi.fn(),
log: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
} as any;
});
describe('extractRequestInfo', () => {
it('should extract request info from headers', () => {
const mockReq = {
headers: {
'x-forwarded-proto': 'https',
'x-forwarded-host': 'example.com:8443',
},
protocol: 'http',
url: '/callback?code=123&state=456',
} as unknown as FastifyRequest;
const result = OidcRequestHandler.extractRequestInfo(mockReq);
expect(result.protocol).toBe('https');
expect(result.host).toBe('example.com:8443');
expect(result.fullUrl).toBe('https://example.com:8443/callback?code=123&state=456');
expect(result.baseUrl).toBe('https://example.com:8443');
});
it('should fall back to request properties when headers are missing', () => {
const mockReq = {
headers: {
host: 'localhost:3000',
},
protocol: 'http',
url: '/callback?code=123&state=456',
} as FastifyRequest;
const result = OidcRequestHandler.extractRequestInfo(mockReq);
expect(result.protocol).toBe('http');
expect(result.host).toBe('localhost:3000');
expect(result.fullUrl).toBe('http://localhost:3000/callback?code=123&state=456');
expect(result.baseUrl).toBe('http://localhost:3000');
});
it('should use defaults when all headers are missing', () => {
const mockReq = {
headers: {},
url: '/callback?code=123&state=456',
} as FastifyRequest;
const result = OidcRequestHandler.extractRequestInfo(mockReq);
expect(result.protocol).toBe('http');
expect(result.host).toBe('localhost:3000');
expect(result.fullUrl).toBe('http://localhost:3000/callback?code=123&state=456');
expect(result.baseUrl).toBe('http://localhost:3000');
});
});
describe('validateAuthorizeParams', () => {
it('should validate valid parameters', () => {
const result = OidcRequestHandler.validateAuthorizeParams(
'provider123',
'state456',
'https://example.com/callback'
);
expect(result.providerId).toBe('provider123');
expect(result.state).toBe('state456');
expect(result.redirectUri).toBe('https://example.com/callback');
});
it('should throw error for missing provider ID', () => {
expect(() => {
OidcRequestHandler.validateAuthorizeParams(
undefined,
'state456',
'https://example.com/callback'
);
}).toThrow('Provider ID is required');
});
it('should throw error for missing state', () => {
expect(() => {
OidcRequestHandler.validateAuthorizeParams(
'provider123',
undefined,
'https://example.com/callback'
);
}).toThrow('State parameter is required');
});
it('should throw error for missing redirect URI', () => {
expect(() => {
OidcRequestHandler.validateAuthorizeParams('provider123', 'state456', undefined);
}).toThrow('Redirect URI is required');
});
});
describe('validateCallbackParams', () => {
it('should validate valid parameters', () => {
const result = OidcRequestHandler.validateCallbackParams('code123', 'state456');
expect(result.code).toBe('code123');
expect(result.state).toBe('state456');
});
it('should throw error for missing code', () => {
expect(() => {
OidcRequestHandler.validateCallbackParams(undefined, 'state456');
}).toThrow('Missing required parameters');
});
it('should throw error for missing state', () => {
expect(() => {
OidcRequestHandler.validateCallbackParams('code123', undefined);
}).toThrow('Missing required parameters');
});
it('should throw error for empty code', () => {
expect(() => {
OidcRequestHandler.validateCallbackParams('', 'state456');
}).toThrow('Missing required parameters');
});
it('should throw error for empty state', () => {
expect(() => {
OidcRequestHandler.validateCallbackParams('code123', '');
}).toThrow('Missing required parameters');
});
});
describe('handleAuthorize', () => {
it('should handle authorization flow', async () => {
const mockAuthService = {
getAuthorizationUrl: vi
.fn()
.mockResolvedValue('https://provider.com/auth?client_id=123'),
};
const mockReq = {
headers: { 'x-forwarded-proto': 'https', 'x-forwarded-host': 'example.com' },
url: '/authorize',
} as unknown as FastifyRequest;
const authUrl = await OidcRequestHandler.handleAuthorize(
'provider123',
'state456',
'https://example.com/callback',
mockReq,
mockAuthService as any,
mockLogger
);
expect(authUrl).toBe('https://provider.com/auth?client_id=123');
expect(mockAuthService.getAuthorizationUrl).toHaveBeenCalledWith({
providerId: 'provider123',
state: 'state456',
requestOrigin: 'https://example.com/callback',
requestHeaders: {
'x-forwarded-proto': 'https',
'x-forwarded-host': 'example.com',
},
});
expect(mockLogger.debug).toHaveBeenCalledWith(
'Authorization request - Provider: provider123'
);
expect(mockLogger.log).toHaveBeenCalledWith(
'Redirecting to OIDC provider: https://provider.com/auth?client_id=123'
);
});
});
describe('handleCallback', () => {
it('should handle callback flow', async () => {
const mockStateService = {
extractProviderFromState: vi.fn().mockReturnValue('provider123'),
};
const mockAuthService = {
getStateService: vi.fn().mockReturnValue(mockStateService),
handleCallback: vi.fn().mockResolvedValue('paddedToken123'),
};
const mockReq: Pick<FastifyRequest, 'id' | 'headers' | 'url'> = {
id: '123',
headers: { 'x-forwarded-proto': 'https', 'x-forwarded-host': 'example.com' },
url: '/callback?code=123&state=456',
};
const result = await OidcRequestHandler.handleCallback(
'code123',
'state456',
mockReq as unknown as FastifyRequest,
mockAuthService as any,
mockLogger
);
expect(result.providerId).toBe('provider123');
expect(result.paddedToken).toBe('paddedToken123');
expect(result.requestInfo.fullUrl).toBe('https://example.com/callback?code=123&state=456');
expect(mockAuthService.handleCallback).toHaveBeenCalledWith({
providerId: 'provider123',
code: 'code123',
state: 'state456',
requestOrigin: 'https://example.com',
fullCallbackUrl: 'https://example.com/callback?code=123&state=456',
requestHeaders: {
'x-forwarded-proto': 'https',
'x-forwarded-host': 'example.com',
},
});
expect(mockLogger.debug).toHaveBeenCalledWith('Callback request - Provider: provider123');
});
});
});

View File

@@ -0,0 +1,155 @@
import { Logger } from '@nestjs/common';
import type { FastifyRequest } from '@app/unraid-api/types/fastify.js';
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
import { OidcStateExtractor } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js';
export interface RequestInfo {
protocol: string;
host: string;
fullUrl: string;
baseUrl: string;
}
export interface OidcFlowResult {
providerId: string;
requestInfo: RequestInfo;
}
export interface OidcCallbackResult extends OidcFlowResult {
paddedToken: string;
}
/**
* Utility class to handle common OIDC request processing logic
* between authorize and callback endpoints
*/
export class OidcRequestHandler {
/**
* Extract request information from Fastify request headers
*/
static extractRequestInfo(req: FastifyRequest): RequestInfo {
// Handle potentially comma-separated forwarded headers (take first value)
const forwardedProto = String(req.headers['x-forwarded-proto'] || '')
.split(',')[0]
?.trim();
const forwardedHost = String(req.headers['x-forwarded-host'] || '')
.split(',')[0]
?.trim();
const protocol = forwardedProto || req.protocol || 'http';
const host = forwardedHost || req.headers.host || 'localhost:3000';
const fullUrl = `${protocol}://${host}${req.url}`;
const baseUrl = `${protocol}://${host}`;
return {
protocol,
host,
fullUrl,
baseUrl,
};
}
/**
* Handle OIDC authorization flow
*/
static async handleAuthorize(
providerId: string,
state: string,
redirectUri: string,
req: FastifyRequest,
oidcService: OidcService,
logger: Logger
): Promise<string> {
const requestInfo = this.extractRequestInfo(req);
logger.debug(`Authorization request - Provider: ${providerId}`);
logger.debug(`Authorization request - Full URL: ${requestInfo.fullUrl}`);
logger.debug(`Authorization request - Redirect URI: ${redirectUri}`);
// Get authorization URL using the validated redirect URI and request headers
const authUrl = await oidcService.getAuthorizationUrl({
providerId,
state,
requestOrigin: redirectUri,
requestHeaders: req.headers as Record<string, string | string[] | undefined>,
});
logger.log(`Redirecting to OIDC provider: ${authUrl}`);
return authUrl;
}
/**
* Handle OIDC callback flow
*/
static async handleCallback(
code: string,
state: string,
req: FastifyRequest,
oidcService: OidcService,
logger: Logger
): Promise<OidcCallbackResult> {
// Extract provider ID from state for routing
const { providerId } = OidcStateExtractor.extractProviderFromState(
state,
oidcService.getStateService()
);
const requestInfo = this.extractRequestInfo(req);
logger.debug(`Callback request - Provider: ${providerId}`);
logger.debug(`Callback request - Full URL: ${requestInfo.fullUrl}`);
logger.debug(`Redirect URI will be retrieved from encrypted state`);
// Handle the callback using stored redirect URI from state and request headers
const paddedToken = await oidcService.handleCallback({
providerId,
code,
state,
requestOrigin: requestInfo.baseUrl,
fullCallbackUrl: requestInfo.fullUrl,
requestHeaders: req.headers as Record<string, string | string[] | undefined>,
});
return {
providerId,
requestInfo,
paddedToken,
};
}
/**
* Validate required parameters for authorization flow
*/
static validateAuthorizeParams(
providerId: string | undefined,
state: string | undefined,
redirectUri: string | undefined
): { providerId: string; state: string; redirectUri: string } {
if (!providerId) {
throw new Error('Provider ID is required');
}
if (!state) {
throw new Error('State parameter is required');
}
if (!redirectUri) {
throw new Error('Redirect URI is required');
}
return { providerId, state, redirectUri };
}
/**
* Validate required parameters for callback flow
*/
static validateCallbackParams(
code: string | undefined,
state: string | undefined
): { code: string; state: string } {
if (!code || !state) {
throw new Error('Missing required parameters');
}
return { code, state };
}
}

View File

@@ -2,12 +2,12 @@ import { Module } from '@nestjs/common';
import { ScheduleModule } from '@nestjs/schedule';
import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js';
import { SubscriptionPollingService } from '@app/unraid-api/graph/services/subscription-polling.service.js';
import { SubscriptionManagerService } from '@app/unraid-api/graph/services/subscription-manager.service.js';
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
@Module({
imports: [],
providers: [SubscriptionTrackerService, SubscriptionHelperService, SubscriptionPollingService],
exports: [SubscriptionTrackerService, SubscriptionHelperService, SubscriptionPollingService],
providers: [SubscriptionTrackerService, SubscriptionHelperService, SubscriptionManagerService],
exports: [SubscriptionTrackerService, SubscriptionHelperService], // SubscriptionManagerService is internal
})
export class ServicesModule {}

View File

@@ -4,7 +4,25 @@ import { createSubscription, PUBSUB_CHANNEL } from '@app/core/pubsub.js';
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
/**
* Helper service for creating tracked GraphQL subscriptions with automatic cleanup
* High-level helper service for creating GraphQL subscriptions with automatic cleanup.
*
* This service provides a convenient way to create GraphQL subscriptions that:
* - Automatically track subscriber count via SubscriptionTrackerService
* - Properly clean up resources when subscriptions end
* - Handle errors gracefully
*
* **When to use this service:**
* - In GraphQL resolvers when implementing subscriptions
* - When you need automatic reference counting for shared resources
* - When you want to ensure proper cleanup on subscription termination
*
* @example
* // In a GraphQL resolver
* \@Subscription(() => MetricsUpdate)
* async metricsSubscription() {
* // Topic must be registered first via SubscriptionTrackerService
* return this.subscriptionHelper.createTrackedSubscription(PUBSUB_CHANNEL.METRICS);
* }
*/
@Injectable()
export class SubscriptionHelperService {
@@ -15,7 +33,7 @@ export class SubscriptionHelperService {
* @param topic The subscription topic/channel to subscribe to
* @returns A proxy async iterator with automatic cleanup
*/
public createTrackedSubscription<T = any>(topic: PUBSUB_CHANNEL): AsyncIterableIterator<T> {
public createTrackedSubscription<T = any>(topic: PUBSUB_CHANNEL | string): AsyncIterableIterator<T> {
const innerIterator = createSubscription<T>(topic);
// Subscribe when the subscription starts

View File

@@ -0,0 +1,196 @@
import { Injectable, Logger, OnModuleDestroy } from '@nestjs/common';
import { SchedulerRegistry } from '@nestjs/schedule';
/**
* Configuration for managed subscriptions
*/
export interface SubscriptionConfig {
/** Unique identifier for the subscription */
name: string;
/**
* Polling interval in milliseconds.
* - If set to a number, the callback will be called at that interval
* - If null/undefined, the subscription is event-based (no polling)
*/
intervalMs?: number | null;
/** Function to call periodically (for polling) or once (for setup) */
callback: () => Promise<void>;
/** Optional function called when the subscription starts */
onStart?: () => Promise<void>;
/** Optional function called when the subscription stops */
onStop?: () => Promise<void>;
}
/**
* Low-level service for managing both polling and event-based subscriptions.
*
* ⚠️ **IMPORTANT**: This is an internal service. Do not use directly in resolvers or business logic.
* Instead, use one of the higher-level services:
* - **SubscriptionTrackerService**: For subscriptions that need reference counting
* - **SubscriptionHelperService**: For GraphQL subscriptions with automatic cleanup
*
* This service provides the underlying implementation for:
* - **Polling subscriptions**: Execute a callback at regular intervals
* - **Event-based subscriptions**: Set up event listeners or watchers that persist until stopped
*
* @internal
*/
@Injectable()
export class SubscriptionManagerService implements OnModuleDestroy {
private readonly logger = new Logger(SubscriptionManagerService.name);
private readonly activeSubscriptions = new Map<
string,
{ isPolling: boolean; config?: SubscriptionConfig }
>();
constructor(private readonly schedulerRegistry: SchedulerRegistry) {}
async onModuleDestroy() {
await this.stopAll();
}
/**
* Start a managed subscription (polling or event-based).
*
* @param config - The subscription configuration
* @throws Will throw an error if the onStart callback fails
*/
async startSubscription(config: SubscriptionConfig): Promise<void> {
const { name, intervalMs, callback, onStart } = config;
// Clean up any existing subscription with the same name
await this.stopSubscription(name);
// Initialize subscription state with config
this.activeSubscriptions.set(name, { isPolling: false, config });
// Call onStart callback if provided
if (onStart) {
try {
await onStart();
this.logger.debug(`Called onStart for '${name}'`);
} catch (error) {
this.logger.error(`Error in onStart for '${name}'`, error);
throw error;
}
}
// If intervalMs is null, this is a continuous/event-based subscription
if (intervalMs === null || intervalMs === undefined) {
this.logger.debug(`Started continuous subscription for '${name}' (no polling)`);
return;
}
// Create the polling function with guard against overlapping executions
const pollFunction = async () => {
const subscription = this.activeSubscriptions.get(name);
if (!subscription || subscription.isPolling) {
return;
}
subscription.isPolling = true;
try {
await callback();
} catch (error) {
this.logger.error(`Error in subscription callback '${name}'`, error);
} finally {
if (subscription) {
subscription.isPolling = false;
}
}
};
// Create and register the interval
const interval = setInterval(pollFunction, intervalMs);
this.schedulerRegistry.addInterval(name, interval);
this.logger.debug(`Started polling for '${name}' every ${intervalMs}ms`);
}
/**
* Stop a managed subscription.
*
* This will:
* 1. Stop any active polling interval
* 2. Call the onStop callback if provided
* 3. Clean up internal state
*
* @param name - The unique identifier of the subscription to stop
*/
async stopSubscription(name: string): Promise<void> {
// Get the config before deleting
const subscription = this.activeSubscriptions.get(name);
const onStop = subscription?.config?.onStop;
try {
if (this.schedulerRegistry.doesExist('interval', name)) {
this.schedulerRegistry.deleteInterval(name);
this.logger.debug(`Stopped polling interval for '${name}'`);
}
} catch (error) {
// Interval doesn't exist, which is fine
}
// Call onStop callback if provided
if (onStop) {
try {
await onStop();
this.logger.debug(`Called onStop for '${name}'`);
} catch (error) {
this.logger.error(`Error in onStop for '${name}'`, error);
}
}
// Clean up subscription state
this.activeSubscriptions.delete(name);
}
/**
* Stop all active subscriptions.
*
* This is automatically called when the module is destroyed.
*/
async stopAll(): Promise<void> {
// Get all active subscription keys (both polling and event-based)
const activeKeys = Array.from(this.activeSubscriptions.keys());
// Stop each subscription and await cleanup
await Promise.all(activeKeys.map((key) => this.stopSubscription(key)));
// Clear the map after all subscriptions are stopped
this.activeSubscriptions.clear();
}
/**
* Check if a subscription is active.
*
* @param name - The unique identifier of the subscription
* @returns true if the subscription exists (either polling or event-based)
*/
isSubscriptionActive(name: string): boolean {
// Check both for polling intervals and event-based subscriptions
return this.activeSubscriptions.has(name) || this.schedulerRegistry.doesExist('interval', name);
}
/**
* Get the total number of active subscriptions.
*
* @returns The count of all active subscriptions (polling and event-based)
*/
getActiveSubscriptionCount(): number {
return this.activeSubscriptions.size;
}
/**
* Get a list of all active subscription names.
*
* @returns Array of subscription identifiers
*/
getActiveSubscriptionNames(): string[] {
return Array.from(this.activeSubscriptions.keys());
}
}

View File

@@ -1,91 +0,0 @@
import { Injectable, Logger, OnModuleDestroy } from '@nestjs/common';
import { SchedulerRegistry } from '@nestjs/schedule';
export interface PollingConfig {
name: string;
intervalMs: number;
callback: () => Promise<void>;
}
@Injectable()
export class SubscriptionPollingService implements OnModuleDestroy {
private readonly logger = new Logger(SubscriptionPollingService.name);
private readonly activePollers = new Map<string, { isPolling: boolean }>();
constructor(private readonly schedulerRegistry: SchedulerRegistry) {}
onModuleDestroy() {
this.stopAll();
}
/**
* Start polling for a specific subscription topic
*/
startPolling(config: PollingConfig): void {
const { name, intervalMs, callback } = config;
// Clean up any existing interval
this.stopPolling(name);
// Initialize polling state
this.activePollers.set(name, { isPolling: false });
// Create the polling function with guard against overlapping executions
const pollFunction = async () => {
const poller = this.activePollers.get(name);
if (!poller || poller.isPolling) {
return;
}
poller.isPolling = true;
try {
await callback();
} catch (error) {
this.logger.error(`Error in polling task '${name}'`, error);
} finally {
if (poller) {
poller.isPolling = false;
}
}
};
// Create and register the interval
const interval = setInterval(pollFunction, intervalMs);
this.schedulerRegistry.addInterval(name, interval);
this.logger.debug(`Started polling for '${name}' every ${intervalMs}ms`);
}
/**
* Stop polling for a specific subscription topic
*/
stopPolling(name: string): void {
try {
if (this.schedulerRegistry.doesExist('interval', name)) {
this.schedulerRegistry.deleteInterval(name);
this.logger.debug(`Stopped polling for '${name}'`);
}
} catch (error) {
// Interval doesn't exist, which is fine
}
// Clean up polling state
this.activePollers.delete(name);
}
/**
* Stop all active polling tasks
*/
stopAll(): void {
const intervals = this.schedulerRegistry.getIntervals();
intervals.forEach((key) => this.stopPolling(key));
this.activePollers.clear();
}
/**
* Check if polling is active for a given name
*/
isPolling(name: string): boolean {
return this.schedulerRegistry.doesExist('interval', name);
}
}

View File

@@ -1,14 +1,44 @@
import { Injectable, Logger } from '@nestjs/common';
import { SubscriptionPollingService } from '@app/unraid-api/graph/services/subscription-polling.service.js';
import { SubscriptionManagerService } from '@app/unraid-api/graph/services/subscription-manager.service.js';
/**
* Service for managing subscriptions with automatic reference counting.
*
* This service tracks the number of active subscribers for each topic and automatically
* starts/stops the underlying subscription based on subscriber count.
*
* **When to use this service:**
* - When you have multiple GraphQL subscriptions that share the same data source
* - When you need to start a resource (polling, file watcher, etc.) only when there are active subscribers
* - When you need automatic cleanup when the last subscriber disconnects
*
* @example
* // Register a polling subscription
* subscriptionTracker.registerTopic(
* 'metrics-update',
* async () => {
* const metrics = await fetchMetrics();
* pubsub.publish('metrics-update', { metrics });
* },
* 5000 // Poll every 5 seconds
* );
*
* @example
* // Register an event-based subscription (e.g., file watching)
* subscriptionTracker.registerTopic(
* 'log-file-updates',
* () => startFileWatcher('/var/log/app.log'), // onStart
* () => stopFileWatcher('/var/log/app.log') // onStop
* );
*/
@Injectable()
export class SubscriptionTrackerService {
private readonly logger = new Logger(SubscriptionTrackerService.name);
private subscriberCounts = new Map<string, number>();
private topicHandlers = new Map<string, { onStart: () => void; onStop: () => void }>();
constructor(private readonly pollingService: SubscriptionPollingService) {}
constructor(private readonly subscriptionManager: SubscriptionManagerService) {}
/**
* Register a topic with optional polling support
@@ -29,8 +59,8 @@ export class SubscriptionTrackerService {
callback: async () => callbackOrOnStart(),
};
this.topicHandlers.set(topic, {
onStart: () => this.pollingService.startPolling(pollingConfig),
onStop: () => this.pollingService.stopPolling(topic),
onStart: () => this.subscriptionManager.startSubscription(pollingConfig),
onStop: () => this.subscriptionManager.stopSubscription(topic),
});
} else {
// Legacy API: onStart and onStop handlers

View File

@@ -1,3 +1,4 @@
import { CacheModule } from '@nestjs/cache-manager';
import { Test } from '@nestjs/testing';
import { CANONICAL_INTERNAL_CLIENT_TOKEN } from '@unraid/shared';
@@ -60,7 +61,7 @@ vi.mock('execa', () => ({
describe('RestModule Integration', () => {
it('should compile with RestService having access to ApiReportService', async () => {
const module = await Test.createTestingModule({
imports: [RestModule],
imports: [CacheModule.register({ isGlobal: true }), RestModule],
})
// Override services that have complex dependencies for testing
.overrideProvider(CANONICAL_INTERNAL_CLIENT_TOKEN)

View File

@@ -0,0 +1,487 @@
import { Logger } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { Test, TestingModule } from '@nestjs/testing';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import type { FastifyReply, FastifyRequest } from '@app/unraid-api/types/fastify.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
import { OidcRequestHandler } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-request-handler.util.js';
import { RestController } from '@app/unraid-api/rest/rest.controller.js';
import { RestService } from '@app/unraid-api/rest/rest.service.js';
describe('RestController', () => {
let controller: RestController;
let oidcService: OidcService;
let oidcConfig: OidcConfigPersistence;
let mockReply: Partial<FastifyReply>;
// Helper function to create a mock request with the desired hostname
const createMockRequest = (hostname?: string, headers: Record<string, any> = {}): FastifyRequest => {
return {
headers,
hostname,
url: '/test',
protocol: 'https',
} as FastifyRequest;
};
beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
controllers: [RestController],
providers: [
{
provide: RestService,
useValue: {
getLogs: vi.fn(),
getCustomizationStream: vi.fn(),
},
},
{
provide: OidcService,
useValue: {
getAuthorizationUrl: vi.fn(),
handleCallback: vi.fn(),
},
},
{
provide: OidcConfigPersistence,
useValue: {
getConfig: vi.fn().mockResolvedValue({
defaultAllowedOrigins: [],
}),
},
},
{
provide: ConfigService,
useValue: {
get: vi.fn(),
},
},
],
}).compile();
controller = module.get<RestController>(RestController);
oidcService = module.get<OidcService>(OidcService);
oidcConfig = module.get<OidcConfigPersistence>(OidcConfigPersistence);
mockReply = {
status: vi.fn().mockReturnThis(),
header: vi.fn().mockReturnThis(),
send: vi.fn().mockReturnThis(),
type: vi.fn().mockReturnThis(),
};
});
describe('oidcAuthorize', () => {
describe('redirect URI validation', () => {
beforeEach(() => {
// Mock OidcRequestHandler.handleAuthorize to return a valid auth URL
vi.spyOn(OidcRequestHandler, 'handleAuthorize').mockResolvedValue(
'https://provider.com/authorize?client_id=test&redirect_uri=...'
);
});
it('should accept redirect_uri with same hostname but different port', async () => {
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(302);
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalledWith(
'test-provider',
'test-state',
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback',
mockRequest,
oidcService,
expect.any(Logger)
);
});
it('should accept redirect_uri with same hostname and standard HTTPS port', async () => {
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://unraid.mytailnet.ts.net/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(302);
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
});
it('should accept redirect_uri with same hostname and explicit port 443', async () => {
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://unraid.mytailnet.ts.net:443/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(302);
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
});
it('should reject redirect_uri with different hostname', async () => {
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://evil.com/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(400);
expect(mockReply.send).toHaveBeenCalledWith(
expect.stringContaining(
'Invalid redirect_uri: https://evil.com/graphql/api/auth/oidc/callback'
)
);
expect(OidcRequestHandler.handleAuthorize).not.toHaveBeenCalled();
});
it('should reject redirect_uri with subdomain difference', async () => {
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://evil.unraid.mytailnet.ts.net/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(400);
expect(mockReply.send).toHaveBeenCalledWith(
expect.stringContaining(
'Invalid redirect_uri: https://evil.unraid.mytailnet.ts.net/graphql/api/auth/oidc/callback'
)
);
expect(OidcRequestHandler.handleAuthorize).not.toHaveBeenCalled();
});
it('should handle hostname from host header when hostname is not available', async () => {
const mockRequest = createMockRequest(undefined, {
host: 'unraid.mytailnet.ts.net:8080',
});
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(302);
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
});
it('should reject malformed redirect_uri', async () => {
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'not-a-valid-url',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(400);
expect(mockReply.send).toHaveBeenCalledWith(
expect.stringContaining('Invalid redirect_uri: not-a-valid-url')
);
expect(OidcRequestHandler.handleAuthorize).not.toHaveBeenCalled();
});
it('should handle case-insensitive hostname comparison', async () => {
const mockRequest = createMockRequest('UnRaid.MyTailnet.TS.net');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(302);
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
});
it('should preserve exact redirect_uri including custom port in call to handleAuthorize', async () => {
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
const customRedirectUri =
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback';
await controller.oidcAuthorize(
'test-provider',
'test-state',
customRedirectUri,
mockRequest,
mockReply as FastifyReply
);
// Verify the exact redirect URI with port is passed through
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalledWith(
'test-provider',
'test-state',
customRedirectUri, // Should be exactly as provided, with :1443
mockRequest,
oidcService,
expect.any(Logger)
);
});
it('should allow localhost with different ports', async () => {
const mockRequest = createMockRequest('localhost');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'http://localhost:3000/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(302);
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalledWith(
'test-provider',
'test-state',
'http://localhost:3000/graphql/api/auth/oidc/callback',
mockRequest,
oidcService,
expect.any(Logger)
);
});
it('should allow IP addresses with different ports', async () => {
const mockRequest = createMockRequest('192.168.1.100');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'http://192.168.1.100:8080/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(302);
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
});
it('should accept redirect_uri with different hostname if in allowed origins', async () => {
const mockRequest = createMockRequest('devgen-dev1.local');
// Mock the config to include the allowed origin
vi.mocked(oidcConfig.getConfig).mockResolvedValueOnce({
defaultAllowedOrigins: ['https://devgen-bad-dev1.local'],
} as any);
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://devgen-bad-dev1.local/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(302);
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalledWith(
'test-provider',
'test-state',
'https://devgen-bad-dev1.local/graphql/api/auth/oidc/callback',
mockRequest,
oidcService,
expect.any(Logger)
);
});
describe('integration with centralized validator', () => {
it('should use the same validation logic as validateRedirectUri function', async () => {
const testCases = [
{
name: 'accepts HTTPS upgrade from allowed origins',
requestHost: 'devgen-dev1.local',
redirectUri: 'https://allowed-host.local/graphql/api/auth/oidc/callback',
allowedOrigins: ['http://allowed-host.local'],
expectedStatus: 302,
shouldSucceed: true,
},
{
name: 'rejects hostname not in allowed origins',
requestHost: 'devgen-dev1.local',
redirectUri: 'https://evil.com/graphql/api/auth/oidc/callback',
allowedOrigins: ['https://good-host.local'],
expectedStatus: 400,
shouldSucceed: false,
},
{
name: 'accepts multiple allowed origins',
requestHost: 'devgen-dev1.local',
redirectUri: 'https://second.local/graphql/api/auth/oidc/callback',
allowedOrigins: [
'https://first.local',
'https://second.local',
'https://third.local',
],
expectedStatus: 302,
shouldSucceed: true,
},
{
name: 'respects protocol and hostname from headers',
requestHost: undefined,
headers: {
'x-forwarded-proto': 'https',
'x-forwarded-host': 'proxy.local',
},
redirectUri: 'https://proxy.local/graphql/api/auth/oidc/callback',
allowedOrigins: [],
expectedStatus: 302,
shouldSucceed: true,
},
];
for (const testCase of testCases) {
// Reset mocks for each test case
vi.clearAllMocks();
const mockRequest = createMockRequest(
testCase.requestHost,
testCase.headers || {}
);
vi.mocked(oidcConfig.getConfig).mockResolvedValueOnce({
defaultAllowedOrigins: testCase.allowedOrigins,
} as any);
await controller.oidcAuthorize(
'test-provider',
'test-state',
testCase.redirectUri,
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(testCase.expectedStatus);
if (testCase.shouldSucceed) {
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
} else {
expect(mockReply.send).toHaveBeenCalledWith(
expect.stringContaining(testCase.redirectUri)
);
expect(OidcRequestHandler.handleAuthorize).not.toHaveBeenCalled();
}
}
});
it('should handle edge cases consistently with centralized validator', async () => {
// Test with empty allowed origins
vi.mocked(oidcConfig.getConfig).mockResolvedValueOnce({
defaultAllowedOrigins: [],
} as any);
const mockRequest = createMockRequest('host.local');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://different.local/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(400);
expect(mockReply.send).toHaveBeenCalledWith(
expect.stringContaining('https://different.local/graphql/api/auth/oidc/callback')
);
});
it('should validate that error messages guide users to settings', async () => {
vi.mocked(oidcConfig.getConfig).mockResolvedValueOnce({
defaultAllowedOrigins: [],
} as any);
const mockRequest = createMockRequest('host.local');
await controller.oidcAuthorize(
'test-provider',
'test-state',
'https://different.local/graphql/api/auth/oidc/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.send).toHaveBeenCalledWith(
expect.stringContaining('Settings → Management Access → Allowed Redirect URIs')
);
});
});
});
describe('parameter validation', () => {
it('should return 400 if redirect_uri is missing', async () => {
const mockRequest = createMockRequest('unraid.local');
await controller.oidcAuthorize(
'test-provider',
'test-state',
undefined as any,
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(400);
// The controller catches validation errors and returns a generic message
expect(mockReply.send).toHaveBeenCalledWith('Invalid provider or configuration');
});
it('should return 400 if providerId is missing', async () => {
const mockRequest = createMockRequest('unraid.local');
await controller.oidcAuthorize(
undefined as any,
'test-state',
'https://unraid.local/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(400);
expect(mockReply.send).toHaveBeenCalledWith('Invalid provider or configuration');
});
it('should return 400 if state is missing', async () => {
const mockRequest = createMockRequest('unraid.local');
await controller.oidcAuthorize(
'test-provider',
undefined as any,
'https://unraid.local/callback',
mockRequest,
mockReply as FastifyReply
);
expect(mockReply.status).toHaveBeenCalledWith(400);
expect(mockReply.send).toHaveBeenCalledWith('Invalid provider or configuration');
});
});
});
});

View File

@@ -2,18 +2,25 @@ import { Controller, Get, Logger, Param, Query, Req, Res, UnauthorizedException
import { AuthAction, Resource } from '@unraid/shared/graphql.model.js';
import { UsePermissions } from '@unraid/shared/use-permissions.directive.js';
import escapeHtml from 'escape-html';
import type { FastifyReply, FastifyRequest } from '@app/unraid-api/types/fastify.js';
import { Public } from '@app/unraid-api/auth/public.decorator.js';
import { OidcAuthService } from '@app/unraid-api/graph/resolvers/sso/oidc-auth.service.js';
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
import { OidcRequestHandler } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-request-handler.util.js';
import { RestService } from '@app/unraid-api/rest/rest.service.js';
import { validateRedirectUri } from '@app/unraid-api/utils/redirect-uri-validator.js';
@Controller()
export class RestController {
protected logger = new Logger(RestController.name);
protected oidcLogger = new Logger('OidcRestController');
constructor(
private readonly restService: RestService,
private readonly oidcAuthService: OidcAuthService
private readonly oidcService: OidcService,
private readonly oidcConfig: OidcConfigPersistence
) {}
@Get('/')
@@ -65,38 +72,69 @@ export class RestController {
async oidcAuthorize(
@Param('providerId') providerId: string,
@Query('state') state: string,
@Query('redirect_uri') redirectUri: string,
@Req() req: FastifyRequest,
@Res() res: FastifyReply
) {
try {
if (!state) {
return res.status(400).send('State parameter is required');
// Validate required parameters
const params = OidcRequestHandler.validateAuthorizeParams(providerId, state, redirectUri);
// IMPORTANT: Use the redirect_uri from query params directly
// Do NOT parse headers or try to build/validate against headers
// The frontend provides the complete redirect_uri
if (!params.redirectUri) {
return res.status(400).send('redirect_uri parameter is required');
}
// Get the host and protocol from the request headers
const protocol = (req.headers['x-forwarded-proto'] as string) || req.protocol || 'http';
const host = (req.headers['x-forwarded-host'] as string) || req.headers.host || undefined;
const requestInfo = host ? `${protocol}://${host}` : undefined;
// Security validation: validate redirect_uri with support for allowed origins
const protocol = (req.headers['x-forwarded-proto'] as string) || 'http';
const host = (req.headers['x-forwarded-host'] as string) || req.headers.host || req.hostname;
const authUrl = await this.oidcAuthService.getAuthorizationUrl(
providerId,
state,
requestInfo
// Get allowed origins from OIDC config
const config = await this.oidcConfig.getConfig();
const allowedOrigins = config?.defaultAllowedOrigins;
// Validate the redirect URI using the centralized validator
const validation = validateRedirectUri(
params.redirectUri,
protocol,
host,
this.oidcLogger,
allowedOrigins
);
if (!validation.isValid) {
this.oidcLogger.warn(`Invalid redirect_uri: ${validation.reason}`);
return res
.status(400)
.send(
`Invalid redirect_uri: ${escapeHtml(params.redirectUri)}. ${escapeHtml(validation.reason || 'Unknown validation error')}. Please add this callback URI to Settings → Management Access → Allowed Redirect URIs`
);
}
// Handle authorization flow using the exact redirect_uri from query params
const authUrl = await OidcRequestHandler.handleAuthorize(
params.providerId,
params.state,
params.redirectUri,
req,
this.oidcService,
this.oidcLogger
);
this.logger.log(`Redirecting to OIDC provider: ${authUrl}`);
// Manually set redirect headers for better proxy compatibility
res.status(302);
res.header('Location', authUrl);
return res.send();
} catch (error: unknown) {
this.logger.error(`OIDC authorize error for provider ${providerId}:`, error);
this.oidcLogger.error(`OIDC authorize error for provider ${providerId}:`, error);
// Log more details about the error
if (error instanceof Error) {
this.logger.error(`Error message: ${error.message}`);
this.oidcLogger.error(`Error message: ${error.message}`);
if (error.stack) {
this.logger.debug(`Stack trace: ${error.stack}`);
this.oidcLogger.debug(`Stack trace: ${error.stack}`);
}
}
@@ -117,32 +155,20 @@ export class RestController {
@Res() res: FastifyReply
) {
try {
if (!code || !state) {
return res.status(400).send('Missing required parameters');
}
// Validate required parameters
const params = OidcRequestHandler.validateCallbackParams(code, state);
// Extract provider ID from state
const { providerId } = this.oidcAuthService.extractProviderFromState(state);
// Get the full callback URL as received, respecting reverse proxy headers
const protocol = (req.headers['x-forwarded-proto'] as string) || req.protocol || 'http';
const host =
(req.headers['x-forwarded-host'] as string) || req.headers.host || 'localhost:3000';
const fullUrl = `${protocol}://${host}${req.url}`;
const requestInfo = `${protocol}://${host}`;
this.logger.debug(`Full callback URL from request: ${fullUrl}`);
const paddedToken = await this.oidcAuthService.handleCallback(
providerId,
code,
state,
requestInfo,
fullUrl
// Handle callback flow
const result = await OidcRequestHandler.handleCallback(
params.code,
params.state,
req,
this.oidcService,
this.oidcLogger
);
// Redirect to login page with the token in hash to keep it out of server logs
const loginUrl = `/login#token=${encodeURIComponent(paddedToken)}`;
const loginUrl = `/login#token=${encodeURIComponent(result.paddedToken)}`;
// Manually set redirect headers for better proxy compatibility
res.header('Cache-Control', 'no-store');
@@ -152,16 +178,16 @@ export class RestController {
res.header('Location', loginUrl);
return res.send();
} catch (error: unknown) {
this.logger.error(`OIDC callback error: ${error}`);
this.oidcLogger.error(`OIDC callback error: ${error}`);
// Use a generic error message to avoid leaking sensitive information
const errorMessage = 'Authentication failed';
// Log detailed error for debugging but don't expose to user
if (error instanceof UnauthorizedException) {
this.logger.debug(`UnauthorizedException occurred during OIDC callback`);
this.oidcLogger.debug(`UnauthorizedException occurred during OIDC callback`);
} else if (error instanceof Error) {
this.logger.debug(`Error during OIDC callback: ${error.message}`);
this.oidcLogger.debug(`Error during OIDC callback: ${error.message}`);
}
const loginUrl = `/login#error=${encodeURIComponent(errorMessage)}`;

View File

@@ -0,0 +1,406 @@
import { describe, expect, it, vi } from 'vitest';
import { ErrorExtractor } from '@app/unraid-api/utils/error-extractor.util.js';
describe('ErrorExtractor', () => {
describe('extract', () => {
it('should handle null and undefined errors', () => {
const nullResult = ErrorExtractor.extract(null);
expect(nullResult.message).toBe('Unknown error');
expect(nullResult.type).toBe('Unknown');
const undefinedResult = ErrorExtractor.extract(undefined);
expect(undefinedResult.message).toBe('Unknown error');
expect(undefinedResult.type).toBe('Unknown');
});
it('should extract basic Error properties', () => {
const error = new Error('Test error message');
const result = ErrorExtractor.extract(error);
expect(result.message).toBe('Test error message');
expect(result.type).toBe('Error');
expect(result.stack).toBeDefined();
});
it('should extract custom error types', () => {
class CustomError extends Error {}
const error = new CustomError('Custom error');
const result = ErrorExtractor.extract(error);
expect(result.type).toBe('CustomError');
});
it('should extract error code', () => {
const error: any = new Error('Error with code');
error.code = 'ERR_CODE';
const result = ErrorExtractor.extract(error);
expect(result.code).toBe('ERR_CODE');
});
it('should extract HTTP response details', () => {
const error: any = new Error('HTTP error');
error.response = {
status: 404,
statusText: 'Not Found',
body: { error: 'Resource not found' },
headers: { 'content-type': 'application/json' },
};
const result = ErrorExtractor.extract(error);
expect(result.status).toBe(404);
expect(result.statusText).toBe('Not Found');
expect(result.responseBody).toEqual({ error: 'Resource not found' });
expect(result.responseHeaders).toEqual({ 'content-type': 'application/json' });
});
it('should truncate long response body strings', () => {
const error: any = new Error('Error with long body');
const longString = 'x'.repeat(2000);
error.body = longString;
const result = ErrorExtractor.extract(error);
expect(result.responseBody).toBe('x'.repeat(1000) + '... (truncated)');
});
it('should extract OAuth error details', () => {
const error: any = new Error('OAuth error');
error.error = 'invalid_grant';
error.error_description = 'The provided authorization code is invalid';
const result = ErrorExtractor.extract(error);
expect(result.oauthError).toBe('invalid_grant');
expect(result.oauthErrorDescription).toBe('The provided authorization code is invalid');
});
it('should extract cause chain', () => {
const rootCause = new Error('Root cause');
const middleCause: any = new Error('Middle cause');
middleCause.cause = rootCause;
const topError: any = new Error('Top error');
topError.cause = middleCause;
const result = ErrorExtractor.extract(topError);
expect(result.causeChain).toHaveLength(2);
expect(result.causeChain![0]).toEqual({
depth: 1,
type: 'Error',
message: 'Middle cause',
});
expect(result.causeChain![1]).toEqual({
depth: 2,
type: 'Error',
message: 'Root cause',
});
});
it('should limit cause chain depth', () => {
// Create a deep nested error chain
let deepestError: any = new Error('Level 10');
for (let i = 9; i >= 0; i--) {
const error: any = new Error(`Level ${i}`);
error.cause = deepestError;
deepestError = error;
}
const topError: any = new Error('Top');
topError.cause = deepestError;
const result = ErrorExtractor.extract(topError);
expect(result.causeChain).toHaveLength(5); // MAX_CAUSE_DEPTH
});
it('should extract cause with code', () => {
const cause: any = new Error('Cause with code');
cause.code = 'ECONNREFUSED';
const error: any = new Error('Main error');
error.cause = cause;
const result = ErrorExtractor.extract(error);
expect(result.causeChain![0].code).toBe('ECONNREFUSED');
});
it('should extract additional properties', () => {
const error: any = new Error('Error with extras');
error.customProp1 = 'value1';
error.customProp2 = 123;
const result = ErrorExtractor.extract(error);
expect(result.additionalProperties).toEqual({
customProp1: 'value1',
customProp2: 123,
});
});
it('should handle string errors', () => {
const result = ErrorExtractor.extract('String error message');
expect(result.message).toBe('String error message');
expect(result.type).toBe('String');
});
it('should handle object errors', () => {
const error = { code: 'ERROR', message: 'Object error' };
const result = ErrorExtractor.extract(error);
expect(result.message).toBe(JSON.stringify(error));
expect(result.type).toBe('Object');
});
it('should handle primitive errors', () => {
const result = ErrorExtractor.extract(42);
expect(result.message).toBe('42');
expect(result.type).toBe('number');
});
it('should handle openid-client error structure', () => {
const error: any = new Error('unexpected response content-type');
error.code = 'OAUTH_RESPONSE_IS_NOT_JSON';
error.response = {
status: 200,
headers: { 'content-type': 'text/html' },
body: '<html>Error page</html>',
};
const result = ErrorExtractor.extract(error);
expect(result.code).toBe('OAUTH_RESPONSE_IS_NOT_JSON');
expect(result.responseHeaders!['content-type']).toBe('text/html');
expect(result.responseBody).toContain('<html>');
});
});
describe('isOAuthResponseError', () => {
it('should identify OAuth response errors by code', () => {
const extracted = {
message: 'Some error',
type: 'Error',
code: 'OAUTH_RESPONSE_IS_NOT_JSON',
};
expect(ErrorExtractor.isOAuthResponseError(extracted)).toBe(true);
});
it('should identify OAuth response errors by message', () => {
const extracted = {
message: 'unexpected response content-type from server',
type: 'Error',
};
expect(ErrorExtractor.isOAuthResponseError(extracted)).toBe(true);
});
it('should identify parsing errors', () => {
const extracted = {
message: 'JSON parsing error occurred',
type: 'Error',
};
expect(ErrorExtractor.isOAuthResponseError(extracted)).toBe(true);
});
it('should not identify non-OAuth errors', () => {
const extracted = {
message: 'Some other error',
type: 'Error',
code: 'OTHER_ERROR',
};
expect(ErrorExtractor.isOAuthResponseError(extracted)).toBe(false);
});
});
describe('isJwtClaimError', () => {
it('should identify JWT claim errors', () => {
const extracted = {
message: 'unexpected JWT claim value encountered',
type: 'Error',
};
expect(ErrorExtractor.isJwtClaimError(extracted)).toBe(true);
});
it('should not identify non-JWT errors', () => {
const extracted = {
message: 'Some other error',
type: 'Error',
};
expect(ErrorExtractor.isJwtClaimError(extracted)).toBe(false);
});
});
describe('isNetworkError', () => {
it('should identify network errors by code', () => {
const codes = ['ECONNREFUSED', 'ENOTFOUND', 'ETIMEDOUT', 'ECONNRESET'];
for (const code of codes) {
const extracted = {
message: 'Error',
type: 'Error',
code,
};
expect(ErrorExtractor.isNetworkError(extracted)).toBe(true);
}
});
it('should identify network errors by message', () => {
const messages = ['network timeout occurred', 'failed to connect to server'];
for (const message of messages) {
const extracted = {
message,
type: 'Error',
};
expect(ErrorExtractor.isNetworkError(extracted)).toBe(true);
}
});
it('should not identify non-network errors', () => {
const extracted = {
message: 'Invalid credentials',
type: 'Error',
code: 'AUTH_ERROR',
};
expect(ErrorExtractor.isNetworkError(extracted)).toBe(false);
});
});
describe('formatForLogging', () => {
it('should log basic error information', () => {
const logger = {
error: vi.fn(),
debug: vi.fn(),
};
const extracted = {
message: 'Test error',
type: 'CustomError',
code: 'ERR_TEST',
};
ErrorExtractor.formatForLogging(extracted, logger);
expect(logger.error).toHaveBeenCalledWith('Error type: CustomError');
expect(logger.error).toHaveBeenCalledWith('Error message: Test error');
expect(logger.error).toHaveBeenCalledWith('Error code: ERR_TEST');
});
it('should log HTTP response details', () => {
const logger = {
error: vi.fn(),
debug: vi.fn(),
};
const extracted = {
message: 'HTTP error',
type: 'Error',
status: 500,
statusText: 'Internal Server Error',
responseBody: { error: 'Server error' },
responseHeaders: { 'content-type': 'application/json' },
};
ErrorExtractor.formatForLogging(extracted, logger);
expect(logger.error).toHaveBeenCalledWith('HTTP Status: 500 Internal Server Error');
expect(logger.error).toHaveBeenCalledWith('Response body: %o', { error: 'Server error' });
expect(logger.error).toHaveBeenCalledWith('Response Content-Type: application/json');
});
it('should log OAuth error details', () => {
const logger = {
error: vi.fn(),
debug: vi.fn(),
};
const extracted = {
message: 'OAuth error',
type: 'Error',
oauthError: 'invalid_client',
oauthErrorDescription: 'Client authentication failed',
};
ErrorExtractor.formatForLogging(extracted, logger);
expect(logger.error).toHaveBeenCalledWith('OAuth error: invalid_client');
expect(logger.error).toHaveBeenCalledWith(
'OAuth error description: Client authentication failed'
);
});
it('should log cause chain', () => {
const logger = {
error: vi.fn(),
debug: vi.fn(),
};
const extracted = {
message: 'Top error',
type: 'Error',
causeChain: [
{ depth: 1, type: 'Error', message: 'Cause 1', code: 'CODE1' },
{ depth: 2, type: 'Error', message: 'Cause 2' },
],
};
ErrorExtractor.formatForLogging(extracted, logger);
expect(logger.error).toHaveBeenCalledWith('Error cause chain:');
expect(logger.error).toHaveBeenCalledWith(' [Cause 1] Error: Cause 1');
expect(logger.error).toHaveBeenCalledWith(' [Cause 1] Code: CODE1');
expect(logger.error).toHaveBeenCalledWith(' [Cause 2] Error: Cause 2');
});
it('should log additional properties and stack in debug', () => {
const logger = {
error: vi.fn(),
debug: vi.fn(),
};
const extracted = {
message: 'Error',
type: 'Error',
additionalProperties: { custom: 'value' },
stack: 'Stack trace here',
};
ErrorExtractor.formatForLogging(extracted, logger);
expect(logger.debug).toHaveBeenCalledWith('Additional error properties: %o', {
custom: 'value',
});
expect(logger.debug).toHaveBeenCalledWith('Stack trace: Stack trace here');
});
it('should handle case-insensitive Content-Type header', () => {
const logger = {
error: vi.fn(),
debug: vi.fn(),
};
const extracted = {
message: 'Error',
type: 'Error',
responseHeaders: { 'Content-Type': 'text/html' },
};
ErrorExtractor.formatForLogging(extracted, logger);
expect(logger.error).toHaveBeenCalledWith('Response Content-Type: text/html');
});
});
});

View File

@@ -0,0 +1,277 @@
export interface ExtractedError {
message: string;
type: string;
code?: string;
status?: number;
statusText?: string;
responseBody?: unknown;
responseHeaders?: Record<string, string>;
oauthError?: string;
oauthErrorDescription?: string;
causeChain?: Array<{
depth: number;
type: string;
message: string;
code?: string;
}>;
additionalProperties?: Record<string, unknown>;
stack?: string;
}
export class ErrorExtractor {
private static readonly MAX_CAUSE_DEPTH = 5;
private static readonly MAX_BODY_PREVIEW_LENGTH = 1000;
static extract(error: unknown): ExtractedError {
const result: ExtractedError = {
message: 'Unknown error',
type: 'Unknown',
};
if (!error) {
return result;
}
if (error instanceof Error) {
result.message = error.message;
result.type = error.constructor.name;
result.stack = error.stack;
// Extract error code if present
if ('code' in error && error.code) {
result.code = String(error.code);
}
// Extract HTTP response details
if ('response' in error && error.response) {
this.extractResponseDetails(error.response as any, result);
}
// Extract OAuth-specific errors
if ('error' in error && error.error) {
result.oauthError = String(error.error);
}
if ('error_description' in error && error.error_description) {
result.oauthErrorDescription = String(error.error_description);
}
// Extract response body if directly available
if ('body' in error && error.body) {
result.responseBody = this.formatResponseBody(error.body as any);
}
// Extract cause chain
if ('cause' in error && error.cause) {
result.causeChain = this.extractCauseChain(error.cause);
}
// Extract additional properties
const standardKeys = [
'message',
'stack',
'cause',
'code',
'response',
'body',
'error',
'error_description',
];
const additionalKeys = Object.keys(error).filter((key) => !standardKeys.includes(key));
if (additionalKeys.length > 0) {
result.additionalProperties = {};
for (const key of additionalKeys) {
const value = (error as any)[key];
if (value !== undefined && value !== null) {
result.additionalProperties[key] = value;
}
}
}
} else if (typeof error === 'string') {
result.message = error;
result.type = 'String';
} else if (typeof error === 'object' && error !== null) {
result.message = JSON.stringify(error);
result.type = 'Object';
} else {
result.message = String(error);
result.type = typeof error;
}
return result;
}
private static extractResponseDetails(response: any, result: ExtractedError): void {
if (!response) return;
if (response.status) {
result.status = response.status;
}
if (response.statusText) {
result.statusText = response.statusText;
}
if (response.body) {
result.responseBody = this.formatResponseBody(response.body);
}
if (response.headers) {
result.responseHeaders = this.extractHeaders(response.headers);
}
}
private static formatResponseBody(body: unknown): unknown {
if (!body) return undefined;
if (typeof body === 'string') {
if (body.length > this.MAX_BODY_PREVIEW_LENGTH) {
return body.substring(0, this.MAX_BODY_PREVIEW_LENGTH) + '... (truncated)';
}
return body;
}
return body;
}
private static extractHeaders(headers: any): Record<string, string> | undefined {
if (!headers) return undefined;
const result: Record<string, string> = {};
// Handle different header formats
if (typeof headers === 'object') {
for (const key of Object.keys(headers)) {
const value = headers[key];
if (value !== undefined && value !== null) {
result[key] = String(value);
}
}
}
return Object.keys(result).length > 0 ? result : undefined;
}
private static extractCauseChain(
cause: unknown
): Array<{ depth: number; type: string; message: string; code?: string }> | undefined {
const chain: Array<{ depth: number; type: string; message: string; code?: string }> = [];
let currentCause = cause;
let depth = 1;
while (currentCause && depth <= this.MAX_CAUSE_DEPTH) {
const causeInfo: { depth: number; type: string; message: string; code?: string } = {
depth,
type: 'Unknown',
message: 'Unknown cause',
};
if (currentCause instanceof Error) {
causeInfo.type = currentCause.constructor.name;
causeInfo.message = currentCause.message;
if ('code' in currentCause && currentCause.code) {
causeInfo.code = String(currentCause.code);
}
} else if (typeof currentCause === 'string') {
causeInfo.type = 'String';
causeInfo.message = currentCause;
} else {
causeInfo.type = typeof currentCause;
causeInfo.message = String(currentCause);
}
chain.push(causeInfo);
// Get next cause in chain
currentCause =
currentCause && typeof currentCause === 'object' && 'cause' in currentCause
? (currentCause as any).cause
: undefined;
depth++;
}
return chain.length > 0 ? chain : undefined;
}
static isOAuthResponseError(extracted: ExtractedError): boolean {
return Boolean(
extracted.code === 'OAUTH_RESPONSE_IS_NOT_JSON' ||
extracted.code === 'OAUTH_PARSE_ERROR' ||
extracted.message.includes('unexpected response content-type') ||
extracted.message.includes('parsing error')
);
}
static isJwtClaimError(extracted: ExtractedError): boolean {
return extracted.message.includes('unexpected JWT claim value encountered');
}
static isNetworkError(extracted: ExtractedError): boolean {
return Boolean(
extracted.code === 'ECONNREFUSED' ||
extracted.code === 'ENOTFOUND' ||
extracted.code === 'ETIMEDOUT' ||
extracted.code === 'ECONNRESET' ||
extracted.message.includes('network') ||
extracted.message.includes('connect')
);
}
static formatForLogging(
extracted: ExtractedError,
logger: {
error: (msg: string, ...args: any[]) => void;
debug: (msg: string, ...args: any[]) => void;
}
): void {
logger.error(`Error type: ${extracted.type}`);
logger.error(`Error message: ${extracted.message}`);
if (extracted.code) {
logger.error(`Error code: ${extracted.code}`);
}
if (extracted.status) {
logger.error(`HTTP Status: ${extracted.status} ${extracted.statusText || ''}`);
}
if (extracted.responseBody) {
logger.error('Response body: %o', extracted.responseBody);
}
if (extracted.responseHeaders) {
const contentType =
extracted.responseHeaders['content-type'] || extracted.responseHeaders['Content-Type'];
if (contentType) {
logger.error(`Response Content-Type: ${contentType}`);
}
}
if (extracted.oauthError) {
logger.error(`OAuth error: ${extracted.oauthError}`);
if (extracted.oauthErrorDescription) {
logger.error(`OAuth error description: ${extracted.oauthErrorDescription}`);
}
}
if (extracted.causeChain) {
logger.error('Error cause chain:');
for (const cause of extracted.causeChain) {
logger.error(` [Cause ${cause.depth}] ${cause.type}: ${cause.message}`);
if (cause.code) {
logger.error(` [Cause ${cause.depth}] Code: ${cause.code}`);
}
}
}
if (extracted.additionalProperties && Object.keys(extracted.additionalProperties).length > 0) {
logger.debug('Additional error properties: %o', extracted.additionalProperties);
}
if (extracted.stack) {
logger.debug(`Stack trace: ${extracted.stack}`);
}
}
}

View File

@@ -0,0 +1,506 @@
import { Logger } from '@nestjs/common';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { validateRedirectUri } from '@app/unraid-api/utils/redirect-uri-validator.js';
describe('validateRedirectUri', () => {
let mockLogger: Logger;
beforeEach(() => {
mockLogger = {
debug: vi.fn(),
warn: vi.fn(),
} as any;
});
describe('basic validation', () => {
it('should return base URL when no redirect URI is provided', () => {
const result = validateRedirectUri(undefined, 'https', 'example.com', mockLogger);
expect(result).toEqual({
isValid: true,
validatedUri: 'https://example.com',
reason: 'No redirect URI provided',
});
});
it('should handle missing base URL', () => {
const result = validateRedirectUri('https://example.com', 'https', undefined, mockLogger);
expect(result).toEqual({
isValid: false,
validatedUri: '',
reason: 'No base URL available',
});
});
});
describe('hostname validation', () => {
it('should accept matching hostname with same port', () => {
const result = validateRedirectUri(
'https://example.com:3000',
'https',
'example.com:3000',
mockLogger
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://example.com:3000');
expect(mockLogger.debug).toHaveBeenCalledWith(
'Validated redirect_uri: https://example.com:3000'
);
});
it('should accept matching hostname with different ports', () => {
const result = validateRedirectUri(
'https://example.com:3001',
'https',
'example.com:3000',
mockLogger
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://example.com:3001');
expect(mockLogger.debug).toHaveBeenCalledWith(
'Validated redirect_uri: https://example.com:3001'
);
});
it('should accept matching hostname when expected has no port but provided does', () => {
const result = validateRedirectUri(
'https://example.com:3000',
'https',
'example.com',
mockLogger
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://example.com:3000');
});
it('should reject different hostnames', () => {
const result = validateRedirectUri('https://evil.com', 'https', 'example.com', mockLogger);
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('https://example.com');
expect(result.reason).toContain('Hostname or protocol mismatch');
expect(mockLogger.warn).toHaveBeenCalled();
});
it('should reject subdomain differences', () => {
const result = validateRedirectUri(
'https://sub.example.com',
'https',
'example.com',
mockLogger
);
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('https://example.com');
});
it('should handle case-insensitive hostname comparison', () => {
const result = validateRedirectUri(
'https://EXAMPLE.COM:3000',
'https',
'example.com',
mockLogger
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://EXAMPLE.COM:3000');
});
});
describe('protocol validation', () => {
it('should reject HTTP when expecting HTTPS (prevent downgrade attacks)', () => {
const result = validateRedirectUri('http://example.com', 'https', 'example.com', mockLogger);
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('https://example.com');
expect(result.reason).toContain('Hostname or protocol mismatch');
});
it('should allow HTTPS when expecting HTTP (common with reverse proxies)', () => {
const result = validateRedirectUri('https://example.com', 'http', 'example.com', mockLogger);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://example.com');
});
it('should accept matching protocols', () => {
const result = validateRedirectUri('http://example.com', 'http', 'example.com', mockLogger);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('http://example.com');
});
});
describe('malformed URL handling', () => {
it('should reject invalid URL format', () => {
const result = validateRedirectUri('not-a-valid-url', 'https', 'example.com', mockLogger);
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('https://example.com');
expect(result.reason).toContain('Invalid redirect_uri format');
expect(mockLogger.warn).toHaveBeenCalledWith('Invalid redirect_uri format: not-a-valid-url');
});
it('should reject javascript protocol', () => {
const result = validateRedirectUri(
'javascript:alert(1)',
'https',
'example.com',
mockLogger
);
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('https://example.com');
});
it('should handle URLs with paths and query params', () => {
const result = validateRedirectUri(
'https://example.com:3000/callback?foo=bar',
'https',
'example.com',
mockLogger
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://example.com:3000/callback?foo=bar');
});
});
describe('security scenarios', () => {
it('should prevent open redirect to attacker domain', () => {
const result = validateRedirectUri(
'https://attacker.com/steal-token',
'https',
'legitimate.com',
mockLogger
);
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('https://legitimate.com');
});
it('should prevent homograph attacks with similar looking domains', () => {
const result = validateRedirectUri(
'https://examp1e.com', // with number 1 instead of letter l
'https',
'example.com',
mockLogger
);
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('https://example.com');
});
it('should handle localhost variations correctly', () => {
const result = validateRedirectUri(
'http://localhost:3001',
'http',
'localhost:3000',
mockLogger
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('http://localhost:3001');
});
it('should handle IP addresses correctly', () => {
const result = validateRedirectUri(
'http://192.168.1.100:3001',
'http',
'192.168.1.100:3000',
mockLogger
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('http://192.168.1.100:3001');
});
it('should reject IP when expecting domain', () => {
const result = validateRedirectUri(
'https://192.168.1.100',
'https',
'example.com',
mockLogger
);
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('https://example.com');
});
});
describe('allowed origins validation', () => {
it('should accept redirect URI with different hostname if in allowed origins', () => {
const result = validateRedirectUri(
'https://devgen-bad-dev1.local/graphql/api/auth/oidc/callback',
'http',
'devgen-dev1.local',
mockLogger,
['https://devgen-bad-dev1.local']
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe(
'https://devgen-bad-dev1.local/graphql/api/auth/oidc/callback'
);
});
it('should reject redirect URI with hostname not in allowed origins', () => {
const result = validateRedirectUri(
'https://evil.com/callback',
'http',
'devgen-dev1.local',
mockLogger,
['https://devgen-bad-dev1.local', 'https://another-allowed.local']
);
expect(result.isValid).toBe(false);
expect(result.reason).toContain('Hostname or protocol mismatch');
});
it('should handle multiple allowed origins', () => {
const result = validateRedirectUri(
'https://second-allowed.local/callback',
'http',
'devgen-dev1.local',
mockLogger,
[
'https://first-allowed.local',
'https://second-allowed.local',
'https://third-allowed.local',
]
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://second-allowed.local/callback');
});
it('should allow HTTPS upgrade for allowed origins', () => {
const result = validateRedirectUri(
'https://allowed-host.local/callback',
'http',
'devgen-dev1.local',
mockLogger,
['http://allowed-host.local']
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://allowed-host.local/callback');
});
it('should validate extra hostname from config with proper logging', () => {
// Simulate a config with an extra hostname like "unraid.local"
const allowedOriginsFromConfig = ['https://unraid.local:8443'];
const result = validateRedirectUri(
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
'http', // Expected from headers
'devgen-dev1.local', // Primary hostname
mockLogger,
allowedOriginsFromConfig
);
// This should pass if the extra hostname is configured correctly
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe('https://unraid.local:8443/graphql/api/auth/oidc/callback');
// Verify that debug logging shows the check
expect(mockLogger.debug).toHaveBeenCalledWith('Checking against 1 allowed origins');
expect(mockLogger.debug).toHaveBeenCalledWith(
'Checking allowed origin: https://unraid.local:8443'
);
expect(mockLogger.debug).toHaveBeenCalledWith(
' Origin match: https://unraid.local:8443 matches https://unraid.local:8443'
);
expect(mockLogger.debug).toHaveBeenCalledWith(
'Validated redirect_uri against allowed origin: https://unraid.local:8443/graphql/api/auth/oidc/callback'
);
});
it('should fail validation when extra hostname is not in config', () => {
// Config has no extra hostnames
const allowedOriginsFromConfig: string[] = [];
const result = validateRedirectUri(
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
// This should fail since unraid.local is not in allowed origins
expect(result.isValid).toBe(false);
expect(result.validatedUri).toBe('http://devgen-dev1.local');
expect(result.reason).toContain('Hostname or protocol mismatch');
// Verify error logging
expect(mockLogger.warn).toHaveBeenCalledWith(
expect.stringContaining('Hostname or protocol mismatch')
);
});
describe('enhanced matching modes', () => {
it('should validate URL prefix match', () => {
const allowedOriginsFromConfig = ['https://unraid.local:8443/graphql/api/auth/oidc/'];
const result = validateRedirectUri(
'https://unraid.local:8443/graphql/api/auth/oidc/callback?code=123',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe(
'https://unraid.local:8443/graphql/api/auth/oidc/callback?code=123'
);
expect(mockLogger.debug).toHaveBeenCalledWith(
' URL prefix match: https://unraid.local:8443/graphql/api/auth/oidc/callback?code=123 matches prefix https://unraid.local:8443/graphql/api/auth/oidc/'
);
});
it('should validate origin match (protocol + hostname + port)', () => {
const allowedOriginsFromConfig = ['https://unraid.local:8443'];
const result = validateRedirectUri(
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe(
'https://unraid.local:8443/graphql/api/auth/oidc/callback'
);
expect(mockLogger.debug).toHaveBeenCalledWith(
' Origin match: https://unraid.local:8443 matches https://unraid.local:8443'
);
});
it('should validate hostname-only match (original behavior)', () => {
const allowedOriginsFromConfig = ['https://unraid.local'];
const result = validateRedirectUri(
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result.isValid).toBe(true);
expect(result.validatedUri).toBe(
'https://unraid.local:8443/graphql/api/auth/oidc/callback'
);
expect(mockLogger.debug).toHaveBeenCalledWith(
' Hostname comparison: provided=unraid.local, allowed=unraid.local'
);
});
it('should prefer exact URL match over origin match', () => {
const allowedOriginsFromConfig = [
'https://unraid.local:8443/graphql/api/auth/oidc/callback', // Exact match first
'https://unraid.local:8443', // Origin match second
];
const result = validateRedirectUri(
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result.isValid).toBe(true);
// Should match the first item in the array due to order of specificity
expect(mockLogger.debug).toHaveBeenCalledWith(
' Exact URL match: https://unraid.local:8443/graphql/api/auth/oidc/callback matches https://unraid.local:8443/graphql/api/auth/oidc/callback'
);
});
it('should prefer origin match over hostname match', () => {
const allowedOriginsFromConfig = [
'https://unraid.local:8443', // Origin match first
'https://unraid.local', // Hostname match second
];
const result = validateRedirectUri(
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result.isValid).toBe(true);
expect(mockLogger.debug).toHaveBeenCalledWith(
' Origin match: https://unraid.local:8443 matches https://unraid.local:8443'
);
});
it('should handle protocol upgrades for all matching modes', () => {
const allowedOriginsFromConfig = [
'http://unraid.local:8443/graphql/api/auth/oidc/callback',
'http://unraid.local:8443',
'http://unraid.local',
];
// Test exact URL with protocol upgrade
const result1 = validateRedirectUri(
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result1.isValid).toBe(true);
// Test origin with protocol upgrade
const result2 = validateRedirectUri(
'https://unraid.local:8443/different/path',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result2.isValid).toBe(true);
// Test hostname with protocol upgrade
const result3 = validateRedirectUri(
'https://unraid.local:9443/different/path',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result3.isValid).toBe(true);
});
it('should handle invalid allowed origin formats gracefully', () => {
const allowedOriginsFromConfig = ['not-a-valid-url', 'https://valid.url'];
const result = validateRedirectUri(
'https://valid.url/graphql/api/auth/oidc/callback',
'http',
'devgen-dev1.local',
mockLogger,
allowedOriginsFromConfig
);
expect(result.isValid).toBe(true);
expect(mockLogger.warn).toHaveBeenCalledWith(
'Invalid allowed origin format: not-a-valid-url'
);
});
});
});
});

View File

@@ -0,0 +1,187 @@
import { Logger } from '@nestjs/common';
export interface RedirectUriValidationResult {
isValid: boolean;
validatedUri: string;
reason?: string;
}
/**
* Validates a redirect URI against the expected origin from request headers.
* This is critical for OAuth security to prevent authorization code interception.
*
* Security considerations:
* - Prevents redirecting OAuth codes to external domains
* - Allows port variations (needed for nginx/socket proxy scenarios)
* - Validates protocol to prevent downgrade attacks
* - Optionally validates against additional allowed origins
*
* @param redirectUri - The redirect URI provided by the client
* @param expectedProtocol - The protocol from request headers (http/https)
* @param expectedHost - The host from request headers (may or may not include port)
* @param logger - Optional logger for debugging
* @param allowedOrigins - Optional list of additional allowed origins
* @returns Validation result with the URI to use
*/
export function validateRedirectUri(
redirectUri: string | undefined,
expectedProtocol: string,
expectedHost: string | undefined,
logger?: Logger,
allowedOrigins?: string[]
): RedirectUriValidationResult {
const baseUrl = expectedHost ? `${expectedProtocol}://${expectedHost}` : undefined;
// If no redirect URI provided, use the base URL
if (!redirectUri || !baseUrl) {
return {
isValid: !redirectUri,
validatedUri: baseUrl || '',
reason: !redirectUri ? 'No redirect URI provided' : 'No base URL available',
};
}
try {
// Parse both URLs to validate hostname
const providedUrl = new URL(redirectUri);
const expectedUrl = new URL(baseUrl);
// Security: Validate hostname matches, but allow port differences
// This handles cases where nginx/socket proxy doesn't preserve port info
const providedHostname = providedUrl.hostname.toLowerCase();
const expectedHostname = expectedUrl.hostname.toLowerCase();
// Check protocol matches, but allow HTTPS when expecting HTTP (common with reverse proxies)
// Never allow HTTP when expecting HTTPS (would be a downgrade attack)
const protocolMatches =
providedUrl.protocol === expectedUrl.protocol ||
(expectedUrl.protocol === 'http:' && providedUrl.protocol === 'https:');
const hostnameMatches = providedHostname === expectedHostname;
// Check against primary expected origin
if (protocolMatches && hostnameMatches) {
// Trust the redirect_uri with its port information
logger?.debug(`Validated redirect_uri: ${redirectUri}`);
return {
isValid: true,
validatedUri: redirectUri,
};
}
// Check against additional allowed origins if provided
if (allowedOrigins && allowedOrigins.length > 0) {
logger?.debug(`Checking against ${allowedOrigins.length} allowed origins`);
for (const allowedOrigin of allowedOrigins) {
try {
const allowedUrl = new URL(allowedOrigin);
const allowedOriginStr = allowedUrl.origin.toLowerCase();
const allowedHostname = allowedUrl.hostname.toLowerCase();
logger?.debug(`Checking allowed origin: ${allowedOrigin}`);
// Try multiple matching strategies in order of specificity
// 1. Exact URL match (if allowed origin includes path/query)
if (allowedOrigin.includes('/') && allowedOrigin.length > allowedOriginStr.length) {
const allowedUrlNormalized = allowedOrigin.toLowerCase();
const providedUrlNormalized = redirectUri.toLowerCase();
// Exact match
if (providedUrlNormalized === allowedUrlNormalized) {
logger?.debug(` Exact URL match: ${redirectUri} matches ${allowedOrigin}`);
logger?.debug(
`Validated redirect_uri against allowed origin: ${redirectUri}`
);
return {
isValid: true,
validatedUri: redirectUri,
};
}
// Prefix match (if allowed origin ends with /)
if (
allowedUrlNormalized.endsWith('/') &&
providedUrlNormalized.startsWith(allowedUrlNormalized)
) {
logger?.debug(
` URL prefix match: ${redirectUri} matches prefix ${allowedOrigin}`
);
logger?.debug(
`Validated redirect_uri against allowed origin: ${redirectUri}`
);
return {
isValid: true,
validatedUri: redirectUri,
};
}
}
// 2. Origin match (protocol + hostname + port)
const providedOrigin = providedUrl.origin.toLowerCase();
if (providedOrigin === allowedOriginStr) {
// Allow HTTPS when expecting HTTP (common with reverse proxies)
const originProtocolMatches =
providedUrl.protocol === allowedUrl.protocol ||
(allowedUrl.protocol === 'http:' && providedUrl.protocol === 'https:');
if (originProtocolMatches) {
logger?.debug(
` Origin match: ${providedOrigin} matches ${allowedOriginStr}`
);
logger?.debug(
`Validated redirect_uri against allowed origin: ${redirectUri}`
);
return {
isValid: true,
validatedUri: redirectUri,
};
}
}
// 3. Hostname match (original behavior, but with better logging)
const allowedProtocolMatches =
providedUrl.protocol === allowedUrl.protocol ||
(allowedUrl.protocol === 'http:' && providedUrl.protocol === 'https:');
const allowedHostnameMatches = providedHostname === allowedHostname;
logger?.debug(
` Hostname comparison: provided=${providedHostname}, allowed=${allowedHostname}`
);
logger?.debug(
` Protocol comparison: provided=${providedUrl.protocol}, allowed=${allowedUrl.protocol}`
);
logger?.debug(
` Protocol matches: ${allowedProtocolMatches}, Hostname matches: ${allowedHostnameMatches}`
);
if (allowedProtocolMatches && allowedHostnameMatches) {
logger?.debug(`Validated redirect_uri against allowed origin: ${redirectUri}`);
return {
isValid: true,
validatedUri: redirectUri,
};
}
} catch (e) {
logger?.warn(`Invalid allowed origin format: ${allowedOrigin}`);
}
}
}
// If we get here, validation failed
const reason = `Hostname or protocol mismatch. Expected: ${expectedUrl.protocol}//${expectedHostname}, Got: ${providedUrl.protocol}//${providedHostname}`;
logger?.warn(`Rejected redirect_uri: ${reason}`);
return {
isValid: false,
validatedUri: baseUrl,
reason,
};
} catch (error) {
const reason = `Invalid redirect_uri format: ${redirectUri}`;
logger?.warn(reason);
return {
isValid: false,
validatedUri: baseUrl,
reason,
};
}
}