mirror of
https://github.com/btouchard/ackify-ce.git
synced 2026-02-28 10:48:47 -06:00
feat: initial project setup
Add complete Go application for cryptographic document signature validation with OAuth2 authentication, Ed25519 signatures, and PostgreSQL storage following clean architecture principles.
This commit is contained in:
@@ -0,0 +1,52 @@
|
||||
# Git
|
||||
.git
|
||||
.gitignore
|
||||
.idea
|
||||
|
||||
# Documentation
|
||||
README.md
|
||||
CLAUDE.md
|
||||
*SETUP.md
|
||||
docs/
|
||||
*.md
|
||||
LICENSE
|
||||
|
||||
# Development
|
||||
.env
|
||||
.env.local
|
||||
.env.example
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Build artifacts
|
||||
*.exe
|
||||
coverage.out
|
||||
*.test
|
||||
|
||||
# Temporary files
|
||||
tmp/
|
||||
temp/
|
||||
*.tmp
|
||||
*.log
|
||||
|
||||
# GitHub Actions (not needed in container)
|
||||
.github/
|
||||
|
||||
# Docker
|
||||
Dockerfile*
|
||||
docker-compose*
|
||||
.dockerignore
|
||||
|
||||
# Node.js (if any frontend assets)
|
||||
node_modules/
|
||||
npm-debug.log
|
||||
yarn-error.log
|
||||
@@ -0,0 +1,40 @@
|
||||
# Application Configuration
|
||||
APP_NAME=ackify
|
||||
APP_DNS=your-domain.com
|
||||
APP_BASE_URL=https://your-domain.com
|
||||
APP_ORGANISATION="Your Organization Name"
|
||||
|
||||
# Database Configuration
|
||||
POSTGRES_USER=ackifyr
|
||||
POSTGRES_PASSWORD=your_secure_password
|
||||
POSTGRES_DB=ackify
|
||||
DB_DSN=postgres://user:pass@db:5432/ack?sslmode=disable
|
||||
|
||||
# OAuth2 Configuration - Generic Provider
|
||||
OAUTH_CLIENT_ID=your_oauth_client_id
|
||||
OAUTH_CLIENT_SECRET=your_oauth_client_secret
|
||||
OAUTH_ALLOWED_DOMAIN=your-organization.com
|
||||
|
||||
# OAuth2 Provider Configuration
|
||||
# Use OAUTH_PROVIDER to configure popular providers automatically:
|
||||
# - "google" for Google OAuth2
|
||||
# - "github" for GitHub OAuth2
|
||||
# - "gitlab" for GitLab OAuth2 (set OAUTH_GITLAB_URL if self-hosted)
|
||||
# - Leave empty for custom provider (requires manual URL configuration)
|
||||
OAUTH_PROVIDER=google
|
||||
|
||||
# Custom OAuth2 Provider URLs (only needed if OAUTH_PROVIDER is empty)
|
||||
# OAUTH_AUTH_URL=https://your-provider.com/oauth/authorize
|
||||
# OAUTH_TOKEN_URL=https://your-provider.com/oauth/token
|
||||
# OAUTH_USERINFO_URL=https://your-provider.com/api/user
|
||||
# OAUTH_SCOPES=openid,email
|
||||
|
||||
# GitLab specific (if using gitlab as provider and self-hosted)
|
||||
# OAUTH_GITLAB_URL=https://gitlab.your-company.com
|
||||
|
||||
# Security Configuration
|
||||
OAUTH_COOKIE_SECRET=your_base64_encoded_secret_key
|
||||
ED25519_PRIVATE_KEY_B64=your_base64_encoded_ed25519_private_key
|
||||
|
||||
# Server Configuration
|
||||
LISTEN_ADDR=:8080
|
||||
@@ -0,0 +1,184 @@
|
||||
name: CI/CD Pipeline
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, develop ]
|
||||
pull_request:
|
||||
branches: [ main, develop ]
|
||||
release:
|
||||
types: [ published ]
|
||||
|
||||
env:
|
||||
REGISTRY: docker.io
|
||||
IMAGE_NAME: btouchard/ackify
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Run Tests
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:15-alpine
|
||||
env:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: testpassword
|
||||
POSTGRES_DB: ackify_test
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
ports:
|
||||
- 5432:5432
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.24.5'
|
||||
cache: true
|
||||
|
||||
- name: Download dependencies
|
||||
run: go mod download
|
||||
|
||||
- name: Run go fmt check
|
||||
run: |
|
||||
if [ "$(gofmt -s -l . | wc -l)" -gt 0 ]; then
|
||||
echo "The following files need to be formatted:"
|
||||
gofmt -s -l .
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Run go vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Run unit tests
|
||||
env:
|
||||
APP_BASE_URL: "http://localhost:8080"
|
||||
APP_ORGANISATION: "Test Org"
|
||||
OAUTH_CLIENT_ID: "test-client-id"
|
||||
OAUTH_CLIENT_SECRET: "test-client-secret"
|
||||
OAUTH_COOKIE_SECRET: "dGVzdC1jb29raWUtc2VjcmV0LXRlc3QtY29va2llLXNlY3JldA=="
|
||||
run: go test -v -race -short ./...
|
||||
|
||||
- name: Run integrations tests
|
||||
env:
|
||||
DB_DSN: "postgres://postgres:testpassword@localhost:5432/ackify_test?sslmode=disable"
|
||||
INTEGRATION_TESTS: "true"
|
||||
run: go test -v -race -tags=integrations ./internal/infrastructure/database/...
|
||||
|
||||
- name: Generate coverage report
|
||||
env:
|
||||
DB_DSN: "postgres://postgres:testpassword@localhost:5432/ackify_test?sslmode=disable"
|
||||
INTEGRATION_TESTS: "true"
|
||||
APP_BASE_URL: "http://localhost:8080"
|
||||
APP_ORGANISATION: "Test Org"
|
||||
OAUTH_CLIENT_ID: "test-client-id"
|
||||
OAUTH_CLIENT_SECRET: "test-client-secret"
|
||||
OAUTH_COOKIE_SECRET: "dGVzdC1jb29raWUtc2VjcmV0LXRlc3QtY29va2llLXNlY3JldA=="
|
||||
run: go test -v -race -tags=integrations -coverprofile=coverage.out ./...
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
if: success()
|
||||
uses: codecov/codecov-action@v3
|
||||
with:
|
||||
file: ./coverage.out
|
||||
flags: unittests,integrations
|
||||
name: codecov-umbrella
|
||||
|
||||
build:
|
||||
name: Build and Push Docker Image
|
||||
runs-on: ubuntu-latest
|
||||
needs: test
|
||||
if: github.event_name != 'pull_request'
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||
|
||||
- name: Extract metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=pr
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=semver,pattern={{major}}
|
||||
type=sha,prefix={{branch}}-
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
build-args: |
|
||||
VERSION=${{ github.ref_name }}
|
||||
COMMIT=${{ github.sha }}
|
||||
BUILD_DATE=${{ github.event.head_commit.timestamp }}
|
||||
|
||||
security:
|
||||
name: Security Scan
|
||||
runs-on: ubuntu-latest
|
||||
needs: build
|
||||
if: github.event_name != 'pull_request'
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
with:
|
||||
image-ref: '${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ github.ref_name }}'
|
||||
format: 'sarif'
|
||||
output: 'trivy-results.sarif'
|
||||
|
||||
# - name: Upload Trivy scan results to GitHub Security tab
|
||||
# uses: github/codeql-action/upload-sarif@v2
|
||||
# if: always()
|
||||
# with:
|
||||
# sarif_file: 'trivy-results.sarif'
|
||||
|
||||
notify:
|
||||
name: Notify
|
||||
runs-on: ubuntu-latest
|
||||
needs: [test, build, security]
|
||||
if: always() && github.event_name != 'pull_request'
|
||||
|
||||
steps:
|
||||
- name: Notify success
|
||||
if: needs.test.result == 'success' && needs.build.result == 'success'
|
||||
run: |
|
||||
echo "✅ CI/CD Pipeline completed successfully!"
|
||||
echo "🚀 Image pushed: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ github.ref_name }}"
|
||||
|
||||
- name: Notify failure
|
||||
if: needs.test.result == 'failure' || needs.build.result == 'failure'
|
||||
run: |
|
||||
echo "❌ CI/CD Pipeline failed!"
|
||||
echo "Please check the logs above for details."
|
||||
exit 1
|
||||
@@ -0,0 +1,8 @@
|
||||
CLAUDE.md
|
||||
*SETUP.md
|
||||
.claude
|
||||
.idea
|
||||
.env
|
||||
|
||||
docker-compose.local.yml
|
||||
|
||||
+62
@@ -0,0 +1,62 @@
|
||||
# ---- Build stage ----
|
||||
FROM golang:alpine AS builder
|
||||
|
||||
# Install security updates and ca-certificates
|
||||
RUN apk update && apk add --no-cache ca-certificates git && rm -rf /var/cache/apk/*
|
||||
|
||||
# Create non-root user for build
|
||||
RUN adduser -D -g '' ackuser
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy go mod files first for better layer caching
|
||||
COPY go.mod go.sum ./
|
||||
|
||||
# Set GOTOOLCHAIN to auto to allow Go toolchain updates
|
||||
ENV GOTOOLCHAIN=auto
|
||||
|
||||
RUN go mod download && go mod verify
|
||||
|
||||
# Copy source code
|
||||
COPY . .
|
||||
|
||||
# Build arguments for metadata
|
||||
ARG VERSION="dev"
|
||||
ARG COMMIT="unknown"
|
||||
ARG BUILD_DATE="unknown"
|
||||
|
||||
# Build the application with optimizations and metadata
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
|
||||
-a -installsuffix cgo \
|
||||
-ldflags="-w -s -X main.Version=${VERSION} -X main.Commit=${COMMIT} -X main.BuildDate=${BUILD_DATE}" \
|
||||
-o ackify ./cmd/ackify
|
||||
|
||||
# ---- Runtime stage ----
|
||||
FROM gcr.io/distroless/static-debian12:nonroot
|
||||
|
||||
# Re-declare ARG for runtime stage
|
||||
ARG VERSION="dev"
|
||||
|
||||
# Add metadata labels
|
||||
LABEL maintainer="Benjamin TOUCHARD"
|
||||
LABEL version="${VERSION}"
|
||||
LABEL description="Ackify - Document signature validation platform"
|
||||
LABEL org.opencontainers.image.source="https://github.com/btouchard/ackify"
|
||||
LABEL org.opencontainers.image.description="Professional solution for validating and tracking document reading"
|
||||
LABEL org.opencontainers.image.licenses="SSPL"
|
||||
|
||||
# Copy certificates for HTTPS requests
|
||||
COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
|
||||
|
||||
# Set working directory and copy application files
|
||||
WORKDIR /app
|
||||
COPY --from=builder /app/ackify /app/ackify
|
||||
COPY --from=builder /app/web /app/web
|
||||
|
||||
# Use non-root user (already set in distroless image)
|
||||
# USER 65532:65532
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
|
||||
ENTRYPOINT ["/app/ackify"]
|
||||
@@ -0,0 +1,103 @@
|
||||
Server Side Public License
|
||||
Version 1, October 16, 2018
|
||||
|
||||
Copyright (c) 2024 Benjamin TOUCHARD
|
||||
|
||||
The Server Side Public License (the "License") applies to the use, reproduction,
|
||||
modification and distribution of the Work and any derivatives thereof.
|
||||
|
||||
The Work is (c) 2024 Benjamin TOUCHARD
|
||||
|
||||
Parameters
|
||||
|
||||
Licensor: Benjamin TOUCHARD
|
||||
Licensed Work: Ackify
|
||||
The Licensed Work is (c) 2024 Benjamin TOUCHARD
|
||||
|
||||
Additional Use Grant: You may make use of the Licensed Work, provided that you may not use
|
||||
the Licensed Work for a Service.
|
||||
|
||||
A "Service" is a commercial offering, product, hosted, or managed
|
||||
service, that allows third parties (other than your own employees
|
||||
and contractors acting on your behalf) to access and/or use the
|
||||
Licensed Work or a substantial set of the features or functionality
|
||||
of the Licensed Work to third parties as a software-as-a-service,
|
||||
platform-as-a-service, infrastructure-as-a-service or other similar
|
||||
services that compete with Licensor products or services.
|
||||
|
||||
Change Date: The earlier of the date specified in a Change License, or four
|
||||
years from the date the Licensed Work is published.
|
||||
|
||||
Change License: Apache License, Version 2.0
|
||||
|
||||
For information about alternative licensing arrangements for the Licensed Work,
|
||||
please visit: https://github.com/btouchard/ackify
|
||||
|
||||
Notice
|
||||
|
||||
The Business Source License (this document) is not an Open Source license.
|
||||
However, the Licensed Work will eventually be made available under an Open Source
|
||||
License, as stated in this License.
|
||||
|
||||
License Text
|
||||
|
||||
Terms and Conditions
|
||||
|
||||
1. License Grant. Subject to the terms and conditions of this License, Licensor
|
||||
hereby grants to you a non-exclusive, royalty-free, worldwide, non-transferable
|
||||
license during the term of this License to:
|
||||
|
||||
(a) use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
||||
of the Licensed Work; and
|
||||
(b) permit persons to whom the Licensed Work is furnished to do so, subject to
|
||||
the following conditions.
|
||||
|
||||
2. Limitations. You may not use the Licensed Work:
|
||||
|
||||
(a) for a Service, unless you have a separate agreement with Licensor
|
||||
permitting such use; or
|
||||
(b) to provide a Service to third parties.
|
||||
|
||||
3. Conditions. Your exercise of the rights granted under this License is subject to
|
||||
the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Licensed Work a copy of this License;
|
||||
(b) You must cause any modified files to carry prominent notices stating that you
|
||||
changed the files;
|
||||
(c) You must retain, in the source form of any derivative works that you distribute,
|
||||
all copyright, patent, trademark, and attribution notices from the source form
|
||||
of the Licensed Work, excluding those notices that do not pertain to any part
|
||||
of the derivative works; and
|
||||
(d) If the Licensed Work includes a "NOTICE" text file as part of its distribution,
|
||||
then any derivative works that you distribute must include a readable copy of
|
||||
the attribution notices contained within such NOTICE file.
|
||||
|
||||
4. Termination. This License will terminate automatically upon any breach by you of
|
||||
the terms of this License. Upon termination, you must stop all use of the Licensed
|
||||
Work and destroy all copies of the Licensed Work in your possession or control.
|
||||
|
||||
5. Disclaimer of Warranty. THE LICENSED WORK IS PROVIDED "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO
|
||||
EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
|
||||
OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE LICENSED WORK OR THE USE OR OTHER DEALINGS
|
||||
IN THE LICENSED WORK.
|
||||
|
||||
6. Limitation of Liability. IN NO EVENT SHALL THE LICENSOR BE LIABLE FOR ANY DIRECT,
|
||||
INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
|
||||
OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THE LICENSED WORK, EVEN IF
|
||||
ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
7. General. This License represents the complete agreement concerning the subject
|
||||
matter hereof. If any provision of this License is held to be unenforceable,
|
||||
such provision shall be reformed only to the extent necessary to make it
|
||||
enforceable. Any use of the Licensed Work in violation of this License will
|
||||
automatically terminate your rights under this License for the current and all
|
||||
other versions of the Licensed Work.
|
||||
|
||||
For more information on the Server Side Public License, please see:
|
||||
https://www.mongodb.com/licensing/server-side-public-license
|
||||
@@ -0,0 +1,125 @@
|
||||
# Makefile for ackify project
|
||||
|
||||
.PHONY: build test test-unit test-integration test-short coverage lint fmt vet clean help
|
||||
|
||||
# Variables
|
||||
BINARY_NAME=ackapp
|
||||
BUILD_DIR=./cmd/ackapp
|
||||
COVERAGE_DIR=coverage
|
||||
|
||||
# Default target
|
||||
help: ## Display this help message
|
||||
@echo "Available targets:"
|
||||
@awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf " \033[36m%-15s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST)
|
||||
|
||||
# Build targets
|
||||
build: ## Build the application
|
||||
@echo "Building $(BINARY_NAME)..."
|
||||
go build -o $(BINARY_NAME) $(BUILD_DIR)
|
||||
|
||||
# Test targets
|
||||
test: test-unit test-integration ## Run all tests
|
||||
|
||||
test-unit: ## Run unit tests
|
||||
@echo "Running unit tests with race detection..."
|
||||
CGO_ENABLED=1 go test -short -race -v ./internal/... ./pkg/... ./cmd/...
|
||||
|
||||
test-integration: ## Run integration tests (requires PostgreSQL)
|
||||
@echo "Running integration tests with race detection..."
|
||||
@if [ -z "$(DB_DSN)" ]; then \
|
||||
export DB_DSN="postgres://postgres:testpassword@localhost:5432/ackify_test?sslmode=disable"; \
|
||||
fi; \
|
||||
export INTEGRATION_TESTS=true; \
|
||||
CGO_ENABLED=1 go test -v -race -tags=integration ./internal/infrastructure/database/...
|
||||
|
||||
test-integration-setup: ## Setup test database for integration tests
|
||||
@echo "Setting up test database..."
|
||||
@psql "postgres://postgres:testpassword@localhost:5432/postgres?sslmode=disable" -c "DROP DATABASE IF EXISTS ackify_test;" || true
|
||||
@psql "postgres://postgres:testpassword@localhost:5432/postgres?sslmode=disable" -c "CREATE DATABASE ackify_test;"
|
||||
@echo "Test database ready!"
|
||||
|
||||
test-short: ## Run only quick tests
|
||||
@echo "Running short tests..."
|
||||
go test -short ./...
|
||||
|
||||
|
||||
# Coverage targets
|
||||
coverage: ## Generate test coverage report
|
||||
@echo "Generating coverage report..."
|
||||
@mkdir -p $(COVERAGE_DIR)
|
||||
go test -coverprofile=$(COVERAGE_DIR)/coverage.out ./...
|
||||
go tool cover -html=$(COVERAGE_DIR)/coverage.out -o $(COVERAGE_DIR)/coverage.html
|
||||
@echo "Coverage report generated: $(COVERAGE_DIR)/coverage.html"
|
||||
|
||||
coverage-integration: ## Generate integration test coverage report
|
||||
@echo "Generating integration coverage report..."
|
||||
@mkdir -p $(COVERAGE_DIR)
|
||||
@export DB_DSN="postgres://postgres:testpassword@localhost:5432/ackify_test?sslmode=disable"; \
|
||||
export INTEGRATION_TESTS=true; \
|
||||
go test -v -race -tags=integration -coverprofile=$(COVERAGE_DIR)/coverage-integration.out ./internal/infrastructure/database/...
|
||||
go tool cover -html=$(COVERAGE_DIR)/coverage-integration.out -o $(COVERAGE_DIR)/coverage-integration.html
|
||||
@echo "Integration coverage report generated: $(COVERAGE_DIR)/coverage-integration.html"
|
||||
|
||||
coverage-all: ## Generate full coverage report (unit + integration)
|
||||
@echo "Generating full coverage report..."
|
||||
@mkdir -p $(COVERAGE_DIR)
|
||||
@export DB_DSN="postgres://postgres:testpassword@localhost:5432/ackify_test?sslmode=disable"; \
|
||||
export INTEGRATION_TESTS=true; \
|
||||
go test -v -race -tags=integration -coverprofile=$(COVERAGE_DIR)/coverage-all.out ./...
|
||||
go tool cover -html=$(COVERAGE_DIR)/coverage-all.out -o $(COVERAGE_DIR)/coverage-all.html
|
||||
go tool cover -func=$(COVERAGE_DIR)/coverage-all.out
|
||||
@echo "Full coverage report generated: $(COVERAGE_DIR)/coverage-all.html"
|
||||
|
||||
coverage-func: ## Show function-level coverage
|
||||
go test -coverprofile=$(COVERAGE_DIR)/coverage.out ./...
|
||||
go tool cover -func=$(COVERAGE_DIR)/coverage.out
|
||||
|
||||
# Code quality targets
|
||||
fmt: ## Format Go code
|
||||
@echo "Formatting code..."
|
||||
go fmt ./...
|
||||
|
||||
vet: ## Run go vet
|
||||
@echo "Running go vet..."
|
||||
go vet ./...
|
||||
|
||||
lint: fmt vet ## Run all linting tools
|
||||
|
||||
# Development targets
|
||||
clean: ## Clean build artifacts and test coverage
|
||||
@echo "Cleaning..."
|
||||
rm -f $(BINARY_NAME)
|
||||
rm -rf $(COVERAGE_DIR)
|
||||
go clean ./...
|
||||
|
||||
deps: ## Download and tidy dependencies
|
||||
@echo "Downloading dependencies..."
|
||||
go mod download
|
||||
go mod tidy
|
||||
|
||||
# Mock generation
|
||||
generate-mocks: ## Generate mocks for interfaces
|
||||
@echo "Generating mocks..."
|
||||
@command -v mockgen >/dev/null 2>&1 || { echo "Installing mockgen..."; go install go.uber.org/mock/mockgen@latest; }
|
||||
@mkdir -p test/mocks
|
||||
mockgen -source=internal/presentation/handlers/signature_handlers.go -destination=test/mocks/mock_signature_service.go -package=mocks SignatureService
|
||||
mockgen -source=internal/presentation/handlers/auth_handlers.go -destination=test/mocks/mock_auth_service.go -package=mocks AuthService
|
||||
mockgen -source=internal/domain/repositories/signature_repository.go -destination=test/mocks/mock_signature_repository.go -package=mocks SignatureRepository
|
||||
|
||||
# Docker targets
|
||||
docker-build: ## Build Docker image
|
||||
docker build -t ackify:latest .
|
||||
|
||||
docker-test: ## Run tests in Docker environment
|
||||
docker compose -f docker-compose.local.yml up -d postgres
|
||||
@sleep 5
|
||||
$(MAKE) test
|
||||
docker compose -f docker-compose.local.yml down
|
||||
|
||||
# CI targets
|
||||
ci: deps lint test coverage ## Run all CI checks
|
||||
|
||||
# Install dev tools
|
||||
dev-tools: ## Install development tools
|
||||
@echo "Installing development tools..."
|
||||
go install go.uber.org/mock/mockgen@latest
|
||||
@@ -0,0 +1,292 @@
|
||||
# 🔐 Ackify
|
||||
|
||||
> **Proof of Read. Compliance made simple.**
|
||||
|
||||
Service sécurisé de validation de lecture avec traçabilité cryptographique et preuves incontestables.
|
||||
|
||||
[](https://github.com/btouchard/ackify)
|
||||
[](https://en.wikipedia.org/wiki/EdDSA)
|
||||
[](https://golang.org/)
|
||||
[](LICENSE)
|
||||
|
||||
> 🌍 [English version available here](README_EN.md)
|
||||
|
||||
## 🎯 Pourquoi Ackify ?
|
||||
|
||||
**Problème** : Comment prouver qu'un collaborateur a bien lu et compris un document important ?
|
||||
|
||||
**Solution** : Signatures cryptographiques Ed25519 avec horodatage immutable et traçabilité complète.
|
||||
|
||||
### Cas d'usage concrets
|
||||
- ✅ Validation de politiques de sécurité
|
||||
- ✅ Attestations de formation obligatoire
|
||||
- ✅ Prise de connaissance RGPD
|
||||
- ✅ Accusés de réception contractuels
|
||||
- ✅ Procédures qualité et compliance
|
||||
|
||||
---
|
||||
|
||||
## ⚡ Démarrage Rapide
|
||||
|
||||
### Avec Docker (recommandé)
|
||||
```bash
|
||||
git clone https://github.com/btouchard/ackify.git
|
||||
cd ackify
|
||||
|
||||
# Configuration minimale
|
||||
cp .env.example .env
|
||||
# Éditez .env avec vos paramètres OAuth2
|
||||
|
||||
# Démarrage
|
||||
docker compose up -d
|
||||
|
||||
# Test
|
||||
curl http://localhost:8080/healthz
|
||||
```
|
||||
|
||||
### Variables obligatoires
|
||||
```bash
|
||||
APP_BASE_URL="https://votre-domaine.com"
|
||||
OAUTH_CLIENT_ID="your-oauth-client-id" # Google/GitHub/GitLab
|
||||
OAUTH_CLIENT_SECRET="your-oauth-client-secret"
|
||||
DB_DSN="postgres://user:password@localhost/ackify?sslmode=disable"
|
||||
OAUTH_COOKIE_SECRET="$(openssl rand -base64 32)"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Utilisation Simple
|
||||
|
||||
### 1. Demander une signature
|
||||
```
|
||||
https://votre-domaine.com/sign?doc=procedure_securite_2024
|
||||
```
|
||||
→ L'utilisateur s'authentifie via OAuth2 et valide sa lecture
|
||||
|
||||
### 2. Vérifier les signatures
|
||||
```bash
|
||||
# API JSON - Liste complète
|
||||
curl "https://votre-domaine.com/status?doc=procedure_securite_2024"
|
||||
|
||||
# Badge PNG - Statut individuel
|
||||
curl "https://votre-domaine.com/status.png?doc=procedure_securite_2024&user=jean.dupont@entreprise.com"
|
||||
```
|
||||
|
||||
### 3. Intégrer dans vos pages
|
||||
```html
|
||||
<!-- Widget intégrable -->
|
||||
<iframe src="https://votre-domaine.com/embed?doc=procedure_securite_2024"
|
||||
width="500" height="300"></iframe>
|
||||
|
||||
<!-- Via oEmbed -->
|
||||
<script>
|
||||
fetch('/oembed?url=https://votre-domaine.com/embed?doc=procedure_securite_2024')
|
||||
.then(r => r.json())
|
||||
.then(data => document.getElementById('signatures').innerHTML = data.html);
|
||||
</script>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔧 Configuration OAuth2
|
||||
|
||||
### Providers supportés
|
||||
|
||||
| Provider | Configuration |
|
||||
|----------|---------------|
|
||||
| **Google** | `OAUTH_PROVIDER=google` |
|
||||
| **GitHub** | `OAUTH_PROVIDER=github` |
|
||||
| **GitLab** | `OAUTH_PROVIDER=gitlab` + `OAUTH_GITLAB_URL` |
|
||||
| **Custom** | Endpoints personnalisés |
|
||||
|
||||
### Provider personnalisé
|
||||
```bash
|
||||
# Laissez OAUTH_PROVIDER vide
|
||||
OAUTH_AUTH_URL="https://auth.company.com/oauth/authorize"
|
||||
OAUTH_TOKEN_URL="https://auth.company.com/oauth/token"
|
||||
OAUTH_USERINFO_URL="https://auth.company.com/api/user"
|
||||
OAUTH_SCOPES="read:user,user:email"
|
||||
```
|
||||
|
||||
### Restriction par domaine
|
||||
```bash
|
||||
OAUTH_ALLOWED_DOMAIN="@entreprise.com" # Seuls les emails @entreprise.com
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🛡️ Sécurité & Architecture
|
||||
|
||||
### Sécurité cryptographique
|
||||
- **Ed25519** : Signatures numériques de pointe
|
||||
- **SHA-256** : Hachage des payloads contre le tampering
|
||||
- **Horodatage immutable** : Triggers PostgreSQL
|
||||
- **Sessions chiffrées** : Cookies sécurisés
|
||||
- **CSP headers** : Protection XSS
|
||||
|
||||
### Architecture Go
|
||||
```
|
||||
cmd/ackapp/ # Point d'entrée
|
||||
internal/
|
||||
domain/ # Logique métier
|
||||
models/ # Entités
|
||||
repositories/ # Interfaces persistance
|
||||
application/ # Use cases
|
||||
services/ # Implémentations métier
|
||||
infrastructure/ # Adaptateurs
|
||||
auth/ # OAuth2
|
||||
database/ # PostgreSQL
|
||||
config/ # Configuration
|
||||
presentation/ # HTTP
|
||||
handlers/ # Contrôleurs + interfaces
|
||||
templates/ # Vues HTML
|
||||
pkg/ # Utilitaires partagés
|
||||
```
|
||||
|
||||
### Stack technique
|
||||
- **Go 1.24.5** : Performance et simplicité
|
||||
- **PostgreSQL** : Contraintes d'intégrité
|
||||
- **OAuth2** : Multi-providers
|
||||
- **Docker** : Déploiement simplifié
|
||||
- **Traefik** : Reverse proxy HTTPS
|
||||
|
||||
---
|
||||
|
||||
## 📊 Base de Données
|
||||
|
||||
```sql
|
||||
CREATE TABLE signatures (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
doc_id TEXT NOT NULL, -- ID document
|
||||
user_sub TEXT NOT NULL, -- ID OAuth utilisateur
|
||||
user_email TEXT NOT NULL, -- Email utilisateur
|
||||
signed_at TIMESTAMPTZ NOT NULL, -- Timestamp signature
|
||||
payload_hash TEXT NOT NULL, -- Hash cryptographique
|
||||
signature TEXT NOT NULL, -- Signature Ed25519
|
||||
nonce TEXT NOT NULL, -- Anti-replay
|
||||
created_at TIMESTAMPTZ DEFAULT now(), -- Immutable
|
||||
referer TEXT, -- Source (optionnel)
|
||||
prev_hash TEXT,
|
||||
UNIQUE (doc_id, user_sub) -- Une signature par user/doc
|
||||
);
|
||||
```
|
||||
|
||||
**Garanties** :
|
||||
- ✅ **Unicité** : Un utilisateur = une signature par document
|
||||
- ✅ **Immutabilité** : `created_at` protégé par trigger
|
||||
- ✅ **Intégrité** : Hachage SHA-256 pour détecter modifications
|
||||
- ✅ **Non-répudiation** : Signature Ed25519 cryptographiquement prouvable
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Déploiement Production
|
||||
|
||||
### docker-compose.yml
|
||||
```yaml
|
||||
version: '3.8'
|
||||
services:
|
||||
ackapp:
|
||||
image: btouchard/ackify:latest
|
||||
environment:
|
||||
APP_BASE_URL: https://ackify.company.com
|
||||
DB_DSN: postgres://user:pass@postgres:5432/ackdb?sslmode=require
|
||||
OAUTH_CLIENT_ID: ${OAUTH_CLIENT_ID}
|
||||
OAUTH_CLIENT_SECRET: ${OAUTH_CLIENT_SECRET}
|
||||
OAUTH_COOKIE_SECRET: ${OAUTH_COOKIE_SECRET}
|
||||
labels:
|
||||
- "traefik.enable=true"
|
||||
- "traefik.http.routers.ackify.rule=Host(`ackify.company.com`)"
|
||||
- "traefik.http.routers.ackify.tls.certresolver=letsencrypt"
|
||||
|
||||
postgres:
|
||||
image: postgres:15-alpine
|
||||
environment:
|
||||
POSTGRES_DB: ackdb
|
||||
POSTGRES_USER: ackuser
|
||||
POSTGRES_PASSWORD: ${DB_PASSWORD}
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
```
|
||||
|
||||
### Variables production
|
||||
```bash
|
||||
# Sécurité renforcée
|
||||
OAUTH_COOKIE_SECRET="$(openssl rand -base64 64)" # AES-256
|
||||
ED25519_PRIVATE_KEY_B64="$(openssl genpkey -algorithm Ed25519 | base64 -w 0)"
|
||||
|
||||
# HTTPS obligatoire
|
||||
APP_BASE_URL="https://ackify.company.com"
|
||||
|
||||
# PostgreSQL sécurisé
|
||||
DB_DSN="postgres://user:pass@postgres:5432/ackdb?sslmode=require"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📋 API Complète
|
||||
|
||||
### Authentification
|
||||
- `GET /login?next=<url>` - Connexion OAuth2
|
||||
- `GET /logout` - Déconnexion
|
||||
- `GET /oauth2/callback` - Callback OAuth2
|
||||
|
||||
### Signatures
|
||||
- `GET /sign?doc=<id>` - Interface de signature
|
||||
- `POST /sign` - Créer signature
|
||||
- `GET /signatures` - Mes signatures (auth requis)
|
||||
|
||||
### Consultation
|
||||
- `GET /status?doc=<id>` - JSON toutes signatures
|
||||
- `GET /status.png?doc=<id>&user=<email>` - Badge PNG
|
||||
|
||||
### Intégration
|
||||
- `GET /oembed?url=<embed_url>` - Métadonnées oEmbed
|
||||
- `GET /embed?doc=<id>` - Widget HTML
|
||||
|
||||
### Supervision
|
||||
- `GET /healthz` - Health check
|
||||
|
||||
---
|
||||
|
||||
## 🔍 Développement & Tests
|
||||
|
||||
### Build local
|
||||
```bash
|
||||
# Dépendances
|
||||
go mod tidy
|
||||
|
||||
# Build
|
||||
go build ./cmd/ackify
|
||||
|
||||
# Linting
|
||||
go fmt ./...
|
||||
go vet ./...
|
||||
|
||||
# Tests (TODO: ajouter des tests)
|
||||
go test -v ./...
|
||||
```
|
||||
|
||||
### Docker development
|
||||
```bash
|
||||
# Build image
|
||||
docker build -t ackify:dev .
|
||||
|
||||
# Run avec base locale
|
||||
docker run -p 8080:8080 --env-file .env ackify:dev
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🤝 Support
|
||||
|
||||
### Aide & Documentation
|
||||
- 🐛 **Issues** : [GitHub Issues](https://github.com/btouchard/ackify/issues)
|
||||
- 💬 **Discussions** : [GitHub Discussions](https://github.com/btouchard/ackify/discussions)
|
||||
|
||||
### Licence SSPL
|
||||
Usage libre pour projets internes. Restriction pour services commerciaux concurrents.
|
||||
Voir [LICENSE](LICENSE) pour détails complets.
|
||||
|
||||
---
|
||||
|
||||
**Développé avec ❤️ par [Benjamin TOUCHARD](mailto:benjamin@kolapsis.com)**
|
||||
+292
@@ -0,0 +1,292 @@
|
||||
# 🔐 Ackify
|
||||
|
||||
> **Proof of Read. Compliance made simple.**
|
||||
|
||||
Secure document reading validation service with cryptographic traceability and irrefutable proof.
|
||||
|
||||
[](https://github.com/btouchard/ackify)
|
||||
[](https://en.wikipedia.org/wiki/EdDSA)
|
||||
[](https://golang.org/)
|
||||
[](LICENSE)
|
||||
|
||||
> 🇫🇷 [Version française disponible ici](README.md)
|
||||
|
||||
## 🎯 Why Ackify?
|
||||
|
||||
**Problem**: How to prove that a collaborator has actually read and understood an important document?
|
||||
|
||||
**Solution**: Ed25519 cryptographic signatures with immutable timestamps and complete traceability.
|
||||
|
||||
### Real-world use cases
|
||||
- ✅ Security policy validation
|
||||
- ✅ Mandatory training attestations
|
||||
- ✅ GDPR acknowledgment
|
||||
- ✅ Contractual acknowledgments
|
||||
- ✅ Quality and compliance procedures
|
||||
|
||||
---
|
||||
|
||||
## ⚡ Quick Start
|
||||
|
||||
### With Docker (recommended)
|
||||
```bash
|
||||
git clone https://github.com/btouchard/ackify.git
|
||||
cd ackify
|
||||
|
||||
# Minimal configuration
|
||||
cp .env.example .env
|
||||
# Edit .env with your OAuth2 settings
|
||||
|
||||
# Start
|
||||
docker compose up -d
|
||||
|
||||
# Test
|
||||
curl http://localhost:8080/healthz
|
||||
```
|
||||
|
||||
### Required variables
|
||||
```bash
|
||||
APP_BASE_URL="https://your-domain.com"
|
||||
OAUTH_CLIENT_ID="your-oauth-client-id" # Google/GitHub/GitLab
|
||||
OAUTH_CLIENT_SECRET="your-oauth-client-secret"
|
||||
DB_DSN="postgres://user:password@localhost/ackify?sslmode=disable"
|
||||
OAUTH_COOKIE_SECRET="$(openssl rand -base64 32)"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Simple Usage
|
||||
|
||||
### 1. Request a signature
|
||||
```
|
||||
https://your-domain.com/sign?doc=security_procedure_2024
|
||||
```
|
||||
→ User authenticates via OAuth2 and validates their reading
|
||||
|
||||
### 2. Verify signatures
|
||||
```bash
|
||||
# JSON API - Complete list
|
||||
curl "https://your-domain.com/status?doc=security_procedure_2024"
|
||||
|
||||
# PNG Badge - Individual status
|
||||
curl "https://your-domain.com/status.png?doc=security_procedure_2024&user=john.doe@company.com"
|
||||
```
|
||||
|
||||
### 3. Integrate into your pages
|
||||
```html
|
||||
<!-- Embeddable widget -->
|
||||
<iframe src="https://your-domain.com/embed?doc=security_procedure_2024"
|
||||
width="500" height="300"></iframe>
|
||||
|
||||
<!-- Via oEmbed -->
|
||||
<script>
|
||||
fetch('/oembed?url=https://your-domain.com/embed?doc=security_procedure_2024')
|
||||
.then(r => r.json())
|
||||
.then(data => document.getElementById('signatures').innerHTML = data.html);
|
||||
</script>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔧 OAuth2 Configuration
|
||||
|
||||
### Supported providers
|
||||
|
||||
| Provider | Configuration |
|
||||
|----------|---------------|
|
||||
| **Google** | `OAUTH_PROVIDER=google` |
|
||||
| **GitHub** | `OAUTH_PROVIDER=github` |
|
||||
| **GitLab** | `OAUTH_PROVIDER=gitlab` + `OAUTH_GITLAB_URL` |
|
||||
| **Custom** | Custom endpoints |
|
||||
|
||||
### Custom provider
|
||||
```bash
|
||||
# Leave OAUTH_PROVIDER empty
|
||||
OAUTH_AUTH_URL="https://auth.company.com/oauth/authorize"
|
||||
OAUTH_TOKEN_URL="https://auth.company.com/oauth/token"
|
||||
OAUTH_USERINFO_URL="https://auth.company.com/api/user"
|
||||
OAUTH_SCOPES="read:user,user:email"
|
||||
```
|
||||
|
||||
### Domain restriction
|
||||
```bash
|
||||
OAUTH_ALLOWED_DOMAIN="@company.com" # Only @company.com emails
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🛡️ Security & Architecture
|
||||
|
||||
### Cryptographic security
|
||||
- **Ed25519**: State-of-the-art digital signatures
|
||||
- **SHA-256**: Payload hashing against tampering
|
||||
- **Immutable timestamps**: PostgreSQL triggers
|
||||
- **Encrypted sessions**: Secure cookies
|
||||
- **CSP headers**: XSS protection
|
||||
|
||||
### Go architecture
|
||||
```
|
||||
cmd/ackapp/ # Entry point
|
||||
internal/
|
||||
domain/ # Business logic
|
||||
models/ # Entities
|
||||
repositories/ # Persistence interfaces
|
||||
application/ # Use cases
|
||||
services/ # Business implementations
|
||||
infrastructure/ # Adapters
|
||||
auth/ # OAuth2
|
||||
database/ # PostgreSQL
|
||||
config/ # Configuration
|
||||
presentation/ # HTTP
|
||||
handlers/ # Controllers + interfaces
|
||||
templates/ # HTML views
|
||||
pkg/ # Shared utilities
|
||||
```
|
||||
|
||||
### Technology stack
|
||||
- **Go 1.24.5**: Performance and simplicity
|
||||
- **PostgreSQL**: Integrity constraints
|
||||
- **OAuth2**: Multi-provider
|
||||
- **Docker**: Simplified deployment
|
||||
- **Traefik**: HTTPS reverse proxy
|
||||
|
||||
---
|
||||
|
||||
## 📊 Database
|
||||
|
||||
```sql
|
||||
CREATE TABLE signatures (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
doc_id TEXT NOT NULL, -- Document ID
|
||||
user_sub TEXT NOT NULL, -- OAuth user ID
|
||||
user_email TEXT NOT NULL, -- User email
|
||||
signed_at TIMESTAMPTZ NOT NULL, -- Signature timestamp
|
||||
payload_hash TEXT NOT NULL, -- Cryptographic hash
|
||||
signature TEXT NOT NULL, -- Ed25519 signature
|
||||
nonce TEXT NOT NULL, -- Anti-replay
|
||||
created_at TIMESTAMPTZ DEFAULT now(), -- Immutable
|
||||
referer TEXT, -- Source (optional)
|
||||
prev_hash TEXT,
|
||||
UNIQUE (doc_id, user_sub) -- One signature per user/doc
|
||||
);
|
||||
```
|
||||
|
||||
**Guarantees**:
|
||||
- ✅ **Uniqueness**: One user = one signature per document
|
||||
- ✅ **Immutability**: `created_at` protected by trigger
|
||||
- ✅ **Integrity**: SHA-256 hash to detect modifications
|
||||
- ✅ **Non-repudiation**: Ed25519 signature cryptographically provable
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Production Deployment
|
||||
|
||||
### docker-compose.yml
|
||||
```yaml
|
||||
version: '3.8'
|
||||
services:
|
||||
ackapp:
|
||||
image: btouchard/ackify:latest
|
||||
environment:
|
||||
APP_BASE_URL: https://ackify.company.com
|
||||
DB_DSN: postgres://user:pass@postgres:5432/ackdb?sslmode=require
|
||||
OAUTH_CLIENT_ID: ${OAUTH_CLIENT_ID}
|
||||
OAUTH_CLIENT_SECRET: ${OAUTH_CLIENT_SECRET}
|
||||
OAUTH_COOKIE_SECRET: ${OAUTH_COOKIE_SECRET}
|
||||
labels:
|
||||
- "traefik.enable=true"
|
||||
- "traefik.http.routers.ackify.rule=Host(`ackify.company.com`)"
|
||||
- "traefik.http.routers.ackify.tls.certresolver=letsencrypt"
|
||||
|
||||
postgres:
|
||||
image: postgres:15-alpine
|
||||
environment:
|
||||
POSTGRES_DB: ackdb
|
||||
POSTGRES_USER: ackuser
|
||||
POSTGRES_PASSWORD: ${DB_PASSWORD}
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
```
|
||||
|
||||
### Production variables
|
||||
```bash
|
||||
# Enhanced security
|
||||
OAUTH_COOKIE_SECRET="$(openssl rand -base64 64)" # AES-256
|
||||
ED25519_PRIVATE_KEY_B64="$(openssl genpkey -algorithm Ed25519 | base64 -w 0)"
|
||||
|
||||
# HTTPS mandatory
|
||||
APP_BASE_URL="https://ackify.company.com"
|
||||
|
||||
# Secure PostgreSQL
|
||||
DB_DSN="postgres://user:pass@postgres:5432/ackdb?sslmode=require"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📋 Complete API
|
||||
|
||||
### Authentication
|
||||
- `GET /login?next=<url>` - OAuth2 login
|
||||
- `GET /logout` - Logout
|
||||
- `GET /oauth2/callback` - OAuth2 callback
|
||||
|
||||
### Signatures
|
||||
- `GET /sign?doc=<id>` - Signature interface
|
||||
- `POST /sign` - Create signature
|
||||
- `GET /signatures` - My signatures (auth required)
|
||||
|
||||
### Consultation
|
||||
- `GET /status?doc=<id>` - JSON all signatures
|
||||
- `GET /status.png?doc=<id>&user=<email>` - PNG badge
|
||||
|
||||
### Integration
|
||||
- `GET /oembed?url=<embed_url>` - oEmbed metadata
|
||||
- `GET /embed?doc=<id>` - HTML widget
|
||||
|
||||
### Monitoring
|
||||
- `GET /healthz` - Health check
|
||||
|
||||
---
|
||||
|
||||
## 🔍 Development & Testing
|
||||
|
||||
### Local build
|
||||
```bash
|
||||
# Dependencies
|
||||
go mod tidy
|
||||
|
||||
# Build
|
||||
go build ./cmd/ackify
|
||||
|
||||
# Linting
|
||||
go fmt ./...
|
||||
go vet ./...
|
||||
|
||||
# Tests (TODO: add tests)
|
||||
go test -v ./...
|
||||
```
|
||||
|
||||
### Docker development
|
||||
```bash
|
||||
# Build image
|
||||
docker build -t ackify:dev .
|
||||
|
||||
# Run with local database
|
||||
docker run -p 8080:8080 --env-file .env ackify:dev
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🤝 Support
|
||||
|
||||
### Help & Documentation
|
||||
- 🐛 **Issues**: [GitHub Issues](https://github.com/btouchard/ackify/issues)
|
||||
- 💬 **Discussions**: [GitHub Discussions](https://github.com/btouchard/ackify/discussions)
|
||||
|
||||
### SSPL License
|
||||
Free usage for internal projects. Restriction for competing commercial services.
|
||||
See [LICENSE](LICENSE) for complete details.
|
||||
|
||||
---
|
||||
|
||||
**Developed with ❤️ by [Benjamin TOUCHARD](mailto:benjamin@kolapsis.com)**
|
||||
@@ -0,0 +1,159 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"ackify/internal/application/services"
|
||||
"ackify/internal/infrastructure/auth"
|
||||
"ackify/internal/infrastructure/config"
|
||||
"ackify/internal/infrastructure/database"
|
||||
"ackify/internal/presentation/handlers"
|
||||
"ackify/internal/presentation/templates"
|
||||
"ackify/pkg/crypto"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Initialize dependencies
|
||||
cfg, db, tmpl, signer, err := initInfrastructure(ctx)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to initialize infrastructure: %v", err)
|
||||
}
|
||||
defer func(db *sql.DB) {
|
||||
_ = db.Close()
|
||||
}(db)
|
||||
|
||||
// Initialize services
|
||||
authService := auth.NewOAuthService(auth.Config{
|
||||
BaseURL: cfg.App.BaseURL,
|
||||
ClientID: cfg.OAuth.ClientID,
|
||||
ClientSecret: cfg.OAuth.ClientSecret,
|
||||
AuthURL: cfg.OAuth.AuthURL,
|
||||
TokenURL: cfg.OAuth.TokenURL,
|
||||
UserInfoURL: cfg.OAuth.UserInfoURL,
|
||||
Scopes: cfg.OAuth.Scopes,
|
||||
AllowedDomain: cfg.OAuth.AllowedDomain,
|
||||
CookieSecret: cfg.OAuth.CookieSecret,
|
||||
SecureCookies: cfg.App.SecureCookies,
|
||||
})
|
||||
|
||||
// Initialize signatures
|
||||
signatureRepo := database.NewSignatureRepository(db)
|
||||
signatureService := services.NewSignatureService(signatureRepo, signer)
|
||||
|
||||
// Initialize handlers
|
||||
authHandlers := handlers.NewAuthHandlers(authService, cfg.App.BaseURL)
|
||||
authMiddleware := handlers.NewAuthMiddleware(authService, cfg.App.BaseURL)
|
||||
signatureHandlers := handlers.NewSignatureHandlers(signatureService, authService, tmpl, cfg.App.BaseURL)
|
||||
badgeHandler := handlers.NewBadgeHandler(signatureService)
|
||||
oembedHandler := handlers.NewOEmbedHandler(signatureService, tmpl, cfg.App.BaseURL, cfg.App.Organisation)
|
||||
healthHandler := handlers.NewHealthHandler()
|
||||
|
||||
// Setup HTTP router
|
||||
router := setupRouter(authHandlers, authMiddleware, signatureHandlers, badgeHandler, oembedHandler, healthHandler)
|
||||
|
||||
// Create HTTP server
|
||||
server := &http.Server{
|
||||
Addr: cfg.Server.ListenAddr,
|
||||
Handler: handlers.SecureHeaders(router),
|
||||
}
|
||||
|
||||
// Start server in a goroutine
|
||||
go func() {
|
||||
log.Printf("Server starting on %s", cfg.Server.ListenAddr)
|
||||
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("Server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for interrupt signal for graceful shutdown
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
|
||||
log.Println("Shutting down server...")
|
||||
|
||||
// Graceful shutdown with timeout
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
log.Printf("Server forced to shutdown: %v", err)
|
||||
}
|
||||
|
||||
log.Println("Server exited")
|
||||
}
|
||||
|
||||
// initInfrastructure initializes the basic infrastructure components
|
||||
func initInfrastructure(ctx context.Context) (*config.Config, *sql.DB, *template.Template, *crypto.Ed25519Signer, error) {
|
||||
// Load configuration
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
// Initialize database
|
||||
db, err := database.InitDB(ctx, database.Config{
|
||||
DSN: cfg.Database.DSN,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to initialize database: %w", err)
|
||||
}
|
||||
|
||||
// Initialize templates
|
||||
tmpl, err := templates.InitTemplates()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to initialize templates: %w", err)
|
||||
}
|
||||
|
||||
// Initialize cryptographic signer
|
||||
signer, err := crypto.NewEd25519Signer()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to initialize signer: %w", err)
|
||||
}
|
||||
|
||||
return cfg, db, tmpl, signer, nil
|
||||
}
|
||||
|
||||
// setupRouter configures all HTTP routes
|
||||
func setupRouter(
|
||||
authHandlers *handlers.AuthHandlers,
|
||||
authMiddleware *handlers.AuthMiddleware,
|
||||
signatureHandlers *handlers.SignatureHandlers,
|
||||
badgeHandler *handlers.BadgeHandler,
|
||||
oembedHandler *handlers.OEmbedHandler,
|
||||
healthHandler *handlers.HealthHandler,
|
||||
) *httprouter.Router {
|
||||
router := httprouter.New()
|
||||
|
||||
// Public routes
|
||||
router.GET("/", signatureHandlers.HandleIndex)
|
||||
router.GET("/login", authHandlers.HandleLogin)
|
||||
router.GET("/logout", authHandlers.HandleLogout)
|
||||
router.GET("/oauth2/callback", authHandlers.HandleOAuthCallback)
|
||||
router.GET("/status", signatureHandlers.HandleStatusJSON)
|
||||
router.GET("/status.png", badgeHandler.HandleStatusPNG)
|
||||
router.GET("/oembed", oembedHandler.HandleOEmbed)
|
||||
router.GET("/embed", oembedHandler.HandleEmbedView)
|
||||
router.GET("/health", healthHandler.HandleHealth)
|
||||
|
||||
// Protected routes (require authentication)
|
||||
router.GET("/sign", authMiddleware.RequireAuth(signatureHandlers.HandleSignGET))
|
||||
router.POST("/sign", authMiddleware.RequireAuth(signatureHandlers.HandleSignPOST))
|
||||
router.GET("/signatures", authMiddleware.RequireAuth(signatureHandlers.HandleUserSignatures))
|
||||
|
||||
return router
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
name: ackify
|
||||
|
||||
services:
|
||||
ackify:
|
||||
image: btouchard/ackify
|
||||
container_name: ackify
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
APP_BASE_URL: "https://${APP_DNS}"
|
||||
APP_ORGANISATION: "${APP_ORGANISATION}"
|
||||
OAUTH_PROVIDER: "${OAUTH_PROVIDER}"
|
||||
OAUTH_CLIENT_ID: "${OAUTH_CLIENT_ID}"
|
||||
OAUTH_CLIENT_SECRET: "${OAUTH_CLIENT_SECRET}"
|
||||
OAUTH_ALLOWED_DOMAIN: "${OAUTH_ALLOWED_DOMAIN}"
|
||||
OAUTH_COOKIE_SECRET: "${OAUTH_COOKIE_SECRET}"
|
||||
DB_DSN: "postgres://${POSTGRES_USER}:${POSTGRES_PASSWORD}@ackify_db:5432/${POSTGRES_DB}?sslmode=disable"
|
||||
ED25519_PRIVATE_KEY_B64: "${ED25519_PRIVATE_KEY_B64}"
|
||||
LISTEN_ADDR: ":8080"
|
||||
depends_on:
|
||||
ackify_db:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- web
|
||||
- internal
|
||||
labels:
|
||||
- "traefik.enable=true"
|
||||
- "traefik.http.routers.${APP_NAME}.rule=Host(`${APP_DNS}`)"
|
||||
- "traefik.http.routers.${APP_NAME}.entrypoints=websecure"
|
||||
- "traefik.http.routers.${APP_NAME}.tls.certresolver=letsencrypt"
|
||||
- "traefik.http.services.${APP_NAME}.loadbalancer.server.port=8080"
|
||||
|
||||
ackify_db:
|
||||
image: postgres:16-alpine
|
||||
container_name: ackify_db
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
POSTGRES_USER: ${POSTGRES_USER}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
|
||||
POSTGRES_DB: ${POSTGRES_DB}
|
||||
volumes:
|
||||
- ackify_data:/var/lib/postgresql/data
|
||||
networks:
|
||||
- internal
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
networks:
|
||||
internal:
|
||||
web:
|
||||
external: true
|
||||
|
||||
volumes:
|
||||
ackify_data:
|
||||
@@ -0,0 +1,21 @@
|
||||
module ackify
|
||||
|
||||
go 1.24.5
|
||||
|
||||
require (
|
||||
github.com/gorilla/securecookie v1.1.2
|
||||
github.com/gorilla/sessions v1.4.0
|
||||
github.com/julienschmidt/httprouter v1.3.0
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/stretchr/testify v1.11.1
|
||||
golang.org/x/oauth2 v0.31.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.13.1 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
@@ -0,0 +1,35 @@
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
|
||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
|
||||
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
|
||||
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
|
||||
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
golang.org/x/oauth2 v0.31.0 h1:8Fq0yVZLh4j4YA47vHKFTa9Ew5XIrCP8LC6UeNZnLxo=
|
||||
golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
@@ -0,0 +1,294 @@
|
||||
# Intégration Google Docs/Sheets avec Ackify
|
||||
|
||||
Guide d'intégration d'Ackify pour valider la lecture de documents Google via Apps Script.
|
||||
|
||||
## 🎯 Principe
|
||||
|
||||
Permettre aux utilisateurs de signer/valider qu'ils ont lu un document Google Docs ou Google Sheets directement depuis le document, avec validation cryptographique via Ackify.
|
||||
|
||||
## 📋 Prérequis
|
||||
|
||||
1. **Document Google Docs/Sheets** publié ou partagé
|
||||
2. **Instance Ackify** déployée et accessible (ex: `https://ackify.votre-domaine.com`)
|
||||
3. **Accès Google Apps Script** (compte Google)
|
||||
|
||||
## 🚀 Configuration rapide
|
||||
|
||||
### Étape 1 : Obtenir l'ID du document
|
||||
|
||||
L'ID se trouve dans l'URL de votre document :
|
||||
```
|
||||
https://docs.google.com/document/d/[DOCUMENT_ID]/edit
|
||||
https://docs.google.com/spreadsheets/d/[DOCUMENT_ID]/edit
|
||||
```
|
||||
|
||||
### Étape 2 : Créer le script Apps Script
|
||||
|
||||
1. Ouvrir le document Google
|
||||
2. **Extensions** → **Apps Script**
|
||||
3. Remplacer le code par défaut par :
|
||||
|
||||
```javascript
|
||||
// Configuration Ackify
|
||||
const ACKIFY_BASE_URL = 'https://ackify.votre-domaine.com';
|
||||
const DOCUMENT_ID = 'votre-document-id'; // À remplacer
|
||||
|
||||
/**
|
||||
* Ajoute un menu Ackify dans Google Docs/Sheets
|
||||
*/
|
||||
function onOpen() {
|
||||
const ui = DocumentApp.getUi(); // Pour Docs
|
||||
// const ui = SpreadsheetApp.getUi(); // Pour Sheets - décommenter si nécessaire
|
||||
|
||||
ui.createMenu('📝 Ackify')
|
||||
.addItem('✅ Valider ma lecture', 'validateReading')
|
||||
.addItem('📊 Voir les validations', 'viewSignatures')
|
||||
.addSeparator()
|
||||
.addItem('🔗 Intégrer widget', 'showEmbedCode')
|
||||
.addToUi();
|
||||
}
|
||||
|
||||
/**
|
||||
* Redirige vers la page de validation Ackify
|
||||
*/
|
||||
function validateReading() {
|
||||
const signUrl = `${ACKIFY_BASE_URL}/sign?doc=${DOCUMENT_ID}&referrer=${encodeURIComponent(getDocumentUrl())}`;
|
||||
|
||||
const html = `
|
||||
<div style="padding: 20px; font-family: Arial, sans-serif;">
|
||||
<h2>🔒 Validation de lecture</h2>
|
||||
<p>Cliquez pour valider que vous avez lu ce document :</p>
|
||||
<p><a href="${signUrl}" target="_blank" style="
|
||||
display: inline-block;
|
||||
background: #4285f4;
|
||||
color: white;
|
||||
padding: 12px 24px;
|
||||
text-decoration: none;
|
||||
border-radius: 6px;
|
||||
font-weight: bold;
|
||||
">✅ Valider ma lecture</a></p>
|
||||
<p><small>Une signature cryptographique sera générée pour prouver votre lecture.</small></p>
|
||||
</div>
|
||||
`;
|
||||
|
||||
const htmlOutput = HtmlService.createHtmlOutput(html)
|
||||
.setWidth(400)
|
||||
.setHeight(200);
|
||||
|
||||
const ui = DocumentApp.getUi(); // Pour Docs
|
||||
// const ui = SpreadsheetApp.getUi(); // Pour Sheets
|
||||
|
||||
ui.showModalDialog(htmlOutput, 'Validation Ackify');
|
||||
}
|
||||
|
||||
/**
|
||||
* Affiche les validations existantes
|
||||
*/
|
||||
function viewSignatures() {
|
||||
const statusUrl = `${ACKIFY_BASE_URL}/status?doc=${DOCUMENT_ID}`;
|
||||
|
||||
try {
|
||||
const response = UrlFetchApp.fetch(statusUrl);
|
||||
const signatures = JSON.parse(response.getContentText());
|
||||
|
||||
let html = `
|
||||
<div style="padding: 20px; font-family: Arial, sans-serif;">
|
||||
<h2>📊 Validations de lecture</h2>
|
||||
`;
|
||||
|
||||
if (signatures.length === 0) {
|
||||
html += '<p><em>Aucune validation pour ce document.</em></p>';
|
||||
} else {
|
||||
html += `<p><strong>${signatures.length}</strong> validation(s) :</p><ul>`;
|
||||
|
||||
signatures.forEach(sig => {
|
||||
const date = new Date(sig.signed_at_utc).toLocaleDateString('fr-FR', {
|
||||
year: 'numeric',
|
||||
month: 'short',
|
||||
day: 'numeric',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit'
|
||||
});
|
||||
|
||||
const name = sig.user_name || sig.user_email;
|
||||
html += `<li><strong>${name}</strong> - ${date}</li>`;
|
||||
});
|
||||
|
||||
html += '</ul>';
|
||||
}
|
||||
|
||||
html += `
|
||||
<p><a href="${statusUrl}" target="_blank">🔗 Voir les détails</a></p>
|
||||
</div>
|
||||
`;
|
||||
|
||||
const htmlOutput = HtmlService.createHtmlOutput(html)
|
||||
.setWidth(500)
|
||||
.setHeight(400);
|
||||
|
||||
const ui = DocumentApp.getUi(); // Pour Docs
|
||||
// const ui = SpreadsheetApp.getUi(); // Pour Sheets
|
||||
|
||||
ui.showModalDialog(htmlOutput, 'Validations Ackify');
|
||||
|
||||
} catch (error) {
|
||||
const ui = DocumentApp.getUi();
|
||||
ui.alert('Erreur', `Impossible de récupérer les validations : ${error.message}`, ui.ButtonSet.OK);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Affiche le code d'intégration HTML
|
||||
*/
|
||||
function showEmbedCode() {
|
||||
const embedCode = `<!-- Widget Ackify -->
|
||||
<iframe src="${ACKIFY_BASE_URL}/embed?doc=${DOCUMENT_ID}&referrer=${encodeURIComponent(getDocumentUrl())}"
|
||||
width="100%" height="200" frameborder="0"
|
||||
style="border: 1px solid #ddd; border-radius: 6px;">
|
||||
</iframe>`;
|
||||
|
||||
const html = `
|
||||
<div style="padding: 20px; font-family: Arial, sans-serif;">
|
||||
<h2>🔗 Code d'intégration</h2>
|
||||
<p>Copiez ce code HTML pour intégrer le widget Ackify :</p>
|
||||
<textarea readonly style="width: 100%; height: 100px; font-family: monospace; font-size: 12px;">${embedCode}</textarea>
|
||||
<p><small>À intégrer dans une page web, wiki, ou plateforme supportant l'HTML.</small></p>
|
||||
</div>
|
||||
`;
|
||||
|
||||
const htmlOutput = HtmlService.createHtmlOutput(html)
|
||||
.setWidth(600)
|
||||
.setHeight(300);
|
||||
|
||||
const ui = DocumentApp.getUi(); // Pour Docs
|
||||
// const ui = SpreadsheetApp.getUi(); // Pour Sheets
|
||||
|
||||
ui.showModalDialog(htmlOutput, 'Code d\'intégration');
|
||||
}
|
||||
|
||||
/**
|
||||
* Récupère l'URL du document actuel
|
||||
*/
|
||||
function getDocumentUrl() {
|
||||
try {
|
||||
// Pour Google Docs
|
||||
return DocumentApp.getActiveDocument().getUrl();
|
||||
} catch (e) {
|
||||
try {
|
||||
// Pour Google Sheets
|
||||
return SpreadsheetApp.getActiveSpreadsheet().getUrl();
|
||||
} catch (e2) {
|
||||
return `https://docs.google.com/document/d/${DOCUMENT_ID}/edit`;
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Étape 3 : Configuration du script
|
||||
|
||||
1. **Remplacer les variables** :
|
||||
```javascript
|
||||
const ACKIFY_BASE_URL = 'https://votre-ackify.com';
|
||||
const DOCUMENT_ID = 'votre-id-document-google';
|
||||
```
|
||||
|
||||
2. **Pour Google Sheets** : Décommenter les lignes `SpreadsheetApp` et commenter celles de `DocumentApp`
|
||||
|
||||
3. **Sauvegarder** le script (Ctrl+S)
|
||||
|
||||
### Étape 4 : Autoriser les permissions
|
||||
|
||||
1. Cliquer sur **▶️ Exécuter** → `onOpen`
|
||||
2. **Autoriser** l'accès aux APIs Google (première fois)
|
||||
3. Recharger le document Google
|
||||
|
||||
## ✅ Utilisation
|
||||
|
||||
### Menu Ackify
|
||||
|
||||
Un nouveau menu **📝 Ackify** apparaît dans votre document avec :
|
||||
|
||||
- **✅ Valider ma lecture** : Redirige vers Ackify pour signer
|
||||
- **📊 Voir les validations** : Liste des signatures existantes
|
||||
- **🔗 Intégrer widget** : Code HTML pour intégration externe
|
||||
|
||||
### Processus de validation
|
||||
|
||||
1. **Utilisateur** clique sur "Valider ma lecture"
|
||||
2. **Redirection** vers Ackify avec authentification OAuth2
|
||||
3. **Signature cryptographique** générée (Ed25519)
|
||||
4. **Retour** au document avec confirmation
|
||||
|
||||
## 🔧 Personnalisation avancée
|
||||
|
||||
### Notifications automatiques
|
||||
|
||||
Ajouter une fonction de notification lors de nouvelles signatures :
|
||||
|
||||
```javascript
|
||||
/**
|
||||
* Vérifie périodiquement les nouvelles validations
|
||||
*/
|
||||
function checkNewSignatures() {
|
||||
// Logique de vérification et notification
|
||||
// (à implémenter selon vos besoins)
|
||||
}
|
||||
|
||||
/**
|
||||
* Déclenche des vérifications périodiques
|
||||
*/
|
||||
function createTrigger() {
|
||||
ScriptApp.newTrigger('checkNewSignatures')
|
||||
.timeBased()
|
||||
.everyHours(1)
|
||||
.create();
|
||||
}
|
||||
```
|
||||
|
||||
### Badge dans le document
|
||||
|
||||
Intégrer un badge directement dans le document :
|
||||
|
||||
```javascript
|
||||
/**
|
||||
* Insère un badge Ackify dans le document
|
||||
*/
|
||||
function insertBadge() {
|
||||
const doc = DocumentApp.getActiveDocument();
|
||||
const body = doc.getBody();
|
||||
|
||||
const badgeUrl = `${ACKIFY_BASE_URL}/status.png?doc=${DOCUMENT_ID}`;
|
||||
const signUrl = `${ACKIFY_BASE_URL}/sign?doc=${DOCUMENT_ID}`;
|
||||
|
||||
// Insérer image avec lien
|
||||
const paragraph = body.appendParagraph('');
|
||||
const image = paragraph.appendInlineImage(UrlFetchApp.fetch(badgeUrl).getBlob());
|
||||
image.setLinkUrl(signUrl);
|
||||
}
|
||||
```
|
||||
|
||||
## 🛡️ Sécurité
|
||||
|
||||
- **Authentification** : OAuth2 requis pour signer
|
||||
- **Non-répudiation** : Signatures Ed25519 cryptographiquement vérifiables
|
||||
- **Traçabilité** : Horodatage UTC + hash SHA-256
|
||||
- **Intégrité** : Chaînage cryptographique des signatures
|
||||
|
||||
## 🌐 Intégration multi-plateforme
|
||||
|
||||
Le même principe s'applique à d'autres plateformes :
|
||||
|
||||
- **Notion** : Via API et webhooks
|
||||
- **Confluence** : Apps Script ou macros
|
||||
- **SharePoint** : Power Automate + Custom Connector
|
||||
- **Wiki** : Widget HTML intégré
|
||||
|
||||
## 📞 Support
|
||||
|
||||
- **Documentation** : [Ackify Docs](https://docs.ackify.app)
|
||||
- **API** : `GET /status?doc=<id>` et `POST /sign`
|
||||
- **oEmbed** : `GET /oembed?url=<doc_url>`
|
||||
|
||||
---
|
||||
|
||||
**Architecture validée selon CLAUDE.md - Clean Architecture Go 2025** ✨
|
||||
@@ -0,0 +1,97 @@
|
||||
function onOpen() {
|
||||
DocumentApp.getUi()
|
||||
.createMenu("Signatures")
|
||||
// .addItem("Confirmer la lecture de ce document", "openSignature")
|
||||
.addItem("Afficher la barre latérale", "showSignatures")
|
||||
.addToUi();
|
||||
}
|
||||
|
||||
function getSidebarHtml() {
|
||||
var doc = DocumentApp.getActiveDocument();
|
||||
var docId = doc.getId();
|
||||
var url = "https://sign.neodtx.com/embed?doc=" + encodeURIComponent(docId);
|
||||
|
||||
var response = UrlFetchApp.fetch(url, {muteHttpExceptions: true});
|
||||
return response.getContentText();
|
||||
}
|
||||
|
||||
function openSignature() {
|
||||
var doc = DocumentApp.getActiveDocument();
|
||||
var docId = doc.getId();
|
||||
var url = "https://sign.neodtx.com/sign?doc=" + encodeURIComponent(docId);
|
||||
|
||||
var html = '<style>html,body{height:100%;margin:0;padding:0}</style>' +
|
||||
'<iframe src="' + url + '" ' +
|
||||
'style="border:0;width:100%;height:100vh;" ' +
|
||||
'sandbox="allow-scripts allow-popups allow-same-origin allow-forms"></iframe>';
|
||||
|
||||
var output = HtmlService.createHtmlOutput(html);
|
||||
DocumentApp.getUi().showModalDialog(output, 'Confirmer la lecture du document');
|
||||
}
|
||||
|
||||
function showSignatures() {
|
||||
var doc = DocumentApp.getActiveDocument();
|
||||
var docId = doc.getId();
|
||||
var url = "https://sign.neodtx.com/embed?doc=" + encodeURIComponent(docId);
|
||||
|
||||
var response = UrlFetchApp.fetch(url, {muteHttpExceptions: true});
|
||||
var html = response.getContentText();
|
||||
|
||||
var modifiedHtml = html + `
|
||||
<script>
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
// Fonction pour rafraîchir le sidebar
|
||||
function refreshSidebar() {
|
||||
google.script.run.withSuccessHandler(function(newHtml){
|
||||
document.body.innerHTML = newHtml;
|
||||
// Réinjecte les listeners après rafraîchissement
|
||||
addLinkListeners();
|
||||
}).getSidebarHtml();
|
||||
}
|
||||
|
||||
// Ajoute les listeners sur tous les liens
|
||||
function addLinkListeners() {
|
||||
document.querySelectorAll('a[href]').forEach(function(link){
|
||||
link.addEventListener('click', function(e){
|
||||
e.preventDefault();
|
||||
// Ajoute un listener focus sur window
|
||||
function onFocus() {
|
||||
window.removeEventListener('focus', onFocus); // on supprime après déclenchement
|
||||
refreshSidebar();
|
||||
}
|
||||
window.addEventListener('focus', onFocus);
|
||||
window.open(link.href, '_blank');
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
addLinkListeners(); // initial call
|
||||
});
|
||||
</script>
|
||||
`;
|
||||
|
||||
var output = HtmlService.createHtmlOutput(modifiedHtml)
|
||||
.setTitle("Signatures")
|
||||
.setXFrameOptionsMode(HtmlService.XFrameOptionsMode.ALLOWALL)
|
||||
.setSandboxMode(HtmlService.SandboxMode.IFRAME);
|
||||
|
||||
DocumentApp.getUi().showSidebar(output);
|
||||
}
|
||||
|
||||
// function showSignatures() {
|
||||
// var doc = DocumentApp.getActiveDocument();
|
||||
// var docId = doc.getId();
|
||||
// var url = "https://sign.neodtx.com/embed?doc=" + encodeURIComponent(docId);
|
||||
|
||||
// // On insère un iframe pointant sur ton embed
|
||||
// var html = '<style>html,body{height:100%;margin:0;padding:0}</style>' +
|
||||
// '<iframe src="' + url + '" ' +
|
||||
// 'style="border:0;width:100%;height:100vh;" ' +
|
||||
// 'sandbox="allow-scripts allow-popups allow-same-origin allow-forms"></iframe>';
|
||||
|
||||
// var output = HtmlService.createHtmlOutput(html)
|
||||
// .setTitle("Signatures du document")
|
||||
// .setWidth(360); // largeur sidebar (modifiable)
|
||||
|
||||
// DocumentApp.getUi().showSidebar(output);
|
||||
// }
|
||||
@@ -0,0 +1,43 @@
|
||||
{
|
||||
"timeZone": "Europe/Paris",
|
||||
"exceptionLogging": "STACKDRIVER",
|
||||
"runtimeVersion": "V8",
|
||||
"oauthScopes": [
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"https://www.googleapis.com/auth/drive.file",
|
||||
"https://www.googleapis.com/auth/documents",
|
||||
"https://www.googleapis.com/auth/script.container.ui",
|
||||
"https://www.googleapis.com/auth/script.external_request"
|
||||
],
|
||||
"urlFetchWhitelist": [
|
||||
"https://sign.neodtx.com/"
|
||||
],
|
||||
"addOns": {
|
||||
"common": {
|
||||
"name": "Signature & Certificat",
|
||||
"logoUrl": "https://lh3.googleusercontent.com/-CeJhs3m4l3w/aLwUFVYqUNI/AAAAAAAAABg/qWFdtmoAp9469WfnKIF5-ujRwJj2j7ViACNcBGAsYHQ/s400/icon_128x128.png",
|
||||
"useLocaleFromApp": true,
|
||||
"homepageTrigger": {
|
||||
"runFunction": "showSignatures"
|
||||
},
|
||||
"openLinkUrlPrefixes": [
|
||||
"https://"
|
||||
],
|
||||
"universalActions": [
|
||||
{
|
||||
"label": "Vérifier document",
|
||||
"runFunction": "checkDocument"
|
||||
}
|
||||
]
|
||||
},
|
||||
"docs": {
|
||||
"homepageTrigger": {
|
||||
"runFunction": "showSignatures"
|
||||
},
|
||||
"onFileScopeGrantedTrigger": {
|
||||
"runFunction": "onDocsFileScopeGranted"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,297 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
"ackify/pkg/crypto"
|
||||
"ackify/pkg/logger"
|
||||
)
|
||||
|
||||
type repository interface {
|
||||
Create(ctx context.Context, signature *models.Signature) error
|
||||
GetByDocAndUser(ctx context.Context, docID, userSub string) (*models.Signature, error)
|
||||
GetByDoc(ctx context.Context, docID string) ([]*models.Signature, error)
|
||||
GetByUser(ctx context.Context, userSub string) ([]*models.Signature, error)
|
||||
ExistsByDocAndUser(ctx context.Context, docID, userSub string) (bool, error)
|
||||
CheckUserSignatureStatus(ctx context.Context, docID, userIdentifier string) (bool, error)
|
||||
GetLastSignature(ctx context.Context) (*models.Signature, error)
|
||||
GetAllSignaturesOrdered(ctx context.Context) ([]*models.Signature, error)
|
||||
}
|
||||
|
||||
type cryptoSigner interface {
|
||||
CreateSignature(docID string, user *models.User, timestamp time.Time, nonce string) (string, string, error)
|
||||
}
|
||||
|
||||
type SignatureService struct {
|
||||
repo repository
|
||||
signer cryptoSigner
|
||||
}
|
||||
|
||||
// NewSignatureService creates a new signature service
|
||||
func NewSignatureService(repo repository, signer cryptoSigner) *SignatureService {
|
||||
return &SignatureService{
|
||||
repo: repo,
|
||||
signer: signer,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SignatureService) CreateSignature(ctx context.Context, request *models.SignatureRequest) error {
|
||||
if request.User == nil || !request.User.IsValid() {
|
||||
return models.ErrInvalidUser
|
||||
}
|
||||
|
||||
if request.DocID == "" {
|
||||
return models.ErrInvalidDocument
|
||||
}
|
||||
|
||||
exists, err := s.repo.ExistsByDocAndUser(ctx, request.DocID, request.User.Sub)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check existing signature: %w", err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
return models.ErrSignatureAlreadyExists
|
||||
}
|
||||
|
||||
nonce, err := crypto.GenerateNonce()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
timestamp := time.Now().UTC()
|
||||
payloadHash, signatureB64, err := s.signer.CreateSignature(request.DocID, request.User, timestamp, nonce)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create cryptographic signature: %w", err)
|
||||
}
|
||||
|
||||
lastSignature, err := s.repo.GetLastSignature(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get last signature for chaining: %w", err)
|
||||
}
|
||||
|
||||
var prevHashB64 *string
|
||||
if lastSignature != nil {
|
||||
hash := lastSignature.ComputeRecordHash()
|
||||
prevHashB64 = &hash
|
||||
logger.Logger.Info("Chaining to previous signature",
|
||||
"prevID", lastSignature.ID,
|
||||
"prevHash", hash[:16]+"...")
|
||||
} else {
|
||||
logger.Logger.Info("Creating genesis signature (no previous signature)")
|
||||
}
|
||||
|
||||
var userName *string
|
||||
if request.User.Name != "" {
|
||||
userName = &request.User.Name
|
||||
}
|
||||
|
||||
logger.Logger.Info("Creating signature",
|
||||
"docID", request.DocID,
|
||||
"userSub", request.User.Sub,
|
||||
"userEmail", request.User.NormalizedEmail(),
|
||||
"userName", request.User.Name)
|
||||
|
||||
signature := &models.Signature{
|
||||
DocID: request.DocID,
|
||||
UserSub: request.User.Sub,
|
||||
UserEmail: request.User.NormalizedEmail(),
|
||||
UserName: userName,
|
||||
SignedAtUTC: timestamp,
|
||||
PayloadHashB64: payloadHash,
|
||||
SignatureB64: signatureB64,
|
||||
Nonce: nonce,
|
||||
Referer: request.Referer,
|
||||
PrevHashB64: prevHashB64,
|
||||
}
|
||||
|
||||
if err := s.repo.Create(ctx, signature); err != nil {
|
||||
return fmt.Errorf("failed to save signature: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Info("Signature created successfully", "id", signature.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SignatureService) GetSignatureStatus(ctx context.Context, docID string, user *models.User) (*models.SignatureStatus, error) {
|
||||
if user == nil || !user.IsValid() {
|
||||
return nil, models.ErrInvalidUser
|
||||
}
|
||||
|
||||
signature, err := s.repo.GetByDocAndUser(ctx, docID, user.Sub)
|
||||
if err != nil {
|
||||
if errors.Is(err, models.ErrSignatureNotFound) {
|
||||
return &models.SignatureStatus{
|
||||
DocID: docID,
|
||||
UserEmail: user.Email,
|
||||
IsSigned: false,
|
||||
SignedAt: nil,
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get signature: %w", err)
|
||||
}
|
||||
|
||||
return &models.SignatureStatus{
|
||||
DocID: docID,
|
||||
UserEmail: user.Email,
|
||||
IsSigned: true,
|
||||
SignedAt: &signature.SignedAtUTC,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SignatureService) GetDocumentSignatures(ctx context.Context, docID string) ([]*models.Signature, error) {
|
||||
signatures, err := s.repo.GetByDoc(ctx, docID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get document signatures: %w", err)
|
||||
}
|
||||
|
||||
return signatures, nil
|
||||
}
|
||||
|
||||
func (s *SignatureService) GetUserSignatures(ctx context.Context, user *models.User) ([]*models.Signature, error) {
|
||||
if user == nil || !user.IsValid() {
|
||||
return nil, models.ErrInvalidUser
|
||||
}
|
||||
|
||||
signatures, err := s.repo.GetByUser(ctx, user.Sub)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user signatures: %w", err)
|
||||
}
|
||||
|
||||
return signatures, nil
|
||||
}
|
||||
|
||||
func (s *SignatureService) GetSignatureByDocAndUser(ctx context.Context, docID string, user *models.User) (*models.Signature, error) {
|
||||
if user == nil || !user.IsValid() {
|
||||
return nil, models.ErrInvalidUser
|
||||
}
|
||||
|
||||
signature, err := s.repo.GetByDocAndUser(ctx, docID, user.Sub)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get signature: %w", err)
|
||||
}
|
||||
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
func (s *SignatureService) CheckUserSignature(ctx context.Context, docID, userIdentifier string) (bool, error) {
|
||||
exists, err := s.repo.CheckUserSignatureStatus(ctx, docID, userIdentifier)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check user signature: %w", err)
|
||||
}
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
type ChainIntegrityResult struct {
|
||||
IsValid bool
|
||||
TotalRecords int
|
||||
BreakAtID *int64
|
||||
Details string
|
||||
}
|
||||
|
||||
func (s *SignatureService) VerifyChainIntegrity(ctx context.Context) (*ChainIntegrityResult, error) {
|
||||
signatures, err := s.repo.GetAllSignaturesOrdered(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get signatures for chain verification: %w", err)
|
||||
}
|
||||
|
||||
result := &ChainIntegrityResult{
|
||||
IsValid: true,
|
||||
TotalRecords: len(signatures),
|
||||
}
|
||||
|
||||
if len(signatures) == 0 {
|
||||
result.Details = "No signatures found"
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if signatures[0].PrevHashB64 != nil {
|
||||
result.IsValid = false
|
||||
result.BreakAtID = &signatures[0].ID
|
||||
result.Details = "Genesis signature has non-null previous hash"
|
||||
return result, nil
|
||||
}
|
||||
|
||||
for i := 1; i < len(signatures); i++ {
|
||||
current := signatures[i]
|
||||
previous := signatures[i-1]
|
||||
|
||||
expectedHash := previous.ComputeRecordHash()
|
||||
|
||||
if current.PrevHashB64 == nil {
|
||||
result.IsValid = false
|
||||
result.BreakAtID = ¤t.ID
|
||||
result.Details = fmt.Sprintf("Signature %d has null previous hash, expected: %s...", current.ID, expectedHash[:16])
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if *current.PrevHashB64 != expectedHash {
|
||||
result.IsValid = false
|
||||
result.BreakAtID = ¤t.ID
|
||||
result.Details = fmt.Sprintf("Hash mismatch at signature %d: expected %s..., got %s...",
|
||||
current.ID, expectedHash[:16], (*current.PrevHashB64)[:16])
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
result.Details = "Chain integrity verified successfully"
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// RebuildChain reconstructs the hash chain for existing signatures
|
||||
// This should be used once after deploying the chain feature to populate prev_hash
|
||||
func (s *SignatureService) RebuildChain(ctx context.Context) error {
|
||||
signatures, err := s.repo.GetAllSignaturesOrdered(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get signatures for chain rebuild: %w", err)
|
||||
}
|
||||
|
||||
if len(signatures) == 0 {
|
||||
logger.Logger.Info("No signatures found, nothing to rebuild")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Logger.Info("Starting chain rebuild", "totalSignatures", len(signatures))
|
||||
|
||||
// First signature (genesis) should have null prev_hash
|
||||
if signatures[0].PrevHashB64 != nil {
|
||||
// Reset genesis signature
|
||||
signatures[0].PrevHashB64 = nil
|
||||
if err := s.repo.Create(ctx, signatures[0]); err != nil {
|
||||
logger.Logger.Warn("Failed to update genesis signature", "id", signatures[0].ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Process subsequent signatures
|
||||
for i := 1; i < len(signatures); i++ {
|
||||
current := signatures[i]
|
||||
previous := signatures[i-1]
|
||||
|
||||
expectedHash := previous.ComputeRecordHash()
|
||||
|
||||
// Update if hash is missing or incorrect
|
||||
if current.PrevHashB64 == nil || *current.PrevHashB64 != expectedHash {
|
||||
current.PrevHashB64 = &expectedHash
|
||||
|
||||
// Note: This would require an UPDATE method in the repository
|
||||
// For now, we'll log what needs to be updated
|
||||
logger.Logger.Info("Chain rebuild needed for signature",
|
||||
"id", current.ID,
|
||||
"expectedHash", expectedHash[:16]+"...",
|
||||
"currentHash", func() string {
|
||||
if current.PrevHashB64 != nil {
|
||||
return (*current.PrevHashB64)[:16] + "..."
|
||||
}
|
||||
return "null"
|
||||
}())
|
||||
}
|
||||
}
|
||||
|
||||
logger.Logger.Info("Chain rebuild completed", "processedSignatures", len(signatures))
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,911 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
)
|
||||
|
||||
// Mock repository implementation
|
||||
type fakeRepository struct {
|
||||
signatures map[string]*models.Signature // key: docID_userSub
|
||||
allSignatures []*models.Signature
|
||||
shouldFailCreate bool
|
||||
shouldFailGet bool
|
||||
shouldFailExists bool
|
||||
shouldFailGetLast bool
|
||||
shouldFailGetAll bool
|
||||
shouldFailCheck bool
|
||||
}
|
||||
|
||||
func newFakeRepository() *fakeRepository {
|
||||
return &fakeRepository{
|
||||
signatures: make(map[string]*models.Signature),
|
||||
allSignatures: make([]*models.Signature, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeRepository) Create(ctx context.Context, signature *models.Signature) error {
|
||||
if f.shouldFailCreate {
|
||||
return errors.New("repository create failed")
|
||||
}
|
||||
|
||||
signature.ID = int64(len(f.allSignatures) + 1)
|
||||
signature.CreatedAt = time.Now().UTC()
|
||||
|
||||
key := signature.DocID + "_" + signature.UserSub
|
||||
f.signatures[key] = signature
|
||||
f.allSignatures = append(f.allSignatures, signature)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeRepository) GetByDocAndUser(ctx context.Context, docID, userSub string) (*models.Signature, error) {
|
||||
if f.shouldFailGet {
|
||||
return nil, errors.New("repository get failed")
|
||||
}
|
||||
|
||||
key := docID + "_" + userSub
|
||||
signature, exists := f.signatures[key]
|
||||
if !exists {
|
||||
return nil, models.ErrSignatureNotFound
|
||||
}
|
||||
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
func (f *fakeRepository) GetByDoc(ctx context.Context, docID string) ([]*models.Signature, error) {
|
||||
if f.shouldFailGet {
|
||||
return nil, errors.New("repository get failed")
|
||||
}
|
||||
|
||||
var result []*models.Signature
|
||||
for _, sig := range f.signatures {
|
||||
if sig.DocID == docID {
|
||||
result = append(result, sig)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (f *fakeRepository) GetByUser(ctx context.Context, userSub string) ([]*models.Signature, error) {
|
||||
if f.shouldFailGet {
|
||||
return nil, errors.New("repository get failed")
|
||||
}
|
||||
|
||||
var result []*models.Signature
|
||||
for _, sig := range f.signatures {
|
||||
if sig.UserSub == userSub {
|
||||
result = append(result, sig)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (f *fakeRepository) ExistsByDocAndUser(ctx context.Context, docID, userSub string) (bool, error) {
|
||||
if f.shouldFailExists {
|
||||
return false, errors.New("repository exists failed")
|
||||
}
|
||||
|
||||
key := docID + "_" + userSub
|
||||
_, exists := f.signatures[key]
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func (f *fakeRepository) CheckUserSignatureStatus(ctx context.Context, docID, userIdentifier string) (bool, error) {
|
||||
if f.shouldFailCheck {
|
||||
return false, errors.New("repository check failed")
|
||||
}
|
||||
|
||||
for _, sig := range f.signatures {
|
||||
if sig.DocID == docID && (sig.UserSub == userIdentifier || sig.UserEmail == userIdentifier) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (f *fakeRepository) GetLastSignature(ctx context.Context) (*models.Signature, error) {
|
||||
if f.shouldFailGetLast {
|
||||
return nil, errors.New("repository get last failed")
|
||||
}
|
||||
|
||||
if len(f.allSignatures) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return f.allSignatures[len(f.allSignatures)-1], nil
|
||||
}
|
||||
|
||||
func (f *fakeRepository) GetAllSignaturesOrdered(ctx context.Context) ([]*models.Signature, error) {
|
||||
if f.shouldFailGetAll {
|
||||
return nil, errors.New("repository get all failed")
|
||||
}
|
||||
|
||||
return f.allSignatures, nil
|
||||
}
|
||||
|
||||
// Mock crypto signer implementation
|
||||
type fakeCryptoSigner struct {
|
||||
shouldFail bool
|
||||
}
|
||||
|
||||
func newFakeCryptoSigner() *fakeCryptoSigner {
|
||||
return &fakeCryptoSigner{}
|
||||
}
|
||||
|
||||
func (f *fakeCryptoSigner) CreateSignature(docID string, user *models.User, timestamp time.Time, nonce string) (string, string, error) {
|
||||
if f.shouldFail {
|
||||
return "", "", errors.New("crypto signing failed")
|
||||
}
|
||||
|
||||
payloadHash := "fake-payload-hash-" + docID
|
||||
signature := "fake-signature-" + user.Sub
|
||||
return payloadHash, signature, nil
|
||||
}
|
||||
|
||||
// Test NewSignatureService
|
||||
func TestNewSignatureService(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
|
||||
service := NewSignatureService(repo, signer)
|
||||
|
||||
if service == nil {
|
||||
t.Error("NewSignatureService should not return nil")
|
||||
} else if service.repo != repo {
|
||||
t.Error("Service repository not set correctly")
|
||||
} else if service.signer != signer {
|
||||
t.Error("Service signer not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignatureService_CreateSignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request *models.SignatureRequest
|
||||
setupRepo func(*fakeRepository)
|
||||
setupSigner func(*fakeCryptoSigner)
|
||||
expectError bool
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "valid signature creation",
|
||||
request: &models.SignatureRequest{
|
||||
DocID: "test-doc-123",
|
||||
User: &models.User{
|
||||
Sub: "user-123",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
},
|
||||
Referer: stringPtr("github"),
|
||||
},
|
||||
setupRepo: func(r *fakeRepository) {},
|
||||
setupSigner: func(s *fakeCryptoSigner) {},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user - nil",
|
||||
request: &models.SignatureRequest{
|
||||
DocID: "test-doc-123",
|
||||
User: nil,
|
||||
},
|
||||
expectError: true,
|
||||
expectedError: models.ErrInvalidUser,
|
||||
},
|
||||
{
|
||||
name: "invalid user - invalid data",
|
||||
request: &models.SignatureRequest{
|
||||
DocID: "test-doc-123",
|
||||
User: &models.User{
|
||||
Sub: "",
|
||||
Email: "test@example.com",
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
expectedError: models.ErrInvalidUser,
|
||||
},
|
||||
{
|
||||
name: "empty document ID",
|
||||
request: &models.SignatureRequest{
|
||||
DocID: "",
|
||||
User: &models.User{
|
||||
Sub: "user-123",
|
||||
Email: "test@example.com",
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
expectedError: models.ErrInvalidDocument,
|
||||
},
|
||||
{
|
||||
name: "signature already exists",
|
||||
request: &models.SignatureRequest{
|
||||
DocID: "existing-doc",
|
||||
User: &models.User{
|
||||
Sub: "existing-user",
|
||||
Email: "existing@example.com",
|
||||
},
|
||||
},
|
||||
setupRepo: func(r *fakeRepository) {
|
||||
// Pre-populate with existing signature
|
||||
r.signatures["existing-doc_existing-user"] = &models.Signature{
|
||||
ID: 1,
|
||||
DocID: "existing-doc",
|
||||
UserSub: "existing-user",
|
||||
}
|
||||
},
|
||||
setupSigner: func(s *fakeCryptoSigner) {},
|
||||
expectError: true,
|
||||
expectedError: models.ErrSignatureAlreadyExists,
|
||||
},
|
||||
{
|
||||
name: "repository exists check fails",
|
||||
request: &models.SignatureRequest{
|
||||
DocID: "test-doc",
|
||||
User: &models.User{
|
||||
Sub: "user-123",
|
||||
Email: "test@example.com",
|
||||
},
|
||||
},
|
||||
setupRepo: func(r *fakeRepository) {
|
||||
r.shouldFailExists = true
|
||||
},
|
||||
setupSigner: func(s *fakeCryptoSigner) {},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "crypto signing fails",
|
||||
request: &models.SignatureRequest{
|
||||
DocID: "test-doc",
|
||||
User: &models.User{
|
||||
Sub: "user-123",
|
||||
Email: "test@example.com",
|
||||
},
|
||||
},
|
||||
setupRepo: func(r *fakeRepository) {},
|
||||
setupSigner: func(s *fakeCryptoSigner) {
|
||||
s.shouldFail = true
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "repository get last signature fails",
|
||||
request: &models.SignatureRequest{
|
||||
DocID: "test-doc",
|
||||
User: &models.User{
|
||||
Sub: "user-123",
|
||||
Email: "test@example.com",
|
||||
},
|
||||
},
|
||||
setupRepo: func(r *fakeRepository) {
|
||||
r.shouldFailGetLast = true
|
||||
},
|
||||
setupSigner: func(s *fakeCryptoSigner) {},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "repository create fails",
|
||||
request: &models.SignatureRequest{
|
||||
DocID: "test-doc",
|
||||
User: &models.User{
|
||||
Sub: "user-123",
|
||||
Email: "test@example.com",
|
||||
},
|
||||
},
|
||||
setupRepo: func(r *fakeRepository) {
|
||||
r.shouldFailCreate = true
|
||||
},
|
||||
setupSigner: func(s *fakeCryptoSigner) {},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "user without name",
|
||||
request: &models.SignatureRequest{
|
||||
DocID: "test-doc",
|
||||
User: &models.User{
|
||||
Sub: "user-123",
|
||||
Email: "test@example.com",
|
||||
Name: "",
|
||||
},
|
||||
},
|
||||
setupRepo: func(r *fakeRepository) {},
|
||||
setupSigner: func(s *fakeCryptoSigner) {},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
|
||||
if tt.setupRepo != nil {
|
||||
tt.setupRepo(repo)
|
||||
}
|
||||
if tt.setupSigner != nil {
|
||||
tt.setupSigner(signer)
|
||||
}
|
||||
|
||||
service := NewSignatureService(repo, signer)
|
||||
|
||||
err := service.CreateSignature(context.Background(), tt.request)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
return
|
||||
}
|
||||
if tt.expectedError != nil && !errors.Is(err, tt.expectedError) {
|
||||
t.Errorf("Error = %v, expected %v", err, tt.expectedError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify signature was created
|
||||
key := tt.request.DocID + "_" + tt.request.User.Sub
|
||||
signature, exists := repo.signatures[key]
|
||||
if !exists {
|
||||
t.Error("Signature should have been created")
|
||||
return
|
||||
}
|
||||
|
||||
if signature.DocID != tt.request.DocID {
|
||||
t.Errorf("DocID = %v, expected %v", signature.DocID, tt.request.DocID)
|
||||
}
|
||||
if signature.UserSub != tt.request.User.Sub {
|
||||
t.Errorf("UserSub = %v, expected %v", signature.UserSub, tt.request.User.Sub)
|
||||
}
|
||||
if signature.UserEmail != tt.request.User.NormalizedEmail() {
|
||||
t.Errorf("UserEmail = %v, expected %v", signature.UserEmail, tt.request.User.NormalizedEmail())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignatureService_GetSignatureStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
docID string
|
||||
user *models.User
|
||||
setupRepo func(*fakeRepository)
|
||||
expectError bool
|
||||
expectedError error
|
||||
expectedSigned bool
|
||||
}{
|
||||
{
|
||||
name: "user has signed",
|
||||
docID: "test-doc",
|
||||
user: &models.User{
|
||||
Sub: "user-123",
|
||||
Email: "test@example.com",
|
||||
},
|
||||
setupRepo: func(r *fakeRepository) {
|
||||
r.signatures["test-doc_user-123"] = &models.Signature{
|
||||
ID: 1,
|
||||
DocID: "test-doc",
|
||||
UserSub: "user-123",
|
||||
SignedAtUTC: time.Now().UTC(),
|
||||
}
|
||||
},
|
||||
expectError: false,
|
||||
expectedSigned: true,
|
||||
},
|
||||
{
|
||||
name: "user has not signed",
|
||||
docID: "test-doc",
|
||||
user: &models.User{
|
||||
Sub: "user-123",
|
||||
Email: "test@example.com",
|
||||
},
|
||||
setupRepo: func(r *fakeRepository) {},
|
||||
expectError: false,
|
||||
expectedSigned: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user - nil",
|
||||
docID: "test-doc",
|
||||
user: nil,
|
||||
expectError: true,
|
||||
expectedError: models.ErrInvalidUser,
|
||||
},
|
||||
{
|
||||
name: "invalid user - invalid data",
|
||||
docID: "test-doc",
|
||||
user: &models.User{
|
||||
Sub: "",
|
||||
Email: "test@example.com",
|
||||
},
|
||||
expectError: true,
|
||||
expectedError: models.ErrInvalidUser,
|
||||
},
|
||||
{
|
||||
name: "repository get fails",
|
||||
docID: "test-doc",
|
||||
user: &models.User{
|
||||
Sub: "user-123",
|
||||
Email: "test@example.com",
|
||||
},
|
||||
setupRepo: func(r *fakeRepository) {
|
||||
r.shouldFailGet = true
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
|
||||
if tt.setupRepo != nil {
|
||||
tt.setupRepo(repo)
|
||||
}
|
||||
|
||||
service := NewSignatureService(repo, signer)
|
||||
|
||||
status, err := service.GetSignatureStatus(context.Background(), tt.docID, tt.user)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
return
|
||||
}
|
||||
if tt.expectedError != nil && !errors.Is(err, tt.expectedError) {
|
||||
t.Errorf("Error = %v, expected %v", err, tt.expectedError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if status == nil {
|
||||
t.Error("Status should not be nil")
|
||||
return
|
||||
}
|
||||
|
||||
if status.DocID != tt.docID {
|
||||
t.Errorf("Status.DocID = %v, expected %v", status.DocID, tt.docID)
|
||||
}
|
||||
if status.UserEmail != tt.user.Email {
|
||||
t.Errorf("Status.UserEmail = %v, expected %v", status.UserEmail, tt.user.Email)
|
||||
}
|
||||
if status.IsSigned != tt.expectedSigned {
|
||||
t.Errorf("Status.IsSigned = %v, expected %v", status.IsSigned, tt.expectedSigned)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignatureService_GetDocumentSignatures(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
|
||||
// Setup test data
|
||||
sig1 := &models.Signature{ID: 1, DocID: "doc1", UserSub: "user1"}
|
||||
sig2 := &models.Signature{ID: 2, DocID: "doc1", UserSub: "user2"}
|
||||
sig3 := &models.Signature{ID: 3, DocID: "doc2", UserSub: "user1"}
|
||||
|
||||
repo.signatures["doc1_user1"] = sig1
|
||||
repo.signatures["doc1_user2"] = sig2
|
||||
repo.signatures["doc2_user1"] = sig3
|
||||
|
||||
t.Run("get signatures for document", func(t *testing.T) {
|
||||
signatures, err := service.GetDocumentSignatures(context.Background(), "doc1")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(signatures) != 2 {
|
||||
t.Errorf("Expected 2 signatures, got %d", len(signatures))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("repository fails", func(t *testing.T) {
|
||||
repo.shouldFailGet = true
|
||||
_, err := service.GetDocumentSignatures(context.Background(), "doc1")
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSignatureService_GetUserSignatures(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
|
||||
// Setup test data
|
||||
sig1 := &models.Signature{ID: 1, DocID: "doc1", UserSub: "user1"}
|
||||
sig2 := &models.Signature{ID: 2, DocID: "doc2", UserSub: "user1"}
|
||||
sig3 := &models.Signature{ID: 3, DocID: "doc1", UserSub: "user2"}
|
||||
|
||||
repo.signatures["doc1_user1"] = sig1
|
||||
repo.signatures["doc2_user1"] = sig2
|
||||
repo.signatures["doc1_user2"] = sig3
|
||||
|
||||
t.Run("get signatures for user", func(t *testing.T) {
|
||||
user := &models.User{Sub: "user1", Email: "user1@example.com"}
|
||||
signatures, err := service.GetUserSignatures(context.Background(), user)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(signatures) != 2 {
|
||||
t.Errorf("Expected 2 signatures, got %d", len(signatures))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid user", func(t *testing.T) {
|
||||
_, err := service.GetUserSignatures(context.Background(), nil)
|
||||
if err != models.ErrInvalidUser {
|
||||
t.Errorf("Error = %v, expected %v", err, models.ErrInvalidUser)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("repository fails", func(t *testing.T) {
|
||||
user := &models.User{Sub: "user1", Email: "user1@example.com"}
|
||||
repo.shouldFailGet = true
|
||||
_, err := service.GetUserSignatures(context.Background(), user)
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSignatureService_GetSignatureByDocAndUser(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
|
||||
// Setup test data
|
||||
sig := &models.Signature{ID: 1, DocID: "doc1", UserSub: "user1"}
|
||||
repo.signatures["doc1_user1"] = sig
|
||||
|
||||
t.Run("get existing signature", func(t *testing.T) {
|
||||
user := &models.User{Sub: "user1", Email: "user1@example.com"}
|
||||
signature, err := service.GetSignatureByDocAndUser(context.Background(), "doc1", user)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if signature.ID != 1 {
|
||||
t.Errorf("Signature.ID = %v, expected 1", signature.ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid user", func(t *testing.T) {
|
||||
_, err := service.GetSignatureByDocAndUser(context.Background(), "doc1", nil)
|
||||
if err != models.ErrInvalidUser {
|
||||
t.Errorf("Error = %v, expected %v", err, models.ErrInvalidUser)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("repository fails", func(t *testing.T) {
|
||||
user := &models.User{Sub: "user1", Email: "user1@example.com"}
|
||||
repo.shouldFailGet = true
|
||||
_, err := service.GetSignatureByDocAndUser(context.Background(), "doc1", user)
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSignatureService_CheckUserSignature(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
|
||||
// Setup test data
|
||||
sig := &models.Signature{ID: 1, DocID: "doc1", UserSub: "user1", UserEmail: "user1@example.com"}
|
||||
repo.signatures["doc1_user1"] = sig
|
||||
|
||||
t.Run("check by user sub", func(t *testing.T) {
|
||||
exists, err := service.CheckUserSignature(context.Background(), "doc1", "user1")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Error("Should find signature by user sub")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check by email", func(t *testing.T) {
|
||||
exists, err := service.CheckUserSignature(context.Background(), "doc1", "user1@example.com")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Error("Should find signature by email")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("signature not found", func(t *testing.T) {
|
||||
exists, err := service.CheckUserSignature(context.Background(), "doc1", "nonexistent")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if exists {
|
||||
t.Error("Should not find nonexistent signature")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("repository fails", func(t *testing.T) {
|
||||
repo.shouldFailCheck = true
|
||||
_, err := service.CheckUserSignature(context.Background(), "doc1", "user1")
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSignatureService_VerifyChainIntegrity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSignatures func(*fakeRepository)
|
||||
expectValid bool
|
||||
expectBreakAtID *int64
|
||||
expectDetails string
|
||||
}{
|
||||
{
|
||||
name: "empty chain",
|
||||
setupSignatures: func(r *fakeRepository) {},
|
||||
expectValid: true,
|
||||
expectDetails: "No signatures found",
|
||||
},
|
||||
{
|
||||
name: "valid chain with single signature",
|
||||
setupSignatures: func(r *fakeRepository) {
|
||||
sig1 := &models.Signature{
|
||||
ID: 1,
|
||||
DocID: "doc1",
|
||||
UserSub: "user1",
|
||||
PrevHashB64: nil, // Genesis
|
||||
}
|
||||
r.allSignatures = []*models.Signature{sig1}
|
||||
},
|
||||
expectValid: true,
|
||||
expectDetails: "Chain integrity verified successfully",
|
||||
},
|
||||
{
|
||||
name: "valid chain with multiple signatures",
|
||||
setupSignatures: func(r *fakeRepository) {
|
||||
sig1 := &models.Signature{
|
||||
ID: 1,
|
||||
DocID: "doc1",
|
||||
UserSub: "user1",
|
||||
PrevHashB64: nil, // Genesis
|
||||
}
|
||||
hash1 := sig1.ComputeRecordHash()
|
||||
sig2 := &models.Signature{
|
||||
ID: 2,
|
||||
DocID: "doc2",
|
||||
UserSub: "user2",
|
||||
PrevHashB64: &hash1,
|
||||
}
|
||||
r.allSignatures = []*models.Signature{sig1, sig2}
|
||||
},
|
||||
expectValid: true,
|
||||
expectDetails: "Chain integrity verified successfully",
|
||||
},
|
||||
{
|
||||
name: "invalid chain - genesis has prev hash",
|
||||
setupSignatures: func(r *fakeRepository) {
|
||||
hash := "invalid-genesis-hash"
|
||||
sig1 := &models.Signature{
|
||||
ID: 1,
|
||||
DocID: "doc1",
|
||||
UserSub: "user1",
|
||||
PrevHashB64: &hash,
|
||||
}
|
||||
r.allSignatures = []*models.Signature{sig1}
|
||||
},
|
||||
expectValid: false,
|
||||
expectBreakAtID: int64Ptr(1),
|
||||
expectDetails: "Genesis signature has non-null previous hash",
|
||||
},
|
||||
{
|
||||
name: "invalid chain - missing prev hash",
|
||||
setupSignatures: func(r *fakeRepository) {
|
||||
sig1 := &models.Signature{
|
||||
ID: 1,
|
||||
DocID: "doc1",
|
||||
UserSub: "user1",
|
||||
PrevHashB64: nil, // Genesis
|
||||
}
|
||||
sig2 := &models.Signature{
|
||||
ID: 2,
|
||||
DocID: "doc2",
|
||||
UserSub: "user2",
|
||||
PrevHashB64: nil, // Should have prev hash
|
||||
}
|
||||
r.allSignatures = []*models.Signature{sig1, sig2}
|
||||
},
|
||||
expectValid: false,
|
||||
expectBreakAtID: int64Ptr(2),
|
||||
},
|
||||
{
|
||||
name: "invalid chain - wrong prev hash",
|
||||
setupSignatures: func(r *fakeRepository) {
|
||||
sig1 := &models.Signature{
|
||||
ID: 1,
|
||||
DocID: "doc1",
|
||||
UserSub: "user1",
|
||||
PrevHashB64: nil, // Genesis
|
||||
}
|
||||
wrongHash := "wrong-hash-that-is-long-enough-for-display"
|
||||
sig2 := &models.Signature{
|
||||
ID: 2,
|
||||
DocID: "doc2",
|
||||
UserSub: "user2",
|
||||
PrevHashB64: &wrongHash,
|
||||
}
|
||||
r.allSignatures = []*models.Signature{sig1, sig2}
|
||||
},
|
||||
expectValid: false,
|
||||
expectBreakAtID: int64Ptr(2),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
|
||||
tt.setupSignatures(repo)
|
||||
|
||||
result, err := service.VerifyChainIntegrity(context.Background())
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result.IsValid != tt.expectValid {
|
||||
t.Errorf("IsValid = %v, expected %v", result.IsValid, tt.expectValid)
|
||||
}
|
||||
|
||||
if tt.expectBreakAtID != nil {
|
||||
if result.BreakAtID == nil {
|
||||
t.Error("Expected BreakAtID to be set")
|
||||
} else if *result.BreakAtID != *tt.expectBreakAtID {
|
||||
t.Errorf("BreakAtID = %v, expected %v", *result.BreakAtID, *tt.expectBreakAtID)
|
||||
}
|
||||
}
|
||||
|
||||
if tt.expectDetails != "" && !contains(result.Details, tt.expectDetails) {
|
||||
t.Errorf("Details should contain %v, got %v", tt.expectDetails, result.Details)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("repository fails", func(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
|
||||
repo.shouldFailGetAll = true
|
||||
|
||||
_, err := service.VerifyChainIntegrity(context.Background())
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSignatureService_RebuildChain(t *testing.T) {
|
||||
t.Run("empty chain", func(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
|
||||
err := service.RebuildChain(context.Background())
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("chain with signatures", func(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
|
||||
// Setup signatures that need rebuilding
|
||||
hash := "wrong-hash"
|
||||
sig1 := &models.Signature{
|
||||
ID: 1,
|
||||
DocID: "doc1",
|
||||
UserSub: "user1",
|
||||
PrevHashB64: &hash, // Should be nil for genesis
|
||||
}
|
||||
sig2 := &models.Signature{
|
||||
ID: 2,
|
||||
DocID: "doc2",
|
||||
UserSub: "user2",
|
||||
PrevHashB64: nil, // Should have correct hash
|
||||
}
|
||||
repo.allSignatures = []*models.Signature{sig1, sig2}
|
||||
|
||||
err := service.RebuildChain(context.Background())
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("repository fails", func(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
signer := newFakeCryptoSigner()
|
||||
service := NewSignatureService(repo, signer)
|
||||
|
||||
repo.shouldFailGetAll = true
|
||||
|
||||
err := service.RebuildChain(context.Background())
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestChainIntegrityResult_Structure(t *testing.T) {
|
||||
result := &ChainIntegrityResult{
|
||||
IsValid: true,
|
||||
TotalRecords: 5,
|
||||
BreakAtID: int64Ptr(3),
|
||||
Details: "Test details",
|
||||
}
|
||||
|
||||
if !result.IsValid {
|
||||
t.Error("IsValid should be true")
|
||||
}
|
||||
if result.TotalRecords != 5 {
|
||||
t.Errorf("TotalRecords = %v, expected 5", result.TotalRecords)
|
||||
}
|
||||
if result.BreakAtID == nil || *result.BreakAtID != 3 {
|
||||
t.Errorf("BreakAtID = %v, expected 3", result.BreakAtID)
|
||||
}
|
||||
if result.Details != "Test details" {
|
||||
t.Errorf("Details = %v, expected 'Test details'", result.Details)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func int64Ptr(i int64) *int64 {
|
||||
return &i
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || containsAt(s, substr, 0))
|
||||
}
|
||||
|
||||
func containsAt(s, substr string, start int) bool {
|
||||
if start+len(substr) > len(s) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(substr); i++ {
|
||||
if s[start+i] != substr[i] {
|
||||
if start+1 < len(s) {
|
||||
return containsAt(s, substr, start+1)
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package models
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrSignatureNotFound = errors.New("signature not found")
|
||||
ErrSignatureAlreadyExists = errors.New("signature already exists")
|
||||
ErrInvalidUser = errors.New("invalid user")
|
||||
ErrInvalidDocument = errors.New("invalid document ID")
|
||||
ErrDatabaseConnection = errors.New("database connection error")
|
||||
ErrUnauthorized = errors.New("unauthorized")
|
||||
ErrDomainNotAllowed = errors.New("domain not allowed")
|
||||
)
|
||||
@@ -0,0 +1,234 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDomainErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expectedMsg string
|
||||
shouldNotBeNil bool
|
||||
}{
|
||||
{
|
||||
name: "ErrSignatureNotFound",
|
||||
err: ErrSignatureNotFound,
|
||||
expectedMsg: "signature not found",
|
||||
shouldNotBeNil: true,
|
||||
},
|
||||
{
|
||||
name: "ErrSignatureAlreadyExists",
|
||||
err: ErrSignatureAlreadyExists,
|
||||
expectedMsg: "signature already exists",
|
||||
shouldNotBeNil: true,
|
||||
},
|
||||
{
|
||||
name: "ErrInvalidUser",
|
||||
err: ErrInvalidUser,
|
||||
expectedMsg: "invalid user",
|
||||
shouldNotBeNil: true,
|
||||
},
|
||||
{
|
||||
name: "ErrInvalidDocument",
|
||||
err: ErrInvalidDocument,
|
||||
expectedMsg: "invalid document ID",
|
||||
shouldNotBeNil: true,
|
||||
},
|
||||
{
|
||||
name: "ErrDatabaseConnection",
|
||||
err: ErrDatabaseConnection,
|
||||
expectedMsg: "database connection error",
|
||||
shouldNotBeNil: true,
|
||||
},
|
||||
{
|
||||
name: "ErrUnauthorized",
|
||||
err: ErrUnauthorized,
|
||||
expectedMsg: "unauthorized",
|
||||
shouldNotBeNil: true,
|
||||
},
|
||||
{
|
||||
name: "ErrDomainNotAllowed",
|
||||
err: ErrDomainNotAllowed,
|
||||
expectedMsg: "domain not allowed",
|
||||
shouldNotBeNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.shouldNotBeNil && tt.err == nil {
|
||||
t.Errorf("Error should not be nil")
|
||||
return
|
||||
}
|
||||
|
||||
if tt.err.Error() != tt.expectedMsg {
|
||||
t.Errorf("Error message mismatch: got %v, expected %v", tt.err.Error(), tt.expectedMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorComparison(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err1 error
|
||||
err2 error
|
||||
equal bool
|
||||
}{
|
||||
{
|
||||
name: "same error instances are equal",
|
||||
err1: ErrSignatureNotFound,
|
||||
err2: ErrSignatureNotFound,
|
||||
equal: true,
|
||||
},
|
||||
{
|
||||
name: "different error instances are not equal",
|
||||
err1: ErrSignatureNotFound,
|
||||
err2: ErrSignatureAlreadyExists,
|
||||
equal: false,
|
||||
},
|
||||
{
|
||||
name: "wrapped errors can be detected",
|
||||
err1: ErrInvalidUser,
|
||||
err2: errors.New("wrapped: invalid user"),
|
||||
equal: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isEqual := errors.Is(tt.err1, tt.err2)
|
||||
if isEqual != tt.equal {
|
||||
t.Errorf("Error comparison mismatch: got %v, expected %v", isEqual, tt.equal)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorWrapping(t *testing.T) {
|
||||
originalErr := ErrSignatureNotFound
|
||||
wrappedErr := errors.Join(originalErr, errors.New("additional context"))
|
||||
|
||||
// Test that the original error can be detected in the wrapped error
|
||||
if !errors.Is(wrappedErr, originalErr) {
|
||||
t.Error("Original error should be detectable in wrapped error")
|
||||
}
|
||||
|
||||
// Test error message contains both parts
|
||||
wrappedMsg := wrappedErr.Error()
|
||||
if !contains(wrappedMsg, "signature not found") {
|
||||
t.Errorf("Wrapped error should contain original message: %v", wrappedMsg)
|
||||
}
|
||||
if !contains(wrappedMsg, "additional context") {
|
||||
t.Errorf("Wrapped error should contain additional context: %v", wrappedMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorTypes(t *testing.T) {
|
||||
// Test that all errors are of type error interface
|
||||
errors := []error{
|
||||
ErrSignatureNotFound,
|
||||
ErrSignatureAlreadyExists,
|
||||
ErrInvalidUser,
|
||||
ErrInvalidDocument,
|
||||
ErrDatabaseConnection,
|
||||
ErrUnauthorized,
|
||||
ErrDomainNotAllowed,
|
||||
}
|
||||
|
||||
for i, err := range errors {
|
||||
t.Run("error_type_"+string(rune(i+'0')), func(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Error("Error should not be nil")
|
||||
}
|
||||
|
||||
// Test that error implements error interface
|
||||
if _, ok := err.(error); !ok {
|
||||
t.Error("Error should implement error interface")
|
||||
}
|
||||
|
||||
// Test that error message is not empty
|
||||
if err.Error() == "" {
|
||||
t.Error("Error message should not be empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorUniqueness(t *testing.T) {
|
||||
// Test that all error messages are unique
|
||||
errors := map[string]error{
|
||||
"signature not found": ErrSignatureNotFound,
|
||||
"signature already exists": ErrSignatureAlreadyExists,
|
||||
"invalid user": ErrInvalidUser,
|
||||
"invalid document ID": ErrInvalidDocument,
|
||||
"database connection error": ErrDatabaseConnection,
|
||||
"unauthorized": ErrUnauthorized,
|
||||
"domain not allowed": ErrDomainNotAllowed,
|
||||
}
|
||||
|
||||
messages := make(map[string]bool)
|
||||
for msg, err := range errors {
|
||||
if messages[msg] {
|
||||
t.Errorf("Duplicate error message found: %v", msg)
|
||||
}
|
||||
messages[msg] = true
|
||||
|
||||
if err.Error() != msg {
|
||||
t.Errorf("Error message mismatch for %v: got %v, expected %v", err, err.Error(), msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify we have the expected number of unique errors
|
||||
expectedCount := 7
|
||||
if len(messages) != expectedCount {
|
||||
t.Errorf("Expected %d unique error messages, got %d", expectedCount, len(messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorSentinelValues(t *testing.T) {
|
||||
// Test that errors are sentinel values (same instance when accessed multiple times)
|
||||
if ErrSignatureNotFound != ErrSignatureNotFound {
|
||||
t.Error("ErrSignatureNotFound should be a sentinel value")
|
||||
}
|
||||
if ErrSignatureAlreadyExists != ErrSignatureAlreadyExists {
|
||||
t.Error("ErrSignatureAlreadyExists should be a sentinel value")
|
||||
}
|
||||
if ErrInvalidUser != ErrInvalidUser {
|
||||
t.Error("ErrInvalidUser should be a sentinel value")
|
||||
}
|
||||
if ErrInvalidDocument != ErrInvalidDocument {
|
||||
t.Error("ErrInvalidDocument should be a sentinel value")
|
||||
}
|
||||
if ErrDatabaseConnection != ErrDatabaseConnection {
|
||||
t.Error("ErrDatabaseConnection should be a sentinel value")
|
||||
}
|
||||
if ErrUnauthorized != ErrUnauthorized {
|
||||
t.Error("ErrUnauthorized should be a sentinel value")
|
||||
}
|
||||
if ErrDomainNotAllowed != ErrDomainNotAllowed {
|
||||
t.Error("ErrDomainNotAllowed should be a sentinel value")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if string contains substring
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || containsAt(s, substr, 0))
|
||||
}
|
||||
|
||||
func containsAt(s, substr string, start int) bool {
|
||||
if start+len(substr) > len(s) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(substr); i++ {
|
||||
if s[start+i] != substr[i] {
|
||||
if start+1 < len(s) {
|
||||
return containsAt(s, substr, start+1)
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"ackify/pkg/services"
|
||||
)
|
||||
|
||||
// Signature represents a document signature record
|
||||
type Signature struct {
|
||||
ID int64 `json:"id" db:"id"`
|
||||
DocID string `json:"doc_id" db:"doc_id"`
|
||||
UserSub string `json:"user_sub" db:"user_sub"`
|
||||
UserEmail string `json:"user_email" db:"user_email"`
|
||||
UserName *string `json:"user_name,omitempty" db:"user_name"`
|
||||
SignedAtUTC time.Time `json:"signed_at_utc" db:"signed_at"`
|
||||
PayloadHashB64 string `json:"payload_hash_b64" db:"payload_hash"`
|
||||
SignatureB64 string `json:"signature_b64" db:"signature"`
|
||||
Nonce string `json:"nonce" db:"nonce"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
Referer *string `json:"referer,omitempty" db:"referer"`
|
||||
PrevHashB64 *string `json:"prev_hash_b64,omitempty" db:"prev_hash"`
|
||||
}
|
||||
|
||||
// GetServiceInfo returns information about the service that originated this signature
|
||||
func (s *Signature) GetServiceInfo() *services.ServiceInfo {
|
||||
if s.Referer == nil {
|
||||
return nil
|
||||
}
|
||||
return services.DetectServiceFromReferrer(*s.Referer)
|
||||
}
|
||||
|
||||
// SignatureRequest represents a request to create a signature
|
||||
type SignatureRequest struct {
|
||||
DocID string
|
||||
User *User
|
||||
Referer *string
|
||||
}
|
||||
|
||||
// SignatureStatus represents the status of a signature for a user
|
||||
type SignatureStatus struct {
|
||||
DocID string
|
||||
UserEmail string
|
||||
IsSigned bool
|
||||
SignedAt *time.Time
|
||||
}
|
||||
|
||||
// ComputeRecordHash computes the SHA-256 hash of a signature record for chaining
|
||||
func (s *Signature) ComputeRecordHash() string {
|
||||
data := fmt.Sprintf("%d|%s|%s|%s|%v|%s|%s|%s|%s|%s|%s",
|
||||
s.ID,
|
||||
s.DocID,
|
||||
s.UserSub,
|
||||
s.UserEmail,
|
||||
s.UserName,
|
||||
s.SignedAtUTC.Format(time.RFC3339Nano),
|
||||
s.PayloadHashB64,
|
||||
s.SignatureB64,
|
||||
s.Nonce,
|
||||
s.CreatedAt.Format(time.RFC3339Nano),
|
||||
func() string {
|
||||
if s.Referer != nil {
|
||||
return *s.Referer
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
)
|
||||
|
||||
hash := sha256.Sum256([]byte(data))
|
||||
return base64.StdEncoding.EncodeToString(hash[:])
|
||||
}
|
||||
@@ -0,0 +1,559 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSignature_JSONSerialization(t *testing.T) {
|
||||
timestamp := time.Date(2024, 1, 15, 10, 30, 45, 123456789, time.UTC)
|
||||
createdAt := time.Date(2024, 1, 15, 10, 30, 46, 0, time.UTC)
|
||||
userName := "Test User"
|
||||
referer := "https://github.com/user/repo"
|
||||
prevHash := "abcd1234efgh5678"
|
||||
|
||||
signature := &Signature{
|
||||
ID: 123,
|
||||
DocID: "test-doc-123",
|
||||
UserSub: "google-oauth2|123456789",
|
||||
UserEmail: "test@example.com",
|
||||
UserName: &userName,
|
||||
SignedAtUTC: timestamp,
|
||||
PayloadHashB64: "SGVsbG8gV29ybGQ=",
|
||||
SignatureB64: "c2lnbmF0dXJlLWRhdGE=",
|
||||
Nonce: "random-nonce-123",
|
||||
CreatedAt: createdAt,
|
||||
Referer: &referer,
|
||||
PrevHashB64: &prevHash,
|
||||
}
|
||||
|
||||
// Test serialization
|
||||
data, err := json.Marshal(signature)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal signature: %v", err)
|
||||
}
|
||||
|
||||
// Test deserialization
|
||||
var unmarshaled Signature
|
||||
err = json.Unmarshal(data, &unmarshaled)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal signature: %v", err)
|
||||
}
|
||||
|
||||
// Verify all fields
|
||||
if unmarshaled.ID != signature.ID {
|
||||
t.Errorf("ID mismatch: got %v, expected %v", unmarshaled.ID, signature.ID)
|
||||
}
|
||||
if unmarshaled.DocID != signature.DocID {
|
||||
t.Errorf("DocID mismatch: got %v, expected %v", unmarshaled.DocID, signature.DocID)
|
||||
}
|
||||
if unmarshaled.UserSub != signature.UserSub {
|
||||
t.Errorf("UserSub mismatch: got %v, expected %v", unmarshaled.UserSub, signature.UserSub)
|
||||
}
|
||||
if unmarshaled.UserEmail != signature.UserEmail {
|
||||
t.Errorf("UserEmail mismatch: got %v, expected %v", unmarshaled.UserEmail, signature.UserEmail)
|
||||
}
|
||||
if (unmarshaled.UserName == nil) != (signature.UserName == nil) {
|
||||
t.Errorf("UserName nil mismatch: got %v, expected %v", unmarshaled.UserName == nil, signature.UserName == nil)
|
||||
}
|
||||
if unmarshaled.UserName != nil && signature.UserName != nil && *unmarshaled.UserName != *signature.UserName {
|
||||
t.Errorf("UserName mismatch: got %v, expected %v", *unmarshaled.UserName, *signature.UserName)
|
||||
}
|
||||
if !unmarshaled.SignedAtUTC.Equal(signature.SignedAtUTC) {
|
||||
t.Errorf("SignedAtUTC mismatch: got %v, expected %v", unmarshaled.SignedAtUTC, signature.SignedAtUTC)
|
||||
}
|
||||
if unmarshaled.PayloadHashB64 != signature.PayloadHashB64 {
|
||||
t.Errorf("PayloadHashB64 mismatch: got %v, expected %v", unmarshaled.PayloadHashB64, signature.PayloadHashB64)
|
||||
}
|
||||
if unmarshaled.SignatureB64 != signature.SignatureB64 {
|
||||
t.Errorf("SignatureB64 mismatch: got %v, expected %v", unmarshaled.SignatureB64, signature.SignatureB64)
|
||||
}
|
||||
if unmarshaled.Nonce != signature.Nonce {
|
||||
t.Errorf("Nonce mismatch: got %v, expected %v", unmarshaled.Nonce, signature.Nonce)
|
||||
}
|
||||
if !unmarshaled.CreatedAt.Equal(signature.CreatedAt) {
|
||||
t.Errorf("CreatedAt mismatch: got %v, expected %v", unmarshaled.CreatedAt, signature.CreatedAt)
|
||||
}
|
||||
if (unmarshaled.Referer == nil) != (signature.Referer == nil) {
|
||||
t.Errorf("Referer nil mismatch: got %v, expected %v", unmarshaled.Referer == nil, signature.Referer == nil)
|
||||
}
|
||||
if unmarshaled.Referer != nil && signature.Referer != nil && *unmarshaled.Referer != *signature.Referer {
|
||||
t.Errorf("Referer mismatch: got %v, expected %v", *unmarshaled.Referer, *signature.Referer)
|
||||
}
|
||||
if (unmarshaled.PrevHashB64 == nil) != (signature.PrevHashB64 == nil) {
|
||||
t.Errorf("PrevHashB64 nil mismatch: got %v, expected %v", unmarshaled.PrevHashB64 == nil, signature.PrevHashB64 == nil)
|
||||
}
|
||||
if unmarshaled.PrevHashB64 != nil && signature.PrevHashB64 != nil && *unmarshaled.PrevHashB64 != *signature.PrevHashB64 {
|
||||
t.Errorf("PrevHashB64 mismatch: got %v, expected %v", *unmarshaled.PrevHashB64, *signature.PrevHashB64)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignature_JSONSerializationWithNilFields(t *testing.T) {
|
||||
timestamp := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC)
|
||||
createdAt := time.Date(2024, 1, 15, 10, 30, 46, 0, time.UTC)
|
||||
|
||||
signature := &Signature{
|
||||
ID: 456,
|
||||
DocID: "minimal-doc",
|
||||
UserSub: "github|987654321",
|
||||
UserEmail: "minimal@example.com",
|
||||
UserName: nil,
|
||||
SignedAtUTC: timestamp,
|
||||
PayloadHashB64: "bWluaW1hbA==",
|
||||
SignatureB64: "bWluaW1hbC1zaWc=",
|
||||
Nonce: "minimal-nonce",
|
||||
CreatedAt: createdAt,
|
||||
Referer: nil,
|
||||
PrevHashB64: nil,
|
||||
}
|
||||
|
||||
// Test serialization
|
||||
data, err := json.Marshal(signature)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal signature: %v", err)
|
||||
}
|
||||
|
||||
// Verify nil fields are omitted from JSON
|
||||
jsonStr := string(data)
|
||||
if strings.Contains(jsonStr, "user_name") {
|
||||
t.Error("user_name should be omitted when nil")
|
||||
}
|
||||
if strings.Contains(jsonStr, "referer") {
|
||||
t.Error("referer should be omitted when nil")
|
||||
}
|
||||
if strings.Contains(jsonStr, "prev_hash_b64") {
|
||||
t.Error("prev_hash_b64 should be omitted when nil")
|
||||
}
|
||||
|
||||
// Test deserialization
|
||||
var unmarshaled Signature
|
||||
err = json.Unmarshal(data, &unmarshaled)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal signature: %v", err)
|
||||
}
|
||||
|
||||
// Verify nil fields remain nil
|
||||
if unmarshaled.UserName != nil {
|
||||
t.Errorf("UserName should be nil, got %v", unmarshaled.UserName)
|
||||
}
|
||||
if unmarshaled.Referer != nil {
|
||||
t.Errorf("Referer should be nil, got %v", unmarshaled.Referer)
|
||||
}
|
||||
if unmarshaled.PrevHashB64 != nil {
|
||||
t.Errorf("PrevHashB64 should be nil, got %v", unmarshaled.PrevHashB64)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignature_GetServiceInfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
referer *string
|
||||
expectedService *string
|
||||
expectedIcon *string
|
||||
expectedType *string
|
||||
}{
|
||||
{
|
||||
name: "GitHub referer param",
|
||||
referer: stringPtr("github"),
|
||||
expectedService: stringPtr("GitHub"),
|
||||
expectedIcon: stringPtr("https://cdn.simpleicons.org/github"),
|
||||
expectedType: stringPtr("code"),
|
||||
},
|
||||
{
|
||||
name: "GitLab referer param",
|
||||
referer: stringPtr("gitlab"),
|
||||
expectedService: stringPtr("GitLab"),
|
||||
expectedIcon: stringPtr("https://cdn.simpleicons.org/gitlab"),
|
||||
expectedType: stringPtr("code"),
|
||||
},
|
||||
{
|
||||
name: "Google Docs referer param",
|
||||
referer: stringPtr("google-docs"),
|
||||
expectedService: stringPtr("Google Docs"),
|
||||
expectedIcon: stringPtr("https://cdn.simpleicons.org/googledocs"),
|
||||
expectedType: stringPtr("docs"),
|
||||
},
|
||||
{
|
||||
name: "Google Sheets referer param",
|
||||
referer: stringPtr("google-sheets"),
|
||||
expectedService: stringPtr("Google Sheets"),
|
||||
expectedIcon: stringPtr("https://cdn.simpleicons.org/googlesheets"),
|
||||
expectedType: stringPtr("sheets"),
|
||||
},
|
||||
{
|
||||
name: "Notion referer param",
|
||||
referer: stringPtr("notion"),
|
||||
expectedService: stringPtr("Notion"),
|
||||
expectedIcon: stringPtr("https://cdn.simpleicons.org/notion"),
|
||||
expectedType: stringPtr("notes"),
|
||||
},
|
||||
{
|
||||
name: "nil referer",
|
||||
referer: nil,
|
||||
expectedService: nil,
|
||||
expectedIcon: nil,
|
||||
expectedType: nil,
|
||||
},
|
||||
{
|
||||
name: "empty referer",
|
||||
referer: stringPtr(""),
|
||||
expectedService: nil,
|
||||
expectedIcon: nil,
|
||||
expectedType: nil,
|
||||
},
|
||||
{
|
||||
name: "custom referer param",
|
||||
referer: stringPtr("custom-service"),
|
||||
expectedService: stringPtr("custom-service"),
|
||||
expectedIcon: stringPtr("https://cdn.simpleicons.org/link"),
|
||||
expectedType: stringPtr("custom"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
signature := &Signature{
|
||||
Referer: tt.referer,
|
||||
}
|
||||
|
||||
serviceInfo := signature.GetServiceInfo()
|
||||
|
||||
if tt.expectedService == nil {
|
||||
if serviceInfo != nil {
|
||||
t.Errorf("Expected nil service info, got %+v", serviceInfo)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if serviceInfo == nil {
|
||||
t.Errorf("Expected service info, got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if serviceInfo.Name != *tt.expectedService {
|
||||
t.Errorf("Service name mismatch: got %v, expected %v", serviceInfo.Name, *tt.expectedService)
|
||||
}
|
||||
if serviceInfo.Icon != *tt.expectedIcon {
|
||||
t.Errorf("Service icon mismatch: got %v, expected %v", serviceInfo.Icon, *tt.expectedIcon)
|
||||
}
|
||||
if serviceInfo.Type != *tt.expectedType {
|
||||
t.Errorf("Service type mismatch: got %v, expected %v", serviceInfo.Type, *tt.expectedType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignature_ComputeRecordHash(t *testing.T) {
|
||||
timestamp := time.Date(2024, 1, 15, 10, 30, 45, 123456789, time.UTC)
|
||||
createdAt := time.Date(2024, 1, 15, 10, 30, 46, 0, time.UTC)
|
||||
userName := "Test User"
|
||||
referer := "https://github.com/user/repo"
|
||||
|
||||
signature := &Signature{
|
||||
ID: 123,
|
||||
DocID: "test-doc-123",
|
||||
UserSub: "google-oauth2|123456789",
|
||||
UserEmail: "test@example.com",
|
||||
UserName: &userName,
|
||||
SignedAtUTC: timestamp,
|
||||
PayloadHashB64: "SGVsbG8gV29ybGQ=",
|
||||
SignatureB64: "c2lnbmF0dXJlLWRhdGE=",
|
||||
Nonce: "random-nonce-123",
|
||||
CreatedAt: createdAt,
|
||||
Referer: &referer,
|
||||
}
|
||||
|
||||
hash1 := signature.ComputeRecordHash()
|
||||
hash2 := signature.ComputeRecordHash()
|
||||
|
||||
// Hash should be deterministic
|
||||
if hash1 != hash2 {
|
||||
t.Errorf("Hash computation is not deterministic: %v != %v", hash1, hash2)
|
||||
}
|
||||
|
||||
// Hash should not be empty
|
||||
if hash1 == "" {
|
||||
t.Error("Hash should not be empty")
|
||||
}
|
||||
|
||||
// Hash should be base64 encoded
|
||||
if !isValidBase64(hash1) {
|
||||
t.Errorf("Hash is not valid base64: %v", hash1)
|
||||
}
|
||||
|
||||
// Changing any field should change the hash
|
||||
originalID := signature.ID
|
||||
signature.ID = 456
|
||||
hashChanged := signature.ComputeRecordHash()
|
||||
if hashChanged == hash1 {
|
||||
t.Error("Hash should change when ID changes")
|
||||
}
|
||||
signature.ID = originalID
|
||||
|
||||
// Test with nil UserName
|
||||
signature.UserName = nil
|
||||
hashWithNilName := signature.ComputeRecordHash()
|
||||
if hashWithNilName == hash1 {
|
||||
t.Error("Hash should change when UserName becomes nil")
|
||||
}
|
||||
|
||||
// Test with nil Referer
|
||||
signature.UserName = &userName
|
||||
signature.Referer = nil
|
||||
hashWithNilReferer := signature.ComputeRecordHash()
|
||||
if hashWithNilReferer == hash1 {
|
||||
t.Error("Hash should change when Referer becomes nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignature_ComputeRecordHashDeterministic(t *testing.T) {
|
||||
// Test that the same signature data produces the same hash
|
||||
timestamp := time.Date(2024, 1, 15, 10, 30, 45, 123456789, time.UTC)
|
||||
createdAt := time.Date(2024, 1, 15, 10, 30, 46, 0, time.UTC)
|
||||
userName := "Test User"
|
||||
referer := "https://github.com/user/repo"
|
||||
|
||||
sig1 := &Signature{
|
||||
ID: 123,
|
||||
DocID: "test-doc-123",
|
||||
UserSub: "google-oauth2|123456789",
|
||||
UserEmail: "test@example.com",
|
||||
UserName: &userName,
|
||||
SignedAtUTC: timestamp,
|
||||
PayloadHashB64: "SGVsbG8gV29ybGQ=",
|
||||
SignatureB64: "c2lnbmF0dXJlLWRhdGE=",
|
||||
Nonce: "random-nonce-123",
|
||||
CreatedAt: createdAt,
|
||||
Referer: &referer,
|
||||
}
|
||||
|
||||
sig2 := &Signature{
|
||||
ID: 123,
|
||||
DocID: "test-doc-123",
|
||||
UserSub: "google-oauth2|123456789",
|
||||
UserEmail: "test@example.com",
|
||||
UserName: &userName,
|
||||
SignedAtUTC: timestamp,
|
||||
PayloadHashB64: "SGVsbG8gV29ybGQ=",
|
||||
SignatureB64: "c2lnbmF0dXJlLWRhdGE=",
|
||||
Nonce: "random-nonce-123",
|
||||
CreatedAt: createdAt,
|
||||
Referer: &referer,
|
||||
}
|
||||
|
||||
hash1 := sig1.ComputeRecordHash()
|
||||
hash2 := sig2.ComputeRecordHash()
|
||||
|
||||
if hash1 != hash2 {
|
||||
t.Errorf("Identical signatures should produce identical hashes: %v != %v", hash1, hash2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignatureRequest_Validation(t *testing.T) {
|
||||
validUser := &User{
|
||||
Sub: "google-oauth2|123456789",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
request SignatureRequest
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
name: "valid request",
|
||||
request: SignatureRequest{
|
||||
DocID: "valid-doc-123",
|
||||
User: validUser,
|
||||
Referer: stringPtr("https://github.com/user/repo"),
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "valid request without referer",
|
||||
request: SignatureRequest{
|
||||
DocID: "valid-doc-123",
|
||||
User: validUser,
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "invalid request - empty DocID",
|
||||
request: SignatureRequest{
|
||||
DocID: "",
|
||||
User: validUser,
|
||||
},
|
||||
valid: false,
|
||||
},
|
||||
{
|
||||
name: "invalid request - nil user",
|
||||
request: SignatureRequest{
|
||||
DocID: "valid-doc-123",
|
||||
User: nil,
|
||||
},
|
||||
valid: false,
|
||||
},
|
||||
{
|
||||
name: "invalid request - invalid user",
|
||||
request: SignatureRequest{
|
||||
DocID: "valid-doc-123",
|
||||
User: &User{
|
||||
Sub: "",
|
||||
Email: "test@example.com",
|
||||
},
|
||||
},
|
||||
valid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test basic validation logic
|
||||
hasValidDocID := tt.request.DocID != ""
|
||||
hasValidUser := tt.request.User != nil && tt.request.User.IsValid()
|
||||
isValid := hasValidDocID && hasValidUser
|
||||
|
||||
if isValid != tt.valid {
|
||||
t.Errorf("Request validation mismatch: got %v, expected %v for %+v", isValid, tt.valid, tt.request)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignatureStatus_Creation(t *testing.T) {
|
||||
timestamp := time.Now().UTC()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
status SignatureStatus
|
||||
}{
|
||||
{
|
||||
name: "signed status",
|
||||
status: SignatureStatus{
|
||||
DocID: "test-doc-123",
|
||||
UserEmail: "test@example.com",
|
||||
IsSigned: true,
|
||||
SignedAt: ×tamp,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not signed status",
|
||||
status: SignatureStatus{
|
||||
DocID: "test-doc-123",
|
||||
UserEmail: "test@example.com",
|
||||
IsSigned: false,
|
||||
SignedAt: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test JSON serialization
|
||||
data, err := json.Marshal(tt.status)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal status: %v", err)
|
||||
}
|
||||
|
||||
var unmarshaled SignatureStatus
|
||||
err = json.Unmarshal(data, &unmarshaled)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal status: %v", err)
|
||||
}
|
||||
|
||||
// Verify fields
|
||||
if unmarshaled.DocID != tt.status.DocID {
|
||||
t.Errorf("DocID mismatch: got %v, expected %v", unmarshaled.DocID, tt.status.DocID)
|
||||
}
|
||||
if unmarshaled.UserEmail != tt.status.UserEmail {
|
||||
t.Errorf("UserEmail mismatch: got %v, expected %v", unmarshaled.UserEmail, tt.status.UserEmail)
|
||||
}
|
||||
if unmarshaled.IsSigned != tt.status.IsSigned {
|
||||
t.Errorf("IsSigned mismatch: got %v, expected %v", unmarshaled.IsSigned, tt.status.IsSigned)
|
||||
}
|
||||
if (unmarshaled.SignedAt == nil) != (tt.status.SignedAt == nil) {
|
||||
t.Errorf("SignedAt nil mismatch: got %v, expected %v", unmarshaled.SignedAt == nil, tt.status.SignedAt == nil)
|
||||
}
|
||||
if unmarshaled.SignedAt != nil && tt.status.SignedAt != nil && !unmarshaled.SignedAt.Equal(*tt.status.SignedAt) {
|
||||
t.Errorf("SignedAt mismatch: got %v, expected %v", *unmarshaled.SignedAt, *tt.status.SignedAt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignature_TimestampValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
signedAt time.Time
|
||||
createdAt time.Time
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "valid timestamps - signedAt before createdAt",
|
||||
signedAt: time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC),
|
||||
createdAt: time.Date(2024, 1, 15, 10, 30, 46, 0, time.UTC),
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "valid timestamps - same time",
|
||||
signedAt: time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC),
|
||||
createdAt: time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC),
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "questionable timestamps - createdAt before signedAt",
|
||||
signedAt: time.Date(2024, 1, 15, 10, 30, 46, 0, time.UTC),
|
||||
createdAt: time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC),
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
signature := &Signature{
|
||||
SignedAtUTC: tt.signedAt,
|
||||
CreatedAt: tt.createdAt,
|
||||
}
|
||||
|
||||
// Business rule: CreatedAt should be after or equal to SignedAtUTC
|
||||
isValid := !signature.CreatedAt.Before(signature.SignedAtUTC)
|
||||
if isValid != tt.expectValid {
|
||||
t.Errorf("Timestamp validation mismatch: got %v, expected %v", isValid, tt.expectValid)
|
||||
}
|
||||
|
||||
// Verify timestamps are UTC
|
||||
if signature.SignedAtUTC.Location() != time.UTC {
|
||||
t.Error("SignedAtUTC should be in UTC timezone")
|
||||
}
|
||||
if signature.CreatedAt.Location() != time.UTC {
|
||||
t.Error("CreatedAt should be in UTC timezone")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func isValidBase64(s string) bool {
|
||||
// Simple base64 validation - should contain only valid base64 characters
|
||||
validChars := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="
|
||||
for _, char := range s {
|
||||
found := false
|
||||
for _, valid := range validChars {
|
||||
if char == valid {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return len(s) > 0
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package models
|
||||
|
||||
import "strings"
|
||||
|
||||
// User represents an authenticated user
|
||||
type User struct {
|
||||
Sub string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// IsValid checks if the user has valid credentials
|
||||
func (u *User) IsValid() bool {
|
||||
return strings.TrimSpace(u.Sub) != "" && strings.TrimSpace(u.Email) != ""
|
||||
}
|
||||
|
||||
// NormalizedEmail returns the email in lowercase
|
||||
func (u *User) NormalizedEmail() string {
|
||||
return strings.ToLower(u.Email)
|
||||
}
|
||||
@@ -0,0 +1,396 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUser_IsValid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user User
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "valid user with all fields",
|
||||
user: User{
|
||||
Sub: "google-oauth2|123456789",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "valid user without name",
|
||||
user: User{
|
||||
Sub: "github|987654321",
|
||||
Email: "user@github.com",
|
||||
Name: "",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "invalid user - missing sub",
|
||||
user: User{
|
||||
Sub: "",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user - missing email",
|
||||
user: User{
|
||||
Sub: "google-oauth2|123456789",
|
||||
Email: "",
|
||||
Name: "Test User",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user - missing both sub and email",
|
||||
user: User{
|
||||
Sub: "",
|
||||
Email: "",
|
||||
Name: "Test User",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user - whitespace only sub",
|
||||
user: User{
|
||||
Sub: " ",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user - whitespace only email",
|
||||
user: User{
|
||||
Sub: "google-oauth2|123456789",
|
||||
Email: " ",
|
||||
Name: "Test User",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.user.IsValid()
|
||||
if result != tt.expected {
|
||||
t.Errorf("User.IsValid() = %v, expected %v for user %+v", result, tt.expected, tt.user)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUser_NormalizedEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "lowercase email",
|
||||
email: "test@example.com",
|
||||
expected: "test@example.com",
|
||||
},
|
||||
{
|
||||
name: "uppercase email",
|
||||
email: "TEST@EXAMPLE.COM",
|
||||
expected: "test@example.com",
|
||||
},
|
||||
{
|
||||
name: "mixed case email",
|
||||
email: "TeSt@ExAmPlE.CoM",
|
||||
expected: "test@example.com",
|
||||
},
|
||||
{
|
||||
name: "email with mixed domain",
|
||||
email: "user@GitHub.COM",
|
||||
expected: "user@github.com",
|
||||
},
|
||||
{
|
||||
name: "empty email",
|
||||
email: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "email with special characters",
|
||||
email: "User+Tag@DOMAIN.ORG",
|
||||
expected: "user+tag@domain.org",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
user := User{Email: tt.email}
|
||||
result := user.NormalizedEmail()
|
||||
if result != tt.expected {
|
||||
t.Errorf("User.NormalizedEmail() = %v, expected %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUser_JSONSerialization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user User
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "complete user",
|
||||
user: User{
|
||||
Sub: "google-oauth2|123456789",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
},
|
||||
expected: `{"sub":"google-oauth2|123456789","email":"test@example.com","name":"Test User"}`,
|
||||
},
|
||||
{
|
||||
name: "user without name",
|
||||
user: User{
|
||||
Sub: "github|987654321",
|
||||
Email: "user@github.com",
|
||||
Name: "",
|
||||
},
|
||||
expected: `{"sub":"github|987654321","email":"user@github.com","name":""}`,
|
||||
},
|
||||
{
|
||||
name: "user with special characters",
|
||||
user: User{
|
||||
Sub: "gitlab|special-chars-123",
|
||||
Email: "user+tag@domain.org",
|
||||
Name: "Nom avec accénts",
|
||||
},
|
||||
expected: `{"sub":"gitlab|special-chars-123","email":"user+tag@domain.org","name":"Nom avec accénts"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test serialization
|
||||
data, err := json.Marshal(tt.user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal user: %v", err)
|
||||
}
|
||||
|
||||
if string(data) != tt.expected {
|
||||
t.Errorf("JSON serialization mismatch:\ngot: %s\nexpected: %s", string(data), tt.expected)
|
||||
}
|
||||
|
||||
// Test deserialization
|
||||
var user User
|
||||
err = json.Unmarshal(data, &user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal user: %v", err)
|
||||
}
|
||||
|
||||
if user.Sub != tt.user.Sub || user.Email != tt.user.Email || user.Name != tt.user.Name {
|
||||
t.Errorf("Deserialized user mismatch:\ngot: %+v\nexpected: %+v", user, tt.user)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUser_JSONDeserialization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
jsonData string
|
||||
expected User
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid JSON",
|
||||
jsonData: `{"sub":"google-oauth2|123456789","email":"test@example.com","name":"Test User"}`,
|
||||
expected: User{
|
||||
Sub: "google-oauth2|123456789",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "JSON with missing name",
|
||||
jsonData: `{"sub":"github|987654321","email":"user@github.com"}`,
|
||||
expected: User{
|
||||
Sub: "github|987654321",
|
||||
Email: "user@github.com",
|
||||
Name: "",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "JSON with null values",
|
||||
jsonData: `{"sub":"gitlab|123","email":"test@example.com","name":null}`,
|
||||
expected: User{
|
||||
Sub: "gitlab|123",
|
||||
Email: "test@example.com",
|
||||
Name: "",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
jsonData: `{"sub":"invalid"email":"missing-comma"}`,
|
||||
expected: User{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty JSON object",
|
||||
jsonData: `{}`,
|
||||
expected: User{
|
||||
Sub: "",
|
||||
Email: "",
|
||||
Name: "",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var user User
|
||||
err := json.Unmarshal([]byte(tt.jsonData), &user)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if user.Sub != tt.expected.Sub || user.Email != tt.expected.Email || user.Name != tt.expected.Name {
|
||||
t.Errorf("Deserialized user mismatch:\ngot: %+v\nexpected: %+v", user, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUser_EmailValidationRules(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "valid standard email",
|
||||
email: "test@example.com",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "valid email with subdomain",
|
||||
email: "user@mail.example.com",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "valid email with plus sign",
|
||||
email: "user+tag@example.com",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "valid email with dots",
|
||||
email: "first.last@example.com",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "empty email is invalid",
|
||||
email: "",
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "whitespace email is invalid",
|
||||
email: " ",
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
user := User{
|
||||
Sub: "test-sub",
|
||||
Email: tt.email,
|
||||
}
|
||||
|
||||
isValid := user.IsValid()
|
||||
if isValid != tt.expectValid {
|
||||
t.Errorf("User with email '%s' validation = %v, expected %v", tt.email, isValid, tt.expectValid)
|
||||
}
|
||||
|
||||
// Test normalized email
|
||||
normalized := user.NormalizedEmail()
|
||||
if tt.email != "" {
|
||||
expectedNormalized := strings.ToLower(tt.email)
|
||||
if normalized != expectedNormalized {
|
||||
t.Errorf("NormalizedEmail() = %v, expected %v", normalized, expectedNormalized)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUser_SubValidationRules(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sub string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "valid Google OAuth2 sub",
|
||||
sub: "google-oauth2|123456789012345678901",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "valid GitHub sub",
|
||||
sub: "github|12345678",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "valid GitLab sub",
|
||||
sub: "gitlab|987654321",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "valid custom provider sub",
|
||||
sub: "custom-provider|user-123",
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "empty sub is invalid",
|
||||
sub: "",
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "whitespace sub is invalid",
|
||||
sub: " ",
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "numeric sub is valid",
|
||||
sub: "123456789",
|
||||
expectValid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
user := User{
|
||||
Sub: tt.sub,
|
||||
Email: "test@example.com",
|
||||
}
|
||||
|
||||
isValid := user.IsValid()
|
||||
if isValid != tt.expectValid {
|
||||
t.Errorf("User with sub '%s' validation = %v, expected %v", tt.sub, isValid, tt.expectValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,232 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/gorilla/sessions"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
"ackify/pkg/logger"
|
||||
)
|
||||
|
||||
const sessionName = "ackapp_session"
|
||||
|
||||
type OauthService struct {
|
||||
oauthConfig *oauth2.Config
|
||||
sessionStore *sessions.CookieStore
|
||||
userInfoURL string
|
||||
allowedDomain string
|
||||
secureCookies bool
|
||||
}
|
||||
|
||||
// Config holds OAuth service configuration
|
||||
type Config struct {
|
||||
BaseURL string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AuthURL string
|
||||
TokenURL string
|
||||
UserInfoURL string
|
||||
Scopes []string
|
||||
AllowedDomain string
|
||||
CookieSecret []byte
|
||||
SecureCookies bool
|
||||
}
|
||||
|
||||
// NewOAuthService creates a new OAuth service
|
||||
func NewOAuthService(config Config) *OauthService {
|
||||
oauthConfig := &oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: config.ClientSecret,
|
||||
RedirectURL: config.BaseURL + "/oauth2/callback",
|
||||
Scopes: config.Scopes,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: config.AuthURL,
|
||||
TokenURL: config.TokenURL,
|
||||
},
|
||||
}
|
||||
|
||||
sessionStore := sessions.NewCookieStore(config.CookieSecret)
|
||||
|
||||
return &OauthService{
|
||||
oauthConfig: oauthConfig,
|
||||
sessionStore: sessionStore,
|
||||
userInfoURL: config.UserInfoURL,
|
||||
allowedDomain: config.AllowedDomain,
|
||||
secureCookies: config.SecureCookies,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OauthService) GetUser(r *http.Request) (*models.User, error) {
|
||||
session, err := s.sessionStore.Get(r, sessionName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get session: %w", err)
|
||||
}
|
||||
|
||||
userJSON, ok := session.Values["user"].(string)
|
||||
if !ok || userJSON == "" {
|
||||
return nil, models.ErrUnauthorized
|
||||
}
|
||||
|
||||
var user models.User
|
||||
if err := json.Unmarshal([]byte(userJSON), &user); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *OauthService) SetUser(w http.ResponseWriter, r *http.Request, user *models.User) error {
|
||||
session, _ := s.sessionStore.Get(r, sessionName)
|
||||
|
||||
userJSON, err := json.Marshal(user)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal user: %w", err)
|
||||
}
|
||||
|
||||
session.Values["user"] = string(userJSON)
|
||||
session.Options = &sessions.Options{
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: s.secureCookies,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
|
||||
if err := session.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OauthService) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
session, _ := s.sessionStore.Get(r, sessionName)
|
||||
session.Options.MaxAge = -1
|
||||
_ = session.Save(r, w)
|
||||
}
|
||||
|
||||
func (s *OauthService) GetAuthURL(nextURL string) string {
|
||||
state := base64.RawURLEncoding.EncodeToString(securecookie.GenerateRandomKey(20)) +
|
||||
":" + base64.RawURLEncoding.EncodeToString([]byte(nextURL))
|
||||
|
||||
return s.oauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("prompt", "select_account"))
|
||||
}
|
||||
|
||||
func (s *OauthService) HandleCallback(ctx context.Context, code, state string) (*models.User, string, error) {
|
||||
// Parse state to get next URL
|
||||
parts := strings.SplitN(state, ":", 2)
|
||||
nextURL := "/"
|
||||
if len(parts) == 2 {
|
||||
if nb, err := base64.RawURLEncoding.DecodeString(parts[1]); err == nil {
|
||||
nextURL = string(nb)
|
||||
}
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
token, err := s.oauthConfig.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, nextURL, fmt.Errorf("oauth exchange failed: %w", err)
|
||||
}
|
||||
|
||||
// Get user info
|
||||
client := s.oauthConfig.Client(ctx, token)
|
||||
resp, err := client.Get(s.userInfoURL)
|
||||
if err != nil || resp.StatusCode != 200 {
|
||||
return nil, nextURL, fmt.Errorf("userinfo request failed: %w", err)
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
_ = Body.Close()
|
||||
}(resp.Body)
|
||||
|
||||
user, err := s.parseUserInfo(resp)
|
||||
if err != nil {
|
||||
return nil, nextURL, fmt.Errorf("failed to parse user info: %w", err)
|
||||
}
|
||||
|
||||
// Check domain restriction
|
||||
if !s.IsAllowedDomain(user.Email) {
|
||||
return nil, nextURL, models.ErrDomainNotAllowed
|
||||
}
|
||||
|
||||
return user, nextURL, nil
|
||||
}
|
||||
|
||||
func (s *OauthService) IsAllowedDomain(email string) bool {
|
||||
if s.allowedDomain == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
return strings.HasSuffix(
|
||||
strings.ToLower(email),
|
||||
"@"+strings.ToLower(s.allowedDomain),
|
||||
)
|
||||
}
|
||||
|
||||
// parseUserInfo extracts user information from different OAuth2 providers
|
||||
func (s *OauthService) parseUserInfo(resp *http.Response) (*models.User, error) {
|
||||
var rawUser map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&rawUser); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode user info: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Info("Raw OAuth user info received", "data", rawUser)
|
||||
|
||||
user := &models.User{}
|
||||
|
||||
// Extract user ID (sub field or id field depending on provider)
|
||||
if sub, ok := rawUser["sub"].(string); ok {
|
||||
user.Sub = sub
|
||||
} else if id, ok := rawUser["id"]; ok {
|
||||
user.Sub = fmt.Sprintf("%v", id) // Convert to string regardless of type
|
||||
} else {
|
||||
return nil, fmt.Errorf("missing user ID in response")
|
||||
}
|
||||
|
||||
// Extract email
|
||||
if email, ok := rawUser["email"].(string); ok {
|
||||
user.Email = email
|
||||
} else {
|
||||
// Some providers might have email in a different field or structure
|
||||
return nil, fmt.Errorf("missing email in user info response")
|
||||
}
|
||||
|
||||
// Extract user name with fallback strategy
|
||||
var name string
|
||||
if preferredName, ok := rawUser["preferred_username"].(string); ok && preferredName != "" {
|
||||
name = preferredName
|
||||
} else if firstName, ok := rawUser["given_name"].(string); ok {
|
||||
if lastName, ok := rawUser["family_name"].(string); ok {
|
||||
name = firstName + " " + lastName
|
||||
} else {
|
||||
name = firstName
|
||||
}
|
||||
} else if fullName, ok := rawUser["name"].(string); ok && fullName != "" {
|
||||
name = fullName
|
||||
} else if cn, ok := rawUser["cn"].(string); ok && cn != "" {
|
||||
name = cn
|
||||
} else if displayName, ok := rawUser["display_name"].(string); ok && displayName != "" {
|
||||
name = displayName
|
||||
}
|
||||
|
||||
user.Name = name
|
||||
|
||||
logger.Logger.Info("Extracted user data",
|
||||
"sub", user.Sub,
|
||||
"email", user.Email,
|
||||
"name", user.Name)
|
||||
|
||||
// Validate extracted data
|
||||
if !user.IsValid() {
|
||||
return nil, fmt.Errorf("invalid user data extracted: sub=%s, email=%s", user.Sub, user.Email)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
@@ -0,0 +1,895 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
)
|
||||
|
||||
func TestNewOAuthService(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config Config
|
||||
}{
|
||||
{
|
||||
name: "complete config",
|
||||
config: Config{
|
||||
BaseURL: "https://ackify.example.com",
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
AuthURL: "https://provider.com/auth",
|
||||
TokenURL: "https://provider.com/token",
|
||||
UserInfoURL: "https://provider.com/userinfo",
|
||||
Scopes: []string{"openid", "email", "profile"},
|
||||
AllowedDomain: "@example.com",
|
||||
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
|
||||
SecureCookies: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "minimal config",
|
||||
config: Config{
|
||||
BaseURL: "http://localhost:8080",
|
||||
ClientID: "minimal-client",
|
||||
ClientSecret: "minimal-secret",
|
||||
AuthURL: "https://auth.com/oauth",
|
||||
TokenURL: "https://auth.com/token",
|
||||
UserInfoURL: "https://api.com/user",
|
||||
Scopes: []string{"user"},
|
||||
CookieSecret: []byte("test-secret"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
service := NewOAuthService(tt.config)
|
||||
|
||||
if service == nil {
|
||||
t.Fatal("NewOAuthService() returned nil")
|
||||
}
|
||||
|
||||
// Test that OAuth config is properly initialized
|
||||
if service.oauthConfig == nil {
|
||||
t.Error("OAuth config should not be nil")
|
||||
}
|
||||
if service.oauthConfig.ClientID != tt.config.ClientID {
|
||||
t.Errorf("ClientID = %v, expected %v", service.oauthConfig.ClientID, tt.config.ClientID)
|
||||
}
|
||||
if service.oauthConfig.ClientSecret != tt.config.ClientSecret {
|
||||
t.Errorf("ClientSecret = %v, expected %v", service.oauthConfig.ClientSecret, tt.config.ClientSecret)
|
||||
}
|
||||
|
||||
expectedRedirectURL := tt.config.BaseURL + "/oauth2/callback"
|
||||
if service.oauthConfig.RedirectURL != expectedRedirectURL {
|
||||
t.Errorf("RedirectURL = %v, expected %v", service.oauthConfig.RedirectURL, expectedRedirectURL)
|
||||
}
|
||||
|
||||
if len(service.oauthConfig.Scopes) != len(tt.config.Scopes) {
|
||||
t.Errorf("Scopes length = %v, expected %v", len(service.oauthConfig.Scopes), len(tt.config.Scopes))
|
||||
}
|
||||
|
||||
if service.oauthConfig.Endpoint.AuthURL != tt.config.AuthURL {
|
||||
t.Errorf("AuthURL = %v, expected %v", service.oauthConfig.Endpoint.AuthURL, tt.config.AuthURL)
|
||||
}
|
||||
if service.oauthConfig.Endpoint.TokenURL != tt.config.TokenURL {
|
||||
t.Errorf("TokenURL = %v, expected %v", service.oauthConfig.Endpoint.TokenURL, tt.config.TokenURL)
|
||||
}
|
||||
|
||||
// Test service fields
|
||||
if service.userInfoURL != tt.config.UserInfoURL {
|
||||
t.Errorf("userInfoURL = %v, expected %v", service.userInfoURL, tt.config.UserInfoURL)
|
||||
}
|
||||
if service.allowedDomain != tt.config.AllowedDomain {
|
||||
t.Errorf("allowedDomain = %v, expected %v", service.allowedDomain, tt.config.AllowedDomain)
|
||||
}
|
||||
if service.secureCookies != tt.config.SecureCookies {
|
||||
t.Errorf("secureCookies = %v, expected %v", service.secureCookies, tt.config.SecureCookies)
|
||||
}
|
||||
|
||||
// Test session store
|
||||
if service.sessionStore == nil {
|
||||
t.Error("Session store should not be nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthService_GetUser(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSession func(*httptest.ResponseRecorder, *http.Request)
|
||||
expectError bool
|
||||
expectedError error
|
||||
expectedUser *models.User
|
||||
}{
|
||||
{
|
||||
name: "valid user session",
|
||||
setupSession: func(w *httptest.ResponseRecorder, r *http.Request) {
|
||||
user := &models.User{
|
||||
Sub: "test-sub",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
err := service.SetUser(w, r, user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set user: %v", err)
|
||||
}
|
||||
},
|
||||
expectError: false,
|
||||
expectedUser: &models.User{
|
||||
Sub: "test-sub",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no session",
|
||||
setupSession: func(w *httptest.ResponseRecorder, r *http.Request) {},
|
||||
expectError: true,
|
||||
expectedError: models.ErrUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON in session",
|
||||
setupSession: func(w *httptest.ResponseRecorder, r *http.Request) {
|
||||
session, _ := service.sessionStore.Get(r, sessionName)
|
||||
session.Values["user"] = "invalid-json"
|
||||
session.Save(r, w)
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty user value in session",
|
||||
setupSession: func(w *httptest.ResponseRecorder, r *http.Request) {
|
||||
session, _ := service.sessionStore.Get(r, sessionName)
|
||||
session.Values["user"] = ""
|
||||
session.Save(r, w)
|
||||
},
|
||||
expectError: true,
|
||||
expectedError: models.ErrUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
// Setup session if needed
|
||||
tt.setupSession(w, r)
|
||||
|
||||
// Add cookies from the setup response to the request
|
||||
if len(w.Result().Cookies()) > 0 {
|
||||
for _, cookie := range w.Result().Cookies() {
|
||||
r.AddCookie(cookie)
|
||||
}
|
||||
}
|
||||
|
||||
user, err := service.GetUser(r)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
return
|
||||
}
|
||||
if tt.expectedError != nil && err != tt.expectedError {
|
||||
t.Errorf("Error = %v, expected %v", err, tt.expectedError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
t.Error("User should not be nil")
|
||||
return
|
||||
}
|
||||
|
||||
if user.Sub != tt.expectedUser.Sub {
|
||||
t.Errorf("User.Sub = %v, expected %v", user.Sub, tt.expectedUser.Sub)
|
||||
}
|
||||
if user.Email != tt.expectedUser.Email {
|
||||
t.Errorf("User.Email = %v, expected %v", user.Email, tt.expectedUser.Email)
|
||||
}
|
||||
if user.Name != tt.expectedUser.Name {
|
||||
t.Errorf("User.Name = %v, expected %v", user.Name, tt.expectedUser.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthService_SetUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
service *OauthService
|
||||
user *models.User
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid user with secure cookies",
|
||||
service: createTestServiceWithSecure(true),
|
||||
user: &models.User{
|
||||
Sub: "test-sub",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid user without secure cookies",
|
||||
service: createTestServiceWithSecure(false),
|
||||
user: &models.User{
|
||||
Sub: "github|123",
|
||||
Email: "user@github.com",
|
||||
Name: "GitHub User",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "user with special characters",
|
||||
service: createTestService(),
|
||||
user: &models.User{
|
||||
Sub: "google-oauth2|123456789",
|
||||
Email: "user+test@example.com",
|
||||
Name: "Üser Námé",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
err := tt.service.SetUser(w, r, tt.user)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify that user can be retrieved
|
||||
for _, cookie := range w.Result().Cookies() {
|
||||
r.AddCookie(cookie)
|
||||
}
|
||||
|
||||
retrievedUser, err := tt.service.GetUser(r)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to retrieve user: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if retrievedUser.Sub != tt.user.Sub {
|
||||
t.Errorf("Retrieved user Sub = %v, expected %v", retrievedUser.Sub, tt.user.Sub)
|
||||
}
|
||||
if retrievedUser.Email != tt.user.Email {
|
||||
t.Errorf("Retrieved user Email = %v, expected %v", retrievedUser.Email, tt.user.Email)
|
||||
}
|
||||
if retrievedUser.Name != tt.user.Name {
|
||||
t.Errorf("Retrieved user Name = %v, expected %v", retrievedUser.Name, tt.user.Name)
|
||||
}
|
||||
|
||||
// Verify cookie properties
|
||||
cookies := w.Result().Cookies()
|
||||
if len(cookies) == 0 {
|
||||
t.Error("No cookies set")
|
||||
return
|
||||
}
|
||||
|
||||
sessionCookie := cookies[0]
|
||||
if sessionCookie.HttpOnly != true {
|
||||
t.Error("Cookie should be HttpOnly")
|
||||
}
|
||||
if sessionCookie.Secure != tt.service.secureCookies {
|
||||
t.Errorf("Cookie Secure = %v, expected %v", sessionCookie.Secure, tt.service.secureCookies)
|
||||
}
|
||||
if sessionCookie.SameSite != http.SameSiteLaxMode {
|
||||
t.Errorf("Cookie SameSite = %v, expected %v", sessionCookie.SameSite, http.SameSiteLaxMode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthService_Logout(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
// First, set a user
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
user := &models.User{
|
||||
Sub: "test-sub",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
err := service.SetUser(w, r, user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set user: %v", err)
|
||||
}
|
||||
|
||||
// Add cookies to request
|
||||
for _, cookie := range w.Result().Cookies() {
|
||||
r.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Verify user exists
|
||||
retrievedUser, err := service.GetUser(r)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user before logout: %v", err)
|
||||
}
|
||||
if retrievedUser == nil {
|
||||
t.Fatal("User should exist before logout")
|
||||
}
|
||||
|
||||
// Logout
|
||||
w2 := httptest.NewRecorder()
|
||||
service.Logout(w2, r)
|
||||
|
||||
// Verify logout cookie has MaxAge = -1
|
||||
cookies := w2.Result().Cookies()
|
||||
if len(cookies) == 0 {
|
||||
t.Error("No logout cookies set")
|
||||
return
|
||||
}
|
||||
|
||||
logoutCookie := cookies[0]
|
||||
if logoutCookie.MaxAge != -1 {
|
||||
t.Errorf("Logout cookie MaxAge = %v, expected -1", logoutCookie.MaxAge)
|
||||
}
|
||||
|
||||
// Test that logout doesn't fail even with no session
|
||||
w3 := httptest.NewRecorder()
|
||||
r3 := httptest.NewRequest("GET", "/", nil)
|
||||
service.Logout(w3, r3) // Should not panic
|
||||
}
|
||||
|
||||
func TestOauthService_GetAuthURL(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
nextURL string
|
||||
}{
|
||||
{
|
||||
name: "root next URL",
|
||||
nextURL: "/",
|
||||
},
|
||||
{
|
||||
name: "specific page next URL",
|
||||
nextURL: "/sign?doc=test-doc",
|
||||
},
|
||||
{
|
||||
name: "empty next URL",
|
||||
nextURL: "",
|
||||
},
|
||||
{
|
||||
name: "complex next URL with parameters",
|
||||
nextURL: "/sign?doc=test-doc&referrer=github",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
authURL := service.GetAuthURL(tt.nextURL)
|
||||
|
||||
if authURL == "" {
|
||||
t.Error("Auth URL should not be empty")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the URL to verify it's valid
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Errorf("Invalid auth URL: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify it contains the expected OAuth parameters
|
||||
query := parsedURL.Query()
|
||||
if query.Get("client_id") != "test-client-id" {
|
||||
t.Errorf("client_id = %v, expected test-client-id", query.Get("client_id"))
|
||||
}
|
||||
if query.Get("response_type") != "code" {
|
||||
t.Errorf("response_type = %v, expected code", query.Get("response_type"))
|
||||
}
|
||||
if query.Get("redirect_uri") == "" {
|
||||
t.Error("redirect_uri should not be empty")
|
||||
}
|
||||
if query.Get("scope") == "" {
|
||||
t.Error("scope should not be empty")
|
||||
}
|
||||
if query.Get("state") == "" {
|
||||
t.Error("state should not be empty")
|
||||
}
|
||||
if query.Get("prompt") != "select_account" {
|
||||
t.Errorf("prompt = %v, expected select_account", query.Get("prompt"))
|
||||
}
|
||||
|
||||
// Verify state contains the next URL (basic check)
|
||||
state := query.Get("state")
|
||||
if !strings.Contains(state, ":") {
|
||||
t.Error("State should contain ':' separator")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthService_IsAllowedDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedDomain string
|
||||
email string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "no domain restriction",
|
||||
allowedDomain: "",
|
||||
email: "user@anywhere.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "matching domain",
|
||||
allowedDomain: "example.com",
|
||||
email: "user@example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "non-matching domain",
|
||||
allowedDomain: "example.com",
|
||||
email: "user@other.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "case insensitive matching",
|
||||
allowedDomain: "EXAMPLE.COM",
|
||||
email: "user@example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "case insensitive email",
|
||||
allowedDomain: "example.com",
|
||||
email: "USER@EXAMPLE.COM",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "subdomain not allowed",
|
||||
allowedDomain: "example.com",
|
||||
email: "user@sub.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "partial domain match not allowed",
|
||||
allowedDomain: "example.com",
|
||||
email: "user@notexample.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "domain without @ prefix",
|
||||
allowedDomain: "example.com",
|
||||
email: "user@example.com",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
service := &OauthService{
|
||||
allowedDomain: tt.allowedDomain,
|
||||
}
|
||||
|
||||
result := service.IsAllowedDomain(tt.email)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsAllowedDomain() = %v, expected %v for email %s with domain %s",
|
||||
result, tt.expected, tt.email, tt.allowedDomain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthService_parseUserInfo(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
responseBody map[string]interface{}
|
||||
expectError bool
|
||||
expectedUser *models.User
|
||||
}{
|
||||
{
|
||||
name: "Google OAuth response",
|
||||
responseBody: map[string]interface{}{
|
||||
"sub": "google-oauth2|123456789",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
expectError: false,
|
||||
expectedUser: &models.User{
|
||||
Sub: "google-oauth2|123456789",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "GitHub OAuth response",
|
||||
responseBody: map[string]interface{}{
|
||||
"id": float64(12345), // JSON numbers become float64
|
||||
"email": "user@github.com",
|
||||
"name": "GitHub User",
|
||||
},
|
||||
expectError: false,
|
||||
expectedUser: &models.User{
|
||||
Sub: "12345",
|
||||
Email: "user@github.com",
|
||||
Name: "GitHub User",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "GitLab OAuth response",
|
||||
responseBody: map[string]interface{}{
|
||||
"id": float64(987),
|
||||
"email": "user@gitlab.com",
|
||||
"preferred_username": "gitlabuser",
|
||||
},
|
||||
expectError: false,
|
||||
expectedUser: &models.User{
|
||||
Sub: "987",
|
||||
Email: "user@gitlab.com",
|
||||
Name: "gitlabuser",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OAuth with first/last names",
|
||||
responseBody: map[string]interface{}{
|
||||
"sub": "oauth2|12345",
|
||||
"email": "user@example.com",
|
||||
"given_name": "John",
|
||||
"family_name": "Doe",
|
||||
},
|
||||
expectError: false,
|
||||
expectedUser: &models.User{
|
||||
Sub: "oauth2|12345",
|
||||
Email: "user@example.com",
|
||||
Name: "John Doe",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OAuth with only first name",
|
||||
responseBody: map[string]interface{}{
|
||||
"sub": "oauth2|12345",
|
||||
"email": "user@example.com",
|
||||
"given_name": "John",
|
||||
},
|
||||
expectError: false,
|
||||
expectedUser: &models.User{
|
||||
Sub: "oauth2|12345",
|
||||
Email: "user@example.com",
|
||||
Name: "John",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OAuth with CN field",
|
||||
responseBody: map[string]interface{}{
|
||||
"sub": "ldap|12345",
|
||||
"email": "user@company.com",
|
||||
"cn": "Common Name",
|
||||
},
|
||||
expectError: false,
|
||||
expectedUser: &models.User{
|
||||
Sub: "ldap|12345",
|
||||
Email: "user@company.com",
|
||||
Name: "Common Name",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OAuth with display_name",
|
||||
responseBody: map[string]interface{}{
|
||||
"sub": "custom|12345",
|
||||
"email": "user@custom.com",
|
||||
"display_name": "Display Name",
|
||||
},
|
||||
expectError: false,
|
||||
expectedUser: &models.User{
|
||||
Sub: "custom|12345",
|
||||
Email: "user@custom.com",
|
||||
Name: "Display Name",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OAuth without name fields",
|
||||
responseBody: map[string]interface{}{
|
||||
"sub": "minimal|12345",
|
||||
"email": "user@minimal.com",
|
||||
},
|
||||
expectError: false,
|
||||
expectedUser: &models.User{
|
||||
Sub: "minimal|12345",
|
||||
Email: "user@minimal.com",
|
||||
Name: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing sub and id",
|
||||
responseBody: map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "missing email",
|
||||
responseBody: map[string]interface{}{
|
||||
"sub": "test|12345",
|
||||
"name": "Test User",
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "string ID",
|
||||
responseBody: map[string]interface{}{
|
||||
"id": "string-id-123",
|
||||
"email": "user@example.com",
|
||||
"name": "String ID User",
|
||||
},
|
||||
expectError: false,
|
||||
expectedUser: &models.User{
|
||||
Sub: "string-id-123",
|
||||
Email: "user@example.com",
|
||||
Name: "String ID User",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create HTTP response
|
||||
jsonBody, _ := json.Marshal(tt.responseBody)
|
||||
resp := &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(bytes.NewReader(jsonBody)),
|
||||
}
|
||||
|
||||
user, err := service.parseUserInfo(resp)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
t.Error("User should not be nil")
|
||||
return
|
||||
}
|
||||
|
||||
if user.Sub != tt.expectedUser.Sub {
|
||||
t.Errorf("User.Sub = %v, expected %v", user.Sub, tt.expectedUser.Sub)
|
||||
}
|
||||
if user.Email != tt.expectedUser.Email {
|
||||
t.Errorf("User.Email = %v, expected %v", user.Email, tt.expectedUser.Email)
|
||||
}
|
||||
if user.Name != tt.expectedUser.Name {
|
||||
t.Errorf("User.Name = %v, expected %v", user.Name, tt.expectedUser.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthService_parseUserInfo_InvalidJSON(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(strings.NewReader("invalid json")),
|
||||
}
|
||||
|
||||
_, err := service.parseUserInfo(resp)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid JSON")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to decode user info") {
|
||||
t.Errorf("Error should mention decoding failure: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthService_HandleCallback_StateDecoding(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
state string
|
||||
expectedURL string
|
||||
}{
|
||||
{
|
||||
name: "valid state with next URL",
|
||||
state: "randomstate:L3NpZ24_ZG9jPXRlc3Q", // base64 for "/sign?doc=test"
|
||||
expectedURL: "/sign?doc=test",
|
||||
},
|
||||
{
|
||||
name: "state without separator",
|
||||
state: "invalidstate",
|
||||
expectedURL: "/",
|
||||
},
|
||||
{
|
||||
name: "state with invalid base64",
|
||||
state: "randomstate:invalid-base64!",
|
||||
expectedURL: "/",
|
||||
},
|
||||
{
|
||||
name: "empty state",
|
||||
state: "",
|
||||
expectedURL: "/",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// We can't easily test the full HandleCallback without mocking OAuth2 exchange
|
||||
// So we test the state parsing logic by calling with invalid code
|
||||
_, nextURL, _ := service.HandleCallback(context.Background(), "invalid-code", tt.state)
|
||||
|
||||
if nextURL != tt.expectedURL {
|
||||
t.Errorf("NextURL = %v, expected %v", nextURL, tt.expectedURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func createTestService() *OauthService {
|
||||
return createTestServiceWithSecure(false)
|
||||
}
|
||||
|
||||
func TestOauthService_HandleCallback_DomainRestriction(t *testing.T) {
|
||||
// Create service with domain restriction
|
||||
config := Config{
|
||||
BaseURL: "https://test.example.com",
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
AuthURL: "https://provider.com/auth",
|
||||
TokenURL: "https://provider.com/token",
|
||||
UserInfoURL: "https://provider.com/userinfo",
|
||||
Scopes: []string{"openid", "email", "profile"},
|
||||
AllowedDomain: "example.com",
|
||||
CookieSecret: []byte("test-secret-32-bytes-long-key!"),
|
||||
SecureCookies: false,
|
||||
}
|
||||
service := NewOAuthService(config)
|
||||
|
||||
// Test with disallowed domain - this will fail during OAuth exchange
|
||||
// but we can test the domain check logic by calling IsAllowedDomain directly
|
||||
if service.IsAllowedDomain("user@other.com") {
|
||||
t.Error("Domain restriction should reject other.com emails")
|
||||
}
|
||||
if !service.IsAllowedDomain("user@example.com") {
|
||||
t.Error("Domain restriction should allow example.com emails")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthService_GetUser_SessionError(t *testing.T) {
|
||||
// Test with invalid cookie secret to trigger session errors
|
||||
config := Config{
|
||||
BaseURL: "https://test.example.com",
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
AuthURL: "https://provider.com/auth",
|
||||
TokenURL: "https://provider.com/token",
|
||||
UserInfoURL: "https://provider.com/userinfo",
|
||||
Scopes: []string{"openid", "email"},
|
||||
CookieSecret: []byte("short"), // Too short, might cause issues
|
||||
}
|
||||
service := NewOAuthService(config)
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
// Add a malformed cookie to trigger session error
|
||||
r.AddCookie(&http.Cookie{
|
||||
Name: sessionName,
|
||||
Value: "malformed-session-data",
|
||||
})
|
||||
|
||||
_, err := service.GetUser(r)
|
||||
if err == nil {
|
||||
t.Error("Expected error with malformed session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Structure(t *testing.T) {
|
||||
config := Config{
|
||||
BaseURL: "https://ackify.example.com",
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
AuthURL: "https://auth.provider.com/oauth/authorize",
|
||||
TokenURL: "https://auth.provider.com/oauth/token",
|
||||
UserInfoURL: "https://api.provider.com/user",
|
||||
Scopes: []string{"openid", "email", "profile"},
|
||||
AllowedDomain: "example.com",
|
||||
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
|
||||
SecureCookies: true,
|
||||
}
|
||||
|
||||
// Test that config fields are accessible and correct
|
||||
if config.BaseURL != "https://ackify.example.com" {
|
||||
t.Errorf("BaseURL = %v, expected https://ackify.example.com", config.BaseURL)
|
||||
}
|
||||
if config.ClientID != "test-client-id" {
|
||||
t.Errorf("ClientID = %v, expected test-client-id", config.ClientID)
|
||||
}
|
||||
if len(config.Scopes) != 3 {
|
||||
t.Errorf("Scopes length = %v, expected 3", len(config.Scopes))
|
||||
}
|
||||
if !config.SecureCookies {
|
||||
t.Error("SecureCookies should be true")
|
||||
}
|
||||
if len(config.CookieSecret) == 0 {
|
||||
t.Error("CookieSecret should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthService_SetUser_MarshalError(t *testing.T) {
|
||||
service := createTestService()
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
// Test with a user that would cause JSON marshal issues
|
||||
// In Go, it's hard to make json.Marshal fail with basic types
|
||||
// but we can test with a valid user to ensure the path works
|
||||
user := &models.User{
|
||||
Sub: "test-sub",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
err := service.SetUser(w, r, user)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify session was set
|
||||
cookies := w.Result().Cookies()
|
||||
if len(cookies) == 0 {
|
||||
t.Error("No cookies set")
|
||||
}
|
||||
}
|
||||
|
||||
func createTestServiceWithSecure(secure bool) *OauthService {
|
||||
config := Config{
|
||||
BaseURL: "https://test.example.com",
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
AuthURL: "https://provider.com/auth",
|
||||
TokenURL: "https://provider.com/token",
|
||||
UserInfoURL: "https://provider.com/userinfo",
|
||||
Scopes: []string{"openid", "email", "profile"},
|
||||
AllowedDomain: "example.com",
|
||||
CookieSecret: []byte("test-secret-32-bytes-long-key!"),
|
||||
SecureCookies: secure,
|
||||
}
|
||||
return NewOAuthService(config)
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
)
|
||||
|
||||
// Config holds all application configuration
|
||||
type Config struct {
|
||||
App AppConfig
|
||||
Database DatabaseConfig
|
||||
OAuth OAuthConfig
|
||||
Server ServerConfig
|
||||
}
|
||||
|
||||
// AppConfig holds general application settings
|
||||
type AppConfig struct {
|
||||
BaseURL string
|
||||
Organisation string
|
||||
SecureCookies bool
|
||||
}
|
||||
|
||||
// DatabaseConfig holds database connection settings
|
||||
type DatabaseConfig struct {
|
||||
DSN string
|
||||
}
|
||||
|
||||
// OAuthConfig holds OAuth authentication settings
|
||||
type OAuthConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AuthURL string
|
||||
TokenURL string
|
||||
UserInfoURL string
|
||||
Scopes []string
|
||||
AllowedDomain string
|
||||
CookieSecret []byte
|
||||
}
|
||||
|
||||
// ServerConfig holds server settings
|
||||
type ServerConfig struct {
|
||||
ListenAddr string
|
||||
}
|
||||
|
||||
// Load loads configuration from environment variables
|
||||
func Load() (*Config, error) {
|
||||
config := &Config{}
|
||||
|
||||
// App config
|
||||
baseURL := mustGetEnv("APP_BASE_URL")
|
||||
config.App.BaseURL = baseURL
|
||||
config.App.Organisation = mustGetEnv("APP_ORGANISATION")
|
||||
config.App.SecureCookies = strings.HasPrefix(strings.ToLower(baseURL), "https://")
|
||||
|
||||
// Database config
|
||||
config.Database.DSN = mustGetEnv("DB_DSN")
|
||||
|
||||
// OAuth config
|
||||
config.OAuth.ClientID = mustGetEnv("OAUTH_CLIENT_ID")
|
||||
config.OAuth.ClientSecret = mustGetEnv("OAUTH_CLIENT_SECRET")
|
||||
config.OAuth.AllowedDomain = os.Getenv("OAUTH_ALLOWED_DOMAIN")
|
||||
|
||||
// Configure OAuth endpoints based on provider or use custom URLs
|
||||
provider := strings.ToLower(getEnv("OAUTH_PROVIDER", ""))
|
||||
switch provider {
|
||||
case "google":
|
||||
config.OAuth.AuthURL = "https://accounts.google.com/o/oauth2/auth"
|
||||
config.OAuth.TokenURL = "https://oauth2.googleapis.com/token"
|
||||
config.OAuth.UserInfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
|
||||
config.OAuth.Scopes = []string{"openid", "email", "profile"}
|
||||
case "github":
|
||||
config.OAuth.AuthURL = "https://github.com/login/oauth/authorize"
|
||||
config.OAuth.TokenURL = "https://github.com/login/oauth/access_token"
|
||||
config.OAuth.UserInfoURL = "https://api.github.com/user"
|
||||
config.OAuth.Scopes = []string{"user:email", "read:user"}
|
||||
case "gitlab":
|
||||
gitlabURL := getEnv("OAUTH_GITLAB_URL", "https://gitlab.com")
|
||||
config.OAuth.AuthURL = fmt.Sprintf("%s/oauth/authorize", gitlabURL)
|
||||
config.OAuth.TokenURL = fmt.Sprintf("%s/oauth/token", gitlabURL)
|
||||
config.OAuth.UserInfoURL = fmt.Sprintf("%s/api/v4/user", gitlabURL)
|
||||
config.OAuth.Scopes = []string{"read_user", "profile"}
|
||||
default:
|
||||
// Custom OAuth provider - all URLs must be explicitly set
|
||||
config.OAuth.AuthURL = mustGetEnv("OAUTH_AUTH_URL")
|
||||
config.OAuth.TokenURL = mustGetEnv("OAUTH_TOKEN_URL")
|
||||
config.OAuth.UserInfoURL = mustGetEnv("OAUTH_USERINFO_URL")
|
||||
scopesStr := getEnv("OAUTH_SCOPES", "openid,email,profile")
|
||||
config.OAuth.Scopes = strings.Split(scopesStr, ",")
|
||||
}
|
||||
|
||||
cookieSecret, err := parseCookieSecret()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse cookie secret: %w", err)
|
||||
}
|
||||
config.OAuth.CookieSecret = cookieSecret
|
||||
|
||||
// Server config
|
||||
config.Server.ListenAddr = getEnv("LISTEN_ADDR", ":8080")
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// mustGetEnv gets an environment variable or panics if not found
|
||||
func mustGetEnv(key string) string {
|
||||
value := strings.TrimSpace(os.Getenv(key))
|
||||
if value == "" {
|
||||
panic(fmt.Sprintf("missing required environment variable: %s", key))
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// getEnv gets an environment variable with a default value
|
||||
func getEnv(key, defaultValue string) string {
|
||||
value := strings.TrimSpace(os.Getenv(key))
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// parseCookieSecret parses the cookie secret from environment
|
||||
func parseCookieSecret() ([]byte, error) {
|
||||
raw := os.Getenv("OAUTH_COOKIE_SECRET")
|
||||
if raw == "" {
|
||||
// Generate random 32 bytes for development
|
||||
secret := securecookie.GenerateRandomKey(32)
|
||||
fmt.Println("[WARN] OAUTH_COOKIE_SECRET not set, generated volatile secret (sessions reset on restart)")
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
// Try base64 decoding first
|
||||
if decoded, err := base64.StdEncoding.DecodeString(raw); err == nil && (len(decoded) == 32 || len(decoded) == 64) {
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// Fallback to raw bytes
|
||||
return []byte(raw), nil
|
||||
}
|
||||
@@ -0,0 +1,814 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConfig_Structures(t *testing.T) {
|
||||
t.Run("Config structure", func(t *testing.T) {
|
||||
config := &Config{
|
||||
App: AppConfig{
|
||||
BaseURL: "https://example.com",
|
||||
Organisation: "Test Org",
|
||||
SecureCookies: true,
|
||||
},
|
||||
Database: DatabaseConfig{
|
||||
DSN: "postgres://user:pass@localhost/db",
|
||||
},
|
||||
OAuth: OAuthConfig{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
AuthURL: "https://provider.com/auth",
|
||||
TokenURL: "https://provider.com/token",
|
||||
UserInfoURL: "https://provider.com/userinfo",
|
||||
Scopes: []string{"openid", "email"},
|
||||
AllowedDomain: "@example.com",
|
||||
CookieSecret: []byte("test-secret"),
|
||||
},
|
||||
Server: ServerConfig{
|
||||
ListenAddr: ":8080",
|
||||
},
|
||||
}
|
||||
|
||||
// Test that all fields are accessible
|
||||
if config.App.BaseURL != "https://example.com" {
|
||||
t.Errorf("App.BaseURL mismatch")
|
||||
}
|
||||
if config.Database.DSN != "postgres://user:pass@localhost/db" {
|
||||
t.Errorf("Database.DSN mismatch")
|
||||
}
|
||||
if config.OAuth.ClientID != "test-client-id" {
|
||||
t.Errorf("OAuth.ClientID mismatch")
|
||||
}
|
||||
if config.Server.ListenAddr != ":8080" {
|
||||
t.Errorf("Server.ListenAddr mismatch")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AppConfig structure", func(t *testing.T) {
|
||||
app := AppConfig{
|
||||
BaseURL: "https://ackify.example.com",
|
||||
Organisation: "My Company",
|
||||
SecureCookies: true,
|
||||
}
|
||||
|
||||
if app.BaseURL == "" {
|
||||
t.Error("BaseURL should not be empty")
|
||||
}
|
||||
if app.Organisation == "" {
|
||||
t.Error("Organisation should not be empty")
|
||||
}
|
||||
if !app.SecureCookies {
|
||||
t.Error("SecureCookies should be true for HTTPS")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OAuthConfig structure", func(t *testing.T) {
|
||||
oauth := OAuthConfig{
|
||||
ClientID: "oauth-client-123",
|
||||
ClientSecret: "oauth-secret-456",
|
||||
AuthURL: "https://oauth.provider.com/auth",
|
||||
TokenURL: "https://oauth.provider.com/token",
|
||||
UserInfoURL: "https://oauth.provider.com/userinfo",
|
||||
Scopes: []string{"openid", "email", "profile"},
|
||||
AllowedDomain: "@company.com",
|
||||
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
|
||||
}
|
||||
|
||||
// Test required fields
|
||||
if oauth.ClientID == "" {
|
||||
t.Error("ClientID should not be empty")
|
||||
}
|
||||
if oauth.ClientSecret == "" {
|
||||
t.Error("ClientSecret should not be empty")
|
||||
}
|
||||
if len(oauth.Scopes) == 0 {
|
||||
t.Error("Scopes should not be empty")
|
||||
}
|
||||
if len(oauth.CookieSecret) == 0 {
|
||||
t.Error("CookieSecret should not be empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
defaultValue string
|
||||
envValue string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "existing environment variable",
|
||||
key: "TEST_ENV_VAR",
|
||||
defaultValue: "default",
|
||||
envValue: "custom_value",
|
||||
expected: "custom_value",
|
||||
},
|
||||
{
|
||||
name: "missing environment variable uses default",
|
||||
key: "MISSING_ENV_VAR",
|
||||
defaultValue: "default_value",
|
||||
envValue: "",
|
||||
expected: "default_value",
|
||||
},
|
||||
{
|
||||
name: "empty environment variable uses default",
|
||||
key: "EMPTY_ENV_VAR",
|
||||
defaultValue: "default_value",
|
||||
envValue: "",
|
||||
expected: "default_value",
|
||||
},
|
||||
{
|
||||
name: "whitespace-only environment variable uses default",
|
||||
key: "WHITESPACE_ENV_VAR",
|
||||
defaultValue: "default_value",
|
||||
envValue: " ",
|
||||
expected: "default_value",
|
||||
},
|
||||
{
|
||||
name: "environment variable with leading/trailing spaces",
|
||||
key: "SPACES_ENV_VAR",
|
||||
defaultValue: "default",
|
||||
envValue: " value_with_spaces ",
|
||||
expected: "value_with_spaces",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Clean up environment variable before test
|
||||
os.Unsetenv(tt.key)
|
||||
|
||||
// Set environment variable if specified
|
||||
if tt.envValue != "" {
|
||||
os.Setenv(tt.key, tt.envValue)
|
||||
defer os.Unsetenv(tt.key)
|
||||
}
|
||||
|
||||
result := getEnv(tt.key, tt.defaultValue)
|
||||
if result != tt.expected {
|
||||
t.Errorf("getEnv() = %v, expected %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMustGetEnv(t *testing.T) {
|
||||
t.Run("existing environment variable", func(t *testing.T) {
|
||||
key := "TEST_MUST_ENV_VAR"
|
||||
expected := "test_value"
|
||||
os.Setenv(key, expected)
|
||||
defer os.Unsetenv(key)
|
||||
|
||||
result := mustGetEnv(key)
|
||||
if result != expected {
|
||||
t.Errorf("mustGetEnv() = %v, expected %v", result, expected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("environment variable with spaces is trimmed", func(t *testing.T) {
|
||||
key := "TEST_MUST_ENV_VAR_SPACES"
|
||||
os.Setenv(key, " trimmed_value ")
|
||||
defer os.Unsetenv(key)
|
||||
|
||||
result := mustGetEnv(key)
|
||||
if result != "trimmed_value" {
|
||||
t.Errorf("mustGetEnv() = %v, expected 'trimmed_value'", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing environment variable panics", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("mustGetEnv() should panic for missing environment variable")
|
||||
}
|
||||
}()
|
||||
|
||||
mustGetEnv("DEFINITELY_MISSING_ENV_VAR")
|
||||
})
|
||||
|
||||
t.Run("empty environment variable panics", func(t *testing.T) {
|
||||
key := "TEST_EMPTY_ENV_VAR"
|
||||
os.Setenv(key, "")
|
||||
defer os.Unsetenv(key)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("mustGetEnv() should panic for empty environment variable")
|
||||
}
|
||||
}()
|
||||
|
||||
mustGetEnv(key)
|
||||
})
|
||||
|
||||
t.Run("whitespace-only environment variable panics", func(t *testing.T) {
|
||||
key := "TEST_WHITESPACE_ENV_VAR"
|
||||
os.Setenv(key, " ")
|
||||
defer os.Unsetenv(key)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("mustGetEnv() should panic for whitespace-only environment variable")
|
||||
}
|
||||
}()
|
||||
|
||||
mustGetEnv(key)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseCookieSecret(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
expectError bool
|
||||
minLength int
|
||||
}{
|
||||
{
|
||||
name: "missing cookie secret generates random",
|
||||
envValue: "",
|
||||
expectError: false,
|
||||
minLength: 32,
|
||||
},
|
||||
{
|
||||
name: "valid base64 32-byte secret",
|
||||
envValue: base64.StdEncoding.EncodeToString(make([]byte, 32)),
|
||||
expectError: false,
|
||||
minLength: 32,
|
||||
},
|
||||
{
|
||||
name: "valid base64 64-byte secret",
|
||||
envValue: base64.StdEncoding.EncodeToString(make([]byte, 64)),
|
||||
expectError: false,
|
||||
minLength: 64,
|
||||
},
|
||||
{
|
||||
name: "raw string secret",
|
||||
envValue: "this-is-a-raw-string-secret-that-should-work",
|
||||
expectError: false,
|
||||
minLength: 1,
|
||||
},
|
||||
{
|
||||
name: "short raw string secret",
|
||||
envValue: "short",
|
||||
expectError: false,
|
||||
minLength: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Clean up environment variable
|
||||
os.Unsetenv("OAUTH_COOKIE_SECRET")
|
||||
|
||||
// Set environment variable if specified
|
||||
if tt.envValue != "" {
|
||||
os.Setenv("OAUTH_COOKIE_SECRET", tt.envValue)
|
||||
defer os.Unsetenv("OAUTH_COOKIE_SECRET")
|
||||
}
|
||||
|
||||
result, err := parseCookieSecret()
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
return
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(result) < tt.minLength {
|
||||
t.Errorf("Cookie secret length %d is less than minimum %d", len(result), tt.minLength)
|
||||
}
|
||||
|
||||
// For empty env value, should generate a random 32-byte secret
|
||||
if tt.envValue == "" && len(result) != 32 {
|
||||
t.Errorf("Generated secret should be 32 bytes, got %d", len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_GoogleProvider(t *testing.T) {
|
||||
// Set up environment variables for Google OAuth
|
||||
envVars := map[string]string{
|
||||
"APP_BASE_URL": "https://ackify.example.com",
|
||||
"APP_ORGANISATION": "Test Organisation",
|
||||
"DB_DSN": "postgres://user:pass@localhost/test",
|
||||
"OAUTH_CLIENT_ID": "google-client-id",
|
||||
"OAUTH_CLIENT_SECRET": "google-client-secret",
|
||||
"OAUTH_PROVIDER": "google",
|
||||
"OAUTH_ALLOWED_DOMAIN": "@example.com",
|
||||
"OAUTH_COOKIE_SECRET": base64.StdEncoding.EncodeToString(make([]byte, 32)),
|
||||
"LISTEN_ADDR": ":8080",
|
||||
}
|
||||
|
||||
// Set environment variables
|
||||
for key, value := range envVars {
|
||||
os.Setenv(key, value)
|
||||
defer os.Unsetenv(key)
|
||||
}
|
||||
|
||||
config, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
// Test App config
|
||||
if config.App.BaseURL != "https://ackify.example.com" {
|
||||
t.Errorf("App.BaseURL = %v, expected https://ackify.example.com", config.App.BaseURL)
|
||||
}
|
||||
if config.App.Organisation != "Test Organisation" {
|
||||
t.Errorf("App.Organisation = %v, expected Test Organisation", config.App.Organisation)
|
||||
}
|
||||
if !config.App.SecureCookies {
|
||||
t.Error("App.SecureCookies should be true for HTTPS base URL")
|
||||
}
|
||||
|
||||
// Test Database config
|
||||
if config.Database.DSN != "postgres://user:pass@localhost/test" {
|
||||
t.Errorf("Database.DSN = %v, expected postgres://user:pass@localhost/test", config.Database.DSN)
|
||||
}
|
||||
|
||||
// Test OAuth config for Google
|
||||
if config.OAuth.ClientID != "google-client-id" {
|
||||
t.Errorf("OAuth.ClientID = %v, expected google-client-id", config.OAuth.ClientID)
|
||||
}
|
||||
if config.OAuth.AuthURL != "https://accounts.google.com/o/oauth2/auth" {
|
||||
t.Errorf("OAuth.AuthURL = %v, expected Google auth URL", config.OAuth.AuthURL)
|
||||
}
|
||||
if config.OAuth.TokenURL != "https://oauth2.googleapis.com/token" {
|
||||
t.Errorf("OAuth.TokenURL = %v, expected Google token URL", config.OAuth.TokenURL)
|
||||
}
|
||||
if config.OAuth.UserInfoURL != "https://openidconnect.googleapis.com/v1/userinfo" {
|
||||
t.Errorf("OAuth.UserInfoURL = %v, expected Google userinfo URL", config.OAuth.UserInfoURL)
|
||||
}
|
||||
expectedScopes := []string{"openid", "email", "profile"}
|
||||
if !equalSlices(config.OAuth.Scopes, expectedScopes) {
|
||||
t.Errorf("OAuth.Scopes = %v, expected %v", config.OAuth.Scopes, expectedScopes)
|
||||
}
|
||||
if config.OAuth.AllowedDomain != "@example.com" {
|
||||
t.Errorf("OAuth.AllowedDomain = %v, expected @example.com", config.OAuth.AllowedDomain)
|
||||
}
|
||||
if len(config.OAuth.CookieSecret) != 32 {
|
||||
t.Errorf("OAuth.CookieSecret length = %d, expected 32", len(config.OAuth.CookieSecret))
|
||||
}
|
||||
|
||||
// Test Server config
|
||||
if config.Server.ListenAddr != ":8080" {
|
||||
t.Errorf("Server.ListenAddr = %v, expected :8080", config.Server.ListenAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_GitHubProvider(t *testing.T) {
|
||||
envVars := map[string]string{
|
||||
"APP_BASE_URL": "http://localhost:8080",
|
||||
"APP_ORGANISATION": "GitHub Test",
|
||||
"DB_DSN": "postgres://user:pass@localhost/github",
|
||||
"OAUTH_CLIENT_ID": "github-client-id",
|
||||
"OAUTH_CLIENT_SECRET": "github-client-secret",
|
||||
"OAUTH_PROVIDER": "github",
|
||||
"OAUTH_COOKIE_SECRET": base64.StdEncoding.EncodeToString(make([]byte, 32)),
|
||||
}
|
||||
|
||||
for key, value := range envVars {
|
||||
os.Setenv(key, value)
|
||||
defer os.Unsetenv(key)
|
||||
}
|
||||
|
||||
config, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
// Test GitHub-specific OAuth config
|
||||
if config.OAuth.AuthURL != "https://github.com/login/oauth/authorize" {
|
||||
t.Errorf("OAuth.AuthURL = %v, expected GitHub auth URL", config.OAuth.AuthURL)
|
||||
}
|
||||
if config.OAuth.TokenURL != "https://github.com/login/oauth/access_token" {
|
||||
t.Errorf("OAuth.TokenURL = %v, expected GitHub token URL", config.OAuth.TokenURL)
|
||||
}
|
||||
if config.OAuth.UserInfoURL != "https://api.github.com/user" {
|
||||
t.Errorf("OAuth.UserInfoURL = %v, expected GitHub API user URL", config.OAuth.UserInfoURL)
|
||||
}
|
||||
expectedScopes := []string{"user:email", "read:user"}
|
||||
if !equalSlices(config.OAuth.Scopes, expectedScopes) {
|
||||
t.Errorf("OAuth.Scopes = %v, expected %v", config.OAuth.Scopes, expectedScopes)
|
||||
}
|
||||
|
||||
// Test that SecureCookies is false for HTTP
|
||||
if config.App.SecureCookies {
|
||||
t.Error("App.SecureCookies should be false for HTTP base URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_GitLabProvider(t *testing.T) {
|
||||
envVars := map[string]string{
|
||||
"APP_BASE_URL": "https://ackify.gitlab.com",
|
||||
"APP_ORGANISATION": "GitLab Test",
|
||||
"DB_DSN": "postgres://user:pass@localhost/gitlab",
|
||||
"OAUTH_CLIENT_ID": "gitlab-client-id",
|
||||
"OAUTH_CLIENT_SECRET": "gitlab-client-secret",
|
||||
"OAUTH_PROVIDER": "gitlab",
|
||||
"OAUTH_GITLAB_URL": "https://gitlab.example.com",
|
||||
"OAUTH_COOKIE_SECRET": base64.StdEncoding.EncodeToString(make([]byte, 32)),
|
||||
}
|
||||
|
||||
for key, value := range envVars {
|
||||
os.Setenv(key, value)
|
||||
defer os.Unsetenv(key)
|
||||
}
|
||||
|
||||
config, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
// Test GitLab-specific OAuth config with custom URL
|
||||
if config.OAuth.AuthURL != "https://gitlab.example.com/oauth/authorize" {
|
||||
t.Errorf("OAuth.AuthURL = %v, expected custom GitLab auth URL", config.OAuth.AuthURL)
|
||||
}
|
||||
if config.OAuth.TokenURL != "https://gitlab.example.com/oauth/token" {
|
||||
t.Errorf("OAuth.TokenURL = %v, expected custom GitLab token URL", config.OAuth.TokenURL)
|
||||
}
|
||||
if config.OAuth.UserInfoURL != "https://gitlab.example.com/api/v4/user" {
|
||||
t.Errorf("OAuth.UserInfoURL = %v, expected custom GitLab API user URL", config.OAuth.UserInfoURL)
|
||||
}
|
||||
expectedScopes := []string{"read_user", "profile"}
|
||||
if !equalSlices(config.OAuth.Scopes, expectedScopes) {
|
||||
t.Errorf("OAuth.Scopes = %v, expected %v", config.OAuth.Scopes, expectedScopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_GitLabDefaultURL(t *testing.T) {
|
||||
envVars := map[string]string{
|
||||
"APP_BASE_URL": "https://ackify.gitlab.com",
|
||||
"APP_ORGANISATION": "GitLab Test",
|
||||
"DB_DSN": "postgres://user:pass@localhost/gitlab",
|
||||
"OAUTH_CLIENT_ID": "gitlab-client-id",
|
||||
"OAUTH_CLIENT_SECRET": "gitlab-client-secret",
|
||||
"OAUTH_PROVIDER": "gitlab",
|
||||
"OAUTH_COOKIE_SECRET": base64.StdEncoding.EncodeToString(make([]byte, 32)),
|
||||
}
|
||||
|
||||
for key, value := range envVars {
|
||||
os.Setenv(key, value)
|
||||
defer os.Unsetenv(key)
|
||||
}
|
||||
|
||||
// Ensure OAUTH_GITLAB_URL is not set to test default
|
||||
os.Unsetenv("OAUTH_GITLAB_URL")
|
||||
|
||||
config, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
// Test GitLab-specific OAuth config with default URL
|
||||
if config.OAuth.AuthURL != "https://gitlab.com/oauth/authorize" {
|
||||
t.Errorf("OAuth.AuthURL = %v, expected default GitLab auth URL", config.OAuth.AuthURL)
|
||||
}
|
||||
if config.OAuth.TokenURL != "https://gitlab.com/oauth/token" {
|
||||
t.Errorf("OAuth.TokenURL = %v, expected default GitLab token URL", config.OAuth.TokenURL)
|
||||
}
|
||||
if config.OAuth.UserInfoURL != "https://gitlab.com/api/v4/user" {
|
||||
t.Errorf("OAuth.UserInfoURL = %v, expected default GitLab API user URL", config.OAuth.UserInfoURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_CustomProvider(t *testing.T) {
|
||||
envVars := map[string]string{
|
||||
"APP_BASE_URL": "https://ackify.custom.com",
|
||||
"APP_ORGANISATION": "Custom Test",
|
||||
"DB_DSN": "postgres://user:pass@localhost/custom",
|
||||
"OAUTH_CLIENT_ID": "custom-client-id",
|
||||
"OAUTH_CLIENT_SECRET": "custom-client-secret",
|
||||
"OAUTH_AUTH_URL": "https://auth.custom.com/oauth/authorize",
|
||||
"OAUTH_TOKEN_URL": "https://auth.custom.com/oauth/token",
|
||||
"OAUTH_USERINFO_URL": "https://api.custom.com/user",
|
||||
"OAUTH_SCOPES": "read,write,admin",
|
||||
"OAUTH_COOKIE_SECRET": base64.StdEncoding.EncodeToString(make([]byte, 32)),
|
||||
}
|
||||
|
||||
for key, value := range envVars {
|
||||
os.Setenv(key, value)
|
||||
defer os.Unsetenv(key)
|
||||
}
|
||||
|
||||
config, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
// Test custom OAuth config
|
||||
if config.OAuth.AuthURL != "https://auth.custom.com/oauth/authorize" {
|
||||
t.Errorf("OAuth.AuthURL = %v, expected custom auth URL", config.OAuth.AuthURL)
|
||||
}
|
||||
if config.OAuth.TokenURL != "https://auth.custom.com/oauth/token" {
|
||||
t.Errorf("OAuth.TokenURL = %v, expected custom token URL", config.OAuth.TokenURL)
|
||||
}
|
||||
if config.OAuth.UserInfoURL != "https://api.custom.com/user" {
|
||||
t.Errorf("OAuth.UserInfoURL = %v, expected custom userinfo URL", config.OAuth.UserInfoURL)
|
||||
}
|
||||
expectedScopes := []string{"read", "write", "admin"}
|
||||
if !equalSlices(config.OAuth.Scopes, expectedScopes) {
|
||||
t.Errorf("OAuth.Scopes = %v, expected %v", config.OAuth.Scopes, expectedScopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_MissingRequiredEnvironmentVariables(t *testing.T) {
|
||||
requiredVars := []string{
|
||||
"APP_BASE_URL",
|
||||
"APP_ORGANISATION",
|
||||
"DB_DSN",
|
||||
"OAUTH_CLIENT_ID",
|
||||
"OAUTH_CLIENT_SECRET",
|
||||
}
|
||||
|
||||
for _, missingVar := range requiredVars {
|
||||
t.Run("missing_"+missingVar, func(t *testing.T) {
|
||||
// Set all required variables except the one being tested
|
||||
envVars := map[string]string{
|
||||
"APP_BASE_URL": "https://ackify.example.com",
|
||||
"APP_ORGANISATION": "Test Organisation",
|
||||
"DB_DSN": "postgres://user:pass@localhost/test",
|
||||
"OAUTH_CLIENT_ID": "test-client-id",
|
||||
"OAUTH_CLIENT_SECRET": "test-client-secret",
|
||||
"OAUTH_PROVIDER": "google",
|
||||
}
|
||||
|
||||
// Remove the variable we're testing
|
||||
delete(envVars, missingVar)
|
||||
|
||||
// Set environment variables
|
||||
for key, value := range envVars {
|
||||
os.Setenv(key, value)
|
||||
defer os.Unsetenv(key)
|
||||
}
|
||||
|
||||
// Ensure the missing variable is not set
|
||||
os.Unsetenv(missingVar)
|
||||
|
||||
// Test that Load() panics
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("Load() should panic when %s is missing", missingVar)
|
||||
}
|
||||
}()
|
||||
|
||||
Load()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_CustomProviderMissingRequiredVars(t *testing.T) {
|
||||
customRequiredVars := []string{
|
||||
"OAUTH_AUTH_URL",
|
||||
"OAUTH_TOKEN_URL",
|
||||
"OAUTH_USERINFO_URL",
|
||||
}
|
||||
|
||||
for _, missingVar := range customRequiredVars {
|
||||
t.Run("custom_missing_"+missingVar, func(t *testing.T) {
|
||||
// Set basic required variables
|
||||
envVars := map[string]string{
|
||||
"APP_BASE_URL": "https://ackify.example.com",
|
||||
"APP_ORGANISATION": "Test Organisation",
|
||||
"DB_DSN": "postgres://user:pass@localhost/test",
|
||||
"OAUTH_CLIENT_ID": "test-client-id",
|
||||
"OAUTH_CLIENT_SECRET": "test-client-secret",
|
||||
"OAUTH_AUTH_URL": "https://auth.custom.com/oauth/authorize",
|
||||
"OAUTH_TOKEN_URL": "https://auth.custom.com/oauth/token",
|
||||
"OAUTH_USERINFO_URL": "https://api.custom.com/user",
|
||||
}
|
||||
|
||||
// Remove the variable we're testing
|
||||
delete(envVars, missingVar)
|
||||
|
||||
// Set environment variables
|
||||
for key, value := range envVars {
|
||||
os.Setenv(key, value)
|
||||
defer os.Unsetenv(key)
|
||||
}
|
||||
|
||||
// Ensure the missing variable is not set
|
||||
os.Unsetenv(missingVar)
|
||||
|
||||
// Test that Load() panics for custom provider missing URLs
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("Load() should panic when %s is missing for custom provider", missingVar)
|
||||
}
|
||||
}()
|
||||
|
||||
Load()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_DefaultValues(t *testing.T) {
|
||||
envVars := map[string]string{
|
||||
"APP_BASE_URL": "https://ackify.example.com",
|
||||
"APP_ORGANISATION": "Test Organisation",
|
||||
"DB_DSN": "postgres://user:pass@localhost/test",
|
||||
"OAUTH_CLIENT_ID": "test-client-id",
|
||||
"OAUTH_CLIENT_SECRET": "test-client-secret",
|
||||
"OAUTH_PROVIDER": "google",
|
||||
}
|
||||
|
||||
for key, value := range envVars {
|
||||
os.Setenv(key, value)
|
||||
defer os.Unsetenv(key)
|
||||
}
|
||||
|
||||
// Ensure optional variables are not set to test defaults
|
||||
os.Unsetenv("OAUTH_ALLOWED_DOMAIN")
|
||||
os.Unsetenv("OAUTH_COOKIE_SECRET")
|
||||
os.Unsetenv("LISTEN_ADDR")
|
||||
|
||||
config, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
// Test default values
|
||||
if config.OAuth.AllowedDomain != "" {
|
||||
t.Errorf("OAuth.AllowedDomain = %v, expected empty string", config.OAuth.AllowedDomain)
|
||||
}
|
||||
if len(config.OAuth.CookieSecret) != 32 {
|
||||
t.Errorf("OAuth.CookieSecret should be generated as 32 bytes, got %d", len(config.OAuth.CookieSecret))
|
||||
}
|
||||
if config.Server.ListenAddr != ":8080" {
|
||||
t.Errorf("Server.ListenAddr = %v, expected :8080", config.Server.ListenAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_CustomProviderDefaultScopes(t *testing.T) {
|
||||
envVars := map[string]string{
|
||||
"APP_BASE_URL": "https://ackify.custom.com",
|
||||
"APP_ORGANISATION": "Custom Test",
|
||||
"DB_DSN": "postgres://user:pass@localhost/custom",
|
||||
"OAUTH_CLIENT_ID": "custom-client-id",
|
||||
"OAUTH_CLIENT_SECRET": "custom-client-secret",
|
||||
"OAUTH_AUTH_URL": "https://auth.custom.com/oauth/authorize",
|
||||
"OAUTH_TOKEN_URL": "https://auth.custom.com/oauth/token",
|
||||
"OAUTH_USERINFO_URL": "https://api.custom.com/user",
|
||||
"OAUTH_COOKIE_SECRET": base64.StdEncoding.EncodeToString(make([]byte, 32)),
|
||||
}
|
||||
|
||||
for key, value := range envVars {
|
||||
os.Setenv(key, value)
|
||||
defer os.Unsetenv(key)
|
||||
}
|
||||
|
||||
// Ensure OAUTH_SCOPES is not set to test default
|
||||
os.Unsetenv("OAUTH_SCOPES")
|
||||
|
||||
config, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
// Test default scopes for custom provider
|
||||
expectedScopes := []string{"openid", "email", "profile"}
|
||||
if !equalSlices(config.OAuth.Scopes, expectedScopes) {
|
||||
t.Errorf("OAuth.Scopes = %v, expected default %v", config.OAuth.Scopes, expectedScopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCookieSecret_InvalidBase64(t *testing.T) {
|
||||
// Test invalid base64 that falls back to raw string
|
||||
os.Setenv("OAUTH_COOKIE_SECRET", "this-is-not-valid-base64!")
|
||||
defer os.Unsetenv("OAUTH_COOKIE_SECRET")
|
||||
|
||||
result, err := parseCookieSecret()
|
||||
if err != nil {
|
||||
t.Errorf("parseCookieSecret() should not fail for invalid base64: %v", err)
|
||||
}
|
||||
|
||||
expected := "this-is-not-valid-base64!"
|
||||
if string(result) != expected {
|
||||
t.Errorf("parseCookieSecret() = %v, expected %v", string(result), expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCookieSecret_ValidBase64WrongLength(t *testing.T) {
|
||||
// Test valid base64 but wrong length (should fall back to raw string)
|
||||
wrongLength := base64.StdEncoding.EncodeToString(make([]byte, 16)) // 16 bytes instead of 32/64
|
||||
os.Setenv("OAUTH_COOKIE_SECRET", wrongLength)
|
||||
defer os.Unsetenv("OAUTH_COOKIE_SECRET")
|
||||
|
||||
result, err := parseCookieSecret()
|
||||
if err != nil {
|
||||
t.Errorf("parseCookieSecret() should not fail for wrong length: %v", err)
|
||||
}
|
||||
|
||||
// Should fall back to raw string
|
||||
if string(result) != wrongLength {
|
||||
t.Errorf("parseCookieSecret() should fall back to raw string for wrong length")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_ErrorInParseCookieSecret(t *testing.T) {
|
||||
envVars := map[string]string{
|
||||
"APP_BASE_URL": "https://ackify.example.com",
|
||||
"APP_ORGANISATION": "Test Organisation",
|
||||
"DB_DSN": "postgres://user:pass@localhost/test",
|
||||
"OAUTH_CLIENT_ID": "test-client-id",
|
||||
"OAUTH_CLIENT_SECRET": "test-client-secret",
|
||||
"OAUTH_PROVIDER": "google",
|
||||
}
|
||||
|
||||
for key, value := range envVars {
|
||||
os.Setenv(key, value)
|
||||
defer os.Unsetenv(key)
|
||||
}
|
||||
|
||||
// Set a cookie secret that won't cause an error (parseCookieSecret doesn't actually return errors in current implementation)
|
||||
os.Setenv("OAUTH_COOKIE_SECRET", "valid-secret")
|
||||
defer os.Unsetenv("OAUTH_COOKIE_SECRET")
|
||||
|
||||
config, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() should not fail: %v", err)
|
||||
}
|
||||
|
||||
// Verify the config was loaded successfully
|
||||
if config == nil {
|
||||
t.Error("Config should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppConfig_SecureCookiesLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
baseURL string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "HTTPS URL should enable secure cookies",
|
||||
baseURL: "https://ackify.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "HTTP URL should disable secure cookies",
|
||||
baseURL: "http://ackify.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Mixed case HTTPS should enable secure cookies",
|
||||
baseURL: "HTTPS://ackify.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Mixed case HTTP should disable secure cookies",
|
||||
baseURL: "HTTP://ackify.example.com",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
envVars := map[string]string{
|
||||
"APP_BASE_URL": tt.baseURL,
|
||||
"APP_ORGANISATION": "Test Organisation",
|
||||
"DB_DSN": "postgres://user:pass@localhost/test",
|
||||
"OAUTH_CLIENT_ID": "test-client-id",
|
||||
"OAUTH_CLIENT_SECRET": "test-client-secret",
|
||||
"OAUTH_PROVIDER": "google",
|
||||
}
|
||||
|
||||
for key, value := range envVars {
|
||||
os.Setenv(key, value)
|
||||
defer os.Unsetenv(key)
|
||||
}
|
||||
|
||||
config, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
if config.App.SecureCookies != tt.expected {
|
||||
t.Errorf("SecureCookies = %v, expected %v for URL %s",
|
||||
config.App.SecureCookies, tt.expected, tt.baseURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to compare slices
|
||||
func equalSlices(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// Config holds database configuration
|
||||
type Config struct {
|
||||
DSN string
|
||||
}
|
||||
|
||||
// InitDB initializes the database connection and runs migrations
|
||||
func InitDB(ctx context.Context, config Config) (*sql.DB, error) {
|
||||
db, err := sql.Open("postgres", config.DSN)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
// Test connection with timeout
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
// Run migrations
|
||||
if err := runMigrations(ctx, db); err != nil {
|
||||
return nil, fmt.Errorf("failed to run migrations: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// runMigrations creates the necessary tables
|
||||
func runMigrations(ctx context.Context, db *sql.DB) error {
|
||||
query := `
|
||||
CREATE TABLE IF NOT EXISTS signatures (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
doc_id TEXT NOT NULL,
|
||||
user_sub TEXT NOT NULL,
|
||||
user_email TEXT NOT NULL,
|
||||
user_name TEXT,
|
||||
signed_at TIMESTAMPTZ NOT NULL,
|
||||
payload_hash TEXT NOT NULL,
|
||||
signature TEXT NOT NULL,
|
||||
nonce TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
referer TEXT,
|
||||
UNIQUE (doc_id, user_sub)
|
||||
);
|
||||
|
||||
-- Migration: Add prev_hash_b64 column if it doesn't exist
|
||||
ALTER TABLE signatures ADD COLUMN IF NOT EXISTS prev_hash TEXT;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_signatures_user ON signatures(user_sub);
|
||||
|
||||
CREATE OR REPLACE FUNCTION prevent_created_at_update()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
IF OLD.created_at IS DISTINCT FROM NEW.created_at THEN
|
||||
RAISE EXCEPTION 'Cannot modify created_at timestamp';
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
DROP TRIGGER IF EXISTS trigger_prevent_created_at_update ON signatures;
|
||||
CREATE TRIGGER trigger_prevent_created_at_update
|
||||
BEFORE UPDATE ON signatures
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION prevent_created_at_update();
|
||||
`
|
||||
|
||||
_, err := db.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute migrations: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
)
|
||||
|
||||
type SignatureRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewSignatureRepository creates a new PostgresSQL signature repository
|
||||
func NewSignatureRepository(db *sql.DB) *SignatureRepository {
|
||||
return &SignatureRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *SignatureRepository) Create(ctx context.Context, signature *models.Signature) error {
|
||||
query := `
|
||||
INSERT INTO signatures (doc_id, user_sub, user_email, user_name, signed_at_utc, payload_hash_b64, signature_b64, nonce, referer, prev_hash_b64)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
RETURNING id, created_at
|
||||
`
|
||||
|
||||
err := r.db.QueryRowContext(
|
||||
ctx, query,
|
||||
signature.DocID,
|
||||
signature.UserSub,
|
||||
signature.UserEmail,
|
||||
signature.UserName,
|
||||
signature.SignedAtUTC,
|
||||
signature.PayloadHashB64,
|
||||
signature.SignatureB64,
|
||||
signature.Nonce,
|
||||
signature.Referer,
|
||||
signature.PrevHashB64,
|
||||
).Scan(&signature.ID, &signature.CreatedAt)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create signature: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SignatureRepository) GetByDocAndUser(ctx context.Context, docID, userSub string) (*models.Signature, error) {
|
||||
query := `
|
||||
SELECT id, doc_id, user_sub, user_email, user_name, signed_at_utc, payload_hash_b64, signature_b64, nonce, created_at, referer, prev_hash_b64
|
||||
FROM signatures
|
||||
WHERE doc_id = $1 AND user_sub = $2
|
||||
`
|
||||
|
||||
signature := &models.Signature{}
|
||||
err := r.db.QueryRowContext(ctx, query, docID, userSub).Scan(
|
||||
&signature.ID,
|
||||
&signature.DocID,
|
||||
&signature.UserSub,
|
||||
&signature.UserEmail,
|
||||
&signature.UserName,
|
||||
&signature.SignedAtUTC,
|
||||
&signature.PayloadHashB64,
|
||||
&signature.SignatureB64,
|
||||
&signature.Nonce,
|
||||
&signature.CreatedAt,
|
||||
&signature.Referer,
|
||||
&signature.PrevHashB64,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, models.ErrSignatureNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get signature: %w", err)
|
||||
}
|
||||
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
func (r *SignatureRepository) GetByDoc(ctx context.Context, docID string) ([]*models.Signature, error) {
|
||||
query := `
|
||||
SELECT id, doc_id, user_sub, user_email, user_name, signed_at_utc, payload_hash_b64, signature_b64, nonce, created_at, referer, prev_hash_b64
|
||||
FROM signatures
|
||||
WHERE doc_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, docID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query signatures: %w", err)
|
||||
}
|
||||
defer func(rows *sql.Rows) {
|
||||
_ = rows.Close()
|
||||
}(rows)
|
||||
|
||||
var signatures []*models.Signature
|
||||
for rows.Next() {
|
||||
signature := &models.Signature{}
|
||||
err := rows.Scan(
|
||||
&signature.ID,
|
||||
&signature.DocID,
|
||||
&signature.UserSub,
|
||||
&signature.UserEmail,
|
||||
&signature.UserName,
|
||||
&signature.SignedAtUTC,
|
||||
&signature.PayloadHashB64,
|
||||
&signature.SignatureB64,
|
||||
&signature.Nonce,
|
||||
&signature.CreatedAt,
|
||||
&signature.Referer,
|
||||
&signature.PrevHashB64,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
signatures = append(signatures, signature)
|
||||
}
|
||||
|
||||
return signatures, nil
|
||||
}
|
||||
|
||||
func (r *SignatureRepository) GetByUser(ctx context.Context, userSub string) ([]*models.Signature, error) {
|
||||
query := `
|
||||
SELECT id, doc_id, user_sub, user_email, user_name, signed_at_utc, payload_hash_b64, signature_b64, nonce, created_at, referer, prev_hash_b64
|
||||
FROM signatures
|
||||
WHERE user_sub = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, userSub)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query user signatures: %w", err)
|
||||
}
|
||||
defer func(rows *sql.Rows) {
|
||||
_ = rows.Close()
|
||||
}(rows)
|
||||
|
||||
var signatures []*models.Signature
|
||||
for rows.Next() {
|
||||
signature := &models.Signature{}
|
||||
err := rows.Scan(
|
||||
&signature.ID,
|
||||
&signature.DocID,
|
||||
&signature.UserSub,
|
||||
&signature.UserEmail,
|
||||
&signature.UserName,
|
||||
&signature.SignedAtUTC,
|
||||
&signature.PayloadHashB64,
|
||||
&signature.SignatureB64,
|
||||
&signature.Nonce,
|
||||
&signature.CreatedAt,
|
||||
&signature.Referer,
|
||||
&signature.PrevHashB64,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
signatures = append(signatures, signature)
|
||||
}
|
||||
|
||||
return signatures, nil
|
||||
}
|
||||
|
||||
func (r *SignatureRepository) ExistsByDocAndUser(ctx context.Context, docID, userSub string) (bool, error) {
|
||||
query := `SELECT EXISTS(SELECT 1 FROM signatures WHERE doc_id = $1 AND user_sub = $2)`
|
||||
|
||||
var exists bool
|
||||
err := r.db.QueryRowContext(ctx, query, docID, userSub).Scan(&exists)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check signature existence: %w", err)
|
||||
}
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func (r *SignatureRepository) CheckUserSignatureStatus(ctx context.Context, docID, userIdentifier string) (bool, error) {
|
||||
query := `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM signatures
|
||||
WHERE doc_id = $1 AND (user_sub = $2 OR LOWER(user_email) = LOWER($2))
|
||||
)
|
||||
`
|
||||
|
||||
var exists bool
|
||||
err := r.db.QueryRowContext(ctx, query, docID, userIdentifier).Scan(&exists)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check user signature status: %w", err)
|
||||
}
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func (r *SignatureRepository) GetLastSignature(ctx context.Context) (*models.Signature, error) {
|
||||
query := `
|
||||
SELECT id, doc_id, user_sub, user_email, user_name, signed_at_utc, payload_hash_b64, signature_b64, nonce, created_at, referer, prev_hash_b64
|
||||
FROM signatures
|
||||
ORDER BY id DESC
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
signature := &models.Signature{}
|
||||
err := r.db.QueryRowContext(ctx, query).Scan(
|
||||
&signature.ID,
|
||||
&signature.DocID,
|
||||
&signature.UserSub,
|
||||
&signature.UserEmail,
|
||||
&signature.UserName,
|
||||
&signature.SignedAtUTC,
|
||||
&signature.PayloadHashB64,
|
||||
&signature.SignatureB64,
|
||||
&signature.Nonce,
|
||||
&signature.CreatedAt,
|
||||
&signature.Referer,
|
||||
&signature.PrevHashB64,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get last signature: %w", err)
|
||||
}
|
||||
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
func (r *SignatureRepository) GetAllSignaturesOrdered(ctx context.Context) ([]*models.Signature, error) {
|
||||
query := `
|
||||
SELECT id, doc_id, user_sub, user_email, user_name, signed_at_utc, payload_hash_b64, signature_b64, nonce, created_at, referer, prev_hash_b64
|
||||
FROM signatures
|
||||
ORDER BY id ASC`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query all signatures: %w", err)
|
||||
}
|
||||
defer func(rows *sql.Rows) {
|
||||
_ = rows.Close()
|
||||
}(rows)
|
||||
|
||||
var signatures []*models.Signature
|
||||
for rows.Next() {
|
||||
signature := &models.Signature{}
|
||||
err := rows.Scan(
|
||||
&signature.ID,
|
||||
&signature.DocID,
|
||||
&signature.UserSub,
|
||||
&signature.UserEmail,
|
||||
&signature.UserName,
|
||||
&signature.SignedAtUTC,
|
||||
&signature.PayloadHashB64,
|
||||
&signature.SignatureB64,
|
||||
&signature.Nonce,
|
||||
&signature.CreatedAt,
|
||||
&signature.Referer,
|
||||
&signature.PrevHashB64,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
signatures = append(signatures, signature)
|
||||
}
|
||||
|
||||
return signatures, nil
|
||||
}
|
||||
@@ -0,0 +1,513 @@
|
||||
//go:build integration
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
)
|
||||
|
||||
func TestRepository_Concurrency_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
repo := NewSignatureRepository(testDB.DB)
|
||||
factory := NewSignatureFactory()
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("concurrent creates different docs", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
const numGoroutines = 50
|
||||
const signaturesPerGoroutine = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, numGoroutines*signaturesPerGoroutine)
|
||||
|
||||
// Launch concurrent goroutines creating signatures
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
for j := 0; j < signaturesPerGoroutine; j++ {
|
||||
sig := factory.CreateSignatureWithDocAndUser(
|
||||
fmt.Sprintf("doc-%d-%d", goroutineID, j),
|
||||
fmt.Sprintf("user-%d-%d", goroutineID, j),
|
||||
fmt.Sprintf("user%d%d@example.com", goroutineID, j),
|
||||
)
|
||||
|
||||
if err := repo.Create(ctx, sig); err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
for err := range errors {
|
||||
t.Errorf("Concurrent create error: %v", err)
|
||||
}
|
||||
|
||||
// Verify all signatures were created
|
||||
expectedCount := numGoroutines * signaturesPerGoroutine
|
||||
actualCount := testDB.GetTableCount(t)
|
||||
if actualCount != expectedCount {
|
||||
t.Errorf("Expected %d signatures, got %d", expectedCount, actualCount)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("concurrent creates with duplicate attempts", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
const numGoroutines = 20
|
||||
docID := "shared-doc"
|
||||
userSub := "shared-user"
|
||||
|
||||
var wg sync.WaitGroup
|
||||
successCount := make(chan int, numGoroutines)
|
||||
errorCount := make(chan int, numGoroutines)
|
||||
|
||||
// Launch concurrent goroutines trying to create the same signature
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
sig := factory.CreateSignatureWithDocAndUser(
|
||||
docID,
|
||||
userSub,
|
||||
"shared@example.com",
|
||||
)
|
||||
|
||||
if err := repo.Create(ctx, sig); err != nil {
|
||||
errorCount <- 1
|
||||
} else {
|
||||
successCount <- 1
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(successCount)
|
||||
close(errorCount)
|
||||
|
||||
// Count results
|
||||
successes := 0
|
||||
failures := 0
|
||||
for range successCount {
|
||||
successes++
|
||||
}
|
||||
for range errorCount {
|
||||
failures++
|
||||
}
|
||||
|
||||
// Only one should succeed due to unique constraint
|
||||
if successes != 1 {
|
||||
t.Errorf("Expected exactly 1 success, got %d", successes)
|
||||
}
|
||||
if failures != numGoroutines-1 {
|
||||
t.Errorf("Expected %d failures, got %d", numGoroutines-1, failures)
|
||||
}
|
||||
|
||||
// Verify only one record exists
|
||||
count := testDB.GetTableCount(t)
|
||||
if count != 1 {
|
||||
t.Errorf("Expected 1 signature after concurrent duplicates, got %d", count)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("concurrent reads during writes", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
const numWriters = 10
|
||||
const numReaders = 20
|
||||
const numWrites = 5
|
||||
docID := "concurrent-doc"
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Start writers
|
||||
for i := 0; i < numWriters; i++ {
|
||||
wg.Add(1)
|
||||
go func(writerID int) {
|
||||
defer wg.Done()
|
||||
|
||||
for j := 0; j < numWrites; j++ {
|
||||
sig := factory.CreateSignatureWithDocAndUser(
|
||||
docID,
|
||||
fmt.Sprintf("user-%d-%d", writerID, j),
|
||||
fmt.Sprintf("user%d%d@example.com", writerID, j),
|
||||
)
|
||||
|
||||
_ = repo.Create(ctx, sig)
|
||||
time.Sleep(time.Millisecond) // Small delay to spread writes
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Start readers
|
||||
readResults := make(chan int, numReaders*10) // Buffer for all results
|
||||
for i := 0; i < numReaders; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
for j := 0; j < 10; j++ {
|
||||
signatures, err := repo.GetByDoc(ctx, docID)
|
||||
if err != nil {
|
||||
t.Errorf("Concurrent read error: %v", err)
|
||||
return
|
||||
}
|
||||
readResults <- len(signatures)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(readResults)
|
||||
|
||||
// Verify reads were consistent (no corruption)
|
||||
for count := range readResults {
|
||||
if count < 0 || count > numWriters*numWrites {
|
||||
t.Errorf("Invalid read result: %d (should be 0-%d)", count, numWriters*numWrites)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify final count
|
||||
finalCount := testDB.GetTableCount(t)
|
||||
expectedCount := numWriters * numWrites
|
||||
if finalCount != expectedCount {
|
||||
t.Errorf("Expected %d final signatures, got %d", expectedCount, finalCount)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("concurrent GetLastSignature during creates", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
const numCreators = 10
|
||||
const numReaders = 5
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Start creators
|
||||
for i := 0; i < numCreators; i++ {
|
||||
wg.Add(1)
|
||||
go func(creatorID int) {
|
||||
defer wg.Done()
|
||||
|
||||
for j := 0; j < 5; j++ {
|
||||
sig := factory.CreateSignatureWithUser(
|
||||
fmt.Sprintf("user-%d-%d", creatorID, j),
|
||||
fmt.Sprintf("user%d%d@example.com", creatorID, j),
|
||||
)
|
||||
|
||||
_ = repo.Create(ctx, sig)
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Start readers calling GetLastSignature
|
||||
lastSigResults := make(chan *models.Signature, numReaders*10)
|
||||
for i := 0; i < numReaders; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
for j := 0; j < 10; j++ {
|
||||
lastSig, err := repo.GetLastSignature(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("GetLastSignature error: %v", err)
|
||||
return
|
||||
}
|
||||
lastSigResults <- lastSig
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(lastSigResults)
|
||||
|
||||
// Verify GetLastSignature results are valid
|
||||
for sig := range lastSigResults {
|
||||
if sig != nil {
|
||||
// Should have valid ID assigned by database
|
||||
if sig.ID <= 0 {
|
||||
t.Error("GetLastSignature returned signature with invalid ID")
|
||||
}
|
||||
// Should have valid required fields
|
||||
if sig.DocID == "" || sig.UserSub == "" {
|
||||
t.Error("GetLastSignature returned signature with empty required fields")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("concurrent GetAllSignaturesOrdered during creates", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
const numCreators = 5
|
||||
const numReaders = 3
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Start creators
|
||||
for i := 0; i < numCreators; i++ {
|
||||
wg.Add(1)
|
||||
go func(creatorID int) {
|
||||
defer wg.Done()
|
||||
|
||||
for j := 0; j < 10; j++ {
|
||||
sig := factory.CreateSignatureWithUser(
|
||||
fmt.Sprintf("concurrent-user-%d-%d", creatorID, j),
|
||||
fmt.Sprintf("user%d%d@example.com", creatorID, j),
|
||||
)
|
||||
|
||||
_ = repo.Create(ctx, sig)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Start readers calling GetAllSignaturesOrdered
|
||||
orderingErrors := make(chan error, numReaders*5)
|
||||
for i := 0; i < numReaders; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
for j := 0; j < 5; j++ {
|
||||
signatures, err := repo.GetAllSignaturesOrdered(ctx)
|
||||
if err != nil {
|
||||
orderingErrors <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Verify ordering (ID should be ascending)
|
||||
for k := 1; k < len(signatures); k++ {
|
||||
if signatures[k].ID <= signatures[k-1].ID {
|
||||
orderingErrors <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(orderingErrors)
|
||||
|
||||
// Check for ordering violations
|
||||
for err := range orderingErrors {
|
||||
if err != nil {
|
||||
t.Errorf("Concurrent ordering error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stress test with mixed operations", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
const duration = 2 * time.Second
|
||||
const numWorkers = 20
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, duration)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
operationCounts := make(chan map[string]int, numWorkers)
|
||||
|
||||
// Start workers doing mixed operations
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
wg.Add(1)
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
|
||||
counts := map[string]int{
|
||||
"creates": 0,
|
||||
"gets": 0,
|
||||
"exists": 0,
|
||||
"last": 0,
|
||||
"all": 0,
|
||||
"errors": 0,
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
operationCounts <- counts
|
||||
return
|
||||
default:
|
||||
// Randomly choose operation
|
||||
switch workerID % 5 {
|
||||
case 0: // Create
|
||||
sig := factory.CreateSignatureWithUser(
|
||||
fmt.Sprintf("stress-user-%d-%d", workerID, counts["creates"]),
|
||||
fmt.Sprintf("stress%d%d@example.com", workerID, counts["creates"]),
|
||||
)
|
||||
if err := repo.Create(ctx, sig); err != nil {
|
||||
counts["errors"]++
|
||||
} else {
|
||||
counts["creates"]++
|
||||
}
|
||||
|
||||
case 1: // GetByDocAndUser
|
||||
_, err := repo.GetByDocAndUser(ctx, "test-doc-123", "user-123")
|
||||
if err != nil && !strings.Contains(err.Error(), "not found") {
|
||||
counts["errors"]++
|
||||
} else {
|
||||
counts["gets"]++
|
||||
}
|
||||
|
||||
case 2: // ExistsByDocAndUser
|
||||
_, err := repo.ExistsByDocAndUser(ctx, "test-doc-123", "user-123")
|
||||
if err != nil {
|
||||
counts["errors"]++
|
||||
} else {
|
||||
counts["exists"]++
|
||||
}
|
||||
|
||||
case 3: // GetLastSignature
|
||||
_, err := repo.GetLastSignature(ctx)
|
||||
if err != nil {
|
||||
counts["errors"]++
|
||||
} else {
|
||||
counts["last"]++
|
||||
}
|
||||
|
||||
case 4: // GetAllSignaturesOrdered
|
||||
_, err := repo.GetAllSignaturesOrdered(ctx)
|
||||
if err != nil {
|
||||
counts["errors"]++
|
||||
} else {
|
||||
counts["all"]++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(operationCounts)
|
||||
|
||||
// Aggregate results
|
||||
totalOps := 0
|
||||
totalErrors := 0
|
||||
for counts := range operationCounts {
|
||||
for op, count := range counts {
|
||||
if op == "errors" {
|
||||
totalErrors += count
|
||||
} else {
|
||||
totalOps += count
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Stress test completed: %d operations, %d errors", totalOps, totalErrors)
|
||||
|
||||
// Should have completed many operations with minimal errors
|
||||
if totalOps < 100 {
|
||||
t.Errorf("Expected at least 100 operations, got %d", totalOps)
|
||||
}
|
||||
|
||||
// Error rate should be reasonable (< 10%)
|
||||
errorRate := float64(totalErrors) / float64(totalOps+totalErrors) * 100
|
||||
if errorRate > 10 {
|
||||
t.Errorf("Error rate too high: %.2f%% (expected < 10%%)", errorRate)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRepository_DeadlockPrevention_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
factory := NewSignatureFactory()
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("avoid deadlocks with multiple table access patterns", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
const numWorkers = 10
|
||||
const opsPerWorker = 20
|
||||
|
||||
var wg sync.WaitGroup
|
||||
deadlockErrors := make(chan error, numWorkers)
|
||||
|
||||
// Workers with different access patterns that could cause deadlocks
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
wg.Add(1)
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
repo := NewSignatureRepository(testDB.DB)
|
||||
|
||||
for j := 0; j < opsPerWorker; j++ {
|
||||
// Pattern 1: Create then immediately query
|
||||
if workerID%2 == 0 {
|
||||
sig := factory.CreateSignatureWithUser(
|
||||
fmt.Sprintf("pattern1-user-%d-%d", workerID, j),
|
||||
fmt.Sprintf("pattern1-%d%d@example.com", workerID, j),
|
||||
)
|
||||
|
||||
if err := repo.Create(ctx, sig); err == nil {
|
||||
_, _ = repo.GetByDocAndUser(ctx, sig.DocID, sig.UserSub)
|
||||
_, _ = repo.ExistsByDocAndUser(ctx, sig.DocID, sig.UserSub)
|
||||
}
|
||||
} else {
|
||||
// Pattern 2: Query then create
|
||||
testDocID := fmt.Sprintf("pattern2-doc-%d", workerID)
|
||||
testUserSub := fmt.Sprintf("pattern2-user-%d", j)
|
||||
|
||||
_, _ = repo.GetByDoc(ctx, testDocID)
|
||||
_, _ = repo.GetByUser(ctx, testUserSub)
|
||||
|
||||
sig := factory.CreateSignatureWithDocAndUser(
|
||||
testDocID,
|
||||
testUserSub,
|
||||
"pattern2@example.com",
|
||||
)
|
||||
_ = repo.Create(ctx, sig)
|
||||
}
|
||||
|
||||
// Small random delay to increase chance of contention
|
||||
time.Sleep(time.Duration(workerID%3+1) * time.Millisecond)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait with timeout to detect deadlocks
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
wg.Wait()
|
||||
done <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Success - no deadlocks
|
||||
case <-time.After(30 * time.Second):
|
||||
t.Fatal("Test timed out - possible deadlock detected")
|
||||
}
|
||||
|
||||
close(deadlockErrors)
|
||||
|
||||
// Check for deadlock-specific errors
|
||||
for err := range deadlockErrors {
|
||||
if err != nil {
|
||||
t.Errorf("Deadlock-related error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,504 @@
|
||||
//go:build integration
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
)
|
||||
|
||||
func TestRepository_DatabaseConstraints_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
repo := NewSignatureRepository(testDB.DB)
|
||||
factory := NewSignatureFactory()
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("unique constraint violation", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
// Create first signature
|
||||
sig1 := factory.CreateSignatureWithDocAndUser("doc1", "user1", "user1@example.com")
|
||||
err := repo.Create(ctx, sig1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create first signature: %v", err)
|
||||
}
|
||||
|
||||
// Try to create duplicate
|
||||
sig2 := factory.CreateSignatureWithDocAndUser("doc1", "user1", "user1@example.com")
|
||||
err = repo.Create(ctx, sig2)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for unique constraint violation")
|
||||
}
|
||||
|
||||
// Verify it's a constraint violation (PostgreSQL specific)
|
||||
if !strings.Contains(err.Error(), "duplicate key") &&
|
||||
!strings.Contains(err.Error(), "unique constraint") {
|
||||
t.Errorf("Expected constraint violation error, got: %v", err)
|
||||
}
|
||||
|
||||
// Verify only one record exists
|
||||
count := testDB.GetTableCount(t)
|
||||
if count != 1 {
|
||||
t.Errorf("Expected 1 record after constraint violation, got %d", count)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("null constraints", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
modifyFn func(*models.Signature)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid signature with nulls",
|
||||
modifyFn: func(s *models.Signature) {
|
||||
s.UserName = nil
|
||||
s.Referer = nil
|
||||
s.PrevHashB64 = nil
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty doc_id is allowed by DB",
|
||||
modifyFn: func(s *models.Signature) { s.DocID = "" },
|
||||
wantErr: false, // Empty string != NULL in PostgreSQL
|
||||
},
|
||||
{
|
||||
name: "empty user_sub is allowed by DB",
|
||||
modifyFn: func(s *models.Signature) { s.UserSub = "" },
|
||||
wantErr: false, // Empty string != NULL in PostgreSQL
|
||||
},
|
||||
{
|
||||
name: "empty user_email is allowed by DB",
|
||||
modifyFn: func(s *models.Signature) { s.UserEmail = "" },
|
||||
wantErr: false, // Empty string != NULL in PostgreSQL
|
||||
},
|
||||
{
|
||||
name: "empty payload_hash_b64 is allowed by DB",
|
||||
modifyFn: func(s *models.Signature) { s.PayloadHashB64 = "" },
|
||||
wantErr: false, // Empty string != NULL in PostgreSQL
|
||||
},
|
||||
{
|
||||
name: "empty signature_b64 is allowed by DB",
|
||||
modifyFn: func(s *models.Signature) { s.SignatureB64 = "" },
|
||||
wantErr: false, // Empty string != NULL in PostgreSQL
|
||||
},
|
||||
{
|
||||
name: "empty nonce is allowed by DB",
|
||||
modifyFn: func(s *models.Signature) { s.Nonce = "" },
|
||||
wantErr: false, // Empty string != NULL in PostgreSQL
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sig := factory.CreateSignatureWithDocAndUser(
|
||||
fmt.Sprintf("test-doc-%d", i),
|
||||
fmt.Sprintf("test-user-%d", i),
|
||||
fmt.Sprintf("test%d@example.com", i),
|
||||
)
|
||||
tt.modifyFn(sig)
|
||||
|
||||
err := repo.Create(ctx, sig)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error for null constraint violation")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("index performance validation", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
// Create multiple signatures for performance testing
|
||||
const numSignatures = 1000
|
||||
for i := 0; i < numSignatures; i++ {
|
||||
sig := factory.CreateSignatureWithDocAndUser(
|
||||
"perf-doc",
|
||||
fmt.Sprintf("user-%d", i%100), // Reuse some users
|
||||
fmt.Sprintf("user%d@example.com", i),
|
||||
)
|
||||
_ = repo.Create(ctx, sig)
|
||||
}
|
||||
|
||||
// Test indexed queries performance
|
||||
start := time.Now()
|
||||
_, err := repo.GetByDoc(ctx, "perf-doc")
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("GetByDoc failed: %v", err)
|
||||
}
|
||||
|
||||
// Should be fast with index
|
||||
if duration > 100*time.Millisecond {
|
||||
t.Errorf("GetByDoc too slow: %v (expected < 100ms)", duration)
|
||||
}
|
||||
|
||||
t.Logf("GetByDoc for %d signatures took: %v", numSignatures, duration)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRepository_Transactions_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
factory := NewSignatureFactory()
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("transaction rollback on constraint violation", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
// Start transaction
|
||||
tx, err := testDB.DB.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to begin transaction: %v", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Execute operations within transaction context
|
||||
// Create first signature
|
||||
query := `INSERT INTO signatures (doc_id, user_sub, user_email, signed_at_utc, payload_hash_b64, signature_b64, nonce)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)`
|
||||
|
||||
_, err = tx.ExecContext(ctx, query, "test-doc", "test-user", "test@example.com",
|
||||
time.Now().UTC(), "hash1", "sig1", "nonce1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create first signature: %v", err)
|
||||
}
|
||||
|
||||
// Try to create duplicate - should fail
|
||||
_, err = tx.ExecContext(ctx, query, "test-doc", "test-user", "test@example.com",
|
||||
time.Now().UTC(), "hash2", "sig2", "nonce2")
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected constraint violation error")
|
||||
}
|
||||
|
||||
// Rollback transaction
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to rollback transaction: %v", err)
|
||||
}
|
||||
|
||||
// Verify rollback worked - no signatures should exist
|
||||
count := testDB.GetTableCount(t)
|
||||
if count != 0 {
|
||||
t.Errorf("Expected 0 signatures after rollback, got %d", count)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("transaction commit", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
// Start transaction
|
||||
tx, err := testDB.DB.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to begin transaction: %v", err)
|
||||
}
|
||||
|
||||
// Execute operations within transaction context
|
||||
query := `INSERT INTO signatures (doc_id, user_sub, user_email, signed_at_utc, payload_hash_b64, signature_b64, nonce)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)`
|
||||
|
||||
_, err = tx.ExecContext(ctx, query, "test-doc", "test-user", "test@example.com",
|
||||
time.Now().UTC(), "hash1", "sig1", "nonce1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create signature in transaction: %v", err)
|
||||
}
|
||||
|
||||
// Commit transaction
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to commit transaction: %v", err)
|
||||
}
|
||||
|
||||
// Verify commit worked - signature should exist
|
||||
count := testDB.GetTableCount(t)
|
||||
if count != 1 {
|
||||
t.Errorf("Expected 1 signature after commit, got %d", count)
|
||||
}
|
||||
|
||||
// Verify using repository
|
||||
repo := NewSignatureRepository(testDB.DB)
|
||||
result, err := repo.GetByDocAndUser(ctx, "test-doc", "test-user")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get signature after commit: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("Expected signature after commit")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("isolation levels", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
// Create initial signature
|
||||
sig1 := factory.CreateValidSignature()
|
||||
mainRepo := NewSignatureRepository(testDB.DB)
|
||||
_ = mainRepo.Create(ctx, sig1)
|
||||
|
||||
// Start transaction with READ COMMITTED isolation
|
||||
tx1, err := testDB.DB.BeginTx(ctx, &sql.TxOptions{
|
||||
Isolation: sql.LevelReadCommitted,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to begin transaction 1: %v", err)
|
||||
}
|
||||
defer tx1.Rollback()
|
||||
|
||||
repo1 := NewSignatureRepository(testDB.DB)
|
||||
|
||||
// Start another transaction
|
||||
tx2, err := testDB.DB.BeginTx(ctx, &sql.TxOptions{
|
||||
Isolation: sql.LevelReadCommitted,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to begin transaction 2: %v", err)
|
||||
}
|
||||
defer tx2.Rollback()
|
||||
|
||||
repo2 := NewSignatureRepository(testDB.DB)
|
||||
|
||||
// Both transactions should see the initial signature
|
||||
result1, err := repo1.GetByDocAndUser(ctx, sig1.DocID, sig1.UserSub)
|
||||
if err != nil {
|
||||
t.Fatalf("Transaction 1 failed to get signature: %v", err)
|
||||
}
|
||||
if result1 == nil {
|
||||
t.Fatal("Transaction 1 expected signature")
|
||||
}
|
||||
|
||||
result2, err := repo2.GetByDocAndUser(ctx, sig1.DocID, sig1.UserSub)
|
||||
if err != nil {
|
||||
t.Fatalf("Transaction 2 failed to get signature: %v", err)
|
||||
}
|
||||
if result2 == nil {
|
||||
t.Fatal("Transaction 2 expected signature")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRepository_DataIntegrity_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
repo := NewSignatureRepository(testDB.DB)
|
||||
factory := NewSignatureFactory()
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("timestamp precision", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
// Create signature with specific timestamp
|
||||
now := time.Now().UTC()
|
||||
sig := factory.CreateValidSignature()
|
||||
sig.SignedAtUTC = now
|
||||
|
||||
err := repo.Create(ctx, sig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create signature: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve and verify timestamp precision
|
||||
result, err := repo.GetByDocAndUser(ctx, sig.DocID, sig.UserSub)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get signature: %v", err)
|
||||
}
|
||||
|
||||
// Check timestamp is preserved (allowing for some precision loss)
|
||||
timeDiff := result.SignedAtUTC.Sub(now).Abs()
|
||||
if timeDiff > time.Microsecond {
|
||||
t.Errorf("Timestamp precision lost: expected %v, got %v (diff: %v)",
|
||||
now, result.SignedAtUTC, timeDiff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("string encoding preservation", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
// Test with various string encodings
|
||||
sig := factory.CreateValidSignature()
|
||||
sig.DocID = "test-éñcode-中文-🎯"
|
||||
sig.UserEmail = "tëst@éxample.com"
|
||||
sig.PayloadHashB64 = "SGVsbG8gV29ybGQh" // "Hello World!" in base64
|
||||
sig.Nonce = "nonce-with-special-chars-αβγ"
|
||||
|
||||
referer := "https://example.com/path/with/émojis🚀?param=value"
|
||||
sig.Referer = &referer
|
||||
|
||||
err := repo.Create(ctx, sig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create signature with special chars: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve and verify encoding preservation
|
||||
result, err := repo.GetByDocAndUser(ctx, sig.DocID, sig.UserSub)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get signature: %v", err)
|
||||
}
|
||||
|
||||
AssertSignatureEqual(t, sig, result)
|
||||
})
|
||||
|
||||
t.Run("large data handling", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
// Create signature with large data
|
||||
sig := factory.CreateValidSignature()
|
||||
|
||||
// Large base64 strings (simulate large signatures/hashes)
|
||||
largeData := strings.Repeat("SGVsbG8gV29ybGQh", 100) // Repeat base64 string
|
||||
sig.PayloadHashB64 = largeData
|
||||
sig.SignatureB64 = largeData
|
||||
|
||||
longReferer := "https://example.com/very/long/path/" + strings.Repeat("segment/", 50)
|
||||
sig.Referer = &longReferer
|
||||
|
||||
err := repo.Create(ctx, sig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create signature with large data: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve and verify
|
||||
result, err := repo.GetByDocAndUser(ctx, sig.DocID, sig.UserSub)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get signature: %v", err)
|
||||
}
|
||||
|
||||
if len(result.PayloadHashB64) != len(sig.PayloadHashB64) {
|
||||
t.Errorf("PayloadHashB64 length mismatch: expected %d, got %d",
|
||||
len(sig.PayloadHashB64), len(result.PayloadHashB64))
|
||||
}
|
||||
|
||||
if len(*result.Referer) != len(*sig.Referer) {
|
||||
t.Errorf("Referer length mismatch: expected %d, got %d",
|
||||
len(*sig.Referer), len(*result.Referer))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRepository_EdgeCases_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
repo := NewSignatureRepository(testDB.DB)
|
||||
factory := NewSignatureFactory()
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("empty string vs null handling", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
// Test with empty strings for nullable fields
|
||||
sig := factory.CreateValidSignature()
|
||||
emptyString := ""
|
||||
sig.UserName = &emptyString
|
||||
sig.Referer = &emptyString
|
||||
sig.PrevHashB64 = &emptyString
|
||||
|
||||
err := repo.Create(ctx, sig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create signature with empty strings: %v", err)
|
||||
}
|
||||
|
||||
result, err := repo.GetByDocAndUser(ctx, sig.DocID, sig.UserSub)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get signature: %v", err)
|
||||
}
|
||||
|
||||
// Verify empty strings are preserved (not converted to NULL)
|
||||
if result.UserName == nil || *result.UserName != "" {
|
||||
t.Error("Empty string UserName not preserved")
|
||||
}
|
||||
if result.Referer == nil || *result.Referer != "" {
|
||||
t.Error("Empty string Referer not preserved")
|
||||
}
|
||||
if result.PrevHashB64 == nil || *result.PrevHashB64 != "" {
|
||||
t.Error("Empty string PrevHashB64 not preserved")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("boundary values", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
// Test with boundary timestamp values
|
||||
sig := factory.CreateValidSignature()
|
||||
|
||||
// Use a very old timestamp
|
||||
sig.SignedAtUTC = time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
err := repo.Create(ctx, sig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create signature with old timestamp: %v", err)
|
||||
}
|
||||
|
||||
result, err := repo.GetByDocAndUser(ctx, sig.DocID, sig.UserSub)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get signature: %v", err)
|
||||
}
|
||||
|
||||
if !result.SignedAtUTC.Equal(sig.SignedAtUTC) {
|
||||
t.Errorf("Timestamp boundary value not preserved: expected %v, got %v",
|
||||
sig.SignedAtUTC, result.SignedAtUTC)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("case sensitivity", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
// Create signatures with different case variations
|
||||
sig1 := factory.CreateSignatureWithDocAndUser("Doc1", "User1", "USER1@EXAMPLE.COM")
|
||||
sig2 := factory.CreateSignatureWithDocAndUser("doc1", "user1", "user1@example.com")
|
||||
|
||||
err1 := repo.Create(ctx, sig1)
|
||||
err2 := repo.Create(ctx, sig2)
|
||||
|
||||
if err1 != nil {
|
||||
t.Fatalf("Failed to create signature 1: %v", err1)
|
||||
}
|
||||
if err2 != nil {
|
||||
t.Fatalf("Failed to create signature 2: %v", err2)
|
||||
}
|
||||
|
||||
// Both should exist as they have different case for doc_id and user_sub
|
||||
count := testDB.GetTableCount(t)
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 signatures with different cases, got %d", count)
|
||||
}
|
||||
|
||||
// Test CheckUserSignatureStatus with case variations
|
||||
exists1, _ := repo.CheckUserSignatureStatus(ctx, "Doc1", "USER1@EXAMPLE.COM")
|
||||
exists2, _ := repo.CheckUserSignatureStatus(ctx, "doc1", "user1@example.com")
|
||||
exists3, _ := repo.CheckUserSignatureStatus(ctx, "Doc1", "user1@example.com") // Cross-case: different doc case but same email case-insensitive
|
||||
exists4, _ := repo.CheckUserSignatureStatus(ctx, "doc1", "USER1@EXAMPLE.COM") // Cross-case: different doc case but same email case-insensitive
|
||||
|
||||
if !exists1 {
|
||||
t.Error("Expected to find signature with exact case match")
|
||||
}
|
||||
if !exists2 {
|
||||
t.Error("Expected to find signature with exact case match")
|
||||
}
|
||||
if !exists3 {
|
||||
t.Error("Expected to find signature with case-insensitive email match for Doc1")
|
||||
}
|
||||
if !exists4 {
|
||||
t.Error("Expected to find signature with case-insensitive email match for doc1")
|
||||
}
|
||||
|
||||
// Test with non-matching doc_id case
|
||||
exists5, _ := repo.CheckUserSignatureStatus(ctx, "DOC1", "user1@example.com") // All caps doc_id should not match
|
||||
if exists5 {
|
||||
t.Error("Should not find signature with different case for doc_id when no exact match")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,585 @@
|
||||
//go:build integration
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
)
|
||||
|
||||
func TestRepository_Create_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
repo := NewSignatureRepository(testDB.DB)
|
||||
factory := NewSignatureFactory()
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
signature *models.Signature
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "create valid signature",
|
||||
signature: factory.CreateValidSignature(),
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "create minimal signature",
|
||||
signature: factory.CreateMinimalSignature(),
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "create signature with previous hash",
|
||||
signature: factory.CreateChainedSignature("cHJldmlvdXMtaGFzaA=="),
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
err := repo.Create(ctx, tt.signature)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify signature was created with ID and CreatedAt
|
||||
if tt.signature.ID <= 0 {
|
||||
t.Error("Expected ID to be set after create")
|
||||
}
|
||||
|
||||
if tt.signature.CreatedAt.IsZero() {
|
||||
t.Error("Expected CreatedAt to be set after create")
|
||||
}
|
||||
|
||||
// Verify data in database
|
||||
count := testDB.GetTableCount(t)
|
||||
if count != 1 {
|
||||
t.Errorf("Expected 1 signature in DB, got %d", count)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepository_Create_UniqueConstraint_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
repo := NewSignatureRepository(testDB.DB)
|
||||
factory := NewSignatureFactory()
|
||||
ctx := context.Background()
|
||||
|
||||
// Create first signature
|
||||
sig1 := factory.CreateSignatureWithDocAndUser("doc1", "user1", "user1@example.com")
|
||||
err := repo.Create(ctx, sig1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create first signature: %v", err)
|
||||
}
|
||||
|
||||
// Try to create duplicate (same doc_id and user_sub)
|
||||
sig2 := factory.CreateSignatureWithDocAndUser("doc1", "user1", "user1@example.com")
|
||||
err = repo.Create(ctx, sig2)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for duplicate signature but got none")
|
||||
}
|
||||
|
||||
// Should still have only 1 signature
|
||||
count := testDB.GetTableCount(t)
|
||||
if count != 1 {
|
||||
t.Errorf("Expected 1 signature in DB after constraint violation, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepository_GetByDocAndUser_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
repo := NewSignatureRepository(testDB.DB)
|
||||
factory := NewSignatureFactory()
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func() *models.Signature
|
||||
docID string
|
||||
userSub string
|
||||
wantError bool
|
||||
wantNil bool
|
||||
}{
|
||||
{
|
||||
name: "get existing signature",
|
||||
setup: func() *models.Signature {
|
||||
sig := factory.CreateSignatureWithDocAndUser("doc1", "user1", "user1@example.com")
|
||||
_ = repo.Create(ctx, sig)
|
||||
return sig
|
||||
},
|
||||
docID: "doc1",
|
||||
userSub: "user1",
|
||||
wantError: false,
|
||||
wantNil: false,
|
||||
},
|
||||
{
|
||||
name: "get non-existent signature",
|
||||
setup: func() *models.Signature {
|
||||
return nil
|
||||
},
|
||||
docID: "non-existent",
|
||||
userSub: "non-existent",
|
||||
wantError: true,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "get signature wrong user",
|
||||
setup: func() *models.Signature {
|
||||
sig := factory.CreateSignatureWithDocAndUser("doc1", "user1", "user1@example.com")
|
||||
_ = repo.Create(ctx, sig)
|
||||
return sig
|
||||
},
|
||||
docID: "doc1",
|
||||
userSub: "wrong-user",
|
||||
wantError: true,
|
||||
wantNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
var expected *models.Signature
|
||||
if tt.setup != nil {
|
||||
expected = tt.setup()
|
||||
}
|
||||
|
||||
result, err := repo.GetByDocAndUser(ctx, tt.docID, tt.userSub)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
if !errors.Is(err, models.ErrSignatureNotFound) && tt.wantNil {
|
||||
t.Errorf("Expected ErrSignatureNotFound, got: %v", err)
|
||||
}
|
||||
if result != nil {
|
||||
t.Error("Expected nil result with error")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected signature but got nil")
|
||||
}
|
||||
|
||||
// Compare with expected (excluding ID and CreatedAt which are set by DB)
|
||||
AssertSignatureEqual(t, expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepository_GetByDoc_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
repo := NewSignatureRepository(testDB.DB)
|
||||
factory := NewSignatureFactory()
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup: Create signatures for multiple docs and users
|
||||
sig1 := factory.CreateSignatureWithDocAndUser("doc1", "user1", "user1@example.com")
|
||||
sig2 := factory.CreateSignatureWithDocAndUser("doc1", "user2", "user2@example.com")
|
||||
sig3 := factory.CreateSignatureWithDocAndUser("doc2", "user1", "user1@example.com")
|
||||
|
||||
_ = repo.Create(ctx, sig1)
|
||||
time.Sleep(10 * time.Millisecond) // Ensure different created_at
|
||||
_ = repo.Create(ctx, sig2)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
_ = repo.Create(ctx, sig3)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
docID string
|
||||
expectedCount int
|
||||
expectedUsers []string
|
||||
}{
|
||||
{
|
||||
name: "get signatures for doc with 2 users",
|
||||
docID: "doc1",
|
||||
expectedCount: 2,
|
||||
expectedUsers: []string{"user2", "user1"}, // Should be ordered by created_at DESC
|
||||
},
|
||||
{
|
||||
name: "get signatures for doc with 1 user",
|
||||
docID: "doc2",
|
||||
expectedCount: 1,
|
||||
expectedUsers: []string{"user1"},
|
||||
},
|
||||
{
|
||||
name: "get signatures for non-existent doc",
|
||||
docID: "non-existent",
|
||||
expectedCount: 0,
|
||||
expectedUsers: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := repo.GetByDoc(ctx, tt.docID)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result) != tt.expectedCount {
|
||||
t.Errorf("Expected %d signatures, got %d", tt.expectedCount, len(result))
|
||||
}
|
||||
|
||||
// Verify order (should be by created_at DESC)
|
||||
for i, sig := range result {
|
||||
if i < len(tt.expectedUsers) && sig.UserSub != tt.expectedUsers[i] {
|
||||
t.Errorf("Expected user %s at position %d, got %s", tt.expectedUsers[i], i, sig.UserSub)
|
||||
}
|
||||
|
||||
if sig.DocID != tt.docID {
|
||||
t.Errorf("Expected DocID %s, got %s", tt.docID, sig.DocID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepository_GetByUser_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
repo := NewSignatureRepository(testDB.DB)
|
||||
factory := NewSignatureFactory()
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup: Create signatures for multiple users and docs
|
||||
sig1 := factory.CreateSignatureWithDocAndUser("doc1", "user1", "user1@example.com")
|
||||
sig2 := factory.CreateSignatureWithDocAndUser("doc2", "user1", "user1@example.com")
|
||||
sig3 := factory.CreateSignatureWithDocAndUser("doc1", "user2", "user2@example.com")
|
||||
|
||||
_ = repo.Create(ctx, sig1)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
_ = repo.Create(ctx, sig2)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
_ = repo.Create(ctx, sig3)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userSub string
|
||||
expectedCount int
|
||||
expectedDocIDs []string
|
||||
}{
|
||||
{
|
||||
name: "get signatures for user with 2 docs",
|
||||
userSub: "user1",
|
||||
expectedCount: 2,
|
||||
expectedDocIDs: []string{"doc2", "doc1"}, // Should be ordered by created_at DESC
|
||||
},
|
||||
{
|
||||
name: "get signatures for user with 1 doc",
|
||||
userSub: "user2",
|
||||
expectedCount: 1,
|
||||
expectedDocIDs: []string{"doc1"},
|
||||
},
|
||||
{
|
||||
name: "get signatures for non-existent user",
|
||||
userSub: "non-existent",
|
||||
expectedCount: 0,
|
||||
expectedDocIDs: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := repo.GetByUser(ctx, tt.userSub)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result) != tt.expectedCount {
|
||||
t.Errorf("Expected %d signatures, got %d", tt.expectedCount, len(result))
|
||||
}
|
||||
|
||||
// Verify order and data
|
||||
for i, sig := range result {
|
||||
if i < len(tt.expectedDocIDs) && sig.DocID != tt.expectedDocIDs[i] {
|
||||
t.Errorf("Expected DocID %s at position %d, got %s", tt.expectedDocIDs[i], i, sig.DocID)
|
||||
}
|
||||
|
||||
if sig.UserSub != tt.userSub {
|
||||
t.Errorf("Expected UserSub %s, got %s", tt.userSub, sig.UserSub)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepository_ExistsByDocAndUser_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
repo := NewSignatureRepository(testDB.DB)
|
||||
factory := NewSignatureFactory()
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup: Create a signature
|
||||
sig := factory.CreateSignatureWithDocAndUser("doc1", "user1", "user1@example.com")
|
||||
_ = repo.Create(ctx, sig)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
docID string
|
||||
userSub string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "existing signature",
|
||||
docID: "doc1",
|
||||
userSub: "user1",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "non-existent doc",
|
||||
docID: "non-existent",
|
||||
userSub: "user1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent user",
|
||||
docID: "doc1",
|
||||
userSub: "non-existent",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "both non-existent",
|
||||
docID: "non-existent",
|
||||
userSub: "non-existent",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := repo.ExistsByDocAndUser(ctx, tt.docID, tt.userSub)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepository_CheckUserSignatureStatus_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
repo := NewSignatureRepository(testDB.DB)
|
||||
factory := NewSignatureFactory()
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup: Create signatures with different users
|
||||
sig1 := factory.CreateSignatureWithDocAndUser("doc1", "user-sub-123", "user@EXAMPLE.COM")
|
||||
sig2 := factory.CreateSignatureWithDocAndUser("doc2", "another-user", "another@example.com")
|
||||
|
||||
_ = repo.Create(ctx, sig1)
|
||||
_ = repo.Create(ctx, sig2)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
docID string
|
||||
userIdentifier string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "check by user_sub",
|
||||
docID: "doc1",
|
||||
userIdentifier: "user-sub-123",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "check by email (case insensitive)",
|
||||
docID: "doc1",
|
||||
userIdentifier: "user@example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "check by email exact case",
|
||||
docID: "doc1",
|
||||
userIdentifier: "USER@EXAMPLE.COM",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "non-existent doc",
|
||||
docID: "non-existent",
|
||||
userIdentifier: "user-sub-123",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent user",
|
||||
docID: "doc1",
|
||||
userIdentifier: "non-existent",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "wrong doc for user",
|
||||
docID: "doc2",
|
||||
userIdentifier: "user-sub-123",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := repo.CheckUserSignatureStatus(ctx, tt.docID, tt.userIdentifier)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepository_GetLastSignature_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
repo := NewSignatureRepository(testDB.DB)
|
||||
factory := NewSignatureFactory()
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("no signatures", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
result, err := repo.GetLastSignature(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Error("Expected nil when no signatures exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("single signature", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
sig := factory.CreateValidSignature()
|
||||
_ = repo.Create(ctx, sig)
|
||||
|
||||
result, err := repo.GetLastSignature(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected signature but got nil")
|
||||
}
|
||||
|
||||
AssertSignatureEqual(t, sig, result)
|
||||
})
|
||||
|
||||
t.Run("multiple signatures", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
// Create signatures with different content
|
||||
sig1 := factory.CreateSignatureWithUser("user1", "user1@example.com")
|
||||
sig2 := factory.CreateSignatureWithUser("user2", "user2@example.com")
|
||||
sig3 := factory.CreateSignatureWithUser("user3", "user3@example.com")
|
||||
|
||||
_ = repo.Create(ctx, sig1)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
_ = repo.Create(ctx, sig2)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
_ = repo.Create(ctx, sig3)
|
||||
|
||||
result, err := repo.GetLastSignature(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected signature but got nil")
|
||||
}
|
||||
|
||||
// Should return the last created signature (sig3)
|
||||
if result.UserSub != "user3" {
|
||||
t.Errorf("Expected last signature to be user3, got %s", result.UserSub)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRepository_GetAllSignaturesOrdered_Integration(t *testing.T) {
|
||||
testDB := SetupTestDB(t)
|
||||
repo := NewSignatureRepository(testDB.DB)
|
||||
factory := NewSignatureFactory()
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("no signatures", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
result, err := repo.GetAllSignaturesOrdered(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result) != 0 {
|
||||
t.Errorf("Expected empty slice, got %d signatures", len(result))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple signatures ordered by ID ASC", func(t *testing.T) {
|
||||
testDB.ClearTable(t)
|
||||
|
||||
// Create signatures
|
||||
sig1 := factory.CreateSignatureWithUser("user1", "user1@example.com")
|
||||
sig2 := factory.CreateSignatureWithUser("user2", "user2@example.com")
|
||||
sig3 := factory.CreateSignatureWithUser("user3", "user3@example.com")
|
||||
|
||||
_ = repo.Create(ctx, sig1)
|
||||
_ = repo.Create(ctx, sig2)
|
||||
_ = repo.Create(ctx, sig3)
|
||||
|
||||
result, err := repo.GetAllSignaturesOrdered(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result) != 3 {
|
||||
t.Errorf("Expected 3 signatures, got %d", len(result))
|
||||
}
|
||||
|
||||
// Verify order by ID ASC
|
||||
expectedUsers := []string{"user1", "user2", "user3"}
|
||||
for i, sig := range result {
|
||||
if sig.UserSub != expectedUsers[i] {
|
||||
t.Errorf("Expected user %s at position %d, got %s", expectedUsers[i], i, sig.UserSub)
|
||||
}
|
||||
|
||||
// Verify IDs are in ascending order
|
||||
if i > 0 && result[i].ID <= result[i-1].ID {
|
||||
t.Errorf("IDs not in ascending order: %d should be > %d", result[i].ID, result[i-1].ID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,259 @@
|
||||
//go:build integration
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// TestDB holds test database configuration
|
||||
type TestDB struct {
|
||||
DB *sql.DB
|
||||
DSN string
|
||||
dbName string
|
||||
}
|
||||
|
||||
// SetupTestDB creates a test database connection and runs migrations
|
||||
func SetupTestDB(t *testing.T) *TestDB {
|
||||
t.Helper()
|
||||
|
||||
// Skip if not in integrations test mode
|
||||
if os.Getenv("INTEGRATION_TESTS") == "" {
|
||||
t.Skip("Skipping integrations test (INTEGRATION_TESTS not set)")
|
||||
}
|
||||
|
||||
dsn := os.Getenv("DB_DSN")
|
||||
if dsn == "" {
|
||||
dsn = "postgres://postgres:testpassword@localhost:5432/ackify_test?sslmode=disable"
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to test database: %v", err)
|
||||
}
|
||||
|
||||
// Verify connection
|
||||
if err := db.Ping(); err != nil {
|
||||
t.Fatalf("Failed to ping test database: %v", err)
|
||||
}
|
||||
|
||||
testDB := &TestDB{
|
||||
DB: db,
|
||||
DSN: dsn,
|
||||
dbName: fmt.Sprintf("test_%d_%d", time.Now().UnixNano(), os.Getpid()),
|
||||
}
|
||||
|
||||
// Create test schema
|
||||
if err := testDB.createSchema(); err != nil {
|
||||
t.Fatalf("Failed to create test schema: %v", err)
|
||||
}
|
||||
|
||||
// Clean up on test completion
|
||||
t.Cleanup(func() {
|
||||
testDB.Cleanup()
|
||||
})
|
||||
|
||||
return testDB
|
||||
}
|
||||
|
||||
// createSchema creates the signatures table for testing
|
||||
func (tdb *TestDB) createSchema() error {
|
||||
schema := `
|
||||
-- Drop table if exists (for cleanup)
|
||||
DROP TABLE IF EXISTS signatures;
|
||||
|
||||
-- Create signatures table
|
||||
CREATE TABLE signatures (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
doc_id TEXT NOT NULL,
|
||||
user_sub TEXT NOT NULL,
|
||||
user_email TEXT NOT NULL,
|
||||
user_name TEXT,
|
||||
signed_at_utc TIMESTAMPTZ NOT NULL,
|
||||
payload_hash_b64 TEXT NOT NULL,
|
||||
signature_b64 TEXT NOT NULL,
|
||||
nonce TEXT NOT NULL,
|
||||
referer TEXT,
|
||||
prev_hash_b64 TEXT,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
|
||||
-- Constraints
|
||||
UNIQUE (doc_id, user_sub)
|
||||
);
|
||||
|
||||
-- Create indexes for performance
|
||||
CREATE INDEX idx_signatures_doc_id ON signatures(doc_id);
|
||||
CREATE INDEX idx_signatures_user_sub ON signatures(user_sub);
|
||||
CREATE INDEX idx_signatures_user_email ON signatures(user_email);
|
||||
CREATE INDEX idx_signatures_created_at ON signatures(created_at);
|
||||
CREATE INDEX idx_signatures_id_asc ON signatures(id ASC);
|
||||
`
|
||||
|
||||
_, err := tdb.DB.Exec(schema)
|
||||
return err
|
||||
}
|
||||
|
||||
// Cleanup closes the database connection and cleans up
|
||||
func (tdb *TestDB) Cleanup() {
|
||||
if tdb.DB != nil {
|
||||
// Drop all tables for cleanup
|
||||
_, _ = tdb.DB.Exec("DROP TABLE IF EXISTS signatures")
|
||||
_ = tdb.DB.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// ClearTable removes all data from the signatures table
|
||||
func (tdb *TestDB) ClearTable(t *testing.T) {
|
||||
t.Helper()
|
||||
_, err := tdb.DB.Exec("TRUNCATE TABLE signatures RESTART IDENTITY")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to clear signatures table: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetTableCount returns the number of rows in signatures table
|
||||
func (tdb *TestDB) GetTableCount(t *testing.T) int {
|
||||
t.Helper()
|
||||
var count int
|
||||
err := tdb.DB.QueryRow("SELECT COUNT(*) FROM signatures").Scan(&count)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get table count: %v", err)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// SignatureFactory creates test signature objects
|
||||
type SignatureFactory struct{}
|
||||
|
||||
// CreateValidSignature creates a valid signature for testing
|
||||
func (f *SignatureFactory) CreateValidSignature() *models.Signature {
|
||||
now := time.Now().UTC()
|
||||
userName := "Test User"
|
||||
referer := "https://example.com/doc"
|
||||
|
||||
return &models.Signature{
|
||||
DocID: "test-doc-123",
|
||||
UserSub: "user-123",
|
||||
UserEmail: "test@example.com",
|
||||
UserName: &userName,
|
||||
SignedAtUTC: now,
|
||||
PayloadHashB64: "dGVzdC1wYXlsb2FkLWhhc2g=", // base64("test-payload-hash")
|
||||
SignatureB64: "dGVzdC1zaWduYXR1cmU=", // base64("test-signature")
|
||||
Nonce: "test-nonce-123",
|
||||
Referer: &referer,
|
||||
PrevHashB64: nil, // Will be set for chained signatures
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSignatureWithDoc creates a signature for a specific document
|
||||
func (f *SignatureFactory) CreateSignatureWithDoc(docID string) *models.Signature {
|
||||
sig := f.CreateValidSignature()
|
||||
sig.DocID = docID
|
||||
return sig
|
||||
}
|
||||
|
||||
// CreateSignatureWithUser creates a signature for a specific user
|
||||
func (f *SignatureFactory) CreateSignatureWithUser(userSub, userEmail string) *models.Signature {
|
||||
sig := f.CreateValidSignature()
|
||||
sig.UserSub = userSub
|
||||
sig.UserEmail = userEmail
|
||||
return sig
|
||||
}
|
||||
|
||||
// CreateSignatureWithDocAndUser creates a signature for specific doc and user
|
||||
func (f *SignatureFactory) CreateSignatureWithDocAndUser(docID, userSub, userEmail string) *models.Signature {
|
||||
sig := f.CreateValidSignature()
|
||||
sig.DocID = docID
|
||||
sig.UserSub = userSub
|
||||
sig.UserEmail = userEmail
|
||||
return sig
|
||||
}
|
||||
|
||||
// CreateChainedSignature creates a signature with previous hash for chaining tests
|
||||
func (f *SignatureFactory) CreateChainedSignature(prevHashB64 string) *models.Signature {
|
||||
sig := f.CreateValidSignature()
|
||||
sig.PrevHashB64 = &prevHashB64
|
||||
return sig
|
||||
}
|
||||
|
||||
// CreateMinimalSignature creates signature with only required fields
|
||||
func (f *SignatureFactory) CreateMinimalSignature() *models.Signature {
|
||||
now := time.Now().UTC()
|
||||
|
||||
return &models.Signature{
|
||||
DocID: "minimal-doc",
|
||||
UserSub: "minimal-user",
|
||||
UserEmail: "minimal@example.com",
|
||||
UserName: nil, // NULL
|
||||
SignedAtUTC: now,
|
||||
PayloadHashB64: "bWluaW1hbA==", // base64("minimal")
|
||||
SignatureB64: "bWluaW1hbA==", // base64("minimal")
|
||||
Nonce: "minimal-nonce",
|
||||
Referer: nil, // NULL
|
||||
PrevHashB64: nil, // NULL
|
||||
}
|
||||
}
|
||||
|
||||
// AssertSignatureEqual compares two signatures for testing
|
||||
func AssertSignatureEqual(t *testing.T, expected, actual *models.Signature) {
|
||||
t.Helper()
|
||||
|
||||
if actual.DocID != expected.DocID {
|
||||
t.Errorf("DocID mismatch: got %s, want %s", actual.DocID, expected.DocID)
|
||||
}
|
||||
|
||||
if actual.UserSub != expected.UserSub {
|
||||
t.Errorf("UserSub mismatch: got %s, want %s", actual.UserSub, expected.UserSub)
|
||||
}
|
||||
|
||||
if actual.UserEmail != expected.UserEmail {
|
||||
t.Errorf("UserEmail mismatch: got %s, want %s", actual.UserEmail, expected.UserEmail)
|
||||
}
|
||||
|
||||
if !isStringPtrEqual(actual.UserName, expected.UserName) {
|
||||
t.Errorf("UserName mismatch: got %v, want %v", actual.UserName, expected.UserName)
|
||||
}
|
||||
|
||||
if actual.PayloadHashB64 != expected.PayloadHashB64 {
|
||||
t.Errorf("PayloadHashB64 mismatch: got %s, want %s", actual.PayloadHashB64, expected.PayloadHashB64)
|
||||
}
|
||||
|
||||
if actual.SignatureB64 != expected.SignatureB64 {
|
||||
t.Errorf("SignatureB64 mismatch: got %s, want %s", actual.SignatureB64, expected.SignatureB64)
|
||||
}
|
||||
|
||||
if actual.Nonce != expected.Nonce {
|
||||
t.Errorf("Nonce mismatch: got %s, want %s", actual.Nonce, expected.Nonce)
|
||||
}
|
||||
|
||||
if !isStringPtrEqual(actual.Referer, expected.Referer) {
|
||||
t.Errorf("Referer mismatch: got %v, want %v", actual.Referer, expected.Referer)
|
||||
}
|
||||
|
||||
if !isStringPtrEqual(actual.PrevHashB64, expected.PrevHashB64) {
|
||||
t.Errorf("PrevHashB64 mismatch: got %v, want %v", actual.PrevHashB64, expected.PrevHashB64)
|
||||
}
|
||||
}
|
||||
|
||||
// isStringPtrEqual compares two string pointers
|
||||
func isStringPtrEqual(a, b *string) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return *a == *b
|
||||
}
|
||||
|
||||
// NewSignatureFactory creates a new signature factory
|
||||
func NewSignatureFactory() *SignatureFactory {
|
||||
return &SignatureFactory{}
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
)
|
||||
|
||||
// AuthHandlers handles authentication-related HTTP requests
|
||||
type AuthHandlers struct {
|
||||
authService authService
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// NewAuthHandlers creates new authentication handlers
|
||||
func NewAuthHandlers(authService authService, baseURL string) *AuthHandlers {
|
||||
return &AuthHandlers{
|
||||
authService: authService,
|
||||
baseURL: baseURL,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleLogin handles login requests
|
||||
func (h *AuthHandlers) HandleLogin(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
next := r.URL.Query().Get("next")
|
||||
if next == "" {
|
||||
next = h.baseURL + "/"
|
||||
}
|
||||
|
||||
authURL := h.authService.GetAuthURL(next)
|
||||
http.Redirect(w, r, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// HandleLogout handles logout requests
|
||||
func (h *AuthHandlers) HandleLogout(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
h.authService.Logout(w, r)
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
|
||||
// HandleOAuthCallback handles OAuth callback from the configured provider
|
||||
func (h *AuthHandlers) HandleOAuthCallback(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
if code == "" {
|
||||
http.Error(w, "Missing authorization code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
user, nextURL, err := h.authService.HandleCallback(ctx, code, state)
|
||||
if err != nil {
|
||||
HandleError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.authService.SetUser(w, r, user); err != nil {
|
||||
http.Error(w, "Failed to set user session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse and validate next URL
|
||||
if nextURL == "" {
|
||||
nextURL = "/"
|
||||
}
|
||||
|
||||
// Basic URL validation to prevent open redirects
|
||||
if parsedURL, err := url.Parse(nextURL); err != nil ||
|
||||
(parsedURL.Host != "" && parsedURL.Host != r.Host) {
|
||||
nextURL = "/"
|
||||
}
|
||||
|
||||
http.Redirect(w, r, nextURL, http.StatusFound)
|
||||
}
|
||||
@@ -0,0 +1,230 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/draw"
|
||||
"image/png"
|
||||
"net/http"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
)
|
||||
|
||||
type checkService interface {
|
||||
CheckUserSignature(ctx context.Context, docID, userIdentifier string) (bool, error)
|
||||
}
|
||||
|
||||
// BadgeHandler handles badge generation
|
||||
type BadgeHandler struct {
|
||||
checkService checkService
|
||||
}
|
||||
|
||||
// NewBadgeHandler creates a new badge handler
|
||||
func NewBadgeHandler(checkService checkService) *BadgeHandler {
|
||||
return &BadgeHandler{
|
||||
checkService: checkService,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleStatusPNG generates a PNG badge showing signature status
|
||||
func (h *BadgeHandler) HandleStatusPNG(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
docID, err := validateDocID(r)
|
||||
if err != nil {
|
||||
HandleError(w, models.ErrInvalidDocument)
|
||||
return
|
||||
}
|
||||
|
||||
userIdentifier, err := validateUserIdentifier(r)
|
||||
if err != nil {
|
||||
HandleError(w, models.ErrInvalidUser)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
isSigned, err := h.checkService.CheckUserSignature(ctx, docID, userIdentifier)
|
||||
if err != nil {
|
||||
HandleError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
badge := h.generateBadge(isSigned)
|
||||
|
||||
w.Header().Set("Content-Type", "image/png")
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
_, _ = w.Write(badge)
|
||||
}
|
||||
|
||||
const badgeSize = 64
|
||||
|
||||
// BadgeColors represents the color scheme for badges
|
||||
type BadgeColors struct {
|
||||
Background color.RGBA
|
||||
Icon color.RGBA
|
||||
Border color.RGBA
|
||||
}
|
||||
|
||||
// BadgeThemes contains predefined color schemes
|
||||
var BadgeThemes = struct {
|
||||
Success BadgeColors
|
||||
Error BadgeColors
|
||||
}{
|
||||
Success: BadgeColors{
|
||||
Background: color.RGBA{R: 240, G: 253, B: 244, A: 255}, // success-50
|
||||
Icon: color.RGBA{R: 34, G: 197, B: 94, A: 255}, // success-500
|
||||
Border: color.RGBA{R: 134, G: 239, B: 172, A: 255}, // success-300
|
||||
},
|
||||
Error: BadgeColors{
|
||||
Background: color.RGBA{R: 254, G: 242, B: 242, A: 255}, // red-50
|
||||
Icon: color.RGBA{R: 239, G: 68, B: 68, A: 255}, // red-500
|
||||
Border: color.RGBA{R: 252, G: 165, B: 165, A: 255}, // red-300
|
||||
},
|
||||
}
|
||||
|
||||
// generateBadge creates a PNG badge
|
||||
func (h *BadgeHandler) generateBadge(isSigned bool) []byte {
|
||||
img := image.NewRGBA(image.Rect(0, 0, badgeSize, badgeSize))
|
||||
|
||||
colors := h.getBadgeColors(isSigned)
|
||||
h.drawBackground(img, colors.Background)
|
||||
h.drawBorder(img, colors.Border)
|
||||
h.drawIcon(img, isSigned, colors.Icon)
|
||||
|
||||
return h.encodeToPNG(img)
|
||||
}
|
||||
|
||||
// getBadgeColors returns appropriate colors based on signing status
|
||||
func (h *BadgeHandler) getBadgeColors(isSigned bool) BadgeColors {
|
||||
if isSigned {
|
||||
return BadgeThemes.Success
|
||||
}
|
||||
return BadgeThemes.Error
|
||||
}
|
||||
|
||||
// drawBackground fills the image with background color
|
||||
func (h *BadgeHandler) drawBackground(img *image.RGBA, bgColor color.RGBA) {
|
||||
draw.Draw(img, img.Bounds(), &image.Uniform{C: bgColor}, image.Point{}, draw.Src)
|
||||
}
|
||||
|
||||
// drawBorder draws a circular border around the badge
|
||||
func (h *BadgeHandler) drawBorder(img *image.RGBA, borderColor color.RGBA) {
|
||||
cx, cy, r := badgeSize/2, badgeSize/2, badgeSize/2-3
|
||||
for y := 0; y < badgeSize; y++ {
|
||||
for x := 0; x < badgeSize; x++ {
|
||||
dx, dy := x-cx, y-cy
|
||||
dist := dx*dx + dy*dy
|
||||
if dist >= (r*r) && dist <= ((r+2)*(r+2)) {
|
||||
img.Set(x, y, borderColor)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// drawIcon draws the appropriate icon based on signing status
|
||||
func (h *BadgeHandler) drawIcon(img *image.RGBA, isSigned bool, iconColor color.RGBA) {
|
||||
if isSigned {
|
||||
h.drawCheckmark(img, badgeSize, iconColor)
|
||||
} else {
|
||||
h.drawX(img, badgeSize, iconColor)
|
||||
}
|
||||
}
|
||||
|
||||
// encodeToPNG encodes the image to PNG format
|
||||
func (h *BadgeHandler) encodeToPNG(img *image.RGBA) []byte {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
_ = png.Encode(buf, img)
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// drawCheckmark draws a checkmark icon
|
||||
func (h *BadgeHandler) drawCheckmark(img *image.RGBA, size int, col color.RGBA) {
|
||||
cx, cy := size/2, size/2
|
||||
scale := float64(size) / 64.0
|
||||
|
||||
// Checkmark path points (scaled)
|
||||
points := [][2]int{
|
||||
{int(18 * scale), int(32 * scale)},
|
||||
{int(28 * scale), int(42 * scale)},
|
||||
{int(46 * scale), int(22 * scale)},
|
||||
}
|
||||
|
||||
thickness := int(3 * scale)
|
||||
if thickness < 2 {
|
||||
thickness = 2
|
||||
}
|
||||
|
||||
// Draw first stroke (left part of check)
|
||||
h.drawThickLine(img, cx+points[0][0]-cx, cy+points[0][1]-cy,
|
||||
cx+points[1][0]-cx, cy+points[1][1]-cy, thickness, col)
|
||||
|
||||
// Draw second stroke (right part of check)
|
||||
h.drawThickLine(img, cx+points[1][0]-cx, cy+points[1][1]-cy,
|
||||
cx+points[2][0]-cx, cy+points[2][1]-cy, thickness, col)
|
||||
}
|
||||
|
||||
// drawX draws an X icon
|
||||
func (h *BadgeHandler) drawX(img *image.RGBA, size int, col color.RGBA) {
|
||||
cx, cy := size/2, size/2
|
||||
offset := int(float64(size) * 0.3)
|
||||
thickness := size / 12
|
||||
if thickness < 2 {
|
||||
thickness = 2
|
||||
}
|
||||
|
||||
// Draw diagonal lines for X
|
||||
h.drawThickLine(img, cx-offset, cy-offset, cx+offset, cy+offset, thickness, col)
|
||||
h.drawThickLine(img, cx-offset, cy+offset, cx+offset, cy-offset, thickness, col)
|
||||
}
|
||||
|
||||
// drawThickLine draws a thick line using Bresenham's algorithm
|
||||
func (h *BadgeHandler) drawThickLine(img *image.RGBA, x0, y0, x1, y1, thickness int, col color.RGBA) {
|
||||
dx := abs(x1 - x0)
|
||||
dy := abs(y1 - y0)
|
||||
sx := -1
|
||||
if x0 < x1 {
|
||||
sx = 1
|
||||
}
|
||||
sy := -1
|
||||
if y0 < y1 {
|
||||
sy = 1
|
||||
}
|
||||
err := dx - dy
|
||||
|
||||
x, y := x0, y0
|
||||
for {
|
||||
// Draw thick point
|
||||
for i := -thickness / 2; i <= thickness/2; i++ {
|
||||
for j := -thickness / 2; j <= thickness/2; j++ {
|
||||
px, py := x+i, y+j
|
||||
if px >= 0 && px < img.Bounds().Dx() && py >= 0 && py < img.Bounds().Dy() {
|
||||
img.Set(px, py, col)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if x == x1 && y == y1 {
|
||||
break
|
||||
}
|
||||
|
||||
e2 := 2 * err
|
||||
if e2 > -dy {
|
||||
err -= dy
|
||||
x += sx
|
||||
}
|
||||
if e2 < dx {
|
||||
err += dx
|
||||
y += sy
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// abs returns absolute value
|
||||
func abs(x int) int {
|
||||
if x < 0 {
|
||||
return -x
|
||||
}
|
||||
return x
|
||||
}
|
||||
@@ -0,0 +1,713 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
)
|
||||
|
||||
// Fake implementations for testing
|
||||
|
||||
type fakeAuthService struct {
|
||||
shouldFailSetUser bool
|
||||
shouldFailCallback bool
|
||||
setUserError error
|
||||
callbackUser *models.User
|
||||
callbackNextURL string
|
||||
callbackError error
|
||||
authURL string
|
||||
logoutCalled bool
|
||||
}
|
||||
|
||||
func newFakeAuthService() *fakeAuthService {
|
||||
return &fakeAuthService{
|
||||
authURL: "https://oauth.example.com/auth",
|
||||
callbackUser: &models.User{Sub: "test-user", Email: "test@example.com", Name: "Test User"},
|
||||
callbackNextURL: "/",
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeAuthService) SetUser(_ http.ResponseWriter, _ *http.Request, _ *models.User) error {
|
||||
if f.shouldFailSetUser {
|
||||
return f.setUserError
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeAuthService) Logout(_ http.ResponseWriter, _ *http.Request) {
|
||||
f.logoutCalled = true
|
||||
}
|
||||
|
||||
func (f *fakeAuthService) GetAuthURL(nextURL string) string {
|
||||
return f.authURL + "?next=" + url.QueryEscape(nextURL)
|
||||
}
|
||||
|
||||
func (f *fakeAuthService) HandleCallback(_ context.Context, _, _ string) (*models.User, string, error) {
|
||||
if f.shouldFailCallback {
|
||||
return nil, "", f.callbackError
|
||||
}
|
||||
return f.callbackUser, f.callbackNextURL, nil
|
||||
}
|
||||
|
||||
type fakeUserService struct {
|
||||
user *models.User
|
||||
shouldFail bool
|
||||
getUserError error
|
||||
}
|
||||
|
||||
func newFakeUserService() *fakeUserService {
|
||||
return &fakeUserService{
|
||||
user: &models.User{Sub: "test-user", Email: "test@example.com", Name: "Test User"},
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeUserService) GetUser(_ *http.Request) (*models.User, error) {
|
||||
if f.shouldFail {
|
||||
return nil, f.getUserError
|
||||
}
|
||||
return f.user, nil
|
||||
}
|
||||
|
||||
type fakeSignatureService struct {
|
||||
shouldFailCreate bool
|
||||
shouldFailGetStatus bool
|
||||
shouldFailGetByDocUser bool
|
||||
shouldFailGetDoc bool
|
||||
shouldFailGetUser bool
|
||||
shouldFailCheck bool
|
||||
createError error
|
||||
statusResult *models.SignatureStatus
|
||||
getStatusError error
|
||||
signature *models.Signature
|
||||
getSignatureError error
|
||||
docSignatures []*models.Signature
|
||||
getDocError error
|
||||
userSignatures []*models.Signature
|
||||
getUserError error
|
||||
checkResult bool
|
||||
checkError error
|
||||
}
|
||||
|
||||
func newFakeSignatureService() *fakeSignatureService {
|
||||
return &fakeSignatureService{
|
||||
statusResult: &models.SignatureStatus{
|
||||
DocID: "test-doc",
|
||||
UserEmail: "test@example.com",
|
||||
IsSigned: false,
|
||||
SignedAt: nil,
|
||||
},
|
||||
signature: &models.Signature{
|
||||
ID: 1,
|
||||
DocID: "test-doc",
|
||||
UserSub: "test-user",
|
||||
UserEmail: "test@example.com",
|
||||
SignedAtUTC: time.Now().UTC(),
|
||||
},
|
||||
docSignatures: []*models.Signature{
|
||||
{
|
||||
ID: 1,
|
||||
DocID: "test-doc",
|
||||
UserSub: "test-user",
|
||||
UserEmail: "test@example.com",
|
||||
SignedAtUTC: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
userSignatures: []*models.Signature{
|
||||
{
|
||||
ID: 1,
|
||||
DocID: "test-doc",
|
||||
UserSub: "test-user",
|
||||
UserEmail: "test@example.com",
|
||||
SignedAtUTC: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
checkResult: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeSignatureService) CreateSignature(_ context.Context, _ *models.SignatureRequest) error {
|
||||
if f.shouldFailCreate {
|
||||
return f.createError
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeSignatureService) GetSignatureStatus(_ context.Context, _ string, _ *models.User) (*models.SignatureStatus, error) {
|
||||
if f.shouldFailGetStatus {
|
||||
return nil, f.getStatusError
|
||||
}
|
||||
return f.statusResult, nil
|
||||
}
|
||||
|
||||
func (f *fakeSignatureService) GetSignatureByDocAndUser(_ context.Context, _ string, _ *models.User) (*models.Signature, error) {
|
||||
if f.shouldFailGetByDocUser {
|
||||
return nil, f.getSignatureError
|
||||
}
|
||||
return f.signature, nil
|
||||
}
|
||||
|
||||
func (f *fakeSignatureService) GetDocumentSignatures(_ context.Context, _ string) ([]*models.Signature, error) {
|
||||
if f.shouldFailGetDoc {
|
||||
return nil, f.getDocError
|
||||
}
|
||||
return f.docSignatures, nil
|
||||
}
|
||||
|
||||
func (f *fakeSignatureService) GetUserSignatures(_ context.Context, _ *models.User) ([]*models.Signature, error) {
|
||||
if f.shouldFailGetUser {
|
||||
return nil, f.getUserError
|
||||
}
|
||||
return f.userSignatures, nil
|
||||
}
|
||||
|
||||
func (f *fakeSignatureService) CheckUserSignature(_ context.Context, _, _ string) (bool, error) {
|
||||
if f.shouldFailCheck {
|
||||
return false, f.checkError
|
||||
}
|
||||
return f.checkResult, nil
|
||||
}
|
||||
|
||||
// Test helpers
|
||||
|
||||
func createTestTemplate() *template.Template {
|
||||
tmpl := template.New("test")
|
||||
template.Must(tmpl.New("base").Parse(`<html><body>{{.TemplateName}}</body></html>`))
|
||||
return tmpl
|
||||
}
|
||||
|
||||
func TestAuthHandlers_NewAuthHandlers(t *testing.T) {
|
||||
authService := newFakeAuthService()
|
||||
baseURL := "https://example.com"
|
||||
|
||||
handlers := NewAuthHandlers(authService, baseURL)
|
||||
|
||||
if handlers == nil {
|
||||
t.Error("NewAuthHandlers should not return nil")
|
||||
} else if handlers.authService != authService {
|
||||
t.Error("AuthService not set correctly")
|
||||
} else if handlers.baseURL != baseURL {
|
||||
t.Error("BaseURL not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandlers_HandleLogin(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nextParam string
|
||||
expectedURL string
|
||||
}{
|
||||
{
|
||||
name: "login with next parameter",
|
||||
nextParam: "/sign?doc=test",
|
||||
expectedURL: "https://oauth.example.com/auth?next=" + url.QueryEscape("/sign?doc=test"),
|
||||
},
|
||||
{
|
||||
name: "login without next parameter",
|
||||
nextParam: "",
|
||||
expectedURL: "https://oauth.example.com/auth?next=" + url.QueryEscape("https://example.com/"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
authService := newFakeAuthService()
|
||||
handlers := NewAuthHandlers(authService, "https://example.com")
|
||||
|
||||
req := httptest.NewRequest("GET", "/login", nil)
|
||||
if tt.nextParam != "" {
|
||||
q := req.URL.Query()
|
||||
q.Set("next", tt.nextParam)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handlers.HandleLogin(w, req, nil)
|
||||
|
||||
if w.Code != http.StatusFound {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusFound, w.Code)
|
||||
}
|
||||
|
||||
location := w.Header().Get("Location")
|
||||
if location != tt.expectedURL {
|
||||
t.Errorf("Expected redirect to %s, got %s", tt.expectedURL, location)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandlers_HandleLogout(t *testing.T) {
|
||||
authService := newFakeAuthService()
|
||||
handlers := NewAuthHandlers(authService, "https://example.com")
|
||||
|
||||
req := httptest.NewRequest("GET", "/logout", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handlers.HandleLogout(w, req, nil)
|
||||
|
||||
if w.Code != http.StatusFound {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusFound, w.Code)
|
||||
}
|
||||
|
||||
location := w.Header().Get("Location")
|
||||
if location != "/" {
|
||||
t.Errorf("Expected redirect to /, got %s", location)
|
||||
}
|
||||
|
||||
if !authService.logoutCalled {
|
||||
t.Error("Logout should have been called on auth service")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandlers_HandleOAuthCallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
state string
|
||||
setupAuth func(*fakeAuthService)
|
||||
expectedStatus int
|
||||
expectedRedirect string
|
||||
}{
|
||||
{
|
||||
name: "successful callback",
|
||||
code: "test-code",
|
||||
state: "test-state",
|
||||
setupAuth: func(a *fakeAuthService) {},
|
||||
expectedStatus: http.StatusFound,
|
||||
expectedRedirect: "/",
|
||||
},
|
||||
{
|
||||
name: "missing code",
|
||||
code: "",
|
||||
state: "test-state",
|
||||
setupAuth: func(a *fakeAuthService) {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "callback fails",
|
||||
code: "test-code",
|
||||
state: "test-state",
|
||||
setupAuth: func(a *fakeAuthService) {
|
||||
a.shouldFailCallback = true
|
||||
a.callbackError = models.ErrDomainNotAllowed
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "set user fails",
|
||||
code: "test-code",
|
||||
state: "test-state",
|
||||
setupAuth: func(a *fakeAuthService) {
|
||||
a.shouldFailSetUser = true
|
||||
a.setUserError = errors.New("session error")
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
authService := newFakeAuthService()
|
||||
tt.setupAuth(authService)
|
||||
handlers := NewAuthHandlers(authService, "https://example.com")
|
||||
|
||||
req := httptest.NewRequest("GET", "/oauth2/callback", nil)
|
||||
q := req.URL.Query()
|
||||
if tt.code != "" {
|
||||
q.Set("code", tt.code)
|
||||
}
|
||||
if tt.state != "" {
|
||||
q.Set("state", tt.state)
|
||||
}
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handlers.HandleOAuthCallback(w, req, nil)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
if tt.expectedRedirect != "" {
|
||||
location := w.Header().Get("Location")
|
||||
if location != tt.expectedRedirect {
|
||||
t.Errorf("Expected redirect to %s, got %s", tt.expectedRedirect, location)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignatureHandlers_NewSignatureHandlers(t *testing.T) {
|
||||
signatureService := newFakeSignatureService()
|
||||
userService := newFakeUserService()
|
||||
tmpl := createTestTemplate()
|
||||
baseURL := "https://example.com"
|
||||
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, baseURL)
|
||||
|
||||
if handlers == nil {
|
||||
t.Error("NewSignatureHandlers should not return nil")
|
||||
} else if handlers.signatureService != signatureService {
|
||||
t.Error("SignatureService not set correctly")
|
||||
} else if handlers.userService != userService {
|
||||
t.Error("UserService not set correctly")
|
||||
} else if handlers.template != tmpl {
|
||||
t.Error("Template not set correctly")
|
||||
} else if handlers.baseURL != baseURL {
|
||||
t.Error("BaseURL not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignatureHandlers_HandleIndex(t *testing.T) {
|
||||
signatureService := newFakeSignatureService()
|
||||
userService := newFakeUserService()
|
||||
tmpl := createTestTemplate()
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, "https://example.com")
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handlers.HandleIndex(w, req, nil)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if !strings.Contains(contentType, "text/html") {
|
||||
t.Errorf("Expected HTML content type, got %s", contentType)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, "index") {
|
||||
t.Error("Response should contain template name 'index'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignatureHandlers_HandleSignGET(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
docParam string
|
||||
setupUser func(*fakeUserService)
|
||||
setupSig func(*fakeSignatureService)
|
||||
expectedStatus int
|
||||
shouldRedirect bool
|
||||
}{
|
||||
{
|
||||
name: "successful sign page load - not signed",
|
||||
docParam: "test-doc",
|
||||
setupUser: func(u *fakeUserService) {},
|
||||
setupSig: func(s *fakeSignatureService) {
|
||||
s.statusResult.IsSigned = false
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "successful sign page load - already signed",
|
||||
docParam: "test-doc",
|
||||
setupUser: func(u *fakeUserService) {},
|
||||
setupSig: func(s *fakeSignatureService) {
|
||||
s.statusResult.IsSigned = true
|
||||
signedAt := time.Now().UTC()
|
||||
s.statusResult.SignedAt = &signedAt
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "user not authenticated",
|
||||
docParam: "test-doc",
|
||||
setupUser: func(u *fakeUserService) {
|
||||
u.shouldFail = true
|
||||
u.getUserError = models.ErrUnauthorized
|
||||
},
|
||||
setupSig: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "missing doc parameter",
|
||||
docParam: "",
|
||||
setupUser: func(u *fakeUserService) {},
|
||||
setupSig: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusFound,
|
||||
shouldRedirect: true,
|
||||
},
|
||||
{
|
||||
name: "signature service fails",
|
||||
docParam: "test-doc",
|
||||
setupUser: func(u *fakeUserService) {},
|
||||
setupSig: func(s *fakeSignatureService) {
|
||||
s.shouldFailGetStatus = true
|
||||
s.getStatusError = errors.New("service error")
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
signatureService := newFakeSignatureService()
|
||||
userService := newFakeUserService()
|
||||
tt.setupUser(userService)
|
||||
tt.setupSig(signatureService)
|
||||
|
||||
tmpl := createTestTemplate()
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, "https://example.com")
|
||||
|
||||
req := httptest.NewRequest("GET", "/sign", nil)
|
||||
if tt.docParam != "" {
|
||||
q := req.URL.Query()
|
||||
q.Set("doc", tt.docParam)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handlers.HandleSignGET(w, req, nil)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
if tt.shouldRedirect {
|
||||
location := w.Header().Get("Location")
|
||||
if location == "" {
|
||||
t.Error("Expected redirect but no Location header found")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignatureHandlers_HandleSignPOST(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
formData map[string]string
|
||||
setupUser func(*fakeUserService)
|
||||
setupSig func(*fakeSignatureService)
|
||||
expectedStatus int
|
||||
shouldRedirect bool
|
||||
}{
|
||||
{
|
||||
name: "successful signature creation",
|
||||
formData: map[string]string{
|
||||
"doc": "test-doc",
|
||||
},
|
||||
setupUser: func(u *fakeUserService) {},
|
||||
setupSig: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusFound,
|
||||
shouldRedirect: true,
|
||||
},
|
||||
{
|
||||
name: "signature already exists",
|
||||
formData: map[string]string{
|
||||
"doc": "test-doc",
|
||||
},
|
||||
setupUser: func(u *fakeUserService) {},
|
||||
setupSig: func(s *fakeSignatureService) {
|
||||
s.shouldFailCreate = true
|
||||
s.createError = models.ErrSignatureAlreadyExists
|
||||
},
|
||||
expectedStatus: http.StatusFound,
|
||||
shouldRedirect: true,
|
||||
},
|
||||
{
|
||||
name: "user not authenticated",
|
||||
formData: map[string]string{
|
||||
"doc": "test-doc",
|
||||
},
|
||||
setupUser: func(u *fakeUserService) {
|
||||
u.shouldFail = true
|
||||
u.getUserError = models.ErrUnauthorized
|
||||
},
|
||||
setupSig: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusFound,
|
||||
shouldRedirect: true,
|
||||
},
|
||||
{
|
||||
name: "missing doc parameter",
|
||||
formData: map[string]string{},
|
||||
setupUser: func(u *fakeUserService) {},
|
||||
setupSig: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "signature service fails",
|
||||
formData: map[string]string{
|
||||
"doc": "test-doc",
|
||||
},
|
||||
setupUser: func(u *fakeUserService) {},
|
||||
setupSig: func(s *fakeSignatureService) {
|
||||
s.shouldFailCreate = true
|
||||
s.createError = errors.New("service error")
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
signatureService := newFakeSignatureService()
|
||||
userService := newFakeUserService()
|
||||
tt.setupUser(userService)
|
||||
tt.setupSig(signatureService)
|
||||
|
||||
tmpl := createTestTemplate()
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, "https://example.com")
|
||||
|
||||
// Create form data
|
||||
form := url.Values{}
|
||||
for key, value := range tt.formData {
|
||||
form.Set(key, value)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/sign", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handlers.HandleSignPOST(w, req, nil)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
if tt.shouldRedirect {
|
||||
location := w.Header().Get("Location")
|
||||
if location == "" {
|
||||
t.Error("Expected redirect but no Location header found")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignatureHandlers_HandleStatusJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
docParam string
|
||||
setupSig func(*fakeSignatureService)
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "successful status JSON",
|
||||
docParam: "test-doc",
|
||||
setupSig: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "missing doc parameter",
|
||||
docParam: "",
|
||||
setupSig: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "service fails",
|
||||
docParam: "test-doc",
|
||||
setupSig: func(s *fakeSignatureService) {
|
||||
s.shouldFailGetDoc = true
|
||||
s.getDocError = errors.New("service error")
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
signatureService := newFakeSignatureService()
|
||||
userService := newFakeUserService()
|
||||
tt.setupSig(signatureService)
|
||||
|
||||
tmpl := createTestTemplate()
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, "https://example.com")
|
||||
|
||||
req := httptest.NewRequest("GET", "/status", nil)
|
||||
if tt.docParam != "" {
|
||||
q := req.URL.Query()
|
||||
q.Set("doc", tt.docParam)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handlers.HandleStatusJSON(w, req, nil)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
if tt.expectedStatus == http.StatusOK {
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if !strings.Contains(contentType, "application/json") {
|
||||
t.Errorf("Expected JSON content type, got %s", contentType)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignatureHandlers_HandleUserSignatures(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupUser func(*fakeUserService)
|
||||
setupSig func(*fakeSignatureService)
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "successful user signatures",
|
||||
setupUser: func(u *fakeUserService) {},
|
||||
setupSig: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "user not authenticated",
|
||||
setupUser: func(u *fakeUserService) {
|
||||
u.shouldFail = true
|
||||
u.getUserError = models.ErrUnauthorized
|
||||
},
|
||||
setupSig: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "service fails",
|
||||
setupUser: func(u *fakeUserService) {},
|
||||
setupSig: func(s *fakeSignatureService) {
|
||||
s.shouldFailGetUser = true
|
||||
s.getUserError = errors.New("service error")
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
signatureService := newFakeSignatureService()
|
||||
userService := newFakeUserService()
|
||||
tt.setupUser(userService)
|
||||
tt.setupSig(signatureService)
|
||||
|
||||
tmpl := createTestTemplate()
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, "https://example.com")
|
||||
|
||||
req := httptest.NewRequest("GET", "/signatures", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handlers.HandleUserSignatures(w, req, nil)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
if tt.expectedStatus == http.StatusOK {
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if !strings.Contains(contentType, "text/html") {
|
||||
t.Errorf("Expected HTML content type, got %s", contentType)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,718 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
)
|
||||
|
||||
// Badge Handler Tests
|
||||
|
||||
func TestBadgeHandler_NewBadgeHandler(t *testing.T) {
|
||||
checkService := newFakeSignatureService()
|
||||
handler := NewBadgeHandler(checkService)
|
||||
|
||||
if handler == nil {
|
||||
t.Error("NewBadgeHandler should not return nil")
|
||||
} else if handler.checkService != checkService {
|
||||
t.Error("CheckService not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadgeHandler_HandleStatusPNG(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
docParam string
|
||||
userParam string
|
||||
setupService func(*fakeSignatureService)
|
||||
expectedStatus int
|
||||
expectedType string
|
||||
}{
|
||||
{
|
||||
name: "successful badge - signed",
|
||||
docParam: "test-doc",
|
||||
userParam: "test@example.com",
|
||||
setupService: func(s *fakeSignatureService) { s.checkResult = true },
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedType: "image/png",
|
||||
},
|
||||
{
|
||||
name: "successful badge - not signed",
|
||||
docParam: "test-doc",
|
||||
userParam: "test@example.com",
|
||||
setupService: func(s *fakeSignatureService) { s.checkResult = false },
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedType: "image/png",
|
||||
},
|
||||
{
|
||||
name: "missing doc parameter",
|
||||
docParam: "",
|
||||
userParam: "test@example.com",
|
||||
setupService: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "missing user parameter",
|
||||
docParam: "test-doc",
|
||||
userParam: "",
|
||||
setupService: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "service fails",
|
||||
docParam: "test-doc",
|
||||
userParam: "test@example.com",
|
||||
setupService: func(s *fakeSignatureService) {
|
||||
s.shouldFailCheck = true
|
||||
s.checkError = errors.New("service error")
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
service := newFakeSignatureService()
|
||||
tt.setupService(service)
|
||||
handler := NewBadgeHandler(service)
|
||||
|
||||
req := httptest.NewRequest("GET", "/status.png", nil)
|
||||
q := req.URL.Query()
|
||||
if tt.docParam != "" {
|
||||
q.Set("doc", tt.docParam)
|
||||
}
|
||||
if tt.userParam != "" {
|
||||
q.Set("user", tt.userParam)
|
||||
}
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.HandleStatusPNG(w, req, nil)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
if tt.expectedType != "" {
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if contentType != tt.expectedType {
|
||||
t.Errorf("Expected content type %s, got %s", tt.expectedType, contentType)
|
||||
}
|
||||
|
||||
cacheControl := w.Header().Get("Cache-Control")
|
||||
if cacheControl != "no-store" {
|
||||
t.Errorf("Expected Cache-Control: no-store, got %s", cacheControl)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Health Handler Tests
|
||||
|
||||
func TestHealthHandler_NewHealthHandler(t *testing.T) {
|
||||
handler := NewHealthHandler()
|
||||
if handler == nil {
|
||||
t.Error("NewHealthHandler should not return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthHandler_HandleHealth(t *testing.T) {
|
||||
handler := NewHealthHandler()
|
||||
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.HandleHealth(w, req, nil)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if !strings.Contains(contentType, "application/json") {
|
||||
t.Errorf("Expected JSON content type, got %s", contentType)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, `"ok":true`) {
|
||||
t.Error("Response should contain ok:true")
|
||||
}
|
||||
if !strings.Contains(body, `"time"`) {
|
||||
t.Error("Response should contain time field")
|
||||
}
|
||||
}
|
||||
|
||||
// OEmbed Handler Tests
|
||||
|
||||
func TestOEmbedHandler_NewOEmbedHandler(t *testing.T) {
|
||||
service := newFakeSignatureService()
|
||||
tmpl := createTestTemplate()
|
||||
baseURL := "https://example.com"
|
||||
org := "Test Org"
|
||||
|
||||
handler := NewOEmbedHandler(service, tmpl, baseURL, org)
|
||||
|
||||
if handler == nil {
|
||||
t.Error("NewOEmbedHandler should not return nil")
|
||||
} else if handler.signatureService != service {
|
||||
t.Error("SignatureService not set correctly")
|
||||
} else if handler.template != tmpl {
|
||||
t.Error("Template not set correctly")
|
||||
} else if handler.baseURL != baseURL {
|
||||
t.Error("BaseURL not set correctly")
|
||||
} else if handler.organisation != org {
|
||||
t.Error("Organisation not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOEmbedHandler_HandleOEmbed(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
urlParam string
|
||||
formatParam string
|
||||
setupService func(*fakeSignatureService)
|
||||
expectedStatus int
|
||||
expectedType string
|
||||
}{
|
||||
{
|
||||
name: "successful oembed",
|
||||
urlParam: "https://example.com/embed?doc=test-doc",
|
||||
formatParam: "json",
|
||||
setupService: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedType: "application/json",
|
||||
},
|
||||
{
|
||||
name: "default format (json)",
|
||||
urlParam: "https://example.com/embed?doc=test-doc",
|
||||
formatParam: "",
|
||||
setupService: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedType: "application/json",
|
||||
},
|
||||
{
|
||||
name: "unsupported format",
|
||||
urlParam: "https://example.com/embed?doc=test-doc",
|
||||
formatParam: "xml",
|
||||
setupService: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusNotImplemented,
|
||||
},
|
||||
{
|
||||
name: "missing url parameter",
|
||||
urlParam: "",
|
||||
formatParam: "json",
|
||||
setupService: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invalid url format",
|
||||
urlParam: "https://example.com/embed",
|
||||
formatParam: "json",
|
||||
setupService: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "service fails",
|
||||
urlParam: "https://example.com/embed?doc=test-doc",
|
||||
formatParam: "json",
|
||||
setupService: func(s *fakeSignatureService) {
|
||||
s.shouldFailGetDoc = true
|
||||
s.getDocError = errors.New("service error")
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
service := newFakeSignatureService()
|
||||
tt.setupService(service)
|
||||
tmpl := createTestTemplate()
|
||||
// Add embed template for OEmbed tests
|
||||
template.Must(tmpl.New("embed").Parse(`<div>{{.DocID}} - {{.Count}} signatures</div>`))
|
||||
handler := NewOEmbedHandler(service, tmpl, "https://example.com", "Test Org")
|
||||
|
||||
req := httptest.NewRequest("GET", "/oembed", nil)
|
||||
q := req.URL.Query()
|
||||
if tt.urlParam != "" {
|
||||
q.Set("url", tt.urlParam)
|
||||
}
|
||||
if tt.formatParam != "" {
|
||||
q.Set("format", tt.formatParam)
|
||||
}
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.HandleOEmbed(w, req, nil)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
if tt.expectedType != "" {
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if !strings.Contains(contentType, tt.expectedType) {
|
||||
t.Errorf("Expected content type %s, got %s", tt.expectedType, contentType)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOEmbedHandler_HandleEmbedView(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
docParam string
|
||||
setupService func(*fakeSignatureService)
|
||||
expectedStatus int
|
||||
expectedType string
|
||||
}{
|
||||
{
|
||||
name: "successful embed view",
|
||||
docParam: "test-doc",
|
||||
setupService: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedType: "text/html",
|
||||
},
|
||||
{
|
||||
name: "missing doc parameter",
|
||||
docParam: "",
|
||||
setupService: func(s *fakeSignatureService) {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "service fails",
|
||||
docParam: "test-doc",
|
||||
setupService: func(s *fakeSignatureService) {
|
||||
s.shouldFailGetDoc = true
|
||||
s.getDocError = errors.New("service error")
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
service := newFakeSignatureService()
|
||||
tt.setupService(service)
|
||||
tmpl := createTestTemplate()
|
||||
// Add embed template
|
||||
template.Must(tmpl.New("embed").Parse(`<div>{{.DocID}} - {{.Count}} signatures</div>`))
|
||||
|
||||
handler := NewOEmbedHandler(service, tmpl, "https://example.com", "Test Org")
|
||||
|
||||
req := httptest.NewRequest("GET", "/embed", nil)
|
||||
if tt.docParam != "" {
|
||||
q := req.URL.Query()
|
||||
q.Set("doc", tt.docParam)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.HandleEmbedView(w, req, nil)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
if tt.expectedType != "" {
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if !strings.Contains(contentType, tt.expectedType) {
|
||||
t.Errorf("Expected content type %s, got %s", tt.expectedType, contentType)
|
||||
}
|
||||
|
||||
frameOptions := w.Header().Get("X-Frame-Options")
|
||||
if frameOptions != "ALLOWALL" {
|
||||
t.Errorf("Expected X-Frame-Options: ALLOWALL, got %s", frameOptions)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOEmbedHandler_extractDocIDFromURL(t *testing.T) {
|
||||
handler := &OEmbedHandler{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
expected string
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
name: "extract from query parameter",
|
||||
url: "https://example.com/embed?doc=test-doc",
|
||||
expected: "test-doc",
|
||||
},
|
||||
{
|
||||
name: "extract from embed path",
|
||||
url: "https://example.com/embed/test-doc",
|
||||
expected: "test-doc",
|
||||
},
|
||||
{
|
||||
name: "extract from status path",
|
||||
url: "https://example.com/status/test-doc",
|
||||
expected: "test-doc",
|
||||
},
|
||||
{
|
||||
name: "extract from sign path",
|
||||
url: "https://example.com/sign/test-doc",
|
||||
expected: "test-doc",
|
||||
},
|
||||
{
|
||||
name: "invalid url",
|
||||
url: "not-a-url",
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "no doc id found",
|
||||
url: "https://example.com/other",
|
||||
shouldErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := handler.extractDocIDFromURL(tt.url)
|
||||
|
||||
if tt.shouldErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware Tests
|
||||
|
||||
func TestAuthMiddleware_NewAuthMiddleware(t *testing.T) {
|
||||
userService := newFakeUserService()
|
||||
baseURL := "https://example.com"
|
||||
|
||||
middleware := NewAuthMiddleware(userService, baseURL)
|
||||
|
||||
if middleware == nil {
|
||||
t.Error("NewAuthMiddleware should not return nil")
|
||||
} else if middleware.userService != userService {
|
||||
t.Error("UserService not set correctly")
|
||||
} else if middleware.baseURL != baseURL {
|
||||
t.Error("BaseURL not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_RequireAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupUser func(*fakeUserService)
|
||||
expectedStatus int
|
||||
shouldRedirect bool
|
||||
}{
|
||||
{
|
||||
name: "authenticated user",
|
||||
setupUser: func(u *fakeUserService) {},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "unauthenticated user",
|
||||
setupUser: func(u *fakeUserService) {
|
||||
u.shouldFail = true
|
||||
u.getUserError = models.ErrUnauthorized
|
||||
},
|
||||
expectedStatus: http.StatusFound,
|
||||
shouldRedirect: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
userService := newFakeUserService()
|
||||
tt.setupUser(userService)
|
||||
middleware := NewAuthMiddleware(userService, "https://example.com")
|
||||
|
||||
// Create a test handler that returns 200 OK
|
||||
testHandler := func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("OK"))
|
||||
}
|
||||
|
||||
wrappedHandler := middleware.RequireAuth(testHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler(w, req, nil)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
if tt.shouldRedirect {
|
||||
location := w.Header().Get("Location")
|
||||
if location == "" {
|
||||
t.Error("Expected redirect but no Location header found")
|
||||
}
|
||||
if !strings.Contains(location, "/login") {
|
||||
t.Error("Expected redirect to login page")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureHeaders(t *testing.T) {
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
wrapped := SecureHeaders(nextHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
wrapped.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
// Check security headers
|
||||
headers := map[string]string{
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-Frame-Options": "DENY",
|
||||
"Referrer-Policy": "no-referrer",
|
||||
"Content-Security-Policy": "default-src 'self'",
|
||||
}
|
||||
|
||||
for header, expectedValue := range headers {
|
||||
actualValue := w.Header().Get(header)
|
||||
if !strings.Contains(actualValue, expectedValue) {
|
||||
t.Errorf("Expected header %s to contain %s, got %s", header, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expectedStatus int
|
||||
expectedText string
|
||||
}{
|
||||
{
|
||||
name: "unauthorized error",
|
||||
err: models.ErrUnauthorized,
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedText: "Unauthorized",
|
||||
},
|
||||
{
|
||||
name: "signature not found error",
|
||||
err: models.ErrSignatureNotFound,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectedText: "Signature not found",
|
||||
},
|
||||
{
|
||||
name: "signature already exists error",
|
||||
err: models.ErrSignatureAlreadyExists,
|
||||
expectedStatus: http.StatusConflict,
|
||||
expectedText: "Signature already exists",
|
||||
},
|
||||
{
|
||||
name: "invalid user error",
|
||||
err: models.ErrInvalidUser,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedText: "Invalid user",
|
||||
},
|
||||
{
|
||||
name: "invalid document error",
|
||||
err: models.ErrInvalidDocument,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedText: "Invalid document ID",
|
||||
},
|
||||
{
|
||||
name: "domain not allowed error",
|
||||
err: models.ErrDomainNotAllowed,
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedText: "Domain not allowed",
|
||||
},
|
||||
{
|
||||
name: "database connection error",
|
||||
err: models.ErrDatabaseConnection,
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedText: "Database error",
|
||||
},
|
||||
{
|
||||
name: "unknown error",
|
||||
err: errors.New("unknown error"),
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedText: "Internal server error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
HandleError(w, tt.err)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
body := strings.TrimSpace(w.Body.String())
|
||||
if !strings.Contains(body, tt.expectedText) {
|
||||
t.Errorf("Expected body to contain %s, got %s", tt.expectedText, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Utility function tests
|
||||
|
||||
func TestValidateDocID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupReq func() *http.Request
|
||||
expected string
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
name: "from query parameter",
|
||||
setupReq: func() *http.Request {
|
||||
req := httptest.NewRequest("GET", "/test?doc=test-doc", nil)
|
||||
return req
|
||||
},
|
||||
expected: "test-doc",
|
||||
},
|
||||
{
|
||||
name: "from form value",
|
||||
setupReq: func() *http.Request {
|
||||
req := httptest.NewRequest("POST", "/test", strings.NewReader("doc=test-doc"))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
return req
|
||||
},
|
||||
expected: "test-doc",
|
||||
},
|
||||
{
|
||||
name: "trimmed whitespace",
|
||||
setupReq: func() *http.Request {
|
||||
req := httptest.NewRequest("GET", "/test?doc=%20test-doc%20", nil)
|
||||
return req
|
||||
},
|
||||
expected: "test-doc",
|
||||
},
|
||||
{
|
||||
name: "missing doc parameter",
|
||||
setupReq: func() *http.Request {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
return req
|
||||
},
|
||||
shouldErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := tt.setupReq()
|
||||
result, err := validateDocID(req)
|
||||
|
||||
if tt.shouldErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSignURL(t *testing.T) {
|
||||
result := buildSignURL("https://example.com", "test doc")
|
||||
expected := "https://example.com/sign?doc=test+doc"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected %s, got %s", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildLoginURL(t *testing.T) {
|
||||
result := buildLoginURL("https://example.com/sign?doc=test")
|
||||
expected := "/login?next=" + url.QueryEscape("https://example.com/sign?doc=test")
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected %s, got %s", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUserIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userParam string
|
||||
expected string
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid user identifier",
|
||||
userParam: "test@example.com",
|
||||
expected: "test@example.com",
|
||||
},
|
||||
{
|
||||
name: "trimmed whitespace",
|
||||
userParam: " test@example.com ",
|
||||
expected: "test@example.com",
|
||||
},
|
||||
{
|
||||
name: "missing user parameter",
|
||||
userParam: "",
|
||||
shouldErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
if tt.userParam != "" {
|
||||
q := req.URL.Query()
|
||||
q.Set("user", tt.userParam)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
result, err := validateUserIdentifier(req)
|
||||
|
||||
if tt.shouldErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
)
|
||||
|
||||
// HealthHandler handles health check requests
|
||||
type HealthHandler struct{}
|
||||
|
||||
// NewHealthHandler creates a new health handler
|
||||
func NewHealthHandler() *HealthHandler {
|
||||
return &HealthHandler{}
|
||||
}
|
||||
|
||||
// HealthResponse represents a health check response
|
||||
type HealthResponse struct {
|
||||
OK bool `json:"ok"`
|
||||
Time time.Time `json:"time"`
|
||||
}
|
||||
|
||||
// HandleHealth returns the application health status
|
||||
func (h *HealthHandler) HandleHealth(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
response := HealthResponse{
|
||||
OK: true,
|
||||
Time: time.Now().UTC(),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"ackify/internal/domain/models"
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type authService interface {
|
||||
SetUser(w http.ResponseWriter, r *http.Request, user *models.User) error
|
||||
Logout(w http.ResponseWriter, r *http.Request)
|
||||
GetAuthURL(nextURL string) string
|
||||
HandleCallback(ctx context.Context, code, state string) (*models.User, string, error)
|
||||
}
|
||||
|
||||
type userService interface {
|
||||
GetUser(r *http.Request) (*models.User, error)
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
)
|
||||
|
||||
// AuthMiddleware provides authentication middleware
|
||||
type AuthMiddleware struct {
|
||||
userService userService
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// NewAuthMiddleware creates a new auth middleware
|
||||
func NewAuthMiddleware(userService userService, baseURL string) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
userService: userService,
|
||||
baseURL: baseURL,
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAuth wraps a handler to require authentication
|
||||
func (m *AuthMiddleware) RequireAuth(next httprouter.Handle) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
||||
_, err := m.userService.GetUser(r)
|
||||
if err != nil {
|
||||
nextURL := m.baseURL + r.URL.RequestURI()
|
||||
loginURL := buildLoginURL(nextURL)
|
||||
http.Redirect(w, r, loginURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
next(w, r, ps)
|
||||
}
|
||||
}
|
||||
|
||||
// SecureHeaders middleware adds security headers with default configuration
|
||||
func SecureHeaders(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
w.Header().Set("Referrer-Policy", "no-referrer")
|
||||
w.Header().Set("Content-Security-Policy",
|
||||
"default-src 'self'; style-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; "+
|
||||
"script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; "+
|
||||
"img-src 'self' data: https://cdn.simpleicons.org; connect-src 'self'")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// ErrorResponse represents an error response
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// HandleError handles different types of errors and returns appropriate HTTP responses
|
||||
func HandleError(w http.ResponseWriter, err error) {
|
||||
switch {
|
||||
case errors.Is(err, models.ErrUnauthorized):
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
case errors.Is(err, models.ErrSignatureNotFound):
|
||||
http.Error(w, "Signature not found", http.StatusNotFound)
|
||||
case errors.Is(err, models.ErrSignatureAlreadyExists):
|
||||
http.Error(w, "Signature already exists", http.StatusConflict)
|
||||
case errors.Is(err, models.ErrInvalidUser):
|
||||
http.Error(w, "Invalid user", http.StatusBadRequest)
|
||||
case errors.Is(err, models.ErrInvalidDocument):
|
||||
http.Error(w, "Invalid document ID", http.StatusBadRequest)
|
||||
case errors.Is(err, models.ErrDomainNotAllowed):
|
||||
http.Error(w, "Domain not allowed", http.StatusForbidden)
|
||||
case errors.Is(err, models.ErrDatabaseConnection):
|
||||
http.Error(w, "Database error", http.StatusInternalServerError)
|
||||
default:
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,252 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
)
|
||||
|
||||
// OEmbedHandler handles oEmbed requests
|
||||
type OEmbedHandler struct {
|
||||
signatureService signatureService
|
||||
template *template.Template
|
||||
baseURL string
|
||||
organisation string
|
||||
}
|
||||
|
||||
// NewOEmbedHandler creates a new oEmbed handler
|
||||
func NewOEmbedHandler(signatureService signatureService, tmpl *template.Template, baseURL, organisation string) *OEmbedHandler {
|
||||
return &OEmbedHandler{
|
||||
signatureService: signatureService,
|
||||
template: tmpl,
|
||||
baseURL: baseURL,
|
||||
organisation: organisation,
|
||||
}
|
||||
}
|
||||
|
||||
// OEmbedResponse represents the oEmbed JSON response format
|
||||
type OEmbedResponse struct {
|
||||
Type string `json:"type"`
|
||||
Version string `json:"version"`
|
||||
Title string `json:"title"`
|
||||
AuthorName string `json:"author_name,omitempty"`
|
||||
AuthorURL string `json:"author_url,omitempty"`
|
||||
ProviderName string `json:"provider_name"`
|
||||
ProviderURL string `json:"provider_url"`
|
||||
CacheAge int `json:"cache_age,omitempty"`
|
||||
HTML string `json:"html"`
|
||||
Width int `json:"width,omitempty"`
|
||||
Height int `json:"height,omitempty"`
|
||||
}
|
||||
|
||||
// SignatoryData represents data for rendering signatories
|
||||
type SignatoryData struct {
|
||||
DocID string
|
||||
Signatures []SignatoryInfo
|
||||
Count int
|
||||
LastSignedAt string
|
||||
EmbedURL string
|
||||
SignURL string
|
||||
}
|
||||
|
||||
// SignatoryInfo represents a signatory's information
|
||||
type SignatoryInfo struct {
|
||||
Name string
|
||||
Email string
|
||||
SignedAt string
|
||||
}
|
||||
|
||||
// HandleOEmbed handles oEmbed requests for signature lists
|
||||
func (h *OEmbedHandler) HandleOEmbed(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
// Parse query parameters
|
||||
targetURL := r.URL.Query().Get("url")
|
||||
format := r.URL.Query().Get("format")
|
||||
maxWidth := r.URL.Query().Get("maxwidth")
|
||||
maxHeight := r.URL.Query().Get("maxheight")
|
||||
|
||||
if targetURL == "" {
|
||||
HandleError(w, models.ErrInvalidDocument)
|
||||
return
|
||||
}
|
||||
|
||||
// Default format is JSON
|
||||
if format == "" {
|
||||
format = "json"
|
||||
}
|
||||
|
||||
// Only support JSON format for now
|
||||
if format != "json" {
|
||||
http.Error(w, "Only JSON format is supported", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract document ID from URL
|
||||
docID, err := h.extractDocIDFromURL(targetURL)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid URL format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Get signatures for the document
|
||||
ctx := r.Context()
|
||||
signatures, err := h.signatureService.GetDocumentSignatures(ctx, docID)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to retrieve signatures", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to signatory info
|
||||
signatories := make([]SignatoryInfo, len(signatures))
|
||||
var lastSignedAt string
|
||||
for i, sig := range signatures {
|
||||
name := ""
|
||||
if sig.UserName != nil {
|
||||
name = *sig.UserName
|
||||
}
|
||||
signatories[i] = SignatoryInfo{
|
||||
Name: name,
|
||||
Email: sig.UserEmail,
|
||||
SignedAt: sig.SignedAtUTC.Format("02/01/2006 à 15:04"),
|
||||
}
|
||||
if i == 0 { // First signature (most recent due to ORDER BY in repository)
|
||||
lastSignedAt = signatories[i].SignedAt
|
||||
}
|
||||
}
|
||||
|
||||
// Render embedded HTML
|
||||
embedHTML, err := h.renderEmbeddedHTML(SignatoryData{
|
||||
DocID: docID,
|
||||
Signatures: signatories,
|
||||
Count: len(signatories),
|
||||
LastSignedAt: lastSignedAt,
|
||||
EmbedURL: targetURL,
|
||||
SignURL: fmt.Sprintf("%s/sign?doc=%s", h.baseURL, url.QueryEscape(docID)),
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to render embedded content", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse dimensions
|
||||
width := 480 // Default width
|
||||
height := 320 // Default height
|
||||
|
||||
if maxWidth != "" {
|
||||
if w, err := strconv.Atoi(maxWidth); err == nil && w > 0 && w < 2000 {
|
||||
width = w
|
||||
}
|
||||
}
|
||||
|
||||
if maxHeight != "" {
|
||||
if h, err := strconv.Atoi(maxHeight); err == nil && h > 0 && h < 2000 {
|
||||
height = h
|
||||
}
|
||||
}
|
||||
|
||||
// Create oEmbed response
|
||||
response := OEmbedResponse{
|
||||
Type: "rich",
|
||||
Version: "1.0",
|
||||
Title: fmt.Sprintf("Signataires du document %s", docID),
|
||||
AuthorName: h.organisation,
|
||||
AuthorURL: h.baseURL,
|
||||
ProviderName: "Service de validation de lecture",
|
||||
ProviderURL: h.baseURL,
|
||||
CacheAge: 3600, // Cache for 1 hour
|
||||
HTML: embedHTML,
|
||||
Width: width,
|
||||
Height: height,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// HandleEmbedView handles direct embed view requests
|
||||
func (h *OEmbedHandler) HandleEmbedView(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
docID := strings.TrimSpace(r.URL.Query().Get("doc"))
|
||||
if docID == "" {
|
||||
http.Error(w, "Missing document ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Get signatures for the document
|
||||
ctx := r.Context()
|
||||
signatures, err := h.signatureService.GetDocumentSignatures(ctx, docID)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to retrieve signatures", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to signatory info
|
||||
signatories := make([]SignatoryInfo, len(signatures))
|
||||
var lastSignedAt string
|
||||
for i, sig := range signatures {
|
||||
name := ""
|
||||
if sig.UserName != nil {
|
||||
name = *sig.UserName
|
||||
}
|
||||
signatories[i] = SignatoryInfo{
|
||||
Name: name,
|
||||
Email: sig.UserEmail,
|
||||
SignedAt: sig.SignedAtUTC.Format("02/01/2006 à 15:04"),
|
||||
}
|
||||
if i == 0 {
|
||||
lastSignedAt = signatories[i].SignedAt
|
||||
}
|
||||
}
|
||||
|
||||
data := SignatoryData{
|
||||
DocID: docID,
|
||||
Signatures: signatories,
|
||||
Count: len(signatories),
|
||||
LastSignedAt: lastSignedAt,
|
||||
EmbedURL: fmt.Sprintf("%s/embed?doc=%s", h.baseURL, url.QueryEscape(docID)),
|
||||
SignURL: fmt.Sprintf("%s/sign?doc=%s", h.baseURL, url.QueryEscape(docID)),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.Header().Set("X-Frame-Options", "ALLOWALL") // Allow embedding in iframes
|
||||
|
||||
if err := h.template.ExecuteTemplate(w, "embed", data); err != nil {
|
||||
http.Error(w, "Failed to render template", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// extractDocIDFromURL extracts document ID from various URL formats
|
||||
func (h *OEmbedHandler) extractDocIDFromURL(targetURL string) (string, error) {
|
||||
parsedURL, err := url.Parse(targetURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Try to extract from query parameter
|
||||
if docID := parsedURL.Query().Get("doc"); docID != "" {
|
||||
return docID, nil
|
||||
}
|
||||
|
||||
// Try to extract from path (e.g., /embed/doc_123 or /status/doc_123)
|
||||
pathParts := strings.Split(strings.Trim(parsedURL.Path, "/"), "/")
|
||||
if len(pathParts) >= 2 && (pathParts[0] == "embed" || pathParts[0] == "status" || pathParts[0] == "sign") {
|
||||
return pathParts[1], nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("could not extract document ID from URL")
|
||||
}
|
||||
|
||||
// renderEmbeddedHTML renders the embedded HTML content
|
||||
func (h *OEmbedHandler) renderEmbeddedHTML(data SignatoryData) (string, error) {
|
||||
var buf strings.Builder
|
||||
if err := h.template.ExecuteTemplate(&buf, "embed", data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
@@ -0,0 +1,298 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
"ackify/pkg/services"
|
||||
)
|
||||
|
||||
type signatureService interface {
|
||||
CreateSignature(ctx context.Context, request *models.SignatureRequest) error
|
||||
GetSignatureStatus(ctx context.Context, docID string, user *models.User) (*models.SignatureStatus, error)
|
||||
GetSignatureByDocAndUser(ctx context.Context, docID string, user *models.User) (*models.Signature, error)
|
||||
GetDocumentSignatures(ctx context.Context, docID string) ([]*models.Signature, error)
|
||||
GetUserSignatures(ctx context.Context, user *models.User) ([]*models.Signature, error)
|
||||
CheckUserSignature(ctx context.Context, docID, userIdentifier string) (bool, error)
|
||||
}
|
||||
|
||||
// SignatureHandlers handles signature-related HTTP requests
|
||||
type SignatureHandlers struct {
|
||||
signatureService signatureService
|
||||
userService userService
|
||||
template *template.Template
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// NewSignatureHandlers creates new signature handlers
|
||||
func NewSignatureHandlers(signatureService signatureService, userService userService, tmpl *template.Template, baseURL string) *SignatureHandlers {
|
||||
return &SignatureHandlers{
|
||||
signatureService: signatureService,
|
||||
userService: userService,
|
||||
template: tmpl,
|
||||
baseURL: baseURL,
|
||||
}
|
||||
}
|
||||
|
||||
// PageData represents data passed to templates
|
||||
type PageData struct {
|
||||
User *models.User
|
||||
Year int
|
||||
DocID string
|
||||
Already bool
|
||||
SignedAt string
|
||||
TemplateName string
|
||||
BaseURL string
|
||||
Signatures []*models.Signature
|
||||
ServiceInfo *struct {
|
||||
Name string
|
||||
Icon string
|
||||
Type string
|
||||
Referrer string
|
||||
}
|
||||
}
|
||||
|
||||
// HandleIndex serves the main index page
|
||||
func (h *SignatureHandlers) HandleIndex(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
user, _ := h.userService.GetUser(r)
|
||||
h.render(w, r, "index", PageData{User: user})
|
||||
}
|
||||
|
||||
// HandleSignGET displays the signature page
|
||||
func (h *SignatureHandlers) HandleSignGET(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
user, err := h.userService.GetUser(r)
|
||||
if err != nil {
|
||||
HandleError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
docID, err := validateDocID(r)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
status, err := h.signatureService.GetSignatureStatus(ctx, docID, user)
|
||||
if err != nil {
|
||||
HandleError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
signedAt := ""
|
||||
var serviceInfo *struct {
|
||||
Name string
|
||||
Icon string
|
||||
Type string
|
||||
Referrer string
|
||||
}
|
||||
|
||||
// First try to get service info from URL parameter (always present when coming from embed)
|
||||
if referrerParam := r.URL.Query().Get("referrer"); referrerParam != "" {
|
||||
if sigServiceInfo := services.DetectServiceFromReferrer(referrerParam); sigServiceInfo != nil {
|
||||
serviceInfo = &struct {
|
||||
Name string
|
||||
Icon string
|
||||
Type string
|
||||
Referrer string
|
||||
}{
|
||||
Name: sigServiceInfo.Name,
|
||||
Icon: sigServiceInfo.Icon,
|
||||
Type: sigServiceInfo.Type,
|
||||
Referrer: sigServiceInfo.Referrer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if status.IsSigned {
|
||||
// Get full signature to access referer information
|
||||
signature, err := h.signatureService.GetSignatureByDocAndUser(ctx, docID, user)
|
||||
if err == nil && signature != nil {
|
||||
if signature.SignedAtUTC.IsZero() == false {
|
||||
signedAt = signature.SignedAtUTC.Format("02/01/2006 à 15:04:05")
|
||||
}
|
||||
|
||||
// If no service info from URL, try to get it from stored signature
|
||||
if serviceInfo == nil && signature.Referer != nil {
|
||||
if sigServiceInfo := signature.GetServiceInfo(); sigServiceInfo != nil {
|
||||
serviceInfo = &struct {
|
||||
Name string
|
||||
Icon string
|
||||
Type string
|
||||
Referrer string
|
||||
}{
|
||||
Name: sigServiceInfo.Name,
|
||||
Icon: sigServiceInfo.Icon,
|
||||
Type: sigServiceInfo.Type,
|
||||
Referrer: sigServiceInfo.Referrer,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if signedAt == "" && status.SignedAt != nil {
|
||||
signedAt = status.SignedAt.Format("02/01/2006 à 15:04:05")
|
||||
}
|
||||
|
||||
h.render(w, r, "sign", PageData{
|
||||
User: user,
|
||||
DocID: docID,
|
||||
Already: status.IsSigned,
|
||||
SignedAt: signedAt,
|
||||
BaseURL: h.baseURL,
|
||||
ServiceInfo: serviceInfo,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleSignPOST processes signature creation
|
||||
func (h *SignatureHandlers) HandleSignPOST(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
user, err := h.userService.GetUser(r)
|
||||
if err != nil {
|
||||
if docID := r.FormValue("doc"); docID != "" {
|
||||
loginURL := buildLoginURL(buildSignURL(h.baseURL, docID))
|
||||
http.Redirect(w, r, loginURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
HandleError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
docID, err := validateDocID(r)
|
||||
if err != nil {
|
||||
HandleError(w, models.ErrInvalidDocument)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
var referer *string
|
||||
if referrerParam := r.FormValue("referrer"); referrerParam != "" {
|
||||
referer = &referrerParam
|
||||
} else if referrerParam := r.URL.Query().Get("referrer"); referrerParam != "" {
|
||||
referer = &referrerParam
|
||||
} else {
|
||||
fmt.Printf("DEBUG: No referrer found in form or URL\n")
|
||||
}
|
||||
|
||||
request := &models.SignatureRequest{
|
||||
DocID: docID,
|
||||
User: user,
|
||||
Referer: referer,
|
||||
}
|
||||
|
||||
err = h.signatureService.CreateSignature(ctx, request)
|
||||
if err != nil {
|
||||
if errors.Is(err, models.ErrSignatureAlreadyExists) {
|
||||
// Redirect to view existing signature
|
||||
http.Redirect(w, r, buildSignURL(h.baseURL, docID), http.StatusFound)
|
||||
return
|
||||
}
|
||||
HandleError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect to view the created signature
|
||||
http.Redirect(w, r, buildSignURL(h.baseURL, docID), http.StatusFound)
|
||||
}
|
||||
|
||||
// HandleStatusJSON returns signature status as JSON
|
||||
func (h *SignatureHandlers) HandleStatusJSON(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
docID, err := validateDocID(r)
|
||||
if err != nil {
|
||||
HandleError(w, models.ErrInvalidDocument)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
signatures, err := h.signatureService.GetDocumentSignatures(ctx, docID)
|
||||
if err != nil {
|
||||
HandleError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to JSON response format
|
||||
response := make([]map[string]interface{}, 0, len(signatures))
|
||||
for _, sig := range signatures {
|
||||
sigData := map[string]interface{}{
|
||||
"id": sig.ID,
|
||||
"doc_id": sig.DocID,
|
||||
"user_sub": sig.UserSub,
|
||||
"user_email": sig.UserEmail,
|
||||
"signed_at_utc": sig.SignedAtUTC,
|
||||
}
|
||||
|
||||
// Add username if available
|
||||
if sig.UserName != nil && *sig.UserName != "" {
|
||||
sigData["user_name"] = *sig.UserName
|
||||
}
|
||||
|
||||
// Add service information if available
|
||||
if serviceInfo := sig.GetServiceInfo(); serviceInfo != nil {
|
||||
sigData["service"] = map[string]interface{}{
|
||||
"name": serviceInfo.Name,
|
||||
"icon": serviceInfo.Icon,
|
||||
"type": serviceInfo.Type,
|
||||
}
|
||||
}
|
||||
|
||||
response = append(response, sigData)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// HandleUserSignatures displays the user's signatures page
|
||||
func (h *SignatureHandlers) HandleUserSignatures(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||||
user, err := h.userService.GetUser(r)
|
||||
if err != nil {
|
||||
HandleError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
signatures, err := h.signatureService.GetUserSignatures(ctx, user)
|
||||
if err != nil {
|
||||
HandleError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
h.render(w, r, "signatures", PageData{User: user, BaseURL: h.baseURL, Signatures: signatures})
|
||||
}
|
||||
|
||||
// render executes template with data
|
||||
func (h *SignatureHandlers) render(w http.ResponseWriter, _ *http.Request, templateName string, data PageData) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
|
||||
if data.Year == 0 {
|
||||
data.Year = time.Now().Year()
|
||||
}
|
||||
if data.TemplateName == "" {
|
||||
data.TemplateName = templateName
|
||||
}
|
||||
|
||||
templateData := map[string]interface{}{
|
||||
"User": data.User,
|
||||
"Year": data.Year,
|
||||
"DocID": data.DocID,
|
||||
"Already": data.Already,
|
||||
"SignedAt": data.SignedAt,
|
||||
"TemplateName": data.TemplateName,
|
||||
"BaseURL": data.BaseURL,
|
||||
"Signatures": data.Signatures,
|
||||
"ServiceInfo": data.ServiceInfo,
|
||||
}
|
||||
|
||||
if err := h.template.ExecuteTemplate(w, "base", templateData); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// validateDocID extracts and validates document ID from request
|
||||
func validateDocID(r *http.Request) (string, error) {
|
||||
var docID string
|
||||
|
||||
// Try query parameter first, then form value
|
||||
docID = strings.TrimSpace(r.URL.Query().Get("doc"))
|
||||
if docID == "" {
|
||||
docID = strings.TrimSpace(r.FormValue("doc"))
|
||||
}
|
||||
|
||||
if docID == "" {
|
||||
return "", fmt.Errorf("missing document ID")
|
||||
}
|
||||
|
||||
return docID, nil
|
||||
}
|
||||
|
||||
// buildSignURL constructs a sign URL with proper escaping
|
||||
func buildSignURL(baseURL, docID string) string {
|
||||
return fmt.Sprintf("%s/sign?doc=%s", baseURL, url.QueryEscape(docID))
|
||||
}
|
||||
|
||||
// buildLoginURL constructs a login URL with next parameter
|
||||
func buildLoginURL(nextURL string) string {
|
||||
return "/login?next=" + url.QueryEscape(nextURL)
|
||||
}
|
||||
|
||||
// validateUserIdentifier extracts and validates user identifier from request
|
||||
func validateUserIdentifier(r *http.Request) (string, error) {
|
||||
userIdentifier := strings.TrimSpace(r.URL.Query().Get("user"))
|
||||
if userIdentifier == "" {
|
||||
return "", fmt.Errorf("missing user parameter")
|
||||
}
|
||||
return userIdentifier, nil
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package templates
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// InitTemplates initializes the HTML templates from files
|
||||
func InitTemplates() (*template.Template, error) {
|
||||
// Get the templates directory path relative to the binary
|
||||
templatesDir := "web/templates"
|
||||
|
||||
// Parse the base template first
|
||||
tmpl, err := template.New("base").ParseFiles(filepath.Join(templatesDir, "base.html.tpl"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse base template: %w", err)
|
||||
}
|
||||
|
||||
// Parse the additional templates
|
||||
additionalTemplates := []string{"index.html.tpl", "sign.html.tpl", "signatures.html.tpl", "embed.html.tpl"}
|
||||
for _, templateFile := range additionalTemplates {
|
||||
_, err = tmpl.ParseFiles(filepath.Join(templatesDir, templateFile))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse template %s: %w", templateFile, err)
|
||||
}
|
||||
}
|
||||
|
||||
return tmpl, nil
|
||||
}
|
||||
@@ -0,0 +1,390 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
)
|
||||
|
||||
// TestCryptoIntegration tests the integrations between signature generation and nonce generation
|
||||
func TestCryptoIntegration(t *testing.T) {
|
||||
t.Run("signature with generated nonce", func(t *testing.T) {
|
||||
signer, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
|
||||
user := testUserAlice
|
||||
docID := "integrations-test-doc"
|
||||
timestamp := time.Now().UTC()
|
||||
|
||||
// Generate a nonce
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create signature with generated nonce
|
||||
hash, sig, err := signer.CreateSignature(docID, user, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, hash)
|
||||
assert.NotEmpty(t, sig)
|
||||
|
||||
// Verify hash is SHA-256
|
||||
hashBytes, err := base64.StdEncoding.DecodeString(hash)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, hashBytes, 32, "Hash should be SHA-256 (32 bytes)")
|
||||
|
||||
// Verify signature is Ed25519
|
||||
sigBytes, err := base64.StdEncoding.DecodeString(sig)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, sigBytes, 64, "Signature should be Ed25519 (64 bytes)")
|
||||
})
|
||||
|
||||
t.Run("different nonces produce different signatures", func(t *testing.T) {
|
||||
signer, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
|
||||
user := testUserBob
|
||||
docID := "nonce-diff-test"
|
||||
timestamp := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Generate two different nonces
|
||||
nonce1, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
nonce2, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, nonce1, nonce2, "Nonces should be different")
|
||||
|
||||
// Create signatures with different nonces
|
||||
hash1, sig1, err := signer.CreateSignature(docID, user, timestamp, nonce1)
|
||||
require.NoError(t, err)
|
||||
|
||||
hash2, sig2, err := signer.CreateSignature(docID, user, timestamp, nonce2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Different nonces should produce different signatures
|
||||
assert.NotEqual(t, hash1, hash2, "Different nonces should produce different hashes")
|
||||
assert.NotEqual(t, sig1, sig2, "Different nonces should produce different signatures")
|
||||
})
|
||||
|
||||
t.Run("replay attack prevention", func(t *testing.T) {
|
||||
signer, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
|
||||
user := testUserCharlie
|
||||
docID := "replay-test-doc"
|
||||
timestamp := time.Now().UTC()
|
||||
|
||||
// Simulate multiple signature attempts for same document
|
||||
signatures := make(map[string]bool)
|
||||
nonces := make(map[string]bool)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
// Generate unique nonce for each attempt
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify nonce is unique
|
||||
assert.False(t, nonces[nonce], "Nonce should be unique for replay protection")
|
||||
nonces[nonce] = true
|
||||
|
||||
// Create signature
|
||||
hash, sig, err := signer.CreateSignature(docID, user, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify signature is unique
|
||||
assert.False(t, signatures[sig], "Signature should be unique due to nonce")
|
||||
signatures[sig] = true
|
||||
|
||||
// All should have different hashes due to nonce
|
||||
assert.NotEmpty(t, hash)
|
||||
assert.NotEmpty(t, sig)
|
||||
}
|
||||
|
||||
assert.Len(t, signatures, 10, "All signatures should be unique")
|
||||
assert.Len(t, nonces, 10, "All nonces should be unique")
|
||||
})
|
||||
}
|
||||
|
||||
// TestSHA256Hashing tests SHA-256 hashing functionality indirectly through signature creation
|
||||
func TestSHA256Hashing(t *testing.T) {
|
||||
signer, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("consistent hashing", func(t *testing.T) {
|
||||
user := testUserAlice
|
||||
docID := "hash-test-doc"
|
||||
timestamp := time.Date(2024, 3, 15, 10, 30, 0, 0, time.UTC)
|
||||
nonce := "consistent-nonce"
|
||||
|
||||
// Create signature multiple times
|
||||
hash1, _, err := signer.CreateSignature(docID, user, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
hash2, _, err := signer.CreateSignature(docID, user, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, hash1, hash2, "Same input should produce same hash")
|
||||
})
|
||||
|
||||
t.Run("hash changes with input changes", func(t *testing.T) {
|
||||
user := testUserBob
|
||||
baseTimestamp := time.Date(2024, 4, 1, 14, 0, 0, 0, time.UTC)
|
||||
baseNonce := "base-nonce"
|
||||
|
||||
// Base signature
|
||||
baseHash, _, err := signer.CreateSignature("base-doc", user, baseTimestamp, baseNonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test different document ID
|
||||
hash1, _, err := signer.CreateSignature("different-doc", user, baseTimestamp, baseNonce)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, baseHash, hash1, "Different docID should produce different hash")
|
||||
|
||||
// Test different user
|
||||
differentUser := testUserCharlie
|
||||
hash2, _, err := signer.CreateSignature("base-doc", differentUser, baseTimestamp, baseNonce)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, baseHash, hash2, "Different user should produce different hash")
|
||||
|
||||
// Test different timestamp
|
||||
differentTime := baseTimestamp.Add(time.Hour)
|
||||
hash3, _, err := signer.CreateSignature("base-doc", user, differentTime, baseNonce)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, baseHash, hash3, "Different timestamp should produce different hash")
|
||||
|
||||
// Test different nonce
|
||||
hash4, _, err := signer.CreateSignature("base-doc", user, baseTimestamp, "different-nonce")
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, baseHash, hash4, "Different nonce should produce different hash")
|
||||
})
|
||||
|
||||
t.Run("hash properties", func(t *testing.T) {
|
||||
user := testUserAlice
|
||||
docID := "props-test"
|
||||
timestamp := time.Now().UTC()
|
||||
nonce := "props-nonce"
|
||||
|
||||
hashB64, _, err := signer.CreateSignature(docID, user, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decode hash
|
||||
hashBytes, err := base64.StdEncoding.DecodeString(hashB64)
|
||||
require.NoError(t, err)
|
||||
|
||||
// SHA-256 properties
|
||||
assert.Len(t, hashBytes, 32, "SHA-256 hash should be 32 bytes")
|
||||
assert.NotEqual(t, make([]byte, 32), hashBytes, "Hash should not be all zeros")
|
||||
|
||||
// Verify it's actually SHA-256 by recreating manually
|
||||
expectedPayload := "doc_id=" + docID + "\n" +
|
||||
"user_sub=" + user.Sub + "\n" +
|
||||
"user_email=" + user.NormalizedEmail() + "\n" +
|
||||
"signed_at_utc=" + timestamp.UTC().Format(time.RFC3339Nano) + "\n" +
|
||||
"nonce=" + nonce + "\n"
|
||||
|
||||
expectedHash := sha256.Sum256([]byte(expectedPayload))
|
||||
expectedHashB64 := base64.StdEncoding.EncodeToString(expectedHash[:])
|
||||
|
||||
assert.Equal(t, expectedHashB64, hashB64, "Hash should match manual SHA-256 calculation")
|
||||
})
|
||||
|
||||
t.Run("avalanche effect", func(t *testing.T) {
|
||||
// Test that small changes in input produce large changes in hash (avalanche effect)
|
||||
user := testUserAlice
|
||||
timestamp := time.Now().UTC()
|
||||
nonce := "avalanche-test"
|
||||
|
||||
// Base hash
|
||||
baseHash, _, err := signer.CreateSignature("testdoc", user, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Change one character in docID
|
||||
modHash, _, err := signer.CreateSignature("testdoC", user, timestamp, nonce) // Changed 'c' to 'C'
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decode both hashes
|
||||
baseBytes, err := base64.StdEncoding.DecodeString(baseHash)
|
||||
require.NoError(t, err)
|
||||
|
||||
modBytes, err := base64.StdEncoding.DecodeString(modHash)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Count different bits (should be approximately 50% for good hash function)
|
||||
differentBits := 0
|
||||
for i := range baseBytes {
|
||||
xor := baseBytes[i] ^ modBytes[i]
|
||||
for xor != 0 {
|
||||
differentBits++
|
||||
xor &= xor - 1 // Clear lowest set bit
|
||||
}
|
||||
}
|
||||
|
||||
// SHA-256 should have good avalanche effect
|
||||
totalBits := len(baseBytes) * 8
|
||||
percentage := float64(differentBits) / float64(totalBits)
|
||||
|
||||
// Should be roughly 50% different bits (allow 30-70% range for single test)
|
||||
assert.Greater(t, percentage, 0.3, "Avalanche effect should change at least 30%% of bits")
|
||||
assert.Less(t, percentage, 0.7, "Avalanche effect should not change more than 70%% of bits")
|
||||
})
|
||||
}
|
||||
|
||||
// TestCorruptionDetection tests that signature corruption is detectable
|
||||
func TestCorruptionDetection(t *testing.T) {
|
||||
signer, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("hash corruption detection", func(t *testing.T) {
|
||||
user := testUserAlice
|
||||
docID := "corruption-test"
|
||||
timestamp := time.Now().UTC()
|
||||
nonce := "corruption-nonce"
|
||||
|
||||
originalHash, originalSig, err := signer.CreateSignature(docID, user, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Corrupt the hash
|
||||
hashBytes, err := base64.StdEncoding.DecodeString(originalHash)
|
||||
require.NoError(t, err)
|
||||
|
||||
hashBytes[0] ^= 0x01 // Flip one bit
|
||||
corruptedHash := base64.StdEncoding.EncodeToString(hashBytes)
|
||||
|
||||
assert.NotEqual(t, originalHash, corruptedHash, "Corrupted hash should be different")
|
||||
|
||||
// Original signature won't match corrupted hash when verified
|
||||
// (This would be caught during verification process)
|
||||
assert.NotEmpty(t, originalSig)
|
||||
})
|
||||
|
||||
t.Run("signature corruption detection", func(t *testing.T) {
|
||||
user := testUserBob
|
||||
docID := "sig-corruption-test"
|
||||
timestamp := time.Now().UTC()
|
||||
nonce := "sig-corruption-nonce"
|
||||
|
||||
originalHash, originalSig, err := signer.CreateSignature(docID, user, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Corrupt the signature
|
||||
sigBytes, err := base64.StdEncoding.DecodeString(originalSig)
|
||||
require.NoError(t, err)
|
||||
|
||||
sigBytes[63] ^= 0xFF // Flip bits in last byte
|
||||
corruptedSig := base64.StdEncoding.EncodeToString(sigBytes)
|
||||
|
||||
assert.NotEqual(t, originalSig, corruptedSig, "Corrupted signature should be different")
|
||||
assert.NotEmpty(t, originalHash) // Hash should remain valid
|
||||
})
|
||||
|
||||
t.Run("payload tampering detection", func(t *testing.T) {
|
||||
user := testUserCharlie
|
||||
docID := "tamper-test"
|
||||
timestamp := time.Date(2024, 5, 1, 16, 45, 0, 0, time.UTC)
|
||||
nonce := "tamper-nonce"
|
||||
|
||||
// Original signature
|
||||
originalHash, originalSig, err := signer.CreateSignature(docID, user, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create signature for tampered data (different docID)
|
||||
tamperedHash, tamperedSig, err := signer.CreateSignature("tampered-doc", user, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Tampered data produces different hash and signature
|
||||
assert.NotEqual(t, originalHash, tamperedHash, "Tampered payload should produce different hash")
|
||||
assert.NotEqual(t, originalSig, tamperedSig, "Tampered payload should produce different signature")
|
||||
})
|
||||
}
|
||||
|
||||
// TestBusinessRuleEnforcement tests that cryptographic functions support business rules
|
||||
func TestBusinessRuleEnforcement(t *testing.T) {
|
||||
t.Run("unique signatures per document-user pair", func(t *testing.T) {
|
||||
signer, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
|
||||
user := testUserAlice
|
||||
docID := "business-rule-test"
|
||||
timestamp := time.Now().UTC()
|
||||
|
||||
// Create signatures with different nonces (simulating different attempts)
|
||||
nonce1, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
nonce2, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
hash1, sig1, err := signer.CreateSignature(docID, user, timestamp, nonce1)
|
||||
require.NoError(t, err)
|
||||
|
||||
hash2, sig2, err := signer.CreateSignature(docID, user, timestamp, nonce2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Different nonces create different signatures
|
||||
// This supports business rule that each signing attempt must be unique
|
||||
assert.NotEqual(t, hash1, hash2, "Different nonces should create different hashes")
|
||||
assert.NotEqual(t, sig1, sig2, "Different nonces should create different signatures")
|
||||
})
|
||||
|
||||
t.Run("email normalization consistency", func(t *testing.T) {
|
||||
signer, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create users with same email in different cases
|
||||
user1 := &models.User{
|
||||
Sub: "user-case-test",
|
||||
Email: "Test.User@EXAMPLE.COM",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
user2 := &models.User{
|
||||
Sub: "user-case-test",
|
||||
Email: "test.user@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
docID := "email-case-test"
|
||||
timestamp := time.Date(2024, 6, 1, 12, 0, 0, 0, time.UTC)
|
||||
nonce := "case-nonce"
|
||||
|
||||
hash1, sig1, err := signer.CreateSignature(docID, user1, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
hash2, sig2, err := signer.CreateSignature(docID, user2, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should produce same signature due to email normalization
|
||||
assert.Equal(t, hash1, hash2, "Email case should not affect signature due to normalization")
|
||||
assert.Equal(t, sig1, sig2, "Email case should not affect signature due to normalization")
|
||||
})
|
||||
|
||||
t.Run("timestamp precision handling", func(t *testing.T) {
|
||||
signer, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
|
||||
user := testUserAlice
|
||||
docID := "timestamp-precision-test"
|
||||
nonce := "precision-nonce"
|
||||
|
||||
// Test that nanosecond precision is maintained in signatures
|
||||
timestamp1 := time.Date(2024, 7, 1, 10, 30, 15, 123456789, time.UTC)
|
||||
timestamp2 := time.Date(2024, 7, 1, 10, 30, 15, 123456790, time.UTC) // 1 nanosecond different
|
||||
|
||||
hash1, sig1, err := signer.CreateSignature(docID, user, timestamp1, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
hash2, sig2, err := signer.CreateSignature(docID, user, timestamp2, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Even 1 nanosecond difference should produce different signatures
|
||||
assert.NotEqual(t, hash1, hash2, "Nanosecond precision should be maintained")
|
||||
assert.NotEqual(t, sig1, sig2, "Nanosecond precision should be maintained")
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
)
|
||||
|
||||
// Ed25519Signer handles Ed25519 cryptographic operations
|
||||
type Ed25519Signer struct {
|
||||
privateKey ed25519.PrivateKey
|
||||
publicKey ed25519.PublicKey
|
||||
}
|
||||
|
||||
// NewEd25519Signer creates a new Ed25519 signer
|
||||
func NewEd25519Signer() (*Ed25519Signer, error) {
|
||||
privKey, pubKey, err := loadOrGenerateKeys()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load or generate keys: %w", err)
|
||||
}
|
||||
|
||||
return &Ed25519Signer{
|
||||
privateKey: privKey,
|
||||
publicKey: pubKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateSignature creates a cryptographic signature for a document
|
||||
func (s *Ed25519Signer) CreateSignature(docID string, user *models.User, timestamp time.Time, nonce string) (string, string, error) {
|
||||
payload := canonicalPayload(docID, user, timestamp, nonce)
|
||||
hash := sha256.Sum256(payload)
|
||||
signature := ed25519.Sign(s.privateKey, hash[:])
|
||||
|
||||
return base64.StdEncoding.EncodeToString(hash[:]), base64.StdEncoding.EncodeToString(signature), nil
|
||||
}
|
||||
|
||||
// GetPublicKey returns the base64 encoded public key
|
||||
func (s *Ed25519Signer) GetPublicKey() string {
|
||||
return base64.StdEncoding.EncodeToString(s.publicKey)
|
||||
}
|
||||
|
||||
// canonicalPayload creates a canonical payload for signing
|
||||
func canonicalPayload(docID string, user *models.User, timestamp time.Time, nonce string) []byte {
|
||||
return []byte(fmt.Sprintf(
|
||||
"doc_id=%s\nuser_sub=%s\nuser_email=%s\nsigned_at_utc=%s\nnonce=%s\n",
|
||||
docID,
|
||||
user.Sub,
|
||||
user.NormalizedEmail(),
|
||||
timestamp.UTC().Format(time.RFC3339Nano),
|
||||
nonce,
|
||||
))
|
||||
}
|
||||
|
||||
// loadOrGenerateKeys loads existing keys or generates new ones
|
||||
func loadOrGenerateKeys() (ed25519.PrivateKey, ed25519.PublicKey, error) {
|
||||
b64Key := strings.TrimSpace(os.Getenv("ED25519_PRIVATE_KEY_B64"))
|
||||
|
||||
if b64Key != "" {
|
||||
keyBytes, err := base64.StdEncoding.DecodeString(b64Key)
|
||||
if err != nil || len(keyBytes) != ed25519.PrivateKeySize {
|
||||
return nil, nil, fmt.Errorf("invalid ED25519_PRIVATE_KEY_B64: %v", err)
|
||||
}
|
||||
|
||||
privateKey := ed25519.PrivateKey(keyBytes)
|
||||
publicKey := privateKey.Public().(ed25519.PublicKey)
|
||||
|
||||
return privateKey, publicKey, nil
|
||||
}
|
||||
|
||||
// Generate new keys
|
||||
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate keys: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("[WARN] Generated ephemeral Ed25519 keypair. Set ED25519_PRIVATE_KEY_B64 to persist: %s\n",
|
||||
base64.StdEncoding.EncodeToString(privateKey))
|
||||
|
||||
return privateKey, publicKey, nil
|
||||
}
|
||||
@@ -0,0 +1,437 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"ackify/internal/domain/models"
|
||||
)
|
||||
|
||||
func TestEd25519Signer_NewEd25519Signer(t *testing.T) {
|
||||
t.Run("creates new signer successfully", func(t *testing.T) {
|
||||
// Clear environment variable to force generation
|
||||
originalKey := os.Getenv("ED25519_PRIVATE_KEY_B64")
|
||||
os.Unsetenv("ED25519_PRIVATE_KEY_B64")
|
||||
defer func() {
|
||||
if originalKey != "" {
|
||||
os.Setenv("ED25519_PRIVATE_KEY_B64", originalKey)
|
||||
}
|
||||
}()
|
||||
|
||||
signer, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, signer)
|
||||
|
||||
// Test that public key is accessible
|
||||
pubKey := signer.GetPublicKey()
|
||||
assert.NotEmpty(t, pubKey)
|
||||
|
||||
// Test that public key is valid base64
|
||||
_, err = base64.StdEncoding.DecodeString(pubKey)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("loads signer from environment variable", func(t *testing.T) {
|
||||
// Generate a test key pair
|
||||
pubKey, privKey, err := ed25519.GenerateKey(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set environment variable
|
||||
b64Key := base64.StdEncoding.EncodeToString(privKey)
|
||||
os.Setenv("ED25519_PRIVATE_KEY_B64", b64Key)
|
||||
defer os.Unsetenv("ED25519_PRIVATE_KEY_B64")
|
||||
|
||||
signer, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, signer)
|
||||
|
||||
// Verify the public key matches
|
||||
expectedPubKey := base64.StdEncoding.EncodeToString(pubKey)
|
||||
actualPubKey := signer.GetPublicKey()
|
||||
assert.Equal(t, expectedPubKey, actualPubKey)
|
||||
})
|
||||
|
||||
t.Run("fails with invalid environment variable", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
value string
|
||||
}{
|
||||
{"invalid base64", "invalid!@#$"},
|
||||
{"wrong length", base64.StdEncoding.EncodeToString([]byte("short"))},
|
||||
{"empty string", ""},
|
||||
{"whitespace only", " "},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
os.Setenv("ED25519_PRIVATE_KEY_B64", tc.value)
|
||||
defer os.Unsetenv("ED25519_PRIVATE_KEY_B64")
|
||||
|
||||
if tc.value == "" || tc.value == " " {
|
||||
// Empty or whitespace should generate new keys
|
||||
signer, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, signer)
|
||||
} else {
|
||||
// Invalid keys should return error
|
||||
signer, err := NewEd25519Signer()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, signer)
|
||||
assert.Contains(t, err.Error(), "invalid ED25519_PRIVATE_KEY_B64")
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEd25519Signer_CreateSignature(t *testing.T) {
|
||||
// Create signer for tests
|
||||
signer, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("creates valid signature", func(t *testing.T) {
|
||||
user := testUserAlice
|
||||
docID := "test-document"
|
||||
timestamp := time.Date(2024, 1, 15, 12, 30, 0, 0, time.UTC)
|
||||
nonce := "test-nonce-123"
|
||||
|
||||
hashB64, sigB64, err := signer.CreateSignature(docID, user, timestamp, nonce)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, hashB64)
|
||||
assert.NotEmpty(t, sigB64)
|
||||
|
||||
// Verify hash is valid base64
|
||||
hashBytes, err := base64.StdEncoding.DecodeString(hashB64)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, hashBytes, 32) // SHA-256 hash length
|
||||
|
||||
// Verify signature is valid base64
|
||||
sigBytes, err := base64.StdEncoding.DecodeString(sigB64)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, sigBytes, ed25519.SignatureSize) // Ed25519 signature length
|
||||
})
|
||||
|
||||
t.Run("creates consistent signatures", func(t *testing.T) {
|
||||
user := testUserBob
|
||||
docID := "consistent-doc"
|
||||
timestamp := time.Date(2024, 2, 1, 10, 0, 0, 0, time.UTC)
|
||||
nonce := "consistent-nonce"
|
||||
|
||||
// Create signature twice with same parameters
|
||||
hash1, sig1, err1 := signer.CreateSignature(docID, user, timestamp, nonce)
|
||||
require.NoError(t, err1)
|
||||
|
||||
hash2, sig2, err2 := signer.CreateSignature(docID, user, timestamp, nonce)
|
||||
require.NoError(t, err2)
|
||||
|
||||
// Should produce identical results
|
||||
assert.Equal(t, hash1, hash2)
|
||||
assert.Equal(t, sig1, sig2)
|
||||
})
|
||||
|
||||
t.Run("creates different signatures for different inputs", func(t *testing.T) {
|
||||
user := testUserCharlie
|
||||
timestamp := time.Now().UTC()
|
||||
nonce := "test-nonce"
|
||||
|
||||
// Same user, different documents
|
||||
hash1, sig1, err := signer.CreateSignature("doc1", user, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
hash2, sig2, err := signer.CreateSignature("doc2", user, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, hash1, hash2)
|
||||
assert.NotEqual(t, sig1, sig2)
|
||||
|
||||
// Same document, different users
|
||||
hash3, sig3, err := signer.CreateSignature("doc1", testUserAlice, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, hash1, hash3)
|
||||
assert.NotEqual(t, sig1, sig3)
|
||||
|
||||
// Same everything, different nonces
|
||||
hash4, sig4, err := signer.CreateSignature("doc1", user, timestamp, "different-nonce")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, hash1, hash4)
|
||||
assert.NotEqual(t, sig1, sig4)
|
||||
})
|
||||
|
||||
t.Run("handles different timestamp formats", func(t *testing.T) {
|
||||
user := testUserAlice
|
||||
docID := "timestamp-test"
|
||||
nonce := "timestamp-nonce"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
timestamp time.Time
|
||||
}{
|
||||
{"UTC time", time.Date(2024, 6, 15, 14, 30, 0, 0, time.UTC)},
|
||||
{"Local time", time.Date(2024, 6, 15, 14, 30, 0, 0, time.FixedZone("EST", -5*3600))},
|
||||
{"Nanoseconds", time.Date(2024, 6, 15, 14, 30, 0, 123456789, time.UTC)},
|
||||
{"Zero time", time.Time{}},
|
||||
{"Unix epoch", time.Unix(0, 0)},
|
||||
}
|
||||
|
||||
signatures := make(map[string]string)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
hash, sig, err := signer.CreateSignature(docID, user, tc.timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Each timestamp should produce unique signature
|
||||
assert.NotContains(t, signatures, sig, "Signature should be unique for different timestamps")
|
||||
signatures[sig] = tc.name
|
||||
|
||||
assert.NotEmpty(t, hash)
|
||||
assert.NotEmpty(t, sig)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles edge case inputs", func(t *testing.T) {
|
||||
timestamp := time.Now().UTC()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
docID string
|
||||
user *models.User
|
||||
nonce string
|
||||
}{
|
||||
{"empty docID", "", testUserAlice, "nonce"},
|
||||
{"empty nonce", "doc", testUserAlice, ""},
|
||||
{"special chars in docID", "doc/with:special#chars", testUserAlice, "nonce"},
|
||||
{"unicode in docID", "文档-测试", testUserAlice, "nonce"},
|
||||
{"long docID", string(make([]byte, 1000)), testUserAlice, "nonce"},
|
||||
{"long nonce", string(make([]byte, 1000)), testUserAlice, "nonce"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Fill long strings with valid data
|
||||
if tc.docID == string(make([]byte, 1000)) {
|
||||
tc.docID = "long-doc-" + string(make([]rune, 990))
|
||||
for i := range tc.docID[9:] {
|
||||
tc.docID = tc.docID[:9+i] + "a" + tc.docID[9+i+1:]
|
||||
}
|
||||
}
|
||||
if tc.nonce == string(make([]byte, 1000)) {
|
||||
tc.nonce = "long-nonce-" + string(make([]rune, 985))
|
||||
for i := range tc.nonce[11:] {
|
||||
tc.nonce = tc.nonce[:11+i] + "b" + tc.nonce[11+i+1:]
|
||||
}
|
||||
}
|
||||
|
||||
hash, sig, err := signer.CreateSignature(tc.docID, testUserAlice, timestamp, tc.nonce)
|
||||
|
||||
// Should not fail on edge case inputs
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, hash)
|
||||
assert.NotEmpty(t, sig)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEd25519Signer_SignatureVerification(t *testing.T) {
|
||||
signer, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("signature can be verified", func(t *testing.T) {
|
||||
user := testUserAlice
|
||||
docID := "verify-test"
|
||||
timestamp := time.Date(2024, 3, 1, 9, 15, 30, 0, time.UTC)
|
||||
nonce := "verify-nonce"
|
||||
|
||||
hashB64, sigB64, err := signer.CreateSignature(docID, user, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decode signature and hash
|
||||
sigBytes, err := base64.StdEncoding.DecodeString(sigB64)
|
||||
require.NoError(t, err)
|
||||
|
||||
hashBytes, err := base64.StdEncoding.DecodeString(hashB64)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get public key
|
||||
pubKeyB64 := signer.GetPublicKey()
|
||||
pubKeyBytes, err := base64.StdEncoding.DecodeString(pubKeyB64)
|
||||
require.NoError(t, err)
|
||||
|
||||
pubKey := ed25519.PublicKey(pubKeyBytes)
|
||||
|
||||
// Verify signature against hash
|
||||
isValid := ed25519.Verify(pubKey, hashBytes, sigBytes)
|
||||
assert.True(t, isValid, "Generated signature should be valid")
|
||||
})
|
||||
|
||||
t.Run("corrupted signature fails verification", func(t *testing.T) {
|
||||
user := testUserBob
|
||||
docID := "corrupt-test"
|
||||
timestamp := time.Now().UTC()
|
||||
nonce := "corrupt-nonce"
|
||||
|
||||
hashB64, sigB64, err := signer.CreateSignature(docID, user, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Corrupt the signature
|
||||
sigBytes, err := base64.StdEncoding.DecodeString(sigB64)
|
||||
require.NoError(t, err)
|
||||
|
||||
sigBytes[0] ^= 0xFF // Flip bits in first byte
|
||||
corruptedSig := base64.StdEncoding.EncodeToString(sigBytes)
|
||||
|
||||
// Try to verify corrupted signature
|
||||
hashBytes, err := base64.StdEncoding.DecodeString(hashB64)
|
||||
require.NoError(t, err)
|
||||
|
||||
pubKeyB64 := signer.GetPublicKey()
|
||||
pubKeyBytes, err := base64.StdEncoding.DecodeString(pubKeyB64)
|
||||
require.NoError(t, err)
|
||||
|
||||
pubKey := ed25519.PublicKey(pubKeyBytes)
|
||||
corruptedSigBytes, err := base64.StdEncoding.DecodeString(corruptedSig)
|
||||
require.NoError(t, err)
|
||||
|
||||
isValid := ed25519.Verify(pubKey, hashBytes, corruptedSigBytes)
|
||||
assert.False(t, isValid, "Corrupted signature should not be valid")
|
||||
})
|
||||
}
|
||||
|
||||
func TestEd25519Signer_PayloadGeneration(t *testing.T) {
|
||||
signer, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("canonical payload format", func(t *testing.T) {
|
||||
user := testUserAlice
|
||||
docID := "payload-test"
|
||||
timestamp := time.Date(2024, 4, 1, 12, 0, 0, 0, time.UTC)
|
||||
nonce := "payload-nonce"
|
||||
|
||||
hash1, _, err := signer.CreateSignature(docID, user, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually create expected payload to verify format
|
||||
expectedPayload := []byte("doc_id=payload-test\nuser_sub=user-123-alice\nuser_email=alice@example.com\nsigned_at_utc=2024-04-01T12:00:00Z\nnonce=payload-nonce\n")
|
||||
expectedHash := sha256.Sum256(expectedPayload)
|
||||
expectedHashB64 := base64.StdEncoding.EncodeToString(expectedHash[:])
|
||||
|
||||
assert.Equal(t, expectedHashB64, hash1, "Hash should match canonical payload format")
|
||||
})
|
||||
|
||||
t.Run("email normalization in payload", func(t *testing.T) {
|
||||
// Create user with mixed case email
|
||||
user := &models.User{
|
||||
Sub: "user-email-test",
|
||||
Email: "Test.User@EXAMPLE.COM",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
docID := "email-test"
|
||||
timestamp := time.Date(2024, 5, 1, 10, 0, 0, 0, time.UTC)
|
||||
nonce := "email-nonce"
|
||||
|
||||
hash, _, err := signer.CreateSignature(docID, user, timestamp, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create expected payload with normalized (lowercase) email
|
||||
expectedPayload := []byte("doc_id=email-test\nuser_sub=user-email-test\nuser_email=test.user@example.com\nsigned_at_utc=2024-05-01T10:00:00Z\nnonce=email-nonce\n")
|
||||
expectedHash := sha256.Sum256(expectedPayload)
|
||||
expectedHashB64 := base64.StdEncoding.EncodeToString(expectedHash[:])
|
||||
|
||||
assert.Equal(t, expectedHashB64, hash, "Payload should use normalized lowercase email")
|
||||
})
|
||||
|
||||
t.Run("timestamp format consistency", func(t *testing.T) {
|
||||
user := testUserCharlie
|
||||
docID := "time-format-test"
|
||||
nonce := "time-nonce"
|
||||
|
||||
// Test different timezone inputs but same UTC moment
|
||||
utcTime := time.Date(2024, 6, 1, 15, 30, 45, 123456789, time.UTC)
|
||||
localTime := utcTime.In(time.Local)
|
||||
|
||||
hash1, _, err := signer.CreateSignature(docID, user, utcTime, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
hash2, _, err := signer.CreateSignature(docID, user, localTime, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should produce same hash as both represent same UTC moment
|
||||
assert.Equal(t, hash1, hash2, "Different timezone representations of same moment should produce same hash")
|
||||
})
|
||||
}
|
||||
|
||||
func TestEd25519Signer_GetPublicKey(t *testing.T) {
|
||||
t.Run("returns consistent public key", func(t *testing.T) {
|
||||
signer, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
|
||||
pubKey1 := signer.GetPublicKey()
|
||||
pubKey2 := signer.GetPublicKey()
|
||||
|
||||
assert.Equal(t, pubKey1, pubKey2, "Public key should be consistent across calls")
|
||||
assert.NotEmpty(t, pubKey1, "Public key should not be empty")
|
||||
})
|
||||
|
||||
t.Run("public key is valid base64", func(t *testing.T) {
|
||||
signer, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
|
||||
pubKeyB64 := signer.GetPublicKey()
|
||||
pubKeyBytes, err := base64.StdEncoding.DecodeString(pubKeyB64)
|
||||
|
||||
require.NoError(t, err, "Public key should be valid base64")
|
||||
assert.Len(t, pubKeyBytes, ed25519.PublicKeySize, "Public key should be correct length")
|
||||
})
|
||||
|
||||
t.Run("different signers have different public keys", func(t *testing.T) {
|
||||
// Clear environment to force generation of different keys
|
||||
originalKey := os.Getenv("ED25519_PRIVATE_KEY_B64")
|
||||
os.Unsetenv("ED25519_PRIVATE_KEY_B64")
|
||||
defer func() {
|
||||
if originalKey != "" {
|
||||
os.Setenv("ED25519_PRIVATE_KEY_B64", originalKey)
|
||||
}
|
||||
}()
|
||||
|
||||
signer1, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
|
||||
signer2, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
|
||||
pubKey1 := signer1.GetPublicKey()
|
||||
pubKey2 := signer2.GetPublicKey()
|
||||
|
||||
assert.NotEqual(t, pubKey1, pubKey2, "Different signers should have different public keys")
|
||||
})
|
||||
}
|
||||
|
||||
func TestEd25519Signer_InterfaceCompliance(t *testing.T) {
|
||||
t.Run("concrete type methods work", func(t *testing.T) {
|
||||
signer, err := NewEd25519Signer()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test methods are accessible directly on concrete type
|
||||
pubKey := signer.GetPublicKey()
|
||||
assert.NotEmpty(t, pubKey)
|
||||
|
||||
user := testUserAlice
|
||||
hash, sig, err := signer.CreateSignature("test", user, time.Now(), "nonce")
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, hash)
|
||||
assert.NotEmpty(t, sig)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package crypto
|
||||
|
||||
import "ackify/internal/domain/models"
|
||||
|
||||
// Internal test fixtures to avoid external dependencies
|
||||
|
||||
var (
|
||||
testUserAlice = &models.User{
|
||||
Sub: "user-123-alice",
|
||||
Email: "alice@example.com",
|
||||
Name: "Alice Smith",
|
||||
}
|
||||
|
||||
testUserBob = &models.User{
|
||||
Sub: "user-456-bob",
|
||||
Email: "bob@example.com",
|
||||
Name: "Bob Johnson",
|
||||
}
|
||||
|
||||
testUserCharlie = &models.User{
|
||||
Sub: "user-789-charlie",
|
||||
Email: "charlie@example.com",
|
||||
Name: "Charlie Brown",
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,16 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
// GenerateNonce generates a cryptographically secure random nonce
|
||||
func GenerateNonce() (string, error) {
|
||||
nonceBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(nonceBytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64.RawURLEncoding.EncodeToString(nonceBytes), nil
|
||||
}
|
||||
@@ -0,0 +1,243 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGenerateNonce(t *testing.T) {
|
||||
t.Run("generates valid nonce", func(t *testing.T) {
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, nonce)
|
||||
})
|
||||
|
||||
t.Run("nonce is valid base64url", func(t *testing.T) {
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be decodable as base64url
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be 16 bytes (128 bits) when decoded
|
||||
assert.Len(t, decoded, 16, "Decoded nonce should be 16 bytes")
|
||||
})
|
||||
|
||||
t.Run("generates unique nonces", func(t *testing.T) {
|
||||
const numNonces = 1000
|
||||
nonces := make(map[string]bool)
|
||||
|
||||
for i := 0; i < numNonces; i++ {
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check for duplicates
|
||||
assert.False(t, nonces[nonce], "Nonce %s should be unique", nonce)
|
||||
nonces[nonce] = true
|
||||
}
|
||||
|
||||
assert.Len(t, nonces, numNonces, "All nonces should be unique")
|
||||
})
|
||||
|
||||
t.Run("nonce format consistency", func(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should not be empty
|
||||
assert.NotEmpty(t, nonce)
|
||||
|
||||
// Should not contain padding (RawURLEncoding)
|
||||
assert.NotContains(t, nonce, "=", "Nonce should not contain padding")
|
||||
|
||||
// Should only contain valid base64url characters
|
||||
assert.Regexp(t, `^[A-Za-z0-9_-]+$`, nonce, "Nonce should only contain base64url characters")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nonce length consistency", func(t *testing.T) {
|
||||
var lengths []int
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
lengths = append(lengths, len(nonce))
|
||||
}
|
||||
|
||||
// All nonces should have same length
|
||||
expectedLength := lengths[0]
|
||||
for _, length := range lengths {
|
||||
assert.Equal(t, expectedLength, length, "All nonces should have consistent length")
|
||||
}
|
||||
|
||||
// For 16 bytes (128 bits), base64url without padding should be 22 characters
|
||||
// 16 bytes = 128 bits = 128/6 = 21.33 -> 22 characters (rounded up)
|
||||
assert.Equal(t, 22, expectedLength, "Nonce should be 22 characters long")
|
||||
})
|
||||
|
||||
t.Run("concurrent nonce generation", func(t *testing.T) {
|
||||
const numGoroutines = 100
|
||||
const noncesPerGoroutine = 10
|
||||
|
||||
nonceChan := make(chan string, numGoroutines*noncesPerGoroutine)
|
||||
errorChan := make(chan error, numGoroutines*noncesPerGoroutine)
|
||||
|
||||
// Start multiple goroutines generating nonces
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
for j := 0; j < noncesPerGoroutine; j++ {
|
||||
nonce, err := GenerateNonce()
|
||||
if err != nil {
|
||||
errorChan <- err
|
||||
return
|
||||
}
|
||||
nonceChan <- nonce
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Collect results
|
||||
nonces := make(map[string]bool)
|
||||
for i := 0; i < numGoroutines*noncesPerGoroutine; i++ {
|
||||
select {
|
||||
case nonce := <-nonceChan:
|
||||
assert.False(t, nonces[nonce], "Concurrent nonce %s should be unique", nonce)
|
||||
nonces[nonce] = true
|
||||
case err := <-errorChan:
|
||||
t.Fatalf("Concurrent nonce generation failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Len(t, nonces, numGoroutines*noncesPerGoroutine, "All concurrent nonces should be unique")
|
||||
})
|
||||
|
||||
t.Run("nonce entropy validation", func(t *testing.T) {
|
||||
const numNonces = 1000
|
||||
bitCounts := make([]int, 8) // Count bits 0-7 across all bytes
|
||||
|
||||
for i := 0; i < numNonces; i++ {
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Count bit frequency
|
||||
for _, b := range decoded {
|
||||
for bit := 0; bit < 8; bit++ {
|
||||
if (b>>bit)&1 == 1 {
|
||||
bitCounts[bit]++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Each bit should appear roughly 50% of the time (within reasonable variance)
|
||||
expectedCount := numNonces * 16 / 2 // 16 bytes per nonce, expect 50% ones
|
||||
tolerance := expectedCount / 10 // 10% tolerance
|
||||
|
||||
for bit, count := range bitCounts {
|
||||
assert.InDelta(t, expectedCount, count, float64(tolerance),
|
||||
"Bit %d should have balanced distribution (got %d, expected ~%d)",
|
||||
bit, count, expectedCount)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nonce base64url safety", func(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should not contain characters that need URL encoding
|
||||
assert.NotContains(t, nonce, "+", "Nonce should not contain + (use URL-safe base64)")
|
||||
assert.NotContains(t, nonce, "/", "Nonce should not contain / (use URL-safe base64)")
|
||||
assert.NotContains(t, nonce, "=", "Nonce should not contain = (use RawURLEncoding)")
|
||||
|
||||
// Should be safe for use in URLs and forms
|
||||
assert.Regexp(t, `^[A-Za-z0-9_-]+$`, nonce, "Nonce should only contain URL-safe characters")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nonce anti-replay properties", func(t *testing.T) {
|
||||
// Generate a large set of nonces to verify anti-replay properties
|
||||
const numNonces = 10000
|
||||
nonces := make([]string, 0, numNonces)
|
||||
nonceSet := make(map[string]bool)
|
||||
|
||||
for i := 0; i < numNonces; i++ {
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify uniqueness (anti-replay)
|
||||
assert.False(t, nonceSet[nonce], "Nonce should not repeat (anti-replay)")
|
||||
nonceSet[nonce] = true
|
||||
nonces = append(nonces, nonce)
|
||||
}
|
||||
|
||||
// Verify we generated the expected number of unique nonces
|
||||
assert.Len(t, nonces, numNonces)
|
||||
assert.Len(t, nonceSet, numNonces)
|
||||
|
||||
// Verify sufficient entropy - no obvious patterns
|
||||
// Check that first characters are well distributed
|
||||
firstChars := make(map[byte]int)
|
||||
for _, nonce := range nonces {
|
||||
firstChars[nonce[0]]++
|
||||
}
|
||||
|
||||
// Should have reasonable distribution of first characters
|
||||
assert.Greater(t, len(firstChars), 10, "First character should have good distribution")
|
||||
})
|
||||
|
||||
t.Run("nonce cryptographic strength", func(t *testing.T) {
|
||||
// Test that nonces have sufficient randomness
|
||||
nonce1, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
nonce2, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Different nonces should be completely different
|
||||
assert.NotEqual(t, nonce1, nonce2)
|
||||
|
||||
// Decode both nonces
|
||||
decoded1, err := base64.RawURLEncoding.DecodeString(nonce1)
|
||||
require.NoError(t, err)
|
||||
|
||||
decoded2, err := base64.RawURLEncoding.DecodeString(nonce2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have no common bytes (extremely unlikely with crypto/rand)
|
||||
commonBytes := 0
|
||||
for i := range decoded1 {
|
||||
if decoded1[i] == decoded2[i] {
|
||||
commonBytes++
|
||||
}
|
||||
}
|
||||
|
||||
// With truly random data, expect 0-2 common bytes in 16-byte sequences
|
||||
assert.LessOrEqual(t, commonBytes, 3, "Too many common bytes between random nonces")
|
||||
})
|
||||
|
||||
t.Run("error handling edge cases", func(t *testing.T) {
|
||||
// This test verifies the function handles errors gracefully
|
||||
// In normal conditions, GenerateNonce should not fail
|
||||
// but we test that error handling pattern is correct
|
||||
|
||||
nonce, err := GenerateNonce()
|
||||
|
||||
// In normal cases, should always succeed
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, nonce)
|
||||
|
||||
// If it did error, nonce should be empty string
|
||||
if err != nil {
|
||||
assert.Empty(t, nonce, "On error, nonce should be empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
)
|
||||
|
||||
var Logger *slog.Logger
|
||||
|
||||
func init() {
|
||||
Logger = slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
}))
|
||||
}
|
||||
|
||||
func SetLevel(level slog.Level) {
|
||||
Logger = slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
}))
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package services
|
||||
|
||||
// ServiceInfo contains information about a detected service
|
||||
type ServiceInfo struct {
|
||||
Name string
|
||||
Icon string // Simple Icons CDN URL for SVG icon
|
||||
Type string // "docs", "sheets", "notes", "wiki", etc.
|
||||
Referrer string // Original referrer parameter value
|
||||
}
|
||||
|
||||
// DetectServiceFromReferrer detects the service from a 'referrer' parameter
|
||||
func DetectServiceFromReferrer(referrerParam string) *ServiceInfo {
|
||||
if referrerParam == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Mapping des paramètres referrer vers les services
|
||||
switch referrerParam {
|
||||
// Google services
|
||||
case "google-docs":
|
||||
return &ServiceInfo{Name: "Google Docs", Icon: "https://cdn.simpleicons.org/googledocs", Type: "docs", Referrer: referrerParam}
|
||||
case "google-sheets":
|
||||
return &ServiceInfo{Name: "Google Sheets", Icon: "https://cdn.simpleicons.org/googlesheets", Type: "sheets", Referrer: referrerParam}
|
||||
case "google-slides":
|
||||
return &ServiceInfo{Name: "Google Slides", Icon: "https://cdn.simpleicons.org/googleslides", Type: "presentation", Referrer: referrerParam}
|
||||
case "google-drive":
|
||||
return &ServiceInfo{Name: "Google Drive", Icon: "https://cdn.simpleicons.org/googledrive", Type: "storage", Referrer: referrerParam}
|
||||
case "google":
|
||||
return &ServiceInfo{Name: "Google", Icon: "https://cdn.simpleicons.org/google", Type: "google", Referrer: referrerParam}
|
||||
|
||||
// Notion
|
||||
case "notion":
|
||||
return &ServiceInfo{Name: "Notion", Icon: "https://cdn.simpleicons.org/notion", Type: "notes", Referrer: referrerParam}
|
||||
|
||||
// Confluence
|
||||
case "confluence":
|
||||
return &ServiceInfo{Name: "Confluence", Icon: "https://cdn.simpleicons.org/confluence", Type: "wiki", Referrer: referrerParam}
|
||||
|
||||
// Microsoft
|
||||
case "microsoft":
|
||||
return &ServiceInfo{Name: "Microsoft Office", Icon: "https://cdn.simpleicons.org/microsoft", Type: "office", Referrer: referrerParam}
|
||||
|
||||
// GitHub
|
||||
case "github":
|
||||
return &ServiceInfo{Name: "GitHub", Icon: "https://cdn.simpleicons.org/github", Type: "code", Referrer: referrerParam}
|
||||
|
||||
// GitLab
|
||||
case "gitlab":
|
||||
return &ServiceInfo{Name: "GitLab", Icon: "https://cdn.simpleicons.org/gitlab", Type: "code", Referrer: referrerParam}
|
||||
|
||||
// Outline
|
||||
case "outline":
|
||||
return &ServiceInfo{Name: "Outline", Icon: "https://cdn.simpleicons.org/outline", Type: "wiki", Referrer: referrerParam}
|
||||
|
||||
// Communication
|
||||
case "slack":
|
||||
return &ServiceInfo{Name: "Slack", Icon: "https://cdn.simpleicons.org/slack", Type: "chat", Referrer: referrerParam}
|
||||
case "discord":
|
||||
return &ServiceInfo{Name: "Discord", Icon: "https://cdn.simpleicons.org/discord", Type: "chat", Referrer: referrerParam}
|
||||
|
||||
// Project management
|
||||
case "trello":
|
||||
return &ServiceInfo{Name: "Trello", Icon: "https://cdn.simpleicons.org/trello", Type: "boards", Referrer: referrerParam}
|
||||
case "asana":
|
||||
return &ServiceInfo{Name: "Asana", Icon: "https://cdn.simpleicons.org/asana", Type: "tasks", Referrer: referrerParam}
|
||||
case "monday":
|
||||
return &ServiceInfo{Name: "Monday.com", Icon: "https://cdn.simpleicons.org/monday", Type: "project", Referrer: referrerParam}
|
||||
|
||||
// Design
|
||||
case "figma":
|
||||
return &ServiceInfo{Name: "Figma", Icon: "https://cdn.simpleicons.org/figma", Type: "design", Referrer: referrerParam}
|
||||
case "miro":
|
||||
return &ServiceInfo{Name: "Miro", Icon: "https://cdn.simpleicons.org/miro", Type: "whiteboard", Referrer: referrerParam}
|
||||
|
||||
// Storage
|
||||
case "dropbox":
|
||||
return &ServiceInfo{Name: "Dropbox", Icon: "https://cdn.simpleicons.org/dropbox", Type: "storage", Referrer: referrerParam}
|
||||
|
||||
default:
|
||||
// Paramètre referrer personnalisé - utiliser tel quel
|
||||
return &ServiceInfo{Name: referrerParam, Icon: "https://cdn.simpleicons.org/link", Type: "custom", Referrer: referrerParam}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
{{define "base"}}<!doctype html>
|
||||
<html lang="fr">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>Service de validation de lecture</title>
|
||||
{{if .DocID}}
|
||||
<link rel="alternate" type="application/json+oembed" href="/oembed?url={{.BaseURL}}/sign?doc={{.DocID}}&format=json" title="Signataires du document {{.DocID}}" />
|
||||
{{end}}
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<script>
|
||||
tailwind.config = {
|
||||
theme: {
|
||||
extend: {
|
||||
colors: {
|
||||
primary: { 50: '#eff6ff', 100: '#dbeafe', 500: '#3b82f6', 600: '#2563eb', 700: '#1d4ed8', 900: '#1e3a8a' },
|
||||
success: { 50: '#f0fdf4', 100: '#dcfce7', 500: '#22c55e', 600: '#16a34a', 700: '#15803d' },
|
||||
warning: { 50: '#fffbeb', 100: '#fef3c7', 500: '#f59e0b', 600: '#d97706' },
|
||||
danger: { 50: '#fef2f2', 100: '#fecaca', 500: '#ef4444', 600: '#dc2626' }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
</head>
|
||||
<body class="min-h-screen bg-gradient-to-br from-slate-50 to-blue-50">
|
||||
<div class="min-h-screen flex flex-col">
|
||||
<header class="bg-white/80 backdrop-blur-sm border-b border-slate-200 sticky top-0 z-10">
|
||||
<div class="max-w-4xl mx-auto px-6 py-4">
|
||||
<div class="flex items-center justify-between">
|
||||
<div class="flex items-center space-x-3">
|
||||
<div class="w-8 h-8 bg-primary-600 rounded-lg flex items-center justify-center">
|
||||
<svg class="w-5 h-5 text-white" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z"/>
|
||||
</svg>
|
||||
</div>
|
||||
<h1 class="text-xl font-bold text-slate-900">Service de validation de lecture</h1>
|
||||
</div>
|
||||
{{if .User}}
|
||||
<div class="flex items-center space-x-4">
|
||||
<div class="text-sm text-slate-600">
|
||||
<span class="inline-flex items-center space-x-2">
|
||||
<div class="w-6 h-6 bg-primary-100 rounded-full flex items-center justify-center">
|
||||
<svg class="w-3 h-3 text-primary-600" fill="currentColor" viewBox="0 0 20 20">
|
||||
<path d="M10 9a3 3 0 100-6 3 3 0 000 6zm-7 9a7 7 0 1114 0H3z"/>
|
||||
</svg>
|
||||
</div>
|
||||
<span>{{if .User.Name}}{{.User.Name}}{{else}}{{.User.Email}}{{end}}</span>
|
||||
</span>
|
||||
</div>
|
||||
<a href="/logout" class="text-sm text-slate-500 hover:text-slate-700 underline">Déconnexion</a>
|
||||
</div>
|
||||
{{end}}
|
||||
</div>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<main class="flex-1 py-8">
|
||||
<div class="max-w-4xl mx-auto px-6">
|
||||
{{if eq .TemplateName "sign"}}{{template "sign" .}}{{else if eq .TemplateName "signatures"}}{{template "signatures" .}}{{else}}{{template "index" .}}{{end}}
|
||||
</div>
|
||||
</main>
|
||||
|
||||
<footer class="bg-white/50 backdrop-blur-sm border-t border-slate-200 py-6">
|
||||
<div class="max-w-4xl mx-auto px-6">
|
||||
<div class="text-center space-y-2">
|
||||
<p class="text-xs text-slate-400">
|
||||
Développé par
|
||||
<a href="mailto:benjamin@kolapsis.com" class="text-primary-600 hover:text-primary-700 font-medium">Benjamin Touchard</a>
|
||||
<span class="mx-1">•</span>
|
||||
<a href="mailto:benjamin@kolapsis.com" class="text-slate-500 hover:text-slate-600">benjamin@kolapsis.com</a>
|
||||
<span class="mx-1">•</span>
|
||||
<span class="text-slate-400">@2025</span>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</footer>
|
||||
</div>
|
||||
</body>
|
||||
</html>{{end}}
|
||||
@@ -0,0 +1,594 @@
|
||||
{{define "embed"}}<!DOCTYPE html>
|
||||
<html lang="fr">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>Signataires - Document {{.DocID}}</title>
|
||||
<style>
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
html, body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
background: #ffffff;
|
||||
color: #334155;
|
||||
line-height: 1.4;
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
height: 100%;
|
||||
overflow-x: hidden;
|
||||
}
|
||||
|
||||
.embed-container {
|
||||
background: white;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.08);
|
||||
border: 1px solid #e2e8f0;
|
||||
overflow: hidden;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
min-width: 280px;
|
||||
max-width: 100%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.header {
|
||||
background: linear-gradient(135deg, #3b82f6 0%, #1d4ed8 100%);
|
||||
color: white;
|
||||
padding: 10px 16px;
|
||||
border-bottom: 1px solid #e2e8f0;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.header h3 {
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
margin: 0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.header .parent-domain {
|
||||
font-size: 11px;
|
||||
opacity: 0.8;
|
||||
text-align: right;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.header .doc-id {
|
||||
font-family: 'Monaco', 'Menlo', 'Consolas', monospace;
|
||||
background: rgba(255, 255, 255, 0.2);
|
||||
padding: 3px 6px;
|
||||
border-radius: 4px;
|
||||
font-size: 12px;
|
||||
word-break: break-all;
|
||||
max-width: 120px;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.stats {
|
||||
background: #f8fafc;
|
||||
padding: 10px 16px;
|
||||
border-bottom: 1px solid #e2e8f0;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: flex-start;
|
||||
font-size: 13px;
|
||||
gap: 8px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.stats .count {
|
||||
font-weight: 600;
|
||||
color: #059669;
|
||||
}
|
||||
|
||||
.stats .last-signed {
|
||||
color: #6b7280;
|
||||
text-align: right;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.signatories {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.signatory {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 10px 16px;
|
||||
border-bottom: 1px solid #f1f5f9;
|
||||
}
|
||||
|
||||
.signatory:last-child {
|
||||
border-bottom: none;
|
||||
}
|
||||
|
||||
.signatory:hover {
|
||||
background: #f8fafc;
|
||||
}
|
||||
|
||||
.signatory-info {
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.signatory-email {
|
||||
font-weight: 500;
|
||||
color: #1e293b;
|
||||
font-size: 13px;
|
||||
word-break: break-word;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.signatory-date {
|
||||
color: #64748b;
|
||||
font-size: 11px;
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
.signature-icon {
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
background: #dcfce7;
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
margin-right: 10px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.signature-icon svg {
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
color: #059669;
|
||||
}
|
||||
|
||||
.empty-state {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
padding: 30px 16px;
|
||||
text-align: center;
|
||||
color: #64748b;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.empty-state svg {
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
color: #cbd5e1;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.empty-state p {
|
||||
font-size: 14px;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
.footer {
|
||||
background: #f8fafc;
|
||||
padding: 10px 16px;
|
||||
text-align: center;
|
||||
border-top: 1px solid #e2e8f0;
|
||||
}
|
||||
|
||||
.sign-button {
|
||||
background: linear-gradient(135deg, #3b82f6 0%, #1d4ed8 100%);
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
font-size: 13px;
|
||||
font-weight: 600;
|
||||
display: inline-block;
|
||||
padding: 10px 20px;
|
||||
border-radius: 6px;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
box-shadow: 0 2px 4px rgba(59, 130, 246, 0.2);
|
||||
}
|
||||
|
||||
.sign-button:hover {
|
||||
transform: translateY(-1px);
|
||||
box-shadow: 0 4px 8px rgba(59, 130, 246, 0.3);
|
||||
text-decoration: none;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.sign-button:active {
|
||||
transform: translateY(0px);
|
||||
box-shadow: 0 2px 4px rgba(59, 130, 246, 0.2);
|
||||
}
|
||||
|
||||
/* Scrollbar styling */
|
||||
.signatories::-webkit-scrollbar {
|
||||
width: 3px;
|
||||
}
|
||||
|
||||
.signatories::-webkit-scrollbar-track {
|
||||
background: #f1f5f9;
|
||||
}
|
||||
|
||||
.signatories::-webkit-scrollbar-thumb {
|
||||
background: #cbd5e1;
|
||||
border-radius: 2px;
|
||||
}
|
||||
|
||||
.signatories::-webkit-scrollbar-thumb:hover {
|
||||
background: #94a3b8;
|
||||
}
|
||||
|
||||
/* Responsive design for very narrow screens */
|
||||
@media (max-width: 320px) {
|
||||
.header {
|
||||
padding: 8px 12px;
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.header .parent-domain {
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.header h3 {
|
||||
font-size: 14px;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.header .doc-id {
|
||||
font-size: 11px;
|
||||
max-width: 100px;
|
||||
}
|
||||
|
||||
.stats {
|
||||
padding: 8px 12px;
|
||||
font-size: 12px;
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.signatory {
|
||||
padding: 8px 12px;
|
||||
}
|
||||
|
||||
.signatory-email {
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
.signatory-date {
|
||||
font-size: 10px;
|
||||
}
|
||||
|
||||
.signature-icon {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
margin-right: 8px;
|
||||
}
|
||||
|
||||
.signature-icon svg {
|
||||
width: 10px;
|
||||
height: 10px;
|
||||
}
|
||||
|
||||
.footer {
|
||||
padding: 8px 12px;
|
||||
}
|
||||
|
||||
.sign-button {
|
||||
font-size: 12px;
|
||||
padding: 8px 16px;
|
||||
}
|
||||
|
||||
.empty-state {
|
||||
padding: 20px 12px;
|
||||
}
|
||||
}
|
||||
|
||||
/* Google Drive sidebar specific optimizations */
|
||||
@media (max-width: 400px) {
|
||||
.embed-container {
|
||||
border-radius: 4px;
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
.header h3 {
|
||||
font-size: 15px;
|
||||
}
|
||||
|
||||
.stats .last-signed {
|
||||
font-size: 11px;
|
||||
line-height: 1.3;
|
||||
}
|
||||
|
||||
.signatories {
|
||||
flex: 1;
|
||||
min-height: 0;
|
||||
}
|
||||
}
|
||||
|
||||
/* Compact mode for iframe embedding */
|
||||
.compact .signatories {
|
||||
flex: 1;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.compact .empty-state {
|
||||
padding: 20px 16px;
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.compact .empty-state svg {
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="embed-container">
|
||||
<div class="header">
|
||||
<h3>
|
||||
<svg width="20" height="20" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z"/>
|
||||
</svg>
|
||||
Signataires
|
||||
<span class="doc-id">{{.DocID}}</span>
|
||||
</h3>
|
||||
<div id="parent-domain" class="parent-domain"></div>
|
||||
</div>
|
||||
|
||||
{{if gt .Count 0}}
|
||||
<div class="stats">
|
||||
<span class="count">{{.Count}} signature{{if gt .Count 1}}s{{end}}</span>
|
||||
{{if .LastSignedAt}}
|
||||
<span class="last-signed">Dernière signature le {{.LastSignedAt}}</span>
|
||||
{{end}}
|
||||
</div>
|
||||
|
||||
<div class="signatories">
|
||||
{{range .Signatures}}
|
||||
<div class="signatory">
|
||||
<div class="signature-icon">
|
||||
<svg fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 13l4 4L19 7"/>
|
||||
</svg>
|
||||
</div>
|
||||
<div class="signatory-info">
|
||||
<div class="signatory-email">{{if .Name}}{{.Name}} • {{end}}{{.Email}}</div>
|
||||
<div class="signatory-date">{{.SignedAt}}</div>
|
||||
</div>
|
||||
</div>
|
||||
{{end}}
|
||||
</div>
|
||||
{{else}}
|
||||
<div class="empty-state">
|
||||
<svg fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12h6m-6 4h6m2 5H7a2 2 0 01-2-2V5a2 2 0 012-2h5.586a1 1 0 01.707.293l5.414 5.414a1 1 0 01.293.707V19a2 2 0 01-2 2z"/>
|
||||
</svg>
|
||||
<p><strong>Aucune signature</strong></p>
|
||||
<p>Ce document n'a pas encore été signé.</p>
|
||||
</div>
|
||||
{{end}}
|
||||
|
||||
<div class="footer">
|
||||
<a href="{{$.SignURL}}" target="_blank" class="sign-button">
|
||||
Signer et confirmer la lecture de ce document
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// Variable globale pour stocker les infos du referrer détecté
|
||||
let detectedReferrer = null;
|
||||
|
||||
// Détecter le domaine parent de l'iframe
|
||||
function detectParentDomain() {
|
||||
const parentDomainEl = document.getElementById('parent-domain');
|
||||
|
||||
try {
|
||||
// Essayer d'accéder au domaine parent
|
||||
let parentHost = '';
|
||||
let parentOrigin = '';
|
||||
|
||||
// Vérifier si on est dans un iframe
|
||||
if (window.parent !== window.self) {
|
||||
try {
|
||||
// Tenter d'accéder à l'URL du parent (peut échouer à cause de CORS)
|
||||
parentHost = window.parent.location.hostname;
|
||||
parentOrigin = window.parent.location.origin;
|
||||
} catch (e) {
|
||||
// Si bloqué par CORS, essayer avec document.referrer
|
||||
if (document.referrer) {
|
||||
try {
|
||||
const referrerUrl = new URL(document.referrer);
|
||||
parentHost = referrerUrl.hostname;
|
||||
parentOrigin = referrerUrl.origin;
|
||||
} catch (err) {
|
||||
console.log('Impossible de parser le referrer:', err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Afficher les informations si disponibles
|
||||
if (parentHost) {
|
||||
// Détecter le service basé sur le domaine
|
||||
let serviceInfo = detectService(parentHost);
|
||||
|
||||
if (serviceInfo) {
|
||||
parentDomainEl.innerHTML = `${serviceInfo.icon} Intégré dans ${serviceInfo.name}`;
|
||||
// Stocker les infos du referrer pour l'URL de signature
|
||||
detectedReferrer = serviceInfo.referrer;
|
||||
} else {
|
||||
parentDomainEl.innerHTML = `🌐 Intégré dans ${parentHost}`;
|
||||
// Utiliser le domaine nettoyé comme referrer
|
||||
detectedReferrer = parentHost.replace(/[^a-z0-9]/g, '-');
|
||||
}
|
||||
|
||||
// Ajouter l'information comme attribut pour debugging
|
||||
parentDomainEl.setAttribute('data-parent-domain', parentHost);
|
||||
parentDomainEl.setAttribute('data-parent-origin', parentOrigin);
|
||||
parentDomainEl.setAttribute('data-referrer', detectedReferrer);
|
||||
} else {
|
||||
parentDomainEl.innerHTML = '📱 Intégré (origine non détectable)';
|
||||
}
|
||||
} else {
|
||||
// Pas dans un iframe
|
||||
parentDomainEl.innerHTML = '🌐 Vue directe';
|
||||
}
|
||||
} catch (e) {
|
||||
console.log('Erreur lors de la détection du domaine parent:', e);
|
||||
parentDomainEl.innerHTML = '🔒 Origine protégée';
|
||||
}
|
||||
}
|
||||
|
||||
// Fonction pour détecter le service basé sur le hostname
|
||||
function detectService(hostname) {
|
||||
const host = hostname.toLowerCase();
|
||||
|
||||
// Google services (including script.googleusercontent.com)
|
||||
if (host.includes('docs.google.com')) {
|
||||
return { name: 'Google Docs', icon: '📝', referrer: 'google-docs' };
|
||||
}
|
||||
if (host.includes('sheets.google.com')) {
|
||||
return { name: 'Google Sheets', icon: '📊', referrer: 'google-sheets' };
|
||||
}
|
||||
if (host.includes('slides.google.com')) {
|
||||
return { name: 'Google Slides', icon: '📊', referrer: 'google-slides' };
|
||||
}
|
||||
if (host.includes('drive.google.com')) {
|
||||
return { name: 'Google Drive', icon: '💾', referrer: 'google-drive' };
|
||||
}
|
||||
if (host.includes('script.googleusercontent.com') || host.includes('googleusercontent.com')) {
|
||||
return { name: 'Google', icon: '🔵', referrer: 'google' };
|
||||
}
|
||||
if (host.includes('google.com')) {
|
||||
return { name: 'Google', icon: '🔵', referrer: 'google' };
|
||||
}
|
||||
|
||||
// Notion
|
||||
if (host.includes('notion.so') || host.includes('notion.com')) {
|
||||
return { name: 'Notion', icon: '📒', referrer: 'notion' };
|
||||
}
|
||||
|
||||
// Confluence
|
||||
if (host.includes('confluence')) {
|
||||
return { name: 'Confluence', icon: '🌊', referrer: 'confluence' };
|
||||
}
|
||||
|
||||
// Microsoft Office
|
||||
if (host.includes('office.com') || host.includes('sharepoint.com')) {
|
||||
return { name: 'Microsoft Office', icon: '🏢', referrer: 'microsoft' };
|
||||
}
|
||||
if (host.includes('live.com') || host.includes('outlook.com')) {
|
||||
return { name: 'Microsoft', icon: '🏢', referrer: 'microsoft' };
|
||||
}
|
||||
|
||||
// GitHub
|
||||
if (host.includes('github.com')) {
|
||||
return { name: 'GitHub', icon: '🐙', referrer: 'github' };
|
||||
}
|
||||
|
||||
// GitLab
|
||||
if (host.includes('gitlab.com')) {
|
||||
return { name: 'GitLab', icon: '🦊', referrer: 'gitlab' };
|
||||
}
|
||||
if (host.includes('gitlab')) {
|
||||
return { name: 'GitLab', icon: '🦊', referrer: 'gitlab' };
|
||||
}
|
||||
|
||||
// Outline
|
||||
if (host.includes('outline')) {
|
||||
return { name: 'Outline', icon: '📖', referrer: 'outline' };
|
||||
}
|
||||
|
||||
// Slack
|
||||
if (host.includes('slack.com')) {
|
||||
return { name: 'Slack', icon: '💬', referrer: 'slack' };
|
||||
}
|
||||
|
||||
// Discord
|
||||
if (host.includes('discord.com')) {
|
||||
return { name: 'Discord', icon: '💬', referrer: 'discord' };
|
||||
}
|
||||
|
||||
// Trello
|
||||
if (host.includes('trello.com')) {
|
||||
return { name: 'Trello', icon: '📋', referrer: 'trello' };
|
||||
}
|
||||
|
||||
// Asana
|
||||
if (host.includes('asana.com')) {
|
||||
return { name: 'Asana', icon: '✅', referrer: 'asana' };
|
||||
}
|
||||
|
||||
// Monday.com
|
||||
if (host.includes('monday.com')) {
|
||||
return { name: 'Monday.com', icon: '📅', referrer: 'monday' };
|
||||
}
|
||||
|
||||
// Figma
|
||||
if (host.includes('figma.com')) {
|
||||
return { name: 'Figma', icon: '🎨', referrer: 'figma' };
|
||||
}
|
||||
|
||||
// Miro
|
||||
if (host.includes('miro.com')) {
|
||||
return { name: 'Miro', icon: '🎨', referrer: 'miro' };
|
||||
}
|
||||
|
||||
// Dropbox
|
||||
if (host.includes('dropbox.com')) {
|
||||
return { name: 'Dropbox', icon: '📦', referrer: 'dropbox' };
|
||||
}
|
||||
|
||||
// Unknown service - use domain as referrer
|
||||
return { name: host, icon: '🌐', referrer: host.replace(/[^a-z0-9]/g, '-') };
|
||||
}
|
||||
|
||||
// Fonction pour mettre à jour l'URL de signature avec le referrer
|
||||
function updateSignatureURL() {
|
||||
const signButton = document.querySelector('.sign-button');
|
||||
if (signButton && detectedReferrer) {
|
||||
const currentUrl = new URL(signButton.href);
|
||||
currentUrl.searchParams.set('referrer', detectedReferrer);
|
||||
signButton.href = currentUrl.toString();
|
||||
}
|
||||
}
|
||||
|
||||
// Détecter le domaine parent au chargement de la page
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
detectParentDomain();
|
||||
// Petite pause pour s'assurer que detectedReferrer est défini
|
||||
setTimeout(updateSignatureURL, 150);
|
||||
});
|
||||
|
||||
// Retry après un court délai au cas où les permissions changeraient
|
||||
setTimeout(function() {
|
||||
detectParentDomain();
|
||||
updateSignatureURL();
|
||||
}, 100);
|
||||
</script>
|
||||
</body>
|
||||
</html>{{end}}
|
||||
@@ -0,0 +1,95 @@
|
||||
{{define "index"}}
|
||||
<div class="space-y-8">
|
||||
<!-- Hero Section -->
|
||||
<div class="bg-white rounded-3xl shadow-xl border border-slate-200 overflow-hidden">
|
||||
<div class="bg-gradient-to-r from-primary-600 to-primary-700 px-8 py-6">
|
||||
<div class="flex items-center space-x-4">
|
||||
<div class="w-12 h-12 bg-white/20 rounded-2xl flex items-center justify-center">
|
||||
<svg class="w-7 h-7 text-white" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12h6m-6 4h6m2 5H7a2 2 0 01-2-2V5a2 2 0 012-2h5.586a1 1 0 01.707.293l5.414 5.414a1 1 0 01.293.707V19a2 2 0 01-2 2z"/>
|
||||
</svg>
|
||||
</div>
|
||||
<div>
|
||||
<h2 class="text-2xl font-bold text-white">Ackify</h2>
|
||||
<p class="text-primary-100">La solution professionnelle pour valider la lecture de vos documents</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="px-8 py-8">
|
||||
<form method="GET" action="/sign" class="space-y-6">
|
||||
<div>
|
||||
<div class="flex justify-between items-center mb-3">
|
||||
<label for="doc" class="text-sm font-semibold text-slate-700">
|
||||
Identifiant du document
|
||||
</label>
|
||||
{{if .User}}
|
||||
<a href="/signatures" class="text-sm font-medium text-primary-600 hover:text-primary-700 transition-colors flex items-center space-x-1">
|
||||
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z"/>
|
||||
</svg>
|
||||
<span>Mes signatures</span>
|
||||
</a>
|
||||
{{end}}
|
||||
</div>
|
||||
<div class="relative">
|
||||
<div class="absolute inset-y-0 left-0 pl-4 flex items-center pointer-events-none">
|
||||
<svg class="h-5 w-5 text-slate-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M7 7h.01M7 3h5c.512 0 1.024.195 1.414.586l7 7a2 2 0 010 2.828l-7 7a.997.997 0 01-.707.293H7a4 4 0 01-4-4V7a4 4 0 014-4z"/>
|
||||
</svg>
|
||||
</div>
|
||||
<input
|
||||
id="doc"
|
||||
name="doc"
|
||||
placeholder="doc_123abc..."
|
||||
class="block w-full pl-12 pr-4 py-4 border border-slate-300 rounded-2xl text-lg placeholder-slate-400 focus:ring-2 focus:ring-primary-500 focus:border-primary-500 transition-colors"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<p class="mt-2 text-sm text-slate-500">Apposez à vos documents une preuve de lecture certifiée</p>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="w-full bg-gradient-to-r from-primary-600 to-primary-700 hover:from-primary-700 hover:to-primary-800 text-white font-semibold py-4 px-6 rounded-2xl transition-all duration-200 flex items-center justify-center space-x-3 shadow-lg hover:shadow-xl">
|
||||
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 7l5 5m0 0l-5 5m5-5H6"/>
|
||||
</svg>
|
||||
<span>Continuer vers la signature</span>
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Info Cards -->
|
||||
<div class="grid md:grid-cols-3 gap-6">
|
||||
<div class="bg-white rounded-2xl p-6 border border-slate-200 hover:shadow-lg transition-shadow">
|
||||
<div class="w-10 h-10 bg-success-100 rounded-xl flex items-center justify-center mb-4">
|
||||
<svg class="w-5 h-5 text-success-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 15v2m-6 4h12a2 2 0 002-2v-6a2 2 0 00-2-2H6a2 2 0 00-2 2v6a2 2 0 002 2zm10-10V7a4 4 0 00-8 0v4h8z"/>
|
||||
</svg>
|
||||
</div>
|
||||
<h3 class="font-semibold text-slate-900 mb-2">Sécurisé</h3>
|
||||
<p class="text-sm text-slate-600">Cryptographie Ed25519 et authentification OAuth2 pour une sécurité maximale</p>
|
||||
</div>
|
||||
|
||||
<div class="bg-white rounded-2xl p-6 border border-slate-200 hover:shadow-lg transition-shadow">
|
||||
<div class="w-10 h-10 bg-primary-100 rounded-xl flex items-center justify-center mb-4">
|
||||
<svg class="w-5 h-5 text-primary-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 10V3L4 14h7v7l9-11h-7z"/>
|
||||
</svg>
|
||||
</div>
|
||||
<h3 class="font-semibold text-slate-900 mb-2">Efficace</h3>
|
||||
<p class="text-sm text-slate-600">Validez vos lectures en 30 secondes, traçabilité garantie</p>
|
||||
</div>
|
||||
|
||||
<div class="bg-white rounded-2xl p-6 border border-slate-200 hover:shadow-lg transition-shadow">
|
||||
<div class="w-10 h-10 bg-warning-100 rounded-xl flex items-center justify-center mb-4">
|
||||
<svg class="w-5 h-5 text-warning-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 5H7a2 2 0 00-2 2v10a2 2 0 002 2h8a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2"/>
|
||||
</svg>
|
||||
</div>
|
||||
<h3 class="font-semibold text-slate-900 mb-2">Conforme</h3>
|
||||
<p class="text-sm text-slate-600">Audit trail complet pour vos besoins de conformité réglementaire</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{{end}}
|
||||
@@ -0,0 +1,114 @@
|
||||
{{define "sign"}}
|
||||
<div class="space-y-6">
|
||||
<!-- Document Info Card -->
|
||||
<div class="bg-white rounded-3xl shadow-xl border border-slate-200 overflow-hidden">
|
||||
<div class="bg-gradient-to-r from-slate-100 to-slate-200 px-8 py-6 border-b border-slate-200">
|
||||
<div class="flex items-center space-x-4">
|
||||
<div class="w-10 h-10 bg-success-100 rounded-xl flex items-center justify-center flex-shrink-0">
|
||||
<svg class="w-5 h-5 text-success-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 13l4 4L19 7"/>
|
||||
</svg>
|
||||
</div>
|
||||
<div>
|
||||
<div class="flex items-center space-x-3">
|
||||
<p class="font-semibold text-slate-900 text-2xl">Document {{.DocID}}</p>
|
||||
{{if .ServiceInfo}}
|
||||
<div class="flex items-center space-x-1 bg-slate-100 px-2 py-1 rounded-md">
|
||||
<img src="{{.ServiceInfo.Icon}}" alt="{{.ServiceInfo.Name}}" class="w-3 h-3">
|
||||
<span class="text-xs text-slate-600">{{.ServiceInfo.Name}}</span>
|
||||
</div>
|
||||
{{end}}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="px-8 py-8">
|
||||
{{if .Already}}
|
||||
<!-- Document Already Signed -->
|
||||
<div class="text-center space-y-6">
|
||||
<div class="mx-auto w-20 h-20 bg-success-100 rounded-full flex items-center justify-center">
|
||||
<svg class="w-10 h-10 text-success-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z"/>
|
||||
</svg>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<h3 class="text-xl font-bold text-success-700 mb-2">Document Déjà Signé</h3>
|
||||
<p class="text-slate-600 mb-4">Vous avez confirmé la lecture de ce document</p>
|
||||
|
||||
<div class="bg-success-50 border border-success-200 rounded-2xl p-6">
|
||||
<div class="flex items-center justify-center space-x-3 text-success-800">
|
||||
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8v4l3 3m6-3a9 9 0 11-18 0 9 9 0 0118 0z"/>
|
||||
</svg>
|
||||
<span class="font-semibold">Signé le {{.SignedAt}}</span>
|
||||
</div>
|
||||
<p class="text-success-700 text-sm mt-2">Signature cryptographique enregistrée et vérifiable</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{{else}}
|
||||
<!-- Document Not Signed Yet -->
|
||||
<div class="text-center space-y-6">
|
||||
<div class="mx-auto w-20 h-20 bg-warning-100 rounded-full flex items-center justify-center">
|
||||
<svg class="w-10 h-10 text-warning-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-2.5L13.732 4c-.77-.833-1.732-.833-2.5 0L3.732 16.5c-.77.833.192 2.5 1.732 2.5z"/>
|
||||
</svg>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<h3 class="text-xl font-bold text-warning-700 mb-2">Document Non Signé</h3>
|
||||
<p class="text-slate-600 mb-6">Vous devez confirmer avoir lu et approuvé ce document</p>
|
||||
|
||||
<div class="bg-warning-50 border border-warning-200 rounded-2xl p-6 mb-6">
|
||||
<div class="flex items-start space-x-3">
|
||||
<svg class="w-5 h-5 text-warning-600 mt-0.5 flex-shrink-0" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"/>
|
||||
</svg>
|
||||
<div class="text-left">
|
||||
<p class="font-semibold text-warning-800 mb-1">Avant de signer</p>
|
||||
<p class="text-warning-700 text-sm">Assurez-vous d'avoir lu et compris l'intégralité du document. La signature est irréversible.</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<form method="POST" action="/sign">
|
||||
<input type="hidden" name="doc" value="{{.DocID}}" />
|
||||
{{if .ServiceInfo}}
|
||||
<input type="hidden" name="referrer" value="{{.ServiceInfo.Referrer}}" />
|
||||
{{end}}
|
||||
<button type="submit" class="w-full bg-gradient-to-r from-success-600 to-success-700 hover:from-success-700 hover:to-success-800 text-white font-bold py-4 px-8 rounded-2xl transition-all duration-200 flex items-center justify-center space-x-3 shadow-lg hover:shadow-xl text-lg">
|
||||
<svg class="w-6 h-6" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M15.232 5.232l3.536 3.536m-2.036-5.036a2.5 2.5 0 113.536 3.536L6.5 21.036H3v-3.572L16.732 3.732z"/>
|
||||
</svg>
|
||||
<span>J'ai lu et j'approuve ce document</span>
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
{{end}}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Actions Card -->
|
||||
<div class="bg-white rounded-2xl border border-slate-200 p-6">
|
||||
<h4 class="font-semibold text-slate-900 mb-4">Actions supplémentaires</h4>
|
||||
<div class="grid grid-cols-1 sm:grid-cols-2 gap-3">
|
||||
<a href="/embed?doc={{.DocID}}" target="_blank" class="flex items-center justify-center space-x-2 px-4 py-3 bg-slate-100 hover:bg-slate-200 text-slate-700 rounded-xl transition-colors text-sm font-medium text-center">
|
||||
<svg class="w-4 h-4 flex-shrink-0" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M8 9l3 3-3 3m5 0h3M5 20h14a2 2 0 002-2V6a2 2 0 00-2-2H5a2 2 0 00-2 2v14a2 2 0 002 2z"/>
|
||||
</svg>
|
||||
<span>Widget embarqué</span>
|
||||
</a>
|
||||
|
||||
<a href="/" class="flex items-center justify-center space-x-2 px-4 py-3 bg-primary-100 hover:bg-primary-200 text-primary-700 rounded-xl transition-colors text-sm font-medium text-center">
|
||||
<svg class="w-4 h-4 flex-shrink-0" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M3 9l9-7 9 7v11a2 2 0 01-2 2H5a2 2 0 01-2-2V9z"/>
|
||||
</svg>
|
||||
<span>Retour à l'accueil</span>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{{end}}
|
||||
@@ -0,0 +1,112 @@
|
||||
{{define "signatures"}}
|
||||
<div class="space-y-6">
|
||||
<!-- Header -->
|
||||
<div class="bg-white rounded-3xl shadow-xl border border-slate-200 overflow-hidden">
|
||||
<div class="bg-gradient-to-r from-primary-600 to-primary-700 px-8 py-6">
|
||||
<div class="flex items-center justify-between">
|
||||
<div class="flex items-center space-x-4">
|
||||
<div class="w-12 h-12 bg-white/20 rounded-2xl flex items-center justify-center">
|
||||
<svg class="w-7 h-7 text-white" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z"/>
|
||||
</svg>
|
||||
</div>
|
||||
<div>
|
||||
<h2 class="text-2xl font-bold text-white">Mes signatures</h2>
|
||||
<p class="text-primary-100">Liste de tous les documents que vous avez signés</p>
|
||||
</div>
|
||||
</div>
|
||||
<a href="/" class="text-primary-100 hover:text-white transition-colors">
|
||||
<svg class="w-6 h-6" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M10 19l-7-7m0 0l7-7m-7 7h18"/>
|
||||
</svg>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Signatures List -->
|
||||
<div class="bg-white rounded-3xl shadow-xl border border-slate-200 overflow-hidden">
|
||||
{{if .Signatures}}
|
||||
<!-- Stats -->
|
||||
<div class="bg-gradient-to-r from-slate-50 to-slate-100 px-8 py-4 border-b border-slate-200">
|
||||
<div class="flex items-center justify-between">
|
||||
<span class="text-slate-600 font-medium">
|
||||
{{len .Signatures}} signature{{if gt (len .Signatures) 1}}s{{end}} au total
|
||||
</span>
|
||||
<span class="text-sm text-slate-500">
|
||||
Trié par date décroissante
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Signatures -->
|
||||
<div class="divide-y divide-slate-100">
|
||||
{{range .Signatures}}
|
||||
<div class="px-8 py-6 hover:bg-slate-50 transition-colors">
|
||||
<div class="flex items-center space-x-4">
|
||||
<div class="w-10 h-10 bg-success-100 rounded-xl flex items-center justify-center flex-shrink-0">
|
||||
<svg class="w-5 h-5 text-success-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 13l4 4L19 7"/>
|
||||
</svg>
|
||||
</div>
|
||||
<div class="flex-1 min-w-0">
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<div class="flex items-center space-x-3">
|
||||
<p class="font-semibold text-slate-900">Document {{.DocID}}</p>
|
||||
{{if .GetServiceInfo}}
|
||||
<div class="flex items-center space-x-1 bg-slate-100 px-2 py-1 rounded-md">
|
||||
<img src="{{.GetServiceInfo.Icon}}" alt="{{.GetServiceInfo.Name}}" class="w-3 h-3">
|
||||
<span class="text-xs text-slate-600">{{.GetServiceInfo.Name}}</span>
|
||||
</div>
|
||||
{{end}}
|
||||
</div>
|
||||
<p class="text-sm text-slate-500 mt-1">
|
||||
Signé le {{.SignedAtUTC.Format "02/01/2006 à 15:04:05"}}
|
||||
</p>
|
||||
</div>
|
||||
<div class="flex space-x-2">
|
||||
<a href="/sign?doc={{.DocID}}"
|
||||
class="inline-flex items-center px-3 py-2 text-sm font-medium text-primary-700 bg-primary-50 rounded-lg hover:bg-primary-100 transition-colors">
|
||||
<svg class="w-4 h-4 mr-2" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z"/>
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M2.458 12C3.732 7.943 7.523 5 12 5c4.478 0 8.268 2.943 9.542 7-1.274 4.057-5.064 7-9.542 7-4.477 0-8.268-2.943-9.542-7z"/>
|
||||
</svg>
|
||||
Voir
|
||||
</a>
|
||||
<a href="/status?doc={{.DocID}}"
|
||||
class="inline-flex items-center px-3 py-2 text-sm font-medium text-slate-600 bg-slate-100 rounded-lg hover:bg-slate-200 transition-colors">
|
||||
<svg class="w-4 h-4 mr-2" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z"/>
|
||||
</svg>
|
||||
Statut
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{{end}}
|
||||
</div>
|
||||
{{else}}
|
||||
<!-- Empty State -->
|
||||
<div class="px-8 py-16 text-center">
|
||||
<div class="w-20 h-20 mx-auto bg-slate-100 rounded-2xl flex items-center justify-center mb-6">
|
||||
<svg class="w-10 h-10 text-slate-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12h6m-6 4h6m2 5H7a2 2 0 01-2-2V5a2 2 0 012-2h5.586a1 1 0 01.707.293l5.414 5.414a1 1 0 01.293.707V19a2 2 0 01-2 2z"/>
|
||||
</svg>
|
||||
</div>
|
||||
<h3 class="text-lg font-semibold text-slate-900 mb-2">Aucune signature</h3>
|
||||
<p class="text-slate-500 mb-6">Vous n'avez encore signé aucun document.</p>
|
||||
<a href="/"
|
||||
class="inline-flex items-center px-6 py-3 bg-gradient-to-r from-primary-600 to-primary-700 hover:from-primary-700 hover:to-primary-800 text-white font-semibold rounded-2xl transition-all duration-200 shadow-lg hover:shadow-xl">
|
||||
<svg class="w-5 h-5 mr-2" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 6v6m0 0v6m0-6h6m-6 0H6"/>
|
||||
</svg>
|
||||
Signer un document
|
||||
</a>
|
||||
</div>
|
||||
{{end}}
|
||||
</div>
|
||||
</div>
|
||||
{{end}}
|
||||
Reference in New Issue
Block a user