Use monadic error handling in postgres (#112)

This commit is contained in:
Dr. Patrick Urbanke (劉自成)
2025-12-31 00:56:04 +01:00
committed by GitHub
parent c8f4ad8d84
commit 9d0f0a29ab
8 changed files with 297 additions and 55 deletions

View File

@@ -41,6 +41,15 @@ class SQLGEN_API Iterator {
Iterator& operator=(Iterator&& _other) noexcept;
static rfl::Result<Ref<Iterator>> make(const std::string& _sql,
const Conn& _conn) noexcept {
try {
return Ref<Iterator>::make(_sql, _conn);
} catch (const std::exception& e) {
return error(e.what());
}
}
private:
static std::string make_cursor_name() {
// TODO: Create unique cursor names.

View File

@@ -24,6 +24,15 @@ class SQLGEN_API PostgresV2Connection {
static rfl::Result<PostgresV2Connection> make(
const std::string& _conn_str) noexcept;
static rfl::Result<PostgresV2Connection> make(PGconn* _ptr) noexcept {
try {
return PostgresV2Connection(_ptr);
} catch (const std::exception& e) {
return rfl::error("Failed to connect to postgres: " +
std::string(e.what()));
}
}
PGconn* ptr() const { return ptr_.get(); }
private:

View File

@@ -25,6 +25,15 @@ class SQLGEN_API PostgresV2Result {
static rfl::Result<PostgresV2Result> make(
const std::string& _query, const PostgresV2Connection& _conn) noexcept;
static rfl::Result<PostgresV2Result> make(PGresult* _ptr) noexcept {
try {
return PostgresV2Result(_ptr);
} catch (const std::exception& e) {
return rfl::error("Failed to retrieve result from postgres: " +
std::string(e.what()));
}
}
PGresult* ptr() const { return ptr_.get(); }
private:

View File

@@ -34,11 +34,13 @@ Result<Nothing> Connection::end_write() {
if (PQputCopyEnd(conn_.ptr(), NULL) == -1) {
return error(PQerrorMessage(conn_.ptr()));
}
const auto res = PostgresV2Result(PQgetResult(conn_.ptr()));
if (PQresultStatus(res.ptr()) != PGRES_COMMAND_OK) {
return error(PQerrorMessage(conn_.ptr()));
}
return Nothing{};
return PostgresV2Result::make(PQgetResult(conn_.ptr()))
.and_then([&](auto&& res) -> Result<Nothing> {
if (PQresultStatus(res.ptr()) != PGRES_COMMAND_OK) {
return error(PQerrorMessage(conn_.ptr()));
}
return Nothing{};
});
}
std::list<Notification> Connection::get_notifications() noexcept {
@@ -46,18 +48,17 @@ std::list<Notification> Connection::get_notifications() noexcept {
// Safe to call even if no data — just returns true
if (!PQconsumeInput(conn_.ptr())) {
// Note: In pure wait/consume pattern, this should rarely happen if socket is healthy
// But we don't error here — just skip
// Note: In pure wait/consume pattern, this should rarely happen if socket
// is healthy But we don't error here — just skip
return notices;
}
PGnotify* notify;
while ((notify = PQnotifies(conn_.ptr())) != nullptr) {
notices.push_back({
.channel = std::string(notify->relname),
.payload = notify->extra[0] ? std::string(notify->extra) : "",
.backend_pid = notify->be_pid
});
notices.push_back(
{.channel = std::string(notify->relname),
.payload = notify->extra[0] ? std::string(notify->extra) : "",
.backend_pid = notify->be_pid});
PQfreemem(notify);
}
@@ -83,16 +84,19 @@ rfl::Result<Nothing> Connection::unlisten(const std::string& channel) noexcept {
return execute(sql);
}
rfl::Result<Nothing> Connection::notify(const std::string& channel, const std::string& payload) noexcept {
rfl::Result<Nothing> Connection::notify(const std::string& channel,
const std::string& payload) noexcept {
if (!is_valid_channel_name(channel)) {
return error("Invalid channel name");
}
auto* escaped_payload = PQescapeLiteral(conn_.ptr(), payload.c_str(), payload.size());
auto* escaped_payload =
PQescapeLiteral(conn_.ptr(), payload.c_str(), payload.size());
if (!escaped_payload) {
return error("Failed to escape NOTIFY payload");
}
const std::string sql = "NOTIFY " + channel + ", " + std::string(escaped_payload);
const std::string sql =
"NOTIFY " + channel + ", " + std::string(escaped_payload);
PQfreemem(escaped_payload);
auto result = execute(sql);
@@ -116,55 +120,64 @@ Result<Nothing> Connection::insert_impl(
const auto sql = to_sql_impl(_stmt);
const auto res = PostgresV2Result(PQprepare(
conn_.ptr(), name.c_str(), sql.c_str(), _data.at(0).size(), nullptr));
return PostgresV2Result::make(PQprepare(conn_.ptr(), name.c_str(),
sql.c_str(), _data.at(0).size(),
nullptr))
.and_then([&](auto&& res) -> Result<Nothing> {
const auto status = PQresultStatus(res.ptr());
const auto status = PQresultStatus(res.ptr());
if (status != PGRES_COMMAND_OK) {
return error("Generating prepared statement for '" + sql +
"' failed: " + PQresultErrorMessage(res.ptr()));
}
if (status != PGRES_COMMAND_OK) {
return error("Generating prepared statement for '" + sql +
"' failed: " + PQresultErrorMessage(res.ptr()));
}
std::vector<const char*> current_row(_data[0].size());
std::vector<const char*> current_row(_data[0].size());
const int n_params = static_cast<int>(current_row.size());
const int n_params = static_cast<int>(current_row.size());
for (size_t i = 0; i < _data.size(); ++i) {
const auto& d = _data[i];
for (size_t i = 0; i < _data.size(); ++i) {
const auto& d = _data[i];
if (d.size() != current_row.size()) {
execute("DEALLOCATE " + name + ";");
return error("Error in entry " + std::to_string(i) + ": Expected " +
std::to_string(current_row.size()) + " entries, got " +
std::to_string(d.size()));
}
if (d.size() != current_row.size()) {
execute("DEALLOCATE " + name + ";");
return error("Error in entry " + std::to_string(i) + ": Expected " +
std::to_string(current_row.size()) + " entries, got " +
std::to_string(d.size()));
}
for (size_t j = 0; j < d.size(); ++j) {
current_row[j] = d[j] ? d[j]->c_str() : nullptr;
}
for (size_t j = 0; j < d.size(); ++j) {
current_row[j] = d[j] ? d[j]->c_str() : nullptr;
}
try {
const auto res = PostgresV2Result(PQexecPrepared(
conn_.ptr(), // conn
name.c_str(), // stmtName
n_params, // nParams
current_row.data(), // paramValues
nullptr, // paramLengths
nullptr, // paramFormats
0 // resultFormat
));
const auto res =
PostgresV2Result(PQexecPrepared(conn_.ptr(), // conn
name.c_str(), // stmtName
n_params, // nParams
current_row.data(), // paramValues
nullptr, // paramLengths
nullptr, // paramFormats
0 // resultFormat
));
const auto status = PQresultStatus(res.ptr());
const auto status = PQresultStatus(res.ptr());
if (status != PGRES_COMMAND_OK) {
const auto err = error(std::string("Executing INSERT failed: ") +
PQresultErrorMessage(res.ptr()));
execute("DEALLOCATE " + name + ";");
return err;
}
} catch (const std::exception& e) {
const auto err =
error(std::string("Executing INSERT failed: ") + e.what());
execute("DEALLOCATE " + name + ";");
return err;
}
}
if (status != PGRES_COMMAND_OK) {
const auto err = error(std::string("Executing INSERT failed: ") +
PQresultErrorMessage(res.ptr()));
execute("DEALLOCATE " + name + ";");
return err;
}
}
return execute("DEALLOCATE " + name + ";");
return execute("DEALLOCATE " + name + ";");
});
}
rfl::Result<Ref<Connection>> Connection::make(
@@ -176,7 +189,7 @@ rfl::Result<Ref<Connection>> Connection::make(
Result<Ref<Iterator>> Connection::read_impl(
const rfl::Variant<dynamic::SelectFrom, dynamic::Union>& _query) {
const auto sql = _query.visit([](const auto& _q) { return to_sql_impl(_q); });
return Ref<Iterator>::make(sql, conn_);
return Iterator::make(sql, conn_);
}
Result<Nothing> Connection::rollback() noexcept { return execute("ROLLBACK;"); }

View File

@@ -0,0 +1,46 @@
#include <gtest/gtest.h>
#include <rfl.hpp>
#include <rfl/json.hpp>
#include <sqlgen.hpp>
#include <sqlgen/duckdb.hpp>
#include <vector>
namespace test_error_handling {
struct Person {
sqlgen::PrimaryKey<uint32_t> id;
std::string first_name;
std::string last_name;
int age;
};
TEST(duckdb, test_error_handling) {
const auto people1 = std::vector<Person>(
{Person{
.id = 0, .first_name = "Homer", .last_name = "Simpson", .age = 45},
Person{.id = 1, .first_name = "Bart", .last_name = "Simpson", .age = 10},
Person{.id = 2, .first_name = "Lisa", .last_name = "Simpson", .age = 8},
Person{
.id = 3, .first_name = "Maggie", .last_name = "Simpson", .age = 0},
Person{
.id = 4, .first_name = "Hugo", .last_name = "Simpson", .age = 10}});
using namespace sqlgen;
using namespace sqlgen::literals;
const auto people2 =
duckdb::connect()
.and_then(write(std::ref(people1)))
.and_then(sqlgen::read<std::vector<Person>> |
where("first_name"_c.in(std::vector<std::string>())))
.value_or(std::vector<Person>({}));
const std::string expected1 = R"([])";
EXPECT_EQ(rfl::json::write(people2), expected1);
}
} // namespace test_error_handling

View File

@@ -0,0 +1,55 @@
#ifndef SQLGEN_BUILD_DRY_TESTS_ONLY
#include <gtest/gtest.h>
#include <rfl.hpp>
#include <rfl/json.hpp>
#include <sqlgen.hpp>
#include <sqlgen/mysql.hpp>
#include <vector>
namespace test_error_handling {
struct Person {
sqlgen::PrimaryKey<uint32_t> id;
std::string first_name;
std::string last_name;
int age;
};
TEST(mysql, test_error_handling) {
const auto people1 = std::vector<Person>(
{Person{
.id = 0, .first_name = "Homer", .last_name = "Simpson", .age = 45},
Person{.id = 1, .first_name = "Bart", .last_name = "Simpson", .age = 10},
Person{.id = 2, .first_name = "Lisa", .last_name = "Simpson", .age = 8},
Person{
.id = 3, .first_name = "Maggie", .last_name = "Simpson", .age = 0},
Person{
.id = 4, .first_name = "Hugo", .last_name = "Simpson", .age = 10}});
const auto credentials = sqlgen::mysql::Credentials{.host = "localhost",
.user = "sqlgen",
.password = "password",
.dbname = "mysql"};
using namespace sqlgen;
using namespace sqlgen::literals;
const auto people2 =
mysql::connect(credentials)
.and_then(drop<Person> | if_exists)
.and_then(write(std::ref(people1)))
.and_then(sqlgen::read<std::vector<Person>> |
where("first_name"_c.in(std::vector<std::string>())) |
order_by("age"_c))
.value_or(std::vector<Person>({}));
const std::string expected1 = R"([])";
EXPECT_EQ(rfl::json::write(people2), expected1);
}
} // namespace test_error_handling
#endif

View File

@@ -0,0 +1,55 @@
#ifndef SQLGEN_BUILD_DRY_TESTS_ONLY
#include <gtest/gtest.h>
#include <rfl.hpp>
#include <rfl/json.hpp>
#include <sqlgen.hpp>
#include <sqlgen/postgres.hpp>
#include <vector>
namespace test_error_handling {
struct Person {
sqlgen::PrimaryKey<uint32_t> id;
std::string first_name;
std::string last_name;
int age;
};
TEST(postgres, test_error_handling) {
const auto people1 = std::vector<Person>(
{Person{
.id = 0, .first_name = "Homer", .last_name = "Simpson", .age = 45},
Person{.id = 1, .first_name = "Bart", .last_name = "Simpson", .age = 10},
Person{.id = 2, .first_name = "Lisa", .last_name = "Simpson", .age = 8},
Person{
.id = 3, .first_name = "Maggie", .last_name = "Simpson", .age = 0},
Person{
.id = 4, .first_name = "Hugo", .last_name = "Simpson", .age = 10}});
const auto credentials = sqlgen::postgres::Credentials{.user = "postgres",
.password = "password",
.host = "localhost",
.dbname = "postgres"};
using namespace sqlgen;
using namespace sqlgen::literals;
/// Intentionally passing an empty vector to test error handling
const auto people2 =
postgres::connect(credentials)
.and_then(drop<Person> | if_exists)
.and_then(write(std::ref(people1)))
.and_then(sqlgen::read<std::vector<Person>> |
where("first_name"_c.in(std::vector<std::string>())))
.value_or(std::vector<Person>({}));
const std::string expected1 = R"([])";
EXPECT_EQ(rfl::json::write(people2), expected1);
}
} // namespace test_error_handling
#endif

View File

@@ -0,0 +1,46 @@
#include <gtest/gtest.h>
#include <rfl.hpp>
#include <rfl/json.hpp>
#include <sqlgen.hpp>
#include <sqlgen/sqlite.hpp>
#include <vector>
namespace test_error_handling {
struct Person {
sqlgen::PrimaryKey<uint32_t> id;
std::string first_name;
std::string last_name;
int age;
};
TEST(sqlite, test_error_handling) {
const auto people1 = std::vector<Person>(
{Person{
.id = 0, .first_name = "Homer", .last_name = "Simpson", .age = 45},
Person{.id = 1, .first_name = "Bart", .last_name = "Simpson", .age = 10},
Person{.id = 2, .first_name = "Lisa", .last_name = "Simpson", .age = 8},
Person{
.id = 3, .first_name = "Maggie", .last_name = "Simpson", .age = 0},
Person{
.id = 4, .first_name = "Hugo", .last_name = "Simpson", .age = 10}});
using namespace sqlgen;
using namespace sqlgen::literals;
const auto people2 =
sqlite::connect()
.and_then(write(std::ref(people1)))
.and_then(sqlgen::read<std::vector<Person>> |
where("first_name"_c.in(std::vector<std::string>())))
.value_or(std::vector<Person>({}));
const std::string expected1 = R"([])";
EXPECT_EQ(rfl::json::write(people2), expected1);
}
} // namespace test_error_handling