Skip to content

Commit

Permalink
finish refactoring of database lookup into slurm-free backend; fix so…
Browse files Browse the repository at this point in the history
…me corner cases in image selection
  • Loading branch information
bcumming committed May 17, 2024
1 parent e6fb5c5 commit e8faeb2
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 190 deletions.
29 changes: 29 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# overview

## current

`mount`
- `is_valid_mount_point`
- `do_mount`

`util/helper`
- split
- is_sha
- is_file

`sqlite`:
The sqlite/sqlite.hpp header exposes an api around sqlite

`parse_args`:
- the `find_repo_image` function that returs the path of the image should be moved to the database

## proposed


`sqlite`:
- refactor so that it exposes an API around the database - we shouldn't leak underlying database implementation to the calling scope
- rename `db`

## sqlite


3 changes: 2 additions & 1 deletion meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ shared_module('slurm-uenv-mount',
sources: ['src/plugin.cpp',
'src/mount.cpp',
'src/parse_args.cpp',
'src/sqlite/sqlite.cpp',
'src/database/database.cpp',
'src/database/sqlite.cpp',
'src/util/strings.cpp',
'src/util/filesystem.cpp'],
dependencies: [libmount_dep, sqlite3_dep],
Expand Down
125 changes: 125 additions & 0 deletions src/database/database.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>

#include "database.hpp"
#include "sqlite.hpp"

#include "../util/filesystem.hpp"

namespace db {

struct uenv_result {
std::string name;
std::string version;
std::string tag;
std::string sha;

uenv_result() = delete;

uenv_result(std::string name, std::string version, std::string tag,
std::string sha);
uenv_result(SQLiteStatement &result)
: uenv_result(result.getColumn(result.getColumnIndex("name")),
result.getColumn(result.getColumnIndex("sha256")),
result.getColumn(result.getColumnIndex("tag")),
result.getColumn(result.getColumnIndex("version"))) {}
friend bool operator<(const uenv_result &, const uenv_result &);
};

util::expected<std::string, std::string>
find_image(const uenv_desc &desc, const std::string &repo_path,
std::optional<std::string> uenv_arch) {
try {
std::string dbpath = repo_path + "/index.db";
// check if dbpath exists.
if (!util::is_file(dbpath)) {
return util::unexpected("Can't open uenv repo. " + dbpath +
" is not a file.");
}
SQLiteDB db(dbpath, sqlite_open::readonly);

// get all results
std::vector<uenv_result> results;
if (desc.sha) {
if (desc.sha.value().size() < 64) {
SQLiteStatement query(db, "SELECT * FROM records WHERE id = :id");
query.bind(":id", desc.sha.value());
while (query.execute()) {
results.emplace_back(query);
}
} else {
SQLiteStatement query(db, "SELECT * FROM records WHERE sha256 = :sha");
query.bind(":sha", desc.sha.value());
while (query.execute()) {
results.emplace_back(query);
}
}
} else {
std::string query_str = "SELECT * FROM records WHERE ";
std::vector<std::string> filter;
if (uenv_arch) {
filter.push_back("uarch");
}
if (desc.name) {
filter.push_back("name");
}
if (desc.version) {
filter.push_back("version");
}
if (desc.tag) {
filter.push_back("tag");
}
for (size_t i = 0; i < filter.size(); ++i) {
if (i > 0) {
query_str += " AND ";
}
query_str += filter[i] + " = " + ":" + filter[i];
}
SQLiteStatement query(db, query_str);
if (uenv_arch) {
query.bind(":uarch", uenv_arch.value());
}
if (desc.name) {
query.bind(":name", desc.name.value());
}
if (desc.version) {
query.bind(":version", desc.version.value());
}
if (desc.tag) {
query.bind(":tag", desc.tag.value());
}
while (query.execute()) {
results.emplace_back(query);
}
}

// sort the results by sha, and build a list of unique sha
std::sort(results.begin(), results.end(),
[](auto &lhs, auto &rhs) { return lhs.sha < rhs.sha; });
std::set<std::string> shas;
for (const auto &r : results) {
shas.insert(r.sha);
}
if (shas.size() > 1) {
std::stringstream ss;
ss << "More than one uenv matches.\n";
for (auto &d : results) {
ss << d.name << "/" << d.version << ":" << d.tag << "\t" << d.sha
<< "\n";
}
return util::unexpected(ss.str());
}
if (results.empty()) {
return util::unexpected("No uenv matches the request. Run `uenv image "
"ls` to list available images.");
}
return repo_path + "/images/" + results[0].sha + "/store.squashfs";
} catch (const SQLiteError &e) {
return util::unexpected(std::string("internal database error: ") +
e.what());
}
}

} // namespace db
24 changes: 24 additions & 0 deletions src/database/database.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include <optional>
#include <set>
#include <string>
#include <vector>

#include "../util/expected.hpp"

namespace db {

struct uenv_desc {
using entry_t = std::optional<std::string>;
entry_t name;
entry_t version;
entry_t tag;
entry_t sha;
};

util::expected<std::string, std::string>
find_image(const uenv_desc &desc, const std::string &repo_path,
std::optional<std::string> uenv_arch);

} // namespace db
31 changes: 12 additions & 19 deletions src/sqlite/sqlite.cpp → src/database/sqlite.cpp
Original file line number Diff line number Diff line change
@@ -1,20 +1,11 @@
#include "sqlite.hpp"
#include <exception>

#include <map>
#include <sqlite3.h>

std::map<sqlite_open, int> sqlite_oflag = {
{sqlite_open::readonly, SQLITE_OPEN_READONLY}};

class SQLiteError : public std::exception {
public:
SQLiteError(const std::string &msg) : msg(msg) {}
const char *what() const noexcept override { return msg.c_str(); }

private:
std::string msg;
};

SQLiteDB::SQLiteDB(const std::string &fname, sqlite_open flag) {
int rc =
sqlite3_open_v2(fname.c_str(), &this->db, sqlite_oflag.at(flag), NULL);
Expand All @@ -26,10 +17,10 @@ SQLiteDB::SQLiteDB(const std::string &fname, sqlite_open flag) {
SQLiteDB::~SQLiteDB() { sqlite3_close(this->db); }

/// SQLiteColumn
SQLiteColumn::SQLiteColumn(SQLiteStatement &statement, int index)
SQLiteColumn::SQLiteColumn(const SQLiteStatement &statement, int index)
: statement(statement), index(index) {}

std::string SQLiteColumn::getColumnName() const {
std::string SQLiteColumn::name() const {
return sqlite3_column_name(this->statement.stmt, this->index);
}

Expand Down Expand Up @@ -58,18 +49,20 @@ SQLiteStatement::SQLiteStatement(SQLiteDB &db, const std::string &query)
column_count = sqlite3_column_count(this->stmt);
}

int SQLiteStatement::getColumnIndex(const std::string &name) {
int SQLiteStatement::getColumnIndex(const std::string &name) const {
for (int i = 0; i < this->column_count; ++i) {
if (this->getColumn(i).getColumnName() == name)
if (this->getColumn(i).name() == name)
return i;
}
return -1;
}

void SQLiteStatement::bind(const std::string& name, const std::string& value) {
void SQLiteStatement::bind(const std::string &name, const std::string &value) {
int i = sqlite3_bind_parameter_index(this->stmt, name.c_str());
if(sqlite3_bind_text(this->stmt, i, value.c_str(), -1, SQLITE_STATIC) != SQLITE_OK) {
throw SQLiteError(std::string("Failed to bind parameter: ") + sqlite3_errmsg(this->db.get()));
if (sqlite3_bind_text(this->stmt, i, value.c_str(), -1, SQLITE_STATIC) !=
SQLITE_OK) {
throw SQLiteError(std::string("Failed to bind parameter: ") +
sqlite3_errmsg(this->db.get()));
}
}

Expand All @@ -79,7 +72,7 @@ void SQLiteStatement::checkIndex(int i) const {
}
}

std::string SQLiteStatement::getColumnType(int i) {
std::string SQLiteStatement::getColumnType(int i) const {
checkIndex(i);
const char *result = sqlite3_column_decltype(this->stmt, i);
if (!result) {
Expand All @@ -89,7 +82,7 @@ std::string SQLiteStatement::getColumnType(int i) {
}
}

SQLiteColumn SQLiteStatement::getColumn(int i) {
SQLiteColumn SQLiteStatement::getColumn(int i) const {
checkIndex(i);
if (this->rc != SQLITE_ROW) {
throw SQLiteError("Statement invalid");
Expand Down
28 changes: 19 additions & 9 deletions src/sqlite/sqlite.hpp → src/database/sqlite.hpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
#pragma once

#include <exception>
#include <string>

struct sqlite3_stmt;
struct sqlite3;

enum class sqlite_open : int { readonly };

class SQLiteError;
class SQLiteError : public std::exception {
public:
SQLiteError(const std::string &msg) : msg(msg) {}
const char *what() const noexcept override { return msg.c_str(); }

private:
std::string msg;
};

class SQLiteStatement;

Expand All @@ -32,10 +42,10 @@ class SQLiteStatement {
SQLiteStatement(SQLiteDB &db, const std::string &query);
SQLiteStatement(const SQLiteStatement &) = delete;
SQLiteStatement operator=(const SQLiteStatement &) = delete;
std::string getColumnType(int i);
SQLiteColumn getColumn(int i);
int getColumnIndex(const std::string &name);
void bind(const std::string& name, const std::string& value);
std::string getColumnType(int i) const;
SQLiteColumn getColumn(int i) const;
int getColumnIndex(const std::string &name) const;
void bind(const std::string &name, const std::string &value);
bool execute();

virtual ~SQLiteStatement();
Expand All @@ -53,12 +63,12 @@ class SQLiteStatement {

class SQLiteColumn {
public:
SQLiteColumn(SQLiteStatement &statement, int index);
std::string getColumnName() const;
SQLiteColumn(const SQLiteStatement &statement, int index);
std::string name() const;
operator int() const;
operator std::string() const;

private:
SQLiteStatement &statement;
int index;
const SQLiteStatement &statement;
const int index;
};
Loading

0 comments on commit e8faeb2

Please sign in to comment.