Source code for hidet.cuda.stream

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=no-name-in-module, c-extension-no-member
from __future__ import annotations
from typing import Union, Optional, Dict
from cuda import cudart
from cuda.cudart import cudaStream_t
from hidet.utils import exiting
from .event import Event
from .device import CudaDeviceContext, current_device


def _get_device_id(device) -> int:
    """
    Get the device ID from a device.

    Parameters
    ----------
    device: Device or int, optional
        The device.

    Returns
    -------
    device_id: int
        The device ID.
    """
    if device is None:
        return current_device()
    elif isinstance(device, int):
        return device
    else:
        from hidet.runtime.device import Device

        if isinstance(device, Device):
            assert device.is_cuda(), "device must be a CUDA device"
            if device.id is not None:
                return device.id
            else:
                return current_device()
        else:
            raise TypeError(f"device must be a hidet.Device or int, not {type(device)}")


[docs]class Stream: """ A CUDA stream. Parameters ---------- device: int or hidet.Device, optional The device on which to create the stream. If None, the current device will be used. blocking: bool Whether to enable the implicit synchronization between this stream and the default stream. When enabled, any operation enqueued in the stream will wait for all previous operations in the default stream to complete before beginning execution. priority: int The priority of the stream. The priority is a hint to the CUDA driver that it can use to reorder operations in the stream relative to other streams. The priority can be 0 (default priority) and -1 (high priority). By default, all streams are created with priority 0. """ def __init__(self, device=None, blocking: bool = False, priority: int = 0, **kwargs): from hidet import cuda self._device_id: int self._handle: cudaStream_t self._external: bool if 'handle' in kwargs: self._device_id = _get_device_id(device) self._handle = kwargs['handle'] self._external = True else: self._device_id = _get_device_id(device) with cuda.device(self._device_id): flags = cudart.cudaStreamNonBlocking if not blocking else cudart.cudaStreamDefault err, handle = cudart.cudaStreamCreateWithPriority(flags, priority) assert err == 0, err self._handle = handle self._external = False def __int__(self): return int(self._handle) def __hash__(self): return hash(self._handle) def __eq__(self, other: Stream): if not isinstance(other, Stream): raise TypeError(f"cannot compare Stream with {type(other)}") return self._device_id == other._device_id and self._handle == other._handle def __del__(self, is_exiting=exiting.is_exiting): if is_exiting(): return if not self._external: (err,) = cudart.cudaStreamDestroy(self._handle) assert err == 0, err
[docs] def device_id(self) -> int: """ Get the device ID of the stream. Returns ------- device_id: int The device ID of the stream. """ return self._device_id
[docs] def handle(self) -> cudaStream_t: """ Get the handle of the stream. Returns ------- handle: cudaStream_t The handle of the stream. """ return self._handle
[docs] def synchronize(self) -> None: """ Block the current host thread until the stream completes all operations. """ (err,) = cudart.cudaStreamSynchronize(self._handle) if err != 0: raise RuntimeError("cudaStreamSynchronize failed with error: {}".format(err.name))
[docs] def wait_event(self, event: Event) -> None: """ Let the subsequent operations in the stream wait for the event to complete. The event might be recorded in another stream. The host thread will not be blocked. Parameters ---------- event: Event The event to wait for. """ (err,) = cudart.cudaStreamWaitEvent(self._handle, event.handle(), 0) assert err == 0, err
[docs]class ExternalStream(Stream): """ An external CUDA stream created from a handle. Parameters ---------- handle: int or cudaStream_t The handle of the stream. device_id: int, optional The device ID of the stream. If None, the current device will be used. """ def __init__(self, handle: Union[cudaStream_t, int], device_id: Optional[int] = None): super().__init__(handle=handle, device=device_id)
_current_streams: Dict[int, Stream] = {} class StreamContext: def __init__(self, stream: Stream): # pylint: disable=redefined-outer-name self.device_context = CudaDeviceContext(stream.device_id()) self.device: int = stream.device_id() self.stream: Stream = stream self.prev_stream: Optional[Stream] = None def __enter__(self): from hidet.ffi import runtime_api current_streams = _current_streams self.prev_stream = current_streams[self.device] current_streams[self.device] = self.stream self.device_context.__enter__() runtime_api.set_current_stream(self.stream) def __exit__(self, exc_type, exc_val, exc_tb): from hidet.ffi import runtime_api current_streams = _current_streams current_streams[self.device] = self.prev_stream runtime_api.set_current_stream(self.prev_stream) self.device_context.__exit__(exc_type, exc_val, exc_tb)
[docs]def current_stream(device=None) -> Stream: """ Get the current stream. Parameters ---------- device: int or hidet.Device, optional The device on which to get the current stream. If None, the current device will be used. Returns ------- stream: Stream The current stream. """ device_id = _get_device_id(device) if device_id not in _current_streams: _current_streams[device_id] = ExternalStream(handle=0, device_id=device_id) return _current_streams[_get_device_id(device)]
[docs]def default_stream(device=None) -> Stream: """ Get the default stream. Parameters ---------- device: int or hidet.Device, optional The device on which to get the default stream. If None, the current device will be used. Returns ------- stream: Stream The default stream. """ return ExternalStream(handle=0, device_id=_get_device_id(device))
[docs]def stream(s: Stream) -> StreamContext: """ Set the current stream. Parameters ---------- s: Stream The stream to set. Examples -------- >>> import hidet >>> stream = hidet.cuda.Stream() >>> with hidet.cuda.stream(stream): >>> ... # all hidet cuda kernels will be executed in the stream """ return StreamContext(s)