Pseudo labeling #138

Open · wants to merge 71 commits into main

Commits (71)
13bb6cc Nutella, yum (mbazzani, Aug 4, 2023)
9f80403 Add feature extraction, specific to efficientnet (mbazzani, Aug 5, 2023)
8b075f8 Pseudo labelling initial commit (benjamin-cates, Aug 5, 2023)
59a1eaf Fix logger printing twice??? idfk (benjamin-cates, Aug 6, 2023)
f69cebc Fix borked wandb logging (benjamin-cates, Aug 6, 2023)
bf4c816 Move pseudolabel init from run_raw to main (benjamin-cates, Aug 6, 2023)
982d279 Fix class_list bug (benjamin-cates, Aug 7, 2023)
8da2ddd Only read audio files in pseudo directory (benjamin-cates, Aug 7, 2023)
567f555 Fix get_test_map checkpoint (benjamin-cates, Aug 7, 2023)
2a125f8 Merge remote-tracking branch 'origin/main' into nutella (mbazzani, Aug 7, 2023)
74d721d Merge branch 'nutella' into pseudo (mbazzani, Aug 7, 2023)
0cc5508 Pylint fixes 🥰 (benjamin-cates, Aug 7, 2023)
97f0204 Change pseudolabeling to finetuning (mbazzani, Aug 7, 2023)
9b7b1ec Merge remote-tracking branch 'refs/remotes/origin/pseudo' into pseudo (mbazzani, Aug 7, 2023)
ae6b73d Notela project (benjamin-cates, Aug 7, 2023)
d08c782 Add epsilon to mixup (mbazzani, Aug 7, 2023)
0c8654f Bug fixes (mbazzani, Aug 7, 2023)
c59c68a Merge remote-tracking branch 'refs/remotes/origin/pseudo' into pseudo (mbazzani, Aug 7, 2023)
1f6ad35 Add pseudo labeling parameters to config (mbazzani, Aug 7, 2023)
6041cbf Linter happy! (mbazzani, Aug 7, 2023)
e1b87b8 Add wandb logging to pseudo (mbazzani, Aug 7, 2023)
e87470a Enforce consistent naming (mbazzani, Aug 7, 2023)
830a76a Remove old comments (mbazzani, Aug 7, 2023)
f10e5d7 Fix pyright (mbazzani, Aug 8, 2023)
9af622f Fix pseudo label progress bar (benjamin-cates, Aug 8, 2023)
beb5f60 Move pseudolabel pipeline to own script (benjamin-cates, Aug 8, 2023)
b5853d9 Pylint 🥰 (benjamin-cates, Aug 8, 2023)
dc03c90 Fix pseudolabel species saving (benjamin-cates, Aug 8, 2023)
69dec99 Partial refactor (mbazzani, Aug 8, 2023)
3fcfdce Fix merge conflict (mbazzani, Aug 8, 2023)
58364c6 Formatting tweaks and bug fixes (mbazzani, Aug 8, 2023)
194714c Update pseudolabel to use new training process (mbazzani, Aug 8, 2023)
c9f571a Make log_metrics ~10000x prettier (mbazzani, Aug 8, 2023)
96bc175 Yay girliesgit add chunking_methods/csv_cleaner.py default_config.yml… (mbazzani, Aug 9, 2023)
a4b961a BCEWL hotfix (benjamin-cates, Aug 9, 2023)
2c93f30 Fix imbalanced loss function (benjamin-cates, Aug 9, 2023)
224f415 Loss fn imbalance hotfix (benjamin-cates, Aug 9, 2023)
04edfef Dataset now returns dataframe indices (mbazzani, Aug 9, 2023)
4c98f58 Dataset hotfix (benjamin-cates, Aug 9, 2023)
e13b11f Merge remote-tracking branch 'refs/remotes/origin/pseudo' into pseudo (mbazzani, Aug 9, 2023)
573ca3f wandb project changed back to "acoustic-species-reu2023" (mbazzani, Aug 9, 2023)
4390bfa Fix class_list required warning (benjamin-cates, Aug 9, 2023)
2c02f45 Add nutella training loop (mbazzani, Aug 10, 2023)
3fe4042 Fix #144 (mbazzani, Aug 10, 2023)
c253ad1 Pylint 🥰 (benjamin-cates, Aug 10, 2023)
f4b4dd8 Fix test bug (benjamin-cates, Aug 10, 2023)
7036419 Fix clips/s metric (benjamin-cates, Aug 10, 2023)
b018e22 Debias pseudo labeling (#145) (benjamin-cates, Aug 11, 2023)
a772e0f Fix debias clips/sec bug (benjamin-cates, Aug 14, 2023)
03d177c Initial implementation (mbazzani, Aug 14, 2023)
1bc82e3 project_suffix option to wandb_init (mbazzani, Aug 14, 2023)
0f3f076 Bug fixes (mbazzani, Aug 15, 2023)
5ec63e6 Add eca_nfnet_l0 features (mbazzani, Aug 15, 2023)
68a5f27 Remove superfluous print statements (mbazzani, Aug 15, 2023)
7c51c0e Add error message for missing pseudolabels (mbazzani, Aug 15, 2023)
26742ee Linting fixes (mbazzani, Aug 16, 2023)
6787a4c Fix bug with label existence edge cases (mbazzani, Aug 18, 2023)
23d2f6d Fix type issues (mbazzani, Aug 18, 2023)
1291a1c Resolve merge conflicts (mbazzani, Aug 18, 2023)
ef554cc Move pseudolabel loading into nutella (mbazzani, Aug 18, 2023)
2115c23 Switch to using scipy distance (mbazzani, Aug 18, 2023)
03de22e Set multiprocessing method to spawn (mbazzani, Aug 18, 2023)
bd3c304 Restore usage of all files for pseudolabeling (mbazzani, Aug 18, 2023)
86ef35f Add options to choose between flattening and pooling features (mbazzani, Aug 18, 2023)
bb31add Add Notela parameters (mbazzani, Aug 18, 2023)
9bca3ea Fix max method call (mbazzani, Aug 18, 2023)
b705aae Various bug fixes (mbazzani, Aug 18, 2023)
73d5c10 Pylint happy 🥰 (mbazzani, Aug 21, 2023)
c34d927 Stop pyright being a silly little goose (mbazzani, Aug 21, 2023)
9d1ba3f Remove excess print statement (mbazzani, Aug 21, 2023)
89d4270 Merge pull request #148 from UCSD-E4E/nutella_for_real (mbazzani, Aug 21, 2023)
Files changed
3 changes: 2 additions & 1 deletion .pylintrc
@@ -414,14 +414,15 @@ confidence=HIGH,
 # no Warning level messages displayed, use "--disable=all --enable=classes
 # --disable=W".
 disable=raw-checker-failed,
+        logging-fstring-interpolation,
         bad-inline-option,
         locally-disabled,
         file-ignored,
         suppressed-message,
         useless-suppression,
         deprecated-pragma,
         use-symbolic-message-instead,
-        trailing-whitespace
+        trailing-whitespace

 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
9 changes: 5 additions & 4 deletions pyha_analyzer/aug_viewer.py
@@ -2,18 +2,19 @@
 This file contains methods that allow the visualization of
 different data augmentations.
 """
-from typing import Callable, List, Tuple, Dict, Any
+from typing import Any, Callable, Dict, List, Tuple

 import numpy as np
 import torch
 from matplotlib import cm
 from matplotlib import pyplot as plt

 from pyha_analyzer import config
-from pyha_analyzer.utils import get_annotation
-from pyha_analyzer.augmentations import (BackgroundNoise, LowpassFilter, Mixup,
-                                         HighpassFilter, RandomEQ, SyntheticNoise)
+from pyha_analyzer.augmentations import (BackgroundNoise, HighpassFilter,
+                                         LowpassFilter, Mixup, RandomEQ,
+                                         SyntheticNoise)
 from pyha_analyzer.dataset import PyhaDFDataset, get_datasets
+from pyha_analyzer.utils import get_annotation

 SYNTH_COLORS = ["white","pink","brown","violet","blue"]
2 changes: 1 addition & 1 deletion pyha_analyzer/augmentations.py
@@ -5,7 +5,7 @@
 import logging
 import os
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Iterable
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

 import numpy as np
 import pandas as pd
6 changes: 3 additions & 3 deletions pyha_analyzer/chunking_methods/csv_cleaner.py
@@ -23,7 +23,7 @@
         "Offset": "OFFSET",
         "Duration": "DURATION",
     },
-
+    # OUTPUT
     "cols_to_save": [
         "FILE NAME",
@@ -34,14 +34,14 @@
         "COMMON"
     ],
     "output_path": "../example_dataset/metadata_cleaned.csv",
-
 }
+
 def main():
     """ Main function """
     if ARGUMENTS["input_path"] == "":
         raise ValueError("Input path not specified")

     if ARGUMENTS["has_index_col"]:
         df = pd.read_csv(ARGUMENTS["input_path"], index_col=0)
     else:
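The csv_cleaner script is driven entirely by the ARGUMENTS dict edited above. A minimal sketch of the flow those keys imply; the rename-map key name ("col_renames" here) and the example input path are assumptions, not visible in this diff:

# Hypothetical sketch of csv_cleaner's flow; "col_renames" is an assumed key,
# while input_path, has_index_col, cols_to_save, and output_path appear in
# the diff above.
import pandas as pd

ARGUMENTS = {
    "input_path": "../example_dataset/metadata.csv",  # assumed example value
    "has_index_col": True,
    "col_renames": {"Offset": "OFFSET", "Duration": "DURATION"},
    # OUTPUT
    "cols_to_save": ["FILE NAME", "OFFSET", "DURATION", "COMMON"],
    "output_path": "../example_dataset/metadata_cleaned.csv",
}

def main():
    """ Main function """
    if ARGUMENTS["input_path"] == "":
        raise ValueError("Input path not specified")
    if ARGUMENTS["has_index_col"]:
        df = pd.read_csv(ARGUMENTS["input_path"], index_col=0)
    else:
        df = pd.read_csv(ARGUMENTS["input_path"])
    # Normalize column names, keep only the requested columns, and save
    df = df.rename(columns=ARGUMENTS["col_renames"])
    df[ARGUMENTS["cols_to_save"]].to_csv(ARGUMENTS["output_path"], index=False)

if __name__ == "__main__":
    main()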
5 changes: 3 additions & 2 deletions pyha_analyzer/config.py
@@ -3,17 +3,18 @@
 """
 import argparse
 import logging
+import os
 import shutil
 import sys
-import os
 from pathlib import Path

 # "Repo" is not exported from module "git" Import from "git.repo" instead
 # https://gitpython.readthedocs.io/en/stable/tutorial.html?highlight=repo#meet-the-repo-type
 import git
+import importlib_resources as pkg_resources
 import yaml
 from git import Repo # pyright: ignore [reportPrivateImportUsage]
 from torch.cuda import is_available
-import importlib_resources as pkg_resources

 logger = logging.getLogger("acoustic_multiclass_training")
87 changes: 36 additions & 51 deletions pyha_analyzer/dataset.py
@@ -7,25 +7,25 @@
 If this module is run directly, it tests that the dataloader works

 """
+import ast
 import logging
 import os
-from typing import List, Tuple, Optional
-import ast
+from typing import List, Optional, Tuple

 import numpy as np
 import pandas as pd
 import torch
 import torchaudio
-from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
+from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
 from torchaudio import transforms as audtr
 from torchvision.transforms import RandomApply
 from tqdm import tqdm
-import wandb

-from pyha_analyzer import config
-from pyha_analyzer import utils
-from pyha_analyzer.augmentations import (BackgroundNoise, LowpassFilter, Mixup, RandomEQ,
-                                         HighpassFilter, SyntheticNoise)
+import wandb
+from pyha_analyzer import config, utils
+from pyha_analyzer.augmentations import (BackgroundNoise, HighpassFilter,
+                                         LowpassFilter, Mixup, RandomEQ,
+                                         SyntheticNoise)
 from pyha_analyzer.chunking_methods import sliding_chunks

 cfg = config.cfg
@@ -285,12 +285,12 @@ def get_sample_weights(self) -> pd.Series:
         weight_list = self.samples[manual_id].apply(lambda x: sample_weights.loc[x])
         return weight_list

+
 def get_datasets() -> Tuple[PyhaDFDataset, PyhaDFDataset, Optional[PyhaDFDataset]]:
     """ Returns train and validation datasets
     does random sampling for train/valid split
     adds transforms to dataset
     """
-
     train_p = cfg.train_test_split
     path = cfg.dataframe_csv
     # Load the dataset
@@ -346,8 +346,9 @@ def get_datasets() -> Tuple[PyhaDFDataset, PyhaDFDataset, Optional[PyhaDFDataset
             classes.add(species)
     classes = list(classes)
     classes.sort()
-    # pylint: disable-next=attribute-defined-outside-init
-    cfg.config_dict["class_list"] = classes
+    # pylint: disable-next=attribute-defined-outside-init
+    cfg.class_list = classes # type: ignore
     wandb.config.update({"class_list": classes}, allow_val_change=True)

     #for each species, get a random sample of files for train/valid split
@@ -358,10 +358,8 @@

     valid = data[~data.index.isin(train.index)]
     train_ds = PyhaDFDataset(train, train=True, species=classes)
-
     valid_ds = PyhaDFDataset(valid, train=False, species=classes)
-
-
+
     #Handle inference datasets
     if cfg.infer_csv is None:
@@ -380,59 +379,45 @@ def set_torch_file_sharing(_) -> None:
     """
     torch.multiprocessing.set_sharing_strategy("file_system")


-def make_dataloaders(train_dataset, val_dataset, infer_dataset
-                     )-> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
+def get_dataloader(train_dataset, val_dataset, infer_dataset):
     """
-    Loads datasets and dataloaders for train and validation
+    Convenience wrapper to apply `make_dataloader` to all datasets
     """
+    train_dataloader = make_dataloader(train_dataset,cfg.train_batch_size,
+                                       cfg.does_weighted_sampling)
+    val_dataloader = make_dataloader(val_dataset,cfg.validation_batch_size)
+    if infer_dataset is None:
+        infer_dataloader = None
+    else:
+        infer_dataloader = make_dataloader(infer_dataset,cfg.validation_batch_size)
+    return train_dataloader, val_dataloader, infer_dataloader


-    # Create our dataloaders
-    # if sampler function is "specified, shuffle must not be specified."
-    # https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
-
-    if cfg.does_weighted_sampling:
-        if train_dataset.samples[cfg.manual_id_col].any(lambda x: isinstance(x,dict)):
+def make_dataloader(dataset, batch_size, weighted_sampling=False, shuffle=True):
+    """ Creates a torch DataLoader from a PyhaDFDataset """
+    if weighted_sampling:
+        if dataset.samples[cfg.manual_id_col].any(lambda x: isinstance(x,dict)):
             raise NotImplementedError("Weighted sampling not implemented for overlapping targets")
         # Code used from:
         # https://www.kaggle.com/competitions/birdclef-2023/discussion/412808
         # Get Sample Weights
-        weights_list = train_dataset.get_sample_weights()
+        weights_list = dataset.get_sample_weights()
         sampler = WeightedRandomSampler(weights_list, len(weights_list))
-        train_dataloader = DataLoader(
-            train_dataset,
+        # if sampler function is "specified, shuffle must not be specified."
+        # https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
+        return DataLoader(
+            dataset,
             cfg.train_batch_size,
             sampler=sampler,
             num_workers=cfg.jobs,
             worker_init_fn=set_torch_file_sharing
         )
-    else:
-        train_dataloader = DataLoader(
-            train_dataset,
-            cfg.train_batch_size,
-            shuffle=True,
-            num_workers=cfg.jobs,
-            worker_init_fn=set_torch_file_sharing
-        )
-
-    val_dataloader = DataLoader(
-        val_dataset,
-        cfg.validation_batch_size,
-        shuffle=False,
+    return DataLoader(
+        dataset,
+        batch_size,
+        shuffle=shuffle,
        num_workers=cfg.jobs,
        worker_init_fn=set_torch_file_sharing
    )
-    if infer_dataset is None:
-        infer_dataloader = None
-    else:
-        infer_dataloader = DataLoader(
-            infer_dataset,
-            cfg.validation_batch_size,
-            shuffle=False,
-            num_workers=cfg.jobs,
-            worker_init_fn=set_torch_file_sharing
-        )
-    return train_dataloader, val_dataloader, infer_dataloader

 def main() -> None:
     """
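The refactor above replaces the single make_dataloaders function with a thin get_dataloader wrapper plus a reusable make_dataloader helper. A minimal sketch of the resulting call pattern, using only names that appear in this diff (the explicit batch size is an illustrative stand-in for the cfg values):

# Sketch of the refactored API above; get_datasets, get_dataloader, and
# make_dataloader are the diff's names, the batch size is illustrative.
from pyha_analyzer.dataset import get_datasets, get_dataloader, make_dataloader

train_ds, valid_ds, infer_ds = get_datasets()   # infer_ds may be None

# Wrapper path: one call builds train/valid/infer loaders from cfg
train_dl, valid_dl, infer_dl = get_dataloader(train_ds, valid_ds, infer_ds)

# Direct path: build a single loader, optionally class-balanced through
# WeightedRandomSampler (a sampler and shuffle are mutually exclusive in torch)
balanced_dl = make_dataloader(
    train_ds,
    batch_size=32,
    weighted_sampling=True,
)

Note that in the weighted branch as written, DataLoader still receives cfg.train_batch_size rather than the batch_size argument, so the parameter only takes effect on the unweighted path.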
7 changes: 5 additions & 2 deletions pyha_analyzer/default_config.yml
@@ -7,6 +7,10 @@ infer_csv:
 # Optional, automatically generates class order if not given
 class_list:

+# Pseudo-labeling if running pyha_analyzer/pseudolabel.py
+pseudo_data_augs: true
+pseudo_threshold: 0.7
+
 # Dataframe column names
 offset_col: "OFFSET"
 duration_col: "DURATION"
@@ -40,7 +44,7 @@ chunk_margin_s: 0.4 # Margin on either side of annotation
 logging: true
 logging_freq: 20 # Batches per log
 wandb_entity: "acoustic-species-identification"
-wandb_project: "acoustic-species-reu2023"
+wandb_project: "nutella"
 wandb_run_name: "auto"
 debug: false
@@ -81,7 +85,6 @@ sweep_id:
 early_stopping: false
 patience: 3
 min_valid_map_delta: 0.01
-
 # Data augmentation probabilities
 mixup_p: 0.4
 time_shift_p: 0
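The two new keys configure the pseudo-labeling entry point: pseudo_data_augs toggles augmentations when training on pseudo-labeled clips, and pseudo_threshold sets the confidence cutoff. A minimal sketch of how such a threshold is typically applied; only the 0.7 value comes from the config above, the function and variable names are illustrative rather than pseudolabel.py's actual API:

# Illustrative sketch of confidence-thresholded pseudo-labeling.
import torch

def threshold_pseudo_labels(logits: torch.Tensor, threshold: float = 0.7):
    """Turn raw model outputs into hard multi-label pseudo-labels.

    logits: (num_clips, num_classes) raw outputs for unlabeled clips.
    Returns (labels, keep): 0/1 labels and a mask of clips where at
    least one class cleared the threshold.
    """
    scores = torch.sigmoid(logits)    # per-class confidences in [0, 1]
    mask = scores > threshold         # bool: confident class predictions
    labels = mask.float()             # hard 0/1 pseudo-labels
    keep = mask.any(dim=1)            # drop clips with no confident class
    return labels, keep

logits = torch.randn(8, 10)           # e.g. 8 unlabeled clips, 10 classes
labels, keep = threshold_pseudo_labels(logits, threshold=0.7)
pseudo_set = labels[keep]             # confident clips only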
54 changes: 0 additions & 54 deletions pyha_analyzer/get_test_map.py

This file was deleted.
