# Source code for hidet.drivers.build_module

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence, Dict, Union
import logging
import os
import pickle
import random
from tqdm import tqdm

import hidet.cuda
from hidet.backend import codegen, compile_source
from hidet.drivers.utils import lazy_initialize_cuda
from hidet.ir.module import IRModule
from hidet.ir.type import FuncType
from hidet.ir.target import Target
from hidet.transforms import lower, PassContext, SaveIRInstrument, ProfileInstrument
from hidet.utils.multiprocess import parallel_imap_2ndlevel, get_parallel_num_workers
from hidet.utils.stack_limit import set_stack_limit
from hidet.utils.folder_lock import FolderLock

# Obtain the module logger through the logging registry instead of constructing
# ``logging.Logger`` directly, so that it can be looked up and configured with
# ``logging.getLogger(<module name>)`` like any other logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# A directly constructed Logger has no parent and therefore never forwarded
# records to ancestor loggers; keep that behaviour to avoid duplicated output.
logger.propagate = False
# The registry returns the same object on every call, so only attach the stream
# handler once (e.g. if this module is reloaded).
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())


def can_remote_build(ir_module: Union[IRModule, Sequence[IRModule]]) -> bool:
    """
    Check whether the given IR module(s) may be handed to the remote build path.

    A module qualifies only when it lists no object files, no linking
    directories and no include directories (presumably because such local
    paths are not available to a remote compiler -- confirm against the
    compile server implementation).

    Parameters
    ----------
    ir_module: Union[IRModule, Sequence[IRModule]]
        A single IR module, or an iterable of IR modules.

    Returns
    -------
    ret: bool
        True if the module (or every module of the iterable) qualifies.
        An empty iterable yields True.
    """

    def has_no_local_dependencies(module: IRModule) -> bool:
        return len(module.object_files) == 0 and len(module.linking_dirs) == 0 and len(module.include_dirs) == 0

    if isinstance(ir_module, IRModule):
        return has_no_local_dependencies(ir_module)
    # Anything that is not a single IRModule is treated as a collection of them.
    return all(has_no_local_dependencies(m) for m in ir_module)


def create_instruments(output_dir: str, ir_module: IRModule):
    """
    Create the pass instruments used while lowering ``ir_module``.

    Parameters
    ----------
    output_dir: str
        The build output directory; instrument output goes to
        ``<output_dir>/ir/<ir_module.namespace>``.
    ir_module: IRModule
        The IR module about to be lowered. Only its ``namespace`` is read.

    Returns
    -------
    ret: list
        An empty list when ``hidet.option.get_save_lower_ir()`` is falsy.
        Otherwise a ``SaveIRInstrument`` followed by a ``ProfileInstrument``
        that logs to ``lower_time.txt`` in the directory above.
    """
    if not hidet.option.get_save_lower_ir():
        return []

    ir_candidate_dir = os.path.join(output_dir, 'ir', ir_module.namespace)
    lower_time_log = os.path.join(ir_candidate_dir, 'lower_time.txt')
    return [SaveIRInstrument(out_dir=ir_candidate_dir), ProfileInstrument(log_file=lower_time_log)]


def configure_target(target):
    """
    Apply the architecture attributes of ``target`` to the hidet options.

    - ``cuda`` target: ``attrs['arch']`` is passed to ``hidet.option.cuda.arch``
      and ``attrs['cpu_arch']`` to ``hidet.option.cpu.arch``, each only if present.
    - ``cpu`` target: ``attrs['arch']`` is passed to ``hidet.option.cpu.arch`` if present.
    - any other target: nothing is changed.

    Parameters
    ----------
    target:
        An object exposing ``name`` and ``attrs`` (a mapping of attribute name to value).
    """
    target_name = target.name

    if target_name == 'cpu':
        if 'arch' in target.attrs:
            hidet.option.cpu.arch(target.attrs['arch'])
        return

    if target_name != 'cuda':
        # Other targets carry no architecture option handled here.
        return

    attrs = target.attrs
    if 'arch' in attrs:
        hidet.option.cuda.arch(attrs['arch'])
    if 'cpu_arch' in attrs:
        hidet.option.cpu.arch(attrs['cpu_arch'])


def build_ir_module(
    ir_module: Union[IRModule, Sequence[IRModule]],
    output_dir: str,
    target: str,
    output_kind: str = '.so',
    force: bool = False,
):
    """
    Build an IR module into a shared library or an object file.

    The build proceeds as follows:

    1. Lower and optimize the IR module.
    2. Generate source code from the lowered IR module.
    3. Invoke the underlying compiler (e.g., gcc or nvcc) on the generated source to produce
       a shared library (`output_kind == '.so'`) or an object file (`output_kind == '.o'`).

    The whole procedure runs while holding a `FolderLock` on `output_dir`, so that concurrent
    processes targeting the same directory do not build at the same time.

    Parameters
    ----------
    ir_module: Union[IRModule, Sequence[IRModule]]
        The IR module to build, either a single IRModule or a sequence of IRModules.

    output_dir: str
        The directory that receives the generated source code and the compiled library.

    target: str
        The build target, such as `cpu` or `cuda`, optionally with attributes
        (e.g., 'cuda --arch=sm_70'). A value that is not a string is used as-is.

    output_kind: str
        Either `'.so'` (shared library, written as `lib.so`) or `'.o'` (object file, written as `lib.o`).

    force: bool
        Whether to rebuild even if a previous build result is found in `output_dir`.

    Notes
    -----
    - If the compile server is enabled and `can_remote_build(ir_module)` holds, the build is delegated
      to `hidet.apps.compile_server.remote_build` and nothing is compiled locally.
    - For `'.so'` outputs the function types of the module are additionally written to `output_dir`.
    """
    lib_path = os.path.join(output_dir, get_library_name(output_kind))

    # Everything below runs under the folder lock of the output directory.
    with FolderLock(output_dir):
        if should_skip_build(lib_path, output_kind, output_dir, force):
            return

        remote_requested = hidet.option.compile_server.enabled() and can_remote_build(ir_module)
        if remote_requested:
            from hidet.apps.compile_server import remote_build

            remote_build(ir_module, output_dir, target=target, output_kind=output_kind)
            return

        if isinstance(target, str):
            target = Target.from_string(target)
        src_path = get_source_path(output_dir, target)

        # Adjust the stack limit before lowering.
        set_stack_limit()

        # IR -> lowered IR -> source file.
        ir_module = lower_ir_module(ir_module, output_dir, target)
        codegen(ir_module, src_out_path=src_path, target=target)

        # Source file -> library, using the dependencies declared by the module(s).
        include_dirs, linking_dirs, linking_libs, object_files = collect_dependencies(ir_module)
        compile_source(
            src_path,
            output_library_file=lib_path,
            target=target,
            include_dirs=include_dirs,
            linking_dirs=linking_dirs,
            linking_libraries=linking_libs,
            object_files=object_files,
        )

        # Only shared libraries get the function type file.
        if output_kind != '.so':
            return
        write_function_types(ir_module, output_dir)


def build_ir_module_batch(
    ir_modules: Sequence[IRModule], output_dirs: Sequence[str], output_kind: str, target: str, force: bool = False
):
    """
    Build a batch of IR modules.

    The modules are shuffled, optionally regrouped into fewer and larger jobs, and every job is built
    with :func:`build_ir_module` through `parallel_imap_2ndlevel`.

    Parameters
    ----------
    ir_modules: Sequence[IRModule]
        A sequence of IR modules to build. NOTE: the sequence is shuffled **in place** (with a fixed
        seed), so it must be mutable and its order is different after the call.

    output_dirs: Sequence[str]
        Directories for compilation artifacts. Job `i` is built into `output_dirs[i]`.

    output_kind: str
        The output kind of the compiled library. Can be `'.so'` or `'.o'`.

    target: str
        The target of the compilation. Can be 'cuda' or 'cpu'.

    force: bool
        Whether to force re-build the IR module. By default, the IR module will not be re-built if the library
        already exists in the specified output directory.

    Returns
    -------
    ret: Sequence[str]
        The leading part of `output_dirs` that was actually used, one directory per job.
    """

    def build_job(args):
        ir_module, output_dir = args
        build_ir_module(ir_module, output_dir, output_kind=output_kind, target=target, force=force)

    def regroup_modules(modules):
        """
        Regroup IR modules for parallel processing.

        Returns `modules` unchanged when there are no more modules than workers. Otherwise returns a
        list of groups (lists of modules); the number of groups is a multiple of the worker count.
        """
        from hidet.utils import cdiv

        MAX_CANDIDATES_PER_JOB = 32
        num_workers = get_parallel_num_workers(is_remote_allowed=True)
        len_modules = len(modules)
        if len_modules <= num_workers:
            return modules

        # Smallest multiple of num_workers that keeps each job at or below MAX_CANDIDATES_PER_JOB modules.
        num_new_jobs = cdiv(len_modules, num_workers * MAX_CANDIDATES_PER_JOB) * num_workers
        job_per_worker = len_modules // num_new_jobs
        num_modules_for_1st_pass = job_per_worker * num_new_jobs
        grouped_modules = [
            modules[i : i + job_per_worker] for i in range(0, num_modules_for_1st_pass, job_per_worker)
        ]

        # Hand out the modules that did not divide evenly, one per group in round-robin order.
        remainder = modules[num_modules_for_1st_pass:]
        for i, module in enumerate(remainder):
            grouped_modules[i % len(grouped_modules)].append(module)

        assert sum(len(group) for group in grouped_modules) == len(modules)
        return grouped_modules

    def check_function_singular(module_list):
        """
        Ensure no duplicate function names exist after regrouping.

        Names are compared as `<namespace>::<function name>` over both the extern functions and the
        functions of every module. A list that was not regrouped (empty, or holding IRModules directly)
        is accepted without inspection.
        """
        if not module_list or isinstance(module_list[0], IRModule):
            return True
        name_set = set()
        for modules in module_list:
            for module in modules:
                namespace = module.namespace
                for func_name in module.extern_functions.keys() | module.functions.keys():
                    func_str = f"{namespace}::{func_name}"
                    if func_str in name_set:
                        return False
                    name_set.add(func_str)
        return True

    # Shuffle modules for balanced workloads. A private generator seeded with 42 yields the same
    # permutation as seeding the global generator would, without overwriting the process-wide
    # random state (and any seed the caller may have set).
    random.Random(42).shuffle(ir_modules)

    lazy_initialize_cuda()
    ir_modules_list = regroup_modules(ir_modules)
    assert check_function_singular(ir_modules_list), "Duplicate function names detected in regrouped modules."

    # One output directory per job; surplus directories stay unused.
    used_output_dirs = output_dirs[: len(ir_modules_list)]
    jobs = list(zip(ir_modules_list, used_output_dirs))

    # Drain the iterator only to drive the builds and the progress bar.
    for _ in tqdm(
        parallel_imap_2ndlevel(build_job, jobs, is_remote_allowed=True), desc="Compiling", total=len(jobs), ncols=80
    ):
        pass
    return used_output_dirs


def get_library_name(output_kind):
    """
    Map an output kind to the library file name: `'.so'` -> `'lib.so'`, `'.o'` -> `'lib.o'`.

    Raises
    ------
    ValueError
        If `output_kind` is neither `'.so'` nor `'.o'`.
    """
    if output_kind not in ('.so', '.o'):
        raise ValueError(f"Invalid output kind: {output_kind}")
    return 'lib.so' if output_kind == '.so' else 'lib.o'


def should_skip_build(lib_path, output_kind, output_dir, force):
    """
    Tell whether a previous build result can be reused.

    The build is skipped only if all of the following hold:

    - `lib_path` exists and is larger than one byte (presumably to reject empty or
      placeholder files -- TODO confirm the reason for the threshold),
    - for `'.so'` outputs, `func_types.pickle` exists in `output_dir`,
    - `force` is falsy.

    Note that `output_dir` may exist and hold other files (such as a lock file) without
    a finished build, which is why the library file itself is inspected.
    """
    if not os.path.exists(lib_path):
        return False
    if os.path.getsize(lib_path) <= 1:
        return False
    if output_kind == '.so' and not os.path.exists(os.path.join(output_dir, 'func_types.pickle')):
        return False
    return not force


def get_source_path(output_dir, target):
    """
    Return the path of the generated source file inside `output_dir` for `target`.

    The file name depends on `target.name`: `source.cu` for cuda, `source.hip.cpp` for hip
    and `source.cc` for cpu.

    Raises
    ------
    ValueError
        If `target.name` is none of the names above.
    """
    source_file_names = {'cuda': 'source.cu', 'hip': 'source.hip.cpp', 'cpu': 'source.cc'}
    if target.name not in source_file_names:
        raise ValueError(f'Invalid target: {target}')
    return os.path.join(output_dir, source_file_names[target.name])


def lower_ir_module(ir_module, output_dir, target):
    """
    Configure the options for `target` and lower the given IR module(s).

    Each module is lowered inside its own `PassContext` carrying the instruments from
    `create_instruments`. A sequence of modules is lowered element by element **in place**
    (so it must be mutable) and the same sequence object is returned; a single module
    yields the lowered module.
    """
    configure_target(target)

    if not isinstance(ir_module, Sequence):
        with PassContext(instruments=create_instruments(output_dir, ir_module)):
            return lower(ir_module)

    for index, module in enumerate(ir_module):
        with PassContext(instruments=create_instruments(output_dir, module)):
            ir_module[index] = lower(module)
    return ir_module


def collect_dependencies(ir_module):
    """
    Gather the compilation dependencies declared by the given IR module(s).

    Returns
    -------
    ret: tuple
        Four lists `(include_dirs, linking_dirs, linking_libs, object_files)`, each the
        concatenation, in module order, of the same-named attribute of every module.
    """
    modules = ir_module if isinstance(ir_module, Sequence) else [ir_module]

    include_dirs, linking_dirs, linking_libs, object_files = [], [], [], []
    for module in modules:
        include_dirs.extend(module.include_dirs)
        linking_dirs.extend(module.linking_dirs)
        linking_libs.extend(module.linking_libs)
        object_files.extend(module.object_files)
    return include_dirs, linking_dirs, linking_libs, object_files


def write_function_types(ir_module, output_dir):
    """
    Write function types for public functions in the IR module.

    The mapping from function name to `FuncType` of every function whose `kind` is
    `'public'` is pickled to `func_types.pickle` in `output_dir`.

    NOTE(review): this reads `ir_module.functions`, so it expects a single IR module,
    not a sequence of modules -- confirm that callers never pass a sequence here.
    """
    public_func_types = {}
    for func in ir_module.functions.values():
        if func.kind == 'public':
            public_func_types[func.name] = FuncType.from_func(func)

    pickle_path = os.path.join(output_dir, 'func_types.pickle')
    with open(pickle_path, 'wb') as f:
        pickle.dump(public_func_types, f)