# Source code for hidet.cuda.device
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=no-name-in-module, c-extension-no-member
from typing import Tuple, Optional
from functools import lru_cache
from cuda import cudart
from cuda.cudart import cudaDeviceProp
class CudaDeviceContext:
    """
    Context manager that makes a CUDA device current inside a ``with`` block.

    On entry, the device reported by ``current_device()`` is recorded and
    ``device_id`` is made current via ``set_device``. On exit, the recorded device
    is made current again.

    Parameters
    ----------
    device_id: int
        The ID of the device to switch to while the context is active.
    """

    def __init__(self, device_id: int):
        self.device_id: int = device_id
        # Device that was current at the most recent __enter__ (None before first use).
        self.prev_device_id: Optional[int] = None
        # One saved device per active __enter__, so that entering the same object
        # again (nested use) does not overwrite the device the outer block must restore.
        self._saved_device_ids: list = []

    def __enter__(self):
        self.prev_device_id = current_device()
        set_device(self.device_id)
        # Record the saved device only once the switch succeeded, so a failed
        # __enter__ (for which __exit__ is never called) leaves no stale entry.
        self._saved_device_ids.append(self.prev_device_id)
        # Return the context so that ``with ... as ctx`` binds the manager, not None.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the device saved by the matching __enter__. Nothing is returned,
        # so an exception raised inside the ``with`` block keeps propagating.
        set_device(self._saved_device_ids.pop())
@lru_cache(maxsize=None)
def available() -> bool:
    """
    Tell whether CUDA is available on this machine.

    The check asks ``ctypes.util.find_library`` for the 'cuda' library (libcuda.so)
    instead of calling cudart directly. The answer is memoized by ``lru_cache``,
    so the lookup is performed at most once per process.

    Returns
    -------
    ret: bool
        True if the library was found, False otherwise.
    """
    from ctypes.util import find_library

    # find_library yields a non-empty name when found and None otherwise.
    return bool(find_library('cuda'))
@lru_cache(maxsize=None)
def device_count() -> int:
    """
    Get the number of available CUDA devices.

    The value is memoized by ``lru_cache`` after the first successful call.

    Returns
    -------
    count: int
        The number of available CUDA devices.

    Raises
    ------
    AssertionError
        If ``cudaGetDeviceCount`` reports a non-zero error code. The error code is
        the exception argument.
    """
    err, count = cudart.cudaGetDeviceCount()
    # Raise explicitly rather than with ``assert``: assert statements are removed
    # under ``python -O``, which would let a failed call go unnoticed. The exception
    # type and argument are the same as the assert produced.
    if err != 0:
        raise AssertionError(err)
    return count
@lru_cache(maxsize=None)
def properties(device_id: int = 0) -> cudaDeviceProp:
    """
    Get the properties of a CUDA device.

    The result is memoized per ``device_id`` by ``lru_cache`` after the first
    successful call.

    Parameters
    ----------
    device_id: int
        The ID of the device.

    Returns
    -------
    prop: cudaDeviceProp
        The properties of the device.

    Raises
    ------
    AssertionError
        If ``cudaGetDeviceProperties`` reports a non-zero error code. The error code
        is the exception argument.
    """
    err, prop = cudart.cudaGetDeviceProperties(device_id)
    # Raise explicitly rather than with ``assert``: assert statements are removed
    # under ``python -O``, which would let a failed call go unnoticed. The exception
    # type and argument are the same as the assert produced.
    if err != 0:
        raise AssertionError(err)
    return prop
def set_device(device_id: int):
    """
    Set the current cuda device.

    Parameters
    ----------
    device_id: int
        The ID of the cuda device.

    Raises
    ------
    AssertionError
        If ``cudaSetDevice`` reports a non-zero error code. The error code is the
        exception argument.
    """
    (err,) = cudart.cudaSetDevice(device_id)
    # Raise explicitly rather than with ``assert``: assert statements are removed
    # under ``python -O``, which would let a failed device switch go unnoticed. The
    # exception type and argument are the same as the assert produced.
    if err != 0:
        raise AssertionError(err)
def current_device() -> int:
    """
    Get the current cuda device.

    Returns
    -------
    device_id: int
        The ID of the cuda device.

    Raises
    ------
    AssertionError
        If ``cudaGetDevice`` reports a non-zero error code. The error code is the
        exception argument.
    """
    err, device_id = cudart.cudaGetDevice()
    # Raise explicitly rather than with ``assert``: assert statements are removed
    # under ``python -O``, which would let a failed call go unnoticed. The exception
    # type and argument are the same as the assert produced.
    if err != 0:
        raise AssertionError(err)
    return device_id
def device(device_id: int):
    """
    Create a context manager that sets the current cuda device.

    Parameters
    ----------
    device_id: int
        The ID of the cuda device.

    Returns
    -------
    ctx: CudaDeviceContext
        A context manager constructed for ``device_id``.

    Examples
    --------
    >>> import hidet
    >>> with hidet.cuda.device(0):
    ...     pass  # do something on device 0
    """
    context = CudaDeviceContext(device_id)
    return context
@lru_cache(maxsize=None)
def compute_capability(device_id: int = 0) -> Tuple[int, int]:
    """
    Query the compute capability of a CUDA device.

    The values are read from the ``major`` and ``minor`` fields of the object
    returned by ``properties(device_id)``; the result is memoized per ``device_id``.

    Parameters
    ----------
    device_id: int
        The ID of the device to query.

    Returns
    -------
    (major, minor): Tuple[int, int]
        The compute capability of the device.
    """
    device_prop = properties(device_id)
    major, minor = device_prop.major, device_prop.minor
    return major, minor
def synchronize():
    """
    Block the host thread until the device is synchronized.

    This function waits until the device has completed all preceding requested tasks.

    Raises
    ------
    RuntimeError
        If ``cudaDeviceSynchronize`` reports a non-zero error code; the message
        carries the name of the error.
    """
    (status,) = cudart.cudaDeviceSynchronize()
    if status != 0:
        message = "cudaDeviceSynchronize failed with error: {}".format(status.name)
        raise RuntimeError(message)
def profiler_start():
    """
    Mark the start of a profiling range.

    Raises
    ------
    AssertionError
        If ``cudaProfilerStart`` reports a non-zero error code. The error code is
        the exception argument.
    """
    (err,) = cudart.cudaProfilerStart()
    # Raise explicitly rather than with ``assert``: assert statements are removed
    # under ``python -O``, which would let a failed call go unnoticed. The exception
    # type and argument are the same as the assert produced.
    if err != 0:
        raise AssertionError(err)
def profiler_stop():
    """
    Mark the end of a profiling range.

    Raises
    ------
    AssertionError
        If ``cudaProfilerStop`` reports a non-zero error code. The error code is
        the exception argument.
    """
    (err,) = cudart.cudaProfilerStop()
    # Raise explicitly rather than with ``assert``: assert statements are removed
    # under ``python -O``, which would let a failed call go unnoticed. The exception
    # type and argument are the same as the assert produced.
    if err != 0:
        raise AssertionError(err)