/* Copyright (c) Microsoft Corporation.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tfdml/core/dml_operator_helper.h"
#include "tfdml/runtime_adapter/bcast.h"
#include "tfdml/runtime_adapter/macros.h"
#include "tfdml/runtime_adapter/op_kernel_context.h"
#include "tfdml/runtime_adapter/status.h"
namespace tfdml
{
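
// Computes the shape produced by broadcasting two shapes together under
// numpy-style broadcasting rules: trailing dimensions are aligned, and each
// aligned pair of dimensions must either match or contain a 1 (which
// stretches to match the other). For example, broadcasting [2, 3, 1]
// against [3, 4] yields [2, 3, 4].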
static TensorShape BroadcastTensorShape(
    const TensorShape& input_shape_0,
    const TensorShape& input_shape_1)
{
    if (input_shape_0 == input_shape_1)
    {
        return input_shape_0;
    }

    const auto output_rank =
        std::max(input_shape_0.dims(), input_shape_1.dims());

    TensorShape output_shape;
    for (int i = 0; i < output_rank; ++i)
    {
        output_shape.AddDim(0);
    }

    // Walk backwards through both input shapes and broadcast each dimension
    int in_dim_0_idx = input_shape_0.dims() - 1;
    int in_dim_1_idx = input_shape_1.dims() - 1;

    for (int out_dim_idx = output_rank - 1; out_dim_idx >= 0; --out_dim_idx)
    {
        int64_t in_dim_0 = 1;
        if (in_dim_0_idx >= 0)
        {
            in_dim_0 = input_shape_0.dim_size(in_dim_0_idx);
            --in_dim_0_idx;
        }

        int64_t in_dim_1 = 1;
        if (in_dim_1_idx >= 0)
        {
            in_dim_1 = input_shape_1.dim_size(in_dim_1_idx);
            --in_dim_1_idx;
        }

        CHECK((in_dim_0 == in_dim_1) || (in_dim_0 == 1) || (in_dim_1 == 1));

        int64_t broadcasted_dim = std::max(in_dim_0, in_dim_1);
        CHECK(broadcasted_dim >= 0);

        // Special case - you can't broadcast a zero dimension (the dimension
        // stays zero)
        if (in_dim_0 == 0 || in_dim_1 == 0)
        {
            broadcasted_dim = 0;
        }

        output_shape.set_dim(out_dim_idx, broadcasted_dim);
    }

    return output_shape;
}
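
// Folds BroadcastTensorShape over a list of shapes to produce the single
// shape that all inputs broadcast to. Note the loop starts at shapes[0], so
// the first iteration broadcasts the accumulator against itself; this is
// harmless because identical shapes take the early-return path above.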
TensorShape BroadcastTensorShapes(absl::Span<const TensorShape> shapes)
{
    CHECK(!shapes.empty());

    TensorShape accumulated_shape = shapes[0];
    for (const TensorShape& shape : shapes)
    {
        accumulated_shape = BroadcastTensorShape(accumulated_shape, shape);
    }

    return accumulated_shape;
}
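
// Validates that all of the kernel's input shapes are mutually
// broadcast-compatible and accumulates their combined output shape. BCast
// is used here only for validation and shape computation, so the fewer-dims
// optimization (which coalesces adjacent dimensions to speed up the actual
// broadcast) is disabled to preserve the output shape's full rank.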
BroadcastedOutputShapeInitHelper::BroadcastedOutputShapeInitHelper(
    OpKernelContext* ctx,
    std::shared_ptr<const Attributes> attr)
{
    constexpr bool fewer_dims_optimization = false;

    for (int i = 0; i < ctx->num_inputs(); ++i)
    {
        TensorShape input_shape = ctx->input(i).shape();

        BCast bcast_helper(
            BCast::FromShape(broadcasted_shape_),
            BCast::FromShape(input_shape),
            fewer_dims_optimization);

        OP_REQUIRES(
            ctx,
            bcast_helper.IsValid(),
            errors::InvalidArgument(
                "Incompatible shapes: ",
                broadcasted_shape_.DebugString(),
                " vs. ",
                input_shape.DebugString()));

        broadcasted_shape_ = BCast::ToShape(bcast_helper.output_shape());
    }
}
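
// The broadcasted output shape was already computed during initialization,
// so this helper simply forwards it.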
std::vector<TensorShape> GetBroadcastedOutputShapeHelper::GetOutputShapes(
    OpKernelContext* ctx,
    const InitializationHelper* initialization_helper) const
{
    auto init_helper = static_cast<const InitHelper*>(initialization_helper);
    return {init_helper->GetBroadcastedShape()};
}
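
// Output shapes for the FusedBatchNorm family. For example, with an NHWC
// input of shape [N, H, W, C], input(1) is the [C]-shaped scale tensor, so
// the outputs are the normalized [N, H, W, C] tensor followed by four
// [C]-shaped statistics tensors (plus one extra unused output for V3).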
std::vector<TensorShape> BatchNormShapeHelper::GetOutputShapes(
    OpKernelContext* ctx,
    const InitializationHelper* initialization_helper) const
{
    // _FusedBatchNormEx can have 6 inputs
    CHECK(ctx->num_inputs() == 5 || ctx->num_inputs() == 6);
    CHECK(ctx->num_outputs() == 5 || ctx->num_outputs() == 6);

    if (ctx->num_outputs() == 5)
    {
        // FusedBatchNorm/FusedBatchNormV2 case
        // The shape of the normalized output matches the input tensor, and
        // the computed/saved mean/variance tensors match the shape of the
        // scale tensor (which is 1D, and the same size as the input tensor's
        // C dimension)
        return {
            ctx->input(0).shape(),
            ctx->input(1).shape(),
            ctx->input(1).shape(),
            ctx->input(1).shape(),
            ctx->input(1).shape(),
        };
    }
    else
    {
        // FusedBatchNormV3 has an additional output tensor (which we don't
        // actually use, so give it an empty shape)
        return {
            ctx->input(0).shape(),
            ctx->input(1).shape(),
            ctx->input(1).shape(),
            ctx->input(1).shape(),
            ctx->input(1).shape(),
            TensorShape(),
        };
    }
}
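
// Output shapes for the FusedBatchNormGrad family: x_backprop matches the
// shape of x (input 0), while scale_backprop and offset_backprop match the
// 1D scale tensor (input 2). The last two outputs are unused placeholders.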
std::vector<TensorShape> BatchNormGradShapeHelper::GetOutputShapes(
    OpKernelContext* ctx,
    const InitializationHelper* initialization_helper) const
{
    CHECK(ctx->num_inputs() == 5 || ctx->num_inputs() == 6);
    CHECK(ctx->num_outputs() == 5);

    const Tensor x = ctx->input(0);
    const Tensor scale = ctx->input(2);

    const TensorShape& x_shape = x.shape();
    const TensorShape& scale_shape = scale.shape();

    // x_backprop, scale_backprop, offset_backprop, unused, unused
    // scale_backprop and offset_backprop are both 1D and have the same shape.
    return {
        x_shape,
        scale_shape,
        scale_shape,
        TensorShape(),
        TensorShape(),
    };
}
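
// Reshapes `orig` into exactly `num_out_dims` dimensions by keeping the
// first num_out_dims - 1 dimensions and collapsing all remaining dimensions
// into the last one (mirroring TensorFlow's Tensor::flat_outer_dims). If
// `orig` has fewer dimensions than requested, the extra dimensions are
// filled with 1. For example, [2, 3, 4, 5] with num_out_dims = 2 yields
// [2, 60], and [5] with num_out_dims = 3 yields [5, 1, 1].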
TensorShape ComputeFlatOuterDims(const TensorShape& orig, int64_t num_out_dims)
{
    TensorShape out_dims;

    for (int64_t out_dim = 0; out_dim < num_out_dims - 1; ++out_dim)
    {
        int64_t new_dim_size =
            out_dim >= orig.dims() ? 1 : orig.dim_size(out_dim);
        out_dims.AddDim(new_dim_size);
    }

    int64_t last_dim_size = 1;
    for (int64_t in_dim = num_out_dims - 1; in_dim < orig.dims(); ++in_dim)
    {
        last_dim_size *= orig.dim_size(in_dim);
    }
    out_dims.AddDim(last_dim_size);

    return out_dims;
}

} // namespace tfdml