-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
36 lines (34 loc) · 1.15 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from code import BiasMeasure, FindSeed, PromptSyntheticGenerator, DPPromptSyntheticGenerator
import pandas as pd
# End-to-end driver script: score the dataset with BiasMeasure, pick seed rows
# with FindSeed, then produce synthetic rows with the two prompt-based
# generators. Everything runs at module level, in order.

# Load the OpenAI API key from the local ".secrets" file. Surrounding
# whitespace is stripped because a file saved by an editor normally ends with
# a newline, which would otherwise become part of the key.
with open(".secrets", "r", encoding="utf-8") as f:
    openai_api_key = f.read().strip()
print("Read API Key")

# Read the dataset, then shuffle every row with a fixed seed (reproducible
# order) and renumber the index from 0.
df = pd.read_csv("./data/biased/tabular/adult.csv")
df = df.sample(frac=1, random_state=42).reset_index(drop=True)

# NOTE(review): an earlier revision passed ("Maritial Status", "Single",
# "Married") to this constructor; presumably those are now defaults — confirm.
bm = BiasMeasure(openai_api_key)
regex_queries = bm.make_query(df)
print("\nRegEx Done")

scores_df, measure_df = bm.evaluate_df(df, regex_queries)
print("\nScore DF Generated")
print(measure_df)

# NOTE(review): "Maritial Status" is presumably spelled to match the dataset's
# column name — verify against the CSV before correcting the spelling.
fs = FindSeed("Maritial Status")
print("\nModel Loaded")

seed_rows, example_counterfactuals = fs.find_seeds(scores_df)
print("Examples existing in the dataset:\n", example_counterfactuals)
print("\nSeed Found")

psg = PromptSyntheticGenerator(openai_api_key)
for majority_sample in seed_rows:
    print("Majority Sample")
    print(majority_sample)
    print()
    print("Minority Sample")
    print(psg.generate_synthetic(majority_sample))
    # Only the first seed row is processed; the loop exits after one pass.
    break

print("\nDP Synthetic Data Generation")
dppsg = DPPromptSyntheticGenerator(openai_api_key, df)
dppsg.generate_new_point()