From e8faeb2510cab5e1696082899211ead96515921e Mon Sep 17 00:00:00 2001 From: bcumming Date: Fri, 17 May 2024 07:33:26 +0200 Subject: [PATCH] finish refactoring of database lookup into slurm-free backend; fix some corner cases in image selection --- TODO.md | 29 +++++ meson.build | 3 +- src/database/database.cpp | 125 +++++++++++++++++++++ src/database/database.hpp | 24 ++++ src/{sqlite => database}/sqlite.cpp | 31 ++---- src/{sqlite => database}/sqlite.hpp | 28 +++-- src/parse_args.cpp | 165 ++-------------------------- src/util/strings.cpp | 6 +- src/util/strings.hpp | 28 +++++ 9 files changed, 249 insertions(+), 190 deletions(-) create mode 100644 TODO.md create mode 100644 src/database/database.cpp create mode 100644 src/database/database.hpp rename src/{sqlite => database}/sqlite.cpp (74%) rename src/{sqlite => database}/sqlite.hpp (62%) diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..6571303 --- /dev/null +++ b/TODO.md @@ -0,0 +1,29 @@ +# overview + +## current + +`mount` +- `is_valid_mount_point` +- `do_mount` + +`util/helper` +- split +- is_sha +- is_file + +`sqlite`: +The sqlite/sqlite.hpp header exposes an api around sqlite + +`parse_args`: +- the `find_repo_image` function that returs the path of the image should be moved to the database + +## proposed + + +`sqlite`: +- refactor so that it exposes an API around the database - we shouldn't leak underlying database implementation to the calling scope +- rename `db` + +## sqlite + + diff --git a/meson.build b/meson.build index 6da7213..542bc61 100644 --- a/meson.build +++ b/meson.build @@ -19,7 +19,8 @@ shared_module('slurm-uenv-mount', sources: ['src/plugin.cpp', 'src/mount.cpp', 'src/parse_args.cpp', - 'src/sqlite/sqlite.cpp', + 'src/database/database.cpp', + 'src/database/sqlite.cpp', 'src/util/strings.cpp', 'src/util/filesystem.cpp'], dependencies: [libmount_dep, sqlite3_dep], diff --git a/src/database/database.cpp b/src/database/database.cpp new file mode 100644 index 0000000..bda7525 --- /dev/null +++ b/src/database/database.cpp @@ -0,0 +1,125 @@ +#include +#include +#include +#include + +#include "database.hpp" +#include "sqlite.hpp" + +#include "../util/filesystem.hpp" + +namespace db { + +struct uenv_result { + std::string name; + std::string version; + std::string tag; + std::string sha; + + uenv_result() = delete; + + uenv_result(std::string name, std::string version, std::string tag, + std::string sha); + uenv_result(SQLiteStatement &result) + : uenv_result(result.getColumn(result.getColumnIndex("name")), + result.getColumn(result.getColumnIndex("sha256")), + result.getColumn(result.getColumnIndex("tag")), + result.getColumn(result.getColumnIndex("version"))) {} + friend bool operator<(const uenv_result &, const uenv_result &); +}; + +util::expected +find_image(const uenv_desc &desc, const std::string &repo_path, + std::optional uenv_arch) { + try { + std::string dbpath = repo_path + "/index.db"; + // check if dbpath exists. + if (!util::is_file(dbpath)) { + return util::unexpected("Can't open uenv repo. " + dbpath + + " is not a file."); + } + SQLiteDB db(dbpath, sqlite_open::readonly); + + // get all results + std::vector results; + if (desc.sha) { + if (desc.sha.value().size() < 64) { + SQLiteStatement query(db, "SELECT * FROM records WHERE id = :id"); + query.bind(":id", desc.sha.value()); + while (query.execute()) { + results.emplace_back(query); + } + } else { + SQLiteStatement query(db, "SELECT * FROM records WHERE sha256 = :sha"); + query.bind(":sha", desc.sha.value()); + while (query.execute()) { + results.emplace_back(query); + } + } + } else { + std::string query_str = "SELECT * FROM records WHERE "; + std::vector filter; + if (uenv_arch) { + filter.push_back("uarch"); + } + if (desc.name) { + filter.push_back("name"); + } + if (desc.version) { + filter.push_back("version"); + } + if (desc.tag) { + filter.push_back("tag"); + } + for (size_t i = 0; i < filter.size(); ++i) { + if (i > 0) { + query_str += " AND "; + } + query_str += filter[i] + " = " + ":" + filter[i]; + } + SQLiteStatement query(db, query_str); + if (uenv_arch) { + query.bind(":uarch", uenv_arch.value()); + } + if (desc.name) { + query.bind(":name", desc.name.value()); + } + if (desc.version) { + query.bind(":version", desc.version.value()); + } + if (desc.tag) { + query.bind(":tag", desc.tag.value()); + } + while (query.execute()) { + results.emplace_back(query); + } + } + + // sort the results by sha, and build a list of unique sha + std::sort(results.begin(), results.end(), + [](auto &lhs, auto &rhs) { return lhs.sha < rhs.sha; }); + std::set shas; + for (const auto &r : results) { + shas.insert(r.sha); + } + if (shas.size() > 1) { + std::stringstream ss; + ss << "More than one uenv matches.\n"; + for (auto &d : results) { + ss << d.name << "/" << d.version << ":" << d.tag << "\t" << d.sha + << "\n"; + } + return util::unexpected(ss.str()); + } + if (results.empty()) { + return util::unexpected("No uenv matches the request. Run `uenv image " + "ls` to list available images."); + } + return repo_path + "/images/" + results[0].sha + "/store.squashfs"; + } catch (const SQLiteError &e) { + return util::unexpected(std::string("internal database error: ") + + e.what()); + } +} + +} // namespace db diff --git a/src/database/database.hpp b/src/database/database.hpp new file mode 100644 index 0000000..be21f33 --- /dev/null +++ b/src/database/database.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include +#include +#include +#include + +#include "../util/expected.hpp" + +namespace db { + +struct uenv_desc { + using entry_t = std::optional; + entry_t name; + entry_t version; + entry_t tag; + entry_t sha; +}; + +util::expected +find_image(const uenv_desc &desc, const std::string &repo_path, + std::optional uenv_arch); + +} // namespace db diff --git a/src/sqlite/sqlite.cpp b/src/database/sqlite.cpp similarity index 74% rename from src/sqlite/sqlite.cpp rename to src/database/sqlite.cpp index 538bdce..27d6915 100644 --- a/src/sqlite/sqlite.cpp +++ b/src/database/sqlite.cpp @@ -1,20 +1,11 @@ #include "sqlite.hpp" -#include + #include #include std::map sqlite_oflag = { {sqlite_open::readonly, SQLITE_OPEN_READONLY}}; -class SQLiteError : public std::exception { -public: - SQLiteError(const std::string &msg) : msg(msg) {} - const char *what() const noexcept override { return msg.c_str(); } - -private: - std::string msg; -}; - SQLiteDB::SQLiteDB(const std::string &fname, sqlite_open flag) { int rc = sqlite3_open_v2(fname.c_str(), &this->db, sqlite_oflag.at(flag), NULL); @@ -26,10 +17,10 @@ SQLiteDB::SQLiteDB(const std::string &fname, sqlite_open flag) { SQLiteDB::~SQLiteDB() { sqlite3_close(this->db); } /// SQLiteColumn -SQLiteColumn::SQLiteColumn(SQLiteStatement &statement, int index) +SQLiteColumn::SQLiteColumn(const SQLiteStatement &statement, int index) : statement(statement), index(index) {} -std::string SQLiteColumn::getColumnName() const { +std::string SQLiteColumn::name() const { return sqlite3_column_name(this->statement.stmt, this->index); } @@ -58,18 +49,20 @@ SQLiteStatement::SQLiteStatement(SQLiteDB &db, const std::string &query) column_count = sqlite3_column_count(this->stmt); } -int SQLiteStatement::getColumnIndex(const std::string &name) { +int SQLiteStatement::getColumnIndex(const std::string &name) const { for (int i = 0; i < this->column_count; ++i) { - if (this->getColumn(i).getColumnName() == name) + if (this->getColumn(i).name() == name) return i; } return -1; } -void SQLiteStatement::bind(const std::string& name, const std::string& value) { +void SQLiteStatement::bind(const std::string &name, const std::string &value) { int i = sqlite3_bind_parameter_index(this->stmt, name.c_str()); - if(sqlite3_bind_text(this->stmt, i, value.c_str(), -1, SQLITE_STATIC) != SQLITE_OK) { - throw SQLiteError(std::string("Failed to bind parameter: ") + sqlite3_errmsg(this->db.get())); + if (sqlite3_bind_text(this->stmt, i, value.c_str(), -1, SQLITE_STATIC) != + SQLITE_OK) { + throw SQLiteError(std::string("Failed to bind parameter: ") + + sqlite3_errmsg(this->db.get())); } } @@ -79,7 +72,7 @@ void SQLiteStatement::checkIndex(int i) const { } } -std::string SQLiteStatement::getColumnType(int i) { +std::string SQLiteStatement::getColumnType(int i) const { checkIndex(i); const char *result = sqlite3_column_decltype(this->stmt, i); if (!result) { @@ -89,7 +82,7 @@ std::string SQLiteStatement::getColumnType(int i) { } } -SQLiteColumn SQLiteStatement::getColumn(int i) { +SQLiteColumn SQLiteStatement::getColumn(int i) const { checkIndex(i); if (this->rc != SQLITE_ROW) { throw SQLiteError("Statement invalid"); diff --git a/src/sqlite/sqlite.hpp b/src/database/sqlite.hpp similarity index 62% rename from src/sqlite/sqlite.hpp rename to src/database/sqlite.hpp index ccb8b09..f57cfbb 100644 --- a/src/sqlite/sqlite.hpp +++ b/src/database/sqlite.hpp @@ -1,3 +1,6 @@ +#pragma once + +#include #include struct sqlite3_stmt; @@ -5,7 +8,14 @@ struct sqlite3; enum class sqlite_open : int { readonly }; -class SQLiteError; +class SQLiteError : public std::exception { +public: + SQLiteError(const std::string &msg) : msg(msg) {} + const char *what() const noexcept override { return msg.c_str(); } + +private: + std::string msg; +}; class SQLiteStatement; @@ -32,10 +42,10 @@ class SQLiteStatement { SQLiteStatement(SQLiteDB &db, const std::string &query); SQLiteStatement(const SQLiteStatement &) = delete; SQLiteStatement operator=(const SQLiteStatement &) = delete; - std::string getColumnType(int i); - SQLiteColumn getColumn(int i); - int getColumnIndex(const std::string &name); - void bind(const std::string& name, const std::string& value); + std::string getColumnType(int i) const; + SQLiteColumn getColumn(int i) const; + int getColumnIndex(const std::string &name) const; + void bind(const std::string &name, const std::string &value); bool execute(); virtual ~SQLiteStatement(); @@ -53,12 +63,12 @@ class SQLiteStatement { class SQLiteColumn { public: - SQLiteColumn(SQLiteStatement &statement, int index); - std::string getColumnName() const; + SQLiteColumn(const SQLiteStatement &statement, int index); + std::string name() const; operator int() const; operator std::string() const; private: - SQLiteStatement &statement; - int index; + const SQLiteStatement &statement; + const int index; }; diff --git a/src/parse_args.cpp b/src/parse_args.cpp index deefafc..587bc5b 100644 --- a/src/parse_args.cpp +++ b/src/parse_args.cpp @@ -6,8 +6,8 @@ #include #include "config.hpp" +#include "database/database.hpp" #include "parse_args.hpp" -#include "sqlite/sqlite.hpp" #include "util/expected.hpp" #include "util/filesystem.hpp" #include "util/strings.hpp" @@ -18,14 +18,6 @@ namespace impl { -struct uenv_desc { - using entry_t = std::optional; - entry_t name; - entry_t version; - entry_t tag; - entry_t sha; -}; - const std::regex default_pattern("(" LINUX_ABS_FPATH ")" "(:" LINUX_ABS_FPATH ")?", std::regex::ECMAScript); @@ -38,52 +30,6 @@ const std::regex repo_pattern("(" JFROG_IMAGE ")" "(:" LINUX_ABS_FPATH ")?", std::regex::ECMAScript); -// split a string on a character delimiter -// -// if drop_empty==false (default) -// -// "" -> [""] -// "," -> ["", ""] -// ",," -> ["", "", ""] -// ",a" -> ["", "a"] -// "a," -> ["a", ""] -// "a" -> ["a"] -// "a,b" -> ["a", "b"] -// "a,b,c" -> ["a", "b", "c"] -// "a,b,,c" -> ["a", "b", "", "c"] -// -// if drop_empty==true -// -// "" -> [] -// "," -> [] -// ",," -> [] -// ",a" -> ["a"] -// "a," -> ["a"] -// "a" -> ["a"] -// "a,b" -> ["a", "b"] -// "a,b,c" -> ["a", "b", "c"] -// "a,b,,c" -> ["a", "b", "c"] -std::vector split(const std::string &s, const char delim, - const bool drop_empty = false) { - std::vector results; - - auto pos = s.cbegin(); - auto end = s.cend(); - auto next = std::find(pos, end, delim); - while (next != end) { - if (!drop_empty || pos != next) { - results.emplace_back(pos, next); - } - pos = next + 1; - next = std::find(pos, end, delim); - } - if (!drop_empty || pos != next) { - results.emplace_back(pos, next); - } - - return results; -} - /* return dictionary{"name", "version", "tag", "sha" } from a uenv description string @@ -94,10 +40,10 @@ std::vector split(const std::string &s, const char delim, prgenv_gnu:v2 ->("prgenv_gnu", None, "v2", None) 3313739553fe6553 ->(None, None, None, "3313739553fe6553") */ -uenv_desc parse_uenv_string(const std::string &entry) { - uenv_desc res; +db::uenv_desc parse_uenv_string(const std::string &entry) { + db::uenv_desc res; - if (is_sha(entry)) { + if (util::is_sha(entry)) { res.sha = entry; return res; } @@ -122,106 +68,10 @@ uenv_desc parse_uenv_string(const std::string &entry) { return res; } -uenv_desc to_desc(SQLiteStatement &stmt) { - uenv_desc desc; - desc.name = stmt.getColumn(stmt.getColumnIndex("name")); - desc.sha = stmt.getColumn(stmt.getColumnIndex("sha256")); - desc.tag = stmt.getColumn(stmt.getColumnIndex("tag")); - desc.version = stmt.getColumn(stmt.getColumnIndex("version")); - return desc; -} - -struct cmp { - bool operator()(const uenv_desc &d1, const uenv_desc &d2) const { - return d1.sha < d2.sha; - } -}; - -util::expected -find_repo_image(const uenv_desc &desc, const std::string &repo_path, - std::optional uenv_arch) { - std::string dbpath = repo_path + "/index.db"; - // check if dbpath exists. - if (!util::is_file(dbpath)) { - return util::unexpected("Can't open uenv repo. " + dbpath + - " is not a file."); - } - SQLiteDB db(dbpath, sqlite_open::readonly); - - // get all results - std::set shas; - if (desc.sha) { - if (desc.sha.value().size() < 64) { - SQLiteStatement query(db, "SELECT * FROM records WHERE id = :id"); - query.bind(":id", desc.sha.value()); - while (query.execute()) { - shas.insert(to_desc(query)); - } - } else { - SQLiteStatement query(db, "SELECT * FROM records WHERE sha256 = :sha"); - query.bind(":sha", desc.sha.value()); - while (query.execute()) { - shas.insert(to_desc(query)); - } - } - } else { - std::string query_str = "SELECT * FROM records WHERE "; - std::vector filter; - if (uenv_arch) { - filter.push_back("uarch"); - } - if (desc.name) { - filter.push_back("name"); - } - if (desc.version) { - filter.push_back("version"); - } - if (desc.tag) { - filter.push_back("tag"); - } - for (size_t i = 0; i < filter.size(); ++i) { - if (i > 0) { - query_str += " AND "; - } - query_str += filter[i] + " = " + ":" + filter[i]; - } - SQLiteStatement query(db, query_str); - if (uenv_arch.has_value()) { - query.bind(":uarch", uenv_arch.value()); - } - if (desc.name) { - query.bind(":name", desc.name.value()); - } - if (desc.version) { - query.bind(":version", desc.version.value()); - } - if (desc.tag) { - query.bind(":tag", desc.tag.value()); - } - while (query.execute()) { - shas.insert(to_desc(query)); - } - } - if (shas.size() > 1) { - std::stringstream ss; - ss << "More than one uenv matches.\n"; - for (auto &d : shas) { - ss << d.name.value() << "/" << d.version.value() << ":" << d.tag.value() - << "\t" << d.sha.value() << "\n"; - } - return util::unexpected(ss.str()); - } - if (shas.empty()) { - return util::unexpected( - "No images found. Run `uenv image ls` to list available images."); - } - return repo_path + "/images/" + shas.begin()->sha.value() + "/store.squashfs"; -} - util::expected, std::string> parse_arg(const std::string &arg, std::optional uenv_repo_path, std::optional uenv_arch) { - std::vector arguments = split(arg, ',', true); + std::vector arguments = util::split(arg, ',', true); if (arguments.empty()) { return util::unexpected("No mountpoints given."); @@ -242,14 +92,13 @@ parse_arg(const std::string &arg, std::optional uenv_repo_path, mount_entries.emplace_back(mount_entry{image_path, mount_point}); } else if (std::smatch match; std::regex_match(entry, match, repo_pattern)) { - uenv_desc desc = parse_uenv_string(entry); + auto desc = parse_uenv_string(entry); if (!uenv_repo_path) { return util::unexpected("Attempting to open from uenv repository. But " "either $" UENV_REPO_PATH_VARNAME " or $SCRATCH is not set."); } - auto image_path = - find_repo_image(desc, uenv_repo_path.value(), uenv_arch); + auto image_path = db::find_image(desc, uenv_repo_path.value(), uenv_arch); if (!image_path.has_value()) { return util::unexpected(image_path.error()); } diff --git a/src/util/strings.cpp b/src/util/strings.cpp index af1987a..88e2c03 100644 --- a/src/util/strings.cpp +++ b/src/util/strings.cpp @@ -4,9 +4,7 @@ #include "strings.hpp" -extern "C" { -#include -} +namespace util { std::vector split(const std::string &s, char delim) { std::vector elems; @@ -45,3 +43,5 @@ bool is_sha(const std::string &str) { } return false; } + +} // namespace util diff --git a/src/util/strings.hpp b/src/util/strings.hpp index 8a79d89..85100bd 100644 --- a/src/util/strings.hpp +++ b/src/util/strings.hpp @@ -1,7 +1,35 @@ #include #include +namespace util { + +// split a string on a character delimiter +// +// if drop_empty==false (default) +// +// "" -> [""] +// "," -> ["", ""] +// ",," -> ["", "", ""] +// ",a" -> ["", "a"] +// "a," -> ["a", ""] +// "a" -> ["a"] +// "a,b" -> ["a", "b"] +// "a,b,c" -> ["a", "b", "c"] +// "a,b,,c" -> ["a", "b", "", "c"] +// +// if drop_empty==true +// +// "" -> [] +// "," -> [] +// ",," -> [] +// ",a" -> ["a"] +// "a," -> ["a"] +// "a" -> ["a"] +// "a,b" -> ["a", "b"] +// "a,b,c" -> ["a", "b", "c"] +// "a,b,,c" -> ["a", "b", "c"] std::vector split(const std::string &s, const char delim, const bool drop_empty = false); bool is_sha(const std::string &str); +} // namespace util