mirror of
https://github.com/eitchtee/WYGIWYH.git
synced 2025-12-21 13:00:12 -06:00
Merge pull request #425
feat(api): add endpoints for importing files and getting account balance
This commit is contained in:
33
app/apps/accounts/services.py
Normal file
33
app/apps/accounts/services.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
from apps.accounts.models import Account
|
||||||
|
from apps.transactions.models import Transaction
|
||||||
|
|
||||||
|
|
||||||
|
def get_account_balance(account: Account, paid_only: bool = True) -> Decimal:
|
||||||
|
"""
|
||||||
|
Calculate account balance (income - expense).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
account: Account instance to calculate balance for.
|
||||||
|
paid_only: If True, only count paid transactions (current balance).
|
||||||
|
If False, count all transactions (projected balance).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decimal: The calculated balance (income - expense).
|
||||||
|
"""
|
||||||
|
filters = {"account": account}
|
||||||
|
if paid_only:
|
||||||
|
filters["is_paid"] = True
|
||||||
|
|
||||||
|
income = Transaction.objects.filter(
|
||||||
|
type=Transaction.Type.INCOME, **filters
|
||||||
|
).aggregate(total=models.Sum("amount"))["total"] or Decimal("0")
|
||||||
|
|
||||||
|
expense = Transaction.objects.filter(
|
||||||
|
type=Transaction.Type.EXPENSE, **filters
|
||||||
|
).aggregate(total=models.Sum("amount"))["total"] or Decimal("0")
|
||||||
|
|
||||||
|
return income - expense
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from datetime import date
|
||||||
|
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from apps.accounts.models import Account, AccountGroup
|
from apps.accounts.models import Account, AccountGroup
|
||||||
@@ -39,3 +41,135 @@ class AccountTests(TestCase):
|
|||||||
exchange_currency=self.exchange_currency,
|
exchange_currency=self.exchange_currency,
|
||||||
)
|
)
|
||||||
self.assertEqual(account.exchange_currency, self.exchange_currency)
|
self.assertEqual(account.exchange_currency, self.exchange_currency)
|
||||||
|
|
||||||
|
|
||||||
|
class GetAccountBalanceServiceTests(TestCase):
|
||||||
|
"""Tests for the get_account_balance service function"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test data"""
|
||||||
|
from apps.transactions.models import Transaction
|
||||||
|
self.Transaction = Transaction
|
||||||
|
|
||||||
|
self.currency = Currency.objects.create(
|
||||||
|
code="BRL", name="Brazilian Real", decimal_places=2, prefix="R$ "
|
||||||
|
)
|
||||||
|
self.account_group = AccountGroup.objects.create(name="Service Test Group")
|
||||||
|
self.account = Account.objects.create(
|
||||||
|
name="Service Test Account", group=self.account_group, currency=self.currency
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_balance_with_no_transactions(self):
|
||||||
|
"""Test balance is 0 when no transactions exist"""
|
||||||
|
from apps.accounts.services import get_account_balance
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
balance = get_account_balance(self.account, paid_only=True)
|
||||||
|
self.assertEqual(balance, Decimal("0"))
|
||||||
|
|
||||||
|
def test_current_balance_only_counts_paid(self):
|
||||||
|
"""Test current balance only counts paid transactions"""
|
||||||
|
from apps.accounts.services import get_account_balance
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
# Paid income
|
||||||
|
self.Transaction.objects.create(
|
||||||
|
account=self.account,
|
||||||
|
type=self.Transaction.Type.INCOME,
|
||||||
|
amount=Decimal("100.00"),
|
||||||
|
is_paid=True,
|
||||||
|
date=date(2025, 1, 1),
|
||||||
|
description="Paid income",
|
||||||
|
)
|
||||||
|
# Unpaid income (should not count)
|
||||||
|
self.Transaction.objects.create(
|
||||||
|
account=self.account,
|
||||||
|
type=self.Transaction.Type.INCOME,
|
||||||
|
amount=Decimal("50.00"),
|
||||||
|
is_paid=False,
|
||||||
|
date=date(2025, 1, 1),
|
||||||
|
description="Unpaid income",
|
||||||
|
)
|
||||||
|
# Paid expense
|
||||||
|
self.Transaction.objects.create(
|
||||||
|
account=self.account,
|
||||||
|
type=self.Transaction.Type.EXPENSE,
|
||||||
|
amount=Decimal("30.00"),
|
||||||
|
is_paid=True,
|
||||||
|
date=date(2025, 1, 1),
|
||||||
|
description="Paid expense",
|
||||||
|
)
|
||||||
|
|
||||||
|
balance = get_account_balance(self.account, paid_only=True)
|
||||||
|
self.assertEqual(balance, Decimal("70.00")) # 100 - 30
|
||||||
|
|
||||||
|
def test_projected_balance_counts_all(self):
|
||||||
|
"""Test projected balance counts all transactions"""
|
||||||
|
from apps.accounts.services import get_account_balance
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
# Paid income
|
||||||
|
self.Transaction.objects.create(
|
||||||
|
account=self.account,
|
||||||
|
type=self.Transaction.Type.INCOME,
|
||||||
|
amount=Decimal("100.00"),
|
||||||
|
is_paid=True,
|
||||||
|
date=date(2025, 1, 1),
|
||||||
|
description="Paid income",
|
||||||
|
)
|
||||||
|
# Unpaid income
|
||||||
|
self.Transaction.objects.create(
|
||||||
|
account=self.account,
|
||||||
|
type=self.Transaction.Type.INCOME,
|
||||||
|
amount=Decimal("50.00"),
|
||||||
|
is_paid=False,
|
||||||
|
date=date(2025, 1, 1),
|
||||||
|
description="Unpaid income",
|
||||||
|
)
|
||||||
|
# Paid expense
|
||||||
|
self.Transaction.objects.create(
|
||||||
|
account=self.account,
|
||||||
|
type=self.Transaction.Type.EXPENSE,
|
||||||
|
amount=Decimal("30.00"),
|
||||||
|
is_paid=True,
|
||||||
|
date=date(2025, 1, 1),
|
||||||
|
description="Paid expense",
|
||||||
|
)
|
||||||
|
# Unpaid expense
|
||||||
|
self.Transaction.objects.create(
|
||||||
|
account=self.account,
|
||||||
|
type=self.Transaction.Type.EXPENSE,
|
||||||
|
amount=Decimal("20.00"),
|
||||||
|
is_paid=False,
|
||||||
|
date=date(2025, 1, 1),
|
||||||
|
description="Unpaid expense",
|
||||||
|
)
|
||||||
|
|
||||||
|
balance = get_account_balance(self.account, paid_only=False)
|
||||||
|
self.assertEqual(balance, Decimal("100.00")) # (100 + 50) - (30 + 20)
|
||||||
|
|
||||||
|
def test_balance_defaults_to_paid_only(self):
|
||||||
|
"""Test that paid_only defaults to True"""
|
||||||
|
from apps.accounts.services import get_account_balance
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
self.Transaction.objects.create(
|
||||||
|
account=self.account,
|
||||||
|
type=self.Transaction.Type.INCOME,
|
||||||
|
amount=Decimal("100.00"),
|
||||||
|
is_paid=True,
|
||||||
|
date=date(2025, 1, 1),
|
||||||
|
description="Paid",
|
||||||
|
)
|
||||||
|
self.Transaction.objects.create(
|
||||||
|
account=self.account,
|
||||||
|
type=self.Transaction.Type.INCOME,
|
||||||
|
amount=Decimal("50.00"),
|
||||||
|
is_paid=False,
|
||||||
|
date=date(2025, 1, 1),
|
||||||
|
description="Unpaid",
|
||||||
|
)
|
||||||
|
|
||||||
|
balance = get_account_balance(self.account) # defaults to paid_only=True
|
||||||
|
self.assertEqual(balance, Decimal("100.00"))
|
||||||
|
|
||||||
|
|||||||
@@ -11,23 +11,13 @@ from django.utils.translation import gettext_lazy as _
|
|||||||
|
|
||||||
from apps.accounts.forms import AccountBalanceFormSet
|
from apps.accounts.forms import AccountBalanceFormSet
|
||||||
from apps.accounts.models import Account, Transaction
|
from apps.accounts.models import Account, Transaction
|
||||||
|
from apps.accounts.services import get_account_balance
|
||||||
from apps.common.decorators.htmx import only_htmx
|
from apps.common.decorators.htmx import only_htmx
|
||||||
|
|
||||||
|
|
||||||
@only_htmx
|
@only_htmx
|
||||||
@login_required
|
@login_required
|
||||||
def account_reconciliation(request):
|
def account_reconciliation(request):
|
||||||
def get_account_balance(account):
|
|
||||||
income = Transaction.objects.filter(
|
|
||||||
account=account, type=Transaction.Type.INCOME, is_paid=True
|
|
||||||
).aggregate(total=models.Sum("amount"))["total"] or Decimal("0")
|
|
||||||
|
|
||||||
expense = Transaction.objects.filter(
|
|
||||||
account=account, type=Transaction.Type.EXPENSE, is_paid=True
|
|
||||||
).aggregate(total=models.Sum("amount"))["total"] or Decimal("0")
|
|
||||||
|
|
||||||
return income - expense
|
|
||||||
|
|
||||||
initial_data = [
|
initial_data = [
|
||||||
{
|
{
|
||||||
"account_id": account.id,
|
"account_id": account.id,
|
||||||
|
|||||||
@@ -2,3 +2,5 @@ from .transactions import *
|
|||||||
from .accounts import *
|
from .accounts import *
|
||||||
from .currencies import *
|
from .currencies import *
|
||||||
from .dca import *
|
from .dca import *
|
||||||
|
from .imports import *
|
||||||
|
|
||||||
|
|||||||
@@ -67,3 +67,12 @@ class AccountSerializer(serializers.ModelSerializer):
|
|||||||
setattr(instance, attr, value)
|
setattr(instance, attr, value)
|
||||||
instance.save()
|
instance.save()
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
|
||||||
|
class AccountBalanceSerializer(serializers.Serializer):
|
||||||
|
"""Serializer for account balance response."""
|
||||||
|
|
||||||
|
current_balance = serializers.DecimalField(max_digits=20, decimal_places=10)
|
||||||
|
projected_balance = serializers.DecimalField(max_digits=20, decimal_places=10)
|
||||||
|
currency = CurrencySerializer()
|
||||||
|
|
||||||
|
|||||||
41
app/apps/api/serializers/imports.py
Normal file
41
app/apps/api/serializers/imports.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
from apps.import_app.models import ImportProfile, ImportRun
|
||||||
|
|
||||||
|
|
||||||
|
class ImportProfileSerializer(serializers.ModelSerializer):
|
||||||
|
"""Serializer for listing import profiles."""
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = ImportProfile
|
||||||
|
fields = ["id", "name", "version", "yaml_config"]
|
||||||
|
|
||||||
|
|
||||||
|
class ImportRunSerializer(serializers.ModelSerializer):
|
||||||
|
"""Serializer for listing import runs."""
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = ImportRun
|
||||||
|
fields = [
|
||||||
|
"id",
|
||||||
|
"status",
|
||||||
|
"profile",
|
||||||
|
"file_name",
|
||||||
|
"logs",
|
||||||
|
"processed_rows",
|
||||||
|
"total_rows",
|
||||||
|
"successful_rows",
|
||||||
|
"skipped_rows",
|
||||||
|
"failed_rows",
|
||||||
|
"started_at",
|
||||||
|
"finished_at",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ImportFileSerializer(serializers.Serializer):
|
||||||
|
"""Serializer for uploading a file to import using an existing profile."""
|
||||||
|
|
||||||
|
profile_id = serializers.PrimaryKeyRelatedField(
|
||||||
|
queryset=ImportProfile.objects.all(), source="profile"
|
||||||
|
)
|
||||||
|
file = serializers.FileField()
|
||||||
4
app/apps/api/tests/__init__.py
Normal file
4
app/apps/api/tests/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
# Import all test classes for Django test discovery
|
||||||
|
from .test_imports import *
|
||||||
|
from .test_accounts import *
|
||||||
|
|
||||||
99
app/apps/api/tests/test_accounts.py
Normal file
99
app/apps/api/tests/test_accounts.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
from datetime import date
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
from django.contrib.auth import get_user_model
|
||||||
|
from django.test import TestCase, override_settings
|
||||||
|
from rest_framework import status
|
||||||
|
from rest_framework.test import APIClient
|
||||||
|
|
||||||
|
from apps.accounts.models import Account, AccountGroup
|
||||||
|
from apps.currencies.models import Currency
|
||||||
|
from apps.transactions.models import Transaction
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
STORAGES={
|
||||||
|
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||||
|
"staticfiles": {
|
||||||
|
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
WHITENOISE_AUTOREFRESH=True,
|
||||||
|
)
|
||||||
|
class AccountBalanceAPITests(TestCase):
|
||||||
|
"""Tests for the Account Balance API endpoint"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test data"""
|
||||||
|
User = get_user_model()
|
||||||
|
self.user = User.objects.create_user(
|
||||||
|
email="testuser@test.com", password="testpass123"
|
||||||
|
)
|
||||||
|
self.client = APIClient()
|
||||||
|
self.client.force_authenticate(user=self.user)
|
||||||
|
|
||||||
|
self.currency = Currency.objects.create(
|
||||||
|
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||||
|
)
|
||||||
|
self.account_group = AccountGroup.objects.create(name="Test Group")
|
||||||
|
self.account = Account.objects.create(
|
||||||
|
name="Test Account", group=self.account_group, currency=self.currency
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create some transactions
|
||||||
|
Transaction.objects.create(
|
||||||
|
account=self.account,
|
||||||
|
type=Transaction.Type.INCOME,
|
||||||
|
amount=Decimal("500.00"),
|
||||||
|
is_paid=True,
|
||||||
|
date=date(2025, 1, 1),
|
||||||
|
description="Paid income",
|
||||||
|
)
|
||||||
|
Transaction.objects.create(
|
||||||
|
account=self.account,
|
||||||
|
type=Transaction.Type.INCOME,
|
||||||
|
amount=Decimal("200.00"),
|
||||||
|
is_paid=False,
|
||||||
|
date=date(2025, 1, 15),
|
||||||
|
description="Unpaid income",
|
||||||
|
)
|
||||||
|
Transaction.objects.create(
|
||||||
|
account=self.account,
|
||||||
|
type=Transaction.Type.EXPENSE,
|
||||||
|
amount=Decimal("100.00"),
|
||||||
|
is_paid=True,
|
||||||
|
date=date(2025, 1, 10),
|
||||||
|
description="Paid expense",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_get_balance_success(self):
|
||||||
|
"""Test successful balance retrieval"""
|
||||||
|
response = self.client.get(f"/api/accounts/{self.account.id}/balance/")
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertIn("current_balance", response.data)
|
||||||
|
self.assertIn("projected_balance", response.data)
|
||||||
|
self.assertIn("currency", response.data)
|
||||||
|
|
||||||
|
# Current: 500 - 100 = 400
|
||||||
|
self.assertEqual(Decimal(response.data["current_balance"]), Decimal("400.00"))
|
||||||
|
# Projected: (500 + 200) - 100 = 600
|
||||||
|
self.assertEqual(Decimal(response.data["projected_balance"]), Decimal("600.00"))
|
||||||
|
|
||||||
|
# Check currency data
|
||||||
|
self.assertEqual(response.data["currency"]["code"], "USD")
|
||||||
|
|
||||||
|
def test_get_balance_nonexistent_account(self):
|
||||||
|
"""Test balance for non-existent account returns 404"""
|
||||||
|
response = self.client.get("/api/accounts/99999/balance/")
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||||
|
|
||||||
|
def test_get_balance_unauthenticated(self):
|
||||||
|
"""Test unauthenticated request returns 403"""
|
||||||
|
unauthenticated_client = APIClient()
|
||||||
|
response = unauthenticated_client.get(
|
||||||
|
f"/api/accounts/{self.account.id}/balance/"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||||
404
app/apps/api/tests/test_imports.py
Normal file
404
app/apps/api/tests/test_imports.py
Normal file
@@ -0,0 +1,404 @@
|
|||||||
|
from io import BytesIO
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from django.contrib.auth import get_user_model
|
||||||
|
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||||
|
from django.test import TestCase, override_settings
|
||||||
|
from rest_framework import status
|
||||||
|
from rest_framework.test import APIClient
|
||||||
|
|
||||||
|
from apps.import_app.models import ImportProfile, ImportRun
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
STORAGES={
|
||||||
|
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||||
|
"staticfiles": {
|
||||||
|
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
WHITENOISE_AUTOREFRESH=True,
|
||||||
|
)
|
||||||
|
class ImportAPITests(TestCase):
|
||||||
|
"""Tests for the Import API endpoint"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test data"""
|
||||||
|
User = get_user_model()
|
||||||
|
self.user = User.objects.create_user(
|
||||||
|
email="testuser@test.com", password="testpass123"
|
||||||
|
)
|
||||||
|
self.client = APIClient()
|
||||||
|
self.client.force_authenticate(user=self.user)
|
||||||
|
|
||||||
|
# Create a basic import profile with minimal valid YAML config
|
||||||
|
self.profile = ImportProfile.objects.create(
|
||||||
|
name="Test Profile",
|
||||||
|
version=ImportProfile.Versions.VERSION_1,
|
||||||
|
yaml_config="""
|
||||||
|
file_type: csv
|
||||||
|
date_format: "%Y-%m-%d"
|
||||||
|
column_mapping:
|
||||||
|
date:
|
||||||
|
source: date
|
||||||
|
description:
|
||||||
|
source: description
|
||||||
|
amount:
|
||||||
|
source: amount
|
||||||
|
transaction_type:
|
||||||
|
detection_method: always_expense
|
||||||
|
is_paid:
|
||||||
|
detection_method: always_paid
|
||||||
|
account:
|
||||||
|
source: account
|
||||||
|
match_field: name
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("apps.import_app.tasks.process_import.defer")
|
||||||
|
@patch("django.core.files.storage.FileSystemStorage.save")
|
||||||
|
@patch("django.core.files.storage.FileSystemStorage.path")
|
||||||
|
def test_create_import_success(self, mock_path, mock_save, mock_defer):
|
||||||
|
"""Test successful file upload creates ImportRun and queues task"""
|
||||||
|
mock_save.return_value = "test_file.csv"
|
||||||
|
mock_path.return_value = "/usr/src/app/temp/test_file.csv"
|
||||||
|
|
||||||
|
csv_content = b"date,description,amount,account\n2025-01-01,Test,100,Main"
|
||||||
|
file = SimpleUploadedFile(
|
||||||
|
"test_file.csv", csv_content, content_type="text/csv"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = self.client.post(
|
||||||
|
"/api/import/import/",
|
||||||
|
{"profile_id": self.profile.id, "file": file},
|
||||||
|
format="multipart",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)
|
||||||
|
self.assertIn("import_run_id", response.data)
|
||||||
|
self.assertEqual(response.data["status"], "queued")
|
||||||
|
|
||||||
|
# Verify ImportRun was created
|
||||||
|
import_run = ImportRun.objects.get(id=response.data["import_run_id"])
|
||||||
|
self.assertEqual(import_run.profile, self.profile)
|
||||||
|
self.assertEqual(import_run.file_name, "test_file.csv")
|
||||||
|
|
||||||
|
# Verify task was deferred
|
||||||
|
mock_defer.assert_called_once_with(
|
||||||
|
import_run_id=import_run.id,
|
||||||
|
file_path="/usr/src/app/temp/test_file.csv",
|
||||||
|
user_id=self.user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_create_import_missing_profile(self):
|
||||||
|
"""Test request without profile_id returns 400"""
|
||||||
|
csv_content = b"date,description,amount\n2025-01-01,Test,100"
|
||||||
|
file = SimpleUploadedFile(
|
||||||
|
"test_file.csv", csv_content, content_type="text/csv"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = self.client.post(
|
||||||
|
"/api/import/import/",
|
||||||
|
{"file": file},
|
||||||
|
format="multipart",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||||
|
self.assertIn("profile_id", response.data)
|
||||||
|
|
||||||
|
def test_create_import_missing_file(self):
|
||||||
|
"""Test request without file returns 400"""
|
||||||
|
response = self.client.post(
|
||||||
|
"/api/import/import/",
|
||||||
|
{"profile_id": self.profile.id},
|
||||||
|
format="multipart",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||||
|
self.assertIn("file", response.data)
|
||||||
|
|
||||||
|
def test_create_import_invalid_profile(self):
|
||||||
|
"""Test request with non-existent profile returns 400"""
|
||||||
|
csv_content = b"date,description,amount\n2025-01-01,Test,100"
|
||||||
|
file = SimpleUploadedFile(
|
||||||
|
"test_file.csv", csv_content, content_type="text/csv"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = self.client.post(
|
||||||
|
"/api/import/import/",
|
||||||
|
{"profile_id": 99999, "file": file},
|
||||||
|
format="multipart",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||||
|
self.assertIn("profile_id", response.data)
|
||||||
|
|
||||||
|
@patch("apps.import_app.tasks.process_import.defer")
|
||||||
|
@patch("django.core.files.storage.FileSystemStorage.save")
|
||||||
|
@patch("django.core.files.storage.FileSystemStorage.path")
|
||||||
|
def test_create_import_xlsx(self, mock_path, mock_save, mock_defer):
|
||||||
|
"""Test successful XLSX file upload"""
|
||||||
|
mock_save.return_value = "test_file.xlsx"
|
||||||
|
mock_path.return_value = "/usr/src/app/temp/test_file.xlsx"
|
||||||
|
|
||||||
|
# Create a simple XLSX-like content (just for the upload test)
|
||||||
|
xlsx_content = BytesIO(b"PK\x03\x04") # XLSX files start with PK header
|
||||||
|
file = SimpleUploadedFile(
|
||||||
|
"test_file.xlsx",
|
||||||
|
xlsx_content.getvalue(),
|
||||||
|
content_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = self.client.post(
|
||||||
|
"/api/import/import/",
|
||||||
|
{"profile_id": self.profile.id, "file": file},
|
||||||
|
format="multipart",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)
|
||||||
|
self.assertIn("import_run_id", response.data)
|
||||||
|
|
||||||
|
def test_unauthenticated_request(self):
|
||||||
|
"""Test unauthenticated request returns 403"""
|
||||||
|
unauthenticated_client = APIClient()
|
||||||
|
|
||||||
|
csv_content = b"date,description,amount\n2025-01-01,Test,100"
|
||||||
|
file = SimpleUploadedFile(
|
||||||
|
"test_file.csv", csv_content, content_type="text/csv"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = unauthenticated_client.post(
|
||||||
|
"/api/import/import/",
|
||||||
|
{"profile_id": self.profile.id, "file": file},
|
||||||
|
format="multipart",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
STORAGES={
|
||||||
|
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||||
|
"staticfiles": {
|
||||||
|
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
WHITENOISE_AUTOREFRESH=True,
|
||||||
|
)
|
||||||
|
class ImportProfileAPITests(TestCase):
|
||||||
|
"""Tests for the Import Profile API endpoints"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test data"""
|
||||||
|
User = get_user_model()
|
||||||
|
self.user = User.objects.create_user(
|
||||||
|
email="testuser@test.com", password="testpass123"
|
||||||
|
)
|
||||||
|
self.client = APIClient()
|
||||||
|
self.client.force_authenticate(user=self.user)
|
||||||
|
|
||||||
|
self.profile1 = ImportProfile.objects.create(
|
||||||
|
name="Profile 1",
|
||||||
|
version=ImportProfile.Versions.VERSION_1,
|
||||||
|
yaml_config="""
|
||||||
|
file_type: csv
|
||||||
|
date_format: "%Y-%m-%d"
|
||||||
|
column_mapping:
|
||||||
|
date:
|
||||||
|
source: date
|
||||||
|
description:
|
||||||
|
source: description
|
||||||
|
amount:
|
||||||
|
source: amount
|
||||||
|
transaction_type:
|
||||||
|
detection_method: always_expense
|
||||||
|
is_paid:
|
||||||
|
detection_method: always_paid
|
||||||
|
account:
|
||||||
|
source: account
|
||||||
|
match_field: name
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
self.profile2 = ImportProfile.objects.create(
|
||||||
|
name="Profile 2",
|
||||||
|
version=ImportProfile.Versions.VERSION_1,
|
||||||
|
yaml_config="""
|
||||||
|
file_type: csv
|
||||||
|
date_format: "%Y-%m-%d"
|
||||||
|
column_mapping:
|
||||||
|
date:
|
||||||
|
source: date
|
||||||
|
description:
|
||||||
|
source: description
|
||||||
|
amount:
|
||||||
|
source: amount
|
||||||
|
transaction_type:
|
||||||
|
detection_method: always_income
|
||||||
|
is_paid:
|
||||||
|
detection_method: always_unpaid
|
||||||
|
account:
|
||||||
|
source: account
|
||||||
|
match_field: name
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_list_profiles(self):
|
||||||
|
"""Test listing all profiles"""
|
||||||
|
response = self.client.get("/api/import/profiles/")
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data["count"], 2)
|
||||||
|
self.assertEqual(len(response.data["results"]), 2)
|
||||||
|
|
||||||
|
def test_retrieve_profile(self):
|
||||||
|
"""Test retrieving a specific profile"""
|
||||||
|
response = self.client.get(f"/api/import/profiles/{self.profile1.id}/")
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data["id"], self.profile1.id)
|
||||||
|
self.assertEqual(response.data["name"], "Profile 1")
|
||||||
|
self.assertIn("yaml_config", response.data)
|
||||||
|
|
||||||
|
def test_retrieve_nonexistent_profile(self):
|
||||||
|
"""Test retrieving a non-existent profile returns 404"""
|
||||||
|
response = self.client.get("/api/import/profiles/99999/")
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||||
|
|
||||||
|
def test_profiles_unauthenticated(self):
|
||||||
|
"""Test unauthenticated request returns 403"""
|
||||||
|
unauthenticated_client = APIClient()
|
||||||
|
response = unauthenticated_client.get("/api/import/profiles/")
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
STORAGES={
|
||||||
|
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||||
|
"staticfiles": {
|
||||||
|
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
WHITENOISE_AUTOREFRESH=True,
|
||||||
|
)
|
||||||
|
class ImportRunAPITests(TestCase):
|
||||||
|
"""Tests for the Import Run API endpoints"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test data"""
|
||||||
|
User = get_user_model()
|
||||||
|
self.user = User.objects.create_user(
|
||||||
|
email="testuser@test.com", password="testpass123"
|
||||||
|
)
|
||||||
|
self.client = APIClient()
|
||||||
|
self.client.force_authenticate(user=self.user)
|
||||||
|
|
||||||
|
self.profile1 = ImportProfile.objects.create(
|
||||||
|
name="Profile 1",
|
||||||
|
version=ImportProfile.Versions.VERSION_1,
|
||||||
|
yaml_config="""
|
||||||
|
file_type: csv
|
||||||
|
date_format: "%Y-%m-%d"
|
||||||
|
column_mapping:
|
||||||
|
date:
|
||||||
|
source: date
|
||||||
|
description:
|
||||||
|
source: description
|
||||||
|
amount:
|
||||||
|
source: amount
|
||||||
|
transaction_type:
|
||||||
|
detection_method: always_expense
|
||||||
|
is_paid:
|
||||||
|
detection_method: always_paid
|
||||||
|
account:
|
||||||
|
source: account
|
||||||
|
match_field: name
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
self.profile2 = ImportProfile.objects.create(
|
||||||
|
name="Profile 2",
|
||||||
|
version=ImportProfile.Versions.VERSION_1,
|
||||||
|
yaml_config="""
|
||||||
|
file_type: csv
|
||||||
|
date_format: "%Y-%m-%d"
|
||||||
|
column_mapping:
|
||||||
|
date:
|
||||||
|
source: date
|
||||||
|
description:
|
||||||
|
source: description
|
||||||
|
amount:
|
||||||
|
source: amount
|
||||||
|
transaction_type:
|
||||||
|
detection_method: always_income
|
||||||
|
is_paid:
|
||||||
|
detection_method: always_unpaid
|
||||||
|
account:
|
||||||
|
source: account
|
||||||
|
match_field: name
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create import runs
|
||||||
|
self.run1 = ImportRun.objects.create(
|
||||||
|
profile=self.profile1,
|
||||||
|
file_name="file1.csv",
|
||||||
|
status=ImportRun.Status.FINISHED,
|
||||||
|
)
|
||||||
|
self.run2 = ImportRun.objects.create(
|
||||||
|
profile=self.profile1,
|
||||||
|
file_name="file2.csv",
|
||||||
|
status=ImportRun.Status.QUEUED,
|
||||||
|
)
|
||||||
|
self.run3 = ImportRun.objects.create(
|
||||||
|
profile=self.profile2,
|
||||||
|
file_name="file3.csv",
|
||||||
|
status=ImportRun.Status.FINISHED,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_list_all_runs(self):
|
||||||
|
"""Test listing all runs"""
|
||||||
|
response = self.client.get("/api/import/runs/")
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data["count"], 3)
|
||||||
|
self.assertEqual(len(response.data["results"]), 3)
|
||||||
|
|
||||||
|
def test_list_runs_by_profile(self):
|
||||||
|
"""Test filtering runs by profile_id"""
|
||||||
|
response = self.client.get(f"/api/import/runs/?profile_id={self.profile1.id}")
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data["count"], 2)
|
||||||
|
for run in response.data["results"]:
|
||||||
|
self.assertEqual(run["profile"], self.profile1.id)
|
||||||
|
|
||||||
|
def test_list_runs_by_other_profile(self):
|
||||||
|
"""Test filtering runs by another profile_id"""
|
||||||
|
response = self.client.get(f"/api/import/runs/?profile_id={self.profile2.id}")
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data["count"], 1)
|
||||||
|
self.assertEqual(response.data["results"][0]["profile"], self.profile2.id)
|
||||||
|
|
||||||
|
def test_retrieve_run(self):
|
||||||
|
"""Test retrieving a specific run"""
|
||||||
|
response = self.client.get(f"/api/import/runs/{self.run1.id}/")
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data["id"], self.run1.id)
|
||||||
|
self.assertEqual(response.data["file_name"], "file1.csv")
|
||||||
|
self.assertEqual(response.data["status"], "FINISHED")
|
||||||
|
|
||||||
|
def test_retrieve_nonexistent_run(self):
|
||||||
|
"""Test retrieving a non-existent run returns 404"""
|
||||||
|
response = self.client.get("/api/import/runs/99999/")
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||||
|
|
||||||
|
def test_runs_unauthenticated(self):
|
||||||
|
"""Test unauthenticated request returns 403"""
|
||||||
|
unauthenticated_client = APIClient()
|
||||||
|
response = unauthenticated_client.get("/api/import/runs/")
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||||
@@ -16,7 +16,11 @@ router.register(r"currencies", views.CurrencyViewSet)
|
|||||||
router.register(r"exchange-rates", views.ExchangeRateViewSet)
|
router.register(r"exchange-rates", views.ExchangeRateViewSet)
|
||||||
router.register(r"dca/strategies", views.DCAStrategyViewSet)
|
router.register(r"dca/strategies", views.DCAStrategyViewSet)
|
||||||
router.register(r"dca/entries", views.DCAEntryViewSet)
|
router.register(r"dca/entries", views.DCAEntryViewSet)
|
||||||
|
router.register(r"import/profiles", views.ImportProfileViewSet, basename="import-profiles")
|
||||||
|
router.register(r"import/runs", views.ImportRunViewSet, basename="import-runs")
|
||||||
|
router.register(r"import/import", views.ImportViewSet, basename="import-import")
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
path("", include(router.urls)),
|
path("", include(router.urls)),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -2,3 +2,5 @@ from .transactions import *
|
|||||||
from .accounts import *
|
from .accounts import *
|
||||||
from .currencies import *
|
from .currencies import *
|
||||||
from .dca import *
|
from .dca import *
|
||||||
|
from .imports import *
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,18 @@
|
|||||||
|
from drf_spectacular.utils import extend_schema, extend_schema_view
|
||||||
from rest_framework import viewsets
|
from rest_framework import viewsets
|
||||||
|
from rest_framework.decorators import action
|
||||||
|
from rest_framework.permissions import IsAuthenticated
|
||||||
|
from rest_framework.response import Response
|
||||||
|
|
||||||
from apps.api.custom.pagination import CustomPageNumberPagination
|
|
||||||
from apps.accounts.models import AccountGroup, Account
|
from apps.accounts.models import AccountGroup, Account
|
||||||
from apps.api.serializers import AccountGroupSerializer, AccountSerializer
|
from apps.accounts.services import get_account_balance
|
||||||
|
from apps.api.custom.pagination import CustomPageNumberPagination
|
||||||
|
from apps.api.serializers import AccountGroupSerializer, AccountSerializer, AccountBalanceSerializer
|
||||||
|
|
||||||
|
|
||||||
class AccountGroupViewSet(viewsets.ModelViewSet):
|
class AccountGroupViewSet(viewsets.ModelViewSet):
|
||||||
|
"""ViewSet for managing account groups."""
|
||||||
|
|
||||||
queryset = AccountGroup.objects.all()
|
queryset = AccountGroup.objects.all()
|
||||||
serializer_class = AccountGroupSerializer
|
serializer_class = AccountGroupSerializer
|
||||||
pagination_class = CustomPageNumberPagination
|
pagination_class = CustomPageNumberPagination
|
||||||
@@ -14,7 +21,16 @@ class AccountGroupViewSet(viewsets.ModelViewSet):
|
|||||||
return AccountGroup.objects.all().order_by("id")
|
return AccountGroup.objects.all().order_by("id")
|
||||||
|
|
||||||
|
|
||||||
|
@extend_schema_view(
|
||||||
|
balance=extend_schema(
|
||||||
|
summary="Get account balance",
|
||||||
|
description="Returns the current and projected balance for the account, along with currency data.",
|
||||||
|
responses={200: AccountBalanceSerializer},
|
||||||
|
),
|
||||||
|
)
|
||||||
class AccountViewSet(viewsets.ModelViewSet):
|
class AccountViewSet(viewsets.ModelViewSet):
|
||||||
|
"""ViewSet for managing accounts."""
|
||||||
|
|
||||||
queryset = Account.objects.all()
|
queryset = Account.objects.all()
|
||||||
serializer_class = AccountSerializer
|
serializer_class = AccountSerializer
|
||||||
pagination_class = CustomPageNumberPagination
|
pagination_class = CustomPageNumberPagination
|
||||||
@@ -25,3 +41,20 @@ class AccountViewSet(viewsets.ModelViewSet):
|
|||||||
.order_by("id")
|
.order_by("id")
|
||||||
.select_related("group", "currency", "exchange_currency")
|
.select_related("group", "currency", "exchange_currency")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@action(detail=True, methods=["get"], permission_classes=[IsAuthenticated])
|
||||||
|
def balance(self, request, pk=None):
|
||||||
|
"""Get current and projected balance for an account."""
|
||||||
|
account = self.get_object()
|
||||||
|
|
||||||
|
current_balance = get_account_balance(account, paid_only=True)
|
||||||
|
projected_balance = get_account_balance(account, paid_only=False)
|
||||||
|
|
||||||
|
serializer = AccountBalanceSerializer({
|
||||||
|
"current_balance": current_balance,
|
||||||
|
"projected_balance": projected_balance,
|
||||||
|
"currency": account.currency,
|
||||||
|
})
|
||||||
|
|
||||||
|
return Response(serializer.data)
|
||||||
|
|
||||||
|
|||||||
123
app/apps/api/views/imports.py
Normal file
123
app/apps/api/views/imports.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
from django.core.files.storage import FileSystemStorage
|
||||||
|
from drf_spectacular.types import OpenApiTypes
|
||||||
|
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view, inline_serializer
|
||||||
|
from rest_framework import serializers as drf_serializers
|
||||||
|
from rest_framework import status, viewsets
|
||||||
|
from rest_framework.parsers import MultiPartParser
|
||||||
|
from rest_framework.permissions import IsAuthenticated
|
||||||
|
from rest_framework.response import Response
|
||||||
|
|
||||||
|
from apps.api.serializers import ImportFileSerializer, ImportProfileSerializer, ImportRunSerializer
|
||||||
|
from apps.import_app.models import ImportProfile, ImportRun
|
||||||
|
from apps.import_app.tasks import process_import
|
||||||
|
|
||||||
|
|
||||||
|
@extend_schema_view(
|
||||||
|
list=extend_schema(
|
||||||
|
summary="List import profiles",
|
||||||
|
description="Returns a paginated list of all available import profiles.",
|
||||||
|
),
|
||||||
|
retrieve=extend_schema(
|
||||||
|
summary="Get import profile",
|
||||||
|
description="Returns the details of a specific import profile by ID.",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
class ImportProfileViewSet(viewsets.ReadOnlyModelViewSet):
|
||||||
|
"""ViewSet for listing and retrieving import profiles."""
|
||||||
|
|
||||||
|
queryset = ImportProfile.objects.all()
|
||||||
|
serializer_class = ImportProfileSerializer
|
||||||
|
permission_classes = [IsAuthenticated]
|
||||||
|
|
||||||
|
|
||||||
|
@extend_schema_view(
|
||||||
|
list=extend_schema(
|
||||||
|
summary="List import runs",
|
||||||
|
description="Returns a paginated list of import runs. Optionally filter by profile_id.",
|
||||||
|
parameters=[
|
||||||
|
OpenApiParameter(
|
||||||
|
name="profile_id",
|
||||||
|
type=int,
|
||||||
|
location=OpenApiParameter.QUERY,
|
||||||
|
description="Filter runs by profile ID",
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
retrieve=extend_schema(
|
||||||
|
summary="Get import run",
|
||||||
|
description="Returns the details of a specific import run by ID, including status and logs.",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
class ImportRunViewSet(viewsets.ReadOnlyModelViewSet):
|
||||||
|
"""ViewSet for listing and retrieving import runs."""
|
||||||
|
|
||||||
|
queryset = ImportRun.objects.all().order_by("-id")
|
||||||
|
serializer_class = ImportRunSerializer
|
||||||
|
permission_classes = [IsAuthenticated]
|
||||||
|
|
||||||
|
def get_queryset(self):
|
||||||
|
queryset = super().get_queryset()
|
||||||
|
profile_id = self.request.query_params.get("profile_id")
|
||||||
|
if profile_id:
|
||||||
|
queryset = queryset.filter(profile_id=profile_id)
|
||||||
|
return queryset
|
||||||
|
|
||||||
|
|
||||||
|
@extend_schema_view(
|
||||||
|
create=extend_schema(
|
||||||
|
summary="Import file",
|
||||||
|
description="Upload a CSV or XLSX file to import using an existing import profile. The import is queued and processed asynchronously.",
|
||||||
|
request={
|
||||||
|
"multipart/form-data": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"profile_id": {"type": "integer", "description": "ID of the ImportProfile to use"},
|
||||||
|
"file": {"type": "string", "format": "binary", "description": "CSV or XLSX file to import"},
|
||||||
|
},
|
||||||
|
"required": ["profile_id", "file"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
responses={
|
||||||
|
202: inline_serializer(
|
||||||
|
name="ImportResponse",
|
||||||
|
fields={
|
||||||
|
"import_run_id": drf_serializers.IntegerField(),
|
||||||
|
"status": drf_serializers.CharField(),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
class ImportViewSet(viewsets.ViewSet):
|
||||||
|
"""ViewSet for importing data via file upload."""
|
||||||
|
|
||||||
|
permission_classes = [IsAuthenticated]
|
||||||
|
parser_classes = [MultiPartParser]
|
||||||
|
|
||||||
|
def create(self, request):
|
||||||
|
serializer = ImportFileSerializer(data=request.data)
|
||||||
|
serializer.is_valid(raise_exception=True)
|
||||||
|
|
||||||
|
profile = serializer.validated_data["profile"]
|
||||||
|
uploaded_file = serializer.validated_data["file"]
|
||||||
|
|
||||||
|
# Save file to temp location
|
||||||
|
fs = FileSystemStorage(location="/usr/src/app/temp")
|
||||||
|
filename = fs.save(uploaded_file.name, uploaded_file)
|
||||||
|
file_path = fs.path(filename)
|
||||||
|
|
||||||
|
# Create ImportRun record
|
||||||
|
import_run = ImportRun.objects.create(profile=profile, file_name=filename)
|
||||||
|
|
||||||
|
# Queue import task
|
||||||
|
process_import.defer(
|
||||||
|
import_run_id=import_run.id,
|
||||||
|
file_path=file_path,
|
||||||
|
user_id=request.user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
{"import_run_id": import_run.id, "status": "queued"},
|
||||||
|
status=status.HTTP_202_ACCEPTED,
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user