forked from hunter-packages/arrayfire
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path bagging.cpp
More file actions
151 lines (115 loc) · 4.19 KB
/
bagging.cpp
File metadata and controls
151 lines (115 loc) · 4.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
/*******************************************************
* Copyright (c) 2014, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/
#include <arrayfire.h>
#include <stdio.h>
#include <vector>
#include <string>
#include <af/util.h>
#include <math.h>
#include "mnist_common.h"
using namespace af;
// Get accuracy of the predicted results
float accuracy(const array& predicted, const array& target)
{
return 100 * count<float>(predicted == target) / target.elements();
}
// Pairwise L1 (sum of absolute differences) distances between every training
// row and every testing row. Rows are samples and columns are features.
// Returns a num_train x num_test matrix whose (r, c) entry is
//   sum over k of |train(r, k) - test(c, k)|.
// The sum is accumulated one feature column at a time.
array distance(array train, array test)
{
    const int num_train = train.dims(0);
    const int num_test  = test.dims(0);
    const int feat_len  = train.dims(1);

    array dist = constant(0, num_train, num_test);

    for (int col = 0; col < feat_len; col++) {
        // This feature as a column (num_train x 1) and as a row (1 x num_test)
        array train_col = train(span, col);
        array test_row  = test(span, col).T();

        // Expand both to num_train x num_test and add the absolute difference
        dist = dist + abs(tile(train_col, 1, num_test) - tile(test_row, num_train, 1));

        // Force evaluation every iteration so the per-feature temporaries can
        // be released instead of accumulating in an unevaluated expression
        dist.eval();
    }
    return dist;
}
// 1-nearest-neighbour classifier.
// For each test row, finds the closest training row by L1 distance (see
// distance()) and returns that training row's label.
// @param train_feats  num_train x feat_len training features, one sample per row
// @param test_feats   num_test x feat_len testing features, one sample per row
// @param train_labels labels of the training rows, indexed by training row
// @return the predicted label for each test row
array knn(const array &train_feats, const array &test_feats, const array &train_labels)
{
    // num_train x num_test matrix of distances
    array dist = distance(train_feats, test_feats);

    // Reduce explicitly along dim 0, the training axis. The default reduces
    // along the first non-singleton dimension, which would be the test axis
    // when there is only a single training row.
    array val, idx;
    min(val, idx, dist, 0);

    // Label of the nearest training row for each test row
    return train_labels(idx);
}
// Bootstrap-aggregated ("bagged") 1-NN classifier.
// Builds num_models kNN predictors. Each one uses sample_size training rows
// drawn at random with replacement. The result is a per-test-row majority vote.
// @param train_feats  num_train x feat_len training features, one sample per row
// @param test_feats   num_test x feat_len testing features, one sample per row
// @param train_labels per-training-row labels; presumably integer classes in
//                     [0, num_classes), since they index the vote table's columns
// @param num_classes  number of columns in the vote table
// @param num_models   number of bootstrap models to train and vote with
// @param sample_size  training rows drawn for each model
// @return for each test row, the column index of the class with the most votes
array bagging(array &train_feats, array &test_feats, array &train_labels,
              int num_classes, int num_models, int sample_size)
{
    const int n_train = train_feats.dims(0);
    const int n_test  = test_feats.dims(0);

    // Column m holds model m's bootstrap sample: training-row indices in [0, n_train)
    array samples = floor(randu(sample_size, num_models) * n_train);

    // votes(t, c) counts how many models predicted class c for test row t
    array votes = constant(0, n_test, num_classes);
    // Row offsets 0..n_test-1, used to turn (row, class) into a linear index
    array row_offset = seq(n_test);

    for (int m = 0; m < num_models; ++m) {
        array pick       = samples(span, m);
        array sub_feats  = lookup(train_feats, pick, 0);
        array sub_labels = train_labels(pick);

        array predicted = knn(sub_feats, test_feats, sub_labels);

        // Column-major linear index of element (test row, predicted class).
        // Every test row occurs once, so no index repeats within one model.
        array slot = predicted * n_test + row_offset;
        votes(slot) = votes(slot) + 1;
    }

    // Winning class per test row (reduction across the class columns)
    array best_count, winner;
    max(best_count, winner, votes, 1);
    return winner;
}
// Loads a fraction of MNIST, classifies the test images with the bagged
// kNN classifier, and prints the accuracy and the prediction time.
// @param console only consulted by the (currently disabled) display branch;
//                presumably requests text-only output — confirm with caller
// @param perc    percentage of the data set to load; passed to setup_mnist
//                as the fraction perc / 100
void bagging_demo(bool console, int perc)
{
    array train_images, train_labels;
    array test_images, test_labels;
    int num_train, num_test, num_classes;

    // Load the requested fraction of the MNIST data
    const float frac = static_cast<float>(perc) / 100.0;
    setup_mnist<false>(&num_classes, &num_train, &num_test,
                       train_images, test_images,
                       train_labels, test_labels, frac);

    // Flatten every image into one row: features are (samples x feature_length)
    const int feature_length = train_images.elements() / num_train;
    array train_feats = moddims(train_images, feature_length, num_train).T();
    array test_feats  = moddims(test_images,  feature_length, num_test).T();

    const int num_models  = 10;
    const int sample_size = 1000;

    timer::start();
    array res_labels = bagging(train_feats, test_feats, train_labels,
                               num_classes, num_models, sample_size);
    const double test_time = timer::stop();

    printf("Accuracy on testing data: %2.2f\n",
           accuracy(res_labels, test_labels));
    printf("Prediction time: %4.4f\n", test_time);

    // Visual output is deliberately switched off by the leading `false`
    if (false && !console) {
        display_results<false>(test_images, res_labels, test_labels.T(), 20);
    }
}
int main(int argc, char** argv)
{
int device = argc > 1 ? atoi(argv[1]) : 0;
bool console = argc > 2 ? argv[2][0] == '-' : false;
int perc = argc > 3 ? atoi(argv[3]) : 60;
try {
setDevice(device);
af::info();
bagging_demo(console, perc);
} catch (af::exception &ae) {
std::cerr << ae.what() << std::endl;
}
return 0;
}