forked from tensorflow/tfjs
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtypes_utils.ts
More file actions
81 lines (75 loc) · 2.28 KB
/
types_utils.ts
File metadata and controls
81 lines (75 loc) · 2.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/* Original source: utils/generic_utils.py */
import {Tensor} from '@tensorflow/tfjs-core';
import {ValueError} from '../errors';
import {Shape} from '../keras_format/common';
// tslint:enable
/**
 * Checks whether `x` is an Array of Shapes (a nested array) rather than a
 * single Shape.
 *
 * Only the first element is inspected: `x` counts as an Array of Shapes when
 * it is an Array whose first element is itself an Array. Consequently an empty
 * Array yields `false`.
 *
 * @param x A single Shape or an Array of Shapes.
 * @returns `true` if `x` looks like an Array of Shapes, otherwise `false`.
 */
export function isArrayOfShapes(x: Shape|Shape[]): boolean {
  if (!Array.isArray(x)) {
    return false;
  }
  const head = x[0];
  return Array.isArray(head);
}
/**
 * Normalizes a Shape or an Array of Shapes into an Array of Shapes.
 *
 * - An empty input yields a new empty Array. Note that this also applies when
 *   the empty input was meant as a single Shape.
 * - An input whose first element is not an Array is treated as a single Shape
 *   and wrapped in a one-element Array.
 * - Otherwise the input is already an Array of Shapes and is returned as-is
 *   (the same Array object, not a copy).
 *
 * @param x A shape or list of shapes to normalize into a list of Shapes.
 * @return A list of Shapes.
 */
export function normalizeShapeList(x: Shape|Shape[]): Shape[] {
  const isEmpty = x.length === 0;
  if (isEmpty) {
    return [];
  }
  const alreadyNested = Array.isArray(x[0]);
  return alreadyNested ? (x as Shape[]) : ([x] as Shape[]);
}
/**
 * Returns exactly one Tensor from either a single Tensor or an Array of them.
 *
 * @param xs A single `tf.Tensor` or an `Array` of `tf.Tensor`s.
 * @return `xs` itself when it is not an `Array`; otherwise the sole element of
 *   the `Array`.
 * @throws ValueError If `xs` is an `Array` whose length is not exactly 1.
 */
export function getExactlyOneTensor(xs: Tensor|Tensor[]): Tensor {
  if (!Array.isArray(xs)) {
    return xs;
  }
  if (xs.length !== 1) {
    throw new ValueError(`Expected Tensor length to be 1; got ${xs.length}`);
  }
  return xs[0];
}
/**
 * Helper function to obtain exactly one instance of Shape.
 *
 * The input is treated as an `Array` of `Shape`s when it is an `Array` whose
 * first element is itself an `Array`. Anything else, including an empty
 * `Array`, is treated as a single `Shape`.
 *
 * @param shapes Input single `Shape` or Array of `Shape`s.
 * @returns If input is a single `Shape`, return it unchanged. If the input is
 *   an `Array` containing exactly one instance of `Shape`, return the instance.
 *   Otherwise, throw a `ValueError`.
 * @throws ValueError: If input is an `Array` of `Shape`s, and its length is not
 *   1.
 */
export function getExactlyOneShape(shapes: Shape|Shape[]): Shape {
  if (!(Array.isArray(shapes) && Array.isArray(shapes[0]))) {
    // Not a nested array, so `shapes` is a single Shape; return it unchanged.
    return shapes as Shape;
  }
  // Use a separate local instead of reassigning the parameter.
  const shapeList = shapes as Shape[];
  if (shapeList.length !== 1) {
    throw new ValueError(`Expected exactly 1 Shape; got ${shapeList.length}`);
  }
  return shapeList[0];
}