entropy and cross entropy ready for experiment
whoisjones committed Nov 3, 2023
1 parent 83347c2 commit e132131
Showing 2 changed files with 41 additions and 18 deletions.
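
The two new init strategies score every training example with the task model before the few-shot subset is drawn: cross-entropy ranks examples by the per-example loss against the gold label (CrossEntropyLoss(reduction="none")), while entropy ranks them by the entropy of the predicted class distribution. A minimal, self-contained sketch of the two scoring rules (illustrative only, not the repository's helper functions):

import torch

def per_example_scores(logits: torch.Tensor, labels: torch.Tensor):
    # Cross-entropy: loss of the gold label under the model, one value per example.
    cross_entropy = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
    # Entropy: uncertainty of the full predicted distribution, independent of the gold label.
    probs = torch.nn.functional.softmax(logits, dim=-1)
    entropy = torch.distributions.Categorical(probs=probs).entropy()
    return cross_entropy, entropy

# Example: 4 examples, 3 classes.
ce, ent = per_example_scores(torch.randn(4, 3), torch.tensor([0, 2, 1, 0]))
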
first_experiment/main.py: 5 changes (4 additions & 1 deletion)
@@ -47,7 +47,7 @@ def load_dataset_from_hub(dataset_name: str) -> DatasetDict:
parser.add_argument("--dataset", type=str, default="imdb", choices=["imdb", "rte", "qnli", "sst2", "snli"])
parser.add_argument("--tam_model", type=str, default="distilbert-base-uncased")
parser.add_argument("--embedding_model", type=str, default=None)
parser.add_argument("--init_strategy", type=str, choices=["random", "closest-to-centeroid", "furthest-to-centeroid", "expected-gradients", "certainty"], default="random")
parser.add_argument("--init_strategy", type=str, choices=["random", "closest-to-centeroid", "furthest-to-centeroid", "expected-gradients", "cross-entropy", "entropy"], default="random")
parser.add_argument("--stopping_criteria", type=str)
parser.add_argument("--dataset_size", type=int, nargs="+", default=[32, 64, 128, 256, 512, 1024, 2048, 4096, 0])
args = parser.parse_args()
@@ -57,6 +57,9 @@ def load_dataset_from_hub(dataset_name: str) -> DatasetDict:

for dataset_size in args.dataset_size:

if dataset_size == 0 and args.init_strategy != "random":
continue

dataset = deepcopy(full_dataset)

if dataset_size > 0:
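
The new guard skips subset-selection strategies for the full-dataset run (dataset_size == 0), since there is no subset to select in that case. A small sketch of the run schedule this produces, assuming the default size list and two of the new strategies:

# Which (size, strategy) runs survive the new guard; 0 is assumed to mean "use the full dataset".
sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 0]
strategies = ["random", "cross-entropy", "entropy"]

runs = [
    (size, strategy)
    for strategy in strategies
    for size in sizes
    if not (size == 0 and strategy != "random")  # mirrors the new `continue`
]
# The full-dataset run (size 0) is only paired with "random".
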
first_experiment/selection_strategies.py: 54 changes (37 additions & 17 deletions)
@@ -55,12 +55,21 @@ def select_fewshots(
dataset_size,
task_keys
)
elif args.init_strategy == "certainty":
elif args.init_strategy == "cross-entropy":
dataset = cross_entropy_selection(
args.tam_model,
full_dataset,
dataset_size,
task_keys,
args.dataset
)
elif args.init_strategy == "entropy":
dataset = entropy_selection(
args.tam_model,
full_dataset,
dataset_size,
task_keys
task_keys,
args.dataset
)
else:
raise NotImplementedError
@@ -283,33 +292,41 @@ def cross_entropy_selection(
dataset: DatasetDict,
num_total_samples: int,
task_keys: dict,
dataset_name: str,
) -> DatasetDict:
label_column = task_keys["label_column"]
id2label = dict(enumerate(dataset["train"].features[label_column].names))

cache_file_name = f"{dataset}-cross-entropy.json"
cache_file_name = f"{dataset_name}-cross-entropy.json"

if cache_file_name in os.listdir(CACHE_DIR):
with open(os.path.join(CACHE_DIR, cache_file_name), "r") as f:
entropy_tuples = json.load(f)
else:
model, tokenizer = get_classification_model_and_tokenizer(model_name_or_path)
model, tokenizer = get_classification_model_and_tokenizer(model_name_or_path, id2label=id2label)
train_loader = get_trainloader(dataset, tokenizer, task_keys)

entropy = []
criterion = torch.nn.CrossEntropyLoss(reduction="none")

with torch.no_grad():
for batch, targets in tqdm(train_loader):
outputs = model(**batch)
targets = batch.pop("labels")
entropy.extend(criterion(outputs.logits, targets).cpu().numpy().tolist())
for batch in tqdm(train_loader):
outputs = model(**{k: v.to(model.device) for k, v in batch.items()})
targets = batch.pop("labels").to(model.device)
entropy.extend(
[(l, e)
for l, e
in zip(
targets.detach().cpu().numpy().tolist(),
criterion(outputs.logits, targets).cpu().numpy().tolist()
)]
)

assert len(entropy) == len(dataset["train"])

entropy_tuples = [(i, l, e) for i, (l, e) in enumerate(zip(dataset["train"][label_column], entropy))]
entropy_tuples = [(i, l, e) for i, (l, e) in enumerate(entropy)]

with open(os.path.join(CACHE_DIR, cache_file_name, "w")) as f:
with open(os.path.join(CACHE_DIR, cache_file_name), "w") as f:
json.dump(entropy_tuples, f)

num_examples_per_class = num_total_samples // len(id2label.keys())
@@ -323,37 +340,40 @@ def cross_entropy_selection(
dataset["train"] = dataset["train"].select(selected_examples)
return dataset


def entropy_selection(
model_name_or_path: str,
dataset: DatasetDict,
num_total_samples: int,
task_keys: dict,
dataset_name: str,
) -> DatasetDict:
label_column = task_keys["label_column"]
id2label = dict(enumerate(dataset["train"].features[label_column].names))

cache_file_name = f"{dataset}-entropy.json"
cache_file_name = f"{dataset_name}-entropy.json"

if cache_file_name in os.listdir(CACHE_DIR):
with open(os.path.join(CACHE_DIR, cache_file_name), "r") as f:
entropy_tuples = json.load(f)
else:
model, tokenizer = get_classification_model_and_tokenizer(model_name_or_path)
model, tokenizer = get_classification_model_and_tokenizer(model_name_or_path, id2label=id2label)
train_loader = get_trainloader(dataset, tokenizer, task_keys)

entropy = []

with torch.no_grad():
for batch, targets in tqdm(train_loader):
outputs = model(**batch)
for batch in tqdm(train_loader):
labels = batch.pop("labels").detach().cpu().numpy().tolist()
outputs = model(**{k: v.to(model.device) for k, v in batch.items()})
dist = torch.nn.functional.softmax(outputs.logits, dim=-1)
entropy.extend(torch.distributions.Categorical(dist).entropy().cpu().numpy().tolist())
entropy.extend([(l, e) for l, e in zip(labels, torch.distributions.Categorical(dist).entropy().cpu().numpy().tolist())])

assert len(entropy) == len(dataset["train"])

entropy_tuples = [(i, l, e) for i, (l, e) in enumerate(zip(dataset["train"][label_column], entropy))]
entropy_tuples = [(i, l, e) for i, (l, e) in enumerate(entropy)]

with open(os.path.join(CACHE_DIR, cache_file_name, "w")) as f:
with open(os.path.join(CACHE_DIR, cache_file_name), "w") as f:
json.dump(entropy_tuples, f)

num_examples_per_class = num_total_samples // len(id2label.keys())
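
The diff is cut off before the per-class selection that follows num_examples_per_class. A hypothetical sketch of how such a selection could work, assuming entropy_tuples holds (index, label, score) triples; whether the highest- or lowest-scoring examples are kept is a design choice not visible in this diff:

from collections import defaultdict

def select_per_class(entropy_tuples, num_examples_per_class):
    # Group (score, index) pairs by gold label.
    by_label = defaultdict(list)
    for idx, label, score in entropy_tuples:
        by_label[label].append((score, idx))

    # Keep the top-scoring examples of each class (the ordering is an assumption).
    selected = []
    for scored in by_label.values():
        scored.sort(reverse=True)
        selected.extend(idx for _, idx in scored[:num_examples_per_class])
    return selected

# dataset["train"] = dataset["train"].select(select_per_class(entropy_tuples, num_examples_per_class))
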
