Skip to content

Commit

Permalink
Clean code (#4)
Browse files Browse the repository at this point in the history
* Remove unused code

* Clean tests

* Clean src

---------

Co-authored-by: Juliette-Gerbaux <[email protected]>
  • Loading branch information
Juliette-Gerbaux and Juliette-Gerbaux authored Dec 10, 2024
1 parent 0e63c2b commit 85aecd3
Show file tree
Hide file tree
Showing 17 changed files with 44 additions and 290 deletions.
4 changes: 0 additions & 4 deletions src/calculate_reward_and_bellman_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,6 @@ def get_disc(
n_pts_below + n_pts_in + n_pts_above
) # Make sure total adds up
if method == "lines":
# in_curve_pts = np.linspace(lbs, ubs, xNsteps-2) # Disc-2 * R
# all_pts = np.insert(in_curve_pts, 0, empty, axis=0) # (Disc-1) * R
# all_pts = np.insert(all_pts, all_pts.shape[0], full, axis=0) # Disc * R
above_curve_pts = [
np.linspace(ubs[r], full[r], n_pts_above[r], endpoint=True)
for r in range(n_reservoirs)
Expand All @@ -183,7 +180,6 @@ def get_disc(
for r in range(n_reservoirs)
]
).T
# all_pts = np.concatenate((below_curve_pts, in_curve_pts, above_curve_pts), axis=0) # Disc * R
diffs_to_ref = all_pts[:, None] - reference_pt[None, :] # Disc * R
diffs_to_ref = (
diffs_to_ref[:, :, None] * np.eye(n_reservoirs)[None, :, :]
Expand Down
16 changes: 4 additions & 12 deletions src/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ def draw_usage_values(
closest_level
]

# z = np.maximum(np.zeros(reinterpolated_usage_values[area].T.shape), np.minimum(ub*np.ones(reinterpolated_usage_values[area].T.shape), reinterpolated_usage_values[area].T))
# z = np.maximum(np.zeros(reinterpolated_usage_values[area].T.shape) - ub, np.minimum(ub*np.ones(reinterpolated_usage_values[area].T.shape), reinterpolated_usage_values[area].T))
usage_values_plot = go.Figure(
data=[
go.Heatmap(
Expand Down Expand Up @@ -64,9 +62,7 @@ def draw_usage_values(
marker=dict(symbol="circle"),
showlegend=True,
)
for i, (area, mng) in enumerate(
multi_stock_management.dict_reservoirs.items()
)
for i, (_, _) in enumerate(multi_stock_management.dict_reservoirs.items())
]
+ [
go.Scatter(
Expand All @@ -78,9 +74,7 @@ def draw_usage_values(
line=dict(dash="dash"),
showlegend=True,
)
for i, (area, mng) in enumerate(
multi_stock_management.dict_reservoirs.items()
)
for i, (_, mng) in enumerate(multi_stock_management.dict_reservoirs.items())
]
+ [
go.Scatter(
Expand All @@ -92,9 +86,7 @@ def draw_usage_values(
line=dict(dash="dash"),
showlegend=True,
)
for i, (area, mng) in enumerate(
multi_stock_management.dict_reservoirs.items()
)
for i, (_, mng) in enumerate(multi_stock_management.dict_reservoirs.items())
],
layout=dict(title=f"Usage Values"),
)
Expand Down Expand Up @@ -168,7 +160,7 @@ def draw_uvs_sddp(
showlegend=True,
visible=(r == 0),
)
for r, res in enumerate(reservoirs)
for r, _ in enumerate(reservoirs)
]
+ [
go.Scatter(
Expand Down
11 changes: 1 addition & 10 deletions src/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,6 @@ def __init__(
costs: np.ndarray,
duals: np.ndarray,
correlations: Optional[np.ndarray] = None,
interp_mode: bool = False,
) -> None:
"""
Instanciates a LinearCostEstimator
Expand All @@ -324,9 +323,6 @@ def __init__(
costs:np.ndarray: Cost for every input,
duals:np.ndarray: Duals for every input first dimension should be the same as inputs,
"""
# self.true_controls=controls
# self.true_costs=costs
# self.true_duals=duals
self.estimators = np.array(
[
[
Expand All @@ -335,7 +331,6 @@ def __init__(
costs=costs[week, scenario],
duals=duals[week, scenario],
correlations=correlations,
# interp_mode=interp_mode,
)
for scenario in range(param.len_scenario)
]
Expand Down Expand Up @@ -389,7 +384,6 @@ def update(
inputs: np.ndarray,
costs: np.ndarray,
duals: np.ndarray,
interp_mode: bool = False,
) -> None:
"""
Updates the parameters of the Linear Interpolators
Expand All @@ -409,10 +403,9 @@ def update(
inputs=inputs,
costs=costs,
duals=duals,
# interp_mode=interp_mode,
)

def enrich_estimator(self, n_splits: int = 3) -> None:
def enrich_estimator(self) -> None:
"""
Adds 'mid_cuts' to our cost estimator to smoothen the curves and (hopefully) accelerate convergence
Expand Down Expand Up @@ -441,8 +434,6 @@ def cleanup_approximations(
for week_estimators in self.estimators:
for estimator in week_estimators:
estimator.remove_incoherence()
# controls=true_controls[week, scenario],
# real_costs=true_costs[week, scenario])

def remove_redundants(
self,
Expand Down
3 changes: 1 addition & 2 deletions src/functions_iterative.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def itr_control(
) = init_iterative_calculation(param, reservoir_management, output_path, X, solver)
i = 0

while (gap >= tol_gap and gap >= 0) and i < N: # and (i<3):
while (gap >= tol_gap and gap >= 0) and i < N:
debut = time()

initial_x, controls = compute_x_multi_scenario(
Expand Down Expand Up @@ -391,7 +391,6 @@ def init_iterative_calculation(
stock_discretization=X,
)

i = 0
gap = 1e3
fin = time()
tot_t.append(fin - debut)
Expand Down
7 changes: 0 additions & 7 deletions src/hyperplane_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,11 @@ def BezierAv(
if last_axify:
p[:-1] = (p0[:-1] + p1[:-1]) / 2
p[-1] = p0[-1] * diff_rate + p1[-1] * (1 - diff_rate)
# v_sum = np.sum(v[:-1])
# v[-1] /= v_sum
v = v0 * (t) + v1 * (1 - t)
v[-1] = v0[-1] * (t * (1 - partage) + partage * diff_rate) + v1[-1] * (
(1 - t) * (1 - partage) + partage * (1 - diff_rate)
)
elif norm(v) > 1e-12:
# v = v / sum(v[:-1])
v = v / norm(v) * (n0 + n1) / 2
return np.array([p, v])

Expand Down Expand Up @@ -125,10 +122,6 @@ def interpolate_between(
return mlr, data


mlrs: list = []
datas: list = []


def enrich_by_interpolation(
controls_init: np.ndarray, costs: np.ndarray, slopes: np.ndarray, n_splits: int = 3
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
Expand Down
4 changes: 2 additions & 2 deletions src/launch_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def calculate_bellman_values(

elif method == "precalculated":
# or with precalulated reward
vb, G = calculate_bellman_value_with_precalculated_reward(
vb, _ = calculate_bellman_value_with_precalculated_reward(
len_controls=len_controls,
param=param,
reservoir_management=reservoir_management,
Expand All @@ -61,7 +61,7 @@ def calculate_bellman_values(

elif method == "iterative":
# or with iterative algorithm
vb, G, _, _, controls_upper, traj = itr_control(
vb, _, _, _, _, _ = itr_control(
param=param,
reservoir_management=reservoir_management,
output_path=output_path,
Expand Down
Loading

0 comments on commit 85aecd3

Please sign in to comment.