Skip to content

Commit

Permalink
updated notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Aug 28, 2024
1 parent b377c43 commit 6f8e709
Show file tree
Hide file tree
Showing 7 changed files with 585 additions and 221 deletions.
2 changes: 1 addition & 1 deletion folktexts/classifier/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from .._utils import hash_dict, hash_function

DEFAULT_CONTEXT_SIZE = 500
DEFAULT_CONTEXT_SIZE = 600
DEFAULT_BATCH_SIZE = 16

SCORE_COL_NAME = "risk_score"
Expand Down
2 changes: 1 addition & 1 deletion folktexts/cli/eval_feature_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

DEFAULT_TASK_NAME = "ACSIncome"

DEFAULT_CONTEXT_SIZE = 500
DEFAULT_CONTEXT_SIZE = 600
DEFAULT_BATCH_SIZE = 30
DEFAULT_SEED = 42

Expand Down
2 changes: 1 addition & 1 deletion folktexts/cli/run_acs_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
DEFAULT_ACS_TASK = "ACSIncome"

DEFAULT_BATCH_SIZE = 30
DEFAULT_CONTEXT_SIZE = 500
DEFAULT_CONTEXT_SIZE = 600
DEFAULT_SEED = 42


Expand Down
27 changes: 1 addition & 26 deletions folktexts/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def show_or_save(fig, fig_name: str):

# If the group is too small of a fraction, skip (curve will be too erratic)
if len(group_indices) / len(sensitive_attribute) < group_size_threshold:
logging.warning(f"Skipping group {group_value_map(s_value)} plot as it's too small.")
logging.info(f"Skipping group {group_value_map(s_value)} plot as it's too small.")
continue

# Plot global calibration curve
Expand All @@ -257,29 +257,4 @@ def show_or_save(fig, fig_name: str):
plt.title("Calibration curve per sub-group" + model_str)
show_or_save(fig, "calibration_curve_per_subgroup")

# ###
# Plot scores distribution per group
# ###
# TODO: make a decent score-distribution plot... # TODO: try score CDFs!
# hist_bin_edges = np.histogram_bin_edges(y_pred_scores, bins=10)
# for idx, s_value in enumerate(np.unique(sensitive_attribute)):
# group_indices = np.argwhere(sensitive_attribute == s_value).flatten()
# group_y_pred_scores = y_pred_scores[group_indices]
# is_first_group = (idx == 0)
# if is_first_group:
# fig, ax = plt.subplots()
# sns.histplot(
# group_y_pred_scores,
# bins=hist_bin_edges,
# stat="density",
# kde=False,
# color=group_colors[idx],
# label=group_value_map(s_value),
# ax=ax,
# )

# plt.legend()
# plt.title("Score distribution per sub-group" + model_str)
# results["score_distribution_per_subgroup_path"] = save_fig(fig, "score_distribution_per_subgroup", imgs_dir)

return results
262 changes: 136 additions & 126 deletions notebooks/detailed-example.ipynb

Large diffs are not rendered by default.

135 changes: 69 additions & 66 deletions notebooks/minimal-example.ipynb

Large diffs are not rendered by default.

376 changes: 376 additions & 0 deletions notebooks/minimal-example_web-API-model.ipynb

Large diffs are not rendered by default.

0 comments on commit 6f8e709

Please sign in to comment.