Skip to content

Commit

Permalink
feat: move count_classes to opencv
Browse files Browse the repository at this point in the history
  • Loading branch information
4o3F committed Oct 10, 2024
1 parent 1128e0e commit bc26d58
Showing 1 changed file with 43 additions and 25 deletions.
68 changes: 43 additions & 25 deletions src/common/dataset.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use opencv::{
core::{self, MatTraitConst},
imgcodecs::{self, imread, IMREAD_GRAYSCALE},
};
use parking_lot::RwLock;
use rayon::prelude::*;
use rayon_progress::ProgressAdaptor;
use std::{
collections::HashMap,
fs,
path::PathBuf,
sync::{Arc, Mutex},
};

use opencv::{
core::{self, MatTraitConst},
imgcodecs,
};
use tokio::{sync::Semaphore, task::JoinSet};
use tracing_unwrap::{OptionExt, ResultExt};

Expand Down Expand Up @@ -60,39 +63,54 @@ pub async fn split_dataset(dataset_path: &String, train_ratio: &f32) {
}

pub async fn count_classes(dataset_path: &String) {
let entries = fs::read_dir(dataset_path).unwrap();
let mut entries: Vec<PathBuf> = Vec::new();
let dataset_path = PathBuf::from(dataset_path);
if dataset_path.is_file() {
entries.push(dataset_path.clone());
} else {
entries = fs::read_dir(dataset_path.clone())
.unwrap()
.map(|x| x.unwrap().path())
.collect();
}

let type_map = Arc::new(Mutex::new(HashMap::<u8, i32>::new()));
let sem = Arc::new(Semaphore::new(
(*THREAD_POOL.read().expect_or_log("Get pool error")).into(),
));
let mut threads = JoinSet::new();
for entry in entries {
let entry = entry.unwrap();
if entry.is_dir() {
continue;
}
let type_map = Arc::clone(&type_map);
let sem = Arc::clone(&sem);
threads.spawn(async move {
let _ = sem.acquire().await.unwrap();
let image = image::open(entry.path()).unwrap();
let image = image.as_luma8().unwrap();
tracing::info!("Loaded image: {}", entry.path().display());
let mut current_img_type_map = HashMap::<u8, i32>::new();
for (_, _, pixel) in image.enumerate_pixels() {
if !current_img_type_map.contains_key(&pixel[0]) {
current_img_type_map.insert(pixel[0], 1);
} else {
current_img_type_map
.insert(pixel[0], current_img_type_map.get(&pixel[0]).unwrap() + 1);
let img = imread(entry.to_str().unwrap(), IMREAD_GRAYSCALE).unwrap();
let row = img.rows();
let cols = img.cols();
let img = Arc::new(RwLock::new(img));

let row_iter = ProgressAdaptor::new(0..row);
row_iter.for_each(|row_index| {
let mut row_type_map = HashMap::<u8, i32>::new();
let row = img.read().row(row_index).unwrap().clone_pointee();
for col_index in 0..cols {
let pixel = row.at_2d::<u8>(0, col_index).unwrap();
row_type_map
.entry(*pixel)
.and_modify(|x| *x += 1)
.or_insert(1);
}
}

let mut type_map = type_map.lock().unwrap();
for (class_id, count) in current_img_type_map.iter() {
let total_count = type_map.entry(*class_id).or_insert(0);
*total_count += count;
}

tracing::info!("Image {} done", entry.path().display());
let mut entry = type_map.lock().unwrap();
for (class_id, count) in row_type_map.iter() {
let total_count = entry.entry(*class_id).or_insert(0);
*total_count += count;
}
});
tracing::info!("Image {} done", entry.to_str().unwrap());
});
}

Expand Down

0 comments on commit bc26d58

Please sign in to comment.