From ce12604870721b439c09e1778fe4337b82816b5e Mon Sep 17 00:00:00 2001 From: "Dr. Patrick Urbanke" Date: Wed, 14 May 2025 00:52:11 +0200 Subject: [PATCH] Added support for transactions --- docs/README.md | 1 + docs/delete_from.md | 6 +- docs/drop.md | 6 +- docs/transactions.md | 61 +++++++++++++ docs/update.md | 6 +- docs/writing.md | 2 +- include/sqlgen.hpp | 3 + include/sqlgen/Connection.hpp | 8 +- include/sqlgen/begin_transaction.hpp | 38 ++++++++ include/sqlgen/commit.hpp | 35 ++++++++ include/sqlgen/delete_from.hpp | 14 +-- include/sqlgen/drop.hpp | 13 +-- include/sqlgen/postgres/Connection.hpp | 22 ++++- include/sqlgen/rollback.hpp | 35 ++++++++ include/sqlgen/sqlite/Connection.hpp | 21 ++++- include/sqlgen/update.hpp | 16 ++-- include/sqlgen/write.hpp | 7 +- src/sqlgen/postgres/Connection.cpp | 52 +++++++++++ src/sqlgen/sqlite/Connection.cpp | 116 +++++++++++++++++++------ tests/sqlite/test_transaction.cpp | 54 ++++++++++++ 20 files changed, 452 insertions(+), 64 deletions(-) create mode 100644 docs/transactions.md create mode 100644 include/sqlgen/begin_transaction.hpp create mode 100644 include/sqlgen/commit.hpp create mode 100644 include/sqlgen/rollback.hpp create mode 100644 tests/sqlite/test_transaction.cpp diff --git a/docs/README.md b/docs/README.md index 4a677ad..3914024 100644 --- a/docs/README.md +++ b/docs/README.md @@ -17,6 +17,7 @@ Welcome to the sqlgen documentation. This guide provides detailed information ab - [sqlgen::update](update.md) - How to update data in a table - [sqlgen::delete_from](delete_from.md) - How to delete data from a table - [sqlgen::drop](drop.md) - How to drop a table +- [Transactions](transactions.md) - How to use transactions for atomic operations ## Data Types and Validation diff --git a/docs/delete_from.md b/docs/delete_from.md index 942f79c..7617d8f 100644 --- a/docs/delete_from.md +++ b/docs/delete_from.md @@ -24,7 +24,7 @@ Note that `conn` is actually a connection wrapped into an `sqlgen::Result<...>`. This means you can use monadic error handling and fit this into a single line: ```cpp -// sqlgen::Result +// sqlgen::Result> const auto result = sqlgen::sqlite::connect("database.db").and_then( sqlgen::delete_from); ``` @@ -62,7 +62,7 @@ using namespace sqlgen; const auto query = delete_from | where("first_name"_c == "Hugo"); -// sqlgen::Result +// sqlgen::Result> const auto result = sqlite::connect("database.db").and_then(query); ``` @@ -101,6 +101,6 @@ const auto result = query(conn); ## Notes - The `where` clause is optional - if omitted, all records will be deleted -- The `Result` type provides error handling; use `.value()` to extract the result (will throw an exception if there's an error) or handle errors as needed or refer to the documentation on `sqlgen::Result<...>` for other forms of error handling. +- The `Result>` type provides error handling; use `.value()` to extract the result (will throw an exception if there's an error) or handle errors as needed or refer to the documentation on `sqlgen::Result<...>` for other forms of error handling. - `"..."_c` refers to the name of the column diff --git a/docs/drop.md b/docs/drop.md index 6d93d3e..01f89e9 100644 --- a/docs/drop.md +++ b/docs/drop.md @@ -24,7 +24,7 @@ Note that `conn` is actually a connection wrapped into an `sqlgen::Result<...>`. This means you can use monadic error handling and fit this into a single line: ```cpp -// sqlgen::Result +// sqlgen::Result> const auto result = sqlgen::sqlite::connect("database.db").and_then( sqlgen::drop); ``` @@ -56,7 +56,7 @@ using namespace sqlgen; const auto query = drop | if_exists; -// sqlgen::Result +// sqlgen::Result> const auto result = sqlite::connect("database.db").and_then(query); ``` @@ -92,6 +92,6 @@ const auto result = query(conn); ## Notes - The `if_exists` clause is optional - if omitted, the query will fail if the table doesn't exist -- The `Result` type provides error handling; use `.value()` to extract the result (will throw an exception if there's an error) or handle errors as needed or refer to the documentation on `sqlgen::Result<...>` for other forms of error handling. +- The `Result>` type provides error handling; use `.value()` to extract the result (will throw an exception if there's an error) or handle errors as needed or refer to the documentation on `sqlgen::Result<...>` for other forms of error handling. - The table name is derived from the struct name (e.g., `Person` becomes `"Person"` in SQL) diff --git a/docs/transactions.md b/docs/transactions.md new file mode 100644 index 0000000..b4ddd03 --- /dev/null +++ b/docs/transactions.md @@ -0,0 +1,61 @@ +# Transactions + +sqlgen provides a simple and safe way to handle database transactions. Transactions ensure that a series of database operations are executed atomically - either all operations succeed, or none of them do. + +## Basic Usage + +Transactions in sqlgen are managed through three main functions: + +- `sqlgen::begin_transaction` - Starts a new transaction +- `sqlgen::commit` - Commits the current transaction +- `sqlgen::rollback` - Rolls back the current transaction + +Here's a basic example of how to use transactions: + +```cpp +using namespace sqlgen; + +// Start a transaction and chain operations +// Note that all of these operations return +// the connection if they are successful. +auto conn = sqlite::connect("database.db") + .and_then(begin_transaction) + .and_then(delete_from | where("first_name"_c == "Hugo")) + .and_then(update("age"_c.set(46)) | where("first_name"_c == "Homer")) + .and_then(commit); +``` + +## Automatic Rollback + +sqlgen provides automatic rollback protection through RAII (Resource Acquisition Is Initialization). If a transaction is not explicitly committed, it will be automatically rolled back when the connection object goes out of scope. This helps ensure database consistency. + +## Error Handling + +All transaction operations return a `sqlgen::Result>`, which allows for safe error handling and operation chaining. If any operation in the transaction chain fails, the error will be propagated through the Result type, and the transaction will be rolled back. + +Example with error handling: + +```cpp +using namespace sqlgen; + +auto conn = sqlite::connect("database.db") + .and_then(begin_transaction) + .and_then(delete_from | where("first_name"_c == "Hugo")) + .and_then(update("age"_c.set(46)) | where("first_name"_c == "Homer")) + .and_then(commit); + +if (!conn) { + // Handle error + std::cerr << "Transaction failed: " << result.error() << std::endl; + return; +} + +// Transaction was successful. conn contains the active connection. +``` + +## Best Practices + +1. Chain operations using `.and_then()` to maintain transaction context +2. Let the automatic rollback handle cleanup in case of errors +3. Check the Result type after transaction operations to handle any errors +4. Keep transactions as short as possible to minimize lock contention diff --git a/docs/update.md b/docs/update.md index d8f4908..e3b79b2 100644 --- a/docs/update.md +++ b/docs/update.md @@ -29,7 +29,7 @@ Note that `conn` is actually a connection wrapped into an `sqlgen::Result<...>`. This means you can use monadic error handling and fit this into a single line: ```cpp -// sqlgen::Result +// sqlgen::Result> const auto result = sqlgen::sqlite::connect("database.db").and_then( update("age"_c.set(100), "first_name"_c.set("New Name"))); ``` @@ -68,7 +68,7 @@ using namespace sqlgen; const auto query = update("age"_c.set(100)) | where("first_name"_c == "Hugo"); -// sqlgen::Result +// sqlgen::Result> const auto result = sqlite::connect("database.db").and_then(query); ``` @@ -136,7 +136,7 @@ const auto result = query(conn); - You must specify at least one column to update - The `where` clause is optional - if omitted, all records will be updated -- The `Result` type provides error handling; use `.value()` to extract the result (will throw an exception if there's an error) or handle errors as needed or refer to the documentation on `sqlgen::Result<...>` for other forms of error handling +- The `Result>` type provides error handling; use `.value()` to extract the result (will throw an exception if there's an error) or handle errors as needed or refer to the documentation on `sqlgen::Result<...>` for other forms of error handling - `"..."_c` refers to the name of the column - You can set columns to either literal values or other column values - The update operation is atomic - either all specified columns are updated or none are diff --git a/docs/writing.md b/docs/writing.md index 5282d99..381181d 100644 --- a/docs/writing.md +++ b/docs/writing.md @@ -81,7 +81,7 @@ The `write` function performs the following operations in sequence: - The function automatically creates the table, if it doesn't exist - Data is written in batches for better performance -- The `Result` type provides error handling; use `.value()` to extract the result (will throw an exception if there's an error) or handle errors as needed +- The `Result>` type provides error handling; use `.value()` to extract the result (will throw an exception if there's an error) or handle errors as needed - The function has three overloads: 1. Takes a connection reference and iterators 2. Takes a `Result>` and iterators diff --git a/include/sqlgen.hpp b/include/sqlgen.hpp index e540be7..0d852b3 100644 --- a/include/sqlgen.hpp +++ b/include/sqlgen.hpp @@ -15,7 +15,9 @@ #include "sqlgen/Result.hpp" #include "sqlgen/Timestamp.hpp" #include "sqlgen/Varchar.hpp" +#include "sqlgen/begin_transaction.hpp" #include "sqlgen/col.hpp" +#include "sqlgen/commit.hpp" #include "sqlgen/delete_from.hpp" #include "sqlgen/drop.hpp" #include "sqlgen/if_exists.hpp" @@ -23,6 +25,7 @@ #include "sqlgen/order_by.hpp" #include "sqlgen/patterns.hpp" #include "sqlgen/read.hpp" +#include "sqlgen/rollback.hpp" #include "sqlgen/update.hpp" #include "sqlgen/where.hpp" #include "sqlgen/write.hpp" diff --git a/include/sqlgen/Connection.hpp b/include/sqlgen/Connection.hpp index b6a0038..ffe0b45 100644 --- a/include/sqlgen/Connection.hpp +++ b/include/sqlgen/Connection.hpp @@ -19,7 +19,10 @@ namespace sqlgen { struct Connection { virtual ~Connection() = default; - /// Commits a statement, + /// Begins a transaction. + virtual Result begin_transaction() = 0; + + /// Commits a transaction. virtual Result commit() = 0; /// Executes a statement. Note that in order for the statement to take effect, @@ -29,6 +32,9 @@ struct Connection { /// Reads the results of a SelectFrom statement. virtual Result> read(const dynamic::SelectFrom& _query) = 0; + /// Rolls a transaction back. + virtual Result rollback() = 0; + /// Transpiles a statement to a particular SQL dialect. virtual std::string to_sql(const dynamic::Statement& _stmt) = 0; diff --git a/include/sqlgen/begin_transaction.hpp b/include/sqlgen/begin_transaction.hpp new file mode 100644 index 0000000..46a37e4 --- /dev/null +++ b/include/sqlgen/begin_transaction.hpp @@ -0,0 +1,38 @@ +#ifndef SQLGEN_BEGIN_TRANSACTION_HPP_ +#define SQLGEN_BEGIN_TRANSACTION_HPP_ + +#include + +#include "Connection.hpp" +#include "Ref.hpp" +#include "Result.hpp" + +namespace sqlgen { + +inline Result> begin_transaction_impl( + const Ref& _conn) { + return _conn->begin_transaction().transform( + [&](const auto&) { return _conn; }); +} + +inline Result> begin_transaction_impl( + const Result>& _res) { + return _res.and_then( + [&](const auto& _conn) { return begin_transaction_impl(_conn); }); +} + +struct BeginTransaction { + Result> operator()(const auto& _conn) const noexcept { + try { + return begin_transaction_impl(_conn); + } catch (std::exception& e) { + return error(e.what()); + } + } +}; + +inline const auto begin_transaction = BeginTransaction{}; + +} // namespace sqlgen + +#endif diff --git a/include/sqlgen/commit.hpp b/include/sqlgen/commit.hpp new file mode 100644 index 0000000..e9db029 --- /dev/null +++ b/include/sqlgen/commit.hpp @@ -0,0 +1,35 @@ +#ifndef SQLGEN_COMMIT_HPP_ +#define SQLGEN_COMMIT_HPP_ + +#include + +#include "Connection.hpp" +#include "Ref.hpp" +#include "Result.hpp" + +namespace sqlgen { + +inline Result> commit_impl(const Ref& _conn) { + return _conn->commit().transform([&](const auto&) { return _conn; }); +} + +inline Result> commit_impl( + const Result>& _res) { + return _res.and_then([&](const auto& _conn) { return commit_impl(_conn); }); +} + +struct Commit { + Result> operator()(const auto& _conn) const noexcept { + try { + return commit_impl(_conn); + } catch (std::exception& e) { + return error(e.what()); + } + } +}; + +inline const auto commit = Commit{}; + +} // namespace sqlgen + +#endif diff --git a/include/sqlgen/delete_from.hpp b/include/sqlgen/delete_from.hpp index c61c439..8e30958 100644 --- a/include/sqlgen/delete_from.hpp +++ b/include/sqlgen/delete_from.hpp @@ -11,16 +11,18 @@ namespace sqlgen { template -Result delete_from_impl(const Ref& _conn, - const WhereType& _where) { +Result> delete_from_impl(const Ref& _conn, + const WhereType& _where) { const auto query = transpilation::to_delete_from(_where); - return _conn->execute(_conn->to_sql(query)); + return _conn->execute(_conn->to_sql(query)).transform([&](const auto&) { + return _conn; + }); } template -Result delete_from_impl(const Result>& _res, - const WhereType& _where) { +Result> delete_from_impl(const Result>& _res, + const WhereType& _where) { return _res.and_then([&](const auto& _conn) { return delete_from_impl(_conn, _where); }); @@ -28,7 +30,7 @@ Result delete_from_impl(const Result>& _res, template struct DeleteFrom { - Result operator()(const auto& _conn) const noexcept { + Result> operator()(const auto& _conn) const noexcept { try { return delete_from_impl(_conn, where_); } catch (std::exception& e) { diff --git a/include/sqlgen/drop.hpp b/include/sqlgen/drop.hpp index c1231e1..653eabd 100644 --- a/include/sqlgen/drop.hpp +++ b/include/sqlgen/drop.hpp @@ -11,14 +11,17 @@ namespace sqlgen { template -Result drop_impl(const Ref& _conn, const bool _if_exists) { +Result> drop_impl(const Ref& _conn, + const bool _if_exists) { const auto query = transpilation::to_drop(_if_exists); - return _conn->execute(_conn->to_sql(query)); + return _conn->execute(_conn->to_sql(query)).transform([&](const auto&) { + return _conn; + }); } template -Result drop_impl(const Result>& _res, - const bool _if_exists) { +Result> drop_impl(const Result>& _res, + const bool _if_exists) { return _res.and_then([&](const auto& _conn) { return drop_impl(_conn, _if_exists); }); @@ -26,7 +29,7 @@ Result drop_impl(const Result>& _res, template struct Drop { - Result operator()(const auto& _conn) const noexcept { + Result> operator()(const auto& _conn) const noexcept { try { return drop_impl(_conn, if_exists_); } catch (std::exception& e) { diff --git a/include/sqlgen/postgres/Connection.hpp b/include/sqlgen/postgres/Connection.hpp index 46f76b8..6e5fe25 100644 --- a/include/sqlgen/postgres/Connection.hpp +++ b/include/sqlgen/postgres/Connection.hpp @@ -25,21 +25,35 @@ class Connection : public sqlgen::Connection { public: Connection(const Credentials& _credentials) - : conn_(make_conn(_credentials.to_str())), credentials_(_credentials) {} + : conn_(make_conn(_credentials.to_str())), + credentials_(_credentials), + transaction_started_(false) {} static rfl::Result> make( const Credentials& _credentials) noexcept; - ~Connection() = default; + Connection(const Connection& _other) = delete; - Result commit() final { return execute("COMMIT;"); } + Connection(Connection&& _other) noexcept; + + ~Connection(); + + Result begin_transaction() noexcept final; + + Result commit() noexcept final; Result execute(const std::string& _sql) noexcept final { return exec(conn_, _sql).transform([](auto&&) { return Nothing{}; }); } + Connection& operator=(const Connection& _other) = delete; + + Connection& operator=(Connection&& _other) noexcept; + Result> read(const dynamic::SelectFrom& _query) final; + Result rollback() noexcept final; + std::string to_sql(const dynamic::Statement& _stmt) noexcept final { return postgres::to_sql_impl(_stmt); } @@ -63,6 +77,8 @@ class Connection : public sqlgen::Connection { ConnPtr conn_; Credentials credentials_; + + bool transaction_started_; }; } // namespace sqlgen::postgres diff --git a/include/sqlgen/rollback.hpp b/include/sqlgen/rollback.hpp new file mode 100644 index 0000000..f37f143 --- /dev/null +++ b/include/sqlgen/rollback.hpp @@ -0,0 +1,35 @@ +#ifndef SQLGEN_ROLLBACK_HPP_ +#define SQLGEN_ROLLBACK_HPP_ + +#include + +#include "Connection.hpp" +#include "Ref.hpp" +#include "Result.hpp" + +namespace sqlgen { + +inline Result> rollback_impl(const Ref& _conn) { + return _conn->rollback().transform([&](const auto&) { return _conn; }); +} + +inline Result> rollback_impl( + const Result>& _res) { + return _res.and_then([&](const auto& _conn) { return rollback_impl(_conn); }); +} + +struct Rollback { + Result> operator()(const auto& _conn) const noexcept { + try { + return rollback_impl(_conn); + } catch (std::exception& e) { + return error(e.what()); + } + } +}; + +inline const auto rollback = Rollback{}; + +} // namespace sqlgen + +#endif diff --git a/include/sqlgen/sqlite/Connection.hpp b/include/sqlgen/sqlite/Connection.hpp index 0ba0851..ae620d6 100644 --- a/include/sqlgen/sqlite/Connection.hpp +++ b/include/sqlgen/sqlite/Connection.hpp @@ -23,19 +23,31 @@ class Connection : public sqlgen::Connection { public: Connection(const std::string& _fname) - : stmt_(nullptr), conn_(make_conn(_fname)) {} + : stmt_(nullptr), conn_(make_conn(_fname)), transaction_started_(false) {} static rfl::Result> make( const std::string& _fname) noexcept; - ~Connection() = default; + Connection(const Connection& _other) = delete; - Result commit() final { return execute("COMMIT;"); } + Connection(Connection&& _other) noexcept; + + ~Connection(); + + Result begin_transaction() noexcept final; + + Result commit() noexcept final; Result execute(const std::string& _sql) noexcept final; + Connection& operator=(const Connection& _other) = delete; + + Connection& operator=(Connection&& _other) noexcept; + Result> read(const dynamic::SelectFrom& _query) final; + Result rollback() noexcept final; + std::string to_sql(const dynamic::Statement& _stmt) noexcept final { return sqlite::to_sql_impl(_stmt); } @@ -58,6 +70,9 @@ class Connection : public sqlgen::Connection { /// The underlying sqlite3 connection. ConnPtr conn_; + + /// Whether a transaction has been started. + bool transaction_started_; }; } // namespace sqlgen::sqlite diff --git a/include/sqlgen/update.hpp b/include/sqlgen/update.hpp index 5944a27..a90083d 100644 --- a/include/sqlgen/update.hpp +++ b/include/sqlgen/update.hpp @@ -12,16 +12,20 @@ namespace sqlgen { template -Result update_impl(const Ref& _conn, const SetsType& _sets, - const WhereType& _where) { +Result> update_impl(const Ref& _conn, + const SetsType& _sets, + const WhereType& _where) { const auto query = transpilation::to_update(_sets, _where); - return _conn->execute(_conn->to_sql(query)); + return _conn->execute(_conn->to_sql(query)).transform([&](const auto&) { + return _conn; + }); } template -Result update_impl(const Result>& _res, - const SetsType& _sets, const WhereType& _where) { +Result> update_impl(const Result>& _res, + const SetsType& _sets, + const WhereType& _where) { return _res.and_then([&](const auto& _conn) { return update_impl(_conn, _sets, _where); }); @@ -29,7 +33,7 @@ Result update_impl(const Result>& _res, template struct Update { - Result operator()(const auto& _conn) const noexcept { + Result> operator()(const auto& _conn) const noexcept { try { return update_impl(_conn, sets_, where_); } catch (std::exception& e) { diff --git a/include/sqlgen/write.hpp b/include/sqlgen/write.hpp index 52c5136..8874411 100644 --- a/include/sqlgen/write.hpp +++ b/include/sqlgen/write.hpp @@ -20,8 +20,8 @@ namespace sqlgen { template -Result write(const Ref& _conn, ItBegin _begin, - ItEnd _end) noexcept { +Result> write(const Ref& _conn, ItBegin _begin, + ItEnd _end) noexcept { using T = std::remove_cvref_t::value_type>; @@ -59,7 +59,8 @@ Result write(const Ref& _conn, ItBegin _begin, return _conn->execute(_conn->to_sql(create_table_stmt)) .and_then(start_write) .and_then(write) - .and_then(end_write); + .and_then(end_write) + .transform([&](const auto&) { return _conn; }); } template diff --git a/src/sqlgen/postgres/Connection.cpp b/src/sqlgen/postgres/Connection.cpp index bfa741b..aa51ae1 100644 --- a/src/sqlgen/postgres/Connection.cpp +++ b/src/sqlgen/postgres/Connection.cpp @@ -11,6 +11,36 @@ namespace sqlgen::postgres { +Connection::Connection(Connection&& _other) noexcept + : conn_(std::move(_other.conn_)), + credentials_(std::move(_other.credentials_)), + transaction_started_(_other.transaction_started_) { + _other.transaction_started_ = false; +} + +Connection::~Connection() { + if (transaction_started_) { + rollback(); + } +} + +Result Connection::begin_transaction() noexcept { + if (transaction_started_) { + return error( + "Cannot BEGIN TRANSACTION - another transaction has been started."); + } + transaction_started_ = true; + return execute("BEGIN TRANSACTION;"); +} + +Result Connection::commit() noexcept { + if (!transaction_started_) { + return error("Cannot COMMIT - no transaction has been started."); + } + transaction_started_ = false; + return execute("COMMIT;"); +} + Result Connection::end_write() { if (PQputCopyEnd(conn_.get(), NULL) == -1) { return error(PQerrorMessage(conn_.get())); @@ -45,6 +75,20 @@ typename Connection::ConnPtr Connection::make_conn( return ConnPtr::make(std::shared_ptr(raw_ptr, &PQfinish)).value(); } +Connection& Connection::operator=(Connection&& _other) noexcept { + if (this == &_other) { + return *this; + } + if (transaction_started_) { + rollback(); + } + conn_ = std::move(_other.conn_); + credentials_ = std::move(_other.credentials_); + transaction_started_ = _other.transaction_started_; + _other.transaction_started_ = false; + return *this; +} + Result> Connection::read(const dynamic::SelectFrom& _query) { const auto sql = postgres::to_sql_impl(_query); try { @@ -54,6 +98,14 @@ Result> Connection::read(const dynamic::SelectFrom& _query) { } } +Result Connection::rollback() noexcept { + if (!transaction_started_) { + return error("Cannot ROLLBACK - no transaction has been started."); + } + transaction_started_ = false; + return execute("ROLLBACK;"); +} + std::string Connection::to_buffer( const std::vector>& _line) const noexcept { using namespace std::ranges::views; diff --git a/src/sqlgen/sqlite/Connection.cpp b/src/sqlgen/sqlite/Connection.cpp index b6ac28d..c1a5b4d 100644 --- a/src/sqlgen/sqlite/Connection.cpp +++ b/src/sqlgen/sqlite/Connection.cpp @@ -11,6 +11,36 @@ namespace sqlgen::sqlite { +Connection::Connection(Connection&& _other) noexcept + : stmt_(std::move(_other.stmt_)), + conn_(std::move(_other.conn_)), + transaction_started_(_other.transaction_started_) { + _other.transaction_started_ = false; +} + +Connection::~Connection() { + if (transaction_started_) { + rollback(); + } +} + +Result Connection::begin_transaction() noexcept { + if (transaction_started_) { + return error( + "Cannot BEGIN TRANSACTION - another transaction has been started."); + } + transaction_started_ = true; + return execute("BEGIN TRANSACTION;"); +} + +Result Connection::commit() noexcept { + if (!transaction_started_) { + return error("Cannot COMMIT - no transaction has been started."); + } + transaction_started_ = false; + return execute("COMMIT;"); +} + rfl::Result> Connection::make( const std::string& _fname) noexcept { try { @@ -41,6 +71,20 @@ typename Connection::ConnPtr Connection::make_conn(const std::string& _fname) { return ConnPtr::make(std::shared_ptr(conn, &sqlite3_close)).value(); } +Connection& Connection::operator=(Connection&& _other) noexcept { + if (this == &_other) { + return *this; + } + if (transaction_started_) { + rollback(); + } + stmt_ = std::move(_other.stmt_); + conn_ = std::move(_other.conn_); + transaction_started_ = _other.transaction_started_; + _other.transaction_started_ = false; + return *this; +} + Result> Connection::read(const dynamic::SelectFrom& _query) { const auto sql = to_sql_impl(_query); @@ -63,6 +107,14 @@ Result> Connection::read(const dynamic::SelectFrom& _query) { }); } +Result Connection::rollback() noexcept { + if (!transaction_started_) { + return error("Cannot ROLLBACK - no transaction has been started."); + } + transaction_started_ = false; + return execute("ROLLBACK;"); +} + Result Connection::start_write(const dynamic::Insert& _stmt) { if (stmt_) { return error( @@ -98,43 +150,53 @@ Result Connection::write( ".write(...)."); } - for (const auto& row : _data) { - const auto num_cols = static_cast(row.size()); + const auto write = [&](const auto&) -> Result { + for (const auto& row : _data) { + const auto num_cols = static_cast(row.size()); - for (int i = 0; i < num_cols; ++i) { - if (row[i]) { - const auto res = - sqlite3_bind_text(stmt_.get(), i + 1, row[i]->c_str(), - static_cast(row[i]->size()), SQLITE_STATIC); - if (res != SQLITE_OK) { - return error(sqlite3_errmsg(conn_.get())); - } - } else { - const auto res = sqlite3_bind_null(stmt_.get(), i + 1); - if (res != SQLITE_OK) { - return error(sqlite3_errmsg(conn_.get())); + for (int i = 0; i < num_cols; ++i) { + if (row[i]) { + const auto res = sqlite3_bind_text( + stmt_.get(), i + 1, row[i]->c_str(), + static_cast(row[i]->size()), SQLITE_STATIC); + if (res != SQLITE_OK) { + return error(sqlite3_errmsg(conn_.get())); + } + } else { + const auto res = sqlite3_bind_null(stmt_.get(), i + 1); + if (res != SQLITE_OK) { + return error(sqlite3_errmsg(conn_.get())); + } } } + + auto res = sqlite3_step(stmt_.get()); + if (res != SQLITE_OK && res != SQLITE_ROW && res != SQLITE_DONE) { + return error(sqlite3_errmsg(conn_.get())); + } + + res = sqlite3_reset(stmt_.get()); + if (res != SQLITE_OK) { + return error(sqlite3_errmsg(conn_.get())); + } } - auto res = sqlite3_step(stmt_.get()); - if (res != SQLITE_OK && res != SQLITE_ROW && res != SQLITE_DONE) { - return error(sqlite3_errmsg(conn_.get())); - } - - res = sqlite3_reset(stmt_.get()); + // We need to reset the statement to avoid segfaults. + const auto res = sqlite3_clear_bindings(stmt_.get()); if (res != SQLITE_OK) { return error(sqlite3_errmsg(conn_.get())); } - } - // We need to reset the statement to avoid segfaults. - const auto res = sqlite3_clear_bindings(stmt_.get()); - if (res != SQLITE_OK) { - return error(sqlite3_errmsg(conn_.get())); - } + return Nothing{}; + }; - return Nothing{}; + return begin_transaction() + .and_then(write) + .and_then([&](const auto&) { return commit(); }) + .or_else([&](const auto& err) -> Result { + rollback(); + return error(err.what()); + }); } Result Connection::end_write() { diff --git a/tests/sqlite/test_transaction.cpp b/tests/sqlite/test_transaction.cpp new file mode 100644 index 0000000..dc027f2 --- /dev/null +++ b/tests/sqlite/test_transaction.cpp @@ -0,0 +1,54 @@ +#include + +#include +#include +#include +#include +#include + +namespace test_update { + +struct Person { + sqlgen::PrimaryKey id; + std::string first_name; + std::string last_name; + int age; +}; + +TEST(sqlite, test_transaction) { + const auto people1 = std::vector( + {Person{ + .id = 0, .first_name = "Homer", .last_name = "Simpson", .age = 45}, + Person{.id = 1, .first_name = "Bart", .last_name = "Simpson", .age = 10}, + Person{.id = 2, .first_name = "Lisa", .last_name = "Simpson", .age = 8}, + Person{ + .id = 3, .first_name = "Maggie", .last_name = "Simpson", .age = 0}, + Person{ + .id = 4, .first_name = "Hugo", .last_name = "Simpson", .age = 10}}); + + auto conn = sqlgen::sqlite::connect(); + + sqlgen::write(conn, people1); + + using namespace sqlgen; + + const auto delete_hugo = + delete_from | where("first_name"_c == "Hugo"); + + const auto update_homers_age = + update("age"_c.set(46)) | where("first_name"_c == "Homer"); + + conn = sqlgen::begin_transaction(conn) + .and_then(delete_hugo) + .and_then(update_homers_age) + .and_then(sqlgen::commit); + + const auto people2 = sqlgen::read>(conn).value(); + + const std::string expected = + R"([{"id":0,"first_name":"Homer","last_name":"Simpson","age":46},{"id":1,"first_name":"Bart","last_name":"Simpson","age":10},{"id":2,"first_name":"Lisa","last_name":"Simpson","age":8},{"id":3,"first_name":"Maggie","last_name":"Simpson","age":0}])"; + + EXPECT_EQ(rfl::json::write(people2), expected); +} + +} // namespace test_update