Skip to content

Commit aa647a8

Browse files
committed
refactor activation type
1 parent 97d9de7 commit aa647a8

8 files changed

Lines changed: 210 additions & 370 deletions

File tree

include/ActivationFunc.h

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,34 @@
11
#pragma once
2-
#include <memory>
3-
#include <armadillo>
2+
#include "common.h"
3+
44

55
namespace NeuralNet{
66
enum ActivationType {softmax, sigmoid, linear, tanh, ReLU};
77

8+
inline ActivationType GetActivationType(DeepLearning::NeuralNetParameter_ActivationType type){
9+
switch (type) {
10+
case DeepLearning::NeuralNetParameter_ActivationType_sigmoid:
11+
return sigmoid;
12+
break;
13+
case DeepLearning::NeuralNetParameter_ActivationType_linear:
14+
return linear;
15+
break;
16+
case DeepLearning::NeuralNetParameter_ActivationType_tanh:
17+
return tanh;
18+
break;
19+
case DeepLearning::NeuralNetParameter_ActivationType_softmax:
20+
return softmax;
21+
break;
22+
case DeepLearning::NeuralNetParameter_ActivationType_ReLU:
23+
return ReLU;
24+
break;
25+
default:
26+
std::cerr << "invalid activation type" << std::endl;
27+
exit(1);
28+
break;
29+
}
30+
}
31+
832
inline void ApplyActivation(std::shared_ptr<arma::mat> output, ActivationType actType){
933
std::shared_ptr<arma::mat> &p=output;
1034
arma::mat maxVal = arma::max(*p,0);

include/DeepLearning.pb.h

Lines changed: 63 additions & 117 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

include/MultiLayerPerceptron.h

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,34 +27,25 @@ namespace NeuralNet {
2727
void vectoriseWeight(arma::vec &x);
2828
void calLoss(std::shared_ptr<arma::mat> delta);
2929
virtual void forward();
30-
virtual void setTrainingSamples(std::shared_ptr<arma::mat> X, std::shared_ptr<arma::mat> Y);
31-
virtual void applyUpdates(std::vector<std::shared_ptr<arma::mat>>);
30+
virtual void applyUpdates(std::vector<std::shared_ptr<arma::mat>>);
3231
virtual void calGradient();
33-
virtual std::vector<std::shared_ptr<arma::mat>> netGradients();
3432
virtual double getLoss();
3533
virtual void save(std::string filename);
3634
virtual void load(std::string filename);
3735
virtual std::shared_ptr<arma::mat> netOutput() {
3836
return netOutput_;
3937
}
4038
private:
41-
DeepLearning::NeuralNetParameter neuralNetPara;
4239
int numLayers;
4340
int numInstance;
4441
bool testGrad;
4542
double error;
4643
/**the collection of Base layers*/
4744
std::vector<BaseLayer> layers;
48-
/**training data, input and label*/
49-
std::shared_ptr<arma::mat> trainingX;
50-
std::shared_ptr<arma::mat> trainingY;
5145
/* dimension parameters for each layer*/
5246
std::vector<int> dimensions;
5347
/* network output*/
5448
std::shared_ptr<arma::mat> netOutput_;
55-
/* network gradients*/
56-
std::vector<std::shared_ptr<arma::mat>> netGradVector;
57-
5849
int totalDim;
5950

6051
};

0 commit comments

Comments
 (0)