From 4e44aae44c78ea643cd97781cb823303828704e9 Mon Sep 17 00:00:00 2001
From: Nathan Fargo <32229490+ntfargo@users.noreply.github.com>
Date: Thu, 11 Jul 2024 12:44:16 +0200
Subject: [PATCH] src

---
 main.py                  | 129 +++++++++++++++++++++++++++++++++++++++
 requirements.txt         |  16 ++--
 src/__init__.py          |   0
 src/classifier.py        |  32 ++++++++++
 src/denoising_network.py |  47 ++++++++++++++
 src/feature_extractor.py |  32 ++++++++++
 src/lnp_extractor.py     |  10 +++
 7 files changed, 258 insertions(+), 8 deletions(-)
 create mode 100644 main.py
 create mode 100644 src/__init__.py
 create mode 100644 src/classifier.py
 create mode 100644 src/denoising_network.py
 create mode 100644 src/feature_extractor.py
 create mode 100644 src/lnp_extractor.py

diff --git a/main.py b/main.py
new file mode 100644
index 0000000..ba38611
--- /dev/null
+++ b/main.py
@@ -0,0 +1,129 @@
+import torch
+from torchvision import transforms
+from PIL import Image
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+from src.denoising_network import DenoisingNetwork
+from src.lnp_extractor import LNPExtractor
+from src.feature_extractor import FeatureExtractor
+from src.classifier import Classifier
+
+IMAGE_SIZE = (256, 256)
+PLOT_FEATURES = 5
+REAL_IMAGES_PATH = 'data/real_images/'
+TEST_IMAGES_PATH = 'data/test_images/'
+
+def load_image(path):
+    """Load and preprocess a single image."""
+    try:
+        transform = transforms.Compose([
+            transforms.Resize(IMAGE_SIZE),
+            transforms.ToTensor(),
+        ])
+        with Image.open(path) as img:
+            return transform(img.convert('RGB'))
+    except Exception as e:
+        print(f"Error processing {path}: {str(e)}")
+        return None
+
+def load_images_from_directory(directory):
+    """Load all images from a directory."""
+    images = []
+    for filename in os.listdir(directory):
+        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
+            img_path = os.path.join(directory, filename)
+            img = load_image(img_path)
+            if img is not None:
+                images.append(img)
+    return images
+
+def plot_feature_distributions(real_features, test_features, filename):
+    """Plot feature distributions for real and test images."""
+    real_features = np.array(real_features)
+    test_features = np.array(test_features)
+
+    if test_features.ndim == 1:
+        test_features = test_features.reshape(1, -1)
+
+    plt.figure(figsize=(15, 5))
+    for i in range(min(PLOT_FEATURES, real_features.shape[1])):
+        plt.subplot(1, PLOT_FEATURES, i+1)
+        plt.hist(real_features[:, i], bins=20, alpha=0.5, label='Real')
+        plt.axvline(test_features[0, i], color='r', linestyle='dashed', linewidth=2, label='Test')
+        plt.title(f'Feature {i+1}')
+        if i == 0:
+            plt.legend()
+    plt.tight_layout()
+    plt.savefig(filename)
+    plt.close()
+
+def process_images(images, lnp_extractor, feature_extractor, device, image_type="real"):
+    """Process a batch of images and extract features."""
+    features = []
+    for i, img in enumerate(images):
+        img_tensor = img.unsqueeze(0).to(device)
+        lnp = lnp_extractor.extract_lnp(img_tensor)
+        feature = feature_extractor.extract_features(lnp.squeeze(0))
+        features.append(feature.cpu().numpy())
+
+        print(f"Processed {image_type} image {i+1}/{len(images)}")
+        print(f"  LNP shape: {lnp.shape}")
+        print(f"  Features shape: {feature.shape}")
+        print(f"  Features mean: {feature.mean().item():.4f}")
+        print(f"  Features std: {feature.std().item():.4f}")
+
+    return features
+
+def compare_features(real_features, test_features):
+    """Compare test features with real features."""
+    real_features_array = np.array(real_features)
+    real_mean = real_features_array.mean(axis=0)
+    real_std = real_features_array.std(axis=0)
+    test_features_np = test_features.cpu().numpy()
+
+    print("\nFeature comparison:")
+    print(f"  Mean difference: {np.mean(np.abs(real_mean - test_features_np)):.4f}")
+    print(f"  Std difference: {np.mean(np.abs(real_std - np.std(test_features_np))):.4f}")
+
+def main():
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"Using device: {device}")
+
+    denoising_network = DenoisingNetwork().to(device)
+    lnp_extractor = LNPExtractor(denoising_network)
+    feature_extractor = FeatureExtractor()
+    classifier = Classifier()
+
+    real_images = load_images_from_directory(REAL_IMAGES_PATH)
+    print(f"Number of real images loaded: {len(real_images)}")
+    if not real_images:
+        print("No valid images found for training.")
+        return
+
+    real_features = process_images(real_images, lnp_extractor, feature_extractor, device)
+
+    classifier.train(real_features)
+    print("Classifier trained")
+
+    test_images = load_images_from_directory(TEST_IMAGES_PATH)
+
+    for i, test_img in enumerate(test_images):
+        test_img_tensor = test_img.unsqueeze(0).to(device)
+        test_lnp = lnp_extractor.extract_lnp(test_img_tensor)
+        test_features = feature_extractor.extract_features(test_lnp.squeeze(0))
+
+        print(f"\nTest image {i+1}:")
+        print(f"  LNP shape: {test_lnp.shape}")
+        print(f"  Features shape: {test_features.shape}")
+        print(f"  Features mean: {test_features.mean().item():.4f}")
+        print(f"  Features std: {test_features.std().item():.4f}")
+
+        result = classifier.predict([test_features.cpu().numpy()])
+        print("Image is real" if result[0] == 1 else "Image is generated")
+
+        plot_feature_distributions(real_features, test_features.cpu().numpy(), f'feature_distributions_test_{i+1}.png')
+        compare_features(real_features, test_features)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
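
Reviewer note: the driver above can be sanity-checked without any data directories. This is a minimal sketch, not part of the patch, assuming only that the src/ modules added below are importable; with an untrained denoiser only the shapes are meaningful.

    # Smoke test for the pipeline wired up in main.py (sketch only).
    import torch
    from src.denoising_network import DenoisingNetwork
    from src.feature_extractor import FeatureExtractor
    from src.lnp_extractor import LNPExtractor

    net = DenoisingNetwork()                      # untrained weights
    lnp = LNPExtractor(net).extract_lnp(torch.randn(1, 3, 256, 256))
    feats = FeatureExtractor().extract_features(lnp.squeeze(0))
    print(lnp.shape)    # torch.Size([1, 3, 256, 256]) -- LNP matches the input
    print(feats.shape)  # torch.Size([192]) with the default k=32
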
diff --git a/requirements.txt b/requirements.txt
index 50ed95f..444b281 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
-numpy==2.0.0
-torch==2.3.1
-torchvision==0.18.1
-scikit-learn==1.5.1
-pillow==10.4.0
-matplotlib==3.9.1
-scipy==1.14.0
-tqdm==4.66.4
\ No newline at end of file
+numpy==1.21.0
+torch==1.9.0
+torchvision==0.10.0
+scikit-learn==0.24.2
+pillow==8.2.0
+matplotlib==3.4.2
+scipy==1.7.0
+tqdm==4.61.1
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/classifier.py b/src/classifier.py
new file mode 100644
index 0000000..264d973
--- /dev/null
+++ b/src/classifier.py
@@ -0,0 +1,32 @@
+from sklearn.ensemble import IsolationForest
+from sklearn.svm import OneClassSVM
+import numpy as np
+
+class Classifier:
+    def __init__(self, contamination=0.1, random_state=42):
+        self.isolation_forest = IsolationForest(contamination=contamination, random_state=random_state)
+        self.one_class_svm = OneClassSVM(kernel='rbf', nu=0.1)
+
+    def train(self, features):
+        features_array = np.array(features)
+        print(f"Training classifiers with {len(features)} samples")
+        print(f"Feature array shape: {features_array.shape}")
+        print(f"Feature mean: {np.mean(features_array):.4f}")
+        print(f"Feature std: {np.std(features_array):.4f}")
+        self.isolation_forest.fit(features_array)
+        self.one_class_svm.fit(features_array)
+
+    def predict(self, features):
+        features_array = np.array(features)
+        if_prediction = self.isolation_forest.predict(features_array)
+        if_decision = self.isolation_forest.decision_function(features_array)
+        svm_prediction = self.one_class_svm.predict(features_array)
+        svm_decision = self.one_class_svm.decision_function(features_array)
+
+        print(f"Isolation Forest decision: {if_decision[0]:.4f}")
+        print(f"One-Class SVM decision: {svm_decision[0]:.4f}")
+
+        # Combine predictions (1 if both predict 1, -1 otherwise)
+        combined_prediction = np.where((if_prediction == 1) & (svm_prediction == 1), 1, -1)
+
+        return combined_prediction
\ No newline at end of file
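
Reviewer note: both detectors here are one-class models fit only on real-image features, and predict() labels a sample real only when IsolationForest and OneClassSVM both vote inlier (+1), trading recall for precision on the "real" class. A minimal usage sketch with synthetic stand-in features; the 100x192 shapes are illustrative assumptions, not values from the patch.

    # Sketch: exercising Classifier on synthetic features (hypothetical data).
    import numpy as np
    from src.classifier import Classifier

    rng = np.random.default_rng(0)
    clf = Classifier()
    clf.train(rng.normal(size=(100, 192)))              # stand-in "real" features
    print(clf.predict(rng.normal(size=(1, 192))))       # likely [1]: inlier
    print(clf.predict(rng.normal(5.0, 1.0, (1, 192))))  # likely [-1]: far from training data
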
diff --git a/src/denoising_network.py b/src/denoising_network.py
new file mode 100644
index 0000000..e00c87c
--- /dev/null
+++ b/src/denoising_network.py
@@ -0,0 +1,47 @@
+import torch
+import torch.nn as nn
+
+class DualAttentionBlock(nn.Module):
+    def __init__(self, channels):
+        super(DualAttentionBlock, self).__init__()
+        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+        self.conv3 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        residual = x
+        out = self.relu(self.conv1(x))
+        out = self.relu(self.conv2(out))
+        out = self.conv3(out)
+        return out + residual
+
+class RecursiveResidualGroup(nn.Module):
+    def __init__(self, channels):
+        super(RecursiveResidualGroup, self).__init__()
+        self.dab1 = DualAttentionBlock(channels)
+        self.dab2 = DualAttentionBlock(channels)
+
+    def forward(self, x):
+        out = self.dab1(x)
+        out = self.dab2(out)
+        return out + x
+
+class DenoisingNetwork(nn.Module):
+    def __init__(self, channels=64):
+        super(DenoisingNetwork, self).__init__()
+        self.conv_in = nn.Conv2d(3, channels, kernel_size=3, padding=1)
+        self.rrg1 = RecursiveResidualGroup(channels)
+        self.rrg2 = RecursiveResidualGroup(channels)
+        self.rrg3 = RecursiveResidualGroup(channels)
+        self.rrg4 = RecursiveResidualGroup(channels)
+        self.conv_out = nn.Conv2d(channels, 3, kernel_size=3, padding=1)
+
+    def forward(self, x):
+        out = self.conv_in(x)
+        out = self.rrg1(out)
+        out = self.rrg2(out)
+        out = self.rrg3(out)
+        out = self.rrg4(out)
+        out = self.conv_out(out)
+        return out
\ No newline at end of file
diff --git a/src/feature_extractor.py b/src/feature_extractor.py
new file mode 100644
index 0000000..1216e21
--- /dev/null
+++ b/src/feature_extractor.py
@@ -0,0 +1,32 @@
+import torch
+
+class FeatureExtractor:
+    def __init__(self, k=32):
+        self.k = k
+
+    def extract_features(self, lnp):
+        # Ensure lnp is 3D: (channels, height, width)
+        if lnp.dim() == 4:  # (batch, channels, height, width)
+            lnp = lnp.squeeze(0)
+        elif lnp.dim() == 2:  # (height, width)
+            lnp = lnp.unsqueeze(0)
+
+        fft_lnp = torch.fft.fft2(lnp)
+        amplitude_spectrum = torch.abs(fft_lnp)
+        enhanced_spectrum = self._enhance_spectrum(amplitude_spectrum)
+        features = self._sample_features(enhanced_spectrum)
+        features = (features - features.mean()) / (features.std() + 1e-8)
+
+        return features
+
+    def _enhance_spectrum(self, spectrum):
+        A_u = torch.mean(spectrum, dim=1, keepdim=True)
+        enhanced = torch.where(spectrum < A_u, torch.zeros_like(spectrum), spectrum**2)
+        return enhanced
+
+    def _sample_features(self, enhanced_spectrum):
+        C, H, W = enhanced_spectrum.shape
+        m_indices = torch.arange(0, H, self.k)
+        n_indices = torch.arange(0, W, self.k)
+        features = enhanced_spectrum[:, m_indices][:, :, n_indices].flatten()
+        return features
\ No newline at end of file
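
Reviewer note: for intuition on the two helpers — _enhance_spectrum compares each amplitude against the mean of its column (dim=1 is the height axis of the (C, H, W) spectrum), zeroing below-average entries and squaring the rest, and _sample_features then keeps every k-th row and column. A tiny worked sketch with hand-made illustrative values:

    # Sketch of _enhance_spectrum on a (1, 4, 4) toy "spectrum".
    import torch

    s = torch.tensor([[[1., 2., 3., 4.],
                       [4., 3., 2., 1.],
                       [1., 1., 1., 1.],
                       [2., 2., 2., 2.]]])
    col_mean = s.mean(dim=1, keepdim=True)   # every column averages to 2.0
    enhanced = torch.where(s < col_mean, torch.zeros_like(s), s**2)
    # enhanced: [[0, 4, 9, 16], [16, 9, 4, 0], [0, 0, 0, 0], [4, 4, 4, 4]]

With the default k=32 and the 256x256 inputs from main.py, torch.arange(0, 256, 32) selects 8 rows and 8 columns, so each image yields a 3 * 8 * 8 = 192-dimensional feature vector.
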
diff --git a/src/lnp_extractor.py b/src/lnp_extractor.py
new file mode 100644
index 0000000..9fbf211
--- /dev/null
+++ b/src/lnp_extractor.py
@@ -0,0 +1,10 @@
+import torch
+
+class LNPExtractor:
+    def __init__(self, denoising_network):
+        self.denoising_network = denoising_network
+
+    def extract_lnp(self, image):
+        with torch.no_grad():
+            denoised = self.denoising_network(image)
+        return image - denoised
\ No newline at end of file
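
Reviewer note: two caveats on the extractor as committed. First, despite its name, DualAttentionBlock contains no attention mechanism; as written it is a plain residual convolution block. Second, nothing in this patch trains DenoisingNetwork, so extract_lnp subtracts the output of a randomly initialized denoiser and the resulting noise pattern is essentially arbitrary. If pretrained weights are unavailable, one conventional option (not the patch author's stated method) is to fit the network with an MSE objective on synthetically noised images. A minimal sketch, assuming a hypothetical clean_loader that yields batches of clean RGB tensors in [0, 1]:

    # Hypothetical training loop for DenoisingNetwork (not part of this patch).
    import torch
    import torch.nn as nn
    from src.denoising_network import DenoisingNetwork

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = DenoisingNetwork().to(device)
    opt = torch.optim.Adam(net.parameters(), lr=1e-4)
    mse = nn.MSELoss()

    for epoch in range(10):                  # epoch count is arbitrary
        for clean in clean_loader:           # hypothetical DataLoader
            clean = clean.to(device)
            noisy = clean + 0.05 * torch.randn_like(clean)  # synthetic Gaussian noise
            opt.zero_grad()
            loss = mse(net(noisy), clean)    # network outputs the denoised image
            loss.backward()
            opt.step()

    torch.save(net.state_dict(), "denoiser.pth")

With trained weights saved this way, main.py could load them via denoising_network.load_state_dict(torch.load("denoiser.pth")) before constructing LNPExtractor.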