diff --git a/src/common/dataset.rs b/src/common/dataset.rs index 48fe982..ebab05a 100644 --- a/src/common/dataset.rs +++ b/src/common/dataset.rs @@ -1,13 +1,16 @@ +use opencv::{ + core::{self, MatTraitConst}, + imgcodecs::{self, imread, IMREAD_GRAYSCALE}, +}; +use parking_lot::RwLock; +use rayon::prelude::*; +use rayon_progress::ProgressAdaptor; use std::{ collections::HashMap, fs, + path::PathBuf, sync::{Arc, Mutex}, }; - -use opencv::{ - core::{self, MatTraitConst}, - imgcodecs, -}; use tokio::{sync::Semaphore, task::JoinSet}; use tracing_unwrap::{OptionExt, ResultExt}; @@ -60,7 +63,16 @@ pub async fn split_dataset(dataset_path: &String, train_ratio: &f32) { } pub async fn count_classes(dataset_path: &String) { - let entries = fs::read_dir(dataset_path).unwrap(); + let mut entries: Vec = Vec::new(); + let dataset_path = PathBuf::from(dataset_path); + if dataset_path.is_file() { + entries.push(dataset_path.clone()); + } else { + entries = fs::read_dir(dataset_path.clone()) + .unwrap() + .map(|x| x.unwrap().path()) + .collect(); + } let type_map = Arc::new(Mutex::new(HashMap::::new())); let sem = Arc::new(Semaphore::new( @@ -68,31 +80,37 @@ pub async fn count_classes(dataset_path: &String) { )); let mut threads = JoinSet::new(); for entry in entries { - let entry = entry.unwrap(); + if entry.is_dir() { + continue; + } let type_map = Arc::clone(&type_map); let sem = Arc::clone(&sem); threads.spawn(async move { let _ = sem.acquire().await.unwrap(); - let image = image::open(entry.path()).unwrap(); - let image = image.as_luma8().unwrap(); - tracing::info!("Loaded image: {}", entry.path().display()); - let mut current_img_type_map = HashMap::::new(); - for (_, _, pixel) in image.enumerate_pixels() { - if !current_img_type_map.contains_key(&pixel[0]) { - current_img_type_map.insert(pixel[0], 1); - } else { - current_img_type_map - .insert(pixel[0], current_img_type_map.get(&pixel[0]).unwrap() + 1); + let img = imread(entry.to_str().unwrap(), IMREAD_GRAYSCALE).unwrap(); + let row = img.rows(); + let cols = img.cols(); + let img = Arc::new(RwLock::new(img)); + + let row_iter = ProgressAdaptor::new(0..row); + row_iter.for_each(|row_index| { + let mut row_type_map = HashMap::::new(); + let row = img.read().row(row_index).unwrap().clone_pointee(); + for col_index in 0..cols { + let pixel = row.at_2d::(0, col_index).unwrap(); + row_type_map + .entry(*pixel) + .and_modify(|x| *x += 1) + .or_insert(1); } - } - let mut type_map = type_map.lock().unwrap(); - for (class_id, count) in current_img_type_map.iter() { - let total_count = type_map.entry(*class_id).or_insert(0); - *total_count += count; - } - - tracing::info!("Image {} done", entry.path().display()); + let mut entry = type_map.lock().unwrap(); + for (class_id, count) in row_type_map.iter() { + let total_count = entry.entry(*class_id).or_insert(0); + *total_count += count; + } + }); + tracing::info!("Image {} done", entry.to_str().unwrap()); }); }