Skip to content

Commit

Permalink
Remove __CUDACC__
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Jul 27, 2023
1 parent 4865b20 commit 0e59a59
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 95 deletions.
5 changes: 2 additions & 3 deletions include/mscclpp/concurrency.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <stdint.h>

#include <mscclpp/core.hpp>
#include <mscclpp/poll.hpp>

namespace mscclpp {
Expand All @@ -19,11 +20,10 @@ struct DeviceSyncer {
/// Destroy the DeviceSyncer object.
~DeviceSyncer() = default;

#ifdef __CUDACC__
/// Synchronize all threads inside a kernel. Guarantee that all previous work of all threads in cooperating blocks is
/// finished.
/// @param blockNum The number of blocks that will synchronize.
__forceinline__ __device__ void sync(int blockNum) {
MSCCLPP_DEVICE void sync(int blockNum) {
int maxOldCnt = blockNum - 1;
__syncthreads();
if (blockNum == 1) return;
Expand All @@ -48,7 +48,6 @@ struct DeviceSyncer {
// the flag is flipped.
__syncthreads();
}
#endif

private:
/// The flag to indicate whether the barrier is reached by the latest thread.
Expand Down
11 changes: 11 additions & 0 deletions include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@
#define MSCCLPP_PATCH 0
#define MSCCLPP_VERSION (MSCCLPP_MAJOR * 10000 + MSCCLPP_MINOR * 100 + MSCCLPP_PATCH)

#if defined(__CUDA_ARCH__)
#define MSCCLPP_DEVICE __forceinline__ __device__
#else
#define MSCCLPP_DEVICE inline
#define atomicAdd(...) 0
#define atomicSub(...) 0
#define __syncthreads(...)
#define __threadfence(...)
#define __threadfence_system(...)
#endif

#include <array>
#include <bitset>
#include <future>
Expand Down
7 changes: 3 additions & 4 deletions include/mscclpp/fifo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <cstdint>
#include <functional>
#include <memory>
#include <mscclpp/core.hpp>
#include <mscclpp/poll.hpp>

#define MSCCLPP_PROXY_FIFO_SIZE 128
Expand Down Expand Up @@ -35,12 +36,11 @@ struct alignas(16) ProxyTrigger {
/// tail as there is usually enough space for device threads to push their work into.
///
struct DeviceProxyFifo {
#ifdef __CUDACC__
/// Push a trigger to the FIFO.
///
/// @param trigger The trigger to push.
/// @return The new head of the FIFO.
__forceinline__ __device__ uint64_t push(ProxyTrigger trigger) {
MSCCLPP_DEVICE uint64_t push(ProxyTrigger trigger) {
uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->head, 1);

// Only one of two conditions need to be met to proceed. Either the tail has advanced enough or where we need to
Expand All @@ -62,13 +62,12 @@ struct DeviceProxyFifo {
/// Wait until there is a place in the FIFO to push a trigger.
///
/// @param curFifoHead The current head of the FIFO.
__forceinline__ __device__ void sync(uint64_t curFifoHead) {
MSCCLPP_DEVICE void sync(uint64_t curFifoHead) {
// Same as push but in this case checking the fist condition is probably faster since for tail to be pushed we need
// to wait for cudaMemcpy to be done.
OR_POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)&(this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE]) != 0,
*(volatile uint64_t*)(this->tailReplica) <= curFifoHead, 1000000);
}
#endif // __CUDACC__

/// The FIFO buffer that is allocated on the host via `cudaHostAlloc()`.
ProxyTrigger* triggers;
Expand Down
24 changes: 11 additions & 13 deletions include/mscclpp/packet.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#ifndef MSCCLPP_PACKET_HPP_
#define MSCCLPP_PACKET_HPP_

#include <mscclpp/core.hpp>

namespace mscclpp {

/// LL (low latency) protocol packet.
Expand All @@ -23,29 +25,28 @@ union LLPacket {

uint64_t v[2];

#ifdef __CUDACC__
__forceinline__ __device__ LLPacket() {}
MSCCLPP_DEVICE LLPacket() {}

/// Write 8 bytes of data to the packet.
/// @param val1 The first 4-byte data to write.
/// @param val2 The second 4-byte data to write.
/// @param flag The flag to write.
__forceinline__ __device__ void write(uint32_t val1, uint32_t val2, uint32_t flag) {
MSCCLPP_DEVICE void write(uint32_t val1, uint32_t val2, uint32_t flag) {
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(v), "r"(val1), "r"(flag), "r"(val2), "r"(flag));
}

/// Write 8 bytes of data to the packet.
/// @param val The 8-byte data to write.
/// @param flag The flag to write.
__forceinline__ __device__ void write(uint64_t val, uint32_t flag) {
MSCCLPP_DEVICE void write(uint64_t val, uint32_t flag) {
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(v), "r"((uint32_t)val), "r"(flag),
"r"((uint32_t)(val >> 32)), "r"(flag));
}

/// Read 8 bytes of data from the packet.
/// @param flag The flag to read.
/// @return The 8-byte data read.
__forceinline__ __device__ uint2 read(uint32_t flag) {
MSCCLPP_DEVICE uint2 read(uint32_t flag) {
uint2 data;
uint32_t flag1, flag2;
do {
Expand All @@ -57,16 +58,14 @@ union LLPacket {
}

/// Clear the packet.
__forceinline__ __device__ void clear() {
MSCCLPP_DEVICE void clear() {
vec.x = 0;
vec.y = 0;
}
#endif // __CUDACC__
};

#ifdef __CUDACC__
__forceinline__ __device__ void putPackets(void* dst, uint64_t dstOffset, void* src, uint64_t srcOffset,
uint64_t srcBytes, uint32_t threadId, uint32_t numThreads, uint32_t flag) {
MSCCLPP_DEVICE void putPackets(void* dst, uint64_t dstOffset, void* src, uint64_t srcOffset, uint64_t srcBytes,
uint32_t threadId, uint32_t numThreads, uint32_t flag) {
// Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes
uint32_t* srcBase = (uint32_t*)((char*)src + srcOffset);
LLPacket* dstBase = (LLPacket*)((char*)dst + dstOffset);
Expand All @@ -77,8 +76,8 @@ __forceinline__ __device__ void putPackets(void* dst, uint64_t dstOffset, void*
}
}

__forceinline__ __device__ void getPackets(void* dst, uint64_t dstOffset, void* src, uint64_t srcOffset,
uint64_t dstBytes, uint32_t threadId, uint32_t numThreads, uint32_t flag) {
MSCCLPP_DEVICE void getPackets(void* dst, uint64_t dstOffset, void* src, uint64_t srcOffset, uint64_t dstBytes,
uint32_t threadId, uint32_t numThreads, uint32_t flag) {
// Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes
LLPacket* srcBase = (LLPacket*)((char*)src + srcOffset);
uint2* dstBase = (uint2*)((char*)dst + dstOffset);
Expand All @@ -88,7 +87,6 @@ __forceinline__ __device__ void getPackets(void* dst, uint64_t dstOffset, void*
dstBase[i] = pkt->read(flag);
}
}
#endif // __CUDACC__

}; // namespace mscclpp

Expand Down
4 changes: 0 additions & 4 deletions include/mscclpp/poll.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
#ifndef MSCCLPP_POLL_HPP_
#define MSCCLPP_POLL_HPP_

#ifdef __CUDACC__

#ifndef NDEBUG
// TODO(chhwang): https://github.com/microsoft/mscclpp/issues/99
#define POLL_PRINT_ON_STUCK(__cond)
Expand Down Expand Up @@ -57,6 +55,4 @@
} \
} while (0);

#endif // __CUDACC__

#endif // MSCCLPP_POLL_HPP_
58 changes: 23 additions & 35 deletions include/mscclpp/proxy_channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,11 @@ union ChannelTrigger {
MSCCLPP_BITS_TYPE); // ensure 64-bit alignment
} fields;

#ifdef __CUDACC__
/// Default constructor.
__device__ ChannelTrigger() {}
MSCCLPP_DEVICE ChannelTrigger() {}

/// Copy constructor.
__device__ ChannelTrigger(ProxyTrigger value) : value(value) {}
MSCCLPP_DEVICE ChannelTrigger(ProxyTrigger value) : value(value) {}

/// Constructor.
/// @param type The type of the trigger.
Expand All @@ -118,16 +117,15 @@ union ChannelTrigger {
/// @param srcOffset The offset into the source memory region.
/// @param bytes The bytes of the transfer.
/// @param semaphoreId The ID of the semaphore.
__device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
uint64_t bytes, int semaphoreId) {
MSCCLPP_DEVICE ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
uint64_t bytes, int semaphoreId) {
value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + bytes);
value.snd = ((((((((semaphoreId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_REGMEM_HANDLE) + dst)
<< MSCCLPP_BITS_REGMEM_HANDLE) +
src)
<< MSCCLPP_BITS_OFFSET) +
dstOffset);
}
#endif // __CUDACC__
};

/// Proxy channel.
Expand All @@ -143,15 +141,13 @@ struct ProxyChannel {

ProxyChannel& operator=(ProxyChannel& other) = default;

#ifdef __CUDACC__
/// Push a @ref TriggerData to the FIFO.
/// @param dst The destination memory region.
/// @param dstOffset The offset into the destination memory region.
/// @param src The source memory region.
/// @param srcOffset The offset into the source memory region.
/// @param size The size of the transfer.
__forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
uint64_t size) {
MSCCLPP_DEVICE void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, uint64_t size) {
fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, size, semaphoreId_).value);
}

Expand All @@ -160,23 +156,20 @@ struct ProxyChannel {
/// @param src The source memory region.
/// @param offset The common offset into the destination and source memory regions.
/// @param size The size of the transfer.
__forceinline__ __device__ void put(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
MSCCLPP_DEVICE void put(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
put(dst, offset, src, offset, size);
}

/// Push a @ref TriggerFlag to the FIFO.
__forceinline__ __device__ void signal() {
fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_).value);
}
MSCCLPP_DEVICE void signal() { fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_).value); }

/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
/// @param dst The destination memory region.
/// @param dstOffset The offset into the destination memory region.
/// @param src The source memory region.
/// @param srcOffset The offset into the source memory region.
/// @param size The size of the transfer.
__forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
uint64_t size) {
MSCCLPP_DEVICE void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, uint64_t size) {
fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, semaphoreId_).value);
}

Expand All @@ -185,7 +178,7 @@ struct ProxyChannel {
/// @param src The source memory region.
/// @param offset The common offset into the destination and source memory regions.
/// @param size The size of the transfer.
__forceinline__ __device__ void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
MSCCLPP_DEVICE void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
putWithSignal(dst, offset, src, offset, size);
}

Expand All @@ -195,8 +188,8 @@ struct ProxyChannel {
/// @param src The source memory region.
/// @param srcOffset The offset into the source memory region.
/// @param size The size of the transfer.
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, uint64_t dstOffset, MemoryId src,
uint64_t srcOffset, uint64_t size) {
MSCCLPP_DEVICE void putWithSignalAndFlush(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset,
uint64_t size) {
uint64_t curFifoHead = fifo_.push(
ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, dstOffset, src, srcOffset, size, semaphoreId_)
.value);
Expand All @@ -208,20 +201,18 @@ struct ProxyChannel {
/// @param src The source memory region.
/// @param offset The common offset into the destination and source memory regions.
/// @param size The size of the transfer.
__forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
MSCCLPP_DEVICE void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) {
putWithSignalAndFlush(dst, offset, src, offset, size);
}

/// Push a @ref TriggerSync to the FIFO.
__forceinline__ __device__ void flush() {
MSCCLPP_DEVICE void flush() {
uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerSync, 0, 0, 0, 0, 1, semaphoreId_).value);
fifo_.sync(curFifoHead);
}

/// Wait for the proxy channel to be signaled.
__forceinline__ __device__ void wait() { semaphore_.wait(); }

#endif // __CUDACC__
MSCCLPP_DEVICE void wait() { semaphore_.wait(); }

SemaphoreId semaphoreId_;

Expand Down Expand Up @@ -256,58 +247,55 @@ struct SimpleProxyChannel {
/// Assignment operator.
SimpleProxyChannel& operator=(SimpleProxyChannel& other) = default;

#ifdef __CUDACC__
/// Push a @ref TriggerData to the FIFO.
/// @param dstOffset The offset into the destination memory region.
/// @param srcOffset The offset into the source memory region.
/// @param size The size of the transfer.
__forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
MSCCLPP_DEVICE void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
proxyChan_.put(dst_, dstOffset, src_, srcOffset, size);
}

/// Push a @ref TriggerData to the FIFO.
/// @param offset The common offset into the destination and source memory regions.
/// @param size The size of the transfer.
__forceinline__ __device__ void put(uint64_t offset, uint64_t size) { put(offset, offset, size); }
MSCCLPP_DEVICE void put(uint64_t offset, uint64_t size) { put(offset, offset, size); }

/// Push a @ref TriggerFlag to the FIFO.
__forceinline__ __device__ void signal() { proxyChan_.signal(); }
MSCCLPP_DEVICE void signal() { proxyChan_.signal(); }

/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
/// @param dstOffset The offset into the destination memory region.
/// @param srcOffset The offset into the source memory region.
/// @param size The size of the transfer.
__forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
MSCCLPP_DEVICE void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
proxyChan_.putWithSignal(dst_, dstOffset, src_, srcOffset, size);
}

/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
/// @param offset The common offset into the destination and source memory regions.
/// @param size The size of the transfer.
__forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size) { putWithSignal(offset, offset, size); }
MSCCLPP_DEVICE void putWithSignal(uint64_t offset, uint64_t size) { putWithSignal(offset, offset, size); }

/// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
/// @param dstOffset The offset into the destination memory region.
/// @param srcOffset The offset into the source memory region.
/// @param size The size of the transfer.
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
MSCCLPP_DEVICE void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
proxyChan_.putWithSignalAndFlush(dst_, dstOffset, src_, srcOffset, size);
}

/// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
/// @param offset The common offset into the destination and source memory regions.
/// @param size The size of the transfer.
__forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size) {
MSCCLPP_DEVICE void putWithSignalAndFlush(uint64_t offset, uint64_t size) {
putWithSignalAndFlush(offset, offset, size);
}

/// Push a @ref TriggerSync to the FIFO.
__forceinline__ __device__ void flush() { proxyChan_.flush(); }
MSCCLPP_DEVICE void flush() { proxyChan_.flush(); }

/// Wait for the proxy channel to be signaled.
__forceinline__ __device__ void wait() { proxyChan_.wait(); }

#endif // __CUDACC__
MSCCLPP_DEVICE void wait() { proxyChan_.wait(); }

ProxyChannel proxyChan_;
MemoryId dst_;
Expand Down
Loading

0 comments on commit 0e59a59

Please sign in to comment.