Skip to content

Commit 54e79dd

Browse files
committed Dec 11, 2020
perf(mgb/cuda): do not call cudaGetDeviceProperties to avoid io traffic
GitOrigin-RevId: 6aa35928c8ec737d244fdb3ca9639ae49b03b284
·
v1.13.4 … v1.2.0
1 parent 5f17129 commit 54e79dd

File tree

3 files changed

+28
-43
lines changed

3 files changed

+28
-43
lines changed
 

‎dnn/src/cuda/topk/opr_impl.cpp

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -22,20 +22,25 @@ template <typename ctype>
2222
void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda,
2323
const ctype* data, ctype* values,
2424
int* indices, void* workspace) {
25-
auto stream = concrete_handle(handle())->stream();
25+
auto _handle = concrete_handle(handle());
26+
auto stream = _handle->stream();
27+
size_t grid_dim_y_limit = _handle->device_prop().maxGridSize[1];
2628
switch (param().mode) {
2729
case Param::Mode::KTH_ONLY:
2830
cuda_check(topk::find_kth_radix<ctype>(data, values, workspace, m,
29-
n, lda, k, stream));
31+
n, lda, k, grid_dim_y_limit,
32+
stream));
3033
return;
3134
case Param::Mode::VALUE_IDX_NOSORT: {
3235
WorkspaceBundle wk_bundle{workspace, {m * sizeof(ctype), 1}};
3336
auto thresh = static_cast<ctype*>(wk_bundle.get(0));
3437
auto real_wk = wk_bundle.get(1);
3538
cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n,
36-
lda, k, stream));
39+
lda, k, grid_dim_y_limit,
40+
stream));
3741
cuda_check(topk::topk_select<ctype>(data, thresh, values, indices,
38-
real_wk, m, n, lda, k, stream));
42+
real_wk, m, n, lda, k,
43+
grid_dim_y_limit, stream));
3944
return;
4045
}
4146
case Param::Mode::VALUE_IDX_SORTED: {
@@ -48,10 +53,11 @@ void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda,
4853
auto nosort_idx = static_cast<int32_t*>(wk_bundle.get(2));
4954
auto real_wk = wk_bundle.get(3);
5055
cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n,
51-
lda, k, stream));
56+
lda, k, grid_dim_y_limit,
57+
stream));
5258
cuda_check(topk::topk_select<ctype>(data, thresh, nosort_values,
5359
nosort_idx, real_wk, m, n, lda,
54-
k, stream));
60+
k, grid_dim_y_limit, stream));
5561
argsort::forward(nosort_values, values, indices, real_wk, m,
5662
std::abs(k), k > 0, stream, nosort_idx);
5763
return;
@@ -89,9 +95,11 @@ size_t TopKImpl::get_workspace_in_bytes(int k, const TensorLayout& data,
8995
MEGDNN_MARK_USED_VAR(indices);
9096
size_t m = data[0], n = data[1];
9197
size_t kabs = std::abs(k);
98+
size_t grid_dim_y_limit =
99+
concrete_handle(handle())->device_prop().maxGridSize[1];
92100
megdnn_assert(std::max(m, n) <=
93101
static_cast<size_t>(std::numeric_limits<int>::max()));
94-
size_t kth = topk::find_kth_radix_workspace(m, n),
102+
size_t kth = topk::find_kth_radix_workspace(m, n, grid_dim_y_limit),
95103
sel = topk::topk_select_workspace(m, n);
96104
auto ctsize = data.dtype.size();
97105
switch (param().mode) {

‎dnn/src/cuda/topk/topk_radix.cu

Lines changed: 7 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -468,17 +468,9 @@ static size_t get_scan_workspace(uint32_t size) {
468468
} // namespace select
469469
} // namespace cuda_topk_impl
470470

471-
uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length) {
471+
uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length,
472+
uint32_t grid_dim_y_limit) {
472473
using namespace cuda_topk_impl::kth;
473-
int device_id;
474-
if (cudaGetDevice(&device_id) != cudaSuccess) {
475-
megdnn_trap();
476-
}
477-
cudaDeviceProp prop;
478-
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
479-
megdnn_trap();
480-
}
481-
uint32_t grid_dim_y_limit = prop.maxGridSize[1];
482474
uint32_t limit = batch > grid_dim_y_limit ? grid_dim_y_limit : batch;
483475
return (limit * get_grid_dim_x(length) * NR_BUCKET + limit * 2) *
484476
sizeof(uint32_t);
@@ -488,6 +480,7 @@ template <typename ctype>
488480
cudaError_t topk::find_kth_radix(const ctype* input, ctype* output,
489481
void* workspace, uint32_t batch,
490482
uint32_t length, int32_t lda, int32_t k,
483+
uint32_t grid_dim_y_limit,
491484
cudaStream_t stream) {
492485
using namespace cuda_topk_impl::kth;
493486
if (!k) {
@@ -502,16 +495,6 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output,
502495
megdnn_trap();
503496
}
504497

505-
int device_id;
506-
if (cudaGetDevice(&device_id) != cudaSuccess) {
507-
megdnn_trap();
508-
}
509-
cudaDeviceProp prop;
510-
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
511-
megdnn_trap();
512-
}
513-
uint32_t grid_dim_y_limit = prop.maxGridSize[1];
514-
515498
uint32_t batch_idx = 0;
516499
uint32_t grid_dim_x = get_grid_dim_x(length);
517500
uint32_t grid_dim_y = 1;
@@ -567,20 +550,11 @@ template <typename ctype>
567550
cudaError_t topk::topk_select(const ctype* input, const ctype* thresh,
568551
ctype* output_value, int32_t* output_idx,
569552
void* workspace, uint32_t batch, uint32_t length,
570-
int32_t lda, int32_t k, cudaStream_t stream) {
553+
int32_t lda, int32_t k,
554+
uint32_t batch_upper_limit, cudaStream_t stream) {
571555
using namespace cuda_topk_impl;
572556
using namespace cuda_topk_impl::select;
573557

574-
int device_id;
575-
if (cudaGetDevice(&device_id) != cudaSuccess) {
576-
megdnn_trap();
577-
}
578-
cudaDeviceProp prop;
579-
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
580-
megdnn_trap();
581-
}
582-
uint32_t batch_upper_limit = prop.maxGridSize[1];
583-
584558
uint32_t length_split = DIVUP(length, REDUCE_SIZE);
585559

586560
void (*kptr_reduce_block_cnt)(const ctype*, const ctype*, uint32_t, int32_t,
@@ -688,10 +662,10 @@ namespace topk {
688662
#define INST(t) \
689663
template cudaError_t find_kth_radix<t>(const t*, t*, void*, uint32_t, \
690664
uint32_t, int32_t, int32_t, \
691-
cudaStream_t); \
665+
uint32_t, cudaStream_t); \
692666
template cudaError_t topk_select<t>(const t*, const t*, t*, int32_t*, \
693667
void*, uint32_t, uint32_t, int32_t, \
694-
int32_t, cudaStream_t)
668+
int32_t, uint32_t, cudaStream_t)
695669
INST(float);
696670
INST(int32_t);
697671
#undef INST

‎dnn/src/cuda/topk/topk_radix.cuh

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -76,10 +76,12 @@ struct RadixConverter<int32_t> {
7676
template <typename ctype>
7777
cudaError_t find_kth_radix(const ctype* input, ctype* output, void* workspace,
7878
uint32_t batch, uint32_t length, int32_t lda,
79-
int32_t k, cudaStream_t stream);
79+
int32_t k, uint32_t grid_dim_y_limit,
80+
cudaStream_t stream);
8081

8182
//! get workspace in bytes
82-
uint32_t find_kth_radix_workspace(uint32_t batch, uint32_t length);
83+
uint32_t find_kth_radix_workspace(uint32_t batch, uint32_t length,
84+
uint32_t grid_dim_y_limit);
8385

8486
/*!
8587
* \brief select values from rows of input that compare to thresh as specified
@@ -90,7 +92,8 @@ template <typename ctype>
9092
cudaError_t topk_select(const ctype* input, const ctype* thresh,
9193
ctype* output_value, int32_t* output_idx,
9294
void* workspace, uint32_t batch, uint32_t length,
93-
int32_t lda, int32_t k, cudaStream_t stream);
95+
int32_t lda, int32_t k, uint32_t batch_upper_limit,
96+
cudaStream_t stream);
9497

9598
uint32_t topk_select_workspace(uint32_t batch, uint32_t length);
9699

0 commit comments

Comments
 (0)
Please sign in to comment.