diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..0c40dee --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,40 @@ +name: pip - Build Lint Test and Coverage + +on: + push: + branches: [ main, "v*"] + pull_request: + branches: [ main, "v*"] + +jobs: + build-lint-test-coverage: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11"] + defaults: + run: + shell: bash -l {0} + + steps: + - name: Checkout git repo + uses: actions/checkout@v3 + - name: Get git tags + run: git fetch --prune --unshallow --tags + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: "pip" + # # Use pip to install dependencies + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[dev] + # Run tests + - name: Run tests + run: | + pytest tests/ --cov=src --cov-report xml + - name: Coveralls + uses: coverallsapp/github-action@v2 + with: + path-to-lcov: coverage.xml diff --git a/.gitignore b/.gitignore index 136aed2..5896a78 100644 --- a/.gitignore +++ b/.gitignore @@ -101,3 +101,6 @@ dmypy.json # Convenience file for docker development .docker_bash_history.txt + +# As a library, we don't include poetry lock files +poetry.lock diff --git a/pyproject.toml b/pyproject.toml index 450beea..d0437ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,31 +1,50 @@ -[build-system] -requires = ["setuptools"] -build-backend = "setuptools.build_meta" - [project] name = "adam-assist" -dynamic = ["version"] +version = "0.1.0" description = 'ADAM Core Propagator class using ASSIST' +authors = [ + { name = "Alec Koumjian", email = "akoumjian@gmail.com" }, + { name = "Kathleen Kiker" } +] readme = "README.md" -requires-python = ">=3.11" keywords = [] -authors = [{ name = "Alec Koumjian", email = "akoumjian@gmail.com" }] -classifiers = [ - "Development Status :: 4 - Beta", - "Programming Language :: Python", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: Implementation :: CPython", -] +requires-python = ">=3.11,<4.0" dependencies = [ - "assist @ git+https://github.com/B612-Asteroid-Institute/assist.git@ak/wip", - "adam-core @ git+https://github.com/B612-Asteroid-Institute/adam_core@main", - "naif-de440", - "numpy", - "rebound", - "ray", + "adam-core @ git+https://github.com/B612-Asteroid-Institute/adam_core.git@ak/impacts", + "assist", + "naif-de440", + "numpy", + "ray", + "spiceypy>=6.0.0" ] +[build-system] +requires = ["pdm-backend"] +build-backend = "pdm.backend" +[tool.pdm.build] +includes = ["src/adam_core/"] +# package-dir = "src" + [project.urls] -Documentation = "https://github.com/unknown/adam-assist#readme" -Issues = "https://github.com/unknown/adam-assist/issues" -Source = "https://github.com/unknown/adam-assist" +"Documentation" = "https://github.com/unknown/adam-assist#readme" +"Issues" = "https://github.com/unknown/adam-assist/issues" +"Source" = "https://github.com/unknown/adam-assist" + + +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-cov", + "pytest-benchmark", + "black", + "isort", + "ipython" +] + +[tool.black] +line-length = 88 + +[tool.isort] +profile = "black" + + diff --git a/src/adam_assist/__about__.py b/src/adam_assist/__about__.py deleted file mode 100644 index f102a9c..0000000 --- a/src/adam_assist/__about__.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "0.0.1" diff --git a/src/adam_assist/__init__.py b/src/adam_assist/__init__.py deleted file mode 100644 index 906c0d0..0000000 --- 
a/src/adam_assist/__init__.py +++ /dev/null @@ -1,549 +0,0 @@ -import hashlib -import os -import pathlib -from ctypes import c_uint32 -from importlib.resources import files -from typing import Dict, List, Optional, Tuple -import ray -import concurrent.futures - -import assist -import numpy as np -import pyarrow as pa -import pyarrow.compute as pc -import quivr as qv -import rebound -import urllib3 -from adam_core.orbits.variants import VariantEphemeris, VariantOrbits -from adam_core.orbits import Orbits -from adam_core.coordinates import (CartesianCoordinates, Origin, - transform_coordinates) -from adam_core.coordinates.origin import OriginCodes -from adam_core.propagator.propagator import (EphemerisType, ObserverType, - OrbitType, Propagator, - TimestampType, propagation_worker_ray) -from adam_core.time import Timestamp -from quivr.concat import concatenate -from typing import Literal -import concurrent.futures -import logging -from abc import ABC, abstractmethod -from typing import List, Literal, Optional, Type, Union - -import numpy as np -import numpy.typing as npt -import quivr as qv - -from adam_core.ray_cluster import initialize_use_ray - -DATA_DIR = os.getenv("ASSIST_DATA_DIR", "~/.adam_assist_data") - - -class EarthImpacts(qv.Table): - orbit_id = qv.StringColumn() - # Distance from earth center in km - distance = qv.Float64Column() - coordinates = CartesianCoordinates.as_column() - - -def download_jpl_ephemeris_files(data_dir: str = DATA_DIR): - ephemeris_urls = ( - "https://ssd.jpl.nasa.gov/ftp/eph/small_bodies/asteroids_de441/sb441-n16.bsp", - "https://ssd.jpl.nasa.gov/ftp/eph/planets/Linux/de440/linux_p1550p2650.440", - ) - data_dir = pathlib.Path(data_dir).expanduser() - data_dir.mkdir(parents=True, exist_ok=True) - for url in ephemeris_urls: - file_name = pathlib.Path(url).name - file_path = data_dir.joinpath(file_name) - if not file_path.exists(): - # use urllib3 - http = urllib3.PoolManager() - with http.request("GET", url, preload_content=False) as r, open( - file_path, "wb" - ) as out_file: - if r.status != 200: - raise RuntimeError(f"Failed to download {url}") - while True: - data = r.read(1024) - if not data: - break - out_file.write(data) - r.release_conn() - - -def initialize_assist(data_dir: str = DATA_DIR) -> assist.Extras: - root_dir = pathlib.Path(data_dir).expanduser() - ephem = assist.Ephem( - root_dir.joinpath("linux_p1550p2650.440"), - root_dir.joinpath("sb441-n16.bsp"), - ) - sim = rebound.Simulation() - ax = assist.Extras(sim, ephem) - return sim, ephem - - -def uint32_hash(s) -> c_uint32: - sha256_result = hashlib.sha256(s.encode()).digest() - # Get the first 4 bytes of the SHA256 hash to obtain a uint32 value. - return c_uint32(int.from_bytes(sha256_result[:4], byteorder="big")) - - -def hash_orbit_ids_to_uint32( - orbit_ids: np.ndarray[str], -) -> Tuple[Dict[int, str], np.ndarray[np.uint32]]: - """ - Derive uint32 hashes from orbit id strigns - - Rebound uses uint32 to track individual particles, but we use orbit id strings. - Here we attempt to generate uint32 hashes for each and return the mapping as well. 
- """ - hashes = [uint32_hash(o) for o in orbit_ids] - # Because uint32 is an unhashable type, - # we use a dict mapping from uint32 to orbit id string - mapping = {hashes[i].value: orbit_ids[i] for i in range(len(orbit_ids))} - - return mapping, hashes - -@ray.remote -def assist_propagation_worker_ray( - orbits: OrbitType, - times: OrbitType, - propagator: Type["Propagator"], - adaptive_mode: Optional[int] = None, - min_dt: Optional[float] = None, - **kwargs, -) -> OrbitType: - prop = propagator(**kwargs) - propagated = prop._propagate_orbits(orbits, times, adaptive_mode, min_dt) - return propagated - - -class ASSISTPropagator(Propagator): - def _propagate_orbits( - self, orbits: OrbitType, times: TimestampType, adaptive_mode: Optional[int] = None, min_dt: Optional[float] = None - ) -> OrbitType: - orbits, impacts = self._propagate_orbits_inner(orbits, times, False, adaptive_mode, min_dt) - return orbits - - def propagate_orbits( - self, - orbits: OrbitType, - times: TimestampType, - covariance: bool = False, - adaptive_mode: Optional[int] = None, - min_dt: Optional[float] = None, - covariance_method: Literal[ - "auto", "sigma-point", "monte-carlo" - ] = "monte-carlo", - num_samples: int = 1000, - chunk_size: int = 100, - max_processes: Optional[int] = 1, - parallel_backend: Literal["cf", "ray"] = "ray", - ) -> Orbits: - """ - Propagate each orbit in orbits to each time in times. - - Parameters - ---------- - orbits : `~adam_core.orbits.orbits.Orbits` (N) - Orbits to propagate. - times : Timestamp (M) - Times to which to propagate orbits. - covariance : bool, optional - Propagate the covariance matrices of the orbits. This is done by sampling the - orbits from their covariance matrices and propagating each sample. The covariance - of the propagated orbits is then the covariance of the samples. - covariance_method : {'sigma-point', 'monte-carlo', 'auto'}, optional - The method to use for sampling the covariance matrix. If 'auto' is selected then the method - will be automatically selected based on the covariance matrix. The default is 'monte-carlo'. - num_samples : int, optional - The number of samples to draw when sampling with monte-carlo. - chunk_size : int, optional - Number of orbits to send to each job. - max_processes : int or None, optional - Maximum number of processes to launch. If None then the number of - processes will be equal to the number of cores on the machine. If 1 - then no multiprocessing will be used. If "ray" is the parallel_backend and a ray instance - is initialized already then this argument is ignored. - parallel_backend : {'cf', 'ray'}, optional - The parallel backend to use. 'cf' uses concurrent.futures and 'ray' uses ray. The default is 'cf'. - To use ray, ray must be installed. - - Returns - ------- - propagated : `~adam_core.orbits.orbits.Orbits` - Propagated orbits. 
- """ - - if max_processes is None or max_processes > 1: - propagated_list: List[Orbits] = [] - variants_list: List[VariantOrbits] = [] - - if parallel_backend == "cf": - with concurrent.futures.ProcessPoolExecutor( - max_workers=max_processes - ) as executor: - # Add orbits to propagate to futures - futures = [] - for orbit_chunk in _iterate_chunks(orbits, chunk_size): - futures.append( - executor.submit( - propagation_worker, - orbit_chunk, - times, - self.__class__, - **self.__dict__, - ) - ) - - # Add variants to propagate to futures - if ( - covariance is True - and not orbits.coordinates.covariance.is_all_nan() - ): - variants = VariantOrbits.create( - orbits, method=covariance_method, num_samples=num_samples - ) - for variant_chunk in _iterate_chunks(variants, chunk_size): - futures.append( - executor.submit( - propagation_worker, - variant_chunk, - times, - self.__class__, - **self.__dict__, - ) - ) - - for future in concurrent.futures.as_completed(futures): - result = future.result() - if isinstance(result, Orbits): - propagated_list.append(result) - elif isinstance(result, VariantOrbits): - variants_list.append(result) - else: - raise ValueError( - f"Unexpected result type from propagation worker: {type(result)}" - ) - - elif parallel_backend == "ray": - if RAY_INSTALLED is False: - raise ImportError( - "Ray must be installed to use the ray parallel backend" - ) - - initialize_use_ray(num_cpus=max_processes) - - # Add orbits and times to object store if - # they haven't already been added - if not isinstance(times, ObjectRef): - times_ref = ray.put(times) - else: - times_ref = times - - if not isinstance(orbits, ObjectRef): - orbits_ref = ray.put(orbits) - else: - orbits_ref = orbits - # We need to dereference the orbits ObjectRef so we can - # check its length for chunking and determine - # if we need to propagate variants - orbits = ray.get(orbits_ref) - - # Create futures - futures = [] - idx = np.arange(0, len(orbits)) - - for orbit in orbits: - futures.append( - assist_propagation_worker_ray.remote( - orbit, - times_ref, - adaptive_mode, - min_dt, - self.__class__, - **self.__dict__, - ) - ) - if ( - covariance is True - and not orbit.coordinates.covariance.is_all_nan() - ): - variants = VariantOrbits.create( - orbit, - method=covariance_method, - num_samples=num_samples, - ) - futures.append( - assist_propagation_worker_ray.remote( - variants, - times_ref, - adaptive_mode, - min_dt, - self.__class__, - **self.__dict__, - ) - ) - - # Get results as they finish (we sort later) - unfinished = futures - while unfinished: - finished, unfinished = ray.wait(unfinished, num_returns=1) - result = ray.get(finished[0]) - if isinstance(result, Orbits): - propagated_list.append(result) - elif isinstance(result, VariantOrbits): - variants_list.append(result) - else: - raise ValueError( - f"Unexpected result type from propagation worker: {type(result)}" - ) - - else: - raise ValueError(f"Unknown parallel backend: {parallel_backend}") - - # Concatenate propagated orbits - propagated = qv.concatenate(propagated_list) - if len(variants_list) > 0: - propagated_variants = qv.concatenate(variants_list) - else: - propagated_variants = None - - else: - propagated = self._propagate_orbits(orbits, times, adaptive_mode, min_dt) - - if covariance is True and not orbits.coordinates.covariance.is_all_nan(): - variants = VariantOrbits.create( - orbits, method=covariance_method, num_samples=num_samples - ) - propagated_variants = self._propagate_orbits(variants, times, adaptive_mode, min_dt) - else: - 
propagated_variants = None - - if propagated_variants is not None: - propagated = propagated_variants.collapse(propagated) - - return propagated.sort_by( - ["orbit_id", "coordinates.time.days", "coordinates.time.nanos"] - ) - - def _propagate_orbits_inner(self, orbits: OrbitType, times: TimestampType, detect_impacts: bool, adaptive_mode: Optional[int] = None, min_dt: Optional[float] = None) -> Tuple[OrbitType, EarthImpacts]: - # Assert that the time for each orbit definition is the same for the simulator to work - assert len(pc.unique(orbits.coordinates.time.mjd())) == 1 - - # sim, ephem = initialize_assist() - - # The coordinate frame is the equatorial International Celestial Reference Frame (ICRF). - # This is also the native coordinate system for the JPL binary files. - # For units we use solar masses, astronomical units, and days. - # The time coordinate is Barycentric Dynamical Time (TDB) in Julian days. - - # Convert coordinates to ICRF using TDB time - coords = transform_coordinates( - orbits.coordinates, - origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER, - frame_out="equatorial", - ) - input_orbit_times = coords.time.rescale("tdb") - coords = coords.set_column("time", input_orbit_times) - orbits = orbits.set_column("coordinates", coords) - - root_dir = pathlib.Path(DATA_DIR).expanduser() - ephem = assist.Ephem( - root_dir.joinpath("linux_p1550p2650.440"), - root_dir.joinpath("sb441-n16.bsp"), - ) - sim = rebound.Simulation() - - if min_dt is not None: - sim.ri_ias15.min_dt = min_dt - - if adaptive_mode is not None: - sim.ri_ias15.adaptive_mode = adaptive_mode - - # Set the simulation time, relative to the jd_ref - start_tdb_time = orbits.coordinates.time.jd().to_numpy()[0] - ephem.jd_ref - sim.t = start_tdb_time - - output_type = type(orbits) - - orbit_id_mapping, uint_orbit_ids = hash_orbit_ids_to_uint32( - orbits.orbit_id.to_numpy(zero_copy_only=False) - ) - - if isinstance(orbits, VariantOrbits): - variantattributes = {} - for idx, orbit_id in enumerate(orbits.orbit_id.to_numpy(zero_copy_only=False)): - variantattributes[orbit_id] = { - 'weight': orbits.weights[idx], - 'weight_cov': orbits.weights_cov[idx], - 'object_id': orbits.object_id[idx] - } - - # Add the orbits as particles to the simulation - coords_df = orbits.coordinates.to_dataframe() - - for i in range(len(coords_df)): - sim.add( - x=coords_df.x[i], - y=coords_df.y[i], - z=coords_df.z[i], - vx=coords_df.vx[i], - vy=coords_df.vy[i], - vz=coords_df.vz[i], - hash=uint_orbit_ids[i], - ) - - ax = assist.Extras(sim, ephem) - - # Prepare the times as jd - jd_ref - integrator_times = times.rescale("tdb").jd() - integrator_times = pc.subtract( - integrator_times, ephem.jd_ref - ) - integrator_times = integrator_times.to_numpy() - - results = None - - # Step through each time, move the simulation forward and - # collect the results. 
- for i in range(len(integrator_times)): - sim.integrate(integrator_times[i]) - - # Get serialized particle data as numpy arrays - orbit_id_hashes = np.zeros(sim.N, dtype="uint32") - step_xyzvxvyvz = np.zeros((sim.N, 6), dtype="float64") - - sim.serialize_particle_data(xyzvxvyvz=step_xyzvxvyvz, hash=orbit_id_hashes) - - if isinstance(orbits, Orbits): - # Retrieve original orbit id from hash - orbit_ids = [orbit_id_mapping[h] for h in orbit_id_hashes] - time_step_results = Orbits.from_kwargs( - coordinates=CartesianCoordinates.from_kwargs( - x=step_xyzvxvyvz[:, 0], - y=step_xyzvxvyvz[:, 1], - z=step_xyzvxvyvz[:, 2], - vx=step_xyzvxvyvz[:, 3], - vy=step_xyzvxvyvz[:, 4], - vz=step_xyzvxvyvz[:, 5], - time=Timestamp.from_jd(pa.repeat(sim.t + ephem.jd_ref, sim.N), scale="tdb"), - origin=Origin.from_kwargs( - code=pa.repeat( - "SOLAR_SYSTEM_BARYCENTER", - sim.N, - ) - ), - frame="equatorial", - ), - orbit_id=orbit_ids, - ) - elif isinstance(orbits, VariantOrbits): - # Retrieve the orbit id and weights from hash - orbit_ids = [orbit_id_mapping[h] for h in orbit_id_hashes] - object_ids = [variantattributes[orbit_id]["object_id"] for orbit_id in orbit_ids] - weight = [variantattributes[orbit_id]["weight"] for orbit_id in orbit_ids] - weights_covs = [variantattributes[orbit_id]["weight_cov"] for orbit_id in orbit_ids] - time_step_results = VariantOrbits.from_kwargs( - orbit_id=orbit_ids, - object_id=object_ids, - weights=weight, - weights_cov=weights_covs, - coordinates=CartesianCoordinates.from_kwargs( - x=step_xyzvxvyvz[:, 0], - y=step_xyzvxvyvz[:, 1], - z=step_xyzvxvyvz[:, 2], - vx=step_xyzvxvyvz[:, 3], - vy=step_xyzvxvyvz[:, 4], - vz=step_xyzvxvyvz[:, 5], - time=Timestamp.from_jd(pa.repeat(sim.t + ephem.jd_ref, sim.N), scale="tdb"), - origin=Origin.from_kwargs( - code=pa.repeat( - "SOLAR_SYSTEM_BARYCENTER", - sim.N, - ) - ), - frame="equatorial", - ), - ) - - if results is None: - results = time_step_results - else: - results = concatenate([results, time_step_results]) - - results = results.set_column( - "coordinates", - transform_coordinates( - results.coordinates, - origin_out=OriginCodes.SUN, - frame_out="ecliptic", - ), - ) - - if detect_impacts: - impacts = sim._extras_ref.get_impacts() - earth_impacts = None - for i, impact in enumerate(impacts): - orbit_id = orbit_id_mapping.get(impact["hash"], f"unknown-{i}") - time = Timestamp.from_jd([impact["time"]], scale="tdb") - if earth_impacts is None: - coordinates=CartesianCoordinates.from_kwargs( - x=[impact["x"]], - y=[impact["y"]], - z=[impact["z"]], - vx=[impact["vx"]], - vy=[impact["vy"]], - vz=[impact["vz"]], - time=time, - origin=Origin.from_kwargs( - code=["SOLAR_SYSTEM_BARYCENTER"], - ), - frame="equatorial", - ) - coordinates = transform_coordinates( - coordinates, - origin_out=OriginCodes.SUN, - frame_out="ecliptic", - ) - earth_impacts = EarthImpacts.from_kwargs( - orbit_id=[orbit_id], - distance=[impact["distance"]], - coordinates=coordinates, - ) - else: - coordinates=CartesianCoordinates.from_kwargs( - x=[impact["x"]], - y=[impact["y"]], - z=[impact["z"]], - vx=[impact["vx"]], - vy=[impact["vy"]], - vz=[impact["vz"]], - time=time, - origin=Origin.from_kwargs( - code=["SOLAR_SYSTEM_BARYCENTER"], - ), - frame="equatorial", - ) - coordinates = transform_coordinates( - coordinates, - origin_out=OriginCodes.SUN, - frame_out="ecliptic", - ) - - earth_impacts = qv.concatenate(earth_impacts, EarthImpacts.from_kwargs( - orbit_id=[orbit_id], - distance=[impact["distance"]], - coordinates=coordinates, - )) - - return results, earth_impacts - 
else:
-            return results, None
-
-
-    def _generate_ephemeris(
-        self, orbits: OrbitType, observers: ObserverType
-    ) -> EphemerisType:
-        raise NotImplementedError("Ephemeris generation is not implemented for ASSIST.")
diff --git a/src/adam_core/propagator/adam_assist.py b/src/adam_core/propagator/adam_assist.py
new file mode 100644
index 0000000..6a36f4a
--- /dev/null
+++ b/src/adam_core/propagator/adam_assist.py
@@ -0,0 +1,506 @@
+import gc
+import hashlib
+import os
+import pathlib
+from ctypes import c_uint32
+from typing import Dict, Tuple
+
+import assist
+import numpy as np
+import numpy.typing as npt
+import pyarrow as pa
+import pyarrow.compute as pc
+import quivr as qv
+import rebound
+import urllib3
+from adam_core.coordinates import CartesianCoordinates, Origin, transform_coordinates
+from adam_core.coordinates.origin import OriginCodes
+from adam_core.dynamics.impacts import EarthImpacts, ImpactMixin
+from adam_core.orbits import Orbits
+from adam_core.orbits.variants import VariantOrbits
+from adam_core.time import Timestamp
+from adam_core.utils import get_perturber_state
+from quivr.concat import concatenate
+
+from adam_core.propagator.propagator import (
+    EphemerisType,
+    ObserverType,
+    OrbitType,
+    Propagator,
+    TimestampType,
+)
+
+DATA_DIR = os.getenv("ASSIST_DATA_DIR", "~/.adam_assist_data")
+
+EARTH_RADIUS_KM = 6371.0
+
+
+def download_jpl_ephemeris_files(data_dir: str = DATA_DIR):
+    ephemeris_urls = (
+        "https://ssd.jpl.nasa.gov/ftp/eph/small_bodies/asteroids_de441/sb441-n16.bsp",
+        "https://ssd.jpl.nasa.gov/ftp/eph/planets/Linux/de440/linux_p1550p2650.440",
+    )
+    data_dir = pathlib.Path(data_dir).expanduser()
+    data_dir.mkdir(parents=True, exist_ok=True)
+    for url in ephemeris_urls:
+        file_name = pathlib.Path(url).name
+        file_path = data_dir.joinpath(file_name)
+        if not file_path.exists():
+            # use urllib3
+            http = urllib3.PoolManager()
+            with http.request("GET", url, preload_content=False) as r, open(
+                file_path, "wb"
+            ) as out_file:
+                if r.status != 200:
+                    raise RuntimeError(f"Failed to download {url}")
+                while True:
+                    data = r.read(1024)
+                    if not data:
+                        break
+                    out_file.write(data)
+            r.release_conn()
+
+
+def uint32_hash(s) -> c_uint32:
+    sha256_result = hashlib.sha256(s.encode()).digest()
+    # Get the first 4 bytes of the SHA256 hash to obtain a uint32 value.
+    return c_uint32(int.from_bytes(sha256_result[:4], byteorder="big"))
+
+
+def hash_orbit_ids_to_uint32(
+    orbit_ids: np.ndarray[str],
+) -> Tuple[Dict[int, str], np.ndarray[np.uint32]]:
+    """
+    Derive uint32 hashes from orbit id strings
+
+    Rebound uses uint32 to track individual particles, but we use orbit id strings.
+    Here we attempt to generate uint32 hashes for each and return the mapping as well.
+    """
+    hashes = [uint32_hash(o) for o in orbit_ids]
+    # Because uint32 is an unhashable type,
+    # we use a dict mapping from uint32 to orbit id string
+    mapping = {hashes[i].value: orbit_ids[i] for i in range(len(orbit_ids))}
+
+    return mapping, hashes
+
+
+class ASSISTPropagator(Propagator, ImpactMixin):
+
+    def _propagate_orbits(self, orbits: OrbitType, times: TimestampType) -> OrbitType:
+        # Assert that the time for each orbit definition is the same for the simulator to work
+        assert len(pc.unique(orbits.coordinates.time.mjd())) == 1
+
+        # The coordinate frame is the equatorial International Celestial Reference Frame (ICRF).
+        # This is also the native coordinate system for the JPL binary files.
+        # For units we use solar masses, astronomical units, and days.
+ # The time coordinate is Barycentric Dynamical Time (TDB) in Julian days. + + # Convert coordinates to ICRF using TDB time + coords = transform_coordinates( + orbits.coordinates, + origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER, + frame_out="equatorial", + ) + input_orbit_times = coords.time.rescale("tdb") + coords = coords.set_column("time", input_orbit_times) + orbits = orbits.set_column("coordinates", coords) + + root_dir = pathlib.Path(DATA_DIR).expanduser() + ephem = assist.Ephem( + root_dir.joinpath("linux_p1550p2650.440"), + root_dir.joinpath("sb441-n16.bsp"), + ) + sim = rebound.Simulation() + sim.ri_ias15.min_dt = 1e-15 + sim.ri_ias15.adaptive_mode = 2 + + # Set the simulation time, relative to the jd_ref + start_tdb_time = orbits.coordinates.time.jd().to_numpy()[0] - ephem.jd_ref + sim.t = start_tdb_time + + output_type = type(orbits) + + orbit_id_mapping, uint_orbit_ids = hash_orbit_ids_to_uint32( + orbits.orbit_id.to_numpy(zero_copy_only=False) + ) + + if isinstance(orbits, VariantOrbits): + variantattributes = {} + for idx, orbit_id in enumerate( + orbits.orbit_id.to_numpy(zero_copy_only=False) + ): + variantattributes[orbit_id] = { + "weight": orbits.weights[idx], + "weight_cov": orbits.weights_cov[idx], + "object_id": orbits.object_id[idx], + } + + # Add the orbits as particles to the simulation + coords_df = orbits.coordinates.to_dataframe() + + for i in range(len(coords_df)): + sim.add( + x=coords_df.x[i], + y=coords_df.y[i], + z=coords_df.z[i], + vx=coords_df.vx[i], + vy=coords_df.vy[i], + vz=coords_df.vz[i], + hash=uint_orbit_ids[i], + ) + + ax = assist.Extras(sim, ephem) + + # Prepare the times as jd - jd_ref + integrator_times = times.rescale("tdb").jd() + integrator_times = pc.subtract(integrator_times, ephem.jd_ref) + integrator_times = integrator_times.to_numpy() + + results = None + + # Step through each time, move the simulation forward and + # collect the results. 
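+        # rebound tracks particles by the uint32 hashes assigned above; after each
+        # integration step the particle states are serialized and the hashes are
+        # mapped back to orbit id strings via orbit_id_mapping.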
+        for i in range(len(integrator_times)):
+            sim.integrate(integrator_times[i])
+
+            # Get serialized particle data as numpy arrays
+            orbit_id_hashes = np.zeros(sim.N, dtype="uint32")
+            step_xyzvxvyvz = np.zeros((sim.N, 6), dtype="float64")
+
+            sim.serialize_particle_data(xyzvxvyvz=step_xyzvxvyvz, hash=orbit_id_hashes)
+
+            if isinstance(orbits, Orbits):
+                # Retrieve original orbit id from hash
+                orbit_ids = [orbit_id_mapping[h] for h in orbit_id_hashes]
+                time_step_results = Orbits.from_kwargs(
+                    coordinates=CartesianCoordinates.from_kwargs(
+                        x=step_xyzvxvyvz[:, 0],
+                        y=step_xyzvxvyvz[:, 1],
+                        z=step_xyzvxvyvz[:, 2],
+                        vx=step_xyzvxvyvz[:, 3],
+                        vy=step_xyzvxvyvz[:, 4],
+                        vz=step_xyzvxvyvz[:, 5],
+                        time=Timestamp.from_jd(
+                            pa.repeat(sim.t + ephem.jd_ref, sim.N), scale="tdb"
+                        ),
+                        origin=Origin.from_kwargs(
+                            code=pa.repeat(
+                                "SOLAR_SYSTEM_BARYCENTER",
+                                sim.N,
+                            )
+                        ),
+                        frame="equatorial",
+                    ),
+                    orbit_id=orbit_ids,
+                )
+            elif isinstance(orbits, VariantOrbits):
+                # Retrieve the orbit id and weights from hash
+                orbit_ids = [orbit_id_mapping[h] for h in orbit_id_hashes]
+
+                time_step_results = VariantOrbits.from_kwargs(
+                    orbit_id=orbit_ids,
+                    object_id=orbits.object_id,
+                    weights=orbits.weights,
+                    weights_cov=orbits.weights_cov,
+                    coordinates=CartesianCoordinates.from_kwargs(
+                        x=step_xyzvxvyvz[:, 0],
+                        y=step_xyzvxvyvz[:, 1],
+                        z=step_xyzvxvyvz[:, 2],
+                        vx=step_xyzvxvyvz[:, 3],
+                        vy=step_xyzvxvyvz[:, 4],
+                        vz=step_xyzvxvyvz[:, 5],
+                        time=Timestamp.from_jd(
+                            pa.repeat(sim.t + ephem.jd_ref, sim.N), scale="tdb"
+                        ),
+                        origin=Origin.from_kwargs(
+                            code=pa.repeat(
+                                "SOLAR_SYSTEM_BARYCENTER",
+                                sim.N,
+                            )
+                        ),
+                        frame="equatorial",
+                    ),
+                )
+
+            if results is None:
+                results = time_step_results
+            else:
+                results = concatenate([results, time_step_results])
+
+        results = results.set_column(
+            "coordinates",
+            transform_coordinates(
+                results.coordinates,
+                origin_out=OriginCodes.SUN,
+                frame_out="ecliptic",
+            ),
+        )
+
+        return results
+
+    def _detect_impacts(
+        self, orbits: OrbitType, num_days: int
+    ) -> Tuple[VariantOrbits, EarthImpacts]:
+        # Assert that the time for each orbit definition is the same for the simulator to work
+        assert len(pc.unique(orbits.coordinates.time.mjd())) == 1
+
+        # The coordinate frame is the equatorial International Celestial Reference Frame (ICRF).
+        # This is also the native coordinate system for the JPL binary files.
+        # For units we use solar masses, astronomical units, and days.
+        # The time coordinate is Barycentric Dynamical Time (TDB) in Julian days.
+ + # Convert coordinates to ICRF using TDB time + coords = transform_coordinates( + orbits.coordinates, + origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER, + frame_out="equatorial", + ) + input_orbit_times = coords.time.rescale("tdb") + coords = coords.set_column("time", input_orbit_times) + orbits = orbits.set_column("coordinates", coords) + + root_dir = pathlib.Path(DATA_DIR).expanduser() + ephem_paths = [ + root_dir.joinpath("linux_p1550p2650.440"), + root_dir.joinpath("sb441-n16.bsp"), + ] + ephem = assist.Ephem(*ephem_paths) + sim = None + gc.collect() + sim = rebound.Simulation() + + # Set the simulation time, relative to the jd_ref + start_tdb_time = orbits.coordinates.time.jd().to_numpy()[0] + start_tdb_time = start_tdb_time - ephem.jd_ref + sim.t = start_tdb_time + + particle_ids = orbits.orbit_id.to_numpy(zero_copy_only=False) + + # Serialize the variantorbit + if isinstance(orbits, VariantOrbits): + orbit_ids = orbits.orbit_id.to_numpy(zero_copy_only=False).astype(str) + variant_ids = orbits.variant_id.to_numpy(zero_copy_only=False).astype(str) + # Use numpy string operations to concatenate the orbit_id and variant_id + particle_ids = np.char.add( + np.char.add(orbit_ids, np.repeat("-", len(orbit_ids))), variant_ids + ) + particle_ids = np.array(particle_ids, dtype="object") + + orbit_id_mapping, uint_orbit_ids = hash_orbit_ids_to_uint32(particle_ids) + + # Add the orbits as particles to the simulation + coords_df = orbits.coordinates.to_dataframe() + + # ASSIST _must_ be initialized before adding particles + ax = assist.Extras(sim, ephem) + + for i in range(len(coords_df)): + sim.add( + x=coords_df.x[i], + y=coords_df.y[i], + z=coords_df.z[i], + vx=coords_df.vx[i], + vy=coords_df.vy[i], + vz=coords_df.vz[i], + hash=uint_orbit_ids[i], + ) + + + # sim.integrator = "ias15" + sim.ri_ias15.min_dt = 1e-15 + # sim.dt = 1e-9 + # sim.force_is_velocity_dependent = 0 + sim.ri_ias15.adaptive_mode = 2 + + # Prepare the times as jd - jd_ref + final_integrator_time = ( + orbits.coordinates.time.add_days(num_days).jd().to_numpy()[0] + ) + final_integrator_time = final_integrator_time - ephem.jd_ref + + # Results stores the final positions of the objects + # If an object is an impactor, this represents its position at impact time + results = None + earth_impacts = None + past_integrator_time = False + time_step_results = None + + # Step through each time, move the simulation forward and + # collect the results. 
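+        # Advance the integrator one internal step at a time and check every particle's
+        # distance to Earth after each step, so an impact is recorded at the step where
+        # it occurs rather than only at the final time.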
+ while past_integrator_time is False: + sim.steps(1) + # print(sim.dt_last_done) + if sim.t >= final_integrator_time: + past_integrator_time = True + + # Get serialized particle data as numpy arrays + orbit_id_hashes = np.zeros(sim.N, dtype="uint32") + step_xyzvxvyvz = np.zeros((sim.N, 6), dtype="float64") + + sim.serialize_particle_data(xyzvxvyvz=step_xyzvxvyvz, hash=orbit_id_hashes) + + if isinstance(orbits, Orbits): + # Retrieve original orbit id from hash + orbit_ids = [orbit_id_mapping[h] for h in orbit_id_hashes] + time_step_results = Orbits.from_kwargs( + coordinates=CartesianCoordinates.from_kwargs( + x=step_xyzvxvyvz[:, 0], + y=step_xyzvxvyvz[:, 1], + z=step_xyzvxvyvz[:, 2], + vx=step_xyzvxvyvz[:, 3], + vy=step_xyzvxvyvz[:, 4], + vz=step_xyzvxvyvz[:, 5], + time=Timestamp.from_jd( + pa.repeat(sim.t + ephem.jd_ref, sim.N), scale="tdb" + ), + origin=Origin.from_kwargs( + code=pa.repeat( + "SOLAR_SYSTEM_BARYCENTER", + sim.N, + ) + ), + frame="equatorial", + ), + orbit_id=orbit_ids, + ) + elif isinstance(orbits, VariantOrbits): + # Retrieve the orbit id and weights from hash + particle_ids = [orbit_id_mapping[h] for h in orbit_id_hashes] + orbit_ids, variant_ids = zip( + *[particle_id.split("-") for particle_id in particle_ids] + ) + + # Historically we've done a check here to make sure the orbit of the orbits + # and serialized particles is consistent + # np.testing.assert_array_equal(orbits.orbit_id.to_numpy(zero_copy_only=False).astype(str), orbit_ids) + # np.testing.assert_array_equal(orbits.variant_id.to_numpy(zero_copy_only=False).astype(str), variant_ids) + + time_step_results = VariantOrbits.from_kwargs( + orbit_id=orbit_ids, + variant_id=variant_ids, + object_id=orbits.object_id, + weights=orbits.weights, + weights_cov=orbits.weights_cov, + coordinates=CartesianCoordinates.from_kwargs( + x=step_xyzvxvyvz[:, 0], + y=step_xyzvxvyvz[:, 1], + z=step_xyzvxvyvz[:, 2], + vx=step_xyzvxvyvz[:, 3], + vy=step_xyzvxvyvz[:, 4], + vz=step_xyzvxvyvz[:, 5], + time=Timestamp.from_jd( + pa.repeat(sim.t + ephem.jd_ref, sim.N), scale="tdb" + ), + origin=Origin.from_kwargs( + code=pa.repeat( + "SOLAR_SYSTEM_BARYCENTER", + sim.N, + ) + ), + frame="equatorial", + ), + ) + + time_step_results = time_step_results.set_column( + "coordinates", + transform_coordinates( + time_step_results.coordinates, + origin_out=OriginCodes.SUN, + frame_out="ecliptic", + ), + ) + + # Get the Earth's position at the current time + # earth_geo = get_perturber_state(OriginCodes.EARTH, results.coordinates.time[0], origin=OriginCodes.SUN) + # diff = time_step_results.coordinates.values - earth_geo.coordinates.values + earth_geo = ephem.get_particle("Earth", sim.t) + earth_geo = CartesianCoordinates.from_kwargs( + x=[earth_geo.x], + y=[earth_geo.y], + z=[earth_geo.z], + vx=[earth_geo.vx], + vy=[earth_geo.vy], + vz=[earth_geo.vz], + time=Timestamp.from_jd([sim.t + ephem.jd_ref], scale="tdb"), + origin=Origin.from_kwargs( + code=["SOLAR_SYSTEM_BARYCENTER"], + ), + frame="equatorial", + ) + earth_geo = transform_coordinates( + earth_geo, + origin_out=OriginCodes.SUN, + frame_out="ecliptic", + ) + diff = time_step_results.coordinates.values - earth_geo.values + + # Calculate the distance in KM + normalized_distance = np.linalg.norm(diff[:, :3], axis=1) * 149597870.691 + + # Calculate which particles are within an Earth radius + within_radius = normalized_distance < EARTH_RADIUS_KM + + # If any are within our earth radius, we record the impact + # and do bookkeeping to remove the particle from the simulation + if 
np.any(within_radius): + distances = normalized_distance[within_radius] + impacting_orbits = time_step_results.apply_mask(within_radius) + + new_impacts = EarthImpacts.from_kwargs( + orbit_id=impacting_orbits.orbit_id, + distance=distances, + coordinates=impacting_orbits.coordinates, + variant_id=impacting_orbits.variant_id, + ) + if earth_impacts is None: + earth_impacts = new_impacts + else: + earth_impacts = qv.concatenate([earth_impacts, new_impacts]) + + # Remove the particle from the simulation, orbits, and store in results + for hash_id in orbit_id_hashes[within_radius]: + sim.remove(hash=c_uint32(hash_id)) + # For some reason, it fails if we let rebound convert the hash to c_uint32 + + # Remove the particle from the input / running orbits + # This allows us to carry through object_id, weights, and weights_cov + orbits = orbits.apply_mask(~within_radius) + # Put the orbits / variants of the impactors into the results set + if results is None: + results = impacting_orbits + else: + results = qv.concatenate([results, impacting_orbits]) + + # Add the final positions of the particles to the results + if results is None: + results = time_step_results + else: + results = qv.concatenate([results, time_step_results]) + + if earth_impacts is None: + earth_impacts = EarthImpacts.from_kwargs( + orbit_id=[], + distance=[], + coordinates=CartesianCoordinates.from_kwargs( + x=[], + y=[], + z=[], + vx=[], + vy=[], + vz=[], + time=Timestamp.from_jd([], scale="tdb"), + origin=Origin.from_kwargs( + code=[], + ), + frame="ecliptic", + ), + variant_id=[], + ) + return results, earth_impacts + + def _generate_ephemeris( + self, orbits: OrbitType, observers: ObserverType + ) -> EphemerisType: + raise NotImplementedError( + "ASSISTPropagator does not yet support ephemeris generation." 
+ ) diff --git a/tests/data/I00007_orbit.parquet b/tests/data/I00007_orbit.parquet new file mode 100644 index 0000000..45127e4 Binary files /dev/null and b/tests/data/I00007_orbit.parquet differ diff --git a/tests/test_integration.py b/tests/test_integration.py index 3e59046..2cc8889 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,19 +1,30 @@ -import numpy as np -from adam_assist import ASSISTPropagator, download_jpl_ephemeris_files -from adam_core.orbits.query import query_sbdb -from adam_core.propagator.pyoorb import PYOORB -from adam_core.time import Timestamp +import pytest +from adam_core.dynamics.impacts import calculate_impacts +from adam_core.orbits import Orbits +from src.adam_core.propagator.adam_assist import ( + ASSISTPropagator, + download_jpl_ephemeris_files, +) -def test_assist_propagator(): - download_jpl_ephemeris_files() - initial_time = Timestamp.from_mjd([60000.0], scale="tdb") - times = initial_time.from_mjd(initial_time.mjd() + np.arange(0, 100)) - orbits = query_sbdb(["EDLU"]) - propagator = ASSISTPropagator() - assist_propagated_orbits = propagator.propagate_orbits(orbits, times) +# Contains a likely impactor with ~60% chance of impact in 30 days +IMPACTOR_FILE_PATH = "tests/data/I00007_orbit.parquet" - pyoorb = PYOORB() - pyoorb_propagated_orbits = pyoorb.propagate_orbits(orbits, times) - return assist_propagated_orbits, pyoorb_propagated_orbits +@pytest.mark.benchmark +@pytest.mark.parametrize("processes", [1, 2]) +def test_calculate_impacts_benchmark(benchmark, processes): + download_jpl_ephemeris_files() + impactor = Orbits.from_parquet(IMPACTOR_FILE_PATH)[0] + propagator = ASSISTPropagator() + variants, impacts = benchmark( + calculate_impacts, + impactor, + 60, + propagator, + num_samples=200, + processes=processes, + seed=42 # This allows us to predict exact number of impactors empirically + ) + assert len(variants) == 200, "Should have 200 variants" + assert len(impacts) == 138, "Should have exactly 138 impactors"
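A minimal usage sketch of the relocated propagator (not part of the patch). It only uses calls that appear elsewhere in this diff and assumes the installed package exposes the module as adam_core.propagator.adam_assist (the pdm build includes src/adam_core/), so the src. prefix used in the test above is not needed:

from adam_core.dynamics.impacts import calculate_impacts
from adam_core.orbits import Orbits
from adam_core.propagator.adam_assist import (
    ASSISTPropagator,
    download_jpl_ephemeris_files,
)
from adam_core.time import Timestamp

# Fetch the DE440 planetary and SB441 small-body ephemeris files (cached after the first run).
download_jpl_ephemeris_files()
propagator = ASSISTPropagator()

# Propagate the bundled likely impactor to a few epochs via the adam-core base-class API.
impactor = Orbits.from_parquet("tests/data/I00007_orbit.parquet")[0]
times = Timestamp.from_mjd([60000.0, 60030.0, 60060.0], scale="tdb")
propagated = propagator.propagate_orbits(impactor, times)

# Monte Carlo impact search over 60 days, mirroring the benchmark test above.
variants, impacts = calculate_impacts(
    impactor, 60, propagator, num_samples=200, processes=1, seed=42
)
impact_fraction = len(impacts) / len(variants)  # ~0.69 with this seed, per the test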