-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
258 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import torch | ||
from torchvision import transforms | ||
from PIL import Image | ||
import os | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from src.denoising_network import DenoisingNetwork | ||
from src.lnp_extractor import LNPExtractor | ||
from src.feature_extractor import FeatureExtractor | ||
from src.classifier import Classifier | ||
|
||
IMAGE_SIZE = (256, 256) | ||
PLOT_FEATURES = 5 | ||
REAL_IMAGES_PATH = 'data/real_images/' | ||
TEST_IMAGES_PATH = 'data/test_images/' | ||
|
||
def load_image(path): | ||
"""Load and preprocess a single image.""" | ||
try: | ||
transform = transforms.Compose([ | ||
transforms.Resize(IMAGE_SIZE), | ||
transforms.ToTensor(), | ||
]) | ||
with Image.open(path) as img: | ||
return transform(img.convert('RGB')) | ||
except Exception as e: | ||
print(f"Error processing {path}: {str(e)}") | ||
return None | ||
|
||
def load_images_from_directory(directory): | ||
"""Load all images from a directory.""" | ||
images = [] | ||
for filename in os.listdir(directory): | ||
if filename.lower().endswith(('.png', '.jpg', '.jpeg')): | ||
img_path = os.path.join(directory, filename) | ||
img = load_image(img_path) | ||
if img is not None: | ||
images.append(img) | ||
return images | ||
|
||
def plot_feature_distributions(real_features, test_features, filename): | ||
"""Plot feature distributions for real and test images.""" | ||
real_features = np.array(real_features) | ||
test_features = np.array(test_features) | ||
|
||
if test_features.ndim == 1: | ||
test_features = test_features.reshape(1, -1) | ||
|
||
plt.figure(figsize=(15, 5)) | ||
for i in range(min(PLOT_FEATURES, real_features.shape[1])): | ||
plt.subplot(1, PLOT_FEATURES, i+1) | ||
plt.hist(real_features[:, i], bins=20, alpha=0.5, label='Real') | ||
plt.axvline(test_features[0, i], color='r', linestyle='dashed', linewidth=2, label='Test') | ||
plt.title(f'Feature {i+1}') | ||
if i == 0: | ||
plt.legend() | ||
plt.tight_layout() | ||
plt.savefig(filename) | ||
plt.close() | ||
|
||
def process_images(images, lnp_extractor, feature_extractor, device, image_type="real"): | ||
"""Process a batch of images and extract features.""" | ||
features = [] | ||
for i, img in enumerate(images): | ||
img_tensor = img.unsqueeze(0).to(device) | ||
lnp = lnp_extractor.extract_lnp(img_tensor) | ||
feature = feature_extractor.extract_features(lnp.squeeze(0)) | ||
features.append(feature.cpu().numpy()) | ||
|
||
print(f"Processed {image_type} image {i+1}/{len(images)}") | ||
print(f" LNP shape: {lnp.shape}") | ||
print(f" Features shape: {feature.shape}") | ||
print(f" Features mean: {feature.mean().item():.4f}") | ||
print(f" Features std: {feature.std().item():.4f}") | ||
|
||
return features | ||
|
||
def compare_features(real_features, test_features): | ||
"""Compare test features with real features.""" | ||
real_features_array = np.array(real_features) | ||
real_mean = real_features_array.mean(axis=0) | ||
real_std = real_features_array.std(axis=0) | ||
test_features_np = test_features.cpu().numpy() | ||
|
||
print("\nFeature comparison:") | ||
print(f" Mean difference: {np.mean(np.abs(real_mean - test_features_np)):.4f}") | ||
print(f" Std difference: {np.mean(np.abs(real_std - np.std(test_features_np))):.4f}") | ||
|
||
def main(): | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
print(f"Using device: {device}") | ||
|
||
denoising_network = DenoisingNetwork().to(device) | ||
lnp_extractor = LNPExtractor(denoising_network) | ||
feature_extractor = FeatureExtractor() | ||
classifier = Classifier() | ||
|
||
real_images = load_images_from_directory(REAL_IMAGES_PATH) | ||
print(f"Number of real images loaded: {len(real_images)}") | ||
if not real_images: | ||
print("No valid images found for training.") | ||
return | ||
|
||
real_features = process_images(real_images, lnp_extractor, feature_extractor, device) | ||
|
||
classifier.train(real_features) | ||
print("Classifier trained") | ||
|
||
test_images = load_images_from_directory(TEST_IMAGES_PATH) | ||
|
||
for i, test_img in enumerate(test_images): | ||
test_img_tensor = test_img.unsqueeze(0).to(device) | ||
test_lnp = lnp_extractor.extract_lnp(test_img_tensor) | ||
test_features = feature_extractor.extract_features(test_lnp.squeeze(0)) | ||
|
||
print(f"\nTest image {i+1}:") | ||
print(f" LNP shape: {test_lnp.shape}") | ||
print(f" Features shape: {test_features.shape}") | ||
print(f" Features mean: {test_features.mean().item():.4f}") | ||
print(f" Features std: {test_features.std().item():.4f}") | ||
|
||
result = classifier.predict([test_features.cpu().numpy()]) | ||
print("Image is real" if result[0] == 1 else "Image is generated") | ||
|
||
plot_feature_distributions(real_features, test_features.cpu().numpy(), f'feature_distributions_test_{i+1}.png') | ||
compare_features(real_features, test_features) | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
numpy==2.0.0 | ||
torch==2.3.1 | ||
torchvision==0.18.1 | ||
scikit-learn==1.5.1 | ||
pillow==10.4.0 | ||
matplotlib==3.9.1 | ||
scipy==1.14.0 | ||
tqdm==4.66.4 | ||
numpy==1.21.0 | ||
torch==1.9.0 | ||
torchvision==0.10.0 | ||
scikit-learn==0.24.2 | ||
pillow==8.2.0 | ||
matplotlib==3.4.2 | ||
scipy==1.7.0 | ||
tqdm==4.61.1 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from sklearn.ensemble import IsolationForest | ||
from sklearn.svm import OneClassSVM | ||
import numpy as np | ||
|
||
class Classifier: | ||
def __init__(self, contamination=0.1, random_state=42): | ||
self.isolation_forest = IsolationForest(contamination=contamination, random_state=random_state) | ||
self.one_class_svm = OneClassSVM(kernel='rbf', nu=0.1) | ||
|
||
def train(self, features): | ||
features_array = np.array(features) | ||
print(f"Training classifiers with {len(features)} samples") | ||
print(f"Feature array shape: {features_array.shape}") | ||
print(f"Feature mean: {np.mean(features_array):.4f}") | ||
print(f"Feature std: {np.std(features_array):.4f}") | ||
self.isolation_forest.fit(features_array) | ||
self.one_class_svm.fit(features_array) | ||
|
||
def predict(self, features): | ||
features_array = np.array(features) | ||
if_prediction = self.isolation_forest.predict(features_array) | ||
if_decision = self.isolation_forest.decision_function(features_array) | ||
svm_prediction = self.one_class_svm.predict(features_array) | ||
svm_decision = self.one_class_svm.decision_function(features_array) | ||
|
||
print(f"Isolation Forest decision: {if_decision[0]:.4f}") | ||
print(f"One-Class SVM decision: {svm_decision[0]:.4f}") | ||
|
||
# Combine predictions (1 if both predict 1, -1 otherwise) | ||
combined_prediction = np.where((if_prediction == 1) & (svm_prediction == 1), 1, -1) | ||
|
||
return combined_prediction |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
class DualAttentionBlock(nn.Module): | ||
def __init__(self, channels): | ||
super(DualAttentionBlock, self).__init__() | ||
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) | ||
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) | ||
self.conv3 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) | ||
self.relu = nn.ReLU(inplace=True) | ||
|
||
def forward(self, x): | ||
residual = x | ||
out = self.relu(self.conv1(x)) | ||
out = self.relu(self.conv2(out)) | ||
out = self.conv3(out) | ||
return out + residual | ||
|
||
class RecursiveResidualGroup(nn.Module): | ||
def __init__(self, channels): | ||
super(RecursiveResidualGroup, self).__init__() | ||
self.dab1 = DualAttentionBlock(channels) | ||
self.dab2 = DualAttentionBlock(channels) | ||
|
||
def forward(self, x): | ||
out = self.dab1(x) | ||
out = self.dab2(out) | ||
return out + x | ||
|
||
class DenoisingNetwork(nn.Module): | ||
def __init__(self, channels=64): | ||
super(DenoisingNetwork, self).__init__() | ||
self.conv_in = nn.Conv2d(3, channels, kernel_size=3, padding=1) | ||
self.rrg1 = RecursiveResidualGroup(channels) | ||
self.rrg2 = RecursiveResidualGroup(channels) | ||
self.rrg3 = RecursiveResidualGroup(channels) | ||
self.rrg4 = RecursiveResidualGroup(channels) | ||
self.conv_out = nn.Conv2d(channels, 3, kernel_size=3, padding=1) | ||
|
||
def forward(self, x): | ||
out = self.conv_in(x) | ||
out = self.rrg1(out) | ||
out = self.rrg2(out) | ||
out = self.rrg3(out) | ||
out = self.rrg4(out) | ||
out = self.conv_out(out) | ||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import torch | ||
|
||
class FeatureExtractor: | ||
def __init__(self, k=32): | ||
self.k = k | ||
|
||
def extract_features(self, lnp): | ||
# Ensure lnp is 3D: (channels, height, width) | ||
if lnp.dim() == 4: # (batch, channels, height, width) | ||
lnp = lnp.squeeze(0) | ||
elif lnp.dim() == 2: # (height, width) | ||
lnp = lnp.unsqueeze(0) | ||
|
||
fft_lnp = torch.fft.fft2(lnp) | ||
amplitude_spectrum = torch.abs(fft_lnp) | ||
enhanced_spectrum = self._enhance_spectrum(amplitude_spectrum) | ||
features = self._sample_features(enhanced_spectrum) | ||
features = (features - features.mean()) / (features.std() + 1e-8) | ||
|
||
return features | ||
|
||
def _enhance_spectrum(self, spectrum): | ||
A_u = torch.mean(spectrum, dim=1, keepdim=True) | ||
enhanced = torch.where(spectrum < A_u, torch.zeros_like(spectrum), spectrum**2) | ||
return enhanced | ||
|
||
def _sample_features(self, enhanced_spectrum): | ||
C, H, W = enhanced_spectrum.shape | ||
m_indices = torch.arange(0, H, self.k) | ||
n_indices = torch.arange(0, W, self.k) | ||
features = enhanced_spectrum[:, m_indices][:, :, n_indices].flatten() | ||
return features |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import torch | ||
|
||
class LNPExtractor: | ||
def __init__(self, denoising_network): | ||
self.denoising_network = denoising_network | ||
|
||
def extract_lnp(self, image): | ||
with torch.no_grad(): | ||
denoised = self.denoising_network(image) | ||
return image - denoised |