Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance improvements #112

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ sha2 = "0.10"
zip = "0.5"
thiserror = "1"
rgb = "0.8"
rayon = "1.5"
nom = "7.1"
smol_str = { version = "*", features = ["serde"] }

futures = { version = "0.3", optional = true }
reqwest = { version = "0.11", optional = true, features = ["blocking"]}
47 changes: 24 additions & 23 deletions src/gtfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
use std::sync::Arc;

type Map<O> = HashMap<smol_str::SmolStr, O>;
/// Data structure with all the GTFS objects
///
/// This structure is easier to use than the [RawGtfs] structure as some relationships are parsed to be easier to use.
Expand All @@ -24,21 +25,21 @@ pub struct Gtfs {
/// Time needed to read and parse the archive in milliseconds
pub read_duration: i64,
/// All Calendar by `service_id`
pub calendar: HashMap<String, Calendar>,
pub calendar: Map<Calendar>,
/// All calendar dates grouped by service_id
pub calendar_dates: HashMap<String, Vec<CalendarDate>>,
pub calendar_dates: Map<Vec<CalendarDate>>,
/// All stops by `stop_id`. Stops are in an [Arc] because they are also referenced by each [StopTime]
pub stops: HashMap<String, Arc<Stop>>,
pub stops: Map<Arc<Stop>>,
/// All routes by `route_id`
pub routes: HashMap<String, Route>,
pub routes: Map<Route>,
/// All trips by `trip_id`
pub trips: HashMap<String, Trip>,
pub trips: Map<Trip>,
/// All agencies. They can not be read by `agency_id`, as it is not a required field
pub agencies: Vec<Agency>,
/// All shapes by shape_id
pub shapes: HashMap<String, Vec<Shape>>,
pub shapes: Map<Vec<Shape>>,
/// All fare attributes by `fare_id`
pub fare_attributes: HashMap<String, FareAttribute>,
pub fare_attributes: Map<FareAttribute>,
/// All feed information. There is no identifier
pub feed_info: Vec<FeedInfo>,
}
Expand Down Expand Up @@ -221,24 +222,24 @@ impl Gtfs {
}
}

fn to_map<O: Id>(elements: impl IntoIterator<Item = O>) -> HashMap<String, O> {
fn to_map<O: Id>(elements: impl IntoIterator<Item = O>) -> Map<O> {
elements
.into_iter()
.map(|e| (e.id().to_owned(), e))
.map(|e| (smol_str::SmolStr::new(e.id()), e))
.collect()
}

fn to_stop_map(stops: Vec<Stop>) -> HashMap<String, Arc<Stop>> {
fn to_stop_map(stops: Vec<Stop>) -> Map<Arc<Stop>> {
stops
.into_iter()
.map(|s| (s.id.clone(), Arc::new(s)))
.collect()
}

fn to_shape_map(shapes: Vec<Shape>) -> HashMap<String, Vec<Shape>> {
let mut res = HashMap::default();
fn to_shape_map(shapes: Vec<Shape>) -> Map<Vec<Shape>> {
let mut res = Map::default();
for s in shapes {
let shape = res.entry(s.id.to_owned()).or_insert_with(Vec::new);
let shape = res.entry(s.id.clone()).or_insert_with(Vec::new);
shape.push(s);
}
// we sort the shape by its pt_sequence
Expand All @@ -249,10 +250,10 @@ fn to_shape_map(shapes: Vec<Shape>) -> HashMap<String, Vec<Shape>> {
res
}

fn to_calendar_dates(cd: Vec<CalendarDate>) -> HashMap<String, Vec<CalendarDate>> {
let mut res = HashMap::default();
fn to_calendar_dates(cd: Vec<CalendarDate>) -> Map<Vec<CalendarDate>> {
let mut res = Map::default();
for c in cd {
let cal = res.entry(c.service_id.to_owned()).or_insert_with(Vec::new);
let cal = res.entry(c.service_id.clone()).or_insert_with(Vec::new);
cal.push(c);
}
res
Expand All @@ -262,8 +263,8 @@ fn create_trips(
raw_trips: Vec<RawTrip>,
raw_stop_times: Vec<RawStopTime>,
raw_frequencies: Vec<RawFrequency>,
stops: &HashMap<String, Arc<Stop>>,
) -> Result<HashMap<String, Trip>, Error> {
stops: &Map<Arc<Stop>>,
) -> Result<Map<Trip>, Error> {
let mut trips = to_map(raw_trips.into_iter().map(|rt| Trip {
id: rt.id,
service_id: rt.service_id,
Expand All @@ -279,23 +280,23 @@ fn create_trips(
frequencies: vec![],
}));
for s in raw_stop_times {
let trip = &mut trips
.get_mut(&s.trip_id)
let trip = trips
.get_mut(s.trip_id.as_str())
.ok_or_else(|| Error::ReferenceError(s.trip_id.to_string()))?;
let stop = stops
.get(&s.stop_id)
.get(s.stop_id.as_str())
.ok_or_else(|| Error::ReferenceError(s.stop_id.to_string()))?;
trip.stop_times.push(StopTime::from(&s, Arc::clone(stop)));
}

for trip in &mut trips.values_mut() {
for trip in trips.values_mut() {
trip.stop_times
.sort_by(|a, b| a.stop_sequence.cmp(&b.stop_sequence));
}

for f in raw_frequencies {
let trip = &mut trips
.get_mut(&f.trip_id)
.get_mut(f.trip_id.as_str())
.ok_or_else(|| Error::ReferenceError(f.trip_id.to_string()))?;
trip.frequencies.push(Frequency::from(&f));
}
Expand Down
52 changes: 28 additions & 24 deletions src/gtfs_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use serde::Deserialize;
use sha2::{Digest, Sha256};

use crate::{Error, Gtfs, RawGtfs};
use rayon::prelude::*;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::fs::File;
Expand Down Expand Up @@ -268,7 +269,7 @@ impl RawGtfsReader {

fn read_objs<T, O>(mut reader: T, file_name: &str) -> Result<Vec<O>, Error>
where
for<'de> O: Deserialize<'de>,
for<'de> O: Deserialize<'de> + Send,
T: std::io::Read,
{
let mut bom = [0; 3];
Expand Down Expand Up @@ -299,30 +300,33 @@ where
})?
.clone();

let mut res = Vec::new();
for rec in reader.records() {
let r = rec.map_err(|e| Error::CSVError {
file_name: file_name.to_owned(),
source: e,
line_in_error: None,
})?;
let o = r.deserialize(Some(&headers)).map_err(|e| Error::CSVError {
file_name: file_name.to_owned(),
source: e,
line_in_error: Some(crate::error::LineError {
headers: headers.into_iter().map(|s| s.to_owned()).collect(),
values: r.into_iter().map(|s| s.to_owned()).collect(),
}),
})?;
res.push(o);
}

Ok(res)
let v = reader
.records()
.map(|rec| {
rec.map_err(|e| Error::CSVError {
file_name: file_name.to_owned(),
source: e,
line_in_error: None,
})
})
.collect::<Result<Vec<_>, Error>>()?;
v.par_iter()
.map(|r| {
r.deserialize(Some(&headers)).map_err(|e| Error::CSVError {
file_name: file_name.to_owned(),
source: e,
line_in_error: Some(crate::error::LineError {
headers: headers.into_iter().map(|s| s.to_owned()).collect(),
values: r.into_iter().map(|s| s.to_owned()).collect(),
}),
})
})
.collect()
}

fn read_objs_from_path<O>(path: std::path::PathBuf) -> Result<Vec<O>, Error>
where
for<'de> O: Deserialize<'de>,
for<'de> O: Deserialize<'de> + Send,
{
let file_name = path
.file_name()
Expand All @@ -346,7 +350,7 @@ fn read_objs_from_optional_path<O>(
file_name: &str,
) -> Option<Result<Vec<O>, Error>>
where
for<'de> O: Deserialize<'de>,
for<'de> O: Deserialize<'de> + Send,
{
File::open(dir_path.join(file_name))
.ok()
Expand All @@ -359,7 +363,7 @@ fn read_file<O, T>(
file_name: &str,
) -> Result<Vec<O>, Error>
where
for<'de> O: Deserialize<'de>,
for<'de> O: Deserialize<'de> + Send,
T: std::io::Read + std::io::Seek,
{
read_optional_file(file_mapping, archive, file_name)
Expand All @@ -372,7 +376,7 @@ fn read_optional_file<O, T>(
file_name: &str,
) -> Option<Result<Vec<O>, Error>>
where
for<'de> O: Deserialize<'de>,
for<'de> O: Deserialize<'de> + Send,
T: std::io::Read + std::io::Seek,
{
file_mapping.get(&file_name).map(|i| {
Expand Down
35 changes: 18 additions & 17 deletions src/objects.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::serde_helpers::*;
use chrono::{Datelike, NaiveDate, Weekday};
use rgb::RGB8;

use smol_str::SmolStr;
use std::fmt;
use std::sync::Arc;

Expand All @@ -25,7 +26,7 @@ pub trait Type {
pub struct Calendar {
/// Unique technical identifier (not for the traveller) of this calendar
#[serde(rename = "service_id")]
pub id: String,
pub id: SmolStr,
/// Does the service run on mondays
#[serde(
deserialize_with = "deserialize_bool",
Expand Down Expand Up @@ -119,7 +120,7 @@ impl Calendar {
#[derive(Debug, Deserialize, Serialize)]
pub struct CalendarDate {
/// Identifier of the service that is modified at this date
pub service_id: String,
pub service_id: SmolStr,
#[serde(
deserialize_with = "deserialize_date",
serialize_with = "serialize_date"
Expand All @@ -135,7 +136,7 @@ pub struct CalendarDate {
pub struct Stop {
/// Unique technical identifier (not for the traveller) of the stop
#[serde(rename = "stop_id")]
pub id: String,
pub id: SmolStr,
/// Short text or a number that identifies the location for riders
#[serde(rename = "stop_code")]
pub code: Option<String>,
Expand Down Expand Up @@ -197,7 +198,7 @@ impl fmt::Display for Stop {
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct RawStopTime {
/// [Trip] to which this stop time belongs
pub trip_id: String,
pub trip_id: smol_str::SmolStr,
/// Arrival time of the stop time.
/// It's an option since the intermediate stops can have no arrival
/// and this arrival needs to be interpolated
Expand All @@ -215,7 +216,7 @@ pub struct RawStopTime {
)]
pub departure_time: Option<u32>,
/// Identifier of the [Stop] where the vehicle stops
pub stop_id: String,
pub stop_id: smol_str::SmolStr,
/// Order of stops for a particular trip. The values must increase along the trip but do not need to be consecutive
pub stop_sequence: u16,
/// Text that appears on signage identifying the trip's destination to riders
Expand Down Expand Up @@ -294,7 +295,7 @@ impl StopTime {
pub struct Route {
/// Unique technical (not for the traveller) identifier for the route
#[serde(rename = "route_id")]
pub id: String,
pub id: SmolStr,
/// Short name of a route. This will often be a short, abstract identifier like "32", "100X", or "Green" that riders use to identify a route, but which doesn't give any indication of what places the route serves
#[serde(rename = "route_short_name")]
pub short_name: String,
Expand Down Expand Up @@ -363,13 +364,13 @@ impl fmt::Display for Route {
pub struct RawTrip {
/// Unique technical (not for the traveller) identifier for the Trip
#[serde(rename = "trip_id")]
pub id: String,
pub id: SmolStr,
/// References the [Calendar] on which this trip runs
pub service_id: String,
pub service_id: SmolStr,
/// References along which [Route] this trip runs
pub route_id: String,
pub route_id: SmolStr,
/// Shape of the trip
pub shape_id: Option<String>,
pub shape_id: Option<SmolStr>,
/// Text that appears on signage identifying the trip's destination to riders
pub trip_headsign: Option<String>,
/// Public facing text used to identify the trip to riders, for instance, to identify train numbers for commuter rail trips
Expand Down Expand Up @@ -412,15 +413,15 @@ impl fmt::Display for RawTrip {
#[derive(Debug, Default)]
pub struct Trip {
/// Unique technical identifier (not for the traveller) for the Trip
pub id: String,
pub id: SmolStr,
/// References the [Calendar] on which this trip runs
pub service_id: String,
pub service_id: SmolStr,
/// References along which [Route] this trip runs
pub route_id: String,
pub route_id: SmolStr,
/// All the [StopTime] that define the trip
pub stop_times: Vec<StopTime>,
/// Shape of the trip
pub shape_id: Option<String>,
pub shape_id: Option<SmolStr>,
/// Text that appears on signage identifying the trip's destination to riders
pub trip_headsign: Option<String>,
/// Public facing text used to identify the trip to riders, for instance, to identify train numbers for commuter rail trips
Expand Down Expand Up @@ -514,7 +515,7 @@ impl fmt::Display for Agency {
pub struct Shape {
/// Unique technical (not for the traveller) identifier for the Shape
#[serde(rename = "shape_id")]
pub id: String,
pub id: SmolStr,
#[serde(rename = "shape_pt_lat", default)]
/// Latitude of a shape point
pub latitude: f64,
Expand Down Expand Up @@ -546,7 +547,7 @@ impl Id for Shape {
pub struct FareAttribute {
/// Unique technical (not for the traveller) identifier for the FareAttribute
#[serde(rename = "fare_id")]
pub id: String,
pub id: SmolStr,
/// Fare price, in the unit specified by [FareAttribute::currency]
pub price: String,
/// Currency used to pay the fare.
Expand Down Expand Up @@ -578,7 +579,7 @@ impl Type for FareAttribute {
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct RawFrequency {
/// References the [Trip] that uses frequency
pub trip_id: String,
pub trip_id: SmolStr,
/// Time at which the first vehicle departs from the first stop of the trip
#[serde(
deserialize_with = "deserialize_time",
Expand Down
Loading