Skip to content

Commit

Permalink
Continue transfer to xsimd
Browse files Browse the repository at this point in the history
  • Loading branch information
Pencilcaseman committed Aug 6, 2023
1 parent cda9f72 commit 96d28f2
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 214 deletions.
5 changes: 4 additions & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,7 @@
url = https://github.com/mreineck/pocketfft
[submodule "librapid/vendor/CLBlast"]
path = librapid/vendor/CLBlast
url = https://github.com/CNugteren/CLBlast.git
url = https://github.com/CNugteren/CLBlast.git
[submodule "librapid/vendor/xsimd"]
path = librapid/vendor/xsimd
url = https://github.com/LibRapid/xsimd.git
92 changes: 47 additions & 45 deletions librapid/include/librapid/autodiff/dual.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ namespace librapid {
T derivative;

#if defined(LIBRAPID_IN_JITIFY)
using Scalar = T;
using Scalar = T;
using Packet = T;
static constexpr uint64_t packetWidth = 1;
#else
using Scalar = typename typetraits::TypeInfo<T>::Scalar;
using Scalar = typename typetraits::TypeInfo<T>::Scalar;
using Packet = typename typetraits::TypeInfo<T>::Packet;
static constexpr uint64_t packetWidth = typetraits::TypeInfo<T>::packetWidth;
#endif

using Packet = typename typetraits::TypeInfo<T>::Packet;
static constexpr uint64_t packetWidth = typetraits::TypeInfo<T>::packetWidth;

Dual() = default;
explicit Dual(T value) : value(value), derivative(T()) {}
Dual(T value, T derivative) : value(value), derivative(derivative) {}
Expand Down Expand Up @@ -53,40 +54,40 @@ namespace librapid {

static constexpr size_t size() { return typetraits::TypeInfo<Dual>::packetWidth; }

template<typename P>
LIBRAPID_ALWAYS_INLINE void store(P *ptr) const {
// Load the data into batches.
auto casted = reinterpret_cast<const Scalar *>(ptr);

// Compute interleaved values.
std::array<Scalar, 2 * packetWidth> interleaved;
for (std::size_t i = 0; i < packetWidth; ++i) {
interleaved[2 * i] = value.get(i);
interleaved[2 * i + 1] = derivative.get(i);
}

// Store the interleaved values back to memory.
std::copy(interleaved.begin(), interleaved.end(), casted);
}

template<typename P>
LIBRAPID_ALWAYS_INLINE void load(const P *ptr) {
// auto casted = reinterpret_cast<const Scalar *>(ptr);
// Vc::deinterleave(&value, &derivative, casted, Vc::Aligned);

// Load the data into batches.
auto casted = reinterpret_cast<const Scalar *>(ptr);

// Compute interleaved values.
std::array<Scalar, 2 * packetWidth> interleaved;
std::copy(casted, casted + 2 * packetWidth, interleaved.begin());

// Store the interleaved values back to memory.
for (std::size_t i = 0; i < packetWidth; ++i) {
value.set(i, interleaved[2 * i]);
derivative.set(i, interleaved[2 * i + 1]);
}
}
// template<typename P>
// LIBRAPID_ALWAYS_INLINE void store(P *ptr) const {
// // Load the data into batches.
// auto casted = reinterpret_cast<const Scalar *>(ptr);
//
// // Compute interleaved values.
// std::array<Scalar, 2 * packetWidth> interleaved;
// for (std::size_t i = 0; i < packetWidth; ++i) {
// interleaved[2 * i] = value.get(i);
// interleaved[2 * i + 1] = derivative.get(i);
// }
//
// // Store the interleaved values back to memory.
// std::copy(interleaved.begin(), interleaved.end(), casted);
// }

// template<typename P>
// LIBRAPID_ALWAYS_INLINE void load(const P *ptr) {
// // auto casted = reinterpret_cast<const Scalar *>(ptr);
// // Vc::deinterleave(&value, &derivative, casted, Vc::Aligned);
//
// // Load the data into batches.
// auto casted = reinterpret_cast<const Scalar *>(ptr);
//
// // Compute interleaved values.
// std::array<Scalar, 2 * packetWidth> interleaved;
// std::copy(casted, casted + 2 * packetWidth, interleaved.begin());
//
// // Store the interleaved values back to memory.
// for (std::size_t i = 0; i < packetWidth; ++i) {
// value.set(i, interleaved[2 * i]);
// derivative.set(i, interleaved[2 * i + 1]);
// }
// }

LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator+=(const Dual &other) {
value += other.value;
Expand Down Expand Up @@ -387,17 +388,17 @@ namespace librapid {
struct TypeInfo<Dual<T>> {
static constexpr detail::LibRapidType type = detail::LibRapidType::Dual;
using Scalar = T;
using Packet = Dual<typename TypeInfo<T>::Packet>;
using Packet = std::false_type; // Dual<typename TypeInfo<T>::Packet>;
static constexpr int64_t packetWidth =
TypeInfo<typename TypeInfo<T>::Scalar>::packetWidth;
0; // TypeInfo<typename TypeInfo<T>::Scalar>::packetWidth;
using Backend = backend::CPU;

static constexpr char name[] = "Dual_T";

static constexpr bool supportsArithmetic = TypeInfo<T>::supportsArithmetic;
static constexpr bool supportsLogical = TypeInfo<T>::supportsLogical;
static constexpr bool supportsBinary = TypeInfo<T>::supportsBinary;
static constexpr bool allowVectorisation = TypeInfo<T>::allowVectorisation;
static constexpr bool allowVectorisation = false; // TypeInfo<T>::allowVectorisation;

# if defined(LIBRAPID_HAS_CUDA)
static constexpr cudaDataType_t CudaType = TypeInfo<T>::CudaType;
Expand All @@ -421,17 +422,18 @@ namespace librapid {
struct TypeInfo<Dual<float>> {
static constexpr detail::LibRapidType type = detail::LibRapidType::Dual;
using Scalar = float;
using Packet = Dual<typename TypeInfo<float>::Packet>;
using Packet = std::false_type; // Dual<typename TypeInfo<float>::Packet>;
static constexpr int64_t packetWidth =
TypeInfo<typename TypeInfo<float>::Scalar>::packetWidth;
0; // TypeInfo<typename TypeInfo<float>::Scalar>::packetWidth;
using Backend = backend::CPU;

static constexpr char name[] = "Dual_float";

static constexpr bool supportsArithmetic = TypeInfo<float>::supportsArithmetic;
static constexpr bool supportsLogical = TypeInfo<float>::supportsLogical;
static constexpr bool supportsBinary = TypeInfo<float>::supportsBinary;
static constexpr bool allowVectorisation = TypeInfo<float>::allowVectorisation;
static constexpr bool allowVectorisation =
false; // TypeInfo<float>::allowVectorisation;

# if defined(LIBRAPID_HAS_CUDA)
static constexpr cudaDataType_t CudaType = TypeInfo<float>::CudaType;
Expand Down
32 changes: 29 additions & 3 deletions librapid/include/librapid/core/traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -736,25 +736,51 @@ namespace librapid {
};
#endif

template<typename BatchType>
struct TypeInfo<xsimd::batch_element_reference<BatchType>> {
static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar;
using Scalar = typename xsimd::batch_element_reference<BatchType>::Scalar;
using Packet = std::false_type;
using Backend = backend::CPU;
static constexpr int64_t packetWidth = 1;
static constexpr char name[] = "xsimd::batch_element_reference";
static constexpr bool supportsArithmetic = true;
static constexpr bool supportsLogical = false;
static constexpr bool supportsBinary = false;
static constexpr bool allowVectorisation = false;

static constexpr bool canAlign = true;
static constexpr bool canMemcpy = false;

LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); }
LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); }
LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); }
LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); }
LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); }
LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); }
LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); }
LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); }
};

template<>
struct TypeInfo<backend::CPU> {
static constexpr char name[] = "CPU";
using Backend = backend::CPU;
using Backend = backend::CPU;
};

#if defined(LIBRAPID_HAS_OPENCL)
template<>
struct TypeInfo<backend::OpenCL> {
static constexpr char name[] = "OpenCL";
using Backend = backend::OpenCL;
using Backend = backend::OpenCL;
};
#endif

#if defined(LIBRAPID_HAS_CUDA)
template<>
struct TypeInfo<backend::CUDA> {
static constexpr char name[] = "CUDA";
using Backend = backend::CUDA;
using Backend = backend::CUDA;
};
#endif

Expand Down
34 changes: 17 additions & 17 deletions librapid/include/librapid/math/complex.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,19 +480,19 @@ namespace librapid {
return *this;
}

template<typename P>
LIBRAPID_ALWAYS_INLINE void store(P *ptr) const {
auto casted = reinterpret_cast<Scalar *>(ptr);
auto ret = Vc::interleave(m_val[RE], m_val[IM]);
ret.first.store(casted);
ret.second.store(casted + size());
}

template<typename P>
LIBRAPID_ALWAYS_INLINE void load(const P *ptr) {
auto casted = reinterpret_cast<const Scalar *>(ptr);
Vc::deinterleave(&m_val[RE], &m_val[IM], casted, Vc::Aligned);
}
// template<typename P>
// LIBRAPID_ALWAYS_INLINE void store(P *ptr) const {
// auto casted = reinterpret_cast<Scalar *>(ptr);
// auto ret = Vc::interleave(m_val[RE], m_val[IM]);
// ret.first.store(casted);
// ret.second.store(casted + size());
// }

// template<typename P>
// LIBRAPID_ALWAYS_INLINE void load(const P *ptr) {
// auto casted = reinterpret_cast<const Scalar *>(ptr);
// Vc::deinterleave(&m_val[RE], &m_val[IM], casted, Vc::Aligned);
// }

/// \brief Assign to the real component
///
Expand Down Expand Up @@ -2047,11 +2047,11 @@ namespace librapid {
struct TypeInfo<Complex<T>> {
static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar;
using Scalar = Complex<T>;
using Packet =
typename std::conditional_t<(TypeInfo<T>::packetWidth > 1),
Complex<typename TypeInfo<T>::Packet>, std::false_type>;
using Packet = std::false_type;
// typename std::conditional_t<(TypeInfo<T>::packetWidth > 1),
// Complex<typename TypeInfo<T>::Packet>, std::false_type>;
static constexpr int64_t packetWidth =
TypeInfo<typename TypeInfo<T>::Scalar>::packetWidth;
0; // TypeInfo<typename TypeInfo<T>::Scalar>::packetWidth;
static constexpr char name[] = "Complex";
static constexpr bool supportsArithmetic = true;
static constexpr bool supportsLogical = true;
Expand Down
Loading

0 comments on commit 96d28f2

Please sign in to comment.