improvements to arequests

This commit is contained in:
Jakob Pinterits
2024-11-12 16:51:52 +01:00
parent 595eff08d0
commit e8829abb81
2 changed files with 156 additions and 28 deletions

View File

@@ -10,9 +10,43 @@ import json as json_module
import ssl
import typing as t
import urllib.error
import urllib.parse
import urllib.request
import urllib.response
# Re-export JSONDecodeError, since it can be raised by this module
from json import JSONDecodeError as JSONDecodeError
__all__ = [
"HttpError",
"HttpResponse",
"JSONDecodeError",
"request",
"request_sync",
]
HttpMethod = t.Literal[
"get",
"GET",
"head",
"HEAD",
"options",
"OPTIONS",
"trace",
"TRACE",
"put",
"PUT",
"delete",
"DELETE",
"post",
"POST",
"patch",
"PATCH",
"connect",
"CONNECT",
]
HTTP_METHOD_VALUES = t.get_args(HttpMethod)
class HttpError(Exception):
@@ -25,6 +59,10 @@ class HttpError(Exception):
## Attributes
`method`: The HTTP method used in the failed request
`url`: The target URL of the failed request
`message`: A human-readable error message
`status_code`: The HTTP status code, if applicable. This will be `None` if
@@ -34,19 +72,35 @@ class HttpError(Exception):
def __init__(
self,
message: str,
*,
method: HttpMethod,
url: str,
status_code: int | None,
message: str,
) -> None:
super().__init__(message, status_code)
super().__init__(
method,
url,
status_code,
message,
)
@property
def message(self) -> str:
def method(self) -> str:
return self.args[0]
@property
def status_code(self) -> int | None:
def url(self) -> str:
return self.args[1]
@property
def status_code(self) -> int | None:
return self.args[2]
@property
def message(self) -> str:
return self.args[3]
class HttpResponse:
"""
@@ -80,8 +134,15 @@ class HttpResponse:
def json(self) -> t.Any:
"""
Returns the response body as a JSON object. Raises a
`json.JSONDecodeError` if the response body is not valid JSON.
Returns the response body as a JSON object.
## Raises
`json.JSONDecodeError`: If the response body is not valid JSON. To
simplify your error handling, this will also consider invalid UTF-8
to be invalid JSON. (Which is in line with the JSON spec, which
requires all JSON documents to be encoded in UTF-8.)
"""
try:
@@ -95,7 +156,7 @@ class HttpResponse:
def request_sync(
method: t.Literal["get", "post"],
method: HttpMethod,
url: str,
*,
content: str | bytes | None = None,
@@ -118,7 +179,9 @@ def request_sync(
`None` to make a request without a body.
`json`: Convenience parameter to specify the body of the request as a JSON
object. If this is specified, `content` must be `None`.
object. This will also change the `content-type` header to
`"application/json"`. If this parameter is specified, `content` must be
`None`.
`headers`: Additional headers to include in the request. Any headers
provided here will override the default headers.
@@ -130,43 +193,55 @@ def request_sync(
`ValueError`: If the `method` is not one of `"get"` or `"post"`
`HttpError`: If the request fails for any reason
`HttpError`: If the request fails for any reason. This includes external
errors such as network issues, as well as any non-200 status codes.
"""
# Verify the method
if method not in ("get", "post"):
if method not in HTTP_METHOD_VALUES:
raise ValueError("Invalid method")
# Prepare the full headers
# Prepare a set of default headers
all_headers = {
"user-agent": "rio.arequests/0.1",
"content-type": "application/octet-stream",
}
if headers:
for header, value in headers.items():
all_headers[header.lower()] = value
# Prepare the request
req = urllib.request.Request(
url,
method=method.upper(),
)
for key, value in all_headers.items():
req.add_header(key, value)
# If a JSON object was provided, use it as content and also set the
# content-type header
if json:
if content is not None:
raise ValueError("Cannot specify both `content` and `json`")
content = json_module.dumps(json)
all_headers["content-type"] = "application/json"
# If content was provided add it to the request
if content:
if isinstance(content, str):
content = content.encode("utf-8")
req.data = content
# Add the headers
#
# User-provided headers override the default headers. Take care to do this
# late so that the logic above (e.g. JSON) can modify the default headers.
if headers:
for header, value in headers.items():
all_headers[header.lower()] = value
for key, value in all_headers.items():
req.add_header(key, value)
# Prepare the SSL context
#
# The default context verifies SSL. If that is not desired, override it.
if verify_ssl:
ssl_context = None
else:
@@ -180,8 +255,10 @@ def request_sync(
# Check the status code
if response.status >= 300:
raise HttpError(
response.reason,
response.status,
method=method,
url=url,
status_code=response.status,
message=response.reason,
)
# Epic success!
@@ -195,19 +272,23 @@ def request_sync(
except urllib.error.HTTPError as e:
raise HttpError(
e.reason,
e.code,
method=method,
url=url,
status_code=e.code,
message=e.reason,
) from None
except urllib.error.URLError as e:
raise HttpError(
str(e.reason),
None,
method=method,
url=url,
status_code=None,
message=str(e.reason),
) from None
async def request(
method: t.Literal["get", "post"],
method: HttpMethod,
url: str,
*,
content: bytes | None = None,
@@ -242,7 +323,8 @@ async def request(
`ValueError`: If the `method` is not one of `"get"` or `"post"`
`HttpError`: If the request fails for any reason
`HttpError`: If the request fails for any reason. This includes external
errors such as network issues, as well as any non-200 status codes.
"""
return await asyncio.to_thread(

46
tests/test_arequests.py Normal file
View File

@@ -0,0 +1,46 @@
import json
import pytest
import rio.arequests as arequests
def test_http_response_invalid_json() -> None:
response = arequests.HttpResponse(
status_code=200,
headers={},
content=b"invalid json",
)
with pytest.raises(json.JSONDecodeError, match="Expecting value"):
response.json()
def test_http_response_invalid_utf8() -> None:
response = arequests.HttpResponse(
status_code=200,
headers={},
content=b"\xff",
)
with pytest.raises(json.JSONDecodeError, match="UTF-8"):
response.json()
def test_request() -> None:
response = arequests.request_sync(
"get",
"https://postman-echo.com/get",
json={"foo": "bar"},
)
assert response.status_code == 200
assert response.headers["content-type"] == "application/json; charset=utf-8"
response_json = response.json()
assert response_json["headers"]["user-agent"] == "rio.arequests/0.1"
assert response_json["headers"]["host"] == "postman-echo.com"
assert response_json["headers"]["content-type"] == "application/json"
assert response_json["headers"]["content-length"] == "14"
assert response_json["args"] == {"foo": "bar"}