mirror of
https://github.com/btouchard/ackify.git
synced 2025-12-31 01:49:41 -06:00
feat: migrate to Vue.js SPA with API-first architecture
Major refactoring to modernize the application architecture: Backend changes: - Restructure API with v1 versioning and modular handlers - Add comprehensive OpenAPI specification - Implement RESTful endpoints for documents, signatures, admin - Add checksum verification system for document integrity - Add server-side runtime injection of ACKIFY_BASE_URL and meta tags - Generate dynamic Open Graph/Twitter Card meta tags for unfurling - Remove legacy HTML template handlers - Isolate backend source on dedicated folder - Improve tests suite Frontend changes: - Migrate from Go templates to Vue.js 3 SPA with TypeScript - Add Tailwind CSS with shadcn/vue components - Implement i18n support (fr, en, es, de, it) - Add admin dashboard for document and signer management - Add signature tracking with file checksum verification - Add embed page with sign button linking to main app - Implement dark mode and accessibility features - Auto load file to compute checksum Infrastructure: - Update Dockerfile for SPA build process - Simplify deployment with embedded frontend assets - Add migration for checksum_verifications table This enables better UX, proper link previews on social platforms, and provides a foundation for future enhancements.
This commit is contained in:
@@ -15,8 +15,6 @@ LICENSE
|
||||
.env
|
||||
.env.local
|
||||
.env.example
|
||||
community
|
||||
migrate
|
||||
compose.cloud.yml
|
||||
compose.local.yml
|
||||
|
||||
|
||||
62
.github/workflows/ci.yml
vendored
62
.github/workflows/ci.yml
vendored
@@ -43,10 +43,13 @@ jobs:
|
||||
cache: true
|
||||
|
||||
- name: Download dependencies
|
||||
run: go mod download
|
||||
run: |
|
||||
cd backend
|
||||
go mod download
|
||||
|
||||
- name: Run go fmt check
|
||||
run: |
|
||||
cd backend
|
||||
if [ "$(gofmt -s -l . | wc -l)" -gt 0 ]; then
|
||||
echo "The following files need to be formatted:"
|
||||
gofmt -s -l .
|
||||
@@ -54,41 +57,52 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Run go vet
|
||||
run: go vet ./...
|
||||
run: |
|
||||
cd backend
|
||||
go vet ./...
|
||||
|
||||
- name: Run unit tests
|
||||
env:
|
||||
APP_BASE_URL: "http://localhost:8080"
|
||||
APP_ORGANISATION: "Test Org"
|
||||
OAUTH_CLIENT_ID: "test-client-id"
|
||||
OAUTH_CLIENT_SECRET: "test-client-secret"
|
||||
OAUTH_COOKIE_SECRET: "dGVzdC1jb29raWUtc2VjcmV0LXRlc3QtY29va2llLXNlY3JldA=="
|
||||
run: go test -v -race -short ./...
|
||||
ACKIFY_BASE_URL: "http://localhost:8080"
|
||||
ACKIFY_ORGANISATION: "Test Org"
|
||||
ACKIFY_OAUTH_CLIENT_ID: "test-client-id"
|
||||
ACKIFY_OAUTH_CLIENT_SECRET: "test-client-secret"
|
||||
ACKIFY_OAUTH_COOKIE_SECRET: "dGVzdC1jb29raWUtc2VjcmV0LXRlc3QtY29va2llLXNlY3JldA=="
|
||||
run: |
|
||||
cd backend
|
||||
go test -v -race -short ./...
|
||||
|
||||
- name: Run integration tests
|
||||
env:
|
||||
DB_DSN: "postgres://postgres:testpassword@localhost:5432/ackify_test?sslmode=disable"
|
||||
INTEGRATION_TESTS: "true"
|
||||
run: go test -v -race -tags=integration ./internal/infrastructure/database/...
|
||||
ACKIFY_DB_DSN: "postgres://postgres:testpassword@localhost:5432/ackify_test?sslmode=disable"
|
||||
INTEGRATION_TESTS: "1"
|
||||
run: |
|
||||
cd backend
|
||||
go test -v -race -tags=integration ./internal/infrastructure/database/...
|
||||
|
||||
- name: Generate coverage report (unit+integration)
|
||||
env:
|
||||
DB_DSN: "postgres://postgres:testpassword@localhost:5432/ackify_test?sslmode=disable"
|
||||
INTEGRATION_TESTS: "true"
|
||||
APP_BASE_URL: "http://localhost:8080"
|
||||
APP_ORGANISATION: "Test Org"
|
||||
OAUTH_CLIENT_ID: "test-client-id"
|
||||
OAUTH_CLIENT_SECRET: "test-client-secret"
|
||||
OAUTH_COOKIE_SECRET: "dGVzdC1jb29raWUtc2VjcmV0LXRlc3QtY29va2llLXNlY3JldA=="
|
||||
run: go test -v -race -tags=integration -coverprofile=coverage.out ./...
|
||||
ACKIFY_DB_DSN: "postgres://postgres:testpassword@localhost:5432/ackify_test?sslmode=disable"
|
||||
INTEGRATION_TESTS: "1"
|
||||
ACKIFY_BASE_URL: "http://localhost:8080"
|
||||
ACKIFY_ORGANISATION: "Test Org"
|
||||
ACKIFY_OAUTH_CLIENT_ID: "test-client-id"
|
||||
ACKIFY_OAUTH_CLIENT_SECRET: "test-client-secret"
|
||||
ACKIFY_OAUTH_COOKIE_SECRET: "dGVzdC1jb29raWUtc2VjcmV0LXRlc3QtY29va2llLXNlY3JldA=="
|
||||
run: |
|
||||
cd backend
|
||||
go test -v -race -tags=integration -coverprofile=coverage.out ./...
|
||||
go tool cover -func=coverage.out | tail -1
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
if: success()
|
||||
uses: codecov/codecov-action@v3
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
file: ./coverage.out
|
||||
flags: unittests,integrations
|
||||
name: codecov-umbrella
|
||||
files: ./backend/coverage.out
|
||||
flags: unittests,integration
|
||||
name: codecov-ackify-ce
|
||||
fail_ci_if_error: false
|
||||
verbose: true
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
build:
|
||||
name: Build and Push Docker Image
|
||||
|
||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -1,23 +1,21 @@
|
||||
CLAUDE.md
|
||||
AGENTS.md
|
||||
*SETUP.md
|
||||
RELEASE_*.md
|
||||
.ai
|
||||
.claude
|
||||
.idea
|
||||
.env
|
||||
|
||||
.ai/.last_prompt.txt
|
||||
.ai/.cache/
|
||||
.env.local
|
||||
|
||||
.gocache/
|
||||
codecov.yml
|
||||
|
||||
compose.local.yml
|
||||
compose.cloud.yml
|
||||
client_secret*.json
|
||||
|
||||
/static
|
||||
/community
|
||||
/migrate
|
||||
/cmd/community/web/dist
|
||||
|
||||
# Tailwind CSS
|
||||
/bin/tailwindcss
|
||||
|
||||
185
BUILD.md
185
BUILD.md
@@ -2,11 +2,21 @@
|
||||
|
||||
## Overview
|
||||
|
||||
Ackify Community Edition (CE) is the open-source version of Ackify, a document signature validation platform. This guide covers building and deploying the Community Edition.
|
||||
Ackify Community Edition (CE) is the open-source version of Ackify, a document signature validation platform with a modern API-first architecture. This guide covers building and deploying the Community Edition.
|
||||
|
||||
## Architecture
|
||||
|
||||
Ackify CE consists of:
|
||||
- **Go Backend**: Vue 3 SPA frontend served by Go backend with REST API v1, OAuth2 authentication, and PostgreSQL database
|
||||
- **Vue 3 SPA Frontend**: Modern TypeScript-based single-page application with Vite, Pinia state management, and Tailwind CSS
|
||||
- **Docker Multi-Stage Build**: Optimized containerized deployment
|
||||
|
||||
The built Vue 3 SPA is embedded directly into the Go binary via the `//go:embed all:web/dist` directive, allowing single-binary deployment.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Go 1.24.5 or later
|
||||
- Node.js 22+ and npm (for Vue SPA development)
|
||||
- Docker and Docker Compose (for containerized deployment)
|
||||
- PostgreSQL 16+ (for database)
|
||||
|
||||
@@ -19,24 +29,42 @@ git clone https://github.com/btouchard/ackify-ce.git
|
||||
cd ackify-ce
|
||||
```
|
||||
|
||||
### 2. Build the Application
|
||||
### 2. Build the Vue SPA
|
||||
|
||||
```bash
|
||||
cd webapp
|
||||
npm install
|
||||
npm run build
|
||||
cd ..
|
||||
```
|
||||
|
||||
This creates an optimized production build in `webapp/dist/`.
|
||||
|
||||
### 3. Build the Go Application
|
||||
|
||||
Run from project root:
|
||||
|
||||
```bash
|
||||
# Build Community Edition
|
||||
go build ./cmd/community
|
||||
go build ./backend/cmd/community
|
||||
|
||||
# Or build with specific output name
|
||||
go build -o ackify-ce ./cmd/community
|
||||
go build -o ackify-ce ./backend/cmd/community
|
||||
```
|
||||
|
||||
### 3. Run Tests
|
||||
The Go application will serve both the API endpoints and the Vue SPA.
|
||||
|
||||
### 4. Run Tests
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
# Run Go tests
|
||||
go test ./...
|
||||
|
||||
# Run tests with verbose output
|
||||
go test -v ./tests/
|
||||
# Run Go tests with verbose output
|
||||
go test -v ./backend/internal/...
|
||||
|
||||
# Run integration tests (requires PostgreSQL)
|
||||
INTEGRATION_TESTS=1 go test -tags=integration -v ./internal/infrastructure/database/
|
||||
```
|
||||
|
||||
## Configuration
|
||||
@@ -55,14 +83,56 @@ Required environment variables:
|
||||
- `ACKIFY_OAUTH_CLIENT_ID`: OAuth2 client ID
|
||||
- `ACKIFY_OAUTH_CLIENT_SECRET`: OAuth2 client secret
|
||||
- `ACKIFY_DB_DSN`: PostgreSQL connection string
|
||||
- `ACKIFY_OAUTH_COOKIE_SECRET`: Base64-encoded secret for session cookies
|
||||
- `ACKIFY_OAUTH_COOKIE_SECRET`: Base64-encoded secret for session cookies (32+ bytes)
|
||||
- `ACKIFY_ORGANISATION`: Organization name displayed in the application
|
||||
|
||||
Optional configuration:
|
||||
- `ACKIFY_TEMPLATES_DIR`: Custom path to HTML templates directory (defaults to relative path for development, `/app/templates` in Docker)
|
||||
- `ACKIFY_TEMPLATES_DIR`: Custom path to emails templates directory (defaults to relative path for development, `/app/templates` in Docker)
|
||||
- `ACKIFY_LOCALES_DIR`: Custom path to locales directory (default: `locales`)
|
||||
- `ACKIFY_SPA_DIR`: Custom path to Vue SPA build directory (default: `dist`)
|
||||
- `ACKIFY_LISTEN_ADDR`: Server listen address (default: `:8080`)
|
||||
- `ACKIFY_ED25519_PRIVATE_KEY`: Base64-encoded Ed25519 private key for signatures
|
||||
- `ACKIFY_OAUTH_PROVIDER`: OAuth provider (`google`, `github`, `gitlab` or empty for custom)
|
||||
- `ACKIFY_OAUTH_ALLOWED_DOMAIN`: Domain restriction for OAuth users
|
||||
- `ACKIFY_OAUTH_AUTO_LOGIN`: Enable automatic OAuth login when session exists (default: `false`)
|
||||
- `ACKIFY_LOG_LEVEL`: Logging level - `debug`, `info`, `warn`, `error` (default: `info`)
|
||||
- `ACKIFY_ADMIN_EMAILS`: Comma-separated list of admin email addresses
|
||||
- `ACKIFY_MAIL_HOST`: SMTP server host (required to enable email features)
|
||||
- `ACKIFY_MAIL_PORT`: SMTP server port (default: `587`)
|
||||
- `ACKIFY_MAIL_USERNAME`: SMTP username for authentication
|
||||
- `ACKIFY_MAIL_PASSWORD`: SMTP password for authentication
|
||||
- `ACKIFY_MAIL_TLS`: Enable TLS connection (default: `true`)
|
||||
- `ACKIFY_MAIL_STARTTLS`: Enable STARTTLS (default: `true`)
|
||||
- `ACKIFY_MAIL_TIMEOUT`: SMTP connection timeout (default: `10s`)
|
||||
- `ACKIFY_MAIL_FROM`: Email sender address
|
||||
- `ACKIFY_MAIL_FROM_NAME`: Email sender name (defaults to `ACKIFY_ORGANISATION`)
|
||||
- `ACKIFY_MAIL_SUBJECT_PREFIX`: Prefix for email subjects
|
||||
- `ACKIFY_MAIL_TEMPLATE_DIR`: Custom path to email templates (default: `templates/emails`)
|
||||
- `ACKIFY_MAIL_DEFAULT_LOCALE`: Default locale for emails (default: `en`)
|
||||
|
||||
### Logging Configuration
|
||||
|
||||
Ackify uses structured JSON logging with the following levels:
|
||||
|
||||
- **debug**: Detailed diagnostic information (request/response details, authentication attempts)
|
||||
- **info**: General informational messages (successful operations, API requests)
|
||||
- **warn**: Warning messages (failed authentication, rate limiting)
|
||||
- **error**: Error messages (server errors, database failures)
|
||||
|
||||
Example:
|
||||
```bash
|
||||
# Development - verbose logging
|
||||
ACKIFY_LOG_LEVEL=debug
|
||||
|
||||
# Production - standard logging
|
||||
ACKIFY_LOG_LEVEL=info
|
||||
```
|
||||
|
||||
Logs include structured fields for easy parsing:
|
||||
- `request_id`: Unique identifier for each request
|
||||
- `user_email`: Authenticated user email
|
||||
- `method`, `path`, `status`: HTTP request details
|
||||
- `duration_ms`: Request processing time
|
||||
|
||||
### OAuth2 Providers
|
||||
|
||||
@@ -81,7 +151,7 @@ Supported providers:
|
||||
3. Run the binary:
|
||||
|
||||
```bash
|
||||
./community
|
||||
./ackify
|
||||
```
|
||||
|
||||
### Option 2: Docker Compose (Recommended)
|
||||
@@ -143,12 +213,80 @@ curl http://localhost:8080/health
|
||||
|
||||
## API Endpoints
|
||||
|
||||
- `GET /` - Homepage
|
||||
### API v1 (RESTful)
|
||||
|
||||
All API v1 endpoints are prefixed with `/api/v1`.
|
||||
|
||||
#### Public Endpoints
|
||||
- `GET /api/v1/health` - Health check
|
||||
- `GET /api/v1/csrf` - Get CSRF token for authenticated requests
|
||||
- `GET /api/v1/documents` - List all documents
|
||||
- `GET /api/v1/documents/{docId}` - Get document details
|
||||
- `GET /api/v1/documents/{docId}/signatures` - Get document signatures
|
||||
- `GET /api/v1/documents/{docId}/expected-signers` - Get expected signers list
|
||||
|
||||
#### Authentication Endpoints
|
||||
- `POST /api/v1/auth/start` - Start OAuth flow
|
||||
- `GET /api/v1/auth/logout` - Logout
|
||||
- `GET /api/v1/auth/check` - Check authentication status (if `ACKIFY_OAUTH_AUTO_LOGIN=true`)
|
||||
|
||||
#### Authenticated Endpoints (require valid session)
|
||||
- `GET /api/v1/users/me` - Get current user profile
|
||||
- `GET /api/v1/signatures` - Get current user's signatures
|
||||
- `POST /api/v1/signatures` - Create new signature
|
||||
- `GET /api/v1/documents/{docId}/signatures/status` - Get user's signature status for document
|
||||
|
||||
#### Admin Endpoints (require admin privileges)
|
||||
- `GET /api/v1/admin/documents` - List all documents with stats
|
||||
- `GET /api/v1/admin/documents/{docId}` - Get document details (admin view)
|
||||
- `GET /api/v1/admin/documents/{docId}/signers` - Get document with signers and stats
|
||||
- `POST /api/v1/admin/documents/{docId}/signers` - Add expected signer
|
||||
- `DELETE /api/v1/admin/documents/{docId}/signers/{email}` - Remove expected signer
|
||||
- `POST /api/v1/admin/documents/{docId}/reminders` - Send email reminders
|
||||
- `GET /api/v1/admin/documents/{docId}/reminders` - Get reminder history
|
||||
|
||||
### Public Endpoints
|
||||
|
||||
- `GET /` - Vue SPA (serves index.html for all routes)
|
||||
- `GET /health` - Health check
|
||||
- `GET /sign?doc=<id>` - Document signing interface
|
||||
- `POST /sign` - Create signature
|
||||
- `GET /status?doc=<id>` - Get document signature status (JSON)
|
||||
- `GET /status.png?doc=<id>&user=<email>` - Signature status badge
|
||||
- `GET /api/v1/auth/callback` - OAuth2 callback handler
|
||||
|
||||
**Note:** Link unfurling for messaging apps (Slack, Discord, etc.) is handled automatically via dynamic Open Graph meta tags in the Vue SPA. There are no separate `/embed` or `/oembed` endpoints.
|
||||
|
||||
## Development
|
||||
|
||||
### Vue SPA Development
|
||||
|
||||
For Vue SPA development with hot-reload:
|
||||
|
||||
```bash
|
||||
cd webapp
|
||||
npm install
|
||||
npm run dev
|
||||
```
|
||||
|
||||
This starts a Vite development server on `http://localhost:5173` with:
|
||||
- Hot module replacement (HMR)
|
||||
- TypeScript type checking
|
||||
- API proxy to backend (configured in `vite.config.ts`)
|
||||
|
||||
The development server proxies API requests to your Go backend (default: `http://localhost:8080`).
|
||||
|
||||
### Backend Development
|
||||
|
||||
Run the Go backend separately:
|
||||
|
||||
```bash
|
||||
# In project root
|
||||
go build ./backend/cmd/community
|
||||
./ackify
|
||||
```
|
||||
|
||||
Or use Docker Compose for complete stack:
|
||||
|
||||
```bash
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
@@ -157,10 +295,23 @@ curl http://localhost:8080/health
|
||||
1. **Port already in use**: Change `ACKIFY_LISTEN_ADDR` in environment variables
|
||||
2. **Database connection failed**: Check `ACKIFY_DB_DSN` and ensure PostgreSQL is running
|
||||
3. **OAuth2 errors**: Verify `ACKIFY_OAUTH_CLIENT_ID` and `ACKIFY_OAUTH_CLIENT_SECRET`
|
||||
4. **SPA not loading**: Ensure Vue app is built (`npm run build` in webapp/) before running Go binary
|
||||
5. **CORS errors in development**: Check that Vite dev server proxy is correctly configured
|
||||
|
||||
### Logs
|
||||
|
||||
Enable debug logging by setting `LOG_LEVEL=debug` in your environment.
|
||||
Enable debug logging to see detailed request/response information:
|
||||
|
||||
```bash
|
||||
ACKIFY_LOG_LEVEL=debug ./ackify
|
||||
```
|
||||
|
||||
Debug logs include:
|
||||
- HTTP request details (method, path, headers)
|
||||
- Authentication attempts and results
|
||||
- Database queries and performance
|
||||
- OAuth flow progression
|
||||
- Signature creation and validation steps
|
||||
|
||||
## Contributing
|
||||
|
||||
|
||||
148
CHANGELOG.md
148
CHANGELOG.md
@@ -5,6 +5,153 @@ All notable changes to this project will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [2.0.0] - 2025-10-16
|
||||
|
||||
### 🎉 Major Release: API-First Vue Migration
|
||||
|
||||
Complete architectural overhaul to a modern API-first architecture with Vue 3 SPA frontend.
|
||||
|
||||
### Added
|
||||
|
||||
- **RESTful API v1**
|
||||
- Versioned API with `/api/v1` prefix
|
||||
- Structured JSON responses with consistent error handling
|
||||
- Public endpoints: health, documents, signatures, expected signers
|
||||
- Authentication endpoints: OAuth flow, logout, auth check
|
||||
- Authenticated endpoints: user profile, signatures, signature creation
|
||||
- Admin endpoints: document management, signer management, reminders
|
||||
- OpenAPI specification endpoint `/api/v1/openapi.json`
|
||||
|
||||
- **Vue 3 SPA Frontend**
|
||||
- Modern single-page application with TypeScript
|
||||
- Vite build tool with hot module replacement (HMR)
|
||||
- Pinia state management for centralized application state
|
||||
- Vue Router for client-side routing
|
||||
- Tailwind CSS for utility-first styling
|
||||
- Responsive design with mobile support
|
||||
- Pages: Home, Sign, Signatures, Embed, Admin Dashboard, Document Details
|
||||
|
||||
- **Comprehensive Logging System**
|
||||
- Structured JSON logging with `slog` package
|
||||
- Log levels: debug, info, warn, error (configurable via `ACKIFY_LOG_LEVEL`)
|
||||
- Request ID tracking through entire request lifecycle
|
||||
- HTTP request/response logging with timing
|
||||
- Authentication flow logging
|
||||
- Signature operation logging
|
||||
- Reminder service logging
|
||||
- Database query logging
|
||||
- OAuth flow progression logging
|
||||
|
||||
- **Enhanced Security**
|
||||
- CSRF token protection for all state-changing operations
|
||||
- Rate limiting (5 auth attempts/min, 100 general requests/min)
|
||||
- CORS configuration for development and production
|
||||
- Security headers (CSP, X-Content-Type-Options, X-Frame-Options, etc.)
|
||||
- Session-based authentication with secure cookies
|
||||
- Request ID propagation for distributed tracing
|
||||
|
||||
- **Public Embed Route**
|
||||
- `/embed/{docId}` route for public embedding (no authentication required)
|
||||
- oEmbed protocol support for unfurl functionality
|
||||
- CSP headers configured to allow iframe embedding on embed routes
|
||||
- Suitable for integration in documentation tools and wikis
|
||||
|
||||
- **Auto-Login Feature**
|
||||
- Optional `ACKIFY_OAUTH_AUTO_LOGIN` configuration
|
||||
- Silent authentication when OAuth session exists
|
||||
- `/api/v1/auth/check` endpoint for session verification
|
||||
- Seamless user experience when returning to application
|
||||
|
||||
- **Docker Multi-Stage Build**
|
||||
- Optimized Dockerfile with separate Node and Go build stages
|
||||
- Smaller final image size
|
||||
- SPA assets built during Docker build process
|
||||
- Production-ready containerized deployment
|
||||
|
||||
### Changed
|
||||
|
||||
- **Architecture**
|
||||
- Migrated from template-based rendering to API-first architecture
|
||||
- Introduced clear separation between API and frontend
|
||||
- Organized API handlers into logical modules (admin, auth, documents, signatures, users)
|
||||
- Centralized middleware in `shared` package (logging, CORS, CSRF, rate limiting, security headers)
|
||||
|
||||
- **Routing**
|
||||
- Chi router now serves both API v1 and Vue SPA
|
||||
- SPA fallback routing for all unmatched routes
|
||||
- API endpoints prefixed with `/api/v1`
|
||||
- Static assets served from `/assets` for SPA and `/static` for legacy
|
||||
|
||||
- **Authentication**
|
||||
- Standardized session-based auth across API and templates
|
||||
- CSRF protection on all authenticated API endpoints
|
||||
- Rate limiting on authentication endpoints
|
||||
|
||||
- **Documentation**
|
||||
- Updated BUILD.md with Vue SPA build instructions
|
||||
- Updated README.md with API v1 endpoint documentation
|
||||
- Updated README_FR.md with French translations
|
||||
- Added logging configuration documentation
|
||||
- Added development environment setup instructions
|
||||
|
||||
### Fixed
|
||||
|
||||
- Consistent error handling across all API endpoints
|
||||
- Proper HTTP status codes for all responses
|
||||
- CORS issues in development environment
|
||||
|
||||
### Technical Details
|
||||
|
||||
**New Files:**
|
||||
- `internal/presentation/api/` - Complete API v1 implementation
|
||||
- `admin/handler.go` - Admin endpoints
|
||||
- `auth/handler.go` - Authentication endpoints
|
||||
- `documents/handler.go` - Document endpoints
|
||||
- `signatures/handler.go` - Signature endpoints
|
||||
- `users/handler.go` - User endpoints
|
||||
- `health/handler.go` - Health check endpoint
|
||||
- `shared/` - Shared middleware and utilities
|
||||
- `logging.go` - Request logging middleware
|
||||
- `middleware.go` - Auth, admin, CSRF, rate limiting middleware
|
||||
- `response.go` - Standardized JSON response helpers
|
||||
- `errors.go` - Error code constants
|
||||
- `router.go` - API v1 router configuration
|
||||
- `webapp/` - Complete Vue 3 SPA
|
||||
- `src/components/` - Reusable Vue components
|
||||
- `src/pages/` - Page components (Home, Sign, Signatures, Embed, Admin)
|
||||
- `src/services/` - API client services
|
||||
- `src/stores/` - Pinia state stores
|
||||
- `src/router/` - Vue Router configuration
|
||||
- `vite.config.ts` - Vite build configuration
|
||||
- `tsconfig.json` - TypeScript configuration
|
||||
|
||||
**Modified Files:**
|
||||
- `pkg/web/server.go` - Updated to serve both API and SPA
|
||||
- `internal/infrastructure/auth/oauth.go` - Added structured logging
|
||||
- `internal/application/services/signature.go` - Added structured logging
|
||||
- `internal/application/services/reminder.go` - Added structured logging
|
||||
- `Dockerfile` - Multi-stage build for Node and Go
|
||||
- `docker-compose.yml` - Updated for new architecture
|
||||
|
||||
**Deprecated:**
|
||||
- Template-based admin routes (will be maintained for backward compatibility)
|
||||
- Legacy `/status` and `/status.png` endpoints (superseded by API v1)
|
||||
|
||||
### Migration Guide
|
||||
|
||||
For users upgrading from v1.x to v2.0:
|
||||
|
||||
1. **Environment Variables**: Add optional `ACKIFY_LOG_LEVEL` and `ACKIFY_OAUTH_AUTO_LOGIN` if desired
|
||||
2. **Docker**: Rebuild images to include Vue SPA build
|
||||
3. **API Clients**: Consider migrating to new API v1 endpoints for better structure
|
||||
4. **Embed URLs**: Update to use `/embed/{docId}` instead of token-based system
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- None - v2.0 maintains backward compatibility with all v1.x features
|
||||
- Template-based admin interface remains functional
|
||||
- Legacy endpoints continue to work
|
||||
|
||||
## [1.1.3] - 2025-10-08
|
||||
|
||||
### Added
|
||||
@@ -116,6 +263,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- NULL UserName handling in database operations
|
||||
- Proper string conversion for UserName field
|
||||
|
||||
[2.0.0]: https://github.com/btouchard/ackify-ce/compare/v1.1.3...v2.0.0
|
||||
[1.1.3]: https://github.com/btouchard/ackify-ce/compare/v1.1.2...v1.1.3
|
||||
[1.1.2]: https://github.com/btouchard/ackify-ce/compare/v1.1.1...v1.1.2
|
||||
[1.1.1]: https://github.com/btouchard/ackify-ce/compare/v1.1.0...v1.1.1
|
||||
|
||||
38
Dockerfile
38
Dockerfile
@@ -1,4 +1,11 @@
|
||||
# ---- Build ----
|
||||
FROM node:22-alpine AS spa-builder
|
||||
|
||||
WORKDIR /app/webapp
|
||||
COPY webapp/package*.json ./
|
||||
RUN npm ci
|
||||
COPY webapp/ ./
|
||||
RUN npm run build
|
||||
|
||||
FROM golang:alpine AS builder
|
||||
|
||||
RUN apk update && apk add --no-cache ca-certificates git curl && rm -rf /var/cache/apk/*
|
||||
@@ -8,19 +15,10 @@ WORKDIR /app
|
||||
COPY go.mod go.sum ./
|
||||
ENV GOTOOLCHAIN=auto
|
||||
RUN go mod download && go mod verify
|
||||
COPY . .
|
||||
COPY backend/ ./backend/
|
||||
|
||||
# Download Tailwind CSS CLI (use v3 for compatibility)
|
||||
RUN ARCH=$(uname -m) && \
|
||||
if [ "$ARCH" = "x86_64" ]; then TAILWIND_ARCH="x64"; \
|
||||
elif [ "$ARCH" = "aarch64" ]; then TAILWIND_ARCH="arm64"; \
|
||||
else echo "Unsupported architecture: $ARCH" && exit 1; fi && \
|
||||
curl -sL https://github.com/tailwindlabs/tailwindcss/releases/download/v3.4.16/tailwindcss-linux-${TAILWIND_ARCH} -o /tmp/tailwindcss && \
|
||||
chmod +x /tmp/tailwindcss
|
||||
|
||||
# Build CSS
|
||||
RUN mkdir -p ./static && \
|
||||
/tmp/tailwindcss -i ./assets/input.css -o ./static/output.css --minify
|
||||
RUN mkdir -p backend/cmd/community/web/dist
|
||||
COPY --from=spa-builder /app/webapp/dist ./backend/cmd/community/web/dist
|
||||
|
||||
ARG VERSION="dev"
|
||||
ARG COMMIT="unknown"
|
||||
@@ -29,14 +27,13 @@ ARG BUILD_DATE="unknown"
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-a -installsuffix cgo \
|
||||
-ldflags="-w -s -X main.Version=${VERSION} -X main.Commit=${COMMIT} -X main.BuildDate=${BUILD_DATE}" \
|
||||
-o ackify ./cmd/community
|
||||
-o ackify ./backend/cmd/community
|
||||
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-a -installsuffix cgo \
|
||||
-ldflags="-w -s" \
|
||||
-o migrate ./cmd/migrate
|
||||
-o migrate ./backend/cmd/migrate
|
||||
|
||||
# ---- Run ----
|
||||
FROM gcr.io/distroless/static-debian12:nonroot
|
||||
|
||||
ARG VERSION="dev"
|
||||
@@ -53,16 +50,13 @@ COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
|
||||
WORKDIR /app
|
||||
COPY --from=builder /app/ackify /app/ackify
|
||||
COPY --from=builder /app/migrate /app/migrate
|
||||
COPY --from=builder /app/migrations /app/migrations
|
||||
COPY --from=builder /app/locales /app/locales
|
||||
COPY --from=builder /app/templates /app/templates
|
||||
COPY --from=builder /app/static /app/static
|
||||
COPY --from=builder /app/backend/migrations /app/migrations
|
||||
COPY --from=builder /app/backend/locales /app/locales
|
||||
COPY --from=builder /app/backend/templates /app/templates
|
||||
|
||||
ENV ACKIFY_TEMPLATES_DIR=/app/templates
|
||||
ENV ACKIFY_LOCALES_DIR=/app/locales
|
||||
ENV ACKIFY_STATIC_DIR=/app/static
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
ENTRYPOINT ["/app/ackify"]
|
||||
## SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
|
||||
84
Makefile
84
Makefile
@@ -1,12 +1,14 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Makefile for ackify-ce project
|
||||
|
||||
.PHONY: build test test-unit test-integration test-short coverage lint fmt vet clean help
|
||||
.PHONY: build build-frontend build-backend build-all test test-unit test-integration test-short coverage lint fmt vet clean help dev dev-frontend dev-backend migrate-up migrate-down docker-rebuild
|
||||
|
||||
# Variables
|
||||
BINARY_NAME=ackify-ce
|
||||
BUILD_DIR=./cmd/community
|
||||
BUILD_DIR=./backend/cmd/community
|
||||
MIGRATE_DIR=./backend/cmd/migrate
|
||||
COVERAGE_DIR=coverage
|
||||
WEBAPP_DIR=./webapp
|
||||
|
||||
# Default target
|
||||
help: ## Display this help message
|
||||
@@ -14,30 +16,37 @@ help: ## Display this help message
|
||||
@awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf " \033[36m%-15s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST)
|
||||
|
||||
# Build targets
|
||||
build: ## Build the application
|
||||
build: build-all ## Build the complete application (frontend + backend)
|
||||
|
||||
build-frontend: ## Build the Vue.js frontend
|
||||
@echo "Building frontend..."
|
||||
cd $(WEBAPP_DIR) && npm install && npm run build
|
||||
|
||||
build-backend: ## Build the Go backend
|
||||
@echo "Building $(BINARY_NAME)..."
|
||||
go build -o $(BINARY_NAME) $(BUILD_DIR)
|
||||
|
||||
build-all: build-frontend build-backend ## Build frontend and backend
|
||||
|
||||
# Test targets
|
||||
test: test-unit test-integration ## Run all tests
|
||||
|
||||
test-unit: ## Run unit tests
|
||||
@echo "Running unit tests with race detection..."
|
||||
CGO_ENABLED=1 go test -short -race -v ./internal/... ./pkg/... ./cmd/...
|
||||
CGO_ENABLED=1 go test -short -race -v ./backend/internal/... ./backend/pkg/... ./backend/cmd/...
|
||||
|
||||
test-integration: ## Run integration tests (requires PostgreSQL)
|
||||
test-integration: ## Run integration tests (requires PostgreSQL - migrations are applied automatically)
|
||||
@echo "Running integration tests with race detection..."
|
||||
@if [ -z "$(DB_DSN)" ]; then \
|
||||
export DB_DSN="postgres://postgres:testpassword@localhost:5432/ackify_test?sslmode=disable"; \
|
||||
fi; \
|
||||
export INTEGRATION_TESTS=true; \
|
||||
CGO_ENABLED=1 go test -v -race -tags=integration ./internal/infrastructure/database/...
|
||||
@echo "Note: Migrations are applied automatically by test setup"
|
||||
@export INTEGRATION_TESTS=1; \
|
||||
export ACKIFY_DB_DSN="postgres://postgres:testpassword@localhost:5432/ackify_test?sslmode=disable"; \
|
||||
CGO_ENABLED=1 go test -v -race -tags=integration ./backend/internal/infrastructure/database/...
|
||||
|
||||
test-integration-setup: ## Setup test database for integration tests
|
||||
test-integration-setup: ## Setup test database for integration tests (migrations applied by tests)
|
||||
@echo "Setting up test database..."
|
||||
@psql "postgres://postgres:testpassword@localhost:5432/postgres?sslmode=disable" -c "DROP DATABASE IF EXISTS ackify_test;" || true
|
||||
@psql "postgres://postgres:testpassword@localhost:5432/postgres?sslmode=disable" -c "CREATE DATABASE ackify_test;"
|
||||
@echo "Test database ready!"
|
||||
@echo "Test database ready! Migrations will be applied automatically when tests run."
|
||||
|
||||
test-short: ## Run only quick tests
|
||||
@echo "Running short tests..."
|
||||
@@ -55,18 +64,18 @@ coverage: ## Generate test coverage report
|
||||
coverage-integration: ## Generate integration test coverage report
|
||||
@echo "Generating integration coverage report..."
|
||||
@mkdir -p $(COVERAGE_DIR)
|
||||
@export DB_DSN="postgres://postgres:testpassword@localhost:5432/ackify_test?sslmode=disable"; \
|
||||
export INTEGRATION_TESTS=true; \
|
||||
go test -v -race -tags=integration -coverprofile=$(COVERAGE_DIR)/coverage-integration.out ./internal/infrastructure/database/...
|
||||
@export ACKIFY_DB_DSN="postgres://postgres:testpassword@localhost:5432/ackify_test?sslmode=disable"; \
|
||||
export INTEGRATION_TESTS=1; \
|
||||
CGO_ENABLED=1 go test -v -race -tags=integration -coverprofile=$(COVERAGE_DIR)/coverage-integration.out ./backend/internal/infrastructure/database/...
|
||||
go tool cover -html=$(COVERAGE_DIR)/coverage-integration.out -o $(COVERAGE_DIR)/coverage-integration.html
|
||||
@echo "Integration coverage report generated: $(COVERAGE_DIR)/coverage-integration.html"
|
||||
|
||||
coverage-all: ## Generate full coverage report (unit + integration)
|
||||
@echo "Generating full coverage report..."
|
||||
@mkdir -p $(COVERAGE_DIR)
|
||||
@export DB_DSN="postgres://postgres:testpassword@localhost:5432/ackify_test?sslmode=disable"; \
|
||||
export INTEGRATION_TESTS=true; \
|
||||
go test -v -race -tags=integration -coverprofile=$(COVERAGE_DIR)/coverage-all.out ./...
|
||||
@export ACKIFY_DB_DSN="postgres://postgres:testpassword@localhost:5432/ackify_test?sslmode=disable"; \
|
||||
export INTEGRATION_TESTS=1; \
|
||||
CGO_ENABLED=1 go test -v -race -tags=integration -coverprofile=$(COVERAGE_DIR)/coverage-all.out ./...
|
||||
go tool cover -html=$(COVERAGE_DIR)/coverage-all.out -o $(COVERAGE_DIR)/coverage-all.html
|
||||
go tool cover -func=$(COVERAGE_DIR)/coverage-all.out
|
||||
@echo "Full coverage report generated: $(COVERAGE_DIR)/coverage-all.html"
|
||||
@@ -91,16 +100,38 @@ lint-extra: ## Run staticcheck if available (installs if missing)
|
||||
staticcheck ./...
|
||||
|
||||
# Development targets
|
||||
dev: dev-backend ## Start development server (backend only - frontend served by backend)
|
||||
|
||||
dev-frontend: ## Start frontend development server (Vite hot reload)
|
||||
@echo "Starting frontend dev server..."
|
||||
cd $(WEBAPP_DIR) && npm run dev
|
||||
|
||||
dev-backend: ## Run backend in development mode
|
||||
@echo "Starting backend..."
|
||||
go run $(BUILD_DIR)
|
||||
|
||||
clean: ## Clean build artifacts and test coverage
|
||||
@echo "Cleaning..."
|
||||
rm -f $(BINARY_NAME)
|
||||
rm -rf $(COVERAGE_DIR)
|
||||
rm -rf $(WEBAPP_DIR)/dist
|
||||
rm -rf $(WEBAPP_DIR)/node_modules
|
||||
go clean ./...
|
||||
|
||||
deps: ## Download and tidy dependencies
|
||||
@echo "Downloading dependencies..."
|
||||
deps: ## Download and tidy dependencies (Go + npm)
|
||||
@echo "Downloading Go dependencies..."
|
||||
go mod download
|
||||
go mod tidy
|
||||
@echo "Installing frontend dependencies..."
|
||||
cd $(WEBAPP_DIR) && npm install
|
||||
|
||||
migrate-up: ## Apply database migrations
|
||||
@echo "Applying database migrations..."
|
||||
go run $(MIGRATE_DIR) up
|
||||
|
||||
migrate-down: ## Rollback last database migration
|
||||
@echo "Rolling back last migration..."
|
||||
go run $(MIGRATE_DIR) down
|
||||
|
||||
# Mock generation (none at the moment)
|
||||
generate-mocks: ## No exported interfaces to mock (skipped)
|
||||
@@ -110,6 +141,19 @@ generate-mocks: ## No exported interfaces to mock (skipped)
|
||||
docker-build: ## Build Docker image
|
||||
docker build -t ackify-ce:latest .
|
||||
|
||||
docker-rebuild: ## Rebuild and restart Docker containers (as per CLAUDE.md)
|
||||
@echo "Rebuilding and restarting Docker containers..."
|
||||
docker compose -f compose.local.yml up -d --force-recreate ackify-ce --build
|
||||
|
||||
docker-up: ## Start Docker containers
|
||||
docker compose -f compose.local.yml up -d
|
||||
|
||||
docker-down: ## Stop Docker containers
|
||||
docker compose -f compose.local.yml down
|
||||
|
||||
docker-logs: ## View Docker logs
|
||||
docker compose -f compose.local.yml logs -f ackify-ce
|
||||
|
||||
docker-test: ## Run tests in Docker environment
|
||||
docker compose -f compose.local.yml up -d ackify-db
|
||||
@sleep 5
|
||||
|
||||
954
README_FR.md
954
README_FR.md
File diff suppressed because it is too large
Load Diff
1672
api/openapi.yaml
Normal file
1672
api/openapi.yaml
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,3 +0,0 @@
|
||||
@tailwind base;
|
||||
@tailwind components;
|
||||
@tailwind utilities;
|
||||
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
@@ -10,12 +11,14 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/internal/infrastructure/config"
|
||||
"github.com/btouchard/ackify-ce/internal/presentation/admin"
|
||||
"github.com/btouchard/ackify-ce/pkg/logger"
|
||||
"github.com/btouchard/ackify-ce/pkg/web"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/config"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/web"
|
||||
)
|
||||
|
||||
//go:embed all:web/dist
|
||||
var frontend embed.FS
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -26,13 +29,11 @@ func main() {
|
||||
|
||||
logger.SetLevel(logger.ParseLevel(cfg.Logger.Level))
|
||||
|
||||
server, err := web.NewServer(ctx, cfg)
|
||||
server, err := web.NewServer(ctx, cfg, frontend)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
server.RegisterRoutes(admin.RegisterAdminRoutes(cfg, server.GetTemplates(), server.GetDB(), server.GetAuthService(), server.GetEmailSender()))
|
||||
|
||||
go func() {
|
||||
log.Printf("Community Edition server starting on %s", server.GetAddr())
|
||||
if err := server.Start(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
0
backend/cmd/community/web/dist/index.html
vendored
Normal file
0
backend/cmd/community/web/dist/index.html
vendored
Normal file
218
backend/internal/application/services/checksum_service.go
Normal file
218
backend/internal/application/services/checksum_service.go
Normal file
@@ -0,0 +1,218 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// ChecksumVerificationRepository defines the interface for checksum verification persistence
|
||||
type ChecksumVerificationRepository interface {
|
||||
RecordVerification(ctx context.Context, verification *models.ChecksumVerification) error
|
||||
GetVerificationHistory(ctx context.Context, docID string, limit int) ([]*models.ChecksumVerification, error)
|
||||
GetLastVerification(ctx context.Context, docID string) (*models.ChecksumVerification, error)
|
||||
}
|
||||
|
||||
// DocumentRepository defines the interface for document metadata operations
|
||||
type DocumentRepository interface {
|
||||
GetByDocID(ctx context.Context, docID string) (*models.Document, error)
|
||||
}
|
||||
|
||||
// ChecksumService orchestrates document integrity verification with audit trail persistence
|
||||
type ChecksumService struct {
|
||||
verificationRepo ChecksumVerificationRepository
|
||||
documentRepo DocumentRepository
|
||||
}
|
||||
|
||||
// NewChecksumService initializes checksum verification service with required repository dependencies
|
||||
func NewChecksumService(
|
||||
verificationRepo ChecksumVerificationRepository,
|
||||
documentRepo DocumentRepository,
|
||||
) *ChecksumService {
|
||||
return &ChecksumService{
|
||||
verificationRepo: verificationRepo,
|
||||
documentRepo: documentRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateChecksumFormat ensures checksum matches expected hexadecimal length for SHA-256/SHA-512/MD5
|
||||
func (s *ChecksumService) ValidateChecksumFormat(checksum, algorithm string) error {
|
||||
// Remove common separators and whitespace
|
||||
checksum = normalizeChecksum(checksum)
|
||||
|
||||
var expectedLength int
|
||||
switch algorithm {
|
||||
case "SHA-256":
|
||||
expectedLength = 64
|
||||
case "SHA-512":
|
||||
expectedLength = 128
|
||||
case "MD5":
|
||||
expectedLength = 32
|
||||
default:
|
||||
return fmt.Errorf("unsupported algorithm: %s", algorithm)
|
||||
}
|
||||
|
||||
// Check length
|
||||
if len(checksum) != expectedLength {
|
||||
return fmt.Errorf("invalid checksum length for %s: expected %d hexadecimal characters, got %d", algorithm, expectedLength, len(checksum))
|
||||
}
|
||||
|
||||
// Check if it's a valid hex string
|
||||
hexPattern := regexp.MustCompile("^[a-fA-F0-9]+$")
|
||||
if !hexPattern.MatchString(checksum) {
|
||||
return fmt.Errorf("invalid checksum format: must contain only hexadecimal characters (0-9, a-f, A-F)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyChecksum compares calculated hash against stored reference and creates immutable audit record
|
||||
func (s *ChecksumService) VerifyChecksum(ctx context.Context, docID, calculatedChecksum, verifiedBy string) (*models.ChecksumVerificationResult, error) {
|
||||
// Get document metadata
|
||||
doc, err := s.documentRepo.GetByDocID(ctx, docID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get document: %w", err)
|
||||
}
|
||||
|
||||
if doc == nil {
|
||||
return nil, fmt.Errorf("document not found: %s", docID)
|
||||
}
|
||||
|
||||
// Normalize checksums for comparison
|
||||
normalizedCalculated := normalizeChecksum(calculatedChecksum)
|
||||
normalizedStored := normalizeChecksum(doc.Checksum)
|
||||
|
||||
// Determine the algorithm to use (from document or default to SHA-256)
|
||||
algorithm := doc.ChecksumAlgorithm
|
||||
if algorithm == "" {
|
||||
algorithm = "SHA-256"
|
||||
}
|
||||
|
||||
// Validate the calculated checksum format
|
||||
if err := s.ValidateChecksumFormat(normalizedCalculated, algorithm); err != nil {
|
||||
// Record failed verification with error
|
||||
errorMsg := err.Error()
|
||||
verification := &models.ChecksumVerification{
|
||||
DocID: docID,
|
||||
VerifiedBy: verifiedBy,
|
||||
VerifiedAt: time.Now(),
|
||||
StoredChecksum: normalizedStored,
|
||||
CalculatedChecksum: normalizedCalculated,
|
||||
Algorithm: algorithm,
|
||||
IsValid: false,
|
||||
ErrorMessage: &errorMsg,
|
||||
}
|
||||
_ = s.verificationRepo.RecordVerification(ctx, verification)
|
||||
|
||||
return nil, fmt.Errorf("invalid checksum format: %w", err)
|
||||
}
|
||||
|
||||
// Check if document has a reference checksum
|
||||
if !doc.HasChecksum() {
|
||||
result := &models.ChecksumVerificationResult{
|
||||
Valid: false,
|
||||
StoredChecksum: "",
|
||||
CalculatedChecksum: normalizedCalculated,
|
||||
Algorithm: algorithm,
|
||||
Message: "No reference checksum configured for this document",
|
||||
HasReferenceHash: false,
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Compare checksums (case-insensitive)
|
||||
isValid := strings.EqualFold(normalizedCalculated, normalizedStored)
|
||||
|
||||
// Record verification
|
||||
verification := &models.ChecksumVerification{
|
||||
DocID: docID,
|
||||
VerifiedBy: verifiedBy,
|
||||
VerifiedAt: time.Now(),
|
||||
StoredChecksum: normalizedStored,
|
||||
CalculatedChecksum: normalizedCalculated,
|
||||
Algorithm: algorithm,
|
||||
IsValid: isValid,
|
||||
ErrorMessage: nil,
|
||||
}
|
||||
|
||||
if err := s.verificationRepo.RecordVerification(ctx, verification); err != nil {
|
||||
logger.Logger.Error("Failed to record verification", "error", err.Error(), "doc_id", docID)
|
||||
// Continue even if recording fails - return the result
|
||||
}
|
||||
|
||||
var message string
|
||||
if isValid {
|
||||
message = "Checksums match - document integrity verified"
|
||||
} else {
|
||||
message = "Checksums do not match - document may have been modified"
|
||||
}
|
||||
|
||||
result := &models.ChecksumVerificationResult{
|
||||
Valid: isValid,
|
||||
StoredChecksum: normalizedStored,
|
||||
CalculatedChecksum: normalizedCalculated,
|
||||
Algorithm: algorithm,
|
||||
Message: message,
|
||||
HasReferenceHash: true,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetVerificationHistory retrieves paginated audit trail of all checksum validation attempts
|
||||
func (s *ChecksumService) GetVerificationHistory(ctx context.Context, docID string, limit int) ([]*models.ChecksumVerification, error) {
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
|
||||
return s.verificationRepo.GetVerificationHistory(ctx, docID, limit)
|
||||
}
|
||||
|
||||
// GetSupportedAlgorithms returns available hash algorithms for client-side documentation
|
||||
func (s *ChecksumService) GetSupportedAlgorithms() []string {
|
||||
return []string{"SHA-256", "SHA-512", "MD5"}
|
||||
}
|
||||
|
||||
// GetChecksumInfo exposes document hash metadata for public verification interfaces
|
||||
func (s *ChecksumService) GetChecksumInfo(ctx context.Context, docID string) (map[string]interface{}, error) {
|
||||
doc, err := s.documentRepo.GetByDocID(ctx, docID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get document: %w", err)
|
||||
}
|
||||
|
||||
if doc == nil {
|
||||
return nil, fmt.Errorf("document not found: %s", docID)
|
||||
}
|
||||
|
||||
algorithm := doc.ChecksumAlgorithm
|
||||
if algorithm == "" {
|
||||
algorithm = "SHA-256"
|
||||
}
|
||||
|
||||
info := map[string]interface{}{
|
||||
"doc_id": docID,
|
||||
"has_checksum": doc.HasChecksum(),
|
||||
"algorithm": algorithm,
|
||||
"checksum_length": doc.GetExpectedChecksumLength(),
|
||||
"supported_algorithms": s.GetSupportedAlgorithms(),
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// normalizeChecksum removes common separators and converts to lowercase
|
||||
func normalizeChecksum(checksum string) string {
|
||||
// Remove spaces, hyphens, underscores
|
||||
checksum = strings.ReplaceAll(checksum, " ", "")
|
||||
checksum = strings.ReplaceAll(checksum, "-", "")
|
||||
checksum = strings.ReplaceAll(checksum, "_", "")
|
||||
checksum = strings.TrimSpace(checksum)
|
||||
// Convert to lowercase for case-insensitive comparison
|
||||
return strings.ToLower(checksum)
|
||||
}
|
||||
472
backend/internal/application/services/checksum_service_test.go
Normal file
472
backend/internal/application/services/checksum_service_test.go
Normal file
@@ -0,0 +1,472 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
)
|
||||
|
||||
type fakeVerificationRepository struct {
|
||||
verifications []*models.ChecksumVerification
|
||||
shouldFailRecord bool
|
||||
shouldFailGetHistory bool
|
||||
shouldFailGetLast bool
|
||||
}
|
||||
|
||||
func newFakeVerificationRepository() *fakeVerificationRepository {
|
||||
return &fakeVerificationRepository{
|
||||
verifications: make([]*models.ChecksumVerification, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeVerificationRepository) RecordVerification(_ context.Context, verification *models.ChecksumVerification) error {
|
||||
if f.shouldFailRecord {
|
||||
return errors.New("repository record failed")
|
||||
}
|
||||
|
||||
verification.ID = int64(len(f.verifications) + 1)
|
||||
f.verifications = append(f.verifications, verification)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeVerificationRepository) GetVerificationHistory(_ context.Context, docID string, limit int) ([]*models.ChecksumVerification, error) {
|
||||
if f.shouldFailGetHistory {
|
||||
return nil, errors.New("repository get history failed")
|
||||
}
|
||||
|
||||
var result []*models.ChecksumVerification
|
||||
for _, v := range f.verifications {
|
||||
if v.DocID == docID {
|
||||
result = append(result, v)
|
||||
if len(result) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (f *fakeVerificationRepository) GetLastVerification(_ context.Context, docID string) (*models.ChecksumVerification, error) {
|
||||
if f.shouldFailGetLast {
|
||||
return nil, errors.New("repository get last failed")
|
||||
}
|
||||
|
||||
for i := len(f.verifications) - 1; i >= 0; i-- {
|
||||
if f.verifications[i].DocID == docID {
|
||||
return f.verifications[i], nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type fakeDocumentRepository struct {
|
||||
documents map[string]*models.Document
|
||||
shouldFailGet bool
|
||||
}
|
||||
|
||||
func newFakeDocumentRepository() *fakeDocumentRepository {
|
||||
return &fakeDocumentRepository{
|
||||
documents: make(map[string]*models.Document),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) GetByDocID(_ context.Context, docID string) (*models.Document, error) {
|
||||
if f.shouldFailGet {
|
||||
return nil, errors.New("repository get failed")
|
||||
}
|
||||
|
||||
doc, exists := f.documents[docID]
|
||||
if !exists {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) Create(_ context.Context, docID string, input models.DocumentInput, createdBy string) (*models.Document, error) {
|
||||
if f.shouldFailGet {
|
||||
return nil, errors.New("repository create failed")
|
||||
}
|
||||
|
||||
doc := &models.Document{
|
||||
DocID: docID,
|
||||
Title: input.Title,
|
||||
URL: input.URL,
|
||||
Checksum: input.Checksum,
|
||||
ChecksumAlgorithm: input.ChecksumAlgorithm,
|
||||
Description: input.Description,
|
||||
CreatedBy: createdBy,
|
||||
}
|
||||
f.documents[docID] = doc
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) FindByReference(_ context.Context, ref string, refType string) (*models.Document, error) {
|
||||
if f.shouldFailGet {
|
||||
return nil, errors.New("repository find failed")
|
||||
}
|
||||
|
||||
for _, doc := range f.documents {
|
||||
if doc.URL == ref {
|
||||
return doc, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestChecksumService_ValidateChecksumFormat(t *testing.T) {
|
||||
service := NewChecksumService(newFakeVerificationRepository(), newFakeDocumentRepository())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
checksum string
|
||||
algorithm string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "valid SHA-256",
|
||||
checksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
algorithm: "SHA-256",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "valid SHA-512",
|
||||
checksum: "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e",
|
||||
algorithm: "SHA-512",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "valid MD5",
|
||||
checksum: "d41d8cd98f00b204e9800998ecf8427e",
|
||||
algorithm: "MD5",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "valid with uppercase",
|
||||
checksum: "E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855",
|
||||
algorithm: "SHA-256",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "valid with spaces",
|
||||
checksum: "e3b0c442 98fc1c14 9afbf4c8 996fb924 27ae41e4 649b934c a495991b 7852b855",
|
||||
algorithm: "SHA-256",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "valid with hyphens",
|
||||
checksum: "e3b0c442-98fc1c14-9afbf4c8-996fb924-27ae41e4-649b934c-a495991b-7852b855",
|
||||
algorithm: "SHA-256",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid - too short for SHA-256",
|
||||
checksum: "abc123",
|
||||
algorithm: "SHA-256",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - too long for MD5",
|
||||
checksum: "d41d8cd98f00b204e9800998ecf8427eextra",
|
||||
algorithm: "MD5",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - non-hex characters",
|
||||
checksum: "gggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg",
|
||||
algorithm: "SHA-256",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - unsupported algorithm",
|
||||
checksum: "abc123",
|
||||
algorithm: "SHA-1",
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := service.ValidateChecksumFormat(tt.checksum, tt.algorithm)
|
||||
|
||||
if tt.wantError && err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
if !tt.wantError && err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecksumService_VerifyChecksum(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
docID string
|
||||
document *models.Document
|
||||
calculatedChecksum string
|
||||
verifiedBy string
|
||||
wantValid bool
|
||||
wantHasReference bool
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "valid verification - checksums match",
|
||||
docID: "doc-001",
|
||||
document: &models.Document{
|
||||
DocID: "doc-001",
|
||||
Checksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
},
|
||||
calculatedChecksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
verifiedBy: "user@example.com",
|
||||
wantValid: true,
|
||||
wantHasReference: true,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid verification - checksums differ",
|
||||
docID: "doc-002",
|
||||
document: &models.Document{
|
||||
DocID: "doc-002",
|
||||
Checksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
},
|
||||
calculatedChecksum: "0000000000000000000000000000000000000000000000000000000000000000",
|
||||
verifiedBy: "user@example.com",
|
||||
wantValid: false,
|
||||
wantHasReference: true,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "no reference checksum",
|
||||
docID: "doc-003",
|
||||
document: &models.Document{
|
||||
DocID: "doc-003",
|
||||
Checksum: "",
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
},
|
||||
calculatedChecksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
verifiedBy: "user@example.com",
|
||||
wantValid: false,
|
||||
wantHasReference: false,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "case insensitive comparison",
|
||||
docID: "doc-004",
|
||||
document: &models.Document{
|
||||
DocID: "doc-004",
|
||||
Checksum: "E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855",
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
},
|
||||
calculatedChecksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
verifiedBy: "user@example.com",
|
||||
wantValid: true,
|
||||
wantHasReference: true,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "document not found",
|
||||
docID: "non-existent",
|
||||
document: nil,
|
||||
calculatedChecksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
verifiedBy: "user@example.com",
|
||||
wantValid: false,
|
||||
wantHasReference: false,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
verificationRepo := newFakeVerificationRepository()
|
||||
documentRepo := newFakeDocumentRepository()
|
||||
|
||||
if tt.document != nil {
|
||||
documentRepo.documents[tt.docID] = tt.document
|
||||
}
|
||||
|
||||
service := NewChecksumService(verificationRepo, documentRepo)
|
||||
|
||||
result, err := service.VerifyChecksum(ctx, tt.docID, tt.calculatedChecksum, tt.verifiedBy)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result.Valid != tt.wantValid {
|
||||
t.Errorf("expected Valid=%v, got %v", tt.wantValid, result.Valid)
|
||||
}
|
||||
|
||||
if result.HasReferenceHash != tt.wantHasReference {
|
||||
t.Errorf("expected HasReferenceHash=%v, got %v", tt.wantHasReference, result.HasReferenceHash)
|
||||
}
|
||||
|
||||
// Check that verification was recorded (if document has checksum)
|
||||
if tt.wantHasReference {
|
||||
if len(verificationRepo.verifications) != 1 {
|
||||
t.Errorf("expected 1 verification recorded, got %d", len(verificationRepo.verifications))
|
||||
} else {
|
||||
v := verificationRepo.verifications[0]
|
||||
if v.IsValid != tt.wantValid {
|
||||
t.Errorf("recorded verification IsValid=%v, expected %v", v.IsValid, tt.wantValid)
|
||||
}
|
||||
if v.VerifiedBy != tt.verifiedBy {
|
||||
t.Errorf("recorded verification VerifiedBy=%s, expected %s", v.VerifiedBy, tt.verifiedBy)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecksumService_GetVerificationHistory(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
verificationRepo := newFakeVerificationRepository()
|
||||
documentRepo := newFakeDocumentRepository()
|
||||
service := NewChecksumService(verificationRepo, documentRepo)
|
||||
|
||||
// Add test verifications
|
||||
for i := 0; i < 5; i++ {
|
||||
v := &models.ChecksumVerification{
|
||||
DocID: "doc-001",
|
||||
VerifiedBy: "user@example.com",
|
||||
VerifiedAt: time.Now(),
|
||||
StoredChecksum: "abc123",
|
||||
CalculatedChecksum: "abc123",
|
||||
Algorithm: "SHA-256",
|
||||
IsValid: true,
|
||||
}
|
||||
_ = verificationRepo.RecordVerification(ctx, v)
|
||||
}
|
||||
|
||||
// Test get all
|
||||
history, err := service.GetVerificationHistory(ctx, "doc-001", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(history) != 5 {
|
||||
t.Errorf("expected 5 verifications, got %d", len(history))
|
||||
}
|
||||
|
||||
// Test with limit
|
||||
limited, err := service.GetVerificationHistory(ctx, "doc-001", 2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(limited) != 2 {
|
||||
t.Errorf("expected 2 verifications with limit, got %d", len(limited))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecksumService_GetSupportedAlgorithms(t *testing.T) {
|
||||
service := NewChecksumService(newFakeVerificationRepository(), newFakeDocumentRepository())
|
||||
|
||||
algorithms := service.GetSupportedAlgorithms()
|
||||
|
||||
expected := []string{"SHA-256", "SHA-512", "MD5"}
|
||||
if len(algorithms) != len(expected) {
|
||||
t.Errorf("expected %d algorithms, got %d", len(expected), len(algorithms))
|
||||
}
|
||||
|
||||
for i, alg := range expected {
|
||||
if algorithms[i] != alg {
|
||||
t.Errorf("expected algorithm %s at position %d, got %s", alg, i, algorithms[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecksumService_GetChecksumInfo(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
docID string
|
||||
document *models.Document
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "document with checksum",
|
||||
docID: "doc-001",
|
||||
document: &models.Document{
|
||||
DocID: "doc-001",
|
||||
Checksum: "abc123",
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "document without checksum",
|
||||
docID: "doc-002",
|
||||
document: &models.Document{
|
||||
DocID: "doc-002",
|
||||
Checksum: "",
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "document not found",
|
||||
docID: "non-existent",
|
||||
document: nil,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
documentRepo := newFakeDocumentRepository()
|
||||
if tt.document != nil {
|
||||
documentRepo.documents[tt.docID] = tt.document
|
||||
}
|
||||
|
||||
service := NewChecksumService(newFakeVerificationRepository(), documentRepo)
|
||||
|
||||
info, err := service.GetChecksumInfo(ctx, tt.docID)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if info["doc_id"] != tt.docID {
|
||||
t.Errorf("expected doc_id %s, got %v", tt.docID, info["doc_id"])
|
||||
}
|
||||
|
||||
if _, ok := info["has_checksum"]; !ok {
|
||||
t.Error("expected has_checksum field")
|
||||
}
|
||||
|
||||
if _, ok := info["supported_algorithms"]; !ok {
|
||||
t.Error("expected supported_algorithms field")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
311
backend/internal/application/services/document_service.go
Normal file
311
backend/internal/application/services/document_service.go
Normal file
@@ -0,0 +1,311 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/config"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/checksum"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
type documentRepository interface {
|
||||
Create(ctx context.Context, docID string, input models.DocumentInput, createdBy string) (*models.Document, error)
|
||||
GetByDocID(ctx context.Context, docID string) (*models.Document, error)
|
||||
FindByReference(ctx context.Context, ref string, refType string) (*models.Document, error)
|
||||
}
|
||||
|
||||
// DocumentService handles document metadata operations and unique ID generation
|
||||
type DocumentService struct {
|
||||
repo documentRepository
|
||||
checksumConfig *config.ChecksumConfig
|
||||
}
|
||||
|
||||
// NewDocumentService initializes the document service with its repository dependency
|
||||
func NewDocumentService(repo documentRepository, checksumConfig *config.ChecksumConfig) *DocumentService {
|
||||
return &DocumentService{
|
||||
repo: repo,
|
||||
checksumConfig: checksumConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDocumentRequest represents the request to create a document
|
||||
type CreateDocumentRequest struct {
|
||||
Reference string `json:"reference" validate:"required,min=1"`
|
||||
Title string `json:"title"`
|
||||
}
|
||||
|
||||
// CreateDocument generates a collision-resistant base36 identifier and persists document metadata
|
||||
func (s *DocumentService) CreateDocument(ctx context.Context, req CreateDocumentRequest) (*models.Document, error) {
|
||||
logger.Logger.Info("Document creation attempt", "reference", req.Reference)
|
||||
|
||||
var docID string
|
||||
maxRetries := 5
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
docID = generateDocID()
|
||||
|
||||
existing, err := s.repo.GetByDocID(ctx, docID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check doc_id uniqueness: %w", err)
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
break
|
||||
}
|
||||
|
||||
logger.Logger.Debug("Generated doc_id already exists, retrying",
|
||||
"doc_id", docID, "attempt", i+1)
|
||||
}
|
||||
|
||||
var url, title string
|
||||
if strings.HasPrefix(req.Reference, "http://") || strings.HasPrefix(req.Reference, "https://") {
|
||||
url = req.Reference
|
||||
|
||||
if req.Title == "" {
|
||||
title = extractTitleFromURL(req.Reference)
|
||||
} else {
|
||||
title = req.Title
|
||||
}
|
||||
} else {
|
||||
url = ""
|
||||
if req.Title == "" {
|
||||
title = req.Reference
|
||||
} else {
|
||||
title = req.Title
|
||||
}
|
||||
}
|
||||
|
||||
input := models.DocumentInput{
|
||||
Title: title,
|
||||
URL: url,
|
||||
}
|
||||
|
||||
// Automatically compute checksum for remote URLs if enabled
|
||||
if url != "" && s.checksumConfig != nil {
|
||||
checksumResult := s.computeChecksumForURL(url)
|
||||
if checksumResult != nil {
|
||||
input.Checksum = checksumResult.ChecksumHex
|
||||
input.ChecksumAlgorithm = checksumResult.Algorithm
|
||||
logger.Logger.Info("Automatically computed checksum for document",
|
||||
"doc_id", docID,
|
||||
"checksum", checksumResult.ChecksumHex,
|
||||
"algorithm", checksumResult.Algorithm)
|
||||
}
|
||||
}
|
||||
|
||||
doc, err := s.repo.Create(ctx, docID, input, "")
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to create document",
|
||||
"doc_id", docID,
|
||||
"error", err.Error())
|
||||
return nil, fmt.Errorf("failed to create document: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Info("Document created successfully",
|
||||
"doc_id", docID,
|
||||
"url", url,
|
||||
"title", title)
|
||||
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
func generateDocID() string {
|
||||
timestamp := time.Now().Unix()
|
||||
timestampB36 := strconv.FormatInt(timestamp, 36)
|
||||
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
const suffixLen = 4
|
||||
|
||||
suffix := make([]byte, suffixLen)
|
||||
for i := range suffix {
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
||||
if err != nil {
|
||||
suffix[i] = charset[(int(timestamp)+i)%len(charset)]
|
||||
} else {
|
||||
suffix[i] = charset[n.Int64()]
|
||||
}
|
||||
}
|
||||
|
||||
return timestampB36 + string(suffix)
|
||||
}
|
||||
|
||||
func extractTitleFromURL(urlStr string) string {
|
||||
urlStr = strings.TrimRight(urlStr, "/")
|
||||
|
||||
urlStr = strings.TrimPrefix(urlStr, "http://")
|
||||
urlStr = strings.TrimPrefix(urlStr, "https://")
|
||||
|
||||
parts := strings.Split(urlStr, "/")
|
||||
|
||||
if len(parts) == 0 {
|
||||
return urlStr
|
||||
}
|
||||
|
||||
var lastSegment string
|
||||
for i := len(parts) - 1; i >= 0; i-- {
|
||||
if parts[i] != "" {
|
||||
lastSegment = parts[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if lastSegment == "" {
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
return urlStr
|
||||
}
|
||||
|
||||
if idx := strings.Index(lastSegment, "?"); idx >= 0 {
|
||||
lastSegment = lastSegment[:idx]
|
||||
}
|
||||
|
||||
if idx := strings.Index(lastSegment, "#"); idx >= 0 {
|
||||
lastSegment = lastSegment[:idx]
|
||||
}
|
||||
|
||||
if idx := strings.LastIndex(lastSegment, "."); idx > 0 {
|
||||
return lastSegment[:idx]
|
||||
}
|
||||
|
||||
return lastSegment
|
||||
}
|
||||
|
||||
// computeChecksumForURL attempts to compute the checksum for a remote URL
|
||||
// Returns nil if the checksum cannot be computed (error, too large, etc.)
|
||||
func (s *DocumentService) computeChecksumForURL(url string) *checksum.Result {
|
||||
if s.checksumConfig == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
opts := checksum.ComputeOptions{
|
||||
MaxBytes: s.checksumConfig.MaxBytes,
|
||||
TimeoutMs: s.checksumConfig.TimeoutMs,
|
||||
MaxRedirects: s.checksumConfig.MaxRedirects,
|
||||
AllowedContentType: s.checksumConfig.AllowedContentType,
|
||||
SkipSSRFCheck: s.checksumConfig.SkipSSRFCheck,
|
||||
InsecureSkipVerify: s.checksumConfig.InsecureSkipVerify,
|
||||
}
|
||||
|
||||
result, err := checksum.ComputeRemoteChecksum(url, opts)
|
||||
if err != nil {
|
||||
logger.Logger.Warn("Failed to compute checksum for URL",
|
||||
"url", url,
|
||||
"error", err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type ReferenceType string
|
||||
|
||||
const (
|
||||
ReferenceTypeURL ReferenceType = "url"
|
||||
ReferenceTypePath ReferenceType = "path"
|
||||
ReferenceTypeReference ReferenceType = "reference"
|
||||
)
|
||||
|
||||
func detectReferenceType(ref string) ReferenceType {
|
||||
if strings.HasPrefix(ref, "http://") || strings.HasPrefix(ref, "https://") {
|
||||
return ReferenceTypeURL
|
||||
}
|
||||
|
||||
if strings.Contains(ref, "/") || strings.Contains(ref, "\\") {
|
||||
return ReferenceTypePath
|
||||
}
|
||||
|
||||
return ReferenceTypeReference
|
||||
}
|
||||
|
||||
// FindByReference finds a document by its reference without creating it
|
||||
func (s *DocumentService) FindByReference(ctx context.Context, ref string, refType string) (*models.Document, error) {
|
||||
doc, err := s.repo.FindByReference(ctx, ref, refType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// FindOrCreateDocument performs smart lookup by URL/path/reference or creates new document if not found
|
||||
func (s *DocumentService) FindOrCreateDocument(ctx context.Context, ref string) (*models.Document, bool, error) {
|
||||
logger.Logger.Info("Find or create document", "reference", ref)
|
||||
|
||||
refType := detectReferenceType(ref)
|
||||
logger.Logger.Debug("Reference type detected", "type", refType, "reference", ref)
|
||||
|
||||
doc, err := s.repo.FindByReference(ctx, ref, string(refType))
|
||||
if err != nil {
|
||||
logger.Logger.Error("Error searching for document", "reference", ref, "error", err.Error())
|
||||
return nil, false, fmt.Errorf("failed to search for document: %w", err)
|
||||
}
|
||||
|
||||
if doc != nil {
|
||||
logger.Logger.Info("Document found", "doc_id", doc.DocID, "reference", ref)
|
||||
return doc, false, nil
|
||||
}
|
||||
|
||||
logger.Logger.Info("Document not found, creating new one", "reference", ref)
|
||||
|
||||
var title string
|
||||
switch refType {
|
||||
case ReferenceTypeURL:
|
||||
title = extractTitleFromURL(ref)
|
||||
case ReferenceTypePath:
|
||||
title = extractTitleFromURL(ref)
|
||||
case ReferenceTypeReference:
|
||||
title = ref
|
||||
}
|
||||
|
||||
createReq := CreateDocumentRequest{
|
||||
Reference: ref,
|
||||
Title: title,
|
||||
}
|
||||
|
||||
if refType == ReferenceTypeReference {
|
||||
input := models.DocumentInput{
|
||||
Title: title,
|
||||
URL: "",
|
||||
}
|
||||
|
||||
doc, err := s.repo.Create(ctx, ref, input, "")
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to create document with custom doc_id",
|
||||
"doc_id", ref,
|
||||
"error", err.Error())
|
||||
return nil, false, fmt.Errorf("failed to create document: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Info("Document created with custom doc_id",
|
||||
"doc_id", ref,
|
||||
"title", title)
|
||||
|
||||
return doc, true, nil
|
||||
}
|
||||
|
||||
// For URL references, compute checksum before creating
|
||||
if refType == ReferenceTypeURL && s.checksumConfig != nil {
|
||||
logger.Logger.Debug("Computing checksum for URL reference", "url", ref)
|
||||
checksumResult := s.computeChecksumForURL(ref)
|
||||
if checksumResult != nil {
|
||||
logger.Logger.Info("Automatically computed checksum for URL reference",
|
||||
"url", ref,
|
||||
"checksum", checksumResult.ChecksumHex,
|
||||
"algorithm", checksumResult.Algorithm)
|
||||
}
|
||||
}
|
||||
|
||||
doc, err = s.CreateDocument(ctx, createReq)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return doc, true, nil
|
||||
}
|
||||
@@ -0,0 +1,328 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/config"
|
||||
)
|
||||
|
||||
// Test automatic checksum computation with valid PDF
|
||||
func TestDocumentService_CreateDocument_WithAutomaticChecksum(t *testing.T) {
|
||||
content := "Sample PDF content"
|
||||
expectedChecksum := "b3b4e8714358cc79990c5c83391172e01c3e79a1b456d7e0c570cbf59da30e23" // SHA-256
|
||||
|
||||
// Create test HTTP server
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/pdf")
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(content)))
|
||||
if r.Method == "GET" {
|
||||
w.Write([]byte(content))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
mockRepo := &mockDocumentRepository{}
|
||||
checksumConfig := &config.ChecksumConfig{
|
||||
MaxBytes: 10 * 1024 * 1024, // 10 MB
|
||||
TimeoutMs: 5000,
|
||||
MaxRedirects: 3,
|
||||
AllowedContentType: []string{
|
||||
"application/pdf",
|
||||
"image/*",
|
||||
},
|
||||
SkipSSRFCheck: true, // For testing with httptest
|
||||
InsecureSkipVerify: true, // Accept self-signed certs in tests
|
||||
}
|
||||
service := NewDocumentService(mockRepo, checksumConfig)
|
||||
|
||||
req := CreateDocumentRequest{
|
||||
Reference: server.URL,
|
||||
Title: "Test Document",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
doc, err := service.CreateDocument(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("CreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("Expected document to be created, got nil")
|
||||
}
|
||||
|
||||
// Verify checksum was computed
|
||||
if doc.Checksum != expectedChecksum {
|
||||
t.Errorf("Expected checksum %q, got %q", expectedChecksum, doc.Checksum)
|
||||
}
|
||||
|
||||
if doc.ChecksumAlgorithm != "SHA-256" {
|
||||
t.Errorf("Expected algorithm SHA-256, got %q", doc.ChecksumAlgorithm)
|
||||
}
|
||||
}
|
||||
|
||||
// Test automatic checksum computation with HTTP (should be rejected)
|
||||
func TestDocumentService_CreateDocument_RejectsHTTP(t *testing.T) {
|
||||
mockRepo := &mockDocumentRepository{}
|
||||
checksumConfig := &config.ChecksumConfig{
|
||||
MaxBytes: 10 * 1024 * 1024,
|
||||
TimeoutMs: 5000,
|
||||
MaxRedirects: 3,
|
||||
AllowedContentType: []string{
|
||||
"application/pdf",
|
||||
},
|
||||
SkipSSRFCheck: true,
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
service := NewDocumentService(mockRepo, checksumConfig)
|
||||
|
||||
// HTTP URL (not HTTPS)
|
||||
req := CreateDocumentRequest{
|
||||
Reference: "http://example.com/document.pdf",
|
||||
Title: "Test Document",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
doc, err := service.CreateDocument(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("CreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
// Document should be created, but without checksum
|
||||
if doc.Checksum != "" {
|
||||
t.Error("Expected checksum to be empty for HTTP URL, got", doc.Checksum)
|
||||
}
|
||||
}
|
||||
|
||||
// Test automatic checksum computation with too large file
|
||||
func TestDocumentService_CreateDocument_TooLargeFile(t *testing.T) {
|
||||
// Create test HTTP server that returns large Content-Length
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/pdf")
|
||||
w.Header().Set("Content-Length", "20971520") // 20 MB
|
||||
if r.Method == "GET" {
|
||||
w.Write([]byte("should not reach here"))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
mockRepo := &mockDocumentRepository{}
|
||||
checksumConfig := &config.ChecksumConfig{
|
||||
MaxBytes: 10 * 1024 * 1024, // 10 MB limit
|
||||
TimeoutMs: 5000,
|
||||
MaxRedirects: 3,
|
||||
AllowedContentType: []string{
|
||||
"application/pdf",
|
||||
},
|
||||
}
|
||||
service := NewDocumentService(mockRepo, checksumConfig)
|
||||
|
||||
req := CreateDocumentRequest{
|
||||
Reference: server.URL,
|
||||
Title: "Large Document",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
doc, err := service.CreateDocument(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("CreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
// Document should be created, but without checksum (file too large)
|
||||
if doc.Checksum != "" {
|
||||
t.Error("Expected checksum to be empty for too large file, got", doc.Checksum)
|
||||
}
|
||||
}
|
||||
|
||||
// Test automatic checksum computation with wrong content type
|
||||
func TestDocumentService_CreateDocument_WrongContentType(t *testing.T) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html") // Not allowed
|
||||
w.Header().Set("Content-Length", "100")
|
||||
if r.Method == "GET" {
|
||||
w.Write([]byte("<html>test</html>"))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
mockRepo := &mockDocumentRepository{}
|
||||
checksumConfig := &config.ChecksumConfig{
|
||||
MaxBytes: 10 * 1024 * 1024,
|
||||
TimeoutMs: 5000,
|
||||
MaxRedirects: 3,
|
||||
AllowedContentType: []string{
|
||||
"application/pdf",
|
||||
},
|
||||
SkipSSRFCheck: true,
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
service := NewDocumentService(mockRepo, checksumConfig)
|
||||
|
||||
req := CreateDocumentRequest{
|
||||
Reference: server.URL,
|
||||
Title: "HTML Document",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
doc, err := service.CreateDocument(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("CreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
// Document should be created, but without checksum (wrong content type)
|
||||
if doc.Checksum != "" {
|
||||
t.Error("Expected checksum to be empty for wrong content type, got", doc.Checksum)
|
||||
}
|
||||
}
|
||||
|
||||
// Test automatic checksum computation with image wildcard
|
||||
func TestDocumentService_CreateDocument_ImageWildcard(t *testing.T) {
|
||||
content := []byte{0x89, 0x50, 0x4E, 0x47} // PNG header
|
||||
expectedChecksum := "0f4636c78f65d3639ece5a064b5ae753e3408614a14fb18ab4d7540d2c248543"
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "image/png")
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(content)))
|
||||
if r.Method == "GET" {
|
||||
w.Write(content)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
mockRepo := &mockDocumentRepository{}
|
||||
checksumConfig := &config.ChecksumConfig{
|
||||
MaxBytes: 10 * 1024 * 1024,
|
||||
TimeoutMs: 5000,
|
||||
MaxRedirects: 3,
|
||||
AllowedContentType: []string{
|
||||
"image/*", // Wildcard for all images
|
||||
},
|
||||
SkipSSRFCheck: true,
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
service := NewDocumentService(mockRepo, checksumConfig)
|
||||
|
||||
req := CreateDocumentRequest{
|
||||
Reference: server.URL,
|
||||
Title: "Test Image",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
doc, err := service.CreateDocument(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("CreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify checksum was computed for image
|
||||
if doc.Checksum != expectedChecksum {
|
||||
t.Errorf("Expected checksum %q, got %q", expectedChecksum, doc.Checksum)
|
||||
}
|
||||
}
|
||||
|
||||
// Test automatic checksum computation disabled (nil config)
|
||||
func TestDocumentService_CreateDocument_NoChecksumConfig(t *testing.T) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/pdf")
|
||||
w.Write([]byte("content"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
mockRepo := &mockDocumentRepository{}
|
||||
service := NewDocumentService(mockRepo, nil) // No checksum config
|
||||
|
||||
req := CreateDocumentRequest{
|
||||
Reference: server.URL,
|
||||
Title: "Test Document",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
doc, err := service.CreateDocument(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("CreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
// Document should be created without checksum (feature disabled)
|
||||
if doc.Checksum != "" {
|
||||
t.Error("Expected checksum to be empty when config is nil, got", doc.Checksum)
|
||||
}
|
||||
}
|
||||
|
||||
// Test automatic checksum computation with network error
|
||||
func TestDocumentService_CreateDocument_NetworkError(t *testing.T) {
|
||||
mockRepo := &mockDocumentRepository{}
|
||||
checksumConfig := &config.ChecksumConfig{
|
||||
MaxBytes: 10 * 1024 * 1024,
|
||||
TimeoutMs: 100, // Very short timeout
|
||||
MaxRedirects: 3,
|
||||
AllowedContentType: []string{
|
||||
"application/pdf",
|
||||
},
|
||||
}
|
||||
service := NewDocumentService(mockRepo, checksumConfig)
|
||||
|
||||
// Non-existent server
|
||||
req := CreateDocumentRequest{
|
||||
Reference: "https://non-existent-server-12345.example.com/doc.pdf",
|
||||
Title: "Test Document",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
doc, err := service.CreateDocument(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("CreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
// Document should be created without checksum (network error)
|
||||
if doc.Checksum != "" {
|
||||
t.Error("Expected checksum to be empty for network error, got", doc.Checksum)
|
||||
}
|
||||
}
|
||||
|
||||
// Test CreateDocument without URL (plain reference)
|
||||
func TestDocumentService_CreateDocument_PlainReferenceNoChecksum(t *testing.T) {
|
||||
mockRepo := &mockDocumentRepository{}
|
||||
checksumConfig := &config.ChecksumConfig{
|
||||
MaxBytes: 10 * 1024 * 1024,
|
||||
TimeoutMs: 5000,
|
||||
MaxRedirects: 3,
|
||||
AllowedContentType: []string{
|
||||
"application/pdf",
|
||||
},
|
||||
SkipSSRFCheck: true,
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
service := NewDocumentService(mockRepo, checksumConfig)
|
||||
|
||||
req := CreateDocumentRequest{
|
||||
Reference: "company-policy-2024",
|
||||
Title: "",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
doc, err := service.CreateDocument(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("CreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
// Document should be created without checksum (no URL)
|
||||
if doc.Checksum != "" {
|
||||
t.Error("Expected checksum to be empty for plain reference, got", doc.Checksum)
|
||||
}
|
||||
|
||||
// Verify it's not treated as URL
|
||||
if doc.URL != "" {
|
||||
t.Errorf("Expected URL to be empty, got %q", doc.URL)
|
||||
}
|
||||
}
|
||||
646
backend/internal/application/services/document_service_test.go
Normal file
646
backend/internal/application/services/document_service_test.go
Normal file
@@ -0,0 +1,646 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
)
|
||||
|
||||
// Test generateDocID function
|
||||
func TestGenerateDocID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
}{
|
||||
{"Generate first ID"},
|
||||
{"Generate second ID"},
|
||||
{"Generate third ID"},
|
||||
}
|
||||
|
||||
seenIDs := make(map[string]bool)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
id := generateDocID()
|
||||
|
||||
// Check length (timestamp in base36 + 4 random chars = ~10-11 chars)
|
||||
if len(id) < 10 || len(id) > 12 {
|
||||
t.Errorf("Expected ID length between 10-12 chars, got %d (%s)", len(id), id)
|
||||
}
|
||||
|
||||
// Check all characters are alphanumeric lowercase
|
||||
for _, ch := range id {
|
||||
if !((ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9')) {
|
||||
t.Errorf("ID contains invalid character: %c in %s", ch, id)
|
||||
}
|
||||
}
|
||||
|
||||
// Check uniqueness (probabilistic, but should be unique in small sample)
|
||||
if seenIDs[id] {
|
||||
t.Errorf("Duplicate ID generated: %s", id)
|
||||
}
|
||||
seenIDs[id] = true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test extractTitleFromURL function
|
||||
func TestExtractTitleFromURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "URL with file extension",
|
||||
url: "https://example.com/documents/report.pdf",
|
||||
expected: "report",
|
||||
},
|
||||
{
|
||||
name: "URL without extension",
|
||||
url: "https://example.com/documents/annual-report",
|
||||
expected: "annual-report",
|
||||
},
|
||||
{
|
||||
name: "URL with query parameters",
|
||||
url: "https://example.com/doc.pdf?version=2",
|
||||
expected: "doc",
|
||||
},
|
||||
{
|
||||
name: "URL with fragment",
|
||||
url: "https://example.com/guide.html#section1",
|
||||
expected: "guide",
|
||||
},
|
||||
{
|
||||
name: "URL with trailing slash",
|
||||
url: "https://example.com/page/",
|
||||
expected: "page",
|
||||
},
|
||||
{
|
||||
name: "Domain only",
|
||||
url: "https://example.com",
|
||||
expected: "example",
|
||||
},
|
||||
{
|
||||
name: "Domain with trailing slash",
|
||||
url: "https://example.com/",
|
||||
expected: "example",
|
||||
},
|
||||
{
|
||||
name: "HTTP URL",
|
||||
url: "http://example.com/test.txt",
|
||||
expected: "test",
|
||||
},
|
||||
{
|
||||
name: "URL with path and extension",
|
||||
url: "https://docs.example.com/v2/api/reference.json",
|
||||
expected: "reference",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractTitleFromURL(tt.url)
|
||||
if result != tt.expected {
|
||||
t.Errorf("extractTitleFromURL(%q) = %q, want %q", tt.url, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mockDocumentRepository is a mock implementation for testing
|
||||
type mockDocumentRepository struct {
|
||||
createFunc func(ctx context.Context, docID string, input models.DocumentInput, createdBy string) (*models.Document, error)
|
||||
getByDocIDFunc func(ctx context.Context, docID string) (*models.Document, error)
|
||||
findByReferenceFunc func(ctx context.Context, ref string, refType string) (*models.Document, error)
|
||||
}
|
||||
|
||||
func (m *mockDocumentRepository) Create(ctx context.Context, docID string, input models.DocumentInput, createdBy string) (*models.Document, error) {
|
||||
if m.createFunc != nil {
|
||||
return m.createFunc(ctx, docID, input, createdBy)
|
||||
}
|
||||
return &models.Document{
|
||||
DocID: docID,
|
||||
Title: input.Title,
|
||||
URL: input.URL,
|
||||
Checksum: input.Checksum,
|
||||
ChecksumAlgorithm: input.ChecksumAlgorithm,
|
||||
CreatedBy: createdBy,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockDocumentRepository) GetByDocID(ctx context.Context, docID string) (*models.Document, error) {
|
||||
if m.getByDocIDFunc != nil {
|
||||
return m.getByDocIDFunc(ctx, docID)
|
||||
}
|
||||
return nil, nil // Not found by default
|
||||
}
|
||||
|
||||
func (m *mockDocumentRepository) FindByReference(ctx context.Context, ref string, refType string) (*models.Document, error) {
|
||||
if m.findByReferenceFunc != nil {
|
||||
return m.findByReferenceFunc(ctx, ref, refType)
|
||||
}
|
||||
return nil, nil // Not found by default
|
||||
}
|
||||
|
||||
// Test CreateDocument with URL reference
|
||||
func TestDocumentService_CreateDocument_WithURL(t *testing.T) {
|
||||
mockRepo := &mockDocumentRepository{}
|
||||
service := NewDocumentService(mockRepo, nil) // nil config = no automatic checksum
|
||||
|
||||
req := CreateDocumentRequest{
|
||||
Reference: "https://example.com/important-doc.pdf",
|
||||
Title: "",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
doc, err := service.CreateDocument(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("CreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("Expected document to be created, got nil")
|
||||
}
|
||||
|
||||
// Check that URL was extracted
|
||||
if doc.URL != "https://example.com/important-doc.pdf" {
|
||||
t.Errorf("Expected URL to be %q, got %q", "https://example.com/important-doc.pdf", doc.URL)
|
||||
}
|
||||
|
||||
// Check that title was extracted from URL
|
||||
if doc.Title != "important-doc" {
|
||||
t.Errorf("Expected title to be %q, got %q", "important-doc", doc.Title)
|
||||
}
|
||||
|
||||
// Check that doc_id was generated
|
||||
if doc.DocID == "" {
|
||||
t.Error("Expected doc_id to be generated")
|
||||
}
|
||||
}
|
||||
|
||||
// Test CreateDocument with URL reference and custom title
|
||||
func TestDocumentService_CreateDocument_WithURLAndTitle(t *testing.T) {
|
||||
mockRepo := &mockDocumentRepository{}
|
||||
service := NewDocumentService(mockRepo, nil)
|
||||
|
||||
req := CreateDocumentRequest{
|
||||
Reference: "https://example.com/doc.pdf",
|
||||
Title: "My Custom Title",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
doc, err := service.CreateDocument(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("CreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
// Check that URL was extracted
|
||||
if doc.URL != "https://example.com/doc.pdf" {
|
||||
t.Errorf("Expected URL to be %q, got %q", "https://example.com/doc.pdf", doc.URL)
|
||||
}
|
||||
|
||||
// Check that custom title was used
|
||||
if doc.Title != "My Custom Title" {
|
||||
t.Errorf("Expected title to be %q, got %q", "My Custom Title", doc.Title)
|
||||
}
|
||||
}
|
||||
|
||||
// Test CreateDocument with plain text reference
|
||||
func TestDocumentService_CreateDocument_WithPlainReference(t *testing.T) {
|
||||
mockRepo := &mockDocumentRepository{}
|
||||
service := NewDocumentService(mockRepo, nil)
|
||||
|
||||
req := CreateDocumentRequest{
|
||||
Reference: "company-policy-2024",
|
||||
Title: "",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
doc, err := service.CreateDocument(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("CreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
// Check that URL is empty
|
||||
if doc.URL != "" {
|
||||
t.Errorf("Expected URL to be empty, got %q", doc.URL)
|
||||
}
|
||||
|
||||
// Check that reference was used as title
|
||||
if doc.Title != "company-policy-2024" {
|
||||
t.Errorf("Expected title to be %q, got %q", "company-policy-2024", doc.Title)
|
||||
}
|
||||
}
|
||||
|
||||
// Test CreateDocument with plain reference and custom title
|
||||
func TestDocumentService_CreateDocument_WithPlainReferenceAndTitle(t *testing.T) {
|
||||
mockRepo := &mockDocumentRepository{}
|
||||
service := NewDocumentService(mockRepo, nil)
|
||||
|
||||
req := CreateDocumentRequest{
|
||||
Reference: "doc-ref-123",
|
||||
Title: "Employee Handbook",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
doc, err := service.CreateDocument(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("CreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
// Check that URL is empty
|
||||
if doc.URL != "" {
|
||||
t.Errorf("Expected URL to be empty, got %q", doc.URL)
|
||||
}
|
||||
|
||||
// Check that custom title was used
|
||||
if doc.Title != "Employee Handbook" {
|
||||
t.Errorf("Expected title to be %q, got %q", "Employee Handbook", doc.Title)
|
||||
}
|
||||
}
|
||||
|
||||
// Test CreateDocument with HTTP URL
|
||||
func TestDocumentService_CreateDocument_WithHTTPURL(t *testing.T) {
|
||||
mockRepo := &mockDocumentRepository{}
|
||||
service := NewDocumentService(mockRepo, nil)
|
||||
|
||||
req := CreateDocumentRequest{
|
||||
Reference: "http://example.com/doc.html",
|
||||
Title: "",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
doc, err := service.CreateDocument(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("CreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
// Check that URL was extracted (HTTP should work too)
|
||||
if doc.URL != "http://example.com/doc.html" {
|
||||
t.Errorf("Expected URL to be %q, got %q", "http://example.com/doc.html", doc.URL)
|
||||
}
|
||||
|
||||
// Check that title was extracted
|
||||
if doc.Title != "doc" {
|
||||
t.Errorf("Expected title to be %q, got %q", "doc", doc.Title)
|
||||
}
|
||||
}
|
||||
|
||||
// Test CreateDocument with ID collision retry
|
||||
func TestDocumentService_CreateDocument_IDCollisionRetry(t *testing.T) {
|
||||
collisionCount := 0
|
||||
mockRepo := &mockDocumentRepository{
|
||||
getByDocIDFunc: func(ctx context.Context, docID string) (*models.Document, error) {
|
||||
// First two attempts return existing document (collision)
|
||||
if collisionCount < 2 {
|
||||
collisionCount++
|
||||
return &models.Document{DocID: docID}, nil
|
||||
}
|
||||
// Third attempt returns nil (ID is available)
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
service := NewDocumentService(mockRepo, nil)
|
||||
|
||||
req := CreateDocumentRequest{
|
||||
Reference: "test-doc",
|
||||
Title: "",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
doc, err := service.CreateDocument(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("CreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
// Should have retried at least twice
|
||||
if collisionCount < 2 {
|
||||
t.Errorf("Expected at least 2 collision retries, got %d", collisionCount)
|
||||
}
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("Expected document to be created after retries")
|
||||
}
|
||||
}
|
||||
|
||||
// Test that generated IDs are URL-safe
|
||||
func TestGenerateDocID_URLSafe(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
id := generateDocID()
|
||||
|
||||
// Check no uppercase letters
|
||||
if strings.ToLower(id) != id {
|
||||
t.Errorf("ID contains uppercase letters: %s", id)
|
||||
}
|
||||
|
||||
// Check no special characters that need encoding
|
||||
specialChars := []string{"/", "?", "#", "&", "=", "+", " ", "%"}
|
||||
for _, char := range specialChars {
|
||||
if strings.Contains(id, char) {
|
||||
t.Errorf("ID contains special character %q: %s", char, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test detectReferenceType function
|
||||
func TestDetectReferenceType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ref string
|
||||
expected ReferenceType
|
||||
}{
|
||||
{
|
||||
name: "HTTPS URL",
|
||||
ref: "https://example.com/document.pdf",
|
||||
expected: ReferenceTypeURL,
|
||||
},
|
||||
{
|
||||
name: "HTTP URL",
|
||||
ref: "http://example.com/doc",
|
||||
expected: ReferenceTypeURL,
|
||||
},
|
||||
{
|
||||
name: "Unix path",
|
||||
ref: "/home/user/documents/file.pdf",
|
||||
expected: ReferenceTypePath,
|
||||
},
|
||||
{
|
||||
name: "Windows path",
|
||||
ref: "C:\\Users\\Documents\\file.pdf",
|
||||
expected: ReferenceTypePath,
|
||||
},
|
||||
{
|
||||
name: "Relative path with forward slash",
|
||||
ref: "docs/file.pdf",
|
||||
expected: ReferenceTypePath,
|
||||
},
|
||||
{
|
||||
name: "Relative path with backslash",
|
||||
ref: "docs\\file.pdf",
|
||||
expected: ReferenceTypePath,
|
||||
},
|
||||
{
|
||||
name: "Plain reference",
|
||||
ref: "policy-2024",
|
||||
expected: ReferenceTypeReference,
|
||||
},
|
||||
{
|
||||
name: "Plain reference with dashes",
|
||||
ref: "company-doc-v2",
|
||||
expected: ReferenceTypeReference,
|
||||
},
|
||||
{
|
||||
name: "Plain reference with underscores",
|
||||
ref: "employee_handbook_2024",
|
||||
expected: ReferenceTypeReference,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := detectReferenceType(tt.ref)
|
||||
if result != tt.expected {
|
||||
t.Errorf("detectReferenceType(%q) = %q, want %q", tt.ref, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test FindByReference success
|
||||
func TestDocumentService_FindByReference_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expectedDoc := &models.Document{
|
||||
DocID: "test123",
|
||||
Title: "Test Document",
|
||||
URL: "https://example.com/test.pdf",
|
||||
}
|
||||
|
||||
mockRepo := &mockDocumentRepository{
|
||||
findByReferenceFunc: func(ctx context.Context, ref string, refType string) (*models.Document, error) {
|
||||
if ref == "https://example.com/test.pdf" && refType == "url" {
|
||||
return expectedDoc, nil
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
service := NewDocumentService(mockRepo, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
doc, err := service.FindByReference(ctx, "https://example.com/test.pdf", "url")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("FindByReference failed: %v", err)
|
||||
}
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("Expected document to be found, got nil")
|
||||
}
|
||||
|
||||
if doc.DocID != expectedDoc.DocID {
|
||||
t.Errorf("Expected DocID %q, got %q", expectedDoc.DocID, doc.DocID)
|
||||
}
|
||||
}
|
||||
|
||||
// Test FindByReference not found
|
||||
func TestDocumentService_FindByReference_NotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockRepo := &mockDocumentRepository{
|
||||
findByReferenceFunc: func(ctx context.Context, ref string, refType string) (*models.Document, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
service := NewDocumentService(mockRepo, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
doc, err := service.FindByReference(ctx, "nonexistent", "reference")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("FindByReference should not error when document not found: %v", err)
|
||||
}
|
||||
|
||||
if doc != nil {
|
||||
t.Errorf("Expected nil document, got %+v", doc)
|
||||
}
|
||||
}
|
||||
|
||||
// Test FindOrCreateDocument - found existing document
|
||||
func TestDocumentService_FindOrCreateDocument_Found(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
existingDoc := &models.Document{
|
||||
DocID: "existing123",
|
||||
Title: "Existing Document",
|
||||
URL: "https://example.com/existing.pdf",
|
||||
}
|
||||
|
||||
mockRepo := &mockDocumentRepository{
|
||||
findByReferenceFunc: func(ctx context.Context, ref string, refType string) (*models.Document, error) {
|
||||
if ref == "https://example.com/existing.pdf" {
|
||||
return existingDoc, nil
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
service := NewDocumentService(mockRepo, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
doc, created, err := service.FindOrCreateDocument(ctx, "https://example.com/existing.pdf")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("FindOrCreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("Expected document to be returned, got nil")
|
||||
}
|
||||
|
||||
if created {
|
||||
t.Error("Expected created to be false for existing document")
|
||||
}
|
||||
|
||||
if doc.DocID != existingDoc.DocID {
|
||||
t.Errorf("Expected DocID %q, got %q", existingDoc.DocID, doc.DocID)
|
||||
}
|
||||
}
|
||||
|
||||
// Test FindOrCreateDocument - create new document with URL
|
||||
func TestDocumentService_FindOrCreateDocument_CreateWithURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockRepo := &mockDocumentRepository{
|
||||
findByReferenceFunc: func(ctx context.Context, ref string, refType string) (*models.Document, error) {
|
||||
return nil, nil // Not found
|
||||
},
|
||||
}
|
||||
|
||||
service := NewDocumentService(mockRepo, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
doc, created, err := service.FindOrCreateDocument(ctx, "https://example.com/new-doc.pdf")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("FindOrCreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("Expected document to be created, got nil")
|
||||
}
|
||||
|
||||
if !created {
|
||||
t.Error("Expected created to be true for new document")
|
||||
}
|
||||
|
||||
if doc.URL != "https://example.com/new-doc.pdf" {
|
||||
t.Errorf("Expected URL %q, got %q", "https://example.com/new-doc.pdf", doc.URL)
|
||||
}
|
||||
|
||||
if doc.Title != "new-doc" {
|
||||
t.Errorf("Expected title %q, got %q", "new-doc", doc.Title)
|
||||
}
|
||||
}
|
||||
|
||||
// Test FindOrCreateDocument - create new document with path
|
||||
func TestDocumentService_FindOrCreateDocument_CreateWithPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockRepo := &mockDocumentRepository{
|
||||
findByReferenceFunc: func(ctx context.Context, ref string, refType string) (*models.Document, error) {
|
||||
return nil, nil // Not found
|
||||
},
|
||||
}
|
||||
|
||||
service := NewDocumentService(mockRepo, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
doc, created, err := service.FindOrCreateDocument(ctx, "/home/user/important-file.pdf")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("FindOrCreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("Expected document to be created, got nil")
|
||||
}
|
||||
|
||||
if !created {
|
||||
t.Error("Expected created to be true for new document")
|
||||
}
|
||||
|
||||
// Path is extracted as title (like extractTitleFromURL does for paths)
|
||||
if doc.Title != "important-file" {
|
||||
t.Errorf("Expected title %q, got %q", "important-file", doc.Title)
|
||||
}
|
||||
|
||||
// URL should be empty for paths (they're not http/https)
|
||||
if doc.URL != "" {
|
||||
t.Errorf("Expected URL to be empty for path, got %q", doc.URL)
|
||||
}
|
||||
}
|
||||
|
||||
// Test FindOrCreateDocument - create new document with plain reference
|
||||
func TestDocumentService_FindOrCreateDocument_CreateWithReference(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockRepo := &mockDocumentRepository{
|
||||
findByReferenceFunc: func(ctx context.Context, ref string, refType string) (*models.Document, error) {
|
||||
return nil, nil // Not found
|
||||
},
|
||||
createFunc: func(ctx context.Context, docID string, input models.DocumentInput, createdBy string) (*models.Document, error) {
|
||||
return &models.Document{
|
||||
DocID: docID,
|
||||
Title: input.Title,
|
||||
URL: input.URL,
|
||||
CreatedBy: createdBy,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
service := NewDocumentService(mockRepo, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
doc, created, err := service.FindOrCreateDocument(ctx, "company-policy-2024")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("FindOrCreateDocument failed: %v", err)
|
||||
}
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("Expected document to be created, got nil")
|
||||
}
|
||||
|
||||
if !created {
|
||||
t.Error("Expected created to be true for new document")
|
||||
}
|
||||
|
||||
// For plain reference, doc_id should be the reference itself
|
||||
if doc.DocID != "company-policy-2024" {
|
||||
t.Errorf("Expected DocID to be the reference %q, got %q", "company-policy-2024", doc.DocID)
|
||||
}
|
||||
|
||||
if doc.Title != "company-policy-2024" {
|
||||
t.Errorf("Expected title %q, got %q", "company-policy-2024", doc.Title)
|
||||
}
|
||||
|
||||
if doc.URL != "" {
|
||||
t.Errorf("Expected URL to be empty for plain reference, got %q", doc.URL)
|
||||
}
|
||||
}
|
||||
@@ -6,22 +6,35 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/internal/infrastructure/database"
|
||||
"github.com/btouchard/ackify-ce/internal/infrastructure/email"
|
||||
"github.com/btouchard/ackify-ce/pkg/logger"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/email"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// expectedSignerRepository defines minimal interface for expected signer operations
|
||||
type expectedSignerRepository interface {
|
||||
ListWithStatusByDocID(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error)
|
||||
}
|
||||
|
||||
// reminderRepository defines minimal interface for reminder logging and history
|
||||
type reminderRepository interface {
|
||||
LogReminder(ctx context.Context, log *models.ReminderLog) error
|
||||
GetReminderHistory(ctx context.Context, docID string) ([]*models.ReminderLog, error)
|
||||
GetReminderStats(ctx context.Context, docID string) (*models.ReminderStats, error)
|
||||
}
|
||||
|
||||
// ReminderService manages email notifications to pending signers with delivery tracking
|
||||
type ReminderService struct {
|
||||
expectedSignerRepo *database.ExpectedSignerRepository
|
||||
reminderRepo *database.ReminderRepository
|
||||
expectedSignerRepo expectedSignerRepository
|
||||
reminderRepo reminderRepository
|
||||
emailSender email.Sender
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// NewReminderService initializes reminder service with email sender and repository dependencies
|
||||
func NewReminderService(
|
||||
expectedSignerRepo *database.ExpectedSignerRepository,
|
||||
reminderRepo *database.ReminderRepository,
|
||||
expectedSignerRepo expectedSignerRepository,
|
||||
reminderRepo reminderRepository,
|
||||
emailSender email.Sender,
|
||||
baseURL string,
|
||||
) *ReminderService {
|
||||
@@ -33,7 +46,7 @@ func NewReminderService(
|
||||
}
|
||||
}
|
||||
|
||||
// SendReminders sends reminder emails to pending signers
|
||||
// SendReminders dispatches email notifications to all or selected pending signers with result aggregation
|
||||
func (s *ReminderService) SendReminders(
|
||||
ctx context.Context,
|
||||
docID string,
|
||||
@@ -43,11 +56,24 @@ func (s *ReminderService) SendReminders(
|
||||
locale string,
|
||||
) (*models.ReminderSendResult, error) {
|
||||
|
||||
logger.Logger.Info("Starting reminder sending process",
|
||||
"doc_id", docID,
|
||||
"sent_by", sentBy,
|
||||
"specific_emails_count", len(specificEmails),
|
||||
"locale", locale)
|
||||
|
||||
allSigners, err := s.expectedSignerRepo.ListWithStatusByDocID(ctx, docID)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to get expected signers for reminders",
|
||||
"doc_id", docID,
|
||||
"error", err.Error())
|
||||
return nil, fmt.Errorf("failed to get expected signers: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Debug("Retrieved expected signers",
|
||||
"doc_id", docID,
|
||||
"total_signers", len(allSigners))
|
||||
|
||||
var pendingSigners []*models.ExpectedSignerWithStatus
|
||||
for _, signer := range allSigners {
|
||||
if !signer.HasSigned {
|
||||
@@ -61,7 +87,14 @@ func (s *ReminderService) SendReminders(
|
||||
}
|
||||
}
|
||||
|
||||
logger.Logger.Info("Identified pending signers",
|
||||
"doc_id", docID,
|
||||
"pending_count", len(pendingSigners),
|
||||
"total_signers", len(allSigners))
|
||||
|
||||
if len(pendingSigners) == 0 {
|
||||
logger.Logger.Info("No pending signers found, no reminders to send",
|
||||
"doc_id", docID)
|
||||
return &models.ReminderSendResult{
|
||||
TotalAttempted: 0,
|
||||
SuccessfullySent: 0,
|
||||
@@ -83,6 +116,12 @@ func (s *ReminderService) SendReminders(
|
||||
}
|
||||
}
|
||||
|
||||
logger.Logger.Info("Reminder batch completed",
|
||||
"doc_id", docID,
|
||||
"total_attempted", result.TotalAttempted,
|
||||
"successfully_sent", result.SuccessfullySent,
|
||||
"failed", result.Failed)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -97,6 +136,12 @@ func (s *ReminderService) sendSingleReminder(
|
||||
locale string,
|
||||
) error {
|
||||
|
||||
logger.Logger.Debug("Sending reminder to signer",
|
||||
"doc_id", docID,
|
||||
"recipient_email", recipientEmail,
|
||||
"recipient_name", recipientName,
|
||||
"sent_by", sentBy)
|
||||
|
||||
signURL := fmt.Sprintf("%s/sign?doc=%s", s.baseURL, docID)
|
||||
|
||||
log := &models.ReminderLog{
|
||||
@@ -114,27 +159,43 @@ func (s *ReminderService) sendSingleReminder(
|
||||
errMsg := err.Error()
|
||||
log.ErrorMessage = &errMsg
|
||||
|
||||
logger.Logger.Warn("Failed to send reminder email",
|
||||
"doc_id", docID,
|
||||
"recipient_email", recipientEmail,
|
||||
"error", err.Error())
|
||||
|
||||
if logErr := s.reminderRepo.LogReminder(ctx, log); logErr != nil {
|
||||
logger.Logger.Error("failed to log reminder error", "error", logErr, "original_error", err)
|
||||
logger.Logger.Error("Failed to log reminder error",
|
||||
"doc_id", docID,
|
||||
"recipient_email", recipientEmail,
|
||||
"log_error", logErr.Error(),
|
||||
"original_error", err.Error())
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to send email: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Info("Reminder email sent successfully",
|
||||
"doc_id", docID,
|
||||
"recipient_email", recipientEmail)
|
||||
|
||||
if err := s.reminderRepo.LogReminder(ctx, log); err != nil {
|
||||
logger.Logger.Error("failed to log successful reminder", "error", err)
|
||||
logger.Logger.Error("Failed to log successful reminder",
|
||||
"doc_id", docID,
|
||||
"recipient_email", recipientEmail,
|
||||
"error", err.Error())
|
||||
return fmt.Errorf("email sent but failed to log: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetReminderStats returns reminder statistics for a document
|
||||
// GetReminderStats retrieves aggregated reminder metrics for monitoring dashboard
|
||||
func (s *ReminderService) GetReminderStats(ctx context.Context, docID string) (*models.ReminderStats, error) {
|
||||
return s.reminderRepo.GetReminderStats(ctx, docID)
|
||||
}
|
||||
|
||||
// GetReminderHistory returns reminder history for a document
|
||||
// GetReminderHistory retrieves complete email send log with success/failure tracking
|
||||
func (s *ReminderService) GetReminderHistory(ctx context.Context, docID string) ([]*models.ReminderLog, error) {
|
||||
return s.reminderRepo.GetReminderHistory(ctx, docID)
|
||||
}
|
||||
257
backend/internal/application/services/reminder_async.go
Normal file
257
backend/internal/application/services/reminder_async.go
Normal file
@@ -0,0 +1,257 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// emailQueueRepository defines minimal interface for email queue operations
|
||||
type emailQueueRepository interface {
|
||||
Enqueue(ctx context.Context, input models.EmailQueueInput) (*models.EmailQueueItem, error)
|
||||
GetQueueStats(ctx context.Context) (*models.EmailQueueStats, error)
|
||||
}
|
||||
|
||||
// ReminderAsyncService manages email notifications using asynchronous queue
|
||||
type ReminderAsyncService struct {
|
||||
expectedSignerRepo expectedSignerRepository
|
||||
reminderRepo reminderRepository
|
||||
queueRepo emailQueueRepository
|
||||
baseURL string
|
||||
useAsyncQueue bool // Feature flag to enable/disable async queue
|
||||
}
|
||||
|
||||
// NewReminderAsyncService initializes async reminder service with queue support
|
||||
func NewReminderAsyncService(
|
||||
expectedSignerRepo expectedSignerRepository,
|
||||
reminderRepo reminderRepository,
|
||||
queueRepo emailQueueRepository,
|
||||
baseURL string,
|
||||
) *ReminderAsyncService {
|
||||
return &ReminderAsyncService{
|
||||
expectedSignerRepo: expectedSignerRepo,
|
||||
reminderRepo: reminderRepo,
|
||||
queueRepo: queueRepo,
|
||||
baseURL: baseURL,
|
||||
useAsyncQueue: true, // Enable async by default
|
||||
}
|
||||
}
|
||||
|
||||
// SendRemindersAsync dispatches email notifications to queue for async processing
|
||||
func (s *ReminderAsyncService) SendRemindersAsync(
|
||||
ctx context.Context,
|
||||
docID string,
|
||||
sentBy string,
|
||||
specificEmails []string,
|
||||
docURL string,
|
||||
locale string,
|
||||
) (*models.ReminderSendResult, error) {
|
||||
|
||||
logger.Logger.Info("Starting async reminder queueing process",
|
||||
"doc_id", docID,
|
||||
"sent_by", sentBy,
|
||||
"specific_emails_count", len(specificEmails),
|
||||
"locale", locale)
|
||||
|
||||
allSigners, err := s.expectedSignerRepo.ListWithStatusByDocID(ctx, docID)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to get expected signers for reminders",
|
||||
"doc_id", docID,
|
||||
"error", err.Error())
|
||||
return nil, fmt.Errorf("failed to get expected signers: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Debug("Retrieved expected signers",
|
||||
"doc_id", docID,
|
||||
"total_signers", len(allSigners))
|
||||
|
||||
// Filter pending signers
|
||||
var pendingSigners []*models.ExpectedSignerWithStatus
|
||||
for _, signer := range allSigners {
|
||||
if !signer.HasSigned {
|
||||
if len(specificEmails) > 0 {
|
||||
if containsEmail(specificEmails, signer.Email) {
|
||||
pendingSigners = append(pendingSigners, signer)
|
||||
}
|
||||
} else {
|
||||
pendingSigners = append(pendingSigners, signer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.Logger.Info("Identified pending signers",
|
||||
"doc_id", docID,
|
||||
"pending_count", len(pendingSigners),
|
||||
"total_signers", len(allSigners))
|
||||
|
||||
if len(pendingSigners) == 0 {
|
||||
logger.Logger.Info("No pending signers found, no reminders to queue",
|
||||
"doc_id", docID)
|
||||
return &models.ReminderSendResult{
|
||||
TotalAttempted: 0,
|
||||
SuccessfullySent: 0,
|
||||
Failed: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
result := &models.ReminderSendResult{
|
||||
TotalAttempted: len(pendingSigners),
|
||||
}
|
||||
|
||||
// Queue emails asynchronously
|
||||
for _, signer := range pendingSigners {
|
||||
err := s.queueSingleReminder(ctx, docID, signer.Email, signer.Name, sentBy, docURL, locale)
|
||||
if err != nil {
|
||||
result.Failed++
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("%s: %v", signer.Email, err))
|
||||
} else {
|
||||
result.SuccessfullySent++
|
||||
}
|
||||
}
|
||||
|
||||
logger.Logger.Info("Reminder queueing completed",
|
||||
"doc_id", docID,
|
||||
"total_attempted", result.TotalAttempted,
|
||||
"successfully_queued", result.SuccessfullySent,
|
||||
"failed", result.Failed)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// queueSingleReminder queues a reminder for a single signer
|
||||
func (s *ReminderAsyncService) queueSingleReminder(
|
||||
ctx context.Context,
|
||||
docID string,
|
||||
recipientEmail string,
|
||||
recipientName string,
|
||||
sentBy string,
|
||||
docURL string,
|
||||
locale string,
|
||||
) error {
|
||||
|
||||
logger.Logger.Debug("Queueing reminder for signer",
|
||||
"doc_id", docID,
|
||||
"recipient_email", recipientEmail,
|
||||
"recipient_name", recipientName,
|
||||
"sent_by", sentBy)
|
||||
|
||||
signURL := fmt.Sprintf("%s/sign?doc=%s", s.baseURL, docID)
|
||||
|
||||
// Prepare email data (keys must match template variables)
|
||||
data := map[string]interface{}{
|
||||
"DocID": docID,
|
||||
"DocURL": docURL,
|
||||
"SignURL": signURL,
|
||||
"RecipientName": recipientName,
|
||||
"Locale": locale,
|
||||
}
|
||||
|
||||
// Create email queue input
|
||||
refType := "signature_reminder"
|
||||
input := models.EmailQueueInput{
|
||||
ToAddresses: []string{recipientEmail},
|
||||
Subject: "Reminder: Document signature required",
|
||||
Template: "signature_reminder",
|
||||
Locale: locale,
|
||||
Data: data,
|
||||
Priority: models.EmailPriorityHigh,
|
||||
ReferenceType: &refType,
|
||||
ReferenceID: &docID,
|
||||
CreatedBy: &sentBy,
|
||||
MaxRetries: 5, // More retries for important reminders
|
||||
}
|
||||
|
||||
// Queue the email
|
||||
item, err := s.queueRepo.Enqueue(ctx, input)
|
||||
if err != nil {
|
||||
logger.Logger.Warn("Failed to queue reminder email",
|
||||
"doc_id", docID,
|
||||
"recipient_email", recipientEmail,
|
||||
"error", err.Error())
|
||||
|
||||
// Log the failure
|
||||
log := &models.ReminderLog{
|
||||
DocID: docID,
|
||||
RecipientEmail: recipientEmail,
|
||||
SentAt: time.Now(),
|
||||
SentBy: sentBy,
|
||||
TemplateUsed: "signature_reminder",
|
||||
Status: "failed",
|
||||
}
|
||||
errMsg := fmt.Sprintf("Failed to queue: %v", err)
|
||||
log.ErrorMessage = &errMsg
|
||||
|
||||
if logErr := s.reminderRepo.LogReminder(ctx, log); logErr != nil {
|
||||
logger.Logger.Error("Failed to log reminder queue error",
|
||||
"doc_id", docID,
|
||||
"recipient_email", recipientEmail,
|
||||
"log_error", logErr.Error(),
|
||||
"original_error", err.Error())
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to queue email: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Info("Reminder email queued successfully",
|
||||
"doc_id", docID,
|
||||
"recipient_email", recipientEmail,
|
||||
"queue_id", item.ID)
|
||||
|
||||
// Log successful queueing
|
||||
log := &models.ReminderLog{
|
||||
DocID: docID,
|
||||
RecipientEmail: recipientEmail,
|
||||
SentAt: time.Now(),
|
||||
SentBy: sentBy,
|
||||
TemplateUsed: "signature_reminder",
|
||||
Status: "queued", // New status for queued emails
|
||||
}
|
||||
|
||||
if err := s.reminderRepo.LogReminder(ctx, log); err != nil {
|
||||
logger.Logger.Error("Failed to log successful reminder queueing",
|
||||
"doc_id", docID,
|
||||
"recipient_email", recipientEmail,
|
||||
"error", err.Error())
|
||||
// Non-critical error, email is already queued
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetQueueStats returns current email queue statistics
|
||||
func (s *ReminderAsyncService) GetQueueStats(ctx context.Context) (*models.EmailQueueStats, error) {
|
||||
return s.queueRepo.GetQueueStats(ctx)
|
||||
}
|
||||
|
||||
// GetReminderStats retrieves aggregated reminder metrics for monitoring dashboard
|
||||
func (s *ReminderAsyncService) GetReminderStats(ctx context.Context, docID string) (*models.ReminderStats, error) {
|
||||
return s.reminderRepo.GetReminderStats(ctx, docID)
|
||||
}
|
||||
|
||||
// GetReminderHistory retrieves complete email send log with success/failure tracking
|
||||
func (s *ReminderAsyncService) GetReminderHistory(ctx context.Context, docID string) ([]*models.ReminderLog, error) {
|
||||
return s.reminderRepo.GetReminderHistory(ctx, docID)
|
||||
}
|
||||
|
||||
// EnableAsync enables or disables async queue processing
|
||||
func (s *ReminderAsyncService) EnableAsync(enabled bool) {
|
||||
s.useAsyncQueue = enabled
|
||||
logger.Logger.Info("Async queue processing toggled", "enabled", enabled)
|
||||
}
|
||||
|
||||
// SendReminders is a compatibility method that calls SendRemindersAsync
|
||||
// This allows the service to work with existing interfaces expecting SendReminders
|
||||
func (s *ReminderAsyncService) SendReminders(
|
||||
ctx context.Context,
|
||||
docID string,
|
||||
sentBy string,
|
||||
specificEmails []string,
|
||||
docURL string,
|
||||
locale string,
|
||||
) (*models.ReminderSendResult, error) {
|
||||
return s.SendRemindersAsync(ctx, docID, sentBy, specificEmails, docURL, locale)
|
||||
}
|
||||
518
backend/internal/application/services/reminder_test.go
Normal file
518
backend/internal/application/services/reminder_test.go
Normal file
@@ -0,0 +1,518 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/email"
|
||||
)
|
||||
|
||||
// Mock implementations for testing
|
||||
type mockExpectedSignerRepository struct {
|
||||
listWithStatusFunc func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error)
|
||||
}
|
||||
|
||||
func (m *mockExpectedSignerRepository) ListWithStatusByDocID(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
|
||||
if m.listWithStatusFunc != nil {
|
||||
return m.listWithStatusFunc(ctx, docID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type mockReminderRepository struct {
|
||||
logReminderFunc func(ctx context.Context, log *models.ReminderLog) error
|
||||
getReminderHistoryFunc func(ctx context.Context, docID string) ([]*models.ReminderLog, error)
|
||||
getReminderStatsFunc func(ctx context.Context, docID string) (*models.ReminderStats, error)
|
||||
}
|
||||
|
||||
func (m *mockReminderRepository) LogReminder(ctx context.Context, log *models.ReminderLog) error {
|
||||
if m.logReminderFunc != nil {
|
||||
return m.logReminderFunc(ctx, log)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReminderRepository) GetReminderHistory(ctx context.Context, docID string) ([]*models.ReminderLog, error) {
|
||||
if m.getReminderHistoryFunc != nil {
|
||||
return m.getReminderHistoryFunc(ctx, docID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockReminderRepository) GetReminderStats(ctx context.Context, docID string) (*models.ReminderStats, error) {
|
||||
if m.getReminderStatsFunc != nil {
|
||||
return m.getReminderStatsFunc(ctx, docID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type mockEmailSender struct {
|
||||
sendFunc func(ctx context.Context, msg email.Message) error
|
||||
}
|
||||
|
||||
func (m *mockEmailSender) Send(ctx context.Context, msg email.Message) error {
|
||||
if m.sendFunc != nil {
|
||||
return m.sendFunc(ctx, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Test helper function
|
||||
func TestContainsEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
slice []string
|
||||
item string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Email found",
|
||||
slice: []string{"alice@example.com", "bob@example.com", "charlie@example.com"},
|
||||
item: "bob@example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Email not found",
|
||||
slice: []string{"alice@example.com", "bob@example.com"},
|
||||
item: "charlie@example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty slice",
|
||||
slice: []string{},
|
||||
item: "test@example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Case sensitive",
|
||||
slice: []string{"Test@Example.com"},
|
||||
item: "test@example.com",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := containsEmail(tt.slice, tt.item)
|
||||
if result != tt.expected {
|
||||
t.Errorf("containsEmail(%v, %q) = %v, want %v", tt.slice, tt.item, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test SendReminders with no pending signers
|
||||
func TestReminderService_SendReminders_NoPendingSigners(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mockExpectedRepo := &mockExpectedSignerRepository{
|
||||
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
|
||||
return []*models.ExpectedSignerWithStatus{
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "signed@example.com"}, HasSigned: true},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mockReminderRepo := &mockReminderRepository{}
|
||||
mockEmailSender := &mockEmailSender{}
|
||||
|
||||
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, "https://example.com")
|
||||
|
||||
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if result.TotalAttempted != 0 {
|
||||
t.Errorf("Expected 0 total attempted, got %d", result.TotalAttempted)
|
||||
}
|
||||
|
||||
if result.SuccessfullySent != 0 {
|
||||
t.Errorf("Expected 0 successfully sent, got %d", result.SuccessfullySent)
|
||||
}
|
||||
}
|
||||
|
||||
// Test SendReminders with successful email send
|
||||
func TestReminderService_SendReminders_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mockExpectedRepo := &mockExpectedSignerRepository{
|
||||
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
|
||||
return []*models.ExpectedSignerWithStatus{
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "pending@example.com", Name: "Pending User"}, HasSigned: false},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
loggedReminder := false
|
||||
mockReminderRepo := &mockReminderRepository{
|
||||
logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error {
|
||||
loggedReminder = true
|
||||
if log.Status != "sent" {
|
||||
t.Errorf("Expected status 'sent', got '%s'", log.Status)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
emailSent := false
|
||||
mockEmailSender := &mockEmailSender{
|
||||
sendFunc: func(ctx context.Context, msg email.Message) error {
|
||||
emailSent = true
|
||||
if len(msg.To) != 1 || msg.To[0] != "pending@example.com" {
|
||||
t.Errorf("Expected email to 'pending@example.com', got %v", msg.To)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, "https://example.com")
|
||||
|
||||
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if result.TotalAttempted != 1 {
|
||||
t.Errorf("Expected 1 total attempted, got %d", result.TotalAttempted)
|
||||
}
|
||||
|
||||
if result.SuccessfullySent != 1 {
|
||||
t.Errorf("Expected 1 successfully sent, got %d", result.SuccessfullySent)
|
||||
}
|
||||
|
||||
if result.Failed != 0 {
|
||||
t.Errorf("Expected 0 failed, got %d", result.Failed)
|
||||
}
|
||||
|
||||
if !emailSent {
|
||||
t.Error("Expected email to be sent")
|
||||
}
|
||||
|
||||
if !loggedReminder {
|
||||
t.Error("Expected reminder to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
// Test SendReminders with email failure
|
||||
func TestReminderService_SendReminders_EmailFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mockExpectedRepo := &mockExpectedSignerRepository{
|
||||
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
|
||||
return []*models.ExpectedSignerWithStatus{
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "pending@example.com", Name: "Pending User"}, HasSigned: false},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
loggedReminder := false
|
||||
mockReminderRepo := &mockReminderRepository{
|
||||
logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error {
|
||||
loggedReminder = true
|
||||
if log.Status != "failed" {
|
||||
t.Errorf("Expected status 'failed', got '%s'", log.Status)
|
||||
}
|
||||
if log.ErrorMessage == nil {
|
||||
t.Error("Expected error message to be set")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mockEmailSender := &mockEmailSender{
|
||||
sendFunc: func(ctx context.Context, msg email.Message) error {
|
||||
return errors.New("SMTP connection failed")
|
||||
},
|
||||
}
|
||||
|
||||
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, "https://example.com")
|
||||
|
||||
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error from SendReminders, got: %v", err)
|
||||
}
|
||||
|
||||
if result.TotalAttempted != 1 {
|
||||
t.Errorf("Expected 1 total attempted, got %d", result.TotalAttempted)
|
||||
}
|
||||
|
||||
if result.Failed != 1 {
|
||||
t.Errorf("Expected 1 failed, got %d", result.Failed)
|
||||
}
|
||||
|
||||
if result.SuccessfullySent != 0 {
|
||||
t.Errorf("Expected 0 successfully sent, got %d", result.SuccessfullySent)
|
||||
}
|
||||
|
||||
if len(result.Errors) != 1 {
|
||||
t.Errorf("Expected 1 error message, got %d", len(result.Errors))
|
||||
}
|
||||
|
||||
if !loggedReminder {
|
||||
t.Error("Expected failed reminder to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
// Test SendReminders with specific emails filter
|
||||
func TestReminderService_SendReminders_SpecificEmails(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mockExpectedRepo := &mockExpectedSignerRepository{
|
||||
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
|
||||
return []*models.ExpectedSignerWithStatus{
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "pending1@example.com"}, HasSigned: false},
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "pending2@example.com"}, HasSigned: false},
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "pending3@example.com"}, HasSigned: false},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
emailsSent := []string{}
|
||||
mockReminderRepo := &mockReminderRepository{
|
||||
logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mockEmailSender := &mockEmailSender{
|
||||
sendFunc: func(ctx context.Context, msg email.Message) error {
|
||||
emailsSent = append(emailsSent, msg.To[0])
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, "https://example.com")
|
||||
|
||||
specificEmails := []string{"pending2@example.com"}
|
||||
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", specificEmails, "https://example.com/doc.pdf", "en")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if result.TotalAttempted != 1 {
|
||||
t.Errorf("Expected 1 total attempted, got %d", result.TotalAttempted)
|
||||
}
|
||||
|
||||
if len(emailsSent) != 1 || emailsSent[0] != "pending2@example.com" {
|
||||
t.Errorf("Expected only 'pending2@example.com' to receive email, got %v", emailsSent)
|
||||
}
|
||||
}
|
||||
|
||||
// Test SendReminders with repository error
|
||||
func TestReminderService_SendReminders_RepositoryError(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mockExpectedRepo := &mockExpectedSignerRepository{
|
||||
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
|
||||
return nil, errors.New("database connection failed")
|
||||
},
|
||||
}
|
||||
|
||||
mockReminderRepo := &mockReminderRepository{}
|
||||
mockEmailSender := &mockEmailSender{}
|
||||
|
||||
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, "https://example.com")
|
||||
|
||||
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result on error, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
// Test GetReminderHistory
|
||||
func TestReminderService_GetReminderHistory(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
expectedLogs := []*models.ReminderLog{
|
||||
{
|
||||
DocID: "doc1",
|
||||
RecipientEmail: "user@example.com",
|
||||
SentAt: time.Now(),
|
||||
SentBy: "admin@example.com",
|
||||
Status: "sent",
|
||||
},
|
||||
}
|
||||
|
||||
mockReminderRepo := &mockReminderRepository{
|
||||
getReminderHistoryFunc: func(ctx context.Context, docID string) ([]*models.ReminderLog, error) {
|
||||
if docID != "doc1" {
|
||||
t.Errorf("Expected docID 'doc1', got '%s'", docID)
|
||||
}
|
||||
return expectedLogs, nil
|
||||
},
|
||||
}
|
||||
|
||||
service := NewReminderService(&mockExpectedSignerRepository{}, mockReminderRepo, &mockEmailSender{}, "https://example.com")
|
||||
|
||||
logs, err := service.GetReminderHistory(ctx, "doc1")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if len(logs) != 1 {
|
||||
t.Errorf("Expected 1 log, got %d", len(logs))
|
||||
}
|
||||
|
||||
if logs[0].RecipientEmail != "user@example.com" {
|
||||
t.Errorf("Expected recipient 'user@example.com', got '%s'", logs[0].RecipientEmail)
|
||||
}
|
||||
}
|
||||
|
||||
// Test GetReminderStats
|
||||
func TestReminderService_GetReminderStats(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
expectedStats := &models.ReminderStats{
|
||||
TotalSent: 5,
|
||||
LastSentAt: &now,
|
||||
PendingCount: 2,
|
||||
}
|
||||
|
||||
mockReminderRepo := &mockReminderRepository{
|
||||
getReminderStatsFunc: func(ctx context.Context, docID string) (*models.ReminderStats, error) {
|
||||
if docID != "doc1" {
|
||||
t.Errorf("Expected docID 'doc1', got '%s'", docID)
|
||||
}
|
||||
return expectedStats, nil
|
||||
},
|
||||
}
|
||||
|
||||
service := NewReminderService(&mockExpectedSignerRepository{}, mockReminderRepo, &mockEmailSender{}, "https://example.com")
|
||||
|
||||
stats, err := service.GetReminderStats(ctx, "doc1")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if stats.TotalSent != 5 {
|
||||
t.Errorf("Expected 5 total sent, got %d", stats.TotalSent)
|
||||
}
|
||||
|
||||
if stats.PendingCount != 2 {
|
||||
t.Errorf("Expected 2 pending, got %d", stats.PendingCount)
|
||||
}
|
||||
|
||||
if stats.LastSentAt == nil {
|
||||
t.Error("Expected LastSentAt to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// Test SendReminders with multiple pending signers
|
||||
func TestReminderService_SendReminders_MultiplePending(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mockExpectedRepo := &mockExpectedSignerRepository{
|
||||
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
|
||||
return []*models.ExpectedSignerWithStatus{
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "pending1@example.com", Name: "User 1"}, HasSigned: false},
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "pending2@example.com", Name: "User 2"}, HasSigned: false},
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "already-signed@example.com", Name: "User 3"}, HasSigned: true},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
emailsSent := 0
|
||||
mockReminderRepo := &mockReminderRepository{
|
||||
logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mockEmailSender := &mockEmailSender{
|
||||
sendFunc: func(ctx context.Context, msg email.Message) error {
|
||||
emailsSent++
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, "https://example.com")
|
||||
|
||||
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if result.TotalAttempted != 2 {
|
||||
t.Errorf("Expected 2 total attempted, got %d", result.TotalAttempted)
|
||||
}
|
||||
|
||||
if result.SuccessfullySent != 2 {
|
||||
t.Errorf("Expected 2 successfully sent, got %d", result.SuccessfullySent)
|
||||
}
|
||||
|
||||
if emailsSent != 2 {
|
||||
t.Errorf("Expected 2 emails sent, got %d", emailsSent)
|
||||
}
|
||||
}
|
||||
|
||||
// Test SendReminders with log failure after successful email
|
||||
func TestReminderService_SendReminders_LogFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mockExpectedRepo := &mockExpectedSignerRepository{
|
||||
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
|
||||
return []*models.ExpectedSignerWithStatus{
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "pending@example.com", Name: "Pending User"}, HasSigned: false},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mockReminderRepo := &mockReminderRepository{
|
||||
logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error {
|
||||
return errors.New("database write failed")
|
||||
},
|
||||
}
|
||||
|
||||
mockEmailSender := &mockEmailSender{
|
||||
sendFunc: func(ctx context.Context, msg email.Message) error {
|
||||
return nil // Email succeeds
|
||||
},
|
||||
}
|
||||
|
||||
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, "https://example.com")
|
||||
|
||||
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error from SendReminders, got: %v", err)
|
||||
}
|
||||
|
||||
// The send should fail because logging failed
|
||||
if result.Failed != 1 {
|
||||
t.Errorf("Expected 1 failed, got %d", result.Failed)
|
||||
}
|
||||
|
||||
if result.SuccessfullySent != 0 {
|
||||
t.Errorf("Expected 0 successfully sent, got %d", result.SuccessfullySent)
|
||||
}
|
||||
}
|
||||
444
backend/internal/application/services/signature.go
Normal file
444
backend/internal/application/services/signature.go
Normal file
@@ -0,0 +1,444 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/config"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/checksum"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/crypto"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
type repository interface {
|
||||
Create(ctx context.Context, signature *models.Signature) error
|
||||
GetByDocAndUser(ctx context.Context, docID, userSub string) (*models.Signature, error)
|
||||
GetByDoc(ctx context.Context, docID string) ([]*models.Signature, error)
|
||||
GetByUser(ctx context.Context, userSub string) ([]*models.Signature, error)
|
||||
ExistsByDocAndUser(ctx context.Context, docID, userSub string) (bool, error)
|
||||
CheckUserSignatureStatus(ctx context.Context, docID, userIdentifier string) (bool, error)
|
||||
GetLastSignature(ctx context.Context, docID string) (*models.Signature, error)
|
||||
GetAllSignaturesOrdered(ctx context.Context) ([]*models.Signature, error)
|
||||
UpdatePrevHash(ctx context.Context, id int64, prevHash *string) error
|
||||
}
|
||||
|
||||
type cryptoSigner interface {
|
||||
CreateSignature(docID string, user *models.User, timestamp time.Time, nonce string, docChecksum string) (string, string, error)
|
||||
}
|
||||
|
||||
// SignatureService orchestrates signature creation with Ed25519 cryptography and hash chain linking
|
||||
type SignatureService struct {
|
||||
repo repository
|
||||
docRepo documentRepository
|
||||
signer cryptoSigner
|
||||
checksumConfig *config.ChecksumConfig
|
||||
}
|
||||
|
||||
// NewSignatureService initializes the signature service with repository and cryptographic signer dependencies
|
||||
func NewSignatureService(repo repository, docRepo documentRepository, signer cryptoSigner) *SignatureService {
|
||||
return &SignatureService{
|
||||
repo: repo,
|
||||
docRepo: docRepo,
|
||||
signer: signer,
|
||||
}
|
||||
}
|
||||
|
||||
// SetChecksumConfig sets the checksum configuration for document verification
|
||||
func (s *SignatureService) SetChecksumConfig(cfg *config.ChecksumConfig) {
|
||||
s.checksumConfig = cfg
|
||||
}
|
||||
|
||||
// CreateSignature validates user authorization, generates cryptographic proof, and chains to previous signature
|
||||
func (s *SignatureService) CreateSignature(ctx context.Context, request *models.SignatureRequest) error {
|
||||
logger.Logger.Info("Signature creation attempt",
|
||||
"doc_id", request.DocID,
|
||||
"user_email", func() string {
|
||||
if request.User != nil {
|
||||
return request.User.NormalizedEmail()
|
||||
}
|
||||
return ""
|
||||
}())
|
||||
|
||||
if request.User == nil || !request.User.IsValid() {
|
||||
logger.Logger.Warn("Signature creation failed: invalid user",
|
||||
"doc_id", request.DocID,
|
||||
"user_nil", request.User == nil)
|
||||
return models.ErrInvalidUser
|
||||
}
|
||||
|
||||
if request.DocID == "" {
|
||||
logger.Logger.Warn("Signature creation failed: invalid document",
|
||||
"user_email", request.User.NormalizedEmail())
|
||||
return models.ErrInvalidDocument
|
||||
}
|
||||
|
||||
exists, err := s.repo.ExistsByDocAndUser(ctx, request.DocID, request.User.Sub)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Signature creation failed: database check error",
|
||||
"doc_id", request.DocID,
|
||||
"user_email", request.User.NormalizedEmail(),
|
||||
"error", err.Error())
|
||||
return fmt.Errorf("failed to check existing signature: %w", err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
logger.Logger.Warn("Signature creation failed: already exists",
|
||||
"doc_id", request.DocID,
|
||||
"user_email", request.User.NormalizedEmail())
|
||||
return models.ErrSignatureAlreadyExists
|
||||
}
|
||||
|
||||
nonce, err := crypto.GenerateNonce()
|
||||
if err != nil {
|
||||
logger.Logger.Error("Signature creation failed: nonce generation error",
|
||||
"doc_id", request.DocID,
|
||||
"user_email", request.User.NormalizedEmail(),
|
||||
"error", err.Error())
|
||||
return fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Fetch document metadata to get checksum (if available)
|
||||
var docChecksum string
|
||||
doc, err := s.docRepo.GetByDocID(ctx, request.DocID)
|
||||
if err != nil {
|
||||
logger.Logger.Debug("Document metadata not found, signing without checksum",
|
||||
"doc_id", request.DocID,
|
||||
"error", err.Error())
|
||||
// Continue without checksum - document metadata is optional
|
||||
} else if doc != nil && doc.Checksum != "" {
|
||||
// Verify document hasn't been modified before signing
|
||||
if err := s.verifyDocumentIntegrity(doc); err != nil {
|
||||
logger.Logger.Warn("Document integrity check failed",
|
||||
"doc_id", request.DocID,
|
||||
"error", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
docChecksum = doc.Checksum
|
||||
checksumPreview := docChecksum
|
||||
if len(docChecksum) > 16 {
|
||||
checksumPreview = docChecksum[:16] + "..."
|
||||
}
|
||||
logger.Logger.Debug("Including document checksum in signature",
|
||||
"doc_id", request.DocID,
|
||||
"checksum", checksumPreview)
|
||||
}
|
||||
|
||||
timestamp := time.Now().UTC()
|
||||
payloadHash, signatureB64, err := s.signer.CreateSignature(request.DocID, request.User, timestamp, nonce, docChecksum)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Signature creation failed: cryptographic signature error",
|
||||
"doc_id", request.DocID,
|
||||
"user_email", request.User.NormalizedEmail(),
|
||||
"error", err.Error())
|
||||
return fmt.Errorf("failed to create cryptographic signature: %w", err)
|
||||
}
|
||||
|
||||
lastSignature, err := s.repo.GetLastSignature(ctx, request.DocID)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Signature creation failed: chain lookup error",
|
||||
"doc_id", request.DocID,
|
||||
"user_email", request.User.NormalizedEmail(),
|
||||
"error", err.Error())
|
||||
return fmt.Errorf("failed to get last signature for chaining: %w", err)
|
||||
}
|
||||
|
||||
var prevHashB64 *string
|
||||
if lastSignature != nil {
|
||||
hash := lastSignature.ComputeRecordHash()
|
||||
prevHashB64 = &hash
|
||||
logger.Logger.Debug("Chaining to previous signature",
|
||||
"doc_id", request.DocID,
|
||||
"prev_signature_id", lastSignature.ID,
|
||||
"prev_hash", hash[:16]+"...")
|
||||
} else {
|
||||
logger.Logger.Debug("Creating genesis signature (no previous signature)",
|
||||
"doc_id", request.DocID)
|
||||
}
|
||||
|
||||
signature := &models.Signature{
|
||||
DocID: request.DocID,
|
||||
UserSub: request.User.Sub,
|
||||
UserEmail: request.User.NormalizedEmail(),
|
||||
UserName: request.User.Name,
|
||||
SignedAtUTC: timestamp,
|
||||
DocChecksum: docChecksum,
|
||||
PayloadHash: payloadHash,
|
||||
Signature: signatureB64,
|
||||
Nonce: nonce,
|
||||
Referer: request.Referer,
|
||||
PrevHash: prevHashB64,
|
||||
}
|
||||
|
||||
if err := s.repo.Create(ctx, signature); err != nil {
|
||||
logger.Logger.Error("Signature creation failed: database save error",
|
||||
"doc_id", request.DocID,
|
||||
"user_email", request.User.NormalizedEmail(),
|
||||
"error", err.Error())
|
||||
return fmt.Errorf("failed to save signature: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Info("Signature created successfully",
|
||||
"signature_id", signature.ID,
|
||||
"doc_id", request.DocID,
|
||||
"user_email", request.User.NormalizedEmail(),
|
||||
"has_prev_hash", prevHashB64 != nil)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSignatureStatus checks if a user has already signed a document and returns signature timestamp if exists
|
||||
func (s *SignatureService) GetSignatureStatus(ctx context.Context, docID string, user *models.User) (*models.SignatureStatus, error) {
|
||||
if user == nil || !user.IsValid() {
|
||||
return nil, models.ErrInvalidUser
|
||||
}
|
||||
|
||||
signature, err := s.repo.GetByDocAndUser(ctx, docID, user.Sub)
|
||||
if err != nil {
|
||||
if errors.Is(err, models.ErrSignatureNotFound) {
|
||||
return &models.SignatureStatus{
|
||||
DocID: docID,
|
||||
UserEmail: user.Email,
|
||||
IsSigned: false,
|
||||
SignedAt: nil,
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get signature: %w", err)
|
||||
}
|
||||
|
||||
return &models.SignatureStatus{
|
||||
DocID: docID,
|
||||
UserEmail: user.Email,
|
||||
IsSigned: true,
|
||||
SignedAt: &signature.SignedAtUTC,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetDocumentSignatures retrieves all cryptographic signatures associated with a document for public verification
|
||||
func (s *SignatureService) GetDocumentSignatures(ctx context.Context, docID string) ([]*models.Signature, error) {
|
||||
logger.Logger.Debug("Retrieving document signatures",
|
||||
"doc_id", docID)
|
||||
|
||||
signatures, err := s.repo.GetByDoc(ctx, docID)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to retrieve document signatures",
|
||||
"doc_id", docID,
|
||||
"error", err.Error())
|
||||
return nil, fmt.Errorf("failed to get document signatures: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Debug("Document signatures retrieved",
|
||||
"doc_id", docID,
|
||||
"count", len(signatures))
|
||||
|
||||
return signatures, nil
|
||||
}
|
||||
|
||||
// GetUserSignatures retrieves all documents signed by a specific user for personal dashboard display
|
||||
func (s *SignatureService) GetUserSignatures(ctx context.Context, user *models.User) ([]*models.Signature, error) {
|
||||
if user == nil || !user.IsValid() {
|
||||
return nil, models.ErrInvalidUser
|
||||
}
|
||||
|
||||
signatures, err := s.repo.GetByUser(ctx, user.Sub)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user signatures: %w", err)
|
||||
}
|
||||
|
||||
return signatures, nil
|
||||
}
|
||||
|
||||
// GetSignatureByDocAndUser retrieves a specific signature record for verification or display purposes
|
||||
func (s *SignatureService) GetSignatureByDocAndUser(ctx context.Context, docID string, user *models.User) (*models.Signature, error) {
|
||||
if user == nil || !user.IsValid() {
|
||||
return nil, models.ErrInvalidUser
|
||||
}
|
||||
|
||||
signature, err := s.repo.GetByDocAndUser(ctx, docID, user.Sub)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get signature: %w", err)
|
||||
}
|
||||
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
// CheckUserSignature verifies signature existence using flexible identifier matching (email or OAuth subject)
|
||||
func (s *SignatureService) CheckUserSignature(ctx context.Context, docID, userIdentifier string) (bool, error) {
|
||||
exists, err := s.repo.CheckUserSignatureStatus(ctx, docID, userIdentifier)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check user signature: %w", err)
|
||||
}
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
type ChainIntegrityResult struct {
|
||||
IsValid bool
|
||||
TotalRecords int
|
||||
BreakAtID *int64
|
||||
Details string
|
||||
}
|
||||
|
||||
// VerifyChainIntegrity validates the cryptographic hash chain across all signatures for tamper detection
|
||||
func (s *SignatureService) VerifyChainIntegrity(ctx context.Context) (*ChainIntegrityResult, error) {
|
||||
signatures, err := s.repo.GetAllSignaturesOrdered(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get signatures for chain verification: %w", err)
|
||||
}
|
||||
|
||||
result := &ChainIntegrityResult{
|
||||
IsValid: true,
|
||||
TotalRecords: len(signatures),
|
||||
}
|
||||
|
||||
if len(signatures) == 0 {
|
||||
result.Details = "No signatures found"
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if signatures[0].PrevHash != nil {
|
||||
result.IsValid = false
|
||||
result.BreakAtID = &signatures[0].ID
|
||||
result.Details = "Genesis signature has non-null previous hash"
|
||||
return result, nil
|
||||
}
|
||||
|
||||
for i := 1; i < len(signatures); i++ {
|
||||
current := signatures[i]
|
||||
previous := signatures[i-1]
|
||||
|
||||
expectedHash := previous.ComputeRecordHash()
|
||||
|
||||
if current.PrevHash == nil {
|
||||
result.IsValid = false
|
||||
result.BreakAtID = ¤t.ID
|
||||
result.Details = fmt.Sprintf("Signature %d has null previous hash, expected: %s...", current.ID, expectedHash[:16])
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if *current.PrevHash != expectedHash {
|
||||
result.IsValid = false
|
||||
result.BreakAtID = ¤t.ID
|
||||
result.Details = fmt.Sprintf("Hash mismatch at signature %d: expected %s..., got %s...",
|
||||
current.ID, expectedHash[:16], (*current.PrevHash)[:16])
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
result.Details = "Chain integrity verified successfully"
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// RebuildChain recalculates and updates prev_hash pointers for existing signatures during migration
|
||||
func (s *SignatureService) RebuildChain(ctx context.Context) error {
|
||||
signatures, err := s.repo.GetAllSignaturesOrdered(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get signatures for chain rebuild: %w", err)
|
||||
}
|
||||
|
||||
if len(signatures) == 0 {
|
||||
logger.Logger.Info("No signatures found, nothing to rebuild")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Logger.Info("Starting chain rebuild", "totalSignatures", len(signatures))
|
||||
|
||||
if signatures[0].PrevHash != nil {
|
||||
if err := s.repo.UpdatePrevHash(ctx, signatures[0].ID, nil); err != nil {
|
||||
logger.Logger.Warn("Failed to nullify genesis prev_hash", "id", signatures[0].ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 1; i < len(signatures); i++ {
|
||||
current := signatures[i]
|
||||
previous := signatures[i-1]
|
||||
|
||||
expectedHash := previous.ComputeRecordHash()
|
||||
|
||||
if current.PrevHash == nil || *current.PrevHash != expectedHash {
|
||||
logger.Logger.Info("Chain rebuild: updating prev_hash",
|
||||
"id", current.ID,
|
||||
"expectedHash", expectedHash[:16]+"...",
|
||||
"hadPrevHash", current.PrevHash != nil)
|
||||
if err := s.repo.UpdatePrevHash(ctx, current.ID, &expectedHash); err != nil {
|
||||
logger.Logger.Warn("Failed to update prev_hash", "id", current.ID, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.Logger.Info("Chain rebuild completed", "processedSignatures", len(signatures))
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyDocumentIntegrity checks if the document at the URL hasn't been modified since the checksum was stored
|
||||
func (s *SignatureService) verifyDocumentIntegrity(doc *models.Document) error {
|
||||
// Only verify if document has URL and checksum, and checksum config is available
|
||||
if doc.URL == "" || doc.Checksum == "" || s.checksumConfig == nil {
|
||||
logger.Logger.Debug("Skipping document integrity check",
|
||||
"doc_id", doc.DocID,
|
||||
"has_url", doc.URL != "",
|
||||
"has_checksum", doc.Checksum != "",
|
||||
"has_config", s.checksumConfig != nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
storedChecksumPreview := doc.Checksum
|
||||
if len(doc.Checksum) > 16 {
|
||||
storedChecksumPreview = doc.Checksum[:16] + "..."
|
||||
}
|
||||
logger.Logger.Info("Verifying document integrity before signature",
|
||||
"doc_id", doc.DocID,
|
||||
"url", doc.URL,
|
||||
"stored_checksum", storedChecksumPreview)
|
||||
|
||||
// Configure checksum computation options
|
||||
opts := checksum.ComputeOptions{
|
||||
MaxBytes: s.checksumConfig.MaxBytes,
|
||||
TimeoutMs: s.checksumConfig.TimeoutMs,
|
||||
MaxRedirects: s.checksumConfig.MaxRedirects,
|
||||
AllowedContentType: s.checksumConfig.AllowedContentType,
|
||||
SkipSSRFCheck: s.checksumConfig.SkipSSRFCheck,
|
||||
InsecureSkipVerify: s.checksumConfig.InsecureSkipVerify,
|
||||
}
|
||||
|
||||
// Compute current checksum
|
||||
result, err := checksum.ComputeRemoteChecksum(doc.URL, opts)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to compute checksum for integrity check",
|
||||
"doc_id", doc.DocID,
|
||||
"url", doc.URL,
|
||||
"error", err.Error())
|
||||
// If we can't verify, we can't be sure it's modified, so we continue
|
||||
// but log the issue
|
||||
return nil
|
||||
}
|
||||
|
||||
// If checksum computation returned nil (too large, wrong type, network error, etc.)
|
||||
// we can't verify integrity, so we continue but log a warning
|
||||
if result == nil {
|
||||
logger.Logger.Warn("Could not verify document integrity - unable to compute checksum",
|
||||
"doc_id", doc.DocID,
|
||||
"url", doc.URL)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Compare checksums
|
||||
if result.ChecksumHex != doc.Checksum {
|
||||
logger.Logger.Error("Document integrity check FAILED - checksums do not match",
|
||||
"doc_id", doc.DocID,
|
||||
"url", doc.URL,
|
||||
"stored_checksum", doc.Checksum,
|
||||
"current_checksum", result.ChecksumHex)
|
||||
return models.ErrDocumentModified
|
||||
}
|
||||
|
||||
logger.Logger.Info("Document integrity verified successfully",
|
||||
"doc_id", doc.DocID,
|
||||
"checksum", result.ChecksumHex[:16]+"...")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,356 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/config"
|
||||
)
|
||||
|
||||
// mockSignatureRepository for testing
|
||||
type mockSignatureRepository struct {
|
||||
createFunc func(ctx context.Context, signature *models.Signature) error
|
||||
existsByDocAndUserFunc func(ctx context.Context, docID, userSub string) (bool, error)
|
||||
getLastSignatureFunc func(ctx context.Context, docID string) (*models.Signature, error)
|
||||
getByDocAndUserFunc func(ctx context.Context, docID, userSub string) (*models.Signature, error)
|
||||
}
|
||||
|
||||
func (m *mockSignatureRepository) Create(ctx context.Context, signature *models.Signature) error {
|
||||
if m.createFunc != nil {
|
||||
return m.createFunc(ctx, signature)
|
||||
}
|
||||
signature.ID = 1
|
||||
signature.CreatedAt = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSignatureRepository) ExistsByDocAndUser(ctx context.Context, docID, userSub string) (bool, error) {
|
||||
if m.existsByDocAndUserFunc != nil {
|
||||
return m.existsByDocAndUserFunc(ctx, docID, userSub)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *mockSignatureRepository) GetLastSignature(ctx context.Context, docID string) (*models.Signature, error) {
|
||||
if m.getLastSignatureFunc != nil {
|
||||
return m.getLastSignatureFunc(ctx, docID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockSignatureRepository) GetByDocAndUser(ctx context.Context, docID, userSub string) (*models.Signature, error) {
|
||||
if m.getByDocAndUserFunc != nil {
|
||||
return m.getByDocAndUserFunc(ctx, docID, userSub)
|
||||
}
|
||||
return nil, models.ErrSignatureNotFound
|
||||
}
|
||||
|
||||
func (m *mockSignatureRepository) GetByDoc(ctx context.Context, docID string) ([]*models.Signature, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockSignatureRepository) GetByUser(ctx context.Context, userSub string) ([]*models.Signature, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockSignatureRepository) CheckUserSignatureStatus(ctx context.Context, docID, userIdentifier string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *mockSignatureRepository) GetAllSignaturesOrdered(ctx context.Context) ([]*models.Signature, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockSignatureRepository) UpdatePrevHash(ctx context.Context, id int64, prevHash *string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// mockCryptoSigner for testing
|
||||
type mockCryptoSigner struct{}
|
||||
|
||||
func (m *mockCryptoSigner) CreateSignature(docID string, user *models.User, timestamp time.Time, nonce string, docChecksum string) (string, string, error) {
|
||||
return "payload_hash", "signature_base64", nil
|
||||
}
|
||||
|
||||
// Test document integrity verification with matching checksum
|
||||
func TestSignatureService_DocumentIntegrity_Success(t *testing.T) {
|
||||
content := "Sample PDF content"
|
||||
expectedChecksum := "b3b4e8714358cc79990c5c83391172e01c3e79a1b456d7e0c570cbf59da30e23"
|
||||
|
||||
// Create test server with consistent content
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/pdf")
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(content)))
|
||||
if r.Method == "GET" {
|
||||
w.Write([]byte(content))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create mock repositories
|
||||
docRepo := &mockDocumentRepository{
|
||||
getByDocIDFunc: func(ctx context.Context, docID string) (*models.Document, error) {
|
||||
return &models.Document{
|
||||
DocID: "test-doc",
|
||||
URL: server.URL,
|
||||
Checksum: expectedChecksum,
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
sigRepo := &mockSignatureRepository{}
|
||||
signer := &mockCryptoSigner{}
|
||||
|
||||
// Create service with checksum config
|
||||
service := NewSignatureService(sigRepo, docRepo, signer)
|
||||
checksumConfig := &config.ChecksumConfig{
|
||||
MaxBytes: 10 * 1024 * 1024,
|
||||
TimeoutMs: 5000,
|
||||
MaxRedirects: 3,
|
||||
AllowedContentType: []string{
|
||||
"application/pdf",
|
||||
},
|
||||
SkipSSRFCheck: true,
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
service.SetChecksumConfig(checksumConfig)
|
||||
|
||||
// Create signature request
|
||||
user := &models.User{
|
||||
Sub: "test-user",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
request := &models.SignatureRequest{
|
||||
DocID: "test-doc",
|
||||
User: user,
|
||||
}
|
||||
|
||||
// Should succeed because checksum matches
|
||||
err := service.CreateSignature(context.Background(), request)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected signature creation to succeed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test document integrity verification with mismatched checksum
|
||||
func TestSignatureService_DocumentIntegrity_Modified(t *testing.T) {
|
||||
content := "Modified PDF content"
|
||||
storedChecksum := "b3b4e8714358cc79990c5c83391172e01c3e79a1b456d7e0c570cbf59da30e23" // Original checksum
|
||||
|
||||
// Create test server with different content
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/pdf")
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(content)))
|
||||
if r.Method == "GET" {
|
||||
w.Write([]byte(content))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create mock repositories
|
||||
docRepo := &mockDocumentRepository{
|
||||
getByDocIDFunc: func(ctx context.Context, docID string) (*models.Document, error) {
|
||||
return &models.Document{
|
||||
DocID: "test-doc",
|
||||
URL: server.URL,
|
||||
Checksum: storedChecksum,
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
sigRepo := &mockSignatureRepository{}
|
||||
signer := &mockCryptoSigner{}
|
||||
|
||||
// Create service with checksum config
|
||||
service := NewSignatureService(sigRepo, docRepo, signer)
|
||||
checksumConfig := &config.ChecksumConfig{
|
||||
MaxBytes: 10 * 1024 * 1024,
|
||||
TimeoutMs: 5000,
|
||||
MaxRedirects: 3,
|
||||
AllowedContentType: []string{
|
||||
"application/pdf",
|
||||
},
|
||||
SkipSSRFCheck: true,
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
service.SetChecksumConfig(checksumConfig)
|
||||
|
||||
// Create signature request
|
||||
user := &models.User{
|
||||
Sub: "test-user",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
request := &models.SignatureRequest{
|
||||
DocID: "test-doc",
|
||||
User: user,
|
||||
}
|
||||
|
||||
// Should fail with ErrDocumentModified
|
||||
err := service.CreateSignature(context.Background(), request)
|
||||
if err != models.ErrDocumentModified {
|
||||
t.Fatalf("Expected ErrDocumentModified, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test signature creation without checksum (document has no URL or checksum)
|
||||
func TestSignatureService_NoChecksum_Success(t *testing.T) {
|
||||
// Create mock repositories
|
||||
docRepo := &mockDocumentRepository{
|
||||
getByDocIDFunc: func(ctx context.Context, docID string) (*models.Document, error) {
|
||||
return &models.Document{
|
||||
DocID: "test-doc",
|
||||
URL: "",
|
||||
Checksum: "",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
sigRepo := &mockSignatureRepository{}
|
||||
signer := &mockCryptoSigner{}
|
||||
|
||||
// Create service with checksum config
|
||||
service := NewSignatureService(sigRepo, docRepo, signer)
|
||||
checksumConfig := &config.ChecksumConfig{
|
||||
MaxBytes: 10 * 1024 * 1024,
|
||||
TimeoutMs: 5000,
|
||||
MaxRedirects: 3,
|
||||
AllowedContentType: []string{
|
||||
"application/pdf",
|
||||
},
|
||||
SkipSSRFCheck: true,
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
service.SetChecksumConfig(checksumConfig)
|
||||
|
||||
// Create signature request
|
||||
user := &models.User{
|
||||
Sub: "test-user",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
request := &models.SignatureRequest{
|
||||
DocID: "test-doc",
|
||||
User: user,
|
||||
}
|
||||
|
||||
// Should succeed because no checksum to verify
|
||||
err := service.CreateSignature(context.Background(), request)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected signature creation to succeed without checksum, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test signature creation without checksum config
|
||||
func TestSignatureService_NoChecksumConfig_Success(t *testing.T) {
|
||||
content := "Sample PDF content"
|
||||
|
||||
// Create test server
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/pdf")
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(content)))
|
||||
if r.Method == "GET" {
|
||||
w.Write([]byte(content))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create mock repositories
|
||||
docRepo := &mockDocumentRepository{
|
||||
getByDocIDFunc: func(ctx context.Context, docID string) (*models.Document, error) {
|
||||
return &models.Document{
|
||||
DocID: "test-doc",
|
||||
URL: server.URL,
|
||||
Checksum: "some_checksum",
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
sigRepo := &mockSignatureRepository{}
|
||||
signer := &mockCryptoSigner{}
|
||||
|
||||
// Create service WITHOUT checksum config
|
||||
service := NewSignatureService(sigRepo, docRepo, signer)
|
||||
// Don't call SetChecksumConfig
|
||||
|
||||
// Create signature request
|
||||
user := &models.User{
|
||||
Sub: "test-user",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
request := &models.SignatureRequest{
|
||||
DocID: "test-doc",
|
||||
User: user,
|
||||
}
|
||||
|
||||
// Should succeed because no config means no verification
|
||||
err := service.CreateSignature(context.Background(), request)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected signature creation to succeed without config, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test document integrity with network error (should not block signature)
|
||||
func TestSignatureService_NetworkError_ContinuesAnyway(t *testing.T) {
|
||||
// Create mock repositories with unreachable URL
|
||||
docRepo := &mockDocumentRepository{
|
||||
getByDocIDFunc: func(ctx context.Context, docID string) (*models.Document, error) {
|
||||
return &models.Document{
|
||||
DocID: "test-doc",
|
||||
URL: "https://non-existent-server-12345.example.com/doc.pdf",
|
||||
Checksum: "some_checksum",
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
sigRepo := &mockSignatureRepository{}
|
||||
signer := &mockCryptoSigner{}
|
||||
|
||||
// Create service with checksum config
|
||||
service := NewSignatureService(sigRepo, docRepo, signer)
|
||||
checksumConfig := &config.ChecksumConfig{
|
||||
MaxBytes: 10 * 1024 * 1024,
|
||||
TimeoutMs: 100, // Very short timeout
|
||||
MaxRedirects: 3,
|
||||
AllowedContentType: []string{
|
||||
"application/pdf",
|
||||
},
|
||||
SkipSSRFCheck: false, // Enable SSRF check
|
||||
InsecureSkipVerify: false,
|
||||
}
|
||||
service.SetChecksumConfig(checksumConfig)
|
||||
|
||||
// Create signature request
|
||||
user := &models.User{
|
||||
Sub: "test-user",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
request := &models.SignatureRequest{
|
||||
DocID: "test-doc",
|
||||
User: user,
|
||||
}
|
||||
|
||||
// Should succeed even though we can't verify (network error doesn't block signature)
|
||||
err := service.CreateSignature(context.Background(), request)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected signature creation to succeed despite network error, got: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
)
|
||||
|
||||
type fakeRepository struct {
|
||||
@@ -158,7 +158,7 @@ func newFakeCryptoSigner() *fakeCryptoSigner {
|
||||
return &fakeCryptoSigner{}
|
||||
}
|
||||
|
||||
func (f *fakeCryptoSigner) CreateSignature(docID string, user *models.User, _ time.Time, _ string) (string, string, error) {
|
||||
func (f *fakeCryptoSigner) CreateSignature(docID string, user *models.User, _ time.Time, _ string, _ string) (string, string, error) {
|
||||
if f.shouldFail {
|
||||
return "", "", errors.New("crypto signing failed")
|
||||
}
|
||||
@@ -170,14 +170,17 @@ func (f *fakeCryptoSigner) CreateSignature(docID string, user *models.User, _ ti
|
||||
|
||||
func TestNewSignatureService(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
docRepo := newFakeDocumentRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
|
||||
service := NewSignatureService(repo, signer)
|
||||
service := NewSignatureService(repo, docRepo, signer)
|
||||
|
||||
if service == nil {
|
||||
t.Error("NewSignatureService should not return nil")
|
||||
} else if service.repo != repo {
|
||||
t.Error("Service repository not set correctly")
|
||||
} else if service.docRepo == nil {
|
||||
t.Error("Service document repository not set correctly")
|
||||
} else if service.signer != signer {
|
||||
t.Error("Service signer not set correctly")
|
||||
}
|
||||
@@ -348,7 +351,7 @@ func TestSignatureService_CreateSignature(t *testing.T) {
|
||||
tt.setupSigner(signer)
|
||||
}
|
||||
|
||||
service := NewSignatureService(repo, signer)
|
||||
service := NewSignatureService(repo, newFakeDocumentRepository(), signer)
|
||||
|
||||
err := service.CreateSignature(context.Background(), tt.request)
|
||||
|
||||
@@ -467,7 +470,7 @@ func TestSignatureService_GetSignatureStatus(t *testing.T) {
|
||||
tt.setupRepo(repo)
|
||||
}
|
||||
|
||||
service := NewSignatureService(repo, signer)
|
||||
service := NewSignatureService(repo, newFakeDocumentRepository(), signer)
|
||||
|
||||
status, err := service.GetSignatureStatus(context.Background(), tt.docID, tt.user)
|
||||
|
||||
@@ -508,7 +511,7 @@ func TestSignatureService_GetSignatureStatus(t *testing.T) {
|
||||
func TestSignatureService_GetDocumentSignatures(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
service := NewSignatureService(repo, newFakeDocumentRepository(), signer)
|
||||
|
||||
sig1 := &models.Signature{ID: 1, DocID: "doc1", UserSub: "user1"}
|
||||
sig2 := &models.Signature{ID: 2, DocID: "doc1", UserSub: "user2"}
|
||||
@@ -541,7 +544,7 @@ func TestSignatureService_GetDocumentSignatures(t *testing.T) {
|
||||
func TestSignatureService_GetUserSignatures(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
service := NewSignatureService(repo, newFakeDocumentRepository(), signer)
|
||||
|
||||
sig1 := &models.Signature{ID: 1, DocID: "doc1", UserSub: "user1"}
|
||||
sig2 := &models.Signature{ID: 2, DocID: "doc2", UserSub: "user1"}
|
||||
@@ -583,7 +586,7 @@ func TestSignatureService_GetUserSignatures(t *testing.T) {
|
||||
func TestSignatureService_GetSignatureByDocAndUser(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
service := NewSignatureService(repo, newFakeDocumentRepository(), signer)
|
||||
|
||||
sig := &models.Signature{ID: 1, DocID: "doc1", UserSub: "user1"}
|
||||
repo.signatures["doc1_user1"] = sig
|
||||
@@ -620,7 +623,7 @@ func TestSignatureService_GetSignatureByDocAndUser(t *testing.T) {
|
||||
func TestSignatureService_CheckUserSignature(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
service := NewSignatureService(repo, newFakeDocumentRepository(), signer)
|
||||
|
||||
sig := &models.Signature{ID: 1, DocID: "doc1", UserSub: "user1", UserEmail: "user1@example.com"}
|
||||
repo.signatures["doc1_user1"] = sig
|
||||
@@ -776,7 +779,7 @@ func TestSignatureService_VerifyChainIntegrity(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
service := NewSignatureService(repo, newFakeDocumentRepository(), signer)
|
||||
|
||||
tt.setupSignatures(repo)
|
||||
|
||||
@@ -807,7 +810,7 @@ func TestSignatureService_VerifyChainIntegrity(t *testing.T) {
|
||||
t.Run("repository fails", func(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
service := NewSignatureService(repo, newFakeDocumentRepository(), signer)
|
||||
|
||||
repo.shouldFailGetAll = true
|
||||
|
||||
@@ -822,7 +825,7 @@ func TestSignatureService_RebuildChain(t *testing.T) {
|
||||
t.Run("empty chain", func(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
service := NewSignatureService(repo, newFakeDocumentRepository(), signer)
|
||||
|
||||
err := service.RebuildChain(context.Background())
|
||||
if err != nil {
|
||||
@@ -833,7 +836,7 @@ func TestSignatureService_RebuildChain(t *testing.T) {
|
||||
t.Run("chain with signatures", func(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
service := NewSignatureService(repo, newFakeDocumentRepository(), signer)
|
||||
|
||||
hash := "wrong-hash"
|
||||
sig1 := &models.Signature{
|
||||
@@ -859,7 +862,7 @@ func TestSignatureService_RebuildChain(t *testing.T) {
|
||||
t.Run("repository fails", func(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
service := NewSignatureService(repo, newFakeDocumentRepository(), signer)
|
||||
|
||||
repo.shouldFailGetAll = true
|
||||
|
||||
@@ -905,7 +908,7 @@ func int64Ptr(i int64) *int64 {
|
||||
func TestSignatureService_CreateSignature_MultipleDocumentsChaining(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
service := NewSignatureService(repo, newFakeDocumentRepository(), signer)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
27
backend/internal/domain/models/checksum_verification.go
Normal file
27
backend/internal/domain/models/checksum_verification.go
Normal file
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
// ChecksumVerification represents a verification attempt of a document's checksum
|
||||
type ChecksumVerification struct {
|
||||
ID int64 `json:"id" db:"id"`
|
||||
DocID string `json:"doc_id" db:"doc_id"`
|
||||
VerifiedBy string `json:"verified_by" db:"verified_by"`
|
||||
VerifiedAt time.Time `json:"verified_at" db:"verified_at"`
|
||||
StoredChecksum string `json:"stored_checksum" db:"stored_checksum"`
|
||||
CalculatedChecksum string `json:"calculated_checksum" db:"calculated_checksum"`
|
||||
Algorithm string `json:"algorithm" db:"algorithm"`
|
||||
IsValid bool `json:"is_valid" db:"is_valid"`
|
||||
ErrorMessage *string `json:"error_message,omitempty" db:"error_message"`
|
||||
}
|
||||
|
||||
// ChecksumVerificationResult represents the result of a checksum verification operation
|
||||
type ChecksumVerificationResult struct {
|
||||
Valid bool `json:"valid"`
|
||||
StoredChecksum string `json:"stored_checksum"`
|
||||
CalculatedChecksum string `json:"calculated_checksum"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
Message string `json:"message"`
|
||||
HasReferenceHash bool `json:"has_reference_hash"`
|
||||
}
|
||||
46
backend/internal/domain/models/document.go
Normal file
46
backend/internal/domain/models/document.go
Normal file
@@ -0,0 +1,46 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
// Document represents document metadata for tracking and integrity verification
|
||||
type Document struct {
|
||||
DocID string `json:"doc_id" db:"doc_id"`
|
||||
Title string `json:"title" db:"title"`
|
||||
URL string `json:"url" db:"url"`
|
||||
Checksum string `json:"checksum" db:"checksum"`
|
||||
ChecksumAlgorithm string `json:"checksum_algorithm" db:"checksum_algorithm"`
|
||||
Description string `json:"description" db:"description"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
|
||||
CreatedBy string `json:"created_by" db:"created_by"`
|
||||
DeletedAt *time.Time `json:"deleted_at,omitempty" db:"deleted_at"`
|
||||
}
|
||||
|
||||
// DocumentInput represents the input for creating/updating document metadata
|
||||
type DocumentInput struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Checksum string `json:"checksum"`
|
||||
ChecksumAlgorithm string `json:"checksum_algorithm"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// HasChecksum returns true if the document has a checksum configured
|
||||
func (d *Document) HasChecksum() bool {
|
||||
return d.Checksum != ""
|
||||
}
|
||||
|
||||
// GetExpectedChecksumLength returns the expected length for the configured algorithm
|
||||
func (d *Document) GetExpectedChecksumLength() int {
|
||||
switch d.ChecksumAlgorithm {
|
||||
case "SHA-256":
|
||||
return 64
|
||||
case "SHA-512":
|
||||
return 128
|
||||
case "MD5":
|
||||
return 32
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
100
backend/internal/domain/models/document_test.go
Normal file
100
backend/internal/domain/models/document_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package models
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDocument_HasChecksum(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
document *Document
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Document with checksum",
|
||||
document: &Document{
|
||||
Checksum: "abc123def456",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Document without checksum",
|
||||
document: &Document{
|
||||
Checksum: "",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Document with whitespace checksum",
|
||||
document: &Document{
|
||||
Checksum: " ",
|
||||
},
|
||||
expected: true, // Non-empty string
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := tt.document.HasChecksum()
|
||||
if result != tt.expected {
|
||||
t.Errorf("HasChecksum() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDocument_GetExpectedChecksumLength(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
checksumAlgorithm string
|
||||
expectedLength int
|
||||
}{
|
||||
{
|
||||
name: "SHA-256 algorithm",
|
||||
checksumAlgorithm: "SHA-256",
|
||||
expectedLength: 64,
|
||||
},
|
||||
{
|
||||
name: "SHA-512 algorithm",
|
||||
checksumAlgorithm: "SHA-512",
|
||||
expectedLength: 128,
|
||||
},
|
||||
{
|
||||
name: "MD5 algorithm",
|
||||
checksumAlgorithm: "MD5",
|
||||
expectedLength: 32,
|
||||
},
|
||||
{
|
||||
name: "Unknown algorithm",
|
||||
checksumAlgorithm: "UNKNOWN",
|
||||
expectedLength: 0,
|
||||
},
|
||||
{
|
||||
name: "Empty algorithm",
|
||||
checksumAlgorithm: "",
|
||||
expectedLength: 0,
|
||||
},
|
||||
{
|
||||
name: "Lowercase sha-256",
|
||||
checksumAlgorithm: "sha-256",
|
||||
expectedLength: 0, // Case sensitive
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
doc := &Document{
|
||||
ChecksumAlgorithm: tt.checksumAlgorithm,
|
||||
}
|
||||
result := doc.GetExpectedChecksumLength()
|
||||
if result != tt.expectedLength {
|
||||
t.Errorf("GetExpectedChecksumLength() = %v, want %v", result, tt.expectedLength)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
162
backend/internal/domain/models/email_queue.go
Normal file
162
backend/internal/domain/models/email_queue.go
Normal file
@@ -0,0 +1,162 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package models
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EmailQueueStatus represents the status of an email in the queue
|
||||
type EmailQueueStatus string
|
||||
|
||||
const (
|
||||
EmailStatusPending EmailQueueStatus = "pending"
|
||||
EmailStatusProcessing EmailQueueStatus = "processing"
|
||||
EmailStatusSent EmailQueueStatus = "sent"
|
||||
EmailStatusFailed EmailQueueStatus = "failed"
|
||||
EmailStatusCancelled EmailQueueStatus = "cancelled"
|
||||
)
|
||||
|
||||
// EmailPriority represents email priority levels
|
||||
type EmailPriority int
|
||||
|
||||
const (
|
||||
EmailPriorityNormal EmailPriority = 0
|
||||
EmailPriorityHigh EmailPriority = 10
|
||||
EmailPriorityUrgent EmailPriority = 100
|
||||
)
|
||||
|
||||
// EmailQueueItem represents an email in the processing queue
|
||||
type EmailQueueItem struct {
|
||||
ID int64 `json:"id"`
|
||||
ToAddresses []string `json:"to_addresses"`
|
||||
CcAddresses []string `json:"cc_addresses,omitempty"`
|
||||
BccAddresses []string `json:"bcc_addresses,omitempty"`
|
||||
Subject string `json:"subject"`
|
||||
Template string `json:"template"`
|
||||
Locale string `json:"locale"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Headers NullRawMessage `json:"headers,omitempty"`
|
||||
Status EmailQueueStatus `json:"status"`
|
||||
Priority EmailPriority `json:"priority"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ScheduledFor time.Time `json:"scheduled_for"`
|
||||
ProcessedAt *time.Time `json:"processed_at,omitempty"`
|
||||
NextRetryAt *time.Time `json:"next_retry_at,omitempty"`
|
||||
LastError *string `json:"last_error,omitempty"`
|
||||
ErrorDetails NullRawMessage `json:"error_details,omitempty"`
|
||||
ReferenceType *string `json:"reference_type,omitempty"`
|
||||
ReferenceID *string `json:"reference_id,omitempty"`
|
||||
CreatedBy *string `json:"created_by,omitempty"`
|
||||
}
|
||||
|
||||
// EmailQueueInput represents the input for creating a new email queue item
|
||||
type EmailQueueInput struct {
|
||||
ToAddresses []string `json:"to_addresses"`
|
||||
CcAddresses []string `json:"cc_addresses,omitempty"`
|
||||
BccAddresses []string `json:"bcc_addresses,omitempty"`
|
||||
Subject string `json:"subject"`
|
||||
Template string `json:"template"`
|
||||
Locale string `json:"locale"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
Priority EmailPriority `json:"priority"`
|
||||
ScheduledFor *time.Time `json:"scheduled_for,omitempty"` // nil = immediate
|
||||
ReferenceType *string `json:"reference_type,omitempty"`
|
||||
ReferenceID *string `json:"reference_id,omitempty"`
|
||||
CreatedBy *string `json:"created_by,omitempty"`
|
||||
MaxRetries int `json:"max_retries"` // 0 = use default (3)
|
||||
}
|
||||
|
||||
// EmailQueueStats represents aggregated statistics for the email queue
|
||||
type EmailQueueStats struct {
|
||||
TotalPending int `json:"total_pending"`
|
||||
TotalProcessing int `json:"total_processing"`
|
||||
TotalSent int `json:"total_sent"`
|
||||
TotalFailed int `json:"total_failed"`
|
||||
OldestPending *time.Time `json:"oldest_pending,omitempty"`
|
||||
AverageRetries float64 `json:"average_retries"`
|
||||
ByStatus map[string]int `json:"by_status"`
|
||||
ByPriority map[string]int `json:"by_priority"`
|
||||
Last24Hours EmailPeriodStats `json:"last_24_hours"`
|
||||
}
|
||||
|
||||
// EmailPeriodStats represents email statistics for a time period
|
||||
type EmailPeriodStats struct {
|
||||
Sent int `json:"sent"`
|
||||
Failed int `json:"failed"`
|
||||
Queued int `json:"queued"`
|
||||
}
|
||||
|
||||
// JSONB is a helper type for handling JSONB columns
|
||||
type JSONB map[string]interface{}
|
||||
|
||||
// Value implements driver.Valuer
|
||||
func (j JSONB) Value() (driver.Value, error) {
|
||||
if j == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return json.Marshal(j)
|
||||
}
|
||||
|
||||
// Scan implements sql.Scanner
|
||||
func (j *JSONB) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*j = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
var data []byte
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
data = v
|
||||
case string:
|
||||
data = []byte(v)
|
||||
default:
|
||||
data = []byte("{}")
|
||||
}
|
||||
|
||||
return json.Unmarshal(data, j)
|
||||
}
|
||||
|
||||
// NullRawMessage is a nullable json.RawMessage for database scanning
|
||||
type NullRawMessage struct {
|
||||
RawMessage json.RawMessage
|
||||
Valid bool
|
||||
}
|
||||
|
||||
// Scan implements sql.Scanner
|
||||
func (n *NullRawMessage) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
n.RawMessage = nil
|
||||
n.Valid = false
|
||||
return nil
|
||||
}
|
||||
|
||||
var data []byte
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
data = v
|
||||
case string:
|
||||
data = []byte(v)
|
||||
default:
|
||||
n.RawMessage = nil
|
||||
n.Valid = false
|
||||
return nil
|
||||
}
|
||||
|
||||
n.RawMessage = data
|
||||
n.Valid = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements driver.Valuer
|
||||
func (n NullRawMessage) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return n.RawMessage, nil
|
||||
}
|
||||
@@ -11,4 +11,6 @@ var (
|
||||
ErrDatabaseConnection = errors.New("database connection error")
|
||||
ErrUnauthorized = errors.New("unauthorized")
|
||||
ErrDomainNotAllowed = errors.New("domain not allowed")
|
||||
ErrDocumentModified = errors.New("document has been modified since creation")
|
||||
ErrDocumentNotFound = errors.New("document not found")
|
||||
)
|
||||
127
backend/internal/domain/models/signature.go
Normal file
127
backend/internal/domain/models/signature.go
Normal file
@@ -0,0 +1,127 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package models
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/services"
|
||||
)
|
||||
|
||||
type Signature struct {
|
||||
ID int64 `json:"id" db:"id"`
|
||||
DocID string `json:"doc_id" db:"doc_id"`
|
||||
UserSub string `json:"user_sub" db:"user_sub"`
|
||||
UserEmail string `json:"user_email" db:"user_email"`
|
||||
UserName string `json:"user_name,omitempty" db:"user_name"`
|
||||
SignedAtUTC time.Time `json:"signed_at" db:"signed_at"`
|
||||
DocChecksum string `json:"doc_checksum,omitempty" db:"doc_checksum"`
|
||||
PayloadHash string `json:"payload_hash" db:"payload_hash"`
|
||||
Signature string `json:"signature" db:"signature"`
|
||||
Nonce string `json:"nonce" db:"nonce"`
|
||||
Referer *string `json:"referer,omitempty" db:"referer"`
|
||||
PrevHash *string `json:"prev_hash,omitempty" db:"prev_hash"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
HashVersion int `json:"hash_version" db:"hash_version"`
|
||||
DocDeletedAt *time.Time `json:"doc_deleted_at,omitempty" db:"doc_deleted_at"`
|
||||
// Document metadata enriched from LEFT JOIN (not stored in signatures table)
|
||||
DocTitle string `json:"doc_title,omitempty"`
|
||||
DocURL string `json:"doc_url,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Signature) GetServiceInfo() *services.ServiceInfo {
|
||||
if s.Referer == nil {
|
||||
return nil
|
||||
}
|
||||
return services.DetectServiceFromReferrer(*s.Referer)
|
||||
}
|
||||
|
||||
type SignatureRequest struct {
|
||||
DocID string
|
||||
User *User
|
||||
Referer *string
|
||||
}
|
||||
|
||||
type SignatureStatus struct {
|
||||
DocID string
|
||||
UserEmail string
|
||||
IsSigned bool
|
||||
SignedAt *time.Time
|
||||
}
|
||||
|
||||
// ComputeRecordHash computes the hash of the signature record for blockchain integrity
|
||||
// Uses versioned hash algorithms for backward compatibility
|
||||
func (s *Signature) ComputeRecordHash() string {
|
||||
switch s.HashVersion {
|
||||
case 2:
|
||||
return s.computeHashV2()
|
||||
default:
|
||||
// Version 1 or unset (backward compatibility)
|
||||
return s.computeHashV1()
|
||||
}
|
||||
}
|
||||
|
||||
// computeHashV1 computes hash using legacy pipe-separated format
|
||||
// Used for existing signatures to maintain backward compatibility
|
||||
func (s *Signature) computeHashV1() string {
|
||||
data := fmt.Sprintf("%d|%s|%s|%s|%s|%s|%s|%s|%s|%s|%s|%s",
|
||||
s.ID,
|
||||
s.DocID,
|
||||
s.UserSub,
|
||||
s.UserEmail,
|
||||
s.UserName,
|
||||
s.SignedAtUTC.Format(time.RFC3339Nano),
|
||||
s.DocChecksum,
|
||||
s.PayloadHash,
|
||||
s.Signature,
|
||||
s.Nonce,
|
||||
s.CreatedAt.Format(time.RFC3339Nano),
|
||||
func() string {
|
||||
if s.Referer != nil {
|
||||
return *s.Referer
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
)
|
||||
|
||||
hash := sha256.Sum256([]byte(data))
|
||||
return base64.StdEncoding.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// computeHashV2 computes hash using JSON canonical format
|
||||
// Recommended for new signatures - eliminates ambiguity and is more extensible
|
||||
func (s *Signature) computeHashV2() string {
|
||||
// Create canonical representation with keys sorted alphabetically
|
||||
canonical := map[string]interface{}{
|
||||
"created_at": s.CreatedAt.Unix(),
|
||||
"doc_checksum": s.DocChecksum,
|
||||
"doc_id": s.DocID,
|
||||
"id": s.ID,
|
||||
"nonce": s.Nonce,
|
||||
"payload_hash": s.PayloadHash,
|
||||
"referer": func() string {
|
||||
if s.Referer != nil {
|
||||
return *s.Referer
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
"signature": s.Signature,
|
||||
"signed_at": s.SignedAtUTC.Unix(),
|
||||
"user_email": s.UserEmail,
|
||||
"user_name": s.UserName,
|
||||
"user_sub": s.UserSub,
|
||||
}
|
||||
|
||||
// Marshal to JSON with sorted keys (Go's json.Marshal sorts keys automatically)
|
||||
data, err := json.Marshal(canonical)
|
||||
if err != nil {
|
||||
// Fallback to V1 if JSON marshaling fails (should never happen)
|
||||
return s.computeHashV1()
|
||||
}
|
||||
|
||||
hash := sha256.Sum256(data)
|
||||
return base64.StdEncoding.EncodeToString(hash[:])
|
||||
}
|
||||
@@ -14,8 +14,8 @@ import (
|
||||
"github.com/gorilla/sessions"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/btouchard/ackify-ce/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/pkg/logger"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
const sessionName = "ackapp_session"
|
||||
@@ -48,7 +48,7 @@ func NewOAuthService(config Config) *OauthService {
|
||||
oauthConfig := &oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: config.ClientSecret,
|
||||
RedirectURL: config.BaseURL + "/oauth2/callback",
|
||||
RedirectURL: config.BaseURL + "/api/v1/auth/callback",
|
||||
Scopes: config.Scopes,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: config.AuthURL,
|
||||
@@ -58,6 +58,19 @@ func NewOAuthService(config Config) *OauthService {
|
||||
|
||||
sessionStore := sessions.NewCookieStore(config.CookieSecret)
|
||||
|
||||
// Configure session options globally on the store
|
||||
sessionStore.Options = &sessions.Options{
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: config.SecureCookies,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: 86400 * 7, // 7 days
|
||||
}
|
||||
|
||||
logger.Logger.Info("OAuth session store configured",
|
||||
"secure_cookies", config.SecureCookies,
|
||||
"max_age_days", 7)
|
||||
|
||||
return &OauthService{
|
||||
oauthConfig: oauthConfig,
|
||||
sessionStore: sessionStore,
|
||||
@@ -95,7 +108,13 @@ func (s *OauthService) GetUser(r *http.Request) (*models.User, error) {
|
||||
}
|
||||
|
||||
func (s *OauthService) SetUser(w http.ResponseWriter, r *http.Request, user *models.User) error {
|
||||
session, _ := s.sessionStore.Get(r, sessionName)
|
||||
// Always create a fresh new session to ensure session ID is generated
|
||||
// This fixes an issue where reusing an existing invalid session results in empty session.ID
|
||||
session, err := s.sessionStore.New(r, sessionName)
|
||||
if err != nil {
|
||||
logger.Logger.Error("SetUser: failed to create new session", "error", err.Error())
|
||||
return fmt.Errorf("failed to create new session: %w", err)
|
||||
}
|
||||
|
||||
userJSON, err := json.Marshal(user)
|
||||
if err != nil {
|
||||
@@ -103,24 +122,27 @@ func (s *OauthService) SetUser(w http.ResponseWriter, r *http.Request, user *mod
|
||||
return fmt.Errorf("failed to marshal user: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Debug("SetUser: saving user to session",
|
||||
logger.Logger.Debug("SetUser: saving user to new session",
|
||||
"email", user.Email,
|
||||
"secure_cookies", s.secureCookies)
|
||||
"secure_cookies", s.secureCookies,
|
||||
"session_is_new", session.IsNew)
|
||||
|
||||
session.Values["user"] = string(userJSON)
|
||||
session.Options = &sessions.Options{
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: s.secureCookies,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
|
||||
// Session options are already configured globally on the store
|
||||
// No need to set them again here
|
||||
|
||||
if err := session.Save(r, w); err != nil {
|
||||
logger.Logger.Error("SetUser: failed to save session", "error", err.Error())
|
||||
logger.Logger.Error("SetUser: failed to save session",
|
||||
"error", err.Error(),
|
||||
"session_is_new", session.IsNew,
|
||||
"session_id_length", len(session.ID))
|
||||
return fmt.Errorf("failed to save session: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Debug("SetUser: session saved successfully")
|
||||
logger.Logger.Info("SetUser: session saved successfully",
|
||||
"email", user.Email,
|
||||
"session_id_length", len(session.ID))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -155,29 +177,37 @@ func (s *OauthService) GetAuthURL(nextURL string) string {
|
||||
return s.oauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("prompt", "select_account"))
|
||||
}
|
||||
|
||||
// CreateAuthURL Persist a CSRF state token server-side to prevent forged OAuth callbacks; encode nextURL to preserve intended redirect.
|
||||
func (s *OauthService) CreateAuthURL(w http.ResponseWriter, r *http.Request, nextURL string) string {
|
||||
randPart := securecookie.GenerateRandomKey(20)
|
||||
token := base64.RawURLEncoding.EncodeToString(randPart)
|
||||
state := token + ":" + base64.RawURLEncoding.EncodeToString([]byte(nextURL))
|
||||
|
||||
logger.Logger.Debug("CreateAuthURL: generating OAuth state",
|
||||
"token_length", len(token),
|
||||
"next_url", nextURL)
|
||||
|
||||
session, _ := s.sessionStore.Get(r, sessionName)
|
||||
session.Values["oauth_state"] = token
|
||||
session.Options = &sessions.Options{Path: "/", HttpOnly: true, Secure: s.secureCookies, SameSite: http.SameSiteLaxMode}
|
||||
err := session.Save(r, w)
|
||||
if err != nil {
|
||||
logger.Logger.Error("CreateAuthURL: failed to save session", "error", err.Error())
|
||||
promptParam := "select_account"
|
||||
isSilent := r.URL.Query().Get("silent") == "true"
|
||||
if isSilent {
|
||||
promptParam = "none"
|
||||
}
|
||||
|
||||
// Check if silent login is requested
|
||||
promptParam := "select_account"
|
||||
if r.URL.Query().Get("silent") == "true" {
|
||||
promptParam = "none"
|
||||
logger.Logger.Debug("CreateAuthURL: using silent login (prompt=none)")
|
||||
logger.Logger.Info("Starting OAuth flow",
|
||||
"next_url", nextURL,
|
||||
"silent", isSilent,
|
||||
"state_token_length", len(token))
|
||||
|
||||
session, err := s.sessionStore.Get(r, sessionName)
|
||||
if err != nil {
|
||||
logger.Logger.Error("CreateAuthURL: failed to get session from store", "error", err.Error())
|
||||
// Create a new empty session if Get fails
|
||||
session, _ = s.sessionStore.New(r, sessionName)
|
||||
}
|
||||
|
||||
session.Values["oauth_state"] = token
|
||||
|
||||
// Session options are already configured globally on the store
|
||||
// No need to set them again here
|
||||
|
||||
err = session.Save(r, w)
|
||||
if err != nil {
|
||||
logger.Logger.Error("CreateAuthURL: failed to save session", "error", err.Error())
|
||||
}
|
||||
|
||||
authURL := s.oauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("prompt", promptParam))
|
||||
@@ -188,7 +218,6 @@ func (s *OauthService) CreateAuthURL(w http.ResponseWriter, r *http.Request, nex
|
||||
return authURL
|
||||
}
|
||||
|
||||
// VerifyState Clear single-use state on success to prevent replay; compare in constant time to avoid timing leaks.
|
||||
func (s *OauthService) VerifyState(w http.ResponseWriter, r *http.Request, stateToken string) bool {
|
||||
session, _ := s.sessionStore.Get(r, sessionName)
|
||||
stored, _ := session.Values["oauth_state"].(string)
|
||||
@@ -235,29 +264,56 @@ func (s *OauthService) HandleCallback(ctx context.Context, code, state string) (
|
||||
}
|
||||
}
|
||||
|
||||
logger.Logger.Debug("Processing OAuth callback",
|
||||
"has_code", code != "",
|
||||
"next_url", nextURL)
|
||||
|
||||
token, err := s.oauthConfig.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
logger.Logger.Error("OAuth token exchange failed",
|
||||
"error", err.Error())
|
||||
return nil, nextURL, fmt.Errorf("oauth exchange failed: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Debug("OAuth token exchange successful")
|
||||
|
||||
client := s.oauthConfig.Client(ctx, token)
|
||||
resp, err := client.Get(s.userInfoURL)
|
||||
if err != nil || resp.StatusCode != 200 {
|
||||
statusCode := 0
|
||||
if resp != nil {
|
||||
statusCode = resp.StatusCode
|
||||
}
|
||||
logger.Logger.Error("User info request failed",
|
||||
"error", err,
|
||||
"status_code", statusCode)
|
||||
return nil, nextURL, fmt.Errorf("userinfo request failed: %w", err)
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
_ = Body.Close()
|
||||
}(resp.Body)
|
||||
|
||||
logger.Logger.Debug("User info retrieved successfully",
|
||||
"status_code", resp.StatusCode)
|
||||
|
||||
user, err := s.parseUserInfo(resp)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to parse user info",
|
||||
"error", err.Error())
|
||||
return nil, nextURL, fmt.Errorf("failed to parse user info: %w", err)
|
||||
}
|
||||
|
||||
if !s.IsAllowedDomain(user.Email) {
|
||||
logger.Logger.Warn("User domain not allowed",
|
||||
"user_email", user.Email,
|
||||
"allowed_domain", s.allowedDomain)
|
||||
return nil, nextURL, models.ErrDomainNotAllowed
|
||||
}
|
||||
|
||||
logger.Logger.Info("OAuth callback successful",
|
||||
"user_email", user.Email,
|
||||
"user_name", user.Name)
|
||||
|
||||
return user, nextURL, nil
|
||||
}
|
||||
|
||||
@@ -292,7 +348,7 @@ func (s *OauthService) parseUserInfo(resp *http.Response) (*models.User, error)
|
||||
if sub, ok := rawUser["sub"].(string); ok {
|
||||
user.Sub = sub
|
||||
} else if id, ok := rawUser["id"]; ok {
|
||||
user.Sub = fmt.Sprintf("%v", id) // Convert to string regardless of type
|
||||
user.Sub = fmt.Sprintf("%v", id)
|
||||
} else {
|
||||
return nil, fmt.Errorf("missing user ID in response")
|
||||
}
|
||||
@@ -304,7 +360,6 @@ func (s *OauthService) parseUserInfo(resp *http.Response) (*models.User, error)
|
||||
}
|
||||
|
||||
var name string
|
||||
// Priority: full name first, then composite name, then username as fallback
|
||||
if fullName, ok := rawUser["name"].(string); ok && fullName != "" {
|
||||
name = fullName
|
||||
} else if firstName, ok := rawUser["given_name"].(string); ok {
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/btouchard/ackify-ce/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
)
|
||||
|
||||
func TestNewOAuthService(t *testing.T) {
|
||||
@@ -69,7 +69,7 @@ func TestNewOAuthService(t *testing.T) {
|
||||
t.Errorf("ClientSecret = %v, expected %v", service.oauthConfig.ClientSecret, tt.config.ClientSecret)
|
||||
}
|
||||
|
||||
expectedRedirectURL := tt.config.BaseURL + "/oauth2/callback"
|
||||
expectedRedirectURL := tt.config.BaseURL + "/api/v1/auth/callback"
|
||||
if service.oauthConfig.RedirectURL != expectedRedirectURL {
|
||||
t.Errorf("RedirectURL = %v, expected %v", service.oauthConfig.RedirectURL, expectedRedirectURL)
|
||||
}
|
||||
@@ -894,3 +894,237 @@ func createTestServiceWithSecure(secure bool) *OauthService {
|
||||
}
|
||||
return NewOAuthService(config)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - VerifyState
|
||||
// ============================================================================
|
||||
|
||||
func TestOauthService_VerifyState_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service := createTestService()
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
// First, create a session with an oauth_state
|
||||
session, _ := service.sessionStore.Get(r, sessionName)
|
||||
session.Values["oauth_state"] = "test-state-token-123"
|
||||
_ = session.Save(r, w)
|
||||
|
||||
// Get cookies from response
|
||||
cookies := w.Result().Cookies()
|
||||
r2 := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range cookies {
|
||||
r2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
result := service.VerifyState(w2, r2, "test-state-token-123")
|
||||
|
||||
if !result {
|
||||
t.Error("VerifyState should return true for matching state")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthService_VerifyState_Mismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service := createTestService()
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
// Set state in session
|
||||
session, _ := service.sessionStore.Get(r, sessionName)
|
||||
session.Values["oauth_state"] = "correct-state"
|
||||
_ = session.Save(r, w)
|
||||
|
||||
cookies := w.Result().Cookies()
|
||||
r2 := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range cookies {
|
||||
r2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
result := service.VerifyState(w2, r2, "wrong-state")
|
||||
|
||||
if result {
|
||||
t.Error("VerifyState should return false for mismatched state")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthService_VerifyState_EmptyStored(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service := createTestService()
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
// Don't set any state in session (empty)
|
||||
result := service.VerifyState(w, r, "some-token")
|
||||
|
||||
if result {
|
||||
t.Error("VerifyState should return false when stored state is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthService_VerifyState_EmptyToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service := createTestService()
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
// Set state in session
|
||||
session, _ := service.sessionStore.Get(r, sessionName)
|
||||
session.Values["oauth_state"] = "some-state"
|
||||
_ = session.Save(r, w)
|
||||
|
||||
cookies := w.Result().Cookies()
|
||||
r2 := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range cookies {
|
||||
r2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
result := service.VerifyState(w2, r2, "")
|
||||
|
||||
if result {
|
||||
t.Error("VerifyState should return false when token is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthService_VerifyState_BothEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service := createTestService()
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
result := service.VerifyState(w, r, "")
|
||||
|
||||
if result {
|
||||
t.Error("VerifyState should return false when both are empty")
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - subtleConstantTimeCompare
|
||||
// ============================================================================
|
||||
|
||||
func TestSubtleConstantTimeCompare_Equal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
a string
|
||||
b string
|
||||
}{
|
||||
{"identical strings", "hello", "hello"},
|
||||
{"identical long strings", "this-is-a-very-long-state-token-12345", "this-is-a-very-long-state-token-12345"},
|
||||
{"empty strings", "", ""},
|
||||
{"special characters", "abc!@#$%^&*()", "abc!@#$%^&*()"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !subtleConstantTimeCompare(tt.a, tt.b) {
|
||||
t.Errorf("subtleConstantTimeCompare(%q, %q) should return true", tt.a, tt.b)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubtleConstantTimeCompare_NotEqual(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
a string
|
||||
b string
|
||||
}{
|
||||
{"different strings", "hello", "world"},
|
||||
{"different lengths", "short", "longer-string"},
|
||||
{"one empty", "hello", ""},
|
||||
{"other empty", "", "world"},
|
||||
{"similar but different", "state123", "state124"},
|
||||
{"case sensitive", "Hello", "hello"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if subtleConstantTimeCompare(tt.a, tt.b) {
|
||||
t.Errorf("subtleConstantTimeCompare(%q, %q) should return false", tt.a, tt.b)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubtleConstantTimeCompare_TimingSafety(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test that comparison takes similar time regardless of where difference occurs
|
||||
// This is a basic test - true timing attack resistance requires more sophisticated testing
|
||||
a := "this-is-a-long-state-token-with-many-characters"
|
||||
b1 := "Xhis-is-a-long-state-token-with-many-characters" // Differs at start
|
||||
b2 := "this-is-a-long-state-token-with-many-characterX" // Differs at end
|
||||
|
||||
// Both should return false
|
||||
if subtleConstantTimeCompare(a, b1) {
|
||||
t.Error("Should return false for b1")
|
||||
}
|
||||
if subtleConstantTimeCompare(a, b2) {
|
||||
t.Error("Should return false for b2")
|
||||
}
|
||||
|
||||
// The function should have similar behavior regardless of where the difference is
|
||||
// (This is ensured by the XOR loop that always iterates through entire string)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// BENCHMARKS
|
||||
// ============================================================================
|
||||
|
||||
func BenchmarkVerifyState(b *testing.B) {
|
||||
service := createTestService()
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
// Setup session
|
||||
session, _ := service.sessionStore.Get(r, sessionName)
|
||||
session.Values["oauth_state"] = "test-state-token"
|
||||
_ = session.Save(r, w)
|
||||
|
||||
cookies := w.Result().Cookies()
|
||||
r2 := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range cookies {
|
||||
r2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
_ = service.VerifyState(w, r2, "test-state-token")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSubtleConstantTimeCompare(b *testing.B) {
|
||||
a := "this-is-a-state-token-123456789"
|
||||
b1 := "this-is-a-state-token-123456789"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = subtleConstantTimeCompare(a, b1)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSubtleConstantTimeCompare_Different(b *testing.B) {
|
||||
a := "this-is-a-state-token-123456789"
|
||||
b1 := "this-is-a-state-token-987654321"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = subtleConstantTimeCompare(a, b1)
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,7 @@ type Config struct {
|
||||
Server ServerConfig
|
||||
Logger LoggerConfig
|
||||
Mail MailConfig
|
||||
Checksum ChecksumConfig
|
||||
}
|
||||
|
||||
type AppConfig struct {
|
||||
@@ -66,6 +67,15 @@ type MailConfig struct {
|
||||
DefaultLocale string
|
||||
}
|
||||
|
||||
type ChecksumConfig struct {
|
||||
MaxBytes int64
|
||||
TimeoutMs int
|
||||
MaxRedirects int
|
||||
AllowedContentType []string
|
||||
SkipSSRFCheck bool // For testing only - DO NOT use in production
|
||||
InsecureSkipVerify bool // For testing only - DO NOT use in production
|
||||
}
|
||||
|
||||
// Load loads configuration from environment variables
|
||||
func Load() (*Config, error) {
|
||||
config := &Config{}
|
||||
@@ -151,6 +161,23 @@ func Load() (*Config, error) {
|
||||
config.Mail.DefaultLocale = getEnv("ACKIFY_MAIL_DEFAULT_LOCALE", "en")
|
||||
}
|
||||
|
||||
// Parse checksum config (automatic checksum computation for remote URLs)
|
||||
config.Checksum.MaxBytes = getEnvInt64("ACKIFY_CHECKSUM_MAX_BYTES", 10*1024*1024) // 10 MB default
|
||||
config.Checksum.TimeoutMs = getEnvInt("ACKIFY_CHECKSUM_TIMEOUT_MS", 5000) // 5 seconds default
|
||||
config.Checksum.MaxRedirects = getEnvInt("ACKIFY_CHECKSUM_MAX_REDIRECTS", 3)
|
||||
|
||||
// Parse allowed content types
|
||||
allowedTypesStr := getEnv("ACKIFY_CHECKSUM_ALLOWED_TYPES", "application/pdf,image/*,application/msword,application/vnd.openxmlformats-officedocument.wordprocessingml.document,application/vnd.ms-excel,application/vnd.openxmlformats-officedocument.spreadsheetml.sheet,application/vnd.oasis.opendocument.*")
|
||||
if allowedTypesStr != "" {
|
||||
types := strings.Split(allowedTypesStr, ",")
|
||||
for _, typ := range types {
|
||||
trimmed := strings.TrimSpace(typ)
|
||||
if trimmed != "" {
|
||||
config.Checksum.AllowedContentType = append(config.Checksum.AllowedContentType, trimmed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -204,3 +231,15 @@ func getEnvBool(key string, defaultValue bool) bool {
|
||||
}
|
||||
return strings.ToLower(value) == "true" || value == "1"
|
||||
}
|
||||
|
||||
func getEnvInt64(key string, defaultValue int64) int64 {
|
||||
value := strings.TrimSpace(os.Getenv(key))
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
var result int64
|
||||
if _, err := fmt.Sscanf(value, "%d", &result); err == nil {
|
||||
return result
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/btouchard/ackify-ce/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
)
|
||||
|
||||
type DocumentAgg struct {
|
||||
@@ -26,7 +26,7 @@ func NewAdminRepository(db *sql.DB) *AdminRepository {
|
||||
return &AdminRepository{db: db}
|
||||
}
|
||||
|
||||
// ListDocumentsWithCounts returns all documents with their signature counts
|
||||
// ListDocumentsWithCounts aggregates signature metrics across all documents for admin dashboard
|
||||
func (r *AdminRepository) ListDocumentsWithCounts(ctx context.Context) ([]DocumentAgg, error) {
|
||||
query := `
|
||||
SELECT
|
||||
@@ -82,7 +82,7 @@ func (r *AdminRepository) ListDocumentsWithCounts(ctx context.Context) ([]Docume
|
||||
return documents, nil
|
||||
}
|
||||
|
||||
// ListSignaturesByDoc returns all signatures for a specific document
|
||||
// ListSignaturesByDoc retrieves all signatures for a document in reverse chronological order
|
||||
func (r *AdminRepository) ListSignaturesByDoc(ctx context.Context, docID string) ([]*models.Signature, error) {
|
||||
query := `
|
||||
SELECT id, doc_id, user_sub, user_email, user_name, signed_at, payload_hash, signature, nonce, created_at, referer, prev_hash
|
||||
@@ -125,7 +125,7 @@ func (r *AdminRepository) ListSignaturesByDoc(ctx context.Context, docID string)
|
||||
return signatures, nil
|
||||
}
|
||||
|
||||
// VerifyDocumentChainIntegrity vérifie l'intégrité de la chaîne pour un document donné
|
||||
// VerifyDocumentChainIntegrity validates cryptographic hash chain continuity for all signatures in a document
|
||||
func (r *AdminRepository) VerifyDocumentChainIntegrity(ctx context.Context, docID string) (*ChainIntegrityResult, error) {
|
||||
signatures, err := r.ListSignaturesByDoc(ctx, docID)
|
||||
if err != nil {
|
||||
@@ -208,7 +208,7 @@ func (r *AdminRepository) verifyChainIntegrity(signatures []*models.Signature) *
|
||||
return result
|
||||
}
|
||||
|
||||
// Close closes the database connection
|
||||
// Close gracefully terminates the database connection pool to prevent resource leaks
|
||||
func (r *AdminRepository) Close() error {
|
||||
if r.db != nil {
|
||||
return r.db.Close()
|
||||
@@ -0,0 +1,63 @@
|
||||
//go:build integration
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAdminRepository_ListDocumentsWithCounts_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
// Tables are created by migrations in SetupTestDB
|
||||
|
||||
ctx := context.Background()
|
||||
repo := NewAdminRepository(testDB.DB)
|
||||
|
||||
// Test listing documents - should succeed even if empty
|
||||
docs, err := repo.ListDocumentsWithCounts(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ListDocumentsWithCounts failed: %v", err)
|
||||
}
|
||||
|
||||
// docs can be nil or empty slice if no documents exist - both are valid
|
||||
_ = docs
|
||||
}
|
||||
|
||||
func TestAdminRepository_ListSignaturesByDoc_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
|
||||
_, err := testDB.DB.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS signatures (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
doc_id TEXT NOT NULL,
|
||||
user_sub TEXT NOT NULL,
|
||||
user_email TEXT NOT NULL,
|
||||
user_name TEXT,
|
||||
signed_at TIMESTAMPTZ NOT NULL,
|
||||
payload_hash TEXT NOT NULL,
|
||||
signature TEXT NOT NULL,
|
||||
nonce TEXT NOT NULL,
|
||||
referer TEXT,
|
||||
prev_hash TEXT,
|
||||
doc_checksum TEXT,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
UNIQUE (doc_id, user_sub)
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create signatures table: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
repo := NewAdminRepository(testDB.DB)
|
||||
|
||||
// Test listing signatures for a doc
|
||||
sigs, err := repo.ListSignaturesByDoc(ctx, "test-doc")
|
||||
if err != nil {
|
||||
t.Fatalf("ListSignaturesByDoc failed: %v", err)
|
||||
}
|
||||
|
||||
// sigs can be nil or empty if no signatures exist
|
||||
_ = sigs
|
||||
}
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/btouchard/ackify-ce/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/pkg/logger"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// DocumentRepository handles document metadata persistence
|
||||
@@ -20,14 +20,24 @@ func NewDocumentRepository(db *sql.DB) *DocumentRepository {
|
||||
return &DocumentRepository{db: db}
|
||||
}
|
||||
|
||||
// Create creates a new document metadata entry
|
||||
// Create persists a new document with metadata including optional checksum validation data
|
||||
func (r *DocumentRepository) Create(ctx context.Context, docID string, input models.DocumentInput, createdBy string) (*models.Document, error) {
|
||||
query := `
|
||||
INSERT INTO documents (doc_id, title, url, checksum, checksum_algorithm, description, created_by)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING doc_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by
|
||||
RETURNING doc_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by, deleted_at
|
||||
`
|
||||
|
||||
// Use NULL for empty checksum fields to avoid constraint violation
|
||||
var checksum, checksumAlgorithm interface{}
|
||||
if input.Checksum != "" {
|
||||
checksum = input.Checksum
|
||||
checksumAlgorithm = input.ChecksumAlgorithm
|
||||
} else {
|
||||
checksum = ""
|
||||
checksumAlgorithm = "SHA-256"
|
||||
}
|
||||
|
||||
doc := &models.Document{}
|
||||
err := r.db.QueryRowContext(
|
||||
ctx,
|
||||
@@ -35,8 +45,8 @@ func (r *DocumentRepository) Create(ctx context.Context, docID string, input mod
|
||||
docID,
|
||||
input.Title,
|
||||
input.URL,
|
||||
input.Checksum,
|
||||
input.ChecksumAlgorithm,
|
||||
checksum,
|
||||
checksumAlgorithm,
|
||||
input.Description,
|
||||
createdBy,
|
||||
).Scan(
|
||||
@@ -49,6 +59,7 @@ func (r *DocumentRepository) Create(ctx context.Context, docID string, input mod
|
||||
&doc.CreatedAt,
|
||||
&doc.UpdatedAt,
|
||||
&doc.CreatedBy,
|
||||
&doc.DeletedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
@@ -59,12 +70,12 @@ func (r *DocumentRepository) Create(ctx context.Context, docID string, input mod
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// GetByDocID retrieves document metadata by document ID
|
||||
// GetByDocID retrieves document metadata by document ID (excluding soft-deleted documents)
|
||||
func (r *DocumentRepository) GetByDocID(ctx context.Context, docID string) (*models.Document, error) {
|
||||
query := `
|
||||
SELECT doc_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by
|
||||
SELECT doc_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by, deleted_at
|
||||
FROM documents
|
||||
WHERE doc_id = $1
|
||||
WHERE doc_id = $1 AND deleted_at IS NULL
|
||||
`
|
||||
|
||||
doc := &models.Document{}
|
||||
@@ -78,6 +89,7 @@ func (r *DocumentRepository) GetByDocID(ctx context.Context, docID string) (*mod
|
||||
&doc.CreatedAt,
|
||||
&doc.UpdatedAt,
|
||||
&doc.CreatedBy,
|
||||
&doc.DeletedAt,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -92,15 +104,99 @@ func (r *DocumentRepository) GetByDocID(ctx context.Context, docID string) (*mod
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// Update updates document metadata
|
||||
// FindByReference searches for a document by reference (URL, path, or doc_id)
|
||||
func (r *DocumentRepository) FindByReference(ctx context.Context, ref string, refType string) (*models.Document, error) {
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
switch refType {
|
||||
case "url":
|
||||
// Search by URL field (excluding soft-deleted)
|
||||
query = `
|
||||
SELECT doc_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by, deleted_at
|
||||
FROM documents
|
||||
WHERE url = $1 AND deleted_at IS NULL
|
||||
LIMIT 1
|
||||
`
|
||||
args = []interface{}{ref}
|
||||
|
||||
case "path":
|
||||
// Search by URL field (paths are also stored in url field, excluding soft-deleted)
|
||||
query = `
|
||||
SELECT doc_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by, deleted_at
|
||||
FROM documents
|
||||
WHERE url = $1 AND deleted_at IS NULL
|
||||
LIMIT 1
|
||||
`
|
||||
args = []interface{}{ref}
|
||||
|
||||
case "reference":
|
||||
// Search by doc_id (excluding soft-deleted)
|
||||
query = `
|
||||
SELECT doc_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by, deleted_at
|
||||
FROM documents
|
||||
WHERE doc_id = $1 AND deleted_at IS NULL
|
||||
LIMIT 1
|
||||
`
|
||||
args = []interface{}{ref}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown reference type: %s", refType)
|
||||
}
|
||||
|
||||
doc := &models.Document{}
|
||||
err := r.db.QueryRowContext(ctx, query, args...).Scan(
|
||||
&doc.DocID,
|
||||
&doc.Title,
|
||||
&doc.URL,
|
||||
&doc.Checksum,
|
||||
&doc.ChecksumAlgorithm,
|
||||
&doc.Description,
|
||||
&doc.CreatedAt,
|
||||
&doc.UpdatedAt,
|
||||
&doc.CreatedBy,
|
||||
&doc.DeletedAt,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
logger.Logger.Debug("Document not found by reference",
|
||||
"reference", ref,
|
||||
"type", refType)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to find document by reference",
|
||||
"error", err.Error(),
|
||||
"reference", ref,
|
||||
"type", refType)
|
||||
return nil, fmt.Errorf("failed to find document: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Debug("Document found by reference",
|
||||
"doc_id", doc.DocID,
|
||||
"reference", ref,
|
||||
"type", refType)
|
||||
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// Update modifies existing document metadata while preserving creation timestamp and creator
|
||||
func (r *DocumentRepository) Update(ctx context.Context, docID string, input models.DocumentInput) (*models.Document, error) {
|
||||
query := `
|
||||
UPDATE documents
|
||||
SET title = $2, url = $3, checksum = $4, checksum_algorithm = $5, description = $6
|
||||
WHERE doc_id = $1
|
||||
RETURNING doc_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by
|
||||
WHERE doc_id = $1 AND deleted_at IS NULL
|
||||
RETURNING doc_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by, deleted_at
|
||||
`
|
||||
|
||||
// Use empty string for empty checksum fields (table has NOT NULL DEFAULT '')
|
||||
checksum := input.Checksum
|
||||
checksumAlgorithm := input.ChecksumAlgorithm
|
||||
if checksumAlgorithm == "" {
|
||||
checksumAlgorithm = "SHA-256" // Default algorithm
|
||||
}
|
||||
|
||||
doc := &models.Document{}
|
||||
err := r.db.QueryRowContext(
|
||||
ctx,
|
||||
@@ -108,8 +204,8 @@ func (r *DocumentRepository) Update(ctx context.Context, docID string, input mod
|
||||
docID,
|
||||
input.Title,
|
||||
input.URL,
|
||||
input.Checksum,
|
||||
input.ChecksumAlgorithm,
|
||||
checksum,
|
||||
checksumAlgorithm,
|
||||
input.Description,
|
||||
).Scan(
|
||||
&doc.DocID,
|
||||
@@ -121,6 +217,7 @@ func (r *DocumentRepository) Update(ctx context.Context, docID string, input mod
|
||||
&doc.CreatedAt,
|
||||
&doc.UpdatedAt,
|
||||
&doc.CreatedBy,
|
||||
&doc.DeletedAt,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -135,7 +232,7 @@ func (r *DocumentRepository) Update(ctx context.Context, docID string, input mod
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// CreateOrUpdate creates or updates document metadata
|
||||
// CreateOrUpdate performs upsert operation, creating new document or updating existing one atomically
|
||||
func (r *DocumentRepository) CreateOrUpdate(ctx context.Context, docID string, input models.DocumentInput, createdBy string) (*models.Document, error) {
|
||||
query := `
|
||||
INSERT INTO documents (doc_id, title, url, checksum, checksum_algorithm, description, created_by)
|
||||
@@ -145,10 +242,18 @@ func (r *DocumentRepository) CreateOrUpdate(ctx context.Context, docID string, i
|
||||
url = EXCLUDED.url,
|
||||
checksum = EXCLUDED.checksum,
|
||||
checksum_algorithm = EXCLUDED.checksum_algorithm,
|
||||
description = EXCLUDED.description
|
||||
RETURNING doc_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by
|
||||
description = EXCLUDED.description,
|
||||
deleted_at = NULL
|
||||
RETURNING doc_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by, deleted_at
|
||||
`
|
||||
|
||||
// Use empty string for empty checksum fields (table has NOT NULL DEFAULT '')
|
||||
checksum := input.Checksum
|
||||
checksumAlgorithm := input.ChecksumAlgorithm
|
||||
if checksumAlgorithm == "" {
|
||||
checksumAlgorithm = "SHA-256" // Default algorithm
|
||||
}
|
||||
|
||||
doc := &models.Document{}
|
||||
err := r.db.QueryRowContext(
|
||||
ctx,
|
||||
@@ -156,8 +261,8 @@ func (r *DocumentRepository) CreateOrUpdate(ctx context.Context, docID string, i
|
||||
docID,
|
||||
input.Title,
|
||||
input.URL,
|
||||
input.Checksum,
|
||||
input.ChecksumAlgorithm,
|
||||
checksum,
|
||||
checksumAlgorithm,
|
||||
input.Description,
|
||||
createdBy,
|
||||
).Scan(
|
||||
@@ -170,6 +275,7 @@ func (r *DocumentRepository) CreateOrUpdate(ctx context.Context, docID string, i
|
||||
&doc.CreatedAt,
|
||||
&doc.UpdatedAt,
|
||||
&doc.CreatedBy,
|
||||
&doc.DeletedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
@@ -180,9 +286,9 @@ func (r *DocumentRepository) CreateOrUpdate(ctx context.Context, docID string, i
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// Delete deletes document metadata
|
||||
// Delete soft-deletes document by setting deleted_at timestamp, preserving metadata and signature history
|
||||
func (r *DocumentRepository) Delete(ctx context.Context, docID string) error {
|
||||
query := `DELETE FROM documents WHERE doc_id = $1`
|
||||
query := `UPDATE documents SET deleted_at = now() WHERE doc_id = $1 AND deleted_at IS NULL`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, docID)
|
||||
if err != nil {
|
||||
@@ -196,17 +302,18 @@ func (r *DocumentRepository) Delete(ctx context.Context, docID string) error {
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("document not found")
|
||||
return fmt.Errorf("document not found or already deleted")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// List retrieves all documents with pagination
|
||||
// List retrieves paginated documents ordered by creation date, newest first (excluding soft-deleted)
|
||||
func (r *DocumentRepository) List(ctx context.Context, limit, offset int) ([]*models.Document, error) {
|
||||
query := `
|
||||
SELECT doc_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by
|
||||
SELECT doc_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by, deleted_at
|
||||
FROM documents
|
||||
WHERE deleted_at IS NULL
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $1 OFFSET $2
|
||||
`
|
||||
@@ -231,6 +338,7 @@ func (r *DocumentRepository) List(ctx context.Context, limit, offset int) ([]*mo
|
||||
&doc.CreatedAt,
|
||||
&doc.UpdatedAt,
|
||||
&doc.CreatedBy,
|
||||
&doc.DeletedAt,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to scan document row", "error", err.Error())
|
||||
@@ -7,48 +7,11 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/btouchard/ackify-ce/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
)
|
||||
|
||||
func setupDocumentsTable(t *testing.T, testDB *TestDB) {
|
||||
t.Helper()
|
||||
|
||||
schema := `
|
||||
DROP TABLE IF EXISTS documents;
|
||||
|
||||
CREATE TABLE documents (
|
||||
doc_id TEXT PRIMARY KEY,
|
||||
title TEXT NOT NULL DEFAULT '',
|
||||
url TEXT NOT NULL DEFAULT '',
|
||||
checksum TEXT NOT NULL DEFAULT '',
|
||||
checksum_algorithm TEXT NOT NULL DEFAULT 'SHA-256' CHECK (checksum_algorithm IN ('SHA-256', 'SHA-512', 'MD5')),
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
created_by TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
|
||||
CREATE INDEX idx_documents_created_at ON documents(created_at DESC);
|
||||
|
||||
CREATE OR REPLACE FUNCTION update_documents_updated_at()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
NEW.updated_at = now();
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TRIGGER trigger_update_documents_updated_at
|
||||
BEFORE UPDATE ON documents
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_documents_updated_at();
|
||||
`
|
||||
|
||||
_, err := testDB.DB.Exec(schema)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup documents table: %v", err)
|
||||
}
|
||||
}
|
||||
// setupDocumentsTable is no longer needed - migrations handle schema creation
|
||||
// Removed to use real migrations from testutils.go
|
||||
|
||||
func clearDocumentsTable(t *testing.T, testDB *TestDB) {
|
||||
t.Helper()
|
||||
@@ -60,7 +23,6 @@ func clearDocumentsTable(t *testing.T, testDB *TestDB) {
|
||||
|
||||
func TestDocumentRepository_Create(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
setupDocumentsTable(t, testDB)
|
||||
|
||||
ctx := context.Background()
|
||||
repo := NewDocumentRepository(testDB.DB)
|
||||
@@ -117,7 +79,6 @@ func TestDocumentRepository_Create(t *testing.T) {
|
||||
|
||||
func TestDocumentRepository_GetByDocID(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
setupDocumentsTable(t, testDB)
|
||||
|
||||
ctx := context.Background()
|
||||
repo := NewDocumentRepository(testDB.DB)
|
||||
@@ -165,7 +126,6 @@ func TestDocumentRepository_GetByDocID(t *testing.T) {
|
||||
|
||||
func TestDocumentRepository_Update(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
setupDocumentsTable(t, testDB)
|
||||
|
||||
ctx := context.Background()
|
||||
repo := NewDocumentRepository(testDB.DB)
|
||||
@@ -229,7 +189,6 @@ func TestDocumentRepository_Update(t *testing.T) {
|
||||
|
||||
func TestDocumentRepository_CreateOrUpdate(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
setupDocumentsTable(t, testDB)
|
||||
|
||||
ctx := context.Background()
|
||||
repo := NewDocumentRepository(testDB.DB)
|
||||
@@ -287,7 +246,6 @@ func TestDocumentRepository_CreateOrUpdate(t *testing.T) {
|
||||
|
||||
func TestDocumentRepository_Delete(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
setupDocumentsTable(t, testDB)
|
||||
|
||||
ctx := context.Background()
|
||||
repo := NewDocumentRepository(testDB.DB)
|
||||
@@ -329,7 +287,6 @@ func TestDocumentRepository_Delete(t *testing.T) {
|
||||
|
||||
func TestDocumentRepository_List(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
setupDocumentsTable(t, testDB)
|
||||
|
||||
ctx := context.Background()
|
||||
repo := NewDocumentRepository(testDB.DB)
|
||||
@@ -386,3 +343,51 @@ func TestDocumentRepository_List(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDocumentRepository_FindByReference_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
|
||||
ctx := context.Background()
|
||||
repo := NewDocumentRepository(testDB.DB)
|
||||
|
||||
// Create a document first
|
||||
input := models.DocumentInput{
|
||||
Title: "Test Doc",
|
||||
URL: "https://example.com/doc.pdf",
|
||||
}
|
||||
|
||||
created, err := repo.Create(ctx, "test-doc-123", input, "admin@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create document: %v", err)
|
||||
}
|
||||
|
||||
// Test finding by URL reference
|
||||
found, err := repo.FindByReference(ctx, created.URL, "url")
|
||||
if err != nil {
|
||||
t.Fatalf("FindByReference failed: %v", err)
|
||||
}
|
||||
|
||||
if found == nil {
|
||||
t.Fatal("Expected to find document, got nil")
|
||||
}
|
||||
|
||||
if found.DocID != created.DocID {
|
||||
t.Errorf("Expected DocID %s, got %s", created.DocID, found.DocID)
|
||||
}
|
||||
|
||||
// Test finding by reference type (doc_id)
|
||||
foundByRef, err := repo.FindByReference(ctx, "test-doc-123", "reference")
|
||||
if err != nil {
|
||||
t.Fatalf("FindByReference by ref failed: %v", err)
|
||||
}
|
||||
|
||||
if foundByRef == nil {
|
||||
t.Fatal("Expected to find document by reference, got nil")
|
||||
}
|
||||
|
||||
// Test not found case
|
||||
notFound, err := repo.FindByReference(ctx, "non-existent-url", "url")
|
||||
if err == nil && notFound == nil {
|
||||
// This is expected - not found
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,485 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// EmailQueueRepository handles database operations for the email queue
|
||||
type EmailQueueRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewEmailQueueRepository creates a new email queue repository
|
||||
func NewEmailQueueRepository(db *sql.DB) *EmailQueueRepository {
|
||||
return &EmailQueueRepository{db: db}
|
||||
}
|
||||
|
||||
// Enqueue adds a new email to the queue
|
||||
func (r *EmailQueueRepository) Enqueue(ctx context.Context, input models.EmailQueueInput) (*models.EmailQueueItem, error) {
|
||||
// Prepare data as JSON
|
||||
dataJSON, err := json.Marshal(input.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal email data: %w", err)
|
||||
}
|
||||
|
||||
var headersJSON []byte
|
||||
if input.Headers != nil {
|
||||
headersJSON, err = json.Marshal(input.Headers)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal email headers: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Use empty JSON object instead of nil for PostgreSQL JSONB compatibility
|
||||
headersJSON = []byte("{}")
|
||||
}
|
||||
|
||||
// Default values
|
||||
maxRetries := input.MaxRetries
|
||||
if maxRetries == 0 {
|
||||
maxRetries = 3
|
||||
}
|
||||
|
||||
scheduledFor := time.Now()
|
||||
if input.ScheduledFor != nil {
|
||||
scheduledFor = *input.ScheduledFor
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO email_queue (
|
||||
to_addresses, cc_addresses, bcc_addresses,
|
||||
subject, template, locale, data, headers,
|
||||
priority, scheduled_for, max_retries,
|
||||
reference_type, reference_id, created_by
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14
|
||||
) RETURNING
|
||||
id, status, retry_count, created_at, processed_at,
|
||||
next_retry_at, last_error, error_details
|
||||
`
|
||||
|
||||
item := &models.EmailQueueItem{
|
||||
ToAddresses: input.ToAddresses,
|
||||
CcAddresses: input.CcAddresses,
|
||||
BccAddresses: input.BccAddresses,
|
||||
Subject: input.Subject,
|
||||
Template: input.Template,
|
||||
Locale: input.Locale,
|
||||
Data: dataJSON,
|
||||
Headers: models.NullRawMessage{RawMessage: headersJSON, Valid: input.Headers != nil},
|
||||
Priority: input.Priority,
|
||||
ScheduledFor: scheduledFor,
|
||||
MaxRetries: maxRetries,
|
||||
ReferenceType: input.ReferenceType,
|
||||
ReferenceID: input.ReferenceID,
|
||||
CreatedBy: input.CreatedBy,
|
||||
}
|
||||
|
||||
err = r.db.QueryRowContext(
|
||||
ctx,
|
||||
query,
|
||||
pq.Array(input.ToAddresses),
|
||||
pq.Array(input.CcAddresses),
|
||||
pq.Array(input.BccAddresses),
|
||||
input.Subject,
|
||||
input.Template,
|
||||
input.Locale,
|
||||
dataJSON,
|
||||
headersJSON,
|
||||
input.Priority,
|
||||
scheduledFor,
|
||||
maxRetries,
|
||||
input.ReferenceType,
|
||||
input.ReferenceID,
|
||||
input.CreatedBy,
|
||||
).Scan(
|
||||
&item.ID,
|
||||
&item.Status,
|
||||
&item.RetryCount,
|
||||
&item.CreatedAt,
|
||||
&item.ProcessedAt,
|
||||
&item.NextRetryAt,
|
||||
&item.LastError,
|
||||
&item.ErrorDetails,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to enqueue email",
|
||||
"error", err.Error(),
|
||||
"template", input.Template)
|
||||
return nil, fmt.Errorf("failed to enqueue email: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Info("Email enqueued successfully",
|
||||
"id", item.ID,
|
||||
"template", input.Template,
|
||||
"priority", input.Priority)
|
||||
|
||||
return item, nil
|
||||
}
|
||||
|
||||
// GetNextToProcess fetches the next email(s) to process from the queue
|
||||
func (r *EmailQueueRepository) GetNextToProcess(ctx context.Context, limit int) ([]*models.EmailQueueItem, error) {
|
||||
query := `
|
||||
UPDATE email_queue
|
||||
SET status = 'processing'
|
||||
WHERE id IN (
|
||||
SELECT id FROM email_queue
|
||||
WHERE status = 'pending'
|
||||
AND scheduled_for <= $1
|
||||
ORDER BY priority DESC, scheduled_for ASC
|
||||
LIMIT $2
|
||||
FOR UPDATE SKIP LOCKED
|
||||
)
|
||||
RETURNING
|
||||
id, to_addresses, cc_addresses, bcc_addresses,
|
||||
subject, template, locale, data, headers,
|
||||
status, priority, retry_count, max_retries,
|
||||
created_at, scheduled_for, processed_at, next_retry_at,
|
||||
last_error, error_details, reference_type, reference_id, created_by
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, time.Now(), limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get next emails to process: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []*models.EmailQueueItem
|
||||
for rows.Next() {
|
||||
item := &models.EmailQueueItem{}
|
||||
err := rows.Scan(
|
||||
&item.ID,
|
||||
pq.Array(&item.ToAddresses),
|
||||
pq.Array(&item.CcAddresses),
|
||||
pq.Array(&item.BccAddresses),
|
||||
&item.Subject,
|
||||
&item.Template,
|
||||
&item.Locale,
|
||||
&item.Data,
|
||||
&item.Headers,
|
||||
&item.Status,
|
||||
&item.Priority,
|
||||
&item.RetryCount,
|
||||
&item.MaxRetries,
|
||||
&item.CreatedAt,
|
||||
&item.ScheduledFor,
|
||||
&item.ProcessedAt,
|
||||
&item.NextRetryAt,
|
||||
&item.LastError,
|
||||
&item.ErrorDetails,
|
||||
&item.ReferenceType,
|
||||
&item.ReferenceID,
|
||||
&item.CreatedBy,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan email queue item: %w", err)
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// MarkAsSent marks an email as successfully sent
|
||||
func (r *EmailQueueRepository) MarkAsSent(ctx context.Context, id int64) error {
|
||||
query := `
|
||||
UPDATE email_queue
|
||||
SET status = 'sent',
|
||||
processed_at = $1
|
||||
WHERE id = $2
|
||||
`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, time.Now(), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mark email as sent: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("email not found: %d", id)
|
||||
}
|
||||
|
||||
logger.Logger.Debug("Email marked as sent", "id", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkAsFailed marks an email as failed with error details
|
||||
func (r *EmailQueueRepository) MarkAsFailed(ctx context.Context, id int64, err error, shouldRetry bool) error {
|
||||
errorMsg := err.Error()
|
||||
|
||||
errorDetails := map[string]interface{}{
|
||||
"error": errorMsg,
|
||||
"timestamp": time.Now().Format(time.RFC3339),
|
||||
"should_retry": shouldRetry,
|
||||
}
|
||||
|
||||
errorDetailsJSON, _ := json.Marshal(errorDetails)
|
||||
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
if shouldRetry {
|
||||
// If retrying, increment retry count and calculate next retry time
|
||||
query = `
|
||||
UPDATE email_queue
|
||||
SET status = 'pending',
|
||||
retry_count = retry_count + 1,
|
||||
last_error = $1,
|
||||
error_details = $2,
|
||||
scheduled_for = calculate_next_retry_time(retry_count + 1)
|
||||
WHERE id = $3 AND retry_count < max_retries
|
||||
`
|
||||
args = []interface{}{errorMsg, errorDetailsJSON, id}
|
||||
} else {
|
||||
// If not retrying, mark as failed
|
||||
query = `
|
||||
UPDATE email_queue
|
||||
SET status = 'failed',
|
||||
processed_at = $1,
|
||||
last_error = $2,
|
||||
error_details = $3
|
||||
WHERE id = $4
|
||||
`
|
||||
args = []interface{}{time.Now(), errorMsg, errorDetailsJSON, id}
|
||||
}
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mark email as failed: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 && shouldRetry {
|
||||
// Max retries reached, mark as permanently failed
|
||||
query = `
|
||||
UPDATE email_queue
|
||||
SET status = 'failed',
|
||||
processed_at = $1,
|
||||
last_error = $2,
|
||||
error_details = $3
|
||||
WHERE id = $4
|
||||
`
|
||||
_, err = r.db.ExecContext(ctx, query, time.Now(), errorMsg, errorDetailsJSON, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mark email as permanently failed: %w", err)
|
||||
}
|
||||
logger.Logger.Warn("Email max retries reached, marked as failed", "id", id)
|
||||
}
|
||||
|
||||
logger.Logger.Debug("Email marked as failed", "id", id, "should_retry", shouldRetry)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRetryableEmails fetches emails that should be retried
|
||||
func (r *EmailQueueRepository) GetRetryableEmails(ctx context.Context, limit int) ([]*models.EmailQueueItem, error) {
|
||||
query := `
|
||||
UPDATE email_queue
|
||||
SET status = 'processing'
|
||||
WHERE id IN (
|
||||
SELECT id FROM email_queue
|
||||
WHERE status = 'pending'
|
||||
AND retry_count > 0
|
||||
AND retry_count < max_retries
|
||||
AND scheduled_for <= $1
|
||||
ORDER BY priority DESC, scheduled_for ASC
|
||||
LIMIT $2
|
||||
FOR UPDATE SKIP LOCKED
|
||||
)
|
||||
RETURNING
|
||||
id, to_addresses, cc_addresses, bcc_addresses,
|
||||
subject, template, locale, data, headers,
|
||||
status, priority, retry_count, max_retries,
|
||||
created_at, scheduled_for, processed_at, next_retry_at,
|
||||
last_error, error_details, reference_type, reference_id, created_by
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, time.Now(), limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get retryable emails: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []*models.EmailQueueItem
|
||||
for rows.Next() {
|
||||
item := &models.EmailQueueItem{}
|
||||
err := rows.Scan(
|
||||
&item.ID,
|
||||
pq.Array(&item.ToAddresses),
|
||||
pq.Array(&item.CcAddresses),
|
||||
pq.Array(&item.BccAddresses),
|
||||
&item.Subject,
|
||||
&item.Template,
|
||||
&item.Locale,
|
||||
&item.Data,
|
||||
&item.Headers,
|
||||
&item.Status,
|
||||
&item.Priority,
|
||||
&item.RetryCount,
|
||||
&item.MaxRetries,
|
||||
&item.CreatedAt,
|
||||
&item.ScheduledFor,
|
||||
&item.ProcessedAt,
|
||||
&item.NextRetryAt,
|
||||
&item.LastError,
|
||||
&item.ErrorDetails,
|
||||
&item.ReferenceType,
|
||||
&item.ReferenceID,
|
||||
&item.CreatedBy,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan email queue item: %w", err)
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// GetQueueStats returns statistics about the email queue
|
||||
func (r *EmailQueueRepository) GetQueueStats(ctx context.Context) (*models.EmailQueueStats, error) {
|
||||
stats := &models.EmailQueueStats{
|
||||
ByStatus: make(map[string]int),
|
||||
ByPriority: make(map[string]int),
|
||||
}
|
||||
|
||||
// Get counts by status
|
||||
statusQuery := `
|
||||
SELECT status, COUNT(*)
|
||||
FROM email_queue
|
||||
GROUP BY status
|
||||
`
|
||||
rows, err := r.db.QueryContext(ctx, statusQuery)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get status counts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var status string
|
||||
var count int
|
||||
if err := rows.Scan(&status, &count); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan status count: %w", err)
|
||||
}
|
||||
stats.ByStatus[status] = count
|
||||
|
||||
switch models.EmailQueueStatus(status) {
|
||||
case models.EmailStatusPending:
|
||||
stats.TotalPending = count
|
||||
case models.EmailStatusProcessing:
|
||||
stats.TotalProcessing = count
|
||||
case models.EmailStatusSent:
|
||||
stats.TotalSent = count
|
||||
case models.EmailStatusFailed:
|
||||
stats.TotalFailed = count
|
||||
}
|
||||
}
|
||||
|
||||
// Get oldest pending email
|
||||
var oldestPending sql.NullTime
|
||||
err = r.db.QueryRowContext(ctx, `
|
||||
SELECT MIN(created_at)
|
||||
FROM email_queue
|
||||
WHERE status = 'pending'
|
||||
`).Scan(&oldestPending)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("failed to get oldest pending: %w", err)
|
||||
}
|
||||
if oldestPending.Valid {
|
||||
stats.OldestPending = &oldestPending.Time
|
||||
}
|
||||
|
||||
// Get average retry count
|
||||
err = r.db.QueryRowContext(ctx, `
|
||||
SELECT AVG(retry_count)::float
|
||||
FROM email_queue
|
||||
WHERE status IN ('sent', 'failed')
|
||||
`).Scan(&stats.AverageRetries)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("failed to get average retries: %w", err)
|
||||
}
|
||||
|
||||
// Get last 24 hours stats
|
||||
err = r.db.QueryRowContext(ctx, `
|
||||
SELECT
|
||||
COUNT(*) FILTER (WHERE status = 'sent' AND processed_at >= NOW() - INTERVAL '24 hours') as sent,
|
||||
COUNT(*) FILTER (WHERE status = 'failed' AND processed_at >= NOW() - INTERVAL '24 hours') as failed,
|
||||
COUNT(*) FILTER (WHERE created_at >= NOW() - INTERVAL '24 hours') as queued
|
||||
FROM email_queue
|
||||
`).Scan(&stats.Last24Hours.Sent, &stats.Last24Hours.Failed, &stats.Last24Hours.Queued)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get 24h stats: %w", err)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// CancelEmail cancels a pending email
|
||||
func (r *EmailQueueRepository) CancelEmail(ctx context.Context, id int64) error {
|
||||
query := `
|
||||
UPDATE email_queue
|
||||
SET status = 'cancelled',
|
||||
processed_at = $1
|
||||
WHERE id = $2 AND status IN ('pending', 'processing')
|
||||
`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, time.Now(), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to cancel email: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("email not found or already processed: %d", id)
|
||||
}
|
||||
|
||||
logger.Logger.Info("Email cancelled", "id", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupOldEmails removes old processed emails from the queue
|
||||
func (r *EmailQueueRepository) CleanupOldEmails(ctx context.Context, olderThan time.Duration) (int64, error) {
|
||||
query := `
|
||||
DELETE FROM email_queue
|
||||
WHERE status IN ('sent', 'failed', 'cancelled')
|
||||
AND processed_at < $1
|
||||
`
|
||||
|
||||
cutoff := time.Now().Add(-olderThan)
|
||||
result, err := r.db.ExecContext(ctx, query, cutoff)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to cleanup old emails: %w", err)
|
||||
}
|
||||
|
||||
deleted, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get deleted count: %w", err)
|
||||
}
|
||||
|
||||
if deleted > 0 {
|
||||
logger.Logger.Info("Old emails cleaned up", "count", deleted, "older_than", olderThan)
|
||||
}
|
||||
|
||||
return deleted, nil
|
||||
}
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/btouchard/ackify-ce/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/pkg/logger"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// ExpectedSignerRepository handles database operations for expected signers
|
||||
@@ -21,7 +21,7 @@ func NewExpectedSignerRepository(db *sql.DB) *ExpectedSignerRepository {
|
||||
return &ExpectedSignerRepository{db: db}
|
||||
}
|
||||
|
||||
// AddExpected adds multiple expected signers for a document (batch insert with conflict handling)
|
||||
// AddExpected batch-inserts multiple expected signers with conflict-safe deduplication on (doc_id, email)
|
||||
func (r *ExpectedSignerRepository) AddExpected(ctx context.Context, docID string, contacts []models.ContactInfo, addedBy string) error {
|
||||
if len(contacts) == 0 {
|
||||
return nil
|
||||
@@ -50,7 +50,7 @@ func (r *ExpectedSignerRepository) AddExpected(ctx context.Context, docID string
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListByDocID returns all expected signers for a document
|
||||
// ListByDocID retrieves all expected signers for a document, ordered chronologically by when they were added
|
||||
func (r *ExpectedSignerRepository) ListByDocID(ctx context.Context, docID string) ([]*models.ExpectedSigner, error) {
|
||||
query := `
|
||||
SELECT id, doc_id, email, name, added_at, added_by, notes
|
||||
@@ -91,7 +91,7 @@ func (r *ExpectedSignerRepository) ListByDocID(ctx context.Context, docID string
|
||||
return signers, nil
|
||||
}
|
||||
|
||||
// ListWithStatusByDocID returns expected signers with their signature status
|
||||
// ListWithStatusByDocID enriches signer data with signature completion status and reminder tracking metrics
|
||||
func (r *ExpectedSignerRepository) ListWithStatusByDocID(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
|
||||
query := `
|
||||
SELECT
|
||||
@@ -169,7 +169,7 @@ func (r *ExpectedSignerRepository) ListWithStatusByDocID(ctx context.Context, do
|
||||
return signers, nil
|
||||
}
|
||||
|
||||
// Remove removes an expected signer from a document
|
||||
// Remove deletes a specific expected signer by document ID and email address
|
||||
func (r *ExpectedSignerRepository) Remove(ctx context.Context, docID, email string) error {
|
||||
query := `
|
||||
DELETE FROM expected_signers
|
||||
@@ -193,7 +193,7 @@ func (r *ExpectedSignerRepository) Remove(ctx context.Context, docID, email stri
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveAllForDoc removes all expected signers for a document
|
||||
// RemoveAllForDoc purges all expected signers associated with a document in a single operation
|
||||
func (r *ExpectedSignerRepository) RemoveAllForDoc(ctx context.Context, docID string) error {
|
||||
query := `
|
||||
DELETE FROM expected_signers
|
||||
@@ -208,7 +208,7 @@ func (r *ExpectedSignerRepository) RemoveAllForDoc(ctx context.Context, docID st
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsExpected checks if an email is expected for a document
|
||||
// IsExpected efficiently verifies if an email address is in the expected signer list for a document
|
||||
func (r *ExpectedSignerRepository) IsExpected(ctx context.Context, docID, email string) (bool, error) {
|
||||
query := `
|
||||
SELECT EXISTS(
|
||||
@@ -226,7 +226,7 @@ func (r *ExpectedSignerRepository) IsExpected(ctx context.Context, docID, email
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// GetStats returns completion statistics for a document
|
||||
// GetStats calculates signature completion metrics including percentage progress for a document
|
||||
func (r *ExpectedSignerRepository) GetStats(ctx context.Context, docID string) (*models.DocCompletionStats, error) {
|
||||
query := `
|
||||
SELECT
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/btouchard/ackify-ce/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
)
|
||||
|
||||
func TestExpectedSignerRepository_AddExpected(t *testing.T) {
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/btouchard/ackify-ce/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/pkg/logger"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// ReminderRepository handles database operations for reminder logs
|
||||
@@ -20,7 +20,7 @@ func NewReminderRepository(db *sql.DB) *ReminderRepository {
|
||||
return &ReminderRepository{db: db}
|
||||
}
|
||||
|
||||
// LogReminder logs a reminder email attempt
|
||||
// LogReminder records an email reminder event with delivery status for audit tracking
|
||||
func (r *ReminderRepository) LogReminder(ctx context.Context, log *models.ReminderLog) error {
|
||||
query := `
|
||||
INSERT INTO reminder_logs
|
||||
@@ -46,7 +46,7 @@ func (r *ReminderRepository) LogReminder(ctx context.Context, log *models.Remind
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetReminderHistory returns all reminder logs for a document
|
||||
// GetReminderHistory retrieves complete reminder audit trail for a document, ordered by send time descending
|
||||
func (r *ReminderRepository) GetReminderHistory(ctx context.Context, docID string) ([]*models.ReminderLog, error) {
|
||||
query := `
|
||||
SELECT id, doc_id, recipient_email, sent_at, sent_by, template_used, status, error_message
|
||||
@@ -88,7 +88,7 @@ func (r *ReminderRepository) GetReminderHistory(ctx context.Context, docID strin
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
// GetLastReminderByEmail returns the last reminder sent to an email for a document
|
||||
// GetLastReminderByEmail retrieves the most recent reminder sent to a specific recipient for throttling logic
|
||||
func (r *ReminderRepository) GetLastReminderByEmail(ctx context.Context, docID, email string) (*models.ReminderLog, error) {
|
||||
query := `
|
||||
SELECT id, doc_id, recipient_email, sent_at, sent_by, template_used, status, error_message
|
||||
@@ -121,7 +121,7 @@ func (r *ReminderRepository) GetLastReminderByEmail(ctx context.Context, docID,
|
||||
return log, nil
|
||||
}
|
||||
|
||||
// GetReminderCount returns the count of successfully sent reminders for an email
|
||||
// GetReminderCount tallies successfully delivered reminders to a recipient for rate limiting
|
||||
func (r *ReminderRepository) GetReminderCount(ctx context.Context, docID, email string) (int, error) {
|
||||
query := `
|
||||
SELECT COUNT(*)
|
||||
@@ -138,7 +138,7 @@ func (r *ReminderRepository) GetReminderCount(ctx context.Context, docID, email
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetReminderStats returns reminder statistics for a document
|
||||
// GetReminderStats aggregates reminder metrics including pending signers and last send timestamp
|
||||
func (r *ReminderRepository) GetReminderStats(ctx context.Context, docID string) (*models.ReminderStats, error) {
|
||||
query := `
|
||||
SELECT
|
||||
@@ -0,0 +1,92 @@
|
||||
//go:build integration
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
)
|
||||
|
||||
func TestReminderRepository_Basic_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
|
||||
// We need documents and expected_signers tables
|
||||
_, err := testDB.DB.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS documents (
|
||||
doc_id TEXT PRIMARY KEY,
|
||||
title TEXT NOT NULL DEFAULT '',
|
||||
url TEXT NOT NULL DEFAULT '',
|
||||
checksum TEXT NOT NULL DEFAULT '',
|
||||
checksum_algorithm TEXT NOT NULL DEFAULT 'SHA-256',
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
created_by TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS expected_signers (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
doc_id TEXT NOT NULL,
|
||||
email TEXT NOT NULL,
|
||||
name TEXT NOT NULL DEFAULT '',
|
||||
added_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
added_by TEXT NOT NULL,
|
||||
notes TEXT,
|
||||
UNIQUE (doc_id, email),
|
||||
FOREIGN KEY (doc_id) REFERENCES documents(doc_id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS reminder_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
doc_id TEXT NOT NULL,
|
||||
recipient_email TEXT NOT NULL,
|
||||
sent_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
sent_by TEXT NOT NULL,
|
||||
template_used TEXT NOT NULL,
|
||||
status TEXT NOT NULL CHECK (status IN ('sent', 'failed', 'bounced')),
|
||||
error_message TEXT,
|
||||
FOREIGN KEY (doc_id, recipient_email) REFERENCES expected_signers(doc_id, email) ON DELETE CASCADE
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create tables: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
repo := NewReminderRepository(testDB.DB)
|
||||
|
||||
// Create a document and expected signer
|
||||
_, err = testDB.DB.Exec(`
|
||||
INSERT INTO documents (doc_id, title, created_by) VALUES ('doc1', 'Test', 'admin@test.com');
|
||||
INSERT INTO expected_signers (doc_id, email, added_by) VALUES ('doc1', 'user@test.com', 'admin@test.com');
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert test data: %v", err)
|
||||
}
|
||||
|
||||
// Test logging a reminder
|
||||
log := &models.ReminderLog{
|
||||
DocID: "doc1",
|
||||
RecipientEmail: "user@test.com",
|
||||
SentBy: "admin@test.com",
|
||||
TemplateUsed: "test_template",
|
||||
Status: "sent",
|
||||
}
|
||||
|
||||
err = repo.LogReminder(ctx, log)
|
||||
if err != nil {
|
||||
t.Fatalf("LogReminder failed: %v", err)
|
||||
}
|
||||
|
||||
// Test getting reminder history
|
||||
history, err := repo.GetReminderHistory(ctx, "doc1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetReminderHistory failed: %v", err)
|
||||
}
|
||||
|
||||
if len(history) != 1 {
|
||||
t.Errorf("Expected 1 reminder in history, got %d", len(history))
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
)
|
||||
|
||||
func TestRepository_Concurrency_Integration(t *testing.T) {
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
)
|
||||
|
||||
func TestRepository_DatabaseConstraints_Integration(t *testing.T) {
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
)
|
||||
|
||||
func TestRepository_Create_Integration(t *testing.T) {
|
||||
@@ -7,22 +7,28 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/btouchard/ackify-ce/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
)
|
||||
|
||||
// SignatureRepository handles PostgreSQL persistence for cryptographic signatures
|
||||
type SignatureRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewSignatureRepository initializes a signature repository with the given database connection
|
||||
func NewSignatureRepository(db *sql.DB) *SignatureRepository {
|
||||
return &SignatureRepository{db: db}
|
||||
}
|
||||
|
||||
// scanSignature scans a row into a Signature, handling NULL values properly
|
||||
func scanSignature(scanner interface {
|
||||
Scan(dest ...interface{}) error
|
||||
}, signature *models.Signature) error {
|
||||
var userName sql.NullString
|
||||
var docChecksum sql.NullString
|
||||
var hashVersion sql.NullInt64
|
||||
var docDeletedAt sql.NullTime
|
||||
var docTitle sql.NullString
|
||||
var docURL sql.NullString
|
||||
err := scanner.Scan(
|
||||
&signature.ID,
|
||||
&signature.DocID,
|
||||
@@ -30,38 +36,66 @@ func scanSignature(scanner interface {
|
||||
&signature.UserEmail,
|
||||
&userName,
|
||||
&signature.SignedAtUTC,
|
||||
&docChecksum,
|
||||
&signature.PayloadHash,
|
||||
&signature.Signature,
|
||||
&signature.Nonce,
|
||||
&signature.CreatedAt,
|
||||
&signature.Referer,
|
||||
&signature.PrevHash,
|
||||
&hashVersion,
|
||||
&docDeletedAt,
|
||||
&docTitle,
|
||||
&docURL,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Convert sql.NullString to string (empty string if NULL)
|
||||
if userName.Valid {
|
||||
signature.UserName = userName.String
|
||||
} else {
|
||||
signature.UserName = ""
|
||||
}
|
||||
if docChecksum.Valid {
|
||||
signature.DocChecksum = docChecksum.String
|
||||
} else {
|
||||
signature.DocChecksum = ""
|
||||
}
|
||||
if hashVersion.Valid {
|
||||
signature.HashVersion = int(hashVersion.Int64)
|
||||
} else {
|
||||
signature.HashVersion = 1 // Default to version 1
|
||||
}
|
||||
if docDeletedAt.Valid {
|
||||
signature.DocDeletedAt = &docDeletedAt.Time
|
||||
}
|
||||
if docTitle.Valid {
|
||||
signature.DocTitle = docTitle.String
|
||||
}
|
||||
if docURL.Valid {
|
||||
signature.DocURL = docURL.String
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create persists a new signature record to PostgreSQL with UNIQUE constraint enforcement on (doc_id, user_sub)
|
||||
func (r *SignatureRepository) Create(ctx context.Context, signature *models.Signature) error {
|
||||
query := `
|
||||
INSERT INTO signatures (doc_id, user_sub, user_email, user_name, signed_at, payload_hash, signature, nonce, referer, prev_hash)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
INSERT INTO signatures (doc_id, user_sub, user_email, user_name, signed_at, doc_checksum, payload_hash, signature, nonce, referer, prev_hash)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
RETURNING id, created_at
|
||||
`
|
||||
|
||||
// Convert empty string to NULL for user_name
|
||||
var userName sql.NullString
|
||||
if signature.UserName != "" {
|
||||
userName = sql.NullString{String: signature.UserName, Valid: true}
|
||||
}
|
||||
|
||||
var docChecksum sql.NullString
|
||||
if signature.DocChecksum != "" {
|
||||
docChecksum = sql.NullString{String: signature.DocChecksum, Valid: true}
|
||||
}
|
||||
|
||||
err := r.db.QueryRowContext(
|
||||
ctx, query,
|
||||
signature.DocID,
|
||||
@@ -69,6 +103,7 @@ func (r *SignatureRepository) Create(ctx context.Context, signature *models.Sign
|
||||
signature.UserEmail,
|
||||
userName,
|
||||
signature.SignedAtUTC,
|
||||
docChecksum,
|
||||
signature.PayloadHash,
|
||||
signature.Signature,
|
||||
signature.Nonce,
|
||||
@@ -83,11 +118,15 @@ func (r *SignatureRepository) Create(ctx context.Context, signature *models.Sign
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByDocAndUser retrieves a specific signature by document ID and user OAuth subject identifier
|
||||
func (r *SignatureRepository) GetByDocAndUser(ctx context.Context, docID, userSub string) (*models.Signature, error) {
|
||||
query := `
|
||||
SELECT id, doc_id, user_sub, user_email, user_name, signed_at, payload_hash, signature, nonce, created_at, referer, prev_hash
|
||||
FROM signatures
|
||||
WHERE doc_id = $1 AND user_sub = $2
|
||||
SELECT s.id, s.doc_id, s.user_sub, s.user_email, s.user_name, s.signed_at, s.doc_checksum,
|
||||
s.payload_hash, s.signature, s.nonce, s.created_at, s.referer, s.prev_hash,
|
||||
s.hash_version, s.doc_deleted_at, d.title, d.url
|
||||
FROM signatures s
|
||||
LEFT JOIN documents d ON s.doc_id = d.doc_id
|
||||
WHERE s.doc_id = $1 AND s.user_sub = $2
|
||||
`
|
||||
|
||||
signature := &models.Signature{}
|
||||
@@ -103,12 +142,16 @@ func (r *SignatureRepository) GetByDocAndUser(ctx context.Context, docID, userSu
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
// GetByDoc retrieves all signatures for a specific document, ordered by creation timestamp descending
|
||||
func (r *SignatureRepository) GetByDoc(ctx context.Context, docID string) ([]*models.Signature, error) {
|
||||
query := `
|
||||
SELECT id, doc_id, user_sub, user_email, user_name, signed_at, payload_hash, signature, nonce, created_at, referer, prev_hash
|
||||
FROM signatures
|
||||
WHERE doc_id = $1
|
||||
ORDER BY created_at DESC
|
||||
SELECT s.id, s.doc_id, s.user_sub, s.user_email, s.user_name, s.signed_at, s.doc_checksum,
|
||||
s.payload_hash, s.signature, s.nonce, s.created_at, s.referer, s.prev_hash,
|
||||
s.hash_version, s.doc_deleted_at, d.title, d.url
|
||||
FROM signatures s
|
||||
LEFT JOIN documents d ON s.doc_id = d.doc_id
|
||||
WHERE s.doc_id = $1
|
||||
ORDER BY s.created_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, docID)
|
||||
@@ -131,12 +174,16 @@ func (r *SignatureRepository) GetByDoc(ctx context.Context, docID string) ([]*mo
|
||||
return signatures, nil
|
||||
}
|
||||
|
||||
// GetByUser retrieves all signatures created by a specific user, ordered by creation timestamp descending
|
||||
func (r *SignatureRepository) GetByUser(ctx context.Context, userSub string) ([]*models.Signature, error) {
|
||||
query := `
|
||||
SELECT id, doc_id, user_sub, user_email, user_name, signed_at, payload_hash, signature, nonce, created_at, referer, prev_hash
|
||||
FROM signatures
|
||||
WHERE user_sub = $1
|
||||
ORDER BY created_at DESC
|
||||
SELECT s.id, s.doc_id, s.user_sub, s.user_email, s.user_name, s.signed_at, s.doc_checksum,
|
||||
s.payload_hash, s.signature, s.nonce, s.created_at, s.referer, s.prev_hash,
|
||||
s.hash_version, s.doc_deleted_at, d.title, d.url
|
||||
FROM signatures s
|
||||
LEFT JOIN documents d ON s.doc_id = d.doc_id
|
||||
WHERE s.user_sub = $1
|
||||
ORDER BY s.created_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, userSub)
|
||||
@@ -159,6 +206,7 @@ func (r *SignatureRepository) GetByUser(ctx context.Context, userSub string) ([]
|
||||
return signatures, nil
|
||||
}
|
||||
|
||||
// ExistsByDocAndUser efficiently checks if a signature already exists without retrieving full record data
|
||||
func (r *SignatureRepository) ExistsByDocAndUser(ctx context.Context, docID, userSub string) (bool, error) {
|
||||
query := `SELECT EXISTS(SELECT 1 FROM signatures WHERE doc_id = $1 AND user_sub = $2)`
|
||||
|
||||
@@ -171,6 +219,7 @@ func (r *SignatureRepository) ExistsByDocAndUser(ctx context.Context, docID, use
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// CheckUserSignatureStatus verifies if a user has signed, accepting either OAuth subject or email as identifier
|
||||
func (r *SignatureRepository) CheckUserSignatureStatus(ctx context.Context, docID, userIdentifier string) (bool, error) {
|
||||
query := `
|
||||
SELECT EXISTS(
|
||||
@@ -188,12 +237,16 @@ func (r *SignatureRepository) CheckUserSignatureStatus(ctx context.Context, docI
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// GetLastSignature retrieves the most recent signature for hash chain linking (returns nil if no signatures exist)
|
||||
func (r *SignatureRepository) GetLastSignature(ctx context.Context, docID string) (*models.Signature, error) {
|
||||
query := `
|
||||
SELECT id, doc_id, user_sub, user_email, user_name, signed_at, payload_hash, signature, nonce, created_at, referer, prev_hash
|
||||
FROM signatures
|
||||
WHERE doc_id = $1
|
||||
ORDER BY id DESC
|
||||
SELECT s.id, s.doc_id, s.user_sub, s.user_email, s.user_name, s.signed_at, s.doc_checksum,
|
||||
s.payload_hash, s.signature, s.nonce, s.created_at, s.referer, s.prev_hash,
|
||||
s.hash_version, s.doc_deleted_at, d.title, d.url
|
||||
FROM signatures s
|
||||
LEFT JOIN documents d ON s.doc_id = d.doc_id
|
||||
WHERE s.doc_id = $1
|
||||
ORDER BY s.id DESC
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
@@ -210,11 +263,15 @@ func (r *SignatureRepository) GetLastSignature(ctx context.Context, docID string
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
// GetAllSignaturesOrdered retrieves all signatures in chronological order for chain integrity verification
|
||||
func (r *SignatureRepository) GetAllSignaturesOrdered(ctx context.Context) ([]*models.Signature, error) {
|
||||
query := `
|
||||
SELECT id, doc_id, user_sub, user_email, user_name, signed_at, payload_hash, signature, nonce, created_at, referer, prev_hash
|
||||
FROM signatures
|
||||
ORDER BY id ASC`
|
||||
SELECT s.id, s.doc_id, s.user_sub, s.user_email, s.user_name, s.signed_at, s.doc_checksum,
|
||||
s.payload_hash, s.signature, s.nonce, s.created_at, s.referer, s.prev_hash,
|
||||
s.hash_version, s.doc_deleted_at, d.title, d.url
|
||||
FROM signatures s
|
||||
LEFT JOIN documents d ON s.doc_id = d.doc_id
|
||||
ORDER BY s.id ASC`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
@@ -236,6 +293,7 @@ func (r *SignatureRepository) GetAllSignaturesOrdered(ctx context.Context) ([]*m
|
||||
return signatures, nil
|
||||
}
|
||||
|
||||
// UpdatePrevHash modifies the previous hash pointer for chain reconstruction operations
|
||||
func (r *SignatureRepository) UpdatePrevHash(ctx context.Context, id int64, prevHash *string) error {
|
||||
query := `UPDATE signatures SET prev_hash = $2 WHERE id = $1`
|
||||
if _, err := r.db.ExecContext(ctx, query, id, prevHash); err != nil {
|
||||
@@ -7,11 +7,14 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/internal/domain/models"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/postgres"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
@@ -60,45 +63,90 @@ func SetupTestDB(t *testing.T) *TestDB {
|
||||
}
|
||||
|
||||
func (tdb *TestDB) createSchema() error {
|
||||
schema := `
|
||||
-- Drop table if exists (for cleanup)
|
||||
DROP TABLE IF EXISTS signatures;
|
||||
// Find migrations directory
|
||||
migrationsPath := os.Getenv("MIGRATIONS_PATH")
|
||||
if migrationsPath == "" {
|
||||
// Try to find migrations directory by walking up from current directory
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get working directory: %w", err)
|
||||
}
|
||||
|
||||
-- Create signatures table
|
||||
CREATE TABLE signatures (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
doc_id TEXT NOT NULL,
|
||||
user_sub TEXT NOT NULL,
|
||||
user_email TEXT NOT NULL,
|
||||
user_name TEXT,
|
||||
signed_at TIMESTAMPTZ NOT NULL,
|
||||
payload_hash TEXT NOT NULL,
|
||||
signature TEXT NOT NULL,
|
||||
nonce TEXT NOT NULL,
|
||||
referer TEXT,
|
||||
prev_hash TEXT,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
|
||||
-- Constraints
|
||||
UNIQUE (doc_id, user_sub)
|
||||
);
|
||||
// Walk up the directory tree looking for backend/migrations
|
||||
found := false
|
||||
searchDir := wd
|
||||
for i := 0; i < 10; i++ {
|
||||
testPath := filepath.Join(searchDir, "backend", "migrations")
|
||||
if stat, err := os.Stat(testPath); err == nil && stat.IsDir() {
|
||||
migrationsPath = testPath
|
||||
found = true
|
||||
break
|
||||
}
|
||||
|
||||
-- Create indexes for performance
|
||||
CREATE INDEX idx_signatures_doc_id ON signatures(doc_id);
|
||||
CREATE INDEX idx_signatures_user_sub ON signatures(user_sub);
|
||||
CREATE INDEX idx_signatures_user_email ON signatures(user_email);
|
||||
CREATE INDEX idx_signatures_created_at ON signatures(created_at);
|
||||
CREATE INDEX idx_signatures_id_asc ON signatures(id ASC);
|
||||
`
|
||||
// Also try just "migrations" directory
|
||||
testPath = filepath.Join(searchDir, "migrations")
|
||||
if stat, err := os.Stat(testPath); err == nil && stat.IsDir() {
|
||||
migrationsPath = testPath
|
||||
found = true
|
||||
break
|
||||
}
|
||||
|
||||
_, err := tdb.DB.Exec(schema)
|
||||
return err
|
||||
parent := filepath.Dir(searchDir)
|
||||
if parent == searchDir {
|
||||
break // Reached root
|
||||
}
|
||||
searchDir = parent
|
||||
}
|
||||
|
||||
if !found {
|
||||
return fmt.Errorf("migrations directory not found (searched from %s)", wd)
|
||||
}
|
||||
}
|
||||
|
||||
// Get absolute path
|
||||
absPath, err := filepath.Abs(migrationsPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get absolute path for migrations: %w", err)
|
||||
}
|
||||
|
||||
// Create postgres driver instance
|
||||
driver, err := postgres.WithInstance(tdb.DB, &postgres.Config{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create postgres driver: %w", err)
|
||||
}
|
||||
|
||||
// Create migrator
|
||||
m, err := migrate.NewWithDatabaseInstance(
|
||||
fmt.Sprintf("file://%s", absPath),
|
||||
"postgres",
|
||||
driver,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migrator: %w", err)
|
||||
}
|
||||
|
||||
// Apply all migrations
|
||||
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
|
||||
return fmt.Errorf("failed to apply migrations: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tdb *TestDB) Cleanup() {
|
||||
if tdb.DB != nil {
|
||||
// Drop all tables for cleanup
|
||||
_, _ = tdb.DB.Exec("DROP TABLE IF EXISTS signatures")
|
||||
// Drop all tables to ensure clean state
|
||||
// This is more reliable than running migrations down
|
||||
_, _ = tdb.DB.Exec(`
|
||||
DROP TABLE IF EXISTS signatures CASCADE;
|
||||
DROP TABLE IF EXISTS documents CASCADE;
|
||||
DROP TABLE IF EXISTS expected_signers CASCADE;
|
||||
DROP TABLE IF EXISTS reminder_logs CASCADE;
|
||||
DROP TABLE IF EXISTS checksum_verifications CASCADE;
|
||||
DROP TABLE IF EXISTS email_queue CASCADE;
|
||||
DROP TABLE IF EXISTS schema_migrations CASCADE;
|
||||
`)
|
||||
|
||||
_ = tdb.DB.Close()
|
||||
}
|
||||
}
|
||||
549
backend/internal/infrastructure/email/email_test.go
Normal file
549
backend/internal/infrastructure/email/email_test.go
Normal file
@@ -0,0 +1,549 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package email
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/config"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/i18n"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// TEST FIXTURES
|
||||
// ============================================================================
|
||||
|
||||
const (
|
||||
testBaseURL = "https://example.com"
|
||||
testOrganisation = "Test Org"
|
||||
testFromName = "Test Sender"
|
||||
testFromEmail = "noreply@example.com"
|
||||
)
|
||||
|
||||
func createTestI18n(t *testing.T, tmpDir string) *i18n.I18n {
|
||||
t.Helper()
|
||||
|
||||
// Create simple test translations for all supported languages
|
||||
translations := map[string]map[string]string{
|
||||
"en": {
|
||||
"test.title": "Test Template",
|
||||
"test.message": "Message: {{.message}}",
|
||||
},
|
||||
"fr": {
|
||||
"test.title": "Modèle de Test",
|
||||
"test.message": "Message: {{.message}}",
|
||||
},
|
||||
"de": {
|
||||
"test.title": "Test Vorlage",
|
||||
"test.message": "Nachricht: {{.message}}",
|
||||
},
|
||||
"es": {
|
||||
"test.title": "Plantilla de Prueba",
|
||||
"test.message": "Mensaje: {{.message}}",
|
||||
},
|
||||
"it": {
|
||||
"test.title": "Modello di Test",
|
||||
"test.message": "Messaggio: {{.message}}",
|
||||
},
|
||||
}
|
||||
|
||||
for lang, trans := range translations {
|
||||
// Write locale files
|
||||
content := "{"
|
||||
first := true
|
||||
for key, value := range trans {
|
||||
if !first {
|
||||
content += ","
|
||||
}
|
||||
content += `"` + key + `":"` + value + `"`
|
||||
first = false
|
||||
}
|
||||
content += "}"
|
||||
|
||||
err := os.WriteFile(filepath.Join(tmpDir, lang+".json"), []byte(content), 0644)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
i18nService, err := i18n.NewI18n(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
return i18nService
|
||||
}
|
||||
|
||||
func createTestRenderer(t *testing.T) (*Renderer, string) {
|
||||
t.Helper()
|
||||
|
||||
// Create temporary template directory
|
||||
tmpDir := t.TempDir()
|
||||
localesDir := filepath.Join(tmpDir, "locales")
|
||||
err := os.MkdirAll(localesDir, 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create i18n service
|
||||
i18nService := createTestI18n(t, localesDir)
|
||||
|
||||
// Create base templates
|
||||
baseHTML := `{{define "base"}}<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>{{.Organisation}}</title></head>
|
||||
<body>
|
||||
{{template "content" .}}
|
||||
<p>From: {{.FromName}} ({{.FromMail}})</p>
|
||||
<p>Base URL: {{.BaseURL}}</p>
|
||||
</body>
|
||||
</html>{{end}}`
|
||||
|
||||
baseTxt := `{{define "base"}}{{template "content" .}}
|
||||
|
||||
From: {{.FromName}} ({{.FromMail}})
|
||||
Base URL: {{.BaseURL}}
|
||||
Organisation: {{.Organisation}}{{end}}`
|
||||
|
||||
err = os.WriteFile(filepath.Join(tmpDir, "base.html.tmpl"), []byte(baseHTML), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(tmpDir, "base.txt.tmpl"), []byte(baseTxt), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create unified test templates using i18n
|
||||
testHTML := `{{define "content"}}<h1>{{T "test.title"}}</h1><p>{{T "test.message" (dict "message" .Data.message)}}</p>{{end}}`
|
||||
testTxt := `{{define "content"}}{{T "test.title"}}
|
||||
{{T "test.message" (dict "message" .Data.message)}}{{end}}`
|
||||
|
||||
err = os.WriteFile(filepath.Join(tmpDir, "test.html.tmpl"), []byte(testHTML), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(tmpDir, "test.txt.tmpl"), []byte(testTxt), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
renderer := NewRenderer(tmpDir, testBaseURL, testOrganisation, testFromName, testFromEmail, "en", i18nService)
|
||||
|
||||
return renderer, tmpDir
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - NewRenderer
|
||||
// ============================================================================
|
||||
|
||||
func TestNewRenderer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
localesDir := filepath.Join(tmpDir, "locales")
|
||||
os.MkdirAll(localesDir, 0755)
|
||||
i18nService := createTestI18n(t, localesDir)
|
||||
|
||||
renderer := NewRenderer("/tmp/templates", testBaseURL, testOrganisation, testFromName, testFromEmail, "en", i18nService)
|
||||
|
||||
require.NotNil(t, renderer)
|
||||
assert.Equal(t, "/tmp/templates", renderer.templateDir)
|
||||
assert.Equal(t, testBaseURL, renderer.baseURL)
|
||||
assert.Equal(t, testOrganisation, renderer.organisation)
|
||||
assert.Equal(t, testFromName, renderer.fromName)
|
||||
assert.Equal(t, testFromEmail, renderer.fromMail)
|
||||
assert.Equal(t, "en", renderer.defaultLocale)
|
||||
assert.NotNil(t, renderer.i18n)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - Renderer.Render
|
||||
// ============================================================================
|
||||
|
||||
func TestRenderer_Render_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
renderer, _ := createTestRenderer(t)
|
||||
|
||||
data := map[string]any{
|
||||
"message": "Hello World",
|
||||
}
|
||||
|
||||
htmlBody, textBody, err := renderer.Render("test", "en", data)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, htmlBody, "Test Template")
|
||||
assert.Contains(t, htmlBody, "Hello World")
|
||||
assert.Contains(t, htmlBody, testOrganisation)
|
||||
assert.Contains(t, htmlBody, testBaseURL)
|
||||
assert.Contains(t, htmlBody, testFromName)
|
||||
|
||||
assert.Contains(t, textBody, "Test Template")
|
||||
assert.Contains(t, textBody, "Hello World")
|
||||
assert.Contains(t, textBody, testOrganisation)
|
||||
}
|
||||
|
||||
func TestRenderer_Render_FrenchLocale(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
renderer, _ := createTestRenderer(t)
|
||||
|
||||
data := map[string]any{
|
||||
"message": "Bonjour le monde",
|
||||
}
|
||||
|
||||
htmlBody, textBody, err := renderer.Render("test", "fr", data)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, htmlBody, "Modèle de Test")
|
||||
assert.Contains(t, htmlBody, "Bonjour le monde")
|
||||
|
||||
assert.Contains(t, textBody, "Modèle de Test")
|
||||
assert.Contains(t, textBody, "Bonjour le monde")
|
||||
}
|
||||
|
||||
func TestRenderer_Render_DefaultLocale(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
renderer, _ := createTestRenderer(t)
|
||||
|
||||
data := map[string]any{
|
||||
"message": "Default locale test",
|
||||
}
|
||||
|
||||
// Empty locale should use default (en)
|
||||
htmlBody, textBody, err := renderer.Render("test", "", data)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, htmlBody, "Test Template")
|
||||
assert.Contains(t, textBody, "Default locale test")
|
||||
}
|
||||
|
||||
func TestRenderer_Render_TemplateNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
renderer, _ := createTestRenderer(t)
|
||||
|
||||
data := map[string]any{
|
||||
"message": "test",
|
||||
}
|
||||
|
||||
_, _, err := renderer.Render("nonexistent", "en", data)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "template not found")
|
||||
}
|
||||
|
||||
func TestRenderer_Render_InvalidTemplateDir(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
localesDir := filepath.Join(tmpDir, "locales")
|
||||
os.MkdirAll(localesDir, 0755)
|
||||
i18nService := createTestI18n(t, localesDir)
|
||||
|
||||
renderer := NewRenderer("/nonexistent/dir", testBaseURL, testOrganisation, testFromName, testFromEmail, "en", i18nService)
|
||||
|
||||
data := map[string]any{
|
||||
"message": "test",
|
||||
}
|
||||
|
||||
_, _, err := renderer.Render("test", "en", data)
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - NewSMTPSender
|
||||
// ============================================================================
|
||||
|
||||
func TestNewSMTPSender(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
renderer, _ := createTestRenderer(t)
|
||||
|
||||
cfg := config.MailConfig{
|
||||
Host: "smtp.example.com",
|
||||
Port: 587,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
From: testFromEmail,
|
||||
FromName: testFromName,
|
||||
}
|
||||
|
||||
sender := NewSMTPSender(cfg, renderer)
|
||||
|
||||
require.NotNil(t, sender)
|
||||
assert.NotNil(t, sender.config)
|
||||
assert.NotNil(t, sender.renderer)
|
||||
assert.Equal(t, "smtp.example.com", sender.config.Host)
|
||||
assert.Equal(t, 587, sender.config.Port)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - SMTPSender.Send
|
||||
// ============================================================================
|
||||
|
||||
func TestSMTPSender_Send_SMTPNotConfigured(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
renderer, _ := createTestRenderer(t)
|
||||
|
||||
// Empty host = SMTP not configured
|
||||
cfg := config.MailConfig{
|
||||
Host: "",
|
||||
}
|
||||
|
||||
sender := NewSMTPSender(cfg, renderer)
|
||||
|
||||
msg := Message{
|
||||
To: []string{"recipient@example.com"},
|
||||
Subject: "Test",
|
||||
Template: "test",
|
||||
Locale: "en",
|
||||
Data: map[string]any{
|
||||
"message": "test",
|
||||
},
|
||||
}
|
||||
|
||||
// Should not return error when SMTP not configured
|
||||
err := sender.Send(context.Background(), msg)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSMTPSender_Send_NoFrom(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
renderer, _ := createTestRenderer(t)
|
||||
|
||||
cfg := config.MailConfig{
|
||||
Host: "smtp.example.com",
|
||||
Port: 587,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
From: "", // No from address
|
||||
FromName: testFromName,
|
||||
}
|
||||
|
||||
sender := NewSMTPSender(cfg, renderer)
|
||||
|
||||
msg := Message{
|
||||
To: []string{"recipient@example.com"},
|
||||
Subject: "Test",
|
||||
Template: "test",
|
||||
Locale: "en",
|
||||
Data: map[string]any{
|
||||
"message": "test",
|
||||
},
|
||||
}
|
||||
|
||||
err := sender.Send(context.Background(), msg)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "ACKIFY_MAIL_FROM not set")
|
||||
}
|
||||
|
||||
func TestSMTPSender_Send_NoRecipients(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
renderer, _ := createTestRenderer(t)
|
||||
|
||||
cfg := config.MailConfig{
|
||||
Host: "smtp.example.com",
|
||||
Port: 587,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
From: testFromEmail,
|
||||
FromName: testFromName,
|
||||
}
|
||||
|
||||
sender := NewSMTPSender(cfg, renderer)
|
||||
|
||||
msg := Message{
|
||||
To: []string{}, // No recipients
|
||||
Subject: "Test",
|
||||
Template: "test",
|
||||
Locale: "en",
|
||||
Data: map[string]any{
|
||||
"message": "test",
|
||||
},
|
||||
}
|
||||
|
||||
err := sender.Send(context.Background(), msg)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no recipients specified")
|
||||
}
|
||||
|
||||
func TestSMTPSender_Send_InvalidTemplate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
renderer, _ := createTestRenderer(t)
|
||||
|
||||
cfg := config.MailConfig{
|
||||
Host: "smtp.example.com",
|
||||
Port: 587,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
From: testFromEmail,
|
||||
FromName: testFromName,
|
||||
}
|
||||
|
||||
sender := NewSMTPSender(cfg, renderer)
|
||||
|
||||
msg := Message{
|
||||
To: []string{"recipient@example.com"},
|
||||
Subject: "Test",
|
||||
Template: "nonexistent",
|
||||
Locale: "en",
|
||||
Data: map[string]any{},
|
||||
}
|
||||
|
||||
err := sender.Send(context.Background(), msg)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to render email template")
|
||||
}
|
||||
|
||||
func TestSMTPSender_Send_SubjectPrefix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
renderer, _ := createTestRenderer(t)
|
||||
|
||||
cfg := config.MailConfig{
|
||||
Host: "smtp.example.com",
|
||||
Port: 587,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
From: testFromEmail,
|
||||
FromName: testFromName,
|
||||
SubjectPrefix: "[TEST] ",
|
||||
}
|
||||
|
||||
sender := NewSMTPSender(cfg, renderer)
|
||||
|
||||
// We can't actually send email in tests, but we can verify the config is used
|
||||
assert.Equal(t, "[TEST] ", sender.config.SubjectPrefix)
|
||||
}
|
||||
|
||||
func TestMessage_Structure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
msg := Message{
|
||||
To: []string{"to@example.com"},
|
||||
Cc: []string{"cc@example.com"},
|
||||
Bcc: []string{"bcc@example.com"},
|
||||
Subject: "Test Subject",
|
||||
Template: "test",
|
||||
Locale: "en",
|
||||
Data: map[string]any{
|
||||
"key": "value",
|
||||
},
|
||||
Headers: map[string]string{
|
||||
"X-Custom": "value",
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, []string{"to@example.com"}, msg.To)
|
||||
assert.Equal(t, []string{"cc@example.com"}, msg.Cc)
|
||||
assert.Equal(t, []string{"bcc@example.com"}, msg.Bcc)
|
||||
assert.Equal(t, "Test Subject", msg.Subject)
|
||||
assert.Equal(t, "test", msg.Template)
|
||||
assert.Equal(t, "en", msg.Locale)
|
||||
assert.Equal(t, "value", msg.Data["key"])
|
||||
assert.Equal(t, "value", msg.Headers["X-Custom"])
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - TemplateData Structure
|
||||
// ============================================================================
|
||||
|
||||
func TestTemplateData_Structure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
data := TemplateData{
|
||||
Organisation: "Test Org",
|
||||
BaseURL: "https://example.com",
|
||||
FromName: "Test Sender",
|
||||
FromMail: "test@example.com",
|
||||
Data: map[string]any{
|
||||
"key1": "value1",
|
||||
"key2": 123,
|
||||
},
|
||||
T: func(key string, args ...map[string]any) string {
|
||||
return key
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, "Test Org", data.Organisation)
|
||||
assert.Equal(t, "https://example.com", data.BaseURL)
|
||||
assert.Equal(t, "Test Sender", data.FromName)
|
||||
assert.Equal(t, "test@example.com", data.FromMail)
|
||||
assert.Equal(t, "value1", data.Data["key1"])
|
||||
assert.Equal(t, 123, data.Data["key2"])
|
||||
assert.NotNil(t, data.T)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - Concurrency
|
||||
// ============================================================================
|
||||
|
||||
func TestRenderer_Render_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
renderer, _ := createTestRenderer(t)
|
||||
|
||||
const numGoroutines = 50
|
||||
|
||||
done := make(chan bool, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
data := map[string]any{
|
||||
"message": "Concurrent test",
|
||||
}
|
||||
|
||||
locale := "en"
|
||||
if id%2 == 0 {
|
||||
locale = "fr"
|
||||
}
|
||||
|
||||
htmlBody, textBody, err := renderer.Render("test", locale, data)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, htmlBody)
|
||||
assert.NotEmpty(t, textBody)
|
||||
}(i)
|
||||
}
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// BENCHMARKS
|
||||
// ============================================================================
|
||||
|
||||
func BenchmarkRenderer_Render(b *testing.B) {
|
||||
renderer, _ := createTestRenderer(&testing.T{})
|
||||
|
||||
data := map[string]any{
|
||||
"message": "Benchmark test",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = renderer.Render("test", "en", data)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRenderer_Render_Parallel(b *testing.B) {
|
||||
renderer, _ := createTestRenderer(&testing.T{})
|
||||
|
||||
data := map[string]any{
|
||||
"message": "Benchmark test",
|
||||
}
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_, _, _ = renderer.Render("test", "en", data)
|
||||
}
|
||||
})
|
||||
}
|
||||
265
backend/internal/infrastructure/email/helpers_test.go
Normal file
265
backend/internal/infrastructure/email/helpers_test.go
Normal file
@@ -0,0 +1,265 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package email
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Mock sender for testing
|
||||
type mockSender struct {
|
||||
sendFunc func(ctx context.Context, msg Message) error
|
||||
lastMsg *Message
|
||||
}
|
||||
|
||||
func (m *mockSender) Send(ctx context.Context, msg Message) error {
|
||||
m.lastMsg = &msg
|
||||
if m.sendFunc != nil {
|
||||
return m.sendFunc(ctx, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSendEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
to []string
|
||||
locale string
|
||||
subject string
|
||||
data map[string]any
|
||||
sendError error
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Send email successfully",
|
||||
template: "test_template",
|
||||
to: []string{"user@example.com"},
|
||||
locale: "en",
|
||||
subject: "Test Subject",
|
||||
data: map[string]any{
|
||||
"name": "John",
|
||||
},
|
||||
sendError: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Send email with multiple recipients",
|
||||
template: "welcome",
|
||||
to: []string{"user1@example.com", "user2@example.com"},
|
||||
locale: "fr",
|
||||
subject: "Bienvenue",
|
||||
data: map[string]any{
|
||||
"company": "Acme Corp",
|
||||
},
|
||||
sendError: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Send email with error",
|
||||
template: "error_template",
|
||||
to: []string{"user@example.com"},
|
||||
locale: "en",
|
||||
subject: "Error Test",
|
||||
data: nil,
|
||||
sendError: errors.New("SMTP connection failed"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Send email with empty data",
|
||||
template: "simple_template",
|
||||
to: []string{"test@example.com"},
|
||||
locale: "en",
|
||||
subject: "Simple Email",
|
||||
data: map[string]any{},
|
||||
sendError: nil,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mock := &mockSender{
|
||||
sendFunc: func(ctx context.Context, msg Message) error {
|
||||
return tt.sendError
|
||||
},
|
||||
}
|
||||
|
||||
err := SendEmail(ctx, mock, tt.template, tt.to, tt.locale, tt.subject, tt.data)
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error but got nil")
|
||||
}
|
||||
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify message was constructed correctly
|
||||
if mock.lastMsg == nil {
|
||||
t.Fatal("Expected message to be captured")
|
||||
}
|
||||
|
||||
if mock.lastMsg.Template != tt.template {
|
||||
t.Errorf("Expected template '%s', got '%s'", tt.template, mock.lastMsg.Template)
|
||||
}
|
||||
|
||||
if mock.lastMsg.Subject != tt.subject {
|
||||
t.Errorf("Expected subject '%s', got '%s'", tt.subject, mock.lastMsg.Subject)
|
||||
}
|
||||
|
||||
if mock.lastMsg.Locale != tt.locale {
|
||||
t.Errorf("Expected locale '%s', got '%s'", tt.locale, mock.lastMsg.Locale)
|
||||
}
|
||||
|
||||
if len(mock.lastMsg.To) != len(tt.to) {
|
||||
t.Errorf("Expected %d recipients, got %d", len(tt.to), len(mock.lastMsg.To))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendSignatureReminderEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
to []string
|
||||
locale string
|
||||
docID string
|
||||
docURL string
|
||||
signURL string
|
||||
recipientName string
|
||||
expectedSubject string
|
||||
sendError error
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Send reminder in English",
|
||||
to: []string{"user@example.com"},
|
||||
locale: "en",
|
||||
docID: "doc123",
|
||||
docURL: "https://example.com/doc.pdf",
|
||||
signURL: "https://example.com/sign?doc=doc123",
|
||||
recipientName: "John Doe",
|
||||
expectedSubject: "Reminder: Document reading confirmation required",
|
||||
sendError: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Send reminder in French",
|
||||
to: []string{"utilisateur@exemple.fr"},
|
||||
locale: "fr",
|
||||
docID: "doc456",
|
||||
docURL: "https://exemple.fr/document.pdf",
|
||||
signURL: "https://exemple.fr/sign?doc=doc456",
|
||||
recipientName: "Marie Dupont",
|
||||
expectedSubject: "Rappel : Confirmation de lecture de document requise",
|
||||
sendError: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Send reminder with unknown locale defaults to English",
|
||||
to: []string{"user@example.com"},
|
||||
locale: "es",
|
||||
docID: "doc789",
|
||||
docURL: "https://example.com/doc.pdf",
|
||||
signURL: "https://example.com/sign?doc=doc789",
|
||||
recipientName: "Juan Garcia",
|
||||
expectedSubject: "Reminder: Document reading confirmation required",
|
||||
sendError: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Send reminder with error",
|
||||
to: []string{"user@example.com"},
|
||||
locale: "en",
|
||||
docID: "doc999",
|
||||
docURL: "https://example.com/doc.pdf",
|
||||
signURL: "https://example.com/sign?doc=doc999",
|
||||
recipientName: "Test User",
|
||||
expectedSubject: "Reminder: Document reading confirmation required",
|
||||
sendError: errors.New("email server unavailable"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Send reminder with empty recipient name",
|
||||
to: []string{"user@example.com"},
|
||||
locale: "en",
|
||||
docID: "doc000",
|
||||
docURL: "https://example.com/doc.pdf",
|
||||
signURL: "https://example.com/sign?doc=doc000",
|
||||
recipientName: "",
|
||||
expectedSubject: "Reminder: Document reading confirmation required",
|
||||
sendError: nil,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mock := &mockSender{
|
||||
sendFunc: func(ctx context.Context, msg Message) error {
|
||||
return tt.sendError
|
||||
},
|
||||
}
|
||||
|
||||
err := SendSignatureReminderEmail(ctx, mock, tt.to, tt.locale, tt.docID, tt.docURL, tt.signURL, tt.recipientName)
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error but got nil")
|
||||
}
|
||||
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify message construction
|
||||
if mock.lastMsg == nil {
|
||||
t.Fatal("Expected message to be captured")
|
||||
}
|
||||
|
||||
if mock.lastMsg.Template != "signature_reminder" {
|
||||
t.Errorf("Expected template 'signature_reminder', got '%s'", mock.lastMsg.Template)
|
||||
}
|
||||
|
||||
if mock.lastMsg.Subject != tt.expectedSubject {
|
||||
t.Errorf("Expected subject '%s', got '%s'", tt.expectedSubject, mock.lastMsg.Subject)
|
||||
}
|
||||
|
||||
if mock.lastMsg.Locale != tt.locale {
|
||||
t.Errorf("Expected locale '%s', got '%s'", tt.locale, mock.lastMsg.Locale)
|
||||
}
|
||||
|
||||
// Verify data fields
|
||||
if mock.lastMsg.Data == nil {
|
||||
t.Fatal("Expected data to be present")
|
||||
}
|
||||
|
||||
if docID, ok := mock.lastMsg.Data["DocID"].(string); !ok || docID != tt.docID {
|
||||
t.Errorf("Expected DocID '%s', got '%v'", tt.docID, mock.lastMsg.Data["DocID"])
|
||||
}
|
||||
|
||||
if docURL, ok := mock.lastMsg.Data["DocURL"].(string); !ok || docURL != tt.docURL {
|
||||
t.Errorf("Expected DocURL '%s', got '%v'", tt.docURL, mock.lastMsg.Data["DocURL"])
|
||||
}
|
||||
|
||||
if signURL, ok := mock.lastMsg.Data["SignURL"].(string); !ok || signURL != tt.signURL {
|
||||
t.Errorf("Expected SignURL '%s', got '%v'", tt.signURL, mock.lastMsg.Data["SignURL"])
|
||||
}
|
||||
|
||||
if recipientName, ok := mock.lastMsg.Data["RecipientName"].(string); !ok || recipientName != tt.recipientName {
|
||||
t.Errorf("Expected RecipientName '%s', got '%v'", tt.recipientName, mock.lastMsg.Data["RecipientName"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
105
backend/internal/infrastructure/email/queue_helpers.go
Normal file
105
backend/internal/infrastructure/email/queue_helpers.go
Normal file
@@ -0,0 +1,105 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package email
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
)
|
||||
|
||||
// QueueSender implements the Sender interface by queuing emails instead of sending them directly
|
||||
type QueueSender struct {
|
||||
queueRepo QueueRepository
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// NewQueueSender creates a new queue-based email sender
|
||||
func NewQueueSender(queueRepo QueueRepository, baseURL string) *QueueSender {
|
||||
return &QueueSender{
|
||||
queueRepo: queueRepo,
|
||||
baseURL: baseURL,
|
||||
}
|
||||
}
|
||||
|
||||
// Send queues an email for asynchronous processing
|
||||
func (q *QueueSender) Send(ctx context.Context, msg Message) error {
|
||||
// Convert message data to proper format
|
||||
data := msg.Data
|
||||
if data == nil {
|
||||
data = make(map[string]interface{})
|
||||
}
|
||||
|
||||
input := models.EmailQueueInput{
|
||||
ToAddresses: msg.To,
|
||||
CcAddresses: msg.Cc,
|
||||
BccAddresses: msg.Bcc,
|
||||
Subject: msg.Subject,
|
||||
Template: msg.Template,
|
||||
Locale: msg.Locale,
|
||||
Data: data,
|
||||
Headers: msg.Headers,
|
||||
Priority: models.EmailPriorityNormal,
|
||||
}
|
||||
|
||||
// Set priority based on template type
|
||||
switch msg.Template {
|
||||
case "signature_reminder":
|
||||
input.Priority = models.EmailPriorityHigh
|
||||
case "welcome", "notification":
|
||||
input.Priority = models.EmailPriorityNormal
|
||||
default:
|
||||
input.Priority = models.EmailPriorityNormal
|
||||
}
|
||||
|
||||
// Queue the email
|
||||
_, err := q.queueRepo.Enqueue(ctx, input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to queue email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueueSignatureReminderEmail queues a signature reminder email
|
||||
func QueueSignatureReminderEmail(
|
||||
ctx context.Context,
|
||||
queueRepo QueueRepository,
|
||||
recipients []string,
|
||||
locale string,
|
||||
docID string,
|
||||
docURL string,
|
||||
signURL string,
|
||||
recipientName string,
|
||||
sentBy string,
|
||||
) error {
|
||||
data := map[string]interface{}{
|
||||
"doc_id": docID,
|
||||
"doc_url": docURL,
|
||||
"sign_url": signURL,
|
||||
"recipient_name": recipientName,
|
||||
"locale": locale,
|
||||
}
|
||||
|
||||
// Create a reference for tracking
|
||||
refType := "signature_reminder"
|
||||
|
||||
input := models.EmailQueueInput{
|
||||
ToAddresses: recipients,
|
||||
Subject: "Reminder: Document signature required",
|
||||
Template: "signature_reminder",
|
||||
Locale: locale,
|
||||
Data: data,
|
||||
Priority: models.EmailPriorityHigh,
|
||||
ReferenceType: &refType,
|
||||
ReferenceID: &docID,
|
||||
CreatedBy: &sentBy,
|
||||
}
|
||||
|
||||
_, err := queueRepo.Enqueue(ctx, input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to queue signature reminder: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -7,7 +7,10 @@ import (
|
||||
htmlTemplate "html/template"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
txtTemplate "text/template"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/i18n"
|
||||
)
|
||||
|
||||
type Renderer struct {
|
||||
@@ -17,6 +20,7 @@ type Renderer struct {
|
||||
fromName string
|
||||
fromMail string
|
||||
defaultLocale string
|
||||
i18n *i18n.I18n
|
||||
}
|
||||
|
||||
type TemplateData struct {
|
||||
@@ -25,9 +29,10 @@ type TemplateData struct {
|
||||
FromName string
|
||||
FromMail string
|
||||
Data map[string]any
|
||||
T func(key string, args ...map[string]any) string
|
||||
}
|
||||
|
||||
func NewRenderer(templateDir, baseURL, organisation, fromName, fromMail, defaultLocale string) *Renderer {
|
||||
func NewRenderer(templateDir, baseURL, organisation, fromName, fromMail, defaultLocale string, i18nBundle *i18n.I18n) *Renderer {
|
||||
return &Renderer{
|
||||
templateDir: templateDir,
|
||||
baseURL: baseURL,
|
||||
@@ -35,6 +40,7 @@ func NewRenderer(templateDir, baseURL, organisation, fromName, fromMail, default
|
||||
fromName: fromName,
|
||||
fromMail: fromMail,
|
||||
defaultLocale: defaultLocale,
|
||||
i18n: i18nBundle,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,12 +49,28 @@ func (r *Renderer) Render(templateName, locale string, data map[string]any) (htm
|
||||
locale = r.defaultLocale
|
||||
}
|
||||
|
||||
// Create translation function with template variable interpolation
|
||||
tFunc := func(key string, args ...map[string]any) string {
|
||||
translated := r.i18n.T(locale, key)
|
||||
|
||||
// If args provided, interpolate {{.VarName}} placeholders
|
||||
if len(args) > 0 && args[0] != nil {
|
||||
for k, v := range args[0] {
|
||||
placeholder := fmt.Sprintf("{{.%s}}", k)
|
||||
translated = strings.ReplaceAll(translated, placeholder, fmt.Sprintf("%v", v))
|
||||
}
|
||||
}
|
||||
|
||||
return translated
|
||||
}
|
||||
|
||||
templateData := TemplateData{
|
||||
Organisation: r.organisation,
|
||||
BaseURL: r.baseURL,
|
||||
FromName: r.fromName,
|
||||
FromMail: r.fromMail,
|
||||
Data: data,
|
||||
T: tFunc,
|
||||
}
|
||||
|
||||
htmlBody, err = r.renderHTML(templateName, locale, templateData)
|
||||
@@ -66,13 +88,32 @@ func (r *Renderer) Render(templateName, locale string, data map[string]any) (htm
|
||||
|
||||
func (r *Renderer) renderHTML(templateName, locale string, data TemplateData) (string, error) {
|
||||
baseTemplatePath := filepath.Join(r.templateDir, "base.html.tmpl")
|
||||
templatePath := r.resolveTemplatePath(templateName, locale, "html.tmpl")
|
||||
templatePath := filepath.Join(r.templateDir, fmt.Sprintf("%s.html.tmpl", templateName))
|
||||
|
||||
if templatePath == "" {
|
||||
return "", fmt.Errorf("template not found: %s (locale: %s)", templateName, locale)
|
||||
if _, err := os.Stat(templatePath); err != nil {
|
||||
return "", fmt.Errorf("template not found: %s", templatePath)
|
||||
}
|
||||
|
||||
tmpl, err := htmlTemplate.ParseFiles(baseTemplatePath, templatePath)
|
||||
// Create template with helper functions
|
||||
tmpl := htmlTemplate.New("base").Funcs(htmlTemplate.FuncMap{
|
||||
"dict": func(args ...interface{}) map[string]any {
|
||||
if len(args)%2 != 0 {
|
||||
return nil
|
||||
}
|
||||
dict := make(map[string]any)
|
||||
for i := 0; i < len(args); i += 2 {
|
||||
key, ok := args[i].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
dict[key] = args[i+1]
|
||||
}
|
||||
return dict
|
||||
},
|
||||
"T": data.T, // Expose T function to template
|
||||
})
|
||||
|
||||
tmpl, err := tmpl.ParseFiles(baseTemplatePath, templatePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse template: %w", err)
|
||||
}
|
||||
@@ -87,13 +128,32 @@ func (r *Renderer) renderHTML(templateName, locale string, data TemplateData) (s
|
||||
|
||||
func (r *Renderer) renderText(templateName, locale string, data TemplateData) (string, error) {
|
||||
baseTemplatePath := filepath.Join(r.templateDir, "base.txt.tmpl")
|
||||
templatePath := r.resolveTemplatePath(templateName, locale, "txt.tmpl")
|
||||
templatePath := filepath.Join(r.templateDir, fmt.Sprintf("%s.txt.tmpl", templateName))
|
||||
|
||||
if templatePath == "" {
|
||||
return "", fmt.Errorf("template not found: %s (locale: %s)", templateName, locale)
|
||||
if _, err := os.Stat(templatePath); err != nil {
|
||||
return "", fmt.Errorf("template not found: %s", templatePath)
|
||||
}
|
||||
|
||||
tmpl, err := txtTemplate.ParseFiles(baseTemplatePath, templatePath)
|
||||
// Create template with helper functions
|
||||
tmpl := txtTemplate.New("base").Funcs(txtTemplate.FuncMap{
|
||||
"dict": func(args ...interface{}) map[string]any {
|
||||
if len(args)%2 != 0 {
|
||||
return nil
|
||||
}
|
||||
dict := make(map[string]any)
|
||||
for i := 0; i < len(args); i += 2 {
|
||||
key, ok := args[i].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
dict[key] = args[i+1]
|
||||
}
|
||||
return dict
|
||||
},
|
||||
"T": data.T, // Expose T function to template
|
||||
})
|
||||
|
||||
tmpl, err := tmpl.ParseFiles(baseTemplatePath, templatePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse template: %w", err)
|
||||
}
|
||||
@@ -105,17 +165,3 @@ func (r *Renderer) renderText(templateName, locale string, data TemplateData) (s
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (r *Renderer) resolveTemplatePath(templateName, locale, extension string) string {
|
||||
localizedPath := filepath.Join(r.templateDir, fmt.Sprintf("%s.%s.%s", templateName, locale, extension))
|
||||
if _, err := os.Stat(localizedPath); err == nil {
|
||||
return localizedPath
|
||||
}
|
||||
|
||||
fallbackPath := filepath.Join(r.templateDir, fmt.Sprintf("%s.en.%s", templateName, extension))
|
||||
if _, err := os.Stat(fallbackPath); err == nil {
|
||||
return fallbackPath
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
|
||||
mail "github.com/go-mail/mail/v2"
|
||||
|
||||
"github.com/btouchard/ackify-ce/internal/infrastructure/config"
|
||||
"github.com/btouchard/ackify-ce/pkg/logger"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/config"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
type Sender interface {
|
||||
373
backend/internal/infrastructure/email/worker.go
Normal file
373
backend/internal/infrastructure/email/worker.go
Normal file
@@ -0,0 +1,373 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package email
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// QueueRepository defines the interface for email queue operations
|
||||
type QueueRepository interface {
|
||||
Enqueue(ctx context.Context, input models.EmailQueueInput) (*models.EmailQueueItem, error)
|
||||
GetNextToProcess(ctx context.Context, limit int) ([]*models.EmailQueueItem, error)
|
||||
MarkAsSent(ctx context.Context, id int64) error
|
||||
MarkAsFailed(ctx context.Context, id int64, err error, shouldRetry bool) error
|
||||
GetRetryableEmails(ctx context.Context, limit int) ([]*models.EmailQueueItem, error)
|
||||
CleanupOldEmails(ctx context.Context, olderThan time.Duration) (int64, error)
|
||||
}
|
||||
|
||||
// Worker processes emails from the queue asynchronously
|
||||
type Worker struct {
|
||||
queueRepo QueueRepository
|
||||
sender Sender
|
||||
renderer *Renderer
|
||||
|
||||
// Worker configuration
|
||||
batchSize int
|
||||
pollInterval time.Duration
|
||||
cleanupInterval time.Duration
|
||||
cleanupAge time.Duration
|
||||
maxConcurrent int
|
||||
|
||||
// Control
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
stopChan chan struct{}
|
||||
started bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// WorkerConfig contains configuration for the email worker
|
||||
type WorkerConfig struct {
|
||||
BatchSize int // Number of emails to process in each batch (default: 10)
|
||||
PollInterval time.Duration // How often to check for new emails (default: 5s)
|
||||
CleanupInterval time.Duration // How often to cleanup old emails (default: 1 hour)
|
||||
CleanupAge time.Duration // Age of emails to cleanup (default: 7 days)
|
||||
MaxConcurrent int // Maximum concurrent email sends (default: 5)
|
||||
}
|
||||
|
||||
// DefaultWorkerConfig returns default worker configuration
|
||||
func DefaultWorkerConfig() WorkerConfig {
|
||||
return WorkerConfig{
|
||||
BatchSize: 10,
|
||||
PollInterval: 5 * time.Second,
|
||||
CleanupInterval: 1 * time.Hour,
|
||||
CleanupAge: 7 * 24 * time.Hour, // 7 days
|
||||
MaxConcurrent: 5,
|
||||
}
|
||||
}
|
||||
|
||||
// NewWorker creates a new email worker
|
||||
func NewWorker(queueRepo QueueRepository, sender Sender, renderer *Renderer, config WorkerConfig) *Worker {
|
||||
// Apply defaults
|
||||
if config.BatchSize <= 0 {
|
||||
config.BatchSize = 10
|
||||
}
|
||||
if config.PollInterval <= 0 {
|
||||
config.PollInterval = 5 * time.Second
|
||||
}
|
||||
if config.CleanupInterval <= 0 {
|
||||
config.CleanupInterval = 1 * time.Hour
|
||||
}
|
||||
if config.CleanupAge <= 0 {
|
||||
config.CleanupAge = 7 * 24 * time.Hour
|
||||
}
|
||||
if config.MaxConcurrent <= 0 {
|
||||
config.MaxConcurrent = 5
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &Worker{
|
||||
queueRepo: queueRepo,
|
||||
sender: sender,
|
||||
renderer: renderer,
|
||||
batchSize: config.BatchSize,
|
||||
pollInterval: config.PollInterval,
|
||||
cleanupInterval: config.CleanupInterval,
|
||||
cleanupAge: config.CleanupAge,
|
||||
maxConcurrent: config.MaxConcurrent,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins processing emails from the queue
|
||||
func (w *Worker) Start() error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.started {
|
||||
return fmt.Errorf("worker already started")
|
||||
}
|
||||
|
||||
logger.Logger.Info("Starting email worker",
|
||||
"batch_size", w.batchSize,
|
||||
"poll_interval", w.pollInterval,
|
||||
"max_concurrent", w.maxConcurrent)
|
||||
|
||||
w.started = true
|
||||
|
||||
// Start the main processing loop
|
||||
w.wg.Add(1)
|
||||
go w.processLoop()
|
||||
|
||||
// Start the cleanup loop
|
||||
w.wg.Add(1)
|
||||
go w.cleanupLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the worker
|
||||
func (w *Worker) Stop() error {
|
||||
w.mu.Lock()
|
||||
if !w.started {
|
||||
w.mu.Unlock()
|
||||
return fmt.Errorf("worker not started")
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
logger.Logger.Info("Stopping email worker...")
|
||||
|
||||
// Signal shutdown
|
||||
w.cancel()
|
||||
close(w.stopChan)
|
||||
|
||||
// Wait for goroutines to finish with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
w.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
logger.Logger.Info("Email worker stopped gracefully")
|
||||
case <-time.After(30 * time.Second):
|
||||
logger.Logger.Warn("Email worker stop timeout, some operations may not have completed")
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
w.started = false
|
||||
w.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// processLoop is the main processing loop
|
||||
func (w *Worker) processLoop() {
|
||||
defer w.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(w.pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Immediate first check
|
||||
w.processBatch()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return
|
||||
case <-w.stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
w.processBatch()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processBatch processes a batch of emails
|
||||
func (w *Worker) processBatch() {
|
||||
ctx, cancel := context.WithTimeout(w.ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Get next batch of emails
|
||||
emails, err := w.queueRepo.GetNextToProcess(ctx, w.batchSize)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to get emails to process", "error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(emails) == 0 {
|
||||
// Also check for retryable emails
|
||||
emails, err = w.queueRepo.GetRetryableEmails(ctx, w.batchSize)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to get retryable emails", "error", err.Error())
|
||||
return
|
||||
}
|
||||
if len(emails) == 0 {
|
||||
return // Nothing to process
|
||||
}
|
||||
}
|
||||
|
||||
logger.Logger.Debug("Processing email batch", "count", len(emails))
|
||||
|
||||
// Process emails concurrently with limited concurrency
|
||||
sem := make(chan struct{}, w.maxConcurrent)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, email := range emails {
|
||||
wg.Add(1)
|
||||
sem <- struct{}{} // Acquire semaphore
|
||||
|
||||
go func(item *models.EmailQueueItem) {
|
||||
defer wg.Done()
|
||||
defer func() { <-sem }() // Release semaphore
|
||||
|
||||
w.processEmail(ctx, item)
|
||||
}(email)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// processEmail processes a single email
|
||||
func (w *Worker) processEmail(ctx context.Context, item *models.EmailQueueItem) {
|
||||
logger.Logger.Debug("Processing email",
|
||||
"id", item.ID,
|
||||
"template", item.Template,
|
||||
"retry_count", item.RetryCount)
|
||||
|
||||
// Convert data from JSON to map
|
||||
var data map[string]interface{}
|
||||
if len(item.Data) > 0 {
|
||||
if err := json.Unmarshal(item.Data, &data); err != nil {
|
||||
logger.Logger.Error("Failed to unmarshal email data",
|
||||
"id", item.ID,
|
||||
"error", err.Error())
|
||||
// Mark as failed without retry (data corruption)
|
||||
w.queueRepo.MarkAsFailed(ctx, item.ID, err, false)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Convert headers from JSON to map
|
||||
var headers map[string]string
|
||||
if item.Headers.Valid && len(item.Headers.RawMessage) > 0 {
|
||||
if err := json.Unmarshal(item.Headers.RawMessage, &headers); err != nil {
|
||||
logger.Logger.Error("Failed to unmarshal email headers",
|
||||
"id", item.ID,
|
||||
"error", err.Error())
|
||||
// Continue without headers
|
||||
headers = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Create message
|
||||
msg := Message{
|
||||
To: item.ToAddresses,
|
||||
Cc: item.CcAddresses,
|
||||
Bcc: item.BccAddresses,
|
||||
Subject: item.Subject,
|
||||
Template: item.Template,
|
||||
Locale: item.Locale,
|
||||
Data: data,
|
||||
Headers: headers,
|
||||
}
|
||||
|
||||
// Send email
|
||||
err := w.sender.Send(ctx, msg)
|
||||
if err != nil {
|
||||
logger.Logger.Warn("Failed to send email",
|
||||
"id", item.ID,
|
||||
"template", item.Template,
|
||||
"error", err.Error(),
|
||||
"retry_count", item.RetryCount)
|
||||
|
||||
// Determine if we should retry
|
||||
shouldRetry := item.RetryCount < item.MaxRetries && isRetryableError(err)
|
||||
|
||||
// Mark as failed (with or without retry)
|
||||
if markErr := w.queueRepo.MarkAsFailed(ctx, item.ID, err, shouldRetry); markErr != nil {
|
||||
logger.Logger.Error("Failed to mark email as failed",
|
||||
"id", item.ID,
|
||||
"error", markErr.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Mark as sent
|
||||
if err := w.queueRepo.MarkAsSent(ctx, item.ID); err != nil {
|
||||
logger.Logger.Error("Failed to mark email as sent",
|
||||
"id", item.ID,
|
||||
"error", err.Error())
|
||||
// Email was sent but we failed to update the database
|
||||
// This is not critical, the email won't be resent
|
||||
}
|
||||
|
||||
logger.Logger.Info("Email sent successfully",
|
||||
"id", item.ID,
|
||||
"template", item.Template,
|
||||
"to", item.ToAddresses)
|
||||
}
|
||||
|
||||
// cleanupLoop periodically cleans up old emails
|
||||
func (w *Worker) cleanupLoop() {
|
||||
defer w.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(w.cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return
|
||||
case <-w.stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
w.performCleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performCleanup removes old processed emails
|
||||
func (w *Worker) performCleanup() {
|
||||
ctx, cancel := context.WithTimeout(w.ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
deleted, err := w.queueRepo.CleanupOldEmails(ctx, w.cleanupAge)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to cleanup old emails", "error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if deleted > 0 {
|
||||
logger.Logger.Info("Cleaned up old emails", "count", deleted)
|
||||
}
|
||||
}
|
||||
|
||||
// isRetryableError determines if an error is retryable
|
||||
func isRetryableError(err error) bool {
|
||||
// TODO: Implement more sophisticated error detection
|
||||
// For now, retry all errors except explicit data/template errors
|
||||
errStr := err.Error()
|
||||
|
||||
// Don't retry template or data errors
|
||||
if contains(errStr, "template") || contains(errStr, "unmarshal") || contains(errStr, "invalid") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Retry network and timeout errors
|
||||
if contains(errStr, "timeout") || contains(errStr, "connection") || contains(errStr, "refused") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Default to retry
|
||||
return true
|
||||
}
|
||||
|
||||
// contains checks if a string contains a substring (case-insensitive)
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) > 0 && len(substr) > 0 &&
|
||||
(s == substr || len(s) > len(substr) &&
|
||||
(s[:len(substr)] == substr || s[len(s)-len(substr):] == substr))
|
||||
}
|
||||
@@ -21,6 +21,9 @@ var (
|
||||
SupportedLangs = []language.Tag{
|
||||
language.English,
|
||||
language.French,
|
||||
language.Italian,
|
||||
language.German,
|
||||
language.Spanish,
|
||||
}
|
||||
matcher = language.NewMatcher(SupportedLangs)
|
||||
)
|
||||
@@ -34,14 +37,13 @@ func NewI18n(localesDir string) (*I18n, error) {
|
||||
translations: make(map[string]map[string]string),
|
||||
}
|
||||
|
||||
// Load English translations
|
||||
if err := i18n.loadTranslations(filepath.Join(localesDir, "en.json"), "en"); err != nil {
|
||||
return nil, fmt.Errorf("failed to load English translations: %w", err)
|
||||
}
|
||||
|
||||
// Load French translations
|
||||
if err := i18n.loadTranslations(filepath.Join(localesDir, "fr.json"), "fr"); err != nil {
|
||||
return nil, fmt.Errorf("failed to load French translations: %w", err)
|
||||
// Load all supported language translations
|
||||
languages := []string{"en", "fr", "it", "de", "es"}
|
||||
for _, lang := range languages {
|
||||
filePath := filepath.Join(localesDir, lang+".json")
|
||||
if err := i18n.loadTranslations(filePath, lang); err != nil {
|
||||
return nil, fmt.Errorf("failed to load %s translations: %w", lang, err)
|
||||
}
|
||||
}
|
||||
|
||||
return i18n, nil
|
||||
@@ -129,14 +131,12 @@ func SetLangCookie(w http.ResponseWriter, lang string, secureCookies bool) {
|
||||
http.SetCookie(w, cookie)
|
||||
}
|
||||
|
||||
// normalizeLang normalizes language codes (en-US -> en, fr-FR -> fr)
|
||||
// normalizeLang normalizes language codes (en-US -> en, fr-FR -> fr, it-IT -> it, etc.)
|
||||
func normalizeLang(lang string) string {
|
||||
lang = strings.ToLower(lang)
|
||||
if strings.HasPrefix(lang, "en") {
|
||||
return "en"
|
||||
}
|
||||
if strings.HasPrefix(lang, "fr") {
|
||||
return "fr"
|
||||
// Extract base language code (before - or _)
|
||||
if idx := strings.IndexAny(lang, "-_"); idx > 0 {
|
||||
return lang[:idx]
|
||||
}
|
||||
return lang
|
||||
}
|
||||
@@ -144,7 +144,13 @@ func normalizeLang(lang string) string {
|
||||
// isSupported checks if a language is supported
|
||||
func isSupported(lang string) bool {
|
||||
lang = normalizeLang(lang)
|
||||
return lang == "en" || lang == "fr"
|
||||
supportedLangs := []string{"en", "fr", "it", "de", "es"}
|
||||
for _, supported := range supportedLangs {
|
||||
if lang == supported {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetTranslations returns all translations for a given language
|
||||
586
backend/internal/infrastructure/i18n/i18n_test.go
Normal file
586
backend/internal/infrastructure/i18n/i18n_test.go
Normal file
@@ -0,0 +1,586 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package i18n
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// TEST FIXTURES
|
||||
// ============================================================================
|
||||
|
||||
var testLocalesDir = filepath.Join("..", "..", "..", "locales")
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - NewI18n
|
||||
// ============================================================================
|
||||
|
||||
func TestNewI18n_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
i18n, err := NewI18n(testLocalesDir)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, i18n)
|
||||
assert.NotEmpty(t, i18n.translations)
|
||||
assert.Contains(t, i18n.translations, "en")
|
||||
assert.Contains(t, i18n.translations, "fr")
|
||||
}
|
||||
|
||||
func TestNewI18n_InvalidDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
i18n, err := NewI18n("/nonexistent/directory")
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, i18n)
|
||||
assert.Contains(t, err.Error(), "failed to load English translations")
|
||||
}
|
||||
|
||||
func TestNewI18n_MissingEnglishFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create temporary directory without en.json
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
i18n, err := NewI18n(tmpDir)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, i18n)
|
||||
}
|
||||
|
||||
func TestNewI18n_InvalidJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create temporary directory with invalid JSON
|
||||
tmpDir := t.TempDir()
|
||||
err := os.WriteFile(filepath.Join(tmpDir, "en.json"), []byte("invalid json"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
i18n, err := NewI18n(tmpDir)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, i18n)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - T (Translation)
|
||||
// ============================================================================
|
||||
|
||||
func TestI18n_T_EnglishTranslation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
i18n, err := NewI18n(testLocalesDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test a known key from en.json
|
||||
result := i18n.T("en", "site.title")
|
||||
assert.NotEmpty(t, result)
|
||||
assert.NotEqual(t, "site.title", result, "Should return translation, not key")
|
||||
assert.Contains(t, result, "Ackify", "Should contain 'Ackify'")
|
||||
}
|
||||
|
||||
func TestI18n_T_FrenchTranslation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
i18n, err := NewI18n(testLocalesDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test a known key from fr.json
|
||||
result := i18n.T("fr", "site.title")
|
||||
assert.NotEmpty(t, result)
|
||||
assert.NotEqual(t, "site.title", result, "Should return translation, not key")
|
||||
assert.Contains(t, result, "Ackify", "Should contain 'Ackify'")
|
||||
}
|
||||
|
||||
func TestI18n_T_FallbackToEnglish(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
i18n, err := NewI18n(testLocalesDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Request French translation for a key - should work for existing keys
|
||||
result := i18n.T("fr", "site.title")
|
||||
assert.NotEmpty(t, result)
|
||||
}
|
||||
|
||||
func TestI18n_T_UnknownKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
i18n, err := NewI18n(testLocalesDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
unknownKey := "unknown.key.that.does.not.exist"
|
||||
result := i18n.T("en", unknownKey)
|
||||
|
||||
assert.Equal(t, unknownKey, result, "Should return key itself when translation not found")
|
||||
}
|
||||
|
||||
func TestI18n_T_UnknownLanguage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
i18n, err := NewI18n(testLocalesDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with unsupported language, should fallback to English
|
||||
result := i18n.T("de", "site.title")
|
||||
assert.NotEmpty(t, result)
|
||||
assert.Contains(t, result, "Ackify", "Should fallback to English translation")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - GetLangFromRequest
|
||||
// ============================================================================
|
||||
|
||||
func TestGetLangFromRequest_FromCookie(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cookieValue string
|
||||
expectedLang string
|
||||
}{
|
||||
{
|
||||
name: "English cookie",
|
||||
cookieValue: "en",
|
||||
expectedLang: "en",
|
||||
},
|
||||
{
|
||||
name: "French cookie",
|
||||
cookieValue: "fr",
|
||||
expectedLang: "fr",
|
||||
},
|
||||
{
|
||||
name: "English with region",
|
||||
cookieValue: "en-US",
|
||||
expectedLang: "en",
|
||||
},
|
||||
{
|
||||
name: "French with region",
|
||||
cookieValue: "fr-FR",
|
||||
expectedLang: "fr",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: LangCookieName,
|
||||
Value: tt.cookieValue,
|
||||
})
|
||||
|
||||
lang := GetLangFromRequest(req)
|
||||
assert.Equal(t, tt.expectedLang, lang)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLangFromRequest_FromAcceptLanguageHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
acceptLang string
|
||||
expectedLang string
|
||||
}{
|
||||
{
|
||||
name: "English",
|
||||
acceptLang: "en",
|
||||
expectedLang: "en",
|
||||
},
|
||||
{
|
||||
name: "French",
|
||||
acceptLang: "fr",
|
||||
expectedLang: "fr",
|
||||
},
|
||||
{
|
||||
name: "English with quality",
|
||||
acceptLang: "en-US,en;q=0.9",
|
||||
expectedLang: "en",
|
||||
},
|
||||
{
|
||||
name: "French with quality",
|
||||
acceptLang: "fr-FR,fr;q=0.9,en;q=0.8",
|
||||
expectedLang: "fr",
|
||||
},
|
||||
{
|
||||
name: "Unsupported language defaults to English",
|
||||
acceptLang: "de,es",
|
||||
expectedLang: "en",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Accept-Language", tt.acceptLang)
|
||||
|
||||
lang := GetLangFromRequest(req)
|
||||
assert.Equal(t, tt.expectedLang, lang)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLangFromRequest_DefaultToEnglish(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
// No cookie, no Accept-Language header
|
||||
|
||||
lang := GetLangFromRequest(req)
|
||||
assert.Equal(t, DefaultLang, lang)
|
||||
}
|
||||
|
||||
func TestGetLangFromRequest_CookieTakesPrecedence(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: LangCookieName,
|
||||
Value: "fr",
|
||||
})
|
||||
req.Header.Set("Accept-Language", "en")
|
||||
|
||||
lang := GetLangFromRequest(req)
|
||||
assert.Equal(t, "fr", lang, "Cookie should take precedence over Accept-Language header")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - SetLangCookie
|
||||
// ============================================================================
|
||||
|
||||
func TestSetLangCookie_ValidLanguages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
lang string
|
||||
secureCookies bool
|
||||
expectedLang string
|
||||
expectedSecure bool
|
||||
}{
|
||||
{
|
||||
name: "English",
|
||||
lang: "en",
|
||||
secureCookies: false,
|
||||
expectedLang: "en",
|
||||
expectedSecure: false,
|
||||
},
|
||||
{
|
||||
name: "French",
|
||||
lang: "fr",
|
||||
secureCookies: false,
|
||||
expectedLang: "fr",
|
||||
expectedSecure: false,
|
||||
},
|
||||
{
|
||||
name: "English with secure cookies",
|
||||
lang: "en",
|
||||
secureCookies: true,
|
||||
expectedLang: "en",
|
||||
expectedSecure: true,
|
||||
},
|
||||
{
|
||||
name: "English with region",
|
||||
lang: "en-US",
|
||||
secureCookies: false,
|
||||
expectedLang: "en",
|
||||
expectedSecure: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
SetLangCookie(rec, tt.lang, tt.secureCookies)
|
||||
|
||||
cookies := rec.Result().Cookies()
|
||||
require.Len(t, cookies, 1, "Should set exactly one cookie")
|
||||
|
||||
cookie := cookies[0]
|
||||
assert.Equal(t, LangCookieName, cookie.Name)
|
||||
assert.Equal(t, tt.expectedLang, cookie.Value)
|
||||
assert.Equal(t, "/", cookie.Path)
|
||||
assert.Equal(t, 365*24*60*60, cookie.MaxAge)
|
||||
assert.True(t, cookie.HttpOnly)
|
||||
assert.Equal(t, tt.expectedSecure, cookie.Secure)
|
||||
assert.Equal(t, http.SameSiteLaxMode, cookie.SameSite)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetLangCookie_UnsupportedLanguage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
SetLangCookie(rec, "de", false)
|
||||
|
||||
cookies := rec.Result().Cookies()
|
||||
require.Len(t, cookies, 1)
|
||||
|
||||
cookie := cookies[0]
|
||||
assert.Equal(t, DefaultLang, cookie.Value, "Unsupported language should default to English")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - normalizeLang
|
||||
// ============================================================================
|
||||
|
||||
func Test_normalizeLang(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "English",
|
||||
input: "en",
|
||||
expected: "en",
|
||||
},
|
||||
{
|
||||
name: "French",
|
||||
input: "fr",
|
||||
expected: "fr",
|
||||
},
|
||||
{
|
||||
name: "English with region",
|
||||
input: "en-US",
|
||||
expected: "en",
|
||||
},
|
||||
{
|
||||
name: "French with region",
|
||||
input: "fr-FR",
|
||||
expected: "fr",
|
||||
},
|
||||
{
|
||||
name: "English uppercase",
|
||||
input: "EN",
|
||||
expected: "en",
|
||||
},
|
||||
{
|
||||
name: "English mixed case",
|
||||
input: "En-Us",
|
||||
expected: "en",
|
||||
},
|
||||
{
|
||||
name: "Other language",
|
||||
input: "de",
|
||||
expected: "de",
|
||||
},
|
||||
{
|
||||
name: "Other language with region",
|
||||
input: "de-DE",
|
||||
expected: "de-de",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result := normalizeLang(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - isSupported
|
||||
// ============================================================================
|
||||
|
||||
func Test_isSupported(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
lang string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "English",
|
||||
lang: "en",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "French",
|
||||
lang: "fr",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "English with region",
|
||||
lang: "en-US",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "French with region",
|
||||
lang: "fr-FR",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "German",
|
||||
lang: "de",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Spanish",
|
||||
lang: "es",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result := isSupported(tt.lang)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - GetTranslations
|
||||
// ============================================================================
|
||||
|
||||
func TestI18n_GetTranslations_English(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
i18n, err := NewI18n(testLocalesDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
translations := i18n.GetTranslations("en")
|
||||
assert.NotEmpty(t, translations)
|
||||
}
|
||||
|
||||
func TestI18n_GetTranslations_French(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
i18n, err := NewI18n(testLocalesDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
translations := i18n.GetTranslations("fr")
|
||||
assert.NotEmpty(t, translations)
|
||||
}
|
||||
|
||||
func TestI18n_GetTranslations_UnsupportedLanguage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
i18n, err := NewI18n(testLocalesDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should fallback to default language (English)
|
||||
translations := i18n.GetTranslations("de")
|
||||
assert.NotEmpty(t, translations)
|
||||
assert.Equal(t, i18n.translations[DefaultLang], translations)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - Concurrency
|
||||
// ============================================================================
|
||||
|
||||
func TestI18n_T_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
i18n, err := NewI18n(testLocalesDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
const numGoroutines = 100
|
||||
done := make(chan bool, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
lang := "en"
|
||||
if id%2 == 0 {
|
||||
lang = "fr"
|
||||
}
|
||||
|
||||
result := i18n.T(lang, "site.title")
|
||||
assert.NotEmpty(t, result)
|
||||
}(i)
|
||||
}
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// BENCHMARKS
|
||||
// ============================================================================
|
||||
|
||||
func BenchmarkI18n_T(b *testing.B) {
|
||||
i18n, err := NewI18n(testLocalesDir)
|
||||
require.NoError(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
i18n.T("en", "site.title")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkI18n_T_Parallel(b *testing.B) {
|
||||
i18n, err := NewI18n(testLocalesDir)
|
||||
require.NoError(b, err)
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
i18n.T("en", "site.title")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkGetLangFromRequest(b *testing.B) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: LangCookieName,
|
||||
Value: "fr",
|
||||
})
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
GetLangFromRequest(req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetLangFromRequest_Parallel(b *testing.B) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: LangCookieName,
|
||||
Value: "fr",
|
||||
})
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
GetLangFromRequest(req)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkSetLangCookie(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
rec := httptest.NewRecorder()
|
||||
SetLangCookie(rec, "fr", false)
|
||||
}
|
||||
}
|
||||
638
backend/internal/presentation/api/admin/handler.go
Normal file
638
backend/internal/presentation/api/admin/handler.go
Normal file
@@ -0,0 +1,638 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/i18n"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/shared"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// documentRepository defines the interface for document operations
|
||||
type documentRepository interface {
|
||||
GetByDocID(ctx context.Context, docID string) (*models.Document, error)
|
||||
List(ctx context.Context, limit, offset int) ([]*models.Document, error)
|
||||
CreateOrUpdate(ctx context.Context, docID string, input models.DocumentInput, createdBy string) (*models.Document, error)
|
||||
Delete(ctx context.Context, docID string) error
|
||||
}
|
||||
|
||||
// expectedSignerRepository defines the interface for expected signer operations
|
||||
type expectedSignerRepository interface {
|
||||
ListByDocID(ctx context.Context, docID string) ([]*models.ExpectedSigner, error)
|
||||
ListWithStatusByDocID(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error)
|
||||
AddExpected(ctx context.Context, docID string, contacts []models.ContactInfo, addedBy string) error
|
||||
Remove(ctx context.Context, docID, email string) error
|
||||
GetStats(ctx context.Context, docID string) (*models.DocCompletionStats, error)
|
||||
}
|
||||
|
||||
// reminderService defines the interface for reminder operations
|
||||
type reminderService interface {
|
||||
SendReminders(ctx context.Context, docID, sentBy string, specificEmails []string, docURL string, locale string) (*models.ReminderSendResult, error)
|
||||
GetReminderHistory(ctx context.Context, docID string) ([]*models.ReminderLog, error)
|
||||
GetReminderStats(ctx context.Context, docID string) (*models.ReminderStats, error)
|
||||
}
|
||||
|
||||
// signatureService defines the interface for signature operations
|
||||
type signatureService interface {
|
||||
GetDocumentSignatures(ctx context.Context, docID string) ([]*models.Signature, error)
|
||||
}
|
||||
|
||||
// Handler handles admin API requests
|
||||
type Handler struct {
|
||||
documentRepo documentRepository
|
||||
expectedSignerRepo expectedSignerRepository
|
||||
reminderService reminderService
|
||||
signatureService signatureService
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// NewHandler creates a new admin handler
|
||||
func NewHandler(documentRepo documentRepository, expectedSignerRepo expectedSignerRepository, reminderService reminderService, signatureService signatureService, baseURL string) *Handler {
|
||||
return &Handler{
|
||||
documentRepo: documentRepo,
|
||||
expectedSignerRepo: expectedSignerRepo,
|
||||
reminderService: reminderService,
|
||||
signatureService: signatureService,
|
||||
baseURL: baseURL,
|
||||
}
|
||||
}
|
||||
|
||||
// DocumentResponse represents a document in API responses
|
||||
type DocumentResponse struct {
|
||||
DocID string `json:"docId"`
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Checksum string `json:"checksum,omitempty"`
|
||||
ChecksumAlgorithm string `json:"checksumAlgorithm,omitempty"`
|
||||
Description string `json:"description"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
CreatedBy string `json:"createdBy"`
|
||||
}
|
||||
|
||||
// ExpectedSignerResponse represents an expected signer in API responses
|
||||
type ExpectedSignerResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
DocID string `json:"docId"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
AddedAt string `json:"addedAt"`
|
||||
AddedBy string `json:"addedBy"`
|
||||
Notes *string `json:"notes,omitempty"`
|
||||
HasSigned bool `json:"hasSigned"`
|
||||
SignedAt *string `json:"signedAt,omitempty"`
|
||||
UserName *string `json:"userName,omitempty"`
|
||||
LastReminderSent *string `json:"lastReminderSent,omitempty"`
|
||||
ReminderCount int `json:"reminderCount"`
|
||||
DaysSinceAdded int `json:"daysSinceAdded"`
|
||||
DaysSinceLastReminder *int `json:"daysSinceLastReminder,omitempty"`
|
||||
}
|
||||
|
||||
// DocumentStatsResponse represents document statistics
|
||||
type DocumentStatsResponse struct {
|
||||
DocID string `json:"docId"`
|
||||
ExpectedCount int `json:"expectedCount"`
|
||||
SignedCount int `json:"signedCount"`
|
||||
PendingCount int `json:"pendingCount"`
|
||||
CompletionRate float64 `json:"completionRate"`
|
||||
}
|
||||
|
||||
// UnexpectedSignatureResponse represents an unexpected signature
|
||||
type UnexpectedSignatureResponse struct {
|
||||
UserEmail string `json:"userEmail"`
|
||||
UserName *string `json:"userName,omitempty"`
|
||||
SignedAtUTC string `json:"signedAtUTC"`
|
||||
}
|
||||
|
||||
// HandleListDocuments handles GET /api/v1/admin/documents
|
||||
func (h *Handler) HandleListDocuments(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// TODO: Add pagination parameters
|
||||
limit := 100
|
||||
offset := 0
|
||||
|
||||
documents, err := h.documentRepo.List(ctx, limit, offset)
|
||||
if err != nil {
|
||||
shared.WriteError(w, http.StatusInternalServerError, shared.ErrCodeInternal, "Failed to list documents", nil)
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]*DocumentResponse, 0, len(documents))
|
||||
for _, doc := range documents {
|
||||
response = append(response, toDocumentResponse(doc))
|
||||
}
|
||||
|
||||
meta := map[string]interface{}{
|
||||
"total": len(documents), // For now, just return count of results
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
|
||||
shared.WriteJSONWithMeta(w, http.StatusOK, response, meta)
|
||||
}
|
||||
|
||||
// HandleGetDocument handles GET /api/v1/admin/documents/{docId}
|
||||
func (h *Handler) HandleGetDocument(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
docID := chi.URLParam(r, "docId")
|
||||
|
||||
if docID == "" {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Document ID is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
document, err := h.documentRepo.GetByDocID(ctx, docID)
|
||||
if err != nil {
|
||||
shared.WriteError(w, http.StatusNotFound, shared.ErrCodeNotFound, "Document not found", nil)
|
||||
return
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, toDocumentResponse(document))
|
||||
}
|
||||
|
||||
// HandleGetDocumentWithSigners handles GET /api/v1/admin/documents/{docId}/signers
|
||||
func (h *Handler) HandleGetDocumentWithSigners(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
docID := chi.URLParam(r, "docId")
|
||||
|
||||
if docID == "" {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Document ID is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Get document
|
||||
document, err := h.documentRepo.GetByDocID(ctx, docID)
|
||||
if err != nil {
|
||||
shared.WriteError(w, http.StatusNotFound, shared.ErrCodeNotFound, "Document not found", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Get expected signers with status
|
||||
signers, err := h.expectedSignerRepo.ListWithStatusByDocID(ctx, docID)
|
||||
if err != nil {
|
||||
shared.WriteError(w, http.StatusInternalServerError, shared.ErrCodeInternal, "Failed to get signers", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Get completion stats
|
||||
stats, err := h.expectedSignerRepo.GetStats(ctx, docID)
|
||||
if err != nil {
|
||||
shared.WriteError(w, http.StatusInternalServerError, shared.ErrCodeInternal, "Failed to get stats", nil)
|
||||
return
|
||||
}
|
||||
|
||||
signersResponse := make([]*ExpectedSignerResponse, 0, len(signers))
|
||||
for _, signer := range signers {
|
||||
signersResponse = append(signersResponse, toExpectedSignerResponse(signer))
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"document": toDocumentResponse(document),
|
||||
"signers": signersResponse,
|
||||
"stats": toStatsResponse(stats),
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// AddExpectedSignerRequest represents the request body for adding an expected signer
|
||||
type AddExpectedSignerRequest struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes,omitempty"`
|
||||
}
|
||||
|
||||
// HandleAddExpectedSigner handles POST /api/v1/admin/documents/{docId}/signers
|
||||
func (h *Handler) HandleAddExpectedSigner(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
docID := chi.URLParam(r, "docId")
|
||||
|
||||
if docID == "" {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Document ID is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Get user from context
|
||||
user, ok := shared.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
shared.WriteUnauthorized(w, "")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request body
|
||||
var req AddExpectedSignerRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Invalid request body", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate
|
||||
if req.Email == "" {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Email is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Add expected signer
|
||||
contacts := []models.ContactInfo{{Email: req.Email, Name: req.Name}}
|
||||
err := h.expectedSignerRepo.AddExpected(ctx, docID, contacts, user.Email)
|
||||
if err != nil {
|
||||
shared.WriteError(w, http.StatusInternalServerError, shared.ErrCodeInternal, "Failed to add expected signer", nil)
|
||||
return
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"message": "Expected signer added successfully",
|
||||
"email": req.Email,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleRemoveExpectedSigner handles DELETE /api/v1/admin/documents/{docId}/signers/{email}
|
||||
func (h *Handler) HandleRemoveExpectedSigner(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
docID := chi.URLParam(r, "docId")
|
||||
email := chi.URLParam(r, "email")
|
||||
|
||||
if docID == "" || email == "" {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Document ID and email are required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove expected signer
|
||||
err := h.expectedSignerRepo.Remove(ctx, docID, email)
|
||||
if err != nil {
|
||||
shared.WriteError(w, http.StatusInternalServerError, shared.ErrCodeInternal, "Failed to remove expected signer", nil)
|
||||
return
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"message": "Expected signer removed successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// Helper functions to convert models to API responses
|
||||
func toDocumentResponse(doc *models.Document) *DocumentResponse {
|
||||
return &DocumentResponse{
|
||||
DocID: doc.DocID,
|
||||
Title: doc.Title,
|
||||
URL: doc.URL,
|
||||
Checksum: doc.Checksum,
|
||||
ChecksumAlgorithm: doc.ChecksumAlgorithm,
|
||||
Description: doc.Description,
|
||||
CreatedAt: doc.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
UpdatedAt: doc.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
CreatedBy: doc.CreatedBy,
|
||||
}
|
||||
}
|
||||
|
||||
func toExpectedSignerResponse(signer *models.ExpectedSignerWithStatus) *ExpectedSignerResponse {
|
||||
response := &ExpectedSignerResponse{
|
||||
ID: signer.ID,
|
||||
DocID: signer.DocID,
|
||||
Email: signer.Email,
|
||||
Name: signer.Name,
|
||||
AddedAt: signer.AddedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
AddedBy: signer.AddedBy,
|
||||
Notes: signer.Notes,
|
||||
HasSigned: signer.HasSigned,
|
||||
UserName: signer.UserName,
|
||||
ReminderCount: signer.ReminderCount,
|
||||
DaysSinceAdded: signer.DaysSinceAdded,
|
||||
DaysSinceLastReminder: signer.DaysSinceLastReminder,
|
||||
}
|
||||
|
||||
if signer.SignedAt != nil {
|
||||
signedAt := signer.SignedAt.Format("2006-01-02T15:04:05Z07:00")
|
||||
response.SignedAt = &signedAt
|
||||
}
|
||||
|
||||
if signer.LastReminderSent != nil {
|
||||
lastReminder := signer.LastReminderSent.Format("2006-01-02T15:04:05Z07:00")
|
||||
response.LastReminderSent = &lastReminder
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func toStatsResponse(stats *models.DocCompletionStats) *DocumentStatsResponse {
|
||||
return &DocumentStatsResponse{
|
||||
DocID: stats.DocID,
|
||||
ExpectedCount: stats.ExpectedCount,
|
||||
SignedCount: stats.SignedCount,
|
||||
PendingCount: stats.PendingCount,
|
||||
CompletionRate: stats.CompletionRate,
|
||||
}
|
||||
}
|
||||
|
||||
// SendRemindersRequest represents the request body for sending reminders
|
||||
type SendRemindersRequest struct {
|
||||
Emails []string `json:"emails,omitempty"` // If empty, send to all pending signers
|
||||
}
|
||||
|
||||
// HandleSendReminders handles POST /api/v1/admin/documents/{docId}/reminders
|
||||
func (h *Handler) HandleSendReminders(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
docID := chi.URLParam(r, "docId")
|
||||
|
||||
if docID == "" {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Document ID is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if reminder service is available
|
||||
if h.reminderService == nil {
|
||||
shared.WriteError(w, http.StatusServiceUnavailable, shared.ErrCodeInternal, "Reminder service not configured", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Get user from context
|
||||
user, ok := shared.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
shared.WriteUnauthorized(w, "")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request body
|
||||
var req SendRemindersRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Invalid request body", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Get document URL from metadata
|
||||
var docURL string
|
||||
if doc, err := h.documentRepo.GetByDocID(ctx, docID); err == nil && doc != nil && doc.URL != "" {
|
||||
docURL = doc.URL
|
||||
}
|
||||
|
||||
// Get locale from request using i18n helper
|
||||
locale := i18n.GetLangFromRequest(r)
|
||||
|
||||
// Send reminders
|
||||
result, err := h.reminderService.SendReminders(ctx, docID, user.Email, req.Emails, docURL, locale)
|
||||
if err != nil {
|
||||
shared.WriteError(w, http.StatusInternalServerError, shared.ErrCodeInternal, "Failed to send reminders", nil)
|
||||
return
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"message": "Reminders sent",
|
||||
"result": result,
|
||||
})
|
||||
}
|
||||
|
||||
// ReminderLogResponse represents a reminder log entry in API responses
|
||||
type ReminderLogResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
DocID string `json:"docId"`
|
||||
RecipientEmail string `json:"recipientEmail"`
|
||||
SentAt string `json:"sentAt"`
|
||||
SentBy string `json:"sentBy"`
|
||||
TemplateUsed string `json:"templateUsed"`
|
||||
Status string `json:"status"`
|
||||
ErrorMessage *string `json:"errorMessage,omitempty"`
|
||||
}
|
||||
|
||||
// HandleGetReminderHistory handles GET /api/v1/admin/documents/{docId}/reminders
|
||||
func (h *Handler) HandleGetReminderHistory(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
docID := chi.URLParam(r, "docId")
|
||||
|
||||
if docID == "" {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Document ID is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if reminder service is available
|
||||
if h.reminderService == nil {
|
||||
shared.WriteError(w, http.StatusServiceUnavailable, shared.ErrCodeInternal, "Reminder service not configured", nil)
|
||||
return
|
||||
}
|
||||
|
||||
history, err := h.reminderService.GetReminderHistory(ctx, docID)
|
||||
if err != nil {
|
||||
shared.WriteError(w, http.StatusInternalServerError, shared.ErrCodeInternal, "Failed to get reminder history", nil)
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]*ReminderLogResponse, 0, len(history))
|
||||
for _, log := range history {
|
||||
response = append(response, &ReminderLogResponse{
|
||||
ID: log.ID,
|
||||
DocID: log.DocID,
|
||||
RecipientEmail: log.RecipientEmail,
|
||||
SentAt: log.SentAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
SentBy: log.SentBy,
|
||||
TemplateUsed: log.TemplateUsed,
|
||||
Status: log.Status,
|
||||
ErrorMessage: log.ErrorMessage,
|
||||
})
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// UpdateDocumentMetadataRequest represents the request body for updating document metadata
|
||||
type UpdateDocumentMetadataRequest struct {
|
||||
Title *string `json:"title,omitempty"`
|
||||
URL *string `json:"url,omitempty"`
|
||||
Checksum *string `json:"checksum,omitempty"`
|
||||
ChecksumAlgorithm *string `json:"checksumAlgorithm,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
// HandleUpdateDocumentMetadata handles PUT /api/v1/admin/documents/{docId}/metadata
|
||||
func (h *Handler) HandleUpdateDocumentMetadata(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
docID := chi.URLParam(r, "docId")
|
||||
|
||||
if docID == "" {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Document ID is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Get user from context
|
||||
user, ok := shared.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
shared.WriteUnauthorized(w, "")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request body
|
||||
var req UpdateDocumentMetadataRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Invalid request body", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Get existing document or create new one
|
||||
doc, err := h.documentRepo.GetByDocID(ctx, docID)
|
||||
if err != nil || doc == nil {
|
||||
// Document doesn't exist, create a new one
|
||||
doc = &models.Document{
|
||||
DocID: docID,
|
||||
CreatedBy: user.Email,
|
||||
}
|
||||
}
|
||||
|
||||
// Update fields if provided
|
||||
if req.Title != nil {
|
||||
doc.Title = *req.Title
|
||||
}
|
||||
if req.URL != nil {
|
||||
doc.URL = *req.URL
|
||||
}
|
||||
if req.Checksum != nil {
|
||||
doc.Checksum = *req.Checksum
|
||||
}
|
||||
if req.ChecksumAlgorithm != nil {
|
||||
doc.ChecksumAlgorithm = *req.ChecksumAlgorithm
|
||||
}
|
||||
if req.Description != nil {
|
||||
doc.Description = *req.Description
|
||||
}
|
||||
|
||||
// Save document using CreateOrUpdate
|
||||
input := models.DocumentInput{
|
||||
Title: doc.Title,
|
||||
URL: doc.URL,
|
||||
Checksum: doc.Checksum,
|
||||
ChecksumAlgorithm: doc.ChecksumAlgorithm,
|
||||
Description: doc.Description,
|
||||
}
|
||||
doc, err = h.documentRepo.CreateOrUpdate(ctx, docID, input, user.Email)
|
||||
if err != nil {
|
||||
shared.WriteError(w, http.StatusInternalServerError, shared.ErrCodeInternal, "Failed to update document metadata", nil)
|
||||
return
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"message": "Document metadata updated successfully",
|
||||
"document": toDocumentResponse(doc),
|
||||
})
|
||||
}
|
||||
|
||||
// DocumentStatusResponse represents complete document status including everything
|
||||
type DocumentStatusResponse struct {
|
||||
DocID string `json:"docId"`
|
||||
Document *DocumentResponse `json:"document,omitempty"`
|
||||
ExpectedSigners []*ExpectedSignerResponse `json:"expectedSigners"`
|
||||
UnexpectedSignatures []*UnexpectedSignatureResponse `json:"unexpectedSignatures"`
|
||||
Stats *DocumentStatsResponse `json:"stats"`
|
||||
ReminderStats *ReminderStatsResponse `json:"reminderStats,omitempty"`
|
||||
ShareLink string `json:"shareLink"`
|
||||
}
|
||||
|
||||
// ReminderStatsResponse represents reminder statistics
|
||||
type ReminderStatsResponse struct {
|
||||
TotalSent int `json:"totalSent"`
|
||||
PendingCount int `json:"pendingCount"`
|
||||
LastSentAt *string `json:"lastSentAt,omitempty"`
|
||||
}
|
||||
|
||||
// HandleGetDocumentStatus handles GET /api/v1/admin/documents/{docId}/status
|
||||
func (h *Handler) HandleGetDocumentStatus(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
docID := chi.URLParam(r, "docId")
|
||||
|
||||
if docID == "" {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Document ID is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
response := &DocumentStatusResponse{
|
||||
DocID: docID,
|
||||
ExpectedSigners: []*ExpectedSignerResponse{},
|
||||
UnexpectedSignatures: []*UnexpectedSignatureResponse{},
|
||||
ShareLink: h.baseURL + "/?doc=" + docID,
|
||||
}
|
||||
|
||||
// Get document (optional)
|
||||
if doc, err := h.documentRepo.GetByDocID(ctx, docID); err == nil && doc != nil {
|
||||
response.Document = toDocumentResponse(doc)
|
||||
}
|
||||
|
||||
// Get expected signers with status
|
||||
expectedEmails := make(map[string]bool)
|
||||
if signers, err := h.expectedSignerRepo.ListWithStatusByDocID(ctx, docID); err == nil {
|
||||
for _, signer := range signers {
|
||||
response.ExpectedSigners = append(response.ExpectedSigners, toExpectedSignerResponse(signer))
|
||||
expectedEmails[signer.Email] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Get all signatures for this document and find unexpected ones
|
||||
if h.signatureService != nil {
|
||||
if signatures, err := h.signatureService.GetDocumentSignatures(ctx, docID); err == nil {
|
||||
for _, sig := range signatures {
|
||||
// If this signature's email is not in the expected list, it's unexpected
|
||||
if !expectedEmails[sig.UserEmail] {
|
||||
userName := sig.UserName
|
||||
response.UnexpectedSignatures = append(response.UnexpectedSignatures, &UnexpectedSignatureResponse{
|
||||
UserEmail: sig.UserEmail,
|
||||
UserName: &userName,
|
||||
SignedAtUTC: sig.SignedAtUTC.Format("2006-01-02T15:04:05Z07:00"),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get completion stats
|
||||
if stats, err := h.expectedSignerRepo.GetStats(ctx, docID); err == nil {
|
||||
response.Stats = toStatsResponse(stats)
|
||||
} else {
|
||||
// Default stats if no expected signers
|
||||
response.Stats = &DocumentStatsResponse{
|
||||
DocID: docID,
|
||||
ExpectedCount: 0,
|
||||
SignedCount: 0,
|
||||
PendingCount: 0,
|
||||
CompletionRate: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Get reminder stats if service available
|
||||
if h.reminderService != nil {
|
||||
if reminderStats, err := h.reminderService.GetReminderStats(ctx, docID); err == nil {
|
||||
var lastSentAt *string
|
||||
if reminderStats.LastSentAt != nil {
|
||||
formatted := reminderStats.LastSentAt.Format("2006-01-02T15:04:05Z07:00")
|
||||
lastSentAt = &formatted
|
||||
}
|
||||
response.ReminderStats = &ReminderStatsResponse{
|
||||
TotalSent: reminderStats.TotalSent,
|
||||
PendingCount: reminderStats.PendingCount,
|
||||
LastSentAt: lastSentAt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// HandleDeleteDocument handles DELETE /api/v1/admin/documents/{docId}
|
||||
func (h *Handler) HandleDeleteDocument(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
docID := chi.URLParam(r, "docId")
|
||||
|
||||
if docID == "" {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Document ID is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete document (this will cascade delete signatures and expected signers due to DB constraints)
|
||||
err := h.documentRepo.Delete(ctx, docID)
|
||||
if err != nil {
|
||||
shared.WriteError(w, http.StatusInternalServerError, shared.ErrCodeInternal, "Failed to delete document", nil)
|
||||
return
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"message": "Document deleted successfully",
|
||||
})
|
||||
}
|
||||
318
backend/internal/presentation/api/admin/handler_test.go
Normal file
318
backend/internal/presentation/api/admin/handler_test.go
Normal file
@@ -0,0 +1,318 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package admin_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/application/services"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/database"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/admin"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/crypto"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) *database.TestDB {
|
||||
testDB := database.SetupTestDB(t)
|
||||
|
||||
// Create tables
|
||||
schema := `
|
||||
DROP TABLE IF EXISTS reminder_logs CASCADE;
|
||||
DROP TABLE IF EXISTS expected_signers CASCADE;
|
||||
DROP TABLE IF EXISTS signatures CASCADE;
|
||||
DROP TABLE IF EXISTS documents CASCADE;
|
||||
|
||||
CREATE TABLE documents (
|
||||
doc_id TEXT PRIMARY KEY,
|
||||
title TEXT NOT NULL DEFAULT '',
|
||||
url TEXT NOT NULL DEFAULT '',
|
||||
checksum TEXT NOT NULL DEFAULT '',
|
||||
checksum_algorithm TEXT NOT NULL DEFAULT 'SHA-256',
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
created_by TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
|
||||
CREATE TABLE signatures (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
doc_id TEXT NOT NULL,
|
||||
user_sub TEXT NOT NULL,
|
||||
user_email TEXT NOT NULL,
|
||||
user_name TEXT NOT NULL DEFAULT '',
|
||||
signed_at TIMESTAMPTZ NOT NULL,
|
||||
payload_hash TEXT NOT NULL,
|
||||
signature TEXT NOT NULL,
|
||||
nonce TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT now(),
|
||||
referer TEXT,
|
||||
prev_hash TEXT,
|
||||
UNIQUE (doc_id, user_sub)
|
||||
);
|
||||
|
||||
CREATE TABLE expected_signers (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
doc_id TEXT NOT NULL,
|
||||
email TEXT NOT NULL,
|
||||
name TEXT NOT NULL DEFAULT '',
|
||||
added_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
added_by TEXT NOT NULL,
|
||||
notes TEXT,
|
||||
UNIQUE (doc_id, email)
|
||||
);
|
||||
|
||||
CREATE TABLE reminder_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
doc_id TEXT NOT NULL,
|
||||
recipient_email TEXT NOT NULL,
|
||||
sent_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
sent_by TEXT NOT NULL,
|
||||
template_used TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
error_message TEXT
|
||||
);
|
||||
`
|
||||
|
||||
_, err := testDB.DB.Exec(schema)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test schema: %v", err)
|
||||
}
|
||||
|
||||
return testDB
|
||||
}
|
||||
|
||||
func TestAdminHandler_GetDocumentStatus_WithUnexpectedSignatures(t *testing.T) {
|
||||
testDB := setupTestDB(t)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup repositories and services
|
||||
docRepo := database.NewDocumentRepository(testDB.DB)
|
||||
sigRepo := database.NewSignatureRepository(testDB.DB)
|
||||
expectedSignerRepo := database.NewExpectedSignerRepository(testDB.DB)
|
||||
signer, _ := crypto.NewEd25519Signer()
|
||||
sigService := services.NewSignatureService(sigRepo, docRepo, signer)
|
||||
|
||||
// Create test document
|
||||
docID := "test-doc-001"
|
||||
_, err := docRepo.CreateOrUpdate(ctx, docID, models.DocumentInput{
|
||||
Title: "Test Document",
|
||||
URL: "https://example.com/doc.pdf",
|
||||
Checksum: "abc123",
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
Description: "Test description",
|
||||
}, "admin@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create document: %v", err)
|
||||
}
|
||||
|
||||
// Add expected signer
|
||||
err = expectedSignerRepo.AddExpected(ctx, docID, []models.ContactInfo{
|
||||
{Email: "expected@example.com", Name: "Expected User"},
|
||||
}, "admin@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add expected signer: %v", err)
|
||||
}
|
||||
|
||||
// Create signature from expected user
|
||||
expectedUser := &models.User{
|
||||
Sub: "expected-sub",
|
||||
Email: "expected@example.com",
|
||||
Name: "Expected User",
|
||||
}
|
||||
err = sigService.CreateSignature(ctx, &models.SignatureRequest{
|
||||
DocID: docID,
|
||||
User: expectedUser,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create expected signature: %v", err)
|
||||
}
|
||||
|
||||
// Create signature from unexpected user
|
||||
unexpectedUser := &models.User{
|
||||
Sub: "unexpected-sub",
|
||||
Email: "unexpected@example.com",
|
||||
Name: "Unexpected User",
|
||||
}
|
||||
err = sigService.CreateSignature(ctx, &models.SignatureRequest{
|
||||
DocID: docID,
|
||||
User: unexpectedUser,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create unexpected signature: %v", err)
|
||||
}
|
||||
|
||||
// Create admin handler
|
||||
handler := admin.NewHandler(docRepo, expectedSignerRepo, nil, sigService, "https://example.com")
|
||||
|
||||
// Create HTTP request
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/documents/"+docID+"/status", nil)
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("docId", docID)
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||
|
||||
// Create response recorder
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call handler
|
||||
handler.HandleGetDocumentStatus(w, req)
|
||||
|
||||
// Check response
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var response struct {
|
||||
DocID string `json:"docId"`
|
||||
ExpectedSigners []struct {
|
||||
Email string `json:"email"`
|
||||
HasSigned bool `json:"hasSigned"`
|
||||
} `json:"expectedSigners"`
|
||||
UnexpectedSignatures []struct {
|
||||
UserEmail string `json:"userEmail"`
|
||||
UserName *string `json:"userName,omitempty"`
|
||||
SignedAtUTC string `json:"signedAtUTC"`
|
||||
} `json:"unexpectedSignatures"`
|
||||
Stats struct {
|
||||
ExpectedCount int `json:"expectedCount"`
|
||||
SignedCount int `json:"signedCount"`
|
||||
PendingCount int `json:"pendingCount"`
|
||||
CompletionRate float64 `json:"completionRate"`
|
||||
} `json:"stats"`
|
||||
ShareLink string `json:"shareLink"`
|
||||
}
|
||||
|
||||
err = json.NewDecoder(w.Body).Decode(&response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
// Verify response
|
||||
if response.DocID != docID {
|
||||
t.Errorf("Expected docId %s, got %s", docID, response.DocID)
|
||||
}
|
||||
|
||||
// Check expected signers
|
||||
if len(response.ExpectedSigners) != 1 {
|
||||
t.Errorf("Expected 1 expected signer, got %d", len(response.ExpectedSigners))
|
||||
} else {
|
||||
if response.ExpectedSigners[0].Email != "expected@example.com" {
|
||||
t.Errorf("Expected email 'expected@example.com', got '%s'", response.ExpectedSigners[0].Email)
|
||||
}
|
||||
if !response.ExpectedSigners[0].HasSigned {
|
||||
t.Error("Expected signer should have signed")
|
||||
}
|
||||
}
|
||||
|
||||
// Check unexpected signatures
|
||||
if len(response.UnexpectedSignatures) != 1 {
|
||||
t.Fatalf("Expected 1 unexpected signature, got %d", len(response.UnexpectedSignatures))
|
||||
}
|
||||
if response.UnexpectedSignatures[0].UserEmail != "unexpected@example.com" {
|
||||
t.Errorf("Expected unexpected email 'unexpected@example.com', got '%s'", response.UnexpectedSignatures[0].UserEmail)
|
||||
}
|
||||
if response.UnexpectedSignatures[0].UserName == nil || *response.UnexpectedSignatures[0].UserName != "Unexpected User" {
|
||||
t.Error("Expected unexpected userName to be 'Unexpected User'")
|
||||
}
|
||||
|
||||
// Check stats
|
||||
if response.Stats.ExpectedCount != 1 {
|
||||
t.Errorf("Expected expectedCount 1, got %d", response.Stats.ExpectedCount)
|
||||
}
|
||||
if response.Stats.SignedCount != 1 {
|
||||
t.Errorf("Expected signedCount 1, got %d", response.Stats.SignedCount)
|
||||
}
|
||||
if response.Stats.CompletionRate != 100.0 {
|
||||
t.Errorf("Expected completionRate 100.0, got %f", response.Stats.CompletionRate)
|
||||
}
|
||||
|
||||
// Check share link
|
||||
expectedShareLink := "https://example.com/?doc=" + docID
|
||||
if response.ShareLink != expectedShareLink {
|
||||
t.Errorf("Expected shareLink '%s', got '%s'", expectedShareLink, response.ShareLink)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminHandler_GetDocumentStatus_NoExpectedSigners(t *testing.T) {
|
||||
testDB := setupTestDB(t)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup repositories and services
|
||||
docRepo := database.NewDocumentRepository(testDB.DB)
|
||||
sigRepo := database.NewSignatureRepository(testDB.DB)
|
||||
expectedSignerRepo := database.NewExpectedSignerRepository(testDB.DB)
|
||||
signer, _ := crypto.NewEd25519Signer()
|
||||
sigService := services.NewSignatureService(sigRepo, docRepo, signer)
|
||||
|
||||
// Create test document
|
||||
docID := "test-doc-002"
|
||||
|
||||
// Create signature from user (no expected signers)
|
||||
user := &models.User{
|
||||
Sub: "user-sub",
|
||||
Email: "user@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
err := sigService.CreateSignature(ctx, &models.SignatureRequest{
|
||||
DocID: docID,
|
||||
User: user,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create signature: %v", err)
|
||||
}
|
||||
|
||||
// Create admin handler
|
||||
handler := admin.NewHandler(docRepo, expectedSignerRepo, nil, sigService, "https://example.com")
|
||||
|
||||
// Create HTTP request
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/documents/"+docID+"/status", nil)
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("docId", docID)
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||
|
||||
// Create response recorder
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call handler
|
||||
handler.HandleGetDocumentStatus(w, req)
|
||||
|
||||
// Check response
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var response struct {
|
||||
ExpectedSigners []interface{} `json:"expectedSigners"`
|
||||
UnexpectedSignatures []struct {
|
||||
UserEmail string `json:"userEmail"`
|
||||
} `json:"unexpectedSignatures"`
|
||||
}
|
||||
|
||||
err = json.NewDecoder(w.Body).Decode(&response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
// Verify response
|
||||
if len(response.ExpectedSigners) != 0 {
|
||||
t.Errorf("Expected 0 expected signers, got %d", len(response.ExpectedSigners))
|
||||
}
|
||||
|
||||
// All signatures should be unexpected since there are no expected signers
|
||||
if len(response.UnexpectedSignatures) != 1 {
|
||||
t.Fatalf("Expected 1 unexpected signature, got %d", len(response.UnexpectedSignatures))
|
||||
}
|
||||
if response.UnexpectedSignatures[0].UserEmail != "user@example.com" {
|
||||
t.Errorf("Expected email 'user@example.com', got '%s'", response.UnexpectedSignatures[0].UserEmail)
|
||||
}
|
||||
}
|
||||
1341
backend/internal/presentation/api/admin/handler_unit_test.go
Normal file
1341
backend/internal/presentation/api/admin/handler_unit_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package handlers
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
@@ -8,80 +8,80 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/btouchard/ackify-ce/pkg/logger"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/auth"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/shared"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/presentation/handlers"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
type AuthHandlers struct {
|
||||
authService authService
|
||||
// Handler handles authentication API requests
|
||||
type Handler struct {
|
||||
authService *auth.OauthService
|
||||
middleware *shared.Middleware
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func NewAuthHandlers(authService authService, baseURL string) *AuthHandlers {
|
||||
return &AuthHandlers{
|
||||
// NewHandler creates a new auth handler
|
||||
func NewHandler(authService *auth.OauthService, middleware *shared.Middleware, baseURL string) *Handler {
|
||||
return &Handler{
|
||||
authService: authService,
|
||||
middleware: middleware,
|
||||
baseURL: baseURL,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AuthHandlers) HandleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
next := r.URL.Query().Get("next")
|
||||
if next == "" {
|
||||
next = h.baseURL + "/"
|
||||
}
|
||||
|
||||
logger.Logger.Debug("HandleLogin: starting OAuth flow",
|
||||
"next_url", next,
|
||||
"query_params", r.URL.Query().Encode())
|
||||
|
||||
// Persist CSRF state in session when generating auth URL
|
||||
authURL := h.authService.CreateAuthURL(w, r, next)
|
||||
|
||||
logger.Logger.Debug("HandleLogin: redirecting to OAuth provider",
|
||||
"auth_url", authURL)
|
||||
|
||||
http.Redirect(w, r, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
func (h *AuthHandlers) HandleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
h.authService.Logout(w, r)
|
||||
|
||||
// Redirect to SSO logout if configured, otherwise redirect to home
|
||||
ssoLogoutURL := h.authService.GetLogoutURL()
|
||||
if ssoLogoutURL != "" {
|
||||
http.Redirect(w, r, ssoLogoutURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
|
||||
func (h *AuthHandlers) HandleAuthCheck(w http.ResponseWriter, r *http.Request) {
|
||||
user, err := h.authService.GetUser(r)
|
||||
// HandleGetCSRFToken handles GET /api/v1/csrf
|
||||
func (h *Handler) HandleGetCSRFToken(w http.ResponseWriter, r *http.Request) {
|
||||
token, err := h.middleware.GenerateCSRFToken()
|
||||
if err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"authenticated":false}`))
|
||||
shared.WriteInternalError(w)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
response := map[string]interface{}{
|
||||
"authenticated": true,
|
||||
"user": map[string]string{
|
||||
"email": user.Email,
|
||||
"name": user.Name,
|
||||
},
|
||||
}
|
||||
// Set cookie for the token
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: shared.CSRFTokenCookie,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
HttpOnly: false, // Allow JS to read it
|
||||
Secure: r.TLS != nil,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: 86400, // 24 hours
|
||||
})
|
||||
|
||||
if jsonBytes, err := json.Marshal(response); err == nil {
|
||||
w.Write(jsonBytes)
|
||||
} else {
|
||||
w.Write([]byte(`{"authenticated":false}`))
|
||||
}
|
||||
shared.WriteJSON(w, http.StatusOK, map[string]string{
|
||||
"token": token,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandlers) HandleOAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
// HandleStartOAuth handles POST /api/v1/auth/start
|
||||
func (h *Handler) HandleStartOAuth(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
RedirectTo string `json:"redirectTo"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
// If no body, that's fine, use default redirect
|
||||
req.RedirectTo = "/"
|
||||
}
|
||||
|
||||
// Default to home if no redirect specified
|
||||
if req.RedirectTo == "" {
|
||||
req.RedirectTo = "/"
|
||||
}
|
||||
|
||||
// Generate OAuth URL and save state in session
|
||||
// This is critical - CreateAuthURL saves the state token in session
|
||||
// which will be validated when Google redirects to /api/v1/auth/callback
|
||||
authURL := h.authService.CreateAuthURL(w, r, req.RedirectTo)
|
||||
|
||||
// Return redirect URL for SPA to handle
|
||||
shared.WriteJSON(w, http.StatusOK, map[string]string{
|
||||
"redirectUrl": authURL,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) HandleOAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
oauthError := r.URL.Query().Get("error")
|
||||
@@ -149,7 +149,7 @@ func (h *AuthHandlers) HandleOAuthCallback(w http.ResponseWriter, r *http.Reques
|
||||
user, nextURL, err := h.authService.HandleCallback(ctx, code, state)
|
||||
if err != nil {
|
||||
logger.Logger.Error("OAuth callback failed", "error", err.Error())
|
||||
HandleError(w, err)
|
||||
handlers.HandleError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -180,3 +180,45 @@ func (h *AuthHandlers) HandleOAuthCallback(w http.ResponseWriter, r *http.Reques
|
||||
|
||||
http.Redirect(w, r, nextURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// HandleLogout handles GET /api/v1/auth/logout
|
||||
func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
// Clear session
|
||||
h.authService.Logout(w, r)
|
||||
|
||||
// Check if SSO logout is configured
|
||||
logoutURL := h.authService.GetLogoutURL()
|
||||
if logoutURL != "" {
|
||||
returnURL := h.baseURL + "/"
|
||||
fullLogoutURL := logoutURL + "?post_logout_redirect_uri=" + url.QueryEscape(returnURL)
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, map[string]string{
|
||||
"message": "Successfully logged out",
|
||||
"redirectUrl": fullLogoutURL,
|
||||
})
|
||||
} else {
|
||||
shared.WriteJSON(w, http.StatusOK, map[string]string{
|
||||
"message": "Successfully logged out",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HandleAuthCheck handles GET /api/v1/auth/check
|
||||
func (h *Handler) HandleAuthCheck(w http.ResponseWriter, r *http.Request) {
|
||||
user, err := h.authService.GetUser(r)
|
||||
if err != nil || user == nil {
|
||||
shared.WriteJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"authenticated": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"authenticated": true,
|
||||
"user": map[string]interface{}{
|
||||
"id": user.Sub,
|
||||
"email": user.Email,
|
||||
"name": user.Name,
|
||||
},
|
||||
})
|
||||
}
|
||||
906
backend/internal/presentation/api/auth/handler_test.go
Normal file
906
backend/internal/presentation/api/auth/handler_test.go
Normal file
@@ -0,0 +1,906 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/auth"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/shared"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// TEST FIXTURES
|
||||
// ============================================================================
|
||||
|
||||
const (
|
||||
testBaseURL = "https://example.com"
|
||||
testClientID = "test-client-id"
|
||||
testClientSecret = "test-client-secret"
|
||||
testAuthURL = "https://oauth.example.com/authorize"
|
||||
testTokenURL = "https://oauth.example.com/token"
|
||||
testUserInfoURL = "https://oauth.example.com/userinfo"
|
||||
testLogoutURL = "https://oauth.example.com/logout"
|
||||
)
|
||||
|
||||
var (
|
||||
testCookieSecret = securecookie.GenerateRandomKey(32)
|
||||
|
||||
testUser = &models.User{
|
||||
Sub: "oauth2|123456789",
|
||||
Email: "user@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// HELPER FUNCTIONS
|
||||
// ============================================================================
|
||||
|
||||
func createTestAuthService() *auth.OauthService {
|
||||
return auth.NewOAuthService(auth.Config{
|
||||
BaseURL: testBaseURL,
|
||||
ClientID: testClientID,
|
||||
ClientSecret: testClientSecret,
|
||||
AuthURL: testAuthURL,
|
||||
TokenURL: testTokenURL,
|
||||
UserInfoURL: testUserInfoURL,
|
||||
LogoutURL: testLogoutURL,
|
||||
Scopes: []string{"openid", "email", "profile"},
|
||||
AllowedDomain: "",
|
||||
CookieSecret: testCookieSecret,
|
||||
SecureCookies: false, // false for testing (no HTTPS)
|
||||
})
|
||||
}
|
||||
|
||||
func createTestMiddleware() *shared.Middleware {
|
||||
authService := createTestAuthService()
|
||||
return shared.NewMiddleware(authService, testBaseURL, []string{})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - Constructor
|
||||
// ============================================================================
|
||||
|
||||
func TestNewHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authService *auth.OauthService
|
||||
middleware *shared.Middleware
|
||||
baseURL string
|
||||
}{
|
||||
{
|
||||
name: "with valid dependencies",
|
||||
authService: createTestAuthService(),
|
||||
middleware: createTestMiddleware(),
|
||||
baseURL: testBaseURL,
|
||||
},
|
||||
{
|
||||
name: "with empty baseURL",
|
||||
authService: createTestAuthService(),
|
||||
middleware: createTestMiddleware(),
|
||||
baseURL: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler(tt.authService, tt.middleware, tt.baseURL)
|
||||
|
||||
assert.NotNil(t, handler)
|
||||
assert.NotNil(t, handler.authService)
|
||||
assert.NotNil(t, handler.middleware)
|
||||
assert.Equal(t, tt.baseURL, handler.baseURL)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - HandleAuthCheck
|
||||
// ============================================================================
|
||||
|
||||
func TestHandler_HandleAuthCheck_Authenticated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authService := createTestAuthService()
|
||||
handler := NewHandler(authService, createTestMiddleware(), testBaseURL)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/check", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Set user in session
|
||||
err := authService.SetUser(rec, req, testUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the session cookie from the recorder
|
||||
cookies := rec.Result().Cookies()
|
||||
require.NotEmpty(t, cookies, "Session cookie should be set")
|
||||
|
||||
// Create a new request with the session cookie
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/api/v1/auth/check", nil)
|
||||
for _, cookie := range cookies {
|
||||
req2.AddCookie(cookie)
|
||||
}
|
||||
rec2 := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
handler.HandleAuthCheck(rec2, req2)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec2.Code)
|
||||
assert.Equal(t, "application/json", rec2.Header().Get("Content-Type"))
|
||||
|
||||
// Parse response
|
||||
var wrapper struct {
|
||||
Data struct {
|
||||
Authenticated bool `json:"authenticated"`
|
||||
User map[string]interface{} `json:"user"`
|
||||
} `json:"data"`
|
||||
}
|
||||
err = json.Unmarshal(rec2.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err, "Response should be valid JSON")
|
||||
|
||||
// Validate fields
|
||||
assert.True(t, wrapper.Data.Authenticated)
|
||||
assert.NotNil(t, wrapper.Data.User)
|
||||
assert.Equal(t, testUser.Sub, wrapper.Data.User["id"])
|
||||
assert.Equal(t, testUser.Email, wrapper.Data.User["email"])
|
||||
assert.Equal(t, testUser.Name, wrapper.Data.User["name"])
|
||||
}
|
||||
|
||||
func TestHandler_HandleAuthCheck_NotAuthenticated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func(*http.Request) *http.Request
|
||||
}{
|
||||
{
|
||||
name: "no session cookie",
|
||||
setupFunc: func(req *http.Request) *http.Request {
|
||||
return req // No modifications
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid session cookie",
|
||||
setupFunc: func(req *http.Request) *http.Request {
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "ackapp_session",
|
||||
Value: "invalid-cookie-value",
|
||||
})
|
||||
return req
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/check", nil)
|
||||
req = tt.setupFunc(req)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
handler.HandleAuthCheck(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Parse response
|
||||
var wrapper struct {
|
||||
Data struct {
|
||||
Authenticated bool `json:"authenticated"`
|
||||
} `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, wrapper.Data.Authenticated)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_HandleAuthCheck_ResponseFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/check", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleAuthCheck(rec, req)
|
||||
|
||||
// Check Content-Type
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Validate JSON structure
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check wrapper structure
|
||||
assert.Contains(t, response, "data")
|
||||
|
||||
// Get data object
|
||||
data, ok := response["data"].(map[string]interface{})
|
||||
require.True(t, ok, "data should be an object")
|
||||
|
||||
// Check required field
|
||||
assert.Contains(t, data, "authenticated")
|
||||
|
||||
// Validate field type
|
||||
_, ok = data["authenticated"].(bool)
|
||||
assert.True(t, ok, "authenticated should be a boolean")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - HandleLogout
|
||||
// ============================================================================
|
||||
|
||||
func TestHandler_HandleLogout_WithSSO(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authService := createTestAuthService()
|
||||
handler := NewHandler(authService, createTestMiddleware(), testBaseURL)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/logout", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Set user in session first
|
||||
err := authService.SetUser(rec, req, testUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the session cookie
|
||||
cookies := rec.Result().Cookies()
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/api/v1/auth/logout", nil)
|
||||
for _, cookie := range cookies {
|
||||
req2.AddCookie(cookie)
|
||||
}
|
||||
rec2 := httptest.NewRecorder()
|
||||
|
||||
// Execute logout
|
||||
handler.HandleLogout(rec2, req2)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec2.Code)
|
||||
assert.Equal(t, "application/json", rec2.Header().Get("Content-Type"))
|
||||
|
||||
// Parse response
|
||||
var wrapper struct {
|
||||
Data struct {
|
||||
Message string `json:"message"`
|
||||
RedirectURL string `json:"redirectUrl"`
|
||||
} `json:"data"`
|
||||
}
|
||||
err = json.Unmarshal(rec2.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "Successfully logged out", wrapper.Data.Message)
|
||||
assert.Contains(t, wrapper.Data.RedirectURL, testLogoutURL)
|
||||
assert.Contains(t, wrapper.Data.RedirectURL, "post_logout_redirect_uri")
|
||||
assert.Contains(t, wrapper.Data.RedirectURL, testBaseURL)
|
||||
}
|
||||
|
||||
func TestHandler_HandleLogout_WithoutSSO(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create auth service without logout URL
|
||||
authService := auth.NewOAuthService(auth.Config{
|
||||
BaseURL: testBaseURL,
|
||||
ClientID: testClientID,
|
||||
ClientSecret: testClientSecret,
|
||||
AuthURL: testAuthURL,
|
||||
TokenURL: testTokenURL,
|
||||
UserInfoURL: testUserInfoURL,
|
||||
LogoutURL: "", // No SSO logout
|
||||
Scopes: []string{"openid", "email", "profile"},
|
||||
CookieSecret: testCookieSecret,
|
||||
SecureCookies: false,
|
||||
})
|
||||
|
||||
handler := NewHandler(authService, createTestMiddleware(), testBaseURL)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/logout", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
handler.HandleLogout(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Parse response
|
||||
var wrapper struct {
|
||||
Data struct {
|
||||
Message string `json:"message"`
|
||||
RedirectURL string `json:"redirectUrl,omitempty"`
|
||||
} `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "Successfully logged out", wrapper.Data.Message)
|
||||
assert.Empty(t, wrapper.Data.RedirectURL)
|
||||
}
|
||||
|
||||
func TestHandler_HandleLogout_ClearsSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authService := createTestAuthService()
|
||||
handler := NewHandler(authService, createTestMiddleware(), testBaseURL)
|
||||
|
||||
// Set user in session
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/logout", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
err := authService.SetUser(rec, req, testUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the session cookie
|
||||
cookies := rec.Result().Cookies()
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/api/v1/auth/logout", nil)
|
||||
for _, cookie := range cookies {
|
||||
req2.AddCookie(cookie)
|
||||
}
|
||||
rec2 := httptest.NewRecorder()
|
||||
|
||||
// Execute logout
|
||||
handler.HandleLogout(rec2, req2)
|
||||
|
||||
// Verify session is cleared by checking the Set-Cookie header
|
||||
setCookieHeaders := rec2.Header().Values("Set-Cookie")
|
||||
assert.NotEmpty(t, setCookieHeaders, "Should set cookie to clear session")
|
||||
|
||||
// Check that MaxAge is negative (cookie deletion)
|
||||
foundMaxAge := false
|
||||
for _, setCookie := range setCookieHeaders {
|
||||
if strings.Contains(setCookie, "Max-Age") && strings.Contains(setCookie, "ackapp_session") {
|
||||
foundMaxAge = true
|
||||
// Should contain negative Max-Age or Max-Age=0
|
||||
assert.True(t, strings.Contains(setCookie, "Max-Age=-1") || strings.Contains(setCookie, "Max-Age=0"),
|
||||
"Cookie should be deleted with negative Max-Age")
|
||||
}
|
||||
}
|
||||
assert.True(t, foundMaxAge, "Should set Max-Age for session cookie")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - HandleStartOAuth
|
||||
// ============================================================================
|
||||
|
||||
func TestHandler_HandleStartOAuth_WithRedirect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestBody map[string]string
|
||||
expectedURL string
|
||||
}{
|
||||
{
|
||||
name: "with custom redirect path",
|
||||
requestBody: map[string]string{"redirectTo": "/dashboard"},
|
||||
expectedURL: "/dashboard",
|
||||
},
|
||||
{
|
||||
name: "with root redirect",
|
||||
requestBody: map[string]string{"redirectTo": "/"},
|
||||
expectedURL: "/",
|
||||
},
|
||||
{
|
||||
name: "with empty redirect",
|
||||
requestBody: map[string]string{"redirectTo": ""},
|
||||
expectedURL: "/",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
body, err := json.Marshal(tt.requestBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/start", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
handler.HandleStartOAuth(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
|
||||
// Parse response
|
||||
var wrapper struct {
|
||||
Data struct {
|
||||
RedirectURL string `json:"redirectUrl"`
|
||||
} `json:"data"`
|
||||
}
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate redirect URL contains OAuth provider URL
|
||||
assert.NotEmpty(t, wrapper.Data.RedirectURL)
|
||||
assert.Contains(t, wrapper.Data.RedirectURL, testAuthURL)
|
||||
assert.Contains(t, wrapper.Data.RedirectURL, "client_id="+testClientID)
|
||||
assert.Contains(t, wrapper.Data.RedirectURL, "redirect_uri=")
|
||||
assert.Contains(t, wrapper.Data.RedirectURL, "state=")
|
||||
|
||||
// Check that session cookie was set (for state verification)
|
||||
cookies := rec.Result().Cookies()
|
||||
assert.NotEmpty(t, cookies, "Session cookie should be set for OAuth state")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_HandleStartOAuth_NoBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/start", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
handler.HandleStartOAuth(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Parse response
|
||||
var wrapper struct {
|
||||
Data struct {
|
||||
RedirectURL string `json:"redirectUrl"`
|
||||
} `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should default to root redirect
|
||||
assert.NotEmpty(t, wrapper.Data.RedirectURL)
|
||||
assert.Contains(t, wrapper.Data.RedirectURL, testAuthURL)
|
||||
}
|
||||
|
||||
func TestHandler_HandleStartOAuth_InvalidJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/start", bytes.NewReader([]byte("invalid-json")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
handler.HandleStartOAuth(rec, req)
|
||||
|
||||
// Assert - should still succeed and default to "/"
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Parse response
|
||||
var wrapper struct {
|
||||
Data struct {
|
||||
RedirectURL string `json:"redirectUrl"`
|
||||
} `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, wrapper.Data.RedirectURL)
|
||||
}
|
||||
|
||||
func TestHandler_HandleStartOAuth_ResponseFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/start", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleStartOAuth(rec, req)
|
||||
|
||||
// Check Content-Type
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Validate JSON structure
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check wrapper structure
|
||||
assert.Contains(t, response, "data")
|
||||
|
||||
// Get data object
|
||||
data, ok := response["data"].(map[string]interface{})
|
||||
require.True(t, ok, "data should be an object")
|
||||
|
||||
// Check required field
|
||||
assert.Contains(t, data, "redirectUrl")
|
||||
|
||||
// Validate field type
|
||||
redirectURL, ok := data["redirectUrl"].(string)
|
||||
assert.True(t, ok, "redirectUrl should be a string")
|
||||
assert.NotEmpty(t, redirectURL)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - HandleGetCSRFToken
|
||||
// ============================================================================
|
||||
|
||||
func TestHandler_HandleGetCSRFToken_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/csrf", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
handler.HandleGetCSRFToken(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
|
||||
// Parse response
|
||||
var wrapper struct {
|
||||
Data struct {
|
||||
Token string `json:"token"`
|
||||
} `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate token
|
||||
assert.NotEmpty(t, wrapper.Data.Token)
|
||||
assert.Greater(t, len(wrapper.Data.Token), 20, "CSRF token should be sufficiently long")
|
||||
|
||||
// Check cookie was set
|
||||
cookies := rec.Result().Cookies()
|
||||
var csrfCookie *http.Cookie
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == shared.CSRFTokenCookie {
|
||||
csrfCookie = cookie
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, csrfCookie, "CSRF cookie should be set")
|
||||
assert.Equal(t, wrapper.Data.Token, csrfCookie.Value)
|
||||
assert.Equal(t, "/", csrfCookie.Path)
|
||||
assert.False(t, csrfCookie.HttpOnly, "CSRF cookie should be readable by JS")
|
||||
assert.Equal(t, http.SameSiteLaxMode, csrfCookie.SameSite)
|
||||
assert.Equal(t, 86400, csrfCookie.MaxAge, "CSRF token should have 24h lifetime")
|
||||
}
|
||||
|
||||
func TestHandler_HandleGetCSRFToken_ResponseFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/csrf", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetCSRFToken(rec, req)
|
||||
|
||||
// Check Content-Type
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Validate JSON structure
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check wrapper structure
|
||||
assert.Contains(t, response, "data")
|
||||
|
||||
// Get data object
|
||||
data, ok := response["data"].(map[string]interface{})
|
||||
require.True(t, ok, "data should be an object")
|
||||
|
||||
// Check required field
|
||||
assert.Contains(t, data, "token")
|
||||
|
||||
// Validate field type
|
||||
token, ok := data["token"].(string)
|
||||
assert.True(t, ok, "token should be a string")
|
||||
assert.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - Concurrency
|
||||
// ============================================================================
|
||||
|
||||
func TestHandler_HandleAuthCheck_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authService := createTestAuthService()
|
||||
handler := NewHandler(authService, createTestMiddleware(), testBaseURL)
|
||||
|
||||
const numRequests = 100
|
||||
done := make(chan bool, numRequests)
|
||||
errors := make(chan error, numRequests)
|
||||
|
||||
// Spawn concurrent requests
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func(id int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/check", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Half with session, half without
|
||||
if id%2 == 0 {
|
||||
err := authService.SetUser(rec, req, testUser)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
cookies := rec.Result().Cookies()
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/api/v1/auth/check", nil)
|
||||
for _, cookie := range cookies {
|
||||
req2.AddCookie(cookie)
|
||||
}
|
||||
rec2 := httptest.NewRecorder()
|
||||
handler.HandleAuthCheck(rec2, req2)
|
||||
|
||||
if rec2.Code != http.StatusOK {
|
||||
errors <- assert.AnError
|
||||
}
|
||||
} else {
|
||||
handler.HandleAuthCheck(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
errors <- assert.AnError
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all requests
|
||||
for i := 0; i < numRequests; i++ {
|
||||
<-done
|
||||
}
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
var errCount int
|
||||
for err := range errors {
|
||||
t.Logf("Concurrent request error: %v", err)
|
||||
errCount++
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, errCount, "All concurrent requests should succeed")
|
||||
}
|
||||
|
||||
func TestHandler_HandleLogout_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
const numRequests = 100
|
||||
done := make(chan bool, numRequests)
|
||||
errors := make(chan error, numRequests)
|
||||
|
||||
// Spawn concurrent logout requests
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func() {
|
||||
defer func() { done <- true }()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/logout", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleLogout(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
errors <- assert.AnError
|
||||
}
|
||||
|
||||
var wrapper struct {
|
||||
Data struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &wrapper); err != nil {
|
||||
errors <- err
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all requests
|
||||
for i := 0; i < numRequests; i++ {
|
||||
<-done
|
||||
}
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
var errCount int
|
||||
for err := range errors {
|
||||
t.Logf("Concurrent request error: %v", err)
|
||||
errCount++
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, errCount, "All concurrent requests should succeed")
|
||||
}
|
||||
|
||||
func TestHandler_HandleStartOAuth_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
const numRequests = 100
|
||||
done := make(chan bool, numRequests)
|
||||
errors := make(chan error, numRequests)
|
||||
|
||||
// Spawn concurrent OAuth start requests
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func() {
|
||||
defer func() { done <- true }()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/start", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleStartOAuth(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
errors <- assert.AnError
|
||||
}
|
||||
|
||||
var wrapper struct {
|
||||
Data struct {
|
||||
RedirectURL string `json:"redirectUrl"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &wrapper); err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
|
||||
if wrapper.Data.RedirectURL == "" {
|
||||
errors <- assert.AnError
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all requests
|
||||
for i := 0; i < numRequests; i++ {
|
||||
<-done
|
||||
}
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
var errCount int
|
||||
for err := range errors {
|
||||
t.Logf("Concurrent request error: %v", err)
|
||||
errCount++
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, errCount, "All concurrent requests should succeed")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// BENCHMARKS
|
||||
// ============================================================================
|
||||
|
||||
func BenchmarkHandler_HandleAuthCheck(b *testing.B) {
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/check", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleAuthCheck(rec, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandler_HandleAuthCheck_Parallel(b *testing.B) {
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/check", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleAuthCheck(rec, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkHandler_HandleLogout(b *testing.B) {
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/logout", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleLogout(rec, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandler_HandleLogout_Parallel(b *testing.B) {
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/logout", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleLogout(rec, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkHandler_HandleStartOAuth(b *testing.B) {
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/start", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleStartOAuth(rec, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandler_HandleStartOAuth_Parallel(b *testing.B) {
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/start", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleStartOAuth(rec, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkHandler_HandleGetCSRFToken(b *testing.B) {
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/csrf", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetCSRFToken(rec, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandler_HandleGetCSRFToken_Parallel(b *testing.B) {
|
||||
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/csrf", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetCSRFToken(rec, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
372
backend/internal/presentation/api/documents/handler.go
Normal file
372
backend/internal/presentation/api/documents/handler.go
Normal file
@@ -0,0 +1,372 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package documents
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/application/services"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/shared"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// documentService defines the interface for document operations
|
||||
type documentService interface {
|
||||
CreateDocument(ctx context.Context, req services.CreateDocumentRequest) (*models.Document, error)
|
||||
FindOrCreateDocument(ctx context.Context, ref string) (*models.Document, bool, error)
|
||||
FindByReference(ctx context.Context, ref string, refType string) (*models.Document, error)
|
||||
}
|
||||
|
||||
// Handler handles document API requests
|
||||
type Handler struct {
|
||||
signatureService *services.SignatureService
|
||||
documentService documentService
|
||||
}
|
||||
|
||||
// NewHandler creates a new documents handler
|
||||
func NewHandler(signatureService *services.SignatureService, documentService documentService) *Handler {
|
||||
return &Handler{
|
||||
signatureService: signatureService,
|
||||
documentService: documentService,
|
||||
}
|
||||
}
|
||||
|
||||
// DocumentDTO represents a document data transfer object
|
||||
type DocumentDTO struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
CreatedAt string `json:"createdAt,omitempty"`
|
||||
UpdatedAt string `json:"updatedAt,omitempty"`
|
||||
SignatureCount int `json:"signatureCount"`
|
||||
ExpectedSignerCount int `json:"expectedSignerCount"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// SignatureDTO represents a signature data transfer object
|
||||
type SignatureDTO struct {
|
||||
ID string `json:"id"`
|
||||
DocID string `json:"docId"`
|
||||
UserEmail string `json:"userEmail"`
|
||||
UserName string `json:"userName,omitempty"`
|
||||
SignedAt string `json:"signedAt"`
|
||||
Signature string `json:"signature"`
|
||||
PayloadHash string `json:"payloadHash"`
|
||||
Nonce string `json:"nonce"`
|
||||
PrevHash string `json:"prevHash,omitempty"`
|
||||
}
|
||||
|
||||
// CreateDocumentRequest represents the request body for creating a document
|
||||
type CreateDocumentRequest struct {
|
||||
Reference string `json:"reference"`
|
||||
Title string `json:"title,omitempty"`
|
||||
}
|
||||
|
||||
// CreateDocumentResponse represents the response for creating a document
|
||||
type CreateDocumentResponse struct {
|
||||
DocID string `json:"docId"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Title string `json:"title"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
}
|
||||
|
||||
// HandleCreateDocument handles POST /api/v1/documents
|
||||
func (h *Handler) HandleCreateDocument(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Parse request body
|
||||
var req CreateDocumentRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
logger.Logger.Warn("Invalid document creation request body",
|
||||
"error", err.Error(),
|
||||
"remote_addr", r.RemoteAddr)
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Invalid request body", map[string]interface{}{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate reference field
|
||||
if req.Reference == "" {
|
||||
logger.Logger.Warn("Document creation request missing reference field",
|
||||
"remote_addr", r.RemoteAddr)
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Reference is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Logger.Info("Document creation request received",
|
||||
"reference", req.Reference,
|
||||
"has_title", req.Title != "",
|
||||
"remote_addr", r.RemoteAddr)
|
||||
|
||||
// Create document request
|
||||
docRequest := services.CreateDocumentRequest{
|
||||
Reference: req.Reference,
|
||||
Title: req.Title,
|
||||
}
|
||||
|
||||
// Create document
|
||||
doc, err := h.documentService.CreateDocument(ctx, docRequest)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Document creation failed in handler",
|
||||
"reference", req.Reference,
|
||||
"error", err.Error())
|
||||
shared.WriteError(w, http.StatusInternalServerError, shared.ErrCodeInternal, "Failed to create document", map[string]interface{}{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
logger.Logger.Info("Document creation succeeded",
|
||||
"doc_id", doc.DocID,
|
||||
"title", doc.Title,
|
||||
"has_url", doc.URL != "")
|
||||
|
||||
// Return the created document
|
||||
response := CreateDocumentResponse{
|
||||
DocID: doc.DocID,
|
||||
URL: doc.URL,
|
||||
Title: doc.Title,
|
||||
CreatedAt: doc.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusCreated, response)
|
||||
}
|
||||
|
||||
// HandleListDocuments handles GET /api/v1/documents
|
||||
func (h *Handler) HandleListDocuments(w http.ResponseWriter, r *http.Request) {
|
||||
// Parse query parameters
|
||||
page := 1
|
||||
limit := 20
|
||||
_ = r.URL.Query().Get("search") // TODO: implement search
|
||||
|
||||
if p := r.URL.Query().Get("page"); p != "" {
|
||||
if parsed, err := strconv.Atoi(p); err == nil && parsed > 0 {
|
||||
page = parsed
|
||||
}
|
||||
}
|
||||
|
||||
if l := r.URL.Query().Get("limit"); l != "" {
|
||||
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 && parsed <= 100 {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
|
||||
// For now, return empty list (we'll implement document listing later)
|
||||
documents := []DocumentDTO{}
|
||||
|
||||
// TODO: Implement actual document listing from database
|
||||
// This would require adding a document repository and service
|
||||
|
||||
total := 0
|
||||
shared.WritePaginatedJSON(w, documents, page, limit, total)
|
||||
}
|
||||
|
||||
// HandleGetDocument handles GET /api/v1/documents/{docId}
|
||||
func (h *Handler) HandleGetDocument(w http.ResponseWriter, r *http.Request) {
|
||||
docID := chi.URLParam(r, "docId")
|
||||
if docID == "" {
|
||||
shared.WriteValidationError(w, "Document ID is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Get signatures for the document
|
||||
signatures, err := h.signatureService.GetDocumentSignatures(r.Context(), docID)
|
||||
if err != nil {
|
||||
shared.WriteInternalError(w)
|
||||
return
|
||||
}
|
||||
|
||||
// Build document response
|
||||
// TODO: Get actual document metadata from database
|
||||
document := DocumentDTO{
|
||||
ID: docID,
|
||||
Title: "Document " + docID, // Placeholder
|
||||
Description: "",
|
||||
SignatureCount: len(signatures),
|
||||
// ExpectedSignerCount will be populated when we have the expected signers repository
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, document)
|
||||
}
|
||||
|
||||
// HandleGetDocumentSignatures handles GET /api/v1/documents/{docId}/signatures
|
||||
func (h *Handler) HandleGetDocumentSignatures(w http.ResponseWriter, r *http.Request) {
|
||||
docID := chi.URLParam(r, "docId")
|
||||
if docID == "" {
|
||||
shared.WriteValidationError(w, "Document ID is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
signatures, err := h.signatureService.GetDocumentSignatures(ctx, docID)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to get signatures",
|
||||
"doc_id", docID,
|
||||
"error", err.Error())
|
||||
shared.WriteInternalError(w)
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to DTOs
|
||||
dtos := make([]SignatureDTO, len(signatures))
|
||||
for i := range signatures {
|
||||
dtos[i] = signatureToDTO(signatures[i])
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, dtos)
|
||||
}
|
||||
|
||||
// HandleGetExpectedSigners handles GET /api/v1/documents/{docId}/expected-signers
|
||||
func (h *Handler) HandleGetExpectedSigners(w http.ResponseWriter, r *http.Request) {
|
||||
docID := chi.URLParam(r, "docId")
|
||||
if docID == "" {
|
||||
shared.WriteValidationError(w, "Document ID is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Implement with expected signers repository
|
||||
expectedSigners := []interface{}{}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, expectedSigners)
|
||||
}
|
||||
|
||||
// Helper function to convert signature model to DTO
|
||||
func signatureToDTO(sig *models.Signature) SignatureDTO {
|
||||
dto := SignatureDTO{
|
||||
ID: strconv.FormatInt(sig.ID, 10),
|
||||
DocID: sig.DocID,
|
||||
UserEmail: sig.UserEmail,
|
||||
UserName: sig.UserName,
|
||||
SignedAt: sig.SignedAtUTC.Format("2006-01-02T15:04:05Z07:00"),
|
||||
Signature: sig.Signature,
|
||||
PayloadHash: sig.PayloadHash,
|
||||
Nonce: sig.Nonce,
|
||||
}
|
||||
|
||||
if sig.PrevHash != nil && *sig.PrevHash != "" {
|
||||
dto.PrevHash = *sig.PrevHash
|
||||
}
|
||||
|
||||
return dto
|
||||
}
|
||||
|
||||
// FindOrCreateDocumentResponse represents the response for finding or creating a document
|
||||
type FindOrCreateDocumentResponse struct {
|
||||
DocID string `json:"docId"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Title string `json:"title"`
|
||||
Checksum string `json:"checksum,omitempty"`
|
||||
ChecksumAlgorithm string `json:"checksumAlgorithm,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
IsNew bool `json:"isNew"`
|
||||
}
|
||||
|
||||
// HandleFindOrCreateDocument handles GET /api/v1/documents/find-or-create?ref={reference}
|
||||
func (h *Handler) HandleFindOrCreateDocument(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Get reference from query parameter
|
||||
ref := r.URL.Query().Get("ref")
|
||||
if ref == "" {
|
||||
logger.Logger.Warn("Find or create request missing ref parameter",
|
||||
"remote_addr", r.RemoteAddr)
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "ref parameter is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Logger.Info("Find or create document request",
|
||||
"reference", ref,
|
||||
"remote_addr", r.RemoteAddr)
|
||||
|
||||
// Check if user is authenticated
|
||||
user, isAuthenticated := shared.GetUserFromContext(ctx)
|
||||
|
||||
// First, try to find the document (without creating)
|
||||
refType := detectReferenceType(ref)
|
||||
existingDoc, err := h.documentService.FindByReference(ctx, ref, string(refType))
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to search for document",
|
||||
"reference", ref,
|
||||
"error", err.Error())
|
||||
shared.WriteError(w, http.StatusInternalServerError, shared.ErrCodeInternal, "Failed to search for document", map[string]interface{}{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// If document exists, return it
|
||||
if existingDoc != nil {
|
||||
logger.Logger.Info("Document found",
|
||||
"doc_id", existingDoc.DocID,
|
||||
"reference", ref)
|
||||
|
||||
response := FindOrCreateDocumentResponse{
|
||||
DocID: existingDoc.DocID,
|
||||
URL: existingDoc.URL,
|
||||
Title: existingDoc.Title,
|
||||
Checksum: existingDoc.Checksum,
|
||||
ChecksumAlgorithm: existingDoc.ChecksumAlgorithm,
|
||||
Description: existingDoc.Description,
|
||||
CreatedAt: existingDoc.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
IsNew: false,
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, response)
|
||||
return
|
||||
}
|
||||
|
||||
// Document doesn't exist - check authentication before creating
|
||||
if !isAuthenticated {
|
||||
logger.Logger.Warn("Unauthenticated user attempted to create document",
|
||||
"reference", ref,
|
||||
"remote_addr", r.RemoteAddr)
|
||||
shared.WriteError(w, http.StatusUnauthorized, shared.ErrCodeUnauthorized, "Authentication required to create document", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// User is authenticated, create the document
|
||||
doc, isNew, err := h.documentService.FindOrCreateDocument(ctx, ref)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to create document",
|
||||
"reference", ref,
|
||||
"error", err.Error())
|
||||
shared.WriteError(w, http.StatusInternalServerError, shared.ErrCodeInternal, "Failed to create document", map[string]interface{}{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
logger.Logger.Info("Document created",
|
||||
"doc_id", doc.DocID,
|
||||
"reference", ref,
|
||||
"user_email", user.Email)
|
||||
|
||||
// Build response
|
||||
response := FindOrCreateDocumentResponse{
|
||||
DocID: doc.DocID,
|
||||
URL: doc.URL,
|
||||
Title: doc.Title,
|
||||
Checksum: doc.Checksum,
|
||||
ChecksumAlgorithm: doc.ChecksumAlgorithm,
|
||||
Description: doc.Description,
|
||||
CreatedAt: doc.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
IsNew: isNew,
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
func detectReferenceType(ref string) ReferenceType {
|
||||
if strings.HasPrefix(ref, "http://") || strings.HasPrefix(ref, "https://") {
|
||||
return "url"
|
||||
}
|
||||
|
||||
if strings.Contains(ref, "/") || strings.Contains(ref, "\\") {
|
||||
return "path"
|
||||
}
|
||||
|
||||
return "reference"
|
||||
}
|
||||
|
||||
type ReferenceType string
|
||||
761
backend/internal/presentation/api/documents/handler_test.go
Normal file
761
backend/internal/presentation/api/documents/handler_test.go
Normal file
@@ -0,0 +1,761 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package documents
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/application/services"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/shared"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// TEST FIXTURES & MOCKS
|
||||
// ============================================================================
|
||||
|
||||
var (
|
||||
testDoc = &models.Document{
|
||||
DocID: "test-doc-123",
|
||||
Title: "Test Document",
|
||||
URL: "https://example.com/doc.pdf",
|
||||
Description: "Test description",
|
||||
Checksum: "abc123",
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
CreatedAt: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC),
|
||||
UpdatedAt: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC),
|
||||
CreatedBy: "user@example.com",
|
||||
}
|
||||
|
||||
testSignature = &models.Signature{
|
||||
ID: 1,
|
||||
DocID: "test-doc-123",
|
||||
UserSub: "oauth2|123",
|
||||
UserEmail: "user@example.com",
|
||||
UserName: "Test User",
|
||||
SignedAtUTC: time.Date(2024, 1, 1, 12, 30, 0, 0, time.UTC),
|
||||
PayloadHash: "payload-hash-123",
|
||||
Signature: "signature-123",
|
||||
Nonce: "nonce-123",
|
||||
CreatedAt: time.Date(2024, 1, 1, 12, 30, 0, 0, time.UTC),
|
||||
PrevHash: stringPtr("prev-hash-123"),
|
||||
Referer: stringPtr("https://example.com"),
|
||||
}
|
||||
|
||||
testUser = &models.User{
|
||||
Sub: "oauth2|123",
|
||||
Email: "user@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
)
|
||||
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// Mock document service
|
||||
type mockDocumentService struct {
|
||||
createDocFunc func(ctx context.Context, req services.CreateDocumentRequest) (*models.Document, error)
|
||||
findOrCreateDocFunc func(ctx context.Context, ref string) (*models.Document, bool, error)
|
||||
findByReferenceFunc func(ctx context.Context, ref string, refType string) (*models.Document, error)
|
||||
}
|
||||
|
||||
func (m *mockDocumentService) CreateDocument(ctx context.Context, req services.CreateDocumentRequest) (*models.Document, error) {
|
||||
if m.createDocFunc != nil {
|
||||
return m.createDocFunc(ctx, req)
|
||||
}
|
||||
return testDoc, nil
|
||||
}
|
||||
|
||||
func (m *mockDocumentService) FindOrCreateDocument(ctx context.Context, ref string) (*models.Document, bool, error) {
|
||||
if m.findOrCreateDocFunc != nil {
|
||||
return m.findOrCreateDocFunc(ctx, ref)
|
||||
}
|
||||
return testDoc, true, nil
|
||||
}
|
||||
|
||||
func (m *mockDocumentService) FindByReference(ctx context.Context, ref string, refType string) (*models.Document, error) {
|
||||
if m.findByReferenceFunc != nil {
|
||||
return m.findByReferenceFunc(ctx, ref, refType)
|
||||
}
|
||||
return nil, fmt.Errorf("document not found")
|
||||
}
|
||||
|
||||
// Mock signature service
|
||||
type mockSignatureService struct {
|
||||
getDocumentSignaturesFunc func(ctx context.Context, docID string) ([]*models.Signature, error)
|
||||
}
|
||||
|
||||
func (m *mockSignatureService) GetDocumentSignatures(ctx context.Context, docID string) ([]*models.Signature, error) {
|
||||
if m.getDocumentSignaturesFunc != nil {
|
||||
return m.getDocumentSignaturesFunc(ctx, docID)
|
||||
}
|
||||
return []*models.Signature{testSignature}, nil
|
||||
}
|
||||
|
||||
func createTestHandler() *Handler {
|
||||
return &Handler{
|
||||
signatureService: &services.SignatureService{}, // Not used in these tests
|
||||
documentService: &mockDocumentService{},
|
||||
}
|
||||
}
|
||||
|
||||
func addUserToContext(ctx context.Context, user *models.User) context.Context {
|
||||
return context.WithValue(ctx, shared.ContextKeyUser, user)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - Constructor
|
||||
// ============================================================================
|
||||
|
||||
func TestNewHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sigService := &services.SignatureService{}
|
||||
docService := &mockDocumentService{}
|
||||
|
||||
handler := NewHandler(sigService, docService)
|
||||
|
||||
assert.NotNil(t, handler)
|
||||
assert.Equal(t, sigService, handler.signatureService)
|
||||
assert.Equal(t, docService, handler.documentService)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - HandleCreateDocument
|
||||
// ============================================================================
|
||||
|
||||
func TestHandler_HandleCreateDocument_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
reference string
|
||||
title string
|
||||
}{
|
||||
{
|
||||
name: "with title",
|
||||
reference: "https://example.com/doc.pdf",
|
||||
title: "My Document",
|
||||
},
|
||||
{
|
||||
name: "without title",
|
||||
reference: "https://example.com/doc.pdf",
|
||||
title: "",
|
||||
},
|
||||
{
|
||||
name: "with file path reference",
|
||||
reference: "/path/to/document.pdf",
|
||||
title: "Local Document",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockDocService := &mockDocumentService{
|
||||
createDocFunc: func(ctx context.Context, req services.CreateDocumentRequest) (*models.Document, error) {
|
||||
assert.Equal(t, tt.reference, req.Reference)
|
||||
assert.Equal(t, tt.title, req.Title)
|
||||
return testDoc, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := &Handler{
|
||||
documentService: mockDocService,
|
||||
}
|
||||
|
||||
reqBody := CreateDocumentRequest{
|
||||
Reference: tt.reference,
|
||||
Title: tt.title,
|
||||
}
|
||||
body, err := json.Marshal(reqBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/documents", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCreateDocument(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusCreated, rec.Code)
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
|
||||
var wrapper struct {
|
||||
Data CreateDocumentResponse `json:"data"`
|
||||
}
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, testDoc.DocID, wrapper.Data.DocID)
|
||||
assert.Equal(t, testDoc.Title, wrapper.Data.Title)
|
||||
assert.Equal(t, testDoc.URL, wrapper.Data.URL)
|
||||
assert.NotEmpty(t, wrapper.Data.CreatedAt)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_HandleCreateDocument_ValidationErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestBody interface{}
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "empty reference",
|
||||
requestBody: CreateDocumentRequest{Reference: "", Title: "Title"},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Reference is required",
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
requestBody: "invalid json",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Invalid request body",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := createTestHandler()
|
||||
|
||||
var body []byte
|
||||
var err error
|
||||
if str, ok := tt.requestBody.(string); ok {
|
||||
body = []byte(str)
|
||||
} else {
|
||||
body, err = json.Marshal(tt.requestBody)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/documents", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCreateDocument(rec, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, response, "error")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_HandleCreateDocument_ServiceError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockDocService := &mockDocumentService{
|
||||
createDocFunc: func(ctx context.Context, req services.CreateDocumentRequest) (*models.Document, error) {
|
||||
return nil, fmt.Errorf("database error")
|
||||
},
|
||||
}
|
||||
|
||||
handler := &Handler{
|
||||
documentService: mockDocService,
|
||||
}
|
||||
|
||||
reqBody := CreateDocumentRequest{
|
||||
Reference: "https://example.com/doc.pdf",
|
||||
Title: "Test",
|
||||
}
|
||||
body, err := json.Marshal(reqBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/documents", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCreateDocument(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, response, "error")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - HandleListDocuments
|
||||
// ============================================================================
|
||||
|
||||
func TestHandler_HandleListDocuments_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams string
|
||||
expectedPage int
|
||||
expectedLimit int
|
||||
}{
|
||||
{
|
||||
name: "default pagination",
|
||||
queryParams: "",
|
||||
expectedPage: 1,
|
||||
expectedLimit: 20,
|
||||
},
|
||||
{
|
||||
name: "custom page and limit",
|
||||
queryParams: "?page=2&limit=50",
|
||||
expectedPage: 2,
|
||||
expectedLimit: 50,
|
||||
},
|
||||
{
|
||||
name: "limit max capped at 100",
|
||||
queryParams: "?limit=200",
|
||||
expectedPage: 1,
|
||||
expectedLimit: 20, // Will use default since > 100
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := createTestHandler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents"+tt.queryParams, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleListDocuments(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
|
||||
var wrapper struct {
|
||||
Data interface{} `json:"data"`
|
||||
Meta struct {
|
||||
Page int `json:"page"`
|
||||
Limit int `json:"limit"`
|
||||
Total int `json:"total"`
|
||||
} `json:"meta"`
|
||||
}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Currently returns empty list
|
||||
assert.NotNil(t, wrapper.Data)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - HandleGetDocument
|
||||
// ============================================================================
|
||||
|
||||
// TestHandler_HandleGetDocument_Success is skipped because SignatureService
|
||||
// cannot be mocked without significant refactoring. The service requires
|
||||
// a repository interface that we cannot inject in tests.
|
||||
// TODO: Refactor to use interface for signature service
|
||||
func TestHandler_HandleGetDocument_Success(t *testing.T) {
|
||||
t.Skip("SignatureService is not mockable - needs refactoring")
|
||||
}
|
||||
|
||||
func TestHandler_HandleGetDocument_MissingDocID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := createTestHandler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/", nil)
|
||||
|
||||
// Empty docId parameter
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("docId", "")
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetDocument(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - HandleGetDocumentSignatures
|
||||
// ============================================================================
|
||||
|
||||
func TestHandler_HandleGetDocumentSignatures_MissingDocID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := createTestHandler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents//signatures", nil)
|
||||
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("docId", "")
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetDocumentSignatures(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - HandleFindOrCreateDocument
|
||||
// ============================================================================
|
||||
|
||||
func TestHandler_HandleFindOrCreateDocument_FindExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockDocService := &mockDocumentService{
|
||||
findByReferenceFunc: func(ctx context.Context, ref string, refType string) (*models.Document, error) {
|
||||
assert.Equal(t, "https://example.com/doc.pdf", ref)
|
||||
assert.Equal(t, "url", refType)
|
||||
return testDoc, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := &Handler{
|
||||
documentService: mockDocService,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/find-or-create?ref=https://example.com/doc.pdf", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleFindOrCreateDocument(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var wrapper struct {
|
||||
Data FindOrCreateDocumentResponse `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, testDoc.DocID, wrapper.Data.DocID)
|
||||
assert.False(t, wrapper.Data.IsNew, "Should not be new since document was found")
|
||||
}
|
||||
|
||||
func TestHandler_HandleFindOrCreateDocument_CreateNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockDocService := &mockDocumentService{
|
||||
findByReferenceFunc: func(ctx context.Context, ref string, refType string) (*models.Document, error) {
|
||||
// Document not found - return nil, nil (not an error)
|
||||
return nil, nil
|
||||
},
|
||||
findOrCreateDocFunc: func(ctx context.Context, ref string) (*models.Document, bool, error) {
|
||||
assert.Equal(t, "https://example.com/new-doc.pdf", ref)
|
||||
return testDoc, true, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := &Handler{
|
||||
documentService: mockDocService,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/find-or-create?ref=https://example.com/new-doc.pdf", nil)
|
||||
|
||||
// Add authenticated user to context
|
||||
ctx := addUserToContext(req.Context(), testUser)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleFindOrCreateDocument(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var wrapper struct {
|
||||
Data FindOrCreateDocumentResponse `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, testDoc.DocID, wrapper.Data.DocID)
|
||||
assert.True(t, wrapper.Data.IsNew, "Should be new since document was created")
|
||||
}
|
||||
|
||||
func TestHandler_HandleFindOrCreateDocument_UnauthenticatedCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockDocService := &mockDocumentService{
|
||||
findByReferenceFunc: func(ctx context.Context, ref string, refType string) (*models.Document, error) {
|
||||
// Document not found - return nil, nil (not an error)
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := &Handler{
|
||||
documentService: mockDocService,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/find-or-create?ref=https://example.com/new-doc.pdf", nil)
|
||||
// No user in context
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleFindOrCreateDocument(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, response, "error")
|
||||
}
|
||||
|
||||
func TestHandler_HandleFindOrCreateDocument_MissingRef(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := createTestHandler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/find-or-create", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleFindOrCreateDocument(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, response, "error")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - detectReferenceType
|
||||
// ============================================================================
|
||||
|
||||
func Test_detectReferenceType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ref string
|
||||
expected ReferenceType
|
||||
}{
|
||||
{
|
||||
name: "HTTP URL",
|
||||
ref: "http://example.com/doc.pdf",
|
||||
expected: "url",
|
||||
},
|
||||
{
|
||||
name: "HTTPS URL",
|
||||
ref: "https://example.com/doc.pdf",
|
||||
expected: "url",
|
||||
},
|
||||
{
|
||||
name: "Unix file path",
|
||||
ref: "/path/to/document.pdf",
|
||||
expected: "path",
|
||||
},
|
||||
{
|
||||
name: "Windows file path",
|
||||
ref: "C:\\path\\to\\document.pdf",
|
||||
expected: "path",
|
||||
},
|
||||
{
|
||||
name: "Simple reference",
|
||||
ref: "doc-12345",
|
||||
expected: "reference",
|
||||
},
|
||||
{
|
||||
name: "Hash reference",
|
||||
ref: "abc123def456",
|
||||
expected: "reference",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result := detectReferenceType(tt.ref)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - signatureToDTO
|
||||
// ============================================================================
|
||||
|
||||
func Test_signatureToDTO(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sig *models.Signature
|
||||
checkDTO func(t *testing.T, dto SignatureDTO)
|
||||
}{
|
||||
{
|
||||
name: "with prevHash",
|
||||
sig: testSignature,
|
||||
checkDTO: func(t *testing.T, dto SignatureDTO) {
|
||||
assert.Equal(t, "1", dto.ID)
|
||||
assert.Equal(t, testSignature.DocID, dto.DocID)
|
||||
assert.Equal(t, testSignature.UserEmail, dto.UserEmail)
|
||||
assert.Equal(t, testSignature.UserName, dto.UserName)
|
||||
assert.Equal(t, testSignature.Signature, dto.Signature)
|
||||
assert.Equal(t, testSignature.PayloadHash, dto.PayloadHash)
|
||||
assert.Equal(t, testSignature.Nonce, dto.Nonce)
|
||||
assert.Equal(t, *testSignature.PrevHash, dto.PrevHash)
|
||||
assert.NotEmpty(t, dto.SignedAt)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "without prevHash",
|
||||
sig: &models.Signature{
|
||||
ID: 2,
|
||||
DocID: "doc-456",
|
||||
UserSub: "oauth2|456",
|
||||
UserEmail: "user2@example.com",
|
||||
UserName: "User 2",
|
||||
SignedAtUTC: time.Date(2024, 1, 2, 10, 0, 0, 0, time.UTC),
|
||||
PayloadHash: "hash-456",
|
||||
Signature: "sig-456",
|
||||
Nonce: "nonce-456",
|
||||
PrevHash: nil,
|
||||
},
|
||||
checkDTO: func(t *testing.T, dto SignatureDTO) {
|
||||
assert.Equal(t, "2", dto.ID)
|
||||
assert.Empty(t, dto.PrevHash)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dto := signatureToDTO(tt.sig)
|
||||
tt.checkDTO(t, dto)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - Concurrency
|
||||
// ============================================================================
|
||||
|
||||
func TestHandler_HandleCreateDocument_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := createTestHandler()
|
||||
|
||||
const numRequests = 50
|
||||
done := make(chan bool, numRequests)
|
||||
errors := make(chan error, numRequests)
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func(id int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
reqBody := CreateDocumentRequest{
|
||||
Reference: fmt.Sprintf("https://example.com/doc-%d.pdf", id),
|
||||
Title: fmt.Sprintf("Document %d", id),
|
||||
}
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/documents", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCreateDocument(rec, req)
|
||||
|
||||
if rec.Code != http.StatusCreated {
|
||||
errors <- fmt.Errorf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
<-done
|
||||
}
|
||||
close(errors)
|
||||
|
||||
var errCount int
|
||||
for err := range errors {
|
||||
t.Logf("Concurrent request error: %v", err)
|
||||
errCount++
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, errCount, "All concurrent requests should succeed")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// BENCHMARKS
|
||||
// ============================================================================
|
||||
|
||||
func BenchmarkHandler_HandleCreateDocument(b *testing.B) {
|
||||
handler := createTestHandler()
|
||||
|
||||
reqBody := CreateDocumentRequest{
|
||||
Reference: "https://example.com/doc.pdf",
|
||||
Title: "Test Document",
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/documents", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCreateDocument(rec, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandler_HandleCreateDocument_Parallel(b *testing.B) {
|
||||
handler := createTestHandler()
|
||||
|
||||
reqBody := CreateDocumentRequest{
|
||||
Reference: "https://example.com/doc.pdf",
|
||||
Title: "Test Document",
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/documents", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCreateDocument(rec, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Benchmark_detectReferenceType(b *testing.B) {
|
||||
refs := []string{
|
||||
"https://example.com/doc.pdf",
|
||||
"/path/to/file.pdf",
|
||||
"simple-reference",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
detectReferenceType(refs[i%len(refs)])
|
||||
}
|
||||
}
|
||||
33
backend/internal/presentation/api/health/handler.go
Normal file
33
backend/internal/presentation/api/health/handler.go
Normal file
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package health
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/shared"
|
||||
)
|
||||
|
||||
// Handler handles health check requests
|
||||
type Handler struct{}
|
||||
|
||||
// NewHandler creates a new health handler
|
||||
func NewHandler() *Handler {
|
||||
return &Handler{}
|
||||
}
|
||||
|
||||
// HealthResponse represents the health check response
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// HandleHealth handles GET /api/v1/health
|
||||
func (h *Handler) HandleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
response := HealthResponse{
|
||||
Status: "ok",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, response)
|
||||
}
|
||||
234
backend/internal/presentation/api/health/handler_test.go
Normal file
234
backend/internal/presentation/api/health/handler_test.go
Normal file
@@ -0,0 +1,234 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package health
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler()
|
||||
|
||||
assert.NotNil(t, handler)
|
||||
}
|
||||
|
||||
func TestHandler_HandleHealth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "GET returns 200 OK",
|
||||
method: http.MethodGet,
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "POST also works (health check should be method-agnostic)",
|
||||
method: http.MethodPost,
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "HEAD also works",
|
||||
method: http.MethodHead,
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Setup
|
||||
handler := NewHandler()
|
||||
req := httptest.NewRequest(tt.method, "/api/v1/health", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
handler.HandleHealth(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, tt.expectedStatus, rec.Code)
|
||||
|
||||
// Validate response body for non-HEAD requests
|
||||
if tt.method != http.MethodHead {
|
||||
// Response is wrapped in {"data": {...}}
|
||||
var wrapper struct {
|
||||
Data HealthResponse `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err, "Response should be valid JSON")
|
||||
|
||||
assert.Equal(t, "ok", wrapper.Data.Status)
|
||||
assert.NotZero(t, wrapper.Data.Timestamp)
|
||||
|
||||
// Timestamp should be recent (within last 5 seconds)
|
||||
now := time.Now()
|
||||
assert.WithinDuration(t, now, wrapper.Data.Timestamp, 5*time.Second)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_HandleHealth_ResponseFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/health", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleHealth(rec, req)
|
||||
|
||||
// Check Content-Type
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
|
||||
// Validate JSON structure
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check wrapper structure
|
||||
assert.Contains(t, response, "data")
|
||||
|
||||
// Get data object
|
||||
data, ok := response["data"].(map[string]interface{})
|
||||
require.True(t, ok, "data should be an object")
|
||||
|
||||
// Check required fields in data
|
||||
assert.Contains(t, data, "status")
|
||||
assert.Contains(t, data, "timestamp")
|
||||
|
||||
// Validate status value
|
||||
status, ok := data["status"].(string)
|
||||
require.True(t, ok, "status should be a string")
|
||||
assert.Equal(t, "ok", status)
|
||||
|
||||
// Validate timestamp format (RFC3339)
|
||||
timestampStr, ok := data["timestamp"].(string)
|
||||
require.True(t, ok, "timestamp should be a string")
|
||||
|
||||
_, err = time.Parse(time.RFC3339, timestampStr)
|
||||
assert.NoError(t, err, "timestamp should be in RFC3339 format")
|
||||
}
|
||||
|
||||
func TestHandler_HandleHealth_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler()
|
||||
|
||||
const numRequests = 100
|
||||
done := make(chan bool, numRequests)
|
||||
errors := make(chan error, numRequests)
|
||||
|
||||
// Spawn concurrent requests
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func() {
|
||||
defer func() { done <- true }()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/health", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleHealth(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
errors <- assert.AnError
|
||||
}
|
||||
|
||||
var wrapper struct {
|
||||
Data HealthResponse `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &wrapper); err != nil {
|
||||
errors <- err
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all requests
|
||||
for i := 0; i < numRequests; i++ {
|
||||
<-done
|
||||
}
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
var errCount int
|
||||
for err := range errors {
|
||||
t.Logf("Concurrent request error: %v", err)
|
||||
errCount++
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, errCount, "All concurrent health checks should succeed")
|
||||
}
|
||||
|
||||
func TestHandler_HandleHealth_Idempotency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler()
|
||||
|
||||
// First request
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/api/v1/health", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
handler.HandleHealth(rec1, req1)
|
||||
|
||||
var wrapper1 struct {
|
||||
Data HealthResponse `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(rec1.Body.Bytes(), &wrapper1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Small delay
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Second request
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/api/v1/health", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
handler.HandleHealth(rec2, req2)
|
||||
|
||||
var wrapper2 struct {
|
||||
Data HealthResponse `json:"data"`
|
||||
}
|
||||
err = json.Unmarshal(rec2.Body.Bytes(), &wrapper2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Status should be same
|
||||
assert.Equal(t, wrapper1.Data.Status, wrapper2.Data.Status)
|
||||
|
||||
// Timestamps should be different (but close)
|
||||
assert.NotEqual(t, wrapper1.Data.Timestamp, wrapper2.Data.Timestamp)
|
||||
assert.WithinDuration(t, wrapper1.Data.Timestamp, wrapper2.Data.Timestamp, 1*time.Second)
|
||||
}
|
||||
|
||||
func BenchmarkHandler_HandleHealth(b *testing.B) {
|
||||
handler := NewHandler()
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/health", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleHealth(rec, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandler_HandleHealth_Parallel(b *testing.B) {
|
||||
handler := NewHandler()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/health", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleHealth(rec, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
175
backend/internal/presentation/api/router.go
Normal file
175
backend/internal/presentation/api/router.go
Normal file
@@ -0,0 +1,175 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/application/services"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/auth"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/database"
|
||||
apiAdmin "github.com/btouchard/ackify-ce/backend/internal/presentation/api/admin"
|
||||
apiAuth "github.com/btouchard/ackify-ce/backend/internal/presentation/api/auth"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/documents"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/health"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/shared"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/signatures"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/users"
|
||||
)
|
||||
|
||||
// RouterConfig holds configuration for the API router
|
||||
type RouterConfig struct {
|
||||
AuthService *auth.OauthService
|
||||
SignatureService *services.SignatureService
|
||||
DocumentService *services.DocumentService
|
||||
DocumentRepository *database.DocumentRepository
|
||||
ExpectedSignerRepository *database.ExpectedSignerRepository
|
||||
ReminderService *services.ReminderAsyncService // Now using async service
|
||||
BaseURL string
|
||||
AdminEmails []string
|
||||
AutoLogin bool
|
||||
}
|
||||
|
||||
// NewRouter creates and configures the API v1 router
|
||||
func NewRouter(cfg RouterConfig) *chi.Mux {
|
||||
r := chi.NewRouter()
|
||||
|
||||
// Initialize middleware
|
||||
apiMiddleware := shared.NewMiddleware(cfg.AuthService, cfg.BaseURL, cfg.AdminEmails)
|
||||
|
||||
// Rate limiters
|
||||
authRateLimit := shared.NewRateLimit(5, time.Minute) // 5 attempts per minute for auth
|
||||
documentRateLimit := shared.NewRateLimit(10, time.Minute) // 10 documents per minute
|
||||
generalRateLimit := shared.NewRateLimit(100, time.Minute) // 100 requests per minute general
|
||||
|
||||
// Global middleware
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(shared.AddRequestIDToContext)
|
||||
r.Use(middleware.RealIP)
|
||||
r.Use(shared.RequestLogger)
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(shared.SecurityHeaders)
|
||||
r.Use(apiMiddleware.CORS)
|
||||
r.Use(generalRateLimit.Middleware)
|
||||
|
||||
// Initialize handlers
|
||||
healthHandler := health.NewHandler()
|
||||
authHandler := apiAuth.NewHandler(cfg.AuthService, apiMiddleware, cfg.BaseURL)
|
||||
usersHandler := users.NewHandler(cfg.AdminEmails)
|
||||
documentsHandler := documents.NewHandler(cfg.SignatureService, cfg.DocumentService)
|
||||
signaturesHandler := signatures.NewHandler(cfg.SignatureService)
|
||||
|
||||
// Public routes
|
||||
r.Group(func(r chi.Router) {
|
||||
// Health check
|
||||
r.Get("/health", healthHandler.HandleHealth)
|
||||
|
||||
// CSRF token
|
||||
r.Get("/csrf", authHandler.HandleGetCSRFToken)
|
||||
|
||||
// Auth endpoints
|
||||
r.Route("/auth", func(r chi.Router) {
|
||||
r.Use(authRateLimit.Middleware)
|
||||
|
||||
r.Post("/start", authHandler.HandleStartOAuth)
|
||||
r.Get("/callback", authHandler.HandleOAuthCallback)
|
||||
r.Get("/logout", authHandler.HandleLogout)
|
||||
|
||||
if cfg.AutoLogin {
|
||||
r.Get("/check", authHandler.HandleAuthCheck)
|
||||
}
|
||||
})
|
||||
|
||||
// Public document endpoints
|
||||
r.Route("/documents", func(r chi.Router) {
|
||||
// Document creation (with CSRF and stricter rate limiting)
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(apiMiddleware.CSRFProtect)
|
||||
r.Use(documentRateLimit.Middleware)
|
||||
r.Post("/", documentsHandler.HandleCreateDocument)
|
||||
})
|
||||
|
||||
// Read-only document endpoints
|
||||
r.Get("/", documentsHandler.HandleListDocuments)
|
||||
r.Get("/{docId}", documentsHandler.HandleGetDocument)
|
||||
r.Get("/{docId}/signatures", documentsHandler.HandleGetDocumentSignatures)
|
||||
r.Get("/{docId}/expected-signers", documentsHandler.HandleGetExpectedSigners)
|
||||
|
||||
// Find or create document by reference (public for embed support, but with optional auth)
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(apiMiddleware.OptionalAuth)
|
||||
r.Get("/find-or-create", documentsHandler.HandleFindOrCreateDocument)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// Authenticated routes
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(apiMiddleware.RequireAuth)
|
||||
r.Use(apiMiddleware.CSRFProtect)
|
||||
|
||||
// User endpoints
|
||||
r.Route("/users", func(r chi.Router) {
|
||||
r.Get("/me", usersHandler.HandleGetCurrentUser)
|
||||
})
|
||||
|
||||
// Signature endpoints
|
||||
r.Route("/signatures", func(r chi.Router) {
|
||||
r.Get("/", signaturesHandler.HandleGetUserSignatures)
|
||||
r.Post("/", signaturesHandler.HandleCreateSignature)
|
||||
})
|
||||
|
||||
// Document signature status (authenticated)
|
||||
r.Get("/documents/{docId}/signatures/status", signaturesHandler.HandleGetSignatureStatus)
|
||||
})
|
||||
|
||||
// Admin routes
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(apiMiddleware.RequireAdmin)
|
||||
r.Use(apiMiddleware.CSRFProtect)
|
||||
|
||||
// Initialize admin handler
|
||||
adminHandler := apiAdmin.NewHandler(cfg.DocumentRepository, cfg.ExpectedSignerRepository, cfg.ReminderService, cfg.SignatureService, cfg.BaseURL)
|
||||
|
||||
r.Route("/admin", func(r chi.Router) {
|
||||
// Document management
|
||||
r.Route("/documents", func(r chi.Router) {
|
||||
r.Get("/", adminHandler.HandleListDocuments)
|
||||
r.Get("/{docId}", adminHandler.HandleGetDocument)
|
||||
r.Get("/{docId}/signers", adminHandler.HandleGetDocumentWithSigners)
|
||||
r.Get("/{docId}/status", adminHandler.HandleGetDocumentStatus)
|
||||
|
||||
// Document metadata
|
||||
r.Put("/{docId}/metadata", adminHandler.HandleUpdateDocumentMetadata)
|
||||
|
||||
// Document deletion
|
||||
r.Delete("/{docId}", adminHandler.HandleDeleteDocument)
|
||||
|
||||
// Expected signers management
|
||||
r.Post("/{docId}/signers", adminHandler.HandleAddExpectedSigner)
|
||||
r.Delete("/{docId}/signers/{email}", adminHandler.HandleRemoveExpectedSigner)
|
||||
|
||||
// Reminder management
|
||||
r.Post("/{docId}/reminders", adminHandler.HandleSendReminders)
|
||||
r.Get("/{docId}/reminders", adminHandler.HandleGetReminderHistory)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// Serve OpenAPI spec
|
||||
r.Get("/openapi.json", serveOpenAPISpec)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// serveOpenAPISpec serves the OpenAPI specification
|
||||
func serveOpenAPISpec(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: Read and serve the OpenAPI YAML file as JSON
|
||||
// For now, return a simple response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"info":{"title":"Ackify API","version":"1.0.0"}}`))
|
||||
}
|
||||
101
backend/internal/presentation/api/shared/errors.go
Normal file
101
backend/internal/presentation/api/shared/errors.go
Normal file
@@ -0,0 +1,101 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package shared
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ErrorCode represents standardized API error codes
|
||||
type ErrorCode string
|
||||
|
||||
const (
|
||||
// Client errors
|
||||
ErrCodeValidation ErrorCode = "VALIDATION_ERROR"
|
||||
ErrCodeBadRequest ErrorCode = "BAD_REQUEST"
|
||||
ErrCodeUnauthorized ErrorCode = "UNAUTHORIZED"
|
||||
ErrCodeForbidden ErrorCode = "FORBIDDEN"
|
||||
ErrCodeNotFound ErrorCode = "NOT_FOUND"
|
||||
ErrCodeConflict ErrorCode = "CONFLICT"
|
||||
ErrCodeRateLimited ErrorCode = "RATE_LIMITED"
|
||||
ErrCodeCSRFInvalid ErrorCode = "CSRF_INVALID"
|
||||
|
||||
// Server errors
|
||||
ErrCodeInternal ErrorCode = "INTERNAL_ERROR"
|
||||
ErrCodeServiceUnavailable ErrorCode = "SERVICE_UNAVAILABLE"
|
||||
)
|
||||
|
||||
// ErrorResponse represents a standardized error response
|
||||
type ErrorResponse struct {
|
||||
Error ErrorDetail `json:"error"`
|
||||
}
|
||||
|
||||
// ErrorDetail contains error details
|
||||
type ErrorDetail struct {
|
||||
Code ErrorCode `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details map[string]interface{} `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// WriteError writes a standardized error response
|
||||
func WriteError(w http.ResponseWriter, statusCode int, code ErrorCode, message string, details map[string]interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
response := ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Details: details,
|
||||
},
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// WriteValidationError writes a validation error response
|
||||
func WriteValidationError(w http.ResponseWriter, message string, fieldErrors map[string]string) {
|
||||
details := make(map[string]interface{})
|
||||
if fieldErrors != nil {
|
||||
details["fields"] = fieldErrors
|
||||
}
|
||||
WriteError(w, http.StatusBadRequest, ErrCodeValidation, message, details)
|
||||
}
|
||||
|
||||
// WriteUnauthorized writes an unauthorized error response
|
||||
func WriteUnauthorized(w http.ResponseWriter, message string) {
|
||||
if message == "" {
|
||||
message = "Authentication required"
|
||||
}
|
||||
WriteError(w, http.StatusUnauthorized, ErrCodeUnauthorized, message, nil)
|
||||
}
|
||||
|
||||
// WriteForbidden writes a forbidden error response
|
||||
func WriteForbidden(w http.ResponseWriter, message string) {
|
||||
if message == "" {
|
||||
message = "Access denied"
|
||||
}
|
||||
WriteError(w, http.StatusForbidden, ErrCodeForbidden, message, nil)
|
||||
}
|
||||
|
||||
// WriteNotFound writes a not found error response
|
||||
func WriteNotFound(w http.ResponseWriter, resource string) {
|
||||
message := "Resource not found"
|
||||
if resource != "" {
|
||||
message = resource + " not found"
|
||||
}
|
||||
WriteError(w, http.StatusNotFound, ErrCodeNotFound, message, nil)
|
||||
}
|
||||
|
||||
// WriteConflict writes a conflict error response
|
||||
func WriteConflict(w http.ResponseWriter, message string) {
|
||||
if message == "" {
|
||||
message = "Resource conflict"
|
||||
}
|
||||
WriteError(w, http.StatusConflict, ErrCodeConflict, message, nil)
|
||||
}
|
||||
|
||||
// WriteInternalError writes an internal server error response
|
||||
func WriteInternalError(w http.ResponseWriter) {
|
||||
WriteError(w, http.StatusInternalServerError, ErrCodeInternal, "An internal error occurred", nil)
|
||||
}
|
||||
188
backend/internal/presentation/api/shared/errors_test.go
Normal file
188
backend/internal/presentation/api/shared/errors_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package shared
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWriteValidationError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
message string
|
||||
fieldErrors map[string]string
|
||||
}{
|
||||
{
|
||||
name: "Validation error with field errors",
|
||||
message: "Invalid input",
|
||||
fieldErrors: map[string]string{
|
||||
"email": "Invalid email format",
|
||||
"age": "Must be positive",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Validation error without field errors",
|
||||
message: "Invalid request",
|
||||
fieldErrors: nil,
|
||||
},
|
||||
{
|
||||
name: "Validation error with empty field errors",
|
||||
message: "Validation failed",
|
||||
fieldErrors: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
WriteValidationError(w, tt.message, tt.fieldErrors)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
var response ErrorResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response.Error.Message != tt.message {
|
||||
t.Errorf("Expected message '%s', got '%s'", tt.message, response.Error.Message)
|
||||
}
|
||||
|
||||
if response.Error.Code != ErrCodeValidation {
|
||||
t.Errorf("Expected code '%s', got '%s'", ErrCodeValidation, response.Error.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
resource string
|
||||
expectedMessage string
|
||||
}{
|
||||
{
|
||||
name: "Not found with resource name",
|
||||
resource: "User",
|
||||
expectedMessage: "User not found",
|
||||
},
|
||||
{
|
||||
name: "Not found without resource name",
|
||||
resource: "",
|
||||
expectedMessage: "Resource not found",
|
||||
},
|
||||
{
|
||||
name: "Not found with document resource",
|
||||
resource: "Document",
|
||||
expectedMessage: "Document not found",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
WriteNotFound(w, tt.resource)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
var response ErrorResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response.Error.Message != tt.expectedMessage {
|
||||
t.Errorf("Expected message '%s', got '%s'", tt.expectedMessage, response.Error.Message)
|
||||
}
|
||||
|
||||
if response.Error.Code != ErrCodeNotFound {
|
||||
t.Errorf("Expected code '%s', got '%s'", ErrCodeNotFound, response.Error.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteConflict(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
message string
|
||||
expectedMessage string
|
||||
}{
|
||||
{
|
||||
name: "Conflict with custom message",
|
||||
message: "Email already exists",
|
||||
expectedMessage: "Email already exists",
|
||||
},
|
||||
{
|
||||
name: "Conflict with empty message",
|
||||
message: "",
|
||||
expectedMessage: "Resource conflict",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
WriteConflict(w, tt.message)
|
||||
|
||||
if w.Code != http.StatusConflict {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusConflict, w.Code)
|
||||
}
|
||||
|
||||
var response ErrorResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response.Error.Message != tt.expectedMessage {
|
||||
t.Errorf("Expected message '%s', got '%s'", tt.expectedMessage, response.Error.Message)
|
||||
}
|
||||
|
||||
if response.Error.Code != ErrCodeConflict {
|
||||
t.Errorf("Expected code '%s', got '%s'", ErrCodeConflict, response.Error.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteInternalError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
WriteInternalError(w)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
|
||||
var response ErrorResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response.Error.Message != "An internal error occurred" {
|
||||
t.Errorf("Expected message 'An internal error occurred', got '%s'", response.Error.Message)
|
||||
}
|
||||
|
||||
if response.Error.Code != ErrCodeInternal {
|
||||
t.Errorf("Expected code '%s', got '%s'", ErrCodeInternal, response.Error.Code)
|
||||
}
|
||||
}
|
||||
110
backend/internal/presentation/api/shared/logging.go
Normal file
110
backend/internal/presentation/api/shared/logging.go
Normal file
@@ -0,0 +1,110 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// responseWriter is a wrapper around http.ResponseWriter that captures the status code
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
wroteHeader bool
|
||||
}
|
||||
|
||||
func wrapResponseWriter(w http.ResponseWriter) *responseWriter {
|
||||
return &responseWriter{ResponseWriter: w}
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Status() int {
|
||||
return rw.status
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(code int) {
|
||||
if rw.wroteHeader {
|
||||
return
|
||||
}
|
||||
|
||||
rw.status = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
rw.wroteHeader = true
|
||||
}
|
||||
|
||||
// RequestLogger middleware logs all API requests with structured logging
|
||||
func RequestLogger(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
requestID := getRequestID(r.Context())
|
||||
|
||||
// Log request start in DEBUG
|
||||
logger.Logger.Debug("api_request_start",
|
||||
"request_id", requestID,
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"remote_addr", r.RemoteAddr,
|
||||
"user_agent", r.UserAgent())
|
||||
|
||||
wrapped := wrapResponseWriter(w)
|
||||
next.ServeHTTP(wrapped, r)
|
||||
|
||||
// Log request completion
|
||||
duration := time.Since(start)
|
||||
status := wrapped.status
|
||||
if status == 0 {
|
||||
status = 200
|
||||
}
|
||||
|
||||
fields := []interface{}{
|
||||
"request_id", requestID,
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", status,
|
||||
"duration_ms", duration.Milliseconds(),
|
||||
}
|
||||
|
||||
// Add user email if available
|
||||
if user, ok := GetUserFromContext(r.Context()); ok {
|
||||
fields = append(fields, "user_email", user.Email)
|
||||
}
|
||||
|
||||
// Log at appropriate level based on status
|
||||
if status >= 500 {
|
||||
logger.Logger.Error("api_request_error", fields...)
|
||||
} else if status >= 400 {
|
||||
logger.Logger.Warn("api_request_client_error", fields...)
|
||||
} else {
|
||||
logger.Logger.Info("api_request_complete", fields...)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func getRequestID(ctx context.Context) string {
|
||||
if requestID, ok := ctx.Value(ContextKeyRequestID).(string); ok {
|
||||
return requestID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func errToString(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
// AddRequestIDToContext middleware adds the request ID from chi middleware to our context
|
||||
func AddRequestIDToContext(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestID := middleware.GetReqID(r.Context())
|
||||
ctx := context.WithValue(r.Context(), ContextKeyRequestID, requestID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
321
backend/internal/presentation/api/shared/middleware.go
Normal file
321
backend/internal/presentation/api/shared/middleware.go
Normal file
@@ -0,0 +1,321 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/auth"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// ContextKey represents a context key type
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
// ContextKeyUser is the context key for the authenticated user
|
||||
ContextKeyUser ContextKey = "user"
|
||||
// ContextKeyRequestID is the context key for the request ID
|
||||
ContextKeyRequestID ContextKey = "request_id"
|
||||
// CSRFTokenHeader is the header name for CSRF token
|
||||
CSRFTokenHeader = "X-CSRF-Token"
|
||||
// CSRFTokenCookie is the cookie name for CSRF token
|
||||
CSRFTokenCookie = "csrf_token"
|
||||
)
|
||||
|
||||
// Middleware represents API middleware
|
||||
type Middleware struct {
|
||||
authService *auth.OauthService
|
||||
csrfTokens *sync.Map
|
||||
baseURL string
|
||||
adminEmails []string
|
||||
}
|
||||
|
||||
// NewMiddleware creates a new middleware instance
|
||||
func NewMiddleware(authService *auth.OauthService, baseURL string, adminEmails []string) *Middleware {
|
||||
return &Middleware{
|
||||
authService: authService,
|
||||
csrfTokens: &sync.Map{},
|
||||
baseURL: baseURL,
|
||||
adminEmails: adminEmails,
|
||||
}
|
||||
}
|
||||
|
||||
// CORS middleware for handling cross-origin requests
|
||||
func (m *Middleware) CORS(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
|
||||
// In development, allow localhost:5173 (Vite dev server)
|
||||
if origin == "http://localhost:5173" {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, Authorization, X-CSRF-Token")
|
||||
w.Header().Set("Access-Control-Expose-Headers", "X-CSRF-Token")
|
||||
}
|
||||
|
||||
// Handle preflight requests
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// RequireAuth middleware ensures user is authenticated
|
||||
func (m *Middleware) RequireAuth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestID := getRequestID(r.Context())
|
||||
|
||||
user, err := m.authService.GetUser(r)
|
||||
if err != nil || user == nil {
|
||||
logger.Logger.Debug("authentication_required",
|
||||
"request_id", requestID,
|
||||
"path", r.URL.Path,
|
||||
"method", r.Method,
|
||||
"error", errToString(err))
|
||||
WriteUnauthorized(w, "Authentication required")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Logger.Debug("authentication_success",
|
||||
"request_id", requestID,
|
||||
"user_email", user.Email,
|
||||
"path", r.URL.Path)
|
||||
|
||||
// Add user to context
|
||||
ctx := context.WithValue(r.Context(), ContextKeyUser, user)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// OptionalAuth middleware adds user to context if authenticated, but doesn't block if not
|
||||
func (m *Middleware) OptionalAuth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestID := getRequestID(r.Context())
|
||||
|
||||
user, err := m.authService.GetUser(r)
|
||||
if err == nil && user != nil {
|
||||
// User is authenticated, add to context
|
||||
logger.Logger.Debug("optional_auth_success",
|
||||
"request_id", requestID,
|
||||
"user_email", user.Email,
|
||||
"path", r.URL.Path)
|
||||
ctx := context.WithValue(r.Context(), ContextKeyUser, user)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
} else {
|
||||
// User not authenticated, continue without user in context
|
||||
logger.Logger.Debug("optional_auth_none",
|
||||
"request_id", requestID,
|
||||
"path", r.URL.Path)
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RequireAdmin middleware ensures user is an admin
|
||||
func (m *Middleware) RequireAdmin(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestID := getRequestID(r.Context())
|
||||
|
||||
user, err := m.authService.GetUser(r)
|
||||
if err != nil || user == nil {
|
||||
logger.Logger.Debug("admin_authentication_required",
|
||||
"request_id", requestID,
|
||||
"path", r.URL.Path,
|
||||
"error", errToString(err))
|
||||
WriteUnauthorized(w, "Authentication required")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user is admin
|
||||
isAdmin := false
|
||||
for _, adminEmail := range m.adminEmails {
|
||||
if strings.EqualFold(user.Email, adminEmail) {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isAdmin {
|
||||
logger.Logger.Warn("admin_access_denied",
|
||||
"request_id", requestID,
|
||||
"user_email", user.Email,
|
||||
"path", r.URL.Path)
|
||||
WriteForbidden(w, "Admin access required")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Logger.Debug("admin_access_granted",
|
||||
"request_id", requestID,
|
||||
"user_email", user.Email,
|
||||
"path", r.URL.Path)
|
||||
|
||||
// Add user to context
|
||||
ctx := context.WithValue(r.Context(), ContextKeyUser, user)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateCSRFToken generates a new CSRF token
|
||||
func (m *Middleware) GenerateCSRFToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
token := base64.URLEncoding.EncodeToString(b)
|
||||
|
||||
// Store token with expiration
|
||||
m.csrfTokens.Store(token, time.Now().Add(24*time.Hour))
|
||||
|
||||
// Clean up expired tokens periodically
|
||||
go m.cleanExpiredTokens()
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// ValidateCSRFToken validates a CSRF token
|
||||
func (m *Middleware) ValidateCSRFToken(token string) bool {
|
||||
if token == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
if val, ok := m.csrfTokens.Load(token); ok {
|
||||
expiry := val.(time.Time)
|
||||
if time.Now().Before(expiry) {
|
||||
return true
|
||||
}
|
||||
// Token expired, remove it
|
||||
m.csrfTokens.Delete(token)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// CSRFProtect middleware for CSRF protection
|
||||
func (m *Middleware) CSRFProtect(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip CSRF check for safe methods
|
||||
if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Get token from header
|
||||
token := r.Header.Get(CSRFTokenHeader)
|
||||
if token == "" {
|
||||
// Try cookie as fallback
|
||||
if cookie, err := r.Cookie(CSRFTokenCookie); err == nil {
|
||||
token = cookie.Value
|
||||
}
|
||||
}
|
||||
|
||||
if !m.ValidateCSRFToken(token) {
|
||||
WriteError(w, http.StatusForbidden, ErrCodeCSRFInvalid, "Invalid or missing CSRF token", nil)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// cleanExpiredTokens removes expired CSRF tokens
|
||||
func (m *Middleware) cleanExpiredTokens() {
|
||||
m.csrfTokens.Range(func(key, value interface{}) bool {
|
||||
expiry := value.(time.Time)
|
||||
if time.Now().After(expiry) {
|
||||
m.csrfTokens.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserFromContext retrieves the user from the request context
|
||||
func GetUserFromContext(ctx context.Context) (*models.User, bool) {
|
||||
user, ok := ctx.Value(ContextKeyUser).(*models.User)
|
||||
return user, ok
|
||||
}
|
||||
|
||||
// SecurityHeaders middleware adds security headers
|
||||
func SecurityHeaders(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Security headers
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
|
||||
|
||||
// CSP for API endpoints
|
||||
w.Header().Set("Content-Security-Policy", "default-src 'none'; frame-ancestors 'none';")
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// RateLimit represents a simple rate limiter
|
||||
type RateLimit struct {
|
||||
attempts *sync.Map
|
||||
limit int
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
// NewRateLimit creates a new rate limiter
|
||||
func NewRateLimit(limit int, window time.Duration) *RateLimit {
|
||||
return &RateLimit{
|
||||
attempts: &sync.Map{},
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitMiddleware creates a rate limiting middleware
|
||||
func (rl *RateLimit) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Use IP address as identifier
|
||||
ip := r.RemoteAddr
|
||||
if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
|
||||
ip = strings.Split(forwarded, ",")[0]
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Check current attempts
|
||||
if val, ok := rl.attempts.Load(ip); ok {
|
||||
attempts := val.([]time.Time)
|
||||
|
||||
// Filter out old attempts
|
||||
var valid []time.Time
|
||||
for _, t := range attempts {
|
||||
if now.Sub(t) < rl.window {
|
||||
valid = append(valid, t)
|
||||
}
|
||||
}
|
||||
|
||||
if len(valid) >= rl.limit {
|
||||
WriteError(w, http.StatusTooManyRequests, ErrCodeRateLimited, "Rate limit exceeded", map[string]interface{}{
|
||||
"retryAfter": rl.window.Seconds(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
valid = append(valid, now)
|
||||
rl.attempts.Store(ip, valid)
|
||||
} else {
|
||||
rl.attempts.Store(ip, []time.Time{now})
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
1050
backend/internal/presentation/api/shared/middleware_test.go
Normal file
1050
backend/internal/presentation/api/shared/middleware_test.go
Normal file
File diff suppressed because it is too large
Load Diff
68
backend/internal/presentation/api/shared/response.go
Normal file
68
backend/internal/presentation/api/shared/response.go
Normal file
@@ -0,0 +1,68 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package shared
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Response represents a standardized API response
|
||||
type Response struct {
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
Meta map[string]interface{} `json:"meta,omitempty"`
|
||||
}
|
||||
|
||||
// PaginationMeta represents pagination metadata
|
||||
type PaginationMeta struct {
|
||||
Page int `json:"page"`
|
||||
Limit int `json:"limit"`
|
||||
Total int `json:"total"`
|
||||
TotalPages int `json:"totalPages"`
|
||||
}
|
||||
|
||||
// WriteJSON writes a JSON response
|
||||
func WriteJSON(w http.ResponseWriter, statusCode int, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
response := Response{
|
||||
Data: data,
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// WriteJSONWithMeta writes a JSON response with metadata
|
||||
func WriteJSONWithMeta(w http.ResponseWriter, statusCode int, data interface{}, meta map[string]interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
response := Response{
|
||||
Data: data,
|
||||
Meta: meta,
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// WritePaginatedJSON writes a paginated JSON response
|
||||
func WritePaginatedJSON(w http.ResponseWriter, data interface{}, page, limit, total int) {
|
||||
totalPages := (total + limit - 1) / limit
|
||||
if totalPages < 1 {
|
||||
totalPages = 1
|
||||
}
|
||||
|
||||
meta := map[string]interface{}{
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"totalPages": totalPages,
|
||||
}
|
||||
|
||||
WriteJSONWithMeta(w, http.StatusOK, data, meta)
|
||||
}
|
||||
|
||||
// WriteNoContent writes a 204 No Content response
|
||||
func WriteNoContent(w http.ResponseWriter) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
240
backend/internal/presentation/api/shared/response_test.go
Normal file
240
backend/internal/presentation/api/shared/response_test.go
Normal file
@@ -0,0 +1,240 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package shared
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWriteJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
data interface{}
|
||||
}{
|
||||
{
|
||||
name: "Write simple string data",
|
||||
statusCode: http.StatusOK,
|
||||
data: "test data",
|
||||
},
|
||||
{
|
||||
name: "Write struct data",
|
||||
statusCode: http.StatusCreated,
|
||||
data: map[string]string{
|
||||
"message": "created successfully",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Write nil data",
|
||||
statusCode: http.StatusOK,
|
||||
data: nil,
|
||||
},
|
||||
{
|
||||
name: "Write error status",
|
||||
statusCode: http.StatusBadRequest,
|
||||
data: map[string]string{"error": "bad request"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
WriteJSON(w, tt.statusCode, tt.data)
|
||||
|
||||
if w.Code != tt.statusCode {
|
||||
t.Errorf("Expected status code %d, got %d", tt.statusCode, w.Code)
|
||||
}
|
||||
|
||||
if contentType := w.Header().Get("Content-Type"); contentType != "application/json" {
|
||||
t.Errorf("Expected Content-Type application/json, got %s", contentType)
|
||||
}
|
||||
|
||||
var response Response
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
// Meta should not be present in simple WriteJSON
|
||||
if response.Meta != nil {
|
||||
t.Error("Expected Meta to be nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteJSONWithMeta(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
data interface{}
|
||||
meta map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "Write with metadata",
|
||||
statusCode: http.StatusOK,
|
||||
data: []string{"item1", "item2"},
|
||||
meta: map[string]interface{}{
|
||||
"count": 2,
|
||||
"page": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Write with empty meta",
|
||||
statusCode: http.StatusOK,
|
||||
data: "test",
|
||||
meta: map[string]interface{}{},
|
||||
},
|
||||
{
|
||||
name: "Write with nil meta",
|
||||
statusCode: http.StatusOK,
|
||||
data: "test",
|
||||
meta: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
WriteJSONWithMeta(w, tt.statusCode, tt.data, tt.meta)
|
||||
|
||||
if w.Code != tt.statusCode {
|
||||
t.Errorf("Expected status code %d, got %d", tt.statusCode, w.Code)
|
||||
}
|
||||
|
||||
if contentType := w.Header().Get("Content-Type"); contentType != "application/json" {
|
||||
t.Errorf("Expected Content-Type application/json, got %s", contentType)
|
||||
}
|
||||
|
||||
var response Response
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
// Check meta is present when provided
|
||||
if tt.meta != nil && len(tt.meta) > 0 {
|
||||
if response.Meta == nil {
|
||||
t.Error("Expected Meta to be present")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWritePaginatedJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data interface{}
|
||||
page int
|
||||
limit int
|
||||
total int
|
||||
expectedTotalPages int
|
||||
}{
|
||||
{
|
||||
name: "Standard pagination",
|
||||
data: []string{"item1", "item2", "item3"},
|
||||
page: 1,
|
||||
limit: 10,
|
||||
total: 25,
|
||||
expectedTotalPages: 3,
|
||||
},
|
||||
{
|
||||
name: "Exact division",
|
||||
data: []string{"item1"},
|
||||
page: 2,
|
||||
limit: 5,
|
||||
total: 10,
|
||||
expectedTotalPages: 2,
|
||||
},
|
||||
{
|
||||
name: "Zero total",
|
||||
data: []string{},
|
||||
page: 1,
|
||||
limit: 10,
|
||||
total: 0,
|
||||
expectedTotalPages: 1, // Minimum 1 page
|
||||
},
|
||||
{
|
||||
name: "Single item",
|
||||
data: []string{"item1"},
|
||||
page: 1,
|
||||
limit: 10,
|
||||
total: 1,
|
||||
expectedTotalPages: 1,
|
||||
},
|
||||
{
|
||||
name: "Large dataset",
|
||||
data: []string{"item1"},
|
||||
page: 5,
|
||||
limit: 50,
|
||||
total: 500,
|
||||
expectedTotalPages: 10,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
WritePaginatedJSON(w, tt.data, tt.page, tt.limit, tt.total)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var response Response
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response.Meta == nil {
|
||||
t.Fatal("Expected Meta to be present in paginated response")
|
||||
}
|
||||
|
||||
// Check pagination metadata
|
||||
if page, ok := response.Meta["page"].(float64); !ok || int(page) != tt.page {
|
||||
t.Errorf("Expected page %d, got %v", tt.page, response.Meta["page"])
|
||||
}
|
||||
|
||||
if limit, ok := response.Meta["limit"].(float64); !ok || int(limit) != tt.limit {
|
||||
t.Errorf("Expected limit %d, got %v", tt.limit, response.Meta["limit"])
|
||||
}
|
||||
|
||||
if total, ok := response.Meta["total"].(float64); !ok || int(total) != tt.total {
|
||||
t.Errorf("Expected total %d, got %v", tt.total, response.Meta["total"])
|
||||
}
|
||||
|
||||
if totalPages, ok := response.Meta["totalPages"].(float64); !ok || int(totalPages) != tt.expectedTotalPages {
|
||||
t.Errorf("Expected totalPages %d, got %v", tt.expectedTotalPages, response.Meta["totalPages"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteNoContent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
WriteNoContent(w)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusNoContent, w.Code)
|
||||
}
|
||||
|
||||
if w.Body.Len() != 0 {
|
||||
t.Errorf("Expected empty body, got %d bytes", w.Body.Len())
|
||||
}
|
||||
}
|
||||
284
backend/internal/presentation/api/signatures/handler.go
Normal file
284
backend/internal/presentation/api/signatures/handler.go
Normal file
@@ -0,0 +1,284 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package signatures
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/shared"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// signatureService defines the interface for signature operations
|
||||
type signatureService interface {
|
||||
CreateSignature(ctx context.Context, request *models.SignatureRequest) error
|
||||
GetSignatureStatus(ctx context.Context, docID string, user *models.User) (*models.SignatureStatus, error)
|
||||
GetSignatureByDocAndUser(ctx context.Context, docID string, user *models.User) (*models.Signature, error)
|
||||
GetDocumentSignatures(ctx context.Context, docID string) ([]*models.Signature, error)
|
||||
GetUserSignatures(ctx context.Context, user *models.User) ([]*models.Signature, error)
|
||||
}
|
||||
|
||||
// Handler handles signature-related requests
|
||||
type Handler struct {
|
||||
signatureService signatureService
|
||||
}
|
||||
|
||||
// NewHandler creates a new signature handler
|
||||
func NewHandler(signatureService signatureService) *Handler {
|
||||
return &Handler{
|
||||
signatureService: signatureService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSignatureRequest represents the request body for creating a signature
|
||||
type CreateSignatureRequest struct {
|
||||
DocID string `json:"docId"`
|
||||
Referer *string `json:"referer,omitempty"`
|
||||
}
|
||||
|
||||
// SignatureResponse represents a signature in API responses
|
||||
type SignatureResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
DocID string `json:"docId"`
|
||||
UserSub string `json:"userSub"`
|
||||
UserEmail string `json:"userEmail"`
|
||||
UserName string `json:"userName,omitempty"`
|
||||
SignedAt string `json:"signedAt"`
|
||||
PayloadHash string `json:"payloadHash"`
|
||||
Signature string `json:"signature"`
|
||||
Nonce string `json:"nonce"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
Referer *string `json:"referer,omitempty"`
|
||||
PrevHash *string `json:"prevHash,omitempty"`
|
||||
ServiceInfo *ServiceInfoResult `json:"serviceInfo,omitempty"`
|
||||
DocDeletedAt *string `json:"docDeletedAt,omitempty"`
|
||||
// Document metadata
|
||||
DocTitle *string `json:"docTitle,omitempty"`
|
||||
DocUrl *string `json:"docUrl,omitempty"`
|
||||
}
|
||||
|
||||
// ServiceInfoResult represents service detection information
|
||||
type ServiceInfoResult struct {
|
||||
Name string `json:"name"`
|
||||
Icon string `json:"icon"`
|
||||
Type string `json:"type"`
|
||||
Referrer string `json:"referrer"`
|
||||
}
|
||||
|
||||
// SignatureStatusResponse represents the signature status for a document
|
||||
type SignatureStatusResponse struct {
|
||||
DocID string `json:"docId"`
|
||||
UserEmail string `json:"userEmail"`
|
||||
IsSigned bool `json:"isSigned"`
|
||||
SignedAt *string `json:"signedAt,omitempty"`
|
||||
}
|
||||
|
||||
// HandleCreateSignature handles POST /api/v1/signatures
|
||||
func (h *Handler) HandleCreateSignature(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Get user from context (set by RequireAuth middleware)
|
||||
user, ok := shared.GetUserFromContext(ctx)
|
||||
if !ok || user == nil {
|
||||
shared.WriteUnauthorized(w, "Authentication required")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request body
|
||||
var req CreateSignatureRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Invalid request body", map[string]interface{}{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate document ID
|
||||
if req.DocID == "" {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Document ID is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Create signature request
|
||||
sigRequest := &models.SignatureRequest{
|
||||
DocID: req.DocID,
|
||||
User: user,
|
||||
Referer: req.Referer,
|
||||
}
|
||||
|
||||
// Create signature
|
||||
err := h.signatureService.CreateSignature(ctx, sigRequest)
|
||||
if err != nil {
|
||||
if err == models.ErrSignatureAlreadyExists {
|
||||
shared.WriteConflict(w, "You have already signed this document")
|
||||
return
|
||||
}
|
||||
|
||||
if err == models.ErrInvalidDocument {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Invalid document", nil)
|
||||
return
|
||||
}
|
||||
|
||||
if err == models.ErrDocumentModified {
|
||||
shared.WriteError(w, http.StatusConflict, "DOCUMENT_MODIFIED", "The document has been modified since it was created. Please verify the current version before signing.", map[string]interface{}{
|
||||
"docId": req.DocID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
shared.WriteError(w, http.StatusInternalServerError, shared.ErrCodeInternal, "Failed to create signature", map[string]interface{}{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Get the created signature to return it
|
||||
signature, err := h.signatureService.GetSignatureByDocAndUser(ctx, req.DocID, user)
|
||||
if err != nil {
|
||||
// Signature was created but we couldn't retrieve it
|
||||
shared.WriteJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"message": "Signature created successfully",
|
||||
"docId": req.DocID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Return the created signature
|
||||
shared.WriteJSON(w, http.StatusCreated, h.toSignatureResponse(ctx, signature))
|
||||
}
|
||||
|
||||
// HandleGetUserSignatures handles GET /api/v1/signatures
|
||||
func (h *Handler) HandleGetUserSignatures(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Get user from context
|
||||
user, ok := shared.GetUserFromContext(ctx)
|
||||
if !ok || user == nil {
|
||||
shared.WriteUnauthorized(w, "Authentication required")
|
||||
return
|
||||
}
|
||||
|
||||
// Get user's signatures
|
||||
signatures, err := h.signatureService.GetUserSignatures(ctx, user)
|
||||
if err != nil {
|
||||
shared.WriteError(w, http.StatusInternalServerError, shared.ErrCodeInternal, "Failed to fetch signatures", map[string]interface{}{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to response format
|
||||
response := make([]*SignatureResponse, 0, len(signatures))
|
||||
for _, sig := range signatures {
|
||||
response = append(response, h.toSignatureResponse(ctx, sig))
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// HandleGetDocumentSignatures handles GET /api/v1/documents/{docId}/signatures
|
||||
func (h *Handler) HandleGetDocumentSignatures(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Get document ID from URL
|
||||
docID := chi.URLParam(r, "docId")
|
||||
if docID == "" {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Document ID is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Get document signatures
|
||||
signatures, err := h.signatureService.GetDocumentSignatures(ctx, docID)
|
||||
if err != nil {
|
||||
shared.WriteError(w, http.StatusInternalServerError, shared.ErrCodeInternal, "Failed to fetch signatures", map[string]interface{}{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to response format
|
||||
response := make([]*SignatureResponse, 0, len(signatures))
|
||||
for _, sig := range signatures {
|
||||
response = append(response, h.toSignatureResponse(ctx, sig))
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// HandleGetSignatureStatus handles GET /api/v1/documents/{docId}/signatures/status
|
||||
func (h *Handler) HandleGetSignatureStatus(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Get user from context
|
||||
user, ok := shared.GetUserFromContext(ctx)
|
||||
if !ok || user == nil {
|
||||
shared.WriteUnauthorized(w, "Authentication required")
|
||||
return
|
||||
}
|
||||
|
||||
// Get document ID from URL
|
||||
docID := chi.URLParam(r, "docId")
|
||||
if docID == "" {
|
||||
shared.WriteError(w, http.StatusBadRequest, shared.ErrCodeBadRequest, "Document ID is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Get signature status
|
||||
status, err := h.signatureService.GetSignatureStatus(ctx, docID, user)
|
||||
if err != nil {
|
||||
shared.WriteError(w, http.StatusInternalServerError, shared.ErrCodeInternal, "Failed to fetch signature status", map[string]interface{}{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to response format
|
||||
response := SignatureStatusResponse{
|
||||
DocID: status.DocID,
|
||||
UserEmail: status.UserEmail,
|
||||
IsSigned: status.IsSigned,
|
||||
}
|
||||
|
||||
if status.SignedAt != nil {
|
||||
signedAt := status.SignedAt.Format("2006-01-02T15:04:05Z07:00")
|
||||
response.SignedAt = &signedAt
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// toSignatureResponse converts a domain signature to API response format
|
||||
func (h *Handler) toSignatureResponse(ctx context.Context, sig *models.Signature) *SignatureResponse {
|
||||
response := &SignatureResponse{
|
||||
ID: sig.ID,
|
||||
DocID: sig.DocID,
|
||||
UserSub: sig.UserSub,
|
||||
UserEmail: sig.UserEmail,
|
||||
UserName: sig.UserName,
|
||||
SignedAt: sig.SignedAtUTC.Format("2006-01-02T15:04:05Z07:00"),
|
||||
PayloadHash: sig.PayloadHash,
|
||||
Signature: sig.Signature,
|
||||
Nonce: sig.Nonce,
|
||||
CreatedAt: sig.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
Referer: sig.Referer,
|
||||
PrevHash: sig.PrevHash,
|
||||
}
|
||||
|
||||
// Add doc_deleted_at if document was deleted
|
||||
if sig.DocDeletedAt != nil {
|
||||
deletedAt := sig.DocDeletedAt.Format("2006-01-02T15:04:05Z07:00")
|
||||
response.DocDeletedAt = &deletedAt
|
||||
}
|
||||
|
||||
// Add service info if available
|
||||
if serviceInfo := sig.GetServiceInfo(); serviceInfo != nil {
|
||||
response.ServiceInfo = &ServiceInfoResult{
|
||||
Name: serviceInfo.Name,
|
||||
Icon: serviceInfo.Icon,
|
||||
Type: serviceInfo.Type,
|
||||
Referrer: serviceInfo.Referrer,
|
||||
}
|
||||
}
|
||||
|
||||
// Document metadata is enriched from LEFT JOIN in repository
|
||||
if sig.DocTitle != "" {
|
||||
response.DocTitle = &sig.DocTitle
|
||||
}
|
||||
if sig.DocURL != "" {
|
||||
response.DocUrl = &sig.DocURL
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
899
backend/internal/presentation/api/signatures/handler_test.go
Normal file
899
backend/internal/presentation/api/signatures/handler_test.go
Normal file
@@ -0,0 +1,899 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package signatures
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/shared"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// TEST FIXTURES & MOCKS
|
||||
// ============================================================================
|
||||
|
||||
var (
|
||||
testUser = &models.User{
|
||||
Sub: "oauth2|123",
|
||||
Email: "user@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
testDoc = &models.Document{
|
||||
DocID: "test-doc-123",
|
||||
Title: "Test Document",
|
||||
URL: "https://example.com/doc.pdf",
|
||||
}
|
||||
|
||||
testSignature = &models.Signature{
|
||||
ID: 1,
|
||||
DocID: "test-doc-123",
|
||||
UserSub: "oauth2|123",
|
||||
UserEmail: "user@example.com",
|
||||
UserName: "Test User",
|
||||
SignedAtUTC: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC),
|
||||
DocChecksum: "checksum-123",
|
||||
PayloadHash: "hash-123",
|
||||
Signature: "sig-123",
|
||||
Nonce: "nonce-123",
|
||||
CreatedAt: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC),
|
||||
Referer: stringPtr("https://github.com/owner/repo"),
|
||||
PrevHash: stringPtr("prev-hash-123"),
|
||||
HashVersion: 2,
|
||||
DocTitle: "Test Document",
|
||||
DocURL: "https://example.com/doc.pdf",
|
||||
}
|
||||
|
||||
testSignatureStatus = &models.SignatureStatus{
|
||||
DocID: "test-doc-123",
|
||||
UserEmail: "user@example.com",
|
||||
IsSigned: true,
|
||||
SignedAt: timePtr(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)),
|
||||
}
|
||||
)
|
||||
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func timePtr(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
|
||||
// Mock signature service
|
||||
type mockSignatureService struct {
|
||||
createSignatureFunc func(ctx context.Context, request *models.SignatureRequest) error
|
||||
getSignatureStatusFunc func(ctx context.Context, docID string, user *models.User) (*models.SignatureStatus, error)
|
||||
getSignatureByDocAndUserFunc func(ctx context.Context, docID string, user *models.User) (*models.Signature, error)
|
||||
getDocumentSignaturesFunc func(ctx context.Context, docID string) ([]*models.Signature, error)
|
||||
getUserSignaturesFunc func(ctx context.Context, user *models.User) ([]*models.Signature, error)
|
||||
}
|
||||
|
||||
func (m *mockSignatureService) CreateSignature(ctx context.Context, request *models.SignatureRequest) error {
|
||||
if m.createSignatureFunc != nil {
|
||||
return m.createSignatureFunc(ctx, request)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSignatureService) GetSignatureStatus(ctx context.Context, docID string, user *models.User) (*models.SignatureStatus, error) {
|
||||
if m.getSignatureStatusFunc != nil {
|
||||
return m.getSignatureStatusFunc(ctx, docID, user)
|
||||
}
|
||||
return testSignatureStatus, nil
|
||||
}
|
||||
|
||||
func (m *mockSignatureService) GetSignatureByDocAndUser(ctx context.Context, docID string, user *models.User) (*models.Signature, error) {
|
||||
if m.getSignatureByDocAndUserFunc != nil {
|
||||
return m.getSignatureByDocAndUserFunc(ctx, docID, user)
|
||||
}
|
||||
return testSignature, nil
|
||||
}
|
||||
|
||||
func (m *mockSignatureService) GetDocumentSignatures(ctx context.Context, docID string) ([]*models.Signature, error) {
|
||||
if m.getDocumentSignaturesFunc != nil {
|
||||
return m.getDocumentSignaturesFunc(ctx, docID)
|
||||
}
|
||||
return []*models.Signature{testSignature}, nil
|
||||
}
|
||||
|
||||
func (m *mockSignatureService) GetUserSignatures(ctx context.Context, user *models.User) ([]*models.Signature, error) {
|
||||
if m.getUserSignaturesFunc != nil {
|
||||
return m.getUserSignaturesFunc(ctx, user)
|
||||
}
|
||||
return []*models.Signature{testSignature}, nil
|
||||
}
|
||||
|
||||
func createTestHandler() *Handler {
|
||||
return &Handler{
|
||||
signatureService: &mockSignatureService{},
|
||||
}
|
||||
}
|
||||
|
||||
func addUserToContext(ctx context.Context, user *models.User) context.Context {
|
||||
return context.WithValue(ctx, shared.ContextKeyUser, user)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - Constructor
|
||||
// ============================================================================
|
||||
|
||||
func TestNewHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sigService := &mockSignatureService{}
|
||||
|
||||
handler := NewHandler(sigService)
|
||||
|
||||
assert.NotNil(t, handler)
|
||||
assert.Equal(t, sigService, handler.signatureService)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - HandleCreateSignature
|
||||
// ============================================================================
|
||||
|
||||
func TestHandler_HandleCreateSignature_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
docID string
|
||||
referer *string
|
||||
checkReq func(t *testing.T, req *models.SignatureRequest)
|
||||
}{
|
||||
{
|
||||
name: "with referer",
|
||||
docID: "test-doc-123",
|
||||
referer: stringPtr("https://github.com/owner/repo"),
|
||||
checkReq: func(t *testing.T, req *models.SignatureRequest) {
|
||||
assert.Equal(t, "test-doc-123", req.DocID)
|
||||
assert.NotNil(t, req.Referer)
|
||||
assert.Equal(t, "https://github.com/owner/repo", *req.Referer)
|
||||
assert.Equal(t, testUser.Email, req.User.Email)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "without referer",
|
||||
docID: "test-doc-456",
|
||||
referer: nil,
|
||||
checkReq: func(t *testing.T, req *models.SignatureRequest) {
|
||||
assert.Equal(t, "test-doc-456", req.DocID)
|
||||
assert.Nil(t, req.Referer)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockSigService := &mockSignatureService{
|
||||
createSignatureFunc: func(ctx context.Context, request *models.SignatureRequest) error {
|
||||
tt.checkReq(t, request)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := &Handler{
|
||||
signatureService: mockSigService,
|
||||
}
|
||||
|
||||
reqBody := CreateSignatureRequest{
|
||||
DocID: tt.docID,
|
||||
Referer: tt.referer,
|
||||
}
|
||||
body, err := json.Marshal(reqBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/signatures", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
ctx := addUserToContext(req.Context(), testUser)
|
||||
req = req.WithContext(ctx)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCreateSignature(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusCreated, rec.Code)
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
|
||||
var wrapper struct {
|
||||
Data SignatureResponse `json:"data"`
|
||||
}
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, testSignature.ID, wrapper.Data.ID)
|
||||
assert.Equal(t, testSignature.DocID, wrapper.Data.DocID)
|
||||
assert.Equal(t, testSignature.UserEmail, wrapper.Data.UserEmail)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_HandleCreateSignature_Unauthorized(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := createTestHandler()
|
||||
|
||||
reqBody := CreateSignatureRequest{
|
||||
DocID: "test-doc-123",
|
||||
}
|
||||
body, err := json.Marshal(reqBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/signatures", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
// No user in context
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCreateSignature(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
}
|
||||
|
||||
func TestHandler_HandleCreateSignature_ValidationErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestBody interface{}
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "empty docID",
|
||||
requestBody: CreateSignatureRequest{DocID: ""},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
requestBody: "invalid json",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := createTestHandler()
|
||||
|
||||
var body []byte
|
||||
var err error
|
||||
if str, ok := tt.requestBody.(string); ok {
|
||||
body = []byte(str)
|
||||
} else {
|
||||
body, err = json.Marshal(tt.requestBody)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/signatures", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
ctx := addUserToContext(req.Context(), testUser)
|
||||
req = req.WithContext(ctx)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCreateSignature(rec, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, rec.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_HandleCreateSignature_ServiceErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
serviceError error
|
||||
expectedStatus int
|
||||
expectedMsg string
|
||||
}{
|
||||
{
|
||||
name: "signature already exists",
|
||||
serviceError: models.ErrSignatureAlreadyExists,
|
||||
expectedStatus: http.StatusConflict,
|
||||
expectedMsg: "You have already signed this document",
|
||||
},
|
||||
{
|
||||
name: "invalid document",
|
||||
serviceError: models.ErrInvalidDocument,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedMsg: "Invalid document",
|
||||
},
|
||||
{
|
||||
name: "document modified",
|
||||
serviceError: models.ErrDocumentModified,
|
||||
expectedStatus: http.StatusConflict,
|
||||
expectedMsg: "The document has been modified since it was created",
|
||||
},
|
||||
{
|
||||
name: "generic error",
|
||||
serviceError: fmt.Errorf("database error"),
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedMsg: "Failed to create signature",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockSigService := &mockSignatureService{
|
||||
createSignatureFunc: func(ctx context.Context, request *models.SignatureRequest) error {
|
||||
return tt.serviceError
|
||||
},
|
||||
}
|
||||
|
||||
handler := &Handler{
|
||||
signatureService: mockSigService,
|
||||
}
|
||||
|
||||
reqBody := CreateSignatureRequest{
|
||||
DocID: "test-doc-123",
|
||||
}
|
||||
body, err := json.Marshal(reqBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/signatures", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
ctx := addUserToContext(req.Context(), testUser)
|
||||
req = req.WithContext(ctx)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCreateSignature(rec, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, response, "error")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - HandleGetUserSignatures
|
||||
// ============================================================================
|
||||
|
||||
func TestHandler_HandleGetUserSignatures_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockSigService := &mockSignatureService{
|
||||
getUserSignaturesFunc: func(ctx context.Context, user *models.User) ([]*models.Signature, error) {
|
||||
assert.Equal(t, testUser.Email, user.Email)
|
||||
return []*models.Signature{testSignature}, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := &Handler{
|
||||
signatureService: mockSigService,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/signatures", nil)
|
||||
ctx := addUserToContext(req.Context(), testUser)
|
||||
req = req.WithContext(ctx)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetUserSignatures(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var wrapper struct {
|
||||
Data []*SignatureResponse `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, wrapper.Data, 1)
|
||||
assert.Equal(t, testSignature.ID, wrapper.Data[0].ID)
|
||||
assert.Equal(t, testSignature.DocID, wrapper.Data[0].DocID)
|
||||
}
|
||||
|
||||
func TestHandler_HandleGetUserSignatures_Unauthorized(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := createTestHandler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/signatures", nil)
|
||||
// No user in context
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetUserSignatures(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
}
|
||||
|
||||
func TestHandler_HandleGetUserSignatures_ServiceError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockSigService := &mockSignatureService{
|
||||
getUserSignaturesFunc: func(ctx context.Context, user *models.User) ([]*models.Signature, error) {
|
||||
return nil, fmt.Errorf("database error")
|
||||
},
|
||||
}
|
||||
|
||||
handler := &Handler{
|
||||
signatureService: mockSigService,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/signatures", nil)
|
||||
ctx := addUserToContext(req.Context(), testUser)
|
||||
req = req.WithContext(ctx)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetUserSignatures(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - HandleGetDocumentSignatures
|
||||
// ============================================================================
|
||||
|
||||
func TestHandler_HandleGetDocumentSignatures_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockSigService := &mockSignatureService{
|
||||
getDocumentSignaturesFunc: func(ctx context.Context, docID string) ([]*models.Signature, error) {
|
||||
assert.Equal(t, "test-doc-123", docID)
|
||||
return []*models.Signature{testSignature}, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := &Handler{
|
||||
signatureService: mockSigService,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/test-doc-123/signatures", nil)
|
||||
|
||||
// Add chi context with URL param
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("docId", "test-doc-123")
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetDocumentSignatures(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var wrapper struct {
|
||||
Data []*SignatureResponse `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, wrapper.Data, 1)
|
||||
assert.Equal(t, testSignature.DocID, wrapper.Data[0].DocID)
|
||||
}
|
||||
|
||||
func TestHandler_HandleGetDocumentSignatures_MissingDocID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := createTestHandler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents//signatures", nil)
|
||||
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("docId", "")
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetDocumentSignatures(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestHandler_HandleGetDocumentSignatures_ServiceError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockSigService := &mockSignatureService{
|
||||
getDocumentSignaturesFunc: func(ctx context.Context, docID string) ([]*models.Signature, error) {
|
||||
return nil, fmt.Errorf("database error")
|
||||
},
|
||||
}
|
||||
|
||||
handler := &Handler{
|
||||
signatureService: mockSigService,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/test-doc-123/signatures", nil)
|
||||
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("docId", "test-doc-123")
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetDocumentSignatures(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - HandleGetSignatureStatus
|
||||
// ============================================================================
|
||||
|
||||
func TestHandler_HandleGetSignatureStatus_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
status *models.SignatureStatus
|
||||
expectSigned bool
|
||||
}{
|
||||
{
|
||||
name: "signed document",
|
||||
status: &models.SignatureStatus{
|
||||
DocID: "test-doc-123",
|
||||
UserEmail: "user@example.com",
|
||||
IsSigned: true,
|
||||
SignedAt: timePtr(time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)),
|
||||
},
|
||||
expectSigned: true,
|
||||
},
|
||||
{
|
||||
name: "unsigned document",
|
||||
status: &models.SignatureStatus{
|
||||
DocID: "test-doc-456",
|
||||
UserEmail: "user@example.com",
|
||||
IsSigned: false,
|
||||
SignedAt: nil,
|
||||
},
|
||||
expectSigned: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockSigService := &mockSignatureService{
|
||||
getSignatureStatusFunc: func(ctx context.Context, docID string, user *models.User) (*models.SignatureStatus, error) {
|
||||
return tt.status, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := &Handler{
|
||||
signatureService: mockSigService,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/"+tt.status.DocID+"/signatures/status", nil)
|
||||
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("docId", tt.status.DocID)
|
||||
ctx := context.WithValue(req.Context(), chi.RouteCtxKey, rctx)
|
||||
ctx = addUserToContext(ctx, testUser)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetSignatureStatus(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var wrapper struct {
|
||||
Data SignatureStatusResponse `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tt.status.DocID, wrapper.Data.DocID)
|
||||
assert.Equal(t, tt.status.UserEmail, wrapper.Data.UserEmail)
|
||||
assert.Equal(t, tt.expectSigned, wrapper.Data.IsSigned)
|
||||
|
||||
if tt.expectSigned {
|
||||
assert.NotNil(t, wrapper.Data.SignedAt)
|
||||
} else {
|
||||
assert.Nil(t, wrapper.Data.SignedAt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_HandleGetSignatureStatus_Unauthorized(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := createTestHandler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents/test-doc-123/signatures/status", nil)
|
||||
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("docId", "test-doc-123")
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||
// No user in context
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetSignatureStatus(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
}
|
||||
|
||||
func TestHandler_HandleGetSignatureStatus_MissingDocID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := createTestHandler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/documents//signatures/status", nil)
|
||||
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("docId", "")
|
||||
ctx := addUserToContext(req.Context(), testUser)
|
||||
req = req.WithContext(context.WithValue(ctx, chi.RouteCtxKey, rctx))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetSignatureStatus(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - toSignatureResponse
|
||||
// ============================================================================
|
||||
|
||||
func Test_toSignatureResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sig *models.Signature
|
||||
checkDTO func(t *testing.T, resp *SignatureResponse)
|
||||
}{
|
||||
{
|
||||
name: "with all fields",
|
||||
sig: testSignature,
|
||||
checkDTO: func(t *testing.T, resp *SignatureResponse) {
|
||||
assert.Equal(t, testSignature.ID, resp.ID)
|
||||
assert.Equal(t, testSignature.DocID, resp.DocID)
|
||||
assert.Equal(t, testSignature.UserEmail, resp.UserEmail)
|
||||
assert.NotNil(t, resp.Referer)
|
||||
assert.NotNil(t, resp.PrevHash)
|
||||
assert.NotNil(t, resp.DocTitle)
|
||||
assert.NotNil(t, resp.DocUrl)
|
||||
// Service info may be populated depending on referer URL detection
|
||||
// We just verify the field structure exists
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "without referer",
|
||||
sig: &models.Signature{
|
||||
ID: 2,
|
||||
DocID: "doc-456",
|
||||
UserSub: "oauth2|456",
|
||||
UserEmail: "user2@example.com",
|
||||
UserName: "User 2",
|
||||
SignedAtUTC: time.Date(2024, 1, 2, 10, 0, 0, 0, time.UTC),
|
||||
PayloadHash: "hash-456",
|
||||
Signature: "sig-456",
|
||||
Nonce: "nonce-456",
|
||||
CreatedAt: time.Date(2024, 1, 2, 10, 0, 0, 0, time.UTC),
|
||||
Referer: nil,
|
||||
PrevHash: nil,
|
||||
},
|
||||
checkDTO: func(t *testing.T, resp *SignatureResponse) {
|
||||
assert.Equal(t, int64(2), resp.ID)
|
||||
assert.Nil(t, resp.Referer)
|
||||
assert.Nil(t, resp.PrevHash)
|
||||
assert.Nil(t, resp.ServiceInfo)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := createTestHandler()
|
||||
|
||||
resp := handler.toSignatureResponse(context.Background(), tt.sig)
|
||||
tt.checkDTO(t, resp)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_toSignatureResponse_ServiceInfo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
referer string
|
||||
}{
|
||||
{
|
||||
name: "GitHub URL",
|
||||
referer: "https://github.com/owner/repo",
|
||||
},
|
||||
{
|
||||
name: "GitLab URL",
|
||||
referer: "https://gitlab.com/owner/repo",
|
||||
},
|
||||
{
|
||||
name: "Generic URL",
|
||||
referer: "https://example.com/path",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sig := &models.Signature{
|
||||
ID: 1,
|
||||
DocID: "test-doc",
|
||||
UserSub: "oauth2|123",
|
||||
UserEmail: "user@example.com",
|
||||
SignedAtUTC: time.Now(),
|
||||
PayloadHash: "hash",
|
||||
Signature: "sig",
|
||||
Nonce: "nonce",
|
||||
CreatedAt: time.Now(),
|
||||
Referer: &tt.referer,
|
||||
}
|
||||
|
||||
handler := createTestHandler()
|
||||
resp := handler.toSignatureResponse(context.Background(), sig)
|
||||
|
||||
// Just verify the response is created correctly
|
||||
// Service info detection is tested in the services package
|
||||
assert.Equal(t, sig.ID, resp.ID)
|
||||
assert.Equal(t, sig.UserEmail, resp.UserEmail)
|
||||
assert.NotNil(t, resp.Referer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - Concurrency
|
||||
// ============================================================================
|
||||
|
||||
func TestHandler_HandleCreateSignature_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := createTestHandler()
|
||||
|
||||
const numRequests = 50
|
||||
done := make(chan bool, numRequests)
|
||||
errors := make(chan error, numRequests)
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func(id int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
reqBody := CreateSignatureRequest{
|
||||
DocID: fmt.Sprintf("doc-%d", id),
|
||||
}
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/signatures", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
ctx := addUserToContext(req.Context(), testUser)
|
||||
req = req.WithContext(ctx)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCreateSignature(rec, req)
|
||||
|
||||
if rec.Code != http.StatusCreated {
|
||||
errors <- fmt.Errorf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
<-done
|
||||
}
|
||||
close(errors)
|
||||
|
||||
var errCount int
|
||||
for err := range errors {
|
||||
t.Logf("Concurrent request error: %v", err)
|
||||
errCount++
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, errCount, "All concurrent requests should succeed")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// BENCHMARKS
|
||||
// ============================================================================
|
||||
|
||||
func BenchmarkHandler_HandleCreateSignature(b *testing.B) {
|
||||
handler := createTestHandler()
|
||||
|
||||
reqBody := CreateSignatureRequest{
|
||||
DocID: "test-doc-123",
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/signatures", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
ctx := addUserToContext(req.Context(), testUser)
|
||||
req = req.WithContext(ctx)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCreateSignature(rec, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandler_HandleCreateSignature_Parallel(b *testing.B) {
|
||||
handler := createTestHandler()
|
||||
|
||||
reqBody := CreateSignatureRequest{
|
||||
DocID: "test-doc-123",
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/signatures", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
ctx := addUserToContext(req.Context(), testUser)
|
||||
req = req.WithContext(ctx)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCreateSignature(rec, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkHandler_HandleGetUserSignatures(b *testing.B) {
|
||||
handler := createTestHandler()
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/signatures", nil)
|
||||
ctx := addUserToContext(req.Context(), testUser)
|
||||
req = req.WithContext(ctx)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetUserSignatures(rec, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandler_HandleGetUserSignatures_Parallel(b *testing.B) {
|
||||
handler := createTestHandler()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/signatures", nil)
|
||||
ctx := addUserToContext(req.Context(), testUser)
|
||||
req = req.WithContext(ctx)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetUserSignatures(rec, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Benchmark_toSignatureResponse(b *testing.B) {
|
||||
handler := createTestHandler()
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
handler.toSignatureResponse(ctx, testSignature)
|
||||
}
|
||||
}
|
||||
56
backend/internal/presentation/api/users/handler.go
Normal file
56
backend/internal/presentation/api/users/handler.go
Normal file
@@ -0,0 +1,56 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package users
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/shared"
|
||||
)
|
||||
|
||||
// Handler handles user API requests
|
||||
type Handler struct {
|
||||
adminEmails []string
|
||||
}
|
||||
|
||||
// NewHandler creates a new users handler
|
||||
func NewHandler(adminEmails []string) *Handler {
|
||||
return &Handler{
|
||||
adminEmails: adminEmails,
|
||||
}
|
||||
}
|
||||
|
||||
// UserDTO represents a user data transfer object
|
||||
type UserDTO struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
}
|
||||
|
||||
// HandleGetCurrentUser handles GET /api/v1/users/me
|
||||
func (h *Handler) HandleGetCurrentUser(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := shared.GetUserFromContext(r.Context())
|
||||
if !ok {
|
||||
shared.WriteUnauthorized(w, "")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user is admin
|
||||
isAdmin := false
|
||||
for _, adminEmail := range h.adminEmails {
|
||||
if strings.EqualFold(user.Email, adminEmail) {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
userDTO := UserDTO{
|
||||
ID: user.Sub,
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
IsAdmin: isAdmin,
|
||||
}
|
||||
|
||||
shared.WriteJSON(w, http.StatusOK, userDTO)
|
||||
}
|
||||
511
backend/internal/presentation/api/users/handler_test.go
Normal file
511
backend/internal/presentation/api/users/handler_test.go
Normal file
@@ -0,0 +1,511 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package users
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/shared"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// TEST FIXTURES
|
||||
// ============================================================================
|
||||
|
||||
var (
|
||||
testUserRegular = &models.User{
|
||||
Sub: "google-oauth2|123456789",
|
||||
Email: "user@example.com",
|
||||
Name: "Regular User",
|
||||
}
|
||||
|
||||
testUserAdmin = &models.User{
|
||||
Sub: "google-oauth2|987654321",
|
||||
Email: "admin@example.com",
|
||||
Name: "Admin User",
|
||||
}
|
||||
|
||||
testUserAdminUpperCase = &models.User{
|
||||
Sub: "google-oauth2|111111111",
|
||||
Email: "ADMIN@example.com", // Uppercase to test case-insensitive matching
|
||||
Name: "Admin Uppercase",
|
||||
}
|
||||
|
||||
testAdminEmails = []string{"admin@example.com", "admin2@example.com"}
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// HELPER FUNCTIONS
|
||||
// ============================================================================
|
||||
|
||||
func addUserToContext(ctx context.Context, user *models.User) context.Context {
|
||||
return context.WithValue(ctx, shared.ContextKeyUser, user)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS
|
||||
// ============================================================================
|
||||
|
||||
func TestNewHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
adminEmails []string
|
||||
}{
|
||||
{
|
||||
name: "with admin emails",
|
||||
adminEmails: []string{"admin@example.com"},
|
||||
},
|
||||
{
|
||||
name: "with multiple admin emails",
|
||||
adminEmails: []string{"admin1@example.com", "admin2@example.com", "admin3@example.com"},
|
||||
},
|
||||
{
|
||||
name: "with empty admin emails",
|
||||
adminEmails: []string{},
|
||||
},
|
||||
{
|
||||
name: "with nil admin emails",
|
||||
adminEmails: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler(tt.adminEmails)
|
||||
|
||||
assert.NotNil(t, handler)
|
||||
if tt.adminEmails != nil {
|
||||
assert.Equal(t, len(tt.adminEmails), len(handler.adminEmails))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_HandleGetCurrentUser_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
user *models.User
|
||||
adminEmails []string
|
||||
expectedIsAdmin bool
|
||||
expectedID string
|
||||
expectedEmail string
|
||||
expectedName string
|
||||
}{
|
||||
{
|
||||
name: "regular user - not admin",
|
||||
user: testUserRegular,
|
||||
adminEmails: testAdminEmails,
|
||||
expectedIsAdmin: false,
|
||||
expectedID: "google-oauth2|123456789",
|
||||
expectedEmail: "user@example.com",
|
||||
expectedName: "Regular User",
|
||||
},
|
||||
{
|
||||
name: "admin user - is admin",
|
||||
user: testUserAdmin,
|
||||
adminEmails: testAdminEmails,
|
||||
expectedIsAdmin: true,
|
||||
expectedID: "google-oauth2|987654321",
|
||||
expectedEmail: "admin@example.com",
|
||||
expectedName: "Admin User",
|
||||
},
|
||||
{
|
||||
name: "admin with uppercase email - case insensitive match",
|
||||
user: testUserAdminUpperCase,
|
||||
adminEmails: testAdminEmails,
|
||||
expectedIsAdmin: true,
|
||||
expectedID: "google-oauth2|111111111",
|
||||
expectedEmail: "ADMIN@example.com",
|
||||
expectedName: "Admin Uppercase",
|
||||
},
|
||||
{
|
||||
name: "user with no admin emails configured",
|
||||
user: testUserRegular,
|
||||
adminEmails: []string{},
|
||||
expectedIsAdmin: false,
|
||||
expectedID: "google-oauth2|123456789",
|
||||
expectedEmail: "user@example.com",
|
||||
expectedName: "Regular User",
|
||||
},
|
||||
{
|
||||
name: "user with different admin email",
|
||||
user: testUserRegular,
|
||||
adminEmails: []string{"different@example.com"},
|
||||
expectedIsAdmin: false,
|
||||
expectedID: "google-oauth2|123456789",
|
||||
expectedEmail: "user@example.com",
|
||||
expectedName: "Regular User",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Setup
|
||||
handler := NewHandler(tt.adminEmails)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/users/me", nil)
|
||||
ctx := addUserToContext(req.Context(), tt.user)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
handler.HandleGetCurrentUser(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
|
||||
// Parse response
|
||||
var wrapper struct {
|
||||
Data UserDTO `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err, "Response should be valid JSON")
|
||||
|
||||
// Validate fields
|
||||
assert.Equal(t, tt.expectedID, wrapper.Data.ID)
|
||||
assert.Equal(t, tt.expectedEmail, wrapper.Data.Email)
|
||||
assert.Equal(t, tt.expectedName, wrapper.Data.Name)
|
||||
assert.Equal(t, tt.expectedIsAdmin, wrapper.Data.IsAdmin)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_HandleGetCurrentUser_Unauthorized(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupCtx func(context.Context) context.Context
|
||||
expectedMsg string
|
||||
}{
|
||||
{
|
||||
name: "no user in context",
|
||||
setupCtx: func(ctx context.Context) context.Context {
|
||||
return ctx // No user added
|
||||
},
|
||||
expectedMsg: "", // Empty unauthorized message
|
||||
},
|
||||
{
|
||||
name: "nil user in context",
|
||||
setupCtx: func(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, shared.ContextKeyUser, nil)
|
||||
},
|
||||
expectedMsg: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Setup
|
||||
handler := NewHandler(testAdminEmails)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/users/me", nil)
|
||||
ctx := tt.setupCtx(req.Context())
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
handler.HandleGetCurrentUser(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
|
||||
// Parse error response
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have error structure
|
||||
assert.Contains(t, response, "error")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_HandleGetCurrentUser_ResponseFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler(testAdminEmails)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/users/me", nil)
|
||||
ctx := addUserToContext(req.Context(), testUserRegular)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetCurrentUser(rec, req)
|
||||
|
||||
// Check Content-Type
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Validate JSON structure
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check wrapper structure
|
||||
assert.Contains(t, response, "data")
|
||||
|
||||
// Get data object
|
||||
data, ok := response["data"].(map[string]interface{})
|
||||
require.True(t, ok, "data should be an object")
|
||||
|
||||
// Check required fields
|
||||
assert.Contains(t, data, "id")
|
||||
assert.Contains(t, data, "email")
|
||||
assert.Contains(t, data, "name")
|
||||
assert.Contains(t, data, "isAdmin")
|
||||
|
||||
// Validate field types
|
||||
_, ok = data["id"].(string)
|
||||
assert.True(t, ok, "id should be a string")
|
||||
|
||||
_, ok = data["email"].(string)
|
||||
assert.True(t, ok, "email should be a string")
|
||||
|
||||
_, ok = data["name"].(string)
|
||||
assert.True(t, ok, "name should be a string")
|
||||
|
||||
_, ok = data["isAdmin"].(bool)
|
||||
assert.True(t, ok, "isAdmin should be a boolean")
|
||||
}
|
||||
|
||||
func TestHandler_HandleGetCurrentUser_AdminEmailCaseInsensitive(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
adminEmails []string
|
||||
userEmail string
|
||||
expectedAdmin bool
|
||||
}{
|
||||
{
|
||||
name: "exact match lowercase",
|
||||
adminEmails: []string{"admin@example.com"},
|
||||
userEmail: "admin@example.com",
|
||||
expectedAdmin: true,
|
||||
},
|
||||
{
|
||||
name: "user uppercase, admin lowercase",
|
||||
adminEmails: []string{"admin@example.com"},
|
||||
userEmail: "ADMIN@EXAMPLE.COM",
|
||||
expectedAdmin: true,
|
||||
},
|
||||
{
|
||||
name: "user lowercase, admin uppercase",
|
||||
adminEmails: []string{"ADMIN@EXAMPLE.COM"},
|
||||
userEmail: "admin@example.com",
|
||||
expectedAdmin: true,
|
||||
},
|
||||
{
|
||||
name: "mixed case both",
|
||||
adminEmails: []string{"Admin@Example.COM"},
|
||||
userEmail: "aDmIn@eXaMpLe.CoM",
|
||||
expectedAdmin: true,
|
||||
},
|
||||
{
|
||||
name: "different email",
|
||||
adminEmails: []string{"admin@example.com"},
|
||||
userEmail: "user@example.com",
|
||||
expectedAdmin: false,
|
||||
},
|
||||
{
|
||||
name: "multiple admins, user matches second",
|
||||
adminEmails: []string{"admin1@example.com", "admin2@example.com"},
|
||||
userEmail: "ADMIN2@EXAMPLE.COM",
|
||||
expectedAdmin: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler(tt.adminEmails)
|
||||
|
||||
user := &models.User{
|
||||
Sub: "test-sub",
|
||||
Email: tt.userEmail,
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/users/me", nil)
|
||||
ctx := addUserToContext(req.Context(), user)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetCurrentUser(rec, req)
|
||||
|
||||
var wrapper struct {
|
||||
Data UserDTO `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &wrapper)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tt.expectedAdmin, wrapper.Data.IsAdmin, "Admin status mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_HandleGetCurrentUser_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler(testAdminEmails)
|
||||
|
||||
const numRequests = 100
|
||||
done := make(chan bool, numRequests)
|
||||
errors := make(chan error, numRequests)
|
||||
|
||||
// Spawn concurrent requests
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func(id int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
var user *models.User
|
||||
if id%2 == 0 {
|
||||
user = testUserRegular
|
||||
} else {
|
||||
user = testUserAdmin
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/users/me", nil)
|
||||
ctx := addUserToContext(req.Context(), user)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetCurrentUser(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
errors <- assert.AnError
|
||||
}
|
||||
|
||||
var wrapper struct {
|
||||
Data UserDTO `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &wrapper); err != nil {
|
||||
errors <- err
|
||||
}
|
||||
|
||||
// Validate admin status
|
||||
if id%2 == 0 && wrapper.Data.IsAdmin {
|
||||
errors <- assert.AnError
|
||||
}
|
||||
if id%2 != 0 && !wrapper.Data.IsAdmin {
|
||||
errors <- assert.AnError
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all requests
|
||||
for i := 0; i < numRequests; i++ {
|
||||
<-done
|
||||
}
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
var errCount int
|
||||
for err := range errors {
|
||||
t.Logf("Concurrent request error: %v", err)
|
||||
errCount++
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, errCount, "All concurrent requests should succeed")
|
||||
}
|
||||
|
||||
func TestHandler_HandleGetCurrentUser_DifferentHTTPMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "GET method (correct)",
|
||||
method: http.MethodGet,
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "POST method (works but not RESTful)",
|
||||
method: http.MethodPost,
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "PUT method",
|
||||
method: http.MethodPut,
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewHandler(testAdminEmails)
|
||||
|
||||
req := httptest.NewRequest(tt.method, "/api/v1/users/me", nil)
|
||||
ctx := addUserToContext(req.Context(), testUserRegular)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetCurrentUser(rec, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, rec.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandler_HandleGetCurrentUser(b *testing.B) {
|
||||
handler := NewHandler(testAdminEmails)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/users/me", nil)
|
||||
ctx := addUserToContext(req.Context(), testUserRegular)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetCurrentUser(rec, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandler_HandleGetCurrentUser_Parallel(b *testing.B) {
|
||||
handler := NewHandler(testAdminEmails)
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/users/me", nil)
|
||||
ctx := addUserToContext(req.Context(), testUserRegular)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.HandleGetCurrentUser(rec, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
39
backend/internal/presentation/handlers/errors.go
Normal file
39
backend/internal/presentation/handlers/errors.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// HandleError handles different types of errors and returns appropriate HTTP responses
|
||||
func HandleError(w http.ResponseWriter, err error) {
|
||||
switch {
|
||||
case errors.Is(err, models.ErrUnauthorized):
|
||||
logger.Logger.Warn("Unauthorized access attempt", "error", err.Error())
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
case errors.Is(err, models.ErrSignatureNotFound):
|
||||
logger.Logger.Debug("Signature not found", "error", err.Error())
|
||||
http.Error(w, "Signature not found", http.StatusNotFound)
|
||||
case errors.Is(err, models.ErrSignatureAlreadyExists):
|
||||
logger.Logger.Debug("Duplicate signature attempt", "error", err.Error())
|
||||
http.Error(w, "Signature already exists", http.StatusConflict)
|
||||
case errors.Is(err, models.ErrInvalidUser):
|
||||
logger.Logger.Warn("Invalid user data", "error", err.Error())
|
||||
http.Error(w, "Invalid user", http.StatusBadRequest)
|
||||
case errors.Is(err, models.ErrInvalidDocument):
|
||||
logger.Logger.Warn("Invalid document ID", "error", err.Error())
|
||||
http.Error(w, "Invalid document ID", http.StatusBadRequest)
|
||||
case errors.Is(err, models.ErrDomainNotAllowed):
|
||||
logger.Logger.Warn("Domain not allowed", "error", err.Error())
|
||||
http.Error(w, "Domain not allowed", http.StatusForbidden)
|
||||
case errors.Is(err, models.ErrDatabaseConnection):
|
||||
logger.Logger.Error("Database connection error", "error", err.Error())
|
||||
http.Error(w, "Database error", http.StatusInternalServerError)
|
||||
default:
|
||||
logger.Logger.Error("Unhandled error", "error", err.Error())
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
628
backend/internal/presentation/handlers/handlers_test.go
Normal file
628
backend/internal/presentation/handlers/handlers_test.go
Normal file
@@ -0,0 +1,628 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type fakeAuthService struct {
|
||||
shouldFailSetUser bool
|
||||
shouldFailCallback bool
|
||||
shouldFailGetUser bool
|
||||
setUserError error
|
||||
getUserError error
|
||||
callbackUser *models.User
|
||||
callbackNextURL string
|
||||
callbackError error
|
||||
authURL string
|
||||
logoutURL string
|
||||
logoutCalled bool
|
||||
|
||||
verifyStateResult bool
|
||||
lastVerifyToken string
|
||||
currentUser *models.User
|
||||
}
|
||||
|
||||
func newFakeAuthService() *fakeAuthService {
|
||||
return &fakeAuthService{
|
||||
authURL: "https://oauth.example.com/auth",
|
||||
callbackUser: &models.User{Sub: "test-user", Email: "test@example.com", Name: "Test User"},
|
||||
callbackNextURL: "/",
|
||||
verifyStateResult: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeAuthService) GetUser(_ *http.Request) (*models.User, error) {
|
||||
if f.shouldFailGetUser {
|
||||
return nil, f.getUserError
|
||||
}
|
||||
return f.currentUser, nil
|
||||
}
|
||||
|
||||
func (f *fakeAuthService) SetUser(_ http.ResponseWriter, _ *http.Request, user *models.User) error {
|
||||
if f.shouldFailSetUser {
|
||||
return f.setUserError
|
||||
}
|
||||
f.currentUser = user
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeAuthService) Logout(_ http.ResponseWriter, _ *http.Request) {
|
||||
f.logoutCalled = true
|
||||
f.currentUser = nil
|
||||
}
|
||||
|
||||
func (f *fakeAuthService) GetLogoutURL() string {
|
||||
return f.logoutURL
|
||||
}
|
||||
|
||||
func (f *fakeAuthService) GetAuthURL(nextURL string) string {
|
||||
return f.authURL + "?next=" + url.QueryEscape(nextURL)
|
||||
}
|
||||
|
||||
func (f *fakeAuthService) CreateAuthURL(_ http.ResponseWriter, _ *http.Request, nextURL string) string {
|
||||
return f.GetAuthURL(nextURL)
|
||||
}
|
||||
|
||||
func (f *fakeAuthService) VerifyState(_ http.ResponseWriter, _ *http.Request, token string) bool {
|
||||
f.lastVerifyToken = token
|
||||
return f.verifyStateResult
|
||||
}
|
||||
|
||||
func (f *fakeAuthService) HandleCallback(_ context.Context, _, _ string) (*models.User, string, error) {
|
||||
if f.shouldFailCallback {
|
||||
return nil, "", f.callbackError
|
||||
}
|
||||
return f.callbackUser, f.callbackNextURL, nil
|
||||
}
|
||||
|
||||
type fakeUserService struct {
|
||||
user *models.User
|
||||
shouldFail bool
|
||||
getUserError error
|
||||
}
|
||||
|
||||
func newFakeUserService() *fakeUserService {
|
||||
return &fakeUserService{
|
||||
user: &models.User{Sub: "test-user", Email: "test@example.com", Name: "Test User"},
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeUserService) GetUser(_ *http.Request) (*models.User, error) {
|
||||
if f.shouldFail {
|
||||
return nil, f.getUserError
|
||||
}
|
||||
return f.user, nil
|
||||
}
|
||||
|
||||
func TestHandleOEmbed_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
baseURL := "https://example.com"
|
||||
handler := HandleOEmbed(baseURL)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
docID string
|
||||
referrer string
|
||||
}{
|
||||
{"simple doc", "doc123", ""},
|
||||
{"with referrer", "doc456", "github"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reqURL := baseURL + "/?doc=" + tt.docID
|
||||
if tt.referrer != "" {
|
||||
reqURL += "&referrer=" + tt.referrer
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/oembed?url="+url.QueryEscape(reqURL), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", rec.Code)
|
||||
}
|
||||
|
||||
var response OEmbedResponse
|
||||
if err := json.NewDecoder(rec.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response.Type != "rich" {
|
||||
t.Errorf("Expected type 'rich', got %s", response.Type)
|
||||
}
|
||||
if response.Version != "1.0" {
|
||||
t.Errorf("Expected version '1.0', got %s", response.Version)
|
||||
}
|
||||
if response.ProviderName != "Ackify" {
|
||||
t.Errorf("Expected provider 'Ackify', got %s", response.ProviderName)
|
||||
}
|
||||
if response.Height != 200 {
|
||||
t.Errorf("Expected height 200, got %d", response.Height)
|
||||
}
|
||||
if !strings.Contains(response.HTML, "iframe") {
|
||||
t.Error("Expected HTML to contain iframe")
|
||||
}
|
||||
if !strings.Contains(response.HTML, tt.docID) {
|
||||
t.Errorf("Expected HTML to contain doc ID %s", tt.docID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleOEmbed_MissingURLParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := HandleOEmbed("https://example.com")
|
||||
req := httptest.NewRequest(http.MethodGet, "/oembed", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleOEmbed_InvalidURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := HandleOEmbed("https://example.com")
|
||||
req := httptest.NewRequest(http.MethodGet, "/oembed?url=:::invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleOEmbed_MissingDocParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := HandleOEmbed("https://example.com")
|
||||
req := httptest.NewRequest(http.MethodGet, "/oembed?url="+url.QueryEscape("https://example.com/"), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOEmbedURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
urlStr string
|
||||
baseURL string
|
||||
expected bool
|
||||
}{
|
||||
{"valid same host", "https://example.com/?doc=123", "https://example.com", true},
|
||||
{"valid with port", "https://example.com:443/?doc=123", "https://example.com", true},
|
||||
{"different host", "https://other.com/?doc=123", "https://example.com", false},
|
||||
{"localhost variations", "http://localhost:8080/?doc=123", "http://127.0.0.1:8080", true},
|
||||
{"localhost to 127.0.0.1", "http://127.0.0.1/?doc=123", "http://localhost", true},
|
||||
{"invalid URL", ":::invalid", "https://example.com", false},
|
||||
{"invalid base URL", "https://example.com/?doc=123", ":::invalid", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result := ValidateOEmbedURL(tt.urlStr, tt.baseURL)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// BENCHMARKS
|
||||
// ============================================================================
|
||||
|
||||
func BenchmarkHandleOEmbed(b *testing.B) {
|
||||
handler := HandleOEmbed("https://example.com")
|
||||
reqURL := url.QueryEscape("https://example.com/?doc=test123")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/oembed?url="+reqURL, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler(rec, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkValidateOEmbedURL(b *testing.B) {
|
||||
urlStr := "https://example.com/?doc=test123"
|
||||
baseURL := "https://example.com"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = ValidateOEmbedURL(urlStr, baseURL)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - Middleware: SecureHeaders
|
||||
// ============================================================================
|
||||
|
||||
func TestSecureHeaders_NonEmbedRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := SecureHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "nosniff", rec.Header().Get("X-Content-Type-Options"))
|
||||
assert.Equal(t, "no-referrer", rec.Header().Get("Referrer-Policy"))
|
||||
assert.Equal(t, "DENY", rec.Header().Get("X-Frame-Options"))
|
||||
assert.Contains(t, rec.Header().Get("Content-Security-Policy"), "frame-ancestors 'self'")
|
||||
}
|
||||
|
||||
func TestSecureHeaders_EmbedRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := SecureHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/embed/doc123", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "nosniff", rec.Header().Get("X-Content-Type-Options"))
|
||||
assert.Equal(t, "no-referrer", rec.Header().Get("Referrer-Policy"))
|
||||
assert.Empty(t, rec.Header().Get("X-Frame-Options"), "Embed routes should not have X-Frame-Options")
|
||||
assert.Contains(t, rec.Header().Get("Content-Security-Policy"), "frame-ancestors *")
|
||||
}
|
||||
|
||||
func TestSecureHeaders_EmbedRootRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := SecureHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/embed", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Empty(t, rec.Header().Get("X-Frame-Options"))
|
||||
assert.Contains(t, rec.Header().Get("Content-Security-Policy"), "frame-ancestors *")
|
||||
}
|
||||
|
||||
func TestSecureHeaders_CSPContent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := SecureHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
csp := rec.Header().Get("Content-Security-Policy")
|
||||
assert.Contains(t, csp, "default-src 'self'")
|
||||
assert.Contains(t, csp, "script-src 'self'")
|
||||
assert.Contains(t, csp, "style-src 'self'")
|
||||
assert.Contains(t, csp, "https://cdn.tailwindcss.com")
|
||||
assert.Contains(t, csp, "https://cdn.simpleicons.org")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - Middleware: RequestLogger
|
||||
// ============================================================================
|
||||
|
||||
func TestRequestLogger_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := RequestLogger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("success"))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "success", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestRequestLogger_WithError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := RequestLogger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte("error"))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/fail", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Equal(t, "error", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestRequestLogger_StatusRecorder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := RequestLogger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
// Verify the status recorder is working by checking the wrapper
|
||||
if sr, ok := w.(*statusRecorder); ok {
|
||||
assert.Equal(t, http.StatusCreated, sr.status)
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusCreated, rec.Code)
|
||||
}
|
||||
|
||||
func TestRequestLogger_DifferentMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
methods := []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch}
|
||||
|
||||
for _, method := range methods {
|
||||
t.Run(method, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := RequestLogger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(method, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - HandleError
|
||||
// ============================================================================
|
||||
|
||||
func TestHandleError_Unauthorized(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, models.ErrUnauthorized)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Unauthorized")
|
||||
}
|
||||
|
||||
func TestHandleError_SignatureNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, models.ErrSignatureNotFound)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Signature not found")
|
||||
}
|
||||
|
||||
func TestHandleError_SignatureAlreadyExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, models.ErrSignatureAlreadyExists)
|
||||
|
||||
assert.Equal(t, http.StatusConflict, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Signature already exists")
|
||||
}
|
||||
|
||||
func TestHandleError_InvalidUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, models.ErrInvalidUser)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Invalid user")
|
||||
}
|
||||
|
||||
func TestHandleError_InvalidDocument(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, models.ErrInvalidDocument)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Invalid document ID")
|
||||
}
|
||||
|
||||
func TestHandleError_DomainNotAllowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, models.ErrDomainNotAllowed)
|
||||
|
||||
assert.Equal(t, http.StatusForbidden, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Domain not allowed")
|
||||
}
|
||||
|
||||
func TestHandleError_DatabaseConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, models.ErrDatabaseConnection)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Database error")
|
||||
}
|
||||
|
||||
func TestHandleError_UnknownError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, errors.New("unknown error"))
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Internal server error")
|
||||
}
|
||||
|
||||
func TestHandleError_WrappedErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expectedStatus int
|
||||
expectedMsg string
|
||||
}{
|
||||
{
|
||||
"wrapped unauthorized",
|
||||
fmt.Errorf("auth failed: %w", models.ErrUnauthorized),
|
||||
http.StatusUnauthorized,
|
||||
"Unauthorized",
|
||||
},
|
||||
{
|
||||
"wrapped domain error",
|
||||
fmt.Errorf("validation failed: %w", models.ErrDomainNotAllowed),
|
||||
http.StatusForbidden,
|
||||
"Domain not allowed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, tt.err)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), tt.expectedMsg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - statusRecorder
|
||||
// ============================================================================
|
||||
|
||||
func TestStatusRecorder_WriteHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
sr := &statusRecorder{ResponseWriter: rec, status: http.StatusOK}
|
||||
|
||||
sr.WriteHeader(http.StatusCreated)
|
||||
|
||||
assert.Equal(t, http.StatusCreated, sr.status)
|
||||
assert.Equal(t, http.StatusCreated, rec.Code)
|
||||
}
|
||||
|
||||
func TestStatusRecorder_DefaultStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
sr := &statusRecorder{ResponseWriter: rec, status: http.StatusOK}
|
||||
|
||||
// Don't call WriteHeader, should keep default
|
||||
assert.Equal(t, http.StatusOK, sr.status)
|
||||
}
|
||||
|
||||
func TestStatusRecorder_MultipleWriteHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
sr := &statusRecorder{ResponseWriter: rec, status: http.StatusOK}
|
||||
|
||||
// First call
|
||||
sr.WriteHeader(http.StatusCreated)
|
||||
assert.Equal(t, http.StatusCreated, sr.status)
|
||||
|
||||
// Second call (should be ignored by http.ResponseWriter)
|
||||
sr.WriteHeader(http.StatusInternalServerError)
|
||||
// Status recorder updates but ResponseWriter doesn't change
|
||||
assert.Equal(t, http.StatusInternalServerError, sr.status)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// BENCHMARKS
|
||||
// ============================================================================
|
||||
|
||||
func BenchmarkSecureHeaders(b *testing.B) {
|
||||
handler := SecureHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRequestLogger(b *testing.B) {
|
||||
handler := RequestLogger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandleError(b *testing.B) {
|
||||
err := models.ErrUnauthorized
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, err)
|
||||
}
|
||||
}
|
||||
88
backend/internal/presentation/handlers/middleware.go
Normal file
88
backend/internal/presentation/handlers/middleware.go
Normal file
@@ -0,0 +1,88 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
type userService interface {
|
||||
GetUser(r *http.Request) (*models.User, error)
|
||||
}
|
||||
|
||||
type AuthMiddleware struct {
|
||||
userService userService
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// SecureHeaders Enforce baseline security headers (CSP, XFO, etc.) to mitigate clickjacking, MIME sniffing, and unsafe embedding by default.
|
||||
func SecureHeaders(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("Referrer-Policy", "no-referrer")
|
||||
|
||||
// Check if this is an embed route - allow iframe embedding
|
||||
isEmbedRoute := strings.HasPrefix(r.URL.Path, "/embed/") || strings.HasPrefix(r.URL.Path, "/embed")
|
||||
|
||||
if isEmbedRoute {
|
||||
// Allow embedding from any origin for embed pages
|
||||
// Do not set X-Frame-Options to allow iframe embedding
|
||||
w.Header().Set("Content-Security-Policy",
|
||||
"default-src 'self'; "+
|
||||
"style-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com https://fonts.googleapis.com; "+
|
||||
"font-src 'self' https://fonts.gstatic.com; "+
|
||||
"script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; "+
|
||||
"img-src 'self' data: https://cdn.simpleicons.org; "+
|
||||
"connect-src 'self'; "+
|
||||
"frame-ancestors *") // Allow embedding from any origin
|
||||
} else {
|
||||
// Strict headers for non-embed routes
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
w.Header().Set("Content-Security-Policy",
|
||||
"default-src 'self'; "+
|
||||
"style-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com https://fonts.googleapis.com; "+
|
||||
"font-src 'self' https://fonts.gstatic.com; "+
|
||||
"script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; "+
|
||||
"img-src 'self' data: https://cdn.simpleicons.org; "+
|
||||
"connect-src 'self'; "+
|
||||
"frame-ancestors 'self'")
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// RequestLogger Minimal structured logging without PII; record latency and status for ops visibility.
|
||||
func RequestLogger(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sr := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
|
||||
start := time.Now()
|
||||
next.ServeHTTP(sr, r)
|
||||
duration := time.Since(start)
|
||||
// Minimal structured log to avoid PII
|
||||
logger.Logger.Info("http_request",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", sr.status,
|
||||
"duration_ms", duration.Milliseconds())
|
||||
})
|
||||
}
|
||||
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (sr *statusRecorder) WriteHeader(code int) {
|
||||
sr.status = code
|
||||
sr.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
128
backend/internal/presentation/handlers/oembed.go
Normal file
128
backend/internal/presentation/handlers/oembed.go
Normal file
@@ -0,0 +1,128 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// OEmbedResponse represents the oEmbed JSON response format
|
||||
// Specification: https://oembed.com/
|
||||
type OEmbedResponse struct {
|
||||
Type string `json:"type"` // Must be "rich" for iframe embeds
|
||||
Version string `json:"version"` // oEmbed version (always "1.0")
|
||||
Title string `json:"title"` // Document title
|
||||
ProviderName string `json:"provider_name"` // Service name
|
||||
ProviderURL string `json:"provider_url"` // Service homepage URL
|
||||
HTML string `json:"html"` // HTML embed code (iframe)
|
||||
Width int `json:"width,omitempty"` // Recommended width (optional)
|
||||
Height int `json:"height"` // Recommended height
|
||||
}
|
||||
|
||||
// HandleOEmbed handles GET /oembed?url=<document_url>
|
||||
// Returns oEmbed JSON for embedding Ackify signature widgets in external platforms
|
||||
func HandleOEmbed(baseURL string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get the URL parameter
|
||||
urlParam := r.URL.Query().Get("url")
|
||||
if urlParam == "" {
|
||||
logger.Logger.Warn("oEmbed request missing url parameter",
|
||||
"remote_addr", r.RemoteAddr)
|
||||
http.Error(w, "Missing 'url' parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the URL to extract doc parameter
|
||||
parsedURL, err := url.Parse(urlParam)
|
||||
if err != nil {
|
||||
logger.Logger.Warn("oEmbed request with invalid url",
|
||||
"url", urlParam,
|
||||
"error", err.Error(),
|
||||
"remote_addr", r.RemoteAddr)
|
||||
http.Error(w, "Invalid 'url' parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract doc ID from query parameters
|
||||
docID := parsedURL.Query().Get("doc")
|
||||
if docID == "" {
|
||||
logger.Logger.Warn("oEmbed request missing doc parameter in url",
|
||||
"url", urlParam,
|
||||
"remote_addr", r.RemoteAddr)
|
||||
http.Error(w, "URL must contain 'doc' parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Build embed URL (points to the SPA embed view)
|
||||
embedURL := baseURL + "/embed?doc=" + url.QueryEscape(docID)
|
||||
|
||||
// Check if referrer is provided (for tracking which platform is embedding)
|
||||
referrer := parsedURL.Query().Get("referrer")
|
||||
if referrer != "" {
|
||||
embedURL += "&referrer=" + url.QueryEscape(referrer)
|
||||
}
|
||||
|
||||
// Build iframe HTML
|
||||
iframeHTML := `<iframe src="` + embedURL + `" width="100%" height="200" frameborder="0" style="border: 1px solid #ddd; border-radius: 6px;" allowtransparency="true"></iframe>`
|
||||
|
||||
// Create oEmbed response
|
||||
response := OEmbedResponse{
|
||||
Type: "rich",
|
||||
Version: "1.0",
|
||||
Title: "Document " + docID + " - Confirmations de lecture",
|
||||
ProviderName: "Ackify",
|
||||
ProviderURL: baseURL,
|
||||
HTML: iframeHTML,
|
||||
Height: 200,
|
||||
}
|
||||
|
||||
// Set response headers
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // Allow cross-origin requests for oEmbed
|
||||
|
||||
// Encode and send response
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
logger.Logger.Error("Failed to encode oEmbed response",
|
||||
"doc_id", docID,
|
||||
"error", err.Error())
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Logger.Info("oEmbed response served",
|
||||
"doc_id", docID,
|
||||
"url", urlParam,
|
||||
"remote_addr", r.RemoteAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateOEmbedURL checks if the provided URL is a valid Ackify document URL
|
||||
func ValidateOEmbedURL(urlStr string, baseURL string) bool {
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if the URL belongs to this Ackify instance
|
||||
baseURLParsed, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Normalize hosts for comparison (remove ports if present)
|
||||
urlHost := strings.Split(parsedURL.Host, ":")[0]
|
||||
baseHost := strings.Split(baseURLParsed.Host, ":")[0]
|
||||
|
||||
// Allow localhost variations
|
||||
if urlHost == "localhost" || urlHost == "127.0.0.1" {
|
||||
if baseHost == "localhost" || baseHost == "127.0.0.1" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return urlHost == baseHost
|
||||
}
|
||||
17
backend/locales/de.json
Normal file
17
backend/locales/de.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"email.reminder.subject": "Erinnerung zur Bestätigung des Dokumentenlesens",
|
||||
"email.reminder.title": "Erinnerung zur Bestätigung des Dokumentenlesens",
|
||||
"email.reminder.greeting_with_name": "Hallo {{.RecipientName}},",
|
||||
"email.reminder.greeting": "Hallo,",
|
||||
"email.reminder.intro": "Dies ist eine Erinnerung, dass das folgende Dokument Ihre Lesebestätigung erfordert:",
|
||||
"email.reminder.doc_id_label": "Dokument-ID:",
|
||||
"email.reminder.doc_location_label": "Standort:",
|
||||
"email.reminder.instructions": "Um dieses Dokument anzusehen und das Lesen zu bestätigen, folgen Sie bitte diesen Schritten:",
|
||||
"email.reminder.step_view_doc": "Dokument ansehen unter:",
|
||||
"email.reminder.step_sign": "Ihr Lesen bestätigen unter:",
|
||||
"email.reminder.cta_button": "Lesen jetzt bestätigen",
|
||||
"email.reminder.explanation": "Ihre kryptographische Bestätigung liefert einen überprüfbaren Nachweis, dass Sie dieses Dokument gelesen und zur Kenntnis genommen haben.",
|
||||
"email.reminder.contact": "Bei Fragen wenden Sie sich bitte an Ihren Administrator.",
|
||||
"email.reminder.regards": "Mit freundlichen Grüßen,",
|
||||
"email.reminder.team": "Das {{.Organisation}}-Team"
|
||||
}
|
||||
17
backend/locales/en.json
Normal file
17
backend/locales/en.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"email.reminder.subject": "Document Reading Confirmation Reminder",
|
||||
"email.reminder.title": "Document Reading Confirmation Reminder",
|
||||
"email.reminder.greeting_with_name": "Hello {{.RecipientName}},",
|
||||
"email.reminder.greeting": "Hello,",
|
||||
"email.reminder.intro": "This is a reminder that the following document requires your reading confirmation:",
|
||||
"email.reminder.doc_id_label": "Document ID:",
|
||||
"email.reminder.doc_location_label": "Location:",
|
||||
"email.reminder.instructions": "To review and confirm reading of this document, please follow these steps:",
|
||||
"email.reminder.step_view_doc": "View the document at:",
|
||||
"email.reminder.step_sign": "Confirm your reading at:",
|
||||
"email.reminder.cta_button": "Confirm reading now",
|
||||
"email.reminder.explanation": "Your cryptographic confirmation will provide verifiable proof that you have read and acknowledged this document.",
|
||||
"email.reminder.contact": "If you have any questions, please contact your administrator.",
|
||||
"email.reminder.regards": "Best regards,",
|
||||
"email.reminder.team": "The {{.Organisation}} team"
|
||||
}
|
||||
17
backend/locales/es.json
Normal file
17
backend/locales/es.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"email.reminder.subject": "Recordatorio de confirmación de lectura de documento",
|
||||
"email.reminder.title": "Recordatorio de confirmación de lectura de documento",
|
||||
"email.reminder.greeting_with_name": "Hola {{.RecipientName}},",
|
||||
"email.reminder.greeting": "Hola,",
|
||||
"email.reminder.intro": "Este es un recordatorio de que el siguiente documento requiere su confirmación de lectura:",
|
||||
"email.reminder.doc_id_label": "ID del documento:",
|
||||
"email.reminder.doc_location_label": "Ubicación:",
|
||||
"email.reminder.instructions": "Para revisar y confirmar la lectura de este documento, siga estos pasos:",
|
||||
"email.reminder.step_view_doc": "Ver el documento en:",
|
||||
"email.reminder.step_sign": "Confirmar su lectura en:",
|
||||
"email.reminder.cta_button": "Confirmar lectura ahora",
|
||||
"email.reminder.explanation": "Su confirmación criptográfica proporcionará una prueba verificable de que ha leído y reconocido este documento.",
|
||||
"email.reminder.contact": "Si tiene alguna pregunta, póngase en contacto con su administrador.",
|
||||
"email.reminder.regards": "Saludos cordiales,",
|
||||
"email.reminder.team": "El equipo de {{.Organisation}}"
|
||||
}
|
||||
17
backend/locales/fr.json
Normal file
17
backend/locales/fr.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"email.reminder.subject": "Rappel de confirmation de lecture de document",
|
||||
"email.reminder.title": "Rappel de confirmation de lecture de document",
|
||||
"email.reminder.greeting_with_name": "Bonjour {{.RecipientName}},",
|
||||
"email.reminder.greeting": "Bonjour,",
|
||||
"email.reminder.intro": "Ceci est un rappel que le document suivant nécessite votre confirmation de lecture :",
|
||||
"email.reminder.doc_id_label": "ID du document :",
|
||||
"email.reminder.doc_location_label": "Emplacement :",
|
||||
"email.reminder.instructions": "Pour consulter et confirmer la lecture de ce document, veuillez suivre ces étapes :",
|
||||
"email.reminder.step_view_doc": "Consulter le document à :",
|
||||
"email.reminder.step_sign": "Confirmer votre lecture à :",
|
||||
"email.reminder.cta_button": "Confirmer la lecture maintenant",
|
||||
"email.reminder.explanation": "Votre confirmation cryptographique fournira une preuve vérifiable que vous avez lu et pris connaissance de ce document.",
|
||||
"email.reminder.contact": "Si vous avez des questions, veuillez contacter votre administrateur.",
|
||||
"email.reminder.regards": "Cordialement,",
|
||||
"email.reminder.team": "L'équipe {{.Organisation}}"
|
||||
}
|
||||
17
backend/locales/it.json
Normal file
17
backend/locales/it.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"email.reminder.subject": "Promemoria conferma lettura documento",
|
||||
"email.reminder.title": "Promemoria conferma lettura documento",
|
||||
"email.reminder.greeting_with_name": "Ciao {{.RecipientName}},",
|
||||
"email.reminder.greeting": "Ciao,",
|
||||
"email.reminder.intro": "Questo è un promemoria che il seguente documento richiede la tua conferma di lettura:",
|
||||
"email.reminder.doc_id_label": "ID documento:",
|
||||
"email.reminder.doc_location_label": "Posizione:",
|
||||
"email.reminder.instructions": "Per visualizzare e confermare la lettura di questo documento, si prega di seguire questi passaggi:",
|
||||
"email.reminder.step_view_doc": "Visualizza il documento a:",
|
||||
"email.reminder.step_sign": "Conferma la tua lettura a:",
|
||||
"email.reminder.cta_button": "Conferma lettura ora",
|
||||
"email.reminder.explanation": "La tua conferma crittografica fornirà una prova verificabile che hai letto e preso atto di questo documento.",
|
||||
"email.reminder.contact": "Se hai domande, contatta il tuo amministratore.",
|
||||
"email.reminder.regards": "Cordiali saluti,",
|
||||
"email.reminder.team": "Il team {{.Organisation}}"
|
||||
}
|
||||
@@ -15,8 +15,9 @@ CREATE TABLE signatures (
|
||||
UNIQUE (doc_id, user_sub)
|
||||
);
|
||||
|
||||
-- Create index for efficient queries
|
||||
-- Create indexes for efficient queries
|
||||
CREATE INDEX idx_signatures_user ON signatures(user_sub);
|
||||
CREATE INDEX idx_signatures_doc_id ON signatures(doc_id);
|
||||
|
||||
-- Create trigger to prevent modification of created_at
|
||||
CREATE OR REPLACE FUNCTION prevent_created_at_update()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user