cifar10 dataloader, no suitable network implemented yet
b-d-e committed Dec 21, 2024
1 parent 4e6ae65 commit 59a1f76
Showing 7 changed files with 131 additions and 21 deletions.
.gitignore (5 changes: 3 additions & 2 deletions)
@@ -1,5 +1,6 @@
 /target
-.gz
+*.gz
 
 data/
+!src/data
-cargo.lock
+cargo.lock
README.md (4 changes: 3 additions & 1 deletion)
@@ -22,4 +22,6 @@ Epoch 4/5
 Epoch 5/5
 [00:00:06] ======================================== 390/390 Epoch avg loss: 0.0149 | Train Accuracy: 98.44% | Val Accuracy: 97.24%
 ```
-and log some ascii graphs at the end.
+and log some ascii graphs at the end.
+
+CIFAR-10 _can_ be used, but colour is currently handled very naively (each image is flattened into a single vector for MLP input), so performance is poor. Implementing conv layers is a planned follow-up.
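For reference, the naive colour handling amounts to scaling each pixel byte to [0, 1] and concatenating the three colour planes into one 3,072-element vector, discarding the 2D structure a conv layer would exploit. A minimal sketch of that transform (standalone illustration; the actual loading code lives in src/data/cifar.rs below):

```rust
// Sketch: flatten one 3,073-byte CIFAR-10 record (1 label byte, then
// 1,024 red, 1,024 green and 1,024 blue bytes) into an MLP input.
fn flatten_record(record: &[u8]) -> (usize, Vec<f32>) {
    assert_eq!(record.len(), 3073, "CIFAR-10 records are 3,073 bytes");
    let label = record[0] as usize;
    // The channel planes are simply concatenated; spatial layout is lost.
    let pixels = record[1..].iter().map(|&b| b as f32 / 255.0).collect();
    (label, pixels)
}
```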
src/data/cifar.rs (87 changes: 87 additions & 0 deletions, new file)
@@ -0,0 +1,87 @@
+use ndarray::{Array2, s};
+use std::path::Path;
+use super::dataset::Dataset;
+
+pub struct Cifar10Data {
+    pub train_images: Array2<f32>,
+    pub train_labels: Array2<f32>,
+    pub test_images: Array2<f32>,
+    pub test_labels: Array2<f32>,
+}
+
+impl Dataset for Cifar10Data {
+    fn new() -> Result<Self, Box<dyn std::error::Error>> {
+        if !Path::new("data").exists() {
+            std::fs::create_dir("data")?;
+        }
+
+        let base_url = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz";
+
+        // Download and extract CIFAR-10
+        if !Path::new("data/cifar-10-batches-bin").exists() {
+            println!("Downloading CIFAR-10 dataset...");
+            std::process::Command::new("curl")
+                .args(&["-O", base_url])
+                .output()?;
+
+            std::process::Command::new("tar")
+                .args(&["-xzf", "cifar-10-binary.tar.gz"])
+                .output()?;
+
+            std::fs::rename("cifar-10-batches-bin", "data/cifar-10-batches-bin")?;
+        }
+
+        // Load training data (5 batches)
+        let mut train_images = Vec::new();
+        let mut train_labels = Vec::new();
+
+        for i in 1..=5 {
+            let path = format!("data/cifar-10-batches-bin/data_batch_{}.bin", i);
+            let data = std::fs::read(path)?;
+
+            for chunk in data.chunks(3073) {
+                let label = chunk[0] as usize;
+                let pixels = &chunk[1..];
+
+                let mut one_hot = vec![0.0; 10];
+                one_hot[label] = 1.0;
+                train_labels.extend(one_hot);
+                train_images.extend(pixels.iter().map(|&x| x as f32 / 255.0));
+            }
+        }
+
+        // Load test data
+        let test_data = std::fs::read("data/cifar-10-batches-bin/test_batch.bin")?;
+        let mut test_images = Vec::new();
+        let mut test_labels = Vec::new();
+
+        for chunk in test_data.chunks(3073) {
+            let label = chunk[0] as usize;
+            let pixels = &chunk[1..];
+
+            let mut one_hot = vec![0.0; 10];
+            one_hot[label] = 1.0;
+            test_labels.extend(one_hot);
+            test_images.extend(pixels.iter().map(|&x| x as f32 / 255.0));
+        }
+
+        Ok(Cifar10Data {
+            train_images: Array2::from_shape_vec((50_000, 3072), train_images)?,
+            train_labels: Array2::from_shape_vec((50_000, 10), train_labels)?,
+            test_images: Array2::from_shape_vec((10_000, 3072), test_images)?,
+            test_labels: Array2::from_shape_vec((10_000, 10), test_labels)?,
+        })
+    }
+
+    fn get_batch(&self, start: usize, batch_size: usize) -> (Array2<f32>, Array2<f32>) {
+        let end = start + batch_size;
+        let batch_images = self.train_images.slice(s![start..end, ..]).to_owned();
+        let batch_labels = self.train_labels.slice(s![start..end, ..]).to_owned();
+        (batch_images, batch_labels)
+    }
+
+    fn get_train_size(&self) -> usize { 50_000 }
+    fn get_test_size(&self) -> usize { 10_000 }
+    fn get_input_size(&self) -> usize { 3072 }
+    fn get_num_classes(&self) -> usize { 10 }
+}
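Two caveats worth noting about the loader above: `Command::output()?` only errors if `curl`/`tar` fail to spawn, not if they exit non-zero, and `chunks(3073)` would silently yield a short final chunk from a truncated download, surfacing only later as a shape error. A hypothetical early check (not part of this commit):

```rust
// Hypothetical helper: verify a CIFAR-10 binary file is an exact
// multiple of the 3,073-byte record size before parsing, so a
// truncated download fails early with a clear message.
fn check_batch_file(path: &str) -> std::io::Result<()> {
    let len = std::fs::metadata(path)?.len();
    assert_eq!(len % 3073, 0, "{}: truncated or corrupt batch file", path);
    Ok(())
}
```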
src/data/dataset.rs (10 changes: 10 additions & 0 deletions, new file)
@@ -0,0 +1,10 @@
+use ndarray::Array2;
+
+pub trait Dataset {
+    fn new() -> Result<Self, Box<dyn std::error::Error>> where Self: Sized;
+    fn get_batch(&self, start: usize, batch_size: usize) -> (Array2<f32>, Array2<f32>);
+    fn get_train_size(&self) -> usize;
+    fn get_test_size(&self) -> usize;
+    fn get_num_classes(&self) -> usize;
+    fn get_input_size(&self) -> usize;
+}
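The trait makes training code dataset-agnostic in principle. A sketch of a generic batch loop built on it (hypothetical helper, assuming `Dataset` and `ndarray` are in scope; main.rs still writes this loop inline):

```rust
use ndarray::Array2;

// Run `step` once per full batch of any Dataset implementation;
// a trailing partial batch is skipped, matching main.rs behaviour.
fn for_each_batch<D: Dataset>(
    data: &D,
    batch_size: usize,
    mut step: impl FnMut(&Array2<f32>, &Array2<f32>),
) {
    let num_batches = data.get_train_size() / batch_size;
    for b in 0..num_batches {
        let (images, labels) = data.get_batch(b * batch_size, batch_size);
        step(&images, &labels);
    }
}
```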
src/data/mnist.rs (16 changes: 10 additions & 6 deletions)
@@ -1,6 +1,7 @@
 use ndarray::{Array2, s};
 use mnist::MnistBuilder;
 use std::path::Path;
+use super::dataset::Dataset;
 
 pub struct MnistData {
     pub train_images: Array2<f32>,
@@ -9,14 +10,13 @@ pub struct MnistData {
     pub test_labels: Array2<f32>,
 }
 
-impl MnistData {
-    pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
-        // Create data directory if it doesn't exist
+impl Dataset for MnistData {
+    fn new() -> Result<Self, Box<dyn std::error::Error>> {
         if !Path::new("data").exists() {
             std::fs::create_dir("data")?;
         }
 
-        // Download the dataset - mnist crate has broken LeCun URL
+        // Download the dataset if needed
         if !Path::new("data/train-images-idx3-ubyte").exists() {
             println!("Downloading MNIST dataset...");
             std::process::Command::new("curl")
@@ -75,7 +75,6 @@ impl MnistData {
             .test_set_length(10_000)
             .finalize();
 
-        // Rest of the implementation remains the same...
         let train_images = Array2::from_shape_vec(
             (50_000, 784),
             mnist.trn_img.iter().map(|&x| x as f32 / 255.0).collect()
@@ -104,10 +103,15 @@ impl MnistData {
         })
     }
 
-    pub fn get_batch(&self, start: usize, batch_size: usize) -> (Array2<f32>, Array2<f32>) {
+    fn get_batch(&self, start: usize, batch_size: usize) -> (Array2<f32>, Array2<f32>) {
         let end = start + batch_size;
         let batch_images = self.train_images.slice(s![start..end, ..]).to_owned();
        let batch_labels = self.train_labels.slice(s![start..end, ..]).to_owned();
         (batch_images, batch_labels)
     }
+
+    fn get_train_size(&self) -> usize { 50_000 }
+    fn get_test_size(&self) -> usize { 10_000 }
+    fn get_input_size(&self) -> usize { 784 }
+    fn get_num_classes(&self) -> usize { 10 }
 }
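One knock-on effect of moving `new` and `get_batch` onto the trait: call sites must import `Dataset` for those methods to resolve. A minimal usage sketch (assuming the crate layout in this commit):

```rust
use corroded_classifier::data::{Dataset, MnistData};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Without `Dataset` in scope, `MnistData::new()` no longer resolves,
    // since the method now lives on the trait rather than the type.
    let mnist = MnistData::new()?;
    println!("{} training examples", mnist.get_train_size());
    Ok(())
}
```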
src/data/mod.rs (8 changes: 7 additions & 1 deletion)
@@ -1 +1,7 @@
-pub mod mnist;
+pub mod dataset; // The trait definition
+pub mod mnist;
+pub mod cifar;
+
+pub use dataset::Dataset;
+pub use mnist::MnistData;
+pub use cifar::Cifar10Data;
src/main.rs (22 changes: 11 additions & 11 deletions)
@@ -4,27 +4,27 @@ use corroded_classifier::{
     network::network::Network,
     network::layer::Layer,
     network::activation::{ReLU, Sigmoid},
-    data::mnist::MnistData,
+    data::{Dataset, MnistData, Cifar10Data},
 };
 use textplots::{Chart, Plot, Shape};
 
 fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Training parameters
-    let batch_size = 128;
+    let batch_size = 32;
     let epochs = 5;
     let hidden_size = 128;
     let learning_rate = 0.1;
 
-    // Load MNIST data
-    let mnist = MnistData::new()?;
+    // Load MNIST or CIFAR data
+    let dataset = MnistData::new()?;
+    // let dataset = Cifar10Data::new()?;
 
     // Create network
     let mut network = Network::new(learning_rate);
 
     // Add layers
-    network.add_layer(Layer::new(784, hidden_size, Box::new(ReLU)));
-    network.add_layer(Layer::new(hidden_size, 10, Box::new(Sigmoid)));
-
+    network.add_layer(Layer::new(dataset.get_input_size(), hidden_size, Box::new(ReLU)));
+    network.add_layer(Layer::new(hidden_size, dataset.get_num_classes(), Box::new(Sigmoid)));
 
     // vectors to track metrics for graphs
     let mut losses = Vec::new();
@@ -37,7 +37,7 @@
     for epoch in 0..epochs {
         println!("Epoch {}/{}", epoch + 1, epochs);
         let mut total_loss = 0.0;
-        let num_batches = mnist.train_images.nrows() / batch_size;
+        let num_batches = dataset.train_images.nrows() / batch_size;
 
         let progress_bar = ProgressBar::new(num_batches as u64);
         progress_bar.set_style(ProgressStyle::default_bar()
@@ -47,7 +47,7 @@
 
         for batch in 0..num_batches {
             let start = batch * batch_size;
-            let (batch_images, batch_labels) = mnist.get_batch(start, batch_size);
+            let (batch_images, batch_labels) = dataset.get_batch(start, batch_size);
 
             let mut batch_loss = 0.0;
             for i in 0..batch_size {
@@ -70,9 +70,9 @@
         let epoch_loss = total_loss / (num_batches * batch_size) as f32;
 
         // Calculate epoch accuracy on whole training set
-        let epoch_train_accuracy = network.calculate_accuracy(&mnist.train_images, &mnist.train_labels) * 100.0;
+        let epoch_train_accuracy = network.calculate_accuracy(&dataset.train_images, &dataset.train_labels) * 100.0;
 
-        let epoch_val_accuracy = network.calculate_accuracy(&mnist.test_images, &mnist.test_labels) * 100.0;
+        let epoch_val_accuracy = network.calculate_accuracy(&dataset.test_images, &dataset.test_labels) * 100.0;
 
         progress_bar.finish_with_message(
             format!("Epoch avg loss: {:.4} | Train Accuracy: {:.2}% | Val Accuracy: {:.2}%", epoch_loss, epoch_train_accuracy, epoch_val_accuracy)
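Because `new()` carries a `where Self: Sized` bound, the trait remains object-safe, so the commented-out swap in main.rs could in principle become a runtime switch. A hypothetical sketch (the `--cifar` flag is invented, and main.rs still reads the concrete `train_images` field directly, so it could not adopt this as-is):

```rust
use corroded_classifier::data::{Cifar10Data, Dataset, MnistData};

// Pick a dataset behind a trait object at runtime; `new()` is excluded
// from the vtable by its `Self: Sized` bound, so `dyn Dataset` is legal.
fn load_dataset() -> Result<Box<dyn Dataset>, Box<dyn std::error::Error>> {
    if std::env::args().any(|a| a == "--cifar") {
        Ok(Box::new(Cifar10Data::new()?))
    } else {
        Ok(Box::new(MnistData::new()?))
    }
}
```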