Writing a Dynamic Kernel
Todo
More details about Hidet Script and how to write dynamic kernels are coming soon. For now, this example shows the core idea: the matrix sizes (m_size, n_size, k_size) are passed to the kernel as int32 arguments at launch time, so a single compiled kernel serves arbitrary problem sizes.
import numpy.testing
import hidet
def matmul_simt_kernel():
    from hidet.lang import attrs
    from hidet.lang import float32, int32
    from hidet.lang import as_tensor_pointer, tensor, register_tensor, shared_tensor
    from hidet.lang.cuda import threadIdx, blockIdx, syncthreads
    from hidet.lang.mapping import repeat, spatial, auto_map
    from hidet.lang.layout import row_major, local_layout
    warps_m, warps_n = 4, 2  # we use 4x2 warps
    warp_m, warp_n = 2, 2  # each warp repeats 2x2 times
    warp_map_m, warp_map_n = 2, 16  # each warp has 2x16 threads
    thread_m, thread_n = 4, 4  # each thread repeats 4x4 times
    # block_size = (64, 256, 8)
    block_m_size, block_n_size = (
        warps_m * warp_m * warp_map_m * thread_m,
        warps_n * warp_n * warp_map_n * thread_n,
    )
    block_k_size = 8
    num_warps = warps_m * warps_n  # 8
    num_threads = num_warps * 32  # 256
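    # Worked out: block_m_size = 4 * 2 * 2 * 4 = 64 and
    # block_n_size = 2 * 2 * 16 * 4 = 256, so each thread block computes a
    # 64x256 tile of C and each of the 256 threads accumulates an
    # 8x8 (= warp_m * thread_m by warp_n * thread_n) fragment of that tile.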
    with hidet.lang.script_module() as script_module:
        @hidet.lang.script
        def matmul_kernel(
            a_ptr: ~float32,  # ~ means "pointer to", similar to "*" in C
            b_ptr: ~float32,
            c_ptr: ~float32,
            m_size: int32,
            n_size: int32,
            k_size: int32,
        ):
            attrs.func_name = 'matmul_kernel'
            attrs.cuda.block_dim = num_threads
            attrs.cuda.grid_dim = (
                (m_size + block_m_size - 1) // block_m_size,
                (n_size + block_n_size - 1) // block_n_size,
            )
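            # ceiling division launches enough blocks to cover the whole m x n
            # output even when the sizes are not multiples of the 64x256 tile;
            # out-of-range elements are masked in the loads and the final store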
            a = as_tensor_pointer(a_ptr, float32, [m_size, k_size])
            b = as_tensor_pointer(b_ptr, float32, [k_size, n_size])
            c = as_tensor_pointer(c_ptr, float32, [m_size, n_size])
            smem_a = shared_tensor(float32, shape=[block_m_size, block_k_size])
            smem_b = shared_tensor(float32, shape=[block_k_size, block_n_size])
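            # staging tiles in shared memory: smem_a is 64x8 and smem_b is
            # 8x256 float32, i.e. 2 KiB + 8 KiB = 10 KiB per thread block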
            regs_c = register_tensor(
                dtype=float32,
                # shape will be inferred from the layout automatically,
                # in this case, the shape is [64, 256]
                layout=(
                    local_layout(warps_m, warps_n)
                    * row_major(warp_m, warp_n)
                    * local_layout(warp_map_m, warp_map_n)
                    * row_major(thread_m, thread_n)
                ),
            )
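            # the local_layout factors are distributed across the 256 threads
            # and contribute no per-thread storage, while the row_major factors
            # are kept in registers: each thread holds an 8x8 fragment
            # (2x2 warp repeats times 4x4 thread repeats) of the 64x256 tile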
            # initialize the registers
            mma_mapping = (
                spatial(warps_m, warps_n)
                .repeat(warp_m, warp_n)
                .spatial(warp_map_m, warp_map_n)
                .repeat(thread_m, thread_n)
            )
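            # the composed task mapping covers the 64x256 tile with 256 workers:
            # spatial(...) splits tasks across threads while repeat(...) makes
            # each thread iterate over its share; mapping.on(threadIdx.x) yields
            # the (i, j) coordinates assigned to the calling thread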
            for i, j in mma_mapping.on(threadIdx.x):
                regs_c[i, j] = 0.0
            # iterate over the k tiles
            num_k_tiles = (k_size + block_k_size - 1) // block_k_size
            for k_tile in range(num_k_tiles):
                # load smem_a [block_m_size, block_k_size] from global memory
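                # auto_map spreads the 64 * 8 = 512 elements of the a-tile over
                # 256 threads (2 elements each); out-of-bounds elements are
                # zero-filled so the mma loop below needs no bounds checks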
                for i, k in auto_map(block_m_size, block_k_size, workers=num_threads).on(
                    threadIdx.x
                ):
                    global_i, global_k = (i + blockIdx.x * block_m_size, k + k_tile * block_k_size)
                    smem_a[i, k] = (
                        a[global_i, global_k] if global_i < m_size and global_k < k_size else 0.0
                    )
                # load smem_b [block_k_size, block_n_size] from global memory
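                # likewise, the 8 * 256 = 2048 elements of the b-tile give each
                # of the 256 threads 8 elements to load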
                for k, j in auto_map(block_k_size, block_n_size, workers=num_threads).on(
                    threadIdx.x
                ):
                    global_k, global_j = (k + k_tile * block_k_size, j + blockIdx.y * block_n_size)
                    smem_b[k, j] = (
                        b[global_k, global_j] if global_k < k_size and global_j < n_size else 0.0
                    )
                # synchronize: both smem tiles must be fully loaded before use
                syncthreads()
                # simt matrix multiply accumulate (mma): regs_c = regs_c + smem_a @ smem_b
                for i, j in mma_mapping.on(threadIdx.x):
                    for k in range(block_k_size):
                        regs_c[i, j] += smem_a[i, k] * smem_b[k, j]
                # synchronize again so no thread overwrites the smem tiles for
                # the next k tile while others are still reading them
                syncthreads()
            # store regs_c back to global memory
            for i, j in mma_mapping.on(threadIdx.x):
                global_i = i + blockIdx.x * block_m_size
                global_j = j + blockIdx.y * block_n_size
                if global_i < m_size and global_j < n_size:
                    c[global_i, global_j] = regs_c[i, j]
    assert isinstance(matmul_kernel, hidet.ir.Function)  # matmul_kernel is a hidet.ir.Function
    return script_module.build()
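script_module.build() compiles the kernel and returns a callable object. Since m_size, n_size, and k_size are run-time arguments, the single compiled kernel below is launched for several different problem sizes, with the results checked against numpy.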
def main():
    from hidet.utils.benchmark import benchmark_func
    func = matmul_simt_kernel()
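    # the shapes deliberately include sizes that are not multiples of the
    # 64x256x8 tile, to exercise the boundary handling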
    for m, n, k in [(1024, 1024, 1024), (333, 444, 555), (1, 12, 13)]:
        a = hidet.randn([m, k], dtype='float32').cuda()
        b = hidet.randn([k, n], dtype='float32').cuda()
        c = hidet.zeros([m, n]).cuda()
        func(a, b, c, m, n, k)
        numpy.testing.assert_allclose(
            actual=c.cpu().numpy(), desired=a.cpu().numpy() @ b.cpu().numpy(), rtol=1e-4, atol=1e-4
        )
        hidet_latency = benchmark_func(lambda: func(a, b, c, m, n, k), repeat=50)
        print(f'{m}x{k}x{n}: hidet takes {hidet_latency:.2f} ms')
main()
1024x1024x1024: hidet takes 0.22 ms
333x555x444: hidet takes 0.11 ms
1x13x12: hidet takes 0.02 ms