Note
Go to the end to download the full example code
Kernel Functions
Besides the `public` function kind, hidet script supports other function kinds. Currently, the following function kinds are supported:
- `public`: a public function. The public functions in a script module are exposed to the outside and can be invoked from outside (in our case, we can call them from Python).
- `cpu_kernel`: a kernel function on CPU.
- `cpu_internal`: an internal function on CPU.
- `cuda_kernel`: a kernel function on CUDA.
- `cuda_internal`: an internal function on CUDA.
Tip: The `cuda_kernel` and `cuda_internal` function kinds correspond to `__global__` and `__device__` functions in CUDA, respectively.
Usually, we use the `cpu_kernel` and `cuda_kernel` function kinds to define kernel functions. The `cpu_internal` and `cuda_internal` function kinds define internal functions that are only used by the kernel functions.
When a script module contains exactly one kernel function and no function named `launch`, hidet generates a default `launch` function that launches the kernel function.
CPU kernel function
import hidet
from hidet.lang import attrs
from hidet.lang.types import f32
hidet.option.cache_dir('./outs/cache')

with hidet.script_module() as script_module:

    @hidet.script
    def matmul(a: f32[16, 16], b: f32[16, 16], c: f32[16, 16]):
        # specify the function kind as 'cpu_kernel'
        attrs.func_kind = 'cpu_kernel'
        # naive triple loop: c[i, j] = sum over k of a[i, k] * b[k, j]
        for i in range(16):
            for j in range(16):
                c[i, j] = 0.0
                for k in range(16):
                    c[i, j] += a[i, k] * b[k, j]

# The module holds exactly one kernel and no `launch` function, so hidet
# generates a default launch function; calling the module invokes it.
module = script_module.build()

a = hidet.randn([16, 16])
b = hidet.randn([16, 16])
c = hidet.empty([16, 16])
module(a, b, c)

# compare the result with torch.matmul, just like the CUDA example below does,
# so the example actually verifies the kernel it builds
hidet.utils.assert_close(c, a.torch() @ b.torch(), atol=1e-4, rtol=1e-4)
We can check the generated source code to see that the `launch` function is generated automatically (it appears as `hidet_launch` in the generated code).
# Dump the C source that hidet generated for the CPU script module.
print(module.source(), end='\n')
#include <stdint.h>
#include <math.h>
#include <hidet/runtime/symbols.h>
#include <hidet/runtime/memory_planner.h>
#include <hidet/runtime/cpu/context.h>
#include <hidet/runtime/cpu/float32.h>
#include <hidet/runtime/logging.h>
static void hidet_matmul(float * __restrict__ a, float * __restrict__ b, float * __restrict__ c) {
for (int32_t i = 0; (i < 16); i = (i + 1)) {
for (int32_t j = 0; (j < 16); j = (j + 1)) {
c[((i * 16) + j)] = 0.0f;
for (int32_t k = 0; (k < 16); k = (k + 1)) {
c[((i * 16) + j)] = (c[((i * 16) + j)] + (a[((i * 16) + k)] * b[((k * 16) + j)]));
}
}
}
}
DLL void hidet_launch(float * __restrict__ a, float * __restrict__ b, float * __restrict__ c) {
hidet_matmul(a, b, c);
}
CUDA kernel function
We can also define a kernel function on CUDA, as the following example shows. CUDA primitive variables (such as `blockIdx`, `threadIdx`, and `blockDim`) and functions are available in the `hidet.lang.cuda` module.
from hidet.lang.cuda import blockIdx, threadIdx, blockDim
# workload size
m_size = 1024
n_size = 1024
k_size = 1024

with hidet.script_module() as script_module:

    @hidet.script
    def matmul(a: f32[m_size, k_size], b: f32[k_size, n_size], c: f32[m_size, n_size]):
        # specify the function kind as 'cuda_kernel'
        attrs.func_kind = 'cuda_kernel'
        # specify the grid dimension and block dimension:
        # ceil(m_size / 16) x ceil(n_size / 16) blocks of 16 x 16 threads,
        # i.e. (at least) one thread per element of c
        attrs.cuda.grid_dim = (m_size + 15) // 16, (n_size + 15) // 16
        attrs.cuda.block_dim = 16, 16
        # the coordinate of the c matrix that this thread is responsible for:
        # the x dimension selects the row (i), the y dimension the column (j)
        # NOTE(review): consecutive threadIdx.x values map to consecutive rows,
        # so neighbouring threads access c and a with a stride of a full row
        # (see the generated indexing `i * 1024 + ...`). Mapping x to j instead
        # would make the accesses to c and b contiguous (coalesced), but it
        # would also change the generated source shown on this page - confirm
        # before changing.
        i = blockIdx.x * blockDim.x + threadIdx.x
        j = blockIdx.y * blockDim.y + threadIdx.y
        # guard for sizes that are not multiples of 16; hidet drops this check
        # when it can prove it always holds (as it does for 1024)
        if i < m_size and j < n_size:
            c[i, j] = 0.0
            # accumulates directly into c, i.e. one load and one store of
            # c[i, j] per iteration in the generated code; a local accumulator
            # would avoid that (performance only, result is the same)
            for k in range(k_size):
                c[i, j] += a[i, k] * b[k, j]

module = script_module.build()
a = hidet.randn([m_size, k_size], device='cuda')
b = hidet.randn([k_size, n_size], device='cuda')
c = hidet.empty([m_size, n_size], device='cuda')
# single kernel and no `launch` function, so the module calls the
# auto-generated launch function with the grid/block dims set above
module(a, b, c)
# compare the result with torch.matmul
hidet.utils.assert_close(c, a.torch() @ b.torch(), atol=1e-4, rtol=1e-4)
We can check the generated source code:
Tip: Notice that the generated kernel has no boundary check. m_size and n_size (1024) are multiples of the block size (16), so the grid covers exactly 1024 x 1024 threads. Hidet infers the value range of each index variable, finds that the if-condition is always true, and removes the if-statement.
# Dump the CUDA C++ source that hidet generated for the CUDA script module.
print(module.source(), end='\n')
#include <stdint.h>
#include <hidet/runtime/symbols.h>
#include <hidet/runtime/memory_planner.h>
#include <hidet/runtime/cpu/context.h>
#include <hidet/runtime/cuda/complex.h>
#include <hidet/runtime/cuda/context.h>
#include <hidet/runtime/logging.h>
static __global__ void __launch_bounds__(256) hidet_matmul(float * __restrict__ a, float * __restrict__ b, float * __restrict__ c) {
int32_t i = (((int)blockIdx.x * 16) + (int)threadIdx.x);
int32_t j = (((int)blockIdx.y * 16) + (int)threadIdx.y);
c[((i * 1024) + j)] = 0.0f;
for (int32_t k = 0; (k < 1024); k = (k + 1)) {
c[((i * 1024) + j)] = (c[((i * 1024) + j)] + (a[((i * 1024) + k)] * b[((k * 1024) + j)]));
}
}
DLL void hidet_launch(float * __restrict__ a, float * __restrict__ b, float * __restrict__ c) {
hidet_matmul<<<dim3(64, 64, 1), dim3(16, 16, 1), 0, (cudaStream_t)get_cuda_stream()>>>(a, b, c);
{cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) LOG(ERROR) << "CUDA error: " << cudaGetErrorString(err) << "\n";}
}
Total running time of the script: (0 minutes 0.037 seconds)