extension fixes

This commit is contained in:
Jakob Pinterits
2025-01-01 22:18:19 +01:00
parent c43b408329
commit b74bdaacb0
8 changed files with 226 additions and 53 deletions

View File

@@ -1036,33 +1036,53 @@ pixels_per_rem;
# Register the extension
self._ids_to_extensions[id(extension)] = extension
# Register the extension's event handlers
# Gather all of the the extension's event handlers. This will put them
# in a dictionary, grouped by their event tag.
handlers = rio.extension_event._collect_tagged_methods_recursive(
extension.__class__
extension
)
self._extension_on_app_start_handlers.extend(
handlers.get(rio.extension_event.ExtensionEventTag.ON_APP_START, [])
# The values in the dictionary above aren't just the callables - they
# allow for an optional argument that was passed to the decorator. Since
# most events don't use that, use a helper function to strip it again.
def extend_with_first_in_tuples(target: list, tuples) -> None:
for tup in tuples:
assert len(tup) == 2
assert tup[1] is None
target.append(tup[0])
extend_with_first_in_tuples(
self._extension_on_app_start_handlers,
handlers.get(
rio.extension_event.ExtensionEventTag.ON_APP_START, []
),
)
self._extension_on_app_close_handlers.extend(
handlers.get(rio.extension_event.ExtensionEventTag.ON_APP_CLOSE, [])
extend_with_first_in_tuples(
self._extension_on_app_close_handlers,
handlers.get(
rio.extension_event.ExtensionEventTag.ON_APP_CLOSE, []
),
)
self._extension_on_session_start_handlers.extend(
extend_with_first_in_tuples(
self._extension_on_session_start_handlers,
handlers.get(
rio.extension_event.ExtensionEventTag.ON_SESSION_START, []
)
),
)
self._extension_on_session_close_handlers.extend(
extend_with_first_in_tuples(
self._extension_on_session_close_handlers,
handlers.get(
rio.extension_event.ExtensionEventTag.ON_SESSION_CLOSE, []
)
),
)
self._extension_on_page_change_handlers.extend(
extend_with_first_in_tuples(
self._extension_on_page_change_handlers,
handlers.get(
rio.extension_event.ExtensionEventTag.ON_PAGE_CHANGE, []
)
),
)
def add_default_attachment(self, attachment: t.Any) -> None:

View File

@@ -1,7 +1,7 @@
import abc
import typing as t
from . import extension_event
import rio
__all__ = [
"Extension",
@@ -10,19 +10,19 @@ __all__ = [
class Extension(abc.ABC):
_rio_on_app_start_event_handlers_: t.ClassVar[
list[t.Callable[[extension_event.ExtensionAppStartEvent], None]]
list[t.Callable[[rio.extension_event.ExtensionAppStartEvent], None]]
]
_rio_on_app_close_event_handlers_: t.ClassVar[
list[t.Callable[[extension_event.ExtensionAppCloseEvent], None]]
list[t.Callable[[rio.extension_event.ExtensionAppCloseEvent], None]]
]
_rio_on_session_start_event_handlers_: t.ClassVar[
list[t.Callable[[extension_event.ExtensionSessionStartEvent], None]]
list[t.Callable[[rio.extension_event.ExtensionSessionStartEvent], None]]
]
_rio_on_session_close_event_handlers_: t.ClassVar[
list[t.Callable[[extension_event.ExtensionSessionCloseEvent], None]]
list[t.Callable[[rio.extension_event.ExtensionSessionCloseEvent], None]]
]
def __init__(self) -> None:

View File

@@ -35,27 +35,27 @@ MethodWithNoParametersVar = t.TypeVar(
MethodWithAppStartEventParameterVar = t.TypeVar(
"MethodWithAppStartEventParameterVar",
bound=t.Callable[["ExtensionAppStartEvent"], t.Any],
bound=t.Callable[[t.Any, "ExtensionAppStartEvent"], t.Any],
)
MethodWithAppCloseEventParameterVar = t.TypeVar(
"MethodWithAppCloseEventParameterVar",
bound=t.Callable[["ExtensionAppCloseEvent"], t.Any],
bound=t.Callable[[t.Any, "ExtensionAppCloseEvent"], t.Any],
)
MethodWithSessionStartEventParameterVar = t.TypeVar(
"MethodWithSessionStartEventParameterVar",
bound=t.Callable[["ExtensionSessionStartEvent"], t.Any],
bound=t.Callable[[t.Any, "ExtensionSessionStartEvent"], t.Any],
)
MethodWithSessionCloseEventParameterVar = t.TypeVar(
"MethodWithSessionCloseEventParameterVar",
bound=t.Callable[["ExtensionSessionCloseEvent"], t.Any],
bound=t.Callable[[t.Any, "ExtensionSessionCloseEvent"], t.Any],
)
MethodWithPageChangeEventParameterVar = t.TypeVar(
"MethodWithPageChangeEventParameterVar",
bound=t.Callable[["ExtensionPageChangeEvent"], t.Any],
bound=t.Callable[[t.Any, "ExtensionPageChangeEvent"], t.Any],
)
@@ -136,7 +136,7 @@ class ExtensionPageChangeEvent:
def _tag_as_event_handler(
function: t.Callable,
tag: ExtensionEventTag,
args: t.Any,
arg: t.Any,
) -> None:
"""
Registers the function as an event handler for the given tag. This simply
@@ -146,13 +146,14 @@ def _tag_as_event_handler(
all_events: dict[ExtensionEventTag, list[t.Any]] = vars(
function
).setdefault("_rio_extension_events_", {})
events_like_this = all_events.setdefault(tag, [])
events_like_this.append(args)
events_like_this.append((function, arg))
def _collect_tagged_methods_recursive(
cls: t.Type,
) -> dict[ExtensionEventTag, list[t.Callable]]:
ext: rio.Extension,
) -> dict[ExtensionEventTag, list[tuple[t.Callable, t.Any]]]:
"""
Walks a class and its parents, gathering all methods that have been tagged
as event handlers.
@@ -162,10 +163,10 @@ def _collect_tagged_methods_recursive(
handlers for a particular event, the result may have no entry for this tag
at all, or contain an empty list.
"""
result: dict[ExtensionEventTag, list[t.Callable]] = {}
result: dict[ExtensionEventTag, list[tuple[t.Callable, t.Any]]] = {}
# The MRO conveniently includes all classes that need to be searched
for base in cls.__mro__:
for base in type(ext).__mro__:
# Walk all methods in the class
for _, method in vars(base).items():
# Skip untagged members. This also conveniently filters out any
@@ -177,7 +178,17 @@ def _collect_tagged_methods_recursive(
# Which events is this method a handler for?
for tag, handlers in method._rio_extension_events_.items():
result.setdefault(tag, []).extend(handlers)
# Because the method was retrieved from the class instead of an
# instance, it's not bound to anything. Fix that.
result.setdefault(tag, []).extend(
[
(
handler.__get__(ext),
arg,
)
for handler, arg in handlers
]
)
return result

View File

@@ -1,8 +1,8 @@
import rio.testing
async def test_fundamental_container_as_root():
def build():
async def test_fundamental_container_as_root() -> None:
def build() -> rio.Component:
return rio.Row(rio.Text("Hello"))
async with rio.testing.TestClient(build) as test_client:

View File

@@ -7,24 +7,24 @@ from rio.state_properties import PleaseTurnThisIntoAnAttributeBinding
class Parent(rio.Component):
text: str = ""
def build(self):
def build(self) -> rio.Component:
return rio.Text(self.bind().text)
class Grandparent(rio.Component):
text: str = ""
def build(self):
def build(self) -> rio.Component:
return Parent(self.bind().text)
async def test_bindings_arent_created_too_early():
async def test_bindings_arent_created_too_early() -> None:
# There was a time when attribute bindings were created in `Component.__init__`,
# thus skipping any properties that were only assigned later.
class IHaveACustomInit(rio.Component):
text: str
def __init__(self, *args, text: str, **kwargs):
def __init__(self, *args, text: str, **kwargs) -> None:
super().__init__(*args, **kwargs)
# `Component.__init__`` has already run, but we haven't assigned
@@ -51,7 +51,7 @@ async def test_bindings_arent_created_too_early():
assert child_component.text == "bye"
async def test_init_receives_attribute_bindings_as_input():
async def test_init_receives_attribute_bindings_as_input() -> None:
# For a while we considered initializing attribute bindings before calling a
# component's `__init__` and passing the values of the bindings as arguments
# into `__init__`. But ultimately we decided against it, because some
@@ -61,7 +61,7 @@ async def test_init_receives_attribute_bindings_as_input():
size_value = None
class Square(rio.Component):
def __init__(self, size: float):
def __init__(self, size: float) -> None:
nonlocal size_value
size_value = size
@@ -82,7 +82,7 @@ async def test_init_receives_attribute_bindings_as_input():
assert isinstance(size_value, PleaseTurnThisIntoAnAttributeBinding)
async def test_binding_assignment_on_child():
async def test_binding_assignment_on_child() -> None:
async with rio.testing.TestClient(Parent) as test_client:
root_component = test_client.get_component(Parent)
text_component = test_client._get_build_output(root_component, rio.Text)
@@ -99,7 +99,7 @@ async def test_binding_assignment_on_child():
assert text_component.text == "Hello"
async def test_binding_assignment_on_parent():
async def test_binding_assignment_on_parent() -> None:
async with rio.testing.TestClient(Parent) as test_client:
root_component = test_client.get_component(Parent)
text_component = test_client._get_build_output(root_component)
@@ -116,7 +116,7 @@ async def test_binding_assignment_on_parent():
assert text_component.text == "Hello"
async def test_binding_assignment_on_sibling():
async def test_binding_assignment_on_sibling() -> None:
class Root(rio.Component):
text: str = ""
@@ -147,7 +147,7 @@ async def test_binding_assignment_on_sibling():
assert text2.text == "Hello"
async def test_binding_assignment_on_grandchild():
async def test_binding_assignment_on_grandchild() -> None:
async with rio.testing.TestClient(Grandparent) as test_client:
root_component = test_client.get_component(Grandparent)
parent = t.cast(Parent, test_client._get_build_output(root_component))
@@ -167,7 +167,7 @@ async def test_binding_assignment_on_grandchild():
assert text_component.text == "Hello"
async def test_binding_assignment_on_middle():
async def test_binding_assignment_on_middle() -> None:
async with rio.testing.TestClient(Grandparent) as test_client:
root_component = test_client.get_component(Grandparent)
parent: Parent = test_client._get_build_output(root_component)
@@ -187,7 +187,7 @@ async def test_binding_assignment_on_middle():
assert text_component.text == "Hello"
async def test_binding_assignment_on_child_after_reconciliation():
async def test_binding_assignment_on_child_after_reconciliation() -> None:
async with rio.testing.TestClient(Parent) as test_client:
root_component = test_client.get_component(Parent)
text_component: rio.Text = test_client._get_build_output(root_component)
@@ -207,7 +207,7 @@ async def test_binding_assignment_on_child_after_reconciliation():
assert text_component.text == "Hello"
async def test_binding_assignment_on_parent_after_reconciliation():
async def test_binding_assignment_on_parent_after_reconciliation() -> None:
async with rio.testing.TestClient(Parent) as test_client:
root_component = test_client.get_component(Parent)
text_component: rio.Text = test_client._get_build_output(root_component)
@@ -227,7 +227,7 @@ async def test_binding_assignment_on_parent_after_reconciliation():
assert text_component.text == "Hello"
async def test_binding_assignment_on_sibling_after_reconciliation():
async def test_binding_assignment_on_sibling_after_reconciliation() -> None:
class Root(rio.Component):
text: str = ""
@@ -258,7 +258,7 @@ async def test_binding_assignment_on_sibling_after_reconciliation():
assert text2.text == "Hello"
async def test_binding_assignment_on_grandchild_after_reconciliation():
async def test_binding_assignment_on_grandchild_after_reconciliation() -> None:
async with rio.testing.TestClient(Grandparent) as test_client:
root_component = test_client.get_component(Grandparent)
parent: Parent = test_client._get_build_output(root_component)
@@ -281,7 +281,7 @@ async def test_binding_assignment_on_grandchild_after_reconciliation():
assert text_component.text == "Hello"
async def test_binding_assignment_on_middle_after_reconciliation():
async def test_binding_assignment_on_middle_after_reconciliation() -> None:
async with rio.testing.TestClient(Grandparent) as test_client:
root_component = test_client.get_component(Grandparent)
parent: Parent = test_client._get_build_output(root_component)

138
tests/test_extensions.py Normal file
View File

@@ -0,0 +1,138 @@
import asyncio
import inspect
import typing as t
import rio.testing
class TracingExtensionParent(rio.Extension):
def __init__(self) -> None:
# Keeps track of called functions
self.function_call_log: list[str] = []
def verify_and_clear_log(self, *expected: str) -> None:
"""
Verifies that the recorded function calls match the expected ones. This
disregards order!
"""
assert set(self.function_call_log) == set(expected)
self.function_call_log.clear()
def _record_function_call(self) -> None:
# Get the caller's name
caller_name = inspect.stack()[1].function
# Record it
self.function_call_log.append(caller_name)
# This function should be inherited and called
@rio.extension_event.on_session_start
def on_session_start_parent(
self,
event: rio.ExtensionSessionStartEvent,
) -> None:
self._record_function_call()
class TracingExtensionChild(TracingExtensionParent):
@rio.extension_event.on_app_start
def on_app_start(
self,
event: rio.ExtensionAppStartEvent,
) -> None:
self._record_function_call()
@rio.extension_event.on_app_close
def on_app_close(
self,
event: rio.ExtensionAppCloseEvent,
) -> None:
self._record_function_call()
# This function isn't registered at all and should not be called
def on_session_start(
self,
event: rio.ExtensionSessionStartEvent,
) -> None:
self._record_function_call()
# This function is asynchronous, testing that the extension system awaits
# functions as needed.
@rio.extension_event.on_session_start
def on_session_start_async(
self,
event: rio.ExtensionSessionStartEvent,
) -> None:
self._record_function_call()
@rio.extension_event.on_session_close
def on_session_close(
self,
event: rio.ExtensionSessionCloseEvent,
) -> None:
self._record_function_call()
@rio.extension_event.on_page_change
def on_page_change(
self,
event: rio.ExtensionPageChangeEvent,
) -> None:
self._record_function_call()
# This function is registered for multiple events and should be called for
# each of them
@rio.extension_event.on_session_start
@rio.extension_event.on_session_close
@rio.extension_event.on_page_change
def on_multiple_events(
self,
event: t.Any,
) -> None:
self._record_function_call()
async def test_extension_events() -> None:
# Create an app
app = rio.App(
build=lambda: rio.Text("Hello"),
)
# Add an extension which records called events
extension_instance = TracingExtensionChild()
app._add_extension(extension_instance)
assert len(extension_instance.function_call_log) == 0
async with rio.testing.TestClient(app) as test_client:
# The test client doesn't support Rio's full feature set and so doesn't
# actually call the app start/close events right now. Skip them
# Expect: Session start events
extension_instance.verify_and_clear_log(
"on_session_start_async",
"on_session_start_parent",
"on_multiple_events",
)
# Navigate to another page
test_client.session.navigate_to("/foobar")
# Expect: Page change event
#
# Rio dispatches this event in an asyncio task. Give it some time to
# run.
await asyncio.sleep(0.1)
extension_instance.verify_and_clear_log(
"on_page_change",
"on_multiple_events",
)
# Expect: Session close events
extension_instance.verify_and_clear_log(
"on_session_close",
"on_multiple_events",
)
# The app close event is missing here for the same reason as above

View File

@@ -1,8 +1,8 @@
import rio.testing
async def test_refresh_with_nothing_to_do():
def build():
async def test_refresh_with_nothing_to_do() -> None:
def build() -> rio.Component:
return rio.Text("Hello")
async with rio.testing.TestClient(build) as test_client:
@@ -13,8 +13,8 @@ async def test_refresh_with_nothing_to_do():
assert not test_client._last_updated_components
async def test_refresh_with_clean_root_component():
def build():
async def test_refresh_with_clean_root_component() -> None:
def build() -> rio.Component:
text_component = rio.Text("Hello")
return rio.Container(text_component)
@@ -27,7 +27,7 @@ async def test_refresh_with_clean_root_component():
assert test_client._last_updated_components == {text_component}
async def test_rebuild_component_with_dead_parent():
async def test_rebuild_component_with_dead_parent() -> None:
class RootComponent(rio.Component):
content: rio.Component
@@ -64,7 +64,7 @@ async def test_rebuild_component_with_dead_parent():
assert test_client._last_updated_components == {root_component}
async def test_unmount_and_remount():
async def test_unmount_and_remount() -> None:
class DemoComponent(rio.Component):
content: rio.Component
show_child: bool

View File

@@ -124,6 +124,10 @@ def test_redirects(
relative_url_before_redirects: str,
relative_url_after_redirects_should: str,
) -> None:
"""
Simulate navigation to URLs, run any guards, and make sure the final,
resulting URL is correct.
"""
# Create a fake session. It contains everything used by the routing system.
# fake_session = t.cast(rio.Session, FakeSession())
#