Skip to content

Commit

Permalink
MNT: avoid code interpolation code repetition.
Browse files Browse the repository at this point in the history
  • Loading branch information
phmbressan committed Dec 19, 2023
1 parent 34ed384 commit 4ed16be
Showing 1 changed file with 43 additions and 50 deletions.
93 changes: 43 additions & 50 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,31 +484,7 @@ def get_value_opt(x):
elif self.__interpolation__ == "shepard":
# change the function's name to avoid mypy's error
def get_value_opt_multiple(*args):
x_data = self.source[:, 0:-1] # Support for N-Dimensions
y_data = self.source[:, -1]

arg_stack = np.column_stack(args)
arg_qty, arg_dim = arg_stack.shape
result = np.zeros(arg_qty)

# Reshape to vectorize calculations
x = arg_stack.reshape(arg_qty, 1, arg_dim)

sub_matrix = x_data - x
distances_squared = np.sum(sub_matrix**2, axis=2)

# Remove zero distances from further calculations
zero_distances = np.where(distances_squared == 0)
valid_indexes = np.ones(arg_qty, dtype=bool)
valid_indexes[zero_distances[0]] = False

weights = distances_squared[valid_indexes] ** (-1.5)
numerator_sum = np.sum(y_data * weights, axis=1)
denominator_sum = np.sum(weights, axis=1)
result[valid_indexes] = numerator_sum / denominator_sum
result[~valid_indexes] = y_data[zero_distances[1]]

return result if len(result) > 1 else result[0]
return self.__interpolate_shepard__(args)

get_value_opt = get_value_opt_multiple

Expand Down Expand Up @@ -880,31 +856,7 @@ def get_value(self, *args):

# Returns value for shepard interpolation
elif self.__interpolation__ == "shepard":
x_data = self.source[:, 0:-1] # Support for N-Dimensions
y_data = self.source[:, -1]

arg_stack = np.column_stack(args)
arg_qty, arg_dim = arg_stack.shape
result = np.zeros(arg_qty)

# Reshape to vectorize calculations
x = arg_stack.reshape(arg_qty, 1, arg_dim)

sub_matrix = x_data - x
distances_squared = np.sum(sub_matrix**2, axis=2)

# Remove zero distances from further calculations
zero_distances = np.where(distances_squared == 0)
valid_indexes = np.ones(arg_qty, dtype=bool)
valid_indexes[zero_distances[0]] = False

weights = distances_squared[valid_indexes] ** (-1.5)
numerator_sum = np.sum(y_data * weights, axis=1)
denominator_sum = np.sum(weights, axis=1)
result[valid_indexes] = numerator_sum / denominator_sum
result[~valid_indexes] = y_data[zero_distances[1]]

return result if len(result) > 1 else result[0]
return self.__interpolate_shepard__(args)

# Returns value for polynomial interpolation function type
elif self.__interpolation__ == "polynomial":
Expand Down Expand Up @@ -1656,6 +1608,47 @@ def __interpolate_akima__(self):
coeffs[4 * i : 4 * i + 4] = np.linalg.solve(matrix, result)
self.__akima_coefficients__ = coeffs

def __interpolate_shepard__(self, args):
"""Calculates the shepard interpolation from the given arguments.
The shepard interpolation is computed by a inverse distance weighting
in a vectorized manner.
Parameters
----------
args : scalar, list
Values where the Function is to be evaluated.
Returns
-------
result : scalar, list
The result of the interpolation.
"""
x_data = self.source[:, 0:-1] # Support for N-Dimensions
y_data = self.source[:, -1]

arg_stack = np.column_stack(args)
arg_qty, arg_dim = arg_stack.shape
result = np.zeros(arg_qty)

# Reshape to vectorize calculations
x = arg_stack.reshape(arg_qty, 1, arg_dim)

sub_matrix = x_data - x
distances_squared = np.sum(sub_matrix**2, axis=2)

# Remove zero distances from further calculations
zero_distances = np.where(distances_squared == 0)
valid_indexes = np.ones(arg_qty, dtype=bool)
valid_indexes[zero_distances[0]] = False

weights = distances_squared[valid_indexes] ** (-1.5)
numerator_sum = np.sum(y_data * weights, axis=1)
denominator_sum = np.sum(weights, axis=1)
result[valid_indexes] = numerator_sum / denominator_sum
result[~valid_indexes] = y_data[zero_distances[1]]

return result if len(result) > 1 else result[0]

def __neg__(self):
"""Negates the Function object. The result has the same effect as
multiplying the Function by -1.
Expand Down

0 comments on commit 4ed16be

Please sign in to comment.