Quick Start
This guide walks through the key functionality of Hidet for tensor computation.
Optimize PyTorch model with Hidet
This section requires PyTorch 2.3 or later.
The easiest way to use Hidet is to call the torch.compile() function with hidet as the backend, for example:
model_opt = torch.compile(model, backend='hidet')
Next, we use the resnet18 model as an example to show how to optimize a PyTorch model with Hidet.
Because TF32 is enabled by default for PyTorch's cuDNN backend, PyTorch's float32 convolution results are slightly less precise. You can disable TF32 if you need full float32 precision (see also PyTorch TF32).
import hidet
import torch
# take resnet18 as an example
x = torch.randn(1, 3, 224, 224, dtype=torch.float16).cuda()
model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True, verbose=False)
model = model.cuda().eval().to(torch.float16)
# optimize the model with 'hidet' backend
model_opt = torch.compile(model, backend='hidet', mode='max-autotune')
# run the optimized model
y1 = model_opt(x)
y2 = model(x)
# check the correctness
torch.testing.assert_close(actual=y1, expected=y2, rtol=2e-2, atol=2e-2)
# benchmark the performance
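for name, m in [('eager', model), ('hidet', model_opt)]:
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()  # make sure earlier work does not leak into the measurement
    start_event.record()
    for _ in range(100):
        y = m(x)
    end_event.record()
    torch.cuda.synchronize()  # wait for all kernels to finish before reading the timer
    print('{:>10}: {:.3f} ms'.format(name, start_event.elapsed_time(end_event) / 100.0))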
eager: 0.781 ms
hidet: 0.163 ms
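The benchmark loop above reports the average latency over 100 calls. The same warm-up-then-average pattern can be sketched on the CPU with only the standard library. `benchmark_ms` is a hypothetical helper and is not part of hidet or PyTorch. When timing GPU work, the clock reads must also be bracketed by `torch.cuda.synchronize()` (or CUDA events), because CUDA kernel launches return before the kernels finish.

```python
import time
from typing import Callable


def benchmark_ms(fn: Callable[[], object], warmup: int = 10, iters: int = 100) -> float:
    """Average wall-clock latency of fn() in milliseconds (hypothetical helper)."""
    # Warm-up calls absorb one-time costs such as lazy compilation and cold caches.
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    elapsed_s = time.perf_counter() - start
    return elapsed_s * 1000.0 / iters


if __name__ == '__main__':
    def workload() -> int:
        return sum(i * i for i in range(10_000))

    print('{:>10}: {:.3f} ms'.format('python', benchmark_ms(workload)))
```

Excluding the warm-up calls matters most for compiled backends. The first call through torch.compile() triggers compilation, which would otherwise dominate the average.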
One operator can have multiple equivalent implementations (i.e., kernel programs) with different performance. We usually need to try different implementations for each concrete input shape to find the best one. This process is called kernel tuning. Kernel tuning is controlled by the search space of the hidet backend, which can be set to one of the following levels:
# 0 - no tuning, the default kernel will be used
# 1 - tuning in a small search space
# 2 - tuning in a large search space; takes longer but usually achieves better performance
When kernel tuning is enabled, hidet can achieve the following performance on an NVIDIA RTX 4090:
eager: 1.176 ms
hidet: 0.286 ms
Hidet provides several configurations to control the optimizations of the hidet backend, such as:
Search Space: you can choose the search space of operator kernel tuning. A larger search space usually achieves better performance but takes longer to optimize.
Correctness Checking: print a correctness checking report that shows the numerical difference between the hidet-generated operators and the original PyTorch operators.
Other Configurations: you can also configure other optimizations of the hidet backend, such as automatically using a lower-precision data type (e.g., float16), or controlling how the reduction dimension of matrix multiplication and convolution operators is parallelized.
See also
You can learn more about configuring hidet as a torch dynamo backend in the Optimize PyTorch Model tutorial.
The remaining sections introduce the key components of Hidet.
Define tensors
Besides randn(), we can also use zeros(), ones(), and full() to create tensors with different initial values. We can use from_torch() to convert a PyTorch tensor to a Hidet tensor that shares the same memory, and asarray() to convert a Python list or NumPy ndarray to a Hidet tensor.
A tensor is an n-dimensional array. As in other machine learning frameworks, Hidet takes Tensor as the core object to compute and manipulate.
The following code defines a randomly initialized tensor with hidet.randn():
a = hidet.randn([2, 3], device='cuda')
Tensor(shape=(2, 3), dtype='float32', device='cuda:0')
[[-0.89 0.02 -0.7 ]
[ 0.31 -0.16 -0.97]]
Each Tensor has a dtype that defines the type of each element, a device that tells which device the tensor resides on, and a shape that indicates the size of each dimension. The example defines a float32 tensor on the cuda device with shape [2, 3].
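These three attributes are enough to derive a tensor's size and memory layout. The sketch below uses a hypothetical `TensorMeta` record, not hidet's `Tensor` class. It assumes a dense row-major layout and covers only a small illustrative set of dtypes. For the `[2, 3]` float32 tensor above it gives 6 elements, 24 bytes, and strides `(3, 1)`.

```python
from dataclasses import dataclass
from math import prod
from typing import Tuple

# Bytes per element for a few common dtypes (illustrative subset only).
DTYPE_BYTES = {'float16': 2, 'float32': 4, 'int32': 4, 'int64': 8}


@dataclass(frozen=True)
class TensorMeta:
    """Hypothetical record of the metadata every tensor carries."""
    shape: Tuple[int, ...]
    dtype: str
    device: str

    @property
    def numel(self) -> int:
        # Total number of elements: product of all dimension sizes.
        return prod(self.shape)

    @property
    def nbytes(self) -> int:
        return self.numel * DTYPE_BYTES[self.dtype]

    def strides(self) -> Tuple[int, ...]:
        """Row-major (C-order) strides, measured in elements."""
        strides = []
        step = 1
        for dim in reversed(self.shape):
            strides.append(step)
            step *= dim
        return tuple(reversed(strides))


meta = TensorMeta((2, 3), 'float32', 'cuda')
print(meta.numel, meta.nbytes, meta.strides())  # 6 24 (3, 1)
```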
Run operators
Hidet provides a set of operators (e.g., matmul()) to compute and manipulate tensors. We can do a matrix multiplication as follows:
b = hidet.randn([3, 2], device='cuda')
c = hidet.randn([2], device='cuda')
d = hidet.ops.matmul(a, b)
d = d + c # 'd + c' is equivalent to 'hidet.ops.add(d, c)'
Tensor(shape=(2, 2), dtype='float32', device='cuda:0')
[[ 2.13 0.16]
[-0.08 2.13]]
In this example, each operator is executed on the device at the time we call it, so this is an imperative style of execution. Imperative execution is intuitive and easy to debug, but it gives up some graph-level optimization opportunities and suffers from higher kernel dispatch latency.
The next section introduces another way to execute operators.
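The shapes in the matrix multiplication example can be made concrete with plain Python lists. A `[2, 3]` by `[3, 2]` product gives a `[2, 2]` result. The `[2]`-shaped bias is then added to every row, which is what `d + c` does in the hidet example, assuming standard NumPy-style broadcasting. The values are fixed rather than random so the result can be checked by hand. The helper names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive (n, k) x (k, m) -> (n, m) matrix multiplication."""
    k = len(b)
    assert all(len(row) == k for row in a), 'inner dimensions must match'
    m = len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(m)] for row in a]


def add_row_bias(d: Matrix, bias: List[float]) -> Matrix:
    """Broadcast a length-m vector across every row of an (n, m) matrix."""
    assert all(len(row) == len(bias) for row in d), 'bias length must match columns'
    return [[v + w for v, w in zip(row, bias)] for row in d]


lhs = [[1.0, 2.0, 3.0],
       [4.0, 5.0, 6.0]]     # shape [2, 3]
rhs = [[1.0, 0.0],
       [0.0, 1.0],
       [1.0, 1.0]]          # shape [3, 2]
bias = [10.0, 20.0]         # shape [2]
out = add_row_bias(matmul(lhs, rhs), bias)  # shape [2, 2]
print(out)  # [[14.0, 25.0], [20.0, 31.0]]
```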
Symbolic tensor and flow graph
In hidet, each tensor has an optional storage attribute that represents the block of memory storing the tensor's contents. If the storage attribute is None, the tensor is a symbolic tensor.
We can use hidet.symbol_like() or hidet.symbol() to create a symbolic tensor. An operator returns symbolic tensors if any of its input tensors is symbolic. We can find out how a symbolic tensor is computed via its trace attribute. It is a tuple (op, idx), where op is the operator that produces the tensor and idx is the index of the tensor among the operator's outputs.
def linear_bias(x, b, c):
    return hidet.ops.matmul(x, b) + c
x = hidet.symbol_like(a)
y = linear_bias(x, b, c)
assert x.trace is None
assert y.trace is not None
print('x:', x)
print('y:', y)
x: Tensor(shape=(2, 3), dtype='float32', device='cuda:0')
y: Tensor(shape=(2, 2), dtype='float32', device='cuda:0')
from (<hidet.graph.ops.arithmetic.AddOp object at 0x7589399a7730>, 0)
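The rules above can be sketched in a few lines of plain Python. A tensor without storage is symbolic. An operator returns a symbolic output, carrying a trace, whenever any input is symbolic. `ToyTensor`, `ToyOp` and `toy_add` are hypothetical names for illustration; hidet's real `Tensor` and operator classes are more involved.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class ToyOp:
    """Hypothetical operator record: its kind and the tensors it consumed."""
    kind: str
    inputs: list


@dataclass
class ToyTensor:
    """Hypothetical tensor: storage=None marks it as symbolic."""
    shape: Tuple[int, ...]
    storage: Optional[List[float]] = None
    # (producing operator, index among its outputs); None for leaf tensors.
    trace: Optional[Tuple[ToyOp, int]] = field(default=None, repr=False)

    @property
    def is_symbolic(self) -> bool:
        return self.storage is None


def toy_add(x: ToyTensor, y: ToyTensor) -> ToyTensor:
    """Element-wise add: computes eagerly, or records a trace if any input is symbolic."""
    if x.is_symbolic or y.is_symbolic:
        return ToyTensor(x.shape, storage=None, trace=(ToyOp('add', [x, y]), 0))
    return ToyTensor(x.shape, storage=[p + q for p, q in zip(x.storage, y.storage)])


sym_in = ToyTensor((2,))                      # symbolic: no storage
weight = ToyTensor((2,), storage=[1.0, 2.0])  # concrete
sym_out = toy_add(sym_in, weight)
print(sym_in.trace is None, sym_out.trace is not None)  # True True
print(toy_add(weight, weight).storage)                  # [2.0, 4.0]
```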
We can use the trace attribute to construct the computation graph, starting from the symbolic output tensor(s). This is what the function hidet.trace_from() does. In hidet, we use hidet.graph.FlowGraph to represent the data flow graph (a.k.a. computation graph).
graph: hidet.FlowGraph = hidet.trace_from(y)
Graph(x: float32[2, 3][cuda]){
  c = Constant(float32[3, 2][cuda])
  c_1 = Constant(float32[2][cuda])
  x_1: float32[2, 2][cuda] = Matmul(x, c, require_prologue=False, transpose_b=False)
  x_2: float32[2, 2][cuda] = Add(x_1, c_1)
  return x_2
}
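Tracing a graph back from its output can be understood as a backward walk over trace links. The sketch below reproduces that idea on hypothetical `SymTensor` and `SymOp` classes. Starting from the output, it records each producing operator after the operators that produce its inputs, so the collected operators come out in a valid execution (topological) order. Leaves without a producer become graph inputs unless they are constants. This is an illustration of the idea, not hidet's implementation of `hidet.trace_from()`.

```python
from typing import List, Optional, Set, Tuple


class SymOp:
    """Hypothetical operator node: a kind plus the tensors it consumes."""
    def __init__(self, kind: str, inputs: List['SymTensor']):
        self.kind = kind
        self.inputs = inputs


class SymTensor:
    """Hypothetical tensor; `trace` is (producing op, output index) or None."""
    def __init__(self, name: str, trace: Optional[Tuple[SymOp, int]] = None,
                 constant: bool = False):
        self.name = name
        self.trace = trace
        self.constant = constant  # leaves holding data become constants, not inputs


def apply_op(kind: str, inputs: List[SymTensor], out_name: str) -> SymTensor:
    return SymTensor(out_name, trace=(SymOp(kind, inputs), 0))


def toy_trace_from(output: SymTensor) -> Tuple[List[SymTensor], List[SymOp]]:
    """Return (graph inputs, operators in topological order) reachable from output."""
    graph_inputs: List[SymTensor] = []
    ops: List[SymOp] = []
    seen_tensors: Set[int] = set()
    seen_ops: Set[int] = set()

    def visit(t: SymTensor) -> None:
        if id(t) in seen_tensors:
            return
        seen_tensors.add(id(t))
        if t.trace is None:           # leaf: a graph input or a constant
            if not t.constant:
                graph_inputs.append(t)
            return
        op = t.trace[0]
        if id(op) in seen_ops:
            return
        seen_ops.add(id(op))
        for inp in op.inputs:         # post-order: producers are recorded first
            visit(inp)
        ops.append(op)

    visit(output)
    return graph_inputs, ops


# Mirrors linear_bias(x, b, c) = matmul(x, b) + c, using the printed graph's names.
sx = SymTensor('x')
sb = SymTensor('c', constant=True)
sc = SymTensor('c_1', constant=True)
s_out = apply_op('add', [apply_op('matmul', [sx, sb], 'x_1'), sc], 'x_2')
leaf_inputs, topo_ops = toy_trace_from(s_out)
print([t.name for t in leaf_inputs], [op.kind for op in topo_ops])  # ['x'] ['matmul', 'add']
```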
Optimize flow graph
We may configure optimizations with PassContext. Potential configs:
Whether to use tensor cores.
Whether to use low-precision data types (e.g., float16).
A flow graph is the basic unit of graph-level optimizations in hidet. We can optimize a flow graph with hidet.graph.optimize(). This function applies the predefined passes to optimize the given flow graph.
In this example, the matrix multiplication and the element-wise addition are fused into a single operator.
opt_graph: hidet.FlowGraph = hidet.graph.optimize(graph)
Graph(x: float32[2, 3][cuda]){
  c = Constant(float32[2][cuda])
  c_1 = Constant(float32[3, 2][cuda])
  x_1: float32[2, 2][cuda] = FusedCudaBatchMatmul(c, x, c_1, fused_graph=FlowGraph(Broadcast, Broadcast, CudaBatchMatmul, Reshape, Add), anchor=2)
  return x_1
}
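Operator fusion of this kind is a graph rewrite. The toy pass below works on a hypothetical list-of-`Node` graph representation rather than hidet's `FlowGraph`. It merges a matmul into the element-wise operator that is its only consumer, provided the intermediate tensor is not also a graph output. Hidet's own pass, as the printed graph shows, keeps the fused sub-graph (`Broadcast, Broadcast, CudaBatchMatmul, Reshape, Add`) inside a single `FusedCudaBatchMatmul` operator.

```python
from dataclasses import dataclass
from typing import Dict, List, Set

# Element-wise operator kinds that may be merged into a preceding matmul.
ELEMENTWISE = {'add', 'mul', 'relu'}


@dataclass
class Node:
    kind: str           # operator kind, e.g. 'matmul' or 'add'
    inputs: List[str]   # names of the tensors it consumes
    output: str         # name of the tensor it produces


def fuse_matmul_epilogue(nodes: List[Node], graph_outputs: Set[str]) -> List[Node]:
    """Merge each matmul into the element-wise op that is its only consumer.

    `nodes` must be topologically ordered; the result stays ordered because the
    fused node takes the consumer's position.
    """
    consumers: Dict[str, List[Node]] = {}
    for n in nodes:
        for name in n.inputs:
            consumers.setdefault(name, []).append(n)

    replacement: Dict[int, Node] = {}  # id(consumer) -> fused node
    absorbed: Set[int] = set()         # ids of matmul nodes folded into a fused node
    for n in nodes:
        users = consumers.get(n.output, [])
        if (n.kind == 'matmul' and n.output not in graph_outputs
                and len(users) == 1 and users[0].kind in ELEMENTWISE):
            user = users[0]
            extra = [i for i in user.inputs if i != n.output]
            replacement[id(user)] = Node('fused_matmul_' + user.kind, n.inputs + extra, user.output)
            absorbed.add(id(n))

    return [replacement.get(id(n), n) for n in nodes if id(n) not in absorbed]


toy_graph = [
    Node('matmul', ['x', 'c'], 'x_1'),
    Node('add', ['x_1', 'c_1'], 'x_2'),
]
print(fuse_matmul_epilogue(toy_graph, graph_outputs={'x_2'}))
# [Node(kind='fused_matmul_add', inputs=['x', 'c', 'c_1'], output='x_2')]
```

After fusion, the intermediate `x_1` no longer has to be written to and read back from memory as a separate tensor. This is typically where the speedup of such epilogue fusion comes from.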
Run flow graph
We can directly call the flow graph to run it:
y1 = opt_graph(a)
Tensor(shape=(2, 2), dtype='float32', device='cuda:0')
[[ 2.13 0.16]
[-0.08 2.13]]
For CUDA devices, a more efficient way is to create a CUDA graph that dispatches the kernels of the flow graph to the NVIDIA GPU.
cuda_graph = opt_graph.cuda_graph()
outputs = cuda_graph.run([a])
y2 = outputs[0]
Tensor(shape=(2, 2), dtype='float32', device='cuda:0')
[[ 2.13 0.16]
[-0.08 2.13]]
In this quick start guide, we walked through several important functionalities of hidet:
Define tensors.
Run operators imperatively.
Use symbolic tensors to create a computation graph (i.e., a flow graph).
Optimize and run a flow graph.
Next Step
It is time to learn how to use hidet in your own project. Good starting points are the Optimize PyTorch Model and Optimize ONNX Model with Hidet tutorials.
Total running time of the script: (2 minutes 1.228 seconds)