config.yaml
name: "Sample Config File" # Name of your experiment
random_state: 42 # Random Seed
wandb_logging: False # If true, the experiment is logged in W&B
data:
  dataset: "reviews"                    # Dataset name: "reviews", "AG_news", or "movies" (the rationale dataset)
  val_size: 0.2                         # Validation split fraction
  data_name: "rt-polarity"              # Name of the saved precomputed-embeddings folder: "rt-polarity", "AG_NEWS", "movie_rationals"
  data_dir: "./src/data"                # Path to the embeddings - don't change
  compute_emb: False                    # If true, the embeddings are recomputed and saved - computationally expensive
train:
  run_training: True                    # If true, train the model; if false, only run inference
  batch_size: 128                       # Size of the training and test batches
  epochs: 100                           # Number of passes through the data
  verbose: False                        # If true, prints the training progress
model:
  name: "proto"                         # Model name: "proto" (ProtoPNet), "MLP", "gpt-2", or "bert"
  submodel: "bert"                      # If not a ProtoPNet, set this to the same value as "name"
  embed_dim: 1024                       # 1024 for sentence + bert, 768 for sentence + roberta/mpnet
  n_classes: 2                          # Number of classes to predict (4 for AG_NEWS)
  freeze_layers: True                   # If true, the encoder weights are frozen
  n_prototypes: 20                      # Make sure n_prototypes is divisible by n_classes
  similaritymeasure: "weighted_cosine"  # One of: L1, L2, weighted_L2, cosine, weighted_cosine, dot_product, learned
  embedding: "sentence"                 # "sentence" or "word" embeddings
  project: True                         # If true, project the prototypes onto the training inputs
optimizer:
  name: "Adam"                          # "SGD" or "Adam"
  lr: 0.005                             # Learning rate
  momentum: 0.9                         # Momentum
  weight_decay: 0.0005                  # Weight decay
  betas: [0.9, 0.999]                   # Adam betas
  T_0: 10                               # Warm-up epochs for cosine annealing
scheduler:
  name: "step"                          # Scheduler: "poly", "step", or "cosineannealing"
  lr_reduce_factor: 0.5                 # Learning-rate reduction factor for "step"
  patience_lr_reduce: 30                # Patience for "step"
  poly_reduce: 0.9                      # Polynomial power for "poly" reduction
loss:
  lambda1: 0.2                          # Cluster loss weight
  lambda2: 0.2                          # Separation loss weight
  lambda3: 0.3                          # Distribution loss weight
  lambda4: 0.1                          # Diversity loss weight
  lambda5: 0.001                        # L1 regularization weight
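  # Note (assumption, not taken from this file): in ProtoPNet-style training these
  # weights typically enter the objective as a weighted sum, e.g.
  #   total = cross_entropy + lambda1*cluster + lambda2*separation
  #         + lambda3*distribution + lambda4*diversity + lambda5*l1_norm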
explain:
  max_numbers: 3                        # Maximum number of prototypes shown in the explanation
  manual_input: False                   # If true, lets you input a sentence and returns its interpretation
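
# Example (illustrative only - not part of the original config): one way this file
# could be loaded and sanity-checked with PyYAML. The path and variable names below
# are assumptions, not the project's actual loading code.
#
#   import yaml
#
#   with open("config.yaml") as f:
#       cfg = yaml.safe_load(f)
#
#   # n_prototypes must split evenly across the classes (see the model section above)
#   assert cfg["model"]["n_prototypes"] % cfg["model"]["n_classes"] == 0
#
#   # the chosen similarity measure must be one of the listed options
#   allowed = {"L1", "L2", "weighted_L2", "cosine", "weighted_cosine", "dot_product", "learned"}
#   assert cfg["model"]["similaritymeasure"] in allowed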