# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for slicing into a `LinearOperator`."""

import collections
import functools

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest

__all__ = ['batch_slice']


def _prefer_static_where(condition, x, y):
  args = [condition, x, y]
  constant_args = [tensor_util.constant_value(a) for a in args]
  # Do this statically.
  if all(arg is not None for arg in constant_args):
    condition_, x_, y_ = constant_args
    return np.where(condition_, x_, y_)
  return array_ops.where(condition, x, y)
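
# Illustrative sketch (assumed example, not part of the original file): when
# every argument has a statically known value, the selection is resolved
# eagerly with NumPy, e.g.
#   _prefer_static_where(True, 0, 5)  # == np.where(True, 0, 5) -> 0
# Otherwise it falls back to the graph op `array_ops.where`.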


def _broadcast_parameter_with_batch_shape(
    param, param_ndims_to_matrix_ndims, batch_shape):
  """Broadcasts `param` with the given batch shape, recursively."""
  if hasattr(param, 'batch_shape_tensor'):
    # Recursively broadcast every parameter inside the operator.
    override_dict = {}
    for name, ndims in param._experimental_parameter_ndims_to_matrix_ndims.items():  # pylint:disable=protected-access,line-too-long
      sub_param = getattr(param, name)
      override_dict[name] = nest.map_structure_up_to(
          sub_param, functools.partial(
              _broadcast_parameter_with_batch_shape,
              batch_shape=batch_shape), sub_param, ndims)
    parameters = dict(param.parameters, **override_dict)
    return type(param)(**parameters)
  base_shape = array_ops.concat(
      [batch_shape, array_ops.ones(
          [param_ndims_to_matrix_ndims], dtype=dtypes.int32)], axis=0)
  return array_ops.broadcast_to(
      param,
      array_ops.broadcast_dynamic_shape(base_shape, array_ops.shape(param)))
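
# Illustrative sketch (assumed example, not part of the original file): a
# scalar `multiplier` parameter with param_ndims_to_matrix_ndims=0 broadcast
# against batch_shape [2, 3] comes back as a Tensor of shape [2, 3]; a nested
# `LinearOperator` parameter is instead rebuilt with each of its own
# parameters broadcast the same way.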


def _sanitize_slices(slices, intended_shape, deficient_shape):
  """Restricts slices to avoid overflowing size-1 (broadcast) dimensions.

  Args:
    slices: iterable of slices received by `__getitem__`.
    intended_shape: int `Tensor` shape for which the slices were intended.
    deficient_shape: int `Tensor` shape to which the slices will be applied.
      Must have the same rank as `intended_shape`.

  Returns:
    sanitized_slices: Python `list` of slice objects.
  """
  sanitized_slices = []
  idx = 0
  for slc in slices:
    if slc is Ellipsis:  # Switch over to negative indexing.
      if idx < 0:
        raise ValueError('Found multiple `...` in slices {}'.format(slices))
      num_remaining_non_newaxis_slices = sum(
          s is not array_ops.newaxis for s in slices[
              slices.index(Ellipsis) + 1:])
      idx = -num_remaining_non_newaxis_slices
    elif slc is array_ops.newaxis:
      pass
    else:
      is_broadcast = intended_shape[idx] > deficient_shape[idx]
      if isinstance(slc, slice):
        # Slices are denoted by start:stop:step.
        start, stop, step = slc.start, slc.stop, slc.step
        if start is not None:
          start = _prefer_static_where(is_broadcast, 0, start)
        if stop is not None:
          stop = _prefer_static_where(is_broadcast, 1, stop)
        if step is not None:
          step = _prefer_static_where(is_broadcast, 1, step)
        slc = slice(start, stop, step)
      else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
        slc = _prefer_static_where(is_broadcast, 0, slc)
      idx += 1
    sanitized_slices.append(slc)
  return sanitized_slices
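
# Illustrative sketch (assumed example, not part of the original file): with
# intended_shape=[3, 4] and deficient_shape=[1, 4], a request like [1:3, :2]
# is rewritten to [0:1, :2], because the first axis is a size-1 broadcast
# dimension of the deficient shape and must not be over-indexed.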


def _slice_single_param(
    param, param_ndims_to_matrix_ndims, slices, batch_shape):
  """Slices into the batch shape of a single parameter.

  Args:
    param: The original parameter to slice; either a `Tensor` or an object
      with batch shape (LinearOperator).
    param_ndims_to_matrix_ndims: `int` number of right-most dimensions used for
      inferring matrix shape of the `LinearOperator`. For non-Tensor
      parameters, this is the number of this param's batch dimensions used by
      the matrix shape of the parent object.
    slices: iterable of slices received by `__getitem__`.
    batch_shape: The parameterized object's batch shape `Tensor`.

  Returns:
    new_param: Instance of the same type as `param`, batch-sliced according to
      `slices`.
  """
  # Broadcast the parameter to have full batch rank.
  param = _broadcast_parameter_with_batch_shape(
      param, param_ndims_to_matrix_ndims, array_ops.ones_like(batch_shape))
  if hasattr(param, 'batch_shape_tensor'):
    param_batch_shape = param.batch_shape_tensor()
  else:
    param_batch_shape = array_ops.shape(param)
  # Truncate by param_ndims_to_matrix_ndims.
  param_batch_rank = array_ops.size(param_batch_shape)
  param_batch_shape = param_batch_shape[
      :(param_batch_rank - param_ndims_to_matrix_ndims)]
  # At this point the param should have full batch rank, *unless* it's an
  # atomic object like `tfb.Identity()` incapable of having any batch rank.
  if (tensor_util.constant_value(array_ops.size(batch_shape)) != 0 and
      tensor_util.constant_value(array_ops.size(param_batch_shape)) == 0):
    return param
  param_slices = _sanitize_slices(
      slices, intended_shape=batch_shape, deficient_shape=param_batch_shape)
  # Extend `param_slices` (which represents slicing into the
  # parameter's batch shape) with the parameter's event ndims. For example, if
  # `params_ndims == 1`, then `[i, ..., j]` would become `[i, ..., j, :]`.
  if param_ndims_to_matrix_ndims > 0:
    if Ellipsis not in [
        slc for slc in slices if not tensor_util.is_tensor(slc)]:
      param_slices.append(Ellipsis)
    param_slices += [slice(None)] * param_ndims_to_matrix_ndims
  return param.__getitem__(tuple(param_slices))
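
# Illustrative sketch (assumed example, not part of the original file): a
# `diag` parameter of shape [1, 4] with param_ndims_to_matrix_ndims=1, owned
# by an operator whose batch_shape is [3], sliced with slices=(2,): the
# broadcast batch index is clamped to 0 and the matrix dimension is kept, so
# the result is diag[0, ..., :] with shape [4].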


def batch_slice(linop, params_overrides, slices):
  """Slices `linop` along its batch dimensions.

  Args:
    linop: A `LinearOperator` instance.
    params_overrides: A `dict` of parameter overrides.
    slices: A `slice` or `int` or `int` `Tensor` or `tf.newaxis` or `tuple`
      thereof (e.g. the argument of a `__getitem__` method).

  Returns:
    new_linop: A batch-sliced `LinearOperator`.
  """
  if not isinstance(slices, collections.abc.Sequence):
    slices = (slices,)
  if len(slices) == 1 and slices[0] is Ellipsis:
    override_dict = {}
  else:
    batch_shape = linop.batch_shape_tensor()
    override_dict = {}
    for param_name, param_ndims_to_matrix_ndims in linop._experimental_parameter_ndims_to_matrix_ndims.items():  # pylint:disable=protected-access,line-too-long
      param = getattr(linop, param_name)
      # These represent optional `Tensor` parameters.
      if param is not None:
        override_dict[param_name] = nest.map_structure_up_to(
            param, functools.partial(
                _slice_single_param, slices=slices, batch_shape=batch_shape),
            param, param_ndims_to_matrix_ndims)
  override_dict.update(params_overrides)
  parameters = dict(linop.parameters, **override_dict)
  return type(linop)(**parameters)
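
# Hedged usage sketch (assumed example, not part of the original file):
#   diag = tf.ones([5, 3, 4])                           # batch_shape [5, 3]
#   linop = tf.linalg.LinearOperatorDiag(diag)
#   sliced = batch_slice(linop, params_overrides={}, slices=(slice(None), 1))
#   # sliced.batch_shape == [5]; sliced.diag has shape [5, 4].
# This is the machinery typically invoked by `__getitem__`-style indexing on
# operators that support batch slicing.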