# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
import os
import pickle
import time
import warnings
from typing import Callable, Dict, Optional

from hidet.ir.type import FuncType, PointerType, DataType, BaseType, VoidType, TensorPointerType
from hidet.ffi.shared_lib import SharedLibrary
from hidet.ffi.utils import c_pointer_compatible

class CompiledModuleLoadError(Exception):
    """Raised when a compiled module cannot be loaded from its directory (e.g. a required file is missing)."""


class CompiledFunction:
    """
    A compiled function that can be directly called.

    Wraps a ctypes function object loaded from a compiled module's shared library,
    configures its ctypes ``argtypes``/``restype`` from the hidet ``FuncType``, and
    checks the backend's last-error status after every call.

    Parameters
    ----------
    name: str
        The name of the function; used in error messages.
    func_type: FuncType
        The hidet signature of the function; its ``param_types`` and ``ret_type`` are
        translated to ctypes types.
    ctypes_func: Callable
        The ctypes function object to wrap.
    """

    def __init__(self, name, func_type: FuncType, ctypes_func):
        self.name: str = name
        self.func_type: FuncType = func_type
        self.ctypes_func: Callable = ctypes_func
        self._update_func_signature()

    def __call__(self, *args):
        """
        Call the compiled function and check for a backend error.

        Returns
        -------
        The value returned by the ctypes function (converted per ``restype``).

        Raises
        ------
        BackendException
            If ``get_last_error()`` reports a non-None status after the call.
        """
        from hidet.ffi.ffi import BackendException, get_last_error

        ret = self.ctypes_func(*args)

        # Errors are reported out-of-band by the backend, so poll after every call.
        status = get_last_error()
        if status is not None:
            msg = 'Calling {} with arguments {} failed. error:\n{}'.format(self.name, args, status)
            raise BackendException(msg)

        return ret

    def _parse_type(self, hidet_type: BaseType):
        """
        Translate a hidet type into the ctypes type used for argument/return conversion.

        Returns
        -------
        A ctypes scalar type for supported ``DataType`` values, ``None`` for
        ``VoidType`` (ctypes' convention for a void return), or
        ``c_pointer_compatible`` for ``PointerType``/``TensorPointerType``.

        Raises
        ------
        NotImplementedError
            If the type has no ctypes counterpart (e.g. float16 or complex types).
        """
        if isinstance(hidet_type, DataType):
            from hidet.ir import dtypes

            mapping = {
                dtypes.int8: ctypes.c_int8,
                dtypes.int16: ctypes.c_int16,
                dtypes.int32: ctypes.c_int32,
                dtypes.int64: ctypes.c_int64,
                dtypes.uint8: ctypes.c_uint8,
                dtypes.uint16: ctypes.c_uint16,
                dtypes.uint32: ctypes.c_uint32,
                dtypes.uint64: ctypes.c_uint64,
                # dtypes.float16: sadly, there is no float16 in ctypes for now, we might need to create a custom type
                dtypes.float32: ctypes.c_float,
                dtypes.float64: ctypes.c_double,
                dtypes.boolean: ctypes.c_bool,
                # dtypes.complex64:
                # dtypes.complex128:
            }
            if hidet_type not in mapping:
                raise NotImplementedError('Unsupported type {}'.format(hidet_type))
            return mapping[hidet_type]
        elif isinstance(hidet_type, VoidType):
            return None
        elif isinstance(hidet_type, (PointerType, TensorPointerType)):
            return c_pointer_compatible
        else:
            raise NotImplementedError('Unsupported type {}'.format(hidet_type))

    def _update_func_signature(self):
        """Set ``argtypes`` and ``restype`` on the wrapped ctypes function from ``self.func_type``."""
        self.ctypes_func.argtypes = [self._parse_type(hidet_type) for hidet_type in self.func_type.param_types]
        self.ctypes_func.restype = self._parse_type(self.func_type.ret_type)

    def profile(self, *args, warmup=1, number=2, repeat=10):
        """
        Measure the latency of the compiled function.

        Performs ``warmup`` untimed calls, then ``repeat`` timed rounds of ``number``
        calls each. The current CUDA stream is synchronized before and after each round
        so asynchronous work is included in the measurement. Calls made here go straight
        to the ctypes function and skip the error check done in ``__call__``.

        Returns
        -------
        List[float]
            For each round, the average latency of one call in milliseconds.
        """
        from hidet.cuda import current_stream

        for _ in range(warmup):
            self.ctypes_func(*args)

        results = []
        for _ in range(repeat):
            current_stream().synchronize()
            # perf_counter is monotonic and high-resolution; time.time() can jump or be coarse.
            start = time.perf_counter()
            for _ in range(number):
                self.ctypes_func(*args)
            current_stream().synchronize()
            end = time.perf_counter()
            results.append((end - start) / number * 1000)
        return results


class CompiledModule:
    """
    A compiled module loaded from a directory on disk.

    The directory must contain ``lib.so`` (the shared library) and
    ``func_types.pickle`` (a pickled ``Dict[str, FuncType]``). Each entry ``name`` in
    the pickle is bound to the shared-library symbol ``hidet_<name>``.

    Parameters
    ----------
    module_dir: str
        Path to the module directory.

    Raises
    ------
    CompiledModuleLoadError
        If ``lib.so`` or ``func_types.pickle`` is missing.
    """

    def __init__(self, module_dir: str):
        self.module_dir: str = module_dir
        self.shared_library: SharedLibrary = self._load_shared_library()
        self.functions: Dict[str, CompiledFunction] = self._load_functions()

    def __call__(self, *args):
        """Invoke the module's ``launch`` function; raise RuntimeError if it does not exist."""
        if 'launch' not in self.functions:
            raise RuntimeError('Launch function not found.')
        return self.functions['launch'](*args)

    def __getitem__(self, item: str) -> 'CompiledFunction':
        """Return the compiled function named ``item`` (KeyError if absent)."""
        return self.functions[item]

    def _load_shared_library(self):
        """Load ``lib.so`` from the module directory."""
        lib_path = os.path.join(self.module_dir, 'lib.so')
        if not os.path.exists(lib_path):
            raise CompiledModuleLoadError('Shared library {} does not exist.'.format(lib_path))
        return SharedLibrary(lib_path)

    def _load_functions(self):
        """Build a ``CompiledFunction`` for every signature stored in ``func_types.pickle``."""
        func_types_path = os.path.join(self.module_dir, 'func_types.pickle')
        if not os.path.exists(func_types_path):
            raise CompiledModuleLoadError('Function types {} does not exist.'.format(func_types_path))
        # NOTE(security): pickle.load can execute arbitrary code; only load module
        # directories produced by a trusted build, never untrusted input.
        with open(func_types_path, 'rb') as f:
            func_types: Dict[str, FuncType] = pickle.load(f)
        return {
            name: CompiledFunction(name, func_type, self.shared_library['hidet_' + name])
            for name, func_type in func_types.items()
        }

    def source(self, color=False) -> Optional[str]:
        """
        Return the module's source code, or ``None`` if no source file is present.

        Looks for ``source.cc`` first, then ``source.cu``, in the module directory.

        Parameters
        ----------
        color: bool
            If True and ``pygments`` is installed, return the source highlighted for a
            256-color terminal. If pygments is missing, emit a warning and return the
            plain source.
        """
        candidates = (os.path.join(self.module_dir, name) for name in ('source.cc', 'source.cu'))
        src_path = next((path for path in candidates if os.path.exists(path)), None)
        if src_path is None:
            return None

        with open(src_path, 'r') as f:
            src_code = f.read()

        if color:
            import importlib.util

            if importlib.util.find_spec('pygments'):
                from pygments import highlight
                from pygments.lexers import CudaLexer
                from pygments.formatters import Terminal256Formatter

                return highlight(src_code, CudaLexer(), Terminal256Formatter(style='autumn'))
            warnings.warn('pygments is not installed, please install it to enable colorized source code.')
        return src_code

    def profile(self, *args, warmup=1, number=2, repeat=10):
        """Profile the module's ``launch`` function and return per-round latencies in milliseconds."""
        return self['launch'].profile(*args, warmup=warmup, number=number, repeat=repeat)


def load_compiled_module(module_dir: str) -> CompiledModule:
    """Load the compiled module stored in ``module_dir``."""
    return CompiledModule(module_dir)


def compiled_module_exists(module_dir: str) -> bool:
    """Return True if ``module_dir`` contains both ``lib.so`` and ``func_types.pickle``."""
    required_files = ['lib.so', 'func_types.pickle']
    return all(os.path.exists(os.path.join(module_dir, file)) for file in required_files)