# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import List, Optional, Dict, Any
import logging
import hidet.option
from hidet.graph.flow_graph import FlowGraph
from hidet.graph.transforms.graph_patterns.base import SubgraphRewriteRule
from .instruments import GraphPassInstrument
logger = logging.Logger(name='hidet.graph.transforms', level=logging.INFO)
logger.addHandler(logging.StreamHandler())
class PassContext:
"""Graph-level pass context.
Use the pass context to control the behavior of optimization passes. Normally, we can
optimize a flow graph by directly calling :func:`hidet.graph.optimize`:
.. code-block:: python
graph_opt = hidet.graph.optimize(graph)
This will optimize the given flow graph in a default context.
To customize the optimizations, run the :func:`~hidet.graph.optimize` function with in
a custom :class:`hidet.graph.PassContext`:
.. code-block:: python
with hidet.graph.PassContext() as ctx:
# config the contexts
ctx.profile_pass_instrument(print_stdout=True) # print elapsed time for each pass
ctx.save_graph_instrument(out_dir='./outs') # save the output of each pass as text
ctx.set_precision(dtype='float16') # use float16 as the data type
ctx.set_reduce_precision(dtype='float32') # use float32 for reduction accumulation
ctx.set_mma('mma') # use TensorCore in NVIDIA GPUs to accelerate matmul and conv2d
... # other configs
# call optimize function
graph_opt = hidet.graph.optimize(graph)
Please refer to the member functions of this class for the available configs and their usage.
Attributes
----------
instruments: List[GraphPassInstrument]
The graph pass instruments that will be applied before and after each pass. The instruments
will be applied in order. See :class:`hidet.graph.GraphPassInstrument` on how to add custom
instrument.
configs: Dict[str, Any]
The current configs of the pass context.
"""
_stack: List['PassContext'] = []
def __init__(self):
self.instruments: List[GraphPassInstrument] = []
self.configs: Dict[str, Any] = {
# target precision:
# [None, 'int8', 'float16', 'bfloat16', 'float32']
'precision': None,
# selectively quantize the given graph patterns
'quantize_patterns': [],
# target reduce precision:
# [None, 'float16', 'float32']
'reduce_precision': None,
# use attention or not
# [True, False]
'use_attention': False,
# mma primitive:
# ['simt', 'mma']
'mma': 'simt',
# parallel k
# ['default', 'disabled', 'search', 2, 4, ...]
'parallel_k': 'default',
# print lower details
'verbose': False,
}
def __enter__(self) -> PassContext:
self._stack.append(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
from ..transforms.graph_patterns.attn_patterns import deregister_attn_patterns
deregister_attn_patterns()
popped = self._stack.pop()
assert popped == self
@classmethod
def current(cls):
"""
Get the current pass context.
Returns
-------
ret: PassContext
The current pass context.
"""
if len(cls._stack) == 0:
cls._stack.append(PassContext())
return cls._stack[-1]
def set_precision(self, dtype: Optional[str] = None) -> PassContext:
"""
Set the target precision to use as the output of most operators. To retain the accuracy,
some operators will still use the original data type.
Parameters
----------
dtype: Optional[str]
The target dtype to mix the precision of the model. Candidates:
- None
Do not mix the precision.
- 'int8'
Convert the model into float16 data type, then selectively quantize subgraphs
using the default quantize patterns.
For greater flexibility and control over quantization, use :meth:`add_quantize_rules`
to selectively quantize subgraphs with custom rewrite rules.
- 'float16'
Convert the model into float16 data type.
- 'bfloat16'
Convert the model into bfloat16 data type.
- 'float32'
Convert the model into float32 data type.
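
For example, a minimal usage sketch of mixed-precision optimization (assuming ``graph`` is a
:class:`FlowGraph` built elsewhere):

.. code-block:: python

    with hidet.graph.PassContext() as ctx:
        ctx.set_precision(dtype='float16')
        graph_opt = hidet.graph.optimize(graph)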
"""
if dtype == 'int8':
self.add_quantize_rules(hidet.graph.quant.default_patterns())
self.configs['precision'] = 'float16'
else:
self.configs['precision'] = dtype
return self
def add_quantize_rules(self, patterns: List[SubgraphRewriteRule]) -> PassContext:
"""
Adds selective quantization rules to the pass context.
Parameters
----------
patterns: Optional[List[SubgraphRewriteRule]]
The rewrite rules that describe which subgraphs to selectively quantize.
- List[SubgraphRewriteRule]
Add new rules on top of the existing ones. The new rules will be applied
after the existing ones.
- None
Clear all previously added quantization rules.
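
For example, a sketch that registers the default int8 patterns (the same patterns that
``set_precision('int8')`` registers internally):

.. code-block:: python

    with hidet.graph.PassContext() as ctx:
        ctx.set_precision(dtype='float16')
        ctx.add_quantize_rules(hidet.graph.quant.default_patterns())
        graph_opt = hidet.graph.optimize(graph)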
"""
if patterns is not None:
for pat in patterns:
if isinstance(pat, SubgraphRewriteRule):
self.configs['quantize_patterns'].append(pat)
elif issubclass(pat, SubgraphRewriteRule):
self.configs['quantize_patterns'].append(pat())
else:
self.configs['quantize_patterns'] = []
return self
def set_reduce_precision(self, dtype: Optional[str] = None) -> PassContext:
"""
Set the target precision used to accumulate reduction results. Operators like reduce_mean, reduce_avg,
matrix multiplication and convolution reduce along some dimensions. We might want to use a data type
with more precision to accumulate the results for better accuracy.
Parameters
----------
dtype: Optional[str]
The target dtype to use for accumulation.
- None
Use the same as inputs of operators.
- 'float16'
Use 'float16' to accumulate. Only valid when set_precision('float16') has been used.
- 'float32'
Use 'float32' to accumulate.
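
For example, a sketch that computes in float16 but accumulates reductions in float32:

.. code-block:: python

    with hidet.graph.PassContext() as ctx:
        ctx.set_precision(dtype='float16')
        ctx.set_reduce_precision(dtype='float32')
        graph_opt = hidet.graph.optimize(graph)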
"""
self.configs['reduce_precision'] = dtype
return self
def set_use_attention(self, flag=False) -> PassContext:
"""
Set whether to use the fused attention schedule. Fused attention requires a GPU with
compute capability 7.5 or higher; on older devices this call has no effect.
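
A minimal sketch (again assuming ``graph`` is an existing :class:`FlowGraph`):

.. code-block:: python

    with hidet.graph.PassContext() as ctx:
        ctx.set_use_attention(True)
        graph_opt = hidet.graph.optimize(graph)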
"""
# fmha requires sm75+
cc = hidet.option.cuda.get_arch_pair()
if cc < (7, 5):
return self
from ..transforms.graph_patterns.attn_patterns import register_attn_patterns, deregister_attn_patterns
self.configs['use_attention'] = flag
if flag:
register_attn_patterns()
else:
deregister_attn_patterns()
return self
def set_verbose(self) -> PassContext:
"""
Allow each graph-level pass to print detailed information related to its lowering and optimization.
"""
self.configs['verbose'] = True
return self
def set_mma(self, mma: str) -> PassContext:
"""
Specify the matrix-multiply-accumulate (mma) computation primitives used in matrix multiplication and
convolution.
Parameters
----------
mma: str
The mma computation primitive to use. Candidates:
- 'simt'
Use CUDA cores.
- 'mma'
Use Tensor Core mma instructions.
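
For example, to select the Tensor Core path:

.. code-block:: python

    with hidet.graph.PassContext() as ctx:
        ctx.set_mma('mma')  # use mma instructions where applicable
        graph_opt = hidet.graph.optimize(graph)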
"""
self.configs['mma'] = mma
return self
def set_parallel_k(self, disabled=False, default=False, search=False, nparts: Optional[int] = None):
"""
Set the strategy to parallelize over the reduction dimension for matrix multiplication and convolution.
Only one of the following parameters should be specified.
Parameters
----------
disabled: bool
Disable parallelization over the reduction dimension.
default: bool
Allow hidet to figure out the parallel factor.
search: bool
Search for the best parallel factor.
nparts: Optional[int]
Use a fixed parallel factor.
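
For example, a sketch that fixes the split factor to 4 (an illustrative value; any single
option above could be used instead):

.. code-block:: python

    with hidet.graph.PassContext() as ctx:
        ctx.set_parallel_k(nparts=4)
        graph_opt = hidet.graph.optimize(graph)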
"""
if sum([disabled, default, search, nparts is not None]) > 1:
raise ValueError('Only one of the parameters should be set.')
if disabled:
self.configs['parallel_k'] = 'disabled'
if default:
self.configs['parallel_k'] = 'default'
if search:
self.configs['parallel_k'] = 'search'
if nparts is not None:
self.configs['parallel_k'] = nparts
def save_graph_instrument(self, out_dir) -> PassContext:
"""
Save the computation graph after each pass to the given output directory.
Parameters
----------
out_dir: str
The directory in which to save the graphs.
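
For example (the output directory name is illustrative):

.. code-block:: python

    with hidet.graph.PassContext() as ctx:
        ctx.save_graph_instrument(out_dir='./outs')
        graph_opt = hidet.graph.optimize(graph)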
"""
from .instruments.save_graph_instrument import SaveGraphInstrument # pylint: disable=import-outside-toplevel
self.instruments.append(SaveGraphInstrument(out_dir))
return self
def profile_pass_instrument(self, log_file: Optional[str] = None, print_stdout: bool = False) -> PassContext:
"""
Profile the time of each pass.
Parameters
----------
log_file: Optional[str]
When given, write the elapsed time for each pass to this file.
print_stdout: bool
Whether to print the elapsed time for each pass to standard output.
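
For example, a sketch that both writes a log file and prints to stdout (the file name is
illustrative):

.. code-block:: python

    with hidet.graph.PassContext() as ctx:
        ctx.profile_pass_instrument(log_file='./pass_time.log', print_stdout=True)
        graph_opt = hidet.graph.optimize(graph)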
"""
from .instruments.profile_instrument import ProfileInstrument # pylint: disable=import-outside-toplevel
self.instruments.append(ProfileInstrument(log_file, print_stdout))
return self
def reduce_cuda_compile_mem(self, enable: Optional[bool] = None):
"""
Reduce the CUDA memory used during compilation by using vcuda tensors; this might incur extra compile-time cost.
Parameters
----------
enable: Optional[bool]
When given, always enable or disable this instrument accordingly.
If not given, the compiler decides whether to enable it based on heuristics.
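
For example, to force the instrument on:

.. code-block:: python

    with hidet.graph.PassContext() as ctx:
        ctx.reduce_cuda_compile_mem(enable=True)
        graph_opt = hidet.graph.optimize(graph)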
"""
from .instruments import ConvertGraphToVCuda # pylint: disable=import-outside-toplevel
self.instruments.append(ConvertGraphToVCuda(enable))
class GraphPass:
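"""
Base class for graph-level passes. A subclass overrides :meth:`process_graph`; the
``__call__`` wrapper applies the instruments of the current :class:`PassContext` before
and after the pass runs. A minimal sketch of a custom pass (the class name and body are
illustrative, not part of hidet):

.. code-block:: python

    class IdentityPass(GraphPass):
        def process_graph(self, graph: FlowGraph) -> FlowGraph:
            # inspect or rewrite the graph here; returning it unchanged is a no-op
            return graph
"""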
def __init__(self):
self.name = self.__class__.__name__
def __call__(self, graph: FlowGraph) -> FlowGraph:
ctx = PassContext.current()
for inst in ctx.instruments:
inst.before_pass(self.name, graph)
graph = self.process_graph(graph)
for inst in reversed(ctx.instruments):
inst.after_pass(self.name, graph)
return graph
@staticmethod
def current_context() -> PassContext:
return PassContext.current()
def process_graph(self, graph: FlowGraph) -> FlowGraph:
raise NotImplementedError()