-
Notifications
You must be signed in to change notification settings - Fork 0
/
tf_utils.py
79 lines (62 loc) · 2.85 KB
/
tf_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union, Optional
import numpy as np
import tensorflow as tf
import logging
logger = logging.getLogger(__name__)
def set_tensor_by_indices_to_value(
    tensor: tf.Tensor, indices: tf.Tensor, value: Union[tf.Tensor, int, float]
):
    """
    Return a copy of `tensor` in which the positions selected by `indices` are replaced by `value`.

    Args:
        tensor (`tf.Tensor`):
            The tensor providing the entries that are kept.
        indices (`tf.Tensor`):
            The condition forwarded to `tf.where`; presumably a boolean mask compatible with `tensor` (to be
            confirmed against callers).
        value (`tf.Tensor`, `int` or `float`):
            The replacement used wherever `indices` selects a position.

    Returns:
        The result of `tf.where(indices, value, tensor)`.
    """
    # In-place item assignment is not possible on TF tensors, so the update is expressed as an element-wise select.
    updated = tf.where(condition=indices, x=value, y=tensor)
    return updated
def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
    """
    Return the shape of `tensor` as a list, using static dimensions where they are known and dynamic ones otherwise.

    Args:
        tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of.

    Returns:
        `List[int]`: The shape of the tensor as a list. For a `tf.Tensor`, every dimension whose static size is
        `None` is filled with the matching entry of `tf.shape(tensor)`. When the static shape equals
        `tf.TensorShape(None)`, the result of `tf.shape(tensor)` is returned directly instead of a list.
    """
    # NumPy arrays always expose a concrete shape tuple.
    if isinstance(tensor, np.ndarray):
        return list(tensor.shape)

    runtime_shape = tf.shape(tensor)
    known_shape = tensor.shape

    # No static information at all: fall back to the dynamic shape as is.
    if known_shape == tf.TensorShape(None):
        return runtime_shape

    dims = []
    for axis, size in enumerate(known_shape.as_list()):
        # Only index into the dynamic shape for dimensions that are statically unknown.
        dims.append(runtime_shape[axis] if size is None else size)
    return dims
def stable_softmax(
    logits: tf.Tensor, axis: Optional[int] = None, name: Optional[str] = None
) -> tf.Tensor:
    """
    Drop-in replacement for `tf.nn.softmax` intended to behave reliably with XLA on CPU. It works around the
    [following issue](https://github.com/tensorflow/tensorflow/issues/55682) and is meant to be removed once that
    issue is fixed. A small constant is added to the logits before the softmax, which leaves the result unchanged
    because `softmax(x) = softmax(x + c)` (see https://ogunlao.github.io/2020/04/26/you_dont_really_know_softmax.html).

    Args:
        logits (`tf.Tensor`):
            Must be one of the following types: half, float32, float64.
        axis (`int`, *optional*):
            The dimension softmax would be performed on. The default is -1 which indicates the last dimension.
        name (`str`, *optional*):
            A name for the operation.

    Returns:
        `tf.Tensor`:
            A Tensor. Has the same type and shape as logits.
    """
    # TODO: When the issue linked above gets sorted, add a check on TF version here and use the original function if
    # it has the fix. After we drop the support for unfixed versions, remove this function.
    # The offset is created in the dtype of `logits` so that the addition does not mix dtypes.
    offset = tf.constant(1e-9, dtype=logits.dtype)
    shifted_logits = logits + offset
    return tf.nn.softmax(logits=shifted_logits, axis=axis, name=name)