Writing Dynamic Kernels

Todo

More details about Hidet Script and how to write dynamic kernels are coming soon.

import numpy.testing
import hidet


def matmul_simt_kernel():
    """Build a CUDA SIMT matmul kernel whose problem sizes are runtime arguments.

    The generated kernel computes ``c = a @ b`` on float32 data. The three pointer
    arguments are viewed as ``a: [m_size, k_size]``, ``b: [k_size, n_size]`` and
    ``c: [m_size, n_size]``. ``m_size``, ``n_size`` and ``k_size`` are int32 kernel
    parameters, not compile-time constants, so one compiled kernel serves every
    shape (this is what makes it "dynamic").

    Each thread block computes one 64x256 tile of ``c`` with 256 threads (8 warps).
    It walks the k dimension in steps of 8.

    Returns:
        The result of ``script_module.build()``. ``main()`` calls it as
        ``func(a, b, c, m, n, k)``.
    """
    from hidet.lang import attrs
    from hidet.lang import float32, int32
    from hidet.lang import as_tensor_pointer, tensor, register_tensor, shared_tensor
    from hidet.lang.cuda import threadIdx, blockIdx, syncthreads
    from hidet.lang.mapping import repeat, spatial, auto_map
    from hidet.lang.layout import row_major, local_layout

    # Hierarchical decomposition of a block's output tile (all compile-time constants).
    warps_m, warps_n = 4, 2  # we use 4x2 warps
    warp_m, warp_n = 2, 2  # each warp repeats 2x2 times
    warp_map_m, warp_map_n = 2, 16  # each warp has 2x16 threads
    thread_m, thread_n = 4, 4  # each thread repeats 4x4 times

    # block_size = (64, 256, 8)
    # 4*2*2*4 = 64 rows and 2*2*16*4 = 256 columns of c per thread block.
    block_m_size, block_n_size = (
        warps_m * warp_m * warp_map_m * thread_m,
        warps_n * warp_n * warp_map_n * thread_n,
    )
    block_k_size = 8
    num_warps = warps_m * warps_n  # 8
    num_threads = num_warps * 32  # 256

    with hidet.lang.script_module() as script_module:

        @hidet.lang.script
        def matmul_kernel(
            a_ptr: ~float32,  # ~ means "pointer to", similar to "*" in C
            b_ptr: ~float32,
            c_ptr: ~float32,
            m_size: int32,
            n_size: int32,
            k_size: int32,
        ):
            # Launch configuration: one block per 64x256 tile of c.
            # The grid size is derived from the runtime sizes by ceiling division.
            attrs.func_name = 'matmul_kernel'
            attrs.cuda.block_dim = num_threads
            attrs.cuda.grid_dim = (
                (m_size + block_m_size - 1) // block_m_size,
                (n_size + block_n_size - 1) // block_n_size,
            )

            # View the raw pointers as 2-D tensors whose shapes are only known at run time.
            a = as_tensor_pointer(a_ptr, float32, [m_size, k_size])
            b = as_tensor_pointer(b_ptr, float32, [k_size, n_size])
            c = as_tensor_pointer(c_ptr, float32, [m_size, n_size])

            # Shared-memory staging buffers for the current k tile of a and b.
            smem_a = shared_tensor(float32, shape=[block_m_size, block_k_size])
            smem_b = shared_tensor(float32, shape=[block_k_size, block_n_size])
            # Accumulator for the block's output tile. The local_layout factors
            # presumably distribute it across threads, so each thread would hold
            # 2*2*4*4 = 64 values (NOTE(review): confirm against hidet layout docs).
            regs_c = register_tensor(
                dtype=float32,
                # shape will be inferred from the layout automatically,
                # in this case, the shape is [64, 256]
                layout=(
                    local_layout(warps_m, warps_n)
                    * row_major(warp_m, warp_n)
                    * local_layout(warp_map_m, warp_map_n)
                    * row_major(thread_m, thread_n)
                ),
            )

            # initialize the registers
            # The same mapping is reused for init, accumulate and store. Each thread
            # therefore always visits the same (i, j) entries of regs_c.
            mma_mapping = (
                spatial(warps_m, warps_n)
                .repeat(warp_m, warp_n)
                .spatial(warp_map_m, warp_map_n)
                .repeat(thread_m, thread_n)
            )
            for i, j in mma_mapping.on(threadIdx.x):
                regs_c[i, j] = 0.0

            # iterate over the k tiles
            num_k_tiles = (k_size + block_k_size - 1) // block_k_size
            for k_tile in range(num_k_tiles):
                # load smem_a [block_m_size, block_k_size] from global memory
                # Elements past the edge of a are zero-filled so they add nothing to the sum.
                for i, k in auto_map(block_m_size, block_k_size, workers=num_threads).on(
                    threadIdx.x
                ):
                    global_i, global_k = (i + blockIdx.x * block_m_size, k + k_tile * block_k_size)
                    smem_a[i, k] = (
                        a[global_i, global_k] if global_i < m_size and global_k < k_size else 0.0
                    )

                # load smem_b [block_k_size, block_n_size] from global memory
                # Elements past the edge of b are zero-filled as well.
                for k, j in auto_map(block_k_size, block_n_size, workers=num_threads).on(
                    threadIdx.x
                ):
                    global_k, global_j = (k + k_tile * block_k_size, j + blockIdx.y * block_n_size)
                    smem_b[k, j] = (
                        b[global_k, global_j] if global_k < k_size and global_j < n_size else 0.0
                    )

                # synchronize all threads in the block
                # (both tiles must be fully written before any thread reads them)
                syncthreads()

                # simt matrix multiply accumulate (mma): regs_c = regs_c + smem_a @ smem_b
                for i, j in mma_mapping.on(threadIdx.x):
                    for k in range(block_k_size):
                        regs_c[i, j] += smem_a[i, k] * smem_b[k, j]

                # synchronize all threads in the block
                # (every thread must finish reading the tiles before the next
                # iteration overwrites smem_a / smem_b)
                syncthreads()

            # store regs_c back to global memory
            # Edge tiles may extend past m_size / n_size; those elements are skipped.
            for i, j in mma_mapping.on(threadIdx.x):
                global_i = i + blockIdx.x * block_m_size
                global_j = j + blockIdx.y * block_n_size
                if global_i < m_size and global_j < n_size:
                    c[global_i, global_j] = regs_c[i, j]

    assert isinstance(matmul_kernel, hidet.ir.Function)  # matmul_kernel is a hidet.ir.Function

    return script_module.build()


def main():
    """Build the dynamic matmul kernel, verify it against NumPy, and benchmark it.

    One compiled kernel is run on several ``(m, n, k)`` shapes. The sizes are runtime
    kernel arguments, so no recompilation happens between shapes. For each shape:

    - the result is checked against ``a @ b`` computed by NumPy;
    - the latency reported by ``hidet.utils.benchmark_func`` is printed in ms.

    Requires a CUDA device, because every tensor is moved there with ``.cuda()``.
    """
    func = matmul_simt_kernel()

    for m, n, k in [(1024, 1024, 1024), (333, 444, 555), (1, 12, 13)]:
        a = hidet.randn([m, k], dtype='float32').cuda()
        b = hidet.randn([k, n], dtype='float32').cuda()
        # Explicit dtype, consistent with a/b and the float32 kernel signature.
        c = hidet.zeros([m, n], dtype='float32').cuda()
        func(a, b, c, m, n, k)
        numpy.testing.assert_allclose(
            actual=c.cpu().numpy(), desired=a.cpu().numpy() @ b.cpu().numpy(), rtol=1e-4, atol=1e-4
        )

        # The lambda is invoked synchronously by benchmark_func, so capturing the
        # loop variables here is safe.
        hidet_latency = hidet.utils.benchmark_func(lambda: func(a, b, c, m, n, k), repeat=50)
        print(f'{m}x{k}x{n}: hidet takes {hidet_latency:.2f} ms')


# Guarded so that importing this module does not compile or launch GPU kernels.
# Sphinx-Gallery runs examples with __name__ == '__main__', so rendering is unaffected.
if __name__ == '__main__':
    main()
1024x1024x1024: hidet takes 0.13 ms
333x555x444: hidet takes 0.07 ms
1x13x12: hidet takes 0.01 ms

Total running time of the script: (0 minutes 0.801 seconds)

Gallery generated by Sphinx-Gallery