-
Notifications
You must be signed in to change notification settings - Fork 628
Expand file tree
/
Copy pathnearest_neighbours.h
More file actions
82 lines (66 loc) · 2.15 KB
/
nearest_neighbours.h
File metadata and controls
82 lines (66 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
// Copyright 2017 Ben Frederickson
#ifndef IMPLICIT_NEAREST_NEIGHBOURS_H_
#define IMPLICIT_NEAREST_NEIGHBOURS_H_
#include <algorithm>
#include <functional>
#include <utility>
#include <vector>
namespace implicit {
/** Functor that keeps the top K (score, index) pairs passed to it.

    Candidates are offered via operator()(index, score). The retained pairs
    live in `results` as a binary heap ordered by `heap_order`
    (std::greater), i.e. a min-heap on (score, index): results.front() is
    always the lowest-ranked pair currently kept, so it can be evicted in
    O(log K). `results` is heap-ordered, not sorted; callers that need the
    pairs in descending score order must sort it themselves.
*/
template <typename Index, typename Value> struct TopK {
  /** K is the maximum number of pairs to retain; K == 0 retains nothing. */
  explicit TopK(size_t K) : K(K) {}

  /** Offers (index, score). The pair is kept if fewer than K pairs are held,
      or if score is strictly greater than the lowest retained score (which
      is then evicted). Ties with the current minimum are rejected. */
  void operator()(Index index, Value score) {
    if (results.size() < K) {
      results.emplace_back(score, index);
    } else if (!results.empty() && score > results.front().first) {
      // pop_heap moves the current minimum to the back; overwrite that slot
      // with the new pair instead of doing a pop_back/push_back round trip.
      std::pop_heap(results.begin(), results.end(), heap_order);
      results.back() = std::make_pair(score, index);
    } else {
      // Rejected. The empty() guard above matters when K == 0: results stays
      // empty, and reading results.front() there would be out of bounds.
      return;
    }
    std::push_heap(results.begin(), results.end(), heap_order);
  }

  size_t K;
  std::vector<std::pair<Value, Index>> results;
  std::greater<std::pair<Value, Index>> heap_order;
};
/** A utility class to multiply rows of a sparse matrix.

    Implements the sparse accumulator described in the paper
    'Sparse Matrix Multiplication Package (SMMP)'
    http://www.i2m.univ-amu.fr/~bradji/multp_sparse.pdf

    Partial products are summed into the dense `sums` vector, while the set
    of touched indices is tracked as a singly linked list threaded through
    `nonzeros`: nonzeros[i] == -1 means "i has not been touched since the
    last foreach()", otherwise it holds the next touched index, and -2
    terminates the list. This lets foreach() visit and reset only the
    touched entries instead of scanning all item_count slots.
*/
template <typename Index, typename Value> class SparseMatrixMultiplier {
 public:
  /** Allocates zeroed accumulator state for indices in [0, item_count). */
  explicit SparseMatrixMultiplier(Index item_count)
      : sums(item_count, 0), nonzeros(item_count, -1), head(-2), length(0) {}

  /** Adds value to the item at index.
      index must be in [0, item_count); it is not bounds-checked. */
  void add(Index index, Value value) {
    sums[index] += value;
    if (nonzeros[index] == -1) {
      // First touch of this index: push it onto the front of the list.
      nonzeros[index] = head;
      head = index;
      length += 1;
    }
  }

  /** Calls f(index, sum) once per touched entry, then clears state for the
      next run.

      Entries are visited in reverse order of first touch (the list is
      built by pushing to the front). An entry is visited if add() touched
      it, even if its accumulated sum cancelled out to zero. */
  template <typename Function> void foreach (Function &f) {  // NOLINT(*)
    // The counter uses Index rather than int so it cannot overflow, or mix
    // signedness with `length`, when Index is a 64-bit or unsigned type.
    for (Index remaining = length; remaining > 0; --remaining) {
      const Index index = head;
      f(index, sums[index]);
      // Advance the list before unlinking, then reset the slot for reuse.
      head = nonzeros[index];
      sums[index] = 0;
      nonzeros[index] = -1;
    }
    length = 0;
    head = -2;
  }

  /** Number of distinct indices touched since the last foreach(). */
  Index nnz() const { return length; }

  std::vector<Value> sums;

 protected:
  std::vector<Index> nonzeros;
  Index head, length;
};
} // namespace implicit
#endif // IMPLICIT_NEAREST_NEIGHBOURS_H_