Optimize PyTorch Model

Hidet provides a backend for PyTorch dynamo to optimize PyTorch models. To use this backend, specify 'hidet' as the backend when calling torch.compile():

# optimize the model with hidet provided backend 'hidet'
model_hidet = torch.compile(model, backend='hidet')
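For completeness, here is a minimal end-to-end sketch; the toy model, input shapes, and variable names below are illustrative assumptions, not part of hidet itself:

import torch
import hidet  # importing hidet makes the 'hidet' backend available to torch.compile

# a toy model, just for illustration
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 10),
).cuda().eval()

x = torch.randn(8, 128).cuda()

with torch.no_grad():
    model_hidet = torch.compile(model, backend='hidet')
    y = model_hidet(x)  # compilation and kernel generation happen on the first call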

Note

Currently, all operators in hidet are generated by hidet itself, with no dependency on kernel libraries such as cuDNN or cuBLAS. In the future, we may support lowering some operators to these libraries when they perform better.

Under the hood, hidet converts the PyTorch model to hidet's graph representation and optimizes the computation graph (sub-graph rewrite, operator fusion, constant folding, etc.). After that, each operator is lowered through hidet's scheduling system to generate the final kernel.
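For reference, the same graph-level flow can be sketched directly with hidet's Python API (hidet.symbol, hidet.trace_from, and hidet.graph.optimize are taken from hidet's public API; exact signatures may differ across hidet versions):

import hidet

# build a small symbolic computation graph with two operators
a = hidet.symbol([8, 128], dtype='float32', device='cuda')
w = hidet.randn([128, 64], device='cuda')
b = hidet.ops.relu(hidet.ops.matmul(a, w))

graph = hidet.trace_from(b, inputs=[a])   # hidet's graph representation
graph_opt = hidet.graph.optimize(graph)   # sub-graph rewrite, fusion, constant folding, ...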

Hidet provides some configurations to control the hidet backend of torch dynamo.

Search in a larger search space

Some operators are compute-intensive, and their schedules are critical to performance. We usually need to search a schedule space to find the best schedule for the given input shapes, but searching a larger schedule space takes longer. By default, hidet uses the default schedule to generate the kernel for all input shapes. To search a larger schedule space for better performance, configure the search space via search_space():

# There are three search spaces:
# 0 - use the default schedule, no search [default]
# 1 - search in a small schedule space (usually 1~30 schedules)
# 2 - search in a large schedule space (usually more than 30 schedules)
hidet.torch.dynamo_config.search_space(2)

# After configuring the search space, optimize the model
model_opt = torch.compile(model, backend='hidet')

# The actual search happens on the first run, when the input shapes are known
outputs = model_opt(inputs)

Please note that the search space set through search_space() is read and used when we first run the model, not when we call torch.compile().
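You can observe this by timing the first call (which includes the schedule search) against a later call. A rough sketch, reusing model and inputs from the snippet above:

import time

model_opt = torch.compile(model, backend='hidet')

start = time.time()
outputs = model_opt(inputs)   # first run: the schedule search happens here
torch.cuda.synchronize()
print(f'first run (includes tuning): {time.time() - start:.1f} s')

start = time.time()
outputs = model_opt(inputs)   # later runs only execute the tuned kernels
torch.cuda.synchronize()
print(f'second run: {time.time() - start:.4f} s')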

Check the correctness

It is important to make sure the optimized model is correct. Hidet provides a configuration that prints the numerical difference between each hidet-generated operator and the original PyTorch operator. You can enable it via correctness_report():

# enable the correctness checking
hidet.torch.dynamo_config.correctness_report()

After enabling the correctness report, every time a new graph is received for compilation, hidet prints the numerical difference using dummy inputs (for now, torch dynamo does not expose the actual inputs to backends, so we cannot use them). Let's take the resnet18 model as an example:

import torch
import hidet

x = torch.randn(1, 3, 224, 224, dtype=torch.float16).cuda()
model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True, verbose=False)
model = model.to(torch.float16).cuda().eval()

with torch.no_grad():
    hidet.torch.dynamo_config.correctness_report()
    model_opt = torch.compile(model, backend='hidet', mode='max-autotune')
    model_opt(x)
     kind           operator                                 dtype    error    attention
---  -------------  ---------------------------------------  -------  -------  -----------
0    placeholder                                             float16  0.0e+00
1    placeholder                                             float16  0.0e+00
2    placeholder                                             float16  0.0e+00
...  (placeholder rows 3-101 omitted; all are float16 with error 0.0e+00)
102  placeholder                                             float16  0.0e+00
103  call_function  torch.nn.functional.conv2d               float16  1.4e-02
104  call_function  torch.nn.functional.batch_norm           float16  5.8e-03
105  call_function  torch.nn.functional.relu                 float16  4.7e-03
106  call_function  torch.nn.functional.max_pool2d           float16  4.4e-03
107  call_function  torch.nn.functional.conv2d               float16  1.7e-02
108  call_function  torch.nn.functional.batch_norm           float16  1.6e-02
109  call_function  torch.nn.functional.relu                 float16  1.3e-02
110  call_function  torch.nn.functional.conv2d               float16  7.3e-03
111  call_function  torch.nn.functional.batch_norm           float16  1.6e-02
112  call_function  operator.iadd                            float16  1.7e-02
113  call_function  torch.nn.functional.relu                 float16  1.6e-02
114  call_function  torch.nn.functional.conv2d               float16  1.8e-02
115  call_function  torch.nn.functional.batch_norm           float16  1.4e-02
116  call_function  torch.nn.functional.relu                 float16  1.3e-02
117  call_function  torch.nn.functional.conv2d               float16  7.4e-03
118  call_function  torch.nn.functional.batch_norm           float16  2.0e-02
119  call_function  operator.iadd                            float16  2.6e-02
120  call_function  torch.nn.functional.relu                 float16  2.6e-02
121  call_function  torch.nn.functional.conv2d               float16  2.9e-02
122  call_function  torch.nn.functional.batch_norm           float16  1.2e-02
123  call_function  torch.nn.functional.relu                 float16  1.2e-02
124  call_function  torch.nn.functional.conv2d               float16  1.2e-02
125  call_function  torch.nn.functional.batch_norm           float16  1.4e-02
126  call_function  torch.nn.functional.conv2d               float16  9.6e-03
127  call_function  torch.nn.functional.batch_norm           float16  1.2e-02
128  call_function  operator.iadd                            float16  1.7e-02
129  call_function  torch.nn.functional.relu                 float16  1.3e-02
130  call_function  torch.nn.functional.conv2d               float16  1.3e-02
131  call_function  torch.nn.functional.batch_norm           float16  1.2e-02
132  call_function  torch.nn.functional.relu                 float16  1.2e-02
133  call_function  torch.nn.functional.conv2d               float16  7.1e-03
134  call_function  torch.nn.functional.batch_norm           float16  1.6e-02
135  call_function  operator.iadd                            float16  1.9e-02
136  call_function  torch.nn.functional.relu                 float16  1.9e-02
137  call_function  torch.nn.functional.conv2d               float16  1.5e-02
138  call_function  torch.nn.functional.batch_norm           float16  1.1e-02
139  call_function  torch.nn.functional.relu                 float16  1.1e-02
140  call_function  torch.nn.functional.conv2d               float16  7.2e-03
141  call_function  torch.nn.functional.batch_norm           float16  1.1e-02
142  call_function  torch.nn.functional.conv2d               float16  4.2e-03
143  call_function  torch.nn.functional.batch_norm           float16  5.0e-03
144  call_function  operator.iadd                            float16  1.1e-02
145  call_function  torch.nn.functional.relu                 float16  1.1e-02
146  call_function  torch.nn.functional.conv2d               float16  8.9e-03
147  call_function  torch.nn.functional.batch_norm           float16  1.1e-02
148  call_function  torch.nn.functional.relu                 float16  9.1e-03
149  call_function  torch.nn.functional.conv2d               float16  4.5e-03
150  call_function  torch.nn.functional.batch_norm           float16  1.3e-02
151  call_function  operator.iadd                            float16  1.4e-02
152  call_function  torch.nn.functional.relu                 float16  1.3e-02
153  call_function  torch.nn.functional.conv2d               float16  1.3e-02
154  call_function  torch.nn.functional.batch_norm           float16  1.0e-02
155  call_function  torch.nn.functional.relu                 float16  6.6e-03
156  call_function  torch.nn.functional.conv2d               float16  3.7e-03
157  call_function  torch.nn.functional.batch_norm           float16  1.0e-02
158  call_function  torch.nn.functional.conv2d               float16  4.4e-03
159  call_function  torch.nn.functional.batch_norm           float16  8.7e-03
160  call_function  operator.iadd                            float16  1.2e-02
161  call_function  torch.nn.functional.relu                 float16  1.1e-02
162  call_function  torch.nn.functional.conv2d               float16  8.2e-03
163  call_function  torch.nn.functional.batch_norm           float16  8.6e-03
164  call_function  torch.nn.functional.relu                 float16  7.8e-03
165  call_function  torch.nn.functional.conv2d               float16  3.4e-03
166  call_function  torch.nn.functional.batch_norm           float16  4.5e-02
167  call_function  operator.iadd                            float16  4.5e-02
168  call_function  torch.nn.functional.relu                 float16  4.4e-02
169  call_function  torch.nn.functional.adaptive_avg_pool2d  float16  7.7e-03
170  call_function  torch.flatten                            float16  7.7e-03
171  call_function  torch.nn.functional.linear               float16  3.5e+00  <------
172  output                                                  float16  3.5e+00  <------

Tip

Usually, we can expect the reported error \(e_h\) (defined below) to satisfy:

  • for float32: \(e_h \leq 10^{-5}\), and

  • for float16: \(e_h \leq 10^{-2}\).

The correctness report prints, for each operator, the harmonic mean of the absolute error and the relative error:

\[e_h = \frac{|actual - expected|}{|expected| + 1} \quad \left(\frac{1}{e_h} = \frac{1}{e_a} + \frac{1}{e_r}\right)\]

where \(actual\) and \(expected\) are the actual and expected results of the operator, and \(e_a\) and \(e_r\) are the absolute error and relative error, respectively.
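For reference, a small sketch of how this error can be computed for a pair of tensors (whether the report reduces the per-element errors with a mean or a maximum is an assumption here; this sketch uses the maximum):

import torch

def harmonic_error(actual: torch.Tensor, expected: torch.Tensor) -> float:
    # e_h = |actual - expected| / (|expected| + 1), reduced to a single number
    err = (actual - expected).abs() / (expected.abs() + 1.0)
    return err.max().item()

expected = torch.randn(1000, dtype=torch.float32)
actual = expected + 1e-6 * torch.randn(1000)      # simulate a small numerical error
print(harmonic_error(actual, expected))           # well below the 1e-5 float32 threshold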

Operator configurations
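Hidet exposes additional operator-level options through hidet.torch.dynamo_config. A minimal sketch, assuming the option names use_fp16 and use_fp16_reduction (these names may differ across hidet versions):

# automatically cast the model to float16 (assumed option name)
hidet.torch.dynamo_config.use_fp16(True)

# use float16 as the accumulation data type in reductions (assumed option name)
hidet.torch.dynamo_config.use_fp16_reduction(True)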

Use CUDA Graph to dispatch kernels

Hidet provides a configuration to dispatch kernels with CUDA Graph. CUDA Graph is a CUDA feature (introduced in CUDA 10) that allows us to record kernel launches and replay them later, which greatly reduces launch overhead when the same kernels are dispatched many times. Hidet enables CUDA Graph by default. You can disable it via use_cuda_graph():

# disable CUDA Graph
hidet.torch.dynamo_config.use_cuda_graph(False)

You may want to disable it, for example, if you plan to use PyTorch's own CUDA Graph feature on the compiled model.
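In that case, a rough sketch of capturing the compiled model with PyTorch's CUDA Graph API (torch.cuda.CUDAGraph) might look as follows; model_opt and x are reused from the example above:

# warm up on a side stream, as recommended by PyTorch before graph capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    y = model_opt(x)              # also triggers hidet compilation on the first call
torch.cuda.current_stream().wait_stream(s)

# capture one forward pass into a CUDA graph
g = torch.cuda.CUDAGraph()
static_x = x.clone()
with torch.cuda.graph(g):
    static_y = model_opt(static_x)

# replay: copy new data into the captured input buffer and replay the graph
static_x.copy_(x)
g.replay()                        # static_y now holds the result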