Skip to content

Commit

Permalink
src
Browse files Browse the repository at this point in the history
  • Loading branch information
ntfargo committed Jul 11, 2024
1 parent c9561fd commit 4e44aae
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 8 deletions.
129 changes: 129 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import torch
from torchvision import transforms
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from src.denoising_network import DenoisingNetwork
from src.lnp_extractor import LNPExtractor
from src.feature_extractor import FeatureExtractor
from src.classifier import Classifier

IMAGE_SIZE = (256, 256)
PLOT_FEATURES = 5
REAL_IMAGES_PATH = 'data/real_images/'
TEST_IMAGES_PATH = 'data/test_images/'

def load_image(path):
"""Load and preprocess a single image."""
try:
transform = transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.ToTensor(),
])
with Image.open(path) as img:
return transform(img.convert('RGB'))
except Exception as e:
print(f"Error processing {path}: {str(e)}")
return None

def load_images_from_directory(directory):
"""Load all images from a directory."""
images = []
for filename in os.listdir(directory):
if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
img_path = os.path.join(directory, filename)
img = load_image(img_path)
if img is not None:
images.append(img)
return images

def plot_feature_distributions(real_features, test_features, filename):
"""Plot feature distributions for real and test images."""
real_features = np.array(real_features)
test_features = np.array(test_features)

if test_features.ndim == 1:
test_features = test_features.reshape(1, -1)

plt.figure(figsize=(15, 5))
for i in range(min(PLOT_FEATURES, real_features.shape[1])):
plt.subplot(1, PLOT_FEATURES, i+1)
plt.hist(real_features[:, i], bins=20, alpha=0.5, label='Real')
plt.axvline(test_features[0, i], color='r', linestyle='dashed', linewidth=2, label='Test')
plt.title(f'Feature {i+1}')
if i == 0:
plt.legend()
plt.tight_layout()
plt.savefig(filename)
plt.close()

def process_images(images, lnp_extractor, feature_extractor, device, image_type="real"):
"""Process a batch of images and extract features."""
features = []
for i, img in enumerate(images):
img_tensor = img.unsqueeze(0).to(device)
lnp = lnp_extractor.extract_lnp(img_tensor)
feature = feature_extractor.extract_features(lnp.squeeze(0))
features.append(feature.cpu().numpy())

print(f"Processed {image_type} image {i+1}/{len(images)}")
print(f" LNP shape: {lnp.shape}")
print(f" Features shape: {feature.shape}")
print(f" Features mean: {feature.mean().item():.4f}")
print(f" Features std: {feature.std().item():.4f}")

return features

def compare_features(real_features, test_features):
"""Compare test features with real features."""
real_features_array = np.array(real_features)
real_mean = real_features_array.mean(axis=0)
real_std = real_features_array.std(axis=0)
test_features_np = test_features.cpu().numpy()

print("\nFeature comparison:")
print(f" Mean difference: {np.mean(np.abs(real_mean - test_features_np)):.4f}")
print(f" Std difference: {np.mean(np.abs(real_std - np.std(test_features_np))):.4f}")

def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

denoising_network = DenoisingNetwork().to(device)
lnp_extractor = LNPExtractor(denoising_network)
feature_extractor = FeatureExtractor()
classifier = Classifier()

real_images = load_images_from_directory(REAL_IMAGES_PATH)
print(f"Number of real images loaded: {len(real_images)}")
if not real_images:
print("No valid images found for training.")
return

real_features = process_images(real_images, lnp_extractor, feature_extractor, device)

classifier.train(real_features)
print("Classifier trained")

test_images = load_images_from_directory(TEST_IMAGES_PATH)

for i, test_img in enumerate(test_images):
test_img_tensor = test_img.unsqueeze(0).to(device)
test_lnp = lnp_extractor.extract_lnp(test_img_tensor)
test_features = feature_extractor.extract_features(test_lnp.squeeze(0))

print(f"\nTest image {i+1}:")
print(f" LNP shape: {test_lnp.shape}")
print(f" Features shape: {test_features.shape}")
print(f" Features mean: {test_features.mean().item():.4f}")
print(f" Features std: {test_features.std().item():.4f}")

result = classifier.predict([test_features.cpu().numpy()])
print("Image is real" if result[0] == 1 else "Image is generated")

plot_feature_distributions(real_features, test_features.cpu().numpy(), f'feature_distributions_test_{i+1}.png')
compare_features(real_features, test_features)

if __name__ == "__main__":
main()
16 changes: 8 additions & 8 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
numpy==2.0.0
torch==2.3.1
torchvision==0.18.1
scikit-learn==1.5.1
pillow==10.4.0
matplotlib==3.9.1
scipy==1.14.0
tqdm==4.66.4
numpy==1.21.0
torch==1.9.0
torchvision==0.10.0
scikit-learn==0.24.2
pillow==8.2.0
matplotlib==3.4.2
scipy==1.7.0
tqdm==4.61.1
Empty file added src/__init__.py
Empty file.
32 changes: 32 additions & 0 deletions src/classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from sklearn.ensemble import IsolationForest
from sklearn.svm import OneClassSVM
import numpy as np

class Classifier:
def __init__(self, contamination=0.1, random_state=42):
self.isolation_forest = IsolationForest(contamination=contamination, random_state=random_state)
self.one_class_svm = OneClassSVM(kernel='rbf', nu=0.1)

def train(self, features):
features_array = np.array(features)
print(f"Training classifiers with {len(features)} samples")
print(f"Feature array shape: {features_array.shape}")
print(f"Feature mean: {np.mean(features_array):.4f}")
print(f"Feature std: {np.std(features_array):.4f}")
self.isolation_forest.fit(features_array)
self.one_class_svm.fit(features_array)

def predict(self, features):
features_array = np.array(features)
if_prediction = self.isolation_forest.predict(features_array)
if_decision = self.isolation_forest.decision_function(features_array)
svm_prediction = self.one_class_svm.predict(features_array)
svm_decision = self.one_class_svm.decision_function(features_array)

print(f"Isolation Forest decision: {if_decision[0]:.4f}")
print(f"One-Class SVM decision: {svm_decision[0]:.4f}")

# Combine predictions (1 if both predict 1, -1 otherwise)
combined_prediction = np.where((if_prediction == 1) & (svm_prediction == 1), 1, -1)

return combined_prediction
47 changes: 47 additions & 0 deletions src/denoising_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn

class DualAttentionBlock(nn.Module):
def __init__(self, channels):
super(DualAttentionBlock, self).__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.conv3 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)

def forward(self, x):
residual = x
out = self.relu(self.conv1(x))
out = self.relu(self.conv2(out))
out = self.conv3(out)
return out + residual

class RecursiveResidualGroup(nn.Module):
def __init__(self, channels):
super(RecursiveResidualGroup, self).__init__()
self.dab1 = DualAttentionBlock(channels)
self.dab2 = DualAttentionBlock(channels)

def forward(self, x):
out = self.dab1(x)
out = self.dab2(out)
return out + x

class DenoisingNetwork(nn.Module):
def __init__(self, channels=64):
super(DenoisingNetwork, self).__init__()
self.conv_in = nn.Conv2d(3, channels, kernel_size=3, padding=1)
self.rrg1 = RecursiveResidualGroup(channels)
self.rrg2 = RecursiveResidualGroup(channels)
self.rrg3 = RecursiveResidualGroup(channels)
self.rrg4 = RecursiveResidualGroup(channels)
self.conv_out = nn.Conv2d(channels, 3, kernel_size=3, padding=1)

def forward(self, x):
out = self.conv_in(x)
out = self.rrg1(out)
out = self.rrg2(out)
out = self.rrg3(out)
out = self.rrg4(out)
out = self.conv_out(out)
return out
32 changes: 32 additions & 0 deletions src/feature_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch

class FeatureExtractor:
def __init__(self, k=32):
self.k = k

def extract_features(self, lnp):
# Ensure lnp is 3D: (channels, height, width)
if lnp.dim() == 4: # (batch, channels, height, width)
lnp = lnp.squeeze(0)
elif lnp.dim() == 2: # (height, width)
lnp = lnp.unsqueeze(0)

fft_lnp = torch.fft.fft2(lnp)
amplitude_spectrum = torch.abs(fft_lnp)
enhanced_spectrum = self._enhance_spectrum(amplitude_spectrum)
features = self._sample_features(enhanced_spectrum)
features = (features - features.mean()) / (features.std() + 1e-8)

return features

def _enhance_spectrum(self, spectrum):
A_u = torch.mean(spectrum, dim=1, keepdim=True)
enhanced = torch.where(spectrum < A_u, torch.zeros_like(spectrum), spectrum**2)
return enhanced

def _sample_features(self, enhanced_spectrum):
C, H, W = enhanced_spectrum.shape
m_indices = torch.arange(0, H, self.k)
n_indices = torch.arange(0, W, self.k)
features = enhanced_spectrum[:, m_indices][:, :, n_indices].flatten()
return features
10 changes: 10 additions & 0 deletions src/lnp_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch

class LNPExtractor:
def __init__(self, denoising_network):
self.denoising_network = denoising_network

def extract_lnp(self, image):
with torch.no_grad():
denoised = self.denoising_network(image)
return image - denoised

0 comments on commit 4e44aae

Please sign in to comment.