forked from gidariss/LocNet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnms_mex.cpp
More file actions
121 lines (98 loc) · 3.07 KB
/
nms_mex.cpp
File metadata and controls
121 lines (98 loc) · 3.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#include "mex.h"
#ifdef WIN32
#include <windows.h>
#include <tchar.h>
#else
#include <algorithm>
#endif
#include <vector>
#include <map>
using namespace std;
// Small record pairing a score value `s` with an index `idx`. Its call
// operator is a strict "less-than" on idx alone (s is ignored), so the object
// can act as an ordering functor. A global instance, also named `score`, is
// declared right after the type; neither nms() nor mexFunction() uses them.
struct score {
    double s;
    int idx;
    bool operator() (score lhs, score rhs) { return rhs.idx > lhs.idx; }
};
struct score score;
// Greedy non-maximum suppression on an N x 5 MATLAB matrix.
//
// input_boxes : N x 5 mxArray whose data has element type T (the caller picks
//               T from the class ID). It is read column-major as the columns
//               [x1 y1 x2 y2 score]; widths/heights use "+ 1", i.e. the
//               coordinates are treated as inclusive.
// overlap     : a remaining box is suppressed when its intersection-over-union
//               with the box just picked is strictly greater than this value.
// vPick       : output; the first nPick entries receive the zero-based row
//               indices of the kept boxes, highest score first (among equal
//               scores, the larger row index first). Grown to N entries if
//               the caller passed a smaller vector.
// nPick       : output; number of boxes kept (always <= N).
//
// Reports an error through mexErrMsgTxt if any box has a negative area.
template <typename T>
void nms(const mxArray *input_boxes, double overlap, vector<int> &vPick, int &nPick)
{
    typedef std::multimap<T, int> ScoreMap;
    typedef typename ScoreMap::iterator ScoreIter;

    const int nSample = (int)mxGetM(input_boxes);
    const T *pBoxes = (const T*)mxGetData(input_boxes);
    // Column views into the column-major data.
    const T *x1 = pBoxes + 0 * nSample;
    const T *y1 = pBoxes + 1 * nSample;
    const T *x2 = pBoxes + 2 * nSample;
    const T *y2 = pBoxes + 3 * nSample;
    const T *sc = pBoxes + 4 * nSample;

    vector<double> vArea(nSample);
    for (int i = 0; i < nSample; ++i)
    {
        vArea[i] = double(x2[i] - x1[i] + 1) * (y2[i] - y1[i] + 1);
        if (vArea[i] < 0)
            mexErrMsgTxt("Boxes area must >= 0");
    }

    // Remaining candidates keyed by score; the largest key is the next pick.
    // NOTE(review): NaN scores violate std::multimap's strict weak ordering;
    // consider rejecting them up front.
    ScoreMap scores;
    for (int i = 0; i < nSample; ++i)
        scores.insert(std::pair<T, int>(sc[i], i));

    // Each iteration below removes at least the picked box, so at most
    // nSample entries of vPick are written.
    if ((int)vPick.size() < nSample)
        vPick.resize(nSample);

    nPick = 0;
    while (!scores.empty())
    {
        ScoreIter top = scores.end();
        --top;
        const int last = top->second;
        // Remove the pick explicitly instead of relying on its self-overlap
        // (1.0) being > overlap: that never holds when overlap >= 1 or for a
        // zero-area box (0/0 = NaN), which would loop forever and write past
        // the end of vPick.
        scores.erase(top);
        vPick[nPick] = last;
        nPick += 1;

        for (ScoreIter it = scores.begin(); it != scores.end(); )
        {
            const int j = it->second;
            // Parenthesised (std::max)/(std::min) cannot be hijacked by the
            // function-like min/max macros that <windows.h> may define.
            const T xx1 = (std::max)(x1[last], x1[j]);
            const T yy1 = (std::max)(y1[last], y1[j]);
            const T xx2 = (std::min)(x2[last], x2[j]);
            const T yy2 = (std::min)(y2[last], y2[j]);
            const double w = (std::max)((T)0, xx2 - xx1 + 1);
            const double h = (std::max)((T)0, yy2 - yy1 + 1);
            const double inter = w * h;
            const double ov = inter / (vArea[last] + vArea[j] - inter);
            if (ov > overlap)
                scores.erase(it++);  // post-increment keeps `it` valid; works in C++98 too
            else
                ++it;
        }
    }
}
//void mexFunction(int nlhs, mxArray *plhs[], int nrhs, mxArray *prhs[])
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
if (nrhs != 2)
mexErrMsgTxt("Wrong number of inputs");
if (nlhs != 1)
mexErrMsgTxt("One output");
const mxArray *input_boxes = prhs[0];
if (mxGetClassID(input_boxes) != mxDOUBLE_CLASS && mxGetClassID(input_boxes) != mxSINGLE_CLASS)
mexErrMsgTxt("Input boxes must be Double or Single");
const mxArray *input_overlap = prhs[1];
if (mxGetClassID(input_overlap) != mxDOUBLE_CLASS )
mexErrMsgTxt("Input overlap must be Double");
double overlap = mxGetScalar(input_overlap);
int nSample = (int)mxGetM(input_boxes);
int nDim_boxes = (int)mxGetN(input_boxes);
if (nSample * nDim_boxes == 0)
{
plhs[0] = mxCreateNumericMatrix(0, 0, mxDOUBLE_CLASS, mxREAL);
return;
}
if (nDim_boxes != 5)
mexErrMsgTxt("nms_mex boxes must has 5 columns");
int nPick = 0;
vector<int> vPick(nSample);
if(mxGetClassID(input_boxes) == mxDOUBLE_CLASS)
nms<double>(input_boxes, overlap, vPick, nPick);
else
nms<float>(input_boxes, overlap, vPick, nPick);
plhs[0] = mxCreateNumericMatrix(nPick, 1, mxDOUBLE_CLASS, mxREAL);
double *pRst = mxGetPr(plhs[0]);
for (int i = 0; i < nPick; ++i)
pRst[i] = vPick[i] + 1;
}