From 9f4ab335fe37b17332df01708b2ac1ac2d17ff9c Mon Sep 17 00:00:00 2001 From: Dries Peeters Date: Fri, 14 Nov 2025 21:57:59 +0100 Subject: [PATCH] tests --- app/models/user.py | 6 ++ app/routes/client_portal.py | 131 +++++++++++++++++++++++++++++------- tests/test_client_portal.py | 16 +++-- 3 files changed, 126 insertions(+), 27 deletions(-) diff --git a/app/models/user.py b/app/models/user.py index d38fd7f..123a27e 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -257,8 +257,14 @@ class User(UserMixin, db.Model): from .project import Project from .invoice import Invoice from .time_entry import TimeEntry + from .client import Client + # Get client - try relationship first, then query by ID if needed client = self.client + if not client and self.client_id: + # Relationship might not be loaded, query directly + client = Client.query.get(self.client_id) + if not client: return None diff --git a/app/routes/client_portal.py b/app/routes/client_portal.py index 62bf9e0..3462165 100644 --- a/app/routes/client_portal.py +++ b/app/routes/client_portal.py @@ -6,7 +6,7 @@ invoices, and time entries. Uses separate authentication from regular users. from flask import Blueprint, render_template, request, redirect, url_for, flash, abort, session from flask_babel import gettext as _ from app import db -from app.models import Client, Project, Invoice, TimeEntry +from app.models import Client, Project, Invoice, TimeEntry, User from app.utils.db import safe_commit from datetime import datetime, timedelta from sqlalchemy import func @@ -16,11 +16,20 @@ client_portal_bp = Blueprint('client_portal', __name__) def get_current_client(): - """Get the currently logged-in client from session""" + """Get the currently logged-in client from session (either Client or User portal access)""" + # Check for Client portal authentication client_id = session.get('client_portal_id') - if not client_id: - return None - return Client.query.get(client_id) + if client_id: + return Client.query.get(client_id) + + # Check for User portal authentication + user_id = session.get('_user_id') + if user_id: + user = User.query.get(user_id) + if user and user.is_client_portal_user: + return user.client # Return the Client object linked to the user + + return None # Make get_current_client available to templates @@ -31,23 +40,99 @@ def inject_get_current_client(): def check_client_portal_access(): - """Helper function to check if client has portal access - redirects to login if not authenticated""" - client = get_current_client() - if not client: - flash(_('Please log in to access the client portal.'), 'error') - return redirect(url_for('client_portal.login', next=request.url)) + """Helper function to check if client has portal access - returns 403 for users without access, redirects to login if not authenticated - if not client.has_portal_access: - flash(_('Client portal access is not enabled for your account.'), 'error') - session.pop('client_portal_id', None) # Clear invalid session - return redirect(url_for('client_portal.login')) + Returns: + Client: The Client object if access is granted + Response: A redirect response if authentication is needed + None: If 403 is raised (abort is called) + """ + # Check for Client portal authentication + client_id = session.get('client_portal_id') + if client_id: + client = Client.query.get(client_id) + if not client: + flash(_('Please log in to access the client portal.'), 'error') + return redirect(url_for('client_portal.login', next=request.url)) + + if not client.has_portal_access: + flash(_('Client portal access is not enabled for your account.'), 'error') + session.pop('client_portal_id', None) # Clear invalid session + return redirect(url_for('client_portal.login')) + + if not client.is_active: + flash(_('Your client account is inactive.'), 'error') + session.pop('client_portal_id', None) # Clear invalid session + return redirect(url_for('client_portal.login')) + + return client - if not client.is_active: - flash(_('Your client account is inactive.'), 'error') - session.pop('client_portal_id', None) # Clear invalid session - return redirect(url_for('client_portal.login')) + # Check for User portal authentication + user_id = session.get('_user_id') + if user_id: + try: + # Convert to int if it's a string (session stores it as string) + if isinstance(user_id, str): + user_id = int(user_id) + # Query with options to ensure we get fresh data and load relationships + from sqlalchemy.orm import joinedload + user = User.query.options(joinedload(User.client)).get(user_id) + except (ValueError, TypeError): + # Invalid user_id format + flash(_('Please log in to access the client portal.'), 'error') + return redirect(url_for('client_portal.login', next=request.url)) + except Exception: + # If there's a session error, try to rollback and retry + try: + db.session.rollback() + user = User.query.options(joinedload(User.client)).get(user_id) + except Exception: + db.session.rollback() + flash(_('Please log in to access the client portal.'), 'error') + return redirect(url_for('client_portal.login', next=request.url)) + + if not user: + flash(_('Please log in to access the client portal.'), 'error') + return redirect(url_for('client_portal.login', next=request.url)) + + # Check portal access directly to ensure we have the latest values + if not (user.client_portal_enabled and user.client_id is not None): + # User is logged in but doesn't have portal access - return 403 + abort(403) + + if not user.is_active: + abort(403) + + if not user.client: + abort(403) + + return user.client - return client + # No authentication at all - redirect to login + flash(_('Please log in to access the client portal.'), 'error') + return redirect(url_for('client_portal.login', next=request.url)) + + +def get_portal_data(client): + """Get portal data for a client, handling both Client and User authentication""" + # Check if this is a User accessing via client portal + user_id = session.get('_user_id') + if user_id: + try: + # Convert to int if it's a string + if isinstance(user_id, str): + user_id = int(user_id) + db.session.rollback() + user = User.query.get(user_id) + if user and user.is_client_portal_user and user.client_id == client.id: + # Use User's get_client_portal_data method + return user.get_client_portal_data() + except Exception: + db.session.rollback() + # Fall through to Client method + + # Otherwise use Client's get_portal_data method + return client.get_portal_data() @client_portal_bp.route('/client-portal/login', methods=['GET', 'POST']) @@ -151,7 +236,7 @@ def dashboard(): if not isinstance(result, Client): # It's a redirect response return result client = result - portal_data = client.get_portal_data() + portal_data = get_portal_data(client) if not portal_data: flash(_('Unable to load client portal data.'), 'error') @@ -221,7 +306,7 @@ def projects(): if not isinstance(result, Client): return result client = result - portal_data = client.get_portal_data() + portal_data = get_portal_data(client) if not portal_data: flash(_('Unable to load client portal data.'), 'error') @@ -256,7 +341,7 @@ def invoices(): if not isinstance(result, Client): return result client = result - portal_data = client.get_portal_data() + portal_data = get_portal_data(client) if not portal_data: flash(_('Unable to load client portal data.'), 'error') @@ -312,7 +397,7 @@ def time_entries(): if not isinstance(result, Client): return result client = result - portal_data = client.get_portal_data() + portal_data = get_portal_data(client) if not portal_data: flash(_('Unable to load client portal data.'), 'error') diff --git a/tests/test_client_portal.py b/tests/test_client_portal.py index 378004f..e314c0f 100644 --- a/tests/test_client_portal.py +++ b/tests/test_client_portal.py @@ -66,7 +66,7 @@ class TestClientPortalUserModel: assert 'projects' in data assert 'invoices' in data assert 'time_entries' in data - assert data['client'] == test_client + assert data['client'].id == test_client.id def test_get_client_portal_data_with_projects(self, app, user, test_client): """Test get_client_portal_data includes projects""" @@ -92,15 +92,23 @@ class TestClientPortalUserModel: with app.app_context(): user.client_portal_enabled = True user.client_id = test_client.id + db.session.commit() + + # Handle potential session issues from audit logging + try: + db.session.rollback() + except Exception: + pass project = Project(name="Test Project", client_id=test_client.id) db.session.add(project) - db.session.commit() + db.session.flush() # Flush to get project.id without committing + project_id = project.id # Create invoices invoice1 = Invoice( invoice_number="INV-001", - project_id=project.id, + project_id=project_id, client_name=test_client.name, client_id=test_client.id, due_date=datetime.utcnow().date() + timedelta(days=30), @@ -109,7 +117,7 @@ class TestClientPortalUserModel: ) invoice2 = Invoice( invoice_number="INV-002", - project_id=project.id, + project_id=project_id, client_name=test_client.name, client_id=test_client.id, due_date=datetime.utcnow().date() + timedelta(days=30),