mirror of
https://github.com/rio-labs/rio.git
synced 2026-01-07 13:49:51 -06:00
wip add oauth extension
This commit is contained in:
244
extensions/oauth_extension.py
Normal file
244
extensions/oauth_extension.py
Normal file
@@ -0,0 +1,244 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import logging
|
||||
import secrets
|
||||
import typing as t
|
||||
from base64 import urlsafe_b64encode
|
||||
from datetime import timedelta
|
||||
from hashlib import sha256
|
||||
|
||||
import fastapi
|
||||
import timer_dict
|
||||
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
||||
|
||||
import rio
|
||||
|
||||
# Define constants for OAuth2 client configuration
|
||||
CLIENT_ID = "Iv23liJIcSo4rnvwwN57"
|
||||
CLIENT_SECRET = "383d3f560f1904880e101440210168f14b1e51b9"
|
||||
SCOPE = "user:email"
|
||||
REDIRECT_URI = "http://localhost:8001/callback"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class OAuth:
|
||||
_extension: OAuthExtension
|
||||
_session: rio.Session
|
||||
|
||||
_state: t.Literal["logged-out", "logging-in", "logged-in"]
|
||||
|
||||
@property
|
||||
def is_logged_in(self) -> bool:
|
||||
"""
|
||||
Returns whether the client is currently authenticated.
|
||||
|
||||
Returns `True` if the user has successfully been authenticated and
|
||||
`False` otherwise. This property is read-only. To log in, use the
|
||||
`login` method.
|
||||
|
||||
Note that this will return `False` even if a login is currently in
|
||||
progress. You can use `is_logging_in` to check if that's the case.
|
||||
"""
|
||||
return self._state == "logged-in"
|
||||
|
||||
@property
|
||||
def is_logging_in(self) -> bool:
|
||||
"""
|
||||
Returns whether the client is currently logging in.
|
||||
|
||||
Returns `True` if the client is currently in the process of logging in
|
||||
and `False` otherwise.
|
||||
"""
|
||||
return self._state == "logging-in"
|
||||
|
||||
async def login(self) -> None:
|
||||
# Can't log in twice
|
||||
if self._state == "logging-in":
|
||||
raise RuntimeError("Already logging in")
|
||||
|
||||
if self._state == "logged-in":
|
||||
raise RuntimeError("Already logged in")
|
||||
|
||||
# Logging in now
|
||||
self._state = "logging-in"
|
||||
|
||||
# Generate a code verifier and challenge for PKCE
|
||||
code_verifier = secrets.token_urlsafe(100)
|
||||
code_challenge = sha256(code_verifier.encode("utf-8")).digest()
|
||||
code_challenge = (
|
||||
urlsafe_b64encode(code_challenge).decode("utf-8").rstrip("=")
|
||||
)
|
||||
|
||||
# Create an authorization URL and state using the OAuth2 client
|
||||
async with AsyncOAuth2Client(CLIENT_ID, CLIENT_SECRET) as oauth:
|
||||
auth_url, state = oauth.create_authorization_url(
|
||||
"https://github.com/login/oauth/authorize",
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method="S256",
|
||||
redirect_uri=REDIRECT_URI,
|
||||
)
|
||||
|
||||
# Have the user log in by redirecting to the authorization URL
|
||||
self._session.navigate_to(auth_url)
|
||||
|
||||
# Wait for the callback to be received
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self._extension._pending_requests[state] = (self, future)
|
||||
code = await future
|
||||
|
||||
# Create an OAuth2 client and fetch the access token
|
||||
oauth = AsyncOAuth2Client(CLIENT_ID, CLIENT_SECRET)
|
||||
await oauth.fetch_token(
|
||||
url="https://github.com/login/oauth/access_token",
|
||||
authorization_response=code,
|
||||
code_verifier=code_verifier,
|
||||
)
|
||||
|
||||
# Fetch the user's email from GitHub
|
||||
user = await oauth.get("https://api.github.com/user")
|
||||
print(f"Logged in as {user}")
|
||||
|
||||
async def logout(self) -> None:
|
||||
raise NotImplementedError("TODO")
|
||||
|
||||
|
||||
class OAuthExtension(rio.Extension):
|
||||
"""
|
||||
Easily integrate OAuth2 with Rio.
|
||||
|
||||
This extension provides a simple way to authenticate users using OAuth2
|
||||
in Rio applications. It handles the OAuth2 flow, including generating the
|
||||
authorization URL, handling the callback, and fetching the access token.
|
||||
|
||||
```python
|
||||
TODO: Example usage
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# This stores currently pending authentication requests. The key is
|
||||
# the state, and the value is a tuple containing
|
||||
#
|
||||
# - The OAuth object that is waiting for the callback
|
||||
# - The future that will be set when the callback is received
|
||||
#
|
||||
# The future's parameters are the code and state received from the
|
||||
# callback.
|
||||
self._pending_requests: timer_dict.TimerDict[
|
||||
str, tuple[OAuth, asyncio.Future[str]]
|
||||
] = timer_dict.TimerDict(
|
||||
default_duration=timedelta(minutes=5),
|
||||
)
|
||||
|
||||
async def _callback_route(
|
||||
self,
|
||||
request: fastapi.Request,
|
||||
) -> fastapi.responses.Response:
|
||||
# TODO: Come up with better responses
|
||||
|
||||
# Get the code & state from the query parameters
|
||||
try:
|
||||
code = request.query_params["code"]
|
||||
except KeyError:
|
||||
logging.error(
|
||||
f"Received invalid OAuth callback. `{request.url}` did not contain a `code` query parameter."
|
||||
)
|
||||
return fastapi.responses.Response(
|
||||
content='TODO: Missing "code" query parameter',
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
try:
|
||||
state = request.query_params["state"]
|
||||
except KeyError:
|
||||
logging.error(
|
||||
f"Received invalid OAuth callback. `{request.url}` did not contain a `state` query parameter."
|
||||
)
|
||||
return fastapi.responses.Response(
|
||||
content='TODO: Missing "state" query parameter',
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Is an OAuth object waiting for this callback?
|
||||
try:
|
||||
oauth, future = self._pending_requests.pop(state)
|
||||
except KeyError:
|
||||
logging.error(
|
||||
f"Received invalid OAuth callback for state `{state}`. There is no pending authentication request with that state."
|
||||
)
|
||||
return fastapi.responses.Response(
|
||||
content="TODO: Invalid state",
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Update that object
|
||||
future.set_result(code)
|
||||
|
||||
# Done
|
||||
return fastapi.responses.Response(
|
||||
content="TODO: You can close this tab now",
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
@rio.extension_event.on_as_fastapi
|
||||
def on_as_fastapi(self, event: rio.ExtensionAsFastapiEvent) -> None:
|
||||
# Add a route to handle the OAuth2 callback
|
||||
event.fastapi_app.add_api_route(
|
||||
"/callback",
|
||||
self._callback_route,
|
||||
methods=["GET"],
|
||||
)
|
||||
|
||||
@rio.extension_event.on_session_start
|
||||
def on_session_start(self, event: rio.ExtensionSessionStartEvent) -> None:
|
||||
# Attach an object to the session for handling authentication
|
||||
event.session.attach(
|
||||
OAuth(
|
||||
_extension=self,
|
||||
_session=event.session,
|
||||
_state="logged-out",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Define a custom component class for the root of the application
|
||||
class MyRoot(rio.Component):
|
||||
logged_in_email: str | None = None
|
||||
|
||||
# Asynchronous method to handle the login process
|
||||
async def on_login(self) -> None:
|
||||
# Get the OAuth object from the session
|
||||
oauth = self.session[OAuth]
|
||||
|
||||
# Authenticate
|
||||
await oauth.login()
|
||||
|
||||
# Method to build the UI component
|
||||
def build(self) -> rio.Component:
|
||||
if self.logged_in_email is None:
|
||||
# Display a login button if the user is not logged in
|
||||
return rio.Button(
|
||||
"Login with GitHub",
|
||||
icon="brand/github",
|
||||
on_press=self.on_login,
|
||||
align_x=0.5,
|
||||
align_y=0.5,
|
||||
)
|
||||
# Display a welcome message if the user is logged in
|
||||
return rio.Text(
|
||||
f"Hello, {self.logged_in_email}!",
|
||||
justify="center",
|
||||
align_x=0.5,
|
||||
align_y=0.5,
|
||||
)
|
||||
|
||||
|
||||
# Create a Rio application with the custom root component
|
||||
app = rio.App(
|
||||
build=MyRoot,
|
||||
)
|
||||
|
||||
|
||||
app._add_extension(OAuthExtension())
|
||||
2
extensions/rio.toml
Normal file
2
extensions/rio.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[app]
|
||||
main-module = "oauth_extension"
|
||||
Reference in New Issue
Block a user