1414
1515#include < ATen/ATen.h>
1616#include < ATen/cuda/CUDAContext.h>
17-
18- #include < THC/THC.h>
19- #include < THC/THCDeviceUtils.cuh>
17+ #include < ATen/ceil_div.h>
18+ #include < c10/cuda/CUDACachingAllocator.h>
2019
2120#include < vector>
2221#include < iostream>
@@ -74,7 +73,7 @@ __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
7473 t |= 1ULL << i;
7574 }
7675 }
77- const int col_blocks = THCCeilDiv (n_boxes, threadsPerBlock);
76+ const int col_blocks = at::ceil_div (n_boxes, threadsPerBlock);
7877 dev_mask[cur_box_idx * col_blocks + col_start] = t;
7978 }
8079}
@@ -89,28 +88,28 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
8988
9089 int boxes_num = boxes.size (0 );
9190
92- const int col_blocks = THCCeilDiv (boxes_num, threadsPerBlock);
91+ const int col_blocks = at::ceil_div (boxes_num, threadsPerBlock);
9392
9493 scalar_t * boxes_dev = boxes_sorted.data_ptr <scalar_t >();
9594
96- THCState *state = at::globalContext ().lazyInitCUDA (); // TODO replace with getTHCState
95+ at::globalContext ().lazyInitCUDA (); // TODO replace with getTHCState
9796
9897 unsigned long long * mask_dev = NULL ;
9998 // THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
10099 // boxes_num * col_blocks * sizeof(unsigned long long)));
101100
102- mask_dev = (unsigned long long *) THCudaMalloc (state, boxes_num * col_blocks * sizeof (unsigned long long ));
101+ mask_dev = (unsigned long long *) c10::cuda::CUDACachingAllocator::raw_alloc ( boxes_num * col_blocks * sizeof (unsigned long long ));
103102
104- dim3 blocks (THCCeilDiv (boxes_num, threadsPerBlock),
105- THCCeilDiv (boxes_num, threadsPerBlock));
103+ dim3 blocks (at::ceil_div (boxes_num, threadsPerBlock),
104+ at::ceil_div (boxes_num, threadsPerBlock));
106105 dim3 threads (threadsPerBlock);
107106 nms_kernel<<<blocks, threads>>> (boxes_num,
108107 nms_overlap_thresh,
109108 boxes_dev,
110109 mask_dev);
111110
112111 std::vector<unsigned long long > mask_host (boxes_num * col_blocks);
113- THCudaCheck (cudaMemcpy (&mask_host[0 ],
112+ C10_CUDA_CHECK (cudaMemcpy (&mask_host[0 ],
114113 mask_dev,
115114 sizeof (unsigned long long ) * boxes_num * col_blocks,
116115 cudaMemcpyDeviceToHost));
@@ -135,7 +134,7 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
135134 }
136135 }
137136
138- THCudaFree (state, mask_dev);
137+ c10::cuda::CUDACachingAllocator::raw_delete ( mask_dev);
139138 // TODO improve this part
140139 return std::get<0 >(order_t .index ({
141140 keep.narrow (/* dim=*/ 0 , /* start=*/ 0 , /* length=*/ num_to_keep).to (
0 commit comments