pow fixup instruction? #51

Open
TellowKrinkle opened this issue Oct 11, 2023 · 7 comments

Comments

@TellowKrinkle
Contributor

Found some interesting things in the pow function

echo "kernel void test(uint pos [[thread_position_in_grid]], device float* out, const device float2* in) { out[pos] = metal::pow(in[pos].x, in[pos].y); }" | python3 compiler_explorer.py -
   0: 72091004             get_sr           r2, sr80 (thread_position_in_grid.x)
   4: 0501440e00c43200     device_load      0, i32, xy, r0_r1, u2_u3, r2, unsigned, lsl 1
   c: 3800                 wait             0
   e: 8a0d80c6             log2             r3.cache, r0.cache.abs
  12: 9a8dc6222800         fmul32           r3.cache, r3.discard, r1.cache
  18: 8a0dc6d2             exp2             r3.cache, r3.discard
  1c: 3a81c0222cc61200     fmadd32          r0, r0.discard, r1.discard, r3.discard
  24: 4501400e00c01200     device_store     0, i32, x, r0, u0_u1, r2, unsigned, 0
  2c: 8800                 stop             

The only differences from powr (which doesn't handle negative x) are the .abs on the log2 input and the fmadd32 at the end.
I don't think adding the product of the inputs would magically fix up the result for negative numbers (and I ran the function to make sure it actually does calculate pow).
The fmadd32 has bit 52 set (which is currently unused in our decoder), so maybe that's what makes it special?

I'm not currently running an OS supported by hwtestbed, so I'll leave actually testing this to someone who is
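
To make the arithmetic concrete, here is a rough Python sketch of what the disassembled sequence would compute if the final instruction were an ordinary multiply-add. The function names are made up for illustration, and Python doubles stand in for the GPU's 32-bit floats.

```python
import math

def powr_sequence(x: float, y: float) -> float:
    """Mirror the log2 / fmul32 / exp2 steps: 2 ** (y * log2(|x|))."""
    return 2.0 ** (y * math.log2(abs(x)))

def plain_fmadd(a: float, b: float, c: float) -> float:
    """What an ordinary fmadd32 would produce: a * b + c."""
    return a * b + c

x, y = -2.0, 5.0
c = powr_sequence(x, y)        # magnitude only, the sign of x is lost
naive = plain_fmadd(x, y, c)   # x * y + c, not a sign fixup
print(c, naive, x ** y)
```

For x = -2 and y = 5 the log2/exp2 sequence gives 32, a plain multiply-add on top of it gives 22, and the true power is -32, so the final instruction has to be doing something other than a * b + c.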

@willcmartin

I did some experimenting and it's clear that bit 52 modifies the fmadd32 op.

Using the fmadd32 example above:

# with bit 52 set
$ python3 hwtestbed.py 3a81c0222cc61200 -r-2.0,5.0,0.0,32.0 -b

Thread 0: r0: c2000000 (-32       ), input: -2.0,5.0,0.0,32.0
# without bit 52 set
$ python3 hwtestbed.py 3a81c0222cc60200 -r-2.0,5.0,0.0,32.0 -b

Thread 0: r0: 41b00000 (22        ), input: -2.0,5.0,0.0,32.0
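
The two encodings above differ only in bit 52. A small helper like the following (a sketch; `with_bit` is a hypothetical name, not part of hwtestbed.py) can produce one from the other, assuming the hex string lists the instruction bytes in memory order so that it reads as a little-endian integer:

```python
def with_bit(encoding_hex: str, bit: int, value: bool) -> str:
    """Return the instruction encoding with one bit forced to 0 or 1.

    The hex string is treated as little-endian bytes (an assumption that
    is consistent with the two encodings used in this comment).
    """
    raw = bytes.fromhex(encoding_hex)
    word = int.from_bytes(raw, "little")
    if value:
        word |= 1 << bit
    else:
        word &= ~(1 << bit)
    return word.to_bytes(len(raw), "little").hex()

print(with_bit("3a81c0222cc61200", 52, False))  # clear bit 52
print(with_bit("3a81c0222cc60200", 52, True))   # set bit 52
```

Bit 52 lands in the seventh byte (0x12 with the bit set, 0x02 without), which matches the two command lines.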

I played around and the logic of fmadd32 d, a, b, c with bit 52 set seems to be something like this:

(For context, in the pow sequence above a = x, b = y, and c = abs(x)^y, so d = x^y once the op has completed.)

if a == 1 or b == 0:
    d = 1
elif b == 1:
    d = a
elif a < 0 and b is not an int:
    d = 2143289344  # 0x7fc00000 (quiet NaN): no real result, so error code
elif a < 0 and b is odd:
    d = -c
else:
    # a >= 0, or a < 0 with an even integer b
    d = c

This seems pretty convenient for fixing up the pow result for negative bases. I'm unsure whether there is a simpler way to express this logic, or whether the op is used for anything other than the pow function.

I'd be happy to look into this more and update the documentation. Let me know if anyone has any suggestions.
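
For reading the raw register values quoted in this comment, a short Python helper can reinterpret a 32-bit pattern as a float. This is only a convenience sketch; `bits_to_float` is a made-up name.

```python
import math
import struct

def bits_to_float(bits: int) -> float:
    """Reinterpret a 32-bit unsigned integer as an IEEE 754 single."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# The two hwtestbed results and the decimal "error code" from the pseudocode.
for bits in (0xC2000000, 0x41B00000, 2143289344):
    print(f"{bits:08x} -> {bits_to_float(bits)}")
```

The decimal value 2143289344 is 0x7fc00000, which decodes to a quiet NaN.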

@TellowKrinkle
Contributor Author

Thanks for looking at it.
I tried testing some more values, and here are a few more edge cases:

  • a == -0 is considered less than zero for negation purposes but not for NaN-if-b-is-not-integer purposes
  • if (a == -1 && isinf(b)) { d = 1 }
  • If a is infinite and b is not an integer, d = c, not NaN

The rest of the funniness seems to come from the hardware not caring what happens if a or b is NaN but c is not. It acts all goofy in those situations, but you wouldn't encounter them when using the instruction for pow fixup.
(The instruction pays attention to the sign bit of the NaN, which is a bit unusual.)
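
Putting the earlier pseudocode and these edge cases together, the observed behavior can be approximated in Python roughly as follows. This is a sketch, not a definitive description of the hardware: `pow_fixup_model` is a hypothetical name, it works on Python doubles, it ignores denormal flushing, and it makes no attempt to reproduce the goofy NaN-input cases.

```python
import math

def pow_fixup_model(a: float, b: float, c: float) -> float:
    """Approximate the bit-52 fmadd32 result d for non-NaN inputs."""
    # Cases where pow is 1 regardless of c.
    if a == 1.0 or b == 0.0 or (a == -1.0 and math.isinf(b)):
        return 1.0
    if b == 1.0:
        return a
    # copysign looks at the sign bit, so -0.0 counts as negative here.
    if math.copysign(1.0, a) < 0.0:
        # An infinite exponent is treated like an even integer: pass c through.
        if math.isinf(b):
            return c
        if not b.is_integer():
            # -0.0 and -inf skip the NaN and pass c through instead.
            if a != 0.0 and not math.isinf(a):
                return math.nan
            return c
        if math.fmod(b, 2.0) != 0.0:
            return -c  # odd integer exponent: restore the sign
    return c

print(pow_fixup_model(-2.0, 5.0, 32.0))
print(pow_fixup_model(-0.0, 0.5, 0.0))
```

With a = -2, b = 5, c = 32 the model returns -32, matching the hwtestbed result reported earlier in the thread.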

Testing code

(After the first run, it'll generate a metallib in /tmp/. Open it up in a hex editor and modify the fmadd32 to add bit 52, then run it again.)

let shader = """
#include <metal_stdlib>
using namespace metal;

struct alignas(16) Info {
	packed_uint3 base;
	uint counter_size;
	packed_uint3 stride;
};

uint flushDenormals(uint val) {
	return val & 0x7f800000 ? val : val & 0x80000000;
}

kernel void test(uint3 pos [[thread_position_in_grid]], constant Info& info [[buffer(0)]], device atomic_uint& counter [[buffer(1)]], device uint4* output [[buffer(2)]]) {
	const uint3 in = pos * info.stride + info.base;
	float3 fin = as_type<float3>(in);
	uint resA = as_type<uint>(fma(fin.x, fin.y, fin.z));
	uint resB = flushDenormals(in.z);
	if (fin.x == 1.f || fin.y == 0.f || (fin.x == -1.f && isinf(fin.y))) {
		resB = as_type<uint>(1.f);
	} else if (fin.y == 1.f) {
		resB = flushDenormals(isnan(fin.x) ? in.z : in.x);
	} else if (in.x & 0x80000000) {
		float iy = floor(fin.y);
		if (fin.y != iy && !isnan(fin.y)) {
			if (fin.x != 0.f && !isinf(fin.x))
				resB = 0x7fc00000;
		} else if (floor(iy / 2) != iy / 2 && !isnan(fin.x) && !isnan(fin.y)) {
			resB = isnan(fin.z) ? 0x7fc00000 : flushDenormals(in.z ^ 0x80000000);
		}
	}
	if (resA != resB) {
		uint idx = atomic_fetch_add_explicit(&counter, 1, memory_order_relaxed);
		if (idx < info.counter_size) {
			output[idx] = uint4(in, resA);
		}
	}
}
"""

import Metal
import simd

typealias V3UInt = SIMD3<UInt32>
typealias V4UInt = SIMD4<UInt32>

func floatBits(_ x: Float) -> UInt32 { return x.bitPattern }

let tests: [(base: V3UInt, stride: V3UInt, size: V3UInt)] = [
	(V3UInt(0, 0, 0),                     V3UInt(65536, 65536, 0), V3UInt(65536, 65536, 1)),
	(V3UInt(0, 0, 0x80000000),            V3UInt(65536, 65536, 0), V3UInt(65536, 65536, 1)),
	(V3UInt(0, 0, 1),                     V3UInt(65536, 65536, 0), V3UInt(65536, 65536, 1)),
	(V3UInt(0, 0, 0x80000001),            V3UInt(65536, 65536, 0), V3UInt(65536, 65536, 1)),
	(V3UInt(0, 0, floatBits( 16.5)),      V3UInt(65536, 65536, 0), V3UInt(65536, 65536, 1)),
	(V3UInt(0, 0, floatBits(-16.5)),      V3UInt(65536, 65536, 0), V3UInt(65536, 65536, 1)),
	(V3UInt(0, 0, floatBits( .infinity)), V3UInt(65536, 65536, 0), V3UInt(65536, 65536, 1)),
	(V3UInt(0, 0, floatBits(-.infinity)), V3UInt(65536, 65536, 0), V3UInt(65536, 65536, 1)),
	(V3UInt(0, 0, floatBits(.nan)),       V3UInt(65536, 65536, 0), V3UInt(65536, 65536, 1)),
	(V3UInt(0, 0, 0xffffffff),            V3UInt(65536, 65536, 0), V3UInt(65536, 65536, 1)),
	(V3UInt(0, 0,                     floatBits(16)), V3UInt(16, 0, 0), V3UInt(1 << 28, 1, 1)),
	(V3UInt(0, 0x80000000,            floatBits(16)), V3UInt(16, 0, 0), V3UInt(1 << 28, 1, 1)),
	(V3UInt(0, floatBits( 1),         floatBits(16)), V3UInt(16, 0, 0), V3UInt(1 << 28, 1, 1)),
	(V3UInt(0, floatBits( 2),         floatBits(16)), V3UInt(16, 0, 0), V3UInt(1 << 28, 1, 1)),
	(V3UInt(0, floatBits(-1),         floatBits(16)), V3UInt(16, 0, 0), V3UInt(1 << 28, 1, 1)),
	(V3UInt(0, floatBits(-2),         floatBits(16)), V3UInt(16, 0, 0), V3UInt(1 << 28, 1, 1)),
	(V3UInt(0, floatBits( .infinity), floatBits(16)), V3UInt(16, 0, 0), V3UInt(1 << 28, 1, 1)),
	(V3UInt(0, floatBits(-.infinity), floatBits(16)), V3UInt(16, 0, 0), V3UInt(1 << 28, 1, 1)),
	(V3UInt(0, floatBits(.nan),       floatBits(16)), V3UInt(16, 0, 0), V3UInt(1 << 28, 1, 1)),
	(V3UInt(0, 0xffffffff,            floatBits(16)), V3UInt(16, 0, 0), V3UInt(1 << 28, 1, 1)),
	(V3UInt(0,                     0, floatBits(16)), V3UInt(0, 16, 0), V3UInt(1, 1 << 28, 1)),
	(V3UInt(0x80000000,            0, floatBits(16)), V3UInt(0, 16, 0), V3UInt(1, 1 << 28, 1)),
	(V3UInt(floatBits( 1),         0, floatBits(16)), V3UInt(0, 16, 0), V3UInt(1, 1 << 28, 1)),
	(V3UInt(floatBits( 2),         0, floatBits(16)), V3UInt(0, 16, 0), V3UInt(1, 1 << 28, 1)),
	(V3UInt(floatBits(-1),         0, floatBits(16)), V3UInt(0, 16, 0), V3UInt(1, 1 << 28, 1)),
	(V3UInt(floatBits(-2),         0, floatBits(16)), V3UInt(0, 16, 0), V3UInt(1, 1 << 28, 1)),
	(V3UInt(floatBits( .infinity), 0, floatBits(16)), V3UInt(0, 16, 0), V3UInt(1, 1 << 28, 1)),
	(V3UInt(floatBits(-.infinity), 0, floatBits(16)), V3UInt(0, 16, 0), V3UInt(1, 1 << 28, 1)),
	(V3UInt(floatBits(.nan),       0, floatBits(16)), V3UInt(0, 16, 0), V3UInt(1, 1 << 28, 1)),
	(V3UInt(0xffffffff,            0, floatBits(16)), V3UInt(0, 16, 0), V3UInt(1, 1 << 28, 1)),
]

let path = URL(fileURLWithPath: "/tmp/PowTest.metallib")

func writeShaderAndExit(lib: MTLLibrary) throws -> Never {
	let cdesc = MTLComputePipelineDescriptor()
	cdesc.computeFunction = lib.makeFunction(name: "test")!
	try? FileManager.default.removeItem(at: path)
	let archive = try lib.device.makeBinaryArchive(descriptor: .init())
	try archive.addComputePipelineFunctions(descriptor: cdesc)
	try archive.serialize(to: path)
	print("Wrote shader to \(path.path).  Please modify it for testing and rerun.")
	exit(EXIT_FAILURE)
}

if let gpu = MTLCopyAllDevices().first(where: { $0.supportsFamily(.apple1) }) {
	let options = MTLCompileOptions()
	options.fastMathEnabled = false
	let lib = try gpu.makeLibrary(source: shader, options: options)

	guard FileManager.default.fileExists(atPath: path.path) else {
		try writeShaderAndExit(lib: lib)
	}

	let pipe: MTLComputePipelineState
	do {
		let cdesc = MTLComputePipelineDescriptor()
		cdesc.computeFunction = lib.makeFunction(name: "test")!
		let archiveDesc = MTLBinaryArchiveDescriptor()
		archiveDesc.url = path
		let archive = try gpu.makeBinaryArchive(descriptor: archiveDesc)
		cdesc.binaryArchives = [archive]
		pipe = try gpu.makeComputePipelineState(descriptor: cdesc, options: .failOnBinaryArchiveMiss).0
	} catch {
		print("Failed to load shader: \(error.localizedDescription)")
		print("Generating new metallib...")
		try writeShaderAndExit(lib: lib)
	}

	let queue = gpu.makeCommandQueue()!
	let outputSize = 128
	let counterBuf = gpu.makeBuffer(length: 8)!
	let outputBuf = gpu.makeBuffer(length: outputSize * MemoryLayout<SIMD4<UInt32>>.size)!
	let optr = outputBuf.contents().bindMemory(to: SIMD4<UInt32>.self, capacity: outputSize)
	for test in tests {
		counterBuf.contents().storeBytes(of: 0, as: UInt32.self)
		let cb = queue.makeCommandBuffer()!
		let enc = cb.makeComputeCommandEncoder()!
		withUnsafeBytes(of: (V4UInt(test.base, UInt32(outputSize)), V4UInt(test.stride, 0))) { ptr in
			enc.setBytes(ptr.baseAddress!, length: ptr.count, index: 0)
		}
		enc.setBuffer(counterBuf, offset: 0, index: 1)
		enc.setBuffer(outputBuf, offset: 0, index: 2)
		enc.setComputePipelineState(pipe)
		var gridsize = MTLSizeMake(Int(test.size.x), Int(test.size.y), Int(test.size.z))
		let total = gridsize.width * gridsize.height * gridsize.depth
		var tgsize = MTLSizeMake(1, 1, 1)
		for _ in 0..<5 {
			if gridsize.width % 2 == 0 {
				tgsize.width *= 2
				gridsize.width /= 2
			} else if gridsize.height % 2 == 0 {
				tgsize.height *= 2
				gridsize.height /= 2
			} else if gridsize.depth % 2 == 0 {
				tgsize.depth *= 2
				gridsize.depth /= 2
			}
		}
		enc.dispatchThreadgroups(gridsize, threadsPerThreadgroup: tgsize)
		enc.endEncoding()
		cb.commit()
		let end = test.base &+ (test.size &- 1) &* test.stride
		print(String(format: "Dispatched %zd threads from (%08x, %08x, %08x) through (%08x, %08x, %08x) by (%d, %d, %d)...", total, test.base.x, test.base.y, test.base.z, end.x, end.y, end.z, test.stride.x, test.stride.y, test.stride.z), terminator: "")
		fflush(stdout)
		cb.waitUntilCompleted()
		let fails = Int(counterBuf.contents().load(as: UInt32.self))
		if fails == 0 {
			print(" OK")
		} else {
			print()
			for i in 0..<min(outputSize, fails) {
				let fail = optr[i]
				let ffail = unsafeBitCast(fail, to: SIMD4<Float>.self)
				print(String(format: "\t(%08x, %08x, %08x) -> %08x | (%f, %f, %f) -> %f", fail.x, fail.y, fail.z, fail.w, ffail.x, ffail.y, ffail.z, ffail.w))
			}
			break
		}
	}

} else {
	print("No Apple GPU found")
}
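
The manual hex-editor step could in principle be scripted. The sketch below is only an illustration: `set_bit_in_instructions` is a hypothetical helper, the 8-byte encoding to search for has to come from disassembling your own compiled shader (the one used here is just the encoding from earlier in the thread), and a blind byte search may also match data that is not an instruction.

```python
from typing import Tuple

def set_bit_in_instructions(blob: bytes, encoding: bytes, bit: int) -> Tuple[bytes, int]:
    """Set `bit` in every occurrence of `encoding` inside `blob`.

    Returns the patched bytes and the number of occurrences found.
    The encoding is treated as a little-endian integer.
    """
    word = int.from_bytes(encoding, "little") | (1 << bit)
    patched = word.to_bytes(len(encoding), "little")
    return blob.replace(encoding, patched), blob.count(encoding)

# Stand-in for the file contents; a real run would read /tmp/PowTest.metallib.
fmadd = bytes.fromhex("3a81c0222cc60200")
blob = b"\x00" * 4 + fmadd + b"\xff" * 4
patched, hits = set_bit_in_instructions(blob, fmadd, 52)
print(hits, patched.hex())
```

Checking that the match count is exactly what you expect before writing the file back is a cheap guard against patching the wrong bytes.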

@willcmartin

Thanks, I was able to replicate your results.

It seems a couple of things need to be done to include the fixup instruction:

  1. Add an instruction description class to applegpu.py
  2. Add a test to hwtest.py

Anything I'm missing?

@TellowKrinkle
Contributor Author

Yeah, that should be it
Would you like to add it?

@willcmartin

Sure. I'll put together a PR

@willcmartin

It seems that the pow_fixup d, a, b, c operation does not respect 16-bit register flags for register c and instead pretends the register is 32-bit.

For example:

# using 16-bit register
# pow_fixup r0, r1, r2, r3l

$ python3 hwtestbed.py 3A81424224461000 -r-0.0,2.0,5.0,32.0 -b
Thread 0: r0: 42000000 (32        ), input: -0.0,2.0,5.0,32.0

# correct output should be 0

# using 32-bit register
# pow_fixup r0, r1, r2, r3

$ python3 hwtestbed.py 3A81424224461200 -r-0.0,2.0,5.0,32.0 -b
Thread 0: r0: 42000000 (32        ), input: -0.0,2.0,5.0,32.0

As far as I can tell, registers a, b, and d all behave properly when used as 16-bit registers.

I'm not sure if this matters in practice. It looks like register c is not used as a 16-bit register when operating on halves.
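
To see why 0 is the expected result for the 16-bit case, the sketch below splits a 32-bit register value into its two 16-bit halves and reads each as a half float. It assumes that r3l is the low 16 bits of r3 and r3h the high 16 bits; `halves_of_register` is a made-up name.

```python
import struct

def halves_of_register(value: float) -> tuple:
    """Return (low, high) 16-bit halves of a float32, each read as a half float."""
    raw = struct.pack("<f", value)            # 4 bytes, little-endian
    low = struct.unpack("<e", raw[:2])[0]     # assumed to correspond to r3l
    high = struct.unpack("<e", raw[2:])[0]    # assumed to correspond to r3h
    return low, high

print(halves_of_register(32.0))
```

Since 32.0 is 0x42000000, its low half is 0x0000, so an instruction that honored the 16-bit flag on c would see 0 rather than 32.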

$ echo "kernel void test(ushort pos [[thread_position_in_grid]], device float* out, const device half2* in) { out[pos] = metal::pow(in[pos].x, in[pos].y); }" | python3 compiler_explorer.py -

compute shader:
   0: 72041004             get_sr           r1l, sr80 (thread_position_in_grid.x)
   4: 62060000             mov_imm          r1h, 0
   8: 8500240c00c43000     device_load      0, i16, xy, r0l_r0h, u2_u3, r1, signed, lsl 1
  10: 3800                 wait             0
  12: 8a0980c4             log2             r2.cache, r0l.cache.abs
  16: 9a89c4120800         fmul32           r2.cache, r2.discard, r0h.cache
  1c: 8a09c4d2             exp2             r2.cache, r2.discard
  20: ba80c0100cc41200     pow_fixup        r0l.cache, r0l.discard, r0h.discard, r2.discard
  28: 2a81c0000002         fadd32           r0, r0l.discard, -0.0
  2e: 4501200c00c01200     device_store     0, i32, x, r0, u0_u1, r1, signed, 0
  36: 8800                 stop             

Any advice on how to proceed? I could simply not include test cases with register c as 16-bit, or I could add some logic in applegpu.py to correct for flags that set register c to 16-bit.

@TellowKrinkle
Contributor Author

I think either option would be fine. If you choose not to handle it, include a comment on the instruction definition mentioning it.
