# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Dict, Any, Callable
import zipfile
import os
import json
from dataclasses import dataclass
import tempfile
from tabulate import tabulate
import numpy
import hidet
from hidet.ffi.utils import ctypes_func_pointer
from hidet.ffi.array import Array
from hidet.ir.type import void_p, data_type
from hidet.ir.dtypes import i32, i64
from hidet.runtime.device import Device
from hidet.runtime.compiled_module import CompiledModule
from hidet.runtime.compiled_task import CompiledTask, TensorSignature, _check_inputs
from hidet.runtime.storage import Storage
from hidet.ffi import runtime_api
from hidet.utils.py import prod, median
from hidet.utils.trace_utils import TraceEventEmitter
ModelExecutionHook = Callable[[int, List['Tensor'], List['Tensor']], None]
global_cuda_workspace: Optional[Storage] = None
class ExternalStorage(Storage):
def __init__(self, device: str, addr: int, num_bytes: int):
super().__init__(Device(device), addr, num_bytes, lambda x: x)
@dataclass
class GraphMetaData:
inputs: List[TensorSignature]
outputs: List[TensorSignature]
hidet_version: str
num_kernels: int
graph_hash: str
    share_map: Dict[int, int]  # maps output index -> index of the input whose storage it shares
@dataclass
class GraphExecutionInstruction:
task_idx: int
inputs: List[int]
outputs: List[int]
free: List[int]
@dataclass
class GraphExecution:
weights_index: List[int]
inputs_index: List[int]
instructions: List[GraphExecutionInstruction]
outputs_index: List[int]
tensor_device: List[str]
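# Note on the GraphExecution plan above (illustrative, values made up): it is an
# interpreter-style plan over a flat space of tensor slots. For a two-operator graph
# `y = relu(x @ w)`, a plan could look like
#   weights_index=[0], inputs_index=[1], outputs_index=[3], tensor_device=['cuda'] * 4,
#   instructions=[GraphExecutionInstruction(task_idx=0, inputs=[1, 0], outputs=[2], free=[]),
#                 GraphExecutionInstruction(task_idx=1, inputs=[2], outputs=[3], free=[2])]
# where `free` lists the slots whose tensors can be released once the instruction finishes.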
class CompiledGraph:
"""
A compiled graph that can be directly called in Python.
This class should not be instantiated directly. Instead, use :func:`load_compiled_graph` to load a compiled graph
from disk, or build a compiled graph from :class:`FlowGraph` using :func:`hidet.drivers.build_flow_graph`.
Parameters
----------
meta: GraphMetaData
The meta-data of the graph.
graph_module: CompiledModule
The graph compiled module that contains execution logic of the computation graph.
weights: List[hidet.Tensor]
The weights of the graph.
compiled_tasks: List[CompiledTask]
The compiled tasks of the graph that correspond to the operators in the computation graph.
graph_execution: GraphExecution
The execution plan of the graph (the order and connections of the compiled tasks).
graph_string: str
The string representation of the computation graph.
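
    Examples
    --------
    A minimal usage sketch; the file name and input shape below are illustrative:

    >>> import hidet
    >>> from hidet.runtime.compiled_graph import load_compiled_graph
    >>> graph = load_compiled_graph('resnet50.hidet')     # doctest: +SKIP
    >>> x = hidet.randn([1, 3, 224, 224], device='cuda')  # doctest: +SKIP
    >>> y = graph(x)                                      # doctest: +SKIP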
"""
def __init__(
self,
meta: GraphMetaData,
graph_module: CompiledModule,
weights,
compiled_tasks: List[CompiledTask],
graph_execution: GraphExecution,
graph_string: str,
):
import torch
from hidet.graph.tensor import Tensor
# graph module functions
self._init = graph_module['init']
self._get_output_shape = graph_module['get_output_shape']
self._set_workspace = graph_module['set_workspace']
self._get_workspace_size = graph_module['get_workspace_size']
self._launch = graph_module['launch']
# graph assets
self.meta: GraphMetaData = meta
self.graph_module: CompiledModule = graph_module
self.weights: List[Tensor] = weights
self.weights_torch: List[torch.Tensor] = [w.torch() for w in weights]
self.compiled_tasks: List[CompiledTask] = compiled_tasks
self.graph_execution: GraphExecution = graph_execution
self.graph_string: str = graph_string
# derived properties
self.dynamic_dims: List[Tuple[str, Tuple[int, int]]] = [] # [(name, (tensor_index, dim_index))]
self.is_dynamic: bool = False
self._init_dynamic_dims()
        self.cpu_space_size, self.cuda_space_size, self.hip_space_size = self._init_space_sizes()
# runtime state
self.working_dir: str = hidet.utils.cache_file('graphs', self.meta.graph_hash)
self.dispatch_table_path = hidet.utils.cache_file('graphs', self.meta.graph_hash, 'dispatch_table.txt')
self.dispatch_table: Dict[Tuple[int, ...], Array] = {}
self.cpu_workspace: Optional[Storage] = None
self.cuda_workspace: Optional[Storage] = None
self.hip_workspace: Optional[Storage] = None
if len(self.weights) == len(graph_execution.weights_index):
# the weights are already loaded, initialize the graph directly
self._init_compiled_graph()
def __getstate__(self):
# Create a temporary file and save the CompiledGraph zip in it
with tempfile.NamedTemporaryFile() as temp_file:
self.save(temp_file.name, save_dispatch_table=True)
with open(temp_file.name, 'rb') as f:
state = f.read()
return state
def __setstate__(self, state):
# Load the CompiledGraph
with tempfile.NamedTemporaryFile() as temp_file:
with open(temp_file.name, 'wb') as f:
f.write(state)
self.__dict__.update(load_compiled_graph(temp_file.name).__dict__)
def __str__(self):
"""
Get the basic information of this compiled graph.
Returns
-------
ret: str
The human readable basic information.
"""
rows = []
for i, sig in enumerate(self.meta.inputs):
dtype = data_type(sig.dtype)
if i == 0:
head = 'input'
else:
head = ''
rows.append([head, dtype.short_name + str(sig.shape)])
for i, sig in enumerate(self.meta.outputs):
dtype = data_type(sig.dtype)
if i == 0:
head = 'output'
else:
head = ''
rows.append([head, dtype.short_name + str(sig.shape)])
weight_size = sum(w.nbytes for w in self.weights)
rows.append(['weights', '{:.3f} GiB'.format(weight_size / 1024 / 1024 / 1024)])
rows.append(['parameters', '{}'.format(sum(prod(x.shape) for x in self.weights))])
return tabulate(rows, colalign=('right', 'left'), tablefmt='simple')
def __call__(self, *args):
"""
Run the model asynchronously with the given inputs.
Parameters
----------
args: Sequence[hidet.Tensor]
The input tensors.
Returns
-------
ret: Union[hidet.Tensor, List[hidet.Tensor]]
The output tensor(s).
"""
outs = self.run_async(args)
if len(outs) == 1:
return outs[0]
else:
return outs
def _init_dynamic_dims(self):
# initialize the derived properties
for tensor_index, sig in enumerate(self.meta.inputs):
for dim_index, dim in enumerate(sig.shape):
if isinstance(dim, str) and dim not in [v for v, _ in self.dynamic_dims]:
self.dynamic_dims.append((dim, (tensor_index, dim_index)))
        self.is_dynamic = len(self.dynamic_dims) > 0 or any(
            isinstance(dim, str) for sig in self.meta.outputs for dim in sig.shape
        )
def _init_compiled_graph(self):
# initialize weights
weights_buffer = Array(void_p, len(self.weights))
for i in range(len(self.weights)):
weights_buffer[i] = self.weights[i].storage.addr
self._init(len(self.weights), weights_buffer)
# load the dispatch table
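        # The dispatch table is a whitespace-separated text file: a header line with the
        # symbolic dimension names, then one row per encountered shape. For example, with one
        # dynamic dimension and two kernels, it might look like (values illustrative):
        #     batch_size
        #     1 0 2
        #     8 1 0
        # Each row stores the symbol values followed by the chosen schedule index per kernel.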
if os.path.exists(self.dispatch_table_path):
with open(self.dispatch_table_path, 'r') as f:
lines = f.readlines()
for idx, line in enumerate(lines):
if idx == 0:
continue # skip the header line
items = line.split()
if len(items) == 0:
continue # skip empty lines
if len(items) != len(self.dynamic_dims) + len(self.compiled_tasks):
                        raise RuntimeError('Invalid dispatch table: {}'.format(self.dispatch_table_path))
items = [int(item) for item in items]
symbol_dims = items[: len(self.dynamic_dims)]
schedule_indices = items[len(self.dynamic_dims) :]
kernel_array = Array(void_p, len(self.compiled_tasks))
for task_idx, (compiled_task, sch_idx) in enumerate(zip(self.compiled_tasks, schedule_indices)):
if not 0 <= sch_idx < len(compiled_task.candidates):
raise RuntimeError(
'Invalid schedule index {} for compiled task at {}'.format(
sch_idx, compiled_task.task_dir
)
)
kernel_array[task_idx] = ctypes_func_pointer(compiled_task.candidates[sch_idx].ctypes_func)
self.dispatch_table[tuple(symbol_dims)] = kernel_array
    def _init_space_sizes(self):
        # For static graphs, query the cpu/cuda/hip workspace sizes once at initialization;
        # for dynamic graphs, the sizes depend on the symbol values and are queried per run.
        if self.is_dynamic:
            return None, None, None
        buffer = Array(i64, 3)  # _get_workspace_size writes three sizes: cpu, cuda, hip
        self._get_workspace_size(buffer)
        return tuple(buffer)
    def _update_symbol_table(self, symbol_dims: Tuple[int, ...], best_candidates: List[int]):
        # record the chosen schedule per kernel in the in-memory dispatch table, and append
        # the entry to the on-disk dispatch table so later runs can skip the benchmarking
kernel_array = Array(void_p, len(self.compiled_tasks))
for task_idx, best_candidate in enumerate(best_candidates):
kernel_array[task_idx] = ctypes_func_pointer(
self.compiled_tasks[task_idx].candidates[best_candidate].ctypes_func
)
self.dispatch_table[symbol_dims] = kernel_array
if not os.path.exists(self.dispatch_table_path):
with open(self.dispatch_table_path, 'w') as f:
symbol_names = [name for name, _ in self.dynamic_dims]
f.write(' '.join(symbol_names))
f.write('\n')
with open(self.dispatch_table_path, 'a') as f:
f.write(' '.join(str(x) for x in symbol_dims))
f.write(' ')
f.write(' '.join(str(x) for x in best_candidates))
f.write('\n')
def _update_symbol_dims(self, inputs) -> Tuple[int, ...]:
symbol_dims = []
for name, (tensor_index, dim_index) in self.dynamic_dims:
symbol_dims.append(inputs[tensor_index].shape[dim_index])
runtime_api.set_symbol_value(name, symbol_dims[-1])
return tuple(symbol_dims)
def _create_outputs(self, inputs, output_to_torch_tensor):
from torch import empty as torch_empty
from torch import device as torch_device
from torch import Tensor as TorchTensor
from hidet.graph.tensor import empty
from hidet.graph.tensor import Tensor as HidetTensor
from hidet.graph.frontend.torch.utils import dtype_to_torch
outputs = []
exec_idx_to_output_idx: Dict[int, int] = {}
for output_index, (exec_idx, sig) in enumerate(zip(self.graph_execution.outputs_index, self.meta.outputs)):
if exec_idx in self.graph_execution.inputs_index:
# the graph directly returns an input tensor
outputs.append(inputs[self.graph_execution.inputs_index.index(exec_idx)])
elif exec_idx in self.graph_execution.weights_index:
# the graph directly returns a weight tensor
if output_to_torch_tensor:
outputs.append(self.weights_torch[self.graph_execution.weights_index.index(exec_idx)])
else:
outputs.append(self.weights[self.graph_execution.weights_index.index(exec_idx)])
elif exec_idx in exec_idx_to_output_idx:
# the graph returns the same tensor multiple times
outputs.append(outputs[exec_idx_to_output_idx[exec_idx]])
else:
# get the shape of output tensor
if self.is_dynamic:
shape_buffer = Array(i32, len(sig.shape))
self._get_output_shape(output_index, shape_buffer)
shape = list(shape_buffer)
else:
shape = sig.shape
if output_index not in self.meta.share_map:
# create the output tensor
if output_to_torch_tensor:
torch_dtype = dtype_to_torch(data_type(sig.dtype))
torch_dev = torch_device(sig.device)
outputs.append(torch_empty(size=shape, dtype=torch_dtype, device=torch_dev))
else:
outputs.append(empty(shape=shape, dtype=sig.dtype, device=sig.device))
else:
# this output tensor shares the storage with one input tensor, reuse the storage
if output_to_torch_tensor:
input_tensor: TorchTensor = inputs[self.meta.share_map[output_index]]
assert isinstance(input_tensor, TorchTensor)
outputs.append(input_tensor.view(shape))
else:
input_tensor: HidetTensor = inputs[self.meta.share_map[output_index]]
outputs.append(
HidetTensor(shape=shape, dtype=sig.dtype, device=sig.device, storage=input_tensor.storage)
)
# record the exec_idx of this output tensor, in case the graph returns the same tensor multiple times
exec_idx_to_output_idx[exec_idx] = output_index
return outputs
def _prepare_workspace(self):
if self.is_dynamic:
buffer = Array(i64, 3)
self._get_workspace_size(buffer)
required_cpu_workspace, required_cuda_workspace, required_hip_workspace = list(buffer)
        else:
            required_cpu_workspace = self.cpu_space_size
            required_cuda_workspace = self.cuda_space_size
            required_hip_workspace = self.hip_space_size
if self.cpu_workspace is None or self.cpu_workspace.num_bytes < required_cpu_workspace:
self.cpu_workspace = Storage.new('cpu', required_cpu_workspace)
self._set_workspace(0, self.cpu_workspace.addr)
        global global_cuda_workspace
        if global_cuda_workspace is not None and global_cuda_workspace.num_bytes < required_cuda_workspace:
            # drop the reference so the old workspace is freed exactly once; calling
            # __del__() directly risks a double free when the object is later collected
            global_cuda_workspace = None
if global_cuda_workspace is None:
global_cuda_workspace = Storage.new('cuda', required_cuda_workspace)
self._set_workspace(1, global_cuda_workspace.addr)
if hidet.hip.available() and (
self.hip_workspace is None or self.hip_workspace.num_bytes < required_hip_workspace
):
self.hip_workspace = Storage.new('hip', required_hip_workspace)
self._set_workspace(2, self.hip_workspace.addr)
def _run_fast_path(self, inputs, symbol_dims: Tuple[int, ...], output_to_torch_tensor):
# create output tensors
outputs = self._create_outputs(inputs, output_to_torch_tensor)
# prepare workspace
self._prepare_workspace()
# run the kernels
kernel_array = self.dispatch_table[symbol_dims]
self._launch(*inputs, *outputs, kernel_array)
return outputs
def _run_slow_path(self, inputs, symbol_dims: Tuple[int, ...]):
"""Interpret the graph execution"""
from hidet.graph.tensor import Tensor
index2tensor: Dict[int, Tensor] = {}
exe = self.graph_execution
for i in range(len(inputs)):
index2tensor[exe.inputs_index[i]] = inputs[i]
for i in range(len(self.weights)):
index2tensor[exe.weights_index[i]] = self.weights[i]
best_candidates = [-1 for _ in range(len(self.compiled_tasks))]
trace_emitter = TraceEventEmitter({'graph': self.graph_string})
for inst in exe.instructions:
# prepare inputs and kernel
node_inputs = [index2tensor[i] for i in inst.inputs]
node_kernel: CompiledTask = self.compiled_tasks[inst.task_idx]
# run the kernel
node_outputs = node_kernel.run_async(node_inputs)
# record outputs
for i, output_index in enumerate(inst.outputs):
index2tensor[output_index] = node_outputs[i]
# record best candidate for this kernel
best_candidates[inst.task_idx] = node_kernel.pick_best_candidate(node_inputs, node_outputs)
# record trace events
trace_emitter.append(
name=node_kernel.meta_data.name,
duration_us=int(median(node_kernel.profile(*node_inputs, *node_outputs)) * 1000),
args={
'name': node_kernel.meta_data.name,
'inputs': ['{}{}'.format(x.dtype, x.shape) for x in node_kernel.meta_data.inputs],
'outputs': ['{}{}'.format(x.dtype, x.shape) for x in node_kernel.meta_data.outputs],
},
)
# free tensors that are no longer needed
for idx in inst.free:
del index2tensor[idx]
outputs = [index2tensor[i] for i in exe.outputs_index]
# update the dispatch table
self._update_symbol_table(symbol_dims, best_candidates)
# save the trace
trace_filename = 'trace{}.json'.format('_'.join(str(x) for x in symbol_dims))
with open(os.path.join(self.working_dir, trace_filename), 'w') as f:
trace_emitter.save(f)
return outputs
    def set_weights(self, weights):
"""
        Set the weights of the model.

        When the weights are stored in the model file, they are loaded together with the model and this
        method is not needed. When the model was saved without weights (see ``save_weights=False`` in
        :func:`save_compiled_graph`), this method must be called before running the model.
Parameters
----------
weights: List[hidet.Tensor]
The weights to set.
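
        Examples
        --------
        A minimal sketch, assuming ``model.hidet`` was saved with ``save_weights=False`` and
        ``weights`` is a list of :class:`hidet.Tensor` on the expected devices:

        >>> graph = load_compiled_graph('model.hidet')  # doctest: +SKIP
        >>> graph.set_weights(weights)                  # doctest: +SKIP
        >>> y = graph(x)                                # doctest: +SKIP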
"""
from hidet.runtime.device import instantiate_device
if len(self.weights) == len(self.graph_execution.weights_index):
raise RuntimeError('The weights are already set.')
if len(weights) != len(self.graph_execution.weights_index):
raise ValueError('Expect {} weights, got {}.'.format(len(self.graph_execution.weights_index), len(weights)))
if any(not isinstance(w, hidet.Tensor) for w in weights):
raise ValueError('Expect all weights to be hidet.Tensor, got {}'.format([type(w) for w in weights]))
for idx, weight in enumerate(weights):
expected_device = instantiate_device(
self.graph_execution.tensor_device[self.graph_execution.weights_index[idx]]
)
if expected_device != weight.device:
raise ValueError(
'Expect weight {} to be on device {}, got {}.'.format(idx, expected_device, weight.device)
)
self.weights = weights
self._init_compiled_graph()
    def run_async(self, inputs, output_to_torch_tensor=False):
"""
Run the model asynchronously.
Parameters
----------
        inputs: Sequence[hidet.Tensor]
            The input tensors.
        output_to_torch_tensor: bool
            Whether to return the outputs as ``torch.Tensor`` objects instead of :class:`hidet.Tensor`.
            Default: False.

        Returns
        -------
        ret: List[hidet.Tensor] or List[torch.Tensor]
            The output tensors.
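
        Examples
        --------
        A minimal sketch; the input shape is illustrative:

        >>> x = hidet.randn([1, 3, 224, 224], device='cuda')          # doctest: +SKIP
        >>> y, = graph.run_async([x])                                 # doctest: +SKIP
        >>> yt, = graph.run_async([x], output_to_torch_tensor=True)   # doctest: +SKIP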
"""
if hidet.option.get_runtime_check():
_check_inputs(self.meta.inputs, inputs)
if len(self.weights) != len(self.graph_execution.weights_index):
raise RuntimeError('Please set the weights before running the model with compiled_graph.set_weights(...).')
symbol_dims = self._update_symbol_dims(inputs)
if symbol_dims in self.dispatch_table:
return self._run_fast_path(inputs, symbol_dims, output_to_torch_tensor)
else:
res = self._run_slow_path(inputs, symbol_dims)
if output_to_torch_tensor:
res = [tensor.torch() if isinstance(tensor, hidet.Tensor) else tensor for tensor in res]
return res
    def cuda_graph(self, *args):
"""
Create a CUDA graph for this compiled graph.
Parameters
----------
        args: Sequence[hidet.Tensor or torch.Tensor]
            The example input tensors used to instantiate the CUDA graph; weight tensors are excluded.
            Random tensors with the same signatures are generated for graph capturing.
Returns
-------
cuda_graph: hidet.cuda.graph.CudaGraph
The CUDA graph.
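
        Examples
        --------
        A minimal sketch; the input shape is illustrative and the replay API is that of
        :class:`hidet.cuda.graph.CudaGraph`:

        >>> x = hidet.randn([1, 3, 224, 224], device='cuda')  # doctest: +SKIP
        >>> cg = graph.cuda_graph(x)                          # doctest: +SKIP
        >>> outputs = cg.run([x])                             # doctest: +SKIP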
"""
import torch
from hidet.cuda.graph import CudaGraph, CudaGraphCreationError
from hidet.graph.tensor import Tensor
for x in self.meta.inputs + self.meta.outputs:
if x.device == 'cpu':
raise CudaGraphCreationError(f'Cannot create CUDA graph for a model with CPU inputs:\n {x}')
for d in x.shape:
if not isinstance(d, int):
raise CudaGraphCreationError(f'Cannot create CUDA graph for a model with dynamic inputs:\n {x}')
if any(device == 'cpu' for device in self.graph_execution.tensor_device):
raise CudaGraphCreationError('Cannot create CUDA graph for a model with CPU tensors.')
for ctask in self.compiled_tasks:
if len(ctask.meta_data.symbols) > 0:
raise CudaGraphCreationError('Cannot create CUDA graph for a model with dynamic symbols.')
def f_create_inputs() -> List[Tensor]:
with hidet.option.context():
hidet.option.execution_mode('compilation')
inputs = []
for arg in args:
arg = hidet.from_torch(arg) if isinstance(arg, torch.Tensor) else arg
inputs.append(hidet.randn_like(arg))
return inputs
def f_run(inputs: List[Tensor]) -> List[Tensor]:
return self.run_async(inputs)
global global_cuda_workspace
# clear the workspace to avoid the storage being captured by the CUDA graph.
global_cuda_workspace = None
return CudaGraph(f_create_inputs, f_run, ref_objs=[self])
    def save(self, path: str, save_dispatch_table: bool = False):
"""
        Save the compiled graph to disk.

        Parameters
        ----------
        path: str
            The path to save the compiled graph. By convention, the path should end with '.hidet'.
        save_dispatch_table: bool
            Whether to save the dispatch table to disk. See :func:`save_compiled_graph` for details.

        See Also
        --------
        load_compiled_graph
"""
save_compiled_graph(self, path, save_dispatch_table)
def save_compiled_graph(model: CompiledGraph, file: str, save_dispatch_table: bool = False, save_weights: bool = True):
"""
Save the compiled graph to disk.
Parameters
----------
model: CompiledGraph
The compiled graph to save.
file: str
The path to save the compiled graph. By convention, the path should end with '.hidet'.
    save_dispatch_table: bool
        Whether to save the dispatch table to disk.

        When a model contains multiple candidate kernels for the same operator, the best candidate is
        picked by benchmarking all of them for the given input shapes, and the result is recorded in the
        dispatch table. If the dispatch table is not saved, the benchmarking is repeated the next time the
        model is loaded.

        Default: False
    save_weights: bool
        Whether to save the weights to disk. If False, the weights are not stored in the model file and
        must be provided via :meth:`CompiledGraph.set_weights` after loading. This is useful when the
        weights are distributed separately from the model file.

        Default: True
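
    Examples
    --------
    A minimal sketch; the file names are illustrative:

    >>> save_compiled_graph(model, 'model.hidet')  # doctest: +SKIP
    >>> # ship the weights separately:
    >>> save_compiled_graph(model, 'model_no_weights.hidet', save_weights=False)  # doctest: +SKIP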
"""
from hidet.utils.dataclass import asdict
    # resolve to an absolute directory so makedirs works for bare file names, and so the
    # temporary file lands on the same filesystem as the target (keeping the final move atomic)
    dirname = os.path.dirname(os.path.abspath(file))
    os.makedirs(dirname, exist_ok=True)
with tempfile.NamedTemporaryFile(dir=dirname, delete=False) as temp_file:
temp_path = temp_file.name
with zipfile.ZipFile(temp_path, 'w') as zf:
        def _save_under(dir_path: str, dir_in_zip: str, exclude: Optional[List[str]] = None):
            for root, _, files in os.walk(dir_path):
                for name in files:
                    # skip excluded files before opening the zip entry; opening first
                    # would leave an empty entry in the archive
                    if exclude and name in exclude:
                        continue
                    file_path = os.path.join(root, name)
                    file_in_zip = os.path.join(dir_in_zip, os.path.relpath(file_path, dir_path))
                    with zf.open(file_in_zip, 'w') as f1:
                        with open(file_path, 'rb') as f2:
                            f1.write(f2.read())
# meta info
with zf.open('meta.json', 'w') as f:
meta_bytes = json.dumps(asdict(model.meta), indent=4).encode('utf-8')
f.write(meta_bytes)
# save the modules
_save_under(model.graph_module.module_dir, 'graph_module/')
# save weights
if save_weights:
# zip.open(..., force_zip64=True) is required for >4GB weights
with zf.open('weights.npz', 'w', force_zip64=True) as f:
numpy.savez(f, *[weight.cpu().numpy() for weight in model.weights])
# save the kernels (i.e., compiled tasks)
for i, compiled_task in enumerate(model.compiled_tasks):
_save_under(compiled_task.task_dir, 'kernels/{}/'.format(i))
# save graph execution
with zf.open('graph_execution.json', 'w') as f:
ge_bytes = json.dumps(asdict(model.graph_execution), indent=4).encode('utf-8')
f.write(ge_bytes)
# save dispatch table file
if save_dispatch_table and os.path.exists(model.dispatch_table_path):
with zf.open('dispatch_table.txt', 'w') as f:
with open(model.dispatch_table_path, 'rb') as f2:
f.write(f2.read())
# save graph string
with zf.open('graph_string.txt', 'w') as f:
f.write(model.graph_string.encode('utf-8'))
    os.replace(temp_path, file)  # atomic move that also overwrites an existing file on Windows
def load_compiled_graph(path: str) -> CompiledGraph:
"""
Load a compiled graph from disk.
    The compiled graph is stored in zip format. The path can be either the zip file itself, or a directory
    that contains the extracted contents of the zip file.
Parameters
----------
path: str
The path to load the compiled graph (can be either a single file or a directory).
Returns
-------
ret: CompiledGraph
The loaded compiled graph.
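
    Examples
    --------
    A minimal sketch; the file name is illustrative:

    >>> graph = load_compiled_graph('model.hidet')  # doctest: +SKIP
    >>> print(graph)  # doctest: +SKIP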
"""
from hidet.utils.dataclass import from_dict
if os.path.isfile(path):
with zipfile.ZipFile(path, 'r') as zf:
# load meta data
with zf.open('meta.json', 'r') as f:
meta_data: GraphMetaData = from_dict(GraphMetaData, json.load(f))
# extract all files except weights
files_to_extract: List[str] = zf.namelist()
if 'weights.npz' in files_to_extract:
files_to_extract.remove('weights.npz')
cache_dir = hidet.utils.cache_dir('graphs', meta_data.graph_hash)
if not os.path.exists(os.path.join(cache_dir, 'graph_string.txt')):
# only extract files if the graph_string.txt is not in the cache
# here 'graph_string.txt' is just the last file we usually save to disk, we use it as a flag
# to indicate whether the graph is already in the cache
zf.extractall(cache_dir, files_to_extract)
graph_path = cache_dir
else:
graph_path = path
# load meta data
with open(os.path.join(graph_path, 'meta.json'), 'r') as f:
meta_data: GraphMetaData = from_dict(GraphMetaData, json.load(f))
# load graph execution
with open(os.path.join(graph_path, 'graph_execution.json'), 'r') as f:
graph_execution: GraphExecution = from_dict(GraphExecution, json.load(f))
# load weights if it exists
weights = []
def load_weights_from_npz(npz: zipfile.ZipFile):
for weight_idx, name in enumerate(npz.namelist()):
with npz.open(name, 'r') as npy_file:
npy_file: Any # used to suppress type checker warning
device = graph_execution.tensor_device[graph_execution.weights_index[weight_idx]]
weights.append(hidet.asarray(numpy.load(npy_file), device=device))
if os.path.exists(os.path.join(graph_path, 'weights.npz')):
with zipfile.ZipFile(os.path.join(graph_path, 'weights.npz'), 'r') as npz:
load_weights_from_npz(npz)
elif os.path.isfile(path):
with zipfile.ZipFile(path, 'r') as zf:
if 'weights.npz' in zf.namelist():
# weights are loaded directly from the zip file to memory
# avoid extracting the weights to disk and then loading them from disk
with zf.open('weights.npz', 'r') as f:
with zipfile.ZipFile(f, 'r') as npz:
load_weights_from_npz(npz)
# load kernels (i.e., compiled tasks)
num_kernels = meta_data.num_kernels
compiled_tasks = [CompiledTask(task_dir=os.path.join(graph_path, 'kernels', str(i))) for i in range(num_kernels)]
# load graph module
graph_module = CompiledModule(module_dir=os.path.join(graph_path, 'graph_module'))
# load graph string
with open(os.path.join(graph_path, 'graph_string.txt'), 'r') as f:
graph_string = f.read()
# construct the compiled graph
ret = CompiledGraph(meta_data, graph_module, weights, compiled_tasks, graph_execution, graph_string)
return ret