-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.h
103 lines (90 loc) · 2.64 KB
/
util.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
//
// Created by Daniel Kerbel on 12/14/2019.
//
#ifndef UTIL_H
#define UTIL_H
#include <cassert>
#include <initializer_list>
#include <iostream>
#include <sstream>
#include <string>

#include "catch.hpp"
#include "Matrix.h"
/**
 * Helper for building a Matrix with nested brace-initialization syntax, e.g.
 *     mkMatrix({{1, 2, 3},
 *               {4, 5, 6}})   // 2 rows, 3 columns
 * Each inner list is one row. All rows must have the same length; this is checked with
 * assert, so the check disappears under NDEBUG.
 *
 * Elements are written through a single running linear index, row after row.
 * NOTE(review): this assumes Matrix::operator[] uses row-major linear indexing — confirm
 * against Matrix.h.
 *
 * @param rows the rows of the matrix, each a list of element values
 * @return a Matrix with rows.size() rows and as many columns as the first row has elements
 *         (an empty outer list yields Matrix(0, 0); whether that is valid is up to Matrix)
 */
inline Matrix mkMatrix(const std::initializer_list<std::initializer_list<float>>& rows)
{
    const int numRows = static_cast<int>(rows.size());
    // With no rows there is no first row to take the width from; dereferencing
    // rows.begin() would read past the end, so treat {} as having 0 columns.
    const int numCols = (numRows == 0) ? 0 : static_cast<int>(rows.begin()->size());
    Matrix matrix(numRows, numCols);
    int inserted = 0;
    for (const auto& row : rows)
    {
        // A ragged row would make the running index overrun (or underfill) the matrix.
        assert(static_cast<int>(row.size()) == numCols && "mkMatrix: all rows must have the same length");
        for (const float elm : row)
        {
            matrix[inserted++] = elm;
        }
    }
    return matrix;
}
namespace Catch {
    /** Specialization that tells Catch how to render a Matrix in assertion / failure messages. */
    template<>
    struct StringMaker<Matrix> {
        /**
         * Formats a matrix as "[RxC] = {", then one "{...}" line per row (each element
         * followed by a space), then a closing "}". For example, a 2x2 matrix prints as
         *     [2x2] = {
         *     {1 2 }
         *     {3 4 }
         *     }
         * @param matrix the matrix to print
         * @return the textual representation
         */
        static std::string convert(const Matrix& matrix)
        {
            std::ostringstream ss;
            // '\n' instead of std::endl: flushing an in-memory stream does nothing useful.
            ss << "[" << matrix.getRows() << "x" << matrix.getCols() << "] = {" << '\n';
            for (int i = 0; i < matrix.getRows(); ++i)
            {
                ss << "{";
                for (int j = 0; j < matrix.getCols(); ++j)
                {
                    ss << matrix(i, j) << " ";
                }
                ss << "}" << '\n';
            }
            ss << "}";
            return ss.str();
        }
    };
}
/**
 * Catch matcher that compares matrices for approximate equality. Two matrices match when
 * they have the same dimensions and every element matches under Catch's Approx, using the
 * expected element as the reference value.
 *
 * On the first mismatching element, a diagnostic naming the linear index and both values
 * is written to std::cerr.
 */
class MatrixMatcher: public Catch::MatcherBase<Matrix>
{
private:
    Matrix _b;       // the expected matrix
    double _margin;  // absolute tolerance forwarded to Approx::margin (0 = relative only)
public:
    /**
     * @param b      the expected matrix
     * @param margin absolute tolerance forwarded to Approx::margin. The default 0 keeps
     *               Approx's purely relative comparison. Approx's relative tolerance is 0
     *               when the expected element is 0, so pass a small positive margin when
     *               expected values may be exactly 0. Catch rejects negative margins.
     */
    MatrixMatcher(const Matrix &b, double margin = 0.0)
        : _b(b), _margin(margin)
    {}

    /**
     * @param a the actual matrix produced by the code under test
     * @return true iff dimensions agree and every element is approximately equal
     */
    bool match(const Matrix &a) const override
    {
        if (a.getRows() != _b.getRows() || a.getCols() != _b.getCols())
        {
            return false;
        }
        const int numElements = a.getRows() * a.getCols();
        for (int ix = 0; ix < numElements; ++ix)
        {
            // like checking that a[ix] == _b[ix], but uses approximation since we're dealing with floats
            // see https://github.com/catchorg/Catch2/blob/master/docs/assertions.md for more information
            if (a[ix] != Approx(_b[ix]).margin(_margin))
            {
                std::cerr << "While checking 2 matrices for equality, at index [" << ix << "], the values " << a[ix]
                          << " and " << _b[ix] << " aren't equal(approximately)" << '\n';
                return false;
            }
        }
        return true;
    }

    /** @return the text Catch shows for this matcher, including the expected matrix. */
    std::string describe() const override
    {
        std::ostringstream ss;
        ss << "equals the matrix ";
        ss << Catch::StringMaker<Matrix>::convert(_b);
        // Only mention the tolerance when one was requested, so default output is unchanged.
        if (_margin > 0.0)
        {
            ss << " (within absolute margin " << _margin << ")";
        }
        return ss.str();
    }
};

/**
 * Convenience factory for the matcher above, e.g.
 *     REQUIRE_THAT(result, MatrixEquals(expected));
 *     REQUIRE_THAT(result, MatrixEquals(expected, 1e-6));  // tolerate tiny values around 0
 * @param b      the expected matrix
 * @param margin optional absolute tolerance (see MatrixMatcher); defaults to 0
 */
inline MatrixMatcher MatrixEquals(const Matrix& b, double margin = 0.0)
{
    return MatrixMatcher(b, margin);
}
#endif //UTIL_H