From 66b601ea22cf022abbbe881ce4d2fe736173d5f4 Mon Sep 17 00:00:00 2001 From: Oleg S Date: Fri, 14 Oct 2022 17:32:10 +0300 Subject: [PATCH] Fix Keep-Alive TCP connections (#18) * Fix using SO_REUSEPORT flag (Win32 not supported) * Fix memleak on close peer connection * Fix reinit Request struct into on_message_begin Fix reinit wsgi_input object * Fix define global objects (transfer to server.c) * Integrate Request struct with client_t struct * Add support Keep-Alive TCP connections * Fix build with local libuv sources * Improve function set_header (allow zero length argument) * Fix write response buffer to tcp stream * Add function shutdown_connection --- Makefile | 2 +- fastwsgi/constants.h | 7 ++- fastwsgi/filewrapper.h | 7 ++- fastwsgi/request.c | 105 ++++++++++++++++++++++-------------- fastwsgi/request.h | 19 ++----- fastwsgi/server.c | 119 ++++++++++++++++++++++++----------------- fastwsgi/server.h | 39 +++++++++----- 7 files changed, 178 insertions(+), 120 deletions(-) diff --git a/Makefile b/Makefile index 55caabd..962e34b 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,3 @@ server: fastwsgi/server.c mkdir -p bin - gcc -Illhttp/include llhttp/src/*.c fastwsgi/request.c fastwsgi/server.c fastwsgi/constants.c -o bin/server -luv -I/usr/include/python3.8 -lpython3.8 -O3 \ No newline at end of file + gcc -Illhttp/include -Ilibuv/include -Ilibuv/src llhttp/src/*.c fastwsgi/request.c fastwsgi/server.c fastwsgi/constants.c -o bin/server -luv -I/usr/include/python3.8 -lpython3.8 -O3 \ No newline at end of file diff --git a/fastwsgi/constants.h b/fastwsgi/constants.h index 8711a37..04b529e 100644 --- a/fastwsgi/constants.h +++ b/fastwsgi/constants.h @@ -1,3 +1,6 @@ +#ifndef FASTWSGI_CONSTANTS_H_ +#define FASTWSGI_CONSTANTS_H_ + #include PyObject* REQUEST_METHOD, * SCRIPT_NAME, * SERVER_NAME, * SERVER_PORT, * SERVER_PROTOCOL, * QUERY_STRING; @@ -6,4 +9,6 @@ PyObject* http_scheme, * HTTP_1_1, * HTTP_1_0; PyObject* server_host, * server_port, * empty_string; PyObject* HTTP_, * PATH_INFO, * wsgi_input; -void init_constants(); \ No newline at end of file +void init_constants(); + +#endif diff --git a/fastwsgi/filewrapper.h b/fastwsgi/filewrapper.h index fb9e138..e52b760 100644 --- a/fastwsgi/filewrapper.h +++ b/fastwsgi/filewrapper.h @@ -1,3 +1,6 @@ +#ifndef FASTWSGI_FILEWRAPPER_H_ +#define FASTWSGI_FILEWRAPPER_H_ + #include typedef struct { @@ -10,4 +13,6 @@ PyTypeObject FileWrapper_Type; #define FileWrapper_CheckExact(object) ((object)->ob_type == &FileWrapper_Type) -void FileWrapper_Init(void); \ No newline at end of file +void FileWrapper_Init(void); + +#endif diff --git a/fastwsgi/request.c b/fastwsgi/request.c index a1d9586..ebabf79 100644 --- a/fastwsgi/request.c +++ b/fastwsgi/request.c @@ -17,7 +17,8 @@ static void reprint(PyObject* obj) { static void set_header(PyObject* headers, const char* key, const char* value, size_t length) { logger("setting header"); - PyObject* item = PyUnicode_FromStringAndSize(value, length); + int vlen = (length > 0) ? (int)length : (int)strlen(value); + PyObject* item = PyUnicode_FromStringAndSize(value, vlen); PyObject* existing_item = PyDict_GetItemString(headers, key); if (existing_item) { @@ -38,15 +39,41 @@ static void set_header(PyObject* headers, const char* key, const char* value, si int on_message_begin(llhttp_t* parser) { logger("on message begin"); - Request* request = (Request*)parser->data; - request->headers = PyDict_Copy(base_dict); - request->response_buffer.len = 0; + client_t * client = (client_t *)parser->data; + client->request.state.keep_alive = 0; + client->request.state.error = 0; + if (client->response.buffer.base) + free(client->response.buffer.base); + client->response.buffer.base = NULL; + client->response.buffer.len = 0; + if (client->request.headers == NULL) { + PyObject* headers = PyDict_Copy(base_dict); + // Sets up base request dict for new incoming requests + // https://www.python.org/dev/peps/pep-3333/#specification-details + PyObject* io = PyImport_ImportModule("io"); + PyObject* BytesIO = PyUnicode_FromString("BytesIO"); + PyObject* io_BytesIO = PyObject_CallMethodObjArgs(io, BytesIO, NULL); + PyDict_SetItem(headers, wsgi_input, io_BytesIO); + client->request.headers = headers; + Py_DECREF(BytesIO); + Py_DECREF(io); + } else { + PyObject* input = PyDict_GetItem(client->request.headers, wsgi_input); + PyObject* truncate = PyUnicode_FromString("truncate"); + PyObject* result1 = PyObject_CallMethodObjArgs(input, truncate, PyLong_FromLong(0L), NULL); + Py_DECREF(truncate); + Py_DECREF(result1); + PyObject* seek = PyUnicode_FromString("seek"); + PyObject* result2 = PyObject_CallMethodObjArgs(input, seek, PyLong_FromLong(0L), NULL); + Py_DECREF(seek); + Py_DECREF(result2); + } return 0; }; int on_url(llhttp_t* parser, const char* data, size_t length) { logger("on url"); - Request* request = (Request*)parser->data; + client_t * client = (client_t *)parser->data; char* url = malloc(length + 1); strncpy(url, data, length); @@ -55,9 +82,9 @@ int on_url(llhttp_t* parser, const char* data, size_t length) { char* query_string = strchr(url, '?'); if (query_string) { *query_string = 0; - set_header(request->headers, "QUERY_STRING", query_string + 1, strlen(query_string + 1)); + set_header(client->request.headers, "QUERY_STRING", query_string + 1, strlen(query_string + 1)); } - set_header(request->headers, "PATH_INFO", url, strlen(url)); + set_header(client->request.headers, "PATH_INFO", url, strlen(url)); free(url); return 0; @@ -65,9 +92,9 @@ int on_url(llhttp_t* parser, const char* data, size_t length) { int on_body(llhttp_t* parser, const char* body, size_t length) { logger("on body"); - Request* request = (Request*)parser->data; + client_t * client = (client_t *)parser->data; - PyObject* input = PyDict_GetItem(request->headers, wsgi_input); + PyObject* input = PyDict_GetItem(client->request.headers, wsgi_input); PyObject* write = PyUnicode_FromString("write"); PyObject* body_content = PyBytes_FromStringAndSize(body, length); @@ -81,12 +108,13 @@ int on_body(llhttp_t* parser, const char* body, size_t length) { int on_header_field(llhttp_t* parser, const char* header, size_t length) { logger("on header field"); + client_t * client = (client_t *)parser->data; char* upperHeader = malloc(length + 1); for (size_t i = 0; i < length; i++) { char current = header[i]; if (current == '_') { - current_header = NULL; // CVE-2015-0219 + client->request.current_header = NULL; // CVE-2015-0219 return 0; } if (current == '-') { @@ -97,14 +125,14 @@ int on_header_field(llhttp_t* parser, const char* header, size_t length) { } } upperHeader[length] = 0; - char* old_header = current_header; + char* old_header = client->request.current_header; if ((strcmp(upperHeader, "CONTENT_LENGTH") == 0) || (strcmp(upperHeader, "CONTENT_TYPE") == 0)) { - current_header = upperHeader; + client->request.current_header = upperHeader; } else { - current_header = malloc(strlen(upperHeader) + 5); - sprintf(current_header, "HTTP_%s", upperHeader); + client->request.current_header = malloc(strlen(upperHeader) + 5); + sprintf(client->request.current_header, "HTTP_%s", upperHeader); } if (old_header) @@ -115,9 +143,9 @@ int on_header_field(llhttp_t* parser, const char* header, size_t length) { int on_header_value(llhttp_t* parser, const char* value, size_t length) { logger("on header value"); - if (current_header != NULL) { - Request* request = (Request*)parser->data; - set_header(request->headers, current_header, value, length); + client_t * client = (client_t *)parser->data; + if (client->request.current_header != NULL) { + set_header(client->request.headers, client->request.current_header, value, length); } return 0; }; @@ -170,10 +198,11 @@ PyObject* extract_response(PyObject* wsgi_response) { int on_message_complete(llhttp_t* parser) { logger("on message complete"); - Request* request = (Request*)parser->data; + client_t * client = (client_t *)parser->data; + PyObject * headers = client->request.headers; // Sets the input byte stream position back to 0 - PyObject* body = PyDict_GetItem(request->headers, wsgi_input); + PyObject* body = PyDict_GetItem(headers, wsgi_input); PyObject* seek = PyUnicode_FromString("seek"); PyObject* res = PyObject_CallMethodObjArgs(body, seek, PyLong_FromLong(0L), NULL); Py_DECREF(res); @@ -187,16 +216,16 @@ int on_message_complete(llhttp_t* parser) { logger("calling wsgi application"); PyObject* wsgi_response; wsgi_response = PyObject_CallFunctionObjArgs( - wsgi_app, request->headers, start_response, NULL + wsgi_app, headers, start_response, NULL ); logger("called wsgi application"); if (PyErr_Occurred()) { - request->state.error = 1; + client->request.state.error = 1; PyErr_Print(); } - if (request->state.error == 0) { + if (client->request.state.error == 0) { PyObject* response_body = extract_response(wsgi_response); if (response_body != NULL) { build_response(response_body, start_response, parser); @@ -206,7 +235,7 @@ int on_message_complete(llhttp_t* parser) { // FIXME: Try to not repeat this block in this method if (PyErr_Occurred()) { - request->state.error = 1; + client->request.state.error = 1; PyErr_Print(); } @@ -216,7 +245,7 @@ int on_message_complete(llhttp_t* parser) { Py_CLEAR(start_response); Py_CLEAR(wsgi_response); - Py_CLEAR(request->headers); + Py_CLEAR(client->request.headers); return 0; }; @@ -224,7 +253,7 @@ void build_response(PyObject* response_body, StartResponse* response, llhttp_t* // This function needs a clean up logger("building response"); - Request* request = (Request*)parser->data; + client_t * client = (client_t *)parser->data; int response_has_no_content = 0; @@ -241,7 +270,7 @@ void build_response(PyObject* response_body, StartResponse* response, llhttp_t* char* connection_header = "\r\nConnection: close"; if (llhttp_should_keep_alive(parser)) { connection_header = "\r\nConnection: Keep-Alive"; - request->state.keep_alive = 1; + client->request.state.keep_alive = 1; } char* old_buf = buf; buf = malloc(strlen(old_buf) + strlen(connection_header)); @@ -296,36 +325,30 @@ void build_response(PyObject* response_body, StartResponse* response, llhttp_t* } logger(buf); - request->response_buffer.base = buf; - request->response_buffer.len = strlen(buf); + client->response.buffer.base = buf; + client->response.buffer.len = strlen(buf); } void build_wsgi_environ(llhttp_t* parser) { logger("building wsgi environ"); - Request* request = (Request*)parser->data; + client_t * client = (client_t *)parser->data; + PyObject * headers = client->request.headers; const char* method = llhttp_method_name(parser->method); + set_header(headers, "REQUEST_METHOD", method, 0); const char* protocol = parser->http_minor == 1 ? "HTTP/1.1" : "HTTP/1.0"; - const char* remote_addr = request->remote_addr; - - set_header(request->headers, "REQUEST_METHOD", method, strlen(method)); - set_header(request->headers, "SERVER_PROTOCOL", protocol, strlen(protocol)); - set_header(request->headers, "REMOTE_ADDR", remote_addr, strlen(remote_addr)); + set_header(headers, "SERVER_PROTOCOL", protocol, 0); + set_header(headers, "REMOTE_ADDR", client->remote_addr, 0); } void init_request_dict() { - // Sets up base request dict for new incoming requests - // https://www.python.org/dev/peps/pep-3333/#specification-details - PyObject* io = PyImport_ImportModule("io"); - PyObject* BytesIO = PyUnicode_FromString("BytesIO"); - PyObject* io_BytesIO = PyObject_CallMethodObjArgs(io, BytesIO, NULL); - + // only constant values!!! base_dict = PyDict_New(); PyDict_SetItem(base_dict, SCRIPT_NAME, empty_string); PyDict_SetItem(base_dict, SERVER_NAME, server_host); PyDict_SetItem(base_dict, SERVER_PORT, server_port); - PyDict_SetItem(base_dict, wsgi_input, io_BytesIO); + //PyDict_SetItem(base_dict, wsgi_input, io_BytesIO); // not const!!! PyDict_SetItem(base_dict, wsgi_version, version); PyDict_SetItem(base_dict, wsgi_url_scheme, http_scheme); PyDict_SetItem(base_dict, wsgi_errors, PySys_GetObject("stderr")); diff --git a/fastwsgi/request.h b/fastwsgi/request.h index 1334b14..3b247a2 100644 --- a/fastwsgi/request.h +++ b/fastwsgi/request.h @@ -1,3 +1,6 @@ +#ifndef FASTWSGI_REQUEST_H_ +#define FASTWSGI_REQUEST_H_ + #ifdef _MSC_VER // strncasecmp is not available on Windows #define strncasecmp _strnicmp @@ -6,25 +9,13 @@ #include "start_response.h" -typedef struct { - int error; - int keep_alive; -} RequestState; - -typedef struct { - PyObject* headers; - char remote_addr[17]; - llhttp_t parser; - uv_buf_t response_buffer; - RequestState state; -} Request; PyObject* base_dict; void init_request_dict(); void build_wsgi_environ(llhttp_t* parser); void build_response(PyObject* wsgi_response, StartResponse* response, llhttp_t* parser); -char* current_header; - llhttp_settings_t parser_settings; void configure_parser_settings(); + +#endif diff --git a/fastwsgi/server.c b/fastwsgi/server.c index 5bf5dc0..3d442a1 100644 --- a/fastwsgi/server.c +++ b/fastwsgi/server.c @@ -3,12 +3,24 @@ #include #include "uv.h" +#include "uv-common.h" #include "llhttp.h" #include "server.h" #include "request.h" #include "constants.h" +PyObject* wsgi_app; +char* host; +int port; +int backlog; + +uv_tcp_t server; +uv_loop_t* loop; +uv_os_fd_t file_descriptor; + +struct sockaddr_in addr; + static const char* BAD_REQUEST = "HTTP/1.1 400 Bad Request\r\n\r\n"; static const char* INTERNAL_ERROR = "HTTP/1.1 500 Internal Server Error\r\n\r\n"; @@ -19,17 +31,24 @@ void logger(char* message) { void close_cb(uv_handle_t* handle) { logger("disconnected"); - free(handle); -} - -void shutdown_cb(uv_shutdown_t* req, int status) { - uv_handle_t* handle = (uv_handle_t*)req->handle; - if (!uv_is_closing(handle)) - uv_close(handle, close_cb); - free(req); + client_t * client = (client_t *)handle->data; + Py_XDECREF(client->request.headers); + if (client->response.buffer.base) + free(client->response.buffer.base); + free(client); } void close_connection(uv_stream_t* handle) { + if (!uv_is_closing((uv_handle_t*)handle)) + uv_close((uv_handle_t*)handle, close_cb); +} + +void shutdown_cb(uv_shutdown_t* req, int status) { + close_connection(req->handle); + free(req); +} + +void shutdown_connection(uv_stream_t* handle) { uv_shutdown_t* shutdown = malloc(sizeof(uv_shutdown_t)); uv_shutdown(shutdown, handle, shutdown_cb); } @@ -38,67 +57,66 @@ void write_cb(uv_write_t* req, int status) { if (status) { fprintf(stderr, "Write error %s\n", uv_strerror(status)); } - write_req_t* write_req = (write_req_t*)req; - free(write_req->buf.base); - free(write_req); + //write_req_t* write_req = (write_req_t*)req; + free(req); } -void send_error(write_req_t* req, uv_stream_t* handle, const char* error_string) { - char* error = malloc(strlen(error_string) + 1); - strcpy(error, error_string); - req->buf = uv_buf_init(error, strlen(error)); +void stream_write(uv_stream_t* handle, const void* data, size_t size) { + if (!data || size == 0) + return; + size_t req_size = _Py_SIZE_ROUND_UP(sizeof(write_req_t), 16); + write_req_t* req = (write_req_t*)malloc(req_size + size); + req->buf.base = (char *)req + req_size; + req->buf.len = size; + memcpy(req->buf.base, data, size); uv_write((uv_write_t*)req, handle, &req->buf, 1, write_cb); - close_connection(handle); } -void send_response(write_req_t* req, uv_stream_t* handle, Request* request) { - uv_buf_t response = request->response_buffer; - req->buf = uv_buf_init(response.base, response.len); - uv_write((uv_write_t*)req, handle, &req->buf, 1, write_cb); - if (!request->state.keep_alive) - close_connection(handle); +void send_error(uv_stream_t* handle, const char* error_string) { + stream_write(handle, error_string, strlen(error_string)); + shutdown_connection(handle); // fixme: maybe check keep_alive??? +} + +void send_response(uv_stream_t* handle, client_t* client) { + uv_buf_t * resbuf = &client->response.buffer; + stream_write(handle, resbuf->base, resbuf->len); + if (!client->request.state.keep_alive) + shutdown_connection(handle); } void read_cb(uv_stream_t* handle, ssize_t nread, const uv_buf_t* buf) { + int continue_read = 0; client_t* client = (client_t*)handle->data; - - Request* request = malloc(sizeof(Request)); - request->state.keep_alive = 0; - request->state.error = 0; - strcpy(request->remote_addr, client->remote_addr); - client->parser.data = request; - write_req_t* req = (write_req_t*)malloc(sizeof(write_req_t)); + llhttp_t * parser = &client->request.parser; if (nread > 0) { - enum llhttp_errno err = llhttp_execute(&client->parser, buf->base, nread); + enum llhttp_errno err = llhttp_execute(parser, buf->base, nread); if (err == HPE_OK) { logger("Successfully parsed"); - if (request->response_buffer.len > 0) - send_response(req, handle, request); - else if (request->state.error) - send_error(req, handle, INTERNAL_ERROR); + if (client->response.buffer.len > 0) + send_response(handle, client); + else if (client->request.state.error) + send_error(handle, INTERNAL_ERROR); else - send_error(req, handle, BAD_REQUEST); + continue_read = 1; } else { - fprintf(stderr, "Parse error: %s %s\n", llhttp_errno_name(err), client->parser.reason); - send_error(req, handle, BAD_REQUEST); + fprintf(stderr, "Parse error: %s %s\n", llhttp_errno_name(err), client->request.parser.reason); + send_error(handle, BAD_REQUEST); } } if (nread < 0) { uv_read_stop(handle); - if (nread == UV_ECONNRESET) { - close_connection(handle); - } - else if (nread != UV_EOF) { - fprintf(stderr, "Read error %s\n", uv_err_name(nread)); + if (nread == UV_EOF) { // remote peer disconnected close_connection(handle); + } else { + if (nread != UV_ECONNRESET) + fprintf(stderr, "Read error %s\n", uv_err_name(nread)); + shutdown_connection(handle); } } - free(request); - llhttp_reset(&client->parser); if (buf->base) free(buf->base); @@ -119,7 +137,7 @@ void connection_cb(uv_stream_t* server, int status) { return; } - client_t* client = malloc(sizeof(client_t)); + client_t* client = calloc(1, sizeof(client_t)); uv_tcp_init(loop, &client->handle); uv_tcp_nodelay(&client->handle, 0); @@ -135,7 +153,8 @@ void connection_cb(uv_stream_t* server, int status) { client->handle.data = client; if (uv_accept(server, (uv_stream_t*)&client->handle) == 0) { - llhttp_init(&client->parser, HTTP_REQUEST, &parser_settings); + llhttp_init(&client->request.parser, HTTP_REQUEST, &parser_settings); + client->request.parser.data = client; uv_read_start((uv_stream_t*)&client->handle, alloc_cb, read_cb); } else { @@ -165,8 +184,12 @@ int main() { uv_fileno((const uv_handle_t*)&server, &file_descriptor); int enabled = 1; - int so_reuseport = 15; - setsockopt(file_descriptor, SOL_SOCKET, so_reuseport, &enabled, sizeof(&enabled)); +#ifdef _WIN32 + //uv__socket_sockopt((uv_handle_t*)&server, SO_REUSEADDR, &enabled); +#else + int so_reuseport = 15; // SO_REUSEPORT + uv__socket_sockopt((uv_handle_t*)&server, so_reuseport, &enabled); +#endif int err = uv_tcp_bind(&server, (const struct sockaddr*)&addr, 0); if (err) { diff --git a/fastwsgi/server.h b/fastwsgi/server.h index 501f276..4de7b7f 100644 --- a/fastwsgi/server.h +++ b/fastwsgi/server.h @@ -1,18 +1,14 @@ +#ifndef FASTWSGI_SERVER_H_ +#define FASTWSGI_SERVER_H_ + #include #include "uv.h" +#include "uv-common.h" #include "llhttp.h" +#include "request.h" -PyObject* wsgi_app; -char* host; -int port; -int backlog; +extern PyObject* wsgi_app; -uv_tcp_t server; -uv_buf_t response_buf; -uv_loop_t* loop; -uv_os_fd_t file_descriptor; - -struct sockaddr_in addr; typedef struct { uv_write_t req; @@ -20,12 +16,27 @@ typedef struct { } write_req_t; typedef struct { - uv_tcp_t handle; - llhttp_t parser; - char remote_addr[17]; + int error; + int keep_alive; +} RequestState; + +typedef struct { + uv_tcp_t handle; // peer connection + char remote_addr[24]; + struct { + PyObject* headers; + char* current_header; + llhttp_t parser; + RequestState state; + } request; + struct { + uv_buf_t buffer; + } response; } client_t; PyObject* run_server(PyObject* self, PyObject* args); int LOGGING_ENABLED; -void logger(char* message); \ No newline at end of file +void logger(char* message); + +#endif