Fix Keep-Alive TCP connections (#18)

* Fix using SO_REUSEPORT flag (Win32 not supported)

* Fix memleak on close peer connection

* Fix reinit Request struct into on_message_begin

Fix reinit wsgi_input object

* Fix define global objects (transfer to server.c)

* Integrate Request struct with client_t struct

* Add support Keep-Alive TCP connections

* Fix build with local libuv sources

* Improve function set_header (allow zero length argument)

* Fix write response buffer to tcp stream

* Add function shutdown_connection
This commit is contained in:
Oleg S
2022-10-14 17:32:10 +03:00
committed by GitHub
parent c2775c06cd
commit 66b601ea22
7 changed files with 178 additions and 120 deletions

View File

@@ -1,3 +1,3 @@
server: fastwsgi/server.c
mkdir -p bin
gcc -Illhttp/include llhttp/src/*.c fastwsgi/request.c fastwsgi/server.c fastwsgi/constants.c -o bin/server -luv -I/usr/include/python3.8 -lpython3.8 -O3
gcc -Illhttp/include -Ilibuv/include -Ilibuv/src llhttp/src/*.c fastwsgi/request.c fastwsgi/server.c fastwsgi/constants.c -o bin/server -luv -I/usr/include/python3.8 -lpython3.8 -O3

View File

@@ -1,3 +1,6 @@
#ifndef FASTWSGI_CONSTANTS_H_
#define FASTWSGI_CONSTANTS_H_
#include <Python.h>
PyObject* REQUEST_METHOD, * SCRIPT_NAME, * SERVER_NAME, * SERVER_PORT, * SERVER_PROTOCOL, * QUERY_STRING;
@@ -6,4 +9,6 @@ PyObject* http_scheme, * HTTP_1_1, * HTTP_1_0;
PyObject* server_host, * server_port, * empty_string;
PyObject* HTTP_, * PATH_INFO, * wsgi_input;
void init_constants();
void init_constants();
#endif

View File

@@ -1,3 +1,6 @@
#ifndef FASTWSGI_FILEWRAPPER_H_
#define FASTWSGI_FILEWRAPPER_H_
#include <Python.h>
typedef struct {
@@ -10,4 +13,6 @@ PyTypeObject FileWrapper_Type;
#define FileWrapper_CheckExact(object) ((object)->ob_type == &FileWrapper_Type)
void FileWrapper_Init(void);
void FileWrapper_Init(void);
#endif

View File

@@ -17,7 +17,8 @@ static void reprint(PyObject* obj) {
static void set_header(PyObject* headers, const char* key, const char* value, size_t length) {
logger("setting header");
PyObject* item = PyUnicode_FromStringAndSize(value, length);
int vlen = (length > 0) ? (int)length : (int)strlen(value);
PyObject* item = PyUnicode_FromStringAndSize(value, vlen);
PyObject* existing_item = PyDict_GetItemString(headers, key);
if (existing_item) {
@@ -38,15 +39,41 @@ static void set_header(PyObject* headers, const char* key, const char* value, si
int on_message_begin(llhttp_t* parser) {
logger("on message begin");
Request* request = (Request*)parser->data;
request->headers = PyDict_Copy(base_dict);
request->response_buffer.len = 0;
client_t * client = (client_t *)parser->data;
client->request.state.keep_alive = 0;
client->request.state.error = 0;
if (client->response.buffer.base)
free(client->response.buffer.base);
client->response.buffer.base = NULL;
client->response.buffer.len = 0;
if (client->request.headers == NULL) {
PyObject* headers = PyDict_Copy(base_dict);
// Sets up base request dict for new incoming requests
// https://www.python.org/dev/peps/pep-3333/#specification-details
PyObject* io = PyImport_ImportModule("io");
PyObject* BytesIO = PyUnicode_FromString("BytesIO");
PyObject* io_BytesIO = PyObject_CallMethodObjArgs(io, BytesIO, NULL);
PyDict_SetItem(headers, wsgi_input, io_BytesIO);
client->request.headers = headers;
Py_DECREF(BytesIO);
Py_DECREF(io);
} else {
PyObject* input = PyDict_GetItem(client->request.headers, wsgi_input);
PyObject* truncate = PyUnicode_FromString("truncate");
PyObject* result1 = PyObject_CallMethodObjArgs(input, truncate, PyLong_FromLong(0L), NULL);
Py_DECREF(truncate);
Py_DECREF(result1);
PyObject* seek = PyUnicode_FromString("seek");
PyObject* result2 = PyObject_CallMethodObjArgs(input, seek, PyLong_FromLong(0L), NULL);
Py_DECREF(seek);
Py_DECREF(result2);
}
return 0;
};
int on_url(llhttp_t* parser, const char* data, size_t length) {
logger("on url");
Request* request = (Request*)parser->data;
client_t * client = (client_t *)parser->data;
char* url = malloc(length + 1);
strncpy(url, data, length);
@@ -55,9 +82,9 @@ int on_url(llhttp_t* parser, const char* data, size_t length) {
char* query_string = strchr(url, '?');
if (query_string) {
*query_string = 0;
set_header(request->headers, "QUERY_STRING", query_string + 1, strlen(query_string + 1));
set_header(client->request.headers, "QUERY_STRING", query_string + 1, strlen(query_string + 1));
}
set_header(request->headers, "PATH_INFO", url, strlen(url));
set_header(client->request.headers, "PATH_INFO", url, strlen(url));
free(url);
return 0;
@@ -65,9 +92,9 @@ int on_url(llhttp_t* parser, const char* data, size_t length) {
int on_body(llhttp_t* parser, const char* body, size_t length) {
logger("on body");
Request* request = (Request*)parser->data;
client_t * client = (client_t *)parser->data;
PyObject* input = PyDict_GetItem(request->headers, wsgi_input);
PyObject* input = PyDict_GetItem(client->request.headers, wsgi_input);
PyObject* write = PyUnicode_FromString("write");
PyObject* body_content = PyBytes_FromStringAndSize(body, length);
@@ -81,12 +108,13 @@ int on_body(llhttp_t* parser, const char* body, size_t length) {
int on_header_field(llhttp_t* parser, const char* header, size_t length) {
logger("on header field");
client_t * client = (client_t *)parser->data;
char* upperHeader = malloc(length + 1);
for (size_t i = 0; i < length; i++) {
char current = header[i];
if (current == '_') {
current_header = NULL; // CVE-2015-0219
client->request.current_header = NULL; // CVE-2015-0219
return 0;
}
if (current == '-') {
@@ -97,14 +125,14 @@ int on_header_field(llhttp_t* parser, const char* header, size_t length) {
}
}
upperHeader[length] = 0;
char* old_header = current_header;
char* old_header = client->request.current_header;
if ((strcmp(upperHeader, "CONTENT_LENGTH") == 0) || (strcmp(upperHeader, "CONTENT_TYPE") == 0)) {
current_header = upperHeader;
client->request.current_header = upperHeader;
}
else {
current_header = malloc(strlen(upperHeader) + 5);
sprintf(current_header, "HTTP_%s", upperHeader);
client->request.current_header = malloc(strlen(upperHeader) + 5);
sprintf(client->request.current_header, "HTTP_%s", upperHeader);
}
if (old_header)
@@ -115,9 +143,9 @@ int on_header_field(llhttp_t* parser, const char* header, size_t length) {
int on_header_value(llhttp_t* parser, const char* value, size_t length) {
logger("on header value");
if (current_header != NULL) {
Request* request = (Request*)parser->data;
set_header(request->headers, current_header, value, length);
client_t * client = (client_t *)parser->data;
if (client->request.current_header != NULL) {
set_header(client->request.headers, client->request.current_header, value, length);
}
return 0;
};
@@ -170,10 +198,11 @@ PyObject* extract_response(PyObject* wsgi_response) {
int on_message_complete(llhttp_t* parser) {
logger("on message complete");
Request* request = (Request*)parser->data;
client_t * client = (client_t *)parser->data;
PyObject * headers = client->request.headers;
// Sets the input byte stream position back to 0
PyObject* body = PyDict_GetItem(request->headers, wsgi_input);
PyObject* body = PyDict_GetItem(headers, wsgi_input);
PyObject* seek = PyUnicode_FromString("seek");
PyObject* res = PyObject_CallMethodObjArgs(body, seek, PyLong_FromLong(0L), NULL);
Py_DECREF(res);
@@ -187,16 +216,16 @@ int on_message_complete(llhttp_t* parser) {
logger("calling wsgi application");
PyObject* wsgi_response;
wsgi_response = PyObject_CallFunctionObjArgs(
wsgi_app, request->headers, start_response, NULL
wsgi_app, headers, start_response, NULL
);
logger("called wsgi application");
if (PyErr_Occurred()) {
request->state.error = 1;
client->request.state.error = 1;
PyErr_Print();
}
if (request->state.error == 0) {
if (client->request.state.error == 0) {
PyObject* response_body = extract_response(wsgi_response);
if (response_body != NULL) {
build_response(response_body, start_response, parser);
@@ -206,7 +235,7 @@ int on_message_complete(llhttp_t* parser) {
// FIXME: Try to not repeat this block in this method
if (PyErr_Occurred()) {
request->state.error = 1;
client->request.state.error = 1;
PyErr_Print();
}
@@ -216,7 +245,7 @@ int on_message_complete(llhttp_t* parser) {
Py_CLEAR(start_response);
Py_CLEAR(wsgi_response);
Py_CLEAR(request->headers);
Py_CLEAR(client->request.headers);
return 0;
};
@@ -224,7 +253,7 @@ void build_response(PyObject* response_body, StartResponse* response, llhttp_t*
// This function needs a clean up
logger("building response");
Request* request = (Request*)parser->data;
client_t * client = (client_t *)parser->data;
int response_has_no_content = 0;
@@ -241,7 +270,7 @@ void build_response(PyObject* response_body, StartResponse* response, llhttp_t*
char* connection_header = "\r\nConnection: close";
if (llhttp_should_keep_alive(parser)) {
connection_header = "\r\nConnection: Keep-Alive";
request->state.keep_alive = 1;
client->request.state.keep_alive = 1;
}
char* old_buf = buf;
buf = malloc(strlen(old_buf) + strlen(connection_header));
@@ -296,36 +325,30 @@ void build_response(PyObject* response_body, StartResponse* response, llhttp_t*
}
logger(buf);
request->response_buffer.base = buf;
request->response_buffer.len = strlen(buf);
client->response.buffer.base = buf;
client->response.buffer.len = strlen(buf);
}
void build_wsgi_environ(llhttp_t* parser) {
logger("building wsgi environ");
Request* request = (Request*)parser->data;
client_t * client = (client_t *)parser->data;
PyObject * headers = client->request.headers;
const char* method = llhttp_method_name(parser->method);
set_header(headers, "REQUEST_METHOD", method, 0);
const char* protocol = parser->http_minor == 1 ? "HTTP/1.1" : "HTTP/1.0";
const char* remote_addr = request->remote_addr;
set_header(request->headers, "REQUEST_METHOD", method, strlen(method));
set_header(request->headers, "SERVER_PROTOCOL", protocol, strlen(protocol));
set_header(request->headers, "REMOTE_ADDR", remote_addr, strlen(remote_addr));
set_header(headers, "SERVER_PROTOCOL", protocol, 0);
set_header(headers, "REMOTE_ADDR", client->remote_addr, 0);
}
void init_request_dict() {
// Sets up base request dict for new incoming requests
// https://www.python.org/dev/peps/pep-3333/#specification-details
PyObject* io = PyImport_ImportModule("io");
PyObject* BytesIO = PyUnicode_FromString("BytesIO");
PyObject* io_BytesIO = PyObject_CallMethodObjArgs(io, BytesIO, NULL);
// only constant values!!!
base_dict = PyDict_New();
PyDict_SetItem(base_dict, SCRIPT_NAME, empty_string);
PyDict_SetItem(base_dict, SERVER_NAME, server_host);
PyDict_SetItem(base_dict, SERVER_PORT, server_port);
PyDict_SetItem(base_dict, wsgi_input, io_BytesIO);
//PyDict_SetItem(base_dict, wsgi_input, io_BytesIO); // not const!!!
PyDict_SetItem(base_dict, wsgi_version, version);
PyDict_SetItem(base_dict, wsgi_url_scheme, http_scheme);
PyDict_SetItem(base_dict, wsgi_errors, PySys_GetObject("stderr"));

View File

@@ -1,3 +1,6 @@
#ifndef FASTWSGI_REQUEST_H_
#define FASTWSGI_REQUEST_H_
#ifdef _MSC_VER
// strncasecmp is not available on Windows
#define strncasecmp _strnicmp
@@ -6,25 +9,13 @@
#include "start_response.h"
typedef struct {
int error;
int keep_alive;
} RequestState;
typedef struct {
PyObject* headers;
char remote_addr[17];
llhttp_t parser;
uv_buf_t response_buffer;
RequestState state;
} Request;
PyObject* base_dict;
void init_request_dict();
void build_wsgi_environ(llhttp_t* parser);
void build_response(PyObject* wsgi_response, StartResponse* response, llhttp_t* parser);
char* current_header;
llhttp_settings_t parser_settings;
void configure_parser_settings();
#endif

View File

@@ -3,12 +3,24 @@
#include <string.h>
#include "uv.h"
#include "uv-common.h"
#include "llhttp.h"
#include "server.h"
#include "request.h"
#include "constants.h"
PyObject* wsgi_app;
char* host;
int port;
int backlog;
uv_tcp_t server;
uv_loop_t* loop;
uv_os_fd_t file_descriptor;
struct sockaddr_in addr;
static const char* BAD_REQUEST = "HTTP/1.1 400 Bad Request\r\n\r\n";
static const char* INTERNAL_ERROR = "HTTP/1.1 500 Internal Server Error\r\n\r\n";
@@ -19,17 +31,24 @@ void logger(char* message) {
void close_cb(uv_handle_t* handle) {
logger("disconnected");
free(handle);
}
void shutdown_cb(uv_shutdown_t* req, int status) {
uv_handle_t* handle = (uv_handle_t*)req->handle;
if (!uv_is_closing(handle))
uv_close(handle, close_cb);
free(req);
client_t * client = (client_t *)handle->data;
Py_XDECREF(client->request.headers);
if (client->response.buffer.base)
free(client->response.buffer.base);
free(client);
}
void close_connection(uv_stream_t* handle) {
if (!uv_is_closing((uv_handle_t*)handle))
uv_close((uv_handle_t*)handle, close_cb);
}
void shutdown_cb(uv_shutdown_t* req, int status) {
close_connection(req->handle);
free(req);
}
void shutdown_connection(uv_stream_t* handle) {
uv_shutdown_t* shutdown = malloc(sizeof(uv_shutdown_t));
uv_shutdown(shutdown, handle, shutdown_cb);
}
@@ -38,67 +57,66 @@ void write_cb(uv_write_t* req, int status) {
if (status) {
fprintf(stderr, "Write error %s\n", uv_strerror(status));
}
write_req_t* write_req = (write_req_t*)req;
free(write_req->buf.base);
free(write_req);
//write_req_t* write_req = (write_req_t*)req;
free(req);
}
void send_error(write_req_t* req, uv_stream_t* handle, const char* error_string) {
char* error = malloc(strlen(error_string) + 1);
strcpy(error, error_string);
req->buf = uv_buf_init(error, strlen(error));
void stream_write(uv_stream_t* handle, const void* data, size_t size) {
if (!data || size == 0)
return;
size_t req_size = _Py_SIZE_ROUND_UP(sizeof(write_req_t), 16);
write_req_t* req = (write_req_t*)malloc(req_size + size);
req->buf.base = (char *)req + req_size;
req->buf.len = size;
memcpy(req->buf.base, data, size);
uv_write((uv_write_t*)req, handle, &req->buf, 1, write_cb);
close_connection(handle);
}
void send_response(write_req_t* req, uv_stream_t* handle, Request* request) {
uv_buf_t response = request->response_buffer;
req->buf = uv_buf_init(response.base, response.len);
uv_write((uv_write_t*)req, handle, &req->buf, 1, write_cb);
if (!request->state.keep_alive)
close_connection(handle);
void send_error(uv_stream_t* handle, const char* error_string) {
stream_write(handle, error_string, strlen(error_string));
shutdown_connection(handle); // fixme: maybe check keep_alive???
}
void send_response(uv_stream_t* handle, client_t* client) {
uv_buf_t * resbuf = &client->response.buffer;
stream_write(handle, resbuf->base, resbuf->len);
if (!client->request.state.keep_alive)
shutdown_connection(handle);
}
void read_cb(uv_stream_t* handle, ssize_t nread, const uv_buf_t* buf) {
int continue_read = 0;
client_t* client = (client_t*)handle->data;
Request* request = malloc(sizeof(Request));
request->state.keep_alive = 0;
request->state.error = 0;
strcpy(request->remote_addr, client->remote_addr);
client->parser.data = request;
write_req_t* req = (write_req_t*)malloc(sizeof(write_req_t));
llhttp_t * parser = &client->request.parser;
if (nread > 0) {
enum llhttp_errno err = llhttp_execute(&client->parser, buf->base, nread);
enum llhttp_errno err = llhttp_execute(parser, buf->base, nread);
if (err == HPE_OK) {
logger("Successfully parsed");
if (request->response_buffer.len > 0)
send_response(req, handle, request);
else if (request->state.error)
send_error(req, handle, INTERNAL_ERROR);
if (client->response.buffer.len > 0)
send_response(handle, client);
else if (client->request.state.error)
send_error(handle, INTERNAL_ERROR);
else
send_error(req, handle, BAD_REQUEST);
continue_read = 1;
}
else {
fprintf(stderr, "Parse error: %s %s\n", llhttp_errno_name(err), client->parser.reason);
send_error(req, handle, BAD_REQUEST);
fprintf(stderr, "Parse error: %s %s\n", llhttp_errno_name(err), client->request.parser.reason);
send_error(handle, BAD_REQUEST);
}
}
if (nread < 0) {
uv_read_stop(handle);
if (nread == UV_ECONNRESET) {
close_connection(handle);
}
else if (nread != UV_EOF) {
fprintf(stderr, "Read error %s\n", uv_err_name(nread));
if (nread == UV_EOF) { // remote peer disconnected
close_connection(handle);
} else {
if (nread != UV_ECONNRESET)
fprintf(stderr, "Read error %s\n", uv_err_name(nread));
shutdown_connection(handle);
}
}
free(request);
llhttp_reset(&client->parser);
if (buf->base)
free(buf->base);
@@ -119,7 +137,7 @@ void connection_cb(uv_stream_t* server, int status) {
return;
}
client_t* client = malloc(sizeof(client_t));
client_t* client = calloc(1, sizeof(client_t));
uv_tcp_init(loop, &client->handle);
uv_tcp_nodelay(&client->handle, 0);
@@ -135,7 +153,8 @@ void connection_cb(uv_stream_t* server, int status) {
client->handle.data = client;
if (uv_accept(server, (uv_stream_t*)&client->handle) == 0) {
llhttp_init(&client->parser, HTTP_REQUEST, &parser_settings);
llhttp_init(&client->request.parser, HTTP_REQUEST, &parser_settings);
client->request.parser.data = client;
uv_read_start((uv_stream_t*)&client->handle, alloc_cb, read_cb);
}
else {
@@ -165,8 +184,12 @@ int main() {
uv_fileno((const uv_handle_t*)&server, &file_descriptor);
int enabled = 1;
int so_reuseport = 15;
setsockopt(file_descriptor, SOL_SOCKET, so_reuseport, &enabled, sizeof(&enabled));
#ifdef _WIN32
//uv__socket_sockopt((uv_handle_t*)&server, SO_REUSEADDR, &enabled);
#else
int so_reuseport = 15; // SO_REUSEPORT
uv__socket_sockopt((uv_handle_t*)&server, so_reuseport, &enabled);
#endif
int err = uv_tcp_bind(&server, (const struct sockaddr*)&addr, 0);
if (err) {

View File

@@ -1,18 +1,14 @@
#ifndef FASTWSGI_SERVER_H_
#define FASTWSGI_SERVER_H_
#include <Python.h>
#include "uv.h"
#include "uv-common.h"
#include "llhttp.h"
#include "request.h"
PyObject* wsgi_app;
char* host;
int port;
int backlog;
extern PyObject* wsgi_app;
uv_tcp_t server;
uv_buf_t response_buf;
uv_loop_t* loop;
uv_os_fd_t file_descriptor;
struct sockaddr_in addr;
typedef struct {
uv_write_t req;
@@ -20,12 +16,27 @@ typedef struct {
} write_req_t;
typedef struct {
uv_tcp_t handle;
llhttp_t parser;
char remote_addr[17];
int error;
int keep_alive;
} RequestState;
typedef struct {
uv_tcp_t handle; // peer connection
char remote_addr[24];
struct {
PyObject* headers;
char* current_header;
llhttp_t parser;
RequestState state;
} request;
struct {
uv_buf_t buffer;
} response;
} client_t;
PyObject* run_server(PyObject* self, PyObject* args);
int LOGGING_ENABLED;
void logger(char* message);
void logger(char* message);
#endif