diff --git a/include/sqlgen/postgres/Iterator.hpp b/include/sqlgen/postgres/Iterator.hpp index 9034b44..00bcb93 100644 --- a/include/sqlgen/postgres/Iterator.hpp +++ b/include/sqlgen/postgres/Iterator.hpp @@ -41,6 +41,15 @@ class SQLGEN_API Iterator { Iterator& operator=(Iterator&& _other) noexcept; + static rfl::Result> make(const std::string& _sql, + const Conn& _conn) noexcept { + try { + return Ref::make(_sql, _conn); + } catch (const std::exception& e) { + return error(e.what()); + } + } + private: static std::string make_cursor_name() { // TODO: Create unique cursor names. diff --git a/include/sqlgen/postgres/PostgresV2Connection.hpp b/include/sqlgen/postgres/PostgresV2Connection.hpp index df1ac82..016a1c9 100644 --- a/include/sqlgen/postgres/PostgresV2Connection.hpp +++ b/include/sqlgen/postgres/PostgresV2Connection.hpp @@ -24,6 +24,15 @@ class SQLGEN_API PostgresV2Connection { static rfl::Result make( const std::string& _conn_str) noexcept; + static rfl::Result make(PGconn* _ptr) noexcept { + try { + return PostgresV2Connection(_ptr); + } catch (const std::exception& e) { + return rfl::error("Failed to connect to postgres: " + + std::string(e.what())); + } + } + PGconn* ptr() const { return ptr_.get(); } private: diff --git a/include/sqlgen/postgres/PostgresV2Result.hpp b/include/sqlgen/postgres/PostgresV2Result.hpp index ed3dfa8..f9895c4 100644 --- a/include/sqlgen/postgres/PostgresV2Result.hpp +++ b/include/sqlgen/postgres/PostgresV2Result.hpp @@ -25,6 +25,15 @@ class SQLGEN_API PostgresV2Result { static rfl::Result make( const std::string& _query, const PostgresV2Connection& _conn) noexcept; + static rfl::Result make(PGresult* _ptr) noexcept { + try { + return PostgresV2Result(_ptr); + } catch (const std::exception& e) { + return rfl::error("Failed to retrieve result from postgres: " + + std::string(e.what())); + } + } + PGresult* ptr() const { return ptr_.get(); } private: diff --git a/src/sqlgen/postgres/Connection.cpp b/src/sqlgen/postgres/Connection.cpp index 41e2fe1..ee74c5b 100644 --- a/src/sqlgen/postgres/Connection.cpp +++ b/src/sqlgen/postgres/Connection.cpp @@ -34,11 +34,13 @@ Result Connection::end_write() { if (PQputCopyEnd(conn_.ptr(), NULL) == -1) { return error(PQerrorMessage(conn_.ptr())); } - const auto res = PostgresV2Result(PQgetResult(conn_.ptr())); - if (PQresultStatus(res.ptr()) != PGRES_COMMAND_OK) { - return error(PQerrorMessage(conn_.ptr())); - } - return Nothing{}; + return PostgresV2Result::make(PQgetResult(conn_.ptr())) + .and_then([&](auto&& res) -> Result { + if (PQresultStatus(res.ptr()) != PGRES_COMMAND_OK) { + return error(PQerrorMessage(conn_.ptr())); + } + return Nothing{}; + }); } std::list Connection::get_notifications() noexcept { @@ -46,18 +48,17 @@ std::list Connection::get_notifications() noexcept { // Safe to call even if no data — just returns true if (!PQconsumeInput(conn_.ptr())) { - // Note: In pure wait/consume pattern, this should rarely happen if socket is healthy - // But we don't error here — just skip + // Note: In pure wait/consume pattern, this should rarely happen if socket + // is healthy But we don't error here — just skip return notices; } PGnotify* notify; while ((notify = PQnotifies(conn_.ptr())) != nullptr) { - notices.push_back({ - .channel = std::string(notify->relname), - .payload = notify->extra[0] ? std::string(notify->extra) : "", - .backend_pid = notify->be_pid - }); + notices.push_back( + {.channel = std::string(notify->relname), + .payload = notify->extra[0] ? std::string(notify->extra) : "", + .backend_pid = notify->be_pid}); PQfreemem(notify); } @@ -83,16 +84,19 @@ rfl::Result Connection::unlisten(const std::string& channel) noexcept { return execute(sql); } -rfl::Result Connection::notify(const std::string& channel, const std::string& payload) noexcept { +rfl::Result Connection::notify(const std::string& channel, + const std::string& payload) noexcept { if (!is_valid_channel_name(channel)) { return error("Invalid channel name"); } - auto* escaped_payload = PQescapeLiteral(conn_.ptr(), payload.c_str(), payload.size()); + auto* escaped_payload = + PQescapeLiteral(conn_.ptr(), payload.c_str(), payload.size()); if (!escaped_payload) { return error("Failed to escape NOTIFY payload"); } - const std::string sql = "NOTIFY " + channel + ", " + std::string(escaped_payload); + const std::string sql = + "NOTIFY " + channel + ", " + std::string(escaped_payload); PQfreemem(escaped_payload); auto result = execute(sql); @@ -116,55 +120,64 @@ Result Connection::insert_impl( const auto sql = to_sql_impl(_stmt); - const auto res = PostgresV2Result(PQprepare( - conn_.ptr(), name.c_str(), sql.c_str(), _data.at(0).size(), nullptr)); + return PostgresV2Result::make(PQprepare(conn_.ptr(), name.c_str(), + sql.c_str(), _data.at(0).size(), + nullptr)) + .and_then([&](auto&& res) -> Result { + const auto status = PQresultStatus(res.ptr()); - const auto status = PQresultStatus(res.ptr()); + if (status != PGRES_COMMAND_OK) { + return error("Generating prepared statement for '" + sql + + "' failed: " + PQresultErrorMessage(res.ptr())); + } - if (status != PGRES_COMMAND_OK) { - return error("Generating prepared statement for '" + sql + - "' failed: " + PQresultErrorMessage(res.ptr())); - } + std::vector current_row(_data[0].size()); - std::vector current_row(_data[0].size()); + const int n_params = static_cast(current_row.size()); - const int n_params = static_cast(current_row.size()); + for (size_t i = 0; i < _data.size(); ++i) { + const auto& d = _data[i]; - for (size_t i = 0; i < _data.size(); ++i) { - const auto& d = _data[i]; + if (d.size() != current_row.size()) { + execute("DEALLOCATE " + name + ";"); + return error("Error in entry " + std::to_string(i) + ": Expected " + + std::to_string(current_row.size()) + " entries, got " + + std::to_string(d.size())); + } - if (d.size() != current_row.size()) { - execute("DEALLOCATE " + name + ";"); - return error("Error in entry " + std::to_string(i) + ": Expected " + - std::to_string(current_row.size()) + " entries, got " + - std::to_string(d.size())); - } + for (size_t j = 0; j < d.size(); ++j) { + current_row[j] = d[j] ? d[j]->c_str() : nullptr; + } - for (size_t j = 0; j < d.size(); ++j) { - current_row[j] = d[j] ? d[j]->c_str() : nullptr; - } + try { + const auto res = PostgresV2Result(PQexecPrepared( + conn_.ptr(), // conn + name.c_str(), // stmtName + n_params, // nParams + current_row.data(), // paramValues + nullptr, // paramLengths + nullptr, // paramFormats + 0 // resultFormat + )); - const auto res = - PostgresV2Result(PQexecPrepared(conn_.ptr(), // conn - name.c_str(), // stmtName - n_params, // nParams - current_row.data(), // paramValues - nullptr, // paramLengths - nullptr, // paramFormats - 0 // resultFormat - )); + const auto status = PQresultStatus(res.ptr()); - const auto status = PQresultStatus(res.ptr()); + if (status != PGRES_COMMAND_OK) { + const auto err = error(std::string("Executing INSERT failed: ") + + PQresultErrorMessage(res.ptr())); + execute("DEALLOCATE " + name + ";"); + return err; + } + } catch (const std::exception& e) { + const auto err = + error(std::string("Executing INSERT failed: ") + e.what()); + execute("DEALLOCATE " + name + ";"); + return err; + } + } - if (status != PGRES_COMMAND_OK) { - const auto err = error(std::string("Executing INSERT failed: ") + - PQresultErrorMessage(res.ptr())); - execute("DEALLOCATE " + name + ";"); - return err; - } - } - - return execute("DEALLOCATE " + name + ";"); + return execute("DEALLOCATE " + name + ";"); + }); } rfl::Result> Connection::make( @@ -176,7 +189,7 @@ rfl::Result> Connection::make( Result> Connection::read_impl( const rfl::Variant& _query) { const auto sql = _query.visit([](const auto& _q) { return to_sql_impl(_q); }); - return Ref::make(sql, conn_); + return Iterator::make(sql, conn_); } Result Connection::rollback() noexcept { return execute("ROLLBACK;"); } diff --git a/tests/duckdb/test_error_handling.cpp b/tests/duckdb/test_error_handling.cpp new file mode 100644 index 0000000..5e4bc49 --- /dev/null +++ b/tests/duckdb/test_error_handling.cpp @@ -0,0 +1,46 @@ + +#include + +#include +#include +#include +#include +#include + +namespace test_error_handling { + +struct Person { + sqlgen::PrimaryKey id; + std::string first_name; + std::string last_name; + int age; +}; + +TEST(duckdb, test_error_handling) { + const auto people1 = std::vector( + {Person{ + .id = 0, .first_name = "Homer", .last_name = "Simpson", .age = 45}, + Person{.id = 1, .first_name = "Bart", .last_name = "Simpson", .age = 10}, + Person{.id = 2, .first_name = "Lisa", .last_name = "Simpson", .age = 8}, + Person{ + .id = 3, .first_name = "Maggie", .last_name = "Simpson", .age = 0}, + Person{ + .id = 4, .first_name = "Hugo", .last_name = "Simpson", .age = 10}}); + + using namespace sqlgen; + using namespace sqlgen::literals; + + const auto people2 = + duckdb::connect() + .and_then(write(std::ref(people1))) + .and_then(sqlgen::read> | + where("first_name"_c.in(std::vector()))) + .value_or(std::vector({})); + + const std::string expected1 = R"([])"; + + EXPECT_EQ(rfl::json::write(people2), expected1); +} + +} // namespace test_error_handling + diff --git a/tests/mysql/test_error_handling.cpp b/tests/mysql/test_error_handling.cpp new file mode 100644 index 0000000..94410cc --- /dev/null +++ b/tests/mysql/test_error_handling.cpp @@ -0,0 +1,55 @@ +#ifndef SQLGEN_BUILD_DRY_TESTS_ONLY + +#include + +#include +#include +#include +#include +#include + +namespace test_error_handling { + +struct Person { + sqlgen::PrimaryKey id; + std::string first_name; + std::string last_name; + int age; +}; + +TEST(mysql, test_error_handling) { + const auto people1 = std::vector( + {Person{ + .id = 0, .first_name = "Homer", .last_name = "Simpson", .age = 45}, + Person{.id = 1, .first_name = "Bart", .last_name = "Simpson", .age = 10}, + Person{.id = 2, .first_name = "Lisa", .last_name = "Simpson", .age = 8}, + Person{ + .id = 3, .first_name = "Maggie", .last_name = "Simpson", .age = 0}, + Person{ + .id = 4, .first_name = "Hugo", .last_name = "Simpson", .age = 10}}); + + const auto credentials = sqlgen::mysql::Credentials{.host = "localhost", + .user = "sqlgen", + .password = "password", + .dbname = "mysql"}; + + using namespace sqlgen; + using namespace sqlgen::literals; + + const auto people2 = + mysql::connect(credentials) + .and_then(drop | if_exists) + .and_then(write(std::ref(people1))) + .and_then(sqlgen::read> | + where("first_name"_c.in(std::vector())) | + order_by("age"_c)) + .value_or(std::vector({})); + + const std::string expected1 = R"([])"; + + EXPECT_EQ(rfl::json::write(people2), expected1); +} + +} // namespace test_error_handling + +#endif diff --git a/tests/postgres/test_error_handling.cpp b/tests/postgres/test_error_handling.cpp new file mode 100644 index 0000000..3f45324 --- /dev/null +++ b/tests/postgres/test_error_handling.cpp @@ -0,0 +1,55 @@ +#ifndef SQLGEN_BUILD_DRY_TESTS_ONLY + +#include + +#include +#include +#include +#include +#include + +namespace test_error_handling { + +struct Person { + sqlgen::PrimaryKey id; + std::string first_name; + std::string last_name; + int age; +}; + +TEST(postgres, test_error_handling) { + const auto people1 = std::vector( + {Person{ + .id = 0, .first_name = "Homer", .last_name = "Simpson", .age = 45}, + Person{.id = 1, .first_name = "Bart", .last_name = "Simpson", .age = 10}, + Person{.id = 2, .first_name = "Lisa", .last_name = "Simpson", .age = 8}, + Person{ + .id = 3, .first_name = "Maggie", .last_name = "Simpson", .age = 0}, + Person{ + .id = 4, .first_name = "Hugo", .last_name = "Simpson", .age = 10}}); + + const auto credentials = sqlgen::postgres::Credentials{.user = "postgres", + .password = "password", + .host = "localhost", + .dbname = "postgres"}; + + using namespace sqlgen; + using namespace sqlgen::literals; + + /// Intentionally passing an empty vector to test error handling + const auto people2 = + postgres::connect(credentials) + .and_then(drop | if_exists) + .and_then(write(std::ref(people1))) + .and_then(sqlgen::read> | + where("first_name"_c.in(std::vector()))) + .value_or(std::vector({})); + + const std::string expected1 = R"([])"; + + EXPECT_EQ(rfl::json::write(people2), expected1); +} + +} // namespace test_error_handling + +#endif diff --git a/tests/sqlite/test_error_handling.cpp b/tests/sqlite/test_error_handling.cpp new file mode 100644 index 0000000..6690220 --- /dev/null +++ b/tests/sqlite/test_error_handling.cpp @@ -0,0 +1,46 @@ + +#include + +#include +#include +#include +#include +#include + +namespace test_error_handling { + +struct Person { + sqlgen::PrimaryKey id; + std::string first_name; + std::string last_name; + int age; +}; + +TEST(sqlite, test_error_handling) { + const auto people1 = std::vector( + {Person{ + .id = 0, .first_name = "Homer", .last_name = "Simpson", .age = 45}, + Person{.id = 1, .first_name = "Bart", .last_name = "Simpson", .age = 10}, + Person{.id = 2, .first_name = "Lisa", .last_name = "Simpson", .age = 8}, + Person{ + .id = 3, .first_name = "Maggie", .last_name = "Simpson", .age = 0}, + Person{ + .id = 4, .first_name = "Hugo", .last_name = "Simpson", .age = 10}}); + + using namespace sqlgen; + using namespace sqlgen::literals; + + const auto people2 = + sqlite::connect() + .and_then(write(std::ref(people1))) + .and_then(sqlgen::read> | + where("first_name"_c.in(std::vector()))) + .value_or(std::vector({})); + + const std::string expected1 = R"([])"; + + EXPECT_EQ(rfl::json::write(people2), expected1); +} + +} // namespace test_error_handling +