# Source code for hidet.runtime.compiled_task

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Dict, Tuple, Union, Optional, Iterable
from dataclasses import dataclass
import os
import json
from collections import namedtuple
import tabulate
from hidet.runtime.compiled_module import CompiledModule, CompiledFunction, load_compiled_module
from hidet.ir.dtypes import i32
from hidet.ffi import runtime_api
from hidet.ffi.array import Array


@dataclass
class TensorSignature:
    """Signature (device / dtype / shape) of one task input or output tensor."""

    # device the tensor lives on, e.g. 'cpu' or 'cuda'; may carry an index suffix
    # such as 'cuda:0' (callers strip it with device.partition(':')[0])
    device: str
    # name of the tensor's data type, e.g. 'float32'
    dtype: str
    # per-dimension extents; an int is a concrete extent, a str names a dynamic
    # symbolic dimension shared across tensors of the same task
    shape: List[Union[str, int]]

@dataclass
class TaskMetaData:
    """Metadata of a compiled task, deserialized from 'meta.json' in the task directory."""

    # name of the task
    name: str
    # names of the dynamic-dimension symbols; their runtime values key the dispatch table
    symbols: List[str]
    # signatures of the input tensors
    inputs: List[TensorSignature]
    # signatures of the output tensors
    outputs: List[TensorSignature]
    # maps an output index to the input index whose memory that output shares (aliases)
    share_map: Dict[int, int]
    # compilation target -- presumably a device/arch string; TODO confirm exact format
    target: str
    # number of compiled candidate kernels (functions launch_0 ... launch_{N-1})
    num_candidates: int
    # version of hidet that produced this artifact
    hidet_version: str


class CompiledTask:
    """
    A compiled task is a special kind of compiled module that implements a computation task.

    A compiled task is a compiled module with the following conventions:

    1. The compiled module contains functions named `launch_0`, `launch_1`, ..., `launch_N-1`,
       where N is the number of candidates for the task.
    2. There are two shape-related functions `get_input_shape` and `get_output_shape` that
       return the shape of inputs and outputs respectively.

    When a compiled task is called, the input arguments should be consistent with the input
    signature of the task. The compiled task will pick the best candidate based on the input
    shapes and dispatch the computation to the corresponding candidate. The output tensors
    will be created and passed to the candidate function as arguments. When the candidate
    finishes the execution, the output tensors will be returned.

    This class is not intended to be instantiated by users directly. Instead, users should use
    the :func:`load_compiled_task` function to load a compiled task from the given directory,
    or use :func:`hidet.drivers.build_task` to build a compiled task from a task definition.

    Parameters
    ----------
    task_dir: str
        The directory of the compiled task.
    """

    def __init__(self, task_dir: str):
        self.task_dir: str = task_dir
        self.meta_data: TaskMetaData = self._load_meta_data()
        self.task_module: CompiledModule = load_compiled_module(task_dir)
        # one compiled entry point per tuning candidate
        self.candidates: List[CompiledFunction] = [
            self.task_module['launch_{}'.format(i)] for i in range(self.meta_data.num_candidates)
        ]
        # maps a tuple of runtime symbol values to the index of the best candidate
        self.dispatch_table: Dict[Tuple[int, ...], int] = self._load_dispatch_table()

        self._get_input_shape = self.task_module['get_input_shape']
        self._get_output_shape = self.task_module['get_output_shape']

    def __call__(self, *args):
        """
        Run the compiled task with the given arguments.

        Parameters
        ----------
        args: a sequence of input tensors or scalars
            The input arguments. They should be consistent with the input signature of the task.

        Returns
        -------
        ret:
            A single output tensor when the task has exactly one output, otherwise the list of
            output tensors.
        """
        outs = self.run_async(args)
        if len(outs) == 1:
            return outs[0]
        else:
            return outs

    def _load_meta_data(self) -> TaskMetaData:
        """Deserialize 'meta.json' from the task directory into a TaskMetaData."""
        from hidet.utils.dataclass import from_dict

        meta_data_path = os.path.join(self.task_dir, 'meta.json')
        with open(meta_data_path, 'r') as f:
            return from_dict(TaskMetaData, json.load(f))

    def _load_compiled_modules(self) -> List[CompiledModule]:
        """Load every candidate compiled module found under '<task_dir>/candidates'."""
        compiled_modules = []
        candidates_dir = os.path.join(self.task_dir, 'candidates')
        if not os.path.exists(candidates_dir) or not os.path.isdir(candidates_dir):
            raise RuntimeError(f'Cannot find candidates dir: {candidates_dir}')
        for entry in os.listdir(candidates_dir):
            # BUG FIX: os.listdir returns bare entry names; they must be joined with
            # candidates_dir before use, otherwise isdir() is evaluated relative to the
            # current working directory and the entries can never be loaded correctly.
            module_dir = os.path.join(candidates_dir, entry)
            if not os.path.isdir(module_dir):
                continue
            compiled_modules.append(CompiledModule(module_dir))
        if len(compiled_modules) == 0:
            raise RuntimeError(f'No compiled module found in {candidates_dir}')
        return compiled_modules

    def _load_dispatch_table(self):
        """
        Load the on-disk dispatch table, if present.

        The file's first line is a header with the symbol names; every following line holds
        the symbol values followed by the chosen candidate index, all space-separated.
        Returns an empty dict when the file does not exist.
        """
        dispatch_table_path = os.path.join(self.task_dir, 'dispatch_table.txt')
        if not os.path.exists(dispatch_table_path):
            return {}
        dispatch_table = {}
        with open(dispatch_table_path, 'r') as f:
            for i, line in enumerate(f):
                if i == 0:
                    # skip the header line listing the symbol names
                    continue
                items = line.split()
                if len(items) == 0:
                    continue
                if len(items) != len(self.meta_data.symbols) + 1:
                    # the table is corrupt; delete it so it gets regenerated next run
                    os.remove(dispatch_table_path)
                    raise RuntimeError(f'Invalid dispatch table: {dispatch_table_path}')
                key = tuple(int(item) for item in items[:-1])
                value = int(items[-1])
                dispatch_table[key] = value
        return dispatch_table

    def _get_symbol_values(self) -> Tuple[int, ...]:
        """Query the runtime for the current value of each dynamic symbol, in order."""
        return tuple(runtime_api.get_symbol_value(symbol) for symbol in self.meta_data.symbols)

    def create_outputs(self, inputs):
        """
        Allocate (or alias) the output tensors for one task invocation.

        For each output, the compiled `get_output_shape` function supplies the concrete shape.
        Outputs listed in `share_map` reuse the storage of the corresponding input instead of
        allocating fresh memory; other outputs are freshly allocated with `hidet.empty`.
        """
        import hidet

        outputs = []
        for idx, sig in enumerate(self.meta_data.outputs):
            shape_buffer = Array(i32, len(sig.shape))
            self._get_output_shape(idx, shape_buffer)
            shape: List[int] = list(shape_buffer)
            if idx not in self.meta_data.share_map:
                outputs.append(hidet.empty(shape, sig.dtype, sig.device))
            else:
                shared_tensor = inputs[self.meta_data.share_map[idx]]
                if not isinstance(shared_tensor, hidet.Tensor):
                    import torch

                    assert isinstance(shared_tensor, torch.Tensor), "Unknown tensor type"
                    tensor_dtype = getattr(torch, sig.dtype)
                    # we need to turn the tensor into a view with the graph output's shape & dtype
                    input_tensor = shared_tensor.view(*shape).view(tensor_dtype)
                else:
                    input_tensor = hidet.Tensor(
                        shape=shape, dtype=sig.dtype, device=sig.device, storage=shared_tensor.storage
                    )
                outputs.append(input_tensor)
        return outputs

    def pick_best_candidate(self, inputs, outputs) -> int:
        """
        Return the index of the best candidate for the current symbol values.

        On a dispatch-table miss with multiple candidates, every candidate is benchmarked,
        a human-readable latency report is written under '<task_dir>/reports', and the
        winner is appended to the on-disk dispatch table so later processes skip the search.
        """
        from hidet.utils.benchmark.bench import find_best_candidate

        key = self._get_symbol_values()
        if key not in self.dispatch_table:
            if len(self.candidates) > 1:
                best_idx, latencies = find_best_candidate(self.candidates, self.meta_data.name, *inputs, *outputs)
                self.dispatch_table[key] = best_idx

                # write a benchmark report
                report_name = '_'.join('{}_{}'.format(a, b) for a, b in zip(self.meta_data.symbols, key))
                os.makedirs(os.path.join(self.task_dir, 'reports'), exist_ok=True)
                report_path = os.path.join(self.task_dir, 'reports', report_name + '.txt')
                with open(os.path.join(self.task_dir, 'candidates.json'), 'r') as f:
                    candidates_json = json.load(f)
                headers: List[str] = candidates_json['headers']
                candidate_lines: List[List[str]] = candidates_json['candidates']
                headers.extend(['latency', 'rank'])
                # precompute each candidate's rank (0 = fastest) instead of the original
                # O(n^2) list.index() lookup inside the loop
                sorted_indices = sorted(range(len(latencies)), key=lambda i: latencies[i])
                rank_of = {cand_idx: rank for rank, cand_idx in enumerate(sorted_indices)}
                for idx, line in enumerate(candidate_lines):
                    line.extend(['{:.6f} ms'.format(latencies[idx]), rank_of[idx]])
                candidate_lines.sort(key=lambda l: l[-1])
                with open(report_path, 'w') as f:
                    f.write(tabulate.tabulate(candidate_lines, headers=headers, tablefmt='plain'))
            else:
                assert len(self.candidates) == 1
                self.dispatch_table[key] = 0

            # write the best candidate to dispatch table
            dispatch_table_path = os.path.join(self.task_dir, 'dispatch_table.txt')
            if not os.path.exists(dispatch_table_path):
                with open(dispatch_table_path, 'w') as f:
                    f.write(' '.join(self.meta_data.symbols) + '\n')
            with open(dispatch_table_path, 'a') as f:
                f.write(' '.join([str(v) for v in key]) + ' ' + str(self.dispatch_table[key]) + '\n')

        candidate_index = self.dispatch_table[key]
        if candidate_index >= len(self.candidates):
            raise RuntimeError(f'Invalid candidate index: {candidate_index}')
        return candidate_index

    def run_async(self, inputs):
        """
        Run the compiled task with the given arguments.

        Parameters
        ----------
        inputs: a sequence of input tensors or scalars
            The input arguments. They should be consistent with the input signature of the task.

        Returns
        -------
        A sequence of output tensors:
            The output tensors. They are created by the task and passed to the candidate
            function as arguments. When the candidate finishes the execution, the output
            tensors will be returned.
        """
        from hidet import option

        if option.get_runtime_check():
            _check_inputs(self.meta_data.inputs, inputs)

        outputs = self.create_outputs(inputs)
        candidate = self.candidates[self.pick_best_candidate(inputs, outputs)]
        candidate(*inputs, *outputs)
        return outputs

    def profile(self, *args, warmup=1, number=2, repeat=10):
        """
        Run the compiled task with the given arguments and profile the execution time.

        Parameters
        ----------
        args: a sequence of input tensors or scalars
            The input arguments. They should be consistent with the input signature of the task.
        warmup: int
            The number of warmup runs.
        number: int
            The number of runs for each measurement.
        repeat: int
            The number of measurements.

        Returns
        -------
        latency: List[float]
            The measured latency in milliseconds. The length of the list is equal to `repeat`.
        """
        num_inputs = len(self.meta_data.inputs)
        inputs = args[:num_inputs]
        outputs = args[num_inputs:]
        # For operators like scatter_add_, if we run it multiple times on the same input & output
        # tensors, the input and output tensors will be wrong as they will be wrongly updated
        # multiple times. To avoid this, make a clone of the output tensors if they share the
        # memory with some input tensors.
        if len(self.meta_data.share_map) > 0:
            from hidet import Tensor

            outputs = list(outputs)
            inputs = list(inputs)
            for output_idx in self.meta_data.share_map:
                original_output = outputs[output_idx]
                if isinstance(original_output, Tensor):
                    outputs[output_idx] = original_output.copy()
                else:
                    # presumably a torch.Tensor; clone() is torch's deep-copy
                    outputs[output_idx] = original_output.clone()
            args = inputs + outputs
        candidate = self.candidates[self.pick_best_candidate(inputs, outputs)]
        return candidate.profile(*args, warmup=warmup, number=number, repeat=repeat)
def load_compiled_task(compiled_task_dir: str) -> CompiledTask:
    """
    Load a compiled task from the given directory.

    Parameters
    ----------
    compiled_task_dir: str
        The directory of the compiled task.

    Returns
    -------
    ret: CompiledTask
        The loaded compiled task.
    """
    task = CompiledTask(compiled_task_dir)
    return task
# (device, tuning-space level, task string) triple identifying a compiled task in the cache
CompiledTaskKey = namedtuple('CompiledTaskKey', ['device', 'space', 'task_str'])


class CompiledTaskCache:
    """In-memory, process-wide cache mapping CompiledTaskKey triples to compiled tasks."""

    def __init__(self):
        self.cached: Dict[Tuple[str, int, str], CompiledTask] = {}

    def contains(self, device_type: str, space: int, task_str: str) -> bool:
        """Return True when a compiled task is cached under the given key."""
        key = CompiledTaskKey(device_type, space, task_str)
        return key in self.cached

    def get(self, device_type: str, space: int, task_str: str) -> Optional[CompiledTask]:
        """Return the cached compiled task, or None when absent."""
        key = CompiledTaskKey(device_type, space, task_str)
        # dict.get already returns None for a missing key; the original
        # `self.cached.get(key) if key in self.cached else None` looked the key up twice
        return self.cached.get(key)

    def add(self, device_type: str, space: int, task_str: str, compiled_task: CompiledTask):
        """Insert (or overwrite) the compiled task under the given key."""
        key = CompiledTaskKey(device_type, space, task_str)
        self.cached[key] = compiled_task


# module-level singleton cache shared by all users of this module
compiled_task_cache = CompiledTaskCache()


def _check_shape(arg_idx: int, traced_shape, concrete_shape, symbol_map: dict):
    """
    Validate one argument's concrete shape against its traced signature shape.

    Integer dimensions must match exactly; symbolic (string) dimensions must be
    consistent across all inputs, tracked via `symbol_map` (symbol name -> bound value).
    Raises RuntimeError on rank mismatch, extent mismatch, or inconsistent symbol binding.
    """
    if len(traced_shape) != len(concrete_shape):
        raise RuntimeError(
            f"Rank of input {arg_idx} not equal to original. ({len(concrete_shape)} vs. {len(traced_shape)})"
        )
    for j, (orig_extent, new_extent) in enumerate(zip(traced_shape, concrete_shape)):
        if isinstance(orig_extent, int):
            # NOTE: the original f-strings used backslash line continuations, which embedded
            # long literal runs of spaces in the message; normalized here.
            if orig_extent != new_extent:
                raise RuntimeError(
                    f'shape mismatch at dimension {j}, original: {orig_extent} vs. new: {new_extent}'
                )
        elif orig_extent not in symbol_map:
            symbol_map[orig_extent] = new_extent
        elif symbol_map[orig_extent] != new_extent:
            raise RuntimeError(
                f"There exists multiple instances of the same symbol {orig_extent} "
                f"with different values in inputs (ex: {symbol_map[orig_extent]} and {new_extent})"
            )


def _check_inputs(traced_inputs: Iterable[TensorSignature], inputs):
    """
    Runtime validation that `inputs` match the traced input signatures.

    Checks device kind, dtype, rank, concrete extents, and consistency of symbolic
    dimensions across all inputs. Accepts both torch and hidet tensors; torch devices
    'cuda'/'vcuda' are normalized to the 'cuda' target before comparison.
    Raises RuntimeError on any mismatch.
    """
    from hidet.ir import data_type
    from hidet.graph.frontend.torch.utils import dtype_to_torch
    from torch import Tensor as TorchTensor

    symbol_map = {}
    for i, (traced, new) in enumerate(zip(traced_inputs, inputs)):
        # traced device may carry an index suffix ('cuda:0'); keep only the kind
        traced_dev_kind = traced.device.partition(':')[0]
        if isinstance(new, TorchTensor):
            new_device_target = 'cuda' if new.device.type in ['cuda', 'vcuda'] else 'cpu'
            if traced_dev_kind != new_device_target:
                # BUG FIX: torch.device has no `.kind` attribute -- the original message
                # would itself raise AttributeError; report the device object directly.
                raise RuntimeError(
                    f"device mismatch at arg {i} between original: {traced.device} and new: {new.device}"
                )
            if dtype_to_torch(data_type(traced.dtype)) != new.dtype:
                raise RuntimeError(f"dtype mismatch at arg {i} between original: {traced.dtype} and new: {new.dtype}")
        else:
            if traced_dev_kind != new.device.target:
                raise RuntimeError(
                    f"device mismatch at arg {i} between original: {traced.device} and new: {new.device.kind}"
                )
            if data_type(traced.dtype) != new.dtype:
                raise RuntimeError(f"dtype mismatch at arg {i} between original: {traced.dtype} and new: {new.dtype}")
        # shape/symbol validation is identical for both tensor kinds (was duplicated verbatim)
        _check_shape(i, traced.shape, new.shape, symbol_map)