This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Merge pull request #544 from JuliaGPU/tb/threads
Rework library handles for multithreading.
maleadt authored Dec 20, 2019
2 parents 50c1a3d + 8635a2f commit 16f3bef
Showing 20 changed files with 372 additions and 157 deletions.
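
Taken together, the diffs below replace each wrapper module's single global handle Ref with a per-thread slot table plus a per-context handle cache, both set up when the module initializes. A minimal sketch of that pattern in Julia, with create_handle/destroy_handle as illustrative placeholders for the library-specific calls (this is the shape of the change, not the actual CuArrays code):

using CUDAdrv, CUDAnative

create_handle() = Ptr{Cvoid}(1)      # placeholder for e.g. cublasCreate_v2
destroy_handle(h) = nothing          # placeholder for e.g. cublasDestroy_v2

const created_handles = IdDict{CuContext,Ptr{Cvoid}}()       # one handle per context
const active_handles  = Vector{Union{Nothing,Ptr{Cvoid}}}()  # one slot per thread

function handle()
    tid = Threads.threadid()
    if @inbounds active_handles[tid] === nothing
        ctx = context()                       # current CUDA context, from CUDAnative
        active_handles[tid] = get!(created_handles, ctx) do
            h = create_handle()
            atexit(() -> CUDAdrv.isvalid(ctx) && destroy_handle(h))
            h
        end
    end
    @inbounds active_handles[tid]
end

# in the module's __init__: size the slots and invalidate them on context switches
function __init__()
    resize!(active_handles, Threads.nthreads())
    fill!(active_handles, nothing)
    CUDAnative.atcontextswitch() do tid, ctx
        active_handles[tid] = nothing         # refilled lazily on the next API call
    end
end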
12 changes: 9 additions & 3 deletions Manifest.toml
@@ -22,19 +22,25 @@ version = "0.2.0"

[[CUDAapi]]
deps = ["Libdl", "Logging"]
git-tree-sha1 = "6eee47385c81ed3b3f716b745697869c712c2df3"
git-tree-sha1 = "ca1c7f639c5f6326919ee2834fa0dffb5002ff60"
repo-rev = "master"
repo-url = "https://github.com/JuliaGPU/CUDAapi.jl.git"
uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
version = "2.0.0"

[[CUDAdrv]]
deps = ["CEnum", "CUDAapi", "Printf"]
git-tree-sha1 = "0f39fddace3324707469ace7fbcbc7b28d5cf921"
git-tree-sha1 = "5c2cf00a78503e1f71409cecf3d64508fb33f17f"
repo-rev = "master"
repo-url = "https://github.com/JuliaGPU/CUDAdrv.jl.git"
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
version = "4.0.4"

[[CUDAnative]]
deps = ["Adapt", "CEnum", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Printf", "TimerOutputs"]
git-tree-sha1 = "a67b38619d1fa131027bac1c4a81f0012254d1fd"
git-tree-sha1 = "8b1a585344fee94bdb95ac44653fd057d74e32e6"
repo-rev = "master"
repo-url = "https://github.com/JuliaGPU/CUDAnative.jl.git"
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
version = "2.6.0"

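The Manifest now tracks the master branches of CUDAapi, CUDAdrv and CUDAnative, which is what the added repo-rev and repo-url entries record. For reference, a hedged sketch of how such an entry is produced when developing against an unreleased dependency (package name taken from the diff; the exact invocation is illustrative):

using Pkg

# track the master branch; Pkg writes repo-rev/repo-url for it into Manifest.toml
Pkg.add(PackageSpec(name="CUDAnative", rev="master"))
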
26 changes: 0 additions & 26 deletions src/CuArrays.jl
@@ -32,10 +32,6 @@ include("linalg.jl")

include("gpuarray_interface.jl")

# many libraries need to be initialized per-device (per-context, really, but we assume users
# of CuArrays and/or CUDAnative only use a single context), so keep track of the active one.
const active_context = Ref{CuContext}()

include("blas/CUBLAS.jl")
include("sparse/CUSPARSE.jl")
include("solver/CUSOLVER.jl")
@@ -112,28 +112,6 @@ function __init__()
# package integrations
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl")

# update the active context when we switch devices
callback = (::CuDevice, ctx::CuContext) -> begin
active_context[] = ctx

# wipe the active handles
CUBLAS._handle[] = C_NULL
CUBLAS._xt_handle[] = C_NULL
CUSOLVER._dense_handle[] = C_NULL
CUSOLVER._sparse_handle[] = C_NULL
CUSPARSE._handle[] = C_NULL
CURAND._generator[] = nothing
CUDNN._handle[] = C_NULL
CUTENSOR._handle[] = nothing
end
push!(CUDAnative.device!_listeners, callback)

# a device might be active already
existing_ctx = CUDAdrv.CuCurrentContext()
if existing_ctx !== nothing
active_context[] = existing_ctx
end

__init_memory__()

__initialized__[] = true
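The global active_context Ref and the device!_listeners hook are gone; as the CUBLAS and CUDNN diffs below show, each library module now registers its own CUDAnative.atcontextswitch callback and simply invalidates its per-thread handle slot. From user code, a device switch then looks like this hedged sketch (assumes a second device is present):

CUDAnative.device!(1)     # switching devices fires the registered atcontextswitch callbacks
CuArrays.CUBLAS.handle()  # the next library call lazily resolves a handle for the new context
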
57 changes: 37 additions & 20 deletions src/blas/CUBLAS.jl
@@ -5,10 +5,10 @@ using CUDAapi
using CUDAdrv
using CUDAdrv: CUstream

import CUDAnative
using CUDAnative

using ..CuArrays
using ..CuArrays: active_context, unsafe_free!
using ..CuArrays: unsafe_free!
using LinearAlgebra

using CEnum
@@ -27,45 +27,62 @@ include("wrappers.jl")
# high-level integrations
include("linalg.jl")

const _handles = Dict{CuContext,cublasHandle_t}()
const _xt_handles = Dict{CuContext,cublasXtHandle_t}()
const _handle = Ref{cublasHandle_t}(C_NULL)
const _xt_handle = Ref{cublasXtHandle_t}(C_NULL)
const created_handles = IdDict{CuContext,cublasHandle_t}()
const created_xt_handles = IdDict{CuContext,cublasXtHandle_t}()
const active_handles = Vector{Union{Nothing,cublasHandle_t}}()
const active_xt_handles = Vector{Union{Nothing,cublasXtHandle_t}}()

function handle()
if _handle[] == C_NULL
CUDAnative.maybe_initialize("CUBLAS")
_handle[] = get!(_handles, active_context[]) do
context = active_context[]
tid = Threads.threadid()
if @inbounds active_handles[tid] === nothing
ctx = context()
active_handles[tid] = get!(created_handles, ctx) do
handle = cublasCreate_v2()
atexit(()->CUDAdrv.isvalid(ctx) && cublasDestroy_v2(handle))

# enable tensor math mode if our device supports it, and fast math is enabled
dev = CUDAdrv.device(context)
dev = CUDAdrv.device()
if Base.JLOptions().fast_math == 1 && CUDAdrv.capability(dev) >= v"7.0" && version() >= v"9"
cublasSetMathMode(CUBLAS_TENSOR_OP_MATH, handle)
end

atexit(()->CUDAdrv.isvalid(context) && cublasDestroy_v2(handle))
handle
end
end

return _handle[]
@inbounds active_handles[tid]
end

function xt_handle()
if _xt_handle[] == C_NULL
@assert isassigned(active_context) # some other call should have initialized CUDA
_xt_handle[] = get!(_xt_handles, active_context[]) do
context = active_context[]
tid = Threads.threadid()
if @inbounds active_xt_handles[tid] === nothing
ctx = context()
active_xt_handles[tid] = get!(created_xt_handles, ctx) do
handle = cublasXtCreate()
atexit(()->CUDAdrv.isvalid(ctx) && cublasXtDestroy(handle))

# select the devices
# TODO: this is weird, since we typically use a single device per thread/context
devs = convert.(Cint, CUDAdrv.devices())
cublasXtDeviceSelect(handle, length(devs), devs)
atexit(()->CUDAdrv.isvalid(context) && cublasXtDestroy(handle))

handle
end
end
return _xt_handle[]
@inbounds active_xt_handles[tid]
end

function __init__()
resize!(active_handles, Threads.nthreads())
fill!(active_handles, nothing)

resize!(active_xt_handles, Threads.nthreads())
fill!(active_xt_handles, nothing)

CUDAnative.atcontextswitch() do tid, ctx
# we don't eagerly initialize handles, but do so lazily when requested
active_handles[tid] = nothing
active_xt_handles[tid] = nothing
end
end

end
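
With active_handles sized to Threads.nthreads() in __init__, every Julia thread resolves its own CUBLAS handle on first use, while handles for the same context are shared through created_handles. A hedged usage sketch (assumes CuArrays is loaded against a functional CUDA setup):

using CuArrays

Threads.@threads for _ in 1:Threads.nthreads()
    h = CuArrays.CUBLAS.handle()   # first call on each thread creates or re-uses a handle
    @info "CUBLAS handle" thread = Threads.threadid() handle = h
end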
35 changes: 28 additions & 7 deletions src/blas/error.jl
@@ -37,13 +37,34 @@ function status_message(status)
end
end

macro check(blas_func)

## API call wrapper

# API calls that are allowed without a functional context
const preinit_apicalls = Set{Symbol}([
:cublasGetVersion,
:cublasGetProperty,
:cublasGetCudartVersion
])

# outlined functionality to avoid GC frame allocation
@noinline function throw_api_error(res)
throw(CuError(res))
end

macro check(ex)
fun = Symbol(decode_ccall_function(ex))
init = if !in(fun, preinit_apicalls)
:(CUDAnative.maybe_initialize())
end
quote
local err::cublasStatus_t
err = $(esc(blas_func::Expr))
if err != CUBLAS_STATUS_SUCCESS
throw(CUBLASError(err))
$init

res = $(esc(ex))
if res != CUBLAS_STATUS_SUCCESS
throw_api_error(res)
end
err

return
end
end
end
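
The reworked @check extracts the API function name from the wrapped call via decode_ccall_function, skips driver initialization for the calls whitelisted in preinit_apicalls, and outlines the throw so the success path allocates no GC frame. Roughly, a wrapped call expands as in the sketch below (the ccall target and library lookup are illustrative, not the actual generated wrapper):

# @check ccall((:cublasCreate_v2, libcublas[]), cublasStatus_t, (Ptr{cublasHandle_t},), ref)
# becomes, approximately:
CUDAnative.maybe_initialize()   # omitted when the function is listed in preinit_apicalls

res = ccall((:cublasCreate_v2, libcublas[]), cublasStatus_t, (Ptr{cublasHandle_t},), ref)
if res != CUBLAS_STATUS_SUCCESS
    throw_api_error(res)        # @noinline helper keeps the GC frame off the common path
end

return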
29 changes: 19 additions & 10 deletions src/dnn/CUDNN.jl
@@ -6,12 +6,12 @@ using CUDAapi: libraryPropertyType
using CUDAdrv
using CUDAdrv: CUstream

import CUDAnative
using CUDAnative

using CEnum

using ..CuArrays
using ..CuArrays: active_context, @argout, @workspace
using ..CuArrays: @argout, @workspace
import ..CuArrays.unsafe_free!

import NNlib
@@ -41,21 +41,30 @@ include("nnlib.jl")

include("compat.jl")

const _handles = Dict{CuContext,cudnnHandle_t}()
const _handle = Ref{cudnnHandle_t}(C_NULL)
const created_handles = IdDict{CuContext,cudnnHandle_t}()
const active_handles = Vector{Union{Nothing,cudnnHandle_t}}()

function handle()
if _handle[] == C_NULL
CUDAnative.maybe_initialize("CUDNN")
_handle[] = get!(_handles, active_context[]) do
context = active_context[]
tid = Threads.threadid()
if @inbounds active_handles[tid] === nothing
ctx = context()
active_handles[tid] = get!(created_handles, ctx) do
handle = cudnnCreate()
atexit(()->CUDAdrv.isvalid(context) && cudnnDestroy(handle))
atexit(()->CUDAdrv.isvalid(ctx) && cudnnDestroy(handle))
handle
end
end
@inbounds active_handles[tid]
end

function __init__()
resize!(active_handles, Threads.nthreads())
fill!(active_handles, nothing)

return _handle[]
CUDAnative.atcontextswitch() do tid, ctx
# we don't eagerly initialize handles, but do so lazily when requested
active_handles[tid] = nothing
end
end

end
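
Because the handle tables are sized from Threads.nthreads() when the module initializes, the number of slots follows the thread count the Julia process was started with. A hedged illustration (assumes CUDNN is available):

# start Julia with the desired thread count, e.g. JULIA_NUM_THREADS=4 julia
using CuArrays

Threads.nthreads()        # 4, so the CUDNN active_handles table holds four slots
CuArrays.CUDNN.handle()   # fills the calling thread's slot lazily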
34 changes: 28 additions & 6 deletions src/dnn/error.jl
@@ -11,13 +11,35 @@ function CUDNNError(status::cudnnStatus_t)
return CUDNNError(status, msg)
end

macro check(dnn_func)

## API call wrapper

# API calls that are allowed without a functional context
const preinit_apicalls = Set{Symbol}([
:cudnnGetVersion,
:cudnnGetProperty,
:cudnnGetCudartVersion,
:cudnnGetErrorString,
])

# outlined functionality to avoid GC frame allocation
@noinline function throw_api_error(res)
throw(CUDNNError(res))
end

macro check(ex)
fun = Symbol(decode_ccall_function(ex))
init = if !in(fun, preinit_apicalls)
:(CUDAnative.maybe_initialize())
end
quote
local err::cudnnStatus_t
err = $(esc(dnn_func))
if err != CUDNN_STATUS_SUCCESS
throw(CUDNNError(err))
$init

res = $(esc(ex))
if res != CUDNN_STATUS_SUCCESS
throw_api_error(res)
end
err

return
end
end
2 changes: 1 addition & 1 deletion src/dnn/filter.jl
@@ -10,7 +10,7 @@ Base.unsafe_convert(::Type{cudnnFilterDescriptor_t}, fd::FilterDesc) = fd.ptr

function createFilterDesc()
d = Ref{cudnnFilterDescriptor_t}()
@check cudnnCreateFilterDescriptor(d)
cudnnCreateFilterDescriptor(d)
return d[]
end

2 changes: 2 additions & 0 deletions src/fft/CUFFT.jl
@@ -8,6 +8,8 @@ import ..CuArrays: unsafe_free!
using CUDAdrv
using CUDAdrv: CUstream

using CUDAnative

using CEnum

const libcufft = Ref("libcufft")
32 changes: 26 additions & 6 deletions src/fft/error.jl
@@ -51,13 +51,33 @@ function status_message(status)
end
end

macro check(fft_func)

## API call wrapper

# API calls that are allowed without a functional context
const preinit_apicalls = Set{Symbol}([
:cufftGetVersion,
:cufftGetProperty,
])

# outlined functionality to avoid GC frame allocation
@noinline function throw_api_error(res)
throw(CUFFTError(res))
end

macro check(ex)
fun = Symbol(decode_ccall_function(ex))
init = if !in(fun, preinit_apicalls)
:(CUDAnative.maybe_initialize())
end
quote
local err::cufftResult
err = $(esc(fft_func::Expr))
if err != CUFFT_SUCCESS
throw(CUFFTError(err))
$init

res = $(esc(ex))
if res != CUFFT_SUCCESS
throw_api_error(res)
end
err

return
end
end
4 changes: 1 addition & 3 deletions src/memory.jl
@@ -280,9 +280,7 @@ synchronized right before and after executing `ex` to exclude any external effects
macro time(ex)
quote
# @time might surround an application, so be sure to initialize CUDA before that
# FIXME: this should be done in CUDAdrv (`synchronize(ctx=CuCurrentOrNewContext()`)
# but the CUDA initialization mechanics are part of CUDAnative.jl
CUDAnative.maybe_initialize("@time")
CUDAnative.maybe_initialize()

# coarse synchronization to exclude effects from previously-executed code
CUDAdrv.synchronize()
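CuArrays.@time now calls CUDAnative.maybe_initialize() without a descriptive label; as before, it synchronizes the device around the expression to keep external effects out of the measurement. A hedged usage sketch:

using CuArrays

CuArrays.@time begin                       # device-synchronized timing (illustrative)
    a = CuArrays.rand(Float32, 2048, 2048)
    a * a
end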