Optimize PyTorch Model
Hidet provides a backend for PyTorch Dynamo that optimizes PyTorch models. To use it, specify 'hidet' as the backend when calling torch.compile(), for example:
# optimize the model with hidet provided backend 'hidet'
model_hidet = torch.compile(model, backend='hidet')
Note
Currently, all operators in hidet are generated by hidet itself; there is no dependency on kernel libraries such as cuDNN or cuBLAS. In the future, we may support lowering some operators to these libraries when they perform better.
Under the hood, hidet converts the PyTorch model to hidet's graph representation and optimizes the computation graph (for example, sub-graph rewriting and fusion, and constant folding). Each operator is then lowered to hidet's scheduling system to generate the final kernel.
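To make the constant-folding step more concrete, here is a toy sketch on a hypothetical mini IR. The `Node` class, the op names, and `fold_constants` are illustrative assumptions, not hidet's graph API. The idea is the same, though: any node whose inputs are all compile-time constants is evaluated once and replaced by a constant.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional

@dataclass
class Node:
    op: str                                   # "const", "input", "add", or "mul" (toy op set)
    inputs: List["Node"] = field(default_factory=list)
    value: Optional[float] = None             # only set for "const" nodes

OPS = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b}

def fold_constants(node: Node, memo: Optional[Dict[int, Node]] = None) -> Node:
    """Return an equivalent graph in which every all-constant subexpression is precomputed."""
    if memo is None:
        memo = {}
    if id(node) in memo:
        return memo[id(node)]
    if node.op in ("const", "input"):
        result = node
    else:
        new_inputs = [fold_constants(x, memo) for x in node.inputs]
        if all(x.op == "const" for x in new_inputs):
            # Every operand is known at compile time, so evaluate it now.
            result = Node("const", value=OPS[node.op](*(x.value for x in new_inputs)))
        else:
            result = Node(node.op, new_inputs)
    memo[id(node)] = result
    return result

def evaluate(node: Node, x: float) -> float:
    """Interpret the toy graph for a given value of the single input."""
    if node.op == "const":
        return node.value
    if node.op == "input":
        return x
    return OPS[node.op](*(evaluate(i, x) for i in node.inputs))

# y = x * (2 + 3): the (2 + 3) subexpression folds into a single constant 5.0.
x = Node("input")
y = Node("mul", [x, Node("add", [Node("const", value=2.0), Node("const", value=3.0)])])
folded = fold_constants(y)
print(folded.op, folded.inputs[1].op, folded.inputs[1].value)
```

The folded graph computes the same result with one fewer runtime operation; a real compiler applies the same rewrite to tensor operators (for example, precomputing transformed weights).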
Hidet provides several configuration options to control its torch dynamo backend.
Search in a larger search space
Some operators are compute-intensive, and their scheduling is critical to performance. To achieve the best performance on the given input shapes, we usually need to search a schedule space for the best schedule of each such operator.
However, a larger schedule space takes longer to search, which increases the time needed to optimize the model. By default, hidet uses each operator's default schedule to generate the kernel for all input shapes. To search a larger schedule space for better performance, configure the search space via search_space():
# There are three search spaces:
# 0 - use default schedule, no search [Default]
# 1 - search in a small schedule space (usually 1~30 schedules)
# 2 - search in a large schedule space (usually more than 30 schedules)
hidet.torch.dynamo_config.search_space(2)
# After configuring the search space, you can optimize the model
model_opt = torch.compile(model, backend='hidet')
# The search actually happens on the first run, when the input shapes become known
outputs = model_opt(inputs)
Note that the search space set through search_space() is read and used when the model is first run, not when torch.compile() is called.
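The pure-Python sketch below illustrates both points above: the configuration is read only on the first call, and the schedule search runs once per input shape, with the fastest candidate cached. Everything in it is hypothetical and chosen for illustration: `CONFIG`, `CANDIDATES`, and `LazyTuned` stand in for hidet's real configuration and schedules, and the "kernel" is a blocked sum whose block size plays the role of a schedule parameter.

```python
import time
from typing import Dict, List, Tuple

# Hypothetical global config mirroring the search-space levels 0 / 1 / 2.
CONFIG = {"search_space": 0}

# Hypothetical schedule candidates per level: block sizes for a blocked reduction.
CANDIDATES = {0: [64], 1: [16, 64, 256], 2: [4, 16, 64, 256, 1024, 4096]}

def blocked_sum(xs: List[float], block: int) -> float:
    """Toy 'kernel': sum the list block by block; the block size acts as the schedule."""
    return sum(sum(xs[i:i + block]) for i in range(0, len(xs), block))

class LazyTuned:
    def __init__(self) -> None:
        # CONFIG is deliberately NOT read here, just as compiling does not trigger the search.
        self.best: Dict[Tuple[int, ...], int] = {}

    def _tune(self, xs: List[float]) -> int:
        candidates = CANDIDATES[CONFIG["search_space"]]  # config read at first run
        timings = []
        for block in candidates:
            start = time.perf_counter()
            blocked_sum(xs, block)
            timings.append((time.perf_counter() - start, block))
        return min(timings)[1]  # the fastest candidate wins

    def __call__(self, xs: List[float]) -> float:
        shape = (len(xs),)
        if shape not in self.best:  # first run for this input shape: search now
            self.best[shape] = self._tune(xs)
        return blocked_sum(xs, self.best[shape])

model_opt = LazyTuned()      # analogous to torch.compile(): no search yet
CONFIG["search_space"] = 2   # set after "compiling" but before the first run: still takes effect
print(model_opt([1.0] * 10_000))
print(model_opt.best)
```

Caching by shape is what makes the first call slow and later calls with the same shape fast; a new input shape would trigger a fresh search.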
Check the correctness
It is important to make sure the optimized model is correct. Hidet provides a configuration that prints the numerical difference between each hidet-generated operator and the original PyTorch operator. You can enable it via correctness_report():
# enable the correctness checking
hidet.torch.dynamo_config.correctness_report()
Once the correctness report is enabled, each time hidet receives a new graph to compile, it prints the numerical difference computed on dummy inputs (for now, torch dynamo does not expose the actual inputs to backends, so the actual inputs cannot be used). Let's take the ResNet18 model as an example:
import torch
import hidet

x = torch.randn(1, 3, 224, 224, dtype=torch.float16).cuda()
model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True, verbose=False)
model = model.to(torch.float16).cuda().eval()
with torch.no_grad():
    hidet.torch.dynamo_config.correctness_report()
    model_opt = torch.compile(model, backend='hidet', mode='max-autotune')
    model_opt(x)
kind operator dtype error attention
--- ------------- --------------------------------------- ------- ------- -----------
0 placeholder float16 0.0e+00
1 placeholder float16 0.0e+00
2 placeholder float16 0.0e+00
3 placeholder float16 0.0e+00
4 placeholder float16 0.0e+00
5 placeholder float16 0.0e+00
6 placeholder float16 0.0e+00
7 placeholder float16 0.0e+00
8 placeholder float16 0.0e+00
9 placeholder float16 0.0e+00
10 placeholder float16 0.0e+00
11 placeholder float16 0.0e+00
12 placeholder float16 0.0e+00
13 placeholder float16 0.0e+00
14 placeholder float16 0.0e+00
15 placeholder float16 0.0e+00
16 placeholder float16 0.0e+00
17 placeholder float16 0.0e+00
18 placeholder float16 0.0e+00
19 placeholder float16 0.0e+00
20 placeholder float16 0.0e+00
21 placeholder float16 0.0e+00
22 placeholder float16 0.0e+00
23 placeholder float16 0.0e+00
24 placeholder float16 0.0e+00
25 placeholder float16 0.0e+00
26 placeholder float16 0.0e+00
27 placeholder float16 0.0e+00
28 placeholder float16 0.0e+00
29 placeholder float16 0.0e+00
30 placeholder float16 0.0e+00
31 placeholder float16 0.0e+00
32 placeholder float16 0.0e+00
33 placeholder float16 0.0e+00
34 placeholder float16 0.0e+00
35 placeholder float16 0.0e+00
36 placeholder float16 0.0e+00
37 placeholder float16 0.0e+00
38 placeholder float16 0.0e+00
39 placeholder float16 0.0e+00
40 placeholder float16 0.0e+00
41 placeholder float16 0.0e+00
42 placeholder float16 0.0e+00
43 placeholder float16 0.0e+00
44 placeholder float16 0.0e+00
45 placeholder float16 0.0e+00
46 placeholder float16 0.0e+00
47 placeholder float16 0.0e+00
48 placeholder float16 0.0e+00
49 placeholder float16 0.0e+00
50 placeholder float16 0.0e+00
51 placeholder float16 0.0e+00
52 placeholder float16 0.0e+00
53 placeholder float16 0.0e+00
54 placeholder float16 0.0e+00
55 placeholder float16 0.0e+00
56 placeholder float16 0.0e+00
57 placeholder float16 0.0e+00
58 placeholder float16 0.0e+00
59 placeholder float16 0.0e+00
60 placeholder float16 0.0e+00
61 placeholder float16 0.0e+00
62 placeholder float16 0.0e+00
63 placeholder float16 0.0e+00
64 placeholder float16 0.0e+00
65 placeholder float16 0.0e+00
66 placeholder float16 0.0e+00
67 placeholder float16 0.0e+00
68 placeholder float16 0.0e+00
69 placeholder float16 0.0e+00
70 placeholder float16 0.0e+00
71 placeholder float16 0.0e+00
72 placeholder float16 0.0e+00
73 placeholder float16 0.0e+00
74 placeholder float16 0.0e+00
75 placeholder float16 0.0e+00
76 placeholder float16 0.0e+00
77 placeholder float16 0.0e+00
78 placeholder float16 0.0e+00
79 placeholder float16 0.0e+00
80 placeholder float16 0.0e+00
81 placeholder float16 0.0e+00
82 placeholder float16 0.0e+00
83 placeholder float16 0.0e+00
84 placeholder float16 0.0e+00
85 placeholder float16 0.0e+00
86 placeholder float16 0.0e+00
87 placeholder float16 0.0e+00
88 placeholder float16 0.0e+00
89 placeholder float16 0.0e+00
90 placeholder float16 0.0e+00
91 placeholder float16 0.0e+00
92 placeholder float16 0.0e+00
93 placeholder float16 0.0e+00
94 placeholder float16 0.0e+00
95 placeholder float16 0.0e+00
96 placeholder float16 0.0e+00
97 placeholder float16 0.0e+00
98 placeholder float16 0.0e+00
99 placeholder float16 0.0e+00
100 placeholder float16 0.0e+00
101 placeholder float16 0.0e+00
102 placeholder float16 0.0e+00
103 call_function torch.nn.functional.conv2d float16 1.4e-02
104 call_function torch.nn.functional.batch_norm float16 5.8e-03
105 call_function torch.nn.functional.relu float16 4.7e-03
106 call_function torch.nn.functional.max_pool2d float16 4.4e-03
107 call_function torch.nn.functional.conv2d float16 1.7e-02
108 call_function torch.nn.functional.batch_norm float16 1.6e-02
109 call_function torch.nn.functional.relu float16 1.3e-02
110 call_function torch.nn.functional.conv2d float16 7.3e-03
111 call_function torch.nn.functional.batch_norm float16 1.6e-02
112 call_function operator.iadd float16 1.7e-02
113 call_function torch.nn.functional.relu float16 1.6e-02
114 call_function torch.nn.functional.conv2d float16 1.8e-02
115 call_function torch.nn.functional.batch_norm float16 1.4e-02
116 call_function torch.nn.functional.relu float16 1.3e-02
117 call_function torch.nn.functional.conv2d float16 7.4e-03
118 call_function torch.nn.functional.batch_norm float16 2.0e-02
119 call_function operator.iadd float16 2.6e-02
120 call_function torch.nn.functional.relu float16 2.6e-02
121 call_function torch.nn.functional.conv2d float16 2.9e-02
122 call_function torch.nn.functional.batch_norm float16 1.2e-02
123 call_function torch.nn.functional.relu float16 1.2e-02
124 call_function torch.nn.functional.conv2d float16 1.2e-02
125 call_function torch.nn.functional.batch_norm float16 1.4e-02
126 call_function torch.nn.functional.conv2d float16 9.6e-03
127 call_function torch.nn.functional.batch_norm float16 1.2e-02
128 call_function operator.iadd float16 1.7e-02
129 call_function torch.nn.functional.relu float16 1.3e-02
130 call_function torch.nn.functional.conv2d float16 1.3e-02
131 call_function torch.nn.functional.batch_norm float16 1.2e-02
132 call_function torch.nn.functional.relu float16 1.2e-02
133 call_function torch.nn.functional.conv2d float16 7.1e-03
134 call_function torch.nn.functional.batch_norm float16 1.6e-02
135 call_function operator.iadd float16 1.9e-02
136 call_function torch.nn.functional.relu float16 1.9e-02
137 call_function torch.nn.functional.conv2d float16 1.5e-02
138 call_function torch.nn.functional.batch_norm float16 1.1e-02
139 call_function torch.nn.functional.relu float16 1.1e-02
140 call_function torch.nn.functional.conv2d float16 7.2e-03
141 call_function torch.nn.functional.batch_norm float16 1.1e-02
142 call_function torch.nn.functional.conv2d float16 4.2e-03
143 call_function torch.nn.functional.batch_norm float16 5.0e-03
144 call_function operator.iadd float16 1.1e-02
145 call_function torch.nn.functional.relu float16 1.1e-02
146 call_function torch.nn.functional.conv2d float16 8.9e-03
147 call_function torch.nn.functional.batch_norm float16 1.1e-02
148 call_function torch.nn.functional.relu float16 9.1e-03
149 call_function torch.nn.functional.conv2d float16 4.5e-03
150 call_function torch.nn.functional.batch_norm float16 1.3e-02
151 call_function operator.iadd float16 1.4e-02
152 call_function torch.nn.functional.relu float16 1.3e-02
153 call_function torch.nn.functional.conv2d float16 1.3e-02
154 call_function torch.nn.functional.batch_norm float16 1.0e-02
155 call_function torch.nn.functional.relu float16 6.6e-03
156 call_function torch.nn.functional.conv2d float16 3.7e-03
157 call_function torch.nn.functional.batch_norm float16 1.0e-02
158 call_function torch.nn.functional.conv2d float16 4.4e-03
159 call_function torch.nn.functional.batch_norm float16 8.7e-03
160 call_function operator.iadd float16 1.2e-02
161 call_function torch.nn.functional.relu float16 1.1e-02
162 call_function torch.nn.functional.conv2d float16 8.2e-03
163 call_function torch.nn.functional.batch_norm float16 8.6e-03
164 call_function torch.nn.functional.relu float16 7.8e-03
165 call_function torch.nn.functional.conv2d float16 3.4e-03
166 call_function torch.nn.functional.batch_norm float16 4.5e-02
167 call_function operator.iadd float16 4.5e-02
168 call_function torch.nn.functional.relu float16 4.4e-02
169 call_function torch.nn.functional.adaptive_avg_pool2d float16 7.7e-03
170 call_function torch.flatten float16 7.7e-03
171 call_function torch.nn.functional.linear float16 3.5e+00 <------
172 output float16 3.5e+00 <------
Tip
Usually, we can expect:
for float32: \(e_h \leq 10^{-5}\), and
for float16: \(e_h \leq 10^{-2}\).
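For each operator, the correctness report prints an error \(e_h\) that combines the absolute error \(e_a\) and the relative error \(e_r\) in the manner of a harmonic mean:
\[ e_a = |actual - expected|, \quad e_r = \frac{|actual - expected|}{|expected|}, \quad \frac{1}{e_h} = \frac{1}{e_a} + \frac{1}{e_r} \quad\Longrightarrow\quad e_h = \frac{|actual - expected|}{|expected| + 1} \]
where \(actual\) and \(expected\) are the actual and expected results of the operator, respectively. Because \(e_h\) is dominated by whichever of \(e_a\) and \(e_r\) is smaller, it behaves like an absolute error for small outputs and like a relative error for large outputs.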
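As a rough illustration, the sketch below computes \(e_h\) on random dummy inputs for a toy operator, with the candidate result rounded to float16 (through Python's `struct` half-precision format `'e'`) to mimic a half-precision kernel. The toy operator and the function names are hypothetical, and reducing the element-wise error with `max` is an assumption made for this sketch, not a statement of how hidet aggregates it.

```python
import random
import struct
from typing import List

def to_float16(v: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", v))[0]

def harmonic_error(actual: List[float], expected: List[float]) -> float:
    """Max over elements of |a - e| / (|e| + 1), which equals 1 / (1/e_a + 1/e_r)."""
    return max(abs(a - e) / (abs(e) + 1.0) for a, e in zip(actual, expected))

# Rule-of-thumb bounds taken from the tip above.
TOLERANCE = {"float32": 1e-5, "float16": 1e-2}

def reference_op(xs: List[float]) -> List[float]:
    return [max(v, 0.0) * 2.0 + 1.0 for v in xs]  # toy op: 2 * relu(x) + 1

def candidate_op(xs: List[float]) -> List[float]:
    # Same math, but the input and every intermediate result are rounded to float16.
    return [to_float16(to_float16(max(to_float16(v), 0.0) * 2.0) + 1.0) for v in xs]

random.seed(0)
dummy = [random.uniform(-4.0, 4.0) for _ in range(1000)]  # only the shape matters
err = harmonic_error(candidate_op(dummy), reference_op(dummy))
print(f"e_h = {err:.1e}", "ok" if err <= TOLERANCE["float16"] else "<------")
```

For a single element with \(actual = 1.5\) and \(expected = 1.0\), \(e_a = e_r = 0.5\), so \(1/e_h = 4\) and \(e_h = 0.25\), which matches \(0.5 / (1.0 + 1)\).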
Operator configurations
Use CUDA Graph to dispatch kernels
Hidet provides a configuration to use CUDA Graph to dispatch kernels. CUDA Graphs, introduced in CUDA 10, allow a sequence of kernel launches to be recorded once and replayed later, which reduces CPU launch overhead when the same kernels are dispatched many times. Hidet enables CUDA Graph by default. If you want to use PyTorch's own CUDA Graph feature instead, you can disable it via use_cuda_graph():
# disable CUDA Graph
hidet.torch.dynamo_config.use_cuda_graph(False)
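The record-once, replay-many idea can be illustrated on the CPU with plain Python. This is only an analogy: the `Recorder` class and the list-based "kernels" are hypothetical, and no CUDA API is involved. It shows the two properties that matter in practice: the launch sequence is captured once and replayed without re-running the dispatch logic, and replay reuses the exact buffers captured at record time, so new input data must be copied into those buffers rather than passed as new objects.

```python
from typing import Callable, List, Tuple

class Recorder:
    """Toy stand-in for graph capture: stores (kernel, args) pairs instead of running them."""
    def __init__(self) -> None:
        self.launches: List[Tuple[Callable[..., None], tuple]] = []

    def launch(self, kernel: Callable[..., None], *args) -> None:
        self.launches.append((kernel, args))  # recorded, not executed

    def replay(self) -> None:
        # Re-issue the captured launches with the *same* argument objects (static buffers).
        for kernel, args in self.launches:
            kernel(*args)

def scale(src: List[float], dst: List[float], k: float) -> None:
    for i, v in enumerate(src):
        dst[i] = v * k

def add_one(buf: List[float]) -> None:
    for i in range(len(buf)):
        buf[i] += 1.0

# Static buffers, allocated once before capture.
inp, out = [0.0] * 4, [0.0] * 4

graph = Recorder()
graph.launch(scale, inp, out, 2.0)
graph.launch(add_one, out)

for step in range(3):
    inp[:] = [float(step)] * 4  # copy new data into the captured input buffer in place
    graph.replay()
    print(out)
```

Rebinding `inp` to a new list instead of copying into it would leave the recorded launches reading the old buffer, which mirrors why captured CUDA graphs require inputs to live at fixed addresses.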
Print the input graph
If you are interested in the graph that PyTorch dynamo dispatches to the hidet backend, you can configure hidet to print it via print_input_graph():
# print the input graph
hidet.torch.dynamo_config.print_input_graph(True)
Because ResNet18 is a simple model without control flow, we can print its input graph to see how PyTorch dynamo dispatches the model to the hidet backend:
with torch.no_grad():
    hidet.torch.dynamo_config.print_input_graph(True)
    model_opt = torch.compile(model, backend='hidet', mode='max-autotune')
    model_opt(x)
class GraphModule(torch.nn.Module):
def forward(self, L_self_modules_conv1_parameters_weight_: "f16[64, 3, 7, 7]", L_x_: "f16[1, 3, 224, 224]", L_self_modules_bn1_buffers_running_mean_: "f16[64]", L_self_modules_bn1_buffers_running_var_: "f16[64]", L_self_modules_bn1_parameters_weight_: "f16[64]", L_self_modules_bn1_parameters_bias_: "f16[64]", L_self_modules_layer1_modules_0_modules_conv1_parameters_weight_: "f16[64, 64, 3, 3]", L_self_modules_layer1_modules_0_modules_bn1_buffers_running_mean_: "f16[64]", L_self_modules_layer1_modules_0_modules_bn1_buffers_running_var_: "f16[64]", L_self_modules_layer1_modules_0_modules_bn1_parameters_weight_: "f16[64]", L_self_modules_layer1_modules_0_modules_bn1_parameters_bias_: "f16[64]", L_self_modules_layer1_modules_0_modules_conv2_parameters_weight_: "f16[64, 64, 3, 3]", L_self_modules_layer1_modules_0_modules_bn2_buffers_running_mean_: "f16[64]", L_self_modules_layer1_modules_0_modules_bn2_buffers_running_var_: "f16[64]", L_self_modules_layer1_modules_0_modules_bn2_parameters_weight_: "f16[64]", L_self_modules_layer1_modules_0_modules_bn2_parameters_bias_: "f16[64]", L_self_modules_layer1_modules_1_modules_conv1_parameters_weight_: "f16[64, 64, 3, 3]", L_self_modules_layer1_modules_1_modules_bn1_buffers_running_mean_: "f16[64]", L_self_modules_layer1_modules_1_modules_bn1_buffers_running_var_: "f16[64]", L_self_modules_layer1_modules_1_modules_bn1_parameters_weight_: "f16[64]", L_self_modules_layer1_modules_1_modules_bn1_parameters_bias_: "f16[64]", L_self_modules_layer1_modules_1_modules_conv2_parameters_weight_: "f16[64, 64, 3, 3]", L_self_modules_layer1_modules_1_modules_bn2_buffers_running_mean_: "f16[64]", L_self_modules_layer1_modules_1_modules_bn2_buffers_running_var_: "f16[64]", L_self_modules_layer1_modules_1_modules_bn2_parameters_weight_: "f16[64]", L_self_modules_layer1_modules_1_modules_bn2_parameters_bias_: "f16[64]", L_self_modules_layer2_modules_0_modules_conv1_parameters_weight_: "f16[128, 64, 3, 3]", 
L_self_modules_layer2_modules_0_modules_bn1_buffers_running_mean_: "f16[128]", L_self_modules_layer2_modules_0_modules_bn1_buffers_running_var_: "f16[128]", L_self_modules_layer2_modules_0_modules_bn1_parameters_weight_: "f16[128]", L_self_modules_layer2_modules_0_modules_bn1_parameters_bias_: "f16[128]", L_self_modules_layer2_modules_0_modules_conv2_parameters_weight_: "f16[128, 128, 3, 3]", L_self_modules_layer2_modules_0_modules_bn2_buffers_running_mean_: "f16[128]", L_self_modules_layer2_modules_0_modules_bn2_buffers_running_var_: "f16[128]", L_self_modules_layer2_modules_0_modules_bn2_parameters_weight_: "f16[128]", L_self_modules_layer2_modules_0_modules_bn2_parameters_bias_: "f16[128]", L_self_modules_layer2_modules_0_modules_downsample_modules_0_parameters_weight_: "f16[128, 64, 1, 1]", L_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_mean_: "f16[128]", L_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_var_: "f16[128]", L_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_weight_: "f16[128]", L_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_bias_: "f16[128]", L_self_modules_layer2_modules_1_modules_conv1_parameters_weight_: "f16[128, 128, 3, 3]", L_self_modules_layer2_modules_1_modules_bn1_buffers_running_mean_: "f16[128]", L_self_modules_layer2_modules_1_modules_bn1_buffers_running_var_: "f16[128]", L_self_modules_layer2_modules_1_modules_bn1_parameters_weight_: "f16[128]", L_self_modules_layer2_modules_1_modules_bn1_parameters_bias_: "f16[128]", L_self_modules_layer2_modules_1_modules_conv2_parameters_weight_: "f16[128, 128, 3, 3]", L_self_modules_layer2_modules_1_modules_bn2_buffers_running_mean_: "f16[128]", L_self_modules_layer2_modules_1_modules_bn2_buffers_running_var_: "f16[128]", L_self_modules_layer2_modules_1_modules_bn2_parameters_weight_: "f16[128]", L_self_modules_layer2_modules_1_modules_bn2_parameters_bias_: "f16[128]", 
L_self_modules_layer3_modules_0_modules_conv1_parameters_weight_: "f16[256, 128, 3, 3]", L_self_modules_layer3_modules_0_modules_bn1_buffers_running_mean_: "f16[256]", L_self_modules_layer3_modules_0_modules_bn1_buffers_running_var_: "f16[256]", L_self_modules_layer3_modules_0_modules_bn1_parameters_weight_: "f16[256]", L_self_modules_layer3_modules_0_modules_bn1_parameters_bias_: "f16[256]", L_self_modules_layer3_modules_0_modules_conv2_parameters_weight_: "f16[256, 256, 3, 3]", L_self_modules_layer3_modules_0_modules_bn2_buffers_running_mean_: "f16[256]", L_self_modules_layer3_modules_0_modules_bn2_buffers_running_var_: "f16[256]", L_self_modules_layer3_modules_0_modules_bn2_parameters_weight_: "f16[256]", L_self_modules_layer3_modules_0_modules_bn2_parameters_bias_: "f16[256]", L_self_modules_layer3_modules_0_modules_downsample_modules_0_parameters_weight_: "f16[256, 128, 1, 1]", L_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_mean_: "f16[256]", L_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_var_: "f16[256]", L_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_weight_: "f16[256]", L_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_bias_: "f16[256]", L_self_modules_layer3_modules_1_modules_conv1_parameters_weight_: "f16[256, 256, 3, 3]", L_self_modules_layer3_modules_1_modules_bn1_buffers_running_mean_: "f16[256]", L_self_modules_layer3_modules_1_modules_bn1_buffers_running_var_: "f16[256]", L_self_modules_layer3_modules_1_modules_bn1_parameters_weight_: "f16[256]", L_self_modules_layer3_modules_1_modules_bn1_parameters_bias_: "f16[256]", L_self_modules_layer3_modules_1_modules_conv2_parameters_weight_: "f16[256, 256, 3, 3]", L_self_modules_layer3_modules_1_modules_bn2_buffers_running_mean_: "f16[256]", L_self_modules_layer3_modules_1_modules_bn2_buffers_running_var_: "f16[256]", L_self_modules_layer3_modules_1_modules_bn2_parameters_weight_: "f16[256]", 
L_self_modules_layer3_modules_1_modules_bn2_parameters_bias_: "f16[256]", L_self_modules_layer4_modules_0_modules_conv1_parameters_weight_: "f16[512, 256, 3, 3]", L_self_modules_layer4_modules_0_modules_bn1_buffers_running_mean_: "f16[512]", L_self_modules_layer4_modules_0_modules_bn1_buffers_running_var_: "f16[512]", L_self_modules_layer4_modules_0_modules_bn1_parameters_weight_: "f16[512]", L_self_modules_layer4_modules_0_modules_bn1_parameters_bias_: "f16[512]", L_self_modules_layer4_modules_0_modules_conv2_parameters_weight_: "f16[512, 512, 3, 3]", L_self_modules_layer4_modules_0_modules_bn2_buffers_running_mean_: "f16[512]", L_self_modules_layer4_modules_0_modules_bn2_buffers_running_var_: "f16[512]", L_self_modules_layer4_modules_0_modules_bn2_parameters_weight_: "f16[512]", L_self_modules_layer4_modules_0_modules_bn2_parameters_bias_: "f16[512]", L_self_modules_layer4_modules_0_modules_downsample_modules_0_parameters_weight_: "f16[512, 256, 1, 1]", L_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_mean_: "f16[512]", L_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_var_: "f16[512]", L_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_weight_: "f16[512]", L_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_bias_: "f16[512]", L_self_modules_layer4_modules_1_modules_conv1_parameters_weight_: "f16[512, 512, 3, 3]", L_self_modules_layer4_modules_1_modules_bn1_buffers_running_mean_: "f16[512]", L_self_modules_layer4_modules_1_modules_bn1_buffers_running_var_: "f16[512]", L_self_modules_layer4_modules_1_modules_bn1_parameters_weight_: "f16[512]", L_self_modules_layer4_modules_1_modules_bn1_parameters_bias_: "f16[512]", L_self_modules_layer4_modules_1_modules_conv2_parameters_weight_: "f16[512, 512, 3, 3]", L_self_modules_layer4_modules_1_modules_bn2_buffers_running_mean_: "f16[512]", L_self_modules_layer4_modules_1_modules_bn2_buffers_running_var_: "f16[512]", 
L_self_modules_layer4_modules_1_modules_bn2_parameters_weight_: "f16[512]", L_self_modules_layer4_modules_1_modules_bn2_parameters_bias_: "f16[512]", L_self_modules_fc_parameters_weight_: "f16[1000, 512]", L_self_modules_fc_parameters_bias_: "f16[1000]"):
l_self_modules_conv1_parameters_weight_ = L_self_modules_conv1_parameters_weight_
l_x_ = L_x_
l_self_modules_bn1_buffers_running_mean_ = L_self_modules_bn1_buffers_running_mean_
l_self_modules_bn1_buffers_running_var_ = L_self_modules_bn1_buffers_running_var_
l_self_modules_bn1_parameters_weight_ = L_self_modules_bn1_parameters_weight_
l_self_modules_bn1_parameters_bias_ = L_self_modules_bn1_parameters_bias_
l_self_modules_layer1_modules_0_modules_conv1_parameters_weight_ = L_self_modules_layer1_modules_0_modules_conv1_parameters_weight_
l_self_modules_layer1_modules_0_modules_bn1_buffers_running_mean_ = L_self_modules_layer1_modules_0_modules_bn1_buffers_running_mean_
l_self_modules_layer1_modules_0_modules_bn1_buffers_running_var_ = L_self_modules_layer1_modules_0_modules_bn1_buffers_running_var_
l_self_modules_layer1_modules_0_modules_bn1_parameters_weight_ = L_self_modules_layer1_modules_0_modules_bn1_parameters_weight_
l_self_modules_layer1_modules_0_modules_bn1_parameters_bias_ = L_self_modules_layer1_modules_0_modules_bn1_parameters_bias_
l_self_modules_layer1_modules_0_modules_conv2_parameters_weight_ = L_self_modules_layer1_modules_0_modules_conv2_parameters_weight_
l_self_modules_layer1_modules_0_modules_bn2_buffers_running_mean_ = L_self_modules_layer1_modules_0_modules_bn2_buffers_running_mean_
l_self_modules_layer1_modules_0_modules_bn2_buffers_running_var_ = L_self_modules_layer1_modules_0_modules_bn2_buffers_running_var_
l_self_modules_layer1_modules_0_modules_bn2_parameters_weight_ = L_self_modules_layer1_modules_0_modules_bn2_parameters_weight_
l_self_modules_layer1_modules_0_modules_bn2_parameters_bias_ = L_self_modules_layer1_modules_0_modules_bn2_parameters_bias_
l_self_modules_layer1_modules_1_modules_conv1_parameters_weight_ = L_self_modules_layer1_modules_1_modules_conv1_parameters_weight_
l_self_modules_layer1_modules_1_modules_bn1_buffers_running_mean_ = L_self_modules_layer1_modules_1_modules_bn1_buffers_running_mean_
l_self_modules_layer1_modules_1_modules_bn1_buffers_running_var_ = L_self_modules_layer1_modules_1_modules_bn1_buffers_running_var_
l_self_modules_layer1_modules_1_modules_bn1_parameters_weight_ = L_self_modules_layer1_modules_1_modules_bn1_parameters_weight_
l_self_modules_layer1_modules_1_modules_bn1_parameters_bias_ = L_self_modules_layer1_modules_1_modules_bn1_parameters_bias_
l_self_modules_layer1_modules_1_modules_conv2_parameters_weight_ = L_self_modules_layer1_modules_1_modules_conv2_parameters_weight_
l_self_modules_layer1_modules_1_modules_bn2_buffers_running_mean_ = L_self_modules_layer1_modules_1_modules_bn2_buffers_running_mean_
l_self_modules_layer1_modules_1_modules_bn2_buffers_running_var_ = L_self_modules_layer1_modules_1_modules_bn2_buffers_running_var_
l_self_modules_layer1_modules_1_modules_bn2_parameters_weight_ = L_self_modules_layer1_modules_1_modules_bn2_parameters_weight_
l_self_modules_layer1_modules_1_modules_bn2_parameters_bias_ = L_self_modules_layer1_modules_1_modules_bn2_parameters_bias_
l_self_modules_layer2_modules_0_modules_conv1_parameters_weight_ = L_self_modules_layer2_modules_0_modules_conv1_parameters_weight_
l_self_modules_layer2_modules_0_modules_bn1_buffers_running_mean_ = L_self_modules_layer2_modules_0_modules_bn1_buffers_running_mean_
l_self_modules_layer2_modules_0_modules_bn1_buffers_running_var_ = L_self_modules_layer2_modules_0_modules_bn1_buffers_running_var_
l_self_modules_layer2_modules_0_modules_bn1_parameters_weight_ = L_self_modules_layer2_modules_0_modules_bn1_parameters_weight_
l_self_modules_layer2_modules_0_modules_bn1_parameters_bias_ = L_self_modules_layer2_modules_0_modules_bn1_parameters_bias_
l_self_modules_layer2_modules_0_modules_conv2_parameters_weight_ = L_self_modules_layer2_modules_0_modules_conv2_parameters_weight_
l_self_modules_layer2_modules_0_modules_bn2_buffers_running_mean_ = L_self_modules_layer2_modules_0_modules_bn2_buffers_running_mean_
l_self_modules_layer2_modules_0_modules_bn2_buffers_running_var_ = L_self_modules_layer2_modules_0_modules_bn2_buffers_running_var_
l_self_modules_layer2_modules_0_modules_bn2_parameters_weight_ = L_self_modules_layer2_modules_0_modules_bn2_parameters_weight_
l_self_modules_layer2_modules_0_modules_bn2_parameters_bias_ = L_self_modules_layer2_modules_0_modules_bn2_parameters_bias_
l_self_modules_layer2_modules_0_modules_downsample_modules_0_parameters_weight_ = L_self_modules_layer2_modules_0_modules_downsample_modules_0_parameters_weight_
l_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_mean_ = L_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_mean_
l_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_var_ = L_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_var_
l_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_weight_ = L_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_weight_
l_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_bias_ = L_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_bias_
l_self_modules_layer2_modules_1_modules_conv1_parameters_weight_ = L_self_modules_layer2_modules_1_modules_conv1_parameters_weight_
l_self_modules_layer2_modules_1_modules_bn1_buffers_running_mean_ = L_self_modules_layer2_modules_1_modules_bn1_buffers_running_mean_
l_self_modules_layer2_modules_1_modules_bn1_buffers_running_var_ = L_self_modules_layer2_modules_1_modules_bn1_buffers_running_var_
l_self_modules_layer2_modules_1_modules_bn1_parameters_weight_ = L_self_modules_layer2_modules_1_modules_bn1_parameters_weight_
l_self_modules_layer2_modules_1_modules_bn1_parameters_bias_ = L_self_modules_layer2_modules_1_modules_bn1_parameters_bias_
l_self_modules_layer2_modules_1_modules_conv2_parameters_weight_ = L_self_modules_layer2_modules_1_modules_conv2_parameters_weight_
l_self_modules_layer2_modules_1_modules_bn2_buffers_running_mean_ = L_self_modules_layer2_modules_1_modules_bn2_buffers_running_mean_
l_self_modules_layer2_modules_1_modules_bn2_buffers_running_var_ = L_self_modules_layer2_modules_1_modules_bn2_buffers_running_var_
l_self_modules_layer2_modules_1_modules_bn2_parameters_weight_ = L_self_modules_layer2_modules_1_modules_bn2_parameters_weight_
l_self_modules_layer2_modules_1_modules_bn2_parameters_bias_ = L_self_modules_layer2_modules_1_modules_bn2_parameters_bias_
l_self_modules_layer3_modules_0_modules_conv1_parameters_weight_ = L_self_modules_layer3_modules_0_modules_conv1_parameters_weight_
l_self_modules_layer3_modules_0_modules_bn1_buffers_running_mean_ = L_self_modules_layer3_modules_0_modules_bn1_buffers_running_mean_
l_self_modules_layer3_modules_0_modules_bn1_buffers_running_var_ = L_self_modules_layer3_modules_0_modules_bn1_buffers_running_var_
l_self_modules_layer3_modules_0_modules_bn1_parameters_weight_ = L_self_modules_layer3_modules_0_modules_bn1_parameters_weight_
l_self_modules_layer3_modules_0_modules_bn1_parameters_bias_ = L_self_modules_layer3_modules_0_modules_bn1_parameters_bias_
l_self_modules_layer3_modules_0_modules_conv2_parameters_weight_ = L_self_modules_layer3_modules_0_modules_conv2_parameters_weight_
l_self_modules_layer3_modules_0_modules_bn2_buffers_running_mean_ = L_self_modules_layer3_modules_0_modules_bn2_buffers_running_mean_
l_self_modules_layer3_modules_0_modules_bn2_buffers_running_var_ = L_self_modules_layer3_modules_0_modules_bn2_buffers_running_var_
l_self_modules_layer3_modules_0_modules_bn2_parameters_weight_ = L_self_modules_layer3_modules_0_modules_bn2_parameters_weight_
l_self_modules_layer3_modules_0_modules_bn2_parameters_bias_ = L_self_modules_layer3_modules_0_modules_bn2_parameters_bias_
l_self_modules_layer3_modules_0_modules_downsample_modules_0_parameters_weight_ = L_self_modules_layer3_modules_0_modules_downsample_modules_0_parameters_weight_
l_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_mean_ = L_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_mean_
l_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_var_ = L_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_var_
l_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_weight_ = L_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_weight_
l_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_bias_ = L_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_bias_
l_self_modules_layer3_modules_1_modules_conv1_parameters_weight_ = L_self_modules_layer3_modules_1_modules_conv1_parameters_weight_
l_self_modules_layer3_modules_1_modules_bn1_buffers_running_mean_ = L_self_modules_layer3_modules_1_modules_bn1_buffers_running_mean_
l_self_modules_layer3_modules_1_modules_bn1_buffers_running_var_ = L_self_modules_layer3_modules_1_modules_bn1_buffers_running_var_
l_self_modules_layer3_modules_1_modules_bn1_parameters_weight_ = L_self_modules_layer3_modules_1_modules_bn1_parameters_weight_
l_self_modules_layer3_modules_1_modules_bn1_parameters_bias_ = L_self_modules_layer3_modules_1_modules_bn1_parameters_bias_
l_self_modules_layer3_modules_1_modules_conv2_parameters_weight_ = L_self_modules_layer3_modules_1_modules_conv2_parameters_weight_
l_self_modules_layer3_modules_1_modules_bn2_buffers_running_mean_ = L_self_modules_layer3_modules_1_modules_bn2_buffers_running_mean_
l_self_modules_layer3_modules_1_modules_bn2_buffers_running_var_ = L_self_modules_layer3_modules_1_modules_bn2_buffers_running_var_
l_self_modules_layer3_modules_1_modules_bn2_parameters_weight_ = L_self_modules_layer3_modules_1_modules_bn2_parameters_weight_
l_self_modules_layer3_modules_1_modules_bn2_parameters_bias_ = L_self_modules_layer3_modules_1_modules_bn2_parameters_bias_
l_self_modules_layer4_modules_0_modules_conv1_parameters_weight_ = L_self_modules_layer4_modules_0_modules_conv1_parameters_weight_
l_self_modules_layer4_modules_0_modules_bn1_buffers_running_mean_ = L_self_modules_layer4_modules_0_modules_bn1_buffers_running_mean_
l_self_modules_layer4_modules_0_modules_bn1_buffers_running_var_ = L_self_modules_layer4_modules_0_modules_bn1_buffers_running_var_
l_self_modules_layer4_modules_0_modules_bn1_parameters_weight_ = L_self_modules_layer4_modules_0_modules_bn1_parameters_weight_
l_self_modules_layer4_modules_0_modules_bn1_parameters_bias_ = L_self_modules_layer4_modules_0_modules_bn1_parameters_bias_
l_self_modules_layer4_modules_0_modules_conv2_parameters_weight_ = L_self_modules_layer4_modules_0_modules_conv2_parameters_weight_
l_self_modules_layer4_modules_0_modules_bn2_buffers_running_mean_ = L_self_modules_layer4_modules_0_modules_bn2_buffers_running_mean_
l_self_modules_layer4_modules_0_modules_bn2_buffers_running_var_ = L_self_modules_layer4_modules_0_modules_bn2_buffers_running_var_
l_self_modules_layer4_modules_0_modules_bn2_parameters_weight_ = L_self_modules_layer4_modules_0_modules_bn2_parameters_weight_
l_self_modules_layer4_modules_0_modules_bn2_parameters_bias_ = L_self_modules_layer4_modules_0_modules_bn2_parameters_bias_
l_self_modules_layer4_modules_0_modules_downsample_modules_0_parameters_weight_ = L_self_modules_layer4_modules_0_modules_downsample_modules_0_parameters_weight_
l_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_mean_ = L_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_mean_
l_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_var_ = L_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_var_
l_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_weight_ = L_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_weight_
l_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_bias_ = L_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_bias_
l_self_modules_layer4_modules_1_modules_conv1_parameters_weight_ = L_self_modules_layer4_modules_1_modules_conv1_parameters_weight_
l_self_modules_layer4_modules_1_modules_bn1_buffers_running_mean_ = L_self_modules_layer4_modules_1_modules_bn1_buffers_running_mean_
l_self_modules_layer4_modules_1_modules_bn1_buffers_running_var_ = L_self_modules_layer4_modules_1_modules_bn1_buffers_running_var_
l_self_modules_layer4_modules_1_modules_bn1_parameters_weight_ = L_self_modules_layer4_modules_1_modules_bn1_parameters_weight_
l_self_modules_layer4_modules_1_modules_bn1_parameters_bias_ = L_self_modules_layer4_modules_1_modules_bn1_parameters_bias_
l_self_modules_layer4_modules_1_modules_conv2_parameters_weight_ = L_self_modules_layer4_modules_1_modules_conv2_parameters_weight_
l_self_modules_layer4_modules_1_modules_bn2_buffers_running_mean_ = L_self_modules_layer4_modules_1_modules_bn2_buffers_running_mean_
l_self_modules_layer4_modules_1_modules_bn2_buffers_running_var_ = L_self_modules_layer4_modules_1_modules_bn2_buffers_running_var_
l_self_modules_layer4_modules_1_modules_bn2_parameters_weight_ = L_self_modules_layer4_modules_1_modules_bn2_parameters_weight_
l_self_modules_layer4_modules_1_modules_bn2_parameters_bias_ = L_self_modules_layer4_modules_1_modules_bn2_parameters_bias_
l_self_modules_fc_parameters_weight_ = L_self_modules_fc_parameters_weight_
l_self_modules_fc_parameters_bias_ = L_self_modules_fc_parameters_bias_
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:232 in _forward_impl, code: x = self.conv1(x)
x: "f16[1, 64, 112, 112]" = torch.conv2d(l_x_, l_self_modules_conv1_parameters_weight_, None, (2, 2), (3, 3), (1, 1), 1); l_x_ = l_self_modules_conv1_parameters_weight_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:233 in _forward_impl, code: x = self.bn1(x)
x_1: "f16[1, 64, 112, 112]" = torch.nn.functional.batch_norm(x, l_self_modules_bn1_buffers_running_mean_, l_self_modules_bn1_buffers_running_var_, l_self_modules_bn1_parameters_weight_, l_self_modules_bn1_parameters_bias_, False, 0.1, 1e-05); x = l_self_modules_bn1_buffers_running_mean_ = l_self_modules_bn1_buffers_running_var_ = l_self_modules_bn1_parameters_weight_ = l_self_modules_bn1_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:234 in _forward_impl, code: x = self.relu(x)
x_2: "f16[1, 64, 112, 112]" = torch.nn.functional.relu(x_1, inplace = True); x_1 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:235 in _forward_impl, code: x = self.maxpool(x)
x_3: "f16[1, 64, 56, 56]" = torch.nn.functional.max_pool2d(x_2, 3, 2, 1, 1, ceil_mode = False, return_indices = False); x_2 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:70 in forward, code: out = self.conv1(x)
out: "f16[1, 64, 56, 56]" = torch.conv2d(x_3, l_self_modules_layer1_modules_0_modules_conv1_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1); l_self_modules_layer1_modules_0_modules_conv1_parameters_weight_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:71 in forward, code: out = self.bn1(out)
out_1: "f16[1, 64, 56, 56]" = torch.nn.functional.batch_norm(out, l_self_modules_layer1_modules_0_modules_bn1_buffers_running_mean_, l_self_modules_layer1_modules_0_modules_bn1_buffers_running_var_, l_self_modules_layer1_modules_0_modules_bn1_parameters_weight_, l_self_modules_layer1_modules_0_modules_bn1_parameters_bias_, False, 0.1, 1e-05); out = l_self_modules_layer1_modules_0_modules_bn1_buffers_running_mean_ = l_self_modules_layer1_modules_0_modules_bn1_buffers_running_var_ = l_self_modules_layer1_modules_0_modules_bn1_parameters_weight_ = l_self_modules_layer1_modules_0_modules_bn1_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:72 in forward, code: out = self.relu(out)
out_2: "f16[1, 64, 56, 56]" = torch.nn.functional.relu(out_1, inplace = True); out_1 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:74 in forward, code: out = self.conv2(out)
out_3: "f16[1, 64, 56, 56]" = torch.conv2d(out_2, l_self_modules_layer1_modules_0_modules_conv2_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1); out_2 = l_self_modules_layer1_modules_0_modules_conv2_parameters_weight_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:75 in forward, code: out = self.bn2(out)
out_4: "f16[1, 64, 56, 56]" = torch.nn.functional.batch_norm(out_3, l_self_modules_layer1_modules_0_modules_bn2_buffers_running_mean_, l_self_modules_layer1_modules_0_modules_bn2_buffers_running_var_, l_self_modules_layer1_modules_0_modules_bn2_parameters_weight_, l_self_modules_layer1_modules_0_modules_bn2_parameters_bias_, False, 0.1, 1e-05); out_3 = l_self_modules_layer1_modules_0_modules_bn2_buffers_running_mean_ = l_self_modules_layer1_modules_0_modules_bn2_buffers_running_var_ = l_self_modules_layer1_modules_0_modules_bn2_parameters_weight_ = l_self_modules_layer1_modules_0_modules_bn2_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:80 in forward, code: out += identity
out_4 += x_3; out_5: "f16[1, 64, 56, 56]" = out_4; out_4 = x_3 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:81 in forward, code: out = self.relu(out)
out_6: "f16[1, 64, 56, 56]" = torch.nn.functional.relu(out_5, inplace = True); out_5 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:70 in forward, code: out = self.conv1(x)
out_7: "f16[1, 64, 56, 56]" = torch.conv2d(out_6, l_self_modules_layer1_modules_1_modules_conv1_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1); l_self_modules_layer1_modules_1_modules_conv1_parameters_weight_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:71 in forward, code: out = self.bn1(out)
out_8: "f16[1, 64, 56, 56]" = torch.nn.functional.batch_norm(out_7, l_self_modules_layer1_modules_1_modules_bn1_buffers_running_mean_, l_self_modules_layer1_modules_1_modules_bn1_buffers_running_var_, l_self_modules_layer1_modules_1_modules_bn1_parameters_weight_, l_self_modules_layer1_modules_1_modules_bn1_parameters_bias_, False, 0.1, 1e-05); out_7 = l_self_modules_layer1_modules_1_modules_bn1_buffers_running_mean_ = l_self_modules_layer1_modules_1_modules_bn1_buffers_running_var_ = l_self_modules_layer1_modules_1_modules_bn1_parameters_weight_ = l_self_modules_layer1_modules_1_modules_bn1_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:72 in forward, code: out = self.relu(out)
out_9: "f16[1, 64, 56, 56]" = torch.nn.functional.relu(out_8, inplace = True); out_8 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:74 in forward, code: out = self.conv2(out)
out_10: "f16[1, 64, 56, 56]" = torch.conv2d(out_9, l_self_modules_layer1_modules_1_modules_conv2_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1); out_9 = l_self_modules_layer1_modules_1_modules_conv2_parameters_weight_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:75 in forward, code: out = self.bn2(out)
out_11: "f16[1, 64, 56, 56]" = torch.nn.functional.batch_norm(out_10, l_self_modules_layer1_modules_1_modules_bn2_buffers_running_mean_, l_self_modules_layer1_modules_1_modules_bn2_buffers_running_var_, l_self_modules_layer1_modules_1_modules_bn2_parameters_weight_, l_self_modules_layer1_modules_1_modules_bn2_parameters_bias_, False, 0.1, 1e-05); out_10 = l_self_modules_layer1_modules_1_modules_bn2_buffers_running_mean_ = l_self_modules_layer1_modules_1_modules_bn2_buffers_running_var_ = l_self_modules_layer1_modules_1_modules_bn2_parameters_weight_ = l_self_modules_layer1_modules_1_modules_bn2_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:80 in forward, code: out += identity
out_11 += out_6; out_12: "f16[1, 64, 56, 56]" = out_11; out_11 = out_6 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:81 in forward, code: out = self.relu(out)
out_13: "f16[1, 64, 56, 56]" = torch.nn.functional.relu(out_12, inplace = True); out_12 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:70 in forward, code: out = self.conv1(x)
out_14: "f16[1, 128, 28, 28]" = torch.conv2d(out_13, l_self_modules_layer2_modules_0_modules_conv1_parameters_weight_, None, (2, 2), (1, 1), (1, 1), 1); l_self_modules_layer2_modules_0_modules_conv1_parameters_weight_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:71 in forward, code: out = self.bn1(out)
out_15: "f16[1, 128, 28, 28]" = torch.nn.functional.batch_norm(out_14, l_self_modules_layer2_modules_0_modules_bn1_buffers_running_mean_, l_self_modules_layer2_modules_0_modules_bn1_buffers_running_var_, l_self_modules_layer2_modules_0_modules_bn1_parameters_weight_, l_self_modules_layer2_modules_0_modules_bn1_parameters_bias_, False, 0.1, 1e-05); out_14 = l_self_modules_layer2_modules_0_modules_bn1_buffers_running_mean_ = l_self_modules_layer2_modules_0_modules_bn1_buffers_running_var_ = l_self_modules_layer2_modules_0_modules_bn1_parameters_weight_ = l_self_modules_layer2_modules_0_modules_bn1_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:72 in forward, code: out = self.relu(out)
out_16: "f16[1, 128, 28, 28]" = torch.nn.functional.relu(out_15, inplace = True); out_15 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:74 in forward, code: out = self.conv2(out)
out_17: "f16[1, 128, 28, 28]" = torch.conv2d(out_16, l_self_modules_layer2_modules_0_modules_conv2_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1); out_16 = l_self_modules_layer2_modules_0_modules_conv2_parameters_weight_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:75 in forward, code: out = self.bn2(out)
out_18: "f16[1, 128, 28, 28]" = torch.nn.functional.batch_norm(out_17, l_self_modules_layer2_modules_0_modules_bn2_buffers_running_mean_, l_self_modules_layer2_modules_0_modules_bn2_buffers_running_var_, l_self_modules_layer2_modules_0_modules_bn2_parameters_weight_, l_self_modules_layer2_modules_0_modules_bn2_parameters_bias_, False, 0.1, 1e-05); out_17 = l_self_modules_layer2_modules_0_modules_bn2_buffers_running_mean_ = l_self_modules_layer2_modules_0_modules_bn2_buffers_running_var_ = l_self_modules_layer2_modules_0_modules_bn2_parameters_weight_ = l_self_modules_layer2_modules_0_modules_bn2_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:78 in forward, code: identity = self.downsample(x)
input_1: "f16[1, 128, 28, 28]" = torch.conv2d(out_13, l_self_modules_layer2_modules_0_modules_downsample_modules_0_parameters_weight_, None, (2, 2), (0, 0), (1, 1), 1); out_13 = l_self_modules_layer2_modules_0_modules_downsample_modules_0_parameters_weight_ = None
input_2: "f16[1, 128, 28, 28]" = torch.nn.functional.batch_norm(input_1, l_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_mean_, l_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_var_, l_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_weight_, l_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_bias_, False, 0.1, 1e-05); input_1 = l_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_mean_ = l_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_var_ = l_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_weight_ = l_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:80 in forward, code: out += identity
out_18 += input_2; out_19: "f16[1, 128, 28, 28]" = out_18; out_18 = input_2 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:81 in forward, code: out = self.relu(out)
out_20: "f16[1, 128, 28, 28]" = torch.nn.functional.relu(out_19, inplace = True); out_19 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:70 in forward, code: out = self.conv1(x)
out_21: "f16[1, 128, 28, 28]" = torch.conv2d(out_20, l_self_modules_layer2_modules_1_modules_conv1_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1); l_self_modules_layer2_modules_1_modules_conv1_parameters_weight_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:71 in forward, code: out = self.bn1(out)
out_22: "f16[1, 128, 28, 28]" = torch.nn.functional.batch_norm(out_21, l_self_modules_layer2_modules_1_modules_bn1_buffers_running_mean_, l_self_modules_layer2_modules_1_modules_bn1_buffers_running_var_, l_self_modules_layer2_modules_1_modules_bn1_parameters_weight_, l_self_modules_layer2_modules_1_modules_bn1_parameters_bias_, False, 0.1, 1e-05); out_21 = l_self_modules_layer2_modules_1_modules_bn1_buffers_running_mean_ = l_self_modules_layer2_modules_1_modules_bn1_buffers_running_var_ = l_self_modules_layer2_modules_1_modules_bn1_parameters_weight_ = l_self_modules_layer2_modules_1_modules_bn1_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:72 in forward, code: out = self.relu(out)
out_23: "f16[1, 128, 28, 28]" = torch.nn.functional.relu(out_22, inplace = True); out_22 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:74 in forward, code: out = self.conv2(out)
out_24: "f16[1, 128, 28, 28]" = torch.conv2d(out_23, l_self_modules_layer2_modules_1_modules_conv2_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1); out_23 = l_self_modules_layer2_modules_1_modules_conv2_parameters_weight_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:75 in forward, code: out = self.bn2(out)
out_25: "f16[1, 128, 28, 28]" = torch.nn.functional.batch_norm(out_24, l_self_modules_layer2_modules_1_modules_bn2_buffers_running_mean_, l_self_modules_layer2_modules_1_modules_bn2_buffers_running_var_, l_self_modules_layer2_modules_1_modules_bn2_parameters_weight_, l_self_modules_layer2_modules_1_modules_bn2_parameters_bias_, False, 0.1, 1e-05); out_24 = l_self_modules_layer2_modules_1_modules_bn2_buffers_running_mean_ = l_self_modules_layer2_modules_1_modules_bn2_buffers_running_var_ = l_self_modules_layer2_modules_1_modules_bn2_parameters_weight_ = l_self_modules_layer2_modules_1_modules_bn2_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:80 in forward, code: out += identity
out_25 += out_20; out_26: "f16[1, 128, 28, 28]" = out_25; out_25 = out_20 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:81 in forward, code: out = self.relu(out)
out_27: "f16[1, 128, 28, 28]" = torch.nn.functional.relu(out_26, inplace = True); out_26 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:70 in forward, code: out = self.conv1(x)
out_28: "f16[1, 256, 14, 14]" = torch.conv2d(out_27, l_self_modules_layer3_modules_0_modules_conv1_parameters_weight_, None, (2, 2), (1, 1), (1, 1), 1); l_self_modules_layer3_modules_0_modules_conv1_parameters_weight_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:71 in forward, code: out = self.bn1(out)
out_29: "f16[1, 256, 14, 14]" = torch.nn.functional.batch_norm(out_28, l_self_modules_layer3_modules_0_modules_bn1_buffers_running_mean_, l_self_modules_layer3_modules_0_modules_bn1_buffers_running_var_, l_self_modules_layer3_modules_0_modules_bn1_parameters_weight_, l_self_modules_layer3_modules_0_modules_bn1_parameters_bias_, False, 0.1, 1e-05); out_28 = l_self_modules_layer3_modules_0_modules_bn1_buffers_running_mean_ = l_self_modules_layer3_modules_0_modules_bn1_buffers_running_var_ = l_self_modules_layer3_modules_0_modules_bn1_parameters_weight_ = l_self_modules_layer3_modules_0_modules_bn1_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:72 in forward, code: out = self.relu(out)
out_30: "f16[1, 256, 14, 14]" = torch.nn.functional.relu(out_29, inplace = True); out_29 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:74 in forward, code: out = self.conv2(out)
out_31: "f16[1, 256, 14, 14]" = torch.conv2d(out_30, l_self_modules_layer3_modules_0_modules_conv2_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1); out_30 = l_self_modules_layer3_modules_0_modules_conv2_parameters_weight_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:75 in forward, code: out = self.bn2(out)
out_32: "f16[1, 256, 14, 14]" = torch.nn.functional.batch_norm(out_31, l_self_modules_layer3_modules_0_modules_bn2_buffers_running_mean_, l_self_modules_layer3_modules_0_modules_bn2_buffers_running_var_, l_self_modules_layer3_modules_0_modules_bn2_parameters_weight_, l_self_modules_layer3_modules_0_modules_bn2_parameters_bias_, False, 0.1, 1e-05); out_31 = l_self_modules_layer3_modules_0_modules_bn2_buffers_running_mean_ = l_self_modules_layer3_modules_0_modules_bn2_buffers_running_var_ = l_self_modules_layer3_modules_0_modules_bn2_parameters_weight_ = l_self_modules_layer3_modules_0_modules_bn2_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:78 in forward, code: identity = self.downsample(x)
input_3: "f16[1, 256, 14, 14]" = torch.conv2d(out_27, l_self_modules_layer3_modules_0_modules_downsample_modules_0_parameters_weight_, None, (2, 2), (0, 0), (1, 1), 1); out_27 = l_self_modules_layer3_modules_0_modules_downsample_modules_0_parameters_weight_ = None
input_4: "f16[1, 256, 14, 14]" = torch.nn.functional.batch_norm(input_3, l_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_mean_, l_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_var_, l_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_weight_, l_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_bias_, False, 0.1, 1e-05); input_3 = l_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_mean_ = l_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_var_ = l_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_weight_ = l_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:80 in forward, code: out += identity
out_32 += input_4; out_33: "f16[1, 256, 14, 14]" = out_32; out_32 = input_4 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:81 in forward, code: out = self.relu(out)
out_34: "f16[1, 256, 14, 14]" = torch.nn.functional.relu(out_33, inplace = True); out_33 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:70 in forward, code: out = self.conv1(x)
out_35: "f16[1, 256, 14, 14]" = torch.conv2d(out_34, l_self_modules_layer3_modules_1_modules_conv1_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1); l_self_modules_layer3_modules_1_modules_conv1_parameters_weight_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:71 in forward, code: out = self.bn1(out)
out_36: "f16[1, 256, 14, 14]" = torch.nn.functional.batch_norm(out_35, l_self_modules_layer3_modules_1_modules_bn1_buffers_running_mean_, l_self_modules_layer3_modules_1_modules_bn1_buffers_running_var_, l_self_modules_layer3_modules_1_modules_bn1_parameters_weight_, l_self_modules_layer3_modules_1_modules_bn1_parameters_bias_, False, 0.1, 1e-05); out_35 = l_self_modules_layer3_modules_1_modules_bn1_buffers_running_mean_ = l_self_modules_layer3_modules_1_modules_bn1_buffers_running_var_ = l_self_modules_layer3_modules_1_modules_bn1_parameters_weight_ = l_self_modules_layer3_modules_1_modules_bn1_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:72 in forward, code: out = self.relu(out)
out_37: "f16[1, 256, 14, 14]" = torch.nn.functional.relu(out_36, inplace = True); out_36 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:74 in forward, code: out = self.conv2(out)
out_38: "f16[1, 256, 14, 14]" = torch.conv2d(out_37, l_self_modules_layer3_modules_1_modules_conv2_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1); out_37 = l_self_modules_layer3_modules_1_modules_conv2_parameters_weight_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:75 in forward, code: out = self.bn2(out)
out_39: "f16[1, 256, 14, 14]" = torch.nn.functional.batch_norm(out_38, l_self_modules_layer3_modules_1_modules_bn2_buffers_running_mean_, l_self_modules_layer3_modules_1_modules_bn2_buffers_running_var_, l_self_modules_layer3_modules_1_modules_bn2_parameters_weight_, l_self_modules_layer3_modules_1_modules_bn2_parameters_bias_, False, 0.1, 1e-05); out_38 = l_self_modules_layer3_modules_1_modules_bn2_buffers_running_mean_ = l_self_modules_layer3_modules_1_modules_bn2_buffers_running_var_ = l_self_modules_layer3_modules_1_modules_bn2_parameters_weight_ = l_self_modules_layer3_modules_1_modules_bn2_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:80 in forward, code: out += identity
out_39 += out_34; out_40: "f16[1, 256, 14, 14]" = out_39; out_39 = out_34 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:81 in forward, code: out = self.relu(out)
out_41: "f16[1, 256, 14, 14]" = torch.nn.functional.relu(out_40, inplace = True); out_40 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:70 in forward, code: out = self.conv1(x)
out_42: "f16[1, 512, 7, 7]" = torch.conv2d(out_41, l_self_modules_layer4_modules_0_modules_conv1_parameters_weight_, None, (2, 2), (1, 1), (1, 1), 1); l_self_modules_layer4_modules_0_modules_conv1_parameters_weight_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:71 in forward, code: out = self.bn1(out)
out_43: "f16[1, 512, 7, 7]" = torch.nn.functional.batch_norm(out_42, l_self_modules_layer4_modules_0_modules_bn1_buffers_running_mean_, l_self_modules_layer4_modules_0_modules_bn1_buffers_running_var_, l_self_modules_layer4_modules_0_modules_bn1_parameters_weight_, l_self_modules_layer4_modules_0_modules_bn1_parameters_bias_, False, 0.1, 1e-05); out_42 = l_self_modules_layer4_modules_0_modules_bn1_buffers_running_mean_ = l_self_modules_layer4_modules_0_modules_bn1_buffers_running_var_ = l_self_modules_layer4_modules_0_modules_bn1_parameters_weight_ = l_self_modules_layer4_modules_0_modules_bn1_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:72 in forward, code: out = self.relu(out)
out_44: "f16[1, 512, 7, 7]" = torch.nn.functional.relu(out_43, inplace = True); out_43 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:74 in forward, code: out = self.conv2(out)
out_45: "f16[1, 512, 7, 7]" = torch.conv2d(out_44, l_self_modules_layer4_modules_0_modules_conv2_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1); out_44 = l_self_modules_layer4_modules_0_modules_conv2_parameters_weight_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:75 in forward, code: out = self.bn2(out)
out_46: "f16[1, 512, 7, 7]" = torch.nn.functional.batch_norm(out_45, l_self_modules_layer4_modules_0_modules_bn2_buffers_running_mean_, l_self_modules_layer4_modules_0_modules_bn2_buffers_running_var_, l_self_modules_layer4_modules_0_modules_bn2_parameters_weight_, l_self_modules_layer4_modules_0_modules_bn2_parameters_bias_, False, 0.1, 1e-05); out_45 = l_self_modules_layer4_modules_0_modules_bn2_buffers_running_mean_ = l_self_modules_layer4_modules_0_modules_bn2_buffers_running_var_ = l_self_modules_layer4_modules_0_modules_bn2_parameters_weight_ = l_self_modules_layer4_modules_0_modules_bn2_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:78 in forward, code: identity = self.downsample(x)
input_5: "f16[1, 512, 7, 7]" = torch.conv2d(out_41, l_self_modules_layer4_modules_0_modules_downsample_modules_0_parameters_weight_, None, (2, 2), (0, 0), (1, 1), 1); out_41 = l_self_modules_layer4_modules_0_modules_downsample_modules_0_parameters_weight_ = None
input_6: "f16[1, 512, 7, 7]" = torch.nn.functional.batch_norm(input_5, l_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_mean_, l_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_var_, l_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_weight_, l_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_bias_, False, 0.1, 1e-05); input_5 = l_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_mean_ = l_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_var_ = l_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_weight_ = l_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:80 in forward, code: out += identity
out_46 += input_6; out_47: "f16[1, 512, 7, 7]" = out_46; out_46 = input_6 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:81 in forward, code: out = self.relu(out)
out_48: "f16[1, 512, 7, 7]" = torch.nn.functional.relu(out_47, inplace = True); out_47 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:70 in forward, code: out = self.conv1(x)
out_49: "f16[1, 512, 7, 7]" = torch.conv2d(out_48, l_self_modules_layer4_modules_1_modules_conv1_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1); l_self_modules_layer4_modules_1_modules_conv1_parameters_weight_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:71 in forward, code: out = self.bn1(out)
out_50: "f16[1, 512, 7, 7]" = torch.nn.functional.batch_norm(out_49, l_self_modules_layer4_modules_1_modules_bn1_buffers_running_mean_, l_self_modules_layer4_modules_1_modules_bn1_buffers_running_var_, l_self_modules_layer4_modules_1_modules_bn1_parameters_weight_, l_self_modules_layer4_modules_1_modules_bn1_parameters_bias_, False, 0.1, 1e-05); out_49 = l_self_modules_layer4_modules_1_modules_bn1_buffers_running_mean_ = l_self_modules_layer4_modules_1_modules_bn1_buffers_running_var_ = l_self_modules_layer4_modules_1_modules_bn1_parameters_weight_ = l_self_modules_layer4_modules_1_modules_bn1_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:72 in forward, code: out = self.relu(out)
out_51: "f16[1, 512, 7, 7]" = torch.nn.functional.relu(out_50, inplace = True); out_50 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:74 in forward, code: out = self.conv2(out)
out_52: "f16[1, 512, 7, 7]" = torch.conv2d(out_51, l_self_modules_layer4_modules_1_modules_conv2_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1); out_51 = l_self_modules_layer4_modules_1_modules_conv2_parameters_weight_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:75 in forward, code: out = self.bn2(out)
out_53: "f16[1, 512, 7, 7]" = torch.nn.functional.batch_norm(out_52, l_self_modules_layer4_modules_1_modules_bn2_buffers_running_mean_, l_self_modules_layer4_modules_1_modules_bn2_buffers_running_var_, l_self_modules_layer4_modules_1_modules_bn2_parameters_weight_, l_self_modules_layer4_modules_1_modules_bn2_parameters_bias_, False, 0.1, 1e-05); out_52 = l_self_modules_layer4_modules_1_modules_bn2_buffers_running_mean_ = l_self_modules_layer4_modules_1_modules_bn2_buffers_running_var_ = l_self_modules_layer4_modules_1_modules_bn2_parameters_weight_ = l_self_modules_layer4_modules_1_modules_bn2_parameters_bias_ = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:80 in forward, code: out += identity
out_53 += out_48; out_54: "f16[1, 512, 7, 7]" = out_53; out_53 = out_48 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:81 in forward, code: out = self.relu(out)
out_55: "f16[1, 512, 7, 7]" = torch.nn.functional.relu(out_54, inplace = True); out_54 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:242 in _forward_impl, code: x = self.avgpool(x)
x_4: "f16[1, 512, 1, 1]" = torch.nn.functional.adaptive_avg_pool2d(out_55, (1, 1)); out_55 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:243 in _forward_impl, code: x = torch.flatten(x, 1)
x_5: "f16[1, 512]" = torch.flatten(x_4, 1); x_4 = None
# File: /home/ryan/.cache/torch/hub/pytorch_vision_v0.9.0/torchvision/models/resnet.py:244 in _forward_impl, code: x = self.fc(x)
x_6: "f16[1, 1000]" = torch._C._nn.linear(x_5, l_self_modules_fc_parameters_weight_, l_self_modules_fc_parameters_bias_); x_5 = l_self_modules_fc_parameters_weight_ = l_self_modules_fc_parameters_bias_ = None
return (x_6,)
---
opcode name target args kwargs
------------- ---------------------------------------------------------------------------------- ---------------------------------------------------------------------------------- ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------
placeholder l_self_modules_conv1_parameters_weight_ L_self_modules_conv1_parameters_weight_ () {}
placeholder l_x_ L_x_ () {}
placeholder l_self_modules_bn1_buffers_running_mean_ L_self_modules_bn1_buffers_running_mean_ () {}
placeholder l_self_modules_bn1_buffers_running_var_ L_self_modules_bn1_buffers_running_var_ () {}
placeholder l_self_modules_bn1_parameters_weight_ L_self_modules_bn1_parameters_weight_ () {}
placeholder l_self_modules_bn1_parameters_bias_ L_self_modules_bn1_parameters_bias_ () {}
placeholder l_self_modules_layer1_modules_0_modules_conv1_parameters_weight_ L_self_modules_layer1_modules_0_modules_conv1_parameters_weight_ () {}
placeholder l_self_modules_layer1_modules_0_modules_bn1_buffers_running_mean_ L_self_modules_layer1_modules_0_modules_bn1_buffers_running_mean_ () {}
placeholder l_self_modules_layer1_modules_0_modules_bn1_buffers_running_var_ L_self_modules_layer1_modules_0_modules_bn1_buffers_running_var_ () {}
placeholder l_self_modules_layer1_modules_0_modules_bn1_parameters_weight_ L_self_modules_layer1_modules_0_modules_bn1_parameters_weight_ () {}
placeholder l_self_modules_layer1_modules_0_modules_bn1_parameters_bias_ L_self_modules_layer1_modules_0_modules_bn1_parameters_bias_ () {}
placeholder l_self_modules_layer1_modules_0_modules_conv2_parameters_weight_ L_self_modules_layer1_modules_0_modules_conv2_parameters_weight_ () {}
placeholder l_self_modules_layer1_modules_0_modules_bn2_buffers_running_mean_ L_self_modules_layer1_modules_0_modules_bn2_buffers_running_mean_ () {}
placeholder l_self_modules_layer1_modules_0_modules_bn2_buffers_running_var_ L_self_modules_layer1_modules_0_modules_bn2_buffers_running_var_ () {}
placeholder l_self_modules_layer1_modules_0_modules_bn2_parameters_weight_ L_self_modules_layer1_modules_0_modules_bn2_parameters_weight_ () {}
placeholder l_self_modules_layer1_modules_0_modules_bn2_parameters_bias_ L_self_modules_layer1_modules_0_modules_bn2_parameters_bias_ () {}
placeholder l_self_modules_layer1_modules_1_modules_conv1_parameters_weight_ L_self_modules_layer1_modules_1_modules_conv1_parameters_weight_ () {}
placeholder l_self_modules_layer1_modules_1_modules_bn1_buffers_running_mean_ L_self_modules_layer1_modules_1_modules_bn1_buffers_running_mean_ () {}
placeholder l_self_modules_layer1_modules_1_modules_bn1_buffers_running_var_ L_self_modules_layer1_modules_1_modules_bn1_buffers_running_var_ () {}
placeholder l_self_modules_layer1_modules_1_modules_bn1_parameters_weight_ L_self_modules_layer1_modules_1_modules_bn1_parameters_weight_ () {}
placeholder l_self_modules_layer1_modules_1_modules_bn1_parameters_bias_ L_self_modules_layer1_modules_1_modules_bn1_parameters_bias_ () {}
placeholder l_self_modules_layer1_modules_1_modules_conv2_parameters_weight_ L_self_modules_layer1_modules_1_modules_conv2_parameters_weight_ () {}
placeholder l_self_modules_layer1_modules_1_modules_bn2_buffers_running_mean_ L_self_modules_layer1_modules_1_modules_bn2_buffers_running_mean_ () {}
placeholder l_self_modules_layer1_modules_1_modules_bn2_buffers_running_var_ L_self_modules_layer1_modules_1_modules_bn2_buffers_running_var_ () {}
placeholder l_self_modules_layer1_modules_1_modules_bn2_parameters_weight_ L_self_modules_layer1_modules_1_modules_bn2_parameters_weight_ () {}
placeholder l_self_modules_layer1_modules_1_modules_bn2_parameters_bias_ L_self_modules_layer1_modules_1_modules_bn2_parameters_bias_ () {}
placeholder l_self_modules_layer2_modules_0_modules_conv1_parameters_weight_ L_self_modules_layer2_modules_0_modules_conv1_parameters_weight_ () {}
placeholder l_self_modules_layer2_modules_0_modules_bn1_buffers_running_mean_ L_self_modules_layer2_modules_0_modules_bn1_buffers_running_mean_ () {}
placeholder l_self_modules_layer2_modules_0_modules_bn1_buffers_running_var_ L_self_modules_layer2_modules_0_modules_bn1_buffers_running_var_ () {}
placeholder l_self_modules_layer2_modules_0_modules_bn1_parameters_weight_ L_self_modules_layer2_modules_0_modules_bn1_parameters_weight_ () {}
placeholder l_self_modules_layer2_modules_0_modules_bn1_parameters_bias_ L_self_modules_layer2_modules_0_modules_bn1_parameters_bias_ () {}
placeholder l_self_modules_layer2_modules_0_modules_conv2_parameters_weight_ L_self_modules_layer2_modules_0_modules_conv2_parameters_weight_ () {}
placeholder l_self_modules_layer2_modules_0_modules_bn2_buffers_running_mean_ L_self_modules_layer2_modules_0_modules_bn2_buffers_running_mean_ () {}
placeholder l_self_modules_layer2_modules_0_modules_bn2_buffers_running_var_ L_self_modules_layer2_modules_0_modules_bn2_buffers_running_var_ () {}
placeholder l_self_modules_layer2_modules_0_modules_bn2_parameters_weight_ L_self_modules_layer2_modules_0_modules_bn2_parameters_weight_ () {}
placeholder l_self_modules_layer2_modules_0_modules_bn2_parameters_bias_ L_self_modules_layer2_modules_0_modules_bn2_parameters_bias_ () {}
placeholder l_self_modules_layer2_modules_0_modules_downsample_modules_0_parameters_weight_ L_self_modules_layer2_modules_0_modules_downsample_modules_0_parameters_weight_ () {}
placeholder l_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_mean_ L_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_mean_ () {}
placeholder l_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_var_ L_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_var_ () {}
placeholder l_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_weight_ L_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_weight_ () {}
placeholder l_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_bias_ L_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_bias_ () {}
placeholder l_self_modules_layer2_modules_1_modules_conv1_parameters_weight_ L_self_modules_layer2_modules_1_modules_conv1_parameters_weight_ () {}
placeholder l_self_modules_layer2_modules_1_modules_bn1_buffers_running_mean_ L_self_modules_layer2_modules_1_modules_bn1_buffers_running_mean_ () {}
placeholder l_self_modules_layer2_modules_1_modules_bn1_buffers_running_var_ L_self_modules_layer2_modules_1_modules_bn1_buffers_running_var_ () {}
placeholder l_self_modules_layer2_modules_1_modules_bn1_parameters_weight_ L_self_modules_layer2_modules_1_modules_bn1_parameters_weight_ () {}
placeholder l_self_modules_layer2_modules_1_modules_bn1_parameters_bias_ L_self_modules_layer2_modules_1_modules_bn1_parameters_bias_ () {}
placeholder l_self_modules_layer2_modules_1_modules_conv2_parameters_weight_ L_self_modules_layer2_modules_1_modules_conv2_parameters_weight_ () {}
placeholder l_self_modules_layer2_modules_1_modules_bn2_buffers_running_mean_ L_self_modules_layer2_modules_1_modules_bn2_buffers_running_mean_ () {}
placeholder l_self_modules_layer2_modules_1_modules_bn2_buffers_running_var_ L_self_modules_layer2_modules_1_modules_bn2_buffers_running_var_ () {}
placeholder l_self_modules_layer2_modules_1_modules_bn2_parameters_weight_ L_self_modules_layer2_modules_1_modules_bn2_parameters_weight_ () {}
placeholder l_self_modules_layer2_modules_1_modules_bn2_parameters_bias_ L_self_modules_layer2_modules_1_modules_bn2_parameters_bias_ () {}
placeholder l_self_modules_layer3_modules_0_modules_conv1_parameters_weight_ L_self_modules_layer3_modules_0_modules_conv1_parameters_weight_ () {}
placeholder l_self_modules_layer3_modules_0_modules_bn1_buffers_running_mean_ L_self_modules_layer3_modules_0_modules_bn1_buffers_running_mean_ () {}
placeholder l_self_modules_layer3_modules_0_modules_bn1_buffers_running_var_ L_self_modules_layer3_modules_0_modules_bn1_buffers_running_var_ () {}
placeholder l_self_modules_layer3_modules_0_modules_bn1_parameters_weight_ L_self_modules_layer3_modules_0_modules_bn1_parameters_weight_ () {}
placeholder l_self_modules_layer3_modules_0_modules_bn1_parameters_bias_ L_self_modules_layer3_modules_0_modules_bn1_parameters_bias_ () {}
placeholder l_self_modules_layer3_modules_0_modules_conv2_parameters_weight_ L_self_modules_layer3_modules_0_modules_conv2_parameters_weight_ () {}
placeholder l_self_modules_layer3_modules_0_modules_bn2_buffers_running_mean_ L_self_modules_layer3_modules_0_modules_bn2_buffers_running_mean_ () {}
placeholder l_self_modules_layer3_modules_0_modules_bn2_buffers_running_var_ L_self_modules_layer3_modules_0_modules_bn2_buffers_running_var_ () {}
placeholder l_self_modules_layer3_modules_0_modules_bn2_parameters_weight_ L_self_modules_layer3_modules_0_modules_bn2_parameters_weight_ () {}
placeholder l_self_modules_layer3_modules_0_modules_bn2_parameters_bias_ L_self_modules_layer3_modules_0_modules_bn2_parameters_bias_ () {}
placeholder l_self_modules_layer3_modules_0_modules_downsample_modules_0_parameters_weight_ L_self_modules_layer3_modules_0_modules_downsample_modules_0_parameters_weight_ () {}
placeholder l_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_mean_ L_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_mean_ () {}
placeholder l_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_var_ L_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_var_ () {}
placeholder l_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_weight_ L_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_weight_ () {}
placeholder l_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_bias_ L_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_bias_ () {}
placeholder l_self_modules_layer3_modules_1_modules_conv1_parameters_weight_ L_self_modules_layer3_modules_1_modules_conv1_parameters_weight_ () {}
placeholder l_self_modules_layer3_modules_1_modules_bn1_buffers_running_mean_ L_self_modules_layer3_modules_1_modules_bn1_buffers_running_mean_ () {}
placeholder l_self_modules_layer3_modules_1_modules_bn1_buffers_running_var_ L_self_modules_layer3_modules_1_modules_bn1_buffers_running_var_ () {}
placeholder l_self_modules_layer3_modules_1_modules_bn1_parameters_weight_ L_self_modules_layer3_modules_1_modules_bn1_parameters_weight_ () {}
placeholder l_self_modules_layer3_modules_1_modules_bn1_parameters_bias_ L_self_modules_layer3_modules_1_modules_bn1_parameters_bias_ () {}
placeholder l_self_modules_layer3_modules_1_modules_conv2_parameters_weight_ L_self_modules_layer3_modules_1_modules_conv2_parameters_weight_ () {}
placeholder l_self_modules_layer3_modules_1_modules_bn2_buffers_running_mean_ L_self_modules_layer3_modules_1_modules_bn2_buffers_running_mean_ () {}
placeholder l_self_modules_layer3_modules_1_modules_bn2_buffers_running_var_ L_self_modules_layer3_modules_1_modules_bn2_buffers_running_var_ () {}
placeholder l_self_modules_layer3_modules_1_modules_bn2_parameters_weight_ L_self_modules_layer3_modules_1_modules_bn2_parameters_weight_ () {}
placeholder l_self_modules_layer3_modules_1_modules_bn2_parameters_bias_ L_self_modules_layer3_modules_1_modules_bn2_parameters_bias_ () {}
placeholder l_self_modules_layer4_modules_0_modules_conv1_parameters_weight_ L_self_modules_layer4_modules_0_modules_conv1_parameters_weight_ () {}
placeholder l_self_modules_layer4_modules_0_modules_bn1_buffers_running_mean_ L_self_modules_layer4_modules_0_modules_bn1_buffers_running_mean_ () {}
placeholder l_self_modules_layer4_modules_0_modules_bn1_buffers_running_var_ L_self_modules_layer4_modules_0_modules_bn1_buffers_running_var_ () {}
placeholder l_self_modules_layer4_modules_0_modules_bn1_parameters_weight_ L_self_modules_layer4_modules_0_modules_bn1_parameters_weight_ () {}
placeholder l_self_modules_layer4_modules_0_modules_bn1_parameters_bias_ L_self_modules_layer4_modules_0_modules_bn1_parameters_bias_ () {}
placeholder l_self_modules_layer4_modules_0_modules_conv2_parameters_weight_ L_self_modules_layer4_modules_0_modules_conv2_parameters_weight_ () {}
placeholder l_self_modules_layer4_modules_0_modules_bn2_buffers_running_mean_ L_self_modules_layer4_modules_0_modules_bn2_buffers_running_mean_ () {}
placeholder l_self_modules_layer4_modules_0_modules_bn2_buffers_running_var_ L_self_modules_layer4_modules_0_modules_bn2_buffers_running_var_ () {}
placeholder l_self_modules_layer4_modules_0_modules_bn2_parameters_weight_ L_self_modules_layer4_modules_0_modules_bn2_parameters_weight_ () {}
placeholder l_self_modules_layer4_modules_0_modules_bn2_parameters_bias_ L_self_modules_layer4_modules_0_modules_bn2_parameters_bias_ () {}
placeholder l_self_modules_layer4_modules_0_modules_downsample_modules_0_parameters_weight_ L_self_modules_layer4_modules_0_modules_downsample_modules_0_parameters_weight_ () {}
placeholder l_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_mean_ L_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_mean_ () {}
placeholder l_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_var_ L_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_var_ () {}
placeholder l_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_weight_ L_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_weight_ () {}
placeholder l_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_bias_ L_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_bias_ () {}
placeholder l_self_modules_layer4_modules_1_modules_conv1_parameters_weight_ L_self_modules_layer4_modules_1_modules_conv1_parameters_weight_ () {}
placeholder l_self_modules_layer4_modules_1_modules_bn1_buffers_running_mean_ L_self_modules_layer4_modules_1_modules_bn1_buffers_running_mean_ () {}
placeholder l_self_modules_layer4_modules_1_modules_bn1_buffers_running_var_ L_self_modules_layer4_modules_1_modules_bn1_buffers_running_var_ () {}
placeholder l_self_modules_layer4_modules_1_modules_bn1_parameters_weight_ L_self_modules_layer4_modules_1_modules_bn1_parameters_weight_ () {}
placeholder l_self_modules_layer4_modules_1_modules_bn1_parameters_bias_ L_self_modules_layer4_modules_1_modules_bn1_parameters_bias_ () {}
placeholder l_self_modules_layer4_modules_1_modules_conv2_parameters_weight_ L_self_modules_layer4_modules_1_modules_conv2_parameters_weight_ () {}
placeholder l_self_modules_layer4_modules_1_modules_bn2_buffers_running_mean_ L_self_modules_layer4_modules_1_modules_bn2_buffers_running_mean_ () {}
placeholder l_self_modules_layer4_modules_1_modules_bn2_buffers_running_var_ L_self_modules_layer4_modules_1_modules_bn2_buffers_running_var_ () {}
placeholder l_self_modules_layer4_modules_1_modules_bn2_parameters_weight_ L_self_modules_layer4_modules_1_modules_bn2_parameters_weight_ () {}
placeholder l_self_modules_layer4_modules_1_modules_bn2_parameters_bias_ L_self_modules_layer4_modules_1_modules_bn2_parameters_bias_ () {}
placeholder l_self_modules_fc_parameters_weight_ L_self_modules_fc_parameters_weight_ () {}
placeholder l_self_modules_fc_parameters_bias_ L_self_modules_fc_parameters_bias_ () {}
call_function x <built-in method conv2d of type object at 0x75886d0bf1c0> (l_x_, l_self_modules_conv1_parameters_weight_, None, (2, 2), (3, 3), (1, 1), 1) {}
call_function x_1 <function batch_norm at 0x758881d7a4d0> (x, l_self_modules_bn1_buffers_running_mean_, l_self_modules_bn1_buffers_running_var_, l_self_modules_bn1_parameters_weight_, l_self_modules_bn1_parameters_bias_, False, 0.1, 1e-05) {}
call_function x_2 <function relu at 0x758881d79510> (x_1,) {'inplace': True}
call_function x_3 <function boolean_dispatch.<locals>.fn at 0x758881d783a0> (x_2, 3, 2, 1, 1) {'ceil_mode': False, 'return_indices': False}
call_function out <built-in method conv2d of type object at 0x75886d0bf1c0> (x_3, l_self_modules_layer1_modules_0_modules_conv1_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1) {}
call_function out_1 <function batch_norm at 0x758881d7a4d0> (out, l_self_modules_layer1_modules_0_modules_bn1_buffers_running_mean_, l_self_modules_layer1_modules_0_modules_bn1_buffers_running_var_, l_self_modules_layer1_modules_0_modules_bn1_parameters_weight_, l_self_modules_layer1_modules_0_modules_bn1_parameters_bias_, False, 0.1, 1e-05) {}
call_function out_2 <function relu at 0x758881d79510> (out_1,) {'inplace': True}
call_function out_3 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_2, l_self_modules_layer1_modules_0_modules_conv2_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1) {}
call_function out_4 <function batch_norm at 0x758881d7a4d0> (out_3, l_self_modules_layer1_modules_0_modules_bn2_buffers_running_mean_, l_self_modules_layer1_modules_0_modules_bn2_buffers_running_var_, l_self_modules_layer1_modules_0_modules_bn2_parameters_weight_, l_self_modules_layer1_modules_0_modules_bn2_parameters_bias_, False, 0.1, 1e-05) {}
call_function out_5 <built-in function iadd> (out_4, x_3) {}
call_function out_6 <function relu at 0x758881d79510> (out_5,) {'inplace': True}
call_function out_7 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_6, l_self_modules_layer1_modules_1_modules_conv1_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1) {}
call_function out_8 <function batch_norm at 0x758881d7a4d0> (out_7, l_self_modules_layer1_modules_1_modules_bn1_buffers_running_mean_, l_self_modules_layer1_modules_1_modules_bn1_buffers_running_var_, l_self_modules_layer1_modules_1_modules_bn1_parameters_weight_, l_self_modules_layer1_modules_1_modules_bn1_parameters_bias_, False, 0.1, 1e-05) {}
call_function out_9 <function relu at 0x758881d79510> (out_8,) {'inplace': True}
call_function out_10 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_9, l_self_modules_layer1_modules_1_modules_conv2_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1) {}
call_function out_11 <function batch_norm at 0x758881d7a4d0> (out_10, l_self_modules_layer1_modules_1_modules_bn2_buffers_running_mean_, l_self_modules_layer1_modules_1_modules_bn2_buffers_running_var_, l_self_modules_layer1_modules_1_modules_bn2_parameters_weight_, l_self_modules_layer1_modules_1_modules_bn2_parameters_bias_, False, 0.1, 1e-05) {}
call_function out_12 <built-in function iadd> (out_11, out_6) {}
call_function out_13 <function relu at 0x758881d79510> (out_12,) {'inplace': True}
call_function out_14 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_13, l_self_modules_layer2_modules_0_modules_conv1_parameters_weight_, None, (2, 2), (1, 1), (1, 1), 1) {}
call_function out_15 <function batch_norm at 0x758881d7a4d0> (out_14, l_self_modules_layer2_modules_0_modules_bn1_buffers_running_mean_, l_self_modules_layer2_modules_0_modules_bn1_buffers_running_var_, l_self_modules_layer2_modules_0_modules_bn1_parameters_weight_, l_self_modules_layer2_modules_0_modules_bn1_parameters_bias_, False, 0.1, 1e-05) {}
call_function out_16 <function relu at 0x758881d79510> (out_15,) {'inplace': True}
call_function out_17 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_16, l_self_modules_layer2_modules_0_modules_conv2_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1) {}
call_function out_18 <function batch_norm at 0x758881d7a4d0> (out_17, l_self_modules_layer2_modules_0_modules_bn2_buffers_running_mean_, l_self_modules_layer2_modules_0_modules_bn2_buffers_running_var_, l_self_modules_layer2_modules_0_modules_bn2_parameters_weight_, l_self_modules_layer2_modules_0_modules_bn2_parameters_bias_, False, 0.1, 1e-05) {}
call_function input_1 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_13, l_self_modules_layer2_modules_0_modules_downsample_modules_0_parameters_weight_, None, (2, 2), (0, 0), (1, 1), 1) {}
call_function input_2 <function batch_norm at 0x758881d7a4d0> (input_1, l_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_mean_, l_self_modules_layer2_modules_0_modules_downsample_modules_1_buffers_running_var_, l_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_weight_, l_self_modules_layer2_modules_0_modules_downsample_modules_1_parameters_bias_, False, 0.1, 1e-05) {}
call_function out_19 <built-in function iadd> (out_18, input_2) {}
call_function out_20 <function relu at 0x758881d79510> (out_19,) {'inplace': True}
call_function out_21 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_20, l_self_modules_layer2_modules_1_modules_conv1_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1) {}
call_function out_22 <function batch_norm at 0x758881d7a4d0> (out_21, l_self_modules_layer2_modules_1_modules_bn1_buffers_running_mean_, l_self_modules_layer2_modules_1_modules_bn1_buffers_running_var_, l_self_modules_layer2_modules_1_modules_bn1_parameters_weight_, l_self_modules_layer2_modules_1_modules_bn1_parameters_bias_, False, 0.1, 1e-05) {}
call_function out_23 <function relu at 0x758881d79510> (out_22,) {'inplace': True}
call_function out_24 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_23, l_self_modules_layer2_modules_1_modules_conv2_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1) {}
call_function out_25 <function batch_norm at 0x758881d7a4d0> (out_24, l_self_modules_layer2_modules_1_modules_bn2_buffers_running_mean_, l_self_modules_layer2_modules_1_modules_bn2_buffers_running_var_, l_self_modules_layer2_modules_1_modules_bn2_parameters_weight_, l_self_modules_layer2_modules_1_modules_bn2_parameters_bias_, False, 0.1, 1e-05) {}
call_function out_26 <built-in function iadd> (out_25, out_20) {}
call_function out_27 <function relu at 0x758881d79510> (out_26,) {'inplace': True}
call_function out_28 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_27, l_self_modules_layer3_modules_0_modules_conv1_parameters_weight_, None, (2, 2), (1, 1), (1, 1), 1) {}
call_function out_29 <function batch_norm at 0x758881d7a4d0> (out_28, l_self_modules_layer3_modules_0_modules_bn1_buffers_running_mean_, l_self_modules_layer3_modules_0_modules_bn1_buffers_running_var_, l_self_modules_layer3_modules_0_modules_bn1_parameters_weight_, l_self_modules_layer3_modules_0_modules_bn1_parameters_bias_, False, 0.1, 1e-05) {}
call_function out_30 <function relu at 0x758881d79510> (out_29,) {'inplace': True}
call_function out_31 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_30, l_self_modules_layer3_modules_0_modules_conv2_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1) {}
call_function out_32 <function batch_norm at 0x758881d7a4d0> (out_31, l_self_modules_layer3_modules_0_modules_bn2_buffers_running_mean_, l_self_modules_layer3_modules_0_modules_bn2_buffers_running_var_, l_self_modules_layer3_modules_0_modules_bn2_parameters_weight_, l_self_modules_layer3_modules_0_modules_bn2_parameters_bias_, False, 0.1, 1e-05) {}
call_function input_3 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_27, l_self_modules_layer3_modules_0_modules_downsample_modules_0_parameters_weight_, None, (2, 2), (0, 0), (1, 1), 1) {}
call_function input_4 <function batch_norm at 0x758881d7a4d0> (input_3, l_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_mean_, l_self_modules_layer3_modules_0_modules_downsample_modules_1_buffers_running_var_, l_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_weight_, l_self_modules_layer3_modules_0_modules_downsample_modules_1_parameters_bias_, False, 0.1, 1e-05) {}
call_function out_33 <built-in function iadd> (out_32, input_4) {}
call_function out_34 <function relu at 0x758881d79510> (out_33,) {'inplace': True}
call_function out_35 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_34, l_self_modules_layer3_modules_1_modules_conv1_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1) {}
call_function out_36 <function batch_norm at 0x758881d7a4d0> (out_35, l_self_modules_layer3_modules_1_modules_bn1_buffers_running_mean_, l_self_modules_layer3_modules_1_modules_bn1_buffers_running_var_, l_self_modules_layer3_modules_1_modules_bn1_parameters_weight_, l_self_modules_layer3_modules_1_modules_bn1_parameters_bias_, False, 0.1, 1e-05) {}
call_function out_37 <function relu at 0x758881d79510> (out_36,) {'inplace': True}
call_function out_38 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_37, l_self_modules_layer3_modules_1_modules_conv2_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1) {}
call_function out_39 <function batch_norm at 0x758881d7a4d0> (out_38, l_self_modules_layer3_modules_1_modules_bn2_buffers_running_mean_, l_self_modules_layer3_modules_1_modules_bn2_buffers_running_var_, l_self_modules_layer3_modules_1_modules_bn2_parameters_weight_, l_self_modules_layer3_modules_1_modules_bn2_parameters_bias_, False, 0.1, 1e-05) {}
call_function out_40 <built-in function iadd> (out_39, out_34) {}
call_function out_41 <function relu at 0x758881d79510> (out_40,) {'inplace': True}
call_function out_42 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_41, l_self_modules_layer4_modules_0_modules_conv1_parameters_weight_, None, (2, 2), (1, 1), (1, 1), 1) {}
call_function out_43 <function batch_norm at 0x758881d7a4d0> (out_42, l_self_modules_layer4_modules_0_modules_bn1_buffers_running_mean_, l_self_modules_layer4_modules_0_modules_bn1_buffers_running_var_, l_self_modules_layer4_modules_0_modules_bn1_parameters_weight_, l_self_modules_layer4_modules_0_modules_bn1_parameters_bias_, False, 0.1, 1e-05) {}
call_function out_44 <function relu at 0x758881d79510> (out_43,) {'inplace': True}
call_function out_45 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_44, l_self_modules_layer4_modules_0_modules_conv2_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1) {}
call_function out_46 <function batch_norm at 0x758881d7a4d0> (out_45, l_self_modules_layer4_modules_0_modules_bn2_buffers_running_mean_, l_self_modules_layer4_modules_0_modules_bn2_buffers_running_var_, l_self_modules_layer4_modules_0_modules_bn2_parameters_weight_, l_self_modules_layer4_modules_0_modules_bn2_parameters_bias_, False, 0.1, 1e-05) {}
call_function input_5 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_41, l_self_modules_layer4_modules_0_modules_downsample_modules_0_parameters_weight_, None, (2, 2), (0, 0), (1, 1), 1) {}
call_function input_6 <function batch_norm at 0x758881d7a4d0> (input_5, l_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_mean_, l_self_modules_layer4_modules_0_modules_downsample_modules_1_buffers_running_var_, l_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_weight_, l_self_modules_layer4_modules_0_modules_downsample_modules_1_parameters_bias_, False, 0.1, 1e-05) {}
call_function out_47 <built-in function iadd> (out_46, input_6) {}
call_function out_48 <function relu at 0x758881d79510> (out_47,) {'inplace': True}
call_function out_49 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_48, l_self_modules_layer4_modules_1_modules_conv1_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1) {}
call_function out_50 <function batch_norm at 0x758881d7a4d0> (out_49, l_self_modules_layer4_modules_1_modules_bn1_buffers_running_mean_, l_self_modules_layer4_modules_1_modules_bn1_buffers_running_var_, l_self_modules_layer4_modules_1_modules_bn1_parameters_weight_, l_self_modules_layer4_modules_1_modules_bn1_parameters_bias_, False, 0.1, 1e-05) {}
call_function out_51 <function relu at 0x758881d79510> (out_50,) {'inplace': True}
call_function out_52 <built-in method conv2d of type object at 0x75886d0bf1c0> (out_51, l_self_modules_layer4_modules_1_modules_conv2_parameters_weight_, None, (1, 1), (1, 1), (1, 1), 1) {}
call_function out_53 <function batch_norm at 0x758881d7a4d0> (out_52, l_self_modules_layer4_modules_1_modules_bn2_buffers_running_mean_, l_self_modules_layer4_modules_1_modules_bn2_buffers_running_var_, l_self_modules_layer4_modules_1_modules_bn2_parameters_weight_, l_self_modules_layer4_modules_1_modules_bn2_parameters_bias_, False, 0.1, 1e-05) {}
call_function out_54 <built-in function iadd> (out_53, out_48) {}
call_function out_55 <function relu at 0x758881d79510> (out_54,) {'inplace': True}
call_function x_4 <function adaptive_avg_pool2d at 0x758881d79000> (out_55, (1, 1)) {}
call_function x_5 <built-in method flatten of type object at 0x75886d0bf1c0> (x_4, 1) {}
call_function x_6 <built-in function linear> (x_5, l_self_modules_fc_parameters_weight_, l_self_modules_fc_parameters_bias_) {}
output output output ((x_6,),) {}
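The table above lists one row per node of the `torch.fx` graph that Dynamo captured. Its five columns correspond to the `op`, `name`, `target`, `args` and `kwargs` attributes of each `torch.fx.Node`. The sketch below shows how to collect the same columns yourself. It uses a hypothetical toy module (`TinyBlock`) and `torch.fx.symbolic_trace` so it runs without compiling anything. `inspect_backend` is a hypothetical debugging backend, not part of hidet, that illustrates how a custom `torch.compile` backend receives such a graph. Note that symbolic tracing keeps submodules as `call_module` nodes, while Dynamo inlines them into `call_function` nodes with lifted parameters, as in the table above.

```python
import torch
import torch.fx


def graph_rows(gm: torch.fx.GraphModule):
    """Return (opcode, name, target, args, kwargs) for every node in the graph,
    i.e. the same five columns as the tabular listing above."""
    return [(n.op, n.name, str(n.target), n.args, n.kwargs) for n in gm.graph.nodes]


def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    """Hypothetical debugging backend: print the captured graph, then run it
    unchanged. Pass it as torch.compile(model, backend=inspect_backend)."""
    for row in graph_rows(gm):
        print(*row, sep="  ")
    return gm.forward  # a backend must return a callable


class TinyBlock(torch.nn.Module):
    """Hypothetical toy module: conv -> relu -> residual add."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        out = torch.relu(self.conv(x))
        return out + x


# Trace the toy module into a GraphModule and print its nodes.
gm = torch.fx.symbolic_trace(TinyBlock())
compiled_fn = inspect_backend(gm, example_inputs=[])
y = compiled_fn(torch.randn(1, 3, 4, 4))
```

Because the backend returns `gm.forward`, the model runs exactly as captured. A real backend such as `'hidet'` would instead translate the graph and return a callable that runs the generated kernels.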
Parallel build: 100%|███████████████████████████| 38/38 [01:46<00:00, 2.81s/it]
Total running time of the script: (2 minutes 11.220 seconds)