mirror of
https://github.com/unraid/api.git
synced 2025-12-30 21:19:49 -06:00
feat(api): enhance OIDC redirect URI handling in service and tests (#1618)
- Updated `getRedirectUri` method in `OidcAuthService` to handle various edge cases for redirect URIs, including full URIs, malformed URLs, and default ports. - Added comprehensive tests for `OidcAuthService` to validate redirect URI construction and error handling. - Modified `RestController` to utilize `redirect_uri` query parameter for authorization requests. - Updated frontend components to include `redirect_uri` in authorization URLs, ensuring correct handling of different protocols and ports. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Stronger OIDC redirect_uri validation and an admin GraphQL endpoint to view full OIDC configuration. * OIDC Debug Logs UI (panel, button, modal), enhanced log viewer with presets/filters, ANSI-colored rendering, and a File Viewer component. * New GraphQL queries to list and fetch config files; API Config Download page. * **Refactor** * Centralized, modular OIDC flows and safer redirect handling; topic-based log subscriptions with a watcher manager for scalable live logs. * **Documentation** * Cache TTL guidance clarified to use milliseconds. * **Chores** * Added ansi_up and escape-html deps; improved log formatting; added root codegen script. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
This commit is contained in:
@@ -157,4 +157,7 @@ Enables GraphQL playground at `http://tower.local/graphql`
|
||||
|
||||
- We are using tailwind v4 we do not need a tailwind config anymore
|
||||
- always search the internet for tailwind v4 documentation when making tailwind related style changes
|
||||
- never run or restart the API server or web server. I will handle the lifecylce, simply wait and ask me to do this for you
|
||||
- never run or restart the API server or web server. I will handle the lifecycle, simply wait and ask me to do this for you
|
||||
- Never use the `any` type. Always prefer proper typing
|
||||
- Avoid using casting whenever possible, prefer proper typing from the start
|
||||
- **IMPORTANT:** cache-manager v7 expects TTL values in **milliseconds**, not seconds. Always use milliseconds when setting cache TTL (e.g., 600000 for 10 minutes, not 600)
|
||||
|
||||
@@ -17,5 +17,6 @@
|
||||
],
|
||||
"buttonText": "Login With Unraid.net"
|
||||
}
|
||||
]
|
||||
],
|
||||
"defaultAllowedOrigins": []
|
||||
}
|
||||
@@ -1798,6 +1798,8 @@ type Server implements Node {
|
||||
guid: String!
|
||||
apikey: String!
|
||||
name: String!
|
||||
|
||||
"""Whether this server is online or offline"""
|
||||
status: ServerStatus!
|
||||
wanip: String!
|
||||
lanip: String!
|
||||
@@ -1854,7 +1856,7 @@ type OidcProvider {
|
||||
"""
|
||||
OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration
|
||||
"""
|
||||
issuer: String!
|
||||
issuer: String
|
||||
|
||||
"""
|
||||
OAuth2 authorization endpoint URL. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration
|
||||
@@ -1907,6 +1909,16 @@ enum AuthorizationRuleMode {
|
||||
AND
|
||||
}
|
||||
|
||||
type OidcConfiguration {
|
||||
"""List of configured OIDC providers"""
|
||||
providers: [OidcProvider!]!
|
||||
|
||||
"""
|
||||
Default allowed redirect origins that apply to all OIDC providers (e.g., Tailscale domains)
|
||||
"""
|
||||
defaultAllowedOrigins: [String!]
|
||||
}
|
||||
|
||||
type OidcSessionValidation {
|
||||
valid: Boolean!
|
||||
username: String
|
||||
@@ -2307,8 +2319,6 @@ type Query {
|
||||
getApiKeyCreationFormSchema: ApiKeyFormSettings!
|
||||
config: Config!
|
||||
flash: Flash!
|
||||
logFiles: [LogFile!]!
|
||||
logFile(path: String!, lines: Int, startLine: Int): LogFileContent!
|
||||
me: UserAccount!
|
||||
|
||||
"""Get all notifications"""
|
||||
@@ -2335,6 +2345,8 @@ type Query {
|
||||
disk(id: PrefixedID!): Disk!
|
||||
rclone: RCloneBackupSettings!
|
||||
info: Info!
|
||||
logFiles: [LogFile!]!
|
||||
logFile(path: String!, lines: Int, startLine: Int): LogFileContent!
|
||||
settings: Settings!
|
||||
isSSOEnabled: Boolean!
|
||||
|
||||
@@ -2347,6 +2359,9 @@ type Query {
|
||||
"""Get a specific OIDC provider by ID"""
|
||||
oidcProvider(id: PrefixedID!): OidcProvider
|
||||
|
||||
"""Get the full OIDC configuration (admin only)"""
|
||||
oidcConfiguration: OidcConfiguration!
|
||||
|
||||
"""Validate an OIDC session token (internal use for CLI validation)"""
|
||||
validateOidcSession(token: String!): OidcSessionValidation!
|
||||
metrics: Metrics!
|
||||
@@ -2590,13 +2605,13 @@ input AccessUrlInput {
|
||||
}
|
||||
|
||||
type Subscription {
|
||||
logFile(path: String!): LogFileContent!
|
||||
notificationAdded: Notification!
|
||||
notificationsOverview: NotificationOverview!
|
||||
ownerSubscription: Owner!
|
||||
serversSubscription: Server!
|
||||
parityHistorySubscription: ParityCheck!
|
||||
arraySubscription: UnraidArray!
|
||||
logFile(path: String!): LogFileContent!
|
||||
systemMetricsCpu: CpuUtilization!
|
||||
systemMetricsMemory: MemoryUtilization!
|
||||
upsUpdates: UPSDevice!
|
||||
|
||||
@@ -99,6 +99,7 @@
|
||||
"diff": "8.0.2",
|
||||
"dockerode": "4.0.7",
|
||||
"dotenv": "17.2.1",
|
||||
"escape-html": "1.0.3",
|
||||
"execa": "9.6.0",
|
||||
"exit-hook": "4.0.0",
|
||||
"fastify": "5.5.0",
|
||||
|
||||
@@ -29,8 +29,24 @@ const stream = SUPPRESS_LOGS
|
||||
singleLine: true,
|
||||
hideObject: false,
|
||||
colorize: true,
|
||||
colorizeObjects: true,
|
||||
levelFirst: false,
|
||||
ignore: 'hostname,pid',
|
||||
destination: logDestination,
|
||||
translateTime: 'HH:mm:ss',
|
||||
customPrettifiers: {
|
||||
time: (timestamp: string | object) => `[${timestamp}`,
|
||||
level: (logLevel: string | object, key: string, log: any, extras: any) => {
|
||||
// Use labelColorized which preserves the colors
|
||||
const { labelColorized } = extras;
|
||||
const context = log.context || log.logger || 'app';
|
||||
return `${labelColorized} ${context}]`;
|
||||
},
|
||||
},
|
||||
messageFormat: (log: any, messageKey: string) => {
|
||||
const msg = log[messageKey] || log.msg || '';
|
||||
return msg;
|
||||
},
|
||||
})
|
||||
: logDestination;
|
||||
|
||||
|
||||
@@ -13,10 +13,11 @@ export const pubsub = new PubSub({ eventEmitter });
|
||||
|
||||
/**
|
||||
* Create a pubsub subscription.
|
||||
* @param channel The pubsub channel to subscribe to.
|
||||
* @param channel The pubsub channel to subscribe to. Can be either a predefined GRAPHQL_PUBSUB_CHANNEL
|
||||
* or a dynamic string for runtime-generated topics (e.g., log file paths like "LOG_FILE:/var/log/test.log")
|
||||
*/
|
||||
export const createSubscription = <T = any>(
|
||||
channel: GRAPHQL_PUBSUB_CHANNEL
|
||||
channel: GRAPHQL_PUBSUB_CHANNEL | string
|
||||
): AsyncIterableIterator<T> => {
|
||||
return pubsub.asyncIterableIterator<T>(channel);
|
||||
};
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { CacheModule } from '@nestjs/cache-manager';
|
||||
import { Test } from '@nestjs/testing';
|
||||
|
||||
import { describe, expect, it } from 'vitest';
|
||||
@@ -9,7 +10,7 @@ describe('Module Dependencies Integration', () => {
|
||||
let module;
|
||||
try {
|
||||
module = await Test.createTestingModule({
|
||||
imports: [RestModule],
|
||||
imports: [CacheModule.register({ isGlobal: true }), RestModule],
|
||||
}).compile();
|
||||
|
||||
expect(module).toBeDefined();
|
||||
|
||||
@@ -34,6 +34,15 @@ import { UnraidFileModifierModule } from '@app/unraid-api/unraid-file-modifier/u
|
||||
req: () => undefined,
|
||||
res: () => undefined,
|
||||
},
|
||||
formatters: {
|
||||
log: (obj) => {
|
||||
// Map NestJS context to Pino context field for pino-pretty
|
||||
if (obj.context && !obj.logger) {
|
||||
return { ...obj, logger: obj.context };
|
||||
}
|
||||
return obj;
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
AuthModule,
|
||||
|
||||
@@ -448,6 +448,20 @@ export enum ConfigErrorState {
|
||||
WITHDRAWN = 'WITHDRAWN'
|
||||
}
|
||||
|
||||
export type ConfigFile = {
|
||||
__typename?: 'ConfigFile';
|
||||
content: Scalars['String']['output'];
|
||||
name: Scalars['String']['output'];
|
||||
path: Scalars['String']['output'];
|
||||
/** Human-readable file size (e.g., "1.5 KB", "2.3 MB") */
|
||||
sizeReadable: Scalars['String']['output'];
|
||||
};
|
||||
|
||||
export type ConfigFilesResponse = {
|
||||
__typename?: 'ConfigFilesResponse';
|
||||
files: Array<ConfigFile>;
|
||||
};
|
||||
|
||||
export type Connect = Node & {
|
||||
__typename?: 'Connect';
|
||||
/** The status of dynamic remote access */
|
||||
@@ -1432,6 +1446,14 @@ export type OidcAuthorizationRule = {
|
||||
value: Array<Scalars['String']['output']>;
|
||||
};
|
||||
|
||||
export type OidcConfiguration = {
|
||||
__typename?: 'OidcConfiguration';
|
||||
/** Default allowed redirect origins that apply to all OIDC providers (e.g., Tailscale domains) */
|
||||
defaultAllowedOrigins?: Maybe<Array<Scalars['String']['output']>>;
|
||||
/** List of configured OIDC providers */
|
||||
providers: Array<OidcProvider>;
|
||||
};
|
||||
|
||||
export type OidcProvider = {
|
||||
__typename?: 'OidcProvider';
|
||||
/** OAuth2 authorization endpoint URL. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration */
|
||||
@@ -1455,7 +1477,7 @@ export type OidcProvider = {
|
||||
/** The unique identifier for the OIDC provider */
|
||||
id: Scalars['PrefixedID']['output'];
|
||||
/** OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration */
|
||||
issuer: Scalars['String']['output'];
|
||||
issuer?: Maybe<Scalars['String']['output']>;
|
||||
/** JSON Web Key Set URI for token validation. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration */
|
||||
jwksUri?: Maybe<Scalars['String']['output']>;
|
||||
/** Display name of the OIDC provider */
|
||||
@@ -1623,6 +1645,7 @@ export type PublicPartnerInfo = {
|
||||
|
||||
export type Query = {
|
||||
__typename?: 'Query';
|
||||
allConfigFiles: ConfigFilesResponse;
|
||||
apiKey?: Maybe<ApiKey>;
|
||||
/** All possible permissions for API keys */
|
||||
apiKeyPossiblePermissions: Array<Permission>;
|
||||
@@ -1632,6 +1655,7 @@ export type Query = {
|
||||
array: UnraidArray;
|
||||
cloud: Cloud;
|
||||
config: Config;
|
||||
configFile?: Maybe<ConfigFile>;
|
||||
connect: Connect;
|
||||
customization?: Maybe<Customization>;
|
||||
disk: Disk;
|
||||
@@ -1654,6 +1678,8 @@ export type Query = {
|
||||
network: Network;
|
||||
/** Get all notifications */
|
||||
notifications: Notifications;
|
||||
/** Get the full OIDC configuration (admin only) */
|
||||
oidcConfiguration: OidcConfiguration;
|
||||
/** Get a specific OIDC provider by ID */
|
||||
oidcProvider?: Maybe<OidcProvider>;
|
||||
/** Get all configured OIDC providers (admin only) */
|
||||
@@ -1693,6 +1719,11 @@ export type QueryApiKeyArgs = {
|
||||
};
|
||||
|
||||
|
||||
export type QueryConfigFileArgs = {
|
||||
name: Scalars['String']['input'];
|
||||
};
|
||||
|
||||
|
||||
export type QueryDiskArgs = {
|
||||
id: Scalars['PrefixedID']['input'];
|
||||
};
|
||||
@@ -1933,6 +1964,7 @@ export type Server = Node & {
|
||||
name: Scalars['String']['output'];
|
||||
owner: ProfileModel;
|
||||
remoteurl: Scalars['String']['output'];
|
||||
/** Whether this server is online or offline */
|
||||
status: ServerStatus;
|
||||
wanip: Scalars['String']['output'];
|
||||
};
|
||||
|
||||
@@ -0,0 +1,213 @@
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import {
|
||||
LogWatcherManager,
|
||||
WatcherState,
|
||||
} from '@app/unraid-api/graph/resolvers/logs/log-watcher-manager.service.js';
|
||||
|
||||
describe('LogWatcherManager', () => {
|
||||
let manager: LogWatcherManager;
|
||||
let mockWatcher: any;
|
||||
|
||||
beforeEach(async () => {
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [LogWatcherManager],
|
||||
}).compile();
|
||||
|
||||
manager = module.get<LogWatcherManager>(LogWatcherManager);
|
||||
|
||||
mockWatcher = {
|
||||
close: vi.fn(),
|
||||
on: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
describe('state management', () => {
|
||||
it('should set watcher as initializing', () => {
|
||||
manager.setInitializing('test-key');
|
||||
const entry = manager.getEntry('test-key');
|
||||
expect(entry).toBeDefined();
|
||||
expect(entry?.state).toBe(WatcherState.INITIALIZING);
|
||||
});
|
||||
|
||||
it('should set watcher as active with position', () => {
|
||||
manager.setActive('test-key', mockWatcher as any, 1000);
|
||||
const entry = manager.getEntry('test-key');
|
||||
expect(entry).toBeDefined();
|
||||
expect(entry?.state).toBe(WatcherState.ACTIVE);
|
||||
if (manager.isActive(entry)) {
|
||||
expect(entry.watcher).toBe(mockWatcher);
|
||||
expect(entry.position).toBe(1000);
|
||||
}
|
||||
});
|
||||
|
||||
it('should set watcher as stopping', () => {
|
||||
manager.setStopping('test-key');
|
||||
const entry = manager.getEntry('test-key');
|
||||
expect(entry).toBeDefined();
|
||||
expect(entry?.state).toBe(WatcherState.STOPPING);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isWatchingOrInitializing', () => {
|
||||
it('should return true for initializing watcher', () => {
|
||||
manager.setInitializing('test-key');
|
||||
expect(manager.isWatchingOrInitializing('test-key')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true for active watcher', () => {
|
||||
manager.setActive('test-key', mockWatcher as any, 0);
|
||||
expect(manager.isWatchingOrInitializing('test-key')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false for stopping watcher', () => {
|
||||
manager.setStopping('test-key');
|
||||
expect(manager.isWatchingOrInitializing('test-key')).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false for non-existent watcher', () => {
|
||||
expect(manager.isWatchingOrInitializing('test-key')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('handlePostInitialization', () => {
|
||||
it('should activate watcher when not stopped', () => {
|
||||
manager.setInitializing('test-key');
|
||||
const result = manager.handlePostInitialization('test-key', mockWatcher as any, 500);
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(mockWatcher.close).not.toHaveBeenCalled();
|
||||
|
||||
const entry = manager.getEntry('test-key');
|
||||
expect(entry?.state).toBe(WatcherState.ACTIVE);
|
||||
if (manager.isActive(entry)) {
|
||||
expect(entry.position).toBe(500);
|
||||
}
|
||||
});
|
||||
|
||||
it('should cleanup watcher when marked as stopping', () => {
|
||||
manager.setStopping('test-key');
|
||||
const result = manager.handlePostInitialization('test-key', mockWatcher as any, 500);
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(mockWatcher.close).toHaveBeenCalled();
|
||||
expect(manager.getEntry('test-key')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should cleanup watcher when entry is missing', () => {
|
||||
const result = manager.handlePostInitialization('test-key', mockWatcher as any, 500);
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(mockWatcher.close).toHaveBeenCalled();
|
||||
expect(manager.getEntry('test-key')).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('stopWatcher', () => {
|
||||
it('should mark initializing watcher as stopping', () => {
|
||||
manager.setInitializing('test-key');
|
||||
manager.stopWatcher('test-key');
|
||||
|
||||
const entry = manager.getEntry('test-key');
|
||||
expect(entry?.state).toBe(WatcherState.STOPPING);
|
||||
});
|
||||
|
||||
it('should close and remove active watcher', () => {
|
||||
manager.setActive('test-key', mockWatcher as any, 0);
|
||||
manager.stopWatcher('test-key');
|
||||
|
||||
expect(mockWatcher.close).toHaveBeenCalled();
|
||||
expect(manager.getEntry('test-key')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should do nothing for non-existent watcher', () => {
|
||||
manager.stopWatcher('test-key');
|
||||
expect(mockWatcher.close).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('position management', () => {
|
||||
it('should update position for active watcher', () => {
|
||||
manager.setActive('test-key', mockWatcher as any, 100);
|
||||
manager.updatePosition('test-key', 200);
|
||||
|
||||
const position = manager.getPosition('test-key');
|
||||
expect(position).toBe(200);
|
||||
});
|
||||
|
||||
it('should not update position for non-active watcher', () => {
|
||||
manager.setInitializing('test-key');
|
||||
manager.updatePosition('test-key', 200);
|
||||
|
||||
const position = manager.getPosition('test-key');
|
||||
expect(position).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should get position for active watcher', () => {
|
||||
manager.setActive('test-key', mockWatcher as any, 300);
|
||||
expect(manager.getPosition('test-key')).toBe(300);
|
||||
});
|
||||
|
||||
it('should return undefined for non-active watcher', () => {
|
||||
manager.setStopping('test-key');
|
||||
expect(manager.getPosition('test-key')).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('stopAllWatchers', () => {
|
||||
it('should close all active watchers and clear map', () => {
|
||||
const mockWatcher1 = { close: vi.fn() };
|
||||
const mockWatcher2 = { close: vi.fn() };
|
||||
const mockWatcher3 = { close: vi.fn() };
|
||||
|
||||
manager.setActive('key1', mockWatcher1 as any, 0);
|
||||
manager.setInitializing('key2');
|
||||
manager.setActive('key3', mockWatcher2 as any, 0);
|
||||
manager.setStopping('key4');
|
||||
manager.setActive('key5', mockWatcher3 as any, 0);
|
||||
|
||||
manager.stopAllWatchers();
|
||||
|
||||
expect(mockWatcher1.close).toHaveBeenCalled();
|
||||
expect(mockWatcher2.close).toHaveBeenCalled();
|
||||
expect(mockWatcher3.close).toHaveBeenCalled();
|
||||
|
||||
expect(manager.getEntry('key1')).toBeUndefined();
|
||||
expect(manager.getEntry('key2')).toBeUndefined();
|
||||
expect(manager.getEntry('key3')).toBeUndefined();
|
||||
expect(manager.getEntry('key4')).toBeUndefined();
|
||||
expect(manager.getEntry('key5')).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('in-flight processing', () => {
|
||||
it('should prevent concurrent processing', () => {
|
||||
manager.setActive('test-key', mockWatcher as any, 0);
|
||||
|
||||
// First call should succeed
|
||||
expect(manager.startProcessing('test-key')).toBe(true);
|
||||
|
||||
// Second call should fail (already in flight)
|
||||
expect(manager.startProcessing('test-key')).toBe(false);
|
||||
|
||||
// After finishing, should be able to start again
|
||||
manager.finishProcessing('test-key');
|
||||
expect(manager.startProcessing('test-key')).toBe(true);
|
||||
});
|
||||
|
||||
it('should not start processing for non-active watcher', () => {
|
||||
manager.setInitializing('test-key');
|
||||
expect(manager.startProcessing('test-key')).toBe(false);
|
||||
|
||||
manager.setStopping('test-key');
|
||||
expect(manager.startProcessing('test-key')).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle finish processing for non-existent watcher', () => {
|
||||
// Should not throw
|
||||
expect(() => manager.finishProcessing('non-existent')).not.toThrow();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,183 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import * as chokidar from 'chokidar';
|
||||
|
||||
export enum WatcherState {
|
||||
INITIALIZING = 'initializing',
|
||||
ACTIVE = 'active',
|
||||
STOPPING = 'stopping',
|
||||
}
|
||||
|
||||
export type WatcherEntry =
|
||||
| { state: WatcherState.INITIALIZING }
|
||||
| { state: WatcherState.ACTIVE; watcher: chokidar.FSWatcher; position: number; inFlight: boolean }
|
||||
| { state: WatcherState.STOPPING };
|
||||
|
||||
/**
|
||||
* Service responsible for managing log file watchers and their lifecycle.
|
||||
* Handles race conditions during watcher initialization and cleanup.
|
||||
*/
|
||||
@Injectable()
|
||||
export class LogWatcherManager {
|
||||
private readonly logger = new Logger(LogWatcherManager.name);
|
||||
private readonly watchers = new Map<string, WatcherEntry>();
|
||||
|
||||
/**
|
||||
* Set a watcher as initializing
|
||||
*/
|
||||
setInitializing(key: string): void {
|
||||
this.watchers.set(key, { state: WatcherState.INITIALIZING });
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a watcher as active with its FSWatcher and position
|
||||
*/
|
||||
setActive(key: string, watcher: chokidar.FSWatcher, position: number): void {
|
||||
this.watchers.set(key, { state: WatcherState.ACTIVE, watcher, position, inFlight: false });
|
||||
}
|
||||
|
||||
/**
|
||||
* Mark a watcher as stopping (used during initialization race conditions)
|
||||
*/
|
||||
setStopping(key: string): void {
|
||||
this.watchers.set(key, { state: WatcherState.STOPPING });
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a watcher entry by key
|
||||
*/
|
||||
getEntry(key: string): WatcherEntry | undefined {
|
||||
return this.watchers.get(key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a watcher entry
|
||||
*/
|
||||
removeEntry(key: string): void {
|
||||
this.watchers.delete(key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a watcher is active and return typed entry
|
||||
*/
|
||||
isActive(entry: WatcherEntry | undefined): entry is {
|
||||
state: WatcherState.ACTIVE;
|
||||
watcher: chokidar.FSWatcher;
|
||||
position: number;
|
||||
inFlight: boolean;
|
||||
} {
|
||||
return entry?.state === WatcherState.ACTIVE;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a watcher exists and is either initializing or active
|
||||
*/
|
||||
isWatchingOrInitializing(key: string): boolean {
|
||||
const entry = this.getEntry(key);
|
||||
return (
|
||||
entry !== undefined &&
|
||||
(entry.state === WatcherState.ACTIVE || entry.state === WatcherState.INITIALIZING)
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle cleanup after initialization completes.
|
||||
* Returns true if the watcher should continue, false if it should be cleaned up.
|
||||
*/
|
||||
handlePostInitialization(key: string, watcher: chokidar.FSWatcher, position: number): boolean {
|
||||
const currentEntry = this.getEntry(key);
|
||||
|
||||
if (!currentEntry || currentEntry.state === WatcherState.STOPPING) {
|
||||
// We were stopped during initialization, clean up immediately
|
||||
this.logger.debug(`Watcher for ${key} was stopped during initialization, cleaning up`);
|
||||
watcher.close();
|
||||
this.removeEntry(key);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Store the active watcher and position
|
||||
this.setActive(key, watcher, position);
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop a watcher, handling all possible states
|
||||
*/
|
||||
stopWatcher(key: string): void {
|
||||
const entry = this.getEntry(key);
|
||||
|
||||
if (!entry) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (entry.state === WatcherState.INITIALIZING) {
|
||||
// Mark as stopping so the initialization will clean up
|
||||
this.setStopping(key);
|
||||
this.logger.debug(`Marked watcher as stopping during initialization: ${key}`);
|
||||
} else if (entry.state === WatcherState.ACTIVE) {
|
||||
// Close the active watcher
|
||||
entry.watcher.close();
|
||||
this.removeEntry(key);
|
||||
this.logger.debug(`Stopped active watcher: ${key}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the position for an active watcher
|
||||
*/
|
||||
updatePosition(key: string, newPosition: number): void {
|
||||
const entry = this.getEntry(key);
|
||||
if (this.isActive(entry)) {
|
||||
entry.position = newPosition;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Start processing a change event (set inFlight to true)
|
||||
* Returns true if processing can proceed, false if already in flight
|
||||
*/
|
||||
startProcessing(key: string): boolean {
|
||||
const entry = this.getEntry(key);
|
||||
if (this.isActive(entry)) {
|
||||
if (entry.inFlight) {
|
||||
return false; // Already processing
|
||||
}
|
||||
entry.inFlight = true;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Finish processing a change event (set inFlight to false)
|
||||
*/
|
||||
finishProcessing(key: string): void {
|
||||
const entry = this.getEntry(key);
|
||||
if (this.isActive(entry)) {
|
||||
entry.inFlight = false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the position for an active watcher
|
||||
*/
|
||||
getPosition(key: string): number | undefined {
|
||||
const entry = this.getEntry(key);
|
||||
if (this.isActive(entry)) {
|
||||
return entry.position;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up all watchers (useful for module cleanup)
|
||||
*/
|
||||
stopAllWatchers(): void {
|
||||
for (const entry of this.watchers.values()) {
|
||||
if (this.isActive(entry)) {
|
||||
entry.watcher.close();
|
||||
}
|
||||
}
|
||||
this.watchers.clear();
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,13 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { LogWatcherManager } from '@app/unraid-api/graph/resolvers/logs/log-watcher-manager.service.js';
|
||||
import { LogsResolver } from '@app/unraid-api/graph/resolvers/logs/logs.resolver.js';
|
||||
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
|
||||
import { ServicesModule } from '@app/unraid-api/graph/services/services.module.js';
|
||||
|
||||
@Module({
|
||||
providers: [LogsResolver, LogsService],
|
||||
exports: [LogsService],
|
||||
imports: [ServicesModule],
|
||||
providers: [LogsResolver, LogsService, LogWatcherManager],
|
||||
exports: [LogsService, LogWatcherManager],
|
||||
})
|
||||
export class LogsModule {}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { beforeEach, describe, expect, it } from 'vitest';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { LogsResolver } from '@app/unraid-api/graph/resolvers/logs/logs.resolver.js';
|
||||
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
|
||||
import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js';
|
||||
|
||||
describe('LogsResolver', () => {
|
||||
let resolver: LogsResolver;
|
||||
@@ -18,6 +19,13 @@ describe('LogsResolver', () => {
|
||||
// Add mock implementations for service methods used by resolver
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: SubscriptionHelperService,
|
||||
useValue: {
|
||||
// Add mock implementations for subscription helper methods
|
||||
createTrackedSubscription: vi.fn(),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
resolver = module.get<LogsResolver>(LogsResolver);
|
||||
|
||||
@@ -3,13 +3,16 @@ import { Args, Int, Query, Resolver, Subscription } from '@nestjs/graphql';
|
||||
import { AuthAction, Resource } from '@unraid/shared/graphql.model.js';
|
||||
import { UsePermissions } from '@unraid/shared/use-permissions.directive.js';
|
||||
|
||||
import { createSubscription, PUBSUB_CHANNEL } from '@app/core/pubsub.js';
|
||||
import { LogFile, LogFileContent } from '@app/unraid-api/graph/resolvers/logs/logs.model.js';
|
||||
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
|
||||
import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js';
|
||||
|
||||
@Resolver(() => LogFile)
|
||||
export class LogsResolver {
|
||||
constructor(private readonly logsService: LogsService) {}
|
||||
constructor(
|
||||
private readonly logsService: LogsService,
|
||||
private readonly subscriptionHelper: SubscriptionHelperService
|
||||
) {}
|
||||
|
||||
@Query(() => [LogFile])
|
||||
@UsePermissions({
|
||||
@@ -38,27 +41,12 @@ export class LogsResolver {
|
||||
action: AuthAction.READ_ANY,
|
||||
resource: Resource.LOGS,
|
||||
})
|
||||
async logFileSubscription(@Args('path') path: string) {
|
||||
// Start watching the file
|
||||
this.logsService.getLogFileSubscriptionChannel(path);
|
||||
logFileSubscription(@Args('path') path: string) {
|
||||
// Register the topic and get the key
|
||||
const topicKey = this.logsService.registerLogFileSubscription(path);
|
||||
|
||||
// Create the async iterator
|
||||
const asyncIterator = createSubscription(PUBSUB_CHANNEL.LOG_FILE);
|
||||
|
||||
// Store the original return method to wrap it
|
||||
const originalReturn = asyncIterator.return;
|
||||
|
||||
// Override the return method to clean up resources
|
||||
asyncIterator.return = async () => {
|
||||
// Stop watching the file when subscription ends
|
||||
this.logsService.stopWatchingLogFile(path);
|
||||
|
||||
// Call the original return method
|
||||
return originalReturn
|
||||
? originalReturn.call(asyncIterator)
|
||||
: Promise.resolve({ value: undefined, done: true });
|
||||
};
|
||||
|
||||
return asyncIterator;
|
||||
// Use the helper service to create a tracked subscription
|
||||
// This automatically handles subscribe/unsubscribe with reference counting
|
||||
return this.subscriptionHelper.createTrackedSubscription(topicKey);
|
||||
}
|
||||
}
|
||||
|
||||
201
api/src/unraid-api/graph/resolvers/logs/logs.service.spec.ts
Normal file
201
api/src/unraid-api/graph/resolvers/logs/logs.service.spec.ts
Normal file
@@ -0,0 +1,201 @@
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
import * as fs from 'node:fs/promises';
|
||||
|
||||
import * as chokidar from 'chokidar';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { LogWatcherManager } from '@app/unraid-api/graph/resolvers/logs/log-watcher-manager.service.js';
|
||||
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
|
||||
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
|
||||
|
||||
vi.mock('node:fs/promises');
|
||||
vi.mock('chokidar');
|
||||
vi.mock('@app/store/index.js', () => ({
|
||||
getters: {
|
||||
paths: () => ({
|
||||
'unraid-log-base': '/var/log',
|
||||
}),
|
||||
},
|
||||
}));
|
||||
vi.mock('@app/core/pubsub.js', () => ({
|
||||
pubsub: {
|
||||
publish: vi.fn(),
|
||||
},
|
||||
PUBSUB_CHANNEL: {},
|
||||
}));
|
||||
|
||||
describe('LogsService', () => {
|
||||
let service: LogsService;
|
||||
let mockWatcher: any;
|
||||
let subscriptionTracker: any;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create a mock watcher
|
||||
mockWatcher = {
|
||||
on: vi.fn(),
|
||||
close: vi.fn(),
|
||||
};
|
||||
|
||||
// Mock chokidar.watch to return our mock watcher
|
||||
vi.mocked(chokidar.watch).mockReturnValue(mockWatcher as any);
|
||||
|
||||
// Mock fs.stat to return a file size
|
||||
vi.mocked(fs.stat).mockResolvedValue({ size: 1000 } as any);
|
||||
|
||||
subscriptionTracker = {
|
||||
getSubscriberCount: vi.fn().mockReturnValue(0),
|
||||
registerTopic: vi.fn(),
|
||||
};
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
LogsService,
|
||||
LogWatcherManager,
|
||||
{
|
||||
provide: SubscriptionTrackerService,
|
||||
useValue: subscriptionTracker,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
service = module.get<LogsService>(LogsService);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should be defined', () => {
|
||||
expect(service).toBeDefined();
|
||||
});
|
||||
|
||||
it('should handle race condition when stopping watcher during initialization', async () => {
|
||||
// Setup: Register the subscription which will trigger registerTopic
|
||||
service.registerLogFileSubscription('test.log');
|
||||
|
||||
// Get the onStart callback that was registered
|
||||
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
|
||||
const onStartCallback = registerTopicCall[1];
|
||||
const onStopCallback = registerTopicCall[2];
|
||||
|
||||
// Create a promise to control when stat resolves
|
||||
let statResolve: any;
|
||||
const statPromise = new Promise((resolve) => {
|
||||
statResolve = resolve;
|
||||
});
|
||||
vi.mocked(fs.stat).mockReturnValue(statPromise as any);
|
||||
|
||||
// Start the watcher (this will call startWatchingLogFile internally)
|
||||
onStartCallback();
|
||||
|
||||
// At this point, the watcher should be marked as 'initializing'
|
||||
// Now call stop before the stat promise resolves
|
||||
onStopCallback();
|
||||
|
||||
// Now resolve the stat promise to complete initialization
|
||||
statResolve({ size: 1000 });
|
||||
|
||||
// Wait for any async operations to complete
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
|
||||
// The watcher should have been closed due to the race condition check
|
||||
expect(mockWatcher.close).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not leak watcher if stopped multiple times during initialization', async () => {
|
||||
// Setup: Register the subscription
|
||||
service.registerLogFileSubscription('test.log');
|
||||
|
||||
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
|
||||
const onStartCallback = registerTopicCall[1];
|
||||
const onStopCallback = registerTopicCall[2];
|
||||
|
||||
// Create controlled stat promise
|
||||
let statResolve: any;
|
||||
const statPromise = new Promise((resolve) => {
|
||||
statResolve = resolve;
|
||||
});
|
||||
vi.mocked(fs.stat).mockReturnValue(statPromise as any);
|
||||
|
||||
// Start the watcher
|
||||
onStartCallback();
|
||||
|
||||
// Call stop multiple times during initialization
|
||||
onStopCallback();
|
||||
onStopCallback();
|
||||
onStopCallback();
|
||||
|
||||
// Complete initialization
|
||||
statResolve({ size: 1000 });
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
|
||||
// Should only close once
|
||||
expect(mockWatcher.close).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should properly handle normal start and stop without race condition', async () => {
|
||||
// Setup: Register the subscription
|
||||
service.registerLogFileSubscription('test.log');
|
||||
|
||||
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
|
||||
const onStartCallback = registerTopicCall[1];
|
||||
const onStopCallback = registerTopicCall[2];
|
||||
|
||||
// Make stat resolve immediately
|
||||
vi.mocked(fs.stat).mockResolvedValue({ size: 1000 } as any);
|
||||
|
||||
// Start the watcher and let it complete initialization
|
||||
onStartCallback();
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
|
||||
// Watcher should be created but not closed
|
||||
expect(chokidar.watch).toHaveBeenCalled();
|
||||
expect(mockWatcher.close).not.toHaveBeenCalled();
|
||||
|
||||
// Now stop it normally
|
||||
onStopCallback();
|
||||
|
||||
// Watcher should be closed
|
||||
expect(mockWatcher.close).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should handle error during initialization without leaking watchers', async () => {
|
||||
// Setup: Register the subscription
|
||||
service.registerLogFileSubscription('test.log');
|
||||
|
||||
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
|
||||
const onStartCallback = registerTopicCall[1];
|
||||
|
||||
// Make stat reject with an error
|
||||
vi.mocked(fs.stat).mockRejectedValue(new Error('File not found'));
|
||||
|
||||
// Start the watcher (should fail during initialization)
|
||||
onStartCallback();
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
|
||||
// Watcher should never be created due to stat error
|
||||
expect(chokidar.watch).not.toHaveBeenCalled();
|
||||
expect(mockWatcher.close).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not create duplicate watchers when started multiple times', async () => {
|
||||
// Setup: Register the subscription
|
||||
service.registerLogFileSubscription('test.log');
|
||||
|
||||
const registerTopicCall = subscriptionTracker.registerTopic.mock.calls[0];
|
||||
const onStartCallback = registerTopicCall[1];
|
||||
|
||||
// Make stat resolve immediately
|
||||
vi.mocked(fs.stat).mockResolvedValue({ size: 1000 } as any);
|
||||
|
||||
// Start the watcher multiple times
|
||||
onStartCallback();
|
||||
onStartCallback();
|
||||
onStartCallback();
|
||||
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
|
||||
// Should only create one watcher
|
||||
expect(chokidar.watch).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
@@ -1,13 +1,15 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { createReadStream } from 'node:fs';
|
||||
import { readdir, readFile, stat } from 'node:fs/promises';
|
||||
import { readdir, stat } from 'node:fs/promises';
|
||||
import { basename, join } from 'node:path';
|
||||
import { createInterface } from 'node:readline';
|
||||
|
||||
import * as chokidar from 'chokidar';
|
||||
|
||||
import { pubsub, PUBSUB_CHANNEL } from '@app/core/pubsub.js';
|
||||
import { pubsub } from '@app/core/pubsub.js';
|
||||
import { getters } from '@app/store/index.js';
|
||||
import { LogWatcherManager } from '@app/unraid-api/graph/resolvers/logs/log-watcher-manager.service.js';
|
||||
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
|
||||
|
||||
interface LogFile {
|
||||
name: string;
|
||||
@@ -26,12 +28,13 @@ interface LogFileContent {
|
||||
@Injectable()
|
||||
export class LogsService {
|
||||
private readonly logger = new Logger(LogsService.name);
|
||||
private readonly logWatchers = new Map<
|
||||
string,
|
||||
{ watcher: chokidar.FSWatcher; position: number; subscriptionCount: number }
|
||||
>();
|
||||
private readonly DEFAULT_LINES = 100;
|
||||
|
||||
constructor(
|
||||
private readonly subscriptionTracker: SubscriptionTrackerService,
|
||||
private readonly watcherManager: LogWatcherManager
|
||||
) {}
|
||||
|
||||
/**
|
||||
* Get the base path for log files
|
||||
*/
|
||||
@@ -111,135 +114,208 @@ export class LogsService {
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the subscription channel for a log file
|
||||
* Register and get the topic key for a log file subscription
|
||||
* @param path Path to the log file
|
||||
* @returns The subscription topic key
|
||||
*/
|
||||
getLogFileSubscriptionChannel(path: string): PUBSUB_CHANNEL {
|
||||
registerLogFileSubscription(path: string): string {
|
||||
const normalizedPath = join(this.logBasePath, basename(path));
|
||||
const topicKey = this.getTopicKey(normalizedPath);
|
||||
|
||||
// Start watching the file if not already watching
|
||||
if (!this.logWatchers.has(normalizedPath)) {
|
||||
this.startWatchingLogFile(normalizedPath);
|
||||
} else {
|
||||
// Increment subscription count for existing watcher
|
||||
const watcher = this.logWatchers.get(normalizedPath);
|
||||
if (watcher) {
|
||||
watcher.subscriptionCount++;
|
||||
this.logger.debug(
|
||||
`Incremented subscription count for ${normalizedPath} to ${watcher.subscriptionCount}`
|
||||
);
|
||||
}
|
||||
// Register the topic if not already registered
|
||||
if (!this.subscriptionTracker.getSubscriberCount(topicKey)) {
|
||||
this.logger.debug(`Registering log file subscription topic: ${topicKey}`);
|
||||
|
||||
this.subscriptionTracker.registerTopic(
|
||||
topicKey,
|
||||
// onStart handler
|
||||
() => {
|
||||
this.logger.debug(`Starting log file watcher for topic: ${topicKey}`);
|
||||
this.startWatchingLogFile(normalizedPath);
|
||||
},
|
||||
// onStop handler
|
||||
() => {
|
||||
this.logger.debug(`Stopping log file watcher for topic: ${topicKey}`);
|
||||
this.stopWatchingLogFile(normalizedPath);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return PUBSUB_CHANNEL.LOG_FILE;
|
||||
return topicKey;
|
||||
}
|
||||
|
||||
/**
|
||||
* Start watching a log file for changes using chokidar
|
||||
* @param path Path to the log file
|
||||
*/
|
||||
private async startWatchingLogFile(path: string): Promise<void> {
|
||||
try {
|
||||
// Get initial file size
|
||||
const stats = await stat(path);
|
||||
let position = stats.size;
|
||||
private startWatchingLogFile(path: string): void {
|
||||
const watcherKey = path;
|
||||
|
||||
// Create a watcher for the file using chokidar
|
||||
const watcher = chokidar.watch(path, {
|
||||
persistent: true,
|
||||
awaitWriteFinish: {
|
||||
stabilityThreshold: 300,
|
||||
pollInterval: 100,
|
||||
},
|
||||
});
|
||||
// Check if already watching or initializing
|
||||
if (this.watcherManager.isWatchingOrInitializing(watcherKey)) {
|
||||
this.logger.debug(`Already watching or initializing log file: ${watcherKey}`);
|
||||
return;
|
||||
}
|
||||
|
||||
watcher.on('change', async () => {
|
||||
try {
|
||||
const newStats = await stat(path);
|
||||
// Mark as initializing immediately to prevent race conditions
|
||||
this.watcherManager.setInitializing(watcherKey);
|
||||
|
||||
// If the file has grown
|
||||
if (newStats.size > position) {
|
||||
// Read only the new content
|
||||
const stream = createReadStream(path, {
|
||||
start: position,
|
||||
end: newStats.size - 1,
|
||||
});
|
||||
// Get initial file size and set up watcher
|
||||
stat(path)
|
||||
.then((stats) => {
|
||||
const position = stats.size;
|
||||
|
||||
let newContent = '';
|
||||
stream.on('data', (chunk) => {
|
||||
newContent += chunk.toString();
|
||||
});
|
||||
// Create a watcher for the file using chokidar
|
||||
const watcher = chokidar.watch(path, {
|
||||
persistent: true,
|
||||
awaitWriteFinish: {
|
||||
stabilityThreshold: 300,
|
||||
pollInterval: 100,
|
||||
},
|
||||
});
|
||||
|
||||
stream.on('end', () => {
|
||||
if (newContent) {
|
||||
pubsub.publish(PUBSUB_CHANNEL.LOG_FILE, {
|
||||
watcher.on('change', async () => {
|
||||
// Check if we're already processing a change event for this file
|
||||
if (!this.watcherManager.startProcessing(watcherKey)) {
|
||||
// Already processing, ignore this event
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const newStats = await stat(path);
|
||||
|
||||
// Get the current position
|
||||
const currentPosition = this.watcherManager.getPosition(watcherKey);
|
||||
if (currentPosition === undefined) {
|
||||
// Watcher was stopped or not active, ignore the event
|
||||
return;
|
||||
}
|
||||
|
||||
// If the file has grown
|
||||
if (newStats.size > currentPosition) {
|
||||
// Read only the new content
|
||||
const stream = createReadStream(path, {
|
||||
start: currentPosition,
|
||||
end: newStats.size - 1,
|
||||
});
|
||||
|
||||
let newContent = '';
|
||||
stream.on('data', (chunk) => {
|
||||
newContent += chunk.toString();
|
||||
});
|
||||
|
||||
stream.on('end', () => {
|
||||
try {
|
||||
if (newContent) {
|
||||
// Use topic-specific channel
|
||||
const topicKey = this.getTopicKey(path);
|
||||
pubsub.publish(topicKey, {
|
||||
logFile: {
|
||||
path,
|
||||
content: newContent,
|
||||
totalLines: 0, // We don't need to count lines for updates
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Update position for next read (while still holding the guard)
|
||||
this.watcherManager.updatePosition(watcherKey, newStats.size);
|
||||
} finally {
|
||||
// Clear the in-flight flag
|
||||
this.watcherManager.finishProcessing(watcherKey);
|
||||
}
|
||||
});
|
||||
|
||||
stream.on('error', (error) => {
|
||||
this.logger.error(`Error reading stream for ${path}: ${error}`);
|
||||
// Clear the in-flight flag on error
|
||||
this.watcherManager.finishProcessing(watcherKey);
|
||||
});
|
||||
} else if (newStats.size < currentPosition) {
|
||||
// File was truncated, reset position and read from beginning
|
||||
this.logger.debug(`File ${path} was truncated, resetting position`);
|
||||
|
||||
try {
|
||||
// Read the entire file content
|
||||
const content = await this.getLogFileContent(
|
||||
path,
|
||||
this.DEFAULT_LINES,
|
||||
undefined
|
||||
);
|
||||
|
||||
// Use topic-specific channel
|
||||
const topicKey = this.getTopicKey(path);
|
||||
pubsub.publish(topicKey, {
|
||||
logFile: {
|
||||
path,
|
||||
content: newContent,
|
||||
totalLines: 0, // We don't need to count lines for updates
|
||||
...content,
|
||||
},
|
||||
});
|
||||
|
||||
// Update position (while still holding the guard)
|
||||
this.watcherManager.updatePosition(watcherKey, newStats.size);
|
||||
} finally {
|
||||
// Clear the in-flight flag
|
||||
this.watcherManager.finishProcessing(watcherKey);
|
||||
}
|
||||
|
||||
// Update position for next read
|
||||
position = newStats.size;
|
||||
});
|
||||
} else if (newStats.size < position) {
|
||||
// File was truncated, reset position and read from beginning
|
||||
position = 0;
|
||||
this.logger.debug(`File ${path} was truncated, resetting position`);
|
||||
|
||||
// Read the entire file content
|
||||
const content = await this.getLogFileContent(path);
|
||||
|
||||
pubsub.publish(PUBSUB_CHANNEL.LOG_FILE, {
|
||||
logFile: content,
|
||||
});
|
||||
|
||||
position = newStats.size;
|
||||
} else {
|
||||
// File size unchanged, clear the in-flight flag
|
||||
this.watcherManager.finishProcessing(watcherKey);
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
this.logger.error(`Error processing file change for ${path}: ${error}`);
|
||||
// Clear the in-flight flag on error
|
||||
this.watcherManager.finishProcessing(watcherKey);
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
this.logger.error(`Error processing file change for ${path}: ${error}`);
|
||||
});
|
||||
|
||||
watcher.on('error', (error) => {
|
||||
this.logger.error(`Chokidar watcher error for ${path}: ${error}`);
|
||||
});
|
||||
|
||||
// Check if we were stopped during initialization and handle cleanup
|
||||
if (!this.watcherManager.handlePostInitialization(watcherKey, watcher, position)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Publish initial snapshot
|
||||
this.getLogFileContent(path, this.DEFAULT_LINES, undefined)
|
||||
.then((content) => {
|
||||
const topicKey = this.getTopicKey(path);
|
||||
pubsub.publish(topicKey, {
|
||||
logFile: {
|
||||
...content,
|
||||
},
|
||||
});
|
||||
})
|
||||
.catch((error) => {
|
||||
this.logger.error(`Error publishing initial log content for ${path}: ${error}`);
|
||||
});
|
||||
|
||||
this.logger.debug(`Started watching log file with chokidar: ${path}`);
|
||||
})
|
||||
.catch((error) => {
|
||||
this.logger.error(`Error setting up file watcher for ${path}: ${error}`);
|
||||
// Clean up the initializing entry on error
|
||||
this.watcherManager.removeEntry(watcherKey);
|
||||
});
|
||||
}
|
||||
|
||||
watcher.on('error', (error) => {
|
||||
this.logger.error(`Chokidar watcher error for ${path}: ${error}`);
|
||||
});
|
||||
|
||||
// Store the watcher and current position with initial subscription count of 1
|
||||
this.logWatchers.set(path, { watcher, position, subscriptionCount: 1 });
|
||||
|
||||
this.logger.debug(
|
||||
`Started watching log file with chokidar: ${path} (subscription count: 1)`
|
||||
);
|
||||
} catch (error: unknown) {
|
||||
this.logger.error(`Error setting up chokidar file watcher for ${path}: ${error}`);
|
||||
}
|
||||
/**
|
||||
* Get the topic key for a log file subscription
|
||||
* @param path Path to the log file (should already be normalized)
|
||||
* @returns The topic key
|
||||
*/
|
||||
private getTopicKey(path: string): string {
|
||||
// Assume path is already normalized (full path)
|
||||
return `LOG_FILE:${path}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop watching a log file
|
||||
* @param path Path to the log file
|
||||
*/
|
||||
public stopWatchingLogFile(path: string): void {
|
||||
const normalizedPath = join(this.logBasePath, basename(path));
|
||||
const watcher = this.logWatchers.get(normalizedPath);
|
||||
|
||||
if (watcher) {
|
||||
// Decrement subscription count
|
||||
watcher.subscriptionCount--;
|
||||
this.logger.debug(
|
||||
`Decremented subscription count for ${normalizedPath} to ${watcher.subscriptionCount}`
|
||||
);
|
||||
|
||||
// Only close the watcher when subscription count reaches 0
|
||||
if (watcher.subscriptionCount <= 0) {
|
||||
watcher.watcher.close();
|
||||
this.logWatchers.delete(normalizedPath);
|
||||
this.logger.debug(`Stopped watching log file: ${normalizedPath} (no more subscribers)`);
|
||||
}
|
||||
}
|
||||
private stopWatchingLogFile(path: string): void {
|
||||
this.watcherManager.stopWatcher(path);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -9,7 +9,7 @@ import { CpuService } from '@app/unraid-api/graph/resolvers/info/cpu/cpu.service
|
||||
import { MemoryService } from '@app/unraid-api/graph/resolvers/info/memory/memory.service.js';
|
||||
import { MetricsResolver } from '@app/unraid-api/graph/resolvers/metrics/metrics.resolver.js';
|
||||
import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js';
|
||||
import { SubscriptionPollingService } from '@app/unraid-api/graph/services/subscription-polling.service.js';
|
||||
import { SubscriptionManagerService } from '@app/unraid-api/graph/services/subscription-manager.service.js';
|
||||
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
|
||||
|
||||
describe('MetricsResolver Integration Tests', () => {
|
||||
@@ -25,7 +25,7 @@ describe('MetricsResolver Integration Tests', () => {
|
||||
MemoryService,
|
||||
SubscriptionTrackerService,
|
||||
SubscriptionHelperService,
|
||||
SubscriptionPollingService,
|
||||
SubscriptionManagerService,
|
||||
],
|
||||
}).compile();
|
||||
|
||||
@@ -36,8 +36,8 @@ describe('MetricsResolver Integration Tests', () => {
|
||||
|
||||
afterEach(async () => {
|
||||
// Clean up polling service
|
||||
const pollingService = module.get<SubscriptionPollingService>(SubscriptionPollingService);
|
||||
pollingService.stopAll();
|
||||
const subscriptionManager = module.get<SubscriptionManagerService>(SubscriptionManagerService);
|
||||
subscriptionManager.stopAll();
|
||||
await module.close();
|
||||
});
|
||||
|
||||
@@ -202,10 +202,13 @@ describe('MetricsResolver Integration Tests', () => {
|
||||
it('should handle errors in CPU polling gracefully', async () => {
|
||||
const service = module.get<CpuService>(CpuService);
|
||||
const trackerService = module.get<SubscriptionTrackerService>(SubscriptionTrackerService);
|
||||
const pollingService = module.get<SubscriptionPollingService>(SubscriptionPollingService);
|
||||
const subscriptionManager =
|
||||
module.get<SubscriptionManagerService>(SubscriptionManagerService);
|
||||
|
||||
// Mock logger to capture error logs
|
||||
const loggerSpy = vi.spyOn(pollingService['logger'], 'error').mockImplementation(() => {});
|
||||
const loggerSpy = vi
|
||||
.spyOn(subscriptionManager['logger'], 'error')
|
||||
.mockImplementation(() => {});
|
||||
vi.spyOn(service, 'generateCpuLoad').mockRejectedValueOnce(new Error('CPU error'));
|
||||
|
||||
// Trigger polling
|
||||
@@ -215,7 +218,7 @@ describe('MetricsResolver Integration Tests', () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 1100));
|
||||
|
||||
expect(loggerSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Error in polling task'),
|
||||
expect.stringContaining('Error in subscription callback'),
|
||||
expect.any(Error)
|
||||
);
|
||||
|
||||
@@ -226,10 +229,13 @@ describe('MetricsResolver Integration Tests', () => {
|
||||
it('should handle errors in memory polling gracefully', async () => {
|
||||
const service = module.get<MemoryService>(MemoryService);
|
||||
const trackerService = module.get<SubscriptionTrackerService>(SubscriptionTrackerService);
|
||||
const pollingService = module.get<SubscriptionPollingService>(SubscriptionPollingService);
|
||||
const subscriptionManager =
|
||||
module.get<SubscriptionManagerService>(SubscriptionManagerService);
|
||||
|
||||
// Mock logger to capture error logs
|
||||
const loggerSpy = vi.spyOn(pollingService['logger'], 'error').mockImplementation(() => {});
|
||||
const loggerSpy = vi
|
||||
.spyOn(subscriptionManager['logger'], 'error')
|
||||
.mockImplementation(() => {});
|
||||
vi.spyOn(service, 'generateMemoryLoad').mockRejectedValueOnce(new Error('Memory error'));
|
||||
|
||||
// Trigger polling
|
||||
@@ -239,7 +245,7 @@ describe('MetricsResolver Integration Tests', () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 2100));
|
||||
|
||||
expect(loggerSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Error in polling task'),
|
||||
expect.stringContaining('Error in subscription callback'),
|
||||
expect.any(Error)
|
||||
);
|
||||
|
||||
@@ -251,22 +257,30 @@ describe('MetricsResolver Integration Tests', () => {
|
||||
describe('Polling cleanup on module destroy', () => {
|
||||
it('should clean up timers when module is destroyed', async () => {
|
||||
const trackerService = module.get<SubscriptionTrackerService>(SubscriptionTrackerService);
|
||||
const pollingService = module.get<SubscriptionPollingService>(SubscriptionPollingService);
|
||||
const subscriptionManager =
|
||||
module.get<SubscriptionManagerService>(SubscriptionManagerService);
|
||||
|
||||
// Start polling
|
||||
trackerService.subscribe(PUBSUB_CHANNEL.CPU_UTILIZATION);
|
||||
trackerService.subscribe(PUBSUB_CHANNEL.MEMORY_UTILIZATION);
|
||||
|
||||
// Verify polling is active
|
||||
expect(pollingService.isPolling(PUBSUB_CHANNEL.CPU_UTILIZATION)).toBe(true);
|
||||
expect(pollingService.isPolling(PUBSUB_CHANNEL.MEMORY_UTILIZATION)).toBe(true);
|
||||
// Wait a bit for subscriptions to be fully set up
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Verify subscriptions are active
|
||||
expect(subscriptionManager.isSubscriptionActive(PUBSUB_CHANNEL.CPU_UTILIZATION)).toBe(true);
|
||||
expect(subscriptionManager.isSubscriptionActive(PUBSUB_CHANNEL.MEMORY_UTILIZATION)).toBe(
|
||||
true
|
||||
);
|
||||
|
||||
// Clean up the module
|
||||
await module.close();
|
||||
|
||||
// Timers should be cleaned up
|
||||
expect(pollingService.isPolling(PUBSUB_CHANNEL.CPU_UTILIZATION)).toBe(false);
|
||||
expect(pollingService.isPolling(PUBSUB_CHANNEL.MEMORY_UTILIZATION)).toBe(false);
|
||||
// Subscriptions should be cleaned up
|
||||
expect(subscriptionManager.isSubscriptionActive(PUBSUB_CHANNEL.CPU_UTILIZATION)).toBe(false);
|
||||
expect(subscriptionManager.isSubscriptionActive(PUBSUB_CHANNEL.MEMORY_UTILIZATION)).toBe(
|
||||
false
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { AuthModule } from '@app/unraid-api/auth/auth.module.js';
|
||||
import { ApiConfigModule } from '@app/unraid-api/config/api-config.module.js';
|
||||
import { ApiKeyModule } from '@app/unraid-api/graph/resolvers/api-key/api-key.module.js';
|
||||
import { ApiKeyResolver } from '@app/unraid-api/graph/resolvers/api-key/api-key.resolver.js';
|
||||
import { ArrayModule } from '@app/unraid-api/graph/resolvers/array/array.module.js';
|
||||
@@ -11,8 +12,7 @@ import { DockerModule } from '@app/unraid-api/graph/resolvers/docker/docker.modu
|
||||
import { FlashBackupModule } from '@app/unraid-api/graph/resolvers/flash-backup/flash-backup.module.js';
|
||||
import { FlashResolver } from '@app/unraid-api/graph/resolvers/flash/flash.resolver.js';
|
||||
import { InfoModule } from '@app/unraid-api/graph/resolvers/info/info.module.js';
|
||||
import { LogsResolver } from '@app/unraid-api/graph/resolvers/logs/logs.resolver.js';
|
||||
import { LogsService } from '@app/unraid-api/graph/resolvers/logs/logs.service.js';
|
||||
import { LogsModule } from '@app/unraid-api/graph/resolvers/logs/logs.module.js';
|
||||
import { MetricsModule } from '@app/unraid-api/graph/resolvers/metrics/metrics.module.js';
|
||||
import { RootMutationsResolver } from '@app/unraid-api/graph/resolvers/mutation/mutation.resolver.js';
|
||||
import { NotificationsResolver } from '@app/unraid-api/graph/resolvers/notifications/notifications.resolver.js';
|
||||
@@ -39,12 +39,14 @@ import { MeResolver } from '@app/unraid-api/graph/user/user.resolver.js';
|
||||
ServicesModule,
|
||||
ArrayModule,
|
||||
ApiKeyModule,
|
||||
ApiConfigModule,
|
||||
AuthModule,
|
||||
CustomizationModule,
|
||||
DockerModule,
|
||||
DisksModule,
|
||||
FlashBackupModule,
|
||||
InfoModule,
|
||||
LogsModule,
|
||||
RCloneModule,
|
||||
SettingsModule,
|
||||
SsoModule,
|
||||
@@ -54,8 +56,6 @@ import { MeResolver } from '@app/unraid-api/graph/user/user.resolver.js';
|
||||
providers: [
|
||||
ConfigResolver,
|
||||
FlashResolver,
|
||||
LogsResolver,
|
||||
LogsService,
|
||||
MeResolver,
|
||||
NotificationsResolver,
|
||||
NotificationsService,
|
||||
|
||||
@@ -16,8 +16,8 @@ import {
|
||||
} from '@app/unraid-api/graph/resolvers/settings/settings.model.js';
|
||||
import { ApiSettings } from '@app/unraid-api/graph/resolvers/settings/settings.service.js';
|
||||
import { SsoSettings } from '@app/unraid-api/graph/resolvers/settings/sso-settings.model.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
|
||||
@Resolver(() => Settings)
|
||||
export class SettingsResolver {
|
||||
|
||||
@@ -7,7 +7,7 @@ import { type ApiConfig } from '@unraid/shared/services/api-config.js';
|
||||
import { UserSettingsService } from '@unraid/shared/services/user-settings.js';
|
||||
import { execa } from 'execa';
|
||||
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { createLabeledControl } from '@app/unraid-api/graph/utils/form-utils.js';
|
||||
import { SettingSlice } from '@app/unraid-api/types/json-forms.js';
|
||||
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
|
||||
import { OidcClaimsService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
|
||||
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
|
||||
|
||||
@Module({
|
||||
providers: [OidcAuthorizationService, OidcTokenExchangeService, OidcClaimsService],
|
||||
exports: [OidcAuthorizationService, OidcTokenExchangeService, OidcClaimsService],
|
||||
})
|
||||
export class OidcAuthModule {}
|
||||
@@ -1,70 +1,26 @@
|
||||
import { UnauthorizedException } from '@nestjs/common';
|
||||
import { ConfigService } from '@nestjs/config';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import * as client from 'openid-client';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcAuthService } from '@app/unraid-api/graph/resolvers/sso/oidc-auth.service.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
|
||||
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
|
||||
import {
|
||||
AuthorizationOperator,
|
||||
AuthorizationRuleMode,
|
||||
OidcAuthorizationRule,
|
||||
OidcProvider,
|
||||
} from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/oidc-state.service.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/oidc-validation.service.js';
|
||||
} from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
|
||||
describe('OidcAuthService', () => {
|
||||
let service: OidcAuthService;
|
||||
let oidcConfig: any;
|
||||
let sessionService: any;
|
||||
let configService: any;
|
||||
let stateService: any;
|
||||
let validationService: any;
|
||||
describe('OidcAuthorizationService', () => {
|
||||
let service: OidcAuthorizationService;
|
||||
let module: TestingModule;
|
||||
|
||||
beforeEach(async () => {
|
||||
module = await Test.createTestingModule({
|
||||
providers: [
|
||||
OidcAuthService,
|
||||
{
|
||||
provide: ConfigService,
|
||||
useValue: {
|
||||
get: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcConfigPersistence,
|
||||
useValue: {
|
||||
getProvider: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcSessionService,
|
||||
useValue: {
|
||||
createSession: vi.fn(),
|
||||
},
|
||||
},
|
||||
OidcStateService,
|
||||
{
|
||||
provide: OidcValidationService,
|
||||
useValue: {
|
||||
validateProvider: vi.fn(),
|
||||
performDiscovery: vi.fn(),
|
||||
},
|
||||
},
|
||||
],
|
||||
providers: [OidcAuthorizationService],
|
||||
}).compile();
|
||||
|
||||
service = module.get<OidcAuthService>(OidcAuthService);
|
||||
oidcConfig = module.get(OidcConfigPersistence);
|
||||
sessionService = module.get(OidcSessionService);
|
||||
configService = module.get(ConfigService);
|
||||
stateService = module.get(OidcStateService);
|
||||
validationService = module.get<OidcValidationService>(OidcValidationService);
|
||||
service = module.get<OidcAuthorizationService>(OidcAuthorizationService);
|
||||
});
|
||||
|
||||
describe('Authorization Rule Evaluation', () => {
|
||||
@@ -1189,467 +1145,4 @@ describe('OidcAuthService', () => {
|
||||
).resolves.toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Manual Configuration (No Discovery)', () => {
|
||||
it('should create manual configuration when discovery fails but manual endpoints are provided', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-provider',
|
||||
name: 'Manual Provider',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
issuer: 'https://manual.example.com',
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
jwksUri: 'https://manual.example.com/jwks',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
// Access the private method
|
||||
const getOrCreateConfig = async (provider: OidcProvider) => {
|
||||
return (service as any).getOrCreateConfig(provider);
|
||||
};
|
||||
|
||||
const config = await getOrCreateConfig(provider);
|
||||
|
||||
// Verify the configuration was created with the correct endpoints
|
||||
expect(config).toBeDefined();
|
||||
expect(config.serverMetadata().authorization_endpoint).toBe(
|
||||
'https://manual.example.com/auth'
|
||||
);
|
||||
expect(config.serverMetadata().token_endpoint).toBe('https://manual.example.com/token');
|
||||
expect(config.serverMetadata().jwks_uri).toBe('https://manual.example.com/jwks');
|
||||
expect(config.serverMetadata().issuer).toBe('https://manual.example.com');
|
||||
});
|
||||
|
||||
it('should create manual configuration with fallback issuer when not provided', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-provider-no-issuer',
|
||||
name: 'Manual Provider No Issuer',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
issuer: '', // Empty issuer should skip discovery and use manual endpoints
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// No need to mock discovery since it won't be called with empty issuer
|
||||
|
||||
// Access the private method
|
||||
const getOrCreateConfig = async (provider: OidcProvider) => {
|
||||
return (service as any).getOrCreateConfig(provider);
|
||||
};
|
||||
|
||||
const config = await getOrCreateConfig(provider);
|
||||
|
||||
// Verify the configuration was created with fallback issuer
|
||||
expect(config).toBeDefined();
|
||||
expect(config.serverMetadata().issuer).toBe('manual-manual-provider-no-issuer');
|
||||
expect(config.serverMetadata().authorization_endpoint).toBe(
|
||||
'https://manual.example.com/auth'
|
||||
);
|
||||
expect(config.serverMetadata().token_endpoint).toBe('https://manual.example.com/token');
|
||||
});
|
||||
|
||||
it('should handle manual configuration with client secret properly', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-with-secret',
|
||||
name: 'Manual With Secret',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'secret-123',
|
||||
issuer: 'https://manual.example.com',
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
// Access the private method
|
||||
const getOrCreateConfig = async (provider: OidcProvider) => {
|
||||
return (service as any).getOrCreateConfig(provider);
|
||||
};
|
||||
|
||||
const config = await getOrCreateConfig(provider);
|
||||
|
||||
// Verify configuration was created successfully
|
||||
expect(config).toBeDefined();
|
||||
expect(config.clientMetadata().client_secret).toBe('secret-123');
|
||||
});
|
||||
|
||||
it('should handle manual configuration without client secret (public client)', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-public-client',
|
||||
name: 'Manual Public Client',
|
||||
clientId: 'public-client-id',
|
||||
// No client secret
|
||||
issuer: 'https://manual.example.com',
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
// Access the private method
|
||||
const getOrCreateConfig = async (provider: OidcProvider) => {
|
||||
return (service as any).getOrCreateConfig(provider);
|
||||
};
|
||||
|
||||
const config = await getOrCreateConfig(provider);
|
||||
|
||||
// Verify configuration was created successfully for public client
|
||||
expect(config).toBeDefined();
|
||||
expect(config.clientMetadata().client_secret).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw error when discovery fails and no manual endpoints provided', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'no-manual-endpoints',
|
||||
name: 'No Manual Endpoints',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://broken.example.com',
|
||||
// Missing authorizationEndpoint and tokenEndpoint
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
// Access the private method
|
||||
const getOrCreateConfig = async (provider: OidcProvider) => {
|
||||
return (service as any).getOrCreateConfig(provider);
|
||||
};
|
||||
|
||||
await expect(getOrCreateConfig(provider)).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should throw error when only authorization endpoint is provided', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'partial-manual-endpoints',
|
||||
name: 'Partial Manual Endpoints',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://broken.example.com',
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
// Missing tokenEndpoint
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
// Access the private method
|
||||
const getOrCreateConfig = async (provider: OidcProvider) => {
|
||||
return (service as any).getOrCreateConfig(provider);
|
||||
};
|
||||
|
||||
await expect(getOrCreateConfig(provider)).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should cache manual configuration properly', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'cache-test',
|
||||
name: 'Cache Test',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-secret',
|
||||
issuer: 'https://manual.example.com',
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
// Access the private method
|
||||
const getOrCreateConfig = async (provider: OidcProvider) => {
|
||||
return (service as any).getOrCreateConfig(provider);
|
||||
};
|
||||
|
||||
// First call should create configuration
|
||||
const config1 = await getOrCreateConfig(provider);
|
||||
|
||||
// Second call should return cached configuration
|
||||
const config2 = await getOrCreateConfig(provider);
|
||||
|
||||
expect(config1).toBe(config2); // Should be the exact same instance
|
||||
expect(validationService.performDiscovery).toHaveBeenCalledTimes(1); // Only called once due to caching
|
||||
});
|
||||
|
||||
it('should handle HTTP endpoints with allowInsecureRequests', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'http-endpoints',
|
||||
name: 'HTTP Endpoints',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-secret',
|
||||
issuer: 'http://manual.example.com', // HTTP instead of HTTPS
|
||||
authorizationEndpoint: 'http://manual.example.com/auth',
|
||||
tokenEndpoint: 'http://manual.example.com/token',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
// Access the private method
|
||||
const getOrCreateConfig = async (provider: OidcProvider) => {
|
||||
return (service as any).getOrCreateConfig(provider);
|
||||
};
|
||||
|
||||
const config = await getOrCreateConfig(provider);
|
||||
|
||||
// Verify configuration was created successfully even with HTTP
|
||||
expect(config).toBeDefined();
|
||||
expect(config.serverMetadata().token_endpoint).toBe('http://manual.example.com/token');
|
||||
expect(config.serverMetadata().authorization_endpoint).toBe(
|
||||
'http://manual.example.com/auth'
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAuthorizationUrl', () => {
|
||||
it('should generate authorization URL with custom authorization endpoint', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'test-provider',
|
||||
name: 'Test Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://example.com',
|
||||
authorizationEndpoint: 'https://custom.example.com/auth',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
const authUrl = await service.getAuthorizationUrl(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'localhost:3001'
|
||||
);
|
||||
|
||||
expect(authUrl).toContain('https://custom.example.com/auth');
|
||||
expect(authUrl).toContain('client_id=test-client-id');
|
||||
expect(authUrl).toContain('response_type=code');
|
||||
expect(authUrl).toContain('scope=openid+profile');
|
||||
// State should start with provider ID followed by secure state token
|
||||
expect(authUrl).toMatch(/state=test-provider%3A[a-f0-9]+\.[0-9]+\.[a-f0-9]+/);
|
||||
expect(authUrl).toContain('redirect_uri=');
|
||||
});
|
||||
|
||||
it('should encode provider ID in state parameter', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'encode-test-provider',
|
||||
name: 'Encode Test Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://example.com',
|
||||
authorizationEndpoint: 'https://example.com/auth',
|
||||
scopes: ['openid', 'email'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
const authUrl = await service.getAuthorizationUrl('encode-test-provider', 'original-state');
|
||||
|
||||
// Verify that the state parameter includes provider ID at the start
|
||||
expect(authUrl).toMatch(/state=encode-test-provider%3A[a-f0-9]+\.[0-9]+\.[a-f0-9]+/);
|
||||
});
|
||||
|
||||
it('should throw error when provider not found', async () => {
|
||||
oidcConfig.getProvider.mockResolvedValue(null);
|
||||
|
||||
await expect(
|
||||
service.getAuthorizationUrl('nonexistent-provider', 'test-state')
|
||||
).rejects.toThrow('Provider nonexistent-provider not found');
|
||||
});
|
||||
|
||||
it('should handle custom scopes properly', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'custom-scopes-provider',
|
||||
name: 'Custom Scopes Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://example.com',
|
||||
authorizationEndpoint: 'https://example.com/auth',
|
||||
scopes: ['openid', 'profile', 'groups', 'custom:scope'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
const authUrl = await service.getAuthorizationUrl('custom-scopes-provider', 'test-state');
|
||||
|
||||
expect(authUrl).toContain('scope=openid+profile+groups+custom%3Ascope');
|
||||
});
|
||||
});
|
||||
|
||||
describe('handleCallback', () => {
|
||||
it('should throw error when provider not found in callback', async () => {
|
||||
oidcConfig.getProvider.mockResolvedValue(null);
|
||||
|
||||
await expect(
|
||||
service.handleCallback('nonexistent-provider', 'code', 'redirect-uri')
|
||||
).rejects.toThrow('Provider nonexistent-provider not found');
|
||||
});
|
||||
|
||||
it('should handle malformed state parameter', async () => {
|
||||
await expect(
|
||||
service.handleCallback('invalid-state', 'code', 'redirect-uri')
|
||||
).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should call getProvider with the provided provider ID', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'test-provider',
|
||||
name: 'Test Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://example.com',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
// This will fail during token exchange, but we're testing the provider lookup logic
|
||||
await expect(
|
||||
service.handleCallback('test-provider', 'code', 'redirect-uri')
|
||||
).rejects.toThrow(UnauthorizedException);
|
||||
|
||||
// Verify the provider was looked up with the correct ID
|
||||
expect(oidcConfig.getProvider).toHaveBeenCalledWith('test-provider');
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateProvider', () => {
|
||||
it('should delegate to validation service and return result', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'validate-provider',
|
||||
name: 'Validate Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://example.com',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const expectedResult = {
|
||||
isValid: true,
|
||||
authorizationEndpoint: 'https://example.com/auth',
|
||||
tokenEndpoint: 'https://example.com/token',
|
||||
};
|
||||
|
||||
validationService.validateProvider.mockResolvedValue(expectedResult);
|
||||
|
||||
const result = await service.validateProvider(provider);
|
||||
|
||||
expect(result).toEqual(expectedResult);
|
||||
expect(validationService.validateProvider).toHaveBeenCalledWith(provider);
|
||||
});
|
||||
|
||||
it('should clear config cache before validation', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'cache-clear-provider',
|
||||
name: 'Cache Clear Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://example.com',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const expectedResult = {
|
||||
isValid: false,
|
||||
error: 'Validation failed',
|
||||
};
|
||||
|
||||
validationService.validateProvider.mockResolvedValue(expectedResult);
|
||||
|
||||
const result = await service.validateProvider(provider);
|
||||
|
||||
expect(result).toEqual(expectedResult);
|
||||
// Verify the cache was cleared by checking the method was called
|
||||
expect(validationService.validateProvider).toHaveBeenCalledWith(provider);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getRedirectUri (private method)', () => {
|
||||
it('should generate correct redirect URI with localhost (development)', () => {
|
||||
const getRedirectUri = (service as any).getRedirectUri.bind(service);
|
||||
const redirectUri = getRedirectUri('http://localhost:3000');
|
||||
|
||||
expect(redirectUri).toBe('http://localhost:3000/graphql/api/auth/oidc/callback');
|
||||
});
|
||||
|
||||
it('should generate correct redirect URI with non-localhost host', () => {
|
||||
const getRedirectUri = (service as any).getRedirectUri.bind(service);
|
||||
const redirectUri = getRedirectUri('https://example.com');
|
||||
|
||||
expect(redirectUri).toBe('https://example.com/graphql/api/auth/oidc/callback');
|
||||
});
|
||||
|
||||
it('should handle HTTP protocol for non-localhost hosts', () => {
|
||||
const getRedirectUri = (service as any).getRedirectUri.bind(service);
|
||||
const redirectUri = getRedirectUri('http://tower.local');
|
||||
|
||||
expect(redirectUri).toBe('http://tower.local/graphql/api/auth/oidc/callback');
|
||||
});
|
||||
|
||||
it('should handle non-standard ports correctly', () => {
|
||||
const getRedirectUri = (service as any).getRedirectUri.bind(service);
|
||||
const redirectUri = getRedirectUri('http://example.com:8080');
|
||||
|
||||
expect(redirectUri).toBe('http://example.com:8080/graphql/api/auth/oidc/callback');
|
||||
});
|
||||
|
||||
it('should use default redirect URI when no request host provided', () => {
|
||||
const getRedirectUri = (service as any).getRedirectUri.bind(service);
|
||||
|
||||
// Mock the ConfigService to return a default value
|
||||
configService.get.mockReturnValue('http://tower.local');
|
||||
|
||||
const redirectUri = getRedirectUri();
|
||||
|
||||
expect(redirectUri).toBe('http://tower.local/graphql/api/auth/oidc/callback');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,170 @@
|
||||
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
|
||||
|
||||
import {
|
||||
AuthorizationOperator,
|
||||
AuthorizationRuleMode,
|
||||
OidcAuthorizationRule,
|
||||
OidcProvider,
|
||||
} from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
|
||||
interface JwtClaims {
|
||||
sub?: string;
|
||||
email?: string;
|
||||
name?: string;
|
||||
hd?: string; // Google hosted domain
|
||||
[claim: string]: unknown;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class OidcAuthorizationService {
|
||||
private readonly logger = new Logger(OidcAuthorizationService.name);
|
||||
|
||||
/**
|
||||
* Check authorization based on rules
|
||||
* This will throw a helpful error if misconfigured or unauthorized
|
||||
*/
|
||||
async checkAuthorization(provider: OidcProvider, claims: JwtClaims): Promise<void> {
|
||||
this.logger.debug(
|
||||
`Checking authorization for provider ${provider.id} with ${provider.authorizationRules?.length || 0} rules`
|
||||
);
|
||||
this.logger.debug(`Available claims: ${Object.keys(claims).join(', ')}`);
|
||||
this.logger.debug(
|
||||
`Authorization rule mode: ${provider.authorizationRuleMode || AuthorizationRuleMode.OR}`
|
||||
);
|
||||
|
||||
// If no authorization rules are specified, throw a helpful error
|
||||
if (!provider.authorizationRules || provider.authorizationRules.length === 0) {
|
||||
throw new UnauthorizedException(
|
||||
`Login failed: The ${provider.name} provider has no authorization rules configured. ` +
|
||||
`Please configure authorization rules.`
|
||||
);
|
||||
}
|
||||
|
||||
this.logger.debug('Authorization rules to evaluate: %o', provider.authorizationRules);
|
||||
|
||||
// Evaluate the rules
|
||||
const ruleMode = provider.authorizationRuleMode || AuthorizationRuleMode.OR;
|
||||
const isAuthorized = this.evaluateAuthorizationRules(
|
||||
provider.authorizationRules,
|
||||
claims,
|
||||
ruleMode
|
||||
);
|
||||
|
||||
this.logger.debug(`Authorization result: ${isAuthorized}`);
|
||||
|
||||
if (!isAuthorized) {
|
||||
// Log authorization failure with safe claim representation (no PII)
|
||||
const availableClaimKeys = Object.keys(claims).join(', ');
|
||||
this.logger.warn(
|
||||
`Authorization failed for provider ${provider.name}, user ${claims.sub}, available claim keys: [${availableClaimKeys}]`
|
||||
);
|
||||
throw new UnauthorizedException(
|
||||
`Access denied: Your account does not meet the authorization requirements for ${provider.name}.`
|
||||
);
|
||||
}
|
||||
|
||||
this.logger.debug(`Authorization successful for user ${claims.sub}`);
|
||||
}
|
||||
|
||||
private evaluateAuthorizationRules(
|
||||
rules: OidcAuthorizationRule[],
|
||||
claims: JwtClaims,
|
||||
mode: AuthorizationRuleMode = AuthorizationRuleMode.OR
|
||||
): boolean {
|
||||
// No rules means no authorization
|
||||
if (rules.length === 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (mode === AuthorizationRuleMode.AND) {
|
||||
// All rules must pass (AND logic)
|
||||
return rules.every((rule) => this.evaluateRule(rule, claims));
|
||||
} else {
|
||||
// Any rule can pass (OR logic) - default behavior
|
||||
// Multiple rules act as alternative authorization paths
|
||||
return rules.some((rule) => this.evaluateRule(rule, claims));
|
||||
}
|
||||
}
|
||||
|
||||
private evaluateRule(rule: OidcAuthorizationRule, claims: JwtClaims): boolean {
|
||||
const claimValue = claims[rule.claim];
|
||||
|
||||
this.logger.verbose(
|
||||
`Evaluating rule for claim ${rule.claim}: { claimType: ${typeof claimValue}, isArray: ${Array.isArray(claimValue)}, ruleOperator: ${rule.operator}, ruleValuesCount: ${rule.value.length} }`
|
||||
);
|
||||
|
||||
if (claimValue === undefined || claimValue === null) {
|
||||
this.logger.verbose(`Claim ${rule.claim} not found in token`);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Handle non-array, non-string objects
|
||||
if (typeof claimValue === 'object' && claimValue !== null && !Array.isArray(claimValue)) {
|
||||
this.logger.warn(
|
||||
`unexpected JWT claim value encountered - claim ${rule.claim} has unsupported object type (keys: [${Object.keys(claimValue as Record<string, unknown>).join(', ')}])`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Handle array claims - evaluate rule against each array element
|
||||
if (Array.isArray(claimValue)) {
|
||||
this.logger.verbose(
|
||||
`Processing array claim ${rule.claim} with ${claimValue.length} elements`
|
||||
);
|
||||
|
||||
// For array claims, check if ANY element in the array matches the rule
|
||||
const arrayResult = claimValue.some((element) => {
|
||||
// Skip non-string elements
|
||||
if (
|
||||
typeof element !== 'string' &&
|
||||
typeof element !== 'number' &&
|
||||
typeof element !== 'boolean'
|
||||
) {
|
||||
this.logger.verbose(`Skipping non-primitive element in array: ${typeof element}`);
|
||||
return false;
|
||||
}
|
||||
|
||||
const elementValue = String(element);
|
||||
return this.evaluateSingleValue(elementValue, rule);
|
||||
});
|
||||
|
||||
this.logger.verbose(`Array evaluation result for claim ${rule.claim}: ${arrayResult}`);
|
||||
return arrayResult;
|
||||
}
|
||||
|
||||
// Handle single value claims (string, number, boolean)
|
||||
const value = String(claimValue);
|
||||
this.logger.verbose(`Processing single value claim ${rule.claim}`);
|
||||
|
||||
return this.evaluateSingleValue(value, rule);
|
||||
}
|
||||
|
||||
private evaluateSingleValue(value: string, rule: OidcAuthorizationRule): boolean {
|
||||
let result: boolean;
|
||||
switch (rule.operator) {
|
||||
case AuthorizationOperator.EQUALS:
|
||||
result = rule.value.some((v) => value === v);
|
||||
this.logger.verbose(`EQUALS check: evaluated for claim ${rule.claim}: ${result}`);
|
||||
return result;
|
||||
|
||||
case AuthorizationOperator.CONTAINS:
|
||||
result = rule.value.some((v) => value.includes(v));
|
||||
this.logger.verbose(`CONTAINS check: evaluated for claim ${rule.claim}: ${result}`);
|
||||
return result;
|
||||
|
||||
case AuthorizationOperator.STARTS_WITH:
|
||||
result = rule.value.some((v) => value.startsWith(v));
|
||||
this.logger.verbose(`STARTS_WITH check: evaluated for claim ${rule.claim}: ${result}`);
|
||||
return result;
|
||||
|
||||
case AuthorizationOperator.ENDS_WITH:
|
||||
result = rule.value.some((v) => value.endsWith(v));
|
||||
this.logger.verbose(`ENDS_WITH check: evaluated for claim ${rule.claim}: ${result}`);
|
||||
return result;
|
||||
|
||||
default:
|
||||
this.logger.error(`Unknown authorization operator: ${rule.operator}`);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
import { UnauthorizedException } from '@nestjs/common';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { decodeJwt } from 'jose';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import {
|
||||
JwtClaims,
|
||||
OidcClaimsService,
|
||||
} from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
|
||||
|
||||
// Mock jose
|
||||
vi.mock('jose', () => ({
|
||||
decodeJwt: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('OidcClaimsService', () => {
|
||||
let service: OidcClaimsService;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [OidcClaimsService],
|
||||
}).compile();
|
||||
|
||||
service = module.get<OidcClaimsService>(OidcClaimsService);
|
||||
});
|
||||
|
||||
describe('parseIdToken', () => {
|
||||
it('should parse valid ID token', () => {
|
||||
const mockClaims: JwtClaims = {
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
name: 'Test User',
|
||||
iat: 1234567890,
|
||||
exp: 1234567890,
|
||||
};
|
||||
|
||||
(decodeJwt as any).mockReturnValue(mockClaims);
|
||||
|
||||
const result = service.parseIdToken('valid.jwt.token');
|
||||
|
||||
expect(result).toEqual(mockClaims);
|
||||
expect(decodeJwt).toHaveBeenCalledWith('valid.jwt.token');
|
||||
});
|
||||
|
||||
it('should return null when no token provided', () => {
|
||||
const result = service.parseIdToken(undefined);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null when token parsing fails', () => {
|
||||
(decodeJwt as any).mockImplementation(() => {
|
||||
throw new Error('Invalid token');
|
||||
});
|
||||
|
||||
const result = service.parseIdToken('invalid.token');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should handle claims with array values', () => {
|
||||
const mockClaims: JwtClaims = {
|
||||
sub: 'user123',
|
||||
groups: ['admin', 'user'],
|
||||
roles: ['role1', 'role2', 'role3'],
|
||||
};
|
||||
|
||||
(decodeJwt as any).mockReturnValue(mockClaims);
|
||||
|
||||
const result = service.parseIdToken('token.with.arrays');
|
||||
|
||||
expect(result).toEqual(mockClaims);
|
||||
});
|
||||
|
||||
it('should log warning for complex object claims', () => {
|
||||
const loggerSpy = vi.spyOn(service['logger'], 'warn');
|
||||
|
||||
const mockClaims: JwtClaims = {
|
||||
sub: 'user123',
|
||||
complexClaim: {
|
||||
nested: 'value',
|
||||
another: 'field',
|
||||
},
|
||||
};
|
||||
|
||||
(decodeJwt as any).mockReturnValue(mockClaims);
|
||||
|
||||
service.parseIdToken('token.with.complex');
|
||||
|
||||
expect(loggerSpy).toHaveBeenCalledWith(expect.stringContaining('complex object structure'));
|
||||
});
|
||||
|
||||
it('should handle Google-specific claims', () => {
|
||||
const mockClaims: JwtClaims = {
|
||||
sub: 'google-user-id',
|
||||
email: 'user@company.com',
|
||||
name: 'Google User',
|
||||
hd: 'company.com', // Google hosted domain
|
||||
};
|
||||
|
||||
(decodeJwt as any).mockReturnValue(mockClaims);
|
||||
|
||||
const result = service.parseIdToken('google.jwt.token');
|
||||
|
||||
expect(result).toEqual(mockClaims);
|
||||
expect(result?.hd).toBe('company.com');
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateClaims', () => {
|
||||
it('should return user sub when claims are valid', () => {
|
||||
const claims: JwtClaims = {
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
};
|
||||
|
||||
const result = service.validateClaims(claims);
|
||||
expect(result).toBe('user123');
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException when claims are null', () => {
|
||||
expect(() => service.validateClaims(null)).toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException when sub is missing', () => {
|
||||
const claims: JwtClaims = {
|
||||
email: 'user@example.com',
|
||||
name: 'User',
|
||||
};
|
||||
|
||||
expect(() => service.validateClaims(claims)).toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException when sub is empty', () => {
|
||||
const claims: JwtClaims = {
|
||||
sub: '',
|
||||
email: 'user@example.com',
|
||||
};
|
||||
|
||||
expect(() => service.validateClaims(claims)).toThrow(UnauthorizedException);
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractUserInfo', () => {
|
||||
it('should extract basic user information', () => {
|
||||
const claims: JwtClaims = {
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
name: 'Test User',
|
||||
};
|
||||
|
||||
const result = service.extractUserInfo(claims);
|
||||
|
||||
expect(result).toEqual({
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
name: 'Test User',
|
||||
domain: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
it('should extract Google hosted domain', () => {
|
||||
const claims: JwtClaims = {
|
||||
sub: 'google-user',
|
||||
email: 'user@company.com',
|
||||
name: 'Google User',
|
||||
hd: 'company.com',
|
||||
};
|
||||
|
||||
const result = service.extractUserInfo(claims);
|
||||
|
||||
expect(result).toEqual({
|
||||
sub: 'google-user',
|
||||
email: 'user@company.com',
|
||||
name: 'Google User',
|
||||
domain: 'company.com',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle missing optional fields', () => {
|
||||
const claims: JwtClaims = {
|
||||
sub: 'user123',
|
||||
};
|
||||
|
||||
const result = service.extractUserInfo(claims);
|
||||
|
||||
expect(result).toEqual({
|
||||
sub: 'user123',
|
||||
email: undefined,
|
||||
name: undefined,
|
||||
domain: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
it('should ignore extra claims', () => {
|
||||
const claims: JwtClaims = {
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
name: 'Test User',
|
||||
extra: 'claim',
|
||||
another: 'field',
|
||||
groups: ['admin'],
|
||||
};
|
||||
|
||||
const result = service.extractUserInfo(claims);
|
||||
|
||||
expect(result).toEqual({
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
name: 'Test User',
|
||||
domain: undefined,
|
||||
});
|
||||
expect(result).not.toHaveProperty('extra');
|
||||
expect(result).not.toHaveProperty('groups');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,80 @@
|
||||
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
|
||||
|
||||
import { decodeJwt } from 'jose';
|
||||
|
||||
export interface JwtClaims {
|
||||
sub?: string;
|
||||
email?: string;
|
||||
name?: string;
|
||||
hd?: string; // Google hosted domain
|
||||
[claim: string]: unknown;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class OidcClaimsService {
|
||||
private readonly logger = new Logger(OidcClaimsService.name);
|
||||
|
||||
parseIdToken(idToken: string | undefined): JwtClaims | null {
|
||||
if (!idToken) {
|
||||
this.logger.error('No ID token received from provider');
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
// Use jose to properly decode the JWT
|
||||
const claims = decodeJwt(idToken) as JwtClaims;
|
||||
|
||||
// Log claims safely without PII - only structure, not values
|
||||
if (claims) {
|
||||
const claimKeys = Object.keys(claims).join(', ');
|
||||
this.logger.debug(`ID token decoded successfully. Available claims: [${claimKeys}]`);
|
||||
|
||||
// Log claim types without exposing sensitive values
|
||||
for (const [key, value] of Object.entries(claims)) {
|
||||
const valueType = Array.isArray(value) ? `array[${value.length}]` : typeof value;
|
||||
|
||||
// Only log structure, not actual values (avoid PII)
|
||||
this.logger.debug(`Claim '${key}': type=${valueType}`);
|
||||
|
||||
// Check for unexpected claim types
|
||||
if (valueType === 'object' && value !== null && !Array.isArray(value)) {
|
||||
this.logger.warn(`Claim '${key}' contains complex object structure`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return claims;
|
||||
} catch (e) {
|
||||
this.logger.warn(`Failed to parse ID token: ${e}`);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
validateClaims(claims: JwtClaims | null): string {
|
||||
if (!claims?.sub) {
|
||||
this.logger.error(
|
||||
'No subject in token - claims available: ' +
|
||||
(claims ? Object.keys(claims).join(', ') : 'none')
|
||||
);
|
||||
throw new UnauthorizedException('No subject in token');
|
||||
}
|
||||
|
||||
const userSub = claims.sub;
|
||||
this.logger.debug(`Processing authentication for user: ${userSub}`);
|
||||
return userSub;
|
||||
}
|
||||
|
||||
extractUserInfo(claims: JwtClaims): {
|
||||
sub: string;
|
||||
email?: string;
|
||||
name?: string;
|
||||
domain?: string;
|
||||
} {
|
||||
return {
|
||||
sub: claims.sub!,
|
||||
email: claims.email,
|
||||
name: claims.name,
|
||||
domain: claims.hd,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,224 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
|
||||
import * as client from 'openid-client';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
|
||||
vi.mock('openid-client', () => ({
|
||||
authorizationCodeGrant: vi.fn(),
|
||||
allowInsecureRequests: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('OidcTokenExchangeService', () => {
|
||||
let service: OidcTokenExchangeService;
|
||||
let mockConfig: client.Configuration;
|
||||
let mockProvider: OidcProvider;
|
||||
|
||||
beforeEach(() => {
|
||||
service = new OidcTokenExchangeService();
|
||||
|
||||
mockConfig = {
|
||||
serverMetadata: vi.fn().mockReturnValue({
|
||||
issuer: 'https://example.com',
|
||||
token_endpoint: 'https://example.com/token',
|
||||
response_types_supported: ['code'],
|
||||
grant_types_supported: ['authorization_code'],
|
||||
token_endpoint_auth_methods_supported: ['client_secret_post'],
|
||||
}),
|
||||
} as unknown as client.Configuration;
|
||||
|
||||
mockProvider = {
|
||||
id: 'test-provider',
|
||||
issuer: 'https://example.com',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
} as OidcProvider;
|
||||
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('exchangeCodeForTokens', () => {
|
||||
it('should handle malformed fullCallbackUrl gracefully', async () => {
|
||||
const code = 'test-code';
|
||||
const state = 'test-state';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
const malformedUrl = 'not://a valid url';
|
||||
|
||||
const mockTokens = {
|
||||
access_token: 'test-access-token',
|
||||
id_token: 'test-id-token',
|
||||
};
|
||||
|
||||
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
|
||||
|
||||
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
|
||||
const loggerDebugSpy = vi.spyOn(Logger.prototype, 'debug').mockImplementation(() => {});
|
||||
|
||||
const result = await service.exchangeCodeForTokens(
|
||||
mockConfig,
|
||||
mockProvider,
|
||||
code,
|
||||
state,
|
||||
redirectUri,
|
||||
malformedUrl
|
||||
);
|
||||
|
||||
expect(result).toEqual(mockTokens);
|
||||
expect(loggerWarnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to parse fullCallbackUrl'),
|
||||
expect.any(Error)
|
||||
);
|
||||
expect(client.authorizationCodeGrant).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle empty fullCallbackUrl without throwing', async () => {
|
||||
const code = 'test-code';
|
||||
const state = 'test-state';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
const mockTokens = {
|
||||
access_token: 'test-access-token',
|
||||
id_token: 'test-id-token',
|
||||
};
|
||||
|
||||
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
|
||||
|
||||
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
|
||||
|
||||
const result = await service.exchangeCodeForTokens(
|
||||
mockConfig,
|
||||
mockProvider,
|
||||
code,
|
||||
state,
|
||||
redirectUri,
|
||||
''
|
||||
);
|
||||
|
||||
expect(result).toEqual(mockTokens);
|
||||
expect(loggerWarnSpy).not.toHaveBeenCalled();
|
||||
expect(client.authorizationCodeGrant).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle whitespace-only fullCallbackUrl without throwing', async () => {
|
||||
const code = 'test-code';
|
||||
const state = 'test-state';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
const mockTokens = {
|
||||
access_token: 'test-access-token',
|
||||
id_token: 'test-id-token',
|
||||
};
|
||||
|
||||
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
|
||||
|
||||
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
|
||||
|
||||
const result = await service.exchangeCodeForTokens(
|
||||
mockConfig,
|
||||
mockProvider,
|
||||
code,
|
||||
state,
|
||||
redirectUri,
|
||||
' '
|
||||
);
|
||||
|
||||
expect(result).toEqual(mockTokens);
|
||||
expect(loggerWarnSpy).not.toHaveBeenCalled();
|
||||
expect(client.authorizationCodeGrant).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should copy parameters from valid fullCallbackUrl', async () => {
|
||||
const code = 'test-code';
|
||||
const state = 'test-state';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
const fullCallbackUrl =
|
||||
'https://example.com/callback?code=test-code&state=test-state&scope=openid&authuser=0';
|
||||
|
||||
const mockTokens = {
|
||||
access_token: 'test-access-token',
|
||||
id_token: 'test-id-token',
|
||||
};
|
||||
|
||||
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
|
||||
|
||||
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
|
||||
const loggerDebugSpy = vi.spyOn(Logger.prototype, 'debug').mockImplementation(() => {});
|
||||
|
||||
const result = await service.exchangeCodeForTokens(
|
||||
mockConfig,
|
||||
mockProvider,
|
||||
code,
|
||||
state,
|
||||
redirectUri,
|
||||
fullCallbackUrl
|
||||
);
|
||||
|
||||
expect(result).toEqual(mockTokens);
|
||||
expect(loggerWarnSpy).not.toHaveBeenCalled();
|
||||
|
||||
const authCodeGrantCall = vi.mocked(client.authorizationCodeGrant).mock.calls[0];
|
||||
const cleanUrl = authCodeGrantCall[1] as URL;
|
||||
|
||||
expect(cleanUrl.searchParams.get('scope')).toBe('openid');
|
||||
expect(cleanUrl.searchParams.get('authuser')).toBe('0');
|
||||
});
|
||||
|
||||
it('should handle undefined fullCallbackUrl', async () => {
|
||||
const code = 'test-code';
|
||||
const state = 'test-state';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
const mockTokens = {
|
||||
access_token: 'test-access-token',
|
||||
id_token: 'test-id-token',
|
||||
};
|
||||
|
||||
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
|
||||
|
||||
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
|
||||
|
||||
const result = await service.exchangeCodeForTokens(
|
||||
mockConfig,
|
||||
mockProvider,
|
||||
code,
|
||||
state,
|
||||
redirectUri,
|
||||
undefined
|
||||
);
|
||||
|
||||
expect(result).toEqual(mockTokens);
|
||||
expect(loggerWarnSpy).not.toHaveBeenCalled();
|
||||
expect(client.authorizationCodeGrant).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle non-string fullCallbackUrl types gracefully', async () => {
|
||||
const code = 'test-code';
|
||||
const state = 'test-state';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
const mockTokens = {
|
||||
access_token: 'test-access-token',
|
||||
id_token: 'test-id-token',
|
||||
};
|
||||
|
||||
vi.mocked(client.authorizationCodeGrant).mockResolvedValue(mockTokens as any);
|
||||
|
||||
const loggerWarnSpy = vi.spyOn(Logger.prototype, 'warn').mockImplementation(() => {});
|
||||
|
||||
const result = await service.exchangeCodeForTokens(
|
||||
mockConfig,
|
||||
mockProvider,
|
||||
code,
|
||||
state,
|
||||
redirectUri,
|
||||
123 as any
|
||||
);
|
||||
|
||||
expect(result).toEqual(mockTokens);
|
||||
expect(loggerWarnSpy).not.toHaveBeenCalled();
|
||||
expect(client.authorizationCodeGrant).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,174 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import * as client from 'openid-client';
|
||||
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
import { ErrorExtractor } from '@app/unraid-api/utils/error-extractor.util.js';
|
||||
|
||||
// Extended type for our internal use - openid-client v6 doesn't directly expose
|
||||
// skip options for aud/iss checks, so we'll handle validation errors differently
|
||||
type ExtendedGrantChecks = client.AuthorizationCodeGrantChecks;
|
||||
|
||||
@Injectable()
|
||||
export class OidcTokenExchangeService {
|
||||
private readonly logger = new Logger(OidcTokenExchangeService.name);
|
||||
|
||||
async exchangeCodeForTokens(
|
||||
config: client.Configuration,
|
||||
provider: OidcProvider,
|
||||
code: string,
|
||||
state: string,
|
||||
redirectUri: string,
|
||||
fullCallbackUrl?: string
|
||||
): Promise<client.TokenEndpointResponse> {
|
||||
this.logger.debug(`Provider ${provider.id} config loaded`);
|
||||
this.logger.debug(`Redirect URI: ${redirectUri}`);
|
||||
|
||||
// Build current URL for token exchange
|
||||
// CRITICAL: The URL used here MUST match the redirect_uri that was sent to the authorization endpoint
|
||||
// Google expects the exact same redirect_uri during token exchange
|
||||
const currentUrl = new URL(redirectUri);
|
||||
currentUrl.searchParams.set('code', code);
|
||||
currentUrl.searchParams.set('state', state);
|
||||
|
||||
// Copy additional parameters from the actual callback if provided
|
||||
if (fullCallbackUrl && typeof fullCallbackUrl === 'string' && fullCallbackUrl.trim()) {
|
||||
try {
|
||||
const actualUrl = new URL(fullCallbackUrl);
|
||||
// Copy over additional params that Google might have added (scope, authuser, prompt, etc)
|
||||
// but DO NOT change the base URL or path
|
||||
['scope', 'authuser', 'prompt', 'hd', 'session_state', 'iss'].forEach((param) => {
|
||||
const value = actualUrl.searchParams.get(param);
|
||||
if (value && !currentUrl.searchParams.has(param)) {
|
||||
currentUrl.searchParams.set(param, value);
|
||||
}
|
||||
});
|
||||
} catch (urlError) {
|
||||
this.logger.warn(`Failed to parse fullCallbackUrl: ${fullCallbackUrl}`, urlError);
|
||||
// Continue with the existing currentUrl flow without additional params
|
||||
}
|
||||
}
|
||||
|
||||
// Google returns iss in the response, openid-client v6 expects it
|
||||
// If not present, add it based on the provider's issuer
|
||||
if (!currentUrl.searchParams.has('iss') && provider.issuer) {
|
||||
currentUrl.searchParams.set('iss', provider.issuer);
|
||||
}
|
||||
|
||||
this.logger.debug(`Token exchange URL (matches redirect_uri): ${currentUrl.href}`);
|
||||
|
||||
// For openid-client v6, we need to prepare the authorization response
|
||||
const authorizationResponse = new URLSearchParams(currentUrl.search);
|
||||
|
||||
// Set the original client state for openid-client
|
||||
authorizationResponse.set('state', state);
|
||||
|
||||
// Create a new URL with the cleaned parameters
|
||||
const cleanUrl = new URL(redirectUri);
|
||||
cleanUrl.search = authorizationResponse.toString();
|
||||
|
||||
this.logger.debug(`Clean URL for token exchange: ${cleanUrl.href}`);
|
||||
|
||||
try {
|
||||
this.logger.debug(`Starting token exchange with openid-client`);
|
||||
this.logger.debug(`Config issuer: ${config.serverMetadata().issuer}`);
|
||||
this.logger.debug(`Config token endpoint: ${config.serverMetadata().token_endpoint}`);
|
||||
|
||||
// Log the complete token exchange request details
|
||||
const tokenEndpoint = config.serverMetadata().token_endpoint;
|
||||
this.logger.debug(`Full token endpoint URL: ${tokenEndpoint}`);
|
||||
this.logger.debug(`Authorization code: ${code.substring(0, 10)}...`);
|
||||
this.logger.debug(`Redirect URI in token request: ${redirectUri}`);
|
||||
this.logger.debug(`Client ID: ${provider.clientId}`);
|
||||
this.logger.debug(`Client secret configured: ${provider.clientSecret ? 'Yes' : 'No'}`);
|
||||
this.logger.debug(`Expected state value: ${state}`);
|
||||
|
||||
// Log the server metadata to check for any configuration issues
|
||||
const metadata = config.serverMetadata();
|
||||
this.logger.debug(
|
||||
`Server supports response types: ${metadata.response_types_supported?.join(', ') || 'not specified'}`
|
||||
);
|
||||
this.logger.debug(
|
||||
`Server grant types: ${metadata.grant_types_supported?.join(', ') || 'not specified'}`
|
||||
);
|
||||
this.logger.debug(
|
||||
`Token endpoint auth methods: ${metadata.token_endpoint_auth_methods_supported?.join(', ') || 'not specified'}`
|
||||
);
|
||||
|
||||
// For HTTP endpoints, we need to call allowInsecureRequests on the config
|
||||
if (provider.issuer) {
|
||||
try {
|
||||
const serverUrl = new URL(provider.issuer);
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(
|
||||
`Allowing insecure requests for HTTP endpoint: ${provider.id}`
|
||||
);
|
||||
// allowInsecureRequests is deprecated but still needed for HTTP endpoints
|
||||
client.allowInsecureRequests(config);
|
||||
}
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
`Invalid issuer URL for provider ${provider.id}: ${provider.issuer}`
|
||||
);
|
||||
// Continue without special HTTP options
|
||||
}
|
||||
}
|
||||
|
||||
const requestChecks: ExtendedGrantChecks = {
|
||||
expectedState: state,
|
||||
};
|
||||
|
||||
// Log what we're about to send
|
||||
this.logger.debug(`Executing authorizationCodeGrant with:`);
|
||||
this.logger.debug(`- Clean URL: ${cleanUrl.href}`);
|
||||
this.logger.debug(`- Expected state: ${state}`);
|
||||
this.logger.debug(`- Grant type: authorization_code`);
|
||||
|
||||
const tokens = await client.authorizationCodeGrant(config, cleanUrl, requestChecks);
|
||||
|
||||
this.logger.debug(
|
||||
`Token exchange successful, received tokens: ${Object.keys(tokens).join(', ')}`
|
||||
);
|
||||
|
||||
return tokens;
|
||||
} catch (tokenError) {
|
||||
// Extract and log error details using the utility
|
||||
const extracted = ErrorExtractor.extract(tokenError);
|
||||
this.logger.error('Token exchange failed');
|
||||
ErrorExtractor.formatForLogging(extracted, this.logger);
|
||||
|
||||
// Special handling for content-type and parsing errors
|
||||
if (ErrorExtractor.isOAuthResponseError(extracted)) {
|
||||
this.logger.error('Token endpoint returned invalid or non-JSON response.');
|
||||
this.logger.error('This typically means:');
|
||||
this.logger.error(
|
||||
'1. The token endpoint URL is incorrect (check for typos or wrong paths)'
|
||||
);
|
||||
this.logger.error('2. The server returned an HTML error page instead of JSON');
|
||||
this.logger.error('3. Authentication failed (invalid client_id or client_secret)');
|
||||
this.logger.error('4. A proxy/firewall is intercepting the request');
|
||||
this.logger.error('5. The OAuth server returned malformed JSON');
|
||||
this.logger.error(
|
||||
`Configured token endpoint: ${config.serverMetadata().token_endpoint}`
|
||||
);
|
||||
this.logger.error('Please verify your OIDC provider configuration.');
|
||||
}
|
||||
|
||||
// Check if error message contains the "unexpected JWT claim" text
|
||||
if (ErrorExtractor.isJwtClaimError(extracted)) {
|
||||
this.logger.error(
|
||||
`unexpected JWT claim value encountered during token validation by openid-client`
|
||||
);
|
||||
this.logger.error(
|
||||
`This error typically means the 'iss' claim in the JWT doesn't match the expected issuer`
|
||||
);
|
||||
this.logger.error(`Check that your provider's issuer URL is configured correctly`);
|
||||
this.logger.error(`Expected issuer: ${config.serverMetadata().issuer}`);
|
||||
this.logger.error(`Provider configured issuer: ${provider.issuer}`);
|
||||
}
|
||||
|
||||
// Re-throw the original error with all its properties intact
|
||||
throw tokenError;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
|
||||
describe('OidcClientConfigService', () => {
|
||||
let service: OidcClientConfigService;
|
||||
let validationService: any;
|
||||
|
||||
beforeEach(async () => {
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
OidcClientConfigService,
|
||||
{
|
||||
provide: OidcValidationService,
|
||||
useValue: {
|
||||
performDiscovery: vi.fn(),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
service = module.get<OidcClientConfigService>(OidcClientConfigService);
|
||||
validationService = module.get(OidcValidationService);
|
||||
});
|
||||
|
||||
describe('Manual Configuration', () => {
|
||||
it('should create manual configuration when discovery fails but manual endpoints are provided', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-provider',
|
||||
name: 'Manual Provider',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
issuer: 'https://manual.example.com',
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
jwksUri: 'https://manual.example.com/jwks',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
const config = await service.getOrCreateConfig(provider);
|
||||
|
||||
// Verify the configuration was created with the correct endpoints
|
||||
expect(config).toBeDefined();
|
||||
expect(config.serverMetadata().authorization_endpoint).toBe(
|
||||
'https://manual.example.com/auth'
|
||||
);
|
||||
expect(config.serverMetadata().token_endpoint).toBe('https://manual.example.com/token');
|
||||
expect(config.serverMetadata().jwks_uri).toBe('https://manual.example.com/jwks');
|
||||
expect(config.serverMetadata().issuer).toBe('https://manual.example.com');
|
||||
});
|
||||
|
||||
it('should create manual configuration with fallback issuer when not provided', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-provider-no-issuer',
|
||||
name: 'Manual Provider No Issuer',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
issuer: '', // Empty issuer should skip discovery and use manual endpoints
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const config = await service.getOrCreateConfig(provider);
|
||||
|
||||
// Verify the configuration was created with inferred issuer from endpoints
|
||||
expect(config).toBeDefined();
|
||||
expect(config.serverMetadata().issuer).toBe('https://manual.example.com');
|
||||
expect(config.serverMetadata().authorization_endpoint).toBe(
|
||||
'https://manual.example.com/auth'
|
||||
);
|
||||
expect(config.serverMetadata().token_endpoint).toBe('https://manual.example.com/token');
|
||||
});
|
||||
|
||||
it('should handle manual configuration with client secret properly', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-with-secret',
|
||||
name: 'Manual With Secret',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'secret-123',
|
||||
issuer: 'https://manual.example.com',
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
const config = await service.getOrCreateConfig(provider);
|
||||
|
||||
// Verify configuration was created successfully
|
||||
expect(config).toBeDefined();
|
||||
expect(config.clientMetadata().client_secret).toBe('secret-123');
|
||||
});
|
||||
|
||||
it('should handle manual configuration without client secret (public client)', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-public-client',
|
||||
name: 'Manual Public Client',
|
||||
clientId: 'public-client-id',
|
||||
// No client secret
|
||||
issuer: 'https://manual.example.com',
|
||||
authorizationEndpoint: 'https://manual.example.com/auth',
|
||||
tokenEndpoint: 'https://manual.example.com/token',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
// Mock discovery to fail
|
||||
validationService.performDiscovery.mockRejectedValue(new Error('Discovery failed'));
|
||||
|
||||
const config = await service.getOrCreateConfig(provider);
|
||||
|
||||
// Verify configuration was created successfully
|
||||
expect(config).toBeDefined();
|
||||
expect(config.clientMetadata().client_secret).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should cache configurations', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'cached-provider',
|
||||
name: 'Cached Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: '',
|
||||
authorizationEndpoint: 'https://cached.example.com/auth',
|
||||
tokenEndpoint: 'https://cached.example.com/token',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
// First call
|
||||
const config1 = await service.getOrCreateConfig(provider);
|
||||
|
||||
// Second call - should return cached value
|
||||
const config2 = await service.getOrCreateConfig(provider);
|
||||
|
||||
// Should be the exact same object
|
||||
expect(config1).toBe(config2);
|
||||
expect(service.getCacheSize()).toBe(1);
|
||||
});
|
||||
|
||||
it('should clear cache for specific provider', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'provider-to-clear',
|
||||
name: 'Provider to Clear',
|
||||
clientId: 'test-client-id',
|
||||
issuer: '',
|
||||
authorizationEndpoint: 'https://clear.example.com/auth',
|
||||
tokenEndpoint: 'https://clear.example.com/token',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
await service.getOrCreateConfig(provider);
|
||||
expect(service.getCacheSize()).toBe(1);
|
||||
|
||||
service.clearCache('provider-to-clear');
|
||||
expect(service.getCacheSize()).toBe(0);
|
||||
});
|
||||
|
||||
it('should clear entire cache', async () => {
|
||||
const provider1: OidcProvider = {
|
||||
id: 'provider1',
|
||||
name: 'Provider 1',
|
||||
clientId: 'client1',
|
||||
issuer: '',
|
||||
authorizationEndpoint: 'https://p1.example.com/auth',
|
||||
tokenEndpoint: 'https://p1.example.com/token',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const provider2: OidcProvider = {
|
||||
id: 'provider2',
|
||||
name: 'Provider 2',
|
||||
clientId: 'client2',
|
||||
issuer: '',
|
||||
authorizationEndpoint: 'https://p2.example.com/auth',
|
||||
tokenEndpoint: 'https://p2.example.com/token',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
await service.getOrCreateConfig(provider1);
|
||||
await service.getOrCreateConfig(provider2);
|
||||
expect(service.getCacheSize()).toBe(2);
|
||||
|
||||
service.clearCache();
|
||||
expect(service.getCacheSize()).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Discovery Configuration', () => {
|
||||
it('should use discovery when issuer is provided', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'discovery-provider',
|
||||
name: 'Discovery Provider',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-secret',
|
||||
issuer: 'https://discovery.example.com',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const mockConfig = {
|
||||
serverMetadata: vi.fn().mockReturnValue({
|
||||
issuer: 'https://discovery.example.com',
|
||||
authorization_endpoint: 'https://discovery.example.com/authorize',
|
||||
token_endpoint: 'https://discovery.example.com/token',
|
||||
jwks_uri: 'https://discovery.example.com/.well-known/jwks.json',
|
||||
userinfo_endpoint: 'https://discovery.example.com/userinfo',
|
||||
}),
|
||||
clientMetadata: vi.fn().mockReturnValue({}),
|
||||
};
|
||||
|
||||
validationService.performDiscovery.mockResolvedValue(mockConfig);
|
||||
|
||||
const config = await service.getOrCreateConfig(provider);
|
||||
|
||||
expect(validationService.performDiscovery).toHaveBeenCalledWith(provider, undefined);
|
||||
expect(config).toBe(mockConfig);
|
||||
});
|
||||
|
||||
it('should allow HTTP for discovery when issuer uses HTTP', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'http-discovery-provider',
|
||||
name: 'HTTP Discovery Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'http://discovery.example.com',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const mockConfig = {
|
||||
serverMetadata: vi.fn().mockReturnValue({
|
||||
issuer: 'http://discovery.example.com',
|
||||
authorization_endpoint: 'http://discovery.example.com/authorize',
|
||||
token_endpoint: 'http://discovery.example.com/token',
|
||||
}),
|
||||
clientMetadata: vi.fn().mockReturnValue({}),
|
||||
};
|
||||
|
||||
validationService.performDiscovery.mockResolvedValue(mockConfig);
|
||||
|
||||
const config = await service.getOrCreateConfig(provider);
|
||||
|
||||
expect(validationService.performDiscovery).toHaveBeenCalledWith(
|
||||
provider,
|
||||
expect.objectContaining({
|
||||
execute: expect.any(Array),
|
||||
})
|
||||
);
|
||||
expect(config).toBe(mockConfig);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,168 @@
|
||||
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
|
||||
|
||||
import * as client from 'openid-client';
|
||||
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
import { ErrorExtractor } from '@app/unraid-api/utils/error-extractor.util.js';
|
||||
|
||||
@Injectable()
|
||||
export class OidcClientConfigService {
|
||||
private readonly logger = new Logger(OidcClientConfigService.name);
|
||||
private readonly configCache = new Map<string, client.Configuration>();
|
||||
|
||||
constructor(private readonly validationService: OidcValidationService) {}
|
||||
|
||||
async getOrCreateConfig(provider: OidcProvider): Promise<client.Configuration> {
|
||||
const cacheKey = provider.id;
|
||||
|
||||
if (this.configCache.has(cacheKey)) {
|
||||
return this.configCache.get(cacheKey)!;
|
||||
}
|
||||
|
||||
try {
|
||||
// Use the validation service to perform discovery with HTTP support
|
||||
if (provider.issuer) {
|
||||
this.logger.debug(`Attempting discovery for ${provider.id} at ${provider.issuer}`);
|
||||
|
||||
// Create client options with HTTP support if needed
|
||||
const serverUrl = new URL(provider.issuer);
|
||||
let clientOptions: client.DiscoveryRequestOptions | undefined;
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(`Allowing HTTP for ${provider.id} as specified by user`);
|
||||
clientOptions = {
|
||||
execute: [client.allowInsecureRequests],
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const config = await this.validationService.performDiscovery(
|
||||
provider,
|
||||
clientOptions
|
||||
);
|
||||
this.logger.debug(`Discovery successful for ${provider.id}`);
|
||||
this.logger.debug(
|
||||
`Authorization endpoint: ${config.serverMetadata().authorization_endpoint}`
|
||||
);
|
||||
this.logger.debug(`Token endpoint: ${config.serverMetadata().token_endpoint}`);
|
||||
this.logger.debug(`JWKS URI: ${config.serverMetadata().jwks_uri || 'Not provided'}`);
|
||||
this.logger.debug(
|
||||
`Userinfo endpoint: ${config.serverMetadata().userinfo_endpoint || 'Not provided'}`
|
||||
);
|
||||
this.configCache.set(cacheKey, config);
|
||||
return config;
|
||||
} catch (discoveryError) {
|
||||
const extracted = ErrorExtractor.extract(discoveryError);
|
||||
this.logger.warn(`Discovery failed for ${provider.id}: ${extracted.message}`);
|
||||
|
||||
// Log more details about the discovery error
|
||||
const discoveryUrl = `${provider.issuer}/.well-known/openid-configuration`;
|
||||
this.logger.debug(`Discovery URL attempted: ${discoveryUrl}`);
|
||||
|
||||
// Use error extractor for consistent logging
|
||||
ErrorExtractor.formatForLogging(extracted, this.logger);
|
||||
|
||||
// If discovery fails but we have manual endpoints, use them
|
||||
if (provider.authorizationEndpoint && provider.tokenEndpoint) {
|
||||
this.logger.log(`Using manual endpoints for ${provider.id}`);
|
||||
return this.createManualConfiguration(provider, cacheKey);
|
||||
} else {
|
||||
throw new Error(
|
||||
`OIDC discovery failed and no manual endpoints provided for ${provider.id}`
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Manual configuration when no issuer is provided
|
||||
if (provider.authorizationEndpoint && provider.tokenEndpoint) {
|
||||
this.logger.log(`Using manual endpoints for ${provider.id} (no issuer provided)`);
|
||||
return this.createManualConfiguration(provider, cacheKey);
|
||||
}
|
||||
|
||||
// If we reach here, neither discovery nor manual endpoints are available
|
||||
throw new Error(
|
||||
`No configuration method available for ${provider.id}: requires either valid issuer for discovery or manual endpoints`
|
||||
);
|
||||
} catch (error) {
|
||||
const extracted = ErrorExtractor.extract(error);
|
||||
this.logger.error(
|
||||
`Failed to create OIDC configuration for ${provider.id}: ${extracted.message}`
|
||||
);
|
||||
|
||||
// Log more details in debug mode
|
||||
if (extracted.stack) {
|
||||
this.logger.debug(`Stack trace: ${extracted.stack}`);
|
||||
}
|
||||
|
||||
throw new UnauthorizedException('Provider configuration error');
|
||||
}
|
||||
}
|
||||
|
||||
private createManualConfiguration(provider: OidcProvider, cacheKey: string): client.Configuration {
|
||||
// Create manual configuration with a valid issuer URL
|
||||
const inferredIssuer =
|
||||
provider.issuer && provider.issuer.trim() !== ''
|
||||
? provider.issuer
|
||||
: new URL(provider.authorizationEndpoint ?? provider.tokenEndpoint!).origin;
|
||||
const serverMetadata: client.ServerMetadata = {
|
||||
issuer: inferredIssuer,
|
||||
authorization_endpoint: provider.authorizationEndpoint!,
|
||||
token_endpoint: provider.tokenEndpoint!,
|
||||
jwks_uri: provider.jwksUri,
|
||||
};
|
||||
|
||||
const clientMetadata: Partial<client.ClientMetadata> = {
|
||||
client_secret: provider.clientSecret,
|
||||
};
|
||||
|
||||
// Configure client auth method
|
||||
const clientAuth = provider.clientSecret
|
||||
? client.ClientSecretPost(provider.clientSecret)
|
||||
: client.None();
|
||||
|
||||
try {
|
||||
const config = new client.Configuration(
|
||||
serverMetadata,
|
||||
provider.clientId,
|
||||
clientMetadata,
|
||||
clientAuth
|
||||
);
|
||||
|
||||
// Allow HTTP if any configured endpoint uses http
|
||||
const endpoints = [
|
||||
serverMetadata.authorization_endpoint,
|
||||
serverMetadata.token_endpoint,
|
||||
].filter(Boolean) as string[];
|
||||
const hasHttp = endpoints.some((e) => new URL(e).protocol === 'http:');
|
||||
if (hasHttp) {
|
||||
this.logger.debug(`Allowing HTTP for manual endpoints on ${provider.id}`);
|
||||
// allowInsecureRequests is deprecated but still needed for HTTP endpoints
|
||||
client.allowInsecureRequests(config);
|
||||
}
|
||||
|
||||
this.logger.debug(`Manual configuration created for ${provider.id}`);
|
||||
this.logger.debug(`Authorization endpoint: ${serverMetadata.authorization_endpoint}`);
|
||||
this.logger.debug(`Token endpoint: ${serverMetadata.token_endpoint}`);
|
||||
|
||||
this.configCache.set(cacheKey, config);
|
||||
return config;
|
||||
} catch (manualConfigError) {
|
||||
const extracted = ErrorExtractor.extract(manualConfigError);
|
||||
this.logger.error(`Failed to create manual configuration: ${extracted.message}`);
|
||||
throw new Error(`Manual configuration failed for ${provider.id}`);
|
||||
}
|
||||
}
|
||||
|
||||
clearCache(providerId?: string): void {
|
||||
if (providerId) {
|
||||
this.configCache.delete(providerId);
|
||||
} else {
|
||||
this.configCache.clear();
|
||||
}
|
||||
}
|
||||
|
||||
getCacheSize(): number {
|
||||
return this.configCache.size;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
|
||||
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
|
||||
import { OidcBaseModule } from '@app/unraid-api/graph/resolvers/sso/core/oidc-base.module.js';
|
||||
|
||||
@Module({
|
||||
imports: [OidcBaseModule],
|
||||
providers: [OidcClientConfigService, OidcRedirectUriService],
|
||||
exports: [OidcClientConfigService, OidcRedirectUriService],
|
||||
})
|
||||
export class OidcClientModule {}
|
||||
@@ -0,0 +1,222 @@
|
||||
import { UnauthorizedException } from '@nestjs/common';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { validateRedirectUri } from '@app/unraid-api/utils/redirect-uri-validator.js';
|
||||
|
||||
// Mock the redirect URI validator
|
||||
vi.mock('@app/unraid-api/utils/redirect-uri-validator.js', () => ({
|
||||
validateRedirectUri: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('OidcRedirectUriService', () => {
|
||||
let service: OidcRedirectUriService;
|
||||
let oidcConfig: any;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
OidcRedirectUriService,
|
||||
{
|
||||
provide: OidcConfigPersistence,
|
||||
useValue: {
|
||||
getConfig: vi.fn().mockResolvedValue({
|
||||
providers: [],
|
||||
defaultAllowedOrigins: ['https://allowed.example.com'],
|
||||
}),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
service = module.get<OidcRedirectUriService>(OidcRedirectUriService);
|
||||
oidcConfig = module.get(OidcConfigPersistence);
|
||||
});
|
||||
|
||||
describe('getRedirectUri', () => {
|
||||
it('should return valid redirect URI when validation passes', async () => {
|
||||
const requestOrigin = 'https://example.com';
|
||||
const requestHeaders = {
|
||||
'x-forwarded-proto': 'https',
|
||||
'x-forwarded-host': 'example.com',
|
||||
};
|
||||
|
||||
(validateRedirectUri as any).mockReturnValue({
|
||||
isValid: true,
|
||||
validatedUri: 'https://example.com',
|
||||
});
|
||||
|
||||
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
|
||||
|
||||
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
|
||||
expect(validateRedirectUri).toHaveBeenCalledWith(
|
||||
'https://example.com',
|
||||
'https',
|
||||
'example.com',
|
||||
expect.anything(),
|
||||
['https://allowed.example.com']
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException when validation fails', async () => {
|
||||
const requestOrigin = 'https://evil.com';
|
||||
const requestHeaders = {
|
||||
'x-forwarded-proto': 'https',
|
||||
'x-forwarded-host': 'example.com',
|
||||
};
|
||||
|
||||
(validateRedirectUri as any).mockReturnValue({
|
||||
isValid: false,
|
||||
reason: 'Origin not allowed',
|
||||
});
|
||||
|
||||
await expect(service.getRedirectUri(requestOrigin, requestHeaders)).rejects.toThrow(
|
||||
UnauthorizedException
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle missing allowed origins', async () => {
|
||||
oidcConfig.getConfig.mockResolvedValue({
|
||||
providers: [],
|
||||
defaultAllowedOrigins: undefined,
|
||||
});
|
||||
|
||||
const requestOrigin = 'https://example.com';
|
||||
const requestHeaders = {
|
||||
'x-forwarded-proto': 'https',
|
||||
'x-forwarded-host': 'example.com',
|
||||
};
|
||||
|
||||
(validateRedirectUri as any).mockReturnValue({
|
||||
isValid: true,
|
||||
validatedUri: 'https://example.com',
|
||||
});
|
||||
|
||||
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
|
||||
|
||||
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
|
||||
expect(validateRedirectUri).toHaveBeenCalledWith(
|
||||
'https://example.com',
|
||||
'https',
|
||||
'example.com',
|
||||
expect.anything(),
|
||||
undefined
|
||||
);
|
||||
});
|
||||
|
||||
it('should extract protocol from headers correctly', async () => {
|
||||
const requestOrigin = 'https://example.com';
|
||||
const requestHeaders = {
|
||||
'x-forwarded-proto': ['https', 'http'],
|
||||
host: 'example.com',
|
||||
};
|
||||
|
||||
(validateRedirectUri as any).mockReturnValue({
|
||||
isValid: true,
|
||||
validatedUri: 'https://example.com',
|
||||
});
|
||||
|
||||
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
|
||||
|
||||
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
|
||||
expect(validateRedirectUri).toHaveBeenCalledWith(
|
||||
'https://example.com',
|
||||
'https', // Should use first value from array
|
||||
'example.com',
|
||||
expect.anything(),
|
||||
expect.anything()
|
||||
);
|
||||
});
|
||||
|
||||
it('should use host header as fallback', async () => {
|
||||
const requestOrigin = 'https://example.com';
|
||||
const requestHeaders = {
|
||||
host: 'example.com',
|
||||
};
|
||||
|
||||
(validateRedirectUri as any).mockReturnValue({
|
||||
isValid: true,
|
||||
validatedUri: 'https://example.com',
|
||||
});
|
||||
|
||||
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
|
||||
|
||||
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
|
||||
expect(validateRedirectUri).toHaveBeenCalledWith(
|
||||
'https://example.com',
|
||||
'https', // Inferred from requestOrigin when x-forwarded-proto not present
|
||||
'example.com',
|
||||
expect.anything(),
|
||||
expect.anything()
|
||||
);
|
||||
});
|
||||
|
||||
it('should prefer x-forwarded-host over host header', async () => {
|
||||
const requestOrigin = 'https://example.com';
|
||||
const requestHeaders = {
|
||||
'x-forwarded-host': 'forwarded.example.com',
|
||||
host: 'original.example.com',
|
||||
};
|
||||
|
||||
(validateRedirectUri as any).mockReturnValue({
|
||||
isValid: true,
|
||||
validatedUri: 'https://example.com',
|
||||
});
|
||||
|
||||
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
|
||||
|
||||
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
|
||||
expect(validateRedirectUri).toHaveBeenCalledWith(
|
||||
'https://example.com',
|
||||
'https', // Inferred from requestOrigin when x-forwarded-proto not present
|
||||
'forwarded.example.com', // Should use x-forwarded-host
|
||||
expect.anything(),
|
||||
expect.anything()
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw when URL construction fails', async () => {
|
||||
const requestOrigin = 'https://example.com';
|
||||
const requestHeaders = {};
|
||||
|
||||
(validateRedirectUri as any).mockReturnValue({
|
||||
isValid: true,
|
||||
validatedUri: 'invalid-url', // Invalid URL
|
||||
});
|
||||
|
||||
await expect(service.getRedirectUri(requestOrigin, requestHeaders)).rejects.toThrow(
|
||||
UnauthorizedException
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle array values in headers correctly', async () => {
|
||||
const requestOrigin = 'https://example.com';
|
||||
const requestHeaders = {
|
||||
'x-forwarded-proto': ['https'],
|
||||
'x-forwarded-host': ['forwarded.example.com', 'another.example.com'],
|
||||
host: ['original.example.com'],
|
||||
};
|
||||
|
||||
(validateRedirectUri as any).mockReturnValue({
|
||||
isValid: true,
|
||||
validatedUri: 'https://example.com',
|
||||
});
|
||||
|
||||
const result = await service.getRedirectUri(requestOrigin, requestHeaders);
|
||||
|
||||
expect(result).toBe('https://example.com/graphql/api/auth/oidc/callback');
|
||||
expect(validateRedirectUri).toHaveBeenCalledWith(
|
||||
'https://example.com',
|
||||
'https',
|
||||
'forwarded.example.com', // Should use first value from array
|
||||
expect.anything(),
|
||||
expect.anything()
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,97 @@
|
||||
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
|
||||
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { validateRedirectUri } from '@app/unraid-api/utils/redirect-uri-validator.js';
|
||||
|
||||
@Injectable()
|
||||
export class OidcRedirectUriService {
|
||||
private readonly logger = new Logger(OidcRedirectUriService.name);
|
||||
private readonly CALLBACK_PATH = '/graphql/api/auth/oidc/callback';
|
||||
|
||||
constructor(private readonly oidcConfig: OidcConfigPersistence) {}
|
||||
|
||||
async getRedirectUri(
|
||||
requestOrigin: string,
|
||||
requestHeaders: Record<string, string | string[] | undefined>
|
||||
): Promise<string> {
|
||||
// Extract protocol and host from headers for validation
|
||||
const { protocol, host } = this.getRequestOriginInfo(requestHeaders, requestOrigin);
|
||||
|
||||
// Get the global allowed origins from OIDC config
|
||||
const config = await this.oidcConfig.getConfig();
|
||||
const allowedOrigins = config?.defaultAllowedOrigins;
|
||||
|
||||
// Debug logging to trace the issue
|
||||
this.logger.debug(
|
||||
`OIDC Config loaded: ${JSON.stringify(config ? { hasConfig: true, allowedOrigins } : { hasConfig: false })}`
|
||||
);
|
||||
this.logger.debug(
|
||||
`Validating redirect URI: ${requestOrigin} against host: ${protocol}://${host}`
|
||||
);
|
||||
this.logger.debug(`Allowed origins from config: ${JSON.stringify(allowedOrigins || [])}`);
|
||||
|
||||
// Validate the provided requestOrigin using centralized validator
|
||||
// Pass the global allowed origins if available
|
||||
const validation = validateRedirectUri(
|
||||
requestOrigin,
|
||||
protocol,
|
||||
host,
|
||||
this.logger,
|
||||
allowedOrigins
|
||||
);
|
||||
|
||||
if (!validation.isValid) {
|
||||
this.logger.warn(`Invalid redirect_uri in GraphQL OIDC flow: ${validation.reason}`);
|
||||
throw new UnauthorizedException(
|
||||
`Invalid redirect_uri: ${requestOrigin}. Please add this callback URI to Settings → Management Access → Allowed Redirect URIs`
|
||||
);
|
||||
}
|
||||
|
||||
// Ensure the validated URI has the correct callback path
|
||||
try {
|
||||
const url = new URL(validation.validatedUri);
|
||||
// Only use origin to prevent path manipulation
|
||||
const redirectUri = `${url.origin}${this.CALLBACK_PATH}`;
|
||||
this.logger.debug(`Using validated redirect URI: ${redirectUri}`);
|
||||
return redirectUri;
|
||||
} catch (e) {
|
||||
this.logger.error(
|
||||
`Failed to construct redirect URI from validated URI: ${validation.validatedUri}`
|
||||
);
|
||||
throw new UnauthorizedException('Invalid redirect_uri');
|
||||
}
|
||||
}
|
||||
|
||||
private getRequestOriginInfo(
|
||||
requestHeaders: Record<string, string | string[] | undefined>,
|
||||
requestOrigin?: string
|
||||
): {
|
||||
protocol: string;
|
||||
host: string | undefined;
|
||||
} {
|
||||
// Extract protocol from x-forwarded-proto or infer from requestOrigin, default to http
|
||||
const forwardedProto = requestHeaders['x-forwarded-proto'];
|
||||
const protocol = forwardedProto
|
||||
? Array.isArray(forwardedProto)
|
||||
? forwardedProto[0]
|
||||
: forwardedProto
|
||||
: requestOrigin?.startsWith('https')
|
||||
? 'https'
|
||||
: 'http';
|
||||
|
||||
// Extract host from x-forwarded-host or host header
|
||||
const forwardedHost = requestHeaders['x-forwarded-host'];
|
||||
const hostHeader = requestHeaders['host'];
|
||||
const host = forwardedHost
|
||||
? Array.isArray(forwardedHost)
|
||||
? forwardedHost[0]
|
||||
: forwardedHost
|
||||
: hostHeader
|
||||
? Array.isArray(hostHeader)
|
||||
? hostHeader[0]
|
||||
: hostHeader
|
||||
: undefined;
|
||||
|
||||
return { protocol, host };
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { UserSettingsModule } from '@unraid/shared/services/user-settings.js';
|
||||
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
|
||||
|
||||
@Module({
|
||||
imports: [UserSettingsModule],
|
||||
providers: [OidcConfigPersistence, OidcValidationService],
|
||||
exports: [OidcConfigPersistence, OidcValidationService],
|
||||
})
|
||||
export class OidcBaseModule {}
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { ConfigService } from '@nestjs/config';
|
||||
|
||||
import { RuleEffect } from '@jsonforms/core';
|
||||
@@ -6,12 +6,12 @@ import { mergeSettingSlices } from '@unraid/shared/jsonforms/settings.js';
|
||||
import { ConfigFilePersister } from '@unraid/shared/services/config-file.js';
|
||||
import { UserSettingsService } from '@unraid/shared/services/user-settings.js';
|
||||
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
|
||||
import {
|
||||
AuthorizationOperator,
|
||||
OidcAuthorizationRule,
|
||||
OidcProvider,
|
||||
} from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/oidc-validation.service.js';
|
||||
} from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
import {
|
||||
createAccordionLayout,
|
||||
createLabeledControl,
|
||||
@@ -21,6 +21,7 @@ import { SettingSlice } from '@app/unraid-api/types/json-forms.js';
|
||||
|
||||
export interface OidcConfig {
|
||||
providers: OidcProvider[];
|
||||
defaultAllowedOrigins?: string[];
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
@@ -52,6 +53,7 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
defaultConfig(): OidcConfig {
|
||||
return {
|
||||
providers: [this.getUnraidNetSsoProvider()],
|
||||
defaultAllowedOrigins: [],
|
||||
};
|
||||
}
|
||||
|
||||
@@ -93,6 +95,7 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
|
||||
return {
|
||||
providers: [unraidNetSsoProvider],
|
||||
defaultAllowedOrigins: [],
|
||||
};
|
||||
}
|
||||
|
||||
@@ -119,6 +122,42 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
provider.authorizationRules || currentDefaults.authorizationRules,
|
||||
};
|
||||
}
|
||||
|
||||
// Fix dangerous authorization rules for non-unraid.net providers
|
||||
if (provider.authorizationRules && provider.authorizationRules.length > 0) {
|
||||
// Filter out dangerous rules that would allow all emails
|
||||
const safeRules = provider.authorizationRules.filter((rule) => {
|
||||
// Remove rules that have "email endsWith @" which allows all emails
|
||||
if (
|
||||
rule.claim === 'email' &&
|
||||
rule.operator === AuthorizationOperator.ENDS_WITH &&
|
||||
rule.value &&
|
||||
rule.value.length === 1 &&
|
||||
rule.value[0] === '@'
|
||||
) {
|
||||
this.logger.warn(
|
||||
`Removing dangerous authorization rule from provider "${provider.name}": email endsWith "@" allows all emails`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
// Remove rules with empty or invalid values
|
||||
if (
|
||||
!rule.value ||
|
||||
rule.value.length === 0 ||
|
||||
rule.value.every((v) => !v || !v.trim())
|
||||
) {
|
||||
this.logger.warn(
|
||||
`Removing invalid authorization rule from provider "${provider.name}": empty values`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
// Update provider with safe rules
|
||||
provider.authorizationRules = safeRules;
|
||||
}
|
||||
|
||||
return provider;
|
||||
});
|
||||
|
||||
@@ -155,6 +194,28 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
provider.authorizationRules = rules;
|
||||
}
|
||||
|
||||
// Validate that authorization rules are present and valid for ALL providers
|
||||
if (!provider.authorizationRules || provider.authorizationRules.length === 0) {
|
||||
throw new Error(
|
||||
`Provider "${provider.name}" requires authorization rules. Please configure who can access your server.`
|
||||
);
|
||||
}
|
||||
|
||||
// Validate each rule has valid values
|
||||
for (const rule of provider.authorizationRules) {
|
||||
if (!rule.claim || !rule.claim.trim()) {
|
||||
throw new Error(`Provider "${provider.name}": Authorization rule claim cannot be empty`);
|
||||
}
|
||||
if (!rule.operator) {
|
||||
throw new Error(`Provider "${provider.name}": Authorization rule operator is required`);
|
||||
}
|
||||
if (!rule.value || rule.value.length === 0 || rule.value.every((v) => !v || !v.trim())) {
|
||||
throw new Error(
|
||||
`Provider "${provider.name}": Authorization rule for claim "${rule.claim}" must have at least one non-empty value`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up the provider object - remove UI-only fields
|
||||
const cleanedProvider: OidcProvider = {
|
||||
id: provider.id,
|
||||
@@ -191,46 +252,52 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
allowedDomains?: string[];
|
||||
allowedEmails?: string[];
|
||||
allowedUserIds?: string[];
|
||||
googleWorkspaceDomain?: string;
|
||||
}): OidcAuthorizationRule[] {
|
||||
const rules: OidcAuthorizationRule[] = [];
|
||||
|
||||
// Convert email domains to endsWith rules
|
||||
// Only add if domains are provided AND not empty AND have non-empty values
|
||||
if (simpleAuth?.allowedDomains && simpleAuth.allowedDomains.length > 0) {
|
||||
rules.push({
|
||||
claim: 'email',
|
||||
operator: AuthorizationOperator.ENDS_WITH,
|
||||
value: simpleAuth.allowedDomains.map((domain: string) =>
|
||||
domain.startsWith('@') ? domain : `@${domain}`
|
||||
),
|
||||
});
|
||||
const validDomains = simpleAuth.allowedDomains.filter(
|
||||
(domain: string) => domain && domain.trim()
|
||||
);
|
||||
if (validDomains.length > 0) {
|
||||
rules.push({
|
||||
claim: 'email',
|
||||
operator: AuthorizationOperator.ENDS_WITH,
|
||||
value: validDomains.map((domain: string) =>
|
||||
domain.startsWith('@') ? domain : `@${domain}`
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Convert specific emails to equals rules
|
||||
// Only add if emails are provided AND not empty AND have non-empty values
|
||||
if (simpleAuth?.allowedEmails && simpleAuth.allowedEmails.length > 0) {
|
||||
rules.push({
|
||||
claim: 'email',
|
||||
operator: AuthorizationOperator.EQUALS,
|
||||
value: simpleAuth.allowedEmails,
|
||||
});
|
||||
const validEmails = simpleAuth.allowedEmails.filter(
|
||||
(email: string) => email && email.trim()
|
||||
);
|
||||
if (validEmails.length > 0) {
|
||||
rules.push({
|
||||
claim: 'email',
|
||||
operator: AuthorizationOperator.EQUALS,
|
||||
value: validEmails,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Convert user IDs to sub equals rules
|
||||
// Only add if user IDs are provided AND not empty AND have non-empty values
|
||||
if (simpleAuth?.allowedUserIds && simpleAuth.allowedUserIds.length > 0) {
|
||||
rules.push({
|
||||
claim: 'sub',
|
||||
operator: AuthorizationOperator.EQUALS,
|
||||
value: simpleAuth.allowedUserIds,
|
||||
});
|
||||
}
|
||||
|
||||
// Google Workspace domain (hd claim)
|
||||
if (simpleAuth?.googleWorkspaceDomain) {
|
||||
rules.push({
|
||||
claim: 'hd',
|
||||
operator: AuthorizationOperator.EQUALS,
|
||||
value: [simpleAuth.googleWorkspaceDomain],
|
||||
});
|
||||
const validUserIds = simpleAuth.allowedUserIds.filter((id: string) => id && id.trim());
|
||||
if (validUserIds.length > 0) {
|
||||
rules.push({
|
||||
claim: 'sub',
|
||||
operator: AuthorizationOperator.EQUALS,
|
||||
value: validUserIds,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return rules;
|
||||
@@ -286,7 +353,6 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
allowedDomains?: string[];
|
||||
allowedEmails?: string[];
|
||||
allowedUserIds?: string[];
|
||||
googleWorkspaceDomain?: string;
|
||||
}
|
||||
);
|
||||
// Return provider with generated rules, removing UI-only fields
|
||||
@@ -304,6 +370,38 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
}),
|
||||
};
|
||||
|
||||
// Validate authorization rules for ALL providers including unraid.net
|
||||
for (const provider of processedConfig.providers) {
|
||||
if (!provider.authorizationRules || provider.authorizationRules.length === 0) {
|
||||
throw new Error(
|
||||
`Provider "${provider.name}" requires authorization rules. Please configure who can access your server.`
|
||||
);
|
||||
}
|
||||
|
||||
// Validate each rule has valid values
|
||||
for (const rule of provider.authorizationRules) {
|
||||
if (!rule.claim || !rule.claim.trim()) {
|
||||
throw new Error(
|
||||
`Provider "${provider.name}": Authorization rule claim cannot be empty`
|
||||
);
|
||||
}
|
||||
if (!rule.operator) {
|
||||
throw new Error(
|
||||
`Provider "${provider.name}": Authorization rule operator is required`
|
||||
);
|
||||
}
|
||||
if (
|
||||
!rule.value ||
|
||||
rule.value.length === 0 ||
|
||||
rule.value.every((v) => !v || !v.trim())
|
||||
) {
|
||||
throw new Error(
|
||||
`Provider "${provider.name}": Authorization rule for claim "${rule.claim}" must have at least one non-empty value`
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate OIDC discovery for all providers with issuer URLs
|
||||
const validationErrors: string[] = [];
|
||||
for (const provider of processedConfig.providers) {
|
||||
@@ -419,10 +517,6 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
if (rule.claim === 'sub' && rule.operator === AuthorizationOperator.EQUALS) {
|
||||
return true;
|
||||
}
|
||||
// Google Workspace domain
|
||||
if (rule.claim === 'hd' && rule.operator === AuthorizationOperator.EQUALS) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
@@ -431,13 +525,11 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
allowedDomains: string[];
|
||||
allowedEmails: string[];
|
||||
allowedUserIds: string[];
|
||||
googleWorkspaceDomain?: string;
|
||||
} {
|
||||
const simpleAuth = {
|
||||
allowedDomains: [] as string[],
|
||||
allowedEmails: [] as string[],
|
||||
allowedUserIds: [] as string[],
|
||||
googleWorkspaceDomain: undefined as string | undefined,
|
||||
};
|
||||
|
||||
rules.forEach((rule) => {
|
||||
@@ -449,12 +541,6 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
simpleAuth.allowedEmails = rule.value;
|
||||
} else if (rule.claim === 'sub' && rule.operator === AuthorizationOperator.EQUALS) {
|
||||
simpleAuth.allowedUserIds = rule.value;
|
||||
} else if (
|
||||
rule.claim === 'hd' &&
|
||||
rule.operator === AuthorizationOperator.EQUALS &&
|
||||
rule.value.length > 0
|
||||
) {
|
||||
simpleAuth.googleWorkspaceDomain = rule.value[0];
|
||||
}
|
||||
});
|
||||
|
||||
@@ -462,7 +548,36 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
}
|
||||
|
||||
private buildSlice(): SettingSlice {
|
||||
return mergeSettingSlices([this.oidcProvidersSlice()], { as: 'sso' });
|
||||
const providersSlice = this.oidcProvidersSlice();
|
||||
|
||||
// Add defaultAllowedOrigins to the properties
|
||||
providersSlice.properties.defaultAllowedOrigins = {
|
||||
type: 'array',
|
||||
items: { type: 'string' },
|
||||
title: 'Default Allowed Redirect Origins',
|
||||
default: [],
|
||||
description:
|
||||
'Additional trusted redirect origins to allow redirects from custom ports, reverse proxies, Tailscale, etc.',
|
||||
};
|
||||
|
||||
// Add the control for defaultAllowedOrigins before the providers control using UnraidSettingsLayout
|
||||
if (providersSlice.elements?.[0]?.elements) {
|
||||
providersSlice.elements[0].elements.unshift(
|
||||
createLabeledControl({
|
||||
scope: '#/properties/sso/properties/defaultAllowedOrigins',
|
||||
label: 'Allowed Redirect Origins',
|
||||
description:
|
||||
'Add trusted origins here when accessing Unraid through custom ports, reverse proxies, or Tailscale. Each origin should include the protocol and optionally a port (e.g., https://unraid.local:8443)',
|
||||
controlOptions: {
|
||||
format: 'array',
|
||||
inputType: 'text',
|
||||
placeholder: 'https://unraid.local:8443',
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
return mergeSettingSlices([providersSlice], { as: 'sso' });
|
||||
}
|
||||
|
||||
private oidcProvidersSlice(): SettingSlice {
|
||||
@@ -999,7 +1114,7 @@ export class OidcConfigPersistence extends ConfigFilePersister<OidcConfig> {
|
||||
scope: '#/properties/claim',
|
||||
label: 'JWT Claim:',
|
||||
description:
|
||||
'JWT claim to check (e.g., email, sub, groups, hd for Google hosted domain)',
|
||||
'JWT claim to check (e.g., email, sub, groups)',
|
||||
controlOptions: {
|
||||
inputType: 'text',
|
||||
placeholder: 'email',
|
||||
@@ -0,0 +1,14 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { OidcAuthModule } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-auth.module.js';
|
||||
import { OidcClientModule } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client.module.js';
|
||||
import { OidcBaseModule } from '@app/unraid-api/graph/resolvers/sso/core/oidc-base.module.js';
|
||||
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
|
||||
import { OidcSessionModule } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.module.js';
|
||||
|
||||
@Module({
|
||||
imports: [OidcBaseModule, OidcSessionModule, OidcAuthModule, OidcClientModule],
|
||||
providers: [OidcService],
|
||||
exports: [OidcService, OidcBaseModule, OidcSessionModule, OidcAuthModule, OidcClientModule],
|
||||
})
|
||||
export class OidcCoreModule {}
|
||||
@@ -0,0 +1,160 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { ConfigService } from '@nestjs/config';
|
||||
|
||||
import * as client from 'openid-client';
|
||||
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
import { OidcErrorHelper } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-error.helper.js';
|
||||
|
||||
@Injectable()
|
||||
export class OidcValidationService {
|
||||
private readonly logger = new Logger(OidcValidationService.name);
|
||||
|
||||
constructor(private readonly configService: ConfigService) {}
|
||||
|
||||
/**
|
||||
* Validate OIDC provider configuration by attempting discovery
|
||||
* Returns validation result with helpful error messages for debugging
|
||||
*/
|
||||
async validateProvider(
|
||||
provider: OidcProvider
|
||||
): Promise<{ isValid: boolean; error?: string; details?: unknown }> {
|
||||
try {
|
||||
// Validate issuer URL is present
|
||||
if (!provider.issuer) {
|
||||
return {
|
||||
isValid: false,
|
||||
error: 'No issuer URL provided. Please specify the OIDC provider issuer URL.',
|
||||
details: { type: 'MISSING_ISSUER' },
|
||||
};
|
||||
}
|
||||
|
||||
// Validate issuer URL is valid
|
||||
let serverUrl: URL;
|
||||
try {
|
||||
serverUrl = new URL(provider.issuer);
|
||||
} catch (urlError) {
|
||||
return {
|
||||
isValid: false,
|
||||
error: `Invalid issuer URL format: '${provider.issuer}'. Please provide a valid URL.`,
|
||||
details: {
|
||||
type: 'INVALID_URL',
|
||||
originalError: urlError instanceof Error ? urlError.message : String(urlError),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Configure client options for HTTP if needed
|
||||
let clientOptions: any = undefined;
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.warn(
|
||||
`HTTP issuer URL detected for provider ${provider.id}: ${provider.issuer} - This is insecure`
|
||||
);
|
||||
clientOptions = {
|
||||
execute: [client.allowInsecureRequests],
|
||||
};
|
||||
}
|
||||
|
||||
// Attempt OIDC discovery
|
||||
await this.performDiscovery(provider, clientOptions);
|
||||
return { isValid: true };
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
|
||||
|
||||
// Log the raw error for debugging
|
||||
this.logger.log(`Raw discovery error for ${provider.id}: ${errorMessage}`);
|
||||
|
||||
// Use the helper to parse the error
|
||||
const { userFriendlyError, details } = OidcErrorHelper.parseDiscoveryError(
|
||||
error,
|
||||
provider.issuer
|
||||
);
|
||||
|
||||
this.logger.error(`Validation failed for provider ${provider.id}: ${errorMessage}`);
|
||||
|
||||
// Add debug logging for HTTP status errors
|
||||
if (errorMessage.includes('unexpected HTTP response status code')) {
|
||||
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
|
||||
? provider.issuer.replace('/.well-known/openid-configuration', '')
|
||||
: provider.issuer;
|
||||
this.logger.log(`Attempted to fetch: ${baseUrl}/.well-known/openid-configuration`);
|
||||
this.logger.error(`Full error details: ${errorMessage}`);
|
||||
}
|
||||
|
||||
return {
|
||||
isValid: false,
|
||||
error: userFriendlyError,
|
||||
details,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async performDiscovery(provider: OidcProvider, clientOptions?: any): Promise<client.Configuration> {
|
||||
if (!provider.issuer) {
|
||||
throw new Error('No issuer URL provided');
|
||||
}
|
||||
|
||||
// Configure client auth method
|
||||
const clientAuth = provider.clientSecret
|
||||
? client.ClientSecretPost(provider.clientSecret)
|
||||
: undefined;
|
||||
|
||||
const serverUrl = new URL(provider.issuer);
|
||||
const discoveryUrl = `${provider.issuer}/.well-known/openid-configuration`;
|
||||
|
||||
this.logger.log(`Starting discovery for provider ${provider.id}`);
|
||||
this.logger.log(`Discovery URL: ${discoveryUrl}`);
|
||||
this.logger.log(`Client ID: ${provider.clientId}`);
|
||||
this.logger.log(`Client secret configured: ${provider.clientSecret ? 'Yes' : 'No'}`);
|
||||
|
||||
// Use provided client options or create default options with HTTP support if needed
|
||||
if (!clientOptions && serverUrl.protocol === 'http:') {
|
||||
this.logger.warn(
|
||||
`Allowing HTTP for ${provider.id} - This is insecure and should only be used for testing`
|
||||
);
|
||||
// For openid-client v6, use allowInsecureRequests in the execute array
|
||||
// This is deprecated but needed for local development with HTTP endpoints
|
||||
clientOptions = {
|
||||
execute: [client.allowInsecureRequests],
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const config = await client.discovery(
|
||||
serverUrl,
|
||||
provider.clientId,
|
||||
undefined, // client metadata
|
||||
clientAuth,
|
||||
clientOptions
|
||||
);
|
||||
|
||||
this.logger.log(`Discovery successful for ${provider.id}`);
|
||||
this.logger.log(`Discovery response metadata:`);
|
||||
this.logger.log(` - issuer: ${config.serverMetadata().issuer}`);
|
||||
this.logger.log(
|
||||
` - authorization_endpoint: ${config.serverMetadata().authorization_endpoint}`
|
||||
);
|
||||
this.logger.log(` - token_endpoint: ${config.serverMetadata().token_endpoint}`);
|
||||
this.logger.log(
|
||||
` - userinfo_endpoint: ${config.serverMetadata().userinfo_endpoint || 'not provided'}`
|
||||
);
|
||||
this.logger.log(` - jwks_uri: ${config.serverMetadata().jwks_uri || 'not provided'}`);
|
||||
this.logger.log(
|
||||
` - response_types_supported: ${config.serverMetadata().response_types_supported?.join(', ') || 'not provided'}`
|
||||
);
|
||||
this.logger.log(
|
||||
` - scopes_supported: ${config.serverMetadata().scopes_supported?.join(', ') || 'not provided'}`
|
||||
);
|
||||
|
||||
return config;
|
||||
} catch (discoveryError) {
|
||||
this.logger.error(`Discovery failed for ${provider.id} at ${discoveryUrl}`);
|
||||
|
||||
if (discoveryError instanceof Error) {
|
||||
this.logger.error('Discovery error: %o', discoveryError);
|
||||
}
|
||||
|
||||
throw discoveryError;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,485 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { ConfigModule, ConfigService } from '@nestjs/config';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import * as client from 'openid-client';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
|
||||
import { OidcClaimsService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
|
||||
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
|
||||
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
|
||||
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
|
||||
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
|
||||
|
||||
describe('OidcService Integration Tests - Enhanced Logging', () => {
|
||||
let service: OidcService;
|
||||
let configPersistence: OidcConfigPersistence;
|
||||
let loggerSpy: any;
|
||||
let debugLogs: string[] = [];
|
||||
let errorLogs: string[] = [];
|
||||
let warnLogs: string[] = [];
|
||||
let logLogs: string[] = [];
|
||||
|
||||
beforeEach(async () => {
|
||||
// Clear log arrays
|
||||
debugLogs = [];
|
||||
errorLogs = [];
|
||||
warnLogs = [];
|
||||
logLogs = [];
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
imports: [
|
||||
ConfigModule.forRoot({
|
||||
isGlobal: true,
|
||||
load: [() => ({ BASE_URL: 'http://test.local' })],
|
||||
}),
|
||||
],
|
||||
providers: [
|
||||
OidcService,
|
||||
OidcValidationService,
|
||||
OidcClientConfigService,
|
||||
OidcTokenExchangeService,
|
||||
{
|
||||
provide: OidcAuthorizationService,
|
||||
useValue: {
|
||||
checkAuthorization: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcConfigPersistence,
|
||||
useValue: {
|
||||
getProvider: vi.fn(),
|
||||
saveProvider: vi.fn(),
|
||||
getConfig: vi.fn().mockReturnValue({
|
||||
providers: [],
|
||||
defaultAllowedOrigins: [],
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcSessionService,
|
||||
useValue: {
|
||||
createSession: vi.fn().mockResolvedValue('mock-token'),
|
||||
validateSession: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcStateService,
|
||||
useValue: {
|
||||
generateSecureState: vi.fn().mockResolvedValue('secure-state'),
|
||||
validateSecureState: vi.fn().mockResolvedValue({
|
||||
isValid: true,
|
||||
clientState: 'test-state',
|
||||
redirectUri: 'https://myapp.example.com/graphql/api/auth/oidc/callback',
|
||||
}),
|
||||
extractProviderFromState: vi.fn().mockReturnValue('test-provider'),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcRedirectUriService,
|
||||
useValue: {
|
||||
getRedirectUri: vi
|
||||
.fn()
|
||||
.mockResolvedValue(
|
||||
'https://myapp.example.com/graphql/api/auth/oidc/callback'
|
||||
),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcClaimsService,
|
||||
useValue: {
|
||||
parseIdToken: vi.fn().mockReturnValue({
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
}),
|
||||
validateClaims: vi.fn().mockReturnValue('user123'),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
service = module.get<OidcService>(OidcService);
|
||||
configPersistence = module.get<OidcConfigPersistence>(OidcConfigPersistence);
|
||||
|
||||
// Spy on logger methods to capture logs
|
||||
loggerSpy = {
|
||||
debug: vi
|
||||
.spyOn(Logger.prototype, 'debug')
|
||||
.mockImplementation((message: string, ...args: any[]) => {
|
||||
debugLogs.push(message);
|
||||
}),
|
||||
error: vi
|
||||
.spyOn(Logger.prototype, 'error')
|
||||
.mockImplementation((message: string, ...args: any[]) => {
|
||||
errorLogs.push(message);
|
||||
}),
|
||||
warn: vi
|
||||
.spyOn(Logger.prototype, 'warn')
|
||||
.mockImplementation((message: string, ...args: any[]) => {
|
||||
warnLogs.push(message);
|
||||
}),
|
||||
log: vi
|
||||
.spyOn(Logger.prototype, 'log')
|
||||
.mockImplementation((message: string, ...args: any[]) => {
|
||||
logLogs.push(message);
|
||||
}),
|
||||
verbose: vi.spyOn(Logger.prototype, 'verbose').mockImplementation(() => {}),
|
||||
};
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('Token Exchange Error Logging', () => {
|
||||
it('should log detailed error information when token exchange fails with Google (trailing slash issue)', async () => {
|
||||
// This simulates the issue from #1616 where a trailing slash causes failure
|
||||
const provider: OidcProvider = {
|
||||
id: 'google-test',
|
||||
name: 'Google Test',
|
||||
issuer: 'https://accounts.google.com/', // Trailing slash will cause issue
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid', 'email', 'profile'],
|
||||
authorizationRules: [
|
||||
{
|
||||
claim: 'email',
|
||||
operator: 'ENDS_WITH' as any,
|
||||
value: ['@example.com'],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
|
||||
|
||||
try {
|
||||
await service.handleCallback({
|
||||
providerId: 'google-test',
|
||||
code: 'test-code',
|
||||
state: 'test-state',
|
||||
requestOrigin: 'http://test.local',
|
||||
fullCallbackUrl:
|
||||
'http://test.local/graphql/api/auth/oidc/callback?code=test-code&state=test-state',
|
||||
requestHeaders: { host: 'test.local' },
|
||||
});
|
||||
} catch (error) {
|
||||
// We expect this to fail
|
||||
}
|
||||
|
||||
// Verify that the service attempted to handle the callback
|
||||
// Note: Detailed token exchange logging now happens in OidcTokenExchangeService
|
||||
expect(errorLogs.length).toBeGreaterThan(0);
|
||||
// Changed logging format to use error extractor
|
||||
expect(errorLogs.some((log) => log.includes('Token exchange failed'))).toBe(true);
|
||||
});
|
||||
|
||||
it('should log discovery failure details with invalid issuer URL', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'invalid-issuer',
|
||||
name: 'Invalid Issuer Test',
|
||||
issuer: 'https://invalid-oidc-provider.example.com', // Non-existent domain
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid', 'email'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const validationService = new OidcValidationService(new ConfigService());
|
||||
const result = await validationService.validateProvider(provider);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
// Should now have more specific error message
|
||||
expect(result.error).toBeDefined();
|
||||
// The error should mention the domain cannot be resolved or connection failed
|
||||
expect(result.error).toMatch(
|
||||
/Cannot resolve domain name|Failed to connect to OIDC provider/
|
||||
);
|
||||
expect(result.details).toBeDefined();
|
||||
expect(result.details).toHaveProperty('type');
|
||||
// Should be either DNS_ERROR or FETCH_ERROR depending on the cause
|
||||
expect(['DNS_ERROR', 'FETCH_ERROR']).toContain((result.details as any).type);
|
||||
});
|
||||
|
||||
it('should log detailed HTTP error responses from discovery', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'http-error-test',
|
||||
name: 'HTTP Error Test',
|
||||
issuer: 'https://httpstat.us/500', // Returns 500 error
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
|
||||
|
||||
try {
|
||||
await service.validateProvider(provider);
|
||||
} catch (error) {
|
||||
// Expected to fail
|
||||
}
|
||||
|
||||
// Check that HTTP status details are logged (now in log level)
|
||||
expect(logLogs.some((log) => log.includes('Discovery URL:'))).toBe(true);
|
||||
expect(logLogs.some((log) => log.includes('Client ID:'))).toBe(true);
|
||||
});
|
||||
|
||||
it('should log authorization URL building details', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'auth-url-test',
|
||||
name: 'Auth URL Test',
|
||||
issuer: 'https://accounts.google.com',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid', 'email', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
|
||||
|
||||
try {
|
||||
await service.getAuthorizationUrl({
|
||||
providerId: 'auth-url-test',
|
||||
state: 'test-state',
|
||||
requestOrigin: 'http://test.local',
|
||||
requestHeaders: { host: 'test.local' },
|
||||
});
|
||||
|
||||
// Verify URL building logs
|
||||
expect(logLogs.some((log) => log.includes('Built authorization URL'))).toBe(true);
|
||||
expect(logLogs.some((log) => log.includes('Authorization parameters:'))).toBe(true);
|
||||
} catch (error) {
|
||||
// May fail due to real discovery, but we're interested in the logs
|
||||
}
|
||||
});
|
||||
|
||||
it('should log detailed information for manual endpoint configuration', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'manual-endpoints',
|
||||
name: 'Manual Endpoints Test',
|
||||
issuer: undefined,
|
||||
authorizationEndpoint: 'https://auth.example.com/authorize',
|
||||
tokenEndpoint: 'https://auth.example.com/token',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
|
||||
|
||||
const authUrl = await service.getAuthorizationUrl({
|
||||
providerId: 'manual-endpoints',
|
||||
state: 'test-state',
|
||||
requestOrigin: 'http://test.local',
|
||||
requestHeaders: {
|
||||
'x-forwarded-host': 'test.local',
|
||||
'x-forwarded-proto': 'http',
|
||||
},
|
||||
});
|
||||
|
||||
// Verify manual endpoint logs
|
||||
expect(debugLogs.some((log) => log.includes('Built authorization URL'))).toBe(true);
|
||||
expect(debugLogs.some((log) => log.includes('client_id=test-client-id'))).toBe(true);
|
||||
expect(authUrl).toContain('https://auth.example.com/authorize');
|
||||
});
|
||||
|
||||
it('should log JWT claim validation failures with detailed context', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'jwt-validation-test',
|
||||
name: 'JWT Validation Test',
|
||||
issuer: 'https://accounts.google.com',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid', 'email'],
|
||||
authorizationRules: [
|
||||
{
|
||||
claim: 'email',
|
||||
operator: 'ENDS_WITH' as any,
|
||||
value: ['@restricted.com'],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
|
||||
|
||||
// Mock a scenario where JWT validation fails
|
||||
try {
|
||||
await service.handleCallback({
|
||||
providerId: 'jwt-validation-test',
|
||||
code: 'test-code',
|
||||
state: 'test-state',
|
||||
requestOrigin: 'http://test.local',
|
||||
fullCallbackUrl:
|
||||
'http://test.local/graphql/api/auth/oidc/callback?code=test-code&state=test-state',
|
||||
requestHeaders: { host: 'test.local' },
|
||||
});
|
||||
} catch (error) {
|
||||
// Expected to fail
|
||||
}
|
||||
|
||||
// The JWT error handling is now in OidcTokenExchangeService
|
||||
// We should see some error logged
|
||||
expect(errorLogs.length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Discovery Endpoint Logging', () => {
|
||||
it('should log all discovery metadata when successful', async () => {
|
||||
// Use a real OIDC provider that works
|
||||
const provider: OidcProvider = {
|
||||
id: 'microsoft',
|
||||
name: 'Microsoft',
|
||||
issuer: 'https://login.microsoftonline.com/common/v2.0',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid', 'email', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const validationService = new OidcValidationService(new ConfigService());
|
||||
|
||||
try {
|
||||
await validationService.performDiscovery(provider);
|
||||
} catch (error) {
|
||||
// May fail due to network, but we're checking logs
|
||||
}
|
||||
|
||||
// Verify discovery logging (now in log level)
|
||||
expect(logLogs.some((log) => log.includes('Starting discovery'))).toBe(true);
|
||||
expect(logLogs.some((log) => log.includes('Discovery URL:'))).toBe(true);
|
||||
});
|
||||
|
||||
it('should log discovery failures with malformed JSON response', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'malformed-json',
|
||||
name: 'Malformed JSON Test',
|
||||
issuer: 'https://example.com/malformed',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
// Mock global fetch to return HTML instead of JSON
|
||||
const originalFetch = global.fetch;
|
||||
global.fetch = vi.fn().mockImplementation(() =>
|
||||
Promise.resolve(
|
||||
new Response('<html><body>Not JSON</body></html>', {
|
||||
status: 200,
|
||||
headers: { 'content-type': 'text/html' },
|
||||
})
|
||||
)
|
||||
);
|
||||
|
||||
const validationService = new OidcValidationService(new ConfigService());
|
||||
const result = await validationService.validateProvider(provider);
|
||||
|
||||
// Restore original fetch
|
||||
global.fetch = originalFetch;
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toBeDefined();
|
||||
// The openid-client library will fail when it gets HTML instead of JSON
|
||||
// It returns "unexpected response content-type" error
|
||||
expect(result.error).toMatch(
|
||||
/Invalid OIDC discovery|malformed|doesn't conform|unexpected|content-type/i
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle and log HTTP vs HTTPS protocol differences', async () => {
|
||||
const httpProvider: OidcProvider = {
|
||||
id: 'http-local',
|
||||
name: 'HTTP Local Test',
|
||||
issuer: 'http://localhost:8080', // HTTP endpoint
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
// Create a validation service and spy on its logger
|
||||
const validationService = new OidcValidationService(new ConfigService());
|
||||
|
||||
try {
|
||||
await validationService.validateProvider(httpProvider);
|
||||
} catch (error) {
|
||||
// Expected to fail if localhost:8080 isn't running
|
||||
}
|
||||
|
||||
// The HTTP logging happens in the validation service
|
||||
// We should check that HTTP issuers are detected
|
||||
expect(httpProvider.issuer).toMatch(/^http:/);
|
||||
// Verify that we're testing an HTTP endpoint
|
||||
expect(httpProvider.issuer).toBe('http://localhost:8080');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Request/Response Detail Logging', () => {
|
||||
it('should log complete request parameters for token exchange', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'token-params-test',
|
||||
name: 'Token Params Test',
|
||||
issuer: 'https://accounts.google.com',
|
||||
clientId: 'detailed-client-id',
|
||||
clientSecret: 'detailed-client-secret',
|
||||
scopes: ['openid', 'email', 'profile', 'offline_access'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
vi.mocked(configPersistence.getProvider).mockResolvedValue(provider);
|
||||
|
||||
try {
|
||||
await service.handleCallback({
|
||||
providerId: 'token-params-test',
|
||||
code: 'authorization-code-12345',
|
||||
state: 'state-with-signature',
|
||||
requestOrigin: 'https://myapp.example.com',
|
||||
fullCallbackUrl:
|
||||
'https://myapp.example.com/graphql/api/auth/oidc/callback?code=authorization-code-12345&state=state-with-signature&scope=openid+email+profile',
|
||||
requestHeaders: { host: 'myapp.example.com' },
|
||||
});
|
||||
} catch (error) {
|
||||
// Expected to fail
|
||||
}
|
||||
|
||||
// Verify that we attempted the operation
|
||||
// Detailed parameter logging is now in OidcTokenExchangeService
|
||||
expect(debugLogs.length).toBeGreaterThan(0);
|
||||
expect(debugLogs.some((log) => log.includes('Client ID: detailed-client-id'))).toBe(true);
|
||||
expect(debugLogs.some((log) => log.includes('Client secret configured: Yes'))).toBe(true);
|
||||
});
|
||||
|
||||
it('should capture and log all error properties from openid-client', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'error-properties-test',
|
||||
name: 'Error Properties Test',
|
||||
issuer: 'https://expired-cert.badssl.com/', // SSL cert error
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const validationService = new OidcValidationService(new ConfigService());
|
||||
const result = await validationService.validateProvider(provider);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toBeDefined();
|
||||
// Should detect SSL/certificate issues or connection failure
|
||||
expect(result.error).toMatch(
|
||||
/SSL\/TLS certificate error|Failed to connect to OIDC provider|certificate/
|
||||
);
|
||||
expect(result.details).toBeDefined();
|
||||
expect(result.details).toHaveProperty('type');
|
||||
// Should be either SSL_ERROR or FETCH_ERROR
|
||||
expect(['SSL_ERROR', 'FETCH_ERROR']).toContain((result.details as any).type);
|
||||
});
|
||||
});
|
||||
});
|
||||
381
api/src/unraid-api/graph/resolvers/sso/core/oidc.service.test.ts
Normal file
381
api/src/unraid-api/graph/resolvers/sso/core/oidc.service.test.ts
Normal file
@@ -0,0 +1,381 @@
|
||||
import { CacheModule } from '@nestjs/cache-manager';
|
||||
import { UnauthorizedException } from '@nestjs/common';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
|
||||
import { OidcClaimsService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
|
||||
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
|
||||
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
|
||||
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
|
||||
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
|
||||
|
||||
// Mock openid-client
|
||||
vi.mock('openid-client', () => ({
|
||||
buildAuthorizationUrl: vi.fn((config, params) => {
|
||||
const url = new URL(config.serverMetadata().authorization_endpoint);
|
||||
Object.entries(params).forEach(([key, value]) => {
|
||||
if (value !== undefined) {
|
||||
url.searchParams.set(key, String(value));
|
||||
}
|
||||
});
|
||||
return url;
|
||||
}),
|
||||
allowInsecureRequests: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('OidcService Integration', () => {
|
||||
let service: OidcService;
|
||||
let oidcConfig: any;
|
||||
let sessionService: any;
|
||||
let stateService: OidcStateService;
|
||||
let redirectUriService: any;
|
||||
let clientConfigService: any;
|
||||
let tokenExchangeService: any;
|
||||
let claimsService: any;
|
||||
let authorizationService: any;
|
||||
|
||||
beforeEach(async () => {
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
imports: [CacheModule.register()],
|
||||
providers: [
|
||||
OidcService,
|
||||
{
|
||||
provide: OidcConfigPersistence,
|
||||
useValue: {
|
||||
getProvider: vi.fn(),
|
||||
getConfig: vi.fn().mockResolvedValue({
|
||||
providers: [],
|
||||
defaultAllowedOrigins: ['https://example.com'],
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcSessionService,
|
||||
useValue: {
|
||||
createSession: vi.fn().mockResolvedValue('padded-token-123'),
|
||||
},
|
||||
},
|
||||
OidcStateService,
|
||||
{
|
||||
provide: OidcValidationService,
|
||||
useValue: {
|
||||
validateProvider: vi.fn().mockResolvedValue({ isValid: true }),
|
||||
performDiscovery: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcAuthorizationService,
|
||||
useValue: {
|
||||
checkAuthorization: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcRedirectUriService,
|
||||
useValue: {
|
||||
getRedirectUri: vi.fn().mockResolvedValue('https://example.com/callback'),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcClientConfigService,
|
||||
useValue: {
|
||||
getOrCreateConfig: vi.fn(),
|
||||
clearCache: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcTokenExchangeService,
|
||||
useValue: {
|
||||
exchangeCodeForTokens: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcClaimsService,
|
||||
useValue: {
|
||||
parseIdToken: vi.fn(),
|
||||
validateClaims: vi.fn(),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
service = module.get<OidcService>(OidcService);
|
||||
oidcConfig = module.get(OidcConfigPersistence);
|
||||
sessionService = module.get(OidcSessionService);
|
||||
stateService = module.get<OidcStateService>(OidcStateService);
|
||||
redirectUriService = module.get(OidcRedirectUriService);
|
||||
clientConfigService = module.get(OidcClientConfigService);
|
||||
tokenExchangeService = module.get(OidcTokenExchangeService);
|
||||
claimsService = module.get(OidcClaimsService);
|
||||
authorizationService = module.get(OidcAuthorizationService);
|
||||
});
|
||||
|
||||
describe('getAuthorizationUrl', () => {
|
||||
it('should generate authorization URL with custom endpoints', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-secret',
|
||||
authorizationEndpoint: 'https://custom.example.com/auth',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
|
||||
const params = {
|
||||
providerId: 'custom-provider',
|
||||
state: 'client-state-123',
|
||||
requestOrigin: 'https://example.com',
|
||||
requestHeaders: { host: 'example.com' },
|
||||
};
|
||||
|
||||
const url = await service.getAuthorizationUrl(params);
|
||||
|
||||
expect(redirectUriService.getRedirectUri).toHaveBeenCalledWith('https://example.com', {
|
||||
host: 'example.com',
|
||||
});
|
||||
|
||||
const urlObj = new URL(url);
|
||||
expect(urlObj.origin).toBe('https://custom.example.com');
|
||||
expect(urlObj.pathname).toBe('/auth');
|
||||
expect(urlObj.searchParams.get('client_id')).toBe('test-client-id');
|
||||
expect(urlObj.searchParams.get('redirect_uri')).toBe('https://example.com/callback');
|
||||
expect(urlObj.searchParams.get('scope')).toBe('openid profile');
|
||||
expect(urlObj.searchParams.get('response_type')).toBe('code');
|
||||
expect(urlObj.searchParams.has('state')).toBe(true);
|
||||
});
|
||||
|
||||
it('should use OIDC discovery when no custom authorization endpoint', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'discovery-provider',
|
||||
name: 'Discovery Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://discovery.example.com',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
// Create a mock configuration object
|
||||
const mockConfig = {
|
||||
serverMetadata: vi.fn().mockReturnValue({
|
||||
authorization_endpoint: 'https://discovery.example.com/authorize',
|
||||
}),
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
clientConfigService.getOrCreateConfig.mockResolvedValue(mockConfig);
|
||||
|
||||
const params = {
|
||||
providerId: 'discovery-provider',
|
||||
state: 'client-state-123',
|
||||
requestOrigin: 'https://example.com',
|
||||
requestHeaders: {},
|
||||
};
|
||||
|
||||
const url = await service.getAuthorizationUrl(params);
|
||||
|
||||
expect(clientConfigService.getOrCreateConfig).toHaveBeenCalledWith(provider);
|
||||
expect(url).toContain('https://discovery.example.com/authorize');
|
||||
});
|
||||
|
||||
it('should throw when provider not found', async () => {
|
||||
oidcConfig.getProvider.mockResolvedValue(null);
|
||||
|
||||
const params = {
|
||||
providerId: 'non-existent',
|
||||
state: 'state',
|
||||
requestOrigin: 'https://example.com',
|
||||
requestHeaders: {},
|
||||
};
|
||||
|
||||
await expect(service.getAuthorizationUrl(params)).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
});
|
||||
|
||||
describe('handleCallback', () => {
|
||||
it('should handle successful callback flow', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'test-provider',
|
||||
name: 'Test Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://test.example.com',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const mockConfig = {
|
||||
serverMetadata: vi.fn().mockReturnValue({
|
||||
issuer: 'https://test.example.com',
|
||||
token_endpoint: 'https://test.example.com/token',
|
||||
}),
|
||||
};
|
||||
|
||||
const mockTokens = {
|
||||
id_token: 'id.token.here',
|
||||
access_token: 'access.token.here',
|
||||
};
|
||||
|
||||
const mockClaims = {
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
clientConfigService.getOrCreateConfig.mockResolvedValue(mockConfig);
|
||||
tokenExchangeService.exchangeCodeForTokens.mockResolvedValue(mockTokens);
|
||||
claimsService.parseIdToken.mockReturnValue(mockClaims);
|
||||
claimsService.validateClaims.mockReturnValue('user123');
|
||||
|
||||
// Mock the OidcStateExtractor's static method
|
||||
const OidcStateExtractor = await import(
|
||||
'@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js'
|
||||
);
|
||||
vi.spyOn(OidcStateExtractor.OidcStateExtractor, 'extractAndValidateState').mockResolvedValue(
|
||||
{
|
||||
providerId: 'test-provider',
|
||||
originalState: 'original-state',
|
||||
clientState: 'original-state',
|
||||
redirectUri: 'https://example.com/callback',
|
||||
}
|
||||
);
|
||||
|
||||
const params = {
|
||||
providerId: 'test-provider',
|
||||
code: 'auth-code-123',
|
||||
state: 'secure-state',
|
||||
requestOrigin: 'https://example.com',
|
||||
fullCallbackUrl: 'https://example.com/callback?code=auth-code-123&state=secure-state',
|
||||
requestHeaders: {},
|
||||
};
|
||||
|
||||
const token = await service.handleCallback(params);
|
||||
|
||||
expect(token).toBe('padded-token-123');
|
||||
expect(tokenExchangeService.exchangeCodeForTokens).toHaveBeenCalled();
|
||||
expect(claimsService.parseIdToken).toHaveBeenCalledWith('id.token.here');
|
||||
expect(claimsService.validateClaims).toHaveBeenCalledWith(mockClaims);
|
||||
expect(authorizationService.checkAuthorization).toHaveBeenCalledWith(provider, mockClaims);
|
||||
expect(sessionService.createSession).toHaveBeenCalledWith('test-provider', 'user123');
|
||||
});
|
||||
|
||||
it('should throw when provider not found', async () => {
|
||||
oidcConfig.getProvider.mockResolvedValue(null);
|
||||
|
||||
const params = {
|
||||
providerId: 'non-existent',
|
||||
code: 'code',
|
||||
state: 'state',
|
||||
requestOrigin: 'https://example.com',
|
||||
fullCallbackUrl: 'https://example.com/callback',
|
||||
requestHeaders: {},
|
||||
};
|
||||
|
||||
await expect(service.handleCallback(params)).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should handle authorization rejection', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'test-provider',
|
||||
name: 'Test Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://test.example.com',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const mockConfig = {
|
||||
serverMetadata: vi.fn().mockReturnValue({
|
||||
issuer: 'https://test.example.com',
|
||||
token_endpoint: 'https://test.example.com/token',
|
||||
}),
|
||||
};
|
||||
|
||||
const mockTokens = {
|
||||
id_token: 'id.token.here',
|
||||
};
|
||||
|
||||
const mockClaims = {
|
||||
sub: 'user123',
|
||||
email: 'user@example.com',
|
||||
};
|
||||
|
||||
oidcConfig.getProvider.mockResolvedValue(provider);
|
||||
clientConfigService.getOrCreateConfig.mockResolvedValue(mockConfig);
|
||||
tokenExchangeService.exchangeCodeForTokens.mockResolvedValue(mockTokens);
|
||||
claimsService.parseIdToken.mockReturnValue(mockClaims);
|
||||
claimsService.validateClaims.mockReturnValue('user123');
|
||||
authorizationService.checkAuthorization.mockRejectedValue(
|
||||
new UnauthorizedException('Not authorized')
|
||||
);
|
||||
|
||||
// Mock the OidcStateExtractor's static method
|
||||
const OidcStateExtractor = await import(
|
||||
'@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js'
|
||||
);
|
||||
vi.spyOn(OidcStateExtractor.OidcStateExtractor, 'extractAndValidateState').mockResolvedValue(
|
||||
{
|
||||
providerId: 'test-provider',
|
||||
originalState: 'original-state',
|
||||
clientState: 'original-state',
|
||||
redirectUri: 'https://example.com/callback',
|
||||
}
|
||||
);
|
||||
|
||||
const params = {
|
||||
providerId: 'test-provider',
|
||||
code: 'auth-code-123',
|
||||
state: 'secure-state',
|
||||
requestOrigin: 'https://example.com',
|
||||
fullCallbackUrl: 'https://example.com/callback',
|
||||
requestHeaders: {},
|
||||
};
|
||||
|
||||
await expect(service.handleCallback(params)).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateProvider', () => {
|
||||
it('should clear cache and validate provider', async () => {
|
||||
const provider: OidcProvider = {
|
||||
id: 'test-provider',
|
||||
name: 'Test Provider',
|
||||
clientId: 'test-client-id',
|
||||
issuer: 'https://test.example.com',
|
||||
scopes: ['openid'],
|
||||
authorizationRules: [],
|
||||
};
|
||||
|
||||
const result = await service.validateProvider(provider);
|
||||
|
||||
expect(clientConfigService.clearCache).toHaveBeenCalledWith('test-provider');
|
||||
// The validation service mock already returns { isValid: true }
|
||||
expect(result).toEqual({ isValid: true });
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractProviderFromState', () => {
|
||||
it('should extract provider from state', () => {
|
||||
const state = 'provider-id:original-state';
|
||||
|
||||
const result = service.extractProviderFromState(state);
|
||||
|
||||
expect(result.providerId).toBeDefined();
|
||||
expect(result.originalState).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getStateService', () => {
|
||||
it('should return state service', () => {
|
||||
const result = service.getStateService();
|
||||
expect(result).toBe(stateService);
|
||||
});
|
||||
});
|
||||
});
|
||||
243
api/src/unraid-api/graph/resolvers/sso/core/oidc.service.ts
Normal file
243
api/src/unraid-api/graph/resolvers/sso/core/oidc.service.ts
Normal file
@@ -0,0 +1,243 @@
|
||||
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
|
||||
|
||||
import * as client from 'openid-client';
|
||||
|
||||
import { OidcAuthorizationService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-authorization.service.js';
|
||||
import { OidcClaimsService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-claims.service.js';
|
||||
import { OidcTokenExchangeService } from '@app/unraid-api/graph/resolvers/sso/auth/oidc-token-exchange.service.js';
|
||||
import { OidcClientConfigService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-client-config.service.js';
|
||||
import { OidcRedirectUriService } from '@app/unraid-api/graph/resolvers/sso/client/oidc-redirect-uri.service.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/core/oidc-validation.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
|
||||
import { OidcStateExtractor } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js';
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
|
||||
import { ErrorExtractor } from '@app/unraid-api/utils/error-extractor.util.js';
|
||||
|
||||
export interface GetAuthorizationUrlParams {
|
||||
providerId: string;
|
||||
state: string;
|
||||
requestOrigin: string;
|
||||
requestHeaders: Record<string, string | string[] | undefined>;
|
||||
}
|
||||
|
||||
export interface HandleCallbackParams {
|
||||
providerId: string;
|
||||
code: string;
|
||||
state: string;
|
||||
requestOrigin: string;
|
||||
fullCallbackUrl: string;
|
||||
requestHeaders: Record<string, string | string[] | undefined>;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class OidcService {
|
||||
private readonly logger = new Logger(OidcService.name);
|
||||
|
||||
constructor(
|
||||
private readonly oidcConfig: OidcConfigPersistence,
|
||||
private readonly sessionService: OidcSessionService,
|
||||
private readonly stateService: OidcStateService,
|
||||
private readonly validationService: OidcValidationService,
|
||||
private readonly authorizationService: OidcAuthorizationService,
|
||||
private readonly redirectUriService: OidcRedirectUriService,
|
||||
private readonly clientConfigService: OidcClientConfigService,
|
||||
private readonly tokenExchangeService: OidcTokenExchangeService,
|
||||
private readonly claimsService: OidcClaimsService
|
||||
) {}
|
||||
|
||||
async getAuthorizationUrl(params: GetAuthorizationUrlParams): Promise<string> {
|
||||
const { providerId, state, requestOrigin, requestHeaders } = params;
|
||||
|
||||
const provider = await this.oidcConfig.getProvider(providerId);
|
||||
if (!provider) {
|
||||
throw new UnauthorizedException(`Provider ${providerId} not found`);
|
||||
}
|
||||
|
||||
// Use requestOrigin with validation
|
||||
const redirectUri = await this.redirectUriService.getRedirectUri(requestOrigin, requestHeaders);
|
||||
|
||||
this.logger.debug(`Using redirect URI for authorization: ${redirectUri}`);
|
||||
this.logger.debug(`Request origin was: ${requestOrigin}`);
|
||||
|
||||
// Generate secure state with cryptographic signature, including redirect URI
|
||||
const secureState = await this.stateService.generateSecureState(providerId, state, redirectUri);
|
||||
|
||||
// Build authorization URL
|
||||
if (provider.authorizationEndpoint) {
|
||||
// Use custom authorization endpoint
|
||||
const authUrl = new URL(provider.authorizationEndpoint);
|
||||
|
||||
// Standard OAuth2 parameters
|
||||
authUrl.searchParams.set('client_id', provider.clientId);
|
||||
authUrl.searchParams.set('redirect_uri', redirectUri);
|
||||
authUrl.searchParams.set('scope', provider.scopes.join(' '));
|
||||
authUrl.searchParams.set('state', secureState);
|
||||
authUrl.searchParams.set('response_type', 'code');
|
||||
|
||||
this.logger.debug(`Built authorization URL for provider ${provider.id}`);
|
||||
this.logger.debug(
|
||||
`Authorization parameters: client_id=${provider.clientId}, redirect_uri=${redirectUri}, scope=${provider.scopes.join(' ')}, response_type=code`
|
||||
);
|
||||
|
||||
return authUrl.href;
|
||||
}
|
||||
|
||||
// Use OIDC discovery for providers without custom endpoints
|
||||
const config = await this.clientConfigService.getOrCreateConfig(provider);
|
||||
const parameters: Record<string, string> = {
|
||||
redirect_uri: redirectUri,
|
||||
scope: provider.scopes.join(' '),
|
||||
state: secureState,
|
||||
response_type: 'code',
|
||||
};
|
||||
|
||||
// For HTTP endpoints, we need to call allowInsecureRequests on the config
|
||||
if (provider.issuer) {
|
||||
try {
|
||||
const serverUrl = new URL(provider.issuer);
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(`Allowing insecure requests for HTTP endpoint: ${provider.id}`);
|
||||
// allowInsecureRequests is deprecated but still needed for HTTP endpoints
|
||||
client.allowInsecureRequests(config);
|
||||
}
|
||||
} catch (error) {
|
||||
this.logger.warn(`Invalid issuer URL for provider ${provider.id}: ${provider.issuer}`);
|
||||
// Continue without special HTTP options
|
||||
}
|
||||
}
|
||||
|
||||
const authUrl = client.buildAuthorizationUrl(config, parameters);
|
||||
|
||||
this.logger.log(`Built authorization URL via discovery for provider ${provider.id}`);
|
||||
this.logger.log(`Authorization parameters: ${JSON.stringify(parameters)}`);
|
||||
|
||||
return authUrl.href;
|
||||
}
|
||||
|
||||
extractProviderFromState(state: string): { providerId: string; originalState: string } {
|
||||
return OidcStateExtractor.extractProviderFromState(state, this.stateService);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the state service for external utilities
|
||||
*/
|
||||
getStateService(): OidcStateService {
|
||||
return this.stateService;
|
||||
}
|
||||
|
||||
async handleCallback(params: HandleCallbackParams): Promise<string> {
|
||||
const { providerId, code, state, fullCallbackUrl } = params;
|
||||
|
||||
const provider = await this.oidcConfig.getProvider(providerId);
|
||||
if (!provider) {
|
||||
throw new UnauthorizedException(`Provider ${providerId} not found`);
|
||||
}
|
||||
|
||||
// Extract and validate state, including the stored redirect URI
|
||||
const stateInfo = await OidcStateExtractor.extractAndValidateState(state, this.stateService);
|
||||
if (!stateInfo.redirectUri) {
|
||||
throw new UnauthorizedException('Missing redirect URI in state');
|
||||
}
|
||||
|
||||
// Use the redirect URI that was stored during authorization
|
||||
const redirectUri = stateInfo.redirectUri;
|
||||
this.logger.debug(`Using stored redirect URI from state: ${redirectUri}`);
|
||||
|
||||
try {
|
||||
// Always use openid-client for consistency
|
||||
const config = await this.clientConfigService.getOrCreateConfig(provider);
|
||||
|
||||
// Log configuration details
|
||||
this.logger.debug(`Provider ${providerId} config loaded`);
|
||||
this.logger.debug(`Redirect URI: ${redirectUri}`);
|
||||
|
||||
// Build current URL for token exchange
|
||||
// CRITICAL: The URL used here MUST match the redirect_uri that was sent to the authorization endpoint
|
||||
// Google expects the exact same redirect_uri during token exchange
|
||||
const currentUrl = new URL(redirectUri);
|
||||
currentUrl.searchParams.set('code', code);
|
||||
currentUrl.searchParams.set('state', state);
|
||||
|
||||
// Copy additional parameters from the actual callback if provided
|
||||
if (fullCallbackUrl) {
|
||||
const actualUrl = new URL(fullCallbackUrl);
|
||||
// Copy over additional params that Google might have added (scope, authuser, prompt, etc)
|
||||
// but DO NOT change the base URL or path
|
||||
['scope', 'authuser', 'prompt', 'hd', 'session_state', 'iss'].forEach((param) => {
|
||||
const value = actualUrl.searchParams.get(param);
|
||||
if (value && !currentUrl.searchParams.has(param)) {
|
||||
currentUrl.searchParams.set(param, value);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Google returns iss in the response, openid-client v6 expects it
|
||||
// If not present, add it based on the provider's issuer
|
||||
if (!currentUrl.searchParams.has('iss') && provider.issuer) {
|
||||
currentUrl.searchParams.set('iss', provider.issuer);
|
||||
}
|
||||
|
||||
this.logger.debug(`Token exchange URL (matches redirect_uri): ${currentUrl.href}`);
|
||||
|
||||
// State was already validated in extractAndValidateState above, use that result
|
||||
// The clientState should be present after successful validation, but handle the edge case
|
||||
if (!stateInfo.clientState) {
|
||||
this.logger.warn('Client state missing after successful validation');
|
||||
throw new UnauthorizedException('Invalid state: missing client state');
|
||||
}
|
||||
const originalState = stateInfo.clientState;
|
||||
this.logger.debug(`Exchanging code for tokens with provider ${providerId}`);
|
||||
this.logger.debug(`Client state extracted: ${originalState}`);
|
||||
|
||||
// Use the token exchange service
|
||||
const tokens = await this.tokenExchangeService.exchangeCodeForTokens(
|
||||
config,
|
||||
provider,
|
||||
code,
|
||||
originalState,
|
||||
redirectUri,
|
||||
fullCallbackUrl
|
||||
);
|
||||
|
||||
// Parse ID token to get user info
|
||||
const claims = this.claimsService.parseIdToken(tokens.id_token);
|
||||
const userSub = this.claimsService.validateClaims(claims);
|
||||
|
||||
// Check authorization based on rules
|
||||
// This will throw a helpful error if misconfigured or unauthorized
|
||||
await this.authorizationService.checkAuthorization(provider, claims!);
|
||||
|
||||
// Create session and return padded token
|
||||
const paddedToken = await this.sessionService.createSession(providerId, userSub);
|
||||
|
||||
this.logger.log(`Successfully authenticated user ${userSub} via provider ${providerId}`);
|
||||
|
||||
return paddedToken;
|
||||
} catch (error) {
|
||||
const extracted = ErrorExtractor.extract(error);
|
||||
this.logger.error(`OAuth callback error: ${extracted.message}`);
|
||||
// Re-throw the original error if it's already an UnauthorizedException
|
||||
if (error instanceof UnauthorizedException) {
|
||||
throw error;
|
||||
}
|
||||
// Otherwise throw a generic error
|
||||
throw new UnauthorizedException('Authentication failed');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate OIDC provider configuration by attempting discovery
|
||||
* Returns validation result with helpful error messages for debugging
|
||||
*/
|
||||
async validateProvider(
|
||||
provider: OidcProvider
|
||||
): Promise<{ isValid: boolean; error?: string; details?: unknown }> {
|
||||
// Clear any cached config for this provider to force fresh validation
|
||||
this.clientConfigService.clearCache(provider.id);
|
||||
|
||||
// Delegate to the validation service
|
||||
return this.validationService.validateProvider(provider);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
import { Field, ObjectType } from '@nestjs/graphql';
|
||||
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
|
||||
@ObjectType()
|
||||
export class OidcConfiguration {
|
||||
@Field(() => [OidcProvider], { description: 'List of configured OIDC providers' })
|
||||
providers!: OidcProvider[];
|
||||
|
||||
@Field(() => [String], {
|
||||
nullable: true,
|
||||
description:
|
||||
'Default allowed redirect origins that apply to all OIDC providers (e.g., Tailscale domains)',
|
||||
})
|
||||
defaultAllowedOrigins?: string[];
|
||||
}
|
||||
@@ -80,9 +80,11 @@ export class OidcProvider {
|
||||
@Field(() => String, {
|
||||
description:
|
||||
'OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration',
|
||||
nullable: true,
|
||||
})
|
||||
@IsUrl()
|
||||
issuer!: string;
|
||||
@IsOptional()
|
||||
issuer?: string;
|
||||
|
||||
@Field(() => String, {
|
||||
nullable: true,
|
||||
@@ -1,4 +1,4 @@
|
||||
import type { OidcConfig } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
|
||||
import type { OidcConfig } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
|
||||
declare module '@unraid/shared/services/user-settings.js' {
|
||||
interface UserSettings {
|
||||
@@ -1,701 +0,0 @@
|
||||
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
|
||||
import { ConfigService } from '@nestjs/config';
|
||||
|
||||
import { decodeJwt } from 'jose';
|
||||
import * as client from 'openid-client';
|
||||
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
|
||||
import {
|
||||
AuthorizationOperator,
|
||||
AuthorizationRuleMode,
|
||||
OidcAuthorizationRule,
|
||||
OidcProvider,
|
||||
} from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/oidc-state.service.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/oidc-validation.service.js';
|
||||
|
||||
interface JwtClaims {
|
||||
sub?: string;
|
||||
email?: string;
|
||||
name?: string;
|
||||
hd?: string; // Google hosted domain
|
||||
[claim: string]: unknown;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class OidcAuthService {
|
||||
private readonly logger = new Logger(OidcAuthService.name);
|
||||
private readonly configCache = new Map<string, client.Configuration>();
|
||||
|
||||
constructor(
|
||||
private readonly configService: ConfigService,
|
||||
private readonly oidcConfig: OidcConfigPersistence,
|
||||
private readonly sessionService: OidcSessionService,
|
||||
private readonly stateService: OidcStateService,
|
||||
private readonly validationService: OidcValidationService
|
||||
) {}
|
||||
|
||||
async getAuthorizationUrl(
|
||||
providerId: string,
|
||||
state: string,
|
||||
requestOrigin?: string
|
||||
): Promise<string> {
|
||||
const provider = await this.oidcConfig.getProvider(providerId);
|
||||
if (!provider) {
|
||||
throw new UnauthorizedException(`Provider ${providerId} not found`);
|
||||
}
|
||||
|
||||
const redirectUri = this.getRedirectUri(requestOrigin);
|
||||
|
||||
// Generate secure state with cryptographic signature
|
||||
const secureState = this.stateService.generateSecureState(providerId, state);
|
||||
|
||||
// Build authorization URL
|
||||
if (provider.authorizationEndpoint) {
|
||||
// Use custom authorization endpoint
|
||||
const authUrl = new URL(provider.authorizationEndpoint);
|
||||
|
||||
// Standard OAuth2 parameters
|
||||
authUrl.searchParams.set('client_id', provider.clientId);
|
||||
authUrl.searchParams.set('redirect_uri', redirectUri);
|
||||
authUrl.searchParams.set('scope', provider.scopes.join(' '));
|
||||
authUrl.searchParams.set('state', secureState);
|
||||
authUrl.searchParams.set('response_type', 'code');
|
||||
|
||||
return authUrl.href;
|
||||
}
|
||||
|
||||
// Use OIDC discovery for providers without custom endpoints
|
||||
const config = await this.getOrCreateConfig(provider);
|
||||
const parameters: Record<string, string> = {
|
||||
redirect_uri: redirectUri,
|
||||
scope: provider.scopes.join(' '),
|
||||
state: secureState,
|
||||
response_type: 'code',
|
||||
};
|
||||
|
||||
// For HTTP endpoints, we need to pass the allowInsecureRequests option
|
||||
const serverUrl = new URL(provider.issuer || '');
|
||||
let clientOptions: any = undefined;
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(
|
||||
`Building authorization URL with allowInsecureRequests for ${provider.id}`
|
||||
);
|
||||
clientOptions = {
|
||||
execute: [client.allowInsecureRequests],
|
||||
};
|
||||
}
|
||||
|
||||
const authUrl = client.buildAuthorizationUrl(config, parameters);
|
||||
|
||||
return authUrl.href;
|
||||
}
|
||||
|
||||
extractProviderFromState(state: string): { providerId: string; originalState: string } {
|
||||
// Extract provider from state prefix (no decryption needed)
|
||||
const providerId = this.stateService.extractProviderFromState(state);
|
||||
|
||||
if (providerId) {
|
||||
return {
|
||||
providerId,
|
||||
originalState: state,
|
||||
};
|
||||
}
|
||||
|
||||
// Fallback for unknown formats
|
||||
return {
|
||||
providerId: '',
|
||||
originalState: state,
|
||||
};
|
||||
}
|
||||
|
||||
async handleCallback(
|
||||
providerId: string,
|
||||
code: string,
|
||||
state: string,
|
||||
requestOrigin?: string,
|
||||
fullCallbackUrl?: string
|
||||
): Promise<string> {
|
||||
const provider = await this.oidcConfig.getProvider(providerId);
|
||||
if (!provider) {
|
||||
throw new UnauthorizedException(`Provider ${providerId} not found`);
|
||||
}
|
||||
|
||||
try {
|
||||
const redirectUri = this.getRedirectUri(requestOrigin);
|
||||
|
||||
// Always use openid-client for consistency
|
||||
const config = await this.getOrCreateConfig(provider);
|
||||
|
||||
// Log configuration details
|
||||
this.logger.debug(`Provider ${providerId} config loaded`);
|
||||
this.logger.debug(`Redirect URI: ${redirectUri}`);
|
||||
|
||||
// Build current URL for token exchange
|
||||
// CRITICAL: The URL used here MUST match the redirect_uri that was sent to the authorization endpoint
|
||||
// Google expects the exact same redirect_uri during token exchange
|
||||
const currentUrl = new URL(redirectUri);
|
||||
currentUrl.searchParams.set('code', code);
|
||||
currentUrl.searchParams.set('state', state);
|
||||
|
||||
// Copy additional parameters from the actual callback if provided
|
||||
if (fullCallbackUrl) {
|
||||
const actualUrl = new URL(fullCallbackUrl);
|
||||
// Copy over additional params that Google might have added (scope, authuser, prompt, etc)
|
||||
// but DO NOT change the base URL or path
|
||||
['scope', 'authuser', 'prompt', 'hd', 'session_state', 'iss'].forEach((param) => {
|
||||
const value = actualUrl.searchParams.get(param);
|
||||
if (value && !currentUrl.searchParams.has(param)) {
|
||||
currentUrl.searchParams.set(param, value);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Google returns iss in the response, openid-client v6 expects it
|
||||
// If not present, add it based on the provider's issuer
|
||||
if (!currentUrl.searchParams.has('iss') && provider.issuer) {
|
||||
currentUrl.searchParams.set('iss', provider.issuer);
|
||||
}
|
||||
|
||||
this.logger.debug(`Token exchange URL (matches redirect_uri): ${currentUrl.href}`);
|
||||
|
||||
// Validate secure state
|
||||
const stateValidation = this.stateService.validateSecureState(state, providerId);
|
||||
if (!stateValidation.isValid) {
|
||||
this.logger.error(`State validation failed: ${stateValidation.error}`);
|
||||
throw new UnauthorizedException(stateValidation.error || 'Invalid state parameter');
|
||||
}
|
||||
|
||||
const originalState = stateValidation.clientState!;
|
||||
this.logger.debug(`Exchanging code for tokens with provider ${providerId}`);
|
||||
this.logger.debug(`Client state extracted: ${originalState}`);
|
||||
|
||||
// For openid-client v6, we need to prepare the authorization response
|
||||
const authorizationResponse = new URLSearchParams(currentUrl.search);
|
||||
|
||||
// Set the original client state for openid-client
|
||||
authorizationResponse.set('state', originalState);
|
||||
|
||||
// Create a new URL with the cleaned parameters
|
||||
const cleanUrl = new URL(redirectUri);
|
||||
cleanUrl.search = authorizationResponse.toString();
|
||||
|
||||
this.logger.debug(`Clean URL for token exchange: ${cleanUrl.href}`);
|
||||
|
||||
let tokens;
|
||||
try {
|
||||
this.logger.debug(`Starting token exchange with openid-client`);
|
||||
this.logger.debug(`Config issuer: ${config.serverMetadata().issuer}`);
|
||||
this.logger.debug(`Config token endpoint: ${config.serverMetadata().token_endpoint}`);
|
||||
|
||||
// For HTTP endpoints, we need to pass the allowInsecureRequests option
|
||||
const serverUrl = new URL(provider.issuer || '');
|
||||
let clientOptions: any = undefined;
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(`Token exchange with allowInsecureRequests for ${provider.id}`);
|
||||
clientOptions = {
|
||||
execute: [client.allowInsecureRequests],
|
||||
};
|
||||
}
|
||||
|
||||
tokens = await client.authorizationCodeGrant(
|
||||
config,
|
||||
cleanUrl,
|
||||
{
|
||||
expectedState: originalState,
|
||||
},
|
||||
clientOptions
|
||||
);
|
||||
this.logger.debug(
|
||||
`Token exchange successful, received tokens: ${Object.keys(tokens).join(', ')}`
|
||||
);
|
||||
} catch (tokenError) {
|
||||
const errorMessage =
|
||||
tokenError instanceof Error ? tokenError.message : String(tokenError);
|
||||
this.logger.error(`Token exchange failed: ${errorMessage}`);
|
||||
|
||||
// Check if error message contains the "unexpected JWT claim" text
|
||||
if (errorMessage.includes('unexpected JWT claim value encountered')) {
|
||||
this.logger.error(
|
||||
`unexpected JWT claim value encountered during token validation by openid-client`
|
||||
);
|
||||
this.logger.debug(
|
||||
`Token exchange error details: ${JSON.stringify(tokenError, null, 2)}`
|
||||
);
|
||||
|
||||
// Log the actual vs expected issuer
|
||||
this.logger.error(
|
||||
`This error typically means the 'iss' claim in the JWT doesn't match the expected issuer`
|
||||
);
|
||||
this.logger.error(`Check that your provider's issuer URL is configured correctly`);
|
||||
}
|
||||
|
||||
throw tokenError;
|
||||
}
|
||||
|
||||
// Parse ID token to get user info
|
||||
let claims: JwtClaims | null = null;
|
||||
if (tokens.id_token) {
|
||||
try {
|
||||
// Use jose to properly decode the JWT
|
||||
claims = decodeJwt(tokens.id_token) as JwtClaims;
|
||||
|
||||
// Log claims safely without PII - only structure, not values
|
||||
if (claims) {
|
||||
const claimKeys = Object.keys(claims).join(', ');
|
||||
this.logger.debug(
|
||||
`ID token decoded successfully. Available claims: [${claimKeys}]`
|
||||
);
|
||||
|
||||
// Log claim types without exposing sensitive values
|
||||
for (const [key, value] of Object.entries(claims)) {
|
||||
const valueType = Array.isArray(value)
|
||||
? `array[${value.length}]`
|
||||
: typeof value;
|
||||
|
||||
// Only log structure, not actual values (avoid PII)
|
||||
this.logger.debug(`Claim '${key}': type=${valueType}`);
|
||||
|
||||
// Check for unexpected claim types
|
||||
if (valueType === 'object' && value !== null && !Array.isArray(value)) {
|
||||
this.logger.warn(`Claim '${key}' contains complex object structure`);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
this.logger.warn(`Failed to parse ID token: ${e}`);
|
||||
}
|
||||
} else {
|
||||
this.logger.error('No ID token received from provider');
|
||||
}
|
||||
|
||||
if (!claims?.sub) {
|
||||
this.logger.error(
|
||||
'No subject in token - claims available: ' +
|
||||
(claims ? Object.keys(claims).join(', ') : 'none')
|
||||
);
|
||||
throw new UnauthorizedException('No subject in token');
|
||||
}
|
||||
|
||||
const userSub = claims.sub;
|
||||
this.logger.debug(`Processing authentication for user: ${userSub}`);
|
||||
|
||||
// Check authorization based on rules
|
||||
// This will throw a helpful error if misconfigured or unauthorized
|
||||
await this.checkAuthorization(provider, claims);
|
||||
|
||||
// Create session and return padded token
|
||||
const paddedToken = await this.sessionService.createSession(providerId, userSub);
|
||||
|
||||
this.logger.log(`Successfully authenticated user ${userSub} via provider ${providerId}`);
|
||||
|
||||
return paddedToken;
|
||||
} catch (error) {
|
||||
this.logger.error(
|
||||
`OAuth callback error: ${error instanceof Error ? error.message : 'Unknown error'}`
|
||||
);
|
||||
// Re-throw the original error if it's already an UnauthorizedException
|
||||
if (error instanceof UnauthorizedException) {
|
||||
throw error;
|
||||
}
|
||||
// Otherwise throw a generic error
|
||||
throw new UnauthorizedException('Authentication failed');
|
||||
}
|
||||
}
|
||||
|
||||
private async getOrCreateConfig(provider: OidcProvider): Promise<client.Configuration> {
|
||||
const cacheKey = provider.id;
|
||||
|
||||
if (this.configCache.has(cacheKey)) {
|
||||
return this.configCache.get(cacheKey)!;
|
||||
}
|
||||
|
||||
try {
|
||||
// Use the validation service to perform discovery with HTTP support
|
||||
if (provider.issuer) {
|
||||
this.logger.debug(`Attempting discovery for ${provider.id} at ${provider.issuer}`);
|
||||
|
||||
// Create client options with HTTP support if needed
|
||||
const serverUrl = new URL(provider.issuer);
|
||||
let clientOptions: any = undefined;
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(`Allowing HTTP for ${provider.id} as specified by user`);
|
||||
clientOptions = {
|
||||
execute: [client.allowInsecureRequests],
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const config = await this.validationService.performDiscovery(
|
||||
provider,
|
||||
clientOptions
|
||||
);
|
||||
this.logger.debug(`Discovery successful for ${provider.id}`);
|
||||
this.logger.debug(
|
||||
`Authorization endpoint: ${config.serverMetadata().authorization_endpoint}`
|
||||
);
|
||||
this.logger.debug(`Token endpoint: ${config.serverMetadata().token_endpoint}`);
|
||||
this.configCache.set(cacheKey, config);
|
||||
return config;
|
||||
} catch (discoveryError) {
|
||||
const errorMessage =
|
||||
discoveryError instanceof Error ? discoveryError.message : 'Unknown error';
|
||||
this.logger.warn(`Discovery failed for ${provider.id}: ${errorMessage}`);
|
||||
|
||||
// Log more details about the discovery error
|
||||
this.logger.debug(
|
||||
`Discovery URL attempted: ${provider.issuer}/.well-known/openid-configuration`
|
||||
);
|
||||
this.logger.debug(
|
||||
`Full discovery error: ${JSON.stringify(discoveryError, null, 2)}`
|
||||
);
|
||||
|
||||
// Log stack trace for better debugging
|
||||
if (discoveryError instanceof Error && discoveryError.stack) {
|
||||
this.logger.debug(`Stack trace: ${discoveryError.stack}`);
|
||||
}
|
||||
|
||||
// If discovery fails but we have manual endpoints, use them
|
||||
if (provider.authorizationEndpoint && provider.tokenEndpoint) {
|
||||
this.logger.log(`Using manual endpoints for ${provider.id}`);
|
||||
|
||||
// Create manual configuration
|
||||
const serverMetadata: client.ServerMetadata = {
|
||||
issuer: provider.issuer || `manual-${provider.id}`,
|
||||
authorization_endpoint: provider.authorizationEndpoint,
|
||||
token_endpoint: provider.tokenEndpoint,
|
||||
jwks_uri: provider.jwksUri,
|
||||
};
|
||||
|
||||
const clientMetadata: Partial<client.ClientMetadata> = {
|
||||
client_secret: provider.clientSecret,
|
||||
};
|
||||
|
||||
// Configure client auth method
|
||||
const clientAuth = provider.clientSecret
|
||||
? client.ClientSecretPost(provider.clientSecret)
|
||||
: client.None();
|
||||
|
||||
try {
|
||||
const config = new client.Configuration(
|
||||
serverMetadata,
|
||||
provider.clientId,
|
||||
clientMetadata,
|
||||
clientAuth
|
||||
);
|
||||
|
||||
// Use manual configuration with HTTP support if needed
|
||||
const serverUrl = new URL(provider.tokenEndpoint);
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(
|
||||
`Allowing HTTP for manual endpoints on ${provider.id}`
|
||||
);
|
||||
client.allowInsecureRequests(config);
|
||||
}
|
||||
|
||||
this.logger.debug(`Manual configuration created for ${provider.id}`);
|
||||
this.logger.debug(
|
||||
`Authorization endpoint: ${serverMetadata.authorization_endpoint}`
|
||||
);
|
||||
this.logger.debug(`Token endpoint: ${serverMetadata.token_endpoint}`);
|
||||
|
||||
this.configCache.set(cacheKey, config);
|
||||
return config;
|
||||
} catch (manualConfigError) {
|
||||
this.logger.error(
|
||||
`Failed to create manual configuration: ${manualConfigError instanceof Error ? manualConfigError.message : 'Unknown error'}`
|
||||
);
|
||||
throw new Error(`Manual configuration failed for ${provider.id}`);
|
||||
}
|
||||
} else {
|
||||
throw new Error(
|
||||
`OIDC discovery failed and no manual endpoints provided for ${provider.id}`
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Manual configuration when no issuer is provided
|
||||
if (provider.authorizationEndpoint && provider.tokenEndpoint) {
|
||||
this.logger.log(`Using manual endpoints for ${provider.id} (no issuer provided)`);
|
||||
|
||||
// Create manual configuration
|
||||
const serverMetadata: client.ServerMetadata = {
|
||||
issuer: provider.issuer || `manual-${provider.id}`,
|
||||
authorization_endpoint: provider.authorizationEndpoint,
|
||||
token_endpoint: provider.tokenEndpoint,
|
||||
jwks_uri: provider.jwksUri,
|
||||
};
|
||||
|
||||
const clientMetadata: Partial<client.ClientMetadata> = {
|
||||
client_secret: provider.clientSecret,
|
||||
};
|
||||
|
||||
// Configure client auth method
|
||||
const clientAuth = provider.clientSecret
|
||||
? client.ClientSecretPost(provider.clientSecret)
|
||||
: client.None();
|
||||
|
||||
try {
|
||||
const config = new client.Configuration(
|
||||
serverMetadata,
|
||||
provider.clientId,
|
||||
clientMetadata,
|
||||
clientAuth
|
||||
);
|
||||
|
||||
// Use manual configuration with HTTP support if needed
|
||||
const serverUrl = new URL(provider.tokenEndpoint);
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(`Allowing HTTP for manual endpoints on ${provider.id}`);
|
||||
client.allowInsecureRequests(config);
|
||||
}
|
||||
|
||||
this.logger.debug(`Manual configuration created for ${provider.id}`);
|
||||
this.logger.debug(
|
||||
`Authorization endpoint: ${serverMetadata.authorization_endpoint}`
|
||||
);
|
||||
this.logger.debug(`Token endpoint: ${serverMetadata.token_endpoint}`);
|
||||
|
||||
this.configCache.set(cacheKey, config);
|
||||
return config;
|
||||
} catch (manualConfigError) {
|
||||
this.logger.error(
|
||||
`Failed to create manual configuration: ${manualConfigError instanceof Error ? manualConfigError.message : 'Unknown error'}`
|
||||
);
|
||||
throw new Error(`Manual configuration failed for ${provider.id}`);
|
||||
}
|
||||
}
|
||||
|
||||
// If we reach here, neither discovery nor manual endpoints are available
|
||||
throw new Error(
|
||||
`No configuration method available for ${provider.id}: requires either valid issuer for discovery or manual endpoints`
|
||||
);
|
||||
} catch (error) {
|
||||
this.logger.error(
|
||||
`Failed to create OIDC configuration for ${provider.id}: ${
|
||||
error instanceof Error ? error.message : 'Unknown error'
|
||||
}`
|
||||
);
|
||||
|
||||
// Log more details in debug mode
|
||||
if (error instanceof Error && error.stack) {
|
||||
this.logger.debug(`Stack trace: ${error.stack}`);
|
||||
}
|
||||
|
||||
throw new UnauthorizedException('Provider configuration error');
|
||||
}
|
||||
}
|
||||
|
||||
private async checkAuthorization(provider: OidcProvider, claims: JwtClaims): Promise<void> {
|
||||
this.logger.debug(
|
||||
`Checking authorization for provider ${provider.id} with ${provider.authorizationRules?.length || 0} rules`
|
||||
);
|
||||
this.logger.debug(`Available claims: ${Object.keys(claims).join(', ')}`);
|
||||
this.logger.debug(
|
||||
`Authorization rule mode: ${provider.authorizationRuleMode || AuthorizationRuleMode.OR}`
|
||||
);
|
||||
|
||||
// If no authorization rules are specified, throw a helpful error
|
||||
if (!provider.authorizationRules || provider.authorizationRules.length === 0) {
|
||||
throw new UnauthorizedException(
|
||||
`Login failed: The ${provider.name} provider has no authorization rules configured. ` +
|
||||
`Please configure authorization rules.`
|
||||
);
|
||||
}
|
||||
|
||||
this.logger.debug(
|
||||
`Authorization rules to evaluate: ${JSON.stringify(provider.authorizationRules, null, 2)}`
|
||||
);
|
||||
|
||||
// Evaluate the rules
|
||||
const ruleMode = provider.authorizationRuleMode || AuthorizationRuleMode.OR;
|
||||
const isAuthorized = this.evaluateAuthorizationRules(
|
||||
provider.authorizationRules,
|
||||
claims,
|
||||
ruleMode
|
||||
);
|
||||
|
||||
this.logger.debug(`Authorization result: ${isAuthorized}`);
|
||||
|
||||
if (!isAuthorized) {
|
||||
// Log authorization failure with safe claim representation (no PII)
|
||||
const availableClaimKeys = Object.keys(claims).join(', ');
|
||||
this.logger.warn(
|
||||
`Authorization failed for provider ${provider.name}, user ${claims.sub}, available claim keys: [${availableClaimKeys}]`
|
||||
);
|
||||
throw new UnauthorizedException(
|
||||
`Access denied: Your account does not meet the authorization requirements for ${provider.name}.`
|
||||
);
|
||||
}
|
||||
|
||||
this.logger.debug(`Authorization successful for user ${claims.sub}`);
|
||||
}
|
||||
|
||||
private evaluateAuthorizationRules(
|
||||
rules: OidcAuthorizationRule[],
|
||||
claims: JwtClaims,
|
||||
mode: AuthorizationRuleMode = AuthorizationRuleMode.OR
|
||||
): boolean {
|
||||
// No rules means no authorization
|
||||
if (rules.length === 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (mode === AuthorizationRuleMode.AND) {
|
||||
// All rules must pass (AND logic)
|
||||
return rules.every((rule) => this.evaluateRule(rule, claims));
|
||||
} else {
|
||||
// Any rule can pass (OR logic) - default behavior
|
||||
// Multiple rules act as alternative authorization paths
|
||||
return rules.some((rule) => this.evaluateRule(rule, claims));
|
||||
}
|
||||
}
|
||||
|
||||
private evaluateRule(rule: OidcAuthorizationRule, claims: JwtClaims): boolean {
|
||||
const claimValue = claims[rule.claim];
|
||||
|
||||
this.logger.verbose(
|
||||
`Evaluating rule for claim ${rule.claim}: ${JSON.stringify({
|
||||
claimValue,
|
||||
claimType: typeof claimValue,
|
||||
isArray: Array.isArray(claimValue),
|
||||
ruleOperator: rule.operator,
|
||||
ruleValues: rule.value,
|
||||
})}`
|
||||
);
|
||||
|
||||
if (claimValue === undefined || claimValue === null) {
|
||||
this.logger.verbose(`Claim ${rule.claim} not found in token`);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Handle non-array, non-string objects
|
||||
if (typeof claimValue === 'object' && claimValue !== null && !Array.isArray(claimValue)) {
|
||||
this.logger.warn(
|
||||
`unexpected JWT claim value encountered - claim ${rule.claim} has unsupported object type (keys: [${Object.keys(claimValue as Record<string, unknown>).join(', ')}])`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Handle array claims - evaluate rule against each array element
|
||||
if (Array.isArray(claimValue)) {
|
||||
this.logger.verbose(
|
||||
`Processing array claim ${rule.claim} with ${claimValue.length} elements`
|
||||
);
|
||||
|
||||
// For array claims, check if ANY element in the array matches the rule
|
||||
const arrayResult = claimValue.some((element) => {
|
||||
// Skip non-string elements
|
||||
if (
|
||||
typeof element !== 'string' &&
|
||||
typeof element !== 'number' &&
|
||||
typeof element !== 'boolean'
|
||||
) {
|
||||
this.logger.verbose(`Skipping non-primitive element in array: ${typeof element}`);
|
||||
return false;
|
||||
}
|
||||
|
||||
const elementValue = String(element);
|
||||
return this.evaluateSingleValue(elementValue, rule);
|
||||
});
|
||||
|
||||
this.logger.verbose(`Array evaluation result for claim ${rule.claim}: ${arrayResult}`);
|
||||
return arrayResult;
|
||||
}
|
||||
|
||||
// Handle single value claims (string, number, boolean)
|
||||
const value = String(claimValue);
|
||||
this.logger.verbose(`Processing single value claim ${rule.claim} with value: "${value}"`);
|
||||
|
||||
return this.evaluateSingleValue(value, rule);
|
||||
}
|
||||
|
||||
private evaluateSingleValue(value: string, rule: OidcAuthorizationRule): boolean {
|
||||
let result: boolean;
|
||||
switch (rule.operator) {
|
||||
case AuthorizationOperator.EQUALS:
|
||||
result = rule.value.some((v) => value === v);
|
||||
this.logger.verbose(
|
||||
`EQUALS check: "${value}" matches any of [${rule.value.join(', ')}]: ${result}`
|
||||
);
|
||||
return result;
|
||||
|
||||
case AuthorizationOperator.CONTAINS:
|
||||
result = rule.value.some((v) => value.includes(v));
|
||||
this.logger.verbose(
|
||||
`CONTAINS check: "${value}" contains any of [${rule.value.join(', ')}]: ${result}`
|
||||
);
|
||||
return result;
|
||||
|
||||
case AuthorizationOperator.STARTS_WITH:
|
||||
result = rule.value.some((v) => value.startsWith(v));
|
||||
this.logger.verbose(
|
||||
`STARTS_WITH check: "${value}" starts with any of [${rule.value.join(', ')}]: ${result}`
|
||||
);
|
||||
return result;
|
||||
|
||||
case AuthorizationOperator.ENDS_WITH:
|
||||
result = rule.value.some((v) => value.endsWith(v));
|
||||
this.logger.verbose(
|
||||
`ENDS_WITH check: "${value}" ends with any of [${rule.value.join(', ')}]: ${result}`
|
||||
);
|
||||
return result;
|
||||
|
||||
default:
|
||||
this.logger.error(`Unknown authorization operator: ${rule.operator}`);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate OIDC provider configuration by attempting discovery
|
||||
* Returns validation result with helpful error messages for debugging
|
||||
*/
|
||||
async validateProvider(
|
||||
provider: OidcProvider
|
||||
): Promise<{ isValid: boolean; error?: string; details?: unknown }> {
|
||||
// Clear any cached config for this provider to force fresh validation
|
||||
this.configCache.delete(provider.id);
|
||||
|
||||
// Delegate to the validation service
|
||||
return this.validationService.validateProvider(provider);
|
||||
}
|
||||
|
||||
private getRedirectUri(requestOrigin?: string): string {
|
||||
// If we have the full origin (protocol://host), use it directly
|
||||
if (requestOrigin) {
|
||||
// Parse the origin to extract protocol and host
|
||||
try {
|
||||
const url = new URL(requestOrigin);
|
||||
const { protocol, hostname, port } = url;
|
||||
|
||||
// Reconstruct the URL, removing default ports
|
||||
let cleanOrigin = `${protocol}//${hostname}`;
|
||||
|
||||
// Add port if it's not the default for the protocol
|
||||
if (
|
||||
port &&
|
||||
!(protocol === 'https:' && port === '443') &&
|
||||
!(protocol === 'http:' && port === '80')
|
||||
) {
|
||||
cleanOrigin += `:${port}`;
|
||||
}
|
||||
|
||||
// Special handling for localhost development with Nuxt proxy
|
||||
if (hostname === 'localhost' && port === '3000') {
|
||||
return `${cleanOrigin}/graphql/api/auth/oidc/callback`;
|
||||
}
|
||||
|
||||
return `${cleanOrigin}/graphql/api/auth/oidc/callback`;
|
||||
} catch (e) {
|
||||
this.logger.warn(`Failed to parse request origin: ${requestOrigin}, error: ${e}`);
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to configured BASE_URL or default
|
||||
const baseUrl = this.configService.get('BASE_URL', 'http://tower.local');
|
||||
return `${baseUrl}/graphql/api/auth/oidc/callback`;
|
||||
}
|
||||
}
|
||||
@@ -1,204 +0,0 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/oidc-state.service.js';
|
||||
|
||||
describe('OidcStateService', () => {
|
||||
let service: OidcStateService;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
vi.useFakeTimers();
|
||||
// Create a single instance for all tests in a describe block
|
||||
service = new OidcStateService();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
describe('generateSecureState', () => {
|
||||
it('should generate a state with provider prefix and signed token', () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = service.generateSecureState(providerId, clientState);
|
||||
|
||||
expect(state).toBeTruthy();
|
||||
expect(typeof state).toBe('string');
|
||||
expect(state.startsWith(`${providerId}:`)).toBe(true);
|
||||
|
||||
// Extract signed portion and verify format (nonce.timestamp.signature)
|
||||
const signed = state.substring(providerId.length + 1);
|
||||
expect(signed.split('.').length).toBe(3);
|
||||
});
|
||||
|
||||
it('should generate unique states for each call', () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state1 = service.generateSecureState(providerId, clientState);
|
||||
const state2 = service.generateSecureState(providerId, clientState);
|
||||
|
||||
expect(state1).not.toBe(state2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateSecureState', () => {
|
||||
it('should validate a valid state token', () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = service.generateSecureState(providerId, clientState);
|
||||
const result = service.validateSecureState(state, providerId);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.clientState).toBe(clientState);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should reject state with wrong provider ID', () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = service.generateSecureState(providerId, clientState);
|
||||
const result = service.validateSecureState(state, 'wrong-provider');
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toBe('Provider ID mismatch in state');
|
||||
});
|
||||
|
||||
it('should reject expired state tokens', () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = service.generateSecureState(providerId, clientState);
|
||||
|
||||
// Fast forward time beyond expiration (11 minutes)
|
||||
vi.advanceTimersByTime(11 * 60 * 1000);
|
||||
|
||||
const result = service.validateSecureState(state, providerId);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toBe('State token has expired');
|
||||
});
|
||||
|
||||
it('should reject reused state tokens', () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = service.generateSecureState(providerId, clientState);
|
||||
|
||||
// First validation should succeed
|
||||
const result1 = service.validateSecureState(state, providerId);
|
||||
expect(result1.isValid).toBe(true);
|
||||
|
||||
// Second validation should fail (replay attack prevention)
|
||||
const result2 = service.validateSecureState(state, providerId);
|
||||
expect(result2.isValid).toBe(false);
|
||||
expect(result2.error).toBe('State token not found or already used');
|
||||
});
|
||||
|
||||
it('should reject invalid state tokens', () => {
|
||||
const result = service.validateSecureState('invalid.state.token', 'test-provider');
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toBe('Invalid state format');
|
||||
});
|
||||
|
||||
it('should reject tampered state tokens', () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = service.generateSecureState(providerId, clientState);
|
||||
|
||||
// Tamper with the signature
|
||||
const parts = state.split('.');
|
||||
parts[2] = parts[2].slice(0, -4) + 'XXXX';
|
||||
const tamperedState = parts.join('.');
|
||||
|
||||
const result = service.validateSecureState(tamperedState, providerId);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toBe('Invalid state signature');
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractProviderFromState', () => {
|
||||
it('should extract provider from state prefix', () => {
|
||||
const state = 'provider-id:eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature';
|
||||
const result = service.extractProviderFromState(state);
|
||||
|
||||
expect(result).toBe('provider-id');
|
||||
});
|
||||
|
||||
it('should handle states with multiple colons', () => {
|
||||
const state = 'provider-id:jwt:with:colons';
|
||||
const result = service.extractProviderFromState(state);
|
||||
|
||||
expect(result).toBe('provider-id');
|
||||
});
|
||||
|
||||
it('should return null for invalid format', () => {
|
||||
const result = service.extractProviderFromState('invalid-state');
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractProviderFromLegacyState', () => {
|
||||
it('should extract provider from legacy colon-separated format', () => {
|
||||
const result = service.extractProviderFromLegacyState('provider-id:client-state');
|
||||
|
||||
expect(result.providerId).toBe('provider-id');
|
||||
expect(result.originalState).toBe('client-state');
|
||||
});
|
||||
|
||||
it('should handle multiple colons in legacy format', () => {
|
||||
const result = service.extractProviderFromLegacyState(
|
||||
'provider-id:client:state:with:colons'
|
||||
);
|
||||
|
||||
expect(result.providerId).toBe('provider-id');
|
||||
expect(result.originalState).toBe('client:state:with:colons');
|
||||
});
|
||||
|
||||
it('should return empty provider for JWT format', () => {
|
||||
const jwtState = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature';
|
||||
const result = service.extractProviderFromLegacyState(jwtState);
|
||||
|
||||
expect(result.providerId).toBe('');
|
||||
expect(result.originalState).toBe(jwtState);
|
||||
});
|
||||
|
||||
it('should return empty provider for unknown format', () => {
|
||||
const result = service.extractProviderFromLegacyState('some-random-state');
|
||||
|
||||
expect(result.providerId).toBe('');
|
||||
expect(result.originalState).toBe('some-random-state');
|
||||
});
|
||||
});
|
||||
|
||||
describe('cleanupExpiredStates', () => {
|
||||
it('should clean up expired states periodically', () => {
|
||||
const providerId = 'test-provider';
|
||||
|
||||
// Generate multiple states
|
||||
service.generateSecureState(providerId, 'state1');
|
||||
service.generateSecureState(providerId, 'state2');
|
||||
service.generateSecureState(providerId, 'state3');
|
||||
|
||||
// Fast forward past expiration
|
||||
vi.advanceTimersByTime(11 * 60 * 1000);
|
||||
|
||||
// Generate a new state that shouldn't be cleaned
|
||||
const validState = service.generateSecureState(providerId, 'state4');
|
||||
|
||||
// Trigger cleanup (happens every minute)
|
||||
vi.advanceTimersByTime(60 * 1000);
|
||||
|
||||
// The new state should still be valid
|
||||
const result = service.validateSecureState(validState, providerId);
|
||||
expect(result.isValid).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,164 +0,0 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { ConfigService } from '@nestjs/config';
|
||||
|
||||
import * as client from 'openid-client';
|
||||
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
|
||||
|
||||
@Injectable()
|
||||
export class OidcValidationService {
|
||||
private readonly logger = new Logger(OidcValidationService.name);
|
||||
|
||||
constructor(private readonly configService: ConfigService) {}
|
||||
|
||||
/**
|
||||
* Validate OIDC provider configuration by attempting discovery
|
||||
* Returns validation result with helpful error messages for debugging
|
||||
*/
|
||||
async validateProvider(
|
||||
provider: OidcProvider
|
||||
): Promise<{ isValid: boolean; error?: string; details?: unknown }> {
|
||||
try {
|
||||
// Validate issuer URL is present
|
||||
if (!provider.issuer) {
|
||||
return {
|
||||
isValid: false,
|
||||
error: 'No issuer URL provided. Please specify the OIDC provider issuer URL.',
|
||||
details: { type: 'MISSING_ISSUER' },
|
||||
};
|
||||
}
|
||||
|
||||
// Validate issuer URL is valid
|
||||
let serverUrl: URL;
|
||||
try {
|
||||
serverUrl = new URL(provider.issuer);
|
||||
} catch (urlError) {
|
||||
return {
|
||||
isValid: false,
|
||||
error: `Invalid issuer URL format: '${provider.issuer}'. Please provide a valid URL.`,
|
||||
details: {
|
||||
type: 'INVALID_URL',
|
||||
originalError: urlError instanceof Error ? urlError.message : String(urlError),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Configure client options for HTTP if needed
|
||||
let clientOptions: any = undefined;
|
||||
if (serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(
|
||||
`HTTP issuer URL detected for provider ${provider.id}: ${provider.issuer}`
|
||||
);
|
||||
clientOptions = {
|
||||
execute: [client.allowInsecureRequests],
|
||||
};
|
||||
}
|
||||
|
||||
// Attempt OIDC discovery
|
||||
await this.performDiscovery(provider, clientOptions);
|
||||
return { isValid: true };
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
|
||||
|
||||
// Log the raw error for debugging
|
||||
this.logger.debug(`Raw discovery error for ${provider.id}: ${errorMessage}`);
|
||||
|
||||
// Provide specific error messages for common issues
|
||||
let userFriendlyError = errorMessage;
|
||||
let details: Record<string, unknown> = {};
|
||||
|
||||
if (errorMessage.includes('getaddrinfo ENOTFOUND')) {
|
||||
userFriendlyError = `Cannot resolve domain name. Please check that '${provider.issuer}' is accessible and spelled correctly.`;
|
||||
details = { type: 'DNS_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('ECONNREFUSED')) {
|
||||
userFriendlyError = `Connection refused. The server at '${provider.issuer}' is not accepting connections.`;
|
||||
details = { type: 'CONNECTION_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('ECONNRESET') || errorMessage.includes('ETIMEDOUT')) {
|
||||
userFriendlyError = `Connection timeout. The server at '${provider.issuer}' is not responding.`;
|
||||
details = { type: 'TIMEOUT_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('404') || errorMessage.includes('Not Found')) {
|
||||
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
|
||||
? provider.issuer.replace('/.well-known/openid-configuration', '')
|
||||
: provider.issuer;
|
||||
userFriendlyError = `OIDC discovery endpoint not found. Please verify that '${baseUrl}/.well-known/openid-configuration' exists.`;
|
||||
details = { type: 'DISCOVERY_NOT_FOUND', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('401') || errorMessage.includes('403')) {
|
||||
userFriendlyError = `Access denied to discovery endpoint. Please check the issuer URL and any authentication requirements.`;
|
||||
details = { type: 'AUTHENTICATION_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('unexpected HTTP response status code')) {
|
||||
// Extract status code if possible
|
||||
const statusMatch = errorMessage.match(/status code (\d+)/);
|
||||
const statusCode = statusMatch ? statusMatch[1] : 'unknown';
|
||||
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
|
||||
? provider.issuer.replace('/.well-known/openid-configuration', '')
|
||||
: provider.issuer;
|
||||
userFriendlyError = `HTTP ${statusCode} error from discovery endpoint. Please check that '${baseUrl}/.well-known/openid-configuration' returns a valid OIDC discovery document.`;
|
||||
details = { type: 'HTTP_STATUS_ERROR', statusCode, originalError: errorMessage };
|
||||
} else if (
|
||||
errorMessage.includes('certificate') ||
|
||||
errorMessage.includes('SSL') ||
|
||||
errorMessage.includes('TLS')
|
||||
) {
|
||||
userFriendlyError = `SSL/TLS certificate error. The server certificate may be invalid or expired.`;
|
||||
details = { type: 'SSL_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('JSON') || errorMessage.includes('parse')) {
|
||||
userFriendlyError = `Invalid OIDC discovery response. The server returned malformed JSON.`;
|
||||
details = { type: 'INVALID_JSON', originalError: errorMessage };
|
||||
} else if (error && (error as any).code === 'OAUTH_RESPONSE_IS_NOT_CONFORM') {
|
||||
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
|
||||
? provider.issuer.replace('/.well-known/openid-configuration', '')
|
||||
: provider.issuer;
|
||||
userFriendlyError = `Invalid OIDC discovery document. The server at '${baseUrl}/.well-known/openid-configuration' returned a response that doesn't conform to the OpenID Connect Discovery specification. Please verify the endpoint returns valid OIDC metadata.`;
|
||||
details = { type: 'INVALID_OIDC_DOCUMENT', originalError: errorMessage };
|
||||
}
|
||||
|
||||
this.logger.warn(`OIDC validation failed for provider ${provider.id}: ${errorMessage}`);
|
||||
|
||||
// Add debug logging for HTTP status errors
|
||||
if (errorMessage.includes('unexpected HTTP response status code')) {
|
||||
const baseUrl = provider.issuer?.endsWith('/.well-known/openid-configuration')
|
||||
? provider.issuer.replace('/.well-known/openid-configuration', '')
|
||||
: provider.issuer;
|
||||
this.logger.debug(`Attempted to fetch: ${baseUrl}/.well-known/openid-configuration`);
|
||||
this.logger.debug(`Full error details: ${errorMessage}`);
|
||||
}
|
||||
|
||||
return {
|
||||
isValid: false,
|
||||
error: userFriendlyError,
|
||||
details,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async performDiscovery(provider: OidcProvider, clientOptions?: any): Promise<client.Configuration> {
|
||||
if (!provider.issuer) {
|
||||
throw new Error('No issuer URL provided');
|
||||
}
|
||||
|
||||
// Configure client auth method
|
||||
const clientAuth = provider.clientSecret
|
||||
? client.ClientSecretPost(provider.clientSecret)
|
||||
: undefined;
|
||||
|
||||
const serverUrl = new URL(provider.issuer);
|
||||
|
||||
// Use provided client options or create default options with HTTP support if needed
|
||||
if (!clientOptions && serverUrl.protocol === 'http:') {
|
||||
this.logger.debug(`Allowing HTTP for ${provider.id} as specified by user`);
|
||||
// For openid-client v6, use allowInsecureRequests in the execute array
|
||||
// This is deprecated but needed for local development with HTTP endpoints
|
||||
clientOptions = {
|
||||
execute: [client.allowInsecureRequests],
|
||||
};
|
||||
}
|
||||
|
||||
return client.discovery(
|
||||
serverUrl,
|
||||
provider.clientId,
|
||||
undefined, // client metadata
|
||||
clientAuth,
|
||||
clientOptions
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
|
||||
|
||||
@Module({
|
||||
providers: [OidcSessionService, OidcStateService],
|
||||
exports: [OidcSessionService, OidcStateService],
|
||||
})
|
||||
export class OidcSessionModule {}
|
||||
@@ -4,7 +4,7 @@ import { Test } from '@nestjs/testing';
|
||||
import type { Cache } from 'cache-manager';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
|
||||
|
||||
describe('OidcSessionService', () => {
|
||||
let service: OidcSessionService;
|
||||
@@ -15,7 +15,7 @@ export interface OidcSession {
|
||||
@Injectable()
|
||||
export class OidcSessionService {
|
||||
private readonly logger = new Logger(OidcSessionService.name);
|
||||
private readonly SESSION_TTL_SECONDS = 2 * 60; // 2 minutes for one-time token security
|
||||
private readonly SESSION_TTL_MS = 2 * 60 * 1000; // 2 minutes in milliseconds (cache-manager v7 expects milliseconds)
|
||||
|
||||
constructor(@Inject(CACHE_MANAGER) private readonly cacheManager: Cache) {}
|
||||
|
||||
@@ -28,12 +28,21 @@ export class OidcSessionService {
|
||||
providerId,
|
||||
providerUserId,
|
||||
createdAt: now,
|
||||
expiresAt: new Date(now.getTime() + this.SESSION_TTL_SECONDS * 1000),
|
||||
expiresAt: new Date(now.getTime() + this.SESSION_TTL_MS),
|
||||
};
|
||||
|
||||
// Store in cache with TTL
|
||||
await this.cacheManager.set(sessionId, session, this.SESSION_TTL_SECONDS * 1000);
|
||||
this.logger.log(`Created OIDC session ${sessionId} for provider ${providerId}`);
|
||||
// Store in cache with TTL (in milliseconds for cache-manager v7)
|
||||
await this.cacheManager.set(sessionId, session, this.SESSION_TTL_MS);
|
||||
|
||||
// Verify it was stored
|
||||
const verifyStored = await this.cacheManager.get(sessionId);
|
||||
if (verifyStored) {
|
||||
this.logger.debug(`Session successfully stored and verified with ID: ${sessionId}`);
|
||||
} else {
|
||||
this.logger.error(`CRITICAL: Session was NOT stored in cache for ID: ${sessionId}`);
|
||||
}
|
||||
|
||||
this.logger.log(`Created OIDC session for provider ${providerId}`);
|
||||
|
||||
return this.createPaddedToken(sessionId);
|
||||
}
|
||||
@@ -44,15 +53,16 @@ export class OidcSessionService {
|
||||
return { valid: false };
|
||||
}
|
||||
|
||||
this.logger.debug(`Looking for session with ID: ${sessionId}`);
|
||||
const session = await this.cacheManager.get<OidcSession>(sessionId);
|
||||
if (!session) {
|
||||
this.logger.debug(`Session ${sessionId} not found`);
|
||||
this.logger.debug(`Session not found for ID: ${sessionId}`);
|
||||
return { valid: false };
|
||||
}
|
||||
|
||||
const now = new Date();
|
||||
if (now > new Date(session.expiresAt)) {
|
||||
this.logger.debug(`Session ${sessionId} expired`);
|
||||
this.logger.debug(`Session expired`);
|
||||
await this.cacheManager.del(sessionId);
|
||||
return { valid: false };
|
||||
}
|
||||
@@ -62,7 +72,7 @@ export class OidcSessionService {
|
||||
await this.cacheManager.del(sessionId);
|
||||
|
||||
this.logger.log(
|
||||
`Validated and invalidated session ${sessionId} for provider ${session.providerId} (one-time use)`
|
||||
`Validated and invalidated session for provider ${session.providerId} (one-time use)`
|
||||
);
|
||||
return { valid: true, username: 'root' };
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
import { CacheModule } from '@nestjs/cache-manager';
|
||||
import { UnauthorizedException } from '@nestjs/common';
|
||||
import { Test } from '@nestjs/testing';
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcStateExtractor } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js';
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
|
||||
|
||||
describe('OidcStateExtractor', () => {
|
||||
let stateService: OidcStateService;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
const module = await Test.createTestingModule({
|
||||
imports: [CacheModule.register()],
|
||||
providers: [OidcStateService],
|
||||
}).compile();
|
||||
|
||||
stateService = module.get<OidcStateService>(OidcStateService);
|
||||
});
|
||||
|
||||
describe('extractProviderFromState', () => {
|
||||
it('should extract provider ID from valid state', () => {
|
||||
const state = 'provider123:nonce.timestamp.signature';
|
||||
const result = OidcStateExtractor.extractProviderFromState(state, stateService);
|
||||
|
||||
expect(result.providerId).toBe('provider123');
|
||||
expect(result.originalState).toBe(state);
|
||||
});
|
||||
|
||||
it('should handle state without provider prefix', () => {
|
||||
const state = 'invalid-state-format';
|
||||
const result = OidcStateExtractor.extractProviderFromState(state, stateService);
|
||||
|
||||
expect(result.providerId).toBe('');
|
||||
expect(result.originalState).toBe(state);
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractAndValidateState', () => {
|
||||
it('should extract and validate a valid state with redirectUri', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
// Generate a valid state
|
||||
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// Extract and validate
|
||||
const result = await OidcStateExtractor.extractAndValidateState(state, stateService);
|
||||
|
||||
expect(result.providerId).toBe(providerId);
|
||||
expect(result.originalState).toBe(state);
|
||||
expect(result.clientState).toBe(clientState);
|
||||
expect(result.redirectUri).toBe(redirectUri);
|
||||
});
|
||||
|
||||
it('should extract and validate a valid state without redirectUri', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
// Generate a valid state without redirectUri
|
||||
const state = await stateService.generateSecureState(providerId, clientState);
|
||||
|
||||
// Extract and validate
|
||||
const result = await OidcStateExtractor.extractAndValidateState(state, stateService);
|
||||
|
||||
expect(result.providerId).toBe(providerId);
|
||||
expect(result.originalState).toBe(state);
|
||||
expect(result.clientState).toBe(clientState);
|
||||
expect(result.redirectUri).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException for invalid state format', async () => {
|
||||
const invalidState = 'invalid-format';
|
||||
|
||||
await expect(async () => {
|
||||
await OidcStateExtractor.extractAndValidateState(invalidState, stateService);
|
||||
}).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException for expired state', async () => {
|
||||
vi.useFakeTimers();
|
||||
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
// Generate a valid state
|
||||
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// Fast forward time beyond expiration (11 minutes)
|
||||
vi.advanceTimersByTime(11 * 60 * 1000);
|
||||
|
||||
await expect(async () => {
|
||||
await OidcStateExtractor.extractAndValidateState(state, stateService);
|
||||
}).rejects.toThrow(UnauthorizedException);
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException for wrong provider ID', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const wrongProviderId = 'wrong-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
// Generate a valid state
|
||||
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// Create a fake state with wrong provider prefix
|
||||
const tamperedState = state.replace(providerId, wrongProviderId);
|
||||
|
||||
await expect(async () => {
|
||||
await OidcStateExtractor.extractAndValidateState(tamperedState, stateService);
|
||||
}).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException for tampered state', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
// Generate a valid state
|
||||
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// Tamper with the signature
|
||||
const tamperedState = state.slice(0, -5) + 'xxxxx';
|
||||
|
||||
await expect(async () => {
|
||||
await OidcStateExtractor.extractAndValidateState(tamperedState, stateService);
|
||||
}).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it('should throw UnauthorizedException for reused state (replay attack)', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
// Generate a valid state
|
||||
const state = await stateService.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// First validation should succeed
|
||||
const result1 = await OidcStateExtractor.extractAndValidateState(state, stateService);
|
||||
expect(result1.providerId).toBe(providerId);
|
||||
|
||||
// Second validation should fail (replay attack)
|
||||
await expect(async () => {
|
||||
await OidcStateExtractor.extractAndValidateState(state, stateService);
|
||||
}).rejects.toThrow(UnauthorizedException);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,60 @@
|
||||
import { UnauthorizedException } from '@nestjs/common';
|
||||
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
|
||||
|
||||
export interface StateExtractionResult {
|
||||
providerId: string;
|
||||
originalState: string;
|
||||
clientState?: string;
|
||||
redirectUri?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility to extract and validate OIDC state information consistently
|
||||
* across authorize and callback endpoints
|
||||
*/
|
||||
export class OidcStateExtractor {
|
||||
/**
|
||||
* Extract provider ID from state without validation (for routing purposes)
|
||||
*/
|
||||
static extractProviderFromState(
|
||||
state: string,
|
||||
stateService: OidcStateService
|
||||
): { providerId: string; originalState: string } {
|
||||
// Use the state service's extraction method
|
||||
const providerId = stateService.extractProviderFromState(state);
|
||||
|
||||
return {
|
||||
providerId: providerId || '',
|
||||
originalState: state,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract provider ID and validate the full encrypted state
|
||||
*/
|
||||
static async extractAndValidateState(
|
||||
state: string,
|
||||
stateService: OidcStateService
|
||||
): Promise<StateExtractionResult> {
|
||||
// First extract provider ID for routing
|
||||
const { providerId } = this.extractProviderFromState(state, stateService);
|
||||
|
||||
if (!providerId) {
|
||||
throw new UnauthorizedException('Invalid state format: missing provider ID');
|
||||
}
|
||||
|
||||
// Then validate the full encrypted state
|
||||
const stateValidation = await stateService.validateSecureState(state, providerId);
|
||||
if (!stateValidation.isValid) {
|
||||
throw new UnauthorizedException(`Invalid state: ${stateValidation.error}`);
|
||||
}
|
||||
|
||||
return {
|
||||
providerId,
|
||||
originalState: state,
|
||||
clientState: stateValidation.clientState,
|
||||
redirectUri: stateValidation.redirectUri,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,238 @@
|
||||
import { CacheModule } from '@nestjs/cache-manager';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
|
||||
|
||||
describe('OidcStateService', () => {
|
||||
let service: OidcStateService;
|
||||
let module: TestingModule;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.clearAllMocks();
|
||||
vi.useFakeTimers();
|
||||
// Set a deterministic system time for consistent testing
|
||||
vi.setSystemTime(new Date('2024-01-01T00:00:00Z'));
|
||||
|
||||
module = await Test.createTestingModule({
|
||||
imports: [CacheModule.register()],
|
||||
providers: [OidcStateService],
|
||||
}).compile();
|
||||
|
||||
service = module.get<OidcStateService>(OidcStateService);
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
vi.useRealTimers();
|
||||
// Close the testing module to prevent handle leaks
|
||||
if (module) {
|
||||
await module.close();
|
||||
}
|
||||
});
|
||||
|
||||
describe('generateSecureState', () => {
|
||||
it('should generate a state with provider prefix and signed token', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
expect(state).toBeTruthy();
|
||||
expect(typeof state).toBe('string');
|
||||
expect(state.startsWith(`${providerId}:`)).toBe(true);
|
||||
|
||||
// Extract signed portion and verify format (nonce.timestamp.signature)
|
||||
const signed = state.substring(providerId.length + 1);
|
||||
expect(signed.split('.').length).toBe(3);
|
||||
});
|
||||
|
||||
it('should generate unique states for each call', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
const state1 = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
const state2 = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
expect(state1).not.toBe(state2);
|
||||
});
|
||||
|
||||
it('should work without redirectUri parameter (backwards compatibility)', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
|
||||
expect(state).toBeTruthy();
|
||||
expect(state.startsWith(`${providerId}:`)).toBe(true);
|
||||
});
|
||||
|
||||
it('should store state data in cache and retrieve it', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
const validation = await service.validateSecureState(state, providerId);
|
||||
|
||||
expect(validation.isValid).toBe(true);
|
||||
expect(validation.clientState).toBe(clientState);
|
||||
expect(validation.redirectUri).toBe(redirectUri);
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateSecureState', () => {
|
||||
it('should validate a valid state token', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
const result = await service.validateSecureState(state, providerId);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.clientState).toBe(clientState);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should validate a state token with redirectUri', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'https://example.com/callback';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
const result = await service.validateSecureState(state, providerId);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.clientState).toBe(clientState);
|
||||
expect(result.redirectUri).toBe(redirectUri);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should reject state with wrong provider ID', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
const result = await service.validateSecureState(state, 'different-provider');
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('Provider ID mismatch');
|
||||
});
|
||||
|
||||
it('should reject expired state tokens', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
|
||||
// Advance time by 11 minutes (past the 10-minute TTL)
|
||||
vi.advanceTimersByTime(11 * 60 * 1000);
|
||||
|
||||
const result = await service.validateSecureState(state, providerId);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('expired');
|
||||
});
|
||||
|
||||
it('should reject reused state tokens', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
|
||||
// First validation should succeed
|
||||
const result1 = await service.validateSecureState(state, providerId);
|
||||
expect(result1.isValid).toBe(true);
|
||||
|
||||
// Second validation should fail (replay attack prevention)
|
||||
const result2 = await service.validateSecureState(state, providerId);
|
||||
expect(result2.isValid).toBe(false);
|
||||
expect(result2.error).toContain('not found or already used');
|
||||
});
|
||||
|
||||
it('should reject invalid state tokens', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const invalidState = `${providerId}:invalid-format`;
|
||||
|
||||
const result = await service.validateSecureState(invalidState, providerId);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toBeTruthy();
|
||||
});
|
||||
|
||||
it('should reject tampered state tokens', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
// Tamper with the signature
|
||||
const tamperedState = state.substring(0, state.length - 5) + 'xxxxx';
|
||||
|
||||
const result = await service.validateSecureState(tamperedState, providerId);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('signature');
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractProviderFromState', () => {
|
||||
it('should extract provider ID from state', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
const extracted = service.extractProviderFromState(state);
|
||||
|
||||
expect(extracted).toBe(providerId);
|
||||
});
|
||||
|
||||
it('should return null for invalid state format', () => {
|
||||
const invalidState = 'invalid-state-without-colon';
|
||||
const extracted = service.extractProviderFromState(invalidState);
|
||||
|
||||
expect(extracted).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractProviderFromLegacyState', () => {
|
||||
it('should handle legacy state format', () => {
|
||||
const legacyState = 'provider-id:client-state-value';
|
||||
const result = service.extractProviderFromLegacyState(legacyState);
|
||||
|
||||
expect(result.providerId).toBe('provider-id');
|
||||
expect(result.originalState).toBe('client-state-value');
|
||||
});
|
||||
|
||||
it('should handle new signed state format', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
const result = service.extractProviderFromLegacyState(state);
|
||||
|
||||
// New format should not be recognized as legacy
|
||||
expect(result.providerId).toBe('');
|
||||
expect(result.originalState).toBe(state);
|
||||
});
|
||||
});
|
||||
|
||||
describe('cache TTL', () => {
|
||||
it('should remove state from cache after successful validation', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-state-123';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
|
||||
// First validation should succeed
|
||||
const result1 = await service.validateSecureState(state, providerId);
|
||||
expect(result1.isValid).toBe(true);
|
||||
|
||||
// Second validation should fail (state was removed after first use)
|
||||
const result2 = await service.validateSecureState(state, providerId);
|
||||
expect(result2.isValid).toBe(false);
|
||||
expect(result2.error).toContain('not found or already used');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,241 @@
|
||||
import { CacheModule } from '@nestjs/cache-manager';
|
||||
import { Test } from '@nestjs/testing';
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state.service.js';
|
||||
|
||||
describe('OidcStateService', () => {
|
||||
let service: OidcStateService;
|
||||
|
||||
beforeEach(async () => {
|
||||
const module = await Test.createTestingModule({
|
||||
imports: [CacheModule.register()],
|
||||
providers: [OidcStateService],
|
||||
}).compile();
|
||||
|
||||
service = module.get<OidcStateService>(OidcStateService);
|
||||
});
|
||||
|
||||
describe('state generation and validation flow', () => {
|
||||
it('should generate state with redirect URI and validate it successfully', async () => {
|
||||
const providerId = 'unraid.net';
|
||||
const clientState = 'client-state-123';
|
||||
const redirectUri = 'http://devgen-dev1.local/graphql/api/auth/oidc/callback';
|
||||
|
||||
// Generate state
|
||||
const state = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// Verify state format: providerId:nonce.timestamp.signature
|
||||
expect(state).toMatch(/^unraid\.net:[a-f0-9]+\.\d+\.[a-f0-9]+$/);
|
||||
|
||||
// Extract and verify parts
|
||||
const [extractedProviderId, signedState] = state.split(':');
|
||||
expect(extractedProviderId).toBe(providerId);
|
||||
|
||||
// Parse the signed state components
|
||||
const [nonce, timestamp, signature] = signedState.split('.');
|
||||
|
||||
// Verify nonce is a 32-character hex string (16 bytes)
|
||||
expect(nonce).toMatch(/^[a-f0-9]{32}$/);
|
||||
|
||||
// Verify timestamp is a valid number and recent
|
||||
const timestampNum = parseInt(timestamp, 10);
|
||||
expect(timestampNum).toBeGreaterThan(Date.now() - 1000); // Generated within last second
|
||||
expect(timestampNum).toBeLessThanOrEqual(Date.now());
|
||||
|
||||
// Verify signature is a 64-character hex string (SHA256 output)
|
||||
expect(signature).toMatch(/^[a-f0-9]{64}$/);
|
||||
|
||||
// Validate the state
|
||||
const validation = await service.validateSecureState(state, providerId);
|
||||
|
||||
expect(validation.isValid).toBe(true);
|
||||
expect(validation.clientState).toBe(clientState);
|
||||
expect(validation.redirectUri).toBe(redirectUri);
|
||||
});
|
||||
|
||||
it('should verify signed state integrity with HMAC', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'test-state';
|
||||
const redirectUri = 'http://localhost:3000/callback';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// Tamper with the signature
|
||||
const [provider, signedState] = state.split(':');
|
||||
const [nonce, timestamp] = signedState.split('.');
|
||||
const tamperedSignature = 'a'.repeat(64); // Invalid signature
|
||||
const tamperedState = `${provider}:${nonce}.${timestamp}.${tamperedSignature}`;
|
||||
|
||||
const validation = await service.validateSecureState(tamperedState, providerId);
|
||||
|
||||
expect(validation.isValid).toBe(false);
|
||||
expect(validation.error).toContain('Invalid state signature');
|
||||
});
|
||||
|
||||
it('should fail validation when nonce is not in cache', async () => {
|
||||
const providerId = 'unraid.net';
|
||||
// Create a fake state that looks valid but has unknown nonce
|
||||
const fakeState = `unraid.net:fakenonce123.${Date.now()}.fakesignature456`;
|
||||
|
||||
const validation = await service.validateSecureState(fakeState, providerId);
|
||||
|
||||
expect(validation.isValid).toBe(false);
|
||||
expect(validation.error).toContain('Invalid state signature');
|
||||
});
|
||||
|
||||
it('should prevent replay attacks by removing nonce after validation', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'test-state';
|
||||
const redirectUri = 'http://localhost:3000/callback';
|
||||
|
||||
// Generate and validate state once
|
||||
const state = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
const firstValidation = await service.validateSecureState(state, providerId);
|
||||
expect(firstValidation.isValid).toBe(true);
|
||||
|
||||
// Try to validate the same state again (replay attack)
|
||||
const secondValidation = await service.validateSecureState(state, providerId);
|
||||
expect(secondValidation.isValid).toBe(false);
|
||||
expect(secondValidation.error).toContain('State token not found or already used');
|
||||
});
|
||||
|
||||
it('should handle state with missing redirect URI', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'test-state';
|
||||
// No redirect URI provided
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
const validation = await service.validateSecureState(state, providerId);
|
||||
|
||||
expect(validation.isValid).toBe(true);
|
||||
expect(validation.clientState).toBe(clientState);
|
||||
expect(validation.redirectUri).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should reject state with wrong provider ID', async () => {
|
||||
const providerId = 'provider-a';
|
||||
const wrongProviderId = 'provider-b';
|
||||
const clientState = 'test-state';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
const validation = await service.validateSecureState(state, wrongProviderId);
|
||||
|
||||
expect(validation.isValid).toBe(false);
|
||||
expect(validation.error).toContain('Provider ID mismatch');
|
||||
});
|
||||
|
||||
it('should extract provider from state correctly', async () => {
|
||||
const providerId = 'unraid.net';
|
||||
const state = await service.generateSecureState(providerId, 'test', 'http://example.com');
|
||||
|
||||
const extracted = service.extractProviderFromState(state);
|
||||
expect(extracted).toBe(providerId);
|
||||
});
|
||||
|
||||
it('should handle state expiration', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'test-state';
|
||||
|
||||
// Generate state
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
|
||||
// Mock timestamp to simulate expired state
|
||||
const parts = state.split(':')[1].split('.');
|
||||
const nonce = parts[0];
|
||||
const expiredTimestamp = Date.now() - 700000; // 11+ minutes ago
|
||||
const fakeState = `${providerId}:${nonce}.${expiredTimestamp}.fakesignature`;
|
||||
|
||||
const validation = await service.validateSecureState(fakeState, providerId);
|
||||
expect(validation.isValid).toBe(false);
|
||||
expect(validation.error).toContain('Invalid state signature'); // Will fail on signature first
|
||||
});
|
||||
});
|
||||
|
||||
describe('redirect URI extraction from state', () => {
|
||||
it('should store and retrieve redirect URI from state token', async () => {
|
||||
const providerId = 'unraid.net';
|
||||
const clientState = 'original-client-state';
|
||||
const redirectUri = 'http://devgen-dev1.local/graphql/api/auth/oidc/callback';
|
||||
|
||||
// This simulates the authorize flow
|
||||
const stateToken = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// Log the generated state for debugging
|
||||
console.log('Generated state token:', stateToken);
|
||||
|
||||
// This simulates the callback flow
|
||||
const validation = await service.validateSecureState(stateToken, providerId);
|
||||
|
||||
expect(validation.isValid).toBe(true);
|
||||
expect(validation.redirectUri).toBe(redirectUri);
|
||||
expect(validation.clientState).toBe(clientState);
|
||||
});
|
||||
|
||||
it('should handle dynamic redirect URIs for different origins', async () => {
|
||||
const providerId = 'google';
|
||||
const clientState = 'state123';
|
||||
|
||||
// Test with different origins
|
||||
const origins = [
|
||||
'http://localhost:3000/graphql/api/auth/oidc/callback',
|
||||
'https://myserver.local/graphql/api/auth/oidc/callback',
|
||||
'http://192.168.1.100/graphql/api/auth/oidc/callback',
|
||||
];
|
||||
|
||||
for (const redirectUri of origins) {
|
||||
const state = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
const validation = await service.validateSecureState(state, providerId);
|
||||
|
||||
expect(validation.isValid).toBe(true);
|
||||
expect(validation.redirectUri).toBe(redirectUri);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('cache management', () => {
|
||||
it('should handle TTL expiration correctly', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'test-state';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState);
|
||||
|
||||
// First validation should succeed
|
||||
const validation1 = await service.validateSecureState(state, providerId);
|
||||
expect(validation1.isValid).toBe(true);
|
||||
|
||||
// State should be removed after first use (replay protection)
|
||||
const validation2 = await service.validateSecureState(state, providerId);
|
||||
expect(validation2.isValid).toBe(false);
|
||||
});
|
||||
|
||||
it('should store complete state data in cache with redirect URI', async () => {
|
||||
const providerId = 'test-provider';
|
||||
const clientState = 'client-123';
|
||||
const redirectUri = 'http://example.com/callback';
|
||||
|
||||
const state = await service.generateSecureState(providerId, clientState, redirectUri);
|
||||
|
||||
// Extract nonce from the generated state
|
||||
const [, signedState] = state.split(':');
|
||||
const [nonce] = signedState.split('.');
|
||||
|
||||
// Access the cache directly to verify stored data
|
||||
const cacheKey = `oidc_state:${nonce}`;
|
||||
const cachedData = await service['cacheManager'].get(cacheKey);
|
||||
|
||||
expect(cachedData).toBeDefined();
|
||||
expect(cachedData).toMatchObject({
|
||||
nonce,
|
||||
clientState,
|
||||
providerId,
|
||||
redirectUri,
|
||||
});
|
||||
// @ts-expect-error - cachedData is of type StateData
|
||||
expect(cachedData.timestamp).toBeGreaterThan(Date.now() - 1000);
|
||||
// @ts-expect-error - cachedData is of type StateData
|
||||
expect(cachedData.timestamp).toBeLessThanOrEqual(Date.now());
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,4 +1,5 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { Cache, CACHE_MANAGER } from '@nestjs/cache-manager';
|
||||
import { Inject, Injectable, Logger } from '@nestjs/common';
|
||||
import crypto from 'crypto';
|
||||
|
||||
interface StateData {
|
||||
@@ -6,26 +7,34 @@ interface StateData {
|
||||
clientState: string;
|
||||
timestamp: number;
|
||||
providerId: string;
|
||||
redirectUri?: string;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class OidcStateService {
|
||||
private static instanceCount = 0;
|
||||
private readonly instanceId: number;
|
||||
private readonly logger = new Logger(OidcStateService.name);
|
||||
private readonly stateCache = new Map<string, StateData>();
|
||||
private readonly hmacSecret: string;
|
||||
private readonly STATE_TTL_SECONDS = 600; // 10 minutes
|
||||
private readonly STATE_TTL_MS = 600000; // 10 minutes in milliseconds (cache-manager v7+ expects milliseconds, not seconds)
|
||||
private readonly STATE_CACHE_PREFIX = 'oidc_state:';
|
||||
|
||||
constructor(@Inject(CACHE_MANAGER) private cacheManager: Cache) {
|
||||
// Track instance creation
|
||||
this.instanceId = ++OidcStateService.instanceCount;
|
||||
|
||||
constructor() {
|
||||
// Always generate a new secret on API restart for security
|
||||
// This ensures state tokens cannot be reused across restarts
|
||||
this.hmacSecret = crypto.randomBytes(32).toString('hex');
|
||||
this.logger.debug('Generated new OIDC state secret for this session');
|
||||
|
||||
// Clean up expired states periodically
|
||||
setInterval(() => this.cleanupExpiredStates(), 60000); // Every minute
|
||||
this.logger.warn(`OidcStateService instance #${this.instanceId} created with new HMAC secret`);
|
||||
this.logger.debug(`HMAC secret first 8 chars: ${this.hmacSecret.substring(0, 8)}`);
|
||||
}
|
||||
|
||||
generateSecureState(providerId: string, clientState: string): string {
|
||||
async generateSecureState(
|
||||
providerId: string,
|
||||
clientState: string,
|
||||
redirectUri?: string
|
||||
): Promise<string> {
|
||||
const nonce = crypto.randomBytes(16).toString('hex');
|
||||
const timestamp = Date.now();
|
||||
|
||||
@@ -35,8 +44,21 @@ export class OidcStateService {
|
||||
clientState,
|
||||
timestamp,
|
||||
providerId,
|
||||
redirectUri,
|
||||
};
|
||||
this.stateCache.set(nonce, stateData);
|
||||
|
||||
// Store in cache with TTL (in milliseconds for cache-manager v7)
|
||||
const cacheKey = `${this.STATE_CACHE_PREFIX}${nonce}`;
|
||||
this.logger.debug(`Storing state with key: ${cacheKey}, TTL: ${this.STATE_TTL_MS}ms`);
|
||||
await this.cacheManager.set(cacheKey, stateData, this.STATE_TTL_MS);
|
||||
|
||||
// Verify it was stored
|
||||
const verifyStored = await this.cacheManager.get(cacheKey);
|
||||
if (verifyStored) {
|
||||
this.logger.debug(`State successfully stored and verified for key: ${cacheKey}`);
|
||||
} else {
|
||||
this.logger.error(`CRITICAL: State was NOT stored in cache for key: ${cacheKey}`);
|
||||
}
|
||||
|
||||
// Create signed state: nonce.timestamp.signature
|
||||
const dataToSign = `${nonce}.${timestamp}`;
|
||||
@@ -45,14 +67,18 @@ export class OidcStateService {
|
||||
const signedState = `${dataToSign}.${signature}`;
|
||||
|
||||
this.logger.debug(`Generated secure state for provider ${providerId} with nonce ${nonce}`);
|
||||
this.logger.debug(
|
||||
`Instance #${this.instanceId}, HMAC secret first 8 chars: ${this.hmacSecret.substring(0, 8)}`
|
||||
);
|
||||
this.logger.debug(`Stored redirectUri: ${redirectUri}`);
|
||||
// Return state with provider ID prefix (unencrypted) for routing
|
||||
return `${providerId}:${signedState}`;
|
||||
}
|
||||
|
||||
validateSecureState(
|
||||
async validateSecureState(
|
||||
state: string,
|
||||
expectedProviderId: string
|
||||
): { isValid: boolean; clientState?: string; error?: string } {
|
||||
): Promise<{ isValid: boolean; clientState?: string; redirectUri?: string; error?: string }> {
|
||||
try {
|
||||
// Extract provider ID and signed state
|
||||
const parts = state.split(':');
|
||||
@@ -107,7 +133,7 @@ export class OidcStateService {
|
||||
// Check timestamp expiration
|
||||
const now = Date.now();
|
||||
const age = now - timestamp;
|
||||
if (age > this.STATE_TTL_SECONDS * 1000) {
|
||||
if (age > this.STATE_TTL_MS) {
|
||||
this.logger.warn(`State validation failed: token expired (age: ${age}ms)`);
|
||||
return {
|
||||
isValid: false,
|
||||
@@ -116,11 +142,21 @@ export class OidcStateService {
|
||||
}
|
||||
|
||||
// Check if state exists in cache (prevents replay attacks)
|
||||
const cachedState = this.stateCache.get(nonce);
|
||||
const cacheKey = `${this.STATE_CACHE_PREFIX}${nonce}`;
|
||||
this.logger.debug(`Looking for nonce ${nonce} in cache with key: ${cacheKey}`);
|
||||
this.logger.debug(
|
||||
`Instance #${this.instanceId}, HMAC secret first 8 chars: ${this.hmacSecret.substring(0, 8)}`
|
||||
);
|
||||
this.logger.debug(`Cache manager type: ${this.cacheManager.constructor.name}`);
|
||||
|
||||
const cachedState = await this.cacheManager.get<StateData>(cacheKey);
|
||||
|
||||
if (!cachedState) {
|
||||
this.logger.warn(
|
||||
`State validation failed: nonce ${nonce} not found in cache (possible replay attack)`
|
||||
);
|
||||
this.logger.warn(`Cache key checked: ${cacheKey}`);
|
||||
|
||||
return {
|
||||
isValid: false,
|
||||
error: 'State token not found or already used',
|
||||
@@ -137,12 +173,13 @@ export class OidcStateService {
|
||||
}
|
||||
|
||||
// Remove from cache to prevent reuse
|
||||
this.stateCache.delete(nonce);
|
||||
await this.cacheManager.del(cacheKey);
|
||||
|
||||
this.logger.debug(`State validation successful for provider ${expectedProviderId}`);
|
||||
return {
|
||||
isValid: true,
|
||||
clientState: cachedState.clientState,
|
||||
redirectUri: cachedState.redirectUri,
|
||||
};
|
||||
} catch (error) {
|
||||
this.logger.error(
|
||||
@@ -182,20 +219,5 @@ export class OidcStateService {
|
||||
return null;
|
||||
}
|
||||
|
||||
private cleanupExpiredStates(): void {
|
||||
const now = Date.now();
|
||||
let cleaned = 0;
|
||||
|
||||
for (const [nonce, stateData] of this.stateCache.entries()) {
|
||||
const age = now - stateData.timestamp;
|
||||
if (age > this.STATE_TTL_SECONDS * 1000) {
|
||||
this.stateCache.delete(nonce);
|
||||
cleaned++;
|
||||
}
|
||||
}
|
||||
|
||||
if (cleaned > 0) {
|
||||
this.logger.debug(`Cleaned up ${cleaned} expired state entries`);
|
||||
}
|
||||
}
|
||||
// Cleanup is now handled by cache TTL
|
||||
}
|
||||
@@ -1,33 +1,13 @@
|
||||
import { CacheModule } from '@nestjs/cache-manager';
|
||||
import { Module } from '@nestjs/common';
|
||||
|
||||
import { UserSettingsModule } from '@unraid/shared/services/user-settings.js';
|
||||
|
||||
import { OidcAuthService } from '@app/unraid-api/graph/resolvers/sso/oidc-auth.service.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
|
||||
import { OidcStateService } from '@app/unraid-api/graph/resolvers/sso/oidc-state.service.js';
|
||||
import { OidcValidationService } from '@app/unraid-api/graph/resolvers/sso/oidc-validation.service.js';
|
||||
import { OidcCoreModule } from '@app/unraid-api/graph/resolvers/sso/core/oidc-core.module.js';
|
||||
import { SsoResolver } from '@app/unraid-api/graph/resolvers/sso/sso.resolver.js';
|
||||
|
||||
import '@app/unraid-api/graph/resolvers/sso/sso-settings.types.js';
|
||||
import '@app/unraid-api/graph/resolvers/sso/models/sso-settings.types.js';
|
||||
|
||||
@Module({
|
||||
imports: [UserSettingsModule, CacheModule.register()],
|
||||
providers: [
|
||||
SsoResolver,
|
||||
OidcConfigPersistence,
|
||||
OidcSessionService,
|
||||
OidcStateService,
|
||||
OidcAuthService,
|
||||
OidcValidationService,
|
||||
],
|
||||
exports: [
|
||||
OidcConfigPersistence,
|
||||
OidcSessionService,
|
||||
OidcStateService,
|
||||
OidcAuthService,
|
||||
OidcValidationService,
|
||||
],
|
||||
imports: [OidcCoreModule],
|
||||
providers: [SsoResolver],
|
||||
exports: [OidcCoreModule],
|
||||
})
|
||||
export class SsoModule {}
|
||||
|
||||
@@ -6,11 +6,12 @@ import { PrefixedID } from '@unraid/shared/prefixed-id-scalar.js';
|
||||
import { UsePermissions } from '@unraid/shared/use-permissions.directive.js';
|
||||
|
||||
import { Public } from '@app/unraid-api/auth/public.decorator.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/oidc-config.service.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/oidc-provider.model.js';
|
||||
import { OidcSessionValidation } from '@app/unraid-api/graph/resolvers/sso/oidc-session-validation.model.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/oidc-session.service.js';
|
||||
import { PublicOidcProvider } from '@app/unraid-api/graph/resolvers/sso/public-oidc-provider.model.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { OidcConfiguration } from '@app/unraid-api/graph/resolvers/sso/models/oidc-configuration.model.js';
|
||||
import { OidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/oidc-provider.model.js';
|
||||
import { OidcSessionValidation } from '@app/unraid-api/graph/resolvers/sso/models/oidc-session-validation.model.js';
|
||||
import { PublicOidcProvider } from '@app/unraid-api/graph/resolvers/sso/models/public-oidc-provider.model.js';
|
||||
import { OidcSessionService } from '@app/unraid-api/graph/resolvers/sso/session/oidc-session.service.js';
|
||||
|
||||
@Resolver()
|
||||
export class SsoResolver {
|
||||
@@ -88,6 +89,19 @@ export class SsoResolver {
|
||||
return this.oidcConfig.getProvider(id);
|
||||
}
|
||||
|
||||
@Query(() => OidcConfiguration, { description: 'Get the full OIDC configuration (admin only)' })
|
||||
@UsePermissions({
|
||||
action: AuthAction.READ_ANY,
|
||||
resource: Resource.CONFIG,
|
||||
})
|
||||
public async oidcConfiguration(): Promise<OidcConfiguration> {
|
||||
const config = await this.oidcConfig.getConfig();
|
||||
return {
|
||||
providers: config?.providers || [],
|
||||
defaultAllowedOrigins: config?.defaultAllowedOrigins || [],
|
||||
};
|
||||
}
|
||||
|
||||
@Query(() => OidcSessionValidation, {
|
||||
description: 'Validate an OIDC session token (internal use for CLI validation)',
|
||||
})
|
||||
|
||||
@@ -0,0 +1,234 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
|
||||
export interface OidcErrorDetails {
|
||||
userFriendlyError: string;
|
||||
details: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export class OidcErrorHelper {
|
||||
private static readonly logger = new Logger(OidcErrorHelper.name);
|
||||
|
||||
/**
|
||||
* Parse fetch errors and return user-friendly error messages
|
||||
*/
|
||||
static parseFetchError(error: unknown, issuerUrl?: string): OidcErrorDetails {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
let userFriendlyError = errorMessage;
|
||||
let details: Record<string, unknown> = { originalError: errorMessage };
|
||||
|
||||
// Extract cause information if available
|
||||
if (error instanceof Error && 'cause' in error) {
|
||||
const cause = (error as any).cause;
|
||||
if (cause) {
|
||||
this.logger.log('Fetch error cause: %o', cause);
|
||||
|
||||
const errorCode = cause.code || '';
|
||||
const causeMessage = cause.message || '';
|
||||
|
||||
// Map error codes to user-friendly messages
|
||||
switch (errorCode) {
|
||||
case 'ENOTFOUND':
|
||||
userFriendlyError = `Cannot resolve domain name. Please check that '${issuerUrl}' is accessible and spelled correctly.`;
|
||||
details = {
|
||||
type: 'DNS_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage || errorCode,
|
||||
};
|
||||
break;
|
||||
|
||||
case 'ECONNREFUSED':
|
||||
userFriendlyError = `Connection refused. The server at '${issuerUrl}' is not accepting connections.`;
|
||||
details = {
|
||||
type: 'CONNECTION_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage || errorCode,
|
||||
};
|
||||
break;
|
||||
|
||||
case 'CERT_HAS_EXPIRED':
|
||||
userFriendlyError = `SSL/TLS certificate error. The server certificate may be invalid or expired.`;
|
||||
details = {
|
||||
type: 'SSL_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage || errorCode,
|
||||
};
|
||||
break;
|
||||
|
||||
case 'ETIMEDOUT':
|
||||
userFriendlyError = `Connection timeout. The server at '${issuerUrl}' is not responding.`;
|
||||
details = {
|
||||
type: 'TIMEOUT_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage || errorCode,
|
||||
};
|
||||
break;
|
||||
|
||||
default:
|
||||
// Check message patterns if code doesn't match
|
||||
if (causeMessage.includes('ENOTFOUND')) {
|
||||
userFriendlyError = `Cannot resolve domain name. Please check that '${issuerUrl}' is accessible and spelled correctly.`;
|
||||
details = {
|
||||
type: 'DNS_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage,
|
||||
};
|
||||
} else if (causeMessage.includes('ECONNREFUSED')) {
|
||||
userFriendlyError = `Connection refused. The server at '${issuerUrl}' is not accepting connections.`;
|
||||
details = {
|
||||
type: 'CONNECTION_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage,
|
||||
};
|
||||
} else if (
|
||||
causeMessage.includes('certificate') ||
|
||||
causeMessage.includes('SSL') ||
|
||||
causeMessage.includes('TLS')
|
||||
) {
|
||||
userFriendlyError = `SSL/TLS certificate error. The server certificate may be invalid or expired.`;
|
||||
details = {
|
||||
type: 'SSL_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage,
|
||||
};
|
||||
} else if (causeMessage.includes('ETIMEDOUT')) {
|
||||
userFriendlyError = `Connection timeout. The server at '${issuerUrl}' is not responding.`;
|
||||
details = {
|
||||
type: 'TIMEOUT_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage,
|
||||
};
|
||||
} else {
|
||||
userFriendlyError = `Failed to connect to OIDC provider at '${issuerUrl}'. ${causeMessage || errorCode || 'Unknown network error'}`;
|
||||
details = {
|
||||
type: 'FETCH_ERROR',
|
||||
originalError: errorMessage,
|
||||
cause: causeMessage || errorCode,
|
||||
};
|
||||
}
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
// Generic fetch failed without cause
|
||||
userFriendlyError = `Failed to connect to OIDC provider at '${issuerUrl}'. Please verify the URL is correct and accessible.`;
|
||||
details = { type: 'FETCH_ERROR', originalError: errorMessage };
|
||||
}
|
||||
} else if (errorMessage.includes('fetch failed')) {
|
||||
// Fetch failed but no cause information
|
||||
userFriendlyError = `Failed to connect to OIDC provider at '${issuerUrl}'. Please verify the URL is correct and accessible.`;
|
||||
details = { type: 'FETCH_ERROR', originalError: errorMessage };
|
||||
}
|
||||
|
||||
return { userFriendlyError, details };
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse HTTP status errors and return user-friendly error messages
|
||||
*/
|
||||
static parseHttpError(errorMessage: string, issuerUrl?: string): OidcErrorDetails {
|
||||
let userFriendlyError = errorMessage;
|
||||
let details: Record<string, unknown> = { originalError: errorMessage };
|
||||
|
||||
if (errorMessage.includes('404') || errorMessage.includes('Not Found')) {
|
||||
const baseUrl = issuerUrl?.endsWith('/.well-known/openid-configuration')
|
||||
? issuerUrl.replace('/.well-known/openid-configuration', '')
|
||||
: issuerUrl;
|
||||
userFriendlyError = `OIDC discovery endpoint not found. Please verify that '${baseUrl}/.well-known/openid-configuration' exists.`;
|
||||
details = { type: 'DISCOVERY_NOT_FOUND', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('401') || errorMessage.includes('403')) {
|
||||
userFriendlyError = `Access denied to discovery endpoint. Please check the issuer URL and any authentication requirements.`;
|
||||
details = { type: 'AUTHENTICATION_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('unexpected HTTP response status code')) {
|
||||
// Extract status code if possible
|
||||
const statusMatch = errorMessage.match(/status code (\d+)/);
|
||||
const statusCode = statusMatch ? statusMatch[1] : 'unknown';
|
||||
const baseUrl = issuerUrl?.endsWith('/.well-known/openid-configuration')
|
||||
? issuerUrl.replace('/.well-known/openid-configuration', '')
|
||||
: issuerUrl;
|
||||
userFriendlyError = `HTTP ${statusCode} error from discovery endpoint. Please check that '${baseUrl}/.well-known/openid-configuration' returns a valid OIDC discovery document.`;
|
||||
details = { type: 'HTTP_STATUS_ERROR', statusCode, originalError: errorMessage };
|
||||
}
|
||||
|
||||
return { userFriendlyError, details };
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse generic OIDC errors and return user-friendly error messages
|
||||
*/
|
||||
static parseGenericError(error: unknown, issuerUrl?: string): OidcErrorDetails {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
let userFriendlyError = errorMessage;
|
||||
let details: Record<string, unknown> = { originalError: errorMessage };
|
||||
|
||||
// Check for specific error patterns
|
||||
if (errorMessage.includes('getaddrinfo ENOTFOUND')) {
|
||||
userFriendlyError = `Cannot resolve domain name. Please check that '${issuerUrl}' is accessible and spelled correctly.`;
|
||||
details = { type: 'DNS_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('ECONNREFUSED')) {
|
||||
userFriendlyError = `Connection refused. The server at '${issuerUrl}' is not accepting connections.`;
|
||||
details = { type: 'CONNECTION_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('ECONNRESET') || errorMessage.includes('ETIMEDOUT')) {
|
||||
userFriendlyError = `Connection timeout. The server at '${issuerUrl}' is not responding.`;
|
||||
details = { type: 'TIMEOUT_ERROR', originalError: errorMessage };
|
||||
} else if (
|
||||
errorMessage.includes('certificate') ||
|
||||
errorMessage.includes('SSL') ||
|
||||
errorMessage.includes('TLS')
|
||||
) {
|
||||
userFriendlyError = `SSL/TLS certificate error. The server certificate may be invalid or expired.`;
|
||||
details = { type: 'SSL_ERROR', originalError: errorMessage };
|
||||
} else if (errorMessage.includes('JSON') || errorMessage.includes('parse')) {
|
||||
userFriendlyError = `Invalid OIDC discovery response. The server returned malformed JSON.`;
|
||||
details = { type: 'INVALID_JSON', originalError: errorMessage };
|
||||
} else if (error && (error as any).code === 'OAUTH_RESPONSE_IS_NOT_CONFORM') {
|
||||
const baseUrl = issuerUrl?.endsWith('/.well-known/openid-configuration')
|
||||
? issuerUrl.replace('/.well-known/openid-configuration', '')
|
||||
: issuerUrl;
|
||||
userFriendlyError = `Invalid OIDC discovery document. The server at '${baseUrl}/.well-known/openid-configuration' returned a response that doesn't conform to the OpenID Connect Discovery specification. Please verify the endpoint returns valid OIDC metadata.`;
|
||||
details = { type: 'INVALID_OIDC_DOCUMENT', originalError: errorMessage };
|
||||
}
|
||||
|
||||
return { userFriendlyError, details };
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse OIDC discovery errors and return user-friendly error messages
|
||||
*/
|
||||
static parseDiscoveryError(error: unknown, issuerUrl?: string): OidcErrorDetails {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
|
||||
// Log additional error details for debugging
|
||||
if (error instanceof Error) {
|
||||
this.logger.log(`Error type: ${error.constructor.name}`);
|
||||
if ('stack' in error && error.stack) {
|
||||
this.logger.debug(`Stack trace: ${error.stack}`);
|
||||
}
|
||||
if ('response' in error) {
|
||||
const response = (error as any).response;
|
||||
if (response) {
|
||||
this.logger.log(`Response status: ${response.status}`);
|
||||
this.logger.log(`Response body: ${response.body}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for fetch-specific errors first
|
||||
if (errorMessage.includes('fetch failed')) {
|
||||
return this.parseFetchError(error, issuerUrl);
|
||||
}
|
||||
|
||||
// Check for HTTP status errors
|
||||
const httpError = this.parseHttpError(errorMessage, issuerUrl);
|
||||
// Proper type-narrowing guard for accessing details.type
|
||||
if (
|
||||
httpError.details &&
|
||||
typeof httpError.details === 'object' &&
|
||||
'type' in httpError.details &&
|
||||
httpError.details.type !== undefined
|
||||
) {
|
||||
return httpError;
|
||||
}
|
||||
|
||||
// Fall back to generic error parsing
|
||||
return this.parseGenericError(error, issuerUrl);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,228 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import type { FastifyRequest } from '@app/unraid-api/types/fastify.js';
|
||||
import { OidcRequestHandler } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-request-handler.util.js';
|
||||
|
||||
describe('OidcRequestHandler', () => {
|
||||
let mockLogger: Logger;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockLogger = {
|
||||
debug: vi.fn(),
|
||||
log: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
} as any;
|
||||
});
|
||||
|
||||
describe('extractRequestInfo', () => {
|
||||
it('should extract request info from headers', () => {
|
||||
const mockReq = {
|
||||
headers: {
|
||||
'x-forwarded-proto': 'https',
|
||||
'x-forwarded-host': 'example.com:8443',
|
||||
},
|
||||
protocol: 'http',
|
||||
url: '/callback?code=123&state=456',
|
||||
} as unknown as FastifyRequest;
|
||||
|
||||
const result = OidcRequestHandler.extractRequestInfo(mockReq);
|
||||
|
||||
expect(result.protocol).toBe('https');
|
||||
expect(result.host).toBe('example.com:8443');
|
||||
expect(result.fullUrl).toBe('https://example.com:8443/callback?code=123&state=456');
|
||||
expect(result.baseUrl).toBe('https://example.com:8443');
|
||||
});
|
||||
|
||||
it('should fall back to request properties when headers are missing', () => {
|
||||
const mockReq = {
|
||||
headers: {
|
||||
host: 'localhost:3000',
|
||||
},
|
||||
protocol: 'http',
|
||||
url: '/callback?code=123&state=456',
|
||||
} as FastifyRequest;
|
||||
|
||||
const result = OidcRequestHandler.extractRequestInfo(mockReq);
|
||||
|
||||
expect(result.protocol).toBe('http');
|
||||
expect(result.host).toBe('localhost:3000');
|
||||
expect(result.fullUrl).toBe('http://localhost:3000/callback?code=123&state=456');
|
||||
expect(result.baseUrl).toBe('http://localhost:3000');
|
||||
});
|
||||
|
||||
it('should use defaults when all headers are missing', () => {
|
||||
const mockReq = {
|
||||
headers: {},
|
||||
url: '/callback?code=123&state=456',
|
||||
} as FastifyRequest;
|
||||
|
||||
const result = OidcRequestHandler.extractRequestInfo(mockReq);
|
||||
|
||||
expect(result.protocol).toBe('http');
|
||||
expect(result.host).toBe('localhost:3000');
|
||||
expect(result.fullUrl).toBe('http://localhost:3000/callback?code=123&state=456');
|
||||
expect(result.baseUrl).toBe('http://localhost:3000');
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateAuthorizeParams', () => {
|
||||
it('should validate valid parameters', () => {
|
||||
const result = OidcRequestHandler.validateAuthorizeParams(
|
||||
'provider123',
|
||||
'state456',
|
||||
'https://example.com/callback'
|
||||
);
|
||||
|
||||
expect(result.providerId).toBe('provider123');
|
||||
expect(result.state).toBe('state456');
|
||||
expect(result.redirectUri).toBe('https://example.com/callback');
|
||||
});
|
||||
|
||||
it('should throw error for missing provider ID', () => {
|
||||
expect(() => {
|
||||
OidcRequestHandler.validateAuthorizeParams(
|
||||
undefined,
|
||||
'state456',
|
||||
'https://example.com/callback'
|
||||
);
|
||||
}).toThrow('Provider ID is required');
|
||||
});
|
||||
|
||||
it('should throw error for missing state', () => {
|
||||
expect(() => {
|
||||
OidcRequestHandler.validateAuthorizeParams(
|
||||
'provider123',
|
||||
undefined,
|
||||
'https://example.com/callback'
|
||||
);
|
||||
}).toThrow('State parameter is required');
|
||||
});
|
||||
|
||||
it('should throw error for missing redirect URI', () => {
|
||||
expect(() => {
|
||||
OidcRequestHandler.validateAuthorizeParams('provider123', 'state456', undefined);
|
||||
}).toThrow('Redirect URI is required');
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateCallbackParams', () => {
|
||||
it('should validate valid parameters', () => {
|
||||
const result = OidcRequestHandler.validateCallbackParams('code123', 'state456');
|
||||
|
||||
expect(result.code).toBe('code123');
|
||||
expect(result.state).toBe('state456');
|
||||
});
|
||||
|
||||
it('should throw error for missing code', () => {
|
||||
expect(() => {
|
||||
OidcRequestHandler.validateCallbackParams(undefined, 'state456');
|
||||
}).toThrow('Missing required parameters');
|
||||
});
|
||||
|
||||
it('should throw error for missing state', () => {
|
||||
expect(() => {
|
||||
OidcRequestHandler.validateCallbackParams('code123', undefined);
|
||||
}).toThrow('Missing required parameters');
|
||||
});
|
||||
|
||||
it('should throw error for empty code', () => {
|
||||
expect(() => {
|
||||
OidcRequestHandler.validateCallbackParams('', 'state456');
|
||||
}).toThrow('Missing required parameters');
|
||||
});
|
||||
|
||||
it('should throw error for empty state', () => {
|
||||
expect(() => {
|
||||
OidcRequestHandler.validateCallbackParams('code123', '');
|
||||
}).toThrow('Missing required parameters');
|
||||
});
|
||||
});
|
||||
|
||||
describe('handleAuthorize', () => {
|
||||
it('should handle authorization flow', async () => {
|
||||
const mockAuthService = {
|
||||
getAuthorizationUrl: vi
|
||||
.fn()
|
||||
.mockResolvedValue('https://provider.com/auth?client_id=123'),
|
||||
};
|
||||
|
||||
const mockReq = {
|
||||
headers: { 'x-forwarded-proto': 'https', 'x-forwarded-host': 'example.com' },
|
||||
url: '/authorize',
|
||||
} as unknown as FastifyRequest;
|
||||
|
||||
const authUrl = await OidcRequestHandler.handleAuthorize(
|
||||
'provider123',
|
||||
'state456',
|
||||
'https://example.com/callback',
|
||||
mockReq,
|
||||
mockAuthService as any,
|
||||
mockLogger
|
||||
);
|
||||
|
||||
expect(authUrl).toBe('https://provider.com/auth?client_id=123');
|
||||
expect(mockAuthService.getAuthorizationUrl).toHaveBeenCalledWith({
|
||||
providerId: 'provider123',
|
||||
state: 'state456',
|
||||
requestOrigin: 'https://example.com/callback',
|
||||
requestHeaders: {
|
||||
'x-forwarded-proto': 'https',
|
||||
'x-forwarded-host': 'example.com',
|
||||
},
|
||||
});
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith(
|
||||
'Authorization request - Provider: provider123'
|
||||
);
|
||||
expect(mockLogger.log).toHaveBeenCalledWith(
|
||||
'Redirecting to OIDC provider: https://provider.com/auth?client_id=123'
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('handleCallback', () => {
|
||||
it('should handle callback flow', async () => {
|
||||
const mockStateService = {
|
||||
extractProviderFromState: vi.fn().mockReturnValue('provider123'),
|
||||
};
|
||||
|
||||
const mockAuthService = {
|
||||
getStateService: vi.fn().mockReturnValue(mockStateService),
|
||||
handleCallback: vi.fn().mockResolvedValue('paddedToken123'),
|
||||
};
|
||||
|
||||
const mockReq: Pick<FastifyRequest, 'id' | 'headers' | 'url'> = {
|
||||
id: '123',
|
||||
headers: { 'x-forwarded-proto': 'https', 'x-forwarded-host': 'example.com' },
|
||||
url: '/callback?code=123&state=456',
|
||||
};
|
||||
|
||||
const result = await OidcRequestHandler.handleCallback(
|
||||
'code123',
|
||||
'state456',
|
||||
mockReq as unknown as FastifyRequest,
|
||||
mockAuthService as any,
|
||||
mockLogger
|
||||
);
|
||||
|
||||
expect(result.providerId).toBe('provider123');
|
||||
expect(result.paddedToken).toBe('paddedToken123');
|
||||
expect(result.requestInfo.fullUrl).toBe('https://example.com/callback?code=123&state=456');
|
||||
expect(mockAuthService.handleCallback).toHaveBeenCalledWith({
|
||||
providerId: 'provider123',
|
||||
code: 'code123',
|
||||
state: 'state456',
|
||||
requestOrigin: 'https://example.com',
|
||||
fullCallbackUrl: 'https://example.com/callback?code=123&state=456',
|
||||
requestHeaders: {
|
||||
'x-forwarded-proto': 'https',
|
||||
'x-forwarded-host': 'example.com',
|
||||
},
|
||||
});
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith('Callback request - Provider: provider123');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,155 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
|
||||
import type { FastifyRequest } from '@app/unraid-api/types/fastify.js';
|
||||
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
|
||||
import { OidcStateExtractor } from '@app/unraid-api/graph/resolvers/sso/session/oidc-state-extractor.util.js';
|
||||
|
||||
export interface RequestInfo {
|
||||
protocol: string;
|
||||
host: string;
|
||||
fullUrl: string;
|
||||
baseUrl: string;
|
||||
}
|
||||
|
||||
export interface OidcFlowResult {
|
||||
providerId: string;
|
||||
requestInfo: RequestInfo;
|
||||
}
|
||||
|
||||
export interface OidcCallbackResult extends OidcFlowResult {
|
||||
paddedToken: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility class to handle common OIDC request processing logic
|
||||
* between authorize and callback endpoints
|
||||
*/
|
||||
export class OidcRequestHandler {
|
||||
/**
|
||||
* Extract request information from Fastify request headers
|
||||
*/
|
||||
static extractRequestInfo(req: FastifyRequest): RequestInfo {
|
||||
// Handle potentially comma-separated forwarded headers (take first value)
|
||||
const forwardedProto = String(req.headers['x-forwarded-proto'] || '')
|
||||
.split(',')[0]
|
||||
?.trim();
|
||||
const forwardedHost = String(req.headers['x-forwarded-host'] || '')
|
||||
.split(',')[0]
|
||||
?.trim();
|
||||
|
||||
const protocol = forwardedProto || req.protocol || 'http';
|
||||
const host = forwardedHost || req.headers.host || 'localhost:3000';
|
||||
const fullUrl = `${protocol}://${host}${req.url}`;
|
||||
const baseUrl = `${protocol}://${host}`;
|
||||
|
||||
return {
|
||||
protocol,
|
||||
host,
|
||||
fullUrl,
|
||||
baseUrl,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle OIDC authorization flow
|
||||
*/
|
||||
static async handleAuthorize(
|
||||
providerId: string,
|
||||
state: string,
|
||||
redirectUri: string,
|
||||
req: FastifyRequest,
|
||||
oidcService: OidcService,
|
||||
logger: Logger
|
||||
): Promise<string> {
|
||||
const requestInfo = this.extractRequestInfo(req);
|
||||
|
||||
logger.debug(`Authorization request - Provider: ${providerId}`);
|
||||
logger.debug(`Authorization request - Full URL: ${requestInfo.fullUrl}`);
|
||||
logger.debug(`Authorization request - Redirect URI: ${redirectUri}`);
|
||||
|
||||
// Get authorization URL using the validated redirect URI and request headers
|
||||
const authUrl = await oidcService.getAuthorizationUrl({
|
||||
providerId,
|
||||
state,
|
||||
requestOrigin: redirectUri,
|
||||
requestHeaders: req.headers as Record<string, string | string[] | undefined>,
|
||||
});
|
||||
|
||||
logger.log(`Redirecting to OIDC provider: ${authUrl}`);
|
||||
return authUrl;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle OIDC callback flow
|
||||
*/
|
||||
static async handleCallback(
|
||||
code: string,
|
||||
state: string,
|
||||
req: FastifyRequest,
|
||||
oidcService: OidcService,
|
||||
logger: Logger
|
||||
): Promise<OidcCallbackResult> {
|
||||
// Extract provider ID from state for routing
|
||||
const { providerId } = OidcStateExtractor.extractProviderFromState(
|
||||
state,
|
||||
oidcService.getStateService()
|
||||
);
|
||||
|
||||
const requestInfo = this.extractRequestInfo(req);
|
||||
|
||||
logger.debug(`Callback request - Provider: ${providerId}`);
|
||||
logger.debug(`Callback request - Full URL: ${requestInfo.fullUrl}`);
|
||||
logger.debug(`Redirect URI will be retrieved from encrypted state`);
|
||||
|
||||
// Handle the callback using stored redirect URI from state and request headers
|
||||
const paddedToken = await oidcService.handleCallback({
|
||||
providerId,
|
||||
code,
|
||||
state,
|
||||
requestOrigin: requestInfo.baseUrl,
|
||||
fullCallbackUrl: requestInfo.fullUrl,
|
||||
requestHeaders: req.headers as Record<string, string | string[] | undefined>,
|
||||
});
|
||||
|
||||
return {
|
||||
providerId,
|
||||
requestInfo,
|
||||
paddedToken,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate required parameters for authorization flow
|
||||
*/
|
||||
static validateAuthorizeParams(
|
||||
providerId: string | undefined,
|
||||
state: string | undefined,
|
||||
redirectUri: string | undefined
|
||||
): { providerId: string; state: string; redirectUri: string } {
|
||||
if (!providerId) {
|
||||
throw new Error('Provider ID is required');
|
||||
}
|
||||
if (!state) {
|
||||
throw new Error('State parameter is required');
|
||||
}
|
||||
if (!redirectUri) {
|
||||
throw new Error('Redirect URI is required');
|
||||
}
|
||||
|
||||
return { providerId, state, redirectUri };
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate required parameters for callback flow
|
||||
*/
|
||||
static validateCallbackParams(
|
||||
code: string | undefined,
|
||||
state: string | undefined
|
||||
): { code: string; state: string } {
|
||||
if (!code || !state) {
|
||||
throw new Error('Missing required parameters');
|
||||
}
|
||||
|
||||
return { code, state };
|
||||
}
|
||||
}
|
||||
@@ -2,12 +2,12 @@ import { Module } from '@nestjs/common';
|
||||
import { ScheduleModule } from '@nestjs/schedule';
|
||||
|
||||
import { SubscriptionHelperService } from '@app/unraid-api/graph/services/subscription-helper.service.js';
|
||||
import { SubscriptionPollingService } from '@app/unraid-api/graph/services/subscription-polling.service.js';
|
||||
import { SubscriptionManagerService } from '@app/unraid-api/graph/services/subscription-manager.service.js';
|
||||
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
|
||||
|
||||
@Module({
|
||||
imports: [],
|
||||
providers: [SubscriptionTrackerService, SubscriptionHelperService, SubscriptionPollingService],
|
||||
exports: [SubscriptionTrackerService, SubscriptionHelperService, SubscriptionPollingService],
|
||||
providers: [SubscriptionTrackerService, SubscriptionHelperService, SubscriptionManagerService],
|
||||
exports: [SubscriptionTrackerService, SubscriptionHelperService], // SubscriptionManagerService is internal
|
||||
})
|
||||
export class ServicesModule {}
|
||||
|
||||
@@ -4,7 +4,25 @@ import { createSubscription, PUBSUB_CHANNEL } from '@app/core/pubsub.js';
|
||||
import { SubscriptionTrackerService } from '@app/unraid-api/graph/services/subscription-tracker.service.js';
|
||||
|
||||
/**
|
||||
* Helper service for creating tracked GraphQL subscriptions with automatic cleanup
|
||||
* High-level helper service for creating GraphQL subscriptions with automatic cleanup.
|
||||
*
|
||||
* This service provides a convenient way to create GraphQL subscriptions that:
|
||||
* - Automatically track subscriber count via SubscriptionTrackerService
|
||||
* - Properly clean up resources when subscriptions end
|
||||
* - Handle errors gracefully
|
||||
*
|
||||
* **When to use this service:**
|
||||
* - In GraphQL resolvers when implementing subscriptions
|
||||
* - When you need automatic reference counting for shared resources
|
||||
* - When you want to ensure proper cleanup on subscription termination
|
||||
*
|
||||
* @example
|
||||
* // In a GraphQL resolver
|
||||
* \@Subscription(() => MetricsUpdate)
|
||||
* async metricsSubscription() {
|
||||
* // Topic must be registered first via SubscriptionTrackerService
|
||||
* return this.subscriptionHelper.createTrackedSubscription(PUBSUB_CHANNEL.METRICS);
|
||||
* }
|
||||
*/
|
||||
@Injectable()
|
||||
export class SubscriptionHelperService {
|
||||
@@ -15,7 +33,7 @@ export class SubscriptionHelperService {
|
||||
* @param topic The subscription topic/channel to subscribe to
|
||||
* @returns A proxy async iterator with automatic cleanup
|
||||
*/
|
||||
public createTrackedSubscription<T = any>(topic: PUBSUB_CHANNEL): AsyncIterableIterator<T> {
|
||||
public createTrackedSubscription<T = any>(topic: PUBSUB_CHANNEL | string): AsyncIterableIterator<T> {
|
||||
const innerIterator = createSubscription<T>(topic);
|
||||
|
||||
// Subscribe when the subscription starts
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
import { Injectable, Logger, OnModuleDestroy } from '@nestjs/common';
|
||||
import { SchedulerRegistry } from '@nestjs/schedule';
|
||||
|
||||
/**
|
||||
* Configuration for managed subscriptions
|
||||
*/
|
||||
export interface SubscriptionConfig {
|
||||
/** Unique identifier for the subscription */
|
||||
name: string;
|
||||
|
||||
/**
|
||||
* Polling interval in milliseconds.
|
||||
* - If set to a number, the callback will be called at that interval
|
||||
* - If null/undefined, the subscription is event-based (no polling)
|
||||
*/
|
||||
intervalMs?: number | null;
|
||||
|
||||
/** Function to call periodically (for polling) or once (for setup) */
|
||||
callback: () => Promise<void>;
|
||||
|
||||
/** Optional function called when the subscription starts */
|
||||
onStart?: () => Promise<void>;
|
||||
|
||||
/** Optional function called when the subscription stops */
|
||||
onStop?: () => Promise<void>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Low-level service for managing both polling and event-based subscriptions.
|
||||
*
|
||||
* ⚠️ **IMPORTANT**: This is an internal service. Do not use directly in resolvers or business logic.
|
||||
* Instead, use one of the higher-level services:
|
||||
* - **SubscriptionTrackerService**: For subscriptions that need reference counting
|
||||
* - **SubscriptionHelperService**: For GraphQL subscriptions with automatic cleanup
|
||||
*
|
||||
* This service provides the underlying implementation for:
|
||||
* - **Polling subscriptions**: Execute a callback at regular intervals
|
||||
* - **Event-based subscriptions**: Set up event listeners or watchers that persist until stopped
|
||||
*
|
||||
* @internal
|
||||
*/
|
||||
@Injectable()
|
||||
export class SubscriptionManagerService implements OnModuleDestroy {
|
||||
private readonly logger = new Logger(SubscriptionManagerService.name);
|
||||
private readonly activeSubscriptions = new Map<
|
||||
string,
|
||||
{ isPolling: boolean; config?: SubscriptionConfig }
|
||||
>();
|
||||
|
||||
constructor(private readonly schedulerRegistry: SchedulerRegistry) {}
|
||||
|
||||
async onModuleDestroy() {
|
||||
await this.stopAll();
|
||||
}
|
||||
|
||||
/**
|
||||
* Start a managed subscription (polling or event-based).
|
||||
*
|
||||
* @param config - The subscription configuration
|
||||
* @throws Will throw an error if the onStart callback fails
|
||||
*/
|
||||
async startSubscription(config: SubscriptionConfig): Promise<void> {
|
||||
const { name, intervalMs, callback, onStart } = config;
|
||||
|
||||
// Clean up any existing subscription with the same name
|
||||
await this.stopSubscription(name);
|
||||
|
||||
// Initialize subscription state with config
|
||||
this.activeSubscriptions.set(name, { isPolling: false, config });
|
||||
|
||||
// Call onStart callback if provided
|
||||
if (onStart) {
|
||||
try {
|
||||
await onStart();
|
||||
this.logger.debug(`Called onStart for '${name}'`);
|
||||
} catch (error) {
|
||||
this.logger.error(`Error in onStart for '${name}'`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
// If intervalMs is null, this is a continuous/event-based subscription
|
||||
if (intervalMs === null || intervalMs === undefined) {
|
||||
this.logger.debug(`Started continuous subscription for '${name}' (no polling)`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Create the polling function with guard against overlapping executions
|
||||
const pollFunction = async () => {
|
||||
const subscription = this.activeSubscriptions.get(name);
|
||||
if (!subscription || subscription.isPolling) {
|
||||
return;
|
||||
}
|
||||
|
||||
subscription.isPolling = true;
|
||||
try {
|
||||
await callback();
|
||||
} catch (error) {
|
||||
this.logger.error(`Error in subscription callback '${name}'`, error);
|
||||
} finally {
|
||||
if (subscription) {
|
||||
subscription.isPolling = false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Create and register the interval
|
||||
const interval = setInterval(pollFunction, intervalMs);
|
||||
this.schedulerRegistry.addInterval(name, interval);
|
||||
|
||||
this.logger.debug(`Started polling for '${name}' every ${intervalMs}ms`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop a managed subscription.
|
||||
*
|
||||
* This will:
|
||||
* 1. Stop any active polling interval
|
||||
* 2. Call the onStop callback if provided
|
||||
* 3. Clean up internal state
|
||||
*
|
||||
* @param name - The unique identifier of the subscription to stop
|
||||
*/
|
||||
async stopSubscription(name: string): Promise<void> {
|
||||
// Get the config before deleting
|
||||
const subscription = this.activeSubscriptions.get(name);
|
||||
const onStop = subscription?.config?.onStop;
|
||||
|
||||
try {
|
||||
if (this.schedulerRegistry.doesExist('interval', name)) {
|
||||
this.schedulerRegistry.deleteInterval(name);
|
||||
this.logger.debug(`Stopped polling interval for '${name}'`);
|
||||
}
|
||||
} catch (error) {
|
||||
// Interval doesn't exist, which is fine
|
||||
}
|
||||
|
||||
// Call onStop callback if provided
|
||||
if (onStop) {
|
||||
try {
|
||||
await onStop();
|
||||
this.logger.debug(`Called onStop for '${name}'`);
|
||||
} catch (error) {
|
||||
this.logger.error(`Error in onStop for '${name}'`, error);
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up subscription state
|
||||
this.activeSubscriptions.delete(name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop all active subscriptions.
|
||||
*
|
||||
* This is automatically called when the module is destroyed.
|
||||
*/
|
||||
async stopAll(): Promise<void> {
|
||||
// Get all active subscription keys (both polling and event-based)
|
||||
const activeKeys = Array.from(this.activeSubscriptions.keys());
|
||||
|
||||
// Stop each subscription and await cleanup
|
||||
await Promise.all(activeKeys.map((key) => this.stopSubscription(key)));
|
||||
|
||||
// Clear the map after all subscriptions are stopped
|
||||
this.activeSubscriptions.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a subscription is active.
|
||||
*
|
||||
* @param name - The unique identifier of the subscription
|
||||
* @returns true if the subscription exists (either polling or event-based)
|
||||
*/
|
||||
isSubscriptionActive(name: string): boolean {
|
||||
// Check both for polling intervals and event-based subscriptions
|
||||
return this.activeSubscriptions.has(name) || this.schedulerRegistry.doesExist('interval', name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the total number of active subscriptions.
|
||||
*
|
||||
* @returns The count of all active subscriptions (polling and event-based)
|
||||
*/
|
||||
getActiveSubscriptionCount(): number {
|
||||
return this.activeSubscriptions.size;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a list of all active subscription names.
|
||||
*
|
||||
* @returns Array of subscription identifiers
|
||||
*/
|
||||
getActiveSubscriptionNames(): string[] {
|
||||
return Array.from(this.activeSubscriptions.keys());
|
||||
}
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
import { Injectable, Logger, OnModuleDestroy } from '@nestjs/common';
|
||||
import { SchedulerRegistry } from '@nestjs/schedule';
|
||||
|
||||
export interface PollingConfig {
|
||||
name: string;
|
||||
intervalMs: number;
|
||||
callback: () => Promise<void>;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class SubscriptionPollingService implements OnModuleDestroy {
|
||||
private readonly logger = new Logger(SubscriptionPollingService.name);
|
||||
private readonly activePollers = new Map<string, { isPolling: boolean }>();
|
||||
|
||||
constructor(private readonly schedulerRegistry: SchedulerRegistry) {}
|
||||
|
||||
onModuleDestroy() {
|
||||
this.stopAll();
|
||||
}
|
||||
|
||||
/**
|
||||
* Start polling for a specific subscription topic
|
||||
*/
|
||||
startPolling(config: PollingConfig): void {
|
||||
const { name, intervalMs, callback } = config;
|
||||
|
||||
// Clean up any existing interval
|
||||
this.stopPolling(name);
|
||||
|
||||
// Initialize polling state
|
||||
this.activePollers.set(name, { isPolling: false });
|
||||
|
||||
// Create the polling function with guard against overlapping executions
|
||||
const pollFunction = async () => {
|
||||
const poller = this.activePollers.get(name);
|
||||
if (!poller || poller.isPolling) {
|
||||
return;
|
||||
}
|
||||
|
||||
poller.isPolling = true;
|
||||
try {
|
||||
await callback();
|
||||
} catch (error) {
|
||||
this.logger.error(`Error in polling task '${name}'`, error);
|
||||
} finally {
|
||||
if (poller) {
|
||||
poller.isPolling = false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Create and register the interval
|
||||
const interval = setInterval(pollFunction, intervalMs);
|
||||
this.schedulerRegistry.addInterval(name, interval);
|
||||
|
||||
this.logger.debug(`Started polling for '${name}' every ${intervalMs}ms`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop polling for a specific subscription topic
|
||||
*/
|
||||
stopPolling(name: string): void {
|
||||
try {
|
||||
if (this.schedulerRegistry.doesExist('interval', name)) {
|
||||
this.schedulerRegistry.deleteInterval(name);
|
||||
this.logger.debug(`Stopped polling for '${name}'`);
|
||||
}
|
||||
} catch (error) {
|
||||
// Interval doesn't exist, which is fine
|
||||
}
|
||||
|
||||
// Clean up polling state
|
||||
this.activePollers.delete(name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop all active polling tasks
|
||||
*/
|
||||
stopAll(): void {
|
||||
const intervals = this.schedulerRegistry.getIntervals();
|
||||
intervals.forEach((key) => this.stopPolling(key));
|
||||
this.activePollers.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if polling is active for a given name
|
||||
*/
|
||||
isPolling(name: string): boolean {
|
||||
return this.schedulerRegistry.doesExist('interval', name);
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,44 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import { SubscriptionPollingService } from '@app/unraid-api/graph/services/subscription-polling.service.js';
|
||||
import { SubscriptionManagerService } from '@app/unraid-api/graph/services/subscription-manager.service.js';
|
||||
|
||||
/**
|
||||
* Service for managing subscriptions with automatic reference counting.
|
||||
*
|
||||
* This service tracks the number of active subscribers for each topic and automatically
|
||||
* starts/stops the underlying subscription based on subscriber count.
|
||||
*
|
||||
* **When to use this service:**
|
||||
* - When you have multiple GraphQL subscriptions that share the same data source
|
||||
* - When you need to start a resource (polling, file watcher, etc.) only when there are active subscribers
|
||||
* - When you need automatic cleanup when the last subscriber disconnects
|
||||
*
|
||||
* @example
|
||||
* // Register a polling subscription
|
||||
* subscriptionTracker.registerTopic(
|
||||
* 'metrics-update',
|
||||
* async () => {
|
||||
* const metrics = await fetchMetrics();
|
||||
* pubsub.publish('metrics-update', { metrics });
|
||||
* },
|
||||
* 5000 // Poll every 5 seconds
|
||||
* );
|
||||
*
|
||||
* @example
|
||||
* // Register an event-based subscription (e.g., file watching)
|
||||
* subscriptionTracker.registerTopic(
|
||||
* 'log-file-updates',
|
||||
* () => startFileWatcher('/var/log/app.log'), // onStart
|
||||
* () => stopFileWatcher('/var/log/app.log') // onStop
|
||||
* );
|
||||
*/
|
||||
@Injectable()
|
||||
export class SubscriptionTrackerService {
|
||||
private readonly logger = new Logger(SubscriptionTrackerService.name);
|
||||
private subscriberCounts = new Map<string, number>();
|
||||
private topicHandlers = new Map<string, { onStart: () => void; onStop: () => void }>();
|
||||
|
||||
constructor(private readonly pollingService: SubscriptionPollingService) {}
|
||||
constructor(private readonly subscriptionManager: SubscriptionManagerService) {}
|
||||
|
||||
/**
|
||||
* Register a topic with optional polling support
|
||||
@@ -29,8 +59,8 @@ export class SubscriptionTrackerService {
|
||||
callback: async () => callbackOrOnStart(),
|
||||
};
|
||||
this.topicHandlers.set(topic, {
|
||||
onStart: () => this.pollingService.startPolling(pollingConfig),
|
||||
onStop: () => this.pollingService.stopPolling(topic),
|
||||
onStart: () => this.subscriptionManager.startSubscription(pollingConfig),
|
||||
onStop: () => this.subscriptionManager.stopSubscription(topic),
|
||||
});
|
||||
} else {
|
||||
// Legacy API: onStart and onStop handlers
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { CacheModule } from '@nestjs/cache-manager';
|
||||
import { Test } from '@nestjs/testing';
|
||||
|
||||
import { CANONICAL_INTERNAL_CLIENT_TOKEN } from '@unraid/shared';
|
||||
@@ -60,7 +61,7 @@ vi.mock('execa', () => ({
|
||||
describe('RestModule Integration', () => {
|
||||
it('should compile with RestService having access to ApiReportService', async () => {
|
||||
const module = await Test.createTestingModule({
|
||||
imports: [RestModule],
|
||||
imports: [CacheModule.register({ isGlobal: true }), RestModule],
|
||||
})
|
||||
// Override services that have complex dependencies for testing
|
||||
.overrideProvider(CANONICAL_INTERNAL_CLIENT_TOKEN)
|
||||
|
||||
487
api/src/unraid-api/rest/rest.controller.test.ts
Normal file
487
api/src/unraid-api/rest/rest.controller.test.ts
Normal file
@@ -0,0 +1,487 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { ConfigService } from '@nestjs/config';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import type { FastifyReply, FastifyRequest } from '@app/unraid-api/types/fastify.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
|
||||
import { OidcRequestHandler } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-request-handler.util.js';
|
||||
import { RestController } from '@app/unraid-api/rest/rest.controller.js';
|
||||
import { RestService } from '@app/unraid-api/rest/rest.service.js';
|
||||
|
||||
describe('RestController', () => {
|
||||
let controller: RestController;
|
||||
let oidcService: OidcService;
|
||||
let oidcConfig: OidcConfigPersistence;
|
||||
let mockReply: Partial<FastifyReply>;
|
||||
|
||||
// Helper function to create a mock request with the desired hostname
|
||||
const createMockRequest = (hostname?: string, headers: Record<string, any> = {}): FastifyRequest => {
|
||||
return {
|
||||
headers,
|
||||
hostname,
|
||||
url: '/test',
|
||||
protocol: 'https',
|
||||
} as FastifyRequest;
|
||||
};
|
||||
|
||||
beforeEach(async () => {
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
controllers: [RestController],
|
||||
providers: [
|
||||
{
|
||||
provide: RestService,
|
||||
useValue: {
|
||||
getLogs: vi.fn(),
|
||||
getCustomizationStream: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcService,
|
||||
useValue: {
|
||||
getAuthorizationUrl: vi.fn(),
|
||||
handleCallback: vi.fn(),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: OidcConfigPersistence,
|
||||
useValue: {
|
||||
getConfig: vi.fn().mockResolvedValue({
|
||||
defaultAllowedOrigins: [],
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: ConfigService,
|
||||
useValue: {
|
||||
get: vi.fn(),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
controller = module.get<RestController>(RestController);
|
||||
oidcService = module.get<OidcService>(OidcService);
|
||||
oidcConfig = module.get<OidcConfigPersistence>(OidcConfigPersistence);
|
||||
|
||||
mockReply = {
|
||||
status: vi.fn().mockReturnThis(),
|
||||
header: vi.fn().mockReturnThis(),
|
||||
send: vi.fn().mockReturnThis(),
|
||||
type: vi.fn().mockReturnThis(),
|
||||
};
|
||||
});
|
||||
|
||||
describe('oidcAuthorize', () => {
|
||||
describe('redirect URI validation', () => {
|
||||
beforeEach(() => {
|
||||
// Mock OidcRequestHandler.handleAuthorize to return a valid auth URL
|
||||
vi.spyOn(OidcRequestHandler, 'handleAuthorize').mockResolvedValue(
|
||||
'https://provider.com/authorize?client_id=test&redirect_uri=...'
|
||||
);
|
||||
});
|
||||
|
||||
it('should accept redirect_uri with same hostname but different port', async () => {
|
||||
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(302);
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalledWith(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
oidcService,
|
||||
expect.any(Logger)
|
||||
);
|
||||
});
|
||||
|
||||
it('should accept redirect_uri with same hostname and standard HTTPS port', async () => {
|
||||
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://unraid.mytailnet.ts.net/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(302);
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should accept redirect_uri with same hostname and explicit port 443', async () => {
|
||||
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://unraid.mytailnet.ts.net:443/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(302);
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should reject redirect_uri with different hostname', async () => {
|
||||
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://evil.com/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(400);
|
||||
expect(mockReply.send).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
'Invalid redirect_uri: https://evil.com/graphql/api/auth/oidc/callback'
|
||||
)
|
||||
);
|
||||
expect(OidcRequestHandler.handleAuthorize).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should reject redirect_uri with subdomain difference', async () => {
|
||||
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://evil.unraid.mytailnet.ts.net/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(400);
|
||||
expect(mockReply.send).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
'Invalid redirect_uri: https://evil.unraid.mytailnet.ts.net/graphql/api/auth/oidc/callback'
|
||||
)
|
||||
);
|
||||
expect(OidcRequestHandler.handleAuthorize).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle hostname from host header when hostname is not available', async () => {
|
||||
const mockRequest = createMockRequest(undefined, {
|
||||
host: 'unraid.mytailnet.ts.net:8080',
|
||||
});
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(302);
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should reject malformed redirect_uri', async () => {
|
||||
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'not-a-valid-url',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(400);
|
||||
expect(mockReply.send).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Invalid redirect_uri: not-a-valid-url')
|
||||
);
|
||||
expect(OidcRequestHandler.handleAuthorize).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle case-insensitive hostname comparison', async () => {
|
||||
const mockRequest = createMockRequest('UnRaid.MyTailnet.TS.net');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(302);
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should preserve exact redirect_uri including custom port in call to handleAuthorize', async () => {
|
||||
const mockRequest = createMockRequest('unraid.mytailnet.ts.net');
|
||||
const customRedirectUri =
|
||||
'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback';
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
customRedirectUri,
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
// Verify the exact redirect URI with port is passed through
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalledWith(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
customRedirectUri, // Should be exactly as provided, with :1443
|
||||
mockRequest,
|
||||
oidcService,
|
||||
expect.any(Logger)
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow localhost with different ports', async () => {
|
||||
const mockRequest = createMockRequest('localhost');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'http://localhost:3000/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(302);
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalledWith(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'http://localhost:3000/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
oidcService,
|
||||
expect.any(Logger)
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow IP addresses with different ports', async () => {
|
||||
const mockRequest = createMockRequest('192.168.1.100');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'http://192.168.1.100:8080/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(302);
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should accept redirect_uri with different hostname if in allowed origins', async () => {
|
||||
const mockRequest = createMockRequest('devgen-dev1.local');
|
||||
|
||||
// Mock the config to include the allowed origin
|
||||
vi.mocked(oidcConfig.getConfig).mockResolvedValueOnce({
|
||||
defaultAllowedOrigins: ['https://devgen-bad-dev1.local'],
|
||||
} as any);
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://devgen-bad-dev1.local/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(302);
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalledWith(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://devgen-bad-dev1.local/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
oidcService,
|
||||
expect.any(Logger)
|
||||
);
|
||||
});
|
||||
|
||||
describe('integration with centralized validator', () => {
|
||||
it('should use the same validation logic as validateRedirectUri function', async () => {
|
||||
const testCases = [
|
||||
{
|
||||
name: 'accepts HTTPS upgrade from allowed origins',
|
||||
requestHost: 'devgen-dev1.local',
|
||||
redirectUri: 'https://allowed-host.local/graphql/api/auth/oidc/callback',
|
||||
allowedOrigins: ['http://allowed-host.local'],
|
||||
expectedStatus: 302,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
{
|
||||
name: 'rejects hostname not in allowed origins',
|
||||
requestHost: 'devgen-dev1.local',
|
||||
redirectUri: 'https://evil.com/graphql/api/auth/oidc/callback',
|
||||
allowedOrigins: ['https://good-host.local'],
|
||||
expectedStatus: 400,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
{
|
||||
name: 'accepts multiple allowed origins',
|
||||
requestHost: 'devgen-dev1.local',
|
||||
redirectUri: 'https://second.local/graphql/api/auth/oidc/callback',
|
||||
allowedOrigins: [
|
||||
'https://first.local',
|
||||
'https://second.local',
|
||||
'https://third.local',
|
||||
],
|
||||
expectedStatus: 302,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
{
|
||||
name: 'respects protocol and hostname from headers',
|
||||
requestHost: undefined,
|
||||
headers: {
|
||||
'x-forwarded-proto': 'https',
|
||||
'x-forwarded-host': 'proxy.local',
|
||||
},
|
||||
redirectUri: 'https://proxy.local/graphql/api/auth/oidc/callback',
|
||||
allowedOrigins: [],
|
||||
expectedStatus: 302,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
];
|
||||
|
||||
for (const testCase of testCases) {
|
||||
// Reset mocks for each test case
|
||||
vi.clearAllMocks();
|
||||
|
||||
const mockRequest = createMockRequest(
|
||||
testCase.requestHost,
|
||||
testCase.headers || {}
|
||||
);
|
||||
|
||||
vi.mocked(oidcConfig.getConfig).mockResolvedValueOnce({
|
||||
defaultAllowedOrigins: testCase.allowedOrigins,
|
||||
} as any);
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
testCase.redirectUri,
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(testCase.expectedStatus);
|
||||
|
||||
if (testCase.shouldSucceed) {
|
||||
expect(OidcRequestHandler.handleAuthorize).toHaveBeenCalled();
|
||||
} else {
|
||||
expect(mockReply.send).toHaveBeenCalledWith(
|
||||
expect.stringContaining(testCase.redirectUri)
|
||||
);
|
||||
expect(OidcRequestHandler.handleAuthorize).not.toHaveBeenCalled();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
it('should handle edge cases consistently with centralized validator', async () => {
|
||||
// Test with empty allowed origins
|
||||
vi.mocked(oidcConfig.getConfig).mockResolvedValueOnce({
|
||||
defaultAllowedOrigins: [],
|
||||
} as any);
|
||||
|
||||
const mockRequest = createMockRequest('host.local');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://different.local/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(400);
|
||||
expect(mockReply.send).toHaveBeenCalledWith(
|
||||
expect.stringContaining('https://different.local/graphql/api/auth/oidc/callback')
|
||||
);
|
||||
});
|
||||
|
||||
it('should validate that error messages guide users to settings', async () => {
|
||||
vi.mocked(oidcConfig.getConfig).mockResolvedValueOnce({
|
||||
defaultAllowedOrigins: [],
|
||||
} as any);
|
||||
|
||||
const mockRequest = createMockRequest('host.local');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
'https://different.local/graphql/api/auth/oidc/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.send).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Settings → Management Access → Allowed Redirect URIs')
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('parameter validation', () => {
|
||||
it('should return 400 if redirect_uri is missing', async () => {
|
||||
const mockRequest = createMockRequest('unraid.local');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
'test-state',
|
||||
undefined as any,
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(400);
|
||||
// The controller catches validation errors and returns a generic message
|
||||
expect(mockReply.send).toHaveBeenCalledWith('Invalid provider or configuration');
|
||||
});
|
||||
|
||||
it('should return 400 if providerId is missing', async () => {
|
||||
const mockRequest = createMockRequest('unraid.local');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
undefined as any,
|
||||
'test-state',
|
||||
'https://unraid.local/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(400);
|
||||
expect(mockReply.send).toHaveBeenCalledWith('Invalid provider or configuration');
|
||||
});
|
||||
|
||||
it('should return 400 if state is missing', async () => {
|
||||
const mockRequest = createMockRequest('unraid.local');
|
||||
|
||||
await controller.oidcAuthorize(
|
||||
'test-provider',
|
||||
undefined as any,
|
||||
'https://unraid.local/callback',
|
||||
mockRequest,
|
||||
mockReply as FastifyReply
|
||||
);
|
||||
|
||||
expect(mockReply.status).toHaveBeenCalledWith(400);
|
||||
expect(mockReply.send).toHaveBeenCalledWith('Invalid provider or configuration');
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -2,18 +2,25 @@ import { Controller, Get, Logger, Param, Query, Req, Res, UnauthorizedException
|
||||
|
||||
import { AuthAction, Resource } from '@unraid/shared/graphql.model.js';
|
||||
import { UsePermissions } from '@unraid/shared/use-permissions.directive.js';
|
||||
import escapeHtml from 'escape-html';
|
||||
|
||||
import type { FastifyReply, FastifyRequest } from '@app/unraid-api/types/fastify.js';
|
||||
import { Public } from '@app/unraid-api/auth/public.decorator.js';
|
||||
import { OidcAuthService } from '@app/unraid-api/graph/resolvers/sso/oidc-auth.service.js';
|
||||
import { OidcConfigPersistence } from '@app/unraid-api/graph/resolvers/sso/core/oidc-config.service.js';
|
||||
import { OidcService } from '@app/unraid-api/graph/resolvers/sso/core/oidc.service.js';
|
||||
import { OidcRequestHandler } from '@app/unraid-api/graph/resolvers/sso/utils/oidc-request-handler.util.js';
|
||||
import { RestService } from '@app/unraid-api/rest/rest.service.js';
|
||||
import { validateRedirectUri } from '@app/unraid-api/utils/redirect-uri-validator.js';
|
||||
|
||||
@Controller()
|
||||
export class RestController {
|
||||
protected logger = new Logger(RestController.name);
|
||||
protected oidcLogger = new Logger('OidcRestController');
|
||||
|
||||
constructor(
|
||||
private readonly restService: RestService,
|
||||
private readonly oidcAuthService: OidcAuthService
|
||||
private readonly oidcService: OidcService,
|
||||
private readonly oidcConfig: OidcConfigPersistence
|
||||
) {}
|
||||
|
||||
@Get('/')
|
||||
@@ -65,38 +72,69 @@ export class RestController {
|
||||
async oidcAuthorize(
|
||||
@Param('providerId') providerId: string,
|
||||
@Query('state') state: string,
|
||||
@Query('redirect_uri') redirectUri: string,
|
||||
@Req() req: FastifyRequest,
|
||||
@Res() res: FastifyReply
|
||||
) {
|
||||
try {
|
||||
if (!state) {
|
||||
return res.status(400).send('State parameter is required');
|
||||
// Validate required parameters
|
||||
const params = OidcRequestHandler.validateAuthorizeParams(providerId, state, redirectUri);
|
||||
|
||||
// IMPORTANT: Use the redirect_uri from query params directly
|
||||
// Do NOT parse headers or try to build/validate against headers
|
||||
// The frontend provides the complete redirect_uri
|
||||
if (!params.redirectUri) {
|
||||
return res.status(400).send('redirect_uri parameter is required');
|
||||
}
|
||||
|
||||
// Get the host and protocol from the request headers
|
||||
const protocol = (req.headers['x-forwarded-proto'] as string) || req.protocol || 'http';
|
||||
const host = (req.headers['x-forwarded-host'] as string) || req.headers.host || undefined;
|
||||
const requestInfo = host ? `${protocol}://${host}` : undefined;
|
||||
// Security validation: validate redirect_uri with support for allowed origins
|
||||
const protocol = (req.headers['x-forwarded-proto'] as string) || 'http';
|
||||
const host = (req.headers['x-forwarded-host'] as string) || req.headers.host || req.hostname;
|
||||
|
||||
const authUrl = await this.oidcAuthService.getAuthorizationUrl(
|
||||
providerId,
|
||||
state,
|
||||
requestInfo
|
||||
// Get allowed origins from OIDC config
|
||||
const config = await this.oidcConfig.getConfig();
|
||||
const allowedOrigins = config?.defaultAllowedOrigins;
|
||||
|
||||
// Validate the redirect URI using the centralized validator
|
||||
const validation = validateRedirectUri(
|
||||
params.redirectUri,
|
||||
protocol,
|
||||
host,
|
||||
this.oidcLogger,
|
||||
allowedOrigins
|
||||
);
|
||||
|
||||
if (!validation.isValid) {
|
||||
this.oidcLogger.warn(`Invalid redirect_uri: ${validation.reason}`);
|
||||
return res
|
||||
.status(400)
|
||||
.send(
|
||||
`Invalid redirect_uri: ${escapeHtml(params.redirectUri)}. ${escapeHtml(validation.reason || 'Unknown validation error')}. Please add this callback URI to Settings → Management Access → Allowed Redirect URIs`
|
||||
);
|
||||
}
|
||||
|
||||
// Handle authorization flow using the exact redirect_uri from query params
|
||||
const authUrl = await OidcRequestHandler.handleAuthorize(
|
||||
params.providerId,
|
||||
params.state,
|
||||
params.redirectUri,
|
||||
req,
|
||||
this.oidcService,
|
||||
this.oidcLogger
|
||||
);
|
||||
this.logger.log(`Redirecting to OIDC provider: ${authUrl}`);
|
||||
|
||||
// Manually set redirect headers for better proxy compatibility
|
||||
res.status(302);
|
||||
res.header('Location', authUrl);
|
||||
return res.send();
|
||||
} catch (error: unknown) {
|
||||
this.logger.error(`OIDC authorize error for provider ${providerId}:`, error);
|
||||
this.oidcLogger.error(`OIDC authorize error for provider ${providerId}:`, error);
|
||||
|
||||
// Log more details about the error
|
||||
if (error instanceof Error) {
|
||||
this.logger.error(`Error message: ${error.message}`);
|
||||
this.oidcLogger.error(`Error message: ${error.message}`);
|
||||
if (error.stack) {
|
||||
this.logger.debug(`Stack trace: ${error.stack}`);
|
||||
this.oidcLogger.debug(`Stack trace: ${error.stack}`);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,32 +155,20 @@ export class RestController {
|
||||
@Res() res: FastifyReply
|
||||
) {
|
||||
try {
|
||||
if (!code || !state) {
|
||||
return res.status(400).send('Missing required parameters');
|
||||
}
|
||||
// Validate required parameters
|
||||
const params = OidcRequestHandler.validateCallbackParams(code, state);
|
||||
|
||||
// Extract provider ID from state
|
||||
const { providerId } = this.oidcAuthService.extractProviderFromState(state);
|
||||
|
||||
// Get the full callback URL as received, respecting reverse proxy headers
|
||||
const protocol = (req.headers['x-forwarded-proto'] as string) || req.protocol || 'http';
|
||||
const host =
|
||||
(req.headers['x-forwarded-host'] as string) || req.headers.host || 'localhost:3000';
|
||||
const fullUrl = `${protocol}://${host}${req.url}`;
|
||||
const requestInfo = `${protocol}://${host}`;
|
||||
|
||||
this.logger.debug(`Full callback URL from request: ${fullUrl}`);
|
||||
|
||||
const paddedToken = await this.oidcAuthService.handleCallback(
|
||||
providerId,
|
||||
code,
|
||||
state,
|
||||
requestInfo,
|
||||
fullUrl
|
||||
// Handle callback flow
|
||||
const result = await OidcRequestHandler.handleCallback(
|
||||
params.code,
|
||||
params.state,
|
||||
req,
|
||||
this.oidcService,
|
||||
this.oidcLogger
|
||||
);
|
||||
|
||||
// Redirect to login page with the token in hash to keep it out of server logs
|
||||
const loginUrl = `/login#token=${encodeURIComponent(paddedToken)}`;
|
||||
const loginUrl = `/login#token=${encodeURIComponent(result.paddedToken)}`;
|
||||
|
||||
// Manually set redirect headers for better proxy compatibility
|
||||
res.header('Cache-Control', 'no-store');
|
||||
@@ -152,16 +178,16 @@ export class RestController {
|
||||
res.header('Location', loginUrl);
|
||||
return res.send();
|
||||
} catch (error: unknown) {
|
||||
this.logger.error(`OIDC callback error: ${error}`);
|
||||
this.oidcLogger.error(`OIDC callback error: ${error}`);
|
||||
|
||||
// Use a generic error message to avoid leaking sensitive information
|
||||
const errorMessage = 'Authentication failed';
|
||||
|
||||
// Log detailed error for debugging but don't expose to user
|
||||
if (error instanceof UnauthorizedException) {
|
||||
this.logger.debug(`UnauthorizedException occurred during OIDC callback`);
|
||||
this.oidcLogger.debug(`UnauthorizedException occurred during OIDC callback`);
|
||||
} else if (error instanceof Error) {
|
||||
this.logger.debug(`Error during OIDC callback: ${error.message}`);
|
||||
this.oidcLogger.debug(`Error during OIDC callback: ${error.message}`);
|
||||
}
|
||||
|
||||
const loginUrl = `/login#error=${encodeURIComponent(errorMessage)}`;
|
||||
|
||||
406
api/src/unraid-api/utils/error-extractor.util.test.ts
Normal file
406
api/src/unraid-api/utils/error-extractor.util.test.ts
Normal file
@@ -0,0 +1,406 @@
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { ErrorExtractor } from '@app/unraid-api/utils/error-extractor.util.js';
|
||||
|
||||
describe('ErrorExtractor', () => {
|
||||
describe('extract', () => {
|
||||
it('should handle null and undefined errors', () => {
|
||||
const nullResult = ErrorExtractor.extract(null);
|
||||
expect(nullResult.message).toBe('Unknown error');
|
||||
expect(nullResult.type).toBe('Unknown');
|
||||
|
||||
const undefinedResult = ErrorExtractor.extract(undefined);
|
||||
expect(undefinedResult.message).toBe('Unknown error');
|
||||
expect(undefinedResult.type).toBe('Unknown');
|
||||
});
|
||||
|
||||
it('should extract basic Error properties', () => {
|
||||
const error = new Error('Test error message');
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.message).toBe('Test error message');
|
||||
expect(result.type).toBe('Error');
|
||||
expect(result.stack).toBeDefined();
|
||||
});
|
||||
|
||||
it('should extract custom error types', () => {
|
||||
class CustomError extends Error {}
|
||||
const error = new CustomError('Custom error');
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.type).toBe('CustomError');
|
||||
});
|
||||
|
||||
it('should extract error code', () => {
|
||||
const error: any = new Error('Error with code');
|
||||
error.code = 'ERR_CODE';
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.code).toBe('ERR_CODE');
|
||||
});
|
||||
|
||||
it('should extract HTTP response details', () => {
|
||||
const error: any = new Error('HTTP error');
|
||||
error.response = {
|
||||
status: 404,
|
||||
statusText: 'Not Found',
|
||||
body: { error: 'Resource not found' },
|
||||
headers: { 'content-type': 'application/json' },
|
||||
};
|
||||
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.status).toBe(404);
|
||||
expect(result.statusText).toBe('Not Found');
|
||||
expect(result.responseBody).toEqual({ error: 'Resource not found' });
|
||||
expect(result.responseHeaders).toEqual({ 'content-type': 'application/json' });
|
||||
});
|
||||
|
||||
it('should truncate long response body strings', () => {
|
||||
const error: any = new Error('Error with long body');
|
||||
const longString = 'x'.repeat(2000);
|
||||
error.body = longString;
|
||||
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.responseBody).toBe('x'.repeat(1000) + '... (truncated)');
|
||||
});
|
||||
|
||||
it('should extract OAuth error details', () => {
|
||||
const error: any = new Error('OAuth error');
|
||||
error.error = 'invalid_grant';
|
||||
error.error_description = 'The provided authorization code is invalid';
|
||||
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.oauthError).toBe('invalid_grant');
|
||||
expect(result.oauthErrorDescription).toBe('The provided authorization code is invalid');
|
||||
});
|
||||
|
||||
it('should extract cause chain', () => {
|
||||
const rootCause = new Error('Root cause');
|
||||
const middleCause: any = new Error('Middle cause');
|
||||
middleCause.cause = rootCause;
|
||||
const topError: any = new Error('Top error');
|
||||
topError.cause = middleCause;
|
||||
|
||||
const result = ErrorExtractor.extract(topError);
|
||||
|
||||
expect(result.causeChain).toHaveLength(2);
|
||||
expect(result.causeChain![0]).toEqual({
|
||||
depth: 1,
|
||||
type: 'Error',
|
||||
message: 'Middle cause',
|
||||
});
|
||||
expect(result.causeChain![1]).toEqual({
|
||||
depth: 2,
|
||||
type: 'Error',
|
||||
message: 'Root cause',
|
||||
});
|
||||
});
|
||||
|
||||
it('should limit cause chain depth', () => {
|
||||
// Create a deep nested error chain
|
||||
let deepestError: any = new Error('Level 10');
|
||||
|
||||
for (let i = 9; i >= 0; i--) {
|
||||
const error: any = new Error(`Level ${i}`);
|
||||
error.cause = deepestError;
|
||||
deepestError = error;
|
||||
}
|
||||
|
||||
const topError: any = new Error('Top');
|
||||
topError.cause = deepestError;
|
||||
|
||||
const result = ErrorExtractor.extract(topError);
|
||||
|
||||
expect(result.causeChain).toHaveLength(5); // MAX_CAUSE_DEPTH
|
||||
});
|
||||
|
||||
it('should extract cause with code', () => {
|
||||
const cause: any = new Error('Cause with code');
|
||||
cause.code = 'ECONNREFUSED';
|
||||
const error: any = new Error('Main error');
|
||||
error.cause = cause;
|
||||
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.causeChain![0].code).toBe('ECONNREFUSED');
|
||||
});
|
||||
|
||||
it('should extract additional properties', () => {
|
||||
const error: any = new Error('Error with extras');
|
||||
error.customProp1 = 'value1';
|
||||
error.customProp2 = 123;
|
||||
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.additionalProperties).toEqual({
|
||||
customProp1: 'value1',
|
||||
customProp2: 123,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle string errors', () => {
|
||||
const result = ErrorExtractor.extract('String error message');
|
||||
|
||||
expect(result.message).toBe('String error message');
|
||||
expect(result.type).toBe('String');
|
||||
});
|
||||
|
||||
it('should handle object errors', () => {
|
||||
const error = { code: 'ERROR', message: 'Object error' };
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.message).toBe(JSON.stringify(error));
|
||||
expect(result.type).toBe('Object');
|
||||
});
|
||||
|
||||
it('should handle primitive errors', () => {
|
||||
const result = ErrorExtractor.extract(42);
|
||||
|
||||
expect(result.message).toBe('42');
|
||||
expect(result.type).toBe('number');
|
||||
});
|
||||
|
||||
it('should handle openid-client error structure', () => {
|
||||
const error: any = new Error('unexpected response content-type');
|
||||
error.code = 'OAUTH_RESPONSE_IS_NOT_JSON';
|
||||
error.response = {
|
||||
status: 200,
|
||||
headers: { 'content-type': 'text/html' },
|
||||
body: '<html>Error page</html>',
|
||||
};
|
||||
|
||||
const result = ErrorExtractor.extract(error);
|
||||
|
||||
expect(result.code).toBe('OAUTH_RESPONSE_IS_NOT_JSON');
|
||||
expect(result.responseHeaders!['content-type']).toBe('text/html');
|
||||
expect(result.responseBody).toContain('<html>');
|
||||
});
|
||||
});
|
||||
|
||||
describe('isOAuthResponseError', () => {
|
||||
it('should identify OAuth response errors by code', () => {
|
||||
const extracted = {
|
||||
message: 'Some error',
|
||||
type: 'Error',
|
||||
code: 'OAUTH_RESPONSE_IS_NOT_JSON',
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isOAuthResponseError(extracted)).toBe(true);
|
||||
});
|
||||
|
||||
it('should identify OAuth response errors by message', () => {
|
||||
const extracted = {
|
||||
message: 'unexpected response content-type from server',
|
||||
type: 'Error',
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isOAuthResponseError(extracted)).toBe(true);
|
||||
});
|
||||
|
||||
it('should identify parsing errors', () => {
|
||||
const extracted = {
|
||||
message: 'JSON parsing error occurred',
|
||||
type: 'Error',
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isOAuthResponseError(extracted)).toBe(true);
|
||||
});
|
||||
|
||||
it('should not identify non-OAuth errors', () => {
|
||||
const extracted = {
|
||||
message: 'Some other error',
|
||||
type: 'Error',
|
||||
code: 'OTHER_ERROR',
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isOAuthResponseError(extracted)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isJwtClaimError', () => {
|
||||
it('should identify JWT claim errors', () => {
|
||||
const extracted = {
|
||||
message: 'unexpected JWT claim value encountered',
|
||||
type: 'Error',
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isJwtClaimError(extracted)).toBe(true);
|
||||
});
|
||||
|
||||
it('should not identify non-JWT errors', () => {
|
||||
const extracted = {
|
||||
message: 'Some other error',
|
||||
type: 'Error',
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isJwtClaimError(extracted)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isNetworkError', () => {
|
||||
it('should identify network errors by code', () => {
|
||||
const codes = ['ECONNREFUSED', 'ENOTFOUND', 'ETIMEDOUT', 'ECONNRESET'];
|
||||
|
||||
for (const code of codes) {
|
||||
const extracted = {
|
||||
message: 'Error',
|
||||
type: 'Error',
|
||||
code,
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isNetworkError(extracted)).toBe(true);
|
||||
}
|
||||
});
|
||||
|
||||
it('should identify network errors by message', () => {
|
||||
const messages = ['network timeout occurred', 'failed to connect to server'];
|
||||
|
||||
for (const message of messages) {
|
||||
const extracted = {
|
||||
message,
|
||||
type: 'Error',
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isNetworkError(extracted)).toBe(true);
|
||||
}
|
||||
});
|
||||
|
||||
it('should not identify non-network errors', () => {
|
||||
const extracted = {
|
||||
message: 'Invalid credentials',
|
||||
type: 'Error',
|
||||
code: 'AUTH_ERROR',
|
||||
};
|
||||
|
||||
expect(ErrorExtractor.isNetworkError(extracted)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('formatForLogging', () => {
|
||||
it('should log basic error information', () => {
|
||||
const logger = {
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
};
|
||||
|
||||
const extracted = {
|
||||
message: 'Test error',
|
||||
type: 'CustomError',
|
||||
code: 'ERR_TEST',
|
||||
};
|
||||
|
||||
ErrorExtractor.formatForLogging(extracted, logger);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith('Error type: CustomError');
|
||||
expect(logger.error).toHaveBeenCalledWith('Error message: Test error');
|
||||
expect(logger.error).toHaveBeenCalledWith('Error code: ERR_TEST');
|
||||
});
|
||||
|
||||
it('should log HTTP response details', () => {
|
||||
const logger = {
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
};
|
||||
|
||||
const extracted = {
|
||||
message: 'HTTP error',
|
||||
type: 'Error',
|
||||
status: 500,
|
||||
statusText: 'Internal Server Error',
|
||||
responseBody: { error: 'Server error' },
|
||||
responseHeaders: { 'content-type': 'application/json' },
|
||||
};
|
||||
|
||||
ErrorExtractor.formatForLogging(extracted, logger);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith('HTTP Status: 500 Internal Server Error');
|
||||
expect(logger.error).toHaveBeenCalledWith('Response body: %o', { error: 'Server error' });
|
||||
expect(logger.error).toHaveBeenCalledWith('Response Content-Type: application/json');
|
||||
});
|
||||
|
||||
it('should log OAuth error details', () => {
|
||||
const logger = {
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
};
|
||||
|
||||
const extracted = {
|
||||
message: 'OAuth error',
|
||||
type: 'Error',
|
||||
oauthError: 'invalid_client',
|
||||
oauthErrorDescription: 'Client authentication failed',
|
||||
};
|
||||
|
||||
ErrorExtractor.formatForLogging(extracted, logger);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith('OAuth error: invalid_client');
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
'OAuth error description: Client authentication failed'
|
||||
);
|
||||
});
|
||||
|
||||
it('should log cause chain', () => {
|
||||
const logger = {
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
};
|
||||
|
||||
const extracted = {
|
||||
message: 'Top error',
|
||||
type: 'Error',
|
||||
causeChain: [
|
||||
{ depth: 1, type: 'Error', message: 'Cause 1', code: 'CODE1' },
|
||||
{ depth: 2, type: 'Error', message: 'Cause 2' },
|
||||
],
|
||||
};
|
||||
|
||||
ErrorExtractor.formatForLogging(extracted, logger);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith('Error cause chain:');
|
||||
expect(logger.error).toHaveBeenCalledWith(' [Cause 1] Error: Cause 1');
|
||||
expect(logger.error).toHaveBeenCalledWith(' [Cause 1] Code: CODE1');
|
||||
expect(logger.error).toHaveBeenCalledWith(' [Cause 2] Error: Cause 2');
|
||||
});
|
||||
|
||||
it('should log additional properties and stack in debug', () => {
|
||||
const logger = {
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
};
|
||||
|
||||
const extracted = {
|
||||
message: 'Error',
|
||||
type: 'Error',
|
||||
additionalProperties: { custom: 'value' },
|
||||
stack: 'Stack trace here',
|
||||
};
|
||||
|
||||
ErrorExtractor.formatForLogging(extracted, logger);
|
||||
|
||||
expect(logger.debug).toHaveBeenCalledWith('Additional error properties: %o', {
|
||||
custom: 'value',
|
||||
});
|
||||
expect(logger.debug).toHaveBeenCalledWith('Stack trace: Stack trace here');
|
||||
});
|
||||
|
||||
it('should handle case-insensitive Content-Type header', () => {
|
||||
const logger = {
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
};
|
||||
|
||||
const extracted = {
|
||||
message: 'Error',
|
||||
type: 'Error',
|
||||
responseHeaders: { 'Content-Type': 'text/html' },
|
||||
};
|
||||
|
||||
ErrorExtractor.formatForLogging(extracted, logger);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith('Response Content-Type: text/html');
|
||||
});
|
||||
});
|
||||
});
|
||||
277
api/src/unraid-api/utils/error-extractor.util.ts
Normal file
277
api/src/unraid-api/utils/error-extractor.util.ts
Normal file
@@ -0,0 +1,277 @@
|
||||
export interface ExtractedError {
|
||||
message: string;
|
||||
type: string;
|
||||
code?: string;
|
||||
status?: number;
|
||||
statusText?: string;
|
||||
responseBody?: unknown;
|
||||
responseHeaders?: Record<string, string>;
|
||||
oauthError?: string;
|
||||
oauthErrorDescription?: string;
|
||||
causeChain?: Array<{
|
||||
depth: number;
|
||||
type: string;
|
||||
message: string;
|
||||
code?: string;
|
||||
}>;
|
||||
additionalProperties?: Record<string, unknown>;
|
||||
stack?: string;
|
||||
}
|
||||
|
||||
export class ErrorExtractor {
|
||||
private static readonly MAX_CAUSE_DEPTH = 5;
|
||||
private static readonly MAX_BODY_PREVIEW_LENGTH = 1000;
|
||||
|
||||
static extract(error: unknown): ExtractedError {
|
||||
const result: ExtractedError = {
|
||||
message: 'Unknown error',
|
||||
type: 'Unknown',
|
||||
};
|
||||
|
||||
if (!error) {
|
||||
return result;
|
||||
}
|
||||
|
||||
if (error instanceof Error) {
|
||||
result.message = error.message;
|
||||
result.type = error.constructor.name;
|
||||
result.stack = error.stack;
|
||||
|
||||
// Extract error code if present
|
||||
if ('code' in error && error.code) {
|
||||
result.code = String(error.code);
|
||||
}
|
||||
|
||||
// Extract HTTP response details
|
||||
if ('response' in error && error.response) {
|
||||
this.extractResponseDetails(error.response as any, result);
|
||||
}
|
||||
|
||||
// Extract OAuth-specific errors
|
||||
if ('error' in error && error.error) {
|
||||
result.oauthError = String(error.error);
|
||||
}
|
||||
|
||||
if ('error_description' in error && error.error_description) {
|
||||
result.oauthErrorDescription = String(error.error_description);
|
||||
}
|
||||
|
||||
// Extract response body if directly available
|
||||
if ('body' in error && error.body) {
|
||||
result.responseBody = this.formatResponseBody(error.body as any);
|
||||
}
|
||||
|
||||
// Extract cause chain
|
||||
if ('cause' in error && error.cause) {
|
||||
result.causeChain = this.extractCauseChain(error.cause);
|
||||
}
|
||||
|
||||
// Extract additional properties
|
||||
const standardKeys = [
|
||||
'message',
|
||||
'stack',
|
||||
'cause',
|
||||
'code',
|
||||
'response',
|
||||
'body',
|
||||
'error',
|
||||
'error_description',
|
||||
];
|
||||
const additionalKeys = Object.keys(error).filter((key) => !standardKeys.includes(key));
|
||||
|
||||
if (additionalKeys.length > 0) {
|
||||
result.additionalProperties = {};
|
||||
for (const key of additionalKeys) {
|
||||
const value = (error as any)[key];
|
||||
if (value !== undefined && value !== null) {
|
||||
result.additionalProperties[key] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (typeof error === 'string') {
|
||||
result.message = error;
|
||||
result.type = 'String';
|
||||
} else if (typeof error === 'object' && error !== null) {
|
||||
result.message = JSON.stringify(error);
|
||||
result.type = 'Object';
|
||||
} else {
|
||||
result.message = String(error);
|
||||
result.type = typeof error;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private static extractResponseDetails(response: any, result: ExtractedError): void {
|
||||
if (!response) return;
|
||||
|
||||
if (response.status) {
|
||||
result.status = response.status;
|
||||
}
|
||||
|
||||
if (response.statusText) {
|
||||
result.statusText = response.statusText;
|
||||
}
|
||||
|
||||
if (response.body) {
|
||||
result.responseBody = this.formatResponseBody(response.body);
|
||||
}
|
||||
|
||||
if (response.headers) {
|
||||
result.responseHeaders = this.extractHeaders(response.headers);
|
||||
}
|
||||
}
|
||||
|
||||
private static formatResponseBody(body: unknown): unknown {
|
||||
if (!body) return undefined;
|
||||
|
||||
if (typeof body === 'string') {
|
||||
if (body.length > this.MAX_BODY_PREVIEW_LENGTH) {
|
||||
return body.substring(0, this.MAX_BODY_PREVIEW_LENGTH) + '... (truncated)';
|
||||
}
|
||||
return body;
|
||||
}
|
||||
|
||||
return body;
|
||||
}
|
||||
|
||||
private static extractHeaders(headers: any): Record<string, string> | undefined {
|
||||
if (!headers) return undefined;
|
||||
|
||||
const result: Record<string, string> = {};
|
||||
|
||||
// Handle different header formats
|
||||
if (typeof headers === 'object') {
|
||||
for (const key of Object.keys(headers)) {
|
||||
const value = headers[key];
|
||||
if (value !== undefined && value !== null) {
|
||||
result[key] = String(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Object.keys(result).length > 0 ? result : undefined;
|
||||
}
|
||||
|
||||
private static extractCauseChain(
|
||||
cause: unknown
|
||||
): Array<{ depth: number; type: string; message: string; code?: string }> | undefined {
|
||||
const chain: Array<{ depth: number; type: string; message: string; code?: string }> = [];
|
||||
let currentCause = cause;
|
||||
let depth = 1;
|
||||
|
||||
while (currentCause && depth <= this.MAX_CAUSE_DEPTH) {
|
||||
const causeInfo: { depth: number; type: string; message: string; code?: string } = {
|
||||
depth,
|
||||
type: 'Unknown',
|
||||
message: 'Unknown cause',
|
||||
};
|
||||
|
||||
if (currentCause instanceof Error) {
|
||||
causeInfo.type = currentCause.constructor.name;
|
||||
causeInfo.message = currentCause.message;
|
||||
|
||||
if ('code' in currentCause && currentCause.code) {
|
||||
causeInfo.code = String(currentCause.code);
|
||||
}
|
||||
} else if (typeof currentCause === 'string') {
|
||||
causeInfo.type = 'String';
|
||||
causeInfo.message = currentCause;
|
||||
} else {
|
||||
causeInfo.type = typeof currentCause;
|
||||
causeInfo.message = String(currentCause);
|
||||
}
|
||||
|
||||
chain.push(causeInfo);
|
||||
|
||||
// Get next cause in chain
|
||||
currentCause =
|
||||
currentCause && typeof currentCause === 'object' && 'cause' in currentCause
|
||||
? (currentCause as any).cause
|
||||
: undefined;
|
||||
depth++;
|
||||
}
|
||||
|
||||
return chain.length > 0 ? chain : undefined;
|
||||
}
|
||||
|
||||
static isOAuthResponseError(extracted: ExtractedError): boolean {
|
||||
return Boolean(
|
||||
extracted.code === 'OAUTH_RESPONSE_IS_NOT_JSON' ||
|
||||
extracted.code === 'OAUTH_PARSE_ERROR' ||
|
||||
extracted.message.includes('unexpected response content-type') ||
|
||||
extracted.message.includes('parsing error')
|
||||
);
|
||||
}
|
||||
|
||||
static isJwtClaimError(extracted: ExtractedError): boolean {
|
||||
return extracted.message.includes('unexpected JWT claim value encountered');
|
||||
}
|
||||
|
||||
static isNetworkError(extracted: ExtractedError): boolean {
|
||||
return Boolean(
|
||||
extracted.code === 'ECONNREFUSED' ||
|
||||
extracted.code === 'ENOTFOUND' ||
|
||||
extracted.code === 'ETIMEDOUT' ||
|
||||
extracted.code === 'ECONNRESET' ||
|
||||
extracted.message.includes('network') ||
|
||||
extracted.message.includes('connect')
|
||||
);
|
||||
}
|
||||
|
||||
static formatForLogging(
|
||||
extracted: ExtractedError,
|
||||
logger: {
|
||||
error: (msg: string, ...args: any[]) => void;
|
||||
debug: (msg: string, ...args: any[]) => void;
|
||||
}
|
||||
): void {
|
||||
logger.error(`Error type: ${extracted.type}`);
|
||||
logger.error(`Error message: ${extracted.message}`);
|
||||
|
||||
if (extracted.code) {
|
||||
logger.error(`Error code: ${extracted.code}`);
|
||||
}
|
||||
|
||||
if (extracted.status) {
|
||||
logger.error(`HTTP Status: ${extracted.status} ${extracted.statusText || ''}`);
|
||||
}
|
||||
|
||||
if (extracted.responseBody) {
|
||||
logger.error('Response body: %o', extracted.responseBody);
|
||||
}
|
||||
|
||||
if (extracted.responseHeaders) {
|
||||
const contentType =
|
||||
extracted.responseHeaders['content-type'] || extracted.responseHeaders['Content-Type'];
|
||||
if (contentType) {
|
||||
logger.error(`Response Content-Type: ${contentType}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (extracted.oauthError) {
|
||||
logger.error(`OAuth error: ${extracted.oauthError}`);
|
||||
if (extracted.oauthErrorDescription) {
|
||||
logger.error(`OAuth error description: ${extracted.oauthErrorDescription}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (extracted.causeChain) {
|
||||
logger.error('Error cause chain:');
|
||||
for (const cause of extracted.causeChain) {
|
||||
logger.error(` [Cause ${cause.depth}] ${cause.type}: ${cause.message}`);
|
||||
if (cause.code) {
|
||||
logger.error(` [Cause ${cause.depth}] Code: ${cause.code}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (extracted.additionalProperties && Object.keys(extracted.additionalProperties).length > 0) {
|
||||
logger.debug('Additional error properties: %o', extracted.additionalProperties);
|
||||
}
|
||||
|
||||
if (extracted.stack) {
|
||||
logger.debug(`Stack trace: ${extracted.stack}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
506
api/src/unraid-api/utils/redirect-uri-validator.test.ts
Normal file
506
api/src/unraid-api/utils/redirect-uri-validator.test.ts
Normal file
@@ -0,0 +1,506 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { validateRedirectUri } from '@app/unraid-api/utils/redirect-uri-validator.js';
|
||||
|
||||
describe('validateRedirectUri', () => {
|
||||
let mockLogger: Logger;
|
||||
|
||||
beforeEach(() => {
|
||||
mockLogger = {
|
||||
debug: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
} as any;
|
||||
});
|
||||
|
||||
describe('basic validation', () => {
|
||||
it('should return base URL when no redirect URI is provided', () => {
|
||||
const result = validateRedirectUri(undefined, 'https', 'example.com', mockLogger);
|
||||
|
||||
expect(result).toEqual({
|
||||
isValid: true,
|
||||
validatedUri: 'https://example.com',
|
||||
reason: 'No redirect URI provided',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle missing base URL', () => {
|
||||
const result = validateRedirectUri('https://example.com', 'https', undefined, mockLogger);
|
||||
|
||||
expect(result).toEqual({
|
||||
isValid: false,
|
||||
validatedUri: '',
|
||||
reason: 'No base URL available',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('hostname validation', () => {
|
||||
it('should accept matching hostname with same port', () => {
|
||||
const result = validateRedirectUri(
|
||||
'https://example.com:3000',
|
||||
'https',
|
||||
'example.com:3000',
|
||||
mockLogger
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.validatedUri).toBe('https://example.com:3000');
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith(
|
||||
'Validated redirect_uri: https://example.com:3000'
|
||||
);
|
||||
});
|
||||
|
||||
it('should accept matching hostname with different ports', () => {
|
||||
const result = validateRedirectUri(
|
||||
'https://example.com:3001',
|
||||
'https',
|
||||
'example.com:3000',
|
||||
mockLogger
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.validatedUri).toBe('https://example.com:3001');
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith(
|
||||
'Validated redirect_uri: https://example.com:3001'
|
||||
);
|
||||
});
|
||||
|
||||
it('should accept matching hostname when expected has no port but provided does', () => {
|
||||
const result = validateRedirectUri(
|
||||
'https://example.com:3000',
|
||||
'https',
|
||||
'example.com',
|
||||
mockLogger
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.validatedUri).toBe('https://example.com:3000');
|
||||
});
|
||||
|
||||
it('should reject different hostnames', () => {
|
||||
const result = validateRedirectUri('https://evil.com', 'https', 'example.com', mockLogger);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.validatedUri).toBe('https://example.com');
|
||||
expect(result.reason).toContain('Hostname or protocol mismatch');
|
||||
expect(mockLogger.warn).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should reject subdomain differences', () => {
|
||||
const result = validateRedirectUri(
|
||||
'https://sub.example.com',
|
||||
'https',
|
||||
'example.com',
|
||||
mockLogger
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.validatedUri).toBe('https://example.com');
|
||||
});
|
||||
|
||||
it('should handle case-insensitive hostname comparison', () => {
|
||||
const result = validateRedirectUri(
|
||||
'https://EXAMPLE.COM:3000',
|
||||
'https',
|
||||
'example.com',
|
||||
mockLogger
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.validatedUri).toBe('https://EXAMPLE.COM:3000');
|
||||
});
|
||||
});
|
||||
|
||||
describe('protocol validation', () => {
|
||||
it('should reject HTTP when expecting HTTPS (prevent downgrade attacks)', () => {
|
||||
const result = validateRedirectUri('http://example.com', 'https', 'example.com', mockLogger);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.validatedUri).toBe('https://example.com');
|
||||
expect(result.reason).toContain('Hostname or protocol mismatch');
|
||||
});
|
||||
|
||||
it('should allow HTTPS when expecting HTTP (common with reverse proxies)', () => {
|
||||
const result = validateRedirectUri('https://example.com', 'http', 'example.com', mockLogger);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.validatedUri).toBe('https://example.com');
|
||||
});
|
||||
|
||||
it('should accept matching protocols', () => {
|
||||
const result = validateRedirectUri('http://example.com', 'http', 'example.com', mockLogger);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.validatedUri).toBe('http://example.com');
|
||||
});
|
||||
});
|
||||
|
||||
describe('malformed URL handling', () => {
|
||||
it('should reject invalid URL format', () => {
|
||||
const result = validateRedirectUri('not-a-valid-url', 'https', 'example.com', mockLogger);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.validatedUri).toBe('https://example.com');
|
||||
expect(result.reason).toContain('Invalid redirect_uri format');
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith('Invalid redirect_uri format: not-a-valid-url');
|
||||
});
|
||||
|
||||
it('should reject javascript protocol', () => {
|
||||
const result = validateRedirectUri(
|
||||
'javascript:alert(1)',
|
||||
'https',
|
||||
'example.com',
|
||||
mockLogger
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.validatedUri).toBe('https://example.com');
|
||||
});
|
||||
|
||||
it('should handle URLs with paths and query params', () => {
|
||||
const result = validateRedirectUri(
|
||||
'https://example.com:3000/callback?foo=bar',
|
||||
'https',
|
||||
'example.com',
|
||||
mockLogger
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.validatedUri).toBe('https://example.com:3000/callback?foo=bar');
|
||||
});
|
||||
});
|
||||
|
||||
describe('security scenarios', () => {
|
||||
it('should prevent open redirect to attacker domain', () => {
|
||||
const result = validateRedirectUri(
|
||||
'https://attacker.com/steal-token',
|
||||
'https',
|
||||
'legitimate.com',
|
||||
mockLogger
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.validatedUri).toBe('https://legitimate.com');
|
||||
});
|
||||
|
||||
it('should prevent homograph attacks with similar looking domains', () => {
|
||||
const result = validateRedirectUri(
|
||||
'https://examp1e.com', // with number 1 instead of letter l
|
||||
'https',
|
||||
'example.com',
|
||||
mockLogger
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.validatedUri).toBe('https://example.com');
|
||||
});
|
||||
|
||||
it('should handle localhost variations correctly', () => {
|
||||
const result = validateRedirectUri(
|
||||
'http://localhost:3001',
|
||||
'http',
|
||||
'localhost:3000',
|
||||
mockLogger
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.validatedUri).toBe('http://localhost:3001');
|
||||
});
|
||||
|
||||
it('should handle IP addresses correctly', () => {
|
||||
const result = validateRedirectUri(
|
||||
'http://192.168.1.100:3001',
|
||||
'http',
|
||||
'192.168.1.100:3000',
|
||||
mockLogger
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.validatedUri).toBe('http://192.168.1.100:3001');
|
||||
});
|
||||
|
||||
it('should reject IP when expecting domain', () => {
|
||||
const result = validateRedirectUri(
|
||||
'https://192.168.1.100',
|
||||
'https',
|
||||
'example.com',
|
||||
mockLogger
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.validatedUri).toBe('https://example.com');
|
||||
});
|
||||
});
|
||||
|
||||
describe('allowed origins validation', () => {
|
||||
it('should accept redirect URI with different hostname if in allowed origins', () => {
|
||||
const result = validateRedirectUri(
|
||||
'https://devgen-bad-dev1.local/graphql/api/auth/oidc/callback',
|
||||
'http',
|
||||
'devgen-dev1.local',
|
||||
mockLogger,
|
||||
['https://devgen-bad-dev1.local']
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.validatedUri).toBe(
|
||||
'https://devgen-bad-dev1.local/graphql/api/auth/oidc/callback'
|
||||
);
|
||||
});
|
||||
|
||||
it('should reject redirect URI with hostname not in allowed origins', () => {
|
||||
const result = validateRedirectUri(
|
||||
'https://evil.com/callback',
|
||||
'http',
|
||||
'devgen-dev1.local',
|
||||
mockLogger,
|
||||
['https://devgen-bad-dev1.local', 'https://another-allowed.local']
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.reason).toContain('Hostname or protocol mismatch');
|
||||
});
|
||||
|
||||
it('should handle multiple allowed origins', () => {
|
||||
const result = validateRedirectUri(
|
||||
'https://second-allowed.local/callback',
|
||||
'http',
|
||||
'devgen-dev1.local',
|
||||
mockLogger,
|
||||
[
|
||||
'https://first-allowed.local',
|
||||
'https://second-allowed.local',
|
||||
'https://third-allowed.local',
|
||||
]
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.validatedUri).toBe('https://second-allowed.local/callback');
|
||||
});
|
||||
|
||||
it('should allow HTTPS upgrade for allowed origins', () => {
|
||||
const result = validateRedirectUri(
|
||||
'https://allowed-host.local/callback',
|
||||
'http',
|
||||
'devgen-dev1.local',
|
||||
mockLogger,
|
||||
['http://allowed-host.local']
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.validatedUri).toBe('https://allowed-host.local/callback');
|
||||
});
|
||||
|
||||
it('should validate extra hostname from config with proper logging', () => {
|
||||
// Simulate a config with an extra hostname like "unraid.local"
|
||||
const allowedOriginsFromConfig = ['https://unraid.local:8443'];
|
||||
|
||||
const result = validateRedirectUri(
|
||||
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
|
||||
'http', // Expected from headers
|
||||
'devgen-dev1.local', // Primary hostname
|
||||
mockLogger,
|
||||
allowedOriginsFromConfig
|
||||
);
|
||||
|
||||
// This should pass if the extra hostname is configured correctly
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.validatedUri).toBe('https://unraid.local:8443/graphql/api/auth/oidc/callback');
|
||||
|
||||
// Verify that debug logging shows the check
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith('Checking against 1 allowed origins');
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith(
|
||||
'Checking allowed origin: https://unraid.local:8443'
|
||||
);
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith(
|
||||
' Origin match: https://unraid.local:8443 matches https://unraid.local:8443'
|
||||
);
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith(
|
||||
'Validated redirect_uri against allowed origin: https://unraid.local:8443/graphql/api/auth/oidc/callback'
|
||||
);
|
||||
});
|
||||
|
||||
it('should fail validation when extra hostname is not in config', () => {
|
||||
// Config has no extra hostnames
|
||||
const allowedOriginsFromConfig: string[] = [];
|
||||
|
||||
const result = validateRedirectUri(
|
||||
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
|
||||
'http',
|
||||
'devgen-dev1.local',
|
||||
mockLogger,
|
||||
allowedOriginsFromConfig
|
||||
);
|
||||
|
||||
// This should fail since unraid.local is not in allowed origins
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.validatedUri).toBe('http://devgen-dev1.local');
|
||||
expect(result.reason).toContain('Hostname or protocol mismatch');
|
||||
|
||||
// Verify error logging
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Hostname or protocol mismatch')
|
||||
);
|
||||
});
|
||||
|
||||
describe('enhanced matching modes', () => {
|
||||
it('should validate URL prefix match', () => {
|
||||
const allowedOriginsFromConfig = ['https://unraid.local:8443/graphql/api/auth/oidc/'];
|
||||
|
||||
const result = validateRedirectUri(
|
||||
'https://unraid.local:8443/graphql/api/auth/oidc/callback?code=123',
|
||||
'http',
|
||||
'devgen-dev1.local',
|
||||
mockLogger,
|
||||
allowedOriginsFromConfig
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.validatedUri).toBe(
|
||||
'https://unraid.local:8443/graphql/api/auth/oidc/callback?code=123'
|
||||
);
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith(
|
||||
' URL prefix match: https://unraid.local:8443/graphql/api/auth/oidc/callback?code=123 matches prefix https://unraid.local:8443/graphql/api/auth/oidc/'
|
||||
);
|
||||
});
|
||||
|
||||
it('should validate origin match (protocol + hostname + port)', () => {
|
||||
const allowedOriginsFromConfig = ['https://unraid.local:8443'];
|
||||
|
||||
const result = validateRedirectUri(
|
||||
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
|
||||
'http',
|
||||
'devgen-dev1.local',
|
||||
mockLogger,
|
||||
allowedOriginsFromConfig
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.validatedUri).toBe(
|
||||
'https://unraid.local:8443/graphql/api/auth/oidc/callback'
|
||||
);
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith(
|
||||
' Origin match: https://unraid.local:8443 matches https://unraid.local:8443'
|
||||
);
|
||||
});
|
||||
|
||||
it('should validate hostname-only match (original behavior)', () => {
|
||||
const allowedOriginsFromConfig = ['https://unraid.local'];
|
||||
|
||||
const result = validateRedirectUri(
|
||||
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
|
||||
'http',
|
||||
'devgen-dev1.local',
|
||||
mockLogger,
|
||||
allowedOriginsFromConfig
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.validatedUri).toBe(
|
||||
'https://unraid.local:8443/graphql/api/auth/oidc/callback'
|
||||
);
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith(
|
||||
' Hostname comparison: provided=unraid.local, allowed=unraid.local'
|
||||
);
|
||||
});
|
||||
|
||||
it('should prefer exact URL match over origin match', () => {
|
||||
const allowedOriginsFromConfig = [
|
||||
'https://unraid.local:8443/graphql/api/auth/oidc/callback', // Exact match first
|
||||
'https://unraid.local:8443', // Origin match second
|
||||
];
|
||||
|
||||
const result = validateRedirectUri(
|
||||
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
|
||||
'http',
|
||||
'devgen-dev1.local',
|
||||
mockLogger,
|
||||
allowedOriginsFromConfig
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
// Should match the first item in the array due to order of specificity
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith(
|
||||
' Exact URL match: https://unraid.local:8443/graphql/api/auth/oidc/callback matches https://unraid.local:8443/graphql/api/auth/oidc/callback'
|
||||
);
|
||||
});
|
||||
|
||||
it('should prefer origin match over hostname match', () => {
|
||||
const allowedOriginsFromConfig = [
|
||||
'https://unraid.local:8443', // Origin match first
|
||||
'https://unraid.local', // Hostname match second
|
||||
];
|
||||
|
||||
const result = validateRedirectUri(
|
||||
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
|
||||
'http',
|
||||
'devgen-dev1.local',
|
||||
mockLogger,
|
||||
allowedOriginsFromConfig
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith(
|
||||
' Origin match: https://unraid.local:8443 matches https://unraid.local:8443'
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle protocol upgrades for all matching modes', () => {
|
||||
const allowedOriginsFromConfig = [
|
||||
'http://unraid.local:8443/graphql/api/auth/oidc/callback',
|
||||
'http://unraid.local:8443',
|
||||
'http://unraid.local',
|
||||
];
|
||||
|
||||
// Test exact URL with protocol upgrade
|
||||
const result1 = validateRedirectUri(
|
||||
'https://unraid.local:8443/graphql/api/auth/oidc/callback',
|
||||
'http',
|
||||
'devgen-dev1.local',
|
||||
mockLogger,
|
||||
allowedOriginsFromConfig
|
||||
);
|
||||
expect(result1.isValid).toBe(true);
|
||||
|
||||
// Test origin with protocol upgrade
|
||||
const result2 = validateRedirectUri(
|
||||
'https://unraid.local:8443/different/path',
|
||||
'http',
|
||||
'devgen-dev1.local',
|
||||
mockLogger,
|
||||
allowedOriginsFromConfig
|
||||
);
|
||||
expect(result2.isValid).toBe(true);
|
||||
|
||||
// Test hostname with protocol upgrade
|
||||
const result3 = validateRedirectUri(
|
||||
'https://unraid.local:9443/different/path',
|
||||
'http',
|
||||
'devgen-dev1.local',
|
||||
mockLogger,
|
||||
allowedOriginsFromConfig
|
||||
);
|
||||
expect(result3.isValid).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle invalid allowed origin formats gracefully', () => {
|
||||
const allowedOriginsFromConfig = ['not-a-valid-url', 'https://valid.url'];
|
||||
|
||||
const result = validateRedirectUri(
|
||||
'https://valid.url/graphql/api/auth/oidc/callback',
|
||||
'http',
|
||||
'devgen-dev1.local',
|
||||
mockLogger,
|
||||
allowedOriginsFromConfig
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
'Invalid allowed origin format: not-a-valid-url'
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
187
api/src/unraid-api/utils/redirect-uri-validator.ts
Normal file
187
api/src/unraid-api/utils/redirect-uri-validator.ts
Normal file
@@ -0,0 +1,187 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
|
||||
export interface RedirectUriValidationResult {
|
||||
isValid: boolean;
|
||||
validatedUri: string;
|
||||
reason?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates a redirect URI against the expected origin from request headers.
|
||||
* This is critical for OAuth security to prevent authorization code interception.
|
||||
*
|
||||
* Security considerations:
|
||||
* - Prevents redirecting OAuth codes to external domains
|
||||
* - Allows port variations (needed for nginx/socket proxy scenarios)
|
||||
* - Validates protocol to prevent downgrade attacks
|
||||
* - Optionally validates against additional allowed origins
|
||||
*
|
||||
* @param redirectUri - The redirect URI provided by the client
|
||||
* @param expectedProtocol - The protocol from request headers (http/https)
|
||||
* @param expectedHost - The host from request headers (may or may not include port)
|
||||
* @param logger - Optional logger for debugging
|
||||
* @param allowedOrigins - Optional list of additional allowed origins
|
||||
* @returns Validation result with the URI to use
|
||||
*/
|
||||
export function validateRedirectUri(
|
||||
redirectUri: string | undefined,
|
||||
expectedProtocol: string,
|
||||
expectedHost: string | undefined,
|
||||
logger?: Logger,
|
||||
allowedOrigins?: string[]
|
||||
): RedirectUriValidationResult {
|
||||
const baseUrl = expectedHost ? `${expectedProtocol}://${expectedHost}` : undefined;
|
||||
|
||||
// If no redirect URI provided, use the base URL
|
||||
if (!redirectUri || !baseUrl) {
|
||||
return {
|
||||
isValid: !redirectUri,
|
||||
validatedUri: baseUrl || '',
|
||||
reason: !redirectUri ? 'No redirect URI provided' : 'No base URL available',
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
// Parse both URLs to validate hostname
|
||||
const providedUrl = new URL(redirectUri);
|
||||
const expectedUrl = new URL(baseUrl);
|
||||
|
||||
// Security: Validate hostname matches, but allow port differences
|
||||
// This handles cases where nginx/socket proxy doesn't preserve port info
|
||||
const providedHostname = providedUrl.hostname.toLowerCase();
|
||||
const expectedHostname = expectedUrl.hostname.toLowerCase();
|
||||
|
||||
// Check protocol matches, but allow HTTPS when expecting HTTP (common with reverse proxies)
|
||||
// Never allow HTTP when expecting HTTPS (would be a downgrade attack)
|
||||
const protocolMatches =
|
||||
providedUrl.protocol === expectedUrl.protocol ||
|
||||
(expectedUrl.protocol === 'http:' && providedUrl.protocol === 'https:');
|
||||
const hostnameMatches = providedHostname === expectedHostname;
|
||||
|
||||
// Check against primary expected origin
|
||||
if (protocolMatches && hostnameMatches) {
|
||||
// Trust the redirect_uri with its port information
|
||||
logger?.debug(`Validated redirect_uri: ${redirectUri}`);
|
||||
return {
|
||||
isValid: true,
|
||||
validatedUri: redirectUri,
|
||||
};
|
||||
}
|
||||
|
||||
// Check against additional allowed origins if provided
|
||||
if (allowedOrigins && allowedOrigins.length > 0) {
|
||||
logger?.debug(`Checking against ${allowedOrigins.length} allowed origins`);
|
||||
for (const allowedOrigin of allowedOrigins) {
|
||||
try {
|
||||
const allowedUrl = new URL(allowedOrigin);
|
||||
const allowedOriginStr = allowedUrl.origin.toLowerCase();
|
||||
const allowedHostname = allowedUrl.hostname.toLowerCase();
|
||||
|
||||
logger?.debug(`Checking allowed origin: ${allowedOrigin}`);
|
||||
|
||||
// Try multiple matching strategies in order of specificity
|
||||
|
||||
// 1. Exact URL match (if allowed origin includes path/query)
|
||||
if (allowedOrigin.includes('/') && allowedOrigin.length > allowedOriginStr.length) {
|
||||
const allowedUrlNormalized = allowedOrigin.toLowerCase();
|
||||
const providedUrlNormalized = redirectUri.toLowerCase();
|
||||
|
||||
// Exact match
|
||||
if (providedUrlNormalized === allowedUrlNormalized) {
|
||||
logger?.debug(` Exact URL match: ${redirectUri} matches ${allowedOrigin}`);
|
||||
logger?.debug(
|
||||
`Validated redirect_uri against allowed origin: ${redirectUri}`
|
||||
);
|
||||
return {
|
||||
isValid: true,
|
||||
validatedUri: redirectUri,
|
||||
};
|
||||
}
|
||||
|
||||
// Prefix match (if allowed origin ends with /)
|
||||
if (
|
||||
allowedUrlNormalized.endsWith('/') &&
|
||||
providedUrlNormalized.startsWith(allowedUrlNormalized)
|
||||
) {
|
||||
logger?.debug(
|
||||
` URL prefix match: ${redirectUri} matches prefix ${allowedOrigin}`
|
||||
);
|
||||
logger?.debug(
|
||||
`Validated redirect_uri against allowed origin: ${redirectUri}`
|
||||
);
|
||||
return {
|
||||
isValid: true,
|
||||
validatedUri: redirectUri,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Origin match (protocol + hostname + port)
|
||||
const providedOrigin = providedUrl.origin.toLowerCase();
|
||||
if (providedOrigin === allowedOriginStr) {
|
||||
// Allow HTTPS when expecting HTTP (common with reverse proxies)
|
||||
const originProtocolMatches =
|
||||
providedUrl.protocol === allowedUrl.protocol ||
|
||||
(allowedUrl.protocol === 'http:' && providedUrl.protocol === 'https:');
|
||||
|
||||
if (originProtocolMatches) {
|
||||
logger?.debug(
|
||||
` Origin match: ${providedOrigin} matches ${allowedOriginStr}`
|
||||
);
|
||||
logger?.debug(
|
||||
`Validated redirect_uri against allowed origin: ${redirectUri}`
|
||||
);
|
||||
return {
|
||||
isValid: true,
|
||||
validatedUri: redirectUri,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Hostname match (original behavior, but with better logging)
|
||||
const allowedProtocolMatches =
|
||||
providedUrl.protocol === allowedUrl.protocol ||
|
||||
(allowedUrl.protocol === 'http:' && providedUrl.protocol === 'https:');
|
||||
const allowedHostnameMatches = providedHostname === allowedHostname;
|
||||
|
||||
logger?.debug(
|
||||
` Hostname comparison: provided=${providedHostname}, allowed=${allowedHostname}`
|
||||
);
|
||||
logger?.debug(
|
||||
` Protocol comparison: provided=${providedUrl.protocol}, allowed=${allowedUrl.protocol}`
|
||||
);
|
||||
logger?.debug(
|
||||
` Protocol matches: ${allowedProtocolMatches}, Hostname matches: ${allowedHostnameMatches}`
|
||||
);
|
||||
|
||||
if (allowedProtocolMatches && allowedHostnameMatches) {
|
||||
logger?.debug(`Validated redirect_uri against allowed origin: ${redirectUri}`);
|
||||
return {
|
||||
isValid: true,
|
||||
validatedUri: redirectUri,
|
||||
};
|
||||
}
|
||||
} catch (e) {
|
||||
logger?.warn(`Invalid allowed origin format: ${allowedOrigin}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we get here, validation failed
|
||||
const reason = `Hostname or protocol mismatch. Expected: ${expectedUrl.protocol}//${expectedHostname}, Got: ${providedUrl.protocol}//${providedHostname}`;
|
||||
logger?.warn(`Rejected redirect_uri: ${reason}`);
|
||||
return {
|
||||
isValid: false,
|
||||
validatedUri: baseUrl,
|
||||
reason,
|
||||
};
|
||||
} catch (error) {
|
||||
const reason = `Invalid redirect_uri format: ${redirectUri}`;
|
||||
logger?.warn(reason);
|
||||
return {
|
||||
isValid: false,
|
||||
validatedUri: baseUrl,
|
||||
reason,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@
|
||||
"scripts": {
|
||||
"build": "pnpm -r build",
|
||||
"build:watch": " pnpm -r --parallel build:watch",
|
||||
"codegen": "pnpm -r codegen",
|
||||
"dev": "pnpm -r dev",
|
||||
"unraid:deploy": "pnpm -r unraid:deploy",
|
||||
"test": "pnpm -r test",
|
||||
|
||||
@@ -1,94 +1,87 @@
|
||||
/* eslint-disable */
|
||||
import type {
|
||||
DocumentTypeDecoration,
|
||||
ResultOf,
|
||||
TypedDocumentNode,
|
||||
} from '@graphql-typed-document-node/core';
|
||||
import type { ResultOf, DocumentTypeDecoration, TypedDocumentNode } from '@graphql-typed-document-node/core';
|
||||
import type { FragmentDefinitionNode } from 'graphql';
|
||||
|
||||
import type { Incremental } from './graphql.js';
|
||||
|
||||
export type FragmentType<TDocumentType extends DocumentTypeDecoration<any, any>> =
|
||||
TDocumentType extends DocumentTypeDecoration<infer TType, any>
|
||||
? [TType] extends [{ ' $fragmentName'?: infer TKey }]
|
||||
? TKey extends string
|
||||
? { ' $fragmentRefs'?: { [key in TKey]: TType } }
|
||||
: never
|
||||
: never
|
||||
: never;
|
||||
|
||||
export type FragmentType<TDocumentType extends DocumentTypeDecoration<any, any>> = TDocumentType extends DocumentTypeDecoration<
|
||||
infer TType,
|
||||
any
|
||||
>
|
||||
? [TType] extends [{ ' $fragmentName'?: infer TKey }]
|
||||
? TKey extends string
|
||||
? { ' $fragmentRefs'?: { [key in TKey]: TType } }
|
||||
: never
|
||||
: never
|
||||
: never;
|
||||
|
||||
// return non-nullable if `fragmentType` is non-nullable
|
||||
export function useFragment<TType>(
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType: FragmentType<DocumentTypeDecoration<TType, any>>
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType: FragmentType<DocumentTypeDecoration<TType, any>>
|
||||
): TType;
|
||||
// return nullable if `fragmentType` is undefined
|
||||
export function useFragment<TType>(
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType: FragmentType<DocumentTypeDecoration<TType, any>> | undefined
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType: FragmentType<DocumentTypeDecoration<TType, any>> | undefined
|
||||
): TType | undefined;
|
||||
// return nullable if `fragmentType` is nullable
|
||||
export function useFragment<TType>(
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType: FragmentType<DocumentTypeDecoration<TType, any>> | null
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType: FragmentType<DocumentTypeDecoration<TType, any>> | null
|
||||
): TType | null;
|
||||
// return nullable if `fragmentType` is nullable or undefined
|
||||
export function useFragment<TType>(
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType: FragmentType<DocumentTypeDecoration<TType, any>> | null | undefined
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType: FragmentType<DocumentTypeDecoration<TType, any>> | null | undefined
|
||||
): TType | null | undefined;
|
||||
// return array of non-nullable if `fragmentType` is array of non-nullable
|
||||
export function useFragment<TType>(
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType: Array<FragmentType<DocumentTypeDecoration<TType, any>>>
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType: Array<FragmentType<DocumentTypeDecoration<TType, any>>>
|
||||
): Array<TType>;
|
||||
// return array of nullable if `fragmentType` is array of nullable
|
||||
export function useFragment<TType>(
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType: Array<FragmentType<DocumentTypeDecoration<TType, any>>> | null | undefined
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType: Array<FragmentType<DocumentTypeDecoration<TType, any>>> | null | undefined
|
||||
): Array<TType> | null | undefined;
|
||||
// return readonly array of non-nullable if `fragmentType` is array of non-nullable
|
||||
export function useFragment<TType>(
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType: ReadonlyArray<FragmentType<DocumentTypeDecoration<TType, any>>>
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType: ReadonlyArray<FragmentType<DocumentTypeDecoration<TType, any>>>
|
||||
): ReadonlyArray<TType>;
|
||||
// return readonly array of nullable if `fragmentType` is array of nullable
|
||||
export function useFragment<TType>(
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType: ReadonlyArray<FragmentType<DocumentTypeDecoration<TType, any>>> | null | undefined
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType: ReadonlyArray<FragmentType<DocumentTypeDecoration<TType, any>>> | null | undefined
|
||||
): ReadonlyArray<TType> | null | undefined;
|
||||
export function useFragment<TType>(
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType:
|
||||
| FragmentType<DocumentTypeDecoration<TType, any>>
|
||||
| Array<FragmentType<DocumentTypeDecoration<TType, any>>>
|
||||
| ReadonlyArray<FragmentType<DocumentTypeDecoration<TType, any>>>
|
||||
| null
|
||||
| undefined
|
||||
_documentNode: DocumentTypeDecoration<TType, any>,
|
||||
fragmentType: FragmentType<DocumentTypeDecoration<TType, any>> | Array<FragmentType<DocumentTypeDecoration<TType, any>>> | ReadonlyArray<FragmentType<DocumentTypeDecoration<TType, any>>> | null | undefined
|
||||
): TType | Array<TType> | ReadonlyArray<TType> | null | undefined {
|
||||
return fragmentType as any;
|
||||
return fragmentType as any;
|
||||
}
|
||||
|
||||
export function makeFragmentData<F extends DocumentTypeDecoration<any, any>, FT extends ResultOf<F>>(
|
||||
data: FT,
|
||||
_fragment: F
|
||||
): FragmentType<F> {
|
||||
return data as FragmentType<F>;
|
||||
|
||||
export function makeFragmentData<
|
||||
F extends DocumentTypeDecoration<any, any>,
|
||||
FT extends ResultOf<F>
|
||||
>(data: FT, _fragment: F): FragmentType<F> {
|
||||
return data as FragmentType<F>;
|
||||
}
|
||||
export function isFragmentReady<TQuery, TFrag>(
|
||||
queryNode: DocumentTypeDecoration<TQuery, any>,
|
||||
fragmentNode: TypedDocumentNode<TFrag>,
|
||||
data: FragmentType<TypedDocumentNode<Incremental<TFrag>, any>> | null | undefined
|
||||
queryNode: DocumentTypeDecoration<TQuery, any>,
|
||||
fragmentNode: TypedDocumentNode<TFrag>,
|
||||
data: FragmentType<TypedDocumentNode<Incremental<TFrag>, any>> | null | undefined
|
||||
): data is FragmentType<typeof fragmentNode> {
|
||||
const deferredFields = (
|
||||
queryNode as { __meta__?: { deferredFields: Record<string, (keyof TFrag)[]> } }
|
||||
).__meta__?.deferredFields;
|
||||
const deferredFields = (queryNode as { __meta__?: { deferredFields: Record<string, (keyof TFrag)[]> } }).__meta__
|
||||
?.deferredFields;
|
||||
|
||||
if (!deferredFields) return true;
|
||||
if (!deferredFields) return true;
|
||||
|
||||
const fragDef = fragmentNode.definitions[0] as FragmentDefinitionNode | undefined;
|
||||
const fragName = fragDef?.name?.value;
|
||||
const fragDef = fragmentNode.definitions[0] as FragmentDefinitionNode | undefined;
|
||||
const fragName = fragDef?.name?.value;
|
||||
|
||||
const fields = (fragName && deferredFields[fragName]) || [];
|
||||
return fields.length > 0 && fields.every((field) => data && field in data);
|
||||
const fields = (fragName && deferredFields[fragName]) || [];
|
||||
return fields.length > 0 && fields.every(field => data && field in data);
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
/* eslint-disable */
|
||||
import type { TypedDocumentNode as DocumentNode } from '@graphql-typed-document-node/core';
|
||||
|
||||
import * as types from './graphql.js';
|
||||
import type { TypedDocumentNode as DocumentNode } from '@graphql-typed-document-node/core';
|
||||
|
||||
/**
|
||||
* Map of all GraphQL operations in the project.
|
||||
@@ -15,17 +14,14 @@ import * as types from './graphql.js';
|
||||
* Learn more about it here: https://the-guild.dev/graphql/codegen/plugins/presets/preset-client#reducing-bundle-size
|
||||
*/
|
||||
type Documents = {
|
||||
'\n fragment RemoteGraphQLEventFragment on RemoteGraphQLEvent {\n remoteGraphQLEventData: data {\n type\n body\n sha256\n }\n }\n': typeof types.RemoteGraphQlEventFragmentFragmentDoc;
|
||||
'\n subscription events {\n events {\n __typename\n ... on ClientConnectedEvent {\n connectedData: data {\n type\n version\n apiKey\n }\n connectedEvent: type\n }\n ... on ClientDisconnectedEvent {\n disconnectedData: data {\n type\n version\n apiKey\n }\n disconnectedEvent: type\n }\n ...RemoteGraphQLEventFragment\n }\n }\n': typeof types.EventsDocument;
|
||||
'\n mutation sendRemoteGraphQLResponse($input: RemoteGraphQLServerInput!) {\n remoteGraphQLResponse(input: $input)\n }\n': typeof types.SendRemoteGraphQlResponseDocument;
|
||||
"\n fragment RemoteGraphQLEventFragment on RemoteGraphQLEvent {\n remoteGraphQLEventData: data {\n type\n body\n sha256\n }\n }\n": typeof types.RemoteGraphQlEventFragmentFragmentDoc,
|
||||
"\n subscription events {\n events {\n __typename\n ... on ClientConnectedEvent {\n connectedData: data {\n type\n version\n apiKey\n }\n connectedEvent: type\n }\n ... on ClientDisconnectedEvent {\n disconnectedData: data {\n type\n version\n apiKey\n }\n disconnectedEvent: type\n }\n ...RemoteGraphQLEventFragment\n }\n }\n": typeof types.EventsDocument,
|
||||
"\n mutation sendRemoteGraphQLResponse($input: RemoteGraphQLServerInput!) {\n remoteGraphQLResponse(input: $input)\n }\n": typeof types.SendRemoteGraphQlResponseDocument,
|
||||
};
|
||||
const documents: Documents = {
|
||||
'\n fragment RemoteGraphQLEventFragment on RemoteGraphQLEvent {\n remoteGraphQLEventData: data {\n type\n body\n sha256\n }\n }\n':
|
||||
types.RemoteGraphQlEventFragmentFragmentDoc,
|
||||
'\n subscription events {\n events {\n __typename\n ... on ClientConnectedEvent {\n connectedData: data {\n type\n version\n apiKey\n }\n connectedEvent: type\n }\n ... on ClientDisconnectedEvent {\n disconnectedData: data {\n type\n version\n apiKey\n }\n disconnectedEvent: type\n }\n ...RemoteGraphQLEventFragment\n }\n }\n':
|
||||
types.EventsDocument,
|
||||
'\n mutation sendRemoteGraphQLResponse($input: RemoteGraphQLServerInput!) {\n remoteGraphQLResponse(input: $input)\n }\n':
|
||||
types.SendRemoteGraphQlResponseDocument,
|
||||
"\n fragment RemoteGraphQLEventFragment on RemoteGraphQLEvent {\n remoteGraphQLEventData: data {\n type\n body\n sha256\n }\n }\n": types.RemoteGraphQlEventFragmentFragmentDoc,
|
||||
"\n subscription events {\n events {\n __typename\n ... on ClientConnectedEvent {\n connectedData: data {\n type\n version\n apiKey\n }\n connectedEvent: type\n }\n ... on ClientDisconnectedEvent {\n disconnectedData: data {\n type\n version\n apiKey\n }\n disconnectedEvent: type\n }\n ...RemoteGraphQLEventFragment\n }\n }\n": types.EventsDocument,
|
||||
"\n mutation sendRemoteGraphQLResponse($input: RemoteGraphQLServerInput!) {\n remoteGraphQLResponse(input: $input)\n }\n": types.SendRemoteGraphQlResponseDocument,
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -45,25 +41,18 @@ export function graphql(source: string): unknown;
|
||||
/**
|
||||
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
|
||||
*/
|
||||
export function graphql(
|
||||
source: '\n fragment RemoteGraphQLEventFragment on RemoteGraphQLEvent {\n remoteGraphQLEventData: data {\n type\n body\n sha256\n }\n }\n'
|
||||
): (typeof documents)['\n fragment RemoteGraphQLEventFragment on RemoteGraphQLEvent {\n remoteGraphQLEventData: data {\n type\n body\n sha256\n }\n }\n'];
|
||||
export function graphql(source: "\n fragment RemoteGraphQLEventFragment on RemoteGraphQLEvent {\n remoteGraphQLEventData: data {\n type\n body\n sha256\n }\n }\n"): (typeof documents)["\n fragment RemoteGraphQLEventFragment on RemoteGraphQLEvent {\n remoteGraphQLEventData: data {\n type\n body\n sha256\n }\n }\n"];
|
||||
/**
|
||||
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
|
||||
*/
|
||||
export function graphql(
|
||||
source: '\n subscription events {\n events {\n __typename\n ... on ClientConnectedEvent {\n connectedData: data {\n type\n version\n apiKey\n }\n connectedEvent: type\n }\n ... on ClientDisconnectedEvent {\n disconnectedData: data {\n type\n version\n apiKey\n }\n disconnectedEvent: type\n }\n ...RemoteGraphQLEventFragment\n }\n }\n'
|
||||
): (typeof documents)['\n subscription events {\n events {\n __typename\n ... on ClientConnectedEvent {\n connectedData: data {\n type\n version\n apiKey\n }\n connectedEvent: type\n }\n ... on ClientDisconnectedEvent {\n disconnectedData: data {\n type\n version\n apiKey\n }\n disconnectedEvent: type\n }\n ...RemoteGraphQLEventFragment\n }\n }\n'];
|
||||
export function graphql(source: "\n subscription events {\n events {\n __typename\n ... on ClientConnectedEvent {\n connectedData: data {\n type\n version\n apiKey\n }\n connectedEvent: type\n }\n ... on ClientDisconnectedEvent {\n disconnectedData: data {\n type\n version\n apiKey\n }\n disconnectedEvent: type\n }\n ...RemoteGraphQLEventFragment\n }\n }\n"): (typeof documents)["\n subscription events {\n events {\n __typename\n ... on ClientConnectedEvent {\n connectedData: data {\n type\n version\n apiKey\n }\n connectedEvent: type\n }\n ... on ClientDisconnectedEvent {\n disconnectedData: data {\n type\n version\n apiKey\n }\n disconnectedEvent: type\n }\n ...RemoteGraphQLEventFragment\n }\n }\n"];
|
||||
/**
|
||||
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
|
||||
*/
|
||||
export function graphql(
|
||||
source: '\n mutation sendRemoteGraphQLResponse($input: RemoteGraphQLServerInput!) {\n remoteGraphQLResponse(input: $input)\n }\n'
|
||||
): (typeof documents)['\n mutation sendRemoteGraphQLResponse($input: RemoteGraphQLServerInput!) {\n remoteGraphQLResponse(input: $input)\n }\n'];
|
||||
export function graphql(source: "\n mutation sendRemoteGraphQLResponse($input: RemoteGraphQLServerInput!) {\n remoteGraphQLResponse(input: $input)\n }\n"): (typeof documents)["\n mutation sendRemoteGraphQLResponse($input: RemoteGraphQLServerInput!) {\n remoteGraphQLResponse(input: $input)\n }\n"];
|
||||
|
||||
export function graphql(source: string) {
|
||||
return (documents as any)[source] ?? {};
|
||||
return (documents as any)[source] ?? {};
|
||||
}
|
||||
|
||||
export type DocumentType<TDocumentNode extends DocumentNode<any, any>> =
|
||||
TDocumentNode extends DocumentNode<infer TType, any> ? TType : never;
|
||||
export type DocumentType<TDocumentNode extends DocumentNode<any, any>> = TDocumentNode extends DocumentNode< infer TType, any> ? TType : never;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,2 +1,2 @@
|
||||
export * from './fragment-masking.js';
|
||||
export * from './gql.js';
|
||||
export * from "./fragment-masking.js";
|
||||
export * from "./gql.js";
|
||||
@@ -0,0 +1,6 @@
|
||||
Menu="ManagementAccess:160"
|
||||
Title="API Config Download"
|
||||
Icon="icon-download"
|
||||
Tag="download"
|
||||
---
|
||||
<unraid-config-download />
|
||||
85
pnpm-lock.yaml
generated
85
pnpm-lock.yaml
generated
@@ -178,6 +178,9 @@ importers:
|
||||
dotenv:
|
||||
specifier: 17.2.1
|
||||
version: 17.2.1
|
||||
escape-html:
|
||||
specifier: 1.0.3
|
||||
version: 1.0.3
|
||||
execa:
|
||||
specifier: 9.6.0
|
||||
version: 9.6.0
|
||||
@@ -1091,6 +1094,9 @@ importers:
|
||||
ajv:
|
||||
specifier: 8.17.1
|
||||
version: 8.17.1
|
||||
ansi_up:
|
||||
specifier: ^6.0.6
|
||||
version: 6.0.6
|
||||
class-variance-authority:
|
||||
specifier: 0.7.1
|
||||
version: 0.7.1
|
||||
@@ -6139,6 +6145,9 @@ packages:
|
||||
resolution: {integrity: sha512-bN798gFfQX+viw3R7yrGWRqnrN2oRkEkUjjl4JNn4E8GxxbjtG3FbrEIIY3l8/hrwUwIeCZvi4QuOTP4MErVug==}
|
||||
engines: {node: '>=12'}
|
||||
|
||||
ansi_up@6.0.6:
|
||||
resolution: {integrity: sha512-yIa1x3Ecf8jWP4UWEunNjqNX6gzE4vg2gGz+xqRGY+TBSucnYp6RRdPV4brmtg6bQ1ljD48mZ5iGSEj7QEpRKA==}
|
||||
|
||||
ansis@4.0.0-node10:
|
||||
resolution: {integrity: sha512-BRrU0Bo1X9dFGw6KgGz6hWrqQuOlVEDOzkb0QSLZY9sXHqA7pNj7yHPVJRz7y/rj4EOJ3d/D5uxH+ee9leYgsg==}
|
||||
engines: {node: '>=10'}
|
||||
@@ -7694,10 +7703,6 @@ packages:
|
||||
errx@0.1.0:
|
||||
resolution: {integrity: sha512-fZmsRiDNv07K6s2KkKFTiD2aIvECa7++PKyD5NC32tpRw46qZA3sOz+aM+/V9V0GDHxVTKLziveV4JhzBHDp9Q==}
|
||||
|
||||
es-abstract@1.23.9:
|
||||
resolution: {integrity: sha512-py07lI0wjxAC/DcfK1S6G7iANonniZwTISvdPzk9hzeH0IZIshbuuFxLIU96OyF89Yb9hiqWn8M/bY83KY5vzA==}
|
||||
engines: {node: '>= 0.4'}
|
||||
|
||||
es-abstract@1.24.0:
|
||||
resolution: {integrity: sha512-WSzPgsdLtTcQwm4CROfS5ju2Wa1QQcVeT37jFjYzdFz1r9ahadC8B8/a4qxJxM+09F18iumCdRmlr96ZYkQvEg==}
|
||||
engines: {node: '>= 0.4'}
|
||||
@@ -19163,6 +19168,8 @@ snapshots:
|
||||
|
||||
ansi-styles@6.2.1: {}
|
||||
|
||||
ansi_up@6.0.6: {}
|
||||
|
||||
ansis@4.0.0-node10: {}
|
||||
|
||||
ansis@4.1.0: {}
|
||||
@@ -19249,7 +19256,7 @@ snapshots:
|
||||
call-bind: 1.0.8
|
||||
call-bound: 1.0.4
|
||||
define-properties: 1.2.1
|
||||
es-abstract: 1.23.9
|
||||
es-abstract: 1.24.0
|
||||
es-errors: 1.3.0
|
||||
es-object-atoms: 1.1.1
|
||||
es-shim-unscopables: 1.1.0
|
||||
@@ -19258,14 +19265,14 @@ snapshots:
|
||||
dependencies:
|
||||
call-bind: 1.0.8
|
||||
define-properties: 1.2.1
|
||||
es-abstract: 1.23.9
|
||||
es-abstract: 1.24.0
|
||||
es-shim-unscopables: 1.1.0
|
||||
|
||||
array.prototype.flatmap@1.3.3:
|
||||
dependencies:
|
||||
call-bind: 1.0.8
|
||||
define-properties: 1.2.1
|
||||
es-abstract: 1.23.9
|
||||
es-abstract: 1.24.0
|
||||
es-shim-unscopables: 1.1.0
|
||||
|
||||
arraybuffer.prototype.slice@1.0.4:
|
||||
@@ -19273,7 +19280,7 @@ snapshots:
|
||||
array-buffer-byte-length: 1.0.2
|
||||
call-bind: 1.0.8
|
||||
define-properties: 1.2.1
|
||||
es-abstract: 1.23.9
|
||||
es-abstract: 1.24.0
|
||||
es-errors: 1.3.0
|
||||
get-intrinsic: 1.3.0
|
||||
is-array-buffer: 3.0.5
|
||||
@@ -20819,60 +20826,6 @@ snapshots:
|
||||
|
||||
errx@0.1.0: {}
|
||||
|
||||
es-abstract@1.23.9:
|
||||
dependencies:
|
||||
array-buffer-byte-length: 1.0.2
|
||||
arraybuffer.prototype.slice: 1.0.4
|
||||
available-typed-arrays: 1.0.7
|
||||
call-bind: 1.0.8
|
||||
call-bound: 1.0.4
|
||||
data-view-buffer: 1.0.2
|
||||
data-view-byte-length: 1.0.2
|
||||
data-view-byte-offset: 1.0.1
|
||||
es-define-property: 1.0.1
|
||||
es-errors: 1.3.0
|
||||
es-object-atoms: 1.1.1
|
||||
es-set-tostringtag: 2.1.0
|
||||
es-to-primitive: 1.3.0
|
||||
function.prototype.name: 1.1.8
|
||||
get-intrinsic: 1.3.0
|
||||
get-proto: 1.0.1
|
||||
get-symbol-description: 1.1.0
|
||||
globalthis: 1.0.4
|
||||
gopd: 1.2.0
|
||||
has-property-descriptors: 1.0.2
|
||||
has-proto: 1.2.0
|
||||
has-symbols: 1.1.0
|
||||
hasown: 2.0.2
|
||||
internal-slot: 1.1.0
|
||||
is-array-buffer: 3.0.5
|
||||
is-callable: 1.2.7
|
||||
is-data-view: 1.0.2
|
||||
is-regex: 1.2.1
|
||||
is-shared-array-buffer: 1.0.4
|
||||
is-string: 1.1.1
|
||||
is-typed-array: 1.1.15
|
||||
is-weakref: 1.1.1
|
||||
math-intrinsics: 1.1.0
|
||||
object-inspect: 1.13.4
|
||||
object-keys: 1.1.1
|
||||
object.assign: 4.1.7
|
||||
own-keys: 1.0.1
|
||||
regexp.prototype.flags: 1.5.4
|
||||
safe-array-concat: 1.1.3
|
||||
safe-push-apply: 1.0.0
|
||||
safe-regex-test: 1.1.0
|
||||
set-proto: 1.0.0
|
||||
string.prototype.trim: 1.2.10
|
||||
string.prototype.trimend: 1.0.9
|
||||
string.prototype.trimstart: 1.0.8
|
||||
typed-array-buffer: 1.0.3
|
||||
typed-array-byte-length: 1.0.3
|
||||
typed-array-byte-offset: 1.0.4
|
||||
typed-array-length: 1.0.7
|
||||
unbox-primitive: 1.1.0
|
||||
which-typed-array: 1.1.19
|
||||
|
||||
es-abstract@1.24.0:
|
||||
dependencies:
|
||||
array-buffer-byte-length: 1.0.2
|
||||
@@ -24175,14 +24128,14 @@ snapshots:
|
||||
dependencies:
|
||||
call-bind: 1.0.8
|
||||
define-properties: 1.2.1
|
||||
es-abstract: 1.23.9
|
||||
es-abstract: 1.24.0
|
||||
es-object-atoms: 1.1.1
|
||||
|
||||
object.groupby@1.0.3:
|
||||
dependencies:
|
||||
call-bind: 1.0.8
|
||||
define-properties: 1.2.1
|
||||
es-abstract: 1.23.9
|
||||
es-abstract: 1.24.0
|
||||
|
||||
object.values@1.2.1:
|
||||
dependencies:
|
||||
@@ -25364,7 +25317,7 @@ snapshots:
|
||||
dependencies:
|
||||
call-bind: 1.0.8
|
||||
define-properties: 1.2.1
|
||||
es-abstract: 1.23.9
|
||||
es-abstract: 1.24.0
|
||||
es-errors: 1.3.0
|
||||
es-object-atoms: 1.1.1
|
||||
get-intrinsic: 1.3.0
|
||||
@@ -26146,7 +26099,7 @@ snapshots:
|
||||
call-bound: 1.0.4
|
||||
define-data-property: 1.1.4
|
||||
define-properties: 1.2.1
|
||||
es-abstract: 1.23.9
|
||||
es-abstract: 1.24.0
|
||||
es-object-atoms: 1.1.1
|
||||
has-property-descriptors: 1.0.2
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ import {
|
||||
AccordionTrigger,
|
||||
} from '@/components/ui/accordion';
|
||||
import { jsonFormsAjv } from '@/forms/config';
|
||||
import type { Layout, UISchemaElement } from '@jsonforms/core';
|
||||
import type { BaseUISchemaElement, Labelable, Layout, UISchemaElement } from '@jsonforms/core';
|
||||
import { isVisible } from '@jsonforms/core';
|
||||
import { DispatchRenderer, useJsonFormsLayout } from '@jsonforms/vue';
|
||||
import type { RendererProps } from '@jsonforms/vue';
|
||||
@@ -61,8 +61,9 @@ const elements = computed(() => {
|
||||
const allElements = props.uischema?.elements || [];
|
||||
|
||||
// Filter elements based on visibility rules
|
||||
return allElements.filter((element: UISchemaElement & Record<string, unknown>) => {
|
||||
if (!element.rule) {
|
||||
return allElements.filter((element) => {
|
||||
const elementWithRule = element as BaseUISchemaElement;
|
||||
if (!elementWithRule.rule) {
|
||||
// No rule means always visible
|
||||
return true;
|
||||
}
|
||||
@@ -71,13 +72,13 @@ const elements = computed(() => {
|
||||
try {
|
||||
// Get the root data from JSONForms context for rule evaluation
|
||||
const rootData = jsonFormsContext?.core?.data || {};
|
||||
const formData = props.data || layout.data || rootData;
|
||||
const formPath = props.path || layout.path || '';
|
||||
const formData = props.data || rootData;
|
||||
const formPath = props.path || layout.value.path || '';
|
||||
|
||||
const visible = isVisible(element as UISchemaElement, formData, formPath, jsonFormsAjv);
|
||||
const visible = isVisible(element, formData, formPath, jsonFormsAjv);
|
||||
return visible;
|
||||
} catch (error) {
|
||||
console.warn('[AccordionLayout] Error evaluating visibility:', error, element.rule);
|
||||
console.warn('[AccordionLayout] Error evaluating visibility:', error, elementWithRule.rule);
|
||||
return true; // Default to visible on error
|
||||
}
|
||||
});
|
||||
@@ -127,31 +128,21 @@ const defaultOpenItems = computed(() => {
|
||||
});
|
||||
|
||||
// Get title for accordion item from element options
|
||||
const getAccordionTitle = (
|
||||
element: UISchemaElement & Record<string, unknown>,
|
||||
index: number
|
||||
): string => {
|
||||
return (
|
||||
(element as { options?: { accordion?: { title?: string }; title?: string }; text?: string }).options
|
||||
?.accordion?.title ||
|
||||
(element as { options?: { accordion?: { title?: string }; title?: string }; text?: string }).options
|
||||
?.title ||
|
||||
(element as { options?: { accordion?: { title?: string }; title?: string }; text?: string }).text ||
|
||||
`Section ${index + 1}`
|
||||
);
|
||||
const getAccordionTitle = (element: UISchemaElement, index: number): string => {
|
||||
const el = element as BaseUISchemaElement & Labelable;
|
||||
const options = el.options;
|
||||
const accordionTitle = options?.accordion?.title;
|
||||
const title = options?.title;
|
||||
const text = el.label;
|
||||
return accordionTitle || title || text || `Section ${index + 1}`;
|
||||
};
|
||||
|
||||
// Get description for accordion item from element options
|
||||
const getAccordionDescription = (
|
||||
element: UISchemaElement & Record<string, unknown>,
|
||||
index: number
|
||||
): string => {
|
||||
return (
|
||||
(element as { options?: { accordion?: { description?: string }; description?: string } }).options
|
||||
?.accordion?.description ||
|
||||
(element as { options?: { accordion?: { description?: string }; description?: string } }).options
|
||||
?.description ||
|
||||
''
|
||||
);
|
||||
const getAccordionDescription = (element: UISchemaElement, _index: number): string => {
|
||||
const el = element as BaseUISchemaElement;
|
||||
const options = el.options;
|
||||
const accordionDescription = options?.accordion?.description;
|
||||
const description = options?.description;
|
||||
return accordionDescription || description || '';
|
||||
};
|
||||
</script>
|
||||
|
||||
424
web/__test__/components/Logs/SingleLogViewer.test.ts
Normal file
424
web/__test__/components/Logs/SingleLogViewer.test.ts
Normal file
@@ -0,0 +1,424 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { mount, flushPromises } from '@vue/test-utils';
|
||||
import { nextTick } from 'vue';
|
||||
import DOMPurify from 'isomorphic-dompurify';
|
||||
import { AnsiUp } from 'ansi_up';
|
||||
import { useQuery } from '@vue/apollo-composable';
|
||||
import SingleLogViewer from '~/components/Logs/SingleLogViewer.vue';
|
||||
import {
|
||||
createMockUseQuery,
|
||||
createMockLogFileQuery
|
||||
} from '../../helpers/apollo-mocks';
|
||||
|
||||
// Mock the UI components
|
||||
vi.mock('@unraid/ui', () => ({
|
||||
Button: { template: '<button><slot /></button>' },
|
||||
Tooltip: { template: '<div><slot /></div>' },
|
||||
TooltipContent: { template: '<div><slot /></div>' },
|
||||
TooltipProvider: { template: '<div><slot /></div>' },
|
||||
TooltipTrigger: { template: '<div><slot /></div>' }
|
||||
}));
|
||||
|
||||
// Mock the GraphQL query
|
||||
vi.mock('@vue/apollo-composable', () => ({
|
||||
useApolloClient: vi.fn(() => ({
|
||||
client: {
|
||||
query: vi.fn()
|
||||
}
|
||||
})),
|
||||
useQuery: vi.fn()
|
||||
}));
|
||||
|
||||
// Mock the theme store
|
||||
vi.mock('~/store/theme', () => ({
|
||||
useThemeStore: vi.fn(() => ({
|
||||
darkMode: false
|
||||
}))
|
||||
}));
|
||||
|
||||
describe('SingleLogViewer - ANSI Color Support', () => {
|
||||
let ansiConverter: AnsiUp;
|
||||
|
||||
beforeEach(() => {
|
||||
// Create a fresh converter instance for each test
|
||||
ansiConverter = new AnsiUp();
|
||||
ansiConverter.use_classes = true;
|
||||
ansiConverter.escape_html = true;
|
||||
});
|
||||
|
||||
describe('ANSI to HTML Conversion', () => {
|
||||
it('should convert ANSI color codes to CSS classes', () => {
|
||||
const testCases = [
|
||||
{
|
||||
input: '\x1b[31mRed text\x1b[0m',
|
||||
expected: '<span class="ansi-red-fg">Red text</span>',
|
||||
description: 'red foreground'
|
||||
},
|
||||
{
|
||||
input: '\x1b[32mGreen text\x1b[0m',
|
||||
expected: '<span class="ansi-green-fg">Green text</span>',
|
||||
description: 'green foreground'
|
||||
},
|
||||
{
|
||||
input: '\x1b[33mYellow text\x1b[0m',
|
||||
expected: '<span class="ansi-yellow-fg">Yellow text</span>',
|
||||
description: 'yellow foreground'
|
||||
},
|
||||
{
|
||||
input: '\x1b[34mBlue text\x1b[0m',
|
||||
expected: '<span class="ansi-blue-fg">Blue text</span>',
|
||||
description: 'blue foreground'
|
||||
},
|
||||
{
|
||||
input: '\x1b[91mBright red\x1b[0m',
|
||||
expected: '<span class="ansi-bright-red-fg">Bright red</span>',
|
||||
description: 'bright red foreground'
|
||||
},
|
||||
{
|
||||
input: '\x1b[41mRed background\x1b[0m',
|
||||
expected: '<span class="ansi-red-bg">Red background</span>',
|
||||
description: 'red background'
|
||||
},
|
||||
{
|
||||
input: '\x1b[1mBold text\x1b[0m',
|
||||
expected: '<span style="font-weight:bold">Bold text</span>',
|
||||
description: 'bold text (ansi_up uses inline style for bold)'
|
||||
},
|
||||
{
|
||||
input: '\x1b[3mItalic text\x1b[0m',
|
||||
expected: '<span style="font-style:italic">Italic text</span>',
|
||||
description: 'italic text (ansi_up uses inline style for italic)'
|
||||
},
|
||||
{
|
||||
input: '\x1b[4mUnderlined text\x1b[0m',
|
||||
expected: '<span style="text-decoration:underline">Underlined text</span>',
|
||||
description: 'underlined text (ansi_up uses inline style for underline)'
|
||||
}
|
||||
];
|
||||
|
||||
testCases.forEach(({ input, expected, description }) => {
|
||||
const result = ansiConverter.ansi_to_html(input);
|
||||
expect(result, `Failed for ${description}`).toBe(expected);
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle multiple ANSI codes in one string', () => {
|
||||
const input = '\x1b[31mRed\x1b[0m \x1b[32mGreen\x1b[0m \x1b[34mBlue\x1b[0m';
|
||||
const expected = '<span class="ansi-red-fg">Red</span> <span class="ansi-green-fg">Green</span> <span class="ansi-blue-fg">Blue</span>';
|
||||
const result = ansiConverter.ansi_to_html(input);
|
||||
expect(result).toBe(expected);
|
||||
});
|
||||
|
||||
it('should handle nested ANSI codes', () => {
|
||||
const input = '\x1b[1m\x1b[31mBold Red Text\x1b[0m';
|
||||
const result = ansiConverter.ansi_to_html(input);
|
||||
// ansi_up uses inline style for bold
|
||||
expect(result).toContain('font-weight:bold');
|
||||
expect(result).toContain('ansi-red-fg');
|
||||
});
|
||||
|
||||
it('should escape HTML entities for security', () => {
|
||||
const input = '\x1b[31m<script>alert("XSS")</script>\x1b[0m';
|
||||
const result = ansiConverter.ansi_to_html(input);
|
||||
expect(result).not.toContain('<script>');
|
||||
expect(result).toContain('<script>');
|
||||
});
|
||||
});
|
||||
|
||||
describe('DOMPurify Sanitization', () => {
|
||||
it('should preserve CSS classes after sanitization', () => {
|
||||
const htmlWithClasses = '<span class="ansi-red-fg">Red text</span>';
|
||||
const sanitized = DOMPurify.sanitize(htmlWithClasses, {
|
||||
ALLOWED_TAGS: ['span', 'br'],
|
||||
ALLOWED_ATTR: ['class']
|
||||
});
|
||||
expect(sanitized).toBe(htmlWithClasses);
|
||||
});
|
||||
|
||||
it('should remove inline styles when configured', () => {
|
||||
const htmlWithStyles = '<span style="color: red;">Red text</span>';
|
||||
const sanitized = DOMPurify.sanitize(htmlWithStyles, {
|
||||
ALLOWED_TAGS: ['span', 'br'],
|
||||
ALLOWED_ATTR: ['class'] // Note: 'style' is not allowed
|
||||
});
|
||||
expect(sanitized).toBe('<span>Red text</span>');
|
||||
});
|
||||
|
||||
it('should remove dangerous tags while preserving safe content', () => {
|
||||
const dangerous = '<span class="ansi-red-fg">Safe</span><script>alert("XSS")</script>';
|
||||
const sanitized = DOMPurify.sanitize(dangerous, {
|
||||
ALLOWED_TAGS: ['span', 'br'],
|
||||
ALLOWED_ATTR: ['class']
|
||||
});
|
||||
expect(sanitized).toBe('<span class="ansi-red-fg">Safe</span>');
|
||||
});
|
||||
|
||||
it('should handle complex nested structures', () => {
|
||||
const complex = '<span class="ansi-bold"><span class="ansi-red-fg">Bold Red</span></span>';
|
||||
const sanitized = DOMPurify.sanitize(complex, {
|
||||
ALLOWED_TAGS: ['span', 'br'],
|
||||
ALLOWED_ATTR: ['class']
|
||||
});
|
||||
expect(sanitized).toBe(complex);
|
||||
});
|
||||
});
|
||||
|
||||
describe('CSS Class Definitions', () => {
|
||||
it('should have CSS rules for all standard ANSI colors', async () => {
|
||||
// Mock useQuery to return empty data for this test
|
||||
// @ts-expect-error Mock implementation for testing
|
||||
vi.mocked(useQuery).mockReturnValue(createMockUseQuery());
|
||||
|
||||
const wrapper = mount(SingleLogViewer, {
|
||||
props: {
|
||||
logFilePath: '/test/log.txt',
|
||||
lineCount: 100,
|
||||
autoScroll: false
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
Button: true,
|
||||
Tooltip: true,
|
||||
TooltipContent: true,
|
||||
TooltipProvider: true,
|
||||
TooltipTrigger: true
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Wait for component to mount
|
||||
await nextTick();
|
||||
|
||||
// Check that the component mounts without errors
|
||||
expect(wrapper.exists()).toBe(true);
|
||||
|
||||
wrapper.unmount();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Integration Tests', () => {
|
||||
beforeEach(() => {
|
||||
// Reset mocks before each test
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should properly render ANSI colored log content', async () => {
|
||||
// Create mock data
|
||||
const content = '\x1b[31m[ERROR]\x1b[0m Failed to connect\n\x1b[32m[SUCCESS]\x1b[0m Connected';
|
||||
const mockQuery = createMockLogFileQuery(content, 2, 1);
|
||||
|
||||
// Mock useQuery to return our data
|
||||
// @ts-expect-error Mock implementation for testing
|
||||
vi.mocked(useQuery).mockReturnValue(mockQuery);
|
||||
|
||||
const wrapper = mount(SingleLogViewer, {
|
||||
props: {
|
||||
logFilePath: '/test/log.txt',
|
||||
lineCount: 100,
|
||||
autoScroll: false
|
||||
}
|
||||
});
|
||||
|
||||
// Wait for the component to mount and process initial data
|
||||
await wrapper.vm.$nextTick();
|
||||
|
||||
// Trigger the watcher by modifying the result
|
||||
// @ts-expect-error Accessing mock properties
|
||||
if (mockQuery.result.value) {
|
||||
// @ts-expect-error Modifying mock properties
|
||||
mockQuery.result.value = {
|
||||
logFile: {
|
||||
content,
|
||||
totalLines: 2,
|
||||
startLine: 1
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Wait for watchers to process
|
||||
await wrapper.vm.$nextTick();
|
||||
await flushPromises();
|
||||
await wrapper.vm.$nextTick();
|
||||
|
||||
// Get the pre element that contains the log content
|
||||
const preElement = wrapper.find('pre.hljs');
|
||||
expect(preElement.exists()).toBe(true);
|
||||
|
||||
// Check that the rendered HTML contains the CSS classes
|
||||
const html = preElement.html();
|
||||
if (!html.includes('ansi-red-fg')) {
|
||||
console.log('Pre element HTML:', html);
|
||||
console.log('Full wrapper HTML:', wrapper.html());
|
||||
}
|
||||
expect(html).toContain('ansi-red-fg');
|
||||
expect(html).toContain('[ERROR]');
|
||||
expect(html).toContain('ansi-green-fg');
|
||||
expect(html).toContain('[SUCCESS]');
|
||||
|
||||
wrapper.unmount();
|
||||
});
|
||||
|
||||
it('should handle log content with mixed ANSI and plain text', async () => {
|
||||
const content = 'Plain text \x1b[33mWarning\x1b[0m more plain text';
|
||||
const mockQuery = createMockLogFileQuery(content, 1, 1);
|
||||
// @ts-expect-error Mock implementation for testing
|
||||
vi.mocked(useQuery).mockReturnValue(mockQuery);
|
||||
|
||||
const wrapper = mount(SingleLogViewer, {
|
||||
props: {
|
||||
logFilePath: '/test/log.txt',
|
||||
lineCount: 100,
|
||||
autoScroll: false
|
||||
}
|
||||
});
|
||||
|
||||
// Wait for mount and trigger the watcher
|
||||
await wrapper.vm.$nextTick();
|
||||
|
||||
// @ts-expect-error Accessing mock properties
|
||||
if (mockQuery.result.value) {
|
||||
// @ts-expect-error Modifying mock properties
|
||||
mockQuery.result.value = {
|
||||
logFile: {
|
||||
content,
|
||||
totalLines: 1,
|
||||
startLine: 1
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Wait for processing
|
||||
await wrapper.vm.$nextTick();
|
||||
await flushPromises();
|
||||
await wrapper.vm.$nextTick();
|
||||
|
||||
const preElement = wrapper.find('pre.hljs');
|
||||
expect(preElement.exists()).toBe(true);
|
||||
|
||||
const html = preElement.html();
|
||||
expect(html).toContain('Plain text');
|
||||
expect(html).toContain('ansi-yellow-fg');
|
||||
expect(html).toContain('Warning');
|
||||
expect(html).toContain('more plain text');
|
||||
|
||||
wrapper.unmount();
|
||||
});
|
||||
|
||||
it('should apply client-side filtering while preserving ANSI colors', async () => {
|
||||
const content = '\x1b[31m[ERROR]\x1b[0m Connection failed\n\x1b[32m[INFO]\x1b[0m Connected\n\x1b[31m[ERROR]\x1b[0m Timeout';
|
||||
const mockQuery = createMockLogFileQuery(content, 3, 1);
|
||||
// @ts-expect-error Mock implementation for testing
|
||||
vi.mocked(useQuery).mockReturnValue(mockQuery);
|
||||
|
||||
const wrapper = mount(SingleLogViewer, {
|
||||
props: {
|
||||
logFilePath: '/test/log.txt',
|
||||
lineCount: 100,
|
||||
autoScroll: false,
|
||||
clientFilter: 'ERROR'
|
||||
}
|
||||
});
|
||||
|
||||
// Wait for mount and trigger the watcher
|
||||
await wrapper.vm.$nextTick();
|
||||
|
||||
// @ts-expect-error Accessing mock properties
|
||||
if (mockQuery.result.value) {
|
||||
// @ts-expect-error Modifying mock properties
|
||||
mockQuery.result.value = {
|
||||
logFile: {
|
||||
content,
|
||||
totalLines: 3,
|
||||
startLine: 1
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Wait for processing
|
||||
await wrapper.vm.$nextTick();
|
||||
await flushPromises();
|
||||
await wrapper.vm.$nextTick();
|
||||
|
||||
const preElement = wrapper.find('pre.hljs');
|
||||
expect(preElement.exists()).toBe(true);
|
||||
|
||||
const html = preElement.html();
|
||||
// Should contain ERROR lines with red color
|
||||
expect(html).toContain('ansi-red-fg');
|
||||
expect(html).toContain('ERROR');
|
||||
// Should not contain INFO line (due to filter)
|
||||
expect(html).not.toContain('INFO');
|
||||
expect(html).not.toContain('ansi-green-fg');
|
||||
|
||||
wrapper.unmount();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Performance Tests', () => {
|
||||
it('should handle large amounts of ANSI colored text efficiently', () => {
|
||||
const lines = 1000;
|
||||
const largeInput = Array(lines)
|
||||
.fill(null)
|
||||
.map((_, i) => {
|
||||
const colors = ['31', '32', '33', '34', '35', '36'];
|
||||
const color = colors[i % colors.length];
|
||||
return `\x1b[${color}mLine ${i}: Some log message with color\x1b[0m`;
|
||||
})
|
||||
.join('\n');
|
||||
|
||||
const result = ansiConverter.ansi_to_html(largeInput);
|
||||
|
||||
// Should contain the expected number of color spans
|
||||
const colorMatches = result.match(/class="ansi-/g);
|
||||
expect(colorMatches).toHaveLength(lines);
|
||||
});
|
||||
|
||||
it('should efficiently sanitize large HTML with many CSS classes', () => {
|
||||
const lines = 1000;
|
||||
const largeHtml = Array(lines)
|
||||
.fill(null)
|
||||
.map((_, i) => `<span class="ansi-red-fg">Line ${i}</span>`)
|
||||
.join('\n');
|
||||
|
||||
const sanitized = DOMPurify.sanitize(largeHtml, {
|
||||
ALLOWED_TAGS: ['span', 'br'],
|
||||
ALLOWED_ATTR: ['class']
|
||||
});
|
||||
|
||||
// Should preserve all spans
|
||||
const spanMatches = sanitized.match(/<span/g);
|
||||
expect(spanMatches).toHaveLength(lines);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle empty input gracefully', () => {
|
||||
const result = ansiConverter.ansi_to_html('');
|
||||
expect(result).toBe('');
|
||||
});
|
||||
|
||||
it('should handle input with no ANSI codes', () => {
|
||||
const plainText = 'This is plain text without any colors';
|
||||
const result = ansiConverter.ansi_to_html(plainText);
|
||||
expect(result).toBe(plainText);
|
||||
});
|
||||
|
||||
it('should handle malformed ANSI codes', () => {
|
||||
const malformed = '\x1b[999mInvalid color code\x1b[0m';
|
||||
// Should not throw an error
|
||||
expect(() => ansiConverter.ansi_to_html(malformed)).not.toThrow();
|
||||
});
|
||||
|
||||
it('should handle incomplete ANSI sequences', () => {
|
||||
const incomplete = '\x1b[31mRed text without reset';
|
||||
const result = ansiConverter.ansi_to_html(incomplete);
|
||||
expect(result).toContain('ansi-red-fg');
|
||||
});
|
||||
|
||||
it('should handle ANSI codes at the beginning and end of lines', () => {
|
||||
const input = '\x1b[31mStart\nMiddle\nEnd\x1b[0m';
|
||||
const result = ansiConverter.ansi_to_html(input);
|
||||
expect(result).toContain('ansi-red-fg');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -59,6 +59,8 @@ const mockLocation = {
|
||||
hash: '',
|
||||
origin: 'http://mock-origin.com',
|
||||
pathname: '/login',
|
||||
protocol: 'http:',
|
||||
host: 'mock-origin.com',
|
||||
get href() {
|
||||
return mockLocationHref;
|
||||
},
|
||||
@@ -253,7 +255,8 @@ describe('SsoButtons', () => {
|
||||
expect(sessionStorage.setItem).toHaveBeenCalledWith('sso_provider', 'unraid-net');
|
||||
|
||||
const generatedState = (sessionStorage.setItem as Mock).mock.calls[0][1];
|
||||
const expectedUrl = `/graphql/api/auth/oidc/authorize/unraid-net?state=${encodeURIComponent(generatedState)}`;
|
||||
const redirectUri = `${mockLocation.origin}/graphql/api/auth/oidc/callback`;
|
||||
const expectedUrl = `/graphql/api/auth/oidc/authorize/unraid-net?state=${encodeURIComponent(generatedState)}&redirect_uri=${encodeURIComponent(redirectUri)}`;
|
||||
|
||||
expect(mockLocation.href).toBe(expectedUrl);
|
||||
});
|
||||
@@ -377,6 +380,57 @@ describe('SsoButtons', () => {
|
||||
expect(mockLocation.href).toBe(expectedUrl);
|
||||
});
|
||||
|
||||
it('handles HTTPS with non-standard port correctly', async () => {
|
||||
const mockProviders = [
|
||||
{
|
||||
id: 'tsidp',
|
||||
name: 'Tailscale IDP',
|
||||
buttonText: 'Sign in with Tailscale',
|
||||
buttonIcon: null,
|
||||
buttonVariant: 'secondary',
|
||||
buttonStyle: null
|
||||
}
|
||||
];
|
||||
|
||||
// Set up location with HTTPS and non-standard port
|
||||
mockLocation.protocol = 'https:';
|
||||
mockLocation.host = 'unraid.mytailnet.ts.net:1443';
|
||||
mockLocation.origin = 'https://unraid.mytailnet.ts.net:1443';
|
||||
|
||||
mockUseQuery.mockReturnValue({
|
||||
result: { value: { publicOidcProviders: mockProviders } },
|
||||
refetch: vi.fn().mockResolvedValue({ data: { publicOidcProviders: mockProviders } }),
|
||||
});
|
||||
|
||||
const wrapper = mount(SsoButtons, {
|
||||
global: {
|
||||
stubs: {
|
||||
SsoProviderButton: SsoProviderButtonStub,
|
||||
Button: { template: '<button><slot /></button>' }
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await flushPromises();
|
||||
vi.runAllTimers();
|
||||
await flushPromises();
|
||||
|
||||
const button = wrapper.find('button');
|
||||
await button.trigger('click');
|
||||
|
||||
// Should include the correct redirect URI with HTTPS and port 1443
|
||||
const generatedState = (sessionStorage.setItem as Mock).mock.calls[0][1];
|
||||
const redirectUri = 'https://unraid.mytailnet.ts.net:1443/graphql/api/auth/oidc/callback';
|
||||
const expectedUrl = `/graphql/api/auth/oidc/authorize/tsidp?state=${encodeURIComponent(generatedState)}&redirect_uri=${encodeURIComponent(redirectUri)}`;
|
||||
|
||||
expect(mockLocation.href).toBe(expectedUrl);
|
||||
|
||||
// Reset location mock for other tests
|
||||
mockLocation.protocol = 'http:';
|
||||
mockLocation.host = 'mock-origin.com';
|
||||
mockLocation.origin = 'http://mock-origin.com';
|
||||
});
|
||||
|
||||
it('handles multiple OIDC providers', async () => {
|
||||
const mockProviders = [
|
||||
{
|
||||
|
||||
@@ -1,175 +0,0 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { useApiKeyAuthorization } from '~/composables/useApiKeyAuthorization';
|
||||
import { AuthAction, Resource, Role } from '~/composables/gql/graphql';
|
||||
|
||||
describe('useApiKeyAuthorization', () => {
|
||||
describe('parameter parsing', () => {
|
||||
it('should parse query parameters correctly', () => {
|
||||
const params = new URLSearchParams('?name=TestApp&scopes=docker:read,vms:*&redirect_uri=https://example.com&state=abc123');
|
||||
const { authParams } = useApiKeyAuthorization(params);
|
||||
|
||||
expect(authParams.value.name).toBe('TestApp');
|
||||
expect(authParams.value.scopes).toEqual(['docker:read', 'vms:*']);
|
||||
expect(authParams.value.redirectUri).toBe('https://example.com');
|
||||
expect(authParams.value.state).toBe('abc123');
|
||||
});
|
||||
|
||||
it('should handle missing parameters with defaults', () => {
|
||||
const params = new URLSearchParams('');
|
||||
const { authParams } = useApiKeyAuthorization(params);
|
||||
|
||||
expect(authParams.value.name).toBe('Unknown Application');
|
||||
expect(authParams.value.scopes).toEqual([]);
|
||||
expect(authParams.value.redirectUri).toBe('');
|
||||
expect(authParams.value.state).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
describe('formatPermissions', () => {
|
||||
it('should format role scopes correctly', () => {
|
||||
const params = new URLSearchParams('?scopes=role:admin,role:viewer');
|
||||
const { formattedPermissions } = useApiKeyAuthorization(params);
|
||||
|
||||
expect(formattedPermissions.value).toEqual([
|
||||
{
|
||||
scope: 'role:admin',
|
||||
name: 'ADMIN',
|
||||
description: 'Grant admin role access',
|
||||
isRole: true,
|
||||
},
|
||||
{
|
||||
scope: 'role:viewer',
|
||||
name: 'VIEWER',
|
||||
description: 'Grant viewer role access',
|
||||
isRole: true,
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('should format resource:action scopes correctly', () => {
|
||||
const params = new URLSearchParams('?scopes=docker:read,vms:*');
|
||||
const { formattedPermissions } = useApiKeyAuthorization(params);
|
||||
|
||||
expect(formattedPermissions.value).toEqual([
|
||||
{
|
||||
scope: 'docker:read',
|
||||
name: 'Docker - Read',
|
||||
description: 'Read access to Docker',
|
||||
isRole: false,
|
||||
},
|
||||
{
|
||||
scope: 'vms:*',
|
||||
name: 'Vms - Full',
|
||||
description: 'Full access to Vms',
|
||||
isRole: false,
|
||||
},
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('convertScopesToPermissions', () => {
|
||||
it('should convert role scopes to roles', () => {
|
||||
const params = new URLSearchParams('?scopes=role:admin');
|
||||
const { convertScopesToPermissions } = useApiKeyAuthorization(params);
|
||||
const result = convertScopesToPermissions(['role:admin']);
|
||||
|
||||
expect(result.roles).toContain(Role.ADMIN);
|
||||
expect(result.permissions).toEqual([]);
|
||||
});
|
||||
|
||||
it('should convert resource scopes to permissions', () => {
|
||||
const params = new URLSearchParams('?scopes=docker:read');
|
||||
const { convertScopesToPermissions } = useApiKeyAuthorization(params);
|
||||
const result = convertScopesToPermissions(['docker:read']);
|
||||
|
||||
expect(result.permissions).toEqual([
|
||||
{
|
||||
resource: Resource.DOCKER,
|
||||
actions: [AuthAction.READ_ANY],
|
||||
},
|
||||
]);
|
||||
expect(result.roles).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle wildcard actions', () => {
|
||||
const params = new URLSearchParams('?scopes=vms:*');
|
||||
const { convertScopesToPermissions } = useApiKeyAuthorization(params);
|
||||
const result = convertScopesToPermissions(['vms:*']);
|
||||
|
||||
expect(result.permissions).toEqual([
|
||||
{
|
||||
resource: Resource.VMS,
|
||||
actions: [AuthAction.CREATE_ANY, AuthAction.READ_ANY, AuthAction.UPDATE_ANY, AuthAction.DELETE_ANY],
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('should merge multiple actions for same resource', () => {
|
||||
const params = new URLSearchParams('');
|
||||
const { convertScopesToPermissions } = useApiKeyAuthorization(params);
|
||||
const result = convertScopesToPermissions(['docker:read', 'docker:update']);
|
||||
|
||||
expect(result.permissions).toEqual([
|
||||
{
|
||||
resource: Resource.DOCKER,
|
||||
actions: [AuthAction.READ_ANY, AuthAction.UPDATE_ANY],
|
||||
},
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('redirect URI validation', () => {
|
||||
it('should accept HTTPS URLs', () => {
|
||||
const params = new URLSearchParams('?redirect_uri=https://example.com/callback');
|
||||
const { hasValidRedirectUri } = useApiKeyAuthorization(params);
|
||||
|
||||
expect(hasValidRedirectUri.value).toBe(true);
|
||||
});
|
||||
|
||||
it('should accept localhost URLs', () => {
|
||||
const params = new URLSearchParams('?redirect_uri=http://localhost:3000/callback');
|
||||
const { hasValidRedirectUri } = useApiKeyAuthorization(params);
|
||||
|
||||
expect(hasValidRedirectUri.value).toBe(true);
|
||||
});
|
||||
|
||||
it('should accept HTTP URLs (non-localhost)', () => {
|
||||
const params = new URLSearchParams('?redirect_uri=http://example.com/callback');
|
||||
const { hasValidRedirectUri } = useApiKeyAuthorization(params);
|
||||
|
||||
expect(hasValidRedirectUri.value).toBe(true);
|
||||
});
|
||||
|
||||
it('should reject invalid URLs', () => {
|
||||
const params = new URLSearchParams('?redirect_uri=not-a-url');
|
||||
const { hasValidRedirectUri } = useApiKeyAuthorization(params);
|
||||
|
||||
expect(hasValidRedirectUri.value).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildCallbackUrl', () => {
|
||||
it('should build callback URL with API key', () => {
|
||||
const params = new URLSearchParams('');
|
||||
const { buildCallbackUrl } = useApiKeyAuthorization(params);
|
||||
const url = buildCallbackUrl('https://example.com/callback', 'test-key', undefined, 'state123');
|
||||
|
||||
expect(url).toBe('https://example.com/callback?api_key=test-key&state=state123');
|
||||
});
|
||||
|
||||
it('should build callback URL with error', () => {
|
||||
const params = new URLSearchParams('');
|
||||
const { buildCallbackUrl } = useApiKeyAuthorization(params);
|
||||
const url = buildCallbackUrl('https://example.com/callback', undefined, 'access_denied', 'state123');
|
||||
|
||||
expect(url).toBe('https://example.com/callback?error=access_denied&state=state123');
|
||||
});
|
||||
|
||||
it('should throw for invalid redirect URI', () => {
|
||||
const params = new URLSearchParams('');
|
||||
const { buildCallbackUrl } = useApiKeyAuthorization(params);
|
||||
|
||||
expect(() => buildCallbackUrl('not-a-url', 'key')).toThrow('Invalid redirect URI');
|
||||
});
|
||||
});
|
||||
});
|
||||
86
web/__test__/helpers/apollo-mocks.ts
Normal file
86
web/__test__/helpers/apollo-mocks.ts
Normal file
@@ -0,0 +1,86 @@
|
||||
import { ref } from 'vue';
|
||||
import { vi } from 'vitest';
|
||||
|
||||
// Types for mock data
|
||||
export interface MockLogFile {
|
||||
content: string;
|
||||
totalLines: number;
|
||||
startLine: number;
|
||||
}
|
||||
|
||||
export interface MockQueryResult {
|
||||
logFile: MockLogFile;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a mock useQuery return value with optional result data
|
||||
* Using unknown return type to avoid complex Apollo type issues in tests
|
||||
*/
|
||||
export function createMockUseQuery<TData = unknown>(
|
||||
resultData: TData | null = null,
|
||||
options: {
|
||||
loading?: boolean;
|
||||
error?: unknown;
|
||||
} = {}
|
||||
): unknown {
|
||||
return {
|
||||
result: ref(resultData),
|
||||
loading: ref(options.loading ?? false),
|
||||
error: ref(options.error ?? null),
|
||||
refetch: vi.fn(() => Promise.resolve({
|
||||
data: resultData || {},
|
||||
loading: false,
|
||||
networkStatus: 7,
|
||||
stale: false,
|
||||
error: undefined
|
||||
})),
|
||||
subscribeToMore: vi.fn(),
|
||||
networkStatus: ref(7),
|
||||
start: vi.fn(),
|
||||
stop: vi.fn(),
|
||||
restart: vi.fn(),
|
||||
forceDisabled: ref(false),
|
||||
document: ref(null),
|
||||
variables: ref({}),
|
||||
options: {},
|
||||
query: ref(null),
|
||||
fetchMore: vi.fn(),
|
||||
updateQuery: vi.fn(),
|
||||
onResult: vi.fn(),
|
||||
onError: vi.fn()
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a mock useQuery specifically for log file data
|
||||
*/
|
||||
export function createMockLogFileQuery(
|
||||
content: string,
|
||||
totalLines: number,
|
||||
startLine: number = 1
|
||||
): unknown {
|
||||
const result: MockQueryResult = {
|
||||
logFile: {
|
||||
content,
|
||||
totalLines,
|
||||
startLine
|
||||
}
|
||||
};
|
||||
|
||||
return createMockUseQuery(result);
|
||||
}
|
||||
|
||||
/**
|
||||
* Factory function to create the mock module object for @vue/apollo-composable
|
||||
* Call this at the top level of test files: vi.mock('@vue/apollo-composable', () => apolloComposableMockFactory())
|
||||
*/
|
||||
export function apolloComposableMockFactory() {
|
||||
return {
|
||||
useApolloClient: vi.fn(() => ({
|
||||
client: {
|
||||
query: vi.fn()
|
||||
}
|
||||
})),
|
||||
useQuery: vi.fn(() => createMockUseQuery())
|
||||
};
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import { useServerStore } from '~/store/server';
|
||||
// import type { ConnectSettingsValues } from '~/composables/gql/graphql';
|
||||
|
||||
import { getConnectSettingsForm, updateConnectSettings } from './graphql/settings.query';
|
||||
import OidcDebugLogs from './OidcDebugLogs.vue';
|
||||
|
||||
const { connectPluginInstalled } = storeToRefs(useServerStore());
|
||||
|
||||
@@ -120,6 +121,9 @@ const onChange = ({ data }: { data: Record<string, unknown> }) => {
|
||||
:readonly="isUpdating"
|
||||
@change="onChange"
|
||||
/>
|
||||
<!-- OIDC Debug Logs -->
|
||||
<OidcDebugLogs />
|
||||
|
||||
<!-- form submission & fallback reaction message -->
|
||||
<div class="mt-6 grid grid-cols-settings gap-y-6 items-baseline">
|
||||
<div class="text-sm text-end">
|
||||
|
||||
59
web/components/ConnectSettings/OidcDebugLogs.vue
Normal file
59
web/components/ConnectSettings/OidcDebugLogs.vue
Normal file
@@ -0,0 +1,59 @@
|
||||
<script setup lang="ts">
|
||||
import { ref } from 'vue';
|
||||
import SingleLogViewer from '../Logs/SingleLogViewer.vue';
|
||||
import LogViewerToolbar from '../Logs/LogViewerToolbar.vue';
|
||||
|
||||
const showLogs = ref(false);
|
||||
const autoScroll = ref(true);
|
||||
const filterText = ref('OIDC');
|
||||
const logViewerRef = ref<InstanceType<typeof SingleLogViewer> | null>(null);
|
||||
|
||||
const logFilePath = '/var/log/graphql-api.log';
|
||||
|
||||
const refreshLogs = () => {
|
||||
logViewerRef.value?.refreshLogContent();
|
||||
};
|
||||
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="mt-6 border-2 border-solid rounded-md shadow-md bg-background border-muted">
|
||||
<LogViewerToolbar
|
||||
v-model:filter-text="filterText"
|
||||
v-model:is-expanded="showLogs"
|
||||
title="OIDC Debug Logs"
|
||||
description="View real-time OIDC authentication and configuration logs"
|
||||
:show-toggle="true"
|
||||
:show-refresh="true"
|
||||
filter-placeholder="Filter logs..."
|
||||
@refresh="refreshLogs"
|
||||
/>
|
||||
|
||||
<div v-if="showLogs" class="p-4 pt-0">
|
||||
<div class="border rounded-lg bg-muted/30 h-[400px] overflow-hidden">
|
||||
<SingleLogViewer
|
||||
ref="logViewerRef"
|
||||
:log-file-path="logFilePath"
|
||||
:line-count="100"
|
||||
:auto-scroll="autoScroll"
|
||||
:client-filter="filterText"
|
||||
highlight-language="plaintext"
|
||||
class="h-full"
|
||||
/>
|
||||
</div>
|
||||
<div class="mt-2 flex justify-between items-center text-xs text-muted-foreground">
|
||||
<span>
|
||||
{{ filterText ? `Filtering logs for: "${filterText}"` : 'Showing all log entries' }}
|
||||
</span>
|
||||
<label class="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
v-model="autoScroll"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300"
|
||||
>
|
||||
<span>Auto-scroll</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
82
web/components/FileViewer.vue
Normal file
82
web/components/FileViewer.vue
Normal file
@@ -0,0 +1,82 @@
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue';
|
||||
import { useContentHighlighting } from '~/composables/useContentHighlighting';
|
||||
|
||||
const props = defineProps<{
|
||||
content: string;
|
||||
language?: string;
|
||||
showLineNumbers?: boolean;
|
||||
maxHeight?: string;
|
||||
class?: string;
|
||||
}>();
|
||||
|
||||
const { highlightContent } = useContentHighlighting();
|
||||
|
||||
const highlightedContent = computed(() => {
|
||||
return highlightContent(props.content, props.language);
|
||||
});
|
||||
|
||||
const lines = computed(() => {
|
||||
return props.content.split('\n');
|
||||
});
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div
|
||||
:class="[
|
||||
'file-viewer-container',
|
||||
'relative rounded border bg-background text-foreground overflow-hidden',
|
||||
props.class
|
||||
]"
|
||||
:style="{ height: maxHeight || '300px' }"
|
||||
>
|
||||
<div class="absolute inset-0 overflow-auto">
|
||||
<div class="flex min-w-full">
|
||||
<!-- Line numbers -->
|
||||
<div
|
||||
v-if="showLineNumbers"
|
||||
class="flex-shrink-0 select-none border-r bg-muted/50 px-2 py-2 text-xs font-mono text-muted-foreground"
|
||||
>
|
||||
<div v-for="(_, index) in lines" :key="index" class="leading-5 text-right pr-2">
|
||||
{{ index + 1 }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Content -->
|
||||
<div class="flex-1 min-w-0">
|
||||
<pre
|
||||
class="p-3 text-xs font-mono leading-5 whitespace-pre m-0"
|
||||
v-html="highlightedContent"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
/* Add some basic styling for the highlighted content */
|
||||
:deep(.hljs) {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
/* ANSI color classes */
|
||||
:deep(.ansi-bright-black) { color: #666; }
|
||||
:deep(.ansi-bright-red) { color: #ff6b6b; }
|
||||
:deep(.ansi-bright-green) { color: #51cf66; }
|
||||
:deep(.ansi-bright-yellow) { color: #ffd43b; }
|
||||
:deep(.ansi-bright-blue) { color: #339af0; }
|
||||
:deep(.ansi-bright-magenta) { color: #f06292; }
|
||||
:deep(.ansi-bright-cyan) { color: #22d3ee; }
|
||||
:deep(.ansi-bright-white) { color: #f8f9fa; }
|
||||
|
||||
/* Standard ANSI colors for dark theme */
|
||||
:deep(.ansi-black) { color: #000; }
|
||||
:deep(.ansi-red) { color: #e03131; }
|
||||
:deep(.ansi-green) { color: #2f9e44; }
|
||||
:deep(.ansi-yellow) { color: #f59f00; }
|
||||
:deep(.ansi-blue) { color: #1971c2; }
|
||||
:deep(.ansi-magenta) { color: #c2255c; }
|
||||
:deep(.ansi-cyan) { color: #0891b2; }
|
||||
:deep(.ansi-white) { color: #495057; }
|
||||
</style>
|
||||
67
web/components/Logs/FilteredLogModal.vue
Normal file
67
web/components/Logs/FilteredLogModal.vue
Normal file
@@ -0,0 +1,67 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, computed } from 'vue';
|
||||
import { Dialog } from '@unraid/ui';
|
||||
import SingleLogViewer from './SingleLogViewer.vue';
|
||||
|
||||
interface Props {
|
||||
modelValue: boolean;
|
||||
logFilePath: string;
|
||||
filter?: string;
|
||||
title?: string;
|
||||
description?: string;
|
||||
lineCount?: number;
|
||||
autoScroll?: boolean;
|
||||
highlightLanguage?: string;
|
||||
size?: 'sm' | 'md' | 'lg' | 'xl' | 'full';
|
||||
}
|
||||
|
||||
const props = withDefaults(defineProps<Props>(), {
|
||||
lineCount: 100,
|
||||
autoScroll: true,
|
||||
highlightLanguage: 'plaintext',
|
||||
size: 'xl',
|
||||
title: 'Log Viewer',
|
||||
filter: undefined,
|
||||
description: undefined,
|
||||
});
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:modelValue': [value: boolean];
|
||||
}>();
|
||||
|
||||
const logViewerRef = ref<InstanceType<typeof SingleLogViewer> | null>(null);
|
||||
|
||||
const fullLogPath = computed(() => {
|
||||
if (props.logFilePath.startsWith('/')) {
|
||||
return props.logFilePath;
|
||||
}
|
||||
return `/var/log/${props.logFilePath}`;
|
||||
});
|
||||
|
||||
const handleOpenChange = (open: boolean) => {
|
||||
emit('update:modelValue', open);
|
||||
};
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<Dialog
|
||||
:model-value="modelValue"
|
||||
:title="title"
|
||||
:description="description"
|
||||
:size="size"
|
||||
:show-footer="false"
|
||||
@update:model-value="handleOpenChange"
|
||||
>
|
||||
<div class="h-[600px] flex flex-col">
|
||||
<SingleLogViewer
|
||||
ref="logViewerRef"
|
||||
:log-file-path="fullLogPath"
|
||||
:line-count="lineCount"
|
||||
:auto-scroll="autoScroll"
|
||||
:highlight-language="highlightLanguage"
|
||||
:filter="filter"
|
||||
class="flex-1"
|
||||
/>
|
||||
</div>
|
||||
</Dialog>
|
||||
</template>
|
||||
86
web/components/Logs/LogFilterInput.vue
Normal file
86
web/components/Logs/LogFilterInput.vue
Normal file
@@ -0,0 +1,86 @@
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue';
|
||||
import { Input, Label, Select } from '@unraid/ui';
|
||||
import type { SelectItemType } from '@unraid/ui';
|
||||
import { MagnifyingGlassIcon } from '@heroicons/vue/24/outline';
|
||||
|
||||
const props = withDefaults(defineProps<{
|
||||
modelValue: string;
|
||||
preset?: string;
|
||||
showPresets?: boolean;
|
||||
presetFilters?: SelectItemType[];
|
||||
inputClass?: string;
|
||||
placeholder?: string;
|
||||
label?: string;
|
||||
showIcon?: boolean;
|
||||
}>(), {
|
||||
preset: 'none',
|
||||
showPresets: false,
|
||||
presetFilters: () => [
|
||||
{ value: 'none', label: 'No Filter' },
|
||||
{ value: 'OIDC', label: 'OIDC Logs' },
|
||||
{ value: 'ERROR', label: 'Errors' },
|
||||
{ value: 'WARNING', label: 'Warnings' },
|
||||
{ value: 'AUTH', label: 'Authentication' },
|
||||
],
|
||||
placeholder: 'Filter logs...',
|
||||
label: 'Filter',
|
||||
showIcon: true,
|
||||
inputClass: ''
|
||||
});
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:modelValue': [value: string];
|
||||
'update:preset': [value: string];
|
||||
}>();
|
||||
|
||||
const filterText = computed({
|
||||
get: () => props.modelValue,
|
||||
set: (value) => emit('update:modelValue', value)
|
||||
});
|
||||
|
||||
const presetValue = computed({
|
||||
get: () => props.preset || 'none',
|
||||
set: (value) => {
|
||||
emit('update:preset', value);
|
||||
if (value && value !== 'none') {
|
||||
emit('update:modelValue', value);
|
||||
} else if (value === 'none') {
|
||||
emit('update:modelValue', '');
|
||||
}
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="flex gap-2 items-end">
|
||||
<div v-if="showPresets" class="min-w-[150px]">
|
||||
<Label v-if="label" :for="`preset-filter-${$.uid}`">Quick {{ label }}</Label>
|
||||
<Select
|
||||
:id="`preset-filter-${$.uid}`"
|
||||
v-model="presetValue"
|
||||
:items="presetFilters"
|
||||
placeholder="Select filter"
|
||||
class="w-full"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="relative flex-1">
|
||||
<Label v-if="label && !showPresets" :for="`filter-input-${$.uid}`">{{ label }}</Label>
|
||||
<Label v-else-if="label && showPresets" :for="`filter-input-${$.uid}`">Custom {{ label }}</Label>
|
||||
<div class="relative">
|
||||
<MagnifyingGlassIcon
|
||||
v-if="showIcon"
|
||||
class="absolute left-2 top-1/2 -translate-y-1/2 h-4 w-4 text-muted-foreground"
|
||||
/>
|
||||
<Input
|
||||
:id="`filter-input-${$.uid}`"
|
||||
v-model="filterText"
|
||||
type="text"
|
||||
:placeholder="placeholder"
|
||||
:class="[showIcon ? 'pl-8' : '', inputClass]"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
|
||||
import { GET_LOG_FILES } from './log.query';
|
||||
import SingleLogViewer from './SingleLogViewer.vue';
|
||||
import LogViewerToolbar from './LogViewerToolbar.vue';
|
||||
|
||||
// Types
|
||||
interface LogFile {
|
||||
@@ -20,10 +21,12 @@ interface LogFile {
|
||||
}
|
||||
|
||||
// Component state
|
||||
const selectedLogFile = ref<string>('');
|
||||
const selectedLogFile = ref<string | null>(null);
|
||||
const lineCount = ref<number>(100);
|
||||
const autoScroll = ref<boolean>(true);
|
||||
const highlightLanguage = ref<string>('plaintext');
|
||||
const filterText = ref<string>('');
|
||||
const presetFilter = ref<string>('none');
|
||||
|
||||
// Available highlight languages
|
||||
const highlightLanguages = [
|
||||
@@ -39,6 +42,15 @@ const highlightLanguages = [
|
||||
{ value: 'php', label: 'PHP' },
|
||||
];
|
||||
|
||||
// Preset filter options
|
||||
const presetFilters = [
|
||||
{ value: 'none', label: 'No Filter' },
|
||||
{ value: 'OIDC', label: 'OIDC Logs' },
|
||||
{ value: 'ERROR', label: 'Errors' },
|
||||
{ value: 'WARNING', label: 'Warnings' },
|
||||
{ value: 'AUTH', label: 'Authentication' },
|
||||
];
|
||||
|
||||
// Fetch log files
|
||||
const {
|
||||
result: logFilesResult,
|
||||
@@ -102,15 +114,32 @@ watch(selectedLogFile, (newValue) => {
|
||||
highlightLanguage.value = autoDetectLanguage(newValue);
|
||||
}
|
||||
});
|
||||
|
||||
// Watch for preset filter changes to update the filter text
|
||||
watch(presetFilter, (newValue) => {
|
||||
if (newValue && newValue !== 'none') {
|
||||
filterText.value = newValue;
|
||||
} else if (newValue === 'none') {
|
||||
filterText.value = '';
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div
|
||||
class="flex flex-col h-[500px] resize-y bg-background text-foreground rounded-lg border border-border overflow-hidden"
|
||||
>
|
||||
<LogViewerToolbar
|
||||
v-model:filter-text="filterText"
|
||||
v-model:preset-filter="presetFilter"
|
||||
title="Log Viewer"
|
||||
:show-presets="true"
|
||||
:preset-filters="presetFilters"
|
||||
:show-toggle="false"
|
||||
:show-refresh="false"
|
||||
/>
|
||||
|
||||
<div class="p-4 border-b border-border">
|
||||
<h2 class="text-lg font-semibold mb-4">Log Viewer</h2>
|
||||
|
||||
<div class="flex flex-wrap gap-4 items-end">
|
||||
<div class="flex-1 min-w-[200px]">
|
||||
<Label for="log-file-select">Log File</Label>
|
||||
@@ -186,6 +215,7 @@ watch(selectedLogFile, (newValue) => {
|
||||
:line-count="lineCount"
|
||||
:auto-scroll="autoScroll"
|
||||
:highlight-language="highlightLanguage"
|
||||
:client-filter="filterText"
|
||||
class="h-full"
|
||||
/>
|
||||
</div>
|
||||
|
||||
122
web/components/Logs/LogViewerToolbar.vue
Normal file
122
web/components/Logs/LogViewerToolbar.vue
Normal file
@@ -0,0 +1,122 @@
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue';
|
||||
import { Button } from '@unraid/ui';
|
||||
import type { SelectItemType } from '@unraid/ui';
|
||||
import { ArrowPathIcon, EyeIcon, EyeSlashIcon, ChevronDownIcon, ChevronUpIcon } from '@heroicons/vue/24/outline';
|
||||
import LogFilterInput from './LogFilterInput.vue';
|
||||
|
||||
const props = withDefaults(defineProps<{
|
||||
title?: string;
|
||||
description?: string;
|
||||
filterText: string;
|
||||
showToggle?: boolean;
|
||||
isExpanded?: boolean;
|
||||
showRefresh?: boolean;
|
||||
showPresets?: boolean;
|
||||
presetFilter?: string;
|
||||
presetFilters?: SelectItemType[];
|
||||
filterPlaceholder?: string;
|
||||
filterLabel?: string;
|
||||
compact?: boolean;
|
||||
}>(), {
|
||||
title: '',
|
||||
description: '',
|
||||
showToggle: false,
|
||||
isExpanded: true,
|
||||
showRefresh: true,
|
||||
showPresets: false,
|
||||
presetFilter: 'none',
|
||||
presetFilters: () => [],
|
||||
filterPlaceholder: 'Filter logs...',
|
||||
filterLabel: 'Filter',
|
||||
compact: false
|
||||
});
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:filterText': [value: string];
|
||||
'update:presetFilter': [value: string];
|
||||
'update:isExpanded': [value: boolean];
|
||||
'refresh': [];
|
||||
}>();
|
||||
|
||||
const filterValue = computed({
|
||||
get: () => props.filterText,
|
||||
set: (value) => emit('update:filterText', value)
|
||||
});
|
||||
|
||||
const presetValue = computed({
|
||||
get: () => props.presetFilter || 'none',
|
||||
set: (value) => emit('update:presetFilter', value)
|
||||
});
|
||||
|
||||
const toggleExpanded = () => {
|
||||
emit('update:isExpanded', !props.isExpanded);
|
||||
};
|
||||
|
||||
const handleRefresh = () => {
|
||||
emit('refresh');
|
||||
};
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div :class="['border-b', compact ? 'p-3' : 'p-4 pb-3', 'border-muted']">
|
||||
<div class="flex justify-between items-center">
|
||||
<div v-if="title || description">
|
||||
<h3 v-if="title" :class="[compact ? 'text-sm' : 'text-base', 'font-semibold']">{{ title }}</h3>
|
||||
<p v-if="description" :class="['mt-1 text-muted-foreground', compact ? 'text-xs' : 'text-sm']">
|
||||
{{ description }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div class="flex gap-2 items-center" :class="!title && !description ? 'w-full' : ''">
|
||||
<div :class="!title && !description ? 'flex-1' : ''">
|
||||
<LogFilterInput
|
||||
v-model="filterValue"
|
||||
v-model:preset="presetValue"
|
||||
:show-presets="showPresets"
|
||||
:preset-filters="presetFilters"
|
||||
:placeholder="filterPlaceholder"
|
||||
:label="compact || (!title && !description) ? filterLabel : ''"
|
||||
:input-class="compact ? 'h-7 text-sm' : 'h-8'"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<Button
|
||||
v-if="showRefresh"
|
||||
variant="outline"
|
||||
:size="compact ? 'sm' : 'sm'"
|
||||
title="Refresh logs"
|
||||
@click="handleRefresh"
|
||||
>
|
||||
<ArrowPathIcon :class="compact ? 'h-3 w-3' : 'h-4 w-4'" />
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
v-if="showToggle"
|
||||
variant="outline"
|
||||
:size="compact ? 'sm' : 'sm'"
|
||||
@click="toggleExpanded"
|
||||
>
|
||||
<component
|
||||
:is="isExpanded ? EyeSlashIcon : EyeIcon"
|
||||
:class="compact ? 'h-3 w-3' : 'h-4 w-4'"
|
||||
/>
|
||||
<span v-if="!compact" class="ml-2">{{ isExpanded ? 'Hide' : 'Show' }} Logs</span>
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
v-else-if="showToggle === false && typeof isExpanded === 'boolean'"
|
||||
variant="ghost"
|
||||
:size="compact ? 'sm' : 'sm'"
|
||||
class="p-1"
|
||||
@click="toggleExpanded"
|
||||
>
|
||||
<component
|
||||
:is="isExpanded ? ChevronUpIcon : ChevronDownIcon"
|
||||
:class="compact ? 'h-3 w-3' : 'h-4 w-4'"
|
||||
/>
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
28
web/components/Logs/OidcDebugButton.vue
Normal file
28
web/components/Logs/OidcDebugButton.vue
Normal file
@@ -0,0 +1,28 @@
|
||||
<script setup lang="ts">
|
||||
import { ref } from 'vue';
|
||||
import { Button } from '@unraid/ui';
|
||||
import { BugAntIcon } from '@heroicons/vue/24/outline';
|
||||
import FilteredLogModal from './FilteredLogModal.vue';
|
||||
|
||||
const showOidcLogs = ref(false);
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div>
|
||||
<Button variant="outline" size="sm" @click="showOidcLogs = true">
|
||||
<BugAntIcon class="w-4 h-4 mr-2" />
|
||||
View OIDC Debug Logs
|
||||
</Button>
|
||||
|
||||
<FilteredLogModal
|
||||
v-model="showOidcLogs"
|
||||
log-file-path="graphql-api.log"
|
||||
filter="OIDC"
|
||||
title="OIDC Debug Logs"
|
||||
description="Viewing OIDC authentication and configuration logs"
|
||||
:line-count="200"
|
||||
:auto-scroll="true"
|
||||
highlight-language="plaintext"
|
||||
/>
|
||||
</div>
|
||||
</template>
|
||||
@@ -5,22 +5,7 @@ import { vInfiniteScroll } from '@vueuse/components';
|
||||
|
||||
import { ArrowDownTrayIcon, ArrowPathIcon } from '@heroicons/vue/24/outline';
|
||||
import { Button, Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@unraid/ui';
|
||||
import hljs from 'highlight.js/lib/core';
|
||||
import DOMPurify from 'isomorphic-dompurify';
|
||||
|
||||
import 'highlight.js/styles/github-dark.css'; // You can choose a different style
|
||||
|
||||
import apache from 'highlight.js/lib/languages/apache';
|
||||
import bash from 'highlight.js/lib/languages/bash';
|
||||
import ini from 'highlight.js/lib/languages/ini';
|
||||
import javascript from 'highlight.js/lib/languages/javascript';
|
||||
import json from 'highlight.js/lib/languages/json';
|
||||
import nginx from 'highlight.js/lib/languages/nginx';
|
||||
import php from 'highlight.js/lib/languages/php';
|
||||
// Register the languages you want to support
|
||||
import plaintext from 'highlight.js/lib/languages/plaintext';
|
||||
import xml from 'highlight.js/lib/languages/xml';
|
||||
import yaml from 'highlight.js/lib/languages/yaml';
|
||||
import { useContentHighlighting } from '~/composables/useContentHighlighting';
|
||||
|
||||
import type { LogFileContentQuery, LogFileContentQueryVariables } from '~/composables/gql/graphql';
|
||||
|
||||
@@ -32,28 +17,17 @@ import { LOG_FILE_SUBSCRIPTION } from './log.subscription';
|
||||
const themeStore = useThemeStore();
|
||||
const isDarkMode = computed(() => themeStore.darkMode);
|
||||
|
||||
// Register the languages
|
||||
hljs.registerLanguage('plaintext', plaintext);
|
||||
hljs.registerLanguage('bash', bash);
|
||||
hljs.registerLanguage('ini', ini);
|
||||
hljs.registerLanguage('xml', xml);
|
||||
hljs.registerLanguage('json', json);
|
||||
hljs.registerLanguage('yaml', yaml);
|
||||
hljs.registerLanguage('nginx', nginx);
|
||||
hljs.registerLanguage('apache', apache);
|
||||
hljs.registerLanguage('javascript', javascript);
|
||||
hljs.registerLanguage('php', php);
|
||||
// Use shared highlighting logic
|
||||
const { highlightContent } = useContentHighlighting();
|
||||
|
||||
const props = defineProps<{
|
||||
logFilePath: string;
|
||||
lineCount: number;
|
||||
autoScroll: boolean;
|
||||
highlightLanguage?: string; // Optional prop to specify the language for highlighting
|
||||
clientFilter?: string; // Optional client-side filter to apply to log content
|
||||
}>();
|
||||
|
||||
// Default language for highlighting
|
||||
const defaultLanguage = 'plaintext';
|
||||
|
||||
const DEFAULT_CHUNK_SIZE = 100;
|
||||
const scrollViewportRef = ref<HTMLElement | null>(null);
|
||||
const state = reactive({
|
||||
@@ -111,44 +85,8 @@ onMounted(() => {
|
||||
observer.observe(scrollViewportRef.value as unknown as Node, { childList: true, subtree: true });
|
||||
}
|
||||
|
||||
if (props.logFilePath) {
|
||||
subscribeToMore({
|
||||
document: LOG_FILE_SUBSCRIPTION,
|
||||
variables: { path: props.logFilePath },
|
||||
updateQuery: (prev, { subscriptionData }) => {
|
||||
if (!subscriptionData.data || !prev) return prev;
|
||||
|
||||
// Set subscription as active when we receive data
|
||||
state.isSubscriptionActive = true;
|
||||
|
||||
const existingContent = prev.logFile?.content || '';
|
||||
const newContent = subscriptionData.data.logFile.content;
|
||||
|
||||
// Update the local state with the new content
|
||||
if (newContent && state.loadedContentChunks.length > 0) {
|
||||
const lastChunk = state.loadedContentChunks[state.loadedContentChunks.length - 1];
|
||||
lastChunk.content += newContent;
|
||||
|
||||
// Force scroll to bottom if auto-scroll is enabled
|
||||
if (props.autoScroll) {
|
||||
nextTick(() => forceScrollToBottom());
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
logFile: {
|
||||
...prev.logFile,
|
||||
content: existingContent + newContent,
|
||||
totalLines: (prev.logFile?.totalLines || 0) + (newContent.split('\n').length - 1),
|
||||
},
|
||||
};
|
||||
},
|
||||
});
|
||||
|
||||
// Set subscription as active
|
||||
state.isSubscriptionActive = true;
|
||||
}
|
||||
// Start the log subscription
|
||||
startLogSubscription();
|
||||
});
|
||||
|
||||
// Cleanup observer on unmount
|
||||
@@ -188,75 +126,33 @@ watch(
|
||||
{ deep: true }
|
||||
);
|
||||
|
||||
// Function to highlight log content
|
||||
// Function to highlight log content using shared composable
|
||||
const highlightLog = (content: string): string => {
|
||||
try {
|
||||
// Determine which language to use for highlighting
|
||||
const language = props.highlightLanguage || defaultLanguage;
|
||||
|
||||
// Apply syntax highlighting
|
||||
let highlighted = hljs.highlight(content, { language }).value;
|
||||
|
||||
// Apply additional custom highlighting for common log patterns
|
||||
|
||||
// Highlight timestamps (various formats)
|
||||
highlighted = highlighted.replace(
|
||||
/\b(\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:?\d{2})?)\b/g,
|
||||
'<span class="hljs-timestamp">$1</span>'
|
||||
);
|
||||
|
||||
// Highlight IP addresses
|
||||
highlighted = highlighted.replace(
|
||||
/\b(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\b/g,
|
||||
'<span class="hljs-ip">$1</span>'
|
||||
);
|
||||
|
||||
// Split the content into lines
|
||||
let lines = highlighted.split('\n');
|
||||
|
||||
// Process each line to add error, warning, and success highlighting
|
||||
lines = lines.map((line) => {
|
||||
if (/(error|exception|fail|failed|failure)/i.test(line)) {
|
||||
// Highlight error keywords
|
||||
line = line.replace(
|
||||
/\b(error|exception|fail|failed|failure)\b/gi,
|
||||
'<span class="hljs-error-keyword">$1</span>'
|
||||
);
|
||||
// Wrap the entire line
|
||||
return `<span class="hljs-error">${line}</span>`;
|
||||
} else if (/(warning|warn)/i.test(line)) {
|
||||
// Highlight warning keywords
|
||||
line = line.replace(/\b(warning|warn)\b/gi, '<span class="hljs-warning-keyword">$1</span>');
|
||||
// Wrap the entire line
|
||||
return `<span class="hljs-warning">${line}</span>`;
|
||||
} else if (/(success|successful|completed|done)/i.test(line)) {
|
||||
// Highlight success keywords
|
||||
line = line.replace(
|
||||
/\b(success|successful|completed|done)\b/gi,
|
||||
'<span class="hljs-success-keyword">$1</span>'
|
||||
);
|
||||
// Wrap the entire line
|
||||
return `<span class="hljs-success">${line}</span>`;
|
||||
}
|
||||
return line;
|
||||
});
|
||||
|
||||
// Join the lines back together
|
||||
highlighted = lines.join('\n');
|
||||
|
||||
// Sanitize the highlighted HTML
|
||||
return DOMPurify.sanitize(highlighted);
|
||||
} catch (error) {
|
||||
console.error('Error highlighting log content:', error);
|
||||
// Fallback to sanitized but not highlighted content
|
||||
return DOMPurify.sanitize(content);
|
||||
}
|
||||
return highlightContent(content, props.highlightLanguage);
|
||||
};
|
||||
|
||||
// Apply client-side filtering
|
||||
const filteredContent = computed(() => {
|
||||
// Join chunks ensuring proper newline handling
|
||||
const rawContent = state.loadedContentChunks
|
||||
.map((chunk) => chunk.content)
|
||||
.filter(content => content) // Remove empty chunks
|
||||
.join(''); // Content should already have proper newlines
|
||||
|
||||
// Apply client-side filter if provided
|
||||
if (props.clientFilter && props.clientFilter.trim()) {
|
||||
const filterLower = props.clientFilter.toLowerCase();
|
||||
const lines = rawContent.split('\n');
|
||||
const filtered = lines.filter(line => line.toLowerCase().includes(filterLower));
|
||||
return filtered.join('\n');
|
||||
}
|
||||
|
||||
return rawContent;
|
||||
});
|
||||
|
||||
// Computed properties
|
||||
const logContent = computed(() => {
|
||||
const rawContent = state.loadedContentChunks.map((chunk) => chunk.content).join('');
|
||||
return highlightLog(rawContent);
|
||||
return highlightLog(filteredContent.value);
|
||||
});
|
||||
|
||||
const totalLines = computed(() => logContentResult.value?.logFile?.totalLines || 0);
|
||||
@@ -339,15 +235,84 @@ const downloadLogFile = async () => {
|
||||
}
|
||||
};
|
||||
|
||||
// Refresh logs
|
||||
const refreshLogContent = () => {
|
||||
// Clear all state to initial values
|
||||
const clearState = () => {
|
||||
state.loadedContentChunks = [];
|
||||
state.currentStartLine = undefined;
|
||||
state.isAtTop = false;
|
||||
state.canLoadMore = false;
|
||||
state.initialLoadComplete = false;
|
||||
state.isLoadingMore = false;
|
||||
refetchLogContent();
|
||||
};
|
||||
|
||||
// Helper function to start log subscription
|
||||
const startLogSubscription = () => {
|
||||
if (!props.logFilePath) return;
|
||||
|
||||
try {
|
||||
subscribeToMore({
|
||||
document: LOG_FILE_SUBSCRIPTION,
|
||||
variables: { path: props.logFilePath },
|
||||
updateQuery: (prev, { subscriptionData }) => {
|
||||
if (!subscriptionData.data || !prev) return prev;
|
||||
|
||||
// Set subscription as active when we receive data
|
||||
state.isSubscriptionActive = true;
|
||||
|
||||
const existingContent = prev.logFile?.content || '';
|
||||
const newContent = subscriptionData.data.logFile.content;
|
||||
|
||||
// Update the local state with the new content
|
||||
if (newContent && state.loadedContentChunks.length > 0) {
|
||||
const lastChunk = state.loadedContentChunks[state.loadedContentChunks.length - 1];
|
||||
// Ensure there's a newline between the existing content and new content if needed
|
||||
if (lastChunk.content && !lastChunk.content.endsWith('\n') && newContent) {
|
||||
lastChunk.content += '\n' + newContent;
|
||||
} else {
|
||||
lastChunk.content += newContent;
|
||||
}
|
||||
|
||||
// Force scroll to bottom if auto-scroll is enabled
|
||||
if (props.autoScroll) {
|
||||
nextTick(() => forceScrollToBottom());
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
logFile: {
|
||||
...prev.logFile,
|
||||
content: existingContent + newContent,
|
||||
totalLines: (prev.logFile?.totalLines || 0) + (newContent.split('\n').length - 1),
|
||||
},
|
||||
};
|
||||
},
|
||||
});
|
||||
|
||||
// Set subscription as active
|
||||
state.isSubscriptionActive = true;
|
||||
} catch (error) {
|
||||
console.error('Error starting log subscription:', error);
|
||||
state.isSubscriptionActive = false;
|
||||
}
|
||||
};
|
||||
|
||||
// Refresh logs
|
||||
const refreshLogContent = async () => {
|
||||
// Clear the state
|
||||
clearState();
|
||||
|
||||
// Refetch with explicit variables to ensure we get the latest logs
|
||||
await refetchLogContent({
|
||||
path: props.logFilePath,
|
||||
lines: props.lineCount || DEFAULT_CHUNK_SIZE,
|
||||
startLine: undefined, // Explicitly pass undefined to get the latest lines
|
||||
});
|
||||
|
||||
// Restart the subscription with the same variables used for refetch
|
||||
// Note: subscribeToMore in Vue Apollo doesn't return an unsubscribe function
|
||||
// The previous subscription is automatically replaced when calling subscribeToMore again
|
||||
startLogSubscription();
|
||||
|
||||
nextTick(() => {
|
||||
forceScrollToBottom();
|
||||
@@ -436,7 +401,7 @@ defineExpose({ refreshLogContent });
|
||||
</div>
|
||||
|
||||
<pre
|
||||
class="font-mono whitespace-pre-wrap p-4 m-0 text-xs leading-6 hljs"
|
||||
class="font-mono whitespace-pre p-4 m-0 text-xs leading-6 hljs"
|
||||
:class="{ 'theme-dark': isDarkMode, 'theme-light': !isDarkMode }"
|
||||
v-html="logContent"
|
||||
/>
|
||||
@@ -612,4 +577,51 @@ defineExpose({ refreshLogContent });
|
||||
color: var(--log-success-color);
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
/* ANSI color styles for ansi_up output - using format: ansi-{color}-fg/bg */
|
||||
/* Foreground colors */
|
||||
:deep(.ansi-black-fg) { color: #000; }
|
||||
:deep(.ansi-red-fg) { color: #c91b00; }
|
||||
:deep(.ansi-green-fg) { color: #00c200; }
|
||||
:deep(.ansi-yellow-fg) { color: #c7c400; }
|
||||
:deep(.ansi-blue-fg) { color: #0225c7; }
|
||||
:deep(.ansi-magenta-fg) { color: #c930c7; }
|
||||
:deep(.ansi-cyan-fg) { color: #00c5c7; }
|
||||
:deep(.ansi-white-fg) { color: #c7c7c7; }
|
||||
|
||||
/* Bright foreground colors */
|
||||
:deep(.ansi-bright-black-fg) { color: #676767; }
|
||||
:deep(.ansi-bright-red-fg) { color: #ff6d67; }
|
||||
:deep(.ansi-bright-green-fg) { color: #5ff967; }
|
||||
:deep(.ansi-bright-yellow-fg) { color: #fefb67; }
|
||||
:deep(.ansi-bright-blue-fg) { color: #6871ff; }
|
||||
:deep(.ansi-bright-magenta-fg) { color: #ff76ff; }
|
||||
:deep(.ansi-bright-cyan-fg) { color: #5ffdff; }
|
||||
:deep(.ansi-bright-white-fg) { color: #fff; }
|
||||
|
||||
/* Background colors */
|
||||
:deep(.ansi-black-bg) { background-color: #000; }
|
||||
:deep(.ansi-red-bg) { background-color: #c91b00; }
|
||||
:deep(.ansi-green-bg) { background-color: #00c200; }
|
||||
:deep(.ansi-yellow-bg) { background-color: #c7c400; }
|
||||
:deep(.ansi-blue-bg) { background-color: #0225c7; }
|
||||
:deep(.ansi-magenta-bg) { background-color: #c930c7; }
|
||||
:deep(.ansi-cyan-bg) { background-color: #00c5c7; }
|
||||
:deep(.ansi-white-bg) { background-color: #c7c7c7; }
|
||||
|
||||
/* Bright background colors */
|
||||
:deep(.ansi-bright-black-bg) { background-color: #676767; }
|
||||
:deep(.ansi-bright-red-bg) { background-color: #ff6d67; }
|
||||
:deep(.ansi-bright-green-bg) { background-color: #5ff967; }
|
||||
:deep(.ansi-bright-yellow-bg) { background-color: #fefb67; }
|
||||
:deep(.ansi-bright-blue-bg) { background-color: #6871ff; }
|
||||
:deep(.ansi-bright-magenta-bg) { background-color: #ff76ff; }
|
||||
:deep(.ansi-bright-cyan-bg) { background-color: #5ffdff; }
|
||||
:deep(.ansi-bright-white-bg) { background-color: #fff; }
|
||||
|
||||
/* Additional ansi_up classes */
|
||||
:deep(.ansi-bold) { font-weight: bold; }
|
||||
:deep(.ansi-italic) { font-style: italic; }
|
||||
:deep(.ansi-underline) { text-decoration: underline; }
|
||||
:deep(.ansi-strike) { text-decoration: line-through; }
|
||||
</style>
|
||||
|
||||
@@ -70,8 +70,11 @@ export function useSsoAuth() {
|
||||
sessionStorage.setItem('sso_state', state);
|
||||
sessionStorage.setItem('sso_provider', providerId);
|
||||
|
||||
// Redirect to OIDC authorization endpoint with just the state token
|
||||
const authUrl = `/graphql/api/auth/oidc/authorize/${encodeURIComponent(providerId)}?state=${encodeURIComponent(state)}`;
|
||||
// Build the redirect URI based on current window location
|
||||
const redirectUri = `${window.location.protocol}//${window.location.host}/graphql/api/auth/oidc/callback`;
|
||||
|
||||
// Redirect to OIDC authorization endpoint with state token and redirect URI
|
||||
const authUrl = `/graphql/api/auth/oidc/authorize/${encodeURIComponent(providerId)}?state=${encodeURIComponent(state)}&redirect_uri=${encodeURIComponent(redirectUri)}`;
|
||||
window.location.href = authUrl;
|
||||
};
|
||||
|
||||
|
||||
@@ -49,6 +49,8 @@ type Documents = {
|
||||
"\n query InfoVersions {\n info {\n id\n os {\n id\n hostname\n }\n versions {\n id\n core {\n unraid\n api\n }\n }\n }\n }\n": typeof types.InfoVersionsDocument,
|
||||
"\n query OidcProviders {\n settings {\n sso {\n oidcProviders {\n id\n name\n clientId\n issuer\n authorizationEndpoint\n tokenEndpoint\n jwksUri\n scopes\n authorizationRules {\n claim\n operator\n value\n }\n authorizationRuleMode\n buttonText\n buttonIcon\n }\n }\n }\n }\n": typeof types.OidcProvidersDocument,
|
||||
"\n query PublicOidcProviders {\n publicOidcProviders {\n id\n name\n buttonText\n buttonIcon\n buttonVariant\n buttonStyle\n }\n }\n": typeof types.PublicOidcProvidersDocument,
|
||||
"\n query AllConfigFiles {\n allConfigFiles {\n files {\n name\n content\n path\n sizeReadable\n }\n }\n }\n": typeof types.AllConfigFilesDocument,
|
||||
"\n query ConfigFile($name: String!) {\n configFile(name: $name) {\n name\n content\n path\n sizeReadable\n }\n }\n": typeof types.ConfigFileDocument,
|
||||
"\n query serverInfo {\n info {\n os {\n hostname\n }\n }\n vars {\n comment\n }\n }\n": typeof types.ServerInfoDocument,
|
||||
"\n mutation ConnectSignIn($input: ConnectSignInInput!) {\n connectSignIn(input: $input)\n }\n": typeof types.ConnectSignInDocument,
|
||||
"\n mutation SignOut {\n connectSignOut\n }\n": typeof types.SignOutDocument,
|
||||
@@ -94,6 +96,8 @@ const documents: Documents = {
|
||||
"\n query InfoVersions {\n info {\n id\n os {\n id\n hostname\n }\n versions {\n id\n core {\n unraid\n api\n }\n }\n }\n }\n": types.InfoVersionsDocument,
|
||||
"\n query OidcProviders {\n settings {\n sso {\n oidcProviders {\n id\n name\n clientId\n issuer\n authorizationEndpoint\n tokenEndpoint\n jwksUri\n scopes\n authorizationRules {\n claim\n operator\n value\n }\n authorizationRuleMode\n buttonText\n buttonIcon\n }\n }\n }\n }\n": types.OidcProvidersDocument,
|
||||
"\n query PublicOidcProviders {\n publicOidcProviders {\n id\n name\n buttonText\n buttonIcon\n buttonVariant\n buttonStyle\n }\n }\n": types.PublicOidcProvidersDocument,
|
||||
"\n query AllConfigFiles {\n allConfigFiles {\n files {\n name\n content\n path\n sizeReadable\n }\n }\n }\n": types.AllConfigFilesDocument,
|
||||
"\n query ConfigFile($name: String!) {\n configFile(name: $name) {\n name\n content\n path\n sizeReadable\n }\n }\n": types.ConfigFileDocument,
|
||||
"\n query serverInfo {\n info {\n os {\n hostname\n }\n }\n vars {\n comment\n }\n }\n": types.ServerInfoDocument,
|
||||
"\n mutation ConnectSignIn($input: ConnectSignInInput!) {\n connectSignIn(input: $input)\n }\n": types.ConnectSignInDocument,
|
||||
"\n mutation SignOut {\n connectSignOut\n }\n": types.SignOutDocument,
|
||||
@@ -258,6 +262,14 @@ export function graphql(source: "\n query OidcProviders {\n settings {\n
|
||||
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
|
||||
*/
|
||||
export function graphql(source: "\n query PublicOidcProviders {\n publicOidcProviders {\n id\n name\n buttonText\n buttonIcon\n buttonVariant\n buttonStyle\n }\n }\n"): (typeof documents)["\n query PublicOidcProviders {\n publicOidcProviders {\n id\n name\n buttonText\n buttonIcon\n buttonVariant\n buttonStyle\n }\n }\n"];
|
||||
/**
|
||||
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
|
||||
*/
|
||||
export function graphql(source: "\n query AllConfigFiles {\n allConfigFiles {\n files {\n name\n content\n path\n sizeReadable\n }\n }\n }\n"): (typeof documents)["\n query AllConfigFiles {\n allConfigFiles {\n files {\n name\n content\n path\n sizeReadable\n }\n }\n }\n"];
|
||||
/**
|
||||
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
|
||||
*/
|
||||
export function graphql(source: "\n query ConfigFile($name: String!) {\n configFile(name: $name) {\n name\n content\n path\n sizeReadable\n }\n }\n"): (typeof documents)["\n query ConfigFile($name: String!) {\n configFile(name: $name) {\n name\n content\n path\n sizeReadable\n }\n }\n"];
|
||||
/**
|
||||
* The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients.
|
||||
*/
|
||||
|
||||
@@ -448,6 +448,20 @@ export enum ConfigErrorState {
|
||||
WITHDRAWN = 'WITHDRAWN'
|
||||
}
|
||||
|
||||
export type ConfigFile = {
|
||||
__typename?: 'ConfigFile';
|
||||
content: Scalars['String']['output'];
|
||||
name: Scalars['String']['output'];
|
||||
path: Scalars['String']['output'];
|
||||
/** Human-readable file size (e.g., "1.5 KB", "2.3 MB") */
|
||||
sizeReadable: Scalars['String']['output'];
|
||||
};
|
||||
|
||||
export type ConfigFilesResponse = {
|
||||
__typename?: 'ConfigFilesResponse';
|
||||
files: Array<ConfigFile>;
|
||||
};
|
||||
|
||||
export type Connect = Node & {
|
||||
__typename?: 'Connect';
|
||||
/** The status of dynamic remote access */
|
||||
@@ -1432,6 +1446,14 @@ export type OidcAuthorizationRule = {
|
||||
value: Array<Scalars['String']['output']>;
|
||||
};
|
||||
|
||||
export type OidcConfiguration = {
|
||||
__typename?: 'OidcConfiguration';
|
||||
/** Default allowed redirect origins that apply to all OIDC providers (e.g., Tailscale domains) */
|
||||
defaultAllowedOrigins?: Maybe<Array<Scalars['String']['output']>>;
|
||||
/** List of configured OIDC providers */
|
||||
providers: Array<OidcProvider>;
|
||||
};
|
||||
|
||||
export type OidcProvider = {
|
||||
__typename?: 'OidcProvider';
|
||||
/** OAuth2 authorization endpoint URL. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration */
|
||||
@@ -1455,7 +1477,7 @@ export type OidcProvider = {
|
||||
/** The unique identifier for the OIDC provider */
|
||||
id: Scalars['PrefixedID']['output'];
|
||||
/** OIDC issuer URL (e.g., https://accounts.google.com). Required for auto-discovery via /.well-known/openid-configuration */
|
||||
issuer: Scalars['String']['output'];
|
||||
issuer?: Maybe<Scalars['String']['output']>;
|
||||
/** JSON Web Key Set URI for token validation. If omitted, will be auto-discovered from issuer/.well-known/openid-configuration */
|
||||
jwksUri?: Maybe<Scalars['String']['output']>;
|
||||
/** Display name of the OIDC provider */
|
||||
@@ -1623,6 +1645,7 @@ export type PublicPartnerInfo = {
|
||||
|
||||
export type Query = {
|
||||
__typename?: 'Query';
|
||||
allConfigFiles: ConfigFilesResponse;
|
||||
apiKey?: Maybe<ApiKey>;
|
||||
/** All possible permissions for API keys */
|
||||
apiKeyPossiblePermissions: Array<Permission>;
|
||||
@@ -1632,6 +1655,7 @@ export type Query = {
|
||||
array: UnraidArray;
|
||||
cloud: Cloud;
|
||||
config: Config;
|
||||
configFile?: Maybe<ConfigFile>;
|
||||
connect: Connect;
|
||||
customization?: Maybe<Customization>;
|
||||
disk: Disk;
|
||||
@@ -1654,6 +1678,8 @@ export type Query = {
|
||||
network: Network;
|
||||
/** Get all notifications */
|
||||
notifications: Notifications;
|
||||
/** Get the full OIDC configuration (admin only) */
|
||||
oidcConfiguration: OidcConfiguration;
|
||||
/** Get a specific OIDC provider by ID */
|
||||
oidcProvider?: Maybe<OidcProvider>;
|
||||
/** Get all configured OIDC providers (admin only) */
|
||||
@@ -1693,6 +1719,11 @@ export type QueryApiKeyArgs = {
|
||||
};
|
||||
|
||||
|
||||
export type QueryConfigFileArgs = {
|
||||
name: Scalars['String']['input'];
|
||||
};
|
||||
|
||||
|
||||
export type QueryDiskArgs = {
|
||||
id: Scalars['PrefixedID']['input'];
|
||||
};
|
||||
@@ -1933,6 +1964,7 @@ export type Server = Node & {
|
||||
name: Scalars['String']['output'];
|
||||
owner: ProfileModel;
|
||||
remoteurl: Scalars['String']['output'];
|
||||
/** Whether this server is online or offline */
|
||||
status: ServerStatus;
|
||||
wanip: Scalars['String']['output'];
|
||||
};
|
||||
@@ -2757,13 +2789,25 @@ export type InfoVersionsQuery = { __typename?: 'Query', info: { __typename?: 'In
|
||||
export type OidcProvidersQueryVariables = Exact<{ [key: string]: never; }>;
|
||||
|
||||
|
||||
export type OidcProvidersQuery = { __typename?: 'Query', settings: { __typename?: 'Settings', sso: { __typename?: 'SsoSettings', oidcProviders: Array<{ __typename?: 'OidcProvider', id: string, name: string, clientId: string, issuer: string, authorizationEndpoint?: string | null, tokenEndpoint?: string | null, jwksUri?: string | null, scopes: Array<string>, authorizationRuleMode?: AuthorizationRuleMode | null, buttonText?: string | null, buttonIcon?: string | null, authorizationRules?: Array<{ __typename?: 'OidcAuthorizationRule', claim: string, operator: AuthorizationOperator, value: Array<string> }> | null }> } } };
|
||||
export type OidcProvidersQuery = { __typename?: 'Query', settings: { __typename?: 'Settings', sso: { __typename?: 'SsoSettings', oidcProviders: Array<{ __typename?: 'OidcProvider', id: string, name: string, clientId: string, issuer?: string | null, authorizationEndpoint?: string | null, tokenEndpoint?: string | null, jwksUri?: string | null, scopes: Array<string>, authorizationRuleMode?: AuthorizationRuleMode | null, buttonText?: string | null, buttonIcon?: string | null, authorizationRules?: Array<{ __typename?: 'OidcAuthorizationRule', claim: string, operator: AuthorizationOperator, value: Array<string> }> | null }> } } };
|
||||
|
||||
export type PublicOidcProvidersQueryVariables = Exact<{ [key: string]: never; }>;
|
||||
|
||||
|
||||
export type PublicOidcProvidersQuery = { __typename?: 'Query', publicOidcProviders: Array<{ __typename?: 'PublicOidcProvider', id: string, name: string, buttonText?: string | null, buttonIcon?: string | null, buttonVariant?: string | null, buttonStyle?: string | null }> };
|
||||
|
||||
export type AllConfigFilesQueryVariables = Exact<{ [key: string]: never; }>;
|
||||
|
||||
|
||||
export type AllConfigFilesQuery = { __typename?: 'Query', allConfigFiles: { __typename?: 'ConfigFilesResponse', files: Array<{ __typename?: 'ConfigFile', name: string, content: string, path: string, sizeReadable: string }> } };
|
||||
|
||||
export type ConfigFileQueryVariables = Exact<{
|
||||
name: Scalars['String']['input'];
|
||||
}>;
|
||||
|
||||
|
||||
export type ConfigFileQuery = { __typename?: 'Query', configFile?: { __typename?: 'ConfigFile', name: string, content: string, path: string, sizeReadable: string } | null };
|
||||
|
||||
export type ServerInfoQueryVariables = Exact<{ [key: string]: never; }>;
|
||||
|
||||
|
||||
@@ -2842,6 +2886,8 @@ export const ListRCloneRemotesDocument = {"kind":"Document","definitions":[{"kin
|
||||
export const InfoVersionsDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"query","name":{"kind":"Name","value":"InfoVersions"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"info"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"id"}},{"kind":"Field","name":{"kind":"Name","value":"os"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"id"}},{"kind":"Field","name":{"kind":"Name","value":"hostname"}}]}},{"kind":"Field","name":{"kind":"Name","value":"versions"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"id"}},{"kind":"Field","name":{"kind":"Name","value":"core"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"unraid"}},{"kind":"Field","name":{"kind":"Name","value":"api"}}]}}]}}]}}]}}]} as unknown as DocumentNode<InfoVersionsQuery, InfoVersionsQueryVariables>;
|
||||
export const OidcProvidersDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"query","name":{"kind":"Name","value":"OidcProviders"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"settings"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"sso"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"oidcProviders"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"id"}},{"kind":"Field","name":{"kind":"Name","value":"name"}},{"kind":"Field","name":{"kind":"Name","value":"clientId"}},{"kind":"Field","name":{"kind":"Name","value":"issuer"}},{"kind":"Field","name":{"kind":"Name","value":"authorizationEndpoint"}},{"kind":"Field","name":{"kind":"Name","value":"tokenEndpoint"}},{"kind":"Field","name":{"kind":"Name","value":"jwksUri"}},{"kind":"Field","name":{"kind":"Name","value":"scopes"}},{"kind":"Field","name":{"kind":"Name","value":"authorizationRules"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"claim"}},{"kind":"Field","name":{"kind":"Name","value":"operator"}},{"kind":"Field","name":{"kind":"Name","value":"value"}}]}},{"kind":"Field","name":{"kind":"Name","value":"authorizationRuleMode"}},{"kind":"Field","name":{"kind":"Name","value":"buttonText"}},{"kind":"Field","name":{"kind":"Name","value":"buttonIcon"}}]}}]}}]}}]}}]} as unknown as DocumentNode<OidcProvidersQuery, OidcProvidersQueryVariables>;
|
||||
export const PublicOidcProvidersDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"query","name":{"kind":"Name","value":"PublicOidcProviders"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"publicOidcProviders"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"id"}},{"kind":"Field","name":{"kind":"Name","value":"name"}},{"kind":"Field","name":{"kind":"Name","value":"buttonText"}},{"kind":"Field","name":{"kind":"Name","value":"buttonIcon"}},{"kind":"Field","name":{"kind":"Name","value":"buttonVariant"}},{"kind":"Field","name":{"kind":"Name","value":"buttonStyle"}}]}}]}}]} as unknown as DocumentNode<PublicOidcProvidersQuery, PublicOidcProvidersQueryVariables>;
|
||||
export const AllConfigFilesDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"query","name":{"kind":"Name","value":"AllConfigFiles"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"allConfigFiles"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"files"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"name"}},{"kind":"Field","name":{"kind":"Name","value":"content"}},{"kind":"Field","name":{"kind":"Name","value":"path"}},{"kind":"Field","name":{"kind":"Name","value":"sizeReadable"}}]}}]}}]}}]} as unknown as DocumentNode<AllConfigFilesQuery, AllConfigFilesQueryVariables>;
|
||||
export const ConfigFileDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"query","name":{"kind":"Name","value":"ConfigFile"},"variableDefinitions":[{"kind":"VariableDefinition","variable":{"kind":"Variable","name":{"kind":"Name","value":"name"}},"type":{"kind":"NonNullType","type":{"kind":"NamedType","name":{"kind":"Name","value":"String"}}}}],"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"configFile"},"arguments":[{"kind":"Argument","name":{"kind":"Name","value":"name"},"value":{"kind":"Variable","name":{"kind":"Name","value":"name"}}}],"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"name"}},{"kind":"Field","name":{"kind":"Name","value":"content"}},{"kind":"Field","name":{"kind":"Name","value":"path"}},{"kind":"Field","name":{"kind":"Name","value":"sizeReadable"}}]}}]}}]} as unknown as DocumentNode<ConfigFileQuery, ConfigFileQueryVariables>;
|
||||
export const ServerInfoDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"query","name":{"kind":"Name","value":"serverInfo"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"info"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"os"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"hostname"}}]}}]}},{"kind":"Field","name":{"kind":"Name","value":"vars"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"comment"}}]}}]}}]} as unknown as DocumentNode<ServerInfoQuery, ServerInfoQueryVariables>;
|
||||
export const ConnectSignInDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"mutation","name":{"kind":"Name","value":"ConnectSignIn"},"variableDefinitions":[{"kind":"VariableDefinition","variable":{"kind":"Variable","name":{"kind":"Name","value":"input"}},"type":{"kind":"NonNullType","type":{"kind":"NamedType","name":{"kind":"Name","value":"ConnectSignInInput"}}}}],"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"connectSignIn"},"arguments":[{"kind":"Argument","name":{"kind":"Name","value":"input"},"value":{"kind":"Variable","name":{"kind":"Name","value":"input"}}}]}]}}]} as unknown as DocumentNode<ConnectSignInMutation, ConnectSignInMutationVariables>;
|
||||
export const SignOutDocument = {"kind":"Document","definitions":[{"kind":"OperationDefinition","operation":"mutation","name":{"kind":"Name","value":"SignOut"},"selectionSet":{"kind":"SelectionSet","selections":[{"kind":"Field","name":{"kind":"Name","value":"connectSignOut"}}]}}]} as unknown as DocumentNode<SignOutMutation, SignOutMutationVariables>;
|
||||
|
||||
@@ -1,211 +0,0 @@
|
||||
import { computed, ref } from 'vue';
|
||||
import { AuthAction, Resource, Role } from '~/composables/gql/graphql';
|
||||
|
||||
export interface ScopeConversion {
|
||||
permissions: Array<{ resource: Resource; actions: AuthAction[] }>;
|
||||
roles: Role[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert scope strings to permissions and roles
|
||||
* Scopes can be in format:
|
||||
* - "role:admin" for roles
|
||||
* - "docker:read" for resource permissions
|
||||
* - "docker:*" for all actions on a resource
|
||||
*/
|
||||
function convertScopesToPermissions(scopes: string[]): ScopeConversion {
|
||||
const permissions: Array<{ resource: Resource; actions: AuthAction[] }> = [];
|
||||
const roles: Role[] = [];
|
||||
|
||||
for (const scope of scopes) {
|
||||
if (scope.startsWith('role:')) {
|
||||
// Handle role scope
|
||||
const roleStr = scope.substring(5).toUpperCase();
|
||||
if (Object.values(Role).includes(roleStr as Role)) {
|
||||
roles.push(roleStr as Role);
|
||||
} else {
|
||||
console.warn(`Unknown role in scope: ${scope}`);
|
||||
}
|
||||
} else {
|
||||
// Handle permission scope
|
||||
const [resourceStr, actionStr] = scope.split(':');
|
||||
if (resourceStr && actionStr) {
|
||||
const resourceUpper = resourceStr.toUpperCase();
|
||||
const resource = Object.values(Resource).find(r => r === resourceUpper) as Resource;
|
||||
|
||||
if (!resource) {
|
||||
console.warn(`Unknown resource in scope: ${scope}`);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle wildcard or specific action
|
||||
let actions: AuthAction[];
|
||||
if (actionStr === '*') {
|
||||
// Wildcard means all CRUD actions
|
||||
actions = [
|
||||
AuthAction.CREATE_ANY,
|
||||
AuthAction.READ_ANY,
|
||||
AuthAction.UPDATE_ANY,
|
||||
AuthAction.DELETE_ANY
|
||||
];
|
||||
} else {
|
||||
// Convert action string to AuthAction enum
|
||||
// Scopes come in as 'read', 'create', etc. - convert to 'READ_ANY', 'CREATE_ANY'
|
||||
const enumValue = `${actionStr.toUpperCase()}_ANY` as AuthAction;
|
||||
if (Object.values(AuthAction).includes(enumValue)) {
|
||||
actions = [enumValue];
|
||||
} else {
|
||||
console.warn(`Unknown action in scope: ${scope}`);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Merge with existing permissions for the same resource
|
||||
const existing = permissions.find(p => p.resource === resource);
|
||||
if (existing) {
|
||||
actions.forEach(a => {
|
||||
if (!existing.actions.includes(a)) {
|
||||
existing.actions.push(a);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
permissions.push({ resource, actions });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { permissions, roles };
|
||||
}
|
||||
|
||||
export interface ApiKeyAuthorizationParams {
|
||||
name: string;
|
||||
description: string;
|
||||
scopes: string[];
|
||||
redirectUri: string;
|
||||
state: string;
|
||||
}
|
||||
|
||||
export interface FormattedPermission {
|
||||
scope: string;
|
||||
name: string;
|
||||
description: string;
|
||||
isRole: boolean;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Composable for handling API key authorization flow
|
||||
*/
|
||||
export function useApiKeyAuthorization(urlSearchParams?: URLSearchParams) {
|
||||
// Parse query parameters with SSR safety
|
||||
const params = urlSearchParams || (
|
||||
typeof window !== 'undefined'
|
||||
? new URLSearchParams(window.location.search)
|
||||
: new URLSearchParams()
|
||||
);
|
||||
|
||||
const authParams = ref<ApiKeyAuthorizationParams>({
|
||||
name: params.get('name') || 'Unknown Application',
|
||||
description: params.get('description') || '',
|
||||
scopes: (params.get('scopes') || '').split(',').filter(Boolean),
|
||||
redirectUri: params.get('redirect_uri') || '',
|
||||
state: params.get('state') || '',
|
||||
});
|
||||
|
||||
// Validate redirect URI - allow any valid URL including app URLs and custom schemes
|
||||
const isValidRedirectUri = (uri: string): boolean => {
|
||||
if (!uri) return false;
|
||||
try {
|
||||
// Just check if it's a valid URL format, don't restrict protocols or hosts
|
||||
new URL(uri);
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
// Format scopes for display
|
||||
const formatPermissions = (scopes: string[]): FormattedPermission[] => {
|
||||
return scopes.map(scope => {
|
||||
if (scope.startsWith('role:')) {
|
||||
const role = scope.substring(5);
|
||||
return {
|
||||
scope,
|
||||
name: role.toUpperCase(),
|
||||
description: `Grant ${role} role access`,
|
||||
isRole: true,
|
||||
};
|
||||
} else {
|
||||
const [resource, action] = scope.split(':');
|
||||
if (resource && action) {
|
||||
const resourceName = resource.charAt(0).toUpperCase() + resource.slice(1);
|
||||
const actionDesc = action === '*'
|
||||
? 'Full'
|
||||
: action.charAt(0).toUpperCase() + action.slice(1);
|
||||
return {
|
||||
scope,
|
||||
name: `${resourceName} - ${actionDesc}`,
|
||||
description: `${actionDesc} access to ${resourceName}`,
|
||||
isRole: false,
|
||||
};
|
||||
}
|
||||
}
|
||||
return {
|
||||
scope,
|
||||
name: scope,
|
||||
description: scope,
|
||||
isRole: false
|
||||
};
|
||||
});
|
||||
};
|
||||
|
||||
// Use the shared convertScopesToFrontendFormPermissions function from @unraid/shared
|
||||
// This ensures consistent scope parsing across frontend and backend
|
||||
|
||||
// Build redirect URL with API key or error
|
||||
const buildCallbackUrl = (
|
||||
redirectUri: string,
|
||||
apiKey?: string,
|
||||
error?: string,
|
||||
state?: string
|
||||
): string => {
|
||||
try {
|
||||
const url = new URL(redirectUri);
|
||||
if (apiKey) {
|
||||
url.searchParams.set('api_key', apiKey);
|
||||
}
|
||||
if (error) {
|
||||
url.searchParams.set('error', error);
|
||||
}
|
||||
if (state) {
|
||||
url.searchParams.set('state', state);
|
||||
}
|
||||
return url.toString();
|
||||
} catch {
|
||||
throw new Error('Invalid redirect URI');
|
||||
}
|
||||
};
|
||||
|
||||
// Computed properties
|
||||
const formattedPermissions = computed(() =>
|
||||
formatPermissions(authParams.value.scopes)
|
||||
);
|
||||
|
||||
const hasValidRedirectUri = computed(() =>
|
||||
isValidRedirectUri(authParams.value.redirectUri)
|
||||
);
|
||||
|
||||
const defaultKeyName = computed(() => authParams.value.name);
|
||||
|
||||
return {
|
||||
authParams,
|
||||
formattedPermissions,
|
||||
hasValidRedirectUri,
|
||||
defaultKeyName,
|
||||
isValidRedirectUri,
|
||||
formatPermissions,
|
||||
convertScopesToPermissions,
|
||||
buildCallbackUrl,
|
||||
};
|
||||
}
|
||||
@@ -4,7 +4,7 @@ import {
|
||||
scopesToFormData,
|
||||
buildCallbackUrl as buildUrl,
|
||||
generateAuthorizationUrl as generateUrl
|
||||
} from '~/utils/authorizationScopes.js';
|
||||
} from '~/utils/authorizationScopes';
|
||||
|
||||
export interface ApiKeyAuthorizationParams {
|
||||
name: string;
|
||||
@@ -14,22 +14,6 @@ export interface ApiKeyAuthorizationParams {
|
||||
state: string;
|
||||
}
|
||||
|
||||
// Re-export types from utils for convenience
|
||||
export type {
|
||||
AuthorizationFormData,
|
||||
AuthorizationLinkParams,
|
||||
RawPermission
|
||||
} from '~/utils/authorizationScopes.js';
|
||||
|
||||
// Re-export functions for direct use
|
||||
export {
|
||||
encodePermissionsToScopes,
|
||||
decodeScopesToPermissions,
|
||||
extractActionVerb,
|
||||
generateAuthorizationUrl,
|
||||
buildCallbackUrl
|
||||
} from '~/utils/authorizationScopes.js';
|
||||
|
||||
/**
|
||||
* Composable for authorization link handling with reactive state
|
||||
*/
|
||||
|
||||
83
web/composables/useContentHighlighting.ts
Normal file
83
web/composables/useContentHighlighting.ts
Normal file
@@ -0,0 +1,83 @@
|
||||
import hljs from 'highlight.js/lib/core';
|
||||
import DOMPurify from 'isomorphic-dompurify';
|
||||
import { AnsiUp } from 'ansi_up';
|
||||
|
||||
import 'highlight.js/styles/github-dark.css';
|
||||
|
||||
import apache from 'highlight.js/lib/languages/apache';
|
||||
import bash from 'highlight.js/lib/languages/bash';
|
||||
import ini from 'highlight.js/lib/languages/ini';
|
||||
import javascript from 'highlight.js/lib/languages/javascript';
|
||||
import json from 'highlight.js/lib/languages/json';
|
||||
import nginx from 'highlight.js/lib/languages/nginx';
|
||||
import php from 'highlight.js/lib/languages/php';
|
||||
import plaintext from 'highlight.js/lib/languages/plaintext';
|
||||
import xml from 'highlight.js/lib/languages/xml';
|
||||
import yaml from 'highlight.js/lib/languages/yaml';
|
||||
|
||||
// Register the languages (only once)
|
||||
let languagesRegistered = false;
|
||||
|
||||
const registerLanguages = () => {
|
||||
if (!languagesRegistered) {
|
||||
hljs.registerLanguage('plaintext', plaintext);
|
||||
hljs.registerLanguage('bash', bash);
|
||||
hljs.registerLanguage('ini', ini);
|
||||
hljs.registerLanguage('xml', xml);
|
||||
hljs.registerLanguage('json', json);
|
||||
hljs.registerLanguage('yaml', yaml);
|
||||
hljs.registerLanguage('nginx', nginx);
|
||||
hljs.registerLanguage('apache', apache);
|
||||
hljs.registerLanguage('javascript', javascript);
|
||||
hljs.registerLanguage('php', php);
|
||||
languagesRegistered = true;
|
||||
}
|
||||
};
|
||||
|
||||
export const useContentHighlighting = () => {
|
||||
// Initialize ANSI to HTML converter with CSS classes
|
||||
const ansiConverter = new AnsiUp();
|
||||
ansiConverter.use_classes = true;
|
||||
ansiConverter.escape_html = true;
|
||||
|
||||
// Register languages on first use
|
||||
registerLanguages();
|
||||
|
||||
// Function to highlight content
|
||||
const highlightContent = (content: string, language?: string): string => {
|
||||
try {
|
||||
let highlighted: string;
|
||||
|
||||
// Check if content contains ANSI escape sequences
|
||||
// eslint-disable-next-line no-control-regex
|
||||
const hasAnsiSequences = /\x1b\[/.test(content);
|
||||
|
||||
if (hasAnsiSequences) {
|
||||
// Use ANSI converter for content with ANSI codes
|
||||
highlighted = ansiConverter.ansi_to_html(content);
|
||||
} else if (language) {
|
||||
// Use highlight.js for specific language if provided
|
||||
const result = hljs.highlight(content, { language, ignoreIllegals: true });
|
||||
highlighted = result.value;
|
||||
} else {
|
||||
// Use highlight.js auto-detection for non-ANSI content
|
||||
const result = hljs.highlightAuto(content);
|
||||
highlighted = result.value;
|
||||
}
|
||||
|
||||
// Sanitize the highlighted HTML while preserving class attributes for syntax highlighting
|
||||
return DOMPurify.sanitize(highlighted, {
|
||||
ALLOWED_TAGS: ['span', 'br', 'code', 'pre'],
|
||||
ALLOWED_ATTR: ['class'] // Allow class attribute for hljs and ANSI color classes
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Error highlighting content:', error);
|
||||
// Fallback to sanitized but not highlighted content
|
||||
return DOMPurify.sanitize(content);
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
highlightContent
|
||||
};
|
||||
};
|
||||
@@ -107,6 +107,7 @@
|
||||
"@vueuse/components": "13.8.0",
|
||||
"@vueuse/integrations": "13.8.0",
|
||||
"ajv": "8.17.1",
|
||||
"ansi_up": "^6.0.6",
|
||||
"class-variance-authority": "0.7.1",
|
||||
"clsx": "2.1.1",
|
||||
"crypto-js": "4.2.0",
|
||||
|
||||
Reference in New Issue
Block a user