forked from serizba/cppflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdatatype.h
More file actions
119 lines (110 loc) · 3.25 KB
/
datatype.h
File metadata and controls
119 lines (110 loc) · 3.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
//
// Created by serizba on 12/7/20.
//
#ifndef CPPFLOW2_DATATYPE_H
#define CPPFLOW2_DATATYPE_H
#include <cstdint>
#include <ostream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeinfo>
namespace cppflow {
// Short alias for TF_DataType used by the functions in this header.
// NOTE(review): TF_DataType is not declared by any include visible in this
// file; presumably a TensorFlow C API header must be included first - confirm.
using datatype = TF_DataType;
/**
 * Converts a data type enumerator into its textual name.
 *
 * @param dt Data type to describe
 * @return The enumerator's identifier as text (e.g. "TF_FLOAT"), or
 *         "DATATYPE_NOT_KNOWN" when dt is not one of the values listed below
 */
inline std::string to_string(datatype dt) {
    // Pairs each recognised enumerator with the text returned for it.
    struct named_type {
        datatype value;
        const char* text;
    };
    static const named_type known_types[] = {
        {TF_FLOAT, "TF_FLOAT"},
        {TF_DOUBLE, "TF_DOUBLE"},
        {TF_INT32, "TF_INT32"},
        {TF_UINT8, "TF_UINT8"},
        {TF_INT16, "TF_INT16"},
        {TF_INT8, "TF_INT8"},
        {TF_STRING, "TF_STRING"},
        {TF_COMPLEX64, "TF_COMPLEX64"},
        {TF_INT64, "TF_INT64"},
        {TF_BOOL, "TF_BOOL"},
        {TF_QINT8, "TF_QINT8"},
        {TF_QUINT8, "TF_QUINT8"},
        {TF_QINT32, "TF_QINT32"},
        {TF_BFLOAT16, "TF_BFLOAT16"},
        {TF_QINT16, "TF_QINT16"},
        {TF_QUINT16, "TF_QUINT16"},
        {TF_UINT16, "TF_UINT16"},
        {TF_COMPLEX128, "TF_COMPLEX128"},
        {TF_HALF, "TF_HALF"},
        {TF_RESOURCE, "TF_RESOURCE"},
        {TF_VARIANT, "TF_VARIANT"},
        {TF_UINT32, "TF_UINT32"},
        {TF_UINT64, "TF_UINT64"},
    };

    for (const named_type& candidate : known_types) {
        if (candidate.value == dt) {
            return candidate.text;
        }
    }
    // No table entry matched.
    return "DATATYPE_NOT_KNOWN";
}
/**
 * Maps a C++ type onto the matching TensorFlow data type.
 *
 * The match is exact (std::is_same), so cv-qualified or reference types are
 * not recognised.
 *
 * @tparam T C++ type to look up
 * @return The TF_DataType enumerator associated with T
 * @throws std::runtime_error if T has no mapping; the message contains
 *         typeid(T).name()
 */
template<typename T>
TF_DataType deduce_tf_type() {
    if (std::is_same<T, float>::value)
        return TF_FLOAT;
    if (std::is_same<T, double>::value)
        return TF_DOUBLE;
    if (std::is_same<T, std::int32_t>::value)
        return TF_INT32;
    if (std::is_same<T, std::uint8_t>::value)
        return TF_UINT8;
    if (std::is_same<T, std::int16_t>::value)
        return TF_INT16;
    if (std::is_same<T, std::int8_t>::value)
        return TF_INT8;
    if (std::is_same<T, std::int64_t>::value)
        return TF_INT64;
    // bool is a distinct type from every integer type checked here, so it
    // needs its own entry to be deduced as TF_BOOL.
    if (std::is_same<T, bool>::value)
        return TF_BOOL;
    // Legacy mapping kept for compatibility. It is only reachable where
    // std::uint8_t is not an alias of unsigned char; otherwise the TF_UINT8
    // check above has already returned.
    if (std::is_same<T, unsigned char>::value)
        return TF_BOOL;
    if (std::is_same<T, std::uint16_t>::value)
        return TF_UINT16;
    if (std::is_same<T, std::uint32_t>::value)
        return TF_UINT32;
    if (std::is_same<T, std::uint64_t>::value)
        return TF_UINT64;
    // typeid(T).name() is implementation-defined; for gcc, decode the reported
    // name with `c++filt --type <name>`.
    throw std::runtime_error{"Could not deduce type! type_name: " + std::string(typeid(T).name())};
}
/**
 * Stream insertion for data types: writes the text produced by to_string(dt).
 *
 * @param os Output stream to write to
 * @param dt Data type to print
 * @return os, so that insertions can be chained
 */
inline std::ostream& operator<<(std::ostream& os, datatype dt) {
    return os << to_string(dt);
}
}
#endif //CPPFLOW2_DATATYPE_H