This commit is contained in:
Dries Peeters
2025-11-14 21:57:59 +01:00
parent 07d9d13240
commit 9f4ab335fe
3 changed files with 126 additions and 27 deletions
+6
View File
@@ -257,8 +257,14 @@ class User(UserMixin, db.Model):
from .project import Project
from .invoice import Invoice
from .time_entry import TimeEntry
from .client import Client
# Get client - try relationship first, then query by ID if needed
client = self.client
if not client and self.client_id:
# Relationship might not be loaded, query directly
client = Client.query.get(self.client_id)
if not client:
return None
+108 -23
View File
@@ -6,7 +6,7 @@ invoices, and time entries. Uses separate authentication from regular users.
from flask import Blueprint, render_template, request, redirect, url_for, flash, abort, session
from flask_babel import gettext as _
from app import db
from app.models import Client, Project, Invoice, TimeEntry
from app.models import Client, Project, Invoice, TimeEntry, User
from app.utils.db import safe_commit
from datetime import datetime, timedelta
from sqlalchemy import func
@@ -16,11 +16,20 @@ client_portal_bp = Blueprint('client_portal', __name__)
def get_current_client():
"""Get the currently logged-in client from session"""
"""Get the currently logged-in client from session (either Client or User portal access)"""
# Check for Client portal authentication
client_id = session.get('client_portal_id')
if not client_id:
return None
return Client.query.get(client_id)
if client_id:
return Client.query.get(client_id)
# Check for User portal authentication
user_id = session.get('_user_id')
if user_id:
user = User.query.get(user_id)
if user and user.is_client_portal_user:
return user.client # Return the Client object linked to the user
return None
# Make get_current_client available to templates
@@ -31,23 +40,99 @@ def inject_get_current_client():
def check_client_portal_access():
"""Helper function to check if client has portal access - redirects to login if not authenticated"""
client = get_current_client()
if not client:
flash(_('Please log in to access the client portal.'), 'error')
return redirect(url_for('client_portal.login', next=request.url))
"""Helper function to check if client has portal access - returns 403 for users without access, redirects to login if not authenticated
if not client.has_portal_access:
flash(_('Client portal access is not enabled for your account.'), 'error')
session.pop('client_portal_id', None) # Clear invalid session
return redirect(url_for('client_portal.login'))
Returns:
Client: The Client object if access is granted
Response: A redirect response if authentication is needed
None: If 403 is raised (abort is called)
"""
# Check for Client portal authentication
client_id = session.get('client_portal_id')
if client_id:
client = Client.query.get(client_id)
if not client:
flash(_('Please log in to access the client portal.'), 'error')
return redirect(url_for('client_portal.login', next=request.url))
if not client.has_portal_access:
flash(_('Client portal access is not enabled for your account.'), 'error')
session.pop('client_portal_id', None) # Clear invalid session
return redirect(url_for('client_portal.login'))
if not client.is_active:
flash(_('Your client account is inactive.'), 'error')
session.pop('client_portal_id', None) # Clear invalid session
return redirect(url_for('client_portal.login'))
return client
if not client.is_active:
flash(_('Your client account is inactive.'), 'error')
session.pop('client_portal_id', None) # Clear invalid session
return redirect(url_for('client_portal.login'))
# Check for User portal authentication
user_id = session.get('_user_id')
if user_id:
try:
# Convert to int if it's a string (session stores it as string)
if isinstance(user_id, str):
user_id = int(user_id)
# Query with options to ensure we get fresh data and load relationships
from sqlalchemy.orm import joinedload
user = User.query.options(joinedload(User.client)).get(user_id)
except (ValueError, TypeError):
# Invalid user_id format
flash(_('Please log in to access the client portal.'), 'error')
return redirect(url_for('client_portal.login', next=request.url))
except Exception:
# If there's a session error, try to rollback and retry
try:
db.session.rollback()
user = User.query.options(joinedload(User.client)).get(user_id)
except Exception:
db.session.rollback()
flash(_('Please log in to access the client portal.'), 'error')
return redirect(url_for('client_portal.login', next=request.url))
if not user:
flash(_('Please log in to access the client portal.'), 'error')
return redirect(url_for('client_portal.login', next=request.url))
# Check portal access directly to ensure we have the latest values
if not (user.client_portal_enabled and user.client_id is not None):
# User is logged in but doesn't have portal access - return 403
abort(403)
if not user.is_active:
abort(403)
if not user.client:
abort(403)
return user.client
return client
# No authentication at all - redirect to login
flash(_('Please log in to access the client portal.'), 'error')
return redirect(url_for('client_portal.login', next=request.url))
def get_portal_data(client):
"""Get portal data for a client, handling both Client and User authentication"""
# Check if this is a User accessing via client portal
user_id = session.get('_user_id')
if user_id:
try:
# Convert to int if it's a string
if isinstance(user_id, str):
user_id = int(user_id)
db.session.rollback()
user = User.query.get(user_id)
if user and user.is_client_portal_user and user.client_id == client.id:
# Use User's get_client_portal_data method
return user.get_client_portal_data()
except Exception:
db.session.rollback()
# Fall through to Client method
# Otherwise use Client's get_portal_data method
return client.get_portal_data()
@client_portal_bp.route('/client-portal/login', methods=['GET', 'POST'])
@@ -151,7 +236,7 @@ def dashboard():
if not isinstance(result, Client): # It's a redirect response
return result
client = result
portal_data = client.get_portal_data()
portal_data = get_portal_data(client)
if not portal_data:
flash(_('Unable to load client portal data.'), 'error')
@@ -221,7 +306,7 @@ def projects():
if not isinstance(result, Client):
return result
client = result
portal_data = client.get_portal_data()
portal_data = get_portal_data(client)
if not portal_data:
flash(_('Unable to load client portal data.'), 'error')
@@ -256,7 +341,7 @@ def invoices():
if not isinstance(result, Client):
return result
client = result
portal_data = client.get_portal_data()
portal_data = get_portal_data(client)
if not portal_data:
flash(_('Unable to load client portal data.'), 'error')
@@ -312,7 +397,7 @@ def time_entries():
if not isinstance(result, Client):
return result
client = result
portal_data = client.get_portal_data()
portal_data = get_portal_data(client)
if not portal_data:
flash(_('Unable to load client portal data.'), 'error')
+12 -4
View File
@@ -66,7 +66,7 @@ class TestClientPortalUserModel:
assert 'projects' in data
assert 'invoices' in data
assert 'time_entries' in data
assert data['client'] == test_client
assert data['client'].id == test_client.id
def test_get_client_portal_data_with_projects(self, app, user, test_client):
"""Test get_client_portal_data includes projects"""
@@ -92,15 +92,23 @@ class TestClientPortalUserModel:
with app.app_context():
user.client_portal_enabled = True
user.client_id = test_client.id
db.session.commit()
# Handle potential session issues from audit logging
try:
db.session.rollback()
except Exception:
pass
project = Project(name="Test Project", client_id=test_client.id)
db.session.add(project)
db.session.commit()
db.session.flush() # Flush to get project.id without committing
project_id = project.id
# Create invoices
invoice1 = Invoice(
invoice_number="INV-001",
project_id=project.id,
project_id=project_id,
client_name=test_client.name,
client_id=test_client.id,
due_date=datetime.utcnow().date() + timedelta(days=30),
@@ -109,7 +117,7 @@ class TestClientPortalUserModel:
)
invoice2 = Invoice(
invoice_number="INV-002",
project_id=project.id,
project_id=project_id,
client_name=test_client.name,
client_id=test_client.id,
due_date=datetime.utcnow().date() + timedelta(days=30),