diff --git a/rio/app.py b/rio/app.py index 52e6b5f8..f332f48a 100644 --- a/rio/app.py +++ b/rio/app.py @@ -1036,33 +1036,53 @@ pixels_per_rem; # Register the extension self._ids_to_extensions[id(extension)] = extension - # Register the extension's event handlers + # Gather all of the the extension's event handlers. This will put them + # in a dictionary, grouped by their event tag. handlers = rio.extension_event._collect_tagged_methods_recursive( - extension.__class__ + extension ) - self._extension_on_app_start_handlers.extend( - handlers.get(rio.extension_event.ExtensionEventTag.ON_APP_START, []) + # The values in the dictionary above aren't just the callables - they + # allow for an optional argument that was passed to the decorator. Since + # most events don't use that, use a helper function to strip it again. + def extend_with_first_in_tuples(target: list, tuples) -> None: + for tup in tuples: + assert len(tup) == 2 + assert tup[1] is None + + target.append(tup[0]) + + extend_with_first_in_tuples( + self._extension_on_app_start_handlers, + handlers.get( + rio.extension_event.ExtensionEventTag.ON_APP_START, [] + ), ) - self._extension_on_app_close_handlers.extend( - handlers.get(rio.extension_event.ExtensionEventTag.ON_APP_CLOSE, []) + extend_with_first_in_tuples( + self._extension_on_app_close_handlers, + handlers.get( + rio.extension_event.ExtensionEventTag.ON_APP_CLOSE, [] + ), ) - self._extension_on_session_start_handlers.extend( + extend_with_first_in_tuples( + self._extension_on_session_start_handlers, handlers.get( rio.extension_event.ExtensionEventTag.ON_SESSION_START, [] - ) + ), ) - self._extension_on_session_close_handlers.extend( + extend_with_first_in_tuples( + self._extension_on_session_close_handlers, handlers.get( rio.extension_event.ExtensionEventTag.ON_SESSION_CLOSE, [] - ) + ), ) - self._extension_on_page_change_handlers.extend( + extend_with_first_in_tuples( + self._extension_on_page_change_handlers, handlers.get( rio.extension_event.ExtensionEventTag.ON_PAGE_CHANGE, [] - ) + ), ) def add_default_attachment(self, attachment: t.Any) -> None: diff --git a/rio/extension.py b/rio/extension.py index 232927ed..2090ebcd 100644 --- a/rio/extension.py +++ b/rio/extension.py @@ -1,7 +1,7 @@ import abc import typing as t -from . import extension_event +import rio __all__ = [ "Extension", @@ -10,19 +10,19 @@ __all__ = [ class Extension(abc.ABC): _rio_on_app_start_event_handlers_: t.ClassVar[ - list[t.Callable[[extension_event.ExtensionAppStartEvent], None]] + list[t.Callable[[rio.extension_event.ExtensionAppStartEvent], None]] ] _rio_on_app_close_event_handlers_: t.ClassVar[ - list[t.Callable[[extension_event.ExtensionAppCloseEvent], None]] + list[t.Callable[[rio.extension_event.ExtensionAppCloseEvent], None]] ] _rio_on_session_start_event_handlers_: t.ClassVar[ - list[t.Callable[[extension_event.ExtensionSessionStartEvent], None]] + list[t.Callable[[rio.extension_event.ExtensionSessionStartEvent], None]] ] _rio_on_session_close_event_handlers_: t.ClassVar[ - list[t.Callable[[extension_event.ExtensionSessionCloseEvent], None]] + list[t.Callable[[rio.extension_event.ExtensionSessionCloseEvent], None]] ] def __init__(self) -> None: diff --git a/rio/extension_event.py b/rio/extension_event.py index 3f03c33f..8a2bedfb 100644 --- a/rio/extension_event.py +++ b/rio/extension_event.py @@ -35,27 +35,27 @@ MethodWithNoParametersVar = t.TypeVar( MethodWithAppStartEventParameterVar = t.TypeVar( "MethodWithAppStartEventParameterVar", - bound=t.Callable[["ExtensionAppStartEvent"], t.Any], + bound=t.Callable[[t.Any, "ExtensionAppStartEvent"], t.Any], ) MethodWithAppCloseEventParameterVar = t.TypeVar( "MethodWithAppCloseEventParameterVar", - bound=t.Callable[["ExtensionAppCloseEvent"], t.Any], + bound=t.Callable[[t.Any, "ExtensionAppCloseEvent"], t.Any], ) MethodWithSessionStartEventParameterVar = t.TypeVar( "MethodWithSessionStartEventParameterVar", - bound=t.Callable[["ExtensionSessionStartEvent"], t.Any], + bound=t.Callable[[t.Any, "ExtensionSessionStartEvent"], t.Any], ) MethodWithSessionCloseEventParameterVar = t.TypeVar( "MethodWithSessionCloseEventParameterVar", - bound=t.Callable[["ExtensionSessionCloseEvent"], t.Any], + bound=t.Callable[[t.Any, "ExtensionSessionCloseEvent"], t.Any], ) MethodWithPageChangeEventParameterVar = t.TypeVar( "MethodWithPageChangeEventParameterVar", - bound=t.Callable[["ExtensionPageChangeEvent"], t.Any], + bound=t.Callable[[t.Any, "ExtensionPageChangeEvent"], t.Any], ) @@ -136,7 +136,7 @@ class ExtensionPageChangeEvent: def _tag_as_event_handler( function: t.Callable, tag: ExtensionEventTag, - args: t.Any, + arg: t.Any, ) -> None: """ Registers the function as an event handler for the given tag. This simply @@ -146,13 +146,14 @@ def _tag_as_event_handler( all_events: dict[ExtensionEventTag, list[t.Any]] = vars( function ).setdefault("_rio_extension_events_", {}) + events_like_this = all_events.setdefault(tag, []) - events_like_this.append(args) + events_like_this.append((function, arg)) def _collect_tagged_methods_recursive( - cls: t.Type, -) -> dict[ExtensionEventTag, list[t.Callable]]: + ext: rio.Extension, +) -> dict[ExtensionEventTag, list[tuple[t.Callable, t.Any]]]: """ Walks a class and its parents, gathering all methods that have been tagged as event handlers. @@ -162,10 +163,10 @@ def _collect_tagged_methods_recursive( handlers for a particular event, the result may have no entry for this tag at all, or contain an empty list. """ - result: dict[ExtensionEventTag, list[t.Callable]] = {} + result: dict[ExtensionEventTag, list[tuple[t.Callable, t.Any]]] = {} # The MRO conveniently includes all classes that need to be searched - for base in cls.__mro__: + for base in type(ext).__mro__: # Walk all methods in the class for _, method in vars(base).items(): # Skip untagged members. This also conveniently filters out any @@ -177,7 +178,17 @@ def _collect_tagged_methods_recursive( # Which events is this method a handler for? for tag, handlers in method._rio_extension_events_.items(): - result.setdefault(tag, []).extend(handlers) + # Because the method was retrieved from the class instead of an + # instance, it's not bound to anything. Fix that. + result.setdefault(tag, []).extend( + [ + ( + handler.__get__(ext), + arg, + ) + for handler, arg in handlers + ] + ) return result diff --git a/tests/test_app_build.py b/tests/test_app_build.py index d7672d33..108e8a98 100644 --- a/tests/test_app_build.py +++ b/tests/test_app_build.py @@ -1,8 +1,8 @@ import rio.testing -async def test_fundamental_container_as_root(): - def build(): +async def test_fundamental_container_as_root() -> None: + def build() -> rio.Component: return rio.Row(rio.Text("Hello")) async with rio.testing.TestClient(build) as test_client: diff --git a/tests/test_attribute_bindings.py b/tests/test_attribute_bindings.py index 9f28481b..7cb6d249 100644 --- a/tests/test_attribute_bindings.py +++ b/tests/test_attribute_bindings.py @@ -7,24 +7,24 @@ from rio.state_properties import PleaseTurnThisIntoAnAttributeBinding class Parent(rio.Component): text: str = "" - def build(self): + def build(self) -> rio.Component: return rio.Text(self.bind().text) class Grandparent(rio.Component): text: str = "" - def build(self): + def build(self) -> rio.Component: return Parent(self.bind().text) -async def test_bindings_arent_created_too_early(): +async def test_bindings_arent_created_too_early() -> None: # There was a time when attribute bindings were created in `Component.__init__`, # thus skipping any properties that were only assigned later. class IHaveACustomInit(rio.Component): text: str - def __init__(self, *args, text: str, **kwargs): + def __init__(self, *args, text: str, **kwargs) -> None: super().__init__(*args, **kwargs) # `Component.__init__`` has already run, but we haven't assigned @@ -51,7 +51,7 @@ async def test_bindings_arent_created_too_early(): assert child_component.text == "bye" -async def test_init_receives_attribute_bindings_as_input(): +async def test_init_receives_attribute_bindings_as_input() -> None: # For a while we considered initializing attribute bindings before calling a # component's `__init__` and passing the values of the bindings as arguments # into `__init__`. But ultimately we decided against it, because some @@ -61,7 +61,7 @@ async def test_init_receives_attribute_bindings_as_input(): size_value = None class Square(rio.Component): - def __init__(self, size: float): + def __init__(self, size: float) -> None: nonlocal size_value size_value = size @@ -82,7 +82,7 @@ async def test_init_receives_attribute_bindings_as_input(): assert isinstance(size_value, PleaseTurnThisIntoAnAttributeBinding) -async def test_binding_assignment_on_child(): +async def test_binding_assignment_on_child() -> None: async with rio.testing.TestClient(Parent) as test_client: root_component = test_client.get_component(Parent) text_component = test_client._get_build_output(root_component, rio.Text) @@ -99,7 +99,7 @@ async def test_binding_assignment_on_child(): assert text_component.text == "Hello" -async def test_binding_assignment_on_parent(): +async def test_binding_assignment_on_parent() -> None: async with rio.testing.TestClient(Parent) as test_client: root_component = test_client.get_component(Parent) text_component = test_client._get_build_output(root_component) @@ -116,7 +116,7 @@ async def test_binding_assignment_on_parent(): assert text_component.text == "Hello" -async def test_binding_assignment_on_sibling(): +async def test_binding_assignment_on_sibling() -> None: class Root(rio.Component): text: str = "" @@ -147,7 +147,7 @@ async def test_binding_assignment_on_sibling(): assert text2.text == "Hello" -async def test_binding_assignment_on_grandchild(): +async def test_binding_assignment_on_grandchild() -> None: async with rio.testing.TestClient(Grandparent) as test_client: root_component = test_client.get_component(Grandparent) parent = t.cast(Parent, test_client._get_build_output(root_component)) @@ -167,7 +167,7 @@ async def test_binding_assignment_on_grandchild(): assert text_component.text == "Hello" -async def test_binding_assignment_on_middle(): +async def test_binding_assignment_on_middle() -> None: async with rio.testing.TestClient(Grandparent) as test_client: root_component = test_client.get_component(Grandparent) parent: Parent = test_client._get_build_output(root_component) @@ -187,7 +187,7 @@ async def test_binding_assignment_on_middle(): assert text_component.text == "Hello" -async def test_binding_assignment_on_child_after_reconciliation(): +async def test_binding_assignment_on_child_after_reconciliation() -> None: async with rio.testing.TestClient(Parent) as test_client: root_component = test_client.get_component(Parent) text_component: rio.Text = test_client._get_build_output(root_component) @@ -207,7 +207,7 @@ async def test_binding_assignment_on_child_after_reconciliation(): assert text_component.text == "Hello" -async def test_binding_assignment_on_parent_after_reconciliation(): +async def test_binding_assignment_on_parent_after_reconciliation() -> None: async with rio.testing.TestClient(Parent) as test_client: root_component = test_client.get_component(Parent) text_component: rio.Text = test_client._get_build_output(root_component) @@ -227,7 +227,7 @@ async def test_binding_assignment_on_parent_after_reconciliation(): assert text_component.text == "Hello" -async def test_binding_assignment_on_sibling_after_reconciliation(): +async def test_binding_assignment_on_sibling_after_reconciliation() -> None: class Root(rio.Component): text: str = "" @@ -258,7 +258,7 @@ async def test_binding_assignment_on_sibling_after_reconciliation(): assert text2.text == "Hello" -async def test_binding_assignment_on_grandchild_after_reconciliation(): +async def test_binding_assignment_on_grandchild_after_reconciliation() -> None: async with rio.testing.TestClient(Grandparent) as test_client: root_component = test_client.get_component(Grandparent) parent: Parent = test_client._get_build_output(root_component) @@ -281,7 +281,7 @@ async def test_binding_assignment_on_grandchild_after_reconciliation(): assert text_component.text == "Hello" -async def test_binding_assignment_on_middle_after_reconciliation(): +async def test_binding_assignment_on_middle_after_reconciliation() -> None: async with rio.testing.TestClient(Grandparent) as test_client: root_component = test_client.get_component(Grandparent) parent: Parent = test_client._get_build_output(root_component) diff --git a/tests/test_extensions.py b/tests/test_extensions.py new file mode 100644 index 00000000..b69ee42b --- /dev/null +++ b/tests/test_extensions.py @@ -0,0 +1,138 @@ +import asyncio +import inspect +import typing as t + +import rio.testing + + +class TracingExtensionParent(rio.Extension): + def __init__(self) -> None: + # Keeps track of called functions + self.function_call_log: list[str] = [] + + def verify_and_clear_log(self, *expected: str) -> None: + """ + Verifies that the recorded function calls match the expected ones. This + disregards order! + """ + + assert set(self.function_call_log) == set(expected) + self.function_call_log.clear() + + def _record_function_call(self) -> None: + # Get the caller's name + caller_name = inspect.stack()[1].function + + # Record it + self.function_call_log.append(caller_name) + + # This function should be inherited and called + @rio.extension_event.on_session_start + def on_session_start_parent( + self, + event: rio.ExtensionSessionStartEvent, + ) -> None: + self._record_function_call() + + +class TracingExtensionChild(TracingExtensionParent): + @rio.extension_event.on_app_start + def on_app_start( + self, + event: rio.ExtensionAppStartEvent, + ) -> None: + self._record_function_call() + + @rio.extension_event.on_app_close + def on_app_close( + self, + event: rio.ExtensionAppCloseEvent, + ) -> None: + self._record_function_call() + + # This function isn't registered at all and should not be called + def on_session_start( + self, + event: rio.ExtensionSessionStartEvent, + ) -> None: + self._record_function_call() + + # This function is asynchronous, testing that the extension system awaits + # functions as needed. + @rio.extension_event.on_session_start + def on_session_start_async( + self, + event: rio.ExtensionSessionStartEvent, + ) -> None: + self._record_function_call() + + @rio.extension_event.on_session_close + def on_session_close( + self, + event: rio.ExtensionSessionCloseEvent, + ) -> None: + self._record_function_call() + + @rio.extension_event.on_page_change + def on_page_change( + self, + event: rio.ExtensionPageChangeEvent, + ) -> None: + self._record_function_call() + + # This function is registered for multiple events and should be called for + # each of them + @rio.extension_event.on_session_start + @rio.extension_event.on_session_close + @rio.extension_event.on_page_change + def on_multiple_events( + self, + event: t.Any, + ) -> None: + self._record_function_call() + + +async def test_extension_events() -> None: + # Create an app + app = rio.App( + build=lambda: rio.Text("Hello"), + ) + + # Add an extension which records called events + extension_instance = TracingExtensionChild() + app._add_extension(extension_instance) + + assert len(extension_instance.function_call_log) == 0 + + async with rio.testing.TestClient(app) as test_client: + # The test client doesn't support Rio's full feature set and so doesn't + # actually call the app start/close events right now. Skip them + + # Expect: Session start events + extension_instance.verify_and_clear_log( + "on_session_start_async", + "on_session_start_parent", + "on_multiple_events", + ) + + # Navigate to another page + test_client.session.navigate_to("/foobar") + + # Expect: Page change event + # + # Rio dispatches this event in an asyncio task. Give it some time to + # run. + await asyncio.sleep(0.1) + + extension_instance.verify_and_clear_log( + "on_page_change", + "on_multiple_events", + ) + + # Expect: Session close events + extension_instance.verify_and_clear_log( + "on_session_close", + "on_multiple_events", + ) + + # The app close event is missing here for the same reason as above diff --git a/tests/test_refresh.py b/tests/test_refresh.py index 9e7c2ffb..f40dee38 100644 --- a/tests/test_refresh.py +++ b/tests/test_refresh.py @@ -1,8 +1,8 @@ import rio.testing -async def test_refresh_with_nothing_to_do(): - def build(): +async def test_refresh_with_nothing_to_do() -> None: + def build() -> rio.Component: return rio.Text("Hello") async with rio.testing.TestClient(build) as test_client: @@ -13,8 +13,8 @@ async def test_refresh_with_nothing_to_do(): assert not test_client._last_updated_components -async def test_refresh_with_clean_root_component(): - def build(): +async def test_refresh_with_clean_root_component() -> None: + def build() -> rio.Component: text_component = rio.Text("Hello") return rio.Container(text_component) @@ -27,7 +27,7 @@ async def test_refresh_with_clean_root_component(): assert test_client._last_updated_components == {text_component} -async def test_rebuild_component_with_dead_parent(): +async def test_rebuild_component_with_dead_parent() -> None: class RootComponent(rio.Component): content: rio.Component @@ -64,7 +64,7 @@ async def test_rebuild_component_with_dead_parent(): assert test_client._last_updated_components == {root_component} -async def test_unmount_and_remount(): +async def test_unmount_and_remount() -> None: class DemoComponent(rio.Component): content: rio.Component show_child: bool diff --git a/tests/test_routing.py b/tests/test_routing.py index 106d8e66..22c79d28 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -124,6 +124,10 @@ def test_redirects( relative_url_before_redirects: str, relative_url_after_redirects_should: str, ) -> None: + """ + Simulate navigation to URLs, run any guards, and make sure the final, + resulting URL is correct. + """ # Create a fake session. It contains everything used by the routing system. # fake_session = t.cast(rio.Session, FakeSession()) #