# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Callable, Tuple, Any, Dict, Union
import time
from dataclasses import dataclass
import numpy as np
# copied from: https://github.com/openai/triton/blob/main/python/triton/testing.py
def do_bench(fn, warmup=25, rep=100, percentiles=(0.2, 0.5, 0.8)):
    """
    Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
    the 20th and 80th performance percentiles.

    :param fn: Function to benchmark
    :type fn: Callable
    :param warmup: Warmup time (in ms)
    :type warmup: int
    :param rep: Repetition time (in ms)
    :type rep: int
    :param percentiles: Performance percentiles to return in addition to the median.
    :type percentiles: tuple[float]
    """
    # Estimate the runtime of the function
    import torch

    fn()
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        fn()
    end_event.record()
    torch.cuda.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

    # compute number of warmup and repeat
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))

    # We maintain a buffer of 256 MB that we clear
    # before each kernel call to make sure that the L2
    # doesn't contain any input data before the run
    start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
    end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
    cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')

    # Warm-up
    for _ in range(n_warmup):
        fn()

    # Benchmark
    for i in range(n_repeat):
        # we clear the L2 cache before each run
        cache.zero_()
        # record time of `fn`
        start_event[i].record()
        fn()
        end_event[i].record()
    # Record clocks
    torch.cuda.synchronize()
    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
    if percentiles:
        percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist()
        return tuple(percentiles)
    else:
        return torch.mean(times).item()
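
# A minimal usage sketch for do_bench (not part of the original module). The matmul
# workload below is an illustrative assumption; any zero-argument CUDA callable works:
#
#     import torch
#     a = torch.randn(1024, 1024, device='cuda')
#     b = torch.randn(1024, 1024, device='cuda')
#     p20, p50, p80 = do_bench(lambda: a @ b)              # latency percentiles, in ms
#     mean_ms = do_bench(lambda: a @ b, percentiles=None)  # single mean latency, in ms
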
def benchmark_func(run_func, warmup=1, number=5, repeat=5, median=True) -> Union[List[float], float]:
    """Benchmark the given function.

    The given function ``run_func`` will be executed :math:`warmup + repeat * number` times. Each group of
    :math:`number` executions is measured together.

    Parameters
    ----------
    run_func: Callable[[], Any]
        Any callable function to be benchmarked.
    warmup: int
        The number of warm-up executions.
    number: int
        The number of executions grouped together for a single measurement.
    repeat: int
        The number of times the group measurement is repeated.
    median: bool
        Whether to return the median latency instead of the latency of each repeat.

    Returns
    -------
    ret: Union[float, List[float]]
        - When median == True, a single latency number (in milliseconds) is returned.
        - When median == False, the latency (in milliseconds) of each repeat is returned, as a list of floats.
    """
    import nvtx
    import hidet.cuda

    results = []
    with nvtx.annotate('warmup'):
        for _ in range(warmup):
            run_func()
            hidet.cuda.synchronize()
    for i in range(repeat):
        with nvtx.annotate(f'repeat {i}'):
            hidet.cuda.synchronize()
            start_time = time.time()
            for _ in range(number):
                run_func()
            hidet.cuda.synchronize()
            end_time = time.time()
        results.append((end_time - start_time) * 1000 / number)
    if median:
        return float(np.median(results))
    else:
        return results
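
# A usage sketch for benchmark_func (illustrative; assumes some zero-argument GPU callable
# named `run`, e.g. wrapping a hidet-compiled kernel, and the optional `nvtx` package):
#
#     median_ms = benchmark_func(run, warmup=3, number=10, repeat=5)
#     all_ms = benchmark_func(run, median=False)  # one latency (ms) per repeat
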
@dataclass
class BenchData:
    x_vals: List[Any]
    x_name: str
    y_name: str
    kwargs: Dict[str, Any]
    data: Dict[str, Tuple[List[float], List[float], List[float]]]  # [t_min, t_avg, t_max]

    def show_plot(self, show=True, save_path=None, figsize=None, title=None):
        from matplotlib import pyplot as plt

        if all(isinstance(x, (float, int)) for x in self.x_vals):
            x_vals = self.x_vals
        else:
            x_vals = range(1, len(self.x_vals) + 1)

        plt.figure(figsize=figsize)
        ax = plt.subplot()
        for name, (t_min, t_avg, t_max) in self.data.items():
            p = ax.plot(x_vals, t_avg, label=name)
            color = p[0].get_color()
            ax.fill_between(x_vals, t_min, t_max, alpha=0.15, color=color)
        ax.legend()
        ax.set_xlabel(self.x_name)
        ax.set_ylabel(self.y_name)
        if title is not None:
            ax.set_title(title)
        ax.set_xticks(ticks=x_vals, labels=[str(x) for x in self.x_vals])
        if show:
            plt.show()
        if save_path is not None:
            plt.savefig(save_path)
        return self

    def to_dataframe(self):
        import pandas as pd

        columns = list(self.data.keys())
        df = pd.DataFrame(columns=columns, index=self.x_vals)
        for n in columns:
            df[n] = self.data[n][1]  # get t_avg
        return df

    def print_data(self):
        print(self.to_dataframe())
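
# A sketch of the BenchData layout with hand-made numbers (purely illustrative):
# `data` maps a line name to ([t_min per x], [t_avg per x], [t_max per x]).
#
#     bd = BenchData(
#         x_vals=[128, 256], x_name='n', y_name='ms', kwargs={},
#         data={'torch': ([0.9, 1.8], [1.0, 2.0], [1.2, 2.3])},
#     )
#     bd.print_data()                                  # table of average latencies, indexed by x_vals
#     bd.show_plot(show=False, save_path='bench.png')  # line plot with min/max shading
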
class Bench:
    def __init__(self, x_vals: List[Any], x_name: str, **kwargs):
        self.x_vals = x_vals
        self.x_name = x_name
        self.y_name = 'ms'
        self.byte_fn = None

        self.kwargs: Dict[str, Any] = kwargs
        self.bench_fns: List[Tuple[str, Callable]] = []
        self.bench_data: Dict[str, Tuple[List[float], List[float], List[float]]] = {}
    def measure_flops(self, byte_fn: Callable[[Any], int]):
        """
        Set a function that takes the current x_val and the config kwargs and returns the number of
        floating-point operations performed, so that results are reported in TFLOP/s instead of milliseconds.
        """
        self.byte_fn = byte_fn
        self.y_name = 'TFLOP/s'
    def bench(self, fn: Callable[[Any], Callable[[], Any]], name: Optional[str] = None):
        """
        Add a function to the list of functions to be benchmarked. ``fn`` takes the current x_val and
        the config kwargs, and returns the zero-argument callable that will be timed.
        If the name argument is None, the name for this particular line is ``fn.__name__``.
        """
        if name is None:
            if hasattr(fn, '__name__'):
                name = fn.__name__
            else:
                raise ValueError("cannot get name of function")
        self.bench_fns.append((name, fn))
        return self
    def run(self):
        """
        Run all the functions that need to be benchmarked, returning a BenchData representing
        the collected results.
        """
        for i in self.x_vals:
            for name, fn in self.bench_fns:
                if name not in self.bench_data:
                    self.bench_data[name] = ([], [], [])
                t_min, t_avg, t_max = self.bench_data[name]

                bench_fn = fn(i, **self.kwargs)
                lo, avg, hi = do_bench(bench_fn)
                if self.byte_fn is not None:
                    # convert latency (ms) to throughput (TFLOP/s); the fastest run (lowest latency)
                    # yields the highest throughput, so the low and high bounds swap places
                    flops = self.byte_fn(i, **self.kwargs)
                    lo, avg, hi = (
                        flops * 1e-12 / (hi * 1e-3),
                        flops * 1e-12 / (avg * 1e-3),
                        flops * 1e-12 / (lo * 1e-3),
                    )
                t_min.append(lo)
                t_avg.append(avg)
                t_max.append(hi)
        return BenchData(self.x_vals, self.x_name, self.y_name, self.kwargs, self.bench_data)
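
# End-to-end sketch of the Bench workflow (the square-matmul workload and the 2*n**3 FLOP
# count are illustrative assumptions, not part of this module):
#
#     import torch
#
#     def torch_matmul(n, **kwargs):
#         a = torch.randn(n, n, device='cuda')
#         b = torch.randn(n, n, device='cuda')
#         return lambda: a @ b
#
#     bench = Bench(x_vals=[1024, 2048, 4096], x_name='n')
#     bench.bench(torch_matmul)
#     bench.measure_flops(lambda n, **kwargs: 2 * n ** 3)  # report TFLOP/s instead of ms
#     data = bench.run()
#     data.print_data()
#     data.show_plot(show=False, save_path='matmul.png')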