Simplify tests setup, one test file for multiple source variants (#1407)

This commit is contained in:
Sebastián Ramírez
2025-06-19 16:29:32 +02:00
committed by GitHub
parent af52d6573f
commit f1c9d15525
3 changed files with 43 additions and 193 deletions

View File

@@ -1,8 +1,10 @@
import shutil
import subprocess
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Union
from typing import Any, Callable, Dict, Generator, List, Union
from unittest.mock import patch
import pytest
from pydantic import BaseModel
@@ -26,7 +28,7 @@ def clear_sqlmodel() -> Any:
@pytest.fixture()
def cov_tmp_path(tmp_path: Path):
def cov_tmp_path(tmp_path: Path) -> Generator[Path, None, None]:
yield tmp_path
for coverage_path in tmp_path.glob(".coverage*"):
coverage_destiny_path = top_level_path / coverage_path.name
@@ -53,8 +55,8 @@ def coverage_run(*, module: str, cwd: Union[str, Path]) -> subprocess.CompletedP
def get_testing_print_function(
calls: List[List[Union[str, Dict[str, Any]]]],
) -> Callable[..., Any]:
def new_print(*args):
data = []
def new_print(*args: Any) -> None:
data: List[Any] = []
for arg in args:
if isinstance(arg, BaseModel):
data.append(arg.model_dump())
@@ -71,6 +73,19 @@ def get_testing_print_function(
return new_print
@dataclass
class PrintMock:
calls: List[Any] = field(default_factory=list)
@pytest.fixture(name="print_mock")
def print_mock_fixture() -> Generator[PrintMock, None, None]:
print_mock = PrintMock()
new_print = get_testing_print_function(print_mock.calls)
with patch("builtins.print", new=new_print):
yield print_mock
needs_pydanticv2 = pytest.mark.skipif(not IS_PYDANTIC_V2, reason="requires Pydantic v2")
needs_pydanticv1 = pytest.mark.skipif(IS_PYDANTIC_V2, reason="requires Pydantic v1")

View File

@@ -1,163 +0,0 @@
from typing import Any, Dict, List, Union
from unittest.mock import patch
from sqlmodel import create_engine
from tests.conftest import get_testing_print_function, needs_py310
def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]):
assert calls[0] == ["Before interacting with the database"]
assert calls[1] == [
"Hero 1:",
{
"id": None,
"name": "Deadpond",
"secret_name": "Dive Wilson",
"age": None,
},
]
assert calls[2] == [
"Hero 2:",
{
"id": None,
"name": "Spider-Boy",
"secret_name": "Pedro Parqueador",
"age": None,
},
]
assert calls[3] == [
"Hero 3:",
{
"id": None,
"name": "Rusty-Man",
"secret_name": "Tommy Sharp",
"age": 48,
},
]
assert calls[4] == ["After adding to the session"]
assert calls[5] == [
"Hero 1:",
{
"id": None,
"name": "Deadpond",
"secret_name": "Dive Wilson",
"age": None,
},
]
assert calls[6] == [
"Hero 2:",
{
"id": None,
"name": "Spider-Boy",
"secret_name": "Pedro Parqueador",
"age": None,
},
]
assert calls[7] == [
"Hero 3:",
{
"id": None,
"name": "Rusty-Man",
"secret_name": "Tommy Sharp",
"age": 48,
},
]
assert calls[8] == ["After committing the session"]
assert calls[9] == ["Hero 1:", {}]
assert calls[10] == ["Hero 2:", {}]
assert calls[11] == ["Hero 3:", {}]
assert calls[12] == ["After committing the session, show IDs"]
assert calls[13] == ["Hero 1 ID:", 1]
assert calls[14] == ["Hero 2 ID:", 2]
assert calls[15] == ["Hero 3 ID:", 3]
assert calls[16] == ["After committing the session, show names"]
assert calls[17] == ["Hero 1 name:", "Deadpond"]
assert calls[18] == ["Hero 2 name:", "Spider-Boy"]
assert calls[19] == ["Hero 3 name:", "Rusty-Man"]
assert calls[20] == ["After refreshing the heroes"]
assert calls[21] == [
"Hero 1:",
{
"id": 1,
"name": "Deadpond",
"secret_name": "Dive Wilson",
"age": None,
},
]
assert calls[22] == [
"Hero 2:",
{
"id": 2,
"name": "Spider-Boy",
"secret_name": "Pedro Parqueador",
"age": None,
},
]
assert calls[23] == [
"Hero 3:",
{
"id": 3,
"name": "Rusty-Man",
"secret_name": "Tommy Sharp",
"age": 48,
},
]
assert calls[24] == ["After the session closes"]
assert calls[21] == [
"Hero 1:",
{
"id": 1,
"name": "Deadpond",
"secret_name": "Dive Wilson",
"age": None,
},
]
assert calls[22] == [
"Hero 2:",
{
"id": 2,
"name": "Spider-Boy",
"secret_name": "Pedro Parqueador",
"age": None,
},
]
assert calls[23] == [
"Hero 3:",
{
"id": 3,
"name": "Rusty-Man",
"secret_name": "Tommy Sharp",
"age": 48,
},
]
@needs_py310
def test_tutorial_001(clear_sqlmodel):
from docs_src.tutorial.automatic_id_none_refresh import tutorial001_py310 as mod
mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url)
calls = []
new_print = get_testing_print_function(calls)
with patch("builtins.print", new=new_print):
mod.main()
check_calls(calls)
@needs_py310
def test_tutorial_002(clear_sqlmodel):
from docs_src.tutorial.automatic_id_none_refresh import tutorial002_py310 as mod
mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url)
calls = []
new_print = get_testing_print_function(calls)
with patch("builtins.print", new=new_print):
mod.main()
check_calls(calls)

View File

@@ -1,12 +1,14 @@
import importlib
from types import ModuleType
from typing import Any, Dict, List, Union
from unittest.mock import patch
import pytest
from sqlmodel import create_engine
from tests.conftest import get_testing_print_function
from tests.conftest import PrintMock, needs_py310
def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]):
def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]) -> None:
assert calls[0] == ["Before interacting with the database"]
assert calls[1] == [
"Hero 1:",
@@ -133,29 +135,25 @@ def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]):
]
def test_tutorial_001():
from docs_src.tutorial.automatic_id_none_refresh import tutorial001 as mod
@pytest.fixture(
name="module",
params=[
"tutorial001",
"tutorial002",
pytest.param("tutorial001_py310", marks=needs_py310),
pytest.param("tutorial002_py310", marks=needs_py310),
],
)
def get_module(request: pytest.FixtureRequest) -> ModuleType:
module = importlib.import_module(
f"docs_src.tutorial.automatic_id_none_refresh.{request.param}"
)
module.sqlite_url = "sqlite://"
module.engine = create_engine(module.sqlite_url)
mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url)
calls = []
new_print = get_testing_print_function(calls)
with patch("builtins.print", new=new_print):
mod.main()
check_calls(calls)
return module
def test_tutorial_002():
from docs_src.tutorial.automatic_id_none_refresh import tutorial002 as mod
mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url)
calls = []
new_print = get_testing_print_function(calls)
with patch("builtins.print", new=new_print):
mod.main()
check_calls(calls)
def test_tutorial_001_tutorial_002(print_mock: PrintMock, module: ModuleType) -> None:
module.main()
check_calls(print_mock.calls)