Improve get_last_insert_id() for SQLite

Return the ID used for the given table instead of the last ID inserted
in any database table, which could be very different.

This relies on using internal sqlite_sequence table that is used by
AUTOINCREMENT.

Co-authored-by: Vadim Zeitlin <vz-soci@zeitlins.org>

Closes #971.
This commit is contained in:
Cosmin Cremarenco
2022-06-15 22:05:28 +02:00
committed by Vadim Zeitlin
parent ae3ff29216
commit 9507d084cd
3 changed files with 248 additions and 7 deletions

View File

@@ -329,6 +329,10 @@ struct sqlite3_session_backend : details::session_backend
}
sqlite_api::sqlite3 *conn_;
// This flag is set to true if the internal sqlite_sequence table exists in
// the database.
bool sequence_table_exists_;
};
struct sqlite3_backend_factory : backend_factory

View File

@@ -10,6 +10,8 @@
#include "soci/connection-parameters.h"
#include "soci-cstrtoi.h"
#include <sstream>
#include <string>
@@ -24,11 +26,14 @@ using namespace sqlite_api;
namespace // anonymous
{
// helper function for hardcoded queries
void execude_hardcoded(sqlite_api::sqlite3* conn, char const* const query, char const* const errMsg)
// helper function for hardcoded queries: this is a simple wrapper for
// sqlite3_exec() which throws an exception on error.
void execude_hardcoded(sqlite_api::sqlite3* conn, char const* const query, char const* const errMsg,
int (*callback)(void*, int, char**, char**) = NULL,
void* callback_arg = NULL)
{
char *zErrMsg = 0;
int const res = sqlite3_exec(conn, query, 0, 0, &zErrMsg);
int const res = sqlite3_exec(conn, query, callback, callback_arg, &zErrMsg);
if (res != SQLITE_OK)
{
std::ostringstream ss;
@@ -52,9 +57,31 @@ void check_sqlite_err(sqlite_api::sqlite3* conn, int res, char const* const errM
} // namespace anonymous
static int sequence_table_exists_callback(void* ctxt, int result_columns, char**, char**)
{
bool* const flag = static_cast<bool*>(ctxt);
*flag = result_columns > 0;
return 0;
}
static bool check_if_sequence_table_exists(sqlite_api::sqlite3* conn)
{
bool sequence_table_exists = false;
execude_hardcoded
(
conn,
"select name from sqlite_master where type='table' and name='sqlite_sequence'",
"Failed checking if the sqlite_sequence table exists",
&sequence_table_exists_callback,
&sequence_table_exists
);
return sequence_table_exists;
}
sqlite3_session_backend::sqlite3_session_backend(
connection_parameters const & parameters)
: sequence_table_exists_(false)
{
int timeout = 0;
int connection_flags = SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE;
@@ -139,7 +166,6 @@ sqlite3_session_backend::sqlite3_session_backend(
res = sqlite3_busy_timeout(conn_, timeout * 1000);
check_sqlite_err(conn_, res, "Failed to set busy timeout for connection. ");
}
sqlite3_session_backend::~sqlite3_session_backend()
@@ -162,10 +188,80 @@ void sqlite3_session_backend::rollback()
execude_hardcoded(conn_, "ROLLBACK", "Cannot rollback transaction.");
}
bool sqlite3_session_backend::get_last_insert_id(
session & /* s */, std::string const & /* table */, long long & value)
// Argument passed to store_single_value_callback(), which is used to retrieve
// a single numeric value from a hardcoded query.
struct single_value_callback_ctx
{
value = static_cast<long long>(sqlite3_last_insert_rowid(conn_));
single_value_callback_ctx() : valid_(false) {}
long long value_;
bool valid_;
};
static int store_single_value_callback(void* ctx, int result_columns, char** values, char**)
{
single_value_callback_ctx* arg = static_cast<single_value_callback_ctx*>(ctx);
if (result_columns == 1 && values[0])
{
if (cstring_to_integer(arg->value_, values[0]))
arg->valid_ = true;
}
return 0;
}
static std::string sanitize_table_name(std::string const& table)
{
std::string ret;
ret.reserve(table.length());
for (std::string::size_type pos = 0; pos < table.size(); ++pos)
{
if (isspace(table[pos]))
throw sqlite3_soci_error("Table name must not contain whitespace", 0);
const char c = table[pos];
ret += c;
if (c == '\'')
ret += '\'';
else if (c == '\"')
ret += '\"';
}
return ret;
}
bool sqlite3_session_backend::get_last_insert_id(
session &, std::string const & table, long long & value)
{
single_value_callback_ctx ctx;
if (sequence_table_exists_ || check_if_sequence_table_exists(conn_))
{
// Once the sqlite_sequence table is created (because of a column marked AUTOINCREMENT)
// it can never be dropped, so don't search for it again.
sequence_table_exists_ = true;
std::string const query =
"select seq from sqlite_sequence where name ='" + sanitize_table_name(table) + "'";
execude_hardcoded(conn_, query.c_str(), "Unable to get value in sqlite_sequence",
&store_single_value_callback, &ctx);
// The value will not be filled if the callback was never called.
// It may mean either that nothing was inserted yet into the table
// that has an AUTOINCREMENT column, or that the table does not have an AUTOINCREMENT
// column.
if (ctx.valid_)
{
value = ctx.value_;
return true;
}
}
// Fall-back just get the maximum rowid of what was already inserted in the
// table. This has the disadvantage that if rows were deleted, then ids may be re-used.
// But, if one cares about that, AUTOINCREMENT should be used anyway.
std::string const max_rowid_query = "select max(rowid) from \"" + sanitize_table_name(table) + "\"";
execude_hardcoded(conn_, max_rowid_query.c_str(), "Unable to get max rowid",
&store_single_value_callback, &ctx);
value = ctx.valid_ ? ctx.value_ : 0;
return true;
}

View File

@@ -114,6 +114,147 @@ TEST_CASE("SQLite foreign keys are enabled by foreign_keys option", "[sqlite][fo
"\"delete from parent where id = 1\".");
}
class SetupAutoIncrementTable
{
public:
SetupAutoIncrementTable(soci::session& sql)
: m_sql(sql)
{
m_sql <<
"create table t("
" id integer primary key autoincrement,"
" name text"
")";
}
~SetupAutoIncrementTable()
{
m_sql << "drop table t";
}
private:
SetupAutoIncrementTable(const SetupAutoIncrementTable&);
SetupAutoIncrementTable& operator=(const SetupAutoIncrementTable&);
soci::session& m_sql;
};
TEST_CASE("SQLite get_last_insert_id works with AUTOINCREMENT",
"[sqlite][rowid]")
{
soci::session sql(backEnd, connectString);
SetupAutoIncrementTable createTable(sql);
sql << "insert into t(name) values('x')";
sql << "insert into t(name) values('y')";
long long val;
sql.get_last_insert_id("t", val);
CHECK(val == 2);
}
TEST_CASE("SQLite get_last_insert_id with AUTOINCREMENT does not reuse IDs when rows deleted",
"[sqlite][rowid]")
{
soci::session sql(backEnd, connectString);
SetupAutoIncrementTable createTable(sql);
sql << "insert into t(name) values('x')";
sql << "insert into t(name) values('y')";
sql << "delete from t where id = 2";
long long val;
sql.get_last_insert_id("t", val);
CHECK(val == 2);
}
class SetupNoAutoIncrementTable
{
public:
SetupNoAutoIncrementTable(soci::session& sql)
: m_sql(sql)
{
m_sql <<
"create table t("
" id integer primary key,"
" name text"
")";
}
~SetupNoAutoIncrementTable()
{
m_sql << "drop table t";
}
private:
SetupNoAutoIncrementTable(const SetupNoAutoIncrementTable&);
SetupNoAutoIncrementTable& operator=(const SetupNoAutoIncrementTable&);
soci::session& m_sql;
};
TEST_CASE("SQLite get_last_insert_id without AUTOINCREMENT reuses IDs when rows deleted",
"[sqlite][rowid]")
{
soci::session sql(backEnd, connectString);
SetupNoAutoIncrementTable createTable(sql);
sql << "insert into t(name) values('x')";
sql << "insert into t(name) values('y')";
sql << "delete from t where id = 2";
long long val;
sql.get_last_insert_id("t", val);
CHECK(val == 1);
}
TEST_CASE("SQLite get_last_insert_id throws if table not found",
"[sqlite][rowid]")
{
soci::session sql(backEnd, connectString);
long long val;
CHECK_THROWS(sql.get_last_insert_id("notexisting", val));
}
class SetupTableWithDoubleQuoteInName
{
public:
SetupTableWithDoubleQuoteInName(soci::session& sql)
: m_sql(sql)
{
m_sql <<
"create table \"t\"\"fff\"("
" id integer primary key,"
" name text"
")";
}
~SetupTableWithDoubleQuoteInName()
{
m_sql << "drop table \"t\"\"fff\"";
}
private:
SetupTableWithDoubleQuoteInName(const SetupTableWithDoubleQuoteInName&);
SetupTableWithDoubleQuoteInName& operator=(const SetupTableWithDoubleQuoteInName&);
soci::session& m_sql;
};
TEST_CASE("SQLite get_last_insert_id escapes table name",
"[sqlite][rowid]")
{
soci::session sql(backEnd, connectString);
SetupTableWithDoubleQuoteInName table(sql);
long long val;
sql.get_last_insert_id("t\"fff", val);
CHECK(val == 0);
}
// BLOB test
struct blob_table_creator : public table_creator_base
{