# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Dict, Any, List, Optional, Callable, Iterable, Tuple, Union
import os
import tomlkit
class OptionRegistry:
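    """
    Metadata for one registered option. The class attribute ``registered_options`` maps
    dotted option names (e.g., 'compile_server.port') to their registry entries.
    """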
registered_options: Dict[str, OptionRegistry] = {}
def __init__(
self,
name: str,
type_hint: str,
description: str,
default_value: Any,
normalizer: Optional[Callable[[Any], Any]] = None,
choices: Optional[Iterable[Any]] = None,
checker: Optional[Callable[[Any], bool]] = None,
):
self.name = name
self.type_hint = type_hint
self.description = description
self.default_value = default_value
self.normalizer = normalizer
self.choices = choices
self.checker = checker
def create_toml_doc() -> tomlkit.TOMLDocument:
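    """
    Create a TOML document that holds every registered option, with each option's description as a comment.
    """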
def nest_flattened_dict(d: Dict[str, Any]) -> Dict[str, Any]:
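        # e.g., {'a.b': 1, 'a.c': 2, 'd': 3} -> {'a': {'b': 1, 'c': 2}, 'd': 3}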
new_dict = {}
for k, v in d.items():
if '.' in k:
prefix, suffix = k.split('.', 1)
if prefix not in new_dict:
new_dict[prefix] = {suffix: v}
else:
new_dict[prefix][suffix] = v
else:
new_dict[k] = v
for k, v in new_dict.items():
if isinstance(v, dict):
new_dict[k] = nest_flattened_dict(v)
return new_dict
def gen_doc(d: Dict[str, Any], toml_doc: tomlkit.TOMLDocument):
for k, v in d.items():
if isinstance(v, dict):
table = tomlkit.table()
gen_doc(v, table)
toml_doc.add(k, table)
elif isinstance(v, OptionRegistry):
toml_doc.add(tomlkit.comment(v.description))
if v.choices is not None:
toml_doc.add(tomlkit.comment(f' choices: {v.choices}'))
if isinstance(v.default_value, (bool, int, float, str)):
toml_doc.add(k, v.default_value)
                elif isinstance(v.default_value, tuple):
                    # represent tuples as TOML arrays; python lists are not allowed as default values to avoid ambiguity
val = list(v.default_value)
arr = tomlkit.array()
arr.extend(val)
toml_doc.add(k, arr)
else:
raise ValueError(f'Invalid type of default value for option {k}: {type(v.default_value)}')
toml_doc.add(tomlkit.nl())
else:
                raise ValueError(f'Invalid type of value for option {k}: {type(v)}')
fd = nest_flattened_dict(OptionRegistry.registered_options)
doc = tomlkit.document()
gen_doc(fd, doc)
return doc
def _load_config(config_file_path: str):
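    """
    Load a TOML config file and overwrite the default values of the registered options with its entries.
    """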
def collapse_nested_dict(d: Dict[str, Any]) -> Dict[str, Union[str, int, float, bool, Tuple]]:
# {"cuda": {"arch": "hopper", "cc": [9, 0]}} -> {"cuda.arch": 90, "cuda.cc": (9, 0)}
ret = {}
for k, v in d.items():
if isinstance(v, dict):
v = collapse_nested_dict(v)
for k1, v1 in v.items():
ret[f'{k}.{k1}'] = v1
continue
if isinstance(v, list):
v = tuple(v)
ret[k] = v
return ret
with open(config_file_path, 'r') as f:
config_doc = tomlkit.parse(f.read())
for k, v in collapse_nested_dict(config_doc).items():
if k not in OptionRegistry.registered_options:
raise KeyError(f'Option {k} found in config file {config_file_path} is not registered.')
OptionRegistry.registered_options[k].default_value = v
def _write_default_config(config_file_path: str, config_doc: tomlkit.TOMLDocument):
with open(config_file_path, 'w') as f:
tomlkit.dump(config_doc, f)
def register_option(
name: str,
type_hint: str,
description: str,
default_value: Any,
normalizer: Optional[Callable[[Any], Any]] = None,
choices: Optional[Iterable[Any]] = None,
checker: Optional[Callable[[Any], bool]] = None,
):
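    """
    Register a new option so that it can be set and queried through the option contexts.

    A sketch with a hypothetical option name, for illustration only:

    .. code-block:: python

        register_option(
            name='example.flag',  # hypothetical name; registering the same name twice raises KeyError
            type_hint='bool',
            description='An example flag.',
            default_value=False,
            choices=[True, False],
        )
    """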
registered_options = OptionRegistry.registered_options
if name in registered_options:
raise KeyError(f'Option {name} has already been registered.')
registered_options[name] = OptionRegistry(name, type_hint, description, default_value, normalizer, choices, checker)
def register_hidet_options():
from hidet.utils import git_utils
register_option(
name='bench_config',
type_hint='Tuple[int, int, int]',
description='The (warmup, number, repeat) parameters for benchmarking. '
'The benchmarking will run warmup + number * repeat times.',
default_value=(3, 10, 3),
)
register_option(
        name='search_space',
type_hint='int',
description='The search space level.',
default_value=0,
choices=[0, 1, 2],
)
register_option(
name='cache_operator',
type_hint='bool',
description='Whether to enable operator cache on disk.',
default_value=True,
choices=[True, False],
)
register_option(
name='cache_dir',
type_hint='path',
description='The directory to store the cache.',
default_value=os.path.abspath(
os.path.join(git_utils.git_repo_root(), '.hidet_cache') # developer mode
if git_utils.in_git_repo()
else os.path.join(os.path.expanduser('~'), '.cache', 'hidet') # user mode
),
normalizer=os.path.abspath,
)
register_option(
name='parallel_build',
type_hint='bool',
default_value=True,
description='Whether to build operators in parallel.',
choices=[True, False],
)
register_option(
name='parallel_tune',
        type_hint='Tuple[int, float]',
        default_value=(-1, 1.5),
        description='The pair (max_parallel_jobs, mem_gb_per_job) that describes '
        'the maximum number of parallel jobs and the memory (in GiB) reserved for each job.',
)
register_option(
name='save_lower_ir',
type_hint='bool',
default_value=False,
        description='Whether to save the IR when lowering an IRModule to the operator cache.',
choices=[True, False],
)
register_option(
name='debug_cache_tuning',
type_hint='bool',
default_value=False,
description='Whether to cache the generated kernels during tuning.',
choices=[True, False],
)
register_option(
name='debug_show_var_id',
type_hint='bool',
default_value=False,
description='Whether to show the variable id in the IR.',
choices=[True, False],
)
register_option(
name='runtime_check',
type_hint='bool',
default_value=True,
        description='Whether to check shapes and dtypes of all inputs to compiled graphs and tasks during execution.',
choices=[True, False],
)
register_option(
name='debug_show_verbose_flow_graph',
type_hint='bool',
default_value=False,
description='Whether to show the verbose flow graph.',
choices=[True, False],
)
register_option(
name='compile_server.addr',
type_hint='str',
default_value='localhost',
description='The address of the compile server. Can be an IP address or a domain name.',
)
register_option(
name='compile_server.port', type_hint='int', default_value=8329, description='The port of the compile server.'
)
register_option(
name='compile_server.enabled',
type_hint='bool',
default_value=False,
description='Whether to enable the compile server.',
choices=[True, False],
)
register_option(
name='compile_server.username',
type_hint='str',
default_value='admin',
description='The user name to access the compile server.',
)
register_option(
name='compile_server.password',
type_hint='str',
default_value='admin_password',
description='The password to access the compile server.',
)
register_option(
name='compile_server.repo_url',
type_hint='str',
default_value='https://github.com/hidet-org/hidet',
description='The URL of the repository that the remote server will use.',
)
register_option(
name='compile_server.repo_version',
type_hint='str',
default_value='main',
description='The version (e.g., branch, commit, or tag) that the remote server will use.',
)
register_option(
name='cuda.arch',
type_hint='str',
default_value='auto',
description='The CUDA architecture to compile the kernels for (e.g., "sm_70"). "auto" for auto-detect.',
)
    config_dir = os.path.join(os.path.expanduser('~'), '.config', 'hidet')
    if not os.path.exists(config_dir):
        os.makedirs(config_dir, exist_ok=True)
    config_file_path = os.path.join(config_dir, 'hidet.toml')
if not os.path.exists(config_file_path):
_write_default_config(config_file_path, create_toml_doc())
else:
_load_config(config_file_path)
register_hidet_options()
class OptionContext:
    """
    A stack-based option context. Entering the context (via ``with``) pushes it onto a stack;
    option lookups search the stack from top to bottom and fall back to the registered defaults.
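
    A minimal usage sketch (``search_space`` is an option registered in this module):

    .. code-block:: python

        with OptionContext() as ctx:
            ctx.set_option('search_space', 2)
            assert ctx.get_option('search_space') == 2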
"""
stack: List[OptionContext] = []
def __init__(self):
self.options: Dict[str, Any] = {}
    def __str__(self):
        return '\n'.join(f'{name}={value}' for name, value in self.options.items())
def __enter__(self):
"""
Enter the option context.
Returns
-------
ret: OptionContext
The option context itself.
"""
OptionContext.stack.append(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Exit the option context.
"""
OptionContext.stack.pop()
@staticmethod
def current() -> OptionContext:
return OptionContext.stack[-1]
def load_from_file(self, config_path: str):
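        """
        Load options from an INI-style config file into this context; values are read as strings
        and converted by each option's normalizer, if one is registered.
        """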
import configparser
config = configparser.ConfigParser()
config.read(config_path)
for section in config.sections():
for option in config.options(section):
value = config.get(section, option)
entry_name = '{}.{}'.format(section, option)
if entry_name not in OptionRegistry.registered_options:
raise KeyError(
'Option {} found in config file {} is not registered.'.format(entry_name, config_path)
)
self.set_option(entry_name, value)
def set_option(self, name: str, value: Any):
if name not in OptionRegistry.registered_options:
raise KeyError(f'Option {name} has not been registered.')
registry = OptionRegistry.registered_options[name]
if registry.normalizer is not None:
value = registry.normalizer(value)
if registry.checker is not None:
if not registry.checker(value):
raise ValueError(f'Invalid value for option {name}: {value}')
if registry.choices is not None:
if value not in registry.choices:
raise ValueError(f'Invalid value for option {name}: {value}, choices {registry.choices}')
self.options[name] = value
def get_option(self, name: str) -> Any:
for ctx in reversed(OptionContext.stack):
if name in ctx.options:
return ctx.options[name]
if name not in OptionRegistry.registered_options:
raise KeyError(f'Option {name} has not been registered.')
registry = OptionRegistry.registered_options[name]
return registry.default_value
OptionContext.stack.append(OptionContext())
def dump_options() -> Dict[str, Any]:
    """
    Dump the options in the option context stack.
Returns
-------
ret: Dict[str, Any]
The dumped options.
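
    A sketch of saving and restoring the option state (e.g., to pass it to another process):

    .. code-block:: python

        state = hidet.option.dump_options()
        # ... later, or in another process ...
        hidet.option.restore_options(state)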
"""
return {'option_context_stack': OptionContext.stack, 'registered_options': OptionRegistry.registered_options}
def restore_options(dumped_options: Dict[str, Any]):
"""
Restore the options from dumped options.
Parameters
----------
dumped_options: Dict[str, Any]
The dumped options.
"""
OptionContext.stack = dumped_options['option_context_stack']
OptionRegistry.registered_options = dumped_options['registered_options']
def current_context() -> OptionContext:
"""
Get the current option context.
To get the value of an option in the current context:
.. code-block:: python
ctx = hidet.option.current_context()
cache_dir: str = ctx.get_option('cache_dir')
cache_operator: bool = ctx.get_option('cache_operator')
...
Returns
-------
ctx: OptionContext
The current option context.
"""
return OptionContext.current()
def context() -> OptionContext:
"""
Create a new option context.
To set options in the new context, use the ``with`` statement:
.. code-block:: python
with hidet.option.context() as ctx:
hidet.option.cache_dir('./new_cache_dir') # set predefined option
        hidet.option.set_option('parallel_build', False)  # set any registered option by name
...
Returns
-------
ctx: OptionContext
The new option context.
"""
return OptionContext()
def set_option(name: str, value: Any):
    """
    Set the value of an option in the current option context.
    The option must be registered via :py:func:`hidet.option.register_option` before it can be set.
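
    For example:

    .. code-block:: python

        hidet.option.set_option('search_space', 1)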
Parameters
----------
name: str
The name of the option.
value: Any
The value of the option.
"""
OptionContext.current().set_option(name, value)
def get_option(name: str) -> Any:
    """
    Get the value of an option in the current option context.
Parameters
----------
name: str
The name of the option.
Returns
-------
ret: Any
The value of the option.
"""
return OptionContext.current().get_option(name)
def bench_config(warmup: int = 1, number: int = 5, repeat: int = 5):
"""
Set the benchmark config of operator tuning.
To profile a schedule, hidet will run the following code:
    .. code-block:: python

        for i in range(warmup):
            run()
        latency = []
        for i in range(repeat):
            # synchronize the device
            t1 = time()
            for j in range(number):
                run()
            # synchronize the device
            t2 = time()
            latency.append((t2 - t1) / number)
        return median(latency)

    Thus, there will be a total of ``warmup + number * repeat`` runs.
Parameters
----------
warmup: int
The number of warmup runs.
number: int
The number of runs in a repeat.
repeat: int
The number of repeats.
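
    For example:

    .. code-block:: python

        hidet.option.bench_config(warmup=1, number=10, repeat=3)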
"""
OptionContext.current().set_option('bench_config', (warmup, number, repeat))
def get_bench_config() -> Tuple[int, int, int]:
"""
Get the benchmark config of operator tuning.
Returns
-------
ret: Tuple[int, int, int]
The benchmark config.
"""
return OptionContext.current().get_option('bench_config')
def search_space(space: int):
"""
Set the schedule search space of tunable operator.
    Some operators in hidet, such as matrix multiplication, can be tuned to achieve the best performance.
    During tuning, different operator schedules are tried and profiled to find the best one.
    We call the space of the tried schedules the `schedule space`. There is a trade-off between tuning
    time and operator execution time: the more schedules we try, the longer tuning takes, but the more
    likely we are to find a better schedule.
    This function sets the space level that controls the size of the searched schedule space.
By convention, we have space level
- 0 for schedule space contains only a single schedule.
- 1 for schedule space contains tens of schedules so that the tuning time will be less than 1 minute.
- 2 for arbitrary large space.
Usage
.. code-block:: python
        hidet.option.search_space(2)

    After calling the function above, all subsequent compilations will use space level 2, until this
    function is called again with another space level.
Parameters
----------
space: int
The space level to use. Candidates: 0, 1, and 2.
"""
OptionContext.current().set_option('search_space', space)
def get_search_space() -> int:
"""
Get the schedule search space of tunable operator.
Returns
-------
ret: int
The schedule space level.
"""
return OptionContext.current().get_option('search_space')
def cache_operator(enabled: bool = True):
"""
Whether to cache compiled operator on disk.
    By default, hidet caches all compiled operators and reuses them whenever possible.
    To disable the cache, run
.. code-block:: python
hidet.option.cache_operator(False)
Parameters
----------
enabled: bool
Whether to cache the compiled operator.
"""
OptionContext.current().set_option('cache_operator', enabled)
def get_cache_operator() -> bool:
"""
Get the option value of whether to cache compiled operator on disk.
Returns
-------
ret: bool
Whether to cache the compiled operator.
"""
return OptionContext.current().get_option('cache_operator')
def cache_dir(new_dir: str):
"""
Set the directory to store the cache.
The default cache directory:
- If the hidet code is in a git repo, the cache will be stored in the repo root:
``hidet-repo/.hidet_cache``.
    - Otherwise, the cache will be stored in the user cache directory: ``~/.cache/hidet``.
Parameters
----------
new_dir: str
The new directory to store the cache.
"""
OptionContext.current().set_option('cache_dir', new_dir)
def get_cache_dir() -> str:
"""
Get the directory to store the cache.
Returns
-------
ret: str
The directory to store the cache.
"""
return OptionContext.current().get_option('cache_dir')
def parallel_build(enabled: bool = True):
"""
Whether to build operators in parallel.
Parameters
----------
enabled: bool
Whether to build operators in parallel.
"""
OptionContext.current().set_option('parallel_build', enabled)
def get_parallel_build() -> bool:
"""
Get the option value of whether to build operators in parallel.
Returns
-------
ret: bool
Whether to build operators in parallel.
"""
return OptionContext.current().get_option('parallel_build')
def parallel_tune(max_parallel_jobs: int = -1, mem_gb_per_job: float = 1.5):
"""
    Specify the maximum number of parallel compilation jobs to run,
    and the amount of memory (in GiB) reserved for each job.
Parameters
----------
max_parallel_jobs: int
        The maximum number of parallel jobs allowed; -1 (the default) means using the
        number of available vCPUs returned by ``os.cpu_count()``.
mem_gb_per_job: float
The minimum amount of memory (in GiB) reserved for each tuning job, default 1.5GiB.
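
    For example:

    .. code-block:: python

        hidet.option.parallel_tune(max_parallel_jobs=4, mem_gb_per_job=2.0)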
"""
OptionContext.current().set_option('parallel_tune', (max_parallel_jobs, mem_gb_per_job))
def get_parallel_tune() -> Tuple[int, float]:
    """
    Get the parallel tuning configuration.
Returns
-------
ret: Tuple[int, float]
        The maximum number of parallel jobs and the minimum amount of memory (in GiB) reserved for each tuning job.
"""
return OptionContext.current().get_option('parallel_tune')
def save_lower_ir(enabled: bool = True):
"""
Whether to save the lower IR.
Parameters
----------
enabled: bool
Whether to save the lower IR.
"""
OptionContext.current().set_option('save_lower_ir', enabled)
def get_save_lower_ir() -> bool:
"""
Get the option value of whether to save the lower IR.
"""
return OptionContext.current().get_option('save_lower_ir')
def debug_cache_tuning(enabled: bool = True):
"""
Whether to cache the generated kernels during tuning.
.. note::
        This option is only used for debugging purposes. It generates a lot of files in the cache directory
        and takes a lot of disk space.
Parameters
----------
enabled: bool
Whether to debug cache tuning.
"""
OptionContext.current().set_option('debug_cache_tuning', enabled)
def debug_show_var_id(enable: bool = True):
"""
Whether to show the var id in the IR.
    When this option is enabled, the IR will show each var id in the format ``var@id``, like ``x@1`` and ``d_1@1732``.
    Two variables ``a`` and ``b`` (i.e., instances of ``hidet.ir.Var``) are the same variable if and only if
    ``a is b`` evaluates to True in Python.
Parameters
----------
enable: bool
Whether to show the var id in the IR.
"""
OptionContext.current().set_option('debug_show_var_id', enable)
def runtime_check(enable: bool = True):
"""
Whether to check shapes and dtypes of all input arguments to compiled Graphs or Tasks.
Parameters
----------
enable: bool
Whether to check shapes and dtypes of all input arguments to compiled Graphs or Tasks.
"""
OptionContext.current().set_option('runtime_check', enable)
def get_runtime_check() -> bool:
"""
Get whether to check shapes and dtypes of all input arguments to compiled Graphs or Tasks.
Returns
-------
ret: bool
Get whether to check shapes and dtypes of all input arguments to compiled Graphs or Tasks.
"""
return OptionContext.current().get_option('runtime_check')
def debug_show_verbose_flow_graph(enable: bool = True):
    """Whether to show verbose information (like the task) when converting a flow graph into human-readable text.
Parameters
----------
enable: bool
        Whether to show verbose information when converting a flow graph into human-readable text.
"""
OptionContext.current().set_option('debug_show_verbose_flow_graph', enable)
class cuda:
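    """
    The CUDA related options.
    """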
@staticmethod
def arch(arch: str = 'auto'):
"""
Set the CUDA architecture to use when building CUDA kernels.
Parameters
----------
        arch: str
The CUDA architecture, e.g., 'sm_35', 'sm_70', 'sm_80', etc. "auto" means
using the architecture of the first CUDA GPU on the current machine. Default "auto".
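
        For example:

        .. code-block:: python

            hidet.option.cuda.arch('sm_80')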
"""
OptionContext.current().set_option('cuda.arch', arch)
@staticmethod
def get_arch() -> str:
"""
Get the CUDA architecture to use when building CUDA kernels.
Returns
-------
ret: str
The CUDA architecture, e.g., 'sm_35', 'sm_70', 'sm_80', etc.
"""
arch: Optional[str] = OptionContext.current().get_option('cuda.arch')
if arch == "auto":
import hidet.cuda
# get the architecture of the first CUDA GPU
properties = hidet.cuda.properties(0)
arch = 'sm_{}{}'.format(properties.major, properties.minor)
return arch
@staticmethod
def get_arch_pair() -> Tuple[int, int]:
"""
Get the CUDA architecture to use when building CUDA kernels, with major and minor version as a tuple.
Returns
-------
ret: Tuple[int, int]
The CUDA architecture, e.g., (3, 5), (7, 0), (8, 0), etc.
"""
arch = cuda.get_arch()
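        # parse the compute capability from a name like 'sm_70' -> (7, 0)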
return int(arch[3]), int(arch[4])
class compile_server:
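    """
    Options for the remote compile server. A usage sketch (the address below is a placeholder):

    .. code-block:: python

        hidet.option.compile_server.addr('compile-server.example.com')
        hidet.option.compile_server.enable()
    """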
@staticmethod
def addr(addr: str):
OptionContext.current().set_option('compile_server.addr', addr)
@staticmethod
def port(port: int):
OptionContext.current().set_option('compile_server.port', port)
@staticmethod
def enable(flag: bool = True):
OptionContext.current().set_option('compile_server.enabled', flag)
@staticmethod
def enabled() -> bool:
return OptionContext.current().get_option('compile_server.enabled')
@staticmethod
def username(username: str):
OptionContext.current().set_option('compile_server.username', username)
@staticmethod
def password(password: str):
OptionContext.current().set_option('compile_server.password', password)
@staticmethod
def repo(repo_url: str, version: str = 'main'):
OptionContext.current().set_option('compile_server.repo_url', repo_url)
OptionContext.current().set_option('compile_server.repo_version', version)
# load the options from config file (e.g., ~/.config/hidet.config) if exists
_config_path = os.path.join(os.path.expanduser('~'), '.config', 'hidet.config')
if os.path.exists(_config_path):
OptionContext.current().load_from_file(_config_path)