add static type annotations

This commit is contained in:
David Lord
2024-04-30 10:46:55 -07:00
parent 691acc186b
commit cdcf917044
27 changed files with 463 additions and 253 deletions

View File

@@ -37,3 +37,19 @@ jobs:
cache-dependency-path: requirements*/*.txt
- run: pip install tox
- run: tox run -e ${{ matrix.tox || format('py{0}', matrix.python) }}
typing:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@0ad4b8fadaa221de15dcec353f45205ec38ea70b # v4.1.4
- uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0
with:
python-version: '3.x'
cache: pip
cache-dependency-path: requirements*/*.txt
- name: cache mypy
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2
with:
path: ./.mypy_cache
key: mypy|${{ hashFiles('pyproject.toml') }}
- run: pip install tox
- run: tox run -e typing

View File

@@ -11,6 +11,7 @@ classifiers = [
"Framework :: Flask",
"License :: OSI Approved :: BSD License",
"Programming Language :: Python",
"Typing :: Typed",
]
requires-python = ">=3.8"
dependencies = [
@@ -50,6 +51,12 @@ show_error_codes = true
pretty = true
strict = true
[[tool.mypy.overrides]]
module = [
"sqlparse.*"
]
ignore_missing_imports = true
[tool.pyright]
pythonVersion = "3.8"
include = ["src/flask_debugtoolbar", "tests"]

View File

@@ -15,6 +15,7 @@ babel==2.14.0
blinker==1.8.1
# via
# -r tests.txt
# -r typing.txt
# flask
cachetools==5.3.3
# via tox
@@ -33,6 +34,7 @@ charset-normalizer==3.3.2
click==8.1.7
# via
# -r tests.txt
# -r typing.txt
# flask
colorama==0.4.6
# via tox
@@ -54,9 +56,12 @@ filelock==3.14.0
flask==3.0.3
# via
# -r tests.txt
# -r typing.txt
# flask-sqlalchemy
flask-sqlalchemy==3.1.1
# via -r tests.txt
# via
# -r tests.txt
# -r typing.txt
identify==2.5.36
# via pre-commit
idna==3.7
@@ -71,6 +76,7 @@ importlib-metadata==7.1.0
# via
# -r docs.txt
# -r tests.txt
# -r typing.txt
# flask
# sphinx
iniconfig==2.0.0
@@ -81,17 +87,20 @@ iniconfig==2.0.0
itsdangerous==2.2.0
# via
# -r tests.txt
# -r typing.txt
# flask
jinja2==3.1.3
# via
# -r docs.txt
# -r tests.txt
# -r typing.txt
# flask
# sphinx
markupsafe==2.1.5
# via
# -r docs.txt
# -r tests.txt
# -r typing.txt
# jinja2
# werkzeug
mypy==1.10.0
@@ -190,6 +199,7 @@ sphinxcontrib-serializinghtml==1.1.5
sqlalchemy==2.0.29
# via
# -r tests.txt
# -r typing.txt
# flask-sqlalchemy
tomli==2.0.1
# via
@@ -201,6 +211,16 @@ tomli==2.0.1
# tox
tox==4.15.0
# via -r dev.in
types-docutils==0.21.0.20240423
# via
# -r typing.txt
# types-pygments
types-pygments==2.17.0.20240310
# via -r typing.txt
types-setuptools==69.5.0.20240423
# via
# -r typing.txt
# types-pygments
typing-extensions==4.11.0
# via
# -r tests.txt
@@ -218,11 +238,13 @@ virtualenv==20.26.1
werkzeug==3.0.2
# via
# -r tests.txt
# -r typing.txt
# flask
zipp==3.18.1
# via
# -r docs.txt
# -r tests.txt
# -r typing.txt
# importlib-metadata
# The following packages are considered to be unsafe in a requirements file:

View File

@@ -1,3 +1,5 @@
mypy
pyright
pytest
types-pygments
flask-sqlalchemy

View File

@@ -4,10 +4,28 @@
#
# pip-compile typing.in
#
blinker==1.8.1
# via flask
click==8.1.7
# via flask
exceptiongroup==1.2.1
# via pytest
flask==3.0.3
# via flask-sqlalchemy
flask-sqlalchemy==3.1.1
# via -r typing.in
importlib-metadata==7.1.0
# via flask
iniconfig==2.0.0
# via pytest
itsdangerous==2.2.0
# via flask
jinja2==3.1.3
# via flask
markupsafe==2.1.5
# via
# jinja2
# werkzeug
mypy==1.10.0
# via -r typing.in
mypy-extensions==1.0.0
@@ -22,12 +40,26 @@ pyright==1.1.360
# via -r typing.in
pytest==8.2.0
# via -r typing.in
sqlalchemy==2.0.29
# via flask-sqlalchemy
tomli==2.0.1
# via
# mypy
# pytest
types-docutils==0.21.0.20240423
# via types-pygments
types-pygments==2.17.0.20240310
# via -r typing.in
types-setuptools==69.5.0.20240423
# via types-pygments
typing-extensions==4.11.0
# via mypy
# via
# mypy
# sqlalchemy
werkzeug==3.0.2
# via flask
zipp==3.18.1
# via importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
# setuptools

View File

@@ -1,19 +1,26 @@
import contextvars
from __future__ import annotations
import collections.abc as c
import importlib.metadata
import os
import typing as t
import urllib.parse
import warnings
from contextvars import ContextVar
from flask import Blueprint
from flask import current_app
from flask import Flask
from flask import g
from flask import request
from flask import send_from_directory
from flask import url_for
from flask.globals import request_ctx
from jinja2 import __version__ as __jinja_version__
from jinja2 import Environment
from jinja2 import PackageLoader
from werkzeug import Request
from werkzeug import Response
from werkzeug.routing import Rule
from .toolbar import DebugToolbar
from .utils import decode_text
@@ -21,11 +28,12 @@ from .utils import gzip_compress
from .utils import gzip_decompress
__version__ = importlib.metadata.version("flask-debugtoolbar")
_jinja_version = importlib.metadata.version("jinja2")
module = Blueprint("debugtoolbar", __name__)
module: Blueprint = Blueprint("debugtoolbar", __name__)
def replace_insensitive(string, target, replacement):
def replace_insensitive(string: str, target: str, replacement: str) -> str:
"""Similar to string.replace() but is case insensitive
Code borrowed from:
http://forums.devshed.com/python-programming-11/case-insensitive-string-replace-490921.html
@@ -39,7 +47,7 @@ def replace_insensitive(string, target, replacement):
return string
def _printable(value):
def _printable(value: object) -> str:
try:
return decode_text(repr(value))
except Exception as e:
@@ -52,19 +60,21 @@ class DebugToolbarExtension:
_toolbar_codes = [200, 201, 400, 401, 403, 404, 405, 500, 501, 502, 503, 504]
_redirect_codes = [301, 302, 303, 304]
def __init__(self, app=None):
def __init__(self, app: Flask | None = None) -> None:
self.app = app
# Support threads running `flask.copy_current_request_context` without
# poping toolbar during `teardown_request`
self.debug_toolbars_var = contextvars.ContextVar("debug_toolbars")
self.debug_toolbars_var: ContextVar[dict[Request, DebugToolbar]] = ContextVar(
"debug_toolbars"
)
jinja_extensions = ["jinja2.ext.i18n"]
if __jinja_version__[0] == "2":
if _jinja_version[0] == "2":
jinja_extensions.append("jinja2.ext.with_")
# Configure jinja for the internal templates and add url rules
# for static data
self.jinja_env = Environment(
self.jinja_env: Environment = Environment(
autoescape=True,
extensions=jinja_extensions,
loader=PackageLoader(__name__, "templates"),
@@ -76,7 +86,7 @@ class DebugToolbarExtension:
if app is not None:
self.init_app(app)
def init_app(self, app):
def init_app(self, app: Flask) -> None:
for k, v in self._default_config(app).items():
app.config.setdefault(k, v)
@@ -96,7 +106,7 @@ class DebugToolbarExtension:
app.teardown_request(self.teardown_request)
# Monkey-patch the Flask.dispatch_request method
app.dispatch_request = self.dispatch_request
app.dispatch_request = self.dispatch_request # type: ignore[method-assign]
app.add_url_rule(
"/_debug_toolbar/static/<path:filename>",
@@ -106,7 +116,7 @@ class DebugToolbarExtension:
app.register_blueprint(module, url_prefix="/_debug_toolbar/views")
def _default_config(self, app):
def _default_config(self, app: Flask) -> dict[str, t.Any]:
return {
"DEBUG_TB_ENABLED": app.debug,
"DEBUG_TB_HOSTS": (),
@@ -127,30 +137,32 @@ class DebugToolbarExtension:
"SQLALCHEMY_RECORD_QUERIES": app.debug,
}
def dispatch_request(self):
"""Modified version of Flask.dispatch_request to call process_view."""
def dispatch_request(self) -> t.Any:
"""Modified version of ``Flask.dispatch_request`` to call
:meth:`process_view`.
"""
# self references this extension, use current_app to call app methods.
app = current_app._get_current_object() # type: ignore[attr-defined]
req = request_ctx.request
app = current_app
if req.routing_exception is not None:
app.raise_routing_exception(req)
rule = req.url_rule
rule: Rule = req.url_rule # type: ignore[assignment]
# if we provide automatic options for this URL and the
# request came with the OPTIONS method, reply automatically
if (
getattr(rule, "provide_automatic_options", False)
and req.method == "OPTIONS"
):
return app.make_default_options_response()
# otherwise dispatch to the handler for that endpoint
view_func = app.view_functions[rule.endpoint]
view_func = self.process_view(app, view_func, req.view_args)
return view_func(**req.view_args)
view_args: dict[str, t.Any] = req.view_args # type: ignore[assignment]
# allow each toolbar to process the view and args
view_func = self.process_view(app, view_func, view_args)
return view_func(**view_args)
def _show_toolbar(self):
def _show_toolbar(self) -> bool:
"""Return a boolean to indicate if we need to show the toolbar."""
if request.blueprint == "debugtoolbar":
return False
@@ -162,17 +174,17 @@ class DebugToolbarExtension:
return True
def send_static_file(self, filename):
def send_static_file(self, filename: str) -> Response:
"""Send a static file from the flask-debugtoolbar static directory."""
return send_from_directory(self._static_dir, filename)
def process_request(self):
def process_request(self) -> None:
g.debug_toolbar = self
if not self._show_toolbar():
return
real_request = request._get_current_object()
real_request = request._get_current_object() # type: ignore[attr-defined]
self.debug_toolbars_var.set({})
self.debug_toolbars_var.get()[real_request] = DebugToolbar(
real_request, self.jinja_env
@@ -181,11 +193,16 @@ class DebugToolbarExtension:
for panel in self.debug_toolbars_var.get()[real_request].panels:
panel.process_request(real_request)
def process_view(self, app, view_func, view_kwargs):
def process_view(
self,
app: Flask,
view_func: c.Callable[..., t.Any],
view_kwargs: dict[str, t.Any],
) -> c.Callable[..., t.Any]:
"""This method is called just before the flask view is called.
This is done by the dispatch_request method.
"""
real_request = request._get_current_object()
real_request = request._get_current_object() # type: ignore[attr-defined]
try:
toolbar = self.debug_toolbars_var.get({})[real_request]
@@ -200,8 +217,8 @@ class DebugToolbarExtension:
return view_func
def process_response(self, response):
real_request = request._get_current_object()
def process_response(self, response: Response) -> Response:
real_request = request._get_current_object() # type: ignore[attr-defined]
if real_request not in self.debug_toolbars_var.get({}):
return response
@@ -219,7 +236,7 @@ class DebugToolbarExtension:
{"redirect_to": redirect_to, "redirect_code": redirect_code},
)
response.content_length = len(content)
response.location = None
del response.location
response.response = [content]
response.status_code = 200
@@ -263,20 +280,21 @@ class DebugToolbarExtension:
toolbar_html = toolbar.render_toolbar()
content = "".join((before, toolbar_html, after))
content = content.encode("utf-8")
content_bytes = content.encode("utf-8")
if content_encoding and "gzip" in content_encoding:
content = gzip_compress(content)
content_bytes = gzip_compress(content_bytes)
response.response = [content]
response.content_length = len(content)
response.response = [content_bytes]
response.content_length = len(content_bytes)
return response
def teardown_request(self, exc):
def teardown_request(self, exc: BaseException | None) -> None:
# debug_toolbars_var won't be set under `flask.copy_current_request_context`
self.debug_toolbars_var.get({}).pop(request._get_current_object(), None)
real_request = request._get_current_object() # type: ignore[attr-defined]
self.debug_toolbars_var.get({}).pop(real_request, None)
def render(self, template_name, context):
def render(self, template_name: str, context: dict[str, t.Any]) -> str:
template = self.jinja_env.get_template(template_name)
return template.render(**context)

View File

@@ -1,7 +1,18 @@
from __future__ import annotations
import collections.abc as c
import typing as t
from flask import Flask
from jinja2 import Environment
from werkzeug import Request
from werkzeug import Response
class DebugPanel:
"""Base class for debug panels."""
# name = Base
name: str
# If content returns something, set to true in subclass
has_content = False
@@ -11,10 +22,12 @@ class DebugPanel:
# We'll maintain a local context instance so we can expose our template
# context variables to panels which need them:
context = {}
context: dict[str, t.Any] = {}
# Panel methods
def __init__(self, jinja_env, context=None):
def __init__(
self, jinja_env: Environment, context: dict[str, t.Any] | None = None
) -> None:
if context is not None:
self.context.update(context)
@@ -23,7 +36,7 @@ class DebugPanel:
self.is_active = False
@classmethod
def init_app(cls, app):
def init_app(cls, app: Flask) -> None:
"""Method that can be overridden by child classes.
Can be used for setting up additional URL-rules/routes.
@@ -45,37 +58,42 @@ class DebugPanel:
"""
pass
def render(self, template_name, context):
def render(self, template_name: str, context: dict[str, t.Any]) -> str:
template = self.jinja_env.get_template(template_name)
return template.render(**context)
def dom_id(self):
def dom_id(self) -> str:
return f"flDebug{self.name.replace(' ', '')}Panel"
def nav_title(self):
def nav_title(self) -> str:
"""Title showing in toolbar"""
raise NotImplementedError
def nav_subtitle(self):
def nav_subtitle(self) -> str:
"""Subtitle showing until title in toolbar"""
return ""
def title(self):
def title(self) -> str:
"""Title showing in panel"""
raise NotImplementedError
def url(self):
def url(self) -> str:
raise NotImplementedError
def content(self):
def content(self) -> str:
raise NotImplementedError
# Standard middleware methods
def process_request(self, request):
def process_request(self, request: Request) -> None:
pass
def process_view(self, request, view_func, view_kwargs):
def process_view(
self,
request: Request,
view_func: c.Callable[..., t.Any],
view_kwargs: dict[str, t.Any],
) -> c.Callable[..., t.Any] | None:
pass
def process_response(self, request, response):
def process_response(self, request: Request, response: Response) -> None:
pass

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
from flask import current_app
from . import DebugPanel
@@ -9,16 +11,16 @@ class ConfigVarsDebugPanel(DebugPanel):
name = "ConfigVars"
has_content = True
def nav_title(self):
def nav_title(self) -> str:
return "Config"
def title(self):
def title(self) -> str:
return "Config"
def url(self):
def url(self) -> str:
return ""
def content(self):
def content(self) -> str:
context = self.context.copy()
context.update(
{

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
from flask import g
from . import DebugPanel
@@ -9,16 +11,16 @@ class GDebugPanel(DebugPanel):
name = "g"
has_content = True
def nav_title(self):
def nav_title(self) -> str:
return "flask.g"
def title(self):
def title(self) -> str:
return "flask.g content"
def url(self):
def url(self) -> str:
return ""
def content(self):
def content(self) -> str:
context = self.context.copy()
context.update({"g_content": g.__dict__})
return self.render("panels/g.html", context)

View File

@@ -1,3 +1,9 @@
from __future__ import annotations
import typing as t
from werkzeug import Request
from . import DebugPanel
@@ -7,7 +13,7 @@ class HeaderDebugPanel(DebugPanel):
name = "Header"
has_content = True
# List of headers we want to display
header_filter = (
header_filter: tuple[str, ...] = (
"CONTENT_TYPE",
"HTTP_ACCEPT",
"HTTP_ACCEPT_CHARSET",
@@ -30,25 +36,21 @@ class HeaderDebugPanel(DebugPanel):
"SERVER_SOFTWARE",
)
def nav_title(self):
def nav_title(self) -> str:
return "HTTP Headers"
def title(self):
def title(self) -> str:
return "HTTP Headers"
def url(self):
def url(self) -> str:
return ""
def process_request(self, request):
self.headers = dict(
[
(k, request.environ[k])
for k in self.header_filter
if k in request.environ
]
)
def process_request(self, request: Request) -> None:
self.headers: dict[str, t.Any] = {
k: request.environ[k] for k in self.header_filter if k in request.environ
}
def content(self):
def content(self) -> str:
context = self.context.copy()
context.update({"headers": self.headers})
return self.render("panels/headers.html", context)

View File

@@ -1,20 +1,27 @@
from __future__ import annotations
import datetime
import logging
import threading
from werkzeug import Request
from ..utils import format_fname
from . import DebugPanel
class ThreadTrackingHandler(logging.Handler):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.records = {} # a dictionary that maps threads to log records
# a dictionary that maps threads to log records
self.records: dict[threading.Thread, list[logging.LogRecord]] = {}
def emit(self, record):
def emit(self, record: logging.LogRecord) -> None:
self.get_records().append(record)
def get_records(self, thread=None):
def get_records(
self, thread: threading.Thread | None = None
) -> list[logging.LogRecord]:
"""
Returns a list of records for the provided thread, of if none is
provided, returns a list for the current thread.
@@ -27,7 +34,7 @@ class ThreadTrackingHandler(logging.Handler):
return self.records[thread]
def clear_records(self, thread=None):
def clear_records(self, thread: threading.Thread | None = None) -> None:
if thread is None:
thread = threading.current_thread()
@@ -35,11 +42,11 @@ class ThreadTrackingHandler(logging.Handler):
del self.records[thread]
handler = None
handler: ThreadTrackingHandler = None # type: ignore[assignment]
_init_lock = threading.Lock()
def _init_once():
def _init_once() -> None:
global handler
if handler is not None:
@@ -65,30 +72,30 @@ class LoggingPanel(DebugPanel):
name = "Logging"
has_content = True
def process_request(self, request):
def process_request(self, request: Request) -> None:
_init_once()
handler.clear_records()
def get_and_delete(self):
def get_and_delete(self) -> list[logging.LogRecord]:
records = handler.get_records()
handler.clear_records()
return records
def nav_title(self):
def nav_title(self) -> str:
return "Logging"
def nav_subtitle(self):
def nav_subtitle(self) -> str:
num_records = len(handler.get_records())
plural = "message" if num_records == 1 else "messages"
return f"{num_records} {plural}"
def title(self):
def title(self) -> str:
return "Log Messages"
def url(self):
def url(self) -> str:
return ""
def content(self):
def content(self) -> str:
records = []
for record in self.get_and_delete():

View File

@@ -1,7 +1,14 @@
from __future__ import annotations
import collections.abc as c
import functools
import pstats
import typing as t
from flask import current_app
from jinja2 import Environment
from werkzeug import Request
from werkzeug import Response
from ..utils import format_fname
from . import DebugPanel
@@ -9,7 +16,7 @@ from . import DebugPanel
try:
import cProfile as profile
except ImportError:
import profile
import profile # type: ignore[no-redef]
class ProfilerDebugPanel(DebugPanel):
@@ -18,7 +25,15 @@ class ProfilerDebugPanel(DebugPanel):
name = "Profiler"
user_activate = True
def __init__(self, jinja_env, context=None):
is_active: bool = False
dump_filename: str | None = None
profiler: profile.Profile
stats: pstats.Stats | None = None
function_calls: list[dict[str, t.Any]]
def __init__(
self, jinja_env: Environment, context: dict[str, t.Any] | None = None
) -> None:
super().__init__(jinja_env, context=context)
if current_app.config.get("DEBUG_TB_PROFILER_ENABLED"):
@@ -27,40 +42,48 @@ class ProfilerDebugPanel(DebugPanel):
"DEBUG_TB_PROFILER_DUMP_FILENAME"
)
def has_content(self):
@property
def has_content(self) -> bool: # type: ignore[override]
return bool(self.profiler)
def process_request(self, request):
def process_request(self, request: Request) -> None:
if not self.is_active:
return
self.profiler = profile.Profile()
self.profiler = profile.Profile() # pyright: ignore
self.stats = None
def process_view(self, request, view_func, view_kwargs):
def process_view(
self,
request: Request,
view_func: c.Callable[..., t.Any],
view_kwargs: dict[str, t.Any],
) -> c.Callable[..., t.Any] | None:
if self.is_active:
func = functools.partial(self.profiler.runcall, view_func)
functools.update_wrapper(func, view_func)
return func
def process_response(self, request, response):
return None
def process_response(self, request: Request, response: Response) -> None:
if not self.is_active:
return False
return
if self.profiler is not None:
self.profiler.disable()
self.profiler.disable() # pyright: ignore
try:
stats = pstats.Stats(self.profiler)
except TypeError:
self.is_active = False
return False
return
function_calls = []
function_calls: list[dict[str, t.Any]] = []
for func in stats.sort_stats(1).fcn_list:
current = {}
info = stats.stats[func]
for func in stats.sort_stats(1).fcn_list: # type: ignore[attr-defined]
current: dict[str, t.Any] = {}
info = stats.stats[func] # type: ignore[attr-defined]
# Number of calls
if info[0] != info[1]:
@@ -88,7 +111,7 @@ class ProfilerDebugPanel(DebugPanel):
current["percall_cum"] = 0
# Filename
filename = pstats.func_std_string(func)
filename = pstats.func_std_string(func) # type: ignore[attr-defined]
current["filename_long"] = filename
current["filename"] = format_fname(filename)
function_calls.append(current)
@@ -104,27 +127,25 @@ class ProfilerDebugPanel(DebugPanel):
self.profiler.dump_stats(filename)
return response
def title(self):
def title(self) -> str:
if not self.is_active:
return "Profiler not active"
return f"View: {float(self.stats.total_tt) * 1000:.2f}ms"
return f"View: {float(self.stats.total_tt) * 1000:.2f}ms" # type: ignore[union-attr]
def nav_title(self):
def nav_title(self) -> str:
return "Profiler"
def nav_subtitle(self):
def nav_subtitle(self) -> str:
if not self.is_active:
return "in-active"
return f"View: {float(self.stats.total_tt) * 1000:.2f}ms"
return f"View: {float(self.stats.total_tt) * 1000:.2f}ms" # type: ignore[union-attr]
def url(self):
def url(self) -> str:
return ""
def content(self):
def content(self) -> str:
if not self.is_active:
return "The profiler is not activated, activate it to use it"

View File

@@ -1,4 +1,10 @@
from __future__ import annotations
import collections.abc as c
import typing as t
from flask import session
from werkzeug import Request
from . import DebugPanel
@@ -9,27 +15,31 @@ class RequestVarsDebugPanel(DebugPanel):
name = "RequestVars"
has_content = True
def nav_title(self):
def nav_title(self) -> str:
return "Request Vars"
def title(self):
def title(self) -> str:
return "Request Vars"
def url(self):
def url(self) -> str:
return ""
def process_request(self, request):
def process_request(self, request: Request) -> None:
self.request = request
self.session = session
self.view_func = None
self.view_args = []
self.view_kwargs = {}
self.view_func: c.Callable[..., t.Any] | None = None
self.view_kwargs: dict[str, t.Any] = {}
def process_view(self, request, view_func, view_kwargs):
def process_view(
self,
request: Request,
view_func: c.Callable[..., t.Any],
view_kwargs: dict[str, t.Any],
) -> None:
self.view_func = view_func
self.view_kwargs = view_kwargs
def content(self):
def content(self) -> str:
context = self.context.copy()
context.update(
{
@@ -41,7 +51,6 @@ class RequestVarsDebugPanel(DebugPanel):
if self.view_func
else "[unknown]"
),
"view_args": self.view_args,
"view_kwargs": self.view_kwargs or {},
"session": self.session.items(),
}

View File

@@ -1,4 +1,8 @@
from __future__ import annotations
from flask import current_app
from werkzeug import Request
from werkzeug.routing import Rule
from . import DebugPanel
@@ -8,30 +12,30 @@ class RouteListDebugPanel(DebugPanel):
name = "RouteList"
has_content = True
routes = []
routes: list[Rule] = []
def nav_title(self):
def nav_title(self) -> str:
return "Route List"
def title(self):
def title(self) -> str:
return "Route List"
def url(self):
def url(self) -> str:
return ""
def nav_subtitle(self):
def nav_subtitle(self) -> str:
count = len(self.routes)
plural = "route" if count == 1 else "routes"
return f"{count} {plural}"
def process_request(self, request):
def process_request(self, request: Request) -> None:
self.routes = [
rule
for rule in current_app.url_map.iter_rules()
if not rule.rule.startswith("/_debug_toolbar")
]
def content(self):
def content(self) -> str:
return self.render(
"panels/route_list.html",
{

View File

@@ -1,3 +1,7 @@
from __future__ import annotations
import typing as t
import itsdangerous
from flask import abort
from flask import current_app
@@ -12,40 +16,48 @@ from . import DebugPanel
try:
from flask_sqlalchemy import SQLAlchemy
except ImportError:
sqlalchemy_available = False
get_recorded_queries = SQLAlchemy = None
debug_enables_record_queries = False
sqlalchemy_available: bool = False
get_recorded_queries = SQLAlchemy = None # type: ignore[misc, assignment]
debug_enables_record_queries: bool = False
else:
try:
from flask_sqlalchemy.record_queries import get_recorded_queries
from flask_sqlalchemy.record_queries import ( # type: ignore[assignment]
get_recorded_queries,
)
debug_enables_record_queries = False
except ImportError:
# For flask_sqlalchemy < 3.0.0
from flask_sqlalchemy import get_debug_queries as get_recorded_queries
from flask_sqlalchemy import ( # type: ignore[no-redef]
get_debug_queries as get_recorded_queries,
)
# flask_sqlalchemy < 3.0.0 automatically enabled
# SQLALCHEMY_RECORD_QUERIES in debug or test mode
debug_enables_record_queries = True
location_property = "context"
location_property: str = "context"
else:
location_property = "location"
sqlalchemy_available = True
def query_signer():
def query_signer() -> itsdangerous.URLSafeSerializer:
return itsdangerous.URLSafeSerializer(
current_app.config["SECRET_KEY"], salt="fdt-sql-query"
)
def is_select(statement):
prefix = b"select" if isinstance(statement, bytes) else "select"
return statement.lower().strip().startswith(prefix)
def is_select(statement: str | bytes) -> bool:
statement = statement.lower().strip()
if isinstance(statement, bytes):
return statement.startswith(b"select")
return statement.startswith("select") # pyright: ignore
def dump_query(statement, params):
def dump_query(statement: str, params: t.Any) -> str | None:
if not params or not is_select(statement):
return None
@@ -55,9 +67,9 @@ def dump_query(statement, params):
return None
def load_query(data):
def load_query(data: str) -> tuple[str, t.Any]:
try:
statement, params = query_signer().loads(request.args["query"])
statement, params = query_signer().loads(data)
except (itsdangerous.BadSignature, TypeError):
abort(406)
@@ -68,21 +80,21 @@ def load_query(data):
return statement, params
def extension_used():
def extension_used() -> bool:
return "sqlalchemy" in current_app.extensions
def recording_enabled():
def recording_enabled() -> bool:
return (
debug_enables_record_queries and current_app.debug
) or current_app.config.get("SQLALCHEMY_RECORD_QUERIES")
) or current_app.config.get("SQLALCHEMY_RECORD_QUERIES", False)
def is_available():
def is_available() -> bool:
return sqlalchemy_available and extension_used() and recording_enabled()
def get_queries():
def get_queries() -> list[t.Any]:
if get_recorded_queries:
return get_recorded_queries()
else:
@@ -95,19 +107,13 @@ class SQLAlchemyDebugPanel(DebugPanel):
name = "SQLAlchemy"
@property
def has_content(self):
def has_content(self) -> bool: # type: ignore[override]
return bool(get_queries()) or not is_available()
def process_request(self, request):
pass
def process_response(self, request, response):
pass
def nav_title(self):
def nav_title(self) -> str:
return "SQLAlchemy"
def nav_subtitle(self):
def nav_subtitle(self) -> str:
count = len(get_queries())
if not count and not is_available():
@@ -116,13 +122,13 @@ class SQLAlchemyDebugPanel(DebugPanel):
plural = "query" if count == 1 else "queries"
return f"{count} {plural}"
def title(self):
def title(self) -> str:
return "SQLAlchemy queries"
def url(self):
def url(self) -> str:
return ""
def content(self):
def content(self) -> str:
queries = get_queries()
if not queries and not is_available():
@@ -158,9 +164,9 @@ class SQLAlchemyDebugPanel(DebugPanel):
@module.route(
"/sqlalchemy/sql_explain", methods=["GET", "POST"], defaults=dict(explain=True)
)
def sql_select(explain=False):
def sql_select(explain: bool = False) -> str:
statement, params = load_query(request.args["query"])
engine = SQLAlchemy().get_engine(current_app)
engine = current_app.extensions["sqlalchemy"].engine
if explain:
if engine.driver == "pysqlite":
@@ -169,7 +175,7 @@ def sql_select(explain=False):
statement = f"EXPLAIN\n{statement}"
result = engine.execute(statement, params)
return g.debug_toolbar.render(
return g.debug_toolbar.render( # type: ignore[no-any-return]
"panels/sqlalchemy_select.html",
{
"result": result.fetchall(),

View File

@@ -1,7 +1,10 @@
import collections
from __future__ import annotations
import json
import sys
import typing as t
import uuid
from collections import deque
from flask import abort
from flask import current_app
@@ -10,6 +13,7 @@ from flask import request
from flask import Response
from flask import template_rendered
from flask import url_for
from jinja2 import Template
from .. import module
from . import DebugPanel
@@ -22,23 +26,23 @@ class TemplateDebugPanel(DebugPanel):
has_content = True
# save the context for the 5 most recent requests
template_cache = collections.deque(maxlen=5)
template_cache: deque[tuple[str, list[dict[str, t.Any]]]] = deque(maxlen=5)
@classmethod
def get_cache_for_key(self, key):
for cache_key, value in self.template_cache:
def get_cache_for_key(cls, key: str) -> list[dict[str, t.Any]]:
for cache_key, value in cls.template_cache:
if key == cache_key:
return value
raise KeyError(key)
def __init__(self, *args, **kwargs):
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)
self.key = str(uuid.uuid4())
self.templates = []
self.key: str = str(uuid.uuid4())
self.templates: list[dict[str, t.Any]] = []
template_rendered.connect(self._store_template_info)
def _store_template_info(self, sender, **kwargs):
def _store_template_info(self, sender: t.Any, **kwargs: t.Any) -> None:
# only record in the cache if the editor is enabled and there is
# actually a template for this request
if not self.templates and is_editor_enabled():
@@ -46,25 +50,19 @@ class TemplateDebugPanel(DebugPanel):
self.templates.append(kwargs)
def process_request(self, request):
pass
def process_response(self, request, response):
pass
def nav_title(self):
def nav_title(self) -> str:
return "Templates"
def nav_subtitle(self):
def nav_subtitle(self) -> str:
return f"{len(self.templates)} rendered"
def title(self):
def title(self) -> str:
return "Templates"
def url(self):
def url(self) -> str:
return ""
def content(self):
def content(self) -> str:
return self.render(
"panels/template.html",
{
@@ -75,33 +73,36 @@ class TemplateDebugPanel(DebugPanel):
)
def is_editor_enabled():
return current_app.config.get("DEBUG_TB_TEMPLATE_EDITOR_ENABLED")
def is_editor_enabled() -> bool:
return current_app.config.get("DEBUG_TB_TEMPLATE_EDITOR_ENABLED", False) # type: ignore
def require_enabled():
def require_enabled() -> None:
if not is_editor_enabled():
abort(403)
def _get_source(template):
def _get_source(template: Template) -> str:
if template.filename is None:
return ""
with open(template.filename, "rb") as fp:
source = fp.read()
return source.decode(_template_encoding())
def _template_encoding():
def _template_encoding() -> str:
return getattr(current_app.jinja_loader, "encoding", "utf-8")
@module.route("/template/<key>")
def template_editor(key):
def template_editor(key: str) -> str:
require_enabled()
# TODO set up special loader that caches templates it loads
# and can override template contents
templates = [t["template"] for t in TemplateDebugPanel.get_cache_for_key(key)]
return g.debug_toolbar.render(
return g.debug_toolbar.render( # type: ignore[no-any-return]
"panels/template_editor.html",
{
"static_path": url_for("_debug_toolbar.static", filename=""),
@@ -114,7 +115,7 @@ def template_editor(key):
@module.route("/template/<key>/save", methods=["POST"])
def save_template(key):
def save_template(key: str) -> str:
require_enabled()
template = TemplateDebugPanel.get_cache_for_key(key)[0]["template"]
content = request.form["content"].encode(_template_encoding())
@@ -126,7 +127,7 @@ def save_template(key):
@module.route("/template/<key>", methods=["POST"])
def template_preview(key):
def template_preview(key: str) -> str | Response:
require_enabled()
context = TemplateDebugPanel.get_cache_for_key(key)[0]["context"]
content = request.form["content"]
@@ -139,10 +140,10 @@ def template_preview(key):
tb = sys.exc_info()[2]
try:
while tb.tb_next:
tb = tb.tb_next
while tb.tb_next: # type: ignore[union-attr]
tb = tb.tb_next # type: ignore[union-attr]
msg = {"lineno": tb.tb_lineno, "error": str(e)}
msg = {"lineno": tb.tb_lineno, "error": str(e)} # type: ignore[union-attr]
return Response(json.dumps(msg), status=400, mimetype="application/json")
finally:
del tb

View File

@@ -1,5 +1,10 @@
from __future__ import annotations
import time
from werkzeug import Request
from werkzeug import Response
from . import DebugPanel
try:
@@ -16,22 +21,22 @@ class TimerDebugPanel(DebugPanel):
name = "Timer"
has_content = HAVE_RESOURCE
def process_request(self, request):
def process_request(self, request: Request) -> None:
self._start_time = time.time()
if HAVE_RESOURCE:
self._start_rusage = resource.getrusage(resource.RUSAGE_SELF)
def process_response(self, request, response):
self.total_time = (time.time() - self._start_time) * 1000
def process_response(self, request: Request, response: Response) -> None:
self.total_time: float = (time.time() - self._start_time) * 1000
if HAVE_RESOURCE:
self._end_rusage = resource.getrusage(resource.RUSAGE_SELF)
def nav_title(self):
def nav_title(self) -> str:
return "Time"
def nav_subtitle(self):
def nav_subtitle(self) -> str:
if not HAVE_RESOURCE:
return f"TOTAL: {self.total_time:0.2f}ms"
@@ -39,16 +44,16 @@ class TimerDebugPanel(DebugPanel):
stime = self._end_rusage.ru_stime - self._start_rusage.ru_stime
return f"CPU: {(utime + stime) * 1000.0:0.2f}ms ({self.total_time:0.2f}ms)"
def title(self):
def title(self) -> str:
return "Resource Usage"
def url(self):
def url(self) -> str:
return ""
def _elapsed_ru(self, name):
return getattr(self._end_rusage, name) - getattr(self._start_rusage, name)
def _elapsed_ru(self, name: str) -> float:
return getattr(self._end_rusage, name) - getattr(self._start_rusage, name) # type: ignore[no-any-return]
def content(self):
def content(self) -> str:
utime = 1000 * self._elapsed_ru("ru_utime")
stime = 1000 * self._elapsed_ru("ru_stime")
vcsw = self._elapsed_ru("ru_nvcsw")

View File

@@ -1,10 +1,12 @@
from __future__ import annotations
import importlib.metadata
import os
from sysconfig import get_path
from . import DebugPanel
flask_version = importlib.metadata.version("flask")
flask_version: str = importlib.metadata.version("flask")
class VersionDebugPanel(DebugPanel):
@@ -13,19 +15,19 @@ class VersionDebugPanel(DebugPanel):
name = "Version"
has_content = True
def nav_title(self):
def nav_title(self) -> str:
return "Versions"
def nav_subtitle(self):
def nav_subtitle(self) -> str:
return f"Flask {flask_version}"
def url(self):
def url(self) -> str:
return ""
def title(self):
def title(self) -> str:
return "Versions"
def content(self):
def content(self) -> str:
packages_metadata = [p.metadata for p in importlib.metadata.distributions()]
packages = sorted(packages_metadata, key=lambda p: p["Name"].lower())
return self.render(

View File

View File

@@ -4,14 +4,12 @@
<thead>
<tr>
<th>View Function</th>
<th>args</th>
<th>kwargs</th>
</tr>
</thead>
<tbody>
<tr>
<td>{{ view_func }}</td>
<td>{{ view_args|default("None") }}</td>
<td>
{% if view_kwargs.items() %}
{% for k, v in view_kwargs.items() %}

View File

@@ -1,28 +1,35 @@
from __future__ import annotations
import collections.abc as c
import typing as t
from urllib.parse import unquote
from flask import current_app
from flask import Flask
from flask import url_for
from jinja2 import Environment
from werkzeug import Request
from werkzeug.utils import import_string
from .panels import DebugPanel
class DebugToolbar:
_cached_panel_classes = {}
_cached_panel_classes: t.ClassVar[dict[str, type[DebugPanel] | None]] = {}
def __init__(self, request, jinja_env):
def __init__(self, request: Request, jinja_env: Environment) -> None:
self.jinja_env = jinja_env
self.request = request
self.panels = []
self.template_context = {
self.panels: list[DebugPanel] = []
self.template_context: dict[str, t.Any] = {
"static_path": url_for("_debug_toolbar.static", filename="")
}
self.create_panels()
def create_panels(self):
def create_panels(self) -> None:
"""Populate debug panels"""
activated = self.request.cookies.get("fldt_active", "")
activated = unquote(activated).split(";")
activated_str = self.request.cookies.get("fldt_active", "")
activated = unquote(activated_str).split(";")
for panel_class in self._iter_panels(current_app):
panel_instance = panel_class(
@@ -34,21 +41,20 @@ class DebugToolbar:
self.panels.append(panel_instance)
def render_toolbar(self):
def render_toolbar(self) -> str:
context = self.template_context.copy()
context.update({"panels": self.panels})
template = self.jinja_env.get_template("base.html")
return template.render(**context)
@classmethod
def load_panels(cls, app):
def load_panels(cls, app: Flask) -> None:
for panel_class in cls._iter_panels(app):
# Call `.init_app()` on panels
panel_class.init_app(app)
@classmethod
def _iter_panels(cls, app):
def _iter_panels(cls, app: Flask) -> c.Iterator[type[DebugPanel]]:
for panel_path in app.config["DEBUG_TB_PANELS"]:
panel_class = cls._import_panel(app, panel_path)
@@ -56,7 +62,7 @@ class DebugToolbar:
yield panel_class
@classmethod
def _import_panel(cls, app, path):
def _import_panel(cls, app: Flask, path: str) -> type[DebugPanel] | None:
cache = cls._cached_panel_classes
try:
@@ -65,7 +71,7 @@ class DebugToolbar:
pass
try:
panel_class = import_string(path)
panel_class: type[DebugPanel] | None = import_string(path)
except ImportError as e:
app.logger.warning("Disabled %s due to ImportError: %s", path, e)
panel_class = None

View File

@@ -1,8 +1,12 @@
from __future__ import annotations
import collections.abc as c
import gzip
import io
import itertools
import os.path
import sys
from types import ModuleType
from flask import current_app
from markupsafe import Markup
@@ -19,14 +23,14 @@ except ImportError:
HAVE_PYGMENTS = False
try:
import sqlparse
import sqlparse # pyright: ignore
HAVE_SQLPARSE = True
except ImportError:
HAVE_SQLPARSE = False
def format_fname(value):
def format_fname(value: str) -> str:
# If the value has a builtin prefix, return it unchanged
if value.startswith(("{", "<")):
return value
@@ -46,12 +50,16 @@ def format_fname(value):
return f"<{_shortest_relative_path(value, sys.path, os.path)}>"
def _shortest_relative_path(value, paths, path_module):
def _shortest_relative_path(
value: str, paths: list[str], path_module: ModuleType
) -> str:
relpaths = _relative_paths(value, paths, path_module)
return min(itertools.chain(relpaths, [value]), key=len)
def _relative_paths(value, paths, path_module):
def _relative_paths(
value: str, paths: list[str], path_module: ModuleType
) -> c.Iterator[str]:
for path in paths:
try:
relval = path_module.relpath(value, path)
@@ -64,7 +72,7 @@ def _relative_paths(value, paths, path_module):
yield relval
def decode_text(value):
def decode_text(value: str | bytes) -> str:
"""
Decode a text-like value for display.
@@ -73,11 +81,11 @@ def decode_text(value):
"""
if isinstance(value, bytes):
return value.decode("ascii", "replace")
else:
return value
return value # pyright: ignore
def format_sql(query, args):
def format_sql(query: str | bytes, args: object) -> str:
if HAVE_SQLPARSE:
query = sqlparse.format(query, reindent=True, keyword_case="upper")
@@ -89,7 +97,7 @@ def format_sql(query, args):
)
def gzip_compress(data, compresslevel=6):
def gzip_compress(data: bytes, compresslevel: int = 6) -> bytes:
buff = io.BytesIO()
with gzip.GzipFile(fileobj=buff, mode="wb", compresslevel=compresslevel) as f:
@@ -98,6 +106,6 @@ def gzip_compress(data, compresslevel=6):
return buff.getvalue()
def gzip_decompress(data):
def gzip_decompress(data: bytes) -> bytes:
with gzip.GzipFile(fileobj=io.BytesIO(data), mode="rb") as f:
return f.read()

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
from flask import Flask
from flask import render_template
from flask_sqlalchemy import SQLAlchemy
@@ -21,13 +23,13 @@ toolbar = DebugToolbarExtension(app)
db = SQLAlchemy(app)
class Foo(db.Model):
class Foo(db.Model): # type: ignore[name-defined, misc]
__tablename__ = "foo"
id = db.Column(db.Integer, primary_key=True)
@app.route("/")
def index():
def index() -> str:
Foo.query.filter_by(id=1).all()
return render_template("basic_app.html")

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
import pytest
@pytest.fixture(autouse=True)
def mock_env_development(monkeypatch):
def mock_env_development(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("FLASK_ENV", "development")

View File

@@ -1,10 +1,16 @@
def load_app(name):
app = __import__(name).app
from __future__ import annotations
from flask import Flask
from flask.testing import FlaskClient
def load_app(name: str) -> FlaskClient:
app: Flask = __import__(name).app
app.config["TESTING"] = True
return app.test_client()
def test_basic_app():
def test_basic_app() -> None:
app = load_app("basic_app")
index = app.get("/")
assert index.status_code == 200

View File

@@ -1,5 +1,8 @@
from __future__ import annotations
import ntpath
import posixpath
from types import ModuleType
import pytest
from markupsafe import escape
@@ -34,7 +37,9 @@ from flask_debugtoolbar.utils import HAVE_PYGMENTS
("c:\\Foo\\Bar", ["c:\\foo"], ["Bar"], ntpath),
],
)
def test_relative_paths(value, paths, expected, path_module):
def test_relative_paths(
value: str, paths: list[str], expected: list[str], path_module: ModuleType
) -> None:
assert list(_relative_paths(value, paths, path_module)) == expected
@@ -52,22 +57,24 @@ def test_relative_paths(value, paths, expected, path_module):
("c:\\foo\\bar\\baz", ["c:\\foo", "c:\\foo\\bar"], "baz", ntpath),
],
)
def test_shortest_relative_path(value, paths, expected, path_module):
def test_shortest_relative_path(
value: str, paths: list[str], expected: str, path_module: ModuleType
) -> None:
assert _shortest_relative_path(value, paths, path_module) == expected
def test_decode_text_unicode():
def test_decode_text_unicode() -> None:
value = "\uffff"
decoded = decode_text(value)
assert decoded == value
def test_decode_text_ascii():
def test_decode_text_ascii() -> None:
value = "abc"
assert decode_text(value.encode("ascii")) == value
def test_decode_text_non_ascii():
def test_decode_text_non_ascii() -> None:
value = b"abc \xff xyz"
assert isinstance(value, bytes)
@@ -79,22 +86,25 @@ def test_decode_text_non_ascii():
@pytest.fixture()
def no_pygments(monkeypatch):
def no_pygments(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("flask_debugtoolbar.utils.HAVE_PYGMENTS", False)
def test_format_sql_no_pygments(no_pygments):
@pytest.mark.usefixtures("no_pygments")
def test_format_sql_no_pygments() -> None:
sql = "select 1"
assert format_sql(sql, {}) == sql
def test_format_sql_no_pygments_non_ascii(no_pygments):
@pytest.mark.usefixtures("no_pygments")
def test_format_sql_no_pygments_non_ascii() -> None:
sql = b"select '\xff'"
formatted = format_sql(sql, {})
assert formatted.startswith("select '")
def test_format_sql_no_pygments_escape_html(no_pygments):
@pytest.mark.usefixtures("no_pygments")
def test_format_sql_no_pygments_escape_html() -> None:
sql = "select x < 1"
formatted = format_sql(sql, {})
assert not isinstance(formatted, Markup)
@@ -102,7 +112,7 @@ def test_format_sql_no_pygments_escape_html(no_pygments):
@pytest.mark.skipif(not HAVE_PYGMENTS, reason='test requires the "Pygments" library')
def test_format_sql_pygments():
def test_format_sql_pygments() -> None:
sql = "select 1"
html = format_sql(sql, {})
assert isinstance(html, Markup)
@@ -112,7 +122,7 @@ def test_format_sql_pygments():
@pytest.mark.skipif(not HAVE_PYGMENTS, reason='test requires the "Pygments" library')
def test_format_sql_pygments_non_ascii():
def test_format_sql_pygments_non_ascii() -> None:
sql = b"select 'abc \xff xyz'"
html = format_sql(sql, {})
assert isinstance(html, Markup)

View File

@@ -3,6 +3,7 @@ envlist =
py3{12,11,10,9,8}
minimal
style
typing
docs
skip_missing_interpreters = true
@@ -41,6 +42,7 @@ skip_install = true
commands = pre-commit autoupdate -j4
[testenv:update-requirements]
base_python = 3.8
labels = update
deps = pip-tools
skip_install = true