From e8829abb81c0274029c460fd093215b4e86724e4 Mon Sep 17 00:00:00 2001 From: Jakob Pinterits Date: Tue, 12 Nov 2024 16:51:52 +0100 Subject: [PATCH] improvements to arequests --- rio/arequests.py | 138 ++++++++++++++++++++++++++++++++-------- tests/test_arequests.py | 46 ++++++++++++++ 2 files changed, 156 insertions(+), 28 deletions(-) create mode 100644 tests/test_arequests.py diff --git a/rio/arequests.py b/rio/arequests.py index 75ea31e6..c4493f52 100644 --- a/rio/arequests.py +++ b/rio/arequests.py @@ -10,9 +10,43 @@ import json as json_module import ssl import typing as t import urllib.error -import urllib.parse import urllib.request -import urllib.response + +# Re-export JSONDecodeError, since it can be raised by this module +from json import JSONDecodeError as JSONDecodeError + +__all__ = [ + "HttpError", + "HttpResponse", + "JSONDecodeError", + "request", + "request_sync", +] + + +HttpMethod = t.Literal[ + "get", + "GET", + "head", + "HEAD", + "options", + "OPTIONS", + "trace", + "TRACE", + "put", + "PUT", + "delete", + "DELETE", + "post", + "POST", + "patch", + "PATCH", + "connect", + "CONNECT", +] + + +HTTP_METHOD_VALUES = t.get_args(HttpMethod) class HttpError(Exception): @@ -25,6 +59,10 @@ class HttpError(Exception): ## Attributes + `method`: The HTTP method used in the failed request + + `url`: The target URL of the failed request + `message`: A human-readable error message `status_code`: The HTTP status code, if applicable. This will be `None` if @@ -34,19 +72,35 @@ class HttpError(Exception): def __init__( self, - message: str, + *, + method: HttpMethod, + url: str, status_code: int | None, + message: str, ) -> None: - super().__init__(message, status_code) + super().__init__( + method, + url, + status_code, + message, + ) @property - def message(self) -> str: + def method(self) -> str: return self.args[0] @property - def status_code(self) -> int | None: + def url(self) -> str: return self.args[1] + @property + def status_code(self) -> int | None: + return self.args[2] + + @property + def message(self) -> str: + return self.args[3] + class HttpResponse: """ @@ -80,8 +134,15 @@ class HttpResponse: def json(self) -> t.Any: """ - Returns the response body as a JSON object. Raises a - `json.JSONDecodeError` if the response body is not valid JSON. + Returns the response body as a JSON object. + + + ## Raises + + `json.JSONDecodeError`: If the response body is not valid JSON. To + simplify your error handling, this will also consider invalid UTF-8 + to be invalid JSON. (Which is in line with the JSON spec, which + requires all JSON documents to be encoded in UTF-8.) """ try: @@ -95,7 +156,7 @@ class HttpResponse: def request_sync( - method: t.Literal["get", "post"], + method: HttpMethod, url: str, *, content: str | bytes | None = None, @@ -118,7 +179,9 @@ def request_sync( `None` to make a request without a body. `json`: Convenience parameter to specify the body of the request as a JSON - object. If this is specified, `content` must be `None`. + object. This will also change the `content-type` header to + `"application/json"`. If this parameter is specified, `content` must be + `None`. `headers`: Additional headers to include in the request. Any headers provided here will override the default headers. @@ -130,43 +193,55 @@ def request_sync( `ValueError`: If the `method` is not one of `"get"` or `"post"` - `HttpError`: If the request fails for any reason + `HttpError`: If the request fails for any reason. This includes external + errors such as network issues, as well as any non-200 status codes. """ # Verify the method - if method not in ("get", "post"): + if method not in HTTP_METHOD_VALUES: raise ValueError("Invalid method") - # Prepare the full headers + # Prepare a set of default headers all_headers = { "user-agent": "rio.arequests/0.1", + "content-type": "application/octet-stream", } - if headers: - for header, value in headers.items(): - all_headers[header.lower()] = value - # Prepare the request req = urllib.request.Request( url, method=method.upper(), ) - for key, value in all_headers.items(): - req.add_header(key, value) - + # If a JSON object was provided, use it as content and also set the + # content-type header if json: if content is not None: raise ValueError("Cannot specify both `content` and `json`") content = json_module.dumps(json) + all_headers["content-type"] = "application/json" + # If content was provided add it to the request if content: if isinstance(content, str): content = content.encode("utf-8") req.data = content + # Add the headers + # + # User-provided headers override the default headers. Take care to do this + # late so that the logic above (e.g. JSON) can modify the default headers. + if headers: + for header, value in headers.items(): + all_headers[header.lower()] = value + + for key, value in all_headers.items(): + req.add_header(key, value) + # Prepare the SSL context + # + # The default context verifies SSL. If that is not desired, override it. if verify_ssl: ssl_context = None else: @@ -180,8 +255,10 @@ def request_sync( # Check the status code if response.status >= 300: raise HttpError( - response.reason, - response.status, + method=method, + url=url, + status_code=response.status, + message=response.reason, ) # Epic success! @@ -195,19 +272,23 @@ def request_sync( except urllib.error.HTTPError as e: raise HttpError( - e.reason, - e.code, + method=method, + url=url, + status_code=e.code, + message=e.reason, ) from None except urllib.error.URLError as e: raise HttpError( - str(e.reason), - None, + method=method, + url=url, + status_code=None, + message=str(e.reason), ) from None async def request( - method: t.Literal["get", "post"], + method: HttpMethod, url: str, *, content: bytes | None = None, @@ -242,7 +323,8 @@ async def request( `ValueError`: If the `method` is not one of `"get"` or `"post"` - `HttpError`: If the request fails for any reason + `HttpError`: If the request fails for any reason. This includes external + errors such as network issues, as well as any non-200 status codes. """ return await asyncio.to_thread( diff --git a/tests/test_arequests.py b/tests/test_arequests.py new file mode 100644 index 00000000..5217c9aa --- /dev/null +++ b/tests/test_arequests.py @@ -0,0 +1,46 @@ +import json + +import pytest + +import rio.arequests as arequests + + +def test_http_response_invalid_json() -> None: + response = arequests.HttpResponse( + status_code=200, + headers={}, + content=b"invalid json", + ) + + with pytest.raises(json.JSONDecodeError, match="Expecting value"): + response.json() + + +def test_http_response_invalid_utf8() -> None: + response = arequests.HttpResponse( + status_code=200, + headers={}, + content=b"\xff", + ) + + with pytest.raises(json.JSONDecodeError, match="UTF-8"): + response.json() + + +def test_request() -> None: + response = arequests.request_sync( + "get", + "https://postman-echo.com/get", + json={"foo": "bar"}, + ) + + assert response.status_code == 200 + assert response.headers["content-type"] == "application/json; charset=utf-8" + + response_json = response.json() + + assert response_json["headers"]["user-agent"] == "rio.arequests/0.1" + assert response_json["headers"]["host"] == "postman-echo.com" + assert response_json["headers"]["content-type"] == "application/json" + assert response_json["headers"]["content-length"] == "14" + assert response_json["args"] == {"foo": "bar"}