From 9d6cac05026bbca0790c90c674ffad8a3ec274c3 Mon Sep 17 00:00:00 2001 From: "Dr. Patrick Urbanke" Date: Sun, 11 May 2025 05:15:04 +0200 Subject: [PATCH] Added update --- include/sqlgen.hpp | 1 + include/sqlgen/col.hpp | 12 ++++ include/sqlgen/dynamic/Statement.hpp | 5 +- include/sqlgen/dynamic/Update.hpp | 27 +++++++++ include/sqlgen/transpilation/Set.hpp | 15 +++++ include/sqlgen/transpilation/to_sets.hpp | 69 ++++++++++++++++++++++ include/sqlgen/transpilation/to_sql.hpp | 9 +++ include/sqlgen/transpilation/to_update.hpp | 31 ++++++++++ include/sqlgen/update.hpp | 54 +++++++++++++++++ include/sqlgen/where.hpp | 11 ++++ src/sqlgen/postgres/to_sql.cpp | 37 ++++++++++++ src/sqlgen/sqlite/to_sql.cpp | 37 ++++++++++++ tests/postgres/test_update_dry.cpp | 27 +++++++++ tests/sqlite/test_update.cpp | 51 ++++++++++++++++ 14 files changed, 384 insertions(+), 2 deletions(-) create mode 100644 include/sqlgen/dynamic/Update.hpp create mode 100644 include/sqlgen/transpilation/Set.hpp create mode 100644 include/sqlgen/transpilation/to_sets.hpp create mode 100644 include/sqlgen/transpilation/to_update.hpp create mode 100644 include/sqlgen/update.hpp create mode 100644 tests/postgres/test_update_dry.cpp create mode 100644 tests/sqlite/test_update.cpp diff --git a/include/sqlgen.hpp b/include/sqlgen.hpp index 923ed6f..eb71aca 100644 --- a/include/sqlgen.hpp +++ b/include/sqlgen.hpp @@ -22,6 +22,7 @@ #include "sqlgen/order_by.hpp" #include "sqlgen/patterns.hpp" #include "sqlgen/read.hpp" +#include "sqlgen/update.hpp" #include "sqlgen/where.hpp" #include "sqlgen/write.hpp" diff --git a/include/sqlgen/col.hpp b/include/sqlgen/col.hpp index 9fc996c..7d3c307 100644 --- a/include/sqlgen/col.hpp +++ b/include/sqlgen/col.hpp @@ -6,6 +6,7 @@ #include "transpilation/Condition.hpp" #include "transpilation/Desc.hpp" +#include "transpilation/Set.hpp" #include "transpilation/Value.hpp" #include "transpilation/conditions.hpp" @@ -20,6 +21,17 @@ struct Col { /// Returns the column name. std::string name() const noexcept { return Name().str(); } + + /// Defines a SET clause in an UPDATE statement. + template + auto set(const T& _to) const noexcept { + return transpilation::Set, std::remove_cvref_t>{.to = _to}; + } + + /// Defines a SET clause in an UPDATE statement. + auto set(const char* _to) const noexcept { + return transpilation::Set, std::string>{.to = _to}; + } }; template diff --git a/include/sqlgen/dynamic/Statement.hpp b/include/sqlgen/dynamic/Statement.hpp index 4ada452..290fa62 100644 --- a/include/sqlgen/dynamic/Statement.hpp +++ b/include/sqlgen/dynamic/Statement.hpp @@ -8,11 +8,12 @@ #include "Drop.hpp" #include "Insert.hpp" #include "SelectFrom.hpp" +#include "Update.hpp" namespace sqlgen::dynamic { -using Statement = - rfl::TaggedUnion<"stmt", CreateTable, DeleteFrom, Drop, Insert, SelectFrom>; +using Statement = rfl::TaggedUnion<"stmt", CreateTable, DeleteFrom, Drop, + Insert, SelectFrom, Update>; } // namespace sqlgen::dynamic diff --git a/include/sqlgen/dynamic/Update.hpp b/include/sqlgen/dynamic/Update.hpp new file mode 100644 index 0000000..33cf782 --- /dev/null +++ b/include/sqlgen/dynamic/Update.hpp @@ -0,0 +1,27 @@ +#ifndef SQLGEN_DYNAMIC_UPDATE_HPP_ +#define SQLGEN_DYNAMIC_UPDATE_HPP_ + +#include +#include + +#include "Column.hpp" +#include "ColumnOrValue.hpp" +#include "Condition.hpp" +#include "Table.hpp" + +namespace sqlgen::dynamic { + +struct Update { + struct Set { + Column col; + ColumnOrValue to; + }; + + Table table; + std::vector sets; + std::optional where; +}; + +} // namespace sqlgen::dynamic + +#endif diff --git a/include/sqlgen/transpilation/Set.hpp b/include/sqlgen/transpilation/Set.hpp new file mode 100644 index 0000000..5e59cb6 --- /dev/null +++ b/include/sqlgen/transpilation/Set.hpp @@ -0,0 +1,15 @@ +#ifndef SQLGEN_TRANSPILATION_SET_HPP_ +#define SQLGEN_TRANSPILATION_SET_HPP_ + +namespace sqlgen::transpilation { + +/// Defines the SET clause in an UPDATE statement. +template +struct Set { + using ColType = _ColType; + T to; +}; + +} // namespace sqlgen::transpilation + +#endif diff --git a/include/sqlgen/transpilation/to_sets.hpp b/include/sqlgen/transpilation/to_sets.hpp new file mode 100644 index 0000000..4461e88 --- /dev/null +++ b/include/sqlgen/transpilation/to_sets.hpp @@ -0,0 +1,69 @@ +#ifndef SQLGEN_TRANSPILATION_TO_SETS_HPP_ +#define SQLGEN_TRANSPILATION_TO_SETS_HPP_ + +#include +#include +#include +#include +#include + +#include "../Result.hpp" +#include "../col.hpp" +#include "../dynamic/Table.hpp" +#include "../dynamic/Update.hpp" +#include "Set.hpp" +#include "all_columns_exist.hpp" +#include "get_schema.hpp" +#include "get_tablename.hpp" +#include "to_condition.hpp" +#include "to_sets.hpp" +#include "to_value.hpp" + +namespace sqlgen::transpilation { + +template +struct ToSet; + +template +struct ToSet, ToType>> { + static_assert(all_columns_exist>(), "All columns must exist."); + + dynamic::Update::Set operator()(const auto& _set) const { + return dynamic::Update::Set{ + .col = dynamic::Column{.name = _name.str()}, + .to = to_value(_set.to), + }; + } +}; + +template +struct ToSet, Col<_name2>>> { + static_assert(all_columns_exist>(), "All columns must exist."); + static_assert(all_columns_exist>(), "All columns must exist."); + + dynamic::Update::Set operator()(const auto& _set) const { + return dynamic::Update::Set{ + .col = dynamic::Column{.name = _name1.str()}, + .to = dynamic::Column{.name = _name2.str()}, + }; + } +}; + +template +dynamic::Update::Set to_set(const SetType& _set) { + return ToSet, std::remove_cvref_t>{}(_set); +} + +template +std::vector to_sets(const SetsType& _sets) { + return rfl::apply( + [](const auto&... _s) { + return std::vector({to_set(_s)...}); + }, + _sets); +} + +} // namespace sqlgen::transpilation + +#endif diff --git a/include/sqlgen/transpilation/to_sql.hpp b/include/sqlgen/transpilation/to_sql.hpp index 6855a55..80b0c78 100644 --- a/include/sqlgen/transpilation/to_sql.hpp +++ b/include/sqlgen/transpilation/to_sql.hpp @@ -9,11 +9,13 @@ #include "../drop.hpp" #include "../dynamic/Statement.hpp" #include "../read.hpp" +#include "../update.hpp" #include "to_create_table.hpp" #include "to_delete_from.hpp" #include "to_drop.hpp" #include "to_insert.hpp" #include "to_select_from.hpp" +#include "to_update.hpp" #include "value_t.hpp" namespace sqlgen::transpilation { @@ -56,6 +58,13 @@ struct ToSQL> { } }; +template +struct ToSQL> { + dynamic::Statement operator()(const auto& _update) const { + return to_update(_update.sets_, _update.where_); + } +}; + template dynamic::Statement to_sql(const T& _t) { return ToSQL>{}(_t); diff --git a/include/sqlgen/transpilation/to_update.hpp b/include/sqlgen/transpilation/to_update.hpp new file mode 100644 index 0000000..9ecceb5 --- /dev/null +++ b/include/sqlgen/transpilation/to_update.hpp @@ -0,0 +1,31 @@ +#ifndef SQLGEN_TRANSPILATION_TO_UPDATE_HPP_ +#define SQLGEN_TRANSPILATION_TO_UPDATE_HPP_ + +#include +#include +#include +#include + +#include "../Result.hpp" +#include "../dynamic/Table.hpp" +#include "../dynamic/Update.hpp" +#include "get_schema.hpp" +#include "get_tablename.hpp" +#include "to_condition.hpp" +#include "to_sets.hpp" + +namespace sqlgen::transpilation { + +template + requires std::is_class_v> && + std::is_aggregate_v> +dynamic::Update to_update(const SetsType& _sets, const WhereType& _where) { + return dynamic::Update{.table = dynamic::Table{.name = get_tablename(), + .schema = get_schema()}, + .sets = to_sets(_sets), + .where = to_condition>(_where)}; +} + +} // namespace sqlgen::transpilation + +#endif diff --git a/include/sqlgen/update.hpp b/include/sqlgen/update.hpp new file mode 100644 index 0000000..5944a27 --- /dev/null +++ b/include/sqlgen/update.hpp @@ -0,0 +1,54 @@ +#ifndef SQLGEN_UPDATE_HPP_ +#define SQLGEN_UPDATE_HPP_ + +#include +#include + +#include "Connection.hpp" +#include "Ref.hpp" +#include "Result.hpp" +#include "transpilation/to_update.hpp" + +namespace sqlgen { + +template +Result update_impl(const Ref& _conn, const SetsType& _sets, + const WhereType& _where) { + const auto query = + transpilation::to_update(_sets, _where); + return _conn->execute(_conn->to_sql(query)); +} + +template +Result update_impl(const Result>& _res, + const SetsType& _sets, const WhereType& _where) { + return _res.and_then([&](const auto& _conn) { + return update_impl(_conn, _sets, _where); + }); +} + +template +struct Update { + Result operator()(const auto& _conn) const noexcept { + try { + return update_impl(_conn, sets_, where_); + } catch (std::exception& e) { + return error(e.what()); + } + } + + SetsType sets_; + + WhereType where_; +}; + +template +inline auto update(const SetsType&... _sets) { + static_assert(sizeof...(_sets) > 0, "You must update at least one column."); + using TupleType = rfl::Tuple...>; + return Update{.sets_ = TupleType(_sets...)}; +}; + +} // namespace sqlgen + +#endif diff --git a/include/sqlgen/where.hpp b/include/sqlgen/where.hpp index 0cf36d0..1c64a67 100644 --- a/include/sqlgen/where.hpp +++ b/include/sqlgen/where.hpp @@ -8,6 +8,7 @@ #include "read.hpp" #include "transpilation/Limit.hpp" #include "transpilation/value_t.hpp" +#include "update.hpp" namespace sqlgen { @@ -40,6 +41,16 @@ auto operator|(const Read& _r, .where_ = _where.condition}; } +template +auto operator|(const Update& _u, + const Where& _where) { + static_assert(std::is_same_v, + "You cannot call where(...) twice (but you can apply more " + "than one condition by combining them with && or ||)."); + return Update{.sets_ = _u.sets_, + .where_ = _where.condition}; +} + template inline auto where(const ConditionType& _cond) { return Where>{.condition = _cond}; diff --git a/src/sqlgen/postgres/to_sql.cpp b/src/sqlgen/postgres/to_sql.cpp index ecbab73..bf6a74b 100644 --- a/src/sqlgen/postgres/to_sql.cpp +++ b/src/sqlgen/postgres/to_sql.cpp @@ -37,6 +37,8 @@ std::string select_from_to_sql(const dynamic::SelectFrom& _stmt) noexcept; std::string type_to_sql(const dynamic::Type& _type) noexcept; +std::string update_to_sql(const dynamic::Update& _stmt) noexcept; + // ---------------------------------------------------------------------------- inline std::string get_name(const dynamic::Column& _col) { return _col.name; } @@ -268,6 +270,7 @@ std::string select_from_to_sql(const dynamic::SelectFrom& _stmt) noexcept { std::string to_sql_impl(const dynamic::Statement& _stmt) noexcept { return _stmt.visit([&](const auto& _s) -> std::string { using S = std::remove_cvref_t; + if constexpr (std::is_same_v) { return create_table_to_sql(_s); @@ -283,6 +286,9 @@ std::string to_sql_impl(const dynamic::Statement& _stmt) noexcept { } else if constexpr (std::is_same_v) { return select_from_to_sql(_s); + } else if constexpr (std::is_same_v) { + return update_to_sql(_s); + } else { static_assert(rfl::always_false_v, "Unsupported type."); } @@ -325,4 +331,35 @@ std::string type_to_sql(const dynamic::Type& _type) noexcept { }); } +std::string update_to_sql(const dynamic::Update& _stmt) noexcept { + using namespace std::ranges::views; + + const auto to_str = [](const auto& _set) -> std::string { + return wrap_in_quotes(_set.col.name) + " = " + + column_or_value_to_sql(_set.to); + }; + + std::stringstream stream; + + stream << "UPDATE "; + + if (_stmt.table.schema) { + stream << wrap_in_quotes(*_stmt.table.schema) << "."; + } + stream << wrap_in_quotes(_stmt.table.name); + + stream << " SET "; + + stream << internal::strings::join( + ", ", internal::collect::vector(_stmt.sets | transform(to_str))); + + if (_stmt.where) { + stream << " WHERE " << condition_to_sql(*_stmt.where); + } + + stream << ";"; + + return stream.str(); +} + } // namespace sqlgen::postgres diff --git a/src/sqlgen/sqlite/to_sql.cpp b/src/sqlgen/sqlite/to_sql.cpp index cc4eed6..aaae05d 100644 --- a/src/sqlgen/sqlite/to_sql.cpp +++ b/src/sqlgen/sqlite/to_sql.cpp @@ -32,6 +32,10 @@ std::string select_from_to_sql(const dynamic::SelectFrom& _stmt) noexcept; std::string type_to_sql(const dynamic::Type& _type) noexcept; +std::string update_to_sql(const dynamic::Update& _stmt) noexcept; + +// ---------------------------------------------------------------------------- + std::string column_or_value_to_sql( const dynamic::ColumnOrValue& _col) noexcept { const auto handle_value = [](const auto& _v) -> std::string { @@ -270,6 +274,9 @@ std::string to_sql_impl(const dynamic::Statement& _stmt) noexcept { } else if constexpr (std::is_same_v) { return select_from_to_sql(_s); + } else if constexpr (std::is_same_v) { + return update_to_sql(_s); + } else { static_assert(rfl::always_false_v, "Unsupported type."); } @@ -304,4 +311,34 @@ std::string type_to_sql(const dynamic::Type& _type) noexcept { }); } +std::string update_to_sql(const dynamic::Update& _stmt) noexcept { + using namespace std::ranges::views; + + const auto to_str = [](const auto& _set) -> std::string { + return "\"" + _set.col.name + "\" = " + column_or_value_to_sql(_set.to); + }; + + std::stringstream stream; + + stream << "UPDATE "; + + if (_stmt.table.schema) { + stream << "\"" << *_stmt.table.schema << "\"."; + } + stream << "\"" << _stmt.table.name << "\""; + + stream << " SET "; + + stream << internal::strings::join( + ", ", internal::collect::vector(_stmt.sets | transform(to_str))); + + if (_stmt.where) { + stream << " WHERE " << condition_to_sql(*_stmt.where); + } + + stream << ";"; + + return stream.str(); +} + } // namespace sqlgen::sqlite diff --git a/tests/postgres/test_update_dry.cpp b/tests/postgres/test_update_dry.cpp new file mode 100644 index 0000000..1266282 --- /dev/null +++ b/tests/postgres/test_update_dry.cpp @@ -0,0 +1,27 @@ +#include + +#include +#include + +namespace test_update_dry { + +struct TestTable { + std::string field1; + int32_t field2; + sqlgen::PrimaryKey id; + std::optional nullable; +}; + +TEST(postgres, test_update_dry) { + using namespace sqlgen; + + const auto query = + update("field1"_c.set("Hello"), "nullable"_c.set("field1"_c)) | + where("field2"_c > 0); + + const auto expected = + R"(UPDATE "TestTable" SET "field1" = 'Hello', "nullable" = "field1" WHERE "field2" > 0;)"; + + EXPECT_EQ(sqlgen::postgres::to_sql(query), expected); +} +} // namespace test_update_dry diff --git a/tests/sqlite/test_update.cpp b/tests/sqlite/test_update.cpp new file mode 100644 index 0000000..b41e360 --- /dev/null +++ b/tests/sqlite/test_update.cpp @@ -0,0 +1,51 @@ +#include + +#include +#include +#include +#include +#include + +namespace test_update { + +struct Person { + sqlgen::PrimaryKey id; + std::string first_name; + std::string last_name; + int age; +}; + +TEST(sqlite, test_update) { + const auto people1 = std::vector( + {Person{ + .id = 0, .first_name = "Homer", .last_name = "Simpson", .age = 45}, + Person{.id = 1, .first_name = "Bart", .last_name = "Simpson", .age = 10}, + Person{.id = 2, .first_name = "Lisa", .last_name = "Simpson", .age = 8}, + Person{ + .id = 3, .first_name = "Maggie", .last_name = "Simpson", .age = 0}, + Person{ + .id = 4, .first_name = "Hugo", .last_name = "Simpson", .age = 10}}); + + const auto conn = sqlgen::sqlite::connect(); + + sqlgen::write(conn, people1); + + using namespace sqlgen; + + const auto query = + update("first_name"_c.set("last_name"_c), "age"_c.set(100)) | + where("first_name"_c == "Hugo"); + + query(conn).value(); + + const auto people2 = sqlgen::read>(conn).value(); + + std::cout << rfl::json::write(people2) << std::endl; + + const std::string expected = + R"([{"id":0,"first_name":"Homer","last_name":"Simpson","age":45},{"id":1,"first_name":"Bart","last_name":"Simpson","age":10},{"id":2,"first_name":"Lisa","last_name":"Simpson","age":8},{"id":3,"first_name":"Maggie","last_name":"Simpson","age":0},{"id":4,"first_name":"Simpson","last_name":"Simpson","age":100}])"; + + EXPECT_EQ(rfl::json::write(people2), expected); +} + +} // namespace test_update