Skip to content

Commit

Permalink
variable and type names
Browse files Browse the repository at this point in the history
  • Loading branch information
chenfucn committed Feb 23, 2024
1 parent 18bf463 commit 7d5d5ca
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ class QuantBMmaBase {
typename Operator::IteratorB warp_tile_iterator_B_;

/// Iterator to load a warp-scoped tile of quant scales from shared memory
typename Operator::IteratorQScale warp_tile_iterator_QScale_;
typename Operator::IteratorQMeta warp_tile_iterator_QScale_;

public:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ class QuantBMetaMmaTile{
// 16b gemm (kNumBsPerCoreTileFragement == 2)
// 2 B operand tiles per mma (kBTilesPerMma == 2)
// (1,n) quantization blocking
// The weight and offset tensor is prepacked to reduce load instructions.
// The scale and offset tensors are prepacked to reduce the number of load instructions.
return make_Coord((lane_id % CoreTile::kContiguous) * 4,
lane_id / CoreTile::kContiguous);
} else {
Expand Down Expand Up @@ -356,7 +356,7 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,
// 16b gemm (kNumBsPerCoreTileFragement == 2)
// 2 B operand tiles per mma (kBTilesPerMma == 2)
// (1,n) quantization blocking
// The weight and offset tensor is prepacked to reduce load instructions.
// The scale and offset tensors are prepacked to reduce the number of load instructions needed.
const int row = lane_position_.row();
const int column = lane_position_.column() / BlockingShape::kColumn;

Expand Down Expand Up @@ -550,14 +550,6 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,
return *this;
}

/// NOTE(review): despite being operator--, this moves the lane position
/// *forward* along the K dimension (it adds TileShapeB::kRow, the same
/// direction an advance would take). Confirm the decrement semantics are
/// intended, or remove this operator if it is unused.
CUTLASS_HOST_DEVICE
QuantBMetaMmaTensorOpTileIterator &operator--() {
// This is for operand B, so advance on the K dimension
lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0);
return *this;
}

CUTLASS_DEVICE
QuantBMetaMmaTensorOpTileIterator &add_tile_offset(
TensorCoord const &tile_offset) {
Expand Down Expand Up @@ -761,14 +753,6 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,
return *this;
}

/// NOTE(review): despite being operator--, this moves the lane position
/// *forward* along the K dimension (it adds TileShapeB::kRow, the same
/// direction an advance would take). Confirm the decrement semantics are
/// intended, or remove this operator if it is unused.
CUTLASS_HOST_DEVICE
QuantBMetaMmaTensorOpTileIterator &operator--() {
// This is for operand B, so advance on the K dimension
lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0);
return *this;
}

CUTLASS_DEVICE
QuantBMetaMmaTensorOpTileIterator &add_tile_offset(
TensorCoord const &tile_offset) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,81 +68,6 @@ namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace internal {

/// Copies `source` while swapping the two middle elements of every aligned
/// group of four: indices (0,1,2,3) map to (0,2,1,3); bits above the low two
/// (i & ~3) pass through unchanged. Presumably this reorders an mma.sync
/// accumulator fragment before narrowing to 16-bit -- TODO confirm against
/// the tensor-core fragment layout.
///
/// Hoisted into one helper because the bfloat16_t and half_t specializations
/// below previously duplicated this loop verbatim.
template <int N>
CUTLASS_HOST_DEVICE
Array<float, N> swap_inner_pair(Array<float, N> const &source) {
  Array<float, N> tmp;

  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < N; ++i) {
    // Low two bits of i are exchanged (0->0, 1->2, 2->1, 3->3); the
    // remaining high bits are kept via (i & 0xfffffffc).
    int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc));
    tmp[i] = source[idx];
  }

  return tmp;
}

/// Generic element-wise array conversion from S to T using the library
/// numeric converter (no reordering).
template <typename T, typename S, int N, FloatRoundStyle Round>
struct ConvertAndPack {

  using Converter = NumericArrayConverter<T, S, N, Round>;

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<S, N> const &source) {
    Converter converter;

    return converter(source);
  }
};

/// No-op specialization: source already has the destination element type.
template <typename T, int N, FloatRoundStyle Round>
struct ConvertAndPack<T, T, N, Round> {

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &source) {
    return source;
  }
};

/// float -> bfloat16_t: reorder the fragment, then narrow.
template <int N, FloatRoundStyle Round>
struct ConvertAndPack<bfloat16_t, float, N, Round> {

  using Converter = NumericArrayConverter<bfloat16_t, float, N, Round>;

  CUTLASS_HOST_DEVICE
  Array<bfloat16_t, N> operator()(Array<float, N> const &source) {
    Converter converter;

    return converter(swap_inner_pair<N>(source));
  }
};

/// float -> half_t: reorder the fragment, then narrow.
template <int N, FloatRoundStyle Round>
struct ConvertAndPack<half_t, float, N, Round> {

  using Converter = NumericArrayConverter<half_t, float, N, Round>;

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<float, N> const &source) {
    Converter converter;

    return converter(swap_inner_pair<N>(source));
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace internal

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
Expand Down Expand Up @@ -292,13 +217,13 @@ class QuantBMmaTensorOp {

// TODO This is an expanding iterator, it needs to replicate the quantization parameters
// to all threads in the warp.
using IteratorQScale = QuantBMetaMmaTensorOpTileIterator<
using IteratorQMeta = QuantBMetaMmaTensorOpTileIterator<
MatrixShape<Shape::kK, Shape::kN>, QuantBlocking, ElementQScale, SmemLayoutQScale,
ElementQOffset, SmemLayoutQOffset,
ArchMmaOperator, kThreadCount, kPartitionsK>;

using FragmentQScale = typename IteratorQScale::FragmentScale;
using FragmentQOffset = typename IteratorQScale::FragmentOffset;
using FragmentQScale = typename IteratorQMeta::FragmentScale;
using FragmentQOffset = typename IteratorQMeta::FragmentOffset;

/// Number of mma operations performed
using MmaIterations = MatrixShape<
Expand Down Expand Up @@ -419,7 +344,7 @@ class QuantBMmaTensorOp {

Array<uint8_t, FragmentB::kElements * 2> const *ptr_B =
reinterpret_cast<Array<uint8_t, FragmentB::kElements * 2> const *>(&B);
IteratorQScale::dequant(scales, offsets, *ptr_B, dst_B);
IteratorQMeta::dequant(scales, offsets, *ptr_B, dst_B);
}
};

Expand Down

0 comments on commit 7d5d5ca

Please sign in to comment.