diff --git a/requirements.txt b/requirements.txt index 444b281..8769b45 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,7 @@ -numpy==1.21.0 -torch==1.9.0 -torchvision==0.10.0 -scikit-learn==0.24.2 -pillow==8.2.0 -matplotlib==3.4.2 -scipy==1.7.0 -tqdm==4.61.1 \ No newline at end of file +numpy>=1.21.0,<1.24.0 +torch>=1.12.0,<1.13.0 +torchvision>=0.13.0,<0.14.0 +scikit-learn>=1.0.0,<1.6.0 +pillow>=8.0.0,<10.4.0 +matplotlib>=3.4.0,<3.5.0 +scipy>=1.6.0,<1.8.0 diff --git a/src/feature_extractor.py b/src/feature_extractor.py index 0147d23..a2d3b54 100644 --- a/src/feature_extractor.py +++ b/src/feature_extractor.py @@ -6,19 +6,15 @@ def __init__(self, k=32): self.k = k def extract_features(self, lnp): - # Existing feature extraction amplitude_spectrum = torch.abs(torch.fft.fft2(lnp)) + enhanced_spectrum = self._enhance_spectrum(amplitude_spectrum) - sampled_features = self._sample_features(enhanced_spectrum) - - # New features + sampled_features = self._sample_features(enhanced_spectrum) gradient_features = self._gradient_features(lnp) noise_features = self._noise_features(lnp) - # Combine all features all_features = torch.cat([sampled_features, gradient_features, noise_features]) - # Normalize return (all_features - all_features.mean()) / (all_features.std() + 1e-8) def _enhance_spectrum(self, spectrum):