Skip to content

Commit

Permalink
[buddy-mlir] add new constructor for Memref class
Browse files Browse the repository at this point in the history
Signed-off-by: Avimitin <[email protected]>
  • Loading branch information
Avimitin committed Aug 24, 2024
1 parent ca18bb5 commit b36543e
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions tests/pytorch/memref.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
template <typename T, size_t N> class MemRef {
public:
constexpr MemRef(T *data, const int32_t sizes[N]);
constexpr MemRef(T *data, T init, const int32_t sizes[N]);

protected:
inline void setStrides();
Expand All @@ -19,8 +20,8 @@ template <typename T, size_t N> class MemRef {
int32_t strides[N];
};

template <typename T, std::size_t N> constexpr
MemRef<T, N>::MemRef(T *data, const int32_t sizes[N]) {
template <typename T, std::size_t N>
constexpr MemRef<T, N>::MemRef(T *data, const int32_t sizes[N]) {
for (size_t i = 0; i < N; i++) {
this->sizes[i] = sizes[i];
}
Expand All @@ -31,6 +32,20 @@ MemRef<T, N>::MemRef(T *data, const int32_t sizes[N]) {
aligned = data;
}

template <typename T, std::size_t N>
constexpr MemRef<T, N>::MemRef(T *data, T init, const int32_t sizes[N])
: MemRef(data, sizes) {

int32_t total_size = 0;
for (size_t i = 0; i < N; i++) {
total_size += sizes[i];
}

for (int32_t i = 0; i < total_size; i++) {
aligned[i] = init;
}
}

template <typename T, std::size_t N> inline void MemRef<T, N>::setStrides() {
strides[N - 1] = 1;
if (N < 2)
Expand Down

0 comments on commit b36543e

Please sign in to comment.