Skip to content

Commit

Permalink
Move common imports to top-level; fix model saving bugs.
Browse files Browse the repository at this point in the history
  • Loading branch information
tymorrow committed Aug 15, 2024
1 parent f8f5088 commit e9ab4c1
Show file tree
Hide file tree
Showing 46 changed files with 1,442 additions and 1,413 deletions.
68 changes: 34 additions & 34 deletions examples/courses/Primer 1/Primer1.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,15 @@
"from riid.gadras.api import GADRAS_API_SEEMINGLY_AVAILABLE\n",
"\n",
"if GADRAS_API_SEEMINGLY_AVAILABLE:\n",
" from riid.data.synthetic.seed import SeedSynthesizer\n",
" from riid import SeedSynthesizer\n",
" seed_syn = SeedSynthesizer()\n",
" # The YAML file defining the seed synthesis specification is ultimately parsed into a dictionary.\n",
" # You can also load it yourself and pass in the dictionary instead - this is useful for varying detector parameters!\n",
" seeds_ss = seed_syn.generate(\"./spec_nai_few_sources.yaml\")\n",
"else:\n",
" # If you don't have Windows with GADRAS installed, this will use the dummy seeds below which are not actual gamma spectra.\n",
" # Another option would be to load a seeds file obtained elsewhere.\n",
" from riid.data.synthetic import get_dummy_seeds\n",
" from riid import get_dummy_seeds\n",
" seeds_ss = get_dummy_seeds()"
]
},
Expand Down Expand Up @@ -251,7 +251,7 @@
"outputs": [],
"source": [
"\"\"\"Seed mixing\"\"\"\n",
"from riid.data.synthetic.seed import SeedMixer\n",
"from riid import SeedMixer\n",
"\n",
"mixed_bg_seeds_ss = SeedMixer(\n",
" bg_seeds_ss,\n",
Expand All @@ -278,7 +278,7 @@
"outputs": [],
"source": [
"\"\"\"Combining SampleSets\"\"\"\n",
"from riid.data.sampleset import SampleSet\n",
"from riid import SampleSet\n",
"\n",
"combined_ss = SampleSet()\n",
"combined_ss.concat([fg_seeds_ss, mixed_bg_seeds_ss])\n",
Expand Down Expand Up @@ -318,14 +318,14 @@
"outputs": [],
"source": [
"\"\"\"Static Synthesis\"\"\"\n",
"from riid.data.synthetic.static import StaticSynthesizer\n",
"from riid import StaticSynthesizer\n",
"\n",
"static_syn = StaticSynthesizer(\n",
" samples_per_seed=100,\n",
" bg_cps=300,\n",
" live_time_function=\"uniform\",\n",
" live_time_function_args=(0.25, 8),\n",
" snr_function=\"uniform\",\n",
" snr_function=\"log10\",\n",
" snr_function_args=(0.1, 100),\n",
" apply_poisson_noise=True,\n",
" return_fg=True,\n",
Expand All @@ -348,8 +348,7 @@
"outputs": [],
"source": [
"\"\"\"Normalization\"\"\"\n",
"gross_ss.normalize()\n",
"bg_ss.normalize()"
"fg_ss.normalize()"
]
},
{
Expand All @@ -371,16 +370,10 @@
"outputs": [],
"source": [
"\"\"\"Model fitting\"\"\"\n",
"from riid.models.neural_nets import MLPClassifier\n",
"from riid.metrics import single_f1\n",
"from riid.models import MLPClassifier\n",
"\n",
"model = MLPClassifier(\n",
" hidden_layers=(256,),\n",
" learning_rate=4e-3,\n",
" metrics=[single_f1]\n",
")\n",
"\n",
"history = model.fit(gross_ss, bg_ss, epochs=25, patience=5, verbose=True)"
"model = MLPClassifier()\n",
"history = model.fit(fg_ss, epochs=10, verbose=True)"
]
},
{
Expand All @@ -402,12 +395,9 @@
"outputs": [],
"source": [
"\"\"\"Generate some in-distribution data the model has not seen.\"\"\"\n",
"test_bg_ss, test_gross_ss = static_syn.generate(fg_seeds_ss, bg_seeds_ss)\n",
"test_bg_ss.normalize()\n",
"test_gross_ss.normalize()\n",
"# Adjust ground truth\n",
"#test_gross_ss.sources.drop(test_bg_ss.sources.columns, axis=1, inplace=True)\n",
"#test_gross_ss.normalize_sources()"
"test_fg_ss, test_gross_ss = static_syn.generate(fg_seeds_ss, bg_seeds_ss)\n",
"test_fg_ss.normalize()\n",
"test_gross_ss.normalize()"
]
},
{
Expand All @@ -417,7 +407,7 @@
"outputs": [],
"source": [
"\"\"\"Use the model!\"\"\"\n",
"model.predict(test_gross_ss, test_bg_ss) # Saved in your SampleSet containing non-background sources (the gross spectra)"
"model.predict(test_fg_ss) # Results are saved in the SampleSet's prediction_probas DataFrame"
]
},
{
Expand All @@ -429,8 +419,8 @@
"\"\"\"Calculate performance metric\"\"\"\n",
"from sklearn.metrics import f1_score\n",
"\n",
"labels = test_gross_ss.get_labels()\n",
"predictions = test_gross_ss.get_predictions()\n",
"labels = test_fg_ss.get_labels()\n",
"predictions = test_fg_ss.get_predictions()\n",
"f1_score(labels, predictions, average=\"micro\")"
]
},
Expand All @@ -443,7 +433,7 @@
"\"\"\"Confusion Matrix\"\"\"\n",
"from riid.visualize import confusion_matrix\n",
"\n",
"_ = confusion_matrix(test_gross_ss)"
"_ = confusion_matrix(test_fg_ss)"
]
},
{
Expand All @@ -455,7 +445,7 @@
"\"\"\"SNR vs. Model Score\"\"\"\n",
"from riid.visualize import plot_snr_vs_score\n",
"\n",
"_ = plot_snr_vs_score(test_gross_ss, xscale=\"log\")"
"_ = plot_snr_vs_score(test_fg_ss, xscale=\"log\")"
]
},
{
Expand All @@ -465,13 +455,23 @@
"outputs": [],
"source": [
"\"\"\"Save model\"\"\"\n",
"import os\n",
"from pathlib import Path\n",
"\n",
"\n",
"def _delete_if_exists(path: Path):\n",
" if path.exists():\n",
" path.unlink()\n",
"\n",
"model_path = \"./model.h5\"\n",
"if os.path.exists(model_path):\n",
" os.remove(model_path)\n",
"model_path_json = Path(\"./model.json\")\n",
"model_path_tflite = model_path_json.with_suffix(\".tflite\")\n",
"model_path_onnx = model_path_json.with_suffix(\".onnx\")\n",
"_delete_if_exists(model_path_json)\n",
"_delete_if_exists(model_path_tflite)\n",
"_delete_if_exists(model_path_onnx)\n",
"\n",
"model.save(model_path)"
"model.save(str(model_path_json))\n",
"model.to_tflite(str(model_path_tflite))\n",
"model.to_onnx(str(model_path_onnx))"
]
},
{
Expand Down Expand Up @@ -571,7 +571,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.12.4"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
3 changes: 1 addition & 2 deletions examples/data/conversion/pcf_to_ss.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
import os
from pathlib import Path

from riid import SAMPLESET_HDF_FILE_EXTENSION
from riid import SAMPLESET_HDF_FILE_EXTENSION, read_pcf
from riid.data.converters import (_validate_and_create_output_dir,
convert_directory)
from riid.data.sampleset import read_pcf


def convert_and_save(input_file_path: str, output_dir: str = None,
Expand Down
4 changes: 1 addition & 3 deletions examples/data/difficulty_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
# Under the terms of Contract DE-NA0003525 with NTESS,
# the U.S. Government retains certain rights in this software.
"""This example demonstrates how to compute the difficulty of a given SampleSet."""
from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds

fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3)\
Expand Down
4 changes: 1 addition & 3 deletions examples/data/preprocessing/energy_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
import matplotlib.pyplot as plt
import numpy as np

from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds

SYNTHETIC_DATA_CONFIG = {
"samples_per_seed": 10,
Expand Down
4 changes: 2 additions & 2 deletions examples/data/synthesis/mix_seeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
# the U.S. Government retains certain rights in this software.
"""This example demonstrates how to generate synthetic gamma spectra from seeds."""
import numpy as np
from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer

from riid import SeedMixer, get_dummy_seeds

fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()

Expand Down
3 changes: 1 addition & 2 deletions examples/data/synthesis/synthesize_passbys.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import matplotlib.pyplot as plt
import numpy as np

from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.passby import PassbySynthesizer
from riid import PassbySynthesizer, get_dummy_seeds

if len(sys.argv) == 2:
import matplotlib
Expand Down
2 changes: 1 addition & 1 deletion examples/data/synthesis/synthesize_seeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""This example demonstrates how to generate synthetic seeds from GADRAS."""
import yaml

from riid.data.synthetic.seed import SeedSynthesizer
from riid import SeedSynthesizer

seed_synth_config = """
---
Expand Down
2 changes: 1 addition & 1 deletion examples/data/synthesis/synthesize_seeds_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
configuration expansion features."""
import yaml

from riid.data.synthetic.seed import SeedSynthesizer
from riid import SeedSynthesizer

seed_synth_config = """
---
Expand Down
4 changes: 1 addition & 3 deletions examples/data/synthesis/synthesize_spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
# Under the terms of Contract DE-NA0003525 with NTESS,
# the U.S. Government retains certain rights in this software.
"""This example demonstrates how to generate synthetic gamma spectra from seeds."""
from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds

SYNTHETIC_DATA_CONFIG = {
"samples_per_seed": 10000,
Expand Down
4 changes: 1 addition & 3 deletions examples/modeling/anomaly_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@
import numpy as np
from matplotlib import cm

from riid import PassbySynthesizer, SeedMixer, get_dummy_seeds
from riid.anomaly import PoissonNChannelEventDetector
from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.passby import PassbySynthesizer
from riid.data.synthetic.seed import SeedMixer

if len(sys.argv) == 2:
import matplotlib
Expand Down
6 changes: 2 additions & 4 deletions examples/modeling/arad.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
import numpy as np
import pandas as pd

from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid.models.neural_nets.arad import ARADv1, ARADv2
from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
from riid.models import ARADv1, ARADv2

# Config
rng = np.random.default_rng(42)
Expand Down
6 changes: 2 additions & 4 deletions examples/modeling/arad_latent_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
from keras.api.metrics import Accuracy, CategoricalCrossentropy
from sklearn.metrics import f1_score, mean_squared_error

from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid.models.neural_nets.arad import ARADLatentPredictor, ARADv2
from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
from riid.models import ARADLatentPredictor, ARADv2

# Config
rng = np.random.default_rng(42)
Expand Down
7 changes: 2 additions & 5 deletions examples/modeling/classifier_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score

from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid.models.bayes import PoissonBayesClassifier
from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
from riid.metrics import precision_recall_curve
from riid.models.neural_nets import MLPClassifier
from riid.models import MLPClassifier, PoissonBayesClassifier
from riid.visualize import plot_precision_recall

if len(sys.argv) == 2:
Expand Down
6 changes: 2 additions & 4 deletions examples/modeling/label_proportion_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@

from sklearn.metrics import mean_absolute_error

from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid.models.neural_nets import LabelProportionEstimator
from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
from riid.models import LabelProportionEstimator

# Generate some mixture training data.
fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
Expand Down
6 changes: 2 additions & 4 deletions examples/modeling/neural_network_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
import numpy as np
from sklearn.metrics import f1_score

from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid.models.neural_nets import MLPClassifier
from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
from riid.models import MLPClassifier

# Generate some training data
fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
Expand Down
1 change: 1 addition & 0 deletions examples/run_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import subprocess
import sys
from pathlib import Path

import pandas as pd
from tabulate import tabulate

Expand Down
6 changes: 2 additions & 4 deletions examples/visualization/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
"""This example demonstrates how to obtain confusion matrices."""
import sys

from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid.models.neural_nets import MLPClassifier
from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
from riid.models import MLPClassifier
from riid.visualize import confusion_matrix

if len(sys.argv) == 2:
Expand Down
2 changes: 1 addition & 1 deletion examples/visualization/distance_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import matplotlib.pyplot as plt
import seaborn as sns

from riid.data.synthetic import get_dummy_seeds
from riid import get_dummy_seeds

if len(sys.argv) == 2:
import matplotlib
Expand Down
4 changes: 1 addition & 3 deletions examples/visualization/plot_sampleset_compare_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
"""This example demonstrates how to compare sample sets."""
import sys

from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
from riid.visualize import plot_ss_comparison

if len(sys.argv) == 2:
Expand Down
2 changes: 1 addition & 1 deletion examples/visualization/plot_spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""This example demonstrates how to plot gamma spectra."""
import sys

from riid.data.synthetic import get_dummy_seeds
from riid import get_dummy_seeds
from riid.visualize import plot_spectra

if len(sys.argv) == 2:
Expand Down
Loading

0 comments on commit e9ab4c1

Please sign in to comment.