forked from tensorflow/tfjs
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmath_utils.ts
More file actions
146 lines (132 loc) · 3.79 KB
/
math_utils.ts
File metadata and controls
146 lines (132 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Math utility functions.
*
* This file contains some frequently used math functions that operate on
* number[] or Float32Array and return a number. Many of these functions are
* not-so-thick wrappers around TF.js Core functions. But they offer the
* convenience of
* 1) not having to convert the inputs into Tensors,
* 2) not having to convert the returned Tensors to numbers.
*/
import * as tfc from '@tensorflow/tfjs-core';
import {scalar, Tensor1D, tensor1d} from '@tensorflow/tfjs-core';
import {ValueError} from '../errors';
/**
 * Typed-array types that array utilities in this file (e.g. `arrayProd`)
 * accept in addition to a plain `number[]`.
 */
export type ArrayTypes = Uint8Array | Int32Array | Float32Array;
/**
 * Determine whether a number is an integer.
 *
 * The check is purely numeric. Round-tripping through a string with
 * `parseInt(x.toString(), 10)` mis-parses values whose string form uses
 * exponent notation (e.g. `1e21` stringifies to "1e+21", which parses as 1),
 * so large integral values would be wrongly reported as non-integers.
 *
 * @param x The number to test.
 * @return `true` iff `x` is finite and has no fractional part. `NaN` and
 *     positive/negative `Infinity` yield `false`.
 */
export function isInteger(x: number): boolean {
  // `isFinite` rejects NaN and +/-Infinity before the fractional-part check.
  return isFinite(x) && Math.floor(x) === x;
}
/**
 * Multiply together the elements of an array over an index window.
 * @param array The values to multiply.
 * @param begin First index included in the product. Defaults to 0 when
 *     omitted (or null).
 * @param end Index at which to stop, not itself included. Defaults to
 *     `array.length` when omitted (or null).
 * @return The product of `array[begin] ... array[end - 1]`; 1 when the window
 *     is empty.
 */
export function arrayProd(
    array: number[] | ArrayTypes, begin?: number, end?: number): number {
  const start = begin == null ? 0 : begin;
  const stop = end == null ? array.length : end;

  let product = 1;
  let idx = start;
  while (idx < stop) {
    product *= array[idx];
    idx++;
  }
  return product;
}
/**
 * Helper that wraps either supported input type in a Tensor1D, so the result
 * can be fed directly into TF.js Core functions.
 * @param array Values to wrap. A `number[]` is first copied into a new
 *     `Float32Array`; a `Float32Array` is passed to `tensor1d` as is.
 * @return The Tensor1D produced by `tensor1d`.
 */
function toArray1D(array: number[] | Float32Array): Tensor1D {
  if (Array.isArray(array)) {
    return tensor1d(new Float32Array(array));
  }
  return tensor1d(array);
}
/**
 * Compute the minimum value of an array.
 *
 * The computation runs inside `tfc.tidy` so that the temporary input tensor
 * and the reduction result tensor are disposed once the value has been read
 * back, instead of being leaked on every call.
 *
 * @param array The values to reduce.
 * @return The minimum value, read synchronously from the result tensor.
 */
export function min(array: number[] | Float32Array): number {
  return tfc.tidy(() => tfc.min(toArray1D(array)).dataSync()[0]);
}
/**
 * Compute the maximum value of an array.
 *
 * The computation runs inside `tfc.tidy` so that the temporary input tensor
 * and the reduction result tensor are disposed once the value has been read
 * back, instead of being leaked on every call.
 *
 * @param array The values to reduce.
 * @return The maximum value, read synchronously from the result tensor.
 */
export function max(array: number[] | Float32Array): number {
  return tfc.tidy(() => tfc.max(toArray1D(array)).dataSync()[0]);
}
/**
 * Compute the sum of an array.
 *
 * The computation runs inside `tfc.tidy` so that the temporary input tensor
 * and the reduction result tensor are disposed once the value has been read
 * back, instead of being leaked on every call.
 *
 * @param array The values to add up.
 * @return The sum, read synchronously from the result tensor.
 */
export function sum(array: number[] | Float32Array): number {
  return tfc.tidy(() => tfc.sum(toArray1D(array)).dataSync()[0]);
}
/**
 * Compute the arithmetic mean of an array.
 * @param array The values to average.
 * @return `sum(array)` divided by the number of elements.
 */
export function mean(array: number[] | Float32Array): number {
  const total = sum(array);
  const count = array.length;
  return total / count;
}
/**
 * Compute the variance of an array.
 *
 * This is the population variance: the sum of squared deviations from the
 * mean is divided by the number of elements (not by `length - 1`).
 *
 * All tensor work runs inside `tfc.tidy` so that the input tensor, the mean
 * scalar and every intermediate tensor are disposed after the result has been
 * read back, instead of being leaked on every call.
 *
 * @param array The values whose variance is computed.
 * @return The variance.
 */
export function variance(array: number[] | Float32Array): number {
  const sumSquare = tfc.tidy(() => {
    const demeaned = tfc.sub(toArray1D(array), scalar(mean(array)));
    return tfc.sum(tfc.mulStrict(demeaned, demeaned)).dataSync()[0];
  });
  return sumSquare / array.length;
}
/**
 * Compute the median of an array.
 *
 * The input is left untouched; a copy is sorted in ascending numeric order.
 *
 * @param array The values to take the median of.
 * @return The middle element for an odd number of elements, or the average
 *     of the two middle elements for an even number.
 */
export function median(array: number[] | Float32Array): number {
  const sorted = array.slice().sort((lhs, rhs) => lhs - rhs);
  const count = sorted.length;
  const upperMid = Math.floor(count / 2);

  if (count % 2 === 1) {
    // Odd count: there is exactly one middle element.
    return sorted[upperMid];
  }
  // Even count: average the two elements straddling the middle.
  return (sorted[upperMid - 1] + sorted[upperMid]) / 2;
}
/**
 * Build the array `begin, begin + 1, ...` of values strictly below `end`.
 * @param begin First value, inclusive.
 * @param end Upper bound, exclusive.
 * @returns The generated values; empty when `begin` equals `end`.
 * @throws ValueError, iff `end` < `begin`.
 */
export function range(begin: number, end: number): number[] {
  if (end < begin) {
    throw new ValueError(`end (${end}) < begin (${begin}) is forbidden.`);
  }

  const values: number[] = [];
  let current = begin;
  while (current < end) {
    values.push(current);
    current++;
  }
  return values;
}