mirror of
https://github.com/DRYTRIX/TimeTracker.git
synced 2026-05-12 07:19:49 -05:00
tests
This commit is contained in:
@@ -257,8 +257,14 @@ class User(UserMixin, db.Model):
|
||||
from .project import Project
|
||||
from .invoice import Invoice
|
||||
from .time_entry import TimeEntry
|
||||
from .client import Client
|
||||
|
||||
# Get client - try relationship first, then query by ID if needed
|
||||
client = self.client
|
||||
if not client and self.client_id:
|
||||
# Relationship might not be loaded, query directly
|
||||
client = Client.query.get(self.client_id)
|
||||
|
||||
if not client:
|
||||
return None
|
||||
|
||||
|
||||
+108
-23
@@ -6,7 +6,7 @@ invoices, and time entries. Uses separate authentication from regular users.
|
||||
from flask import Blueprint, render_template, request, redirect, url_for, flash, abort, session
|
||||
from flask_babel import gettext as _
|
||||
from app import db
|
||||
from app.models import Client, Project, Invoice, TimeEntry
|
||||
from app.models import Client, Project, Invoice, TimeEntry, User
|
||||
from app.utils.db import safe_commit
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import func
|
||||
@@ -16,11 +16,20 @@ client_portal_bp = Blueprint('client_portal', __name__)
|
||||
|
||||
|
||||
def get_current_client():
|
||||
"""Get the currently logged-in client from session"""
|
||||
"""Get the currently logged-in client from session (either Client or User portal access)"""
|
||||
# Check for Client portal authentication
|
||||
client_id = session.get('client_portal_id')
|
||||
if not client_id:
|
||||
return None
|
||||
return Client.query.get(client_id)
|
||||
if client_id:
|
||||
return Client.query.get(client_id)
|
||||
|
||||
# Check for User portal authentication
|
||||
user_id = session.get('_user_id')
|
||||
if user_id:
|
||||
user = User.query.get(user_id)
|
||||
if user and user.is_client_portal_user:
|
||||
return user.client # Return the Client object linked to the user
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Make get_current_client available to templates
|
||||
@@ -31,23 +40,99 @@ def inject_get_current_client():
|
||||
|
||||
|
||||
def check_client_portal_access():
|
||||
"""Helper function to check if client has portal access - redirects to login if not authenticated"""
|
||||
client = get_current_client()
|
||||
if not client:
|
||||
flash(_('Please log in to access the client portal.'), 'error')
|
||||
return redirect(url_for('client_portal.login', next=request.url))
|
||||
"""Helper function to check if client has portal access - returns 403 for users without access, redirects to login if not authenticated
|
||||
|
||||
if not client.has_portal_access:
|
||||
flash(_('Client portal access is not enabled for your account.'), 'error')
|
||||
session.pop('client_portal_id', None) # Clear invalid session
|
||||
return redirect(url_for('client_portal.login'))
|
||||
Returns:
|
||||
Client: The Client object if access is granted
|
||||
Response: A redirect response if authentication is needed
|
||||
None: If 403 is raised (abort is called)
|
||||
"""
|
||||
# Check for Client portal authentication
|
||||
client_id = session.get('client_portal_id')
|
||||
if client_id:
|
||||
client = Client.query.get(client_id)
|
||||
if not client:
|
||||
flash(_('Please log in to access the client portal.'), 'error')
|
||||
return redirect(url_for('client_portal.login', next=request.url))
|
||||
|
||||
if not client.has_portal_access:
|
||||
flash(_('Client portal access is not enabled for your account.'), 'error')
|
||||
session.pop('client_portal_id', None) # Clear invalid session
|
||||
return redirect(url_for('client_portal.login'))
|
||||
|
||||
if not client.is_active:
|
||||
flash(_('Your client account is inactive.'), 'error')
|
||||
session.pop('client_portal_id', None) # Clear invalid session
|
||||
return redirect(url_for('client_portal.login'))
|
||||
|
||||
return client
|
||||
|
||||
if not client.is_active:
|
||||
flash(_('Your client account is inactive.'), 'error')
|
||||
session.pop('client_portal_id', None) # Clear invalid session
|
||||
return redirect(url_for('client_portal.login'))
|
||||
# Check for User portal authentication
|
||||
user_id = session.get('_user_id')
|
||||
if user_id:
|
||||
try:
|
||||
# Convert to int if it's a string (session stores it as string)
|
||||
if isinstance(user_id, str):
|
||||
user_id = int(user_id)
|
||||
# Query with options to ensure we get fresh data and load relationships
|
||||
from sqlalchemy.orm import joinedload
|
||||
user = User.query.options(joinedload(User.client)).get(user_id)
|
||||
except (ValueError, TypeError):
|
||||
# Invalid user_id format
|
||||
flash(_('Please log in to access the client portal.'), 'error')
|
||||
return redirect(url_for('client_portal.login', next=request.url))
|
||||
except Exception:
|
||||
# If there's a session error, try to rollback and retry
|
||||
try:
|
||||
db.session.rollback()
|
||||
user = User.query.options(joinedload(User.client)).get(user_id)
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
flash(_('Please log in to access the client portal.'), 'error')
|
||||
return redirect(url_for('client_portal.login', next=request.url))
|
||||
|
||||
if not user:
|
||||
flash(_('Please log in to access the client portal.'), 'error')
|
||||
return redirect(url_for('client_portal.login', next=request.url))
|
||||
|
||||
# Check portal access directly to ensure we have the latest values
|
||||
if not (user.client_portal_enabled and user.client_id is not None):
|
||||
# User is logged in but doesn't have portal access - return 403
|
||||
abort(403)
|
||||
|
||||
if not user.is_active:
|
||||
abort(403)
|
||||
|
||||
if not user.client:
|
||||
abort(403)
|
||||
|
||||
return user.client
|
||||
|
||||
return client
|
||||
# No authentication at all - redirect to login
|
||||
flash(_('Please log in to access the client portal.'), 'error')
|
||||
return redirect(url_for('client_portal.login', next=request.url))
|
||||
|
||||
|
||||
def get_portal_data(client):
|
||||
"""Get portal data for a client, handling both Client and User authentication"""
|
||||
# Check if this is a User accessing via client portal
|
||||
user_id = session.get('_user_id')
|
||||
if user_id:
|
||||
try:
|
||||
# Convert to int if it's a string
|
||||
if isinstance(user_id, str):
|
||||
user_id = int(user_id)
|
||||
db.session.rollback()
|
||||
user = User.query.get(user_id)
|
||||
if user and user.is_client_portal_user and user.client_id == client.id:
|
||||
# Use User's get_client_portal_data method
|
||||
return user.get_client_portal_data()
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
# Fall through to Client method
|
||||
|
||||
# Otherwise use Client's get_portal_data method
|
||||
return client.get_portal_data()
|
||||
|
||||
|
||||
@client_portal_bp.route('/client-portal/login', methods=['GET', 'POST'])
|
||||
@@ -151,7 +236,7 @@ def dashboard():
|
||||
if not isinstance(result, Client): # It's a redirect response
|
||||
return result
|
||||
client = result
|
||||
portal_data = client.get_portal_data()
|
||||
portal_data = get_portal_data(client)
|
||||
|
||||
if not portal_data:
|
||||
flash(_('Unable to load client portal data.'), 'error')
|
||||
@@ -221,7 +306,7 @@ def projects():
|
||||
if not isinstance(result, Client):
|
||||
return result
|
||||
client = result
|
||||
portal_data = client.get_portal_data()
|
||||
portal_data = get_portal_data(client)
|
||||
|
||||
if not portal_data:
|
||||
flash(_('Unable to load client portal data.'), 'error')
|
||||
@@ -256,7 +341,7 @@ def invoices():
|
||||
if not isinstance(result, Client):
|
||||
return result
|
||||
client = result
|
||||
portal_data = client.get_portal_data()
|
||||
portal_data = get_portal_data(client)
|
||||
|
||||
if not portal_data:
|
||||
flash(_('Unable to load client portal data.'), 'error')
|
||||
@@ -312,7 +397,7 @@ def time_entries():
|
||||
if not isinstance(result, Client):
|
||||
return result
|
||||
client = result
|
||||
portal_data = client.get_portal_data()
|
||||
portal_data = get_portal_data(client)
|
||||
|
||||
if not portal_data:
|
||||
flash(_('Unable to load client portal data.'), 'error')
|
||||
|
||||
@@ -66,7 +66,7 @@ class TestClientPortalUserModel:
|
||||
assert 'projects' in data
|
||||
assert 'invoices' in data
|
||||
assert 'time_entries' in data
|
||||
assert data['client'] == test_client
|
||||
assert data['client'].id == test_client.id
|
||||
|
||||
def test_get_client_portal_data_with_projects(self, app, user, test_client):
|
||||
"""Test get_client_portal_data includes projects"""
|
||||
@@ -92,15 +92,23 @@ class TestClientPortalUserModel:
|
||||
with app.app_context():
|
||||
user.client_portal_enabled = True
|
||||
user.client_id = test_client.id
|
||||
db.session.commit()
|
||||
|
||||
# Handle potential session issues from audit logging
|
||||
try:
|
||||
db.session.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
project = Project(name="Test Project", client_id=test_client.id)
|
||||
db.session.add(project)
|
||||
db.session.commit()
|
||||
db.session.flush() # Flush to get project.id without committing
|
||||
project_id = project.id
|
||||
|
||||
# Create invoices
|
||||
invoice1 = Invoice(
|
||||
invoice_number="INV-001",
|
||||
project_id=project.id,
|
||||
project_id=project_id,
|
||||
client_name=test_client.name,
|
||||
client_id=test_client.id,
|
||||
due_date=datetime.utcnow().date() + timedelta(days=30),
|
||||
@@ -109,7 +117,7 @@ class TestClientPortalUserModel:
|
||||
)
|
||||
invoice2 = Invoice(
|
||||
invoice_number="INV-002",
|
||||
project_id=project.id,
|
||||
project_id=project_id,
|
||||
client_name=test_client.name,
|
||||
client_id=test_client.id,
|
||||
due_date=datetime.utcnow().date() + timedelta(days=30),
|
||||
|
||||
Reference in New Issue
Block a user