mirror of
https://github.com/eitchtee/WYGIWYH.git
synced 2025-12-16 18:26:10 -06:00
Merge pull request #425
feat(api): add endpoints for importing files and getting account balance
This commit is contained in:
33
app/apps/accounts/services.py
Normal file
33
app/apps/accounts/services.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from decimal import Decimal
|
||||
|
||||
from django.db import models
|
||||
|
||||
from apps.accounts.models import Account
|
||||
from apps.transactions.models import Transaction
|
||||
|
||||
|
||||
def get_account_balance(account: Account, paid_only: bool = True) -> Decimal:
|
||||
"""
|
||||
Calculate account balance (income - expense).
|
||||
|
||||
Args:
|
||||
account: Account instance to calculate balance for.
|
||||
paid_only: If True, only count paid transactions (current balance).
|
||||
If False, count all transactions (projected balance).
|
||||
|
||||
Returns:
|
||||
Decimal: The calculated balance (income - expense).
|
||||
"""
|
||||
filters = {"account": account}
|
||||
if paid_only:
|
||||
filters["is_paid"] = True
|
||||
|
||||
income = Transaction.objects.filter(
|
||||
type=Transaction.Type.INCOME, **filters
|
||||
).aggregate(total=models.Sum("amount"))["total"] or Decimal("0")
|
||||
|
||||
expense = Transaction.objects.filter(
|
||||
type=Transaction.Type.EXPENSE, **filters
|
||||
).aggregate(total=models.Sum("amount"))["total"] or Decimal("0")
|
||||
|
||||
return income - expense
|
||||
@@ -1,3 +1,5 @@
|
||||
from datetime import date
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
@@ -39,3 +41,135 @@ class AccountTests(TestCase):
|
||||
exchange_currency=self.exchange_currency,
|
||||
)
|
||||
self.assertEqual(account.exchange_currency, self.exchange_currency)
|
||||
|
||||
|
||||
class GetAccountBalanceServiceTests(TestCase):
|
||||
"""Tests for the get_account_balance service function"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
from apps.transactions.models import Transaction
|
||||
self.Transaction = Transaction
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="BRL", name="Brazilian Real", decimal_places=2, prefix="R$ "
|
||||
)
|
||||
self.account_group = AccountGroup.objects.create(name="Service Test Group")
|
||||
self.account = Account.objects.create(
|
||||
name="Service Test Account", group=self.account_group, currency=self.currency
|
||||
)
|
||||
|
||||
def test_balance_with_no_transactions(self):
|
||||
"""Test balance is 0 when no transactions exist"""
|
||||
from apps.accounts.services import get_account_balance
|
||||
from decimal import Decimal
|
||||
|
||||
balance = get_account_balance(self.account, paid_only=True)
|
||||
self.assertEqual(balance, Decimal("0"))
|
||||
|
||||
def test_current_balance_only_counts_paid(self):
|
||||
"""Test current balance only counts paid transactions"""
|
||||
from apps.accounts.services import get_account_balance
|
||||
from decimal import Decimal
|
||||
|
||||
# Paid income
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid income",
|
||||
)
|
||||
# Unpaid income (should not count)
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("50.00"),
|
||||
is_paid=False,
|
||||
date=date(2025, 1, 1),
|
||||
description="Unpaid income",
|
||||
)
|
||||
# Paid expense
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.EXPENSE,
|
||||
amount=Decimal("30.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid expense",
|
||||
)
|
||||
|
||||
balance = get_account_balance(self.account, paid_only=True)
|
||||
self.assertEqual(balance, Decimal("70.00")) # 100 - 30
|
||||
|
||||
def test_projected_balance_counts_all(self):
|
||||
"""Test projected balance counts all transactions"""
|
||||
from apps.accounts.services import get_account_balance
|
||||
from decimal import Decimal
|
||||
|
||||
# Paid income
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid income",
|
||||
)
|
||||
# Unpaid income
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("50.00"),
|
||||
is_paid=False,
|
||||
date=date(2025, 1, 1),
|
||||
description="Unpaid income",
|
||||
)
|
||||
# Paid expense
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.EXPENSE,
|
||||
amount=Decimal("30.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid expense",
|
||||
)
|
||||
# Unpaid expense
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.EXPENSE,
|
||||
amount=Decimal("20.00"),
|
||||
is_paid=False,
|
||||
date=date(2025, 1, 1),
|
||||
description="Unpaid expense",
|
||||
)
|
||||
|
||||
balance = get_account_balance(self.account, paid_only=False)
|
||||
self.assertEqual(balance, Decimal("100.00")) # (100 + 50) - (30 + 20)
|
||||
|
||||
def test_balance_defaults_to_paid_only(self):
|
||||
"""Test that paid_only defaults to True"""
|
||||
from apps.accounts.services import get_account_balance
|
||||
from decimal import Decimal
|
||||
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid",
|
||||
)
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("50.00"),
|
||||
is_paid=False,
|
||||
date=date(2025, 1, 1),
|
||||
description="Unpaid",
|
||||
)
|
||||
|
||||
balance = get_account_balance(self.account) # defaults to paid_only=True
|
||||
self.assertEqual(balance, Decimal("100.00"))
|
||||
|
||||
|
||||
@@ -11,23 +11,13 @@ from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.accounts.forms import AccountBalanceFormSet
|
||||
from apps.accounts.models import Account, Transaction
|
||||
from apps.accounts.services import get_account_balance
|
||||
from apps.common.decorators.htmx import only_htmx
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
def account_reconciliation(request):
|
||||
def get_account_balance(account):
|
||||
income = Transaction.objects.filter(
|
||||
account=account, type=Transaction.Type.INCOME, is_paid=True
|
||||
).aggregate(total=models.Sum("amount"))["total"] or Decimal("0")
|
||||
|
||||
expense = Transaction.objects.filter(
|
||||
account=account, type=Transaction.Type.EXPENSE, is_paid=True
|
||||
).aggregate(total=models.Sum("amount"))["total"] or Decimal("0")
|
||||
|
||||
return income - expense
|
||||
|
||||
initial_data = [
|
||||
{
|
||||
"account_id": account.id,
|
||||
|
||||
@@ -2,3 +2,5 @@ from .transactions import *
|
||||
from .accounts import *
|
||||
from .currencies import *
|
||||
from .dca import *
|
||||
from .imports import *
|
||||
|
||||
|
||||
@@ -67,3 +67,12 @@ class AccountSerializer(serializers.ModelSerializer):
|
||||
setattr(instance, attr, value)
|
||||
instance.save()
|
||||
return instance
|
||||
|
||||
|
||||
class AccountBalanceSerializer(serializers.Serializer):
|
||||
"""Serializer for account balance response."""
|
||||
|
||||
current_balance = serializers.DecimalField(max_digits=20, decimal_places=10)
|
||||
projected_balance = serializers.DecimalField(max_digits=20, decimal_places=10)
|
||||
currency = CurrencySerializer()
|
||||
|
||||
|
||||
41
app/apps/api/serializers/imports.py
Normal file
41
app/apps/api/serializers/imports.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.import_app.models import ImportProfile, ImportRun
|
||||
|
||||
|
||||
class ImportProfileSerializer(serializers.ModelSerializer):
|
||||
"""Serializer for listing import profiles."""
|
||||
|
||||
class Meta:
|
||||
model = ImportProfile
|
||||
fields = ["id", "name", "version", "yaml_config"]
|
||||
|
||||
|
||||
class ImportRunSerializer(serializers.ModelSerializer):
|
||||
"""Serializer for listing import runs."""
|
||||
|
||||
class Meta:
|
||||
model = ImportRun
|
||||
fields = [
|
||||
"id",
|
||||
"status",
|
||||
"profile",
|
||||
"file_name",
|
||||
"logs",
|
||||
"processed_rows",
|
||||
"total_rows",
|
||||
"successful_rows",
|
||||
"skipped_rows",
|
||||
"failed_rows",
|
||||
"started_at",
|
||||
"finished_at",
|
||||
]
|
||||
|
||||
|
||||
class ImportFileSerializer(serializers.Serializer):
|
||||
"""Serializer for uploading a file to import using an existing profile."""
|
||||
|
||||
profile_id = serializers.PrimaryKeyRelatedField(
|
||||
queryset=ImportProfile.objects.all(), source="profile"
|
||||
)
|
||||
file = serializers.FileField()
|
||||
4
app/apps/api/tests/__init__.py
Normal file
4
app/apps/api/tests/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Import all test classes for Django test discovery
|
||||
from .test_imports import *
|
||||
from .test_accounts import *
|
||||
|
||||
99
app/apps/api/tests/test_accounts.py
Normal file
99
app/apps/api/tests/test_accounts.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import TestCase, override_settings
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.transactions.models import Transaction
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class AccountBalanceAPITests(TestCase):
|
||||
"""Tests for the Account Balance API endpoint"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
User = get_user_model()
|
||||
self.user = User.objects.create_user(
|
||||
email="testuser@test.com", password="testpass123"
|
||||
)
|
||||
self.client = APIClient()
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
self.account_group = AccountGroup.objects.create(name="Test Group")
|
||||
self.account = Account.objects.create(
|
||||
name="Test Account", group=self.account_group, currency=self.currency
|
||||
)
|
||||
|
||||
# Create some transactions
|
||||
Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.INCOME,
|
||||
amount=Decimal("500.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid income",
|
||||
)
|
||||
Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.INCOME,
|
||||
amount=Decimal("200.00"),
|
||||
is_paid=False,
|
||||
date=date(2025, 1, 15),
|
||||
description="Unpaid income",
|
||||
)
|
||||
Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 10),
|
||||
description="Paid expense",
|
||||
)
|
||||
|
||||
def test_get_balance_success(self):
|
||||
"""Test successful balance retrieval"""
|
||||
response = self.client.get(f"/api/accounts/{self.account.id}/balance/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertIn("current_balance", response.data)
|
||||
self.assertIn("projected_balance", response.data)
|
||||
self.assertIn("currency", response.data)
|
||||
|
||||
# Current: 500 - 100 = 400
|
||||
self.assertEqual(Decimal(response.data["current_balance"]), Decimal("400.00"))
|
||||
# Projected: (500 + 200) - 100 = 600
|
||||
self.assertEqual(Decimal(response.data["projected_balance"]), Decimal("600.00"))
|
||||
|
||||
# Check currency data
|
||||
self.assertEqual(response.data["currency"]["code"], "USD")
|
||||
|
||||
def test_get_balance_nonexistent_account(self):
|
||||
"""Test balance for non-existent account returns 404"""
|
||||
response = self.client.get("/api/accounts/99999/balance/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||
|
||||
def test_get_balance_unauthenticated(self):
|
||||
"""Test unauthenticated request returns 403"""
|
||||
unauthenticated_client = APIClient()
|
||||
response = unauthenticated_client.get(
|
||||
f"/api/accounts/{self.account.id}/balance/"
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||
404
app/apps/api/tests/test_imports.py
Normal file
404
app/apps/api/tests/test_imports.py
Normal file
@@ -0,0 +1,404 @@
|
||||
from io import BytesIO
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||
from django.test import TestCase, override_settings
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from apps.import_app.models import ImportProfile, ImportRun
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class ImportAPITests(TestCase):
|
||||
"""Tests for the Import API endpoint"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
User = get_user_model()
|
||||
self.user = User.objects.create_user(
|
||||
email="testuser@test.com", password="testpass123"
|
||||
)
|
||||
self.client = APIClient()
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
# Create a basic import profile with minimal valid YAML config
|
||||
self.profile = ImportProfile.objects.create(
|
||||
name="Test Profile",
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
yaml_config="""
|
||||
file_type: csv
|
||||
date_format: "%Y-%m-%d"
|
||||
column_mapping:
|
||||
date:
|
||||
source: date
|
||||
description:
|
||||
source: description
|
||||
amount:
|
||||
source: amount
|
||||
transaction_type:
|
||||
detection_method: always_expense
|
||||
is_paid:
|
||||
detection_method: always_paid
|
||||
account:
|
||||
source: account
|
||||
match_field: name
|
||||
""",
|
||||
)
|
||||
|
||||
@patch("apps.import_app.tasks.process_import.defer")
|
||||
@patch("django.core.files.storage.FileSystemStorage.save")
|
||||
@patch("django.core.files.storage.FileSystemStorage.path")
|
||||
def test_create_import_success(self, mock_path, mock_save, mock_defer):
|
||||
"""Test successful file upload creates ImportRun and queues task"""
|
||||
mock_save.return_value = "test_file.csv"
|
||||
mock_path.return_value = "/usr/src/app/temp/test_file.csv"
|
||||
|
||||
csv_content = b"date,description,amount,account\n2025-01-01,Test,100,Main"
|
||||
file = SimpleUploadedFile(
|
||||
"test_file.csv", csv_content, content_type="text/csv"
|
||||
)
|
||||
|
||||
response = self.client.post(
|
||||
"/api/import/import/",
|
||||
{"profile_id": self.profile.id, "file": file},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)
|
||||
self.assertIn("import_run_id", response.data)
|
||||
self.assertEqual(response.data["status"], "queued")
|
||||
|
||||
# Verify ImportRun was created
|
||||
import_run = ImportRun.objects.get(id=response.data["import_run_id"])
|
||||
self.assertEqual(import_run.profile, self.profile)
|
||||
self.assertEqual(import_run.file_name, "test_file.csv")
|
||||
|
||||
# Verify task was deferred
|
||||
mock_defer.assert_called_once_with(
|
||||
import_run_id=import_run.id,
|
||||
file_path="/usr/src/app/temp/test_file.csv",
|
||||
user_id=self.user.id,
|
||||
)
|
||||
|
||||
def test_create_import_missing_profile(self):
|
||||
"""Test request without profile_id returns 400"""
|
||||
csv_content = b"date,description,amount\n2025-01-01,Test,100"
|
||||
file = SimpleUploadedFile(
|
||||
"test_file.csv", csv_content, content_type="text/csv"
|
||||
)
|
||||
|
||||
response = self.client.post(
|
||||
"/api/import/import/",
|
||||
{"file": file},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertIn("profile_id", response.data)
|
||||
|
||||
def test_create_import_missing_file(self):
|
||||
"""Test request without file returns 400"""
|
||||
response = self.client.post(
|
||||
"/api/import/import/",
|
||||
{"profile_id": self.profile.id},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertIn("file", response.data)
|
||||
|
||||
def test_create_import_invalid_profile(self):
|
||||
"""Test request with non-existent profile returns 400"""
|
||||
csv_content = b"date,description,amount\n2025-01-01,Test,100"
|
||||
file = SimpleUploadedFile(
|
||||
"test_file.csv", csv_content, content_type="text/csv"
|
||||
)
|
||||
|
||||
response = self.client.post(
|
||||
"/api/import/import/",
|
||||
{"profile_id": 99999, "file": file},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertIn("profile_id", response.data)
|
||||
|
||||
@patch("apps.import_app.tasks.process_import.defer")
|
||||
@patch("django.core.files.storage.FileSystemStorage.save")
|
||||
@patch("django.core.files.storage.FileSystemStorage.path")
|
||||
def test_create_import_xlsx(self, mock_path, mock_save, mock_defer):
|
||||
"""Test successful XLSX file upload"""
|
||||
mock_save.return_value = "test_file.xlsx"
|
||||
mock_path.return_value = "/usr/src/app/temp/test_file.xlsx"
|
||||
|
||||
# Create a simple XLSX-like content (just for the upload test)
|
||||
xlsx_content = BytesIO(b"PK\x03\x04") # XLSX files start with PK header
|
||||
file = SimpleUploadedFile(
|
||||
"test_file.xlsx",
|
||||
xlsx_content.getvalue(),
|
||||
content_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
)
|
||||
|
||||
response = self.client.post(
|
||||
"/api/import/import/",
|
||||
{"profile_id": self.profile.id, "file": file},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)
|
||||
self.assertIn("import_run_id", response.data)
|
||||
|
||||
def test_unauthenticated_request(self):
|
||||
"""Test unauthenticated request returns 403"""
|
||||
unauthenticated_client = APIClient()
|
||||
|
||||
csv_content = b"date,description,amount\n2025-01-01,Test,100"
|
||||
file = SimpleUploadedFile(
|
||||
"test_file.csv", csv_content, content_type="text/csv"
|
||||
)
|
||||
|
||||
response = unauthenticated_client.post(
|
||||
"/api/import/import/",
|
||||
{"profile_id": self.profile.id, "file": file},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class ImportProfileAPITests(TestCase):
|
||||
"""Tests for the Import Profile API endpoints"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
User = get_user_model()
|
||||
self.user = User.objects.create_user(
|
||||
email="testuser@test.com", password="testpass123"
|
||||
)
|
||||
self.client = APIClient()
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
self.profile1 = ImportProfile.objects.create(
|
||||
name="Profile 1",
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
yaml_config="""
|
||||
file_type: csv
|
||||
date_format: "%Y-%m-%d"
|
||||
column_mapping:
|
||||
date:
|
||||
source: date
|
||||
description:
|
||||
source: description
|
||||
amount:
|
||||
source: amount
|
||||
transaction_type:
|
||||
detection_method: always_expense
|
||||
is_paid:
|
||||
detection_method: always_paid
|
||||
account:
|
||||
source: account
|
||||
match_field: name
|
||||
""",
|
||||
)
|
||||
self.profile2 = ImportProfile.objects.create(
|
||||
name="Profile 2",
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
yaml_config="""
|
||||
file_type: csv
|
||||
date_format: "%Y-%m-%d"
|
||||
column_mapping:
|
||||
date:
|
||||
source: date
|
||||
description:
|
||||
source: description
|
||||
amount:
|
||||
source: amount
|
||||
transaction_type:
|
||||
detection_method: always_income
|
||||
is_paid:
|
||||
detection_method: always_unpaid
|
||||
account:
|
||||
source: account
|
||||
match_field: name
|
||||
""",
|
||||
)
|
||||
|
||||
def test_list_profiles(self):
|
||||
"""Test listing all profiles"""
|
||||
response = self.client.get("/api/import/profiles/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["count"], 2)
|
||||
self.assertEqual(len(response.data["results"]), 2)
|
||||
|
||||
def test_retrieve_profile(self):
|
||||
"""Test retrieving a specific profile"""
|
||||
response = self.client.get(f"/api/import/profiles/{self.profile1.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["id"], self.profile1.id)
|
||||
self.assertEqual(response.data["name"], "Profile 1")
|
||||
self.assertIn("yaml_config", response.data)
|
||||
|
||||
def test_retrieve_nonexistent_profile(self):
|
||||
"""Test retrieving a non-existent profile returns 404"""
|
||||
response = self.client.get("/api/import/profiles/99999/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||
|
||||
def test_profiles_unauthenticated(self):
|
||||
"""Test unauthenticated request returns 403"""
|
||||
unauthenticated_client = APIClient()
|
||||
response = unauthenticated_client.get("/api/import/profiles/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class ImportRunAPITests(TestCase):
|
||||
"""Tests for the Import Run API endpoints"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
User = get_user_model()
|
||||
self.user = User.objects.create_user(
|
||||
email="testuser@test.com", password="testpass123"
|
||||
)
|
||||
self.client = APIClient()
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
self.profile1 = ImportProfile.objects.create(
|
||||
name="Profile 1",
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
yaml_config="""
|
||||
file_type: csv
|
||||
date_format: "%Y-%m-%d"
|
||||
column_mapping:
|
||||
date:
|
||||
source: date
|
||||
description:
|
||||
source: description
|
||||
amount:
|
||||
source: amount
|
||||
transaction_type:
|
||||
detection_method: always_expense
|
||||
is_paid:
|
||||
detection_method: always_paid
|
||||
account:
|
||||
source: account
|
||||
match_field: name
|
||||
""",
|
||||
)
|
||||
self.profile2 = ImportProfile.objects.create(
|
||||
name="Profile 2",
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
yaml_config="""
|
||||
file_type: csv
|
||||
date_format: "%Y-%m-%d"
|
||||
column_mapping:
|
||||
date:
|
||||
source: date
|
||||
description:
|
||||
source: description
|
||||
amount:
|
||||
source: amount
|
||||
transaction_type:
|
||||
detection_method: always_income
|
||||
is_paid:
|
||||
detection_method: always_unpaid
|
||||
account:
|
||||
source: account
|
||||
match_field: name
|
||||
""",
|
||||
)
|
||||
|
||||
# Create import runs
|
||||
self.run1 = ImportRun.objects.create(
|
||||
profile=self.profile1,
|
||||
file_name="file1.csv",
|
||||
status=ImportRun.Status.FINISHED,
|
||||
)
|
||||
self.run2 = ImportRun.objects.create(
|
||||
profile=self.profile1,
|
||||
file_name="file2.csv",
|
||||
status=ImportRun.Status.QUEUED,
|
||||
)
|
||||
self.run3 = ImportRun.objects.create(
|
||||
profile=self.profile2,
|
||||
file_name="file3.csv",
|
||||
status=ImportRun.Status.FINISHED,
|
||||
)
|
||||
|
||||
def test_list_all_runs(self):
|
||||
"""Test listing all runs"""
|
||||
response = self.client.get("/api/import/runs/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["count"], 3)
|
||||
self.assertEqual(len(response.data["results"]), 3)
|
||||
|
||||
def test_list_runs_by_profile(self):
|
||||
"""Test filtering runs by profile_id"""
|
||||
response = self.client.get(f"/api/import/runs/?profile_id={self.profile1.id}")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["count"], 2)
|
||||
for run in response.data["results"]:
|
||||
self.assertEqual(run["profile"], self.profile1.id)
|
||||
|
||||
def test_list_runs_by_other_profile(self):
|
||||
"""Test filtering runs by another profile_id"""
|
||||
response = self.client.get(f"/api/import/runs/?profile_id={self.profile2.id}")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["count"], 1)
|
||||
self.assertEqual(response.data["results"][0]["profile"], self.profile2.id)
|
||||
|
||||
def test_retrieve_run(self):
|
||||
"""Test retrieving a specific run"""
|
||||
response = self.client.get(f"/api/import/runs/{self.run1.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["id"], self.run1.id)
|
||||
self.assertEqual(response.data["file_name"], "file1.csv")
|
||||
self.assertEqual(response.data["status"], "FINISHED")
|
||||
|
||||
def test_retrieve_nonexistent_run(self):
|
||||
"""Test retrieving a non-existent run returns 404"""
|
||||
response = self.client.get("/api/import/runs/99999/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||
|
||||
def test_runs_unauthenticated(self):
|
||||
"""Test unauthenticated request returns 403"""
|
||||
unauthenticated_client = APIClient()
|
||||
response = unauthenticated_client.get("/api/import/runs/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||
@@ -16,7 +16,11 @@ router.register(r"currencies", views.CurrencyViewSet)
|
||||
router.register(r"exchange-rates", views.ExchangeRateViewSet)
|
||||
router.register(r"dca/strategies", views.DCAStrategyViewSet)
|
||||
router.register(r"dca/entries", views.DCAEntryViewSet)
|
||||
router.register(r"import/profiles", views.ImportProfileViewSet, basename="import-profiles")
|
||||
router.register(r"import/runs", views.ImportRunViewSet, basename="import-runs")
|
||||
router.register(r"import/import", views.ImportViewSet, basename="import-import")
|
||||
|
||||
urlpatterns = [
|
||||
path("", include(router.urls)),
|
||||
]
|
||||
|
||||
|
||||
@@ -2,3 +2,5 @@ from .transactions import *
|
||||
from .accounts import *
|
||||
from .currencies import *
|
||||
from .dca import *
|
||||
from .imports import *
|
||||
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_view
|
||||
from rest_framework import viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
|
||||
from apps.api.custom.pagination import CustomPageNumberPagination
|
||||
from apps.accounts.models import AccountGroup, Account
|
||||
from apps.api.serializers import AccountGroupSerializer, AccountSerializer
|
||||
from apps.accounts.services import get_account_balance
|
||||
from apps.api.custom.pagination import CustomPageNumberPagination
|
||||
from apps.api.serializers import AccountGroupSerializer, AccountSerializer, AccountBalanceSerializer
|
||||
|
||||
|
||||
class AccountGroupViewSet(viewsets.ModelViewSet):
|
||||
"""ViewSet for managing account groups."""
|
||||
|
||||
queryset = AccountGroup.objects.all()
|
||||
serializer_class = AccountGroupSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
@@ -14,7 +21,16 @@ class AccountGroupViewSet(viewsets.ModelViewSet):
|
||||
return AccountGroup.objects.all().order_by("id")
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
balance=extend_schema(
|
||||
summary="Get account balance",
|
||||
description="Returns the current and projected balance for the account, along with currency data.",
|
||||
responses={200: AccountBalanceSerializer},
|
||||
),
|
||||
)
|
||||
class AccountViewSet(viewsets.ModelViewSet):
|
||||
"""ViewSet for managing accounts."""
|
||||
|
||||
queryset = Account.objects.all()
|
||||
serializer_class = AccountSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
@@ -25,3 +41,20 @@ class AccountViewSet(viewsets.ModelViewSet):
|
||||
.order_by("id")
|
||||
.select_related("group", "currency", "exchange_currency")
|
||||
)
|
||||
|
||||
@action(detail=True, methods=["get"], permission_classes=[IsAuthenticated])
|
||||
def balance(self, request, pk=None):
|
||||
"""Get current and projected balance for an account."""
|
||||
account = self.get_object()
|
||||
|
||||
current_balance = get_account_balance(account, paid_only=True)
|
||||
projected_balance = get_account_balance(account, paid_only=False)
|
||||
|
||||
serializer = AccountBalanceSerializer({
|
||||
"current_balance": current_balance,
|
||||
"projected_balance": projected_balance,
|
||||
"currency": account.currency,
|
||||
})
|
||||
|
||||
return Response(serializer.data)
|
||||
|
||||
|
||||
123
app/apps/api/views/imports.py
Normal file
123
app/apps/api/views/imports.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from django.core.files.storage import FileSystemStorage
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view, inline_serializer
|
||||
from rest_framework import serializers as drf_serializers
|
||||
from rest_framework import status, viewsets
|
||||
from rest_framework.parsers import MultiPartParser
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
|
||||
from apps.api.serializers import ImportFileSerializer, ImportProfileSerializer, ImportRunSerializer
|
||||
from apps.import_app.models import ImportProfile, ImportRun
|
||||
from apps.import_app.tasks import process_import
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
list=extend_schema(
|
||||
summary="List import profiles",
|
||||
description="Returns a paginated list of all available import profiles.",
|
||||
),
|
||||
retrieve=extend_schema(
|
||||
summary="Get import profile",
|
||||
description="Returns the details of a specific import profile by ID.",
|
||||
),
|
||||
)
|
||||
class ImportProfileViewSet(viewsets.ReadOnlyModelViewSet):
|
||||
"""ViewSet for listing and retrieving import profiles."""
|
||||
|
||||
queryset = ImportProfile.objects.all()
|
||||
serializer_class = ImportProfileSerializer
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
list=extend_schema(
|
||||
summary="List import runs",
|
||||
description="Returns a paginated list of import runs. Optionally filter by profile_id.",
|
||||
parameters=[
|
||||
OpenApiParameter(
|
||||
name="profile_id",
|
||||
type=int,
|
||||
location=OpenApiParameter.QUERY,
|
||||
description="Filter runs by profile ID",
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
),
|
||||
retrieve=extend_schema(
|
||||
summary="Get import run",
|
||||
description="Returns the details of a specific import run by ID, including status and logs.",
|
||||
),
|
||||
)
|
||||
class ImportRunViewSet(viewsets.ReadOnlyModelViewSet):
|
||||
"""ViewSet for listing and retrieving import runs."""
|
||||
|
||||
queryset = ImportRun.objects.all().order_by("-id")
|
||||
serializer_class = ImportRunSerializer
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = super().get_queryset()
|
||||
profile_id = self.request.query_params.get("profile_id")
|
||||
if profile_id:
|
||||
queryset = queryset.filter(profile_id=profile_id)
|
||||
return queryset
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
create=extend_schema(
|
||||
summary="Import file",
|
||||
description="Upload a CSV or XLSX file to import using an existing import profile. The import is queued and processed asynchronously.",
|
||||
request={
|
||||
"multipart/form-data": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"profile_id": {"type": "integer", "description": "ID of the ImportProfile to use"},
|
||||
"file": {"type": "string", "format": "binary", "description": "CSV or XLSX file to import"},
|
||||
},
|
||||
"required": ["profile_id", "file"],
|
||||
},
|
||||
},
|
||||
responses={
|
||||
202: inline_serializer(
|
||||
name="ImportResponse",
|
||||
fields={
|
||||
"import_run_id": drf_serializers.IntegerField(),
|
||||
"status": drf_serializers.CharField(),
|
||||
},
|
||||
),
|
||||
},
|
||||
),
|
||||
)
|
||||
class ImportViewSet(viewsets.ViewSet):
|
||||
"""ViewSet for importing data via file upload."""
|
||||
|
||||
permission_classes = [IsAuthenticated]
|
||||
parser_classes = [MultiPartParser]
|
||||
|
||||
def create(self, request):
|
||||
serializer = ImportFileSerializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
|
||||
profile = serializer.validated_data["profile"]
|
||||
uploaded_file = serializer.validated_data["file"]
|
||||
|
||||
# Save file to temp location
|
||||
fs = FileSystemStorage(location="/usr/src/app/temp")
|
||||
filename = fs.save(uploaded_file.name, uploaded_file)
|
||||
file_path = fs.path(filename)
|
||||
|
||||
# Create ImportRun record
|
||||
import_run = ImportRun.objects.create(profile=profile, file_name=filename)
|
||||
|
||||
# Queue import task
|
||||
process_import.defer(
|
||||
import_run_id=import_run.id,
|
||||
file_path=file_path,
|
||||
user_id=request.user.id,
|
||||
)
|
||||
|
||||
return Response(
|
||||
{"import_run_id": import_run.id, "status": "queued"},
|
||||
status=status.HTTP_202_ACCEPTED,
|
||||
)
|
||||
Reference in New Issue
Block a user