Skip to content

Commit

Permalink
Make distance estimator param name more descriptive
Browse files Browse the repository at this point in the history
  • Loading branch information
zschira committed Nov 2, 2023
1 parent a57810a commit 688b577
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/pudl/analysis/record_linkage/classify_plants_ferc1.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, plants_steam_df, metric="euclidean", penalty=100):
metric: Distance metric to use in computation.
penalty: Penalty to apply to records with the same report year.
"""
self.df = plants_steam_df
self.plants_steam_df = plants_steam_df
self.metric = metric
self.penalty = penalty

Expand All @@ -48,11 +48,14 @@ def fit(self, X, y=None, **fit_params): # noqa: N803
def transform(self, X, y=None, **fit_params): # noqa: N803
"""Compute distance between records then add penalty to records from same year."""
dist_matrix = pairwise_distances(X, metric=self.metric)
report_years = range(self.df.report_year.min(), self.df.report_year.max() + 1)
report_years = range(
self.plants_steam_df.report_year.min(),
self.plants_steam_df.report_year.max() + 1,
)
penalty_matrix = np.full(dist_matrix.shape, 0)
for yr in report_years:
# get the indices of all the record pairs that have matching report years
yr_idx = self.df[self.df.report_year == yr].index
yr_idx = self.plants_steam_df[self.plants_steam_df.report_year == yr].index
yr_match_pairs_idx = np.array(np.meshgrid(yr_idx, yr_idx)).T.reshape(-1, 2)
idx_x = yr_match_pairs_idx[:, 0]
idx_y = yr_match_pairs_idx[:, 1]
Expand Down

0 comments on commit 688b577

Please sign in to comment.