From dc30a34a0076695c3a458cf35a6736bed805882b Mon Sep 17 00:00:00 2001 From: antoine-de Date: Thu, 17 Mar 2022 18:07:12 +0100 Subject: [PATCH] Use Index instead of Arc POC to see if the ergonomics are ok or not --- Cargo.toml | 2 ++ examples/gtfs_reading.rs | 18 ++++++++++ src/collection.rs | 75 ++++++++++++++++++++++++++++++++++++++++ src/error.rs | 3 ++ src/gtfs.rs | 59 +++++++++++++++---------------- src/lib.rs | 1 + src/objects.rs | 22 +++--------- src/tests.rs | 2 +- 8 files changed, 132 insertions(+), 50 deletions(-) create mode 100644 src/collection.rs diff --git a/Cargo.toml b/Cargo.toml index 9c12032..fb6aa96 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,8 @@ sha2 = "0.10" zip = "0.5" thiserror = "1" rgb = "0.8" +#typed-generational-arena = "0.2" +typed-generational-arena = { git = "https://gitlab.com/antoine-de/typed-generational-arena", branch = "changes" } futures = { version = "0.3", optional = true } reqwest = { version = "0.11", optional = true, features = ["blocking"]} diff --git a/examples/gtfs_reading.rs b/examples/gtfs_reading.rs index 30fd882..a79741e 100644 --- a/examples/gtfs_reading.rs +++ b/examples/gtfs_reading.rs @@ -12,4 +12,22 @@ fn main() { let route_1 = gtfs.routes.get("1").expect("no route 1"); println!("{}: {:?}", route_1.short_name, route_1); + + let trip = gtfs + .trips + .get("trip1") + .expect("impossible to find trip trip1"); + + let stop_time = trip + .stop_times + .iter() + .next() + .expect("no stop times in trips"); + + let stop = gtfs + .stops + .get(stop_time.stop) + .expect("no stop in stop time"); + + println!("first stop of trip 'trip1': {}", &stop.name); } diff --git a/src/collection.rs b/src/collection.rs new file mode 100644 index 0000000..858213a --- /dev/null +++ b/src/collection.rs @@ -0,0 +1,75 @@ +use std::{collections::hash_map::Entry, collections::HashMap, iter::FromIterator}; +use typed_generational_arena::{Arena, Index, Iter}; + +use crate::{Error, Id}; + +pub struct CollectionWithID { + storage: Arena, + ids: HashMap>, +} + +impl Default for CollectionWithID { + fn default() -> Self { + CollectionWithID { + storage: Arena::default(), + ids: HashMap::default(), + } + } +} + +impl CollectionWithID { + pub fn insert(&mut self, o: T) -> Result, Error> { + let id = o.id().to_owned(); + match self.ids.entry(id) { + Entry::Occupied(_) => Err(Error::DuplicateStop(o.id().to_owned())), + Entry::Vacant(e) => { + let index = self.storage.insert(o); + e.insert(index); + Ok(index) + } + } + } +} + +impl CollectionWithID { + pub fn get(&self, i: Index) -> Option<&T> { + self.storage.get(i) + } + + pub fn get_by_id(&self, id: &str) -> Option<&T> { + self.ids.get(id).and_then(|idx| self.storage.get(*idx)) + } + + pub fn get_mut_by_id(&mut self, id: &str) -> Option<&mut T> { + let idx = self.ids.get(id)?; + self.storage.get_mut(*idx) + } + + pub fn get_index(&self, id: &str) -> Option<&Index> { + self.ids.get(id) + } + + pub fn len(&self) -> usize { + self.storage.len() + } + + /// Iterates over the `(Index, &T)` of the `CollectionWithID`. + pub fn iter(&self) -> Iter<'_, T> { + self.storage.iter() + } +} + +impl FromIterator for CollectionWithID { + fn from_iter>(iter: I) -> Self { + let mut c = Self::default(); + + for i in iter { + // Note FromIterator does not handle the insertion error + let _ = c + .insert(i) + .map_err(|e| println!("impossible to insert elt: {}", e)); + } + + c + } +} diff --git a/src/error.rs b/src/error.rs index 7ec098d..fef4fa8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -28,6 +28,9 @@ pub enum Error { /// The color is not given in the RRGGBB format, without a leading `#` #[error("'{0}' is not a valid color; RRGGBB format is expected, without a leading `#`")] InvalidColor(String), + /// Several stops have the same id + #[error("duplicate stop: '{0}'")] + DuplicateStop(String), /// Generic Input/Output error while reading a file #[error("impossible to read file")] IO(#[from] std::io::Error), diff --git a/src/gtfs.rs b/src/gtfs.rs index 6859ea1..2aeec9c 100644 --- a/src/gtfs.rs +++ b/src/gtfs.rs @@ -1,9 +1,8 @@ -use crate::{objects::*, Error, RawGtfs}; +use crate::{collection::CollectionWithID, objects::*, Error, RawGtfs}; use chrono::prelude::NaiveDate; use chrono::Duration; use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; -use std::sync::Arc; /// Data structure with all the GTFS objects /// @@ -27,8 +26,8 @@ pub struct Gtfs { pub calendar: HashMap, /// All calendar dates grouped by service_id pub calendar_dates: HashMap>, - /// All stop by `stop_id`. Stops are in an [Arc] because they are also referenced by each [StopTime] - pub stops: HashMap>, + /// All stop by `stop_id` + pub stops: CollectionWithID, /// All routes by `route_id` pub routes: HashMap, /// All trips by `trip_id` @@ -49,7 +48,7 @@ impl TryFrom for Gtfs { /// /// It might fail if some mandatory files couldn’t be read or if there are references to other objects that are invalid. fn try_from(raw: RawGtfs) -> Result { - let stops = to_stop_map( + let stops = to_stop_collection( raw.stops?, raw.transfers.unwrap_or_else(|| Ok(Vec::new()))?, raw.pathways.unwrap_or(Ok(Vec::new()))?, @@ -176,10 +175,9 @@ impl Gtfs { /// Gets a [Stop] by its `stop_id` pub fn get_stop<'a>(&'a self, id: &str) -> Result<&'a Stop, Error> { - match self.stops.get(id) { - Some(stop) => Ok(stop), - None => Err(Error::ReferenceError(id.to_owned())), - } + self.stops + .get_by_id(id) + .ok_or_else(|| Error::ReferenceError(id.to_owned())) } /// Gets a [Trip] by its `trip_id` @@ -232,37 +230,34 @@ fn to_map(elements: impl IntoIterator) -> HashMap { .collect() } -fn to_stop_map( +fn to_stop_collection( stops: Vec, raw_transfers: Vec, raw_pathways: Vec, -) -> Result>, Error> { - let mut stop_map: HashMap = - stops.into_iter().map(|s| (s.id.clone(), s)).collect(); +) -> Result, Error> { + let mut stops: CollectionWithID = stops.into_iter().collect(); for transfer in raw_transfers { - stop_map - .get(&transfer.to_stop_id) + stops + .get_by_id(&transfer.to_stop_id) .ok_or_else(|| Error::ReferenceError(transfer.to_stop_id.to_string()))?; - stop_map - .entry(transfer.from_stop_id.clone()) - .and_modify(|stop| stop.transfers.push(StopTransfer::from(transfer))); + let s = stops + .get_mut_by_id(&transfer.from_stop_id) + .ok_or_else(|| Error::ReferenceError(transfer.from_stop_id.to_string()))?; + s.transfers.push(StopTransfer::from(transfer)); } for pathway in raw_pathways { - stop_map - .get(&pathway.to_stop_id) + stops + .get_by_id(&pathway.to_stop_id) .ok_or_else(|| Error::ReferenceError(pathway.to_stop_id.to_string()))?; - stop_map - .entry(pathway.from_stop_id.clone()) - .and_modify(|stop| stop.pathways.push(Pathway::from(pathway))); + let s = stops + .get_mut_by_id(&pathway.from_stop_id) + .ok_or_else(|| Error::ReferenceError(pathway.to_stop_id.to_string()))?; + s.pathways.push(Pathway::from(pathway)); } - let res = stop_map - .into_iter() - .map(|(i, s)| (i, Arc::new(s))) - .collect(); - Ok(res) + Ok(stops) } fn to_shape_map(shapes: Vec) -> HashMap> { @@ -292,7 +287,7 @@ fn create_trips( raw_trips: Vec, raw_stop_times: Vec, raw_frequencies: Vec, - stops: &HashMap>, + stops: &CollectionWithID, ) -> Result, Error> { let mut trips = to_map(raw_trips.into_iter().map(|rt| Trip { id: rt.id, @@ -313,9 +308,9 @@ fn create_trips( .get_mut(&s.trip_id) .ok_or_else(|| Error::ReferenceError(s.trip_id.to_string()))?; let stop = stops - .get(&s.stop_id) - .ok_or_else(|| Error::ReferenceError(s.stop_id.to_string()))?; - trip.stop_times.push(StopTime::from(&s, Arc::clone(stop))); + .get_index(&s.stop_id) + .ok_or_else(||Error::ReferenceError(s.stop_id.to_string()))?; + trip.stop_times.push(StopTime::from(&s, *stop)); } for trip in &mut trips.values_mut() { diff --git a/src/lib.rs b/src/lib.rs index ccbd3ff..af7609a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,7 @@ extern crate derivative; #[macro_use] extern crate serde_derive; +mod collection; mod enums; pub mod error; mod gtfs; diff --git a/src/objects.rs b/src/objects.rs index 35d8f82..f6ccece 100644 --- a/src/objects.rs +++ b/src/objects.rs @@ -4,7 +4,7 @@ use chrono::{Datelike, NaiveDate, Weekday}; use rgb::RGB8; use std::fmt; -use std::sync::Arc; +use typed_generational_arena::Index; /// Objects that have an identifier implement this trait /// @@ -14,24 +14,12 @@ pub trait Id { fn id(&self) -> &str; } -impl Id for Arc { - fn id(&self) -> &str { - self.as_ref().id() - } -} - /// Trait to introspect what is the object’s type (stop, route…) pub trait Type { /// What is the type of the object fn object_type(&self) -> ObjectType; } -impl Type for Arc { - fn object_type(&self) -> ObjectType { - self.as_ref().object_type() - } -} - /// A calender describes on which days the vehicle runs. See #[derive(Debug, Deserialize, Serialize)] pub struct Calendar { @@ -258,14 +246,14 @@ pub struct RawStopTime { } /// The moment where a vehicle, running on [Trip] stops at a [Stop]. See -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug)] pub struct StopTime { /// Arrival time of the stop time. /// It's an option since the intermediate stops can have have no arrival /// and this arrival needs to be interpolated pub arrival_time: Option, - /// [Stop] where the vehicle stops - pub stop: Arc, + /// [Index] of the [Stop] where the vehicle stops + pub stop: Index, /// Departure time of the stop time. /// It's an option since the intermediate stops can have have no departure /// and this departure needs to be interpolated @@ -290,7 +278,7 @@ pub struct StopTime { impl StopTime { /// Creates [StopTime] by linking a [RawStopTime::stop_id] to the actual [Stop] - pub fn from(stop_time_gtfs: &RawStopTime, stop: Arc) -> Self { + pub fn from(stop_time_gtfs: &RawStopTime, stop: Index) -> Self { Self { arrival_time: stop_time_gtfs.arrival_time, departure_time: stop_time_gtfs.departure_time, diff --git a/src/tests.rs b/src/tests.rs index 98460a1..728f7d0 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -357,7 +357,7 @@ fn read_interpolated_stops() { assert_eq!(1, gtfs.feed_info.len()); // the second stop have no departure/arrival, it should not cause any problems assert_eq!( - gtfs.trips["trip1"].stop_times[1].stop.name, + gtfs.stops.get(gtfs.trips["trip1"].stop_times[1].stop).expect("no stop").name, "Stop Point child of 1" ); assert!(gtfs.trips["trip1"].stop_times[1].arrival_time.is_none());