Visualize Flow Graph
Visualization is a key component of a machine learning tool that allows us to better understand the model.
We customized the popular Netron viewer to visualize the flow graph of a hidet model. The customized Netron viewer can be found here; there is also a link at the bottom of the documentation sidebar.
In this tutorial, we will show you how to visualize the flow graph of a model.
Define model
We first define a model with a self-attention layer.
import math
import hidet
from hidet import Tensor
from hidet.graph import nn, ops
class SelfAttention(nn.Module):
    def __init__(self, hidden_size=768, num_attention_heads=12):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.query_layer = nn.Linear(hidden_size, hidden_size)
        self.key_layer = nn.Linear(hidden_size, hidden_size)
        self.value_layer = nn.Linear(hidden_size, hidden_size)

    def transpose_for_scores(self, x: Tensor) -> Tensor:
        batch_size, seq_length, hidden_size = x.shape
        x = x.reshape([batch_size, seq_length, self.num_attention_heads, self.attention_head_size])
        x = x.rearrange([[0, 2], [1], [3]])
        return x  # [batch_size * num_attention_heads, seq_length, attention_head_size]

    def forward(self, hidden_states: Tensor, attention_mask: Tensor):
        batch_size, seq_length, _ = hidden_states.shape
        query = self.transpose_for_scores(self.query_layer(hidden_states))
        key = self.transpose_for_scores(self.key_layer(hidden_states))
        value = self.transpose_for_scores(self.value_layer(hidden_states))
        attention_scores = ops.matmul(query, ops.transpose(key, [-1, -2])) / math.sqrt(
            self.attention_head_size
        )
        attention_scores = attention_scores + attention_mask
        attention_probs = ops.softmax(attention_scores, axis=-1)
        context = ops.matmul(attention_probs, value)
        context = context.reshape(
            [batch_size, self.num_attention_heads, seq_length, self.attention_head_size]
        )
        context = context.rearrange([[0], [2], [1, 3]])
        return context
model = SelfAttention()
print(model)
SelfAttention(
  (query_layer): Linear(in_features=768, out_features=768)
  (key_layer): Linear(in_features=768, out_features=768)
  (value_layer): Linear(in_features=768, out_features=768)
)
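Before tracing the model, we can run it eagerly on sample inputs to sanity-check the output shape. This is a quick sketch; like the Linear layers used above, hidet modules are callable.

# run the module eagerly on sample inputs and check the output shape
x = hidet.randn([1, 128, 768])
mask = hidet.ones([1, 128], dtype='int32')
y = model(x, mask)
print(y.shape)  # expect [1, 128, 768]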
Generate flow graph
Then we generate the flow graph of the model.
graph = model.flow_graph_for(
    inputs=[hidet.randn([1, 128, 768]), hidet.ones([1, 128], dtype='int32')]
)
print(graph)
Graph(x: float32[1, 128, 768][cpu], x_1: int32[1, 128][cpu]){
  c = Constant(float32[768, 768][cpu])
  c_1 = Constant(float32[768][cpu])
  c_2 = Constant(float32[768, 768][cpu])
  c_3 = Constant(float32[768][cpu])
  c_4 = Constant(float32[768, 768][cpu])
  c_5 = Constant(float32[768][cpu])
  x_2: float32[1, 128, 768][cpu] = Matmul(x, c, require_prologue=False)
  x_3: float32[1, 128, 768][cpu] = Add(x_2, c_1)
  x_4: float32[1, 128, 12, 64][cpu] = Reshape(x_3, shape=[1, 128, 12, 64])
  x_5: float32[12, 128, 64][cpu] = Rearrange(x_4, plan=[[0, 2], [1], [3]])
  x_6: float32[1, 128, 768][cpu] = Matmul(x, c_2, require_prologue=False)
  x_7: float32[1, 128, 768][cpu] = Add(x_6, c_3)
  x_8: float32[1, 128, 12, 64][cpu] = Reshape(x_7, shape=[1, 128, 12, 64])
  x_9: float32[12, 128, 64][cpu] = Rearrange(x_8, plan=[[0, 2], [1], [3]])
  x_10: float32[12, 64, 128][cpu] = PermuteDims(x_9, axes=[0, 2, 1])
  x_11: float32[12, 128, 128][cpu] = Matmul(x_5, x_10, require_prologue=False)
  x_12: float32[12, 128, 128][cpu] = DivideScalar(x_11, scalar=8.0f)
  x_13: float32[12, 128, 128][cpu] = Add(x_12, x_1)
  x_14: float32[12, 128, 128][cpu] = Softmax(x_13, axis=2)
  x_15: float32[1, 128, 768][cpu] = Matmul(x, c_4, require_prologue=False)
  x_16: float32[1, 128, 768][cpu] = Add(x_15, c_5)
  x_17: float32[1, 128, 12, 64][cpu] = Reshape(x_16, shape=[1, 128, 12, 64])
  x_18: float32[12, 128, 64][cpu] = Rearrange(x_17, plan=[[0, 2], [1], [3]])
  x_19: float32[12, 128, 64][cpu] = Matmul(x_14, x_18, require_prologue=False)
  x_20: float32[1, 12, 128, 64][cpu] = Reshape(x_19, shape=[1, 12, 128, 64])
  x_21: float32[1, 128, 768][cpu] = Rearrange(x_20, plan=[[0], [2], [1, 3]])
  return x_21
}
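The same flow graph can also be built by tracing the model with symbolic inputs. Below is a minimal sketch, assuming hidet.symbol and hidet.trace_from from hidet's tracing API:

# create symbolic inputs (shape and dtype only, no data) and trace through the model
hidden_states = hidet.symbol([1, 128, 768], dtype='float32')
attention_mask = hidet.symbol([1, 128], dtype='int32')
output = model(hidden_states, attention_mask)
graph = hidet.trace_from(output, inputs=[hidden_states, attention_mask])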
Dump netron graph
To visualize the flow graph, we need to dump the graph structure to a JSON file using the hidet.utils.netron.dump() function.
from hidet.utils import netron
with open('attention-graph.json', 'w') as f:
    netron.dump(graph, f)
The above code generates a JSON file named attention-graph.json.
You can download the generated file attention-graph.json and open it with the customized Netron viewer.
Visualize optimization intermediate graphs
Hidet also provides a way to visualize the intermediate graphs of the optimization passes.
To get the JSON files for the intermediate graphs, we add an instrument to the pass context before optimizing the graph; the instrument dumps each intermediate graph as the passes run. We can use the PassContext.save_graph_instrument() method to do that.
with hidet.graph.PassContext() as ctx:
    # print the time cost of each pass
    ctx.profile_pass_instrument(print_stdout=True)
    # save the intermediate graph of each pass to './outs' directory
    ctx.save_graph_instrument(out_dir='./outs')
    # run the optimization passes
    graph_opt = hidet.graph.optimize(graph)
ConvChannelLastPass started...
ConvChannelLastPass 0.007 seconds
SubgraphRewritePass started...
SubgraphRewritePass 0.015 seconds
AutoMixPrecisionPass started...
AutoMixPrecisionPass 0.007 seconds
SelectiveQuantizePass started...
SubgraphRewritePass started...
SubgraphRewritePass 0.015 seconds
SelectiveQuantizePass 0.022 seconds
ResolveVariantPass started...
ResolveVariantPass 0.008 seconds
FuseOperatorPass started...
FuseOperatorPass 0.023 seconds
EliminateBarrierPass started...
EliminateBarrierPass 0.005 seconds
The above code generates a directory named outs that contains the JSON files for the intermediate graphs.
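Each dumped graph is a separate JSON file under ./outs (the exact file names depend on the hidet version). A minimal sketch to enumerate them:

import os

# list the dumped intermediate graphs; each file can be opened in the viewer
for name in sorted(os.listdir('./outs')):
    print(name)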
The optimized graph:
print(graph_opt)
Graph(x: float32[1, 128, 768][cpu], x_1: int32[1, 128][cpu]){
  c = Constant(float32[768, 768][cpu])
  c_1 = Constant(float32[768][cpu])
  c_2 = Constant(float32[768, 768][cpu])
  c_3 = Constant(float32[768][cpu])
  c_4 = Constant(float32[768, 768][cpu])
  c_5 = Constant(float32[768][cpu])
  x_2: float32[12, 128, 64][cpu] = FusedMatmul(x, c, c_1, fused_graph=FlowGraph(Matmul, Add, Reshape, Rearrange), anchor=0)
  x_3: float32[12, 64, 128][cpu] = FusedMatmul(x, c_2, c_3, fused_graph=FlowGraph(Matmul, Add, Reshape, Rearrange, PermuteDims), anchor=0)
  x_4: float32[12, 128, 128][cpu] = FusedMatmul(x_2, x_3, x_1, fused_graph=FlowGraph(Matmul, DivideScalar, Add), anchor=0)
  x_5: float32[12, 128, 128][cpu] = Softmax(x_4, axis=2)
  x_6: float32[12, 128, 64][cpu] = FusedMatmul(x, c_4, c_5, fused_graph=FlowGraph(Matmul, Add, Reshape, Rearrange), anchor=0)
  x_7: float32[1, 128, 768][cpu] = FusedMatmul(x_5, x_6, fused_graph=FlowGraph(Matmul, Reshape, Rearrange), anchor=0)
  return x_7
}
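The optimized graph can be dumped for the viewer in the same way, which makes it easy to compare the graphs before and after optimization (the file name attention-graph-opt.json below is an arbitrary choice):

# dump the optimized graph so it can be compared with the original in Netron
with open('attention-graph-opt.json', 'w') as f:
    netron.dump(graph_opt, f)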
Summary
This tutorial shows how to visualize the flow graph of a model and the intermediate graphs of the optimization passes.
Total running time of the script: (0 minutes 0.123 seconds)