Skip to content

Commit

Permalink
feat: add db flagging & checking
Browse files Browse the repository at this point in the history
  • Loading branch information
mookums committed Nov 24, 2023
1 parent 7fadedf commit c1c0fd4
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 38 deletions.
56 changes: 34 additions & 22 deletions src/databases.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ const int32_t PairDistanceKVectorDatabase::kMagicValue = 0x2536f009;
struct KVectorPair {
int16_t index1;
int16_t index2;
float distance;
decimal distance;
};

bool CompareKVectorPairs(const KVectorPair &p1, const KVectorPair &p2) {
Expand All @@ -33,8 +33,8 @@ bool CompareKVectorPairs(const KVectorPair &p1, const KVectorPair &p2) {
| size | name | description |
|---------------+------------+-------------------------------------------------------------|
| 4 | numEntries | |
| sizeof float | min | minimum value contained in the database |
| sizeof float | max | max value contained in index |
| sizeof decimal | min | minimum value contained in the database |
| sizeof decimal | max | max value contained in index |
| 4 | numBins | |
| 4*(numBins+1) | bins | The `i'th bin (starting from zero) stores how many pairs of |
| | | stars have a distance lesst han or equal to: |
Expand All @@ -57,9 +57,9 @@ bool CompareKVectorPairs(const KVectorPair &p1, const KVectorPair &p2) {
* @param numBins the number of "bins" the KVector should use. A higher number makes query results "tighter" but takes up more disk space. Usually should be set somewhat smaller than (max-min) divided by the "width" of the typical query.
* @param buffer[out] index is written here.
*/
void SerializeKVectorIndex(SerializeContext *ser, const std::vector<float> &values, float min, float max, long numBins) {
void SerializeKVectorIndex(SerializeContext *ser, const std::vector<decimal> &values, decimal min, decimal max, long numBins) {
std::vector<int32_t> kVector(numBins+1); // We store sums before and after each bin
float binWidth = (max - min) / numBins;
decimal binWidth = (max - min) / numBins;

// generate the k-vector part
// Idea: When we find the first star that's across any bin boundary, we want to update all the newly sealed bins
Expand Down Expand Up @@ -92,8 +92,8 @@ void SerializeKVectorIndex(SerializeContext *ser, const std::vector<float> &valu

// metadata fields
SerializePrimitive<int32_t>(ser, values.size());
SerializePrimitive<float>(ser, min);
SerializePrimitive<float>(ser, max);
SerializePrimitive<decimal>(ser, min);
SerializePrimitive<decimal>(ser, max);
SerializePrimitive<int32_t>(ser, numBins);

// kvector index field
Expand All @@ -106,8 +106,8 @@ void SerializeKVectorIndex(SerializeContext *ser, const std::vector<float> &valu
KVectorIndex::KVectorIndex(DeserializeContext *des) {

numValues = DeserializePrimitive<int32_t>(des);
min = DeserializePrimitive<float>(des);
max = DeserializePrimitive<float>(des);
min = DeserializePrimitive<decimal>(des);
max = DeserializePrimitive<decimal>(des);
numBins = DeserializePrimitive<int32_t>(des);

assert(min >= 0.0f);
Expand All @@ -122,7 +122,7 @@ KVectorIndex::KVectorIndex(DeserializeContext *des) {
* @param upperIndex[out] Is set to the index of the last returned value +1.
* @return the index (starting from zero) of the first value matching the query
*/
long KVectorIndex::QueryLiberal(float minQueryDistance, float maxQueryDistance, long *upperIndex) const {
long KVectorIndex::QueryLiberal(decimal minQueryDistance, decimal maxQueryDistance, long *upperIndex) const {
assert(maxQueryDistance > minQueryDistance);
if (maxQueryDistance >= max) {
maxQueryDistance = max - 0.00001; // TODO: better way to avoid hitting the bottom bin
Expand Down Expand Up @@ -152,7 +152,7 @@ long KVectorIndex::QueryLiberal(float minQueryDistance, float maxQueryDistance,
}

/// return the lowest-indexed bin that contains the number of pairs with distance <= dist
long KVectorIndex::BinFor(float query) const {
long KVectorIndex::BinFor(decimal query) const {
long result = (long)ceil((query - min) / binWidth);
assert(result >= 0);
assert(result <= numBins);
Expand All @@ -168,7 +168,7 @@ long KVectorIndex::BinFor(float query) const {
| sizeof kvectorIndex | kVectorIndex | Serialized KVector index |
| 2*sizeof(int16)*numPairs | pairs | Bulk pair data |
*/
std::vector<KVectorPair> CatalogToPairDistances(const Catalog &catalog, float minDistance, float maxDistance) {
std::vector<KVectorPair> CatalogToPairDistances(const Catalog &catalog, decimal minDistance, decimal maxDistance) {
std::vector<KVectorPair> result;
for (int16_t i = 0; i < (int16_t)catalog.size(); i++) {
for (int16_t k = i+1; k < (int16_t)catalog.size(); k++) {
Expand All @@ -191,13 +191,13 @@ std::vector<KVectorPair> CatalogToPairDistances(const Catalog &catalog, float mi
* Serialize a pair-distance KVector into buffer.
* Use SerializeLengthPairDistanceKVector to determine how large the buffer needs to be. See command line documentation for other options.
*/
void SerializePairDistanceKVector(SerializeContext *ser, const Catalog &catalog, float minDistance, float maxDistance, long numBins) {
void SerializePairDistanceKVector(SerializeContext *ser, const Catalog &catalog, decimal minDistance, decimal maxDistance, long numBins) {
std::vector<int32_t> kVector(numBins+1); // numBins = length, all elements zero
std::vector<KVectorPair> pairs = CatalogToPairDistances(catalog, minDistance, maxDistance);

// sort pairs in increasing order.
std::sort(pairs.begin(), pairs.end(), CompareKVectorPairs);
std::vector<float> distances;
std::vector<decimal> distances;

for (const KVectorPair &pair : pairs) {
distances.push_back(pair.distance);
Expand All @@ -221,7 +221,7 @@ PairDistanceKVectorDatabase::PairDistanceKVectorDatabase(DeserializeContext *des
}

/// Return the value in the range [low,high] which is closest to num
float Clamp(float num, float low, float high) {
decimal Clamp(decimal num, decimal low, decimal high) {
return num < low ? low : num > high ? high : num;
}

Expand All @@ -231,7 +231,7 @@ float Clamp(float num, float low, float high) {
* @return A pointer to the start of the matched pairs. Each pair is stored as simply two 16-bit integers, each of which is a catalog index. (you must increment the pointer twice to get to the next pair).
*/
const int16_t *PairDistanceKVectorDatabase::FindPairsLiberal(
float minQueryDistance, float maxQueryDistance, const int16_t **end) const {
decimal minQueryDistance, decimal maxQueryDistance, const int16_t **end) const {

assert(maxQueryDistance <= M_PI);

Expand All @@ -242,16 +242,16 @@ const int16_t *PairDistanceKVectorDatabase::FindPairsLiberal(
}

const int16_t *PairDistanceKVectorDatabase::FindPairsExact(const Catalog &catalog,
float minQueryDistance, float maxQueryDistance, const int16_t **end) const {
decimal minQueryDistance, decimal maxQueryDistance, const int16_t **end) const {

// Instead of computing the angle for every pair in the database, we pre-compute the /cosines/
// of the min and max query distances so that we can compare against dot products directly! As
// angle increases, cosine decreases, up to M_PI (and queries larger than that don't really make
// sense anyway)
assert(maxQueryDistance <= M_PI);

float maxQueryCos = cos(minQueryDistance);
float minQueryCos = cos(maxQueryDistance);
decimal maxQueryCos = cos(minQueryDistance);
decimal minQueryCos = cos(maxQueryDistance);

long liberalUpperIndex;
long liberalLowerIndex = index.QueryLiberal(minQueryDistance, maxQueryDistance, &liberalUpperIndex);
Expand Down Expand Up @@ -280,8 +280,8 @@ long PairDistanceKVectorDatabase::NumPairs() const {
}

/// Return the distances from the given star to each star it's paired with in the database (for debugging).
std::vector<float> PairDistanceKVectorDatabase::StarDistances(int16_t star, const Catalog &catalog) const {
std::vector<float> result;
std::vector<decimal> PairDistanceKVectorDatabase::StarDistances(int16_t star, const Catalog &catalog) const {
std::vector<decimal> result;
for (int i = 0; i < NumPairs(); i++) {
if (pairs[i*2] == star || pairs[i*2+1] == star) {
result.push_back(AngleUnit(catalog[pairs[i*2]].spatial, catalog[pairs[i*2+1]].spatial));
Expand All @@ -296,6 +296,7 @@ std::vector<float> PairDistanceKVectorDatabase::StarDistances(int16_t star, cons
| size | name | description |
|------+----------------+---------------------------------------------|
| 4 | magicValue | unique database identifier |
| 4 | flags | [X, X, X, isDouble?] |
| 4 | databaseLength | length in bytes (32-bit unsigned) |
| n | database | the entire database. 8-byte aligned |
| ... | ... | More databases (each has value, length, db) |
Expand All @@ -318,6 +319,15 @@ const unsigned char *MultiDatabase::SubDatabasePointer(int32_t magicValue) const
if (curMagicValue == 0) {
return nullptr;
}
uint32_t dbFlags = DeserializePrimitive<uint32_t>(des);

// Ensure that our database is using the same type as the runtime.
if(dbFlags & MULTI_DB_IS_DOUBLE) {
assert(typeid(decimal) == typeid(double));
} else {
assert(typeid(decimal) == typeid(float));
}

uint32_t dbLength = DeserializePrimitive<uint32_t>(des);
assert(dbLength > 0);
DeserializePadding<uint64_t>(des); // align to an 8-byte boundary
Expand All @@ -331,9 +341,11 @@ const unsigned char *MultiDatabase::SubDatabasePointer(int32_t magicValue) const
}

void SerializeMultiDatabase(SerializeContext *ser,
const MultiDatabaseDescriptor &dbs) {
const MultiDatabaseDescriptor &dbs,
uint32_t flags) {
for (const MultiDatabaseEntry &multiDbEntry : dbs) {
SerializePrimitive<int32_t>(ser, multiDbEntry.magicValue);
SerializePrimitive<uint32_t>(ser, flags);
SerializePrimitive<uint32_t>(ser, multiDbEntry.bytes.size());
SerializePadding<uint64_t>(ser);
std::copy(multiDbEntry.bytes.cbegin(), multiDbEntry.bytes.cend(), std::back_inserter(ser->buffer));
Expand Down
35 changes: 20 additions & 15 deletions src/databases.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,27 @@ class KVectorIndex {
public:
explicit KVectorIndex(DeserializeContext *des);

long QueryLiberal(float minQueryDistance, float maxQueryDistance, long *upperIndex) const;
long QueryLiberal(decimal minQueryDistance, decimal maxQueryDistance, long *upperIndex) const;

/// The number of data points in the data referred to by the kvector
long NumValues() const { return numValues; };
long NumBins() const { return numBins; };
/// Upper bound on elements
float Max() const { return max; };
decimal Max() const { return max; };
// Lower bound on elements
float Min() const { return min; };
decimal Min() const { return min; };
private:
long BinFor(float dist) const;
long BinFor(decimal dist) const;

long numValues;
float min;
float max;
float binWidth;
decimal min;
decimal max;
decimal binWidth;
long numBins;
const int32_t *bins;
};

void SerializePairDistanceKVector(SerializeContext *, const Catalog &, float minDistance, float maxDistance, long numBins);
void SerializePairDistanceKVector(SerializeContext *, const Catalog &, decimal minDistance, decimal maxDistance, long numBins);

/**
* A database storing distances between pairs of stars.
Expand All @@ -53,14 +53,14 @@ class PairDistanceKVectorDatabase {
public:
explicit PairDistanceKVectorDatabase(DeserializeContext *des);

const int16_t *FindPairsLiberal(float min, float max, const int16_t **end) const;
const int16_t *FindPairsExact(const Catalog &, float min, float max, const int16_t **end) const;
std::vector<float> StarDistances(int16_t star, const Catalog &) const;
const int16_t *FindPairsLiberal(decimal min, decimal max, const int16_t **end) const;
const int16_t *FindPairsExact(const Catalog &, decimal min, decimal max, const int16_t **end) const;
std::vector<decimal> StarDistances(int16_t star, const Catalog &) const;

/// Upper bound on stored star pair distances
float MaxDistance() const { return index.Max(); };
decimal MaxDistance() const { return index.Max(); };
/// Lower bound on stored star pair distances
float MinDistance() const { return index.Min(); };
decimal MinDistance() const { return index.Min(); };
/// Exact number of stored pairs
long NumPairs() const;

Expand All @@ -85,7 +85,7 @@ class PairDistanceKVectorDatabase {
// public:
// explicit TripleInnerKVectorDatabase(const unsigned char *databaseBytes);

// void FindTriplesLiberal(float min, float max, long **begin, long **end) const;
// void FindTriplesLiberal(decimal min, decimal max, long **begin, long **end) const;
// private:
// KVectorIndex index;
// int16_t *triples;
Expand All @@ -96,6 +96,10 @@ class PairDistanceKVectorDatabase {
* This is almost always the database that is actually passed to star-id algorithms in the real world, since you'll want to store at least the catalog plus one specific database.
* Multi-databases are essentially a map from "magic values" to database buffers.
*/

#define MULTI_DB_IS_DOUBLE 0x0001
#define MULTI_DB_IS_FLOAT 0x0000

class MultiDatabase {
public:
/// Create a multidatabase from a serialized multidatabase.
Expand All @@ -111,12 +115,13 @@ class MultiDatabaseEntry {
: magicValue(magicValue), bytes(bytes) { }

int32_t magicValue;
uint32_t flags;
std::vector<unsigned char> bytes;
};

typedef std::vector<MultiDatabaseEntry> MultiDatabaseDescriptor;

void SerializeMultiDatabase(SerializeContext *, const MultiDatabaseDescriptor &dbs);
void SerializeMultiDatabase(SerializeContext *, const MultiDatabaseDescriptor &dbs, uint32_t flags);

}

Expand Down
7 changes: 6 additions & 1 deletion src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "databases.hpp"
#include "centroiders.hpp"
#include "decimal.hpp"
#include "io.hpp"
#include "man-database.h"
#include "man-pipeline.h"
Expand All @@ -30,9 +31,13 @@ static void DatabaseBuild(const DatabaseOptions &values) {

MultiDatabaseDescriptor dbEntries = GenerateDatabases(narrowedCatalog, values);
SerializeContext ser = serFromDbValues(values);
SerializeMultiDatabase(&ser, dbEntries);

// Inject flags into the Serialized Database.
uint32_t dbFlags = typeid(decimal) == typeid(double) ? MULTI_DB_IS_DOUBLE : MULTI_DB_IS_FLOAT;
SerializeMultiDatabase(&ser, dbEntries, dbFlags);

std::cerr << "Generated database with " << ser.buffer.size() << " bytes" << std::endl;
std::cerr << "Database flagged with " << dbFlags << std::endl;

UserSpecifiedOutputStream pos = UserSpecifiedOutputStream(values.outputPath, true);
pos.Stream().write((char *) ser.buffer.data(), ser.buffer.size());
Expand Down

0 comments on commit c1c0fd4

Please sign in to comment.