From d62ea5640004511bfa3017dccc76626584c9348c Mon Sep 17 00:00:00 2001 From: Marek Kaput Date: Thu, 14 Nov 2024 07:37:03 +0100 Subject: [PATCH] LS: Implement multithreaded diagnostics refreshing commit-id:2a95e833 --- .../src/lang/diagnostics/file_diagnostics.rs | 18 +- .../src/lang/diagnostics/mod.rs | 157 ++++++++++++++---- .../lang/diagnostics/project_diagnostics.rs | 71 ++++++++ .../src/lang/diagnostics/refresh.rs | 73 ++++---- .../src/server/schedule/thread/mod.rs | 4 +- .../src/server/schedule/thread/pool.rs | 10 +- 6 files changed, 253 insertions(+), 80 deletions(-) create mode 100644 crates/cairo-lang-language-server/src/lang/diagnostics/project_diagnostics.rs diff --git a/crates/cairo-lang-language-server/src/lang/diagnostics/file_diagnostics.rs b/crates/cairo-lang-language-server/src/lang/diagnostics/file_diagnostics.rs index 1b0a231dde5..1bb75f7d9b4 100644 --- a/crates/cairo-lang-language-server/src/lang/diagnostics/file_diagnostics.rs +++ b/crates/cairo-lang-language-server/src/lang/diagnostics/file_diagnostics.rs @@ -20,6 +20,16 @@ use crate::lang::lsp::LsProtoGroup; use crate::server::panic::is_cancelled; /// Result of refreshing diagnostics for a single file. +/// +/// ## Comparisons +/// +/// Diagnostics in this structure are stored as Arcs that directly come from Salsa caches. +/// This means that equality comparisons of `FileDiagnostics` are efficient. +/// +/// ## Virtual files +/// +/// When collecting diagnostics using [`FileDiagnostics::collect`], all virtual files related +/// to the given `file` will also be visited and their diagnostics collected. #[derive(Clone, PartialEq, Eq)] pub struct FileDiagnostics { /// The file ID these diagnostics are associated with. @@ -92,9 +102,11 @@ impl FileDiagnostics { }) } - /// Returns `true` if this `FileDiagnostics` contains no diagnostics. 
- pub fn is_empty(&self) -> bool { - self.semantic.is_empty() && self.lowering.is_empty() && self.parser.is_empty() + /// Clears all diagnostics from this `FileDiagnostics`. + pub fn clear(&mut self) { + self.parser = Diagnostics::default(); + self.semantic = Diagnostics::default(); + self.lowering = Diagnostics::default(); } /// Constructs a new [`lsp_types::PublishDiagnosticsParams`] from this `FileDiagnostics`. diff --git a/crates/cairo-lang-language-server/src/lang/diagnostics/mod.rs b/crates/cairo-lang-language-server/src/lang/diagnostics/mod.rs index 520bdf93a8b..50fa9143e19 100644 --- a/crates/cairo-lang-language-server/src/lang/diagnostics/mod.rs +++ b/crates/cairo-lang-language-server/src/lang/diagnostics/mod.rs @@ -1,13 +1,18 @@ -use std::collections::HashMap; +use std::collections::HashSet; +use std::iter; +use std::iter::zip; +use std::num::NonZero; use std::panic::{AssertUnwindSafe, catch_unwind}; +use cairo_lang_filesystem::ids::FileId; use lsp_types::Url; use tracing::{error, trace}; -use self::file_diagnostics::FileDiagnostics; -use self::refresh::refresh_diagnostics; +use self::project_diagnostics::ProjectDiagnostics; +use self::refresh::{clear_old_diagnostics, refresh_diagnostics}; use self::trigger::trigger; use crate::lang::diagnostics::file_batches::{batches, find_primary_files, find_secondary_files}; +use crate::lang::lsp::LsProtoGroup; use crate::server::client::Notifier; use crate::server::panic::cancelled_anyhow; use crate::server::schedule::thread::{self, JoinHandle, ThreadPriority}; @@ -16,6 +21,7 @@ use crate::state::{State, StateSnapshot}; mod file_batches; mod file_diagnostics; mod lsp; +mod project_diagnostics; mod refresh; mod trigger; @@ -28,50 +34,67 @@ pub struct DiagnosticsController { // The trigger MUST be dropped before worker's join handle. // Otherwise, the controller thread will never be requested to stop, and the controller's // JoinHandle will never terminate. 
- trigger: trigger::Sender, + trigger: trigger::Sender, _thread: JoinHandle, + state_snapshots_factory: StateSnapshotsFactory, } impl DiagnosticsController { /// Creates a new diagnostics controller. pub fn new(notifier: Notifier) -> Self { let (trigger, receiver) = trigger(); - let thread = DiagnosticsControllerThread::spawn(receiver, notifier); - Self { trigger, _thread: thread } + let (thread, parallelism) = DiagnosticsControllerThread::spawn(receiver, notifier); + Self { + trigger, + _thread: thread, + state_snapshots_factory: StateSnapshotsFactory { parallelism }, + } } /// Schedules diagnostics refreshing on snapshot(s) of the current state. pub fn refresh(&self, state: &State) { - self.trigger.activate(state.snapshot()); + let state_snapshots = self.state_snapshots_factory.create(state); + self.trigger.activate(state_snapshots); } } /// Stores entire state of diagnostics controller's worker thread. struct DiagnosticsControllerThread { - receiver: trigger::Receiver, + receiver: trigger::Receiver, notifier: Notifier, - // NOTE: Globally, we have to always identify files by URL instead of FileId, - // as the diagnostics state is independent of analysis database swaps, - // which invalidate FileIds. - file_diagnostics: HashMap, + pool: thread::Pool, + project_diagnostics: ProjectDiagnostics, } impl DiagnosticsControllerThread { - /// Spawns a new diagnostics controller worker thread. - fn spawn(receiver: trigger::Receiver, notifier: Notifier) -> JoinHandle { - let mut this = Self { receiver, notifier, file_diagnostics: Default::default() }; + /// Spawns a new diagnostics controller worker thread + /// and returns a handle to it and the amount of parallelism it provides. 
+ fn spawn( + receiver: trigger::Receiver, + notifier: Notifier, + ) -> (JoinHandle, NonZero) { + let this = Self { + receiver, + notifier, + pool: thread::Pool::new(), + project_diagnostics: ProjectDiagnostics::new(), + }; + + let parallelism = this.pool.parallelism(); - thread::Builder::new(ThreadPriority::Worker) + let thread = thread::Builder::new(ThreadPriority::Worker) .name("cairo-ls:diagnostics-controller".into()) .spawn(move || this.event_loop()) - .expect("failed to spawn diagnostics controller thread") + .expect("failed to spawn diagnostics controller thread"); + + (thread, parallelism) } /// Runs diagnostics controller's event loop. - fn event_loop(&mut self) { - while let Some(state) = self.receiver.wait() { + fn event_loop(&self) { + while let Some(state_snapshots) = self.receiver.wait() { if let Err(err) = catch_unwind(AssertUnwindSafe(|| { - self.diagnostics_controller_tick(state); + self.diagnostics_controller_tick(state_snapshots); })) { if let Ok(err) = cancelled_anyhow(err, "diagnostics refreshing has been cancelled") { @@ -85,19 +108,87 @@ impl DiagnosticsControllerThread { /// Runs a single tick of the diagnostics controller's event loop. #[tracing::instrument(skip_all)] - fn diagnostics_controller_tick(&mut self, state: StateSnapshot) { - // TODO(mkaput): Make multiple batches and run them in parallel. 
- let primary_files = find_primary_files(&state.db, &state.open_files); - let secondary_files = find_secondary_files(&state.db, &primary_files); - let files = primary_files.into_iter().chain(secondary_files).collect::>(); - for batch in batches(&files, 1.try_into().unwrap()) { - refresh_diagnostics( - &state.db, - batch, - state.config.trace_macro_diagnostics, - &mut self.file_diagnostics, - self.notifier.clone(), - ); + fn diagnostics_controller_tick(&self, state_snapshots: StateSnapshots) { + let (state, primary_snapshots, secondary_snapshots) = state_snapshots.split(); + + let primary = find_primary_files(&state.db, &state.open_files); + self.spawn_refresh_worker(&primary, primary_snapshots); + + let secondary = find_secondary_files(&state.db, &primary); + self.spawn_refresh_worker(&secondary, secondary_snapshots); + + let files_to_preserve: HashSet = primary + .into_iter() + .chain(secondary) + .flat_map(|file| state.db.url_for_file(file)) + .collect(); + + self.spawn_worker(move |project_diagnostics, notifier| { + clear_old_diagnostics(&state.db, files_to_preserve, project_diagnostics, notifier); + }); + } + + /// Shortcut for spawning a worker task which does the boilerplate around cloning state parts + /// and catching panics. + // FIXME(mkaput): Spawning tasks seems to hang on initial loads, this is probably due to + // task queue being overloaded. 
+ fn spawn_worker(&self, f: impl FnOnce(ProjectDiagnostics, Notifier) + Send + 'static) { + let project_diagnostics = self.project_diagnostics.clone(); + let notifier = self.notifier.clone(); + let worker_fn = move || f(project_diagnostics, notifier); + self.pool.spawn(ThreadPriority::Worker, move || { + if let Err(err) = catch_unwind(AssertUnwindSafe(worker_fn)) { + if let Ok(err) = cancelled_anyhow(err, "diagnostics worker has been cancelled") { + trace!("{err:?}"); + } else { + error!("caught panic in diagnostics worker"); + } + } + }); + } + + /// Makes batches out of `files` and spawns workers to run [`refresh_diagnostics`] on them. + fn spawn_refresh_worker(&self, files: &HashSet, state_snapshots: Vec) { + let files_batches = batches(files, self.pool.parallelism()); + assert_eq!(files_batches.len(), state_snapshots.len()); + for (batch, state) in zip(files_batches, state_snapshots) { + self.spawn_worker(move |project_diagnostics, notifier| { + refresh_diagnostics( + &state.db, + batch, + state.config.trace_macro_diagnostics, + project_diagnostics, + notifier, + ); + }); } } } + +/// Holds multiple snapshots of the state. +/// +/// It is not possible to clone Salsa snapshots nor share one between threads, +/// thus we explicitly create separate snapshots for all threads involved in advance. 
+struct StateSnapshots(Vec); + +impl StateSnapshots { + fn split(self) -> (StateSnapshot, Vec, Vec) { + let Self(mut snapshots) = self; + let control = snapshots.pop().unwrap(); + assert_eq!(snapshots.len() % 2, 0); + let secondary = snapshots.split_off(snapshots.len() / 2); + (control, snapshots, secondary) + } +} + +struct StateSnapshotsFactory { + parallelism: NonZero, +} + +impl StateSnapshotsFactory { + fn create(&self, state: &State) -> StateSnapshots { + StateSnapshots( + iter::from_fn(|| Some(state.snapshot())).take(self.parallelism.get() * 2 + 1).collect(), + ) + } +} diff --git a/crates/cairo-lang-language-server/src/lang/diagnostics/project_diagnostics.rs b/crates/cairo-lang-language-server/src/lang/diagnostics/project_diagnostics.rs new file mode 100644 index 00000000000..fa5c9a2f64e --- /dev/null +++ b/crates/cairo-lang-language-server/src/lang/diagnostics/project_diagnostics.rs @@ -0,0 +1,71 @@ +use std::collections::{HashMap, HashSet}; +use std::mem; +use std::sync::{Arc, RwLock}; + +use itertools::{Either, Itertools}; +use lsp_types::Url; + +use crate::lang::diagnostics::file_diagnostics::FileDiagnostics; + +/// Global storage of diagnostics for the entire analysed codebase(s). +/// +/// This object can be shared between threads and accessed concurrently. +/// +/// ## Identifying files +/// +/// Globally, we have to always identify files by [`Url`] instead of [`FileId`], +/// as the diagnostics state is independent of analysis database swaps that invalidate interned IDs. +/// +/// [`FileId`]: cairo_lang_filesystem::ids::FileId +#[derive(Clone)] +pub struct ProjectDiagnostics { + file_diagnostics: Arc>>, +} + +impl ProjectDiagnostics { + /// Creates new project diagnostics instance. + pub fn new() -> Self { + Self { file_diagnostics: Default::default() } + } + + /// Inserts new diagnostics for a file if they update the existing diagnostics. + /// + /// Returns `true` if stored diagnostics were updated; otherwise, returns `false`. 
+ pub fn insert(&self, file_url: &Url, new_file_diagnostics: FileDiagnostics) -> bool { + if let Some(old_file_diagnostics) = self + .file_diagnostics + .read() + .expect("file diagnostics are poisoned, bailing out") + .get(file_url) + { + if *old_file_diagnostics == new_file_diagnostics { + return false; + } + }; + + self.file_diagnostics + .write() + .expect("file diagnostics are poisoned, bailing out") + .insert(file_url.clone(), new_file_diagnostics); + true + } + + /// Removes diagnostics for files not present in the given set and returns a list of actually + /// removed entries. + pub fn clear_old(&self, files_to_retain: &HashSet) -> Vec { + let mut file_diagnostics = + self.file_diagnostics.write().expect("file diagnostics are poisoned, bailing out"); + + let (clean, removed) = + mem::take(&mut *file_diagnostics).into_iter().partition_map(|(file_url, diags)| { + if files_to_retain.contains(&file_url) { + Either::Left((file_url, diags)) + } else { + Either::Right(diags) + } + }); + + *file_diagnostics = clean; + removed + } +} diff --git a/crates/cairo-lang-language-server/src/lang/diagnostics/refresh.rs b/crates/cairo-lang-language-server/src/lang/diagnostics/refresh.rs index 98afe053c84..f2f5066da39 100644 --- a/crates/cairo-lang-language-server/src/lang/diagnostics/refresh.rs +++ b/crates/cairo-lang-language-server/src/lang/diagnostics/refresh.rs @@ -1,14 +1,15 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use cairo_lang_defs::ids::ModuleId; use cairo_lang_filesystem::ids::FileId; use cairo_lang_utils::LookupIntern; +use lsp_types::Url; use lsp_types::notification::PublishDiagnostics; -use lsp_types::{PublishDiagnosticsParams, Url}; -use tracing::{info_span, trace}; +use tracing::trace; use crate::lang::db::AnalysisDatabase; use crate::lang::diagnostics::file_diagnostics::FileDiagnostics; +use crate::lang::diagnostics::project_diagnostics::ProjectDiagnostics; use crate::lang::lsp::LsProtoGroup; use 
crate::server::client::Notifier; @@ -18,10 +19,9 @@ pub fn refresh_diagnostics( db: &AnalysisDatabase, batch: Vec, trace_macro_diagnostics: bool, - file_diagnostics: &mut HashMap, + project_diagnostics: ProjectDiagnostics, notifier: Notifier, ) { - let mut files_with_set_diagnostics: HashSet = HashSet::default(); let mut processed_modules: HashSet = HashSet::default(); for file in batch { @@ -30,31 +30,10 @@ pub fn refresh_diagnostics( file, trace_macro_diagnostics, &mut processed_modules, - &mut files_with_set_diagnostics, - file_diagnostics, + &project_diagnostics, &notifier, ); } - - info_span!("clear_old_diagnostics").in_scope(|| { - let mut removed_files = Vec::new(); - - file_diagnostics.retain(|uri, _| { - let retain = files_with_set_diagnostics.contains(uri); - if !retain { - removed_files.push(uri.clone()); - } - retain - }); - - for file in removed_files { - notifier.notify::(PublishDiagnosticsParams { - uri: file, - diagnostics: vec![], - version: None, - }); - } - }); } /// Refresh diagnostics for a single file. @@ -63,8 +42,7 @@ fn refresh_file_diagnostics( file: FileId, trace_macro_diagnostics: bool, processed_modules: &mut HashSet, - files_with_set_diagnostics: &mut HashSet, - file_diagnostics: &mut HashMap, + project_diagnostics: &ProjectDiagnostics, notifier: &Notifier, ) { let Some(file_uri) = db.url_for_file(file) else { @@ -76,20 +54,33 @@ fn refresh_file_diagnostics( return; }; - if !new_file_diagnostics.is_empty() { - files_with_set_diagnostics.insert(file_uri.clone()); - } - - // Since we are using Arcs, this comparison should be efficient. 
- if let Some(old_file_diagnostics) = file_diagnostics.get(&file_uri) { - if old_file_diagnostics == &new_file_diagnostics { - return; + if project_diagnostics.insert(&file_uri, new_file_diagnostics.clone()) { + if let Some(params) = new_file_diagnostics.to_lsp(db, trace_macro_diagnostics) { + notifier.notify::(params); } + } +} - file_diagnostics.insert(file_uri.clone(), new_file_diagnostics.clone()); - }; +/// Wipes diagnostics for any files not present in the preserve set. +#[tracing::instrument(skip_all)] +pub fn clear_old_diagnostics( + db: &AnalysisDatabase, + files_to_preserve: HashSet, + project_diagnostics: ProjectDiagnostics, + notifier: Notifier, +) { + let removed = project_diagnostics.clear_old(&files_to_preserve); + for mut file_diagnostics in removed { + // It might be that we are removing a file that actually had some diagnostics. + // For example, this might happen if a file with a syntax error is deleted. + // We are reusing just removed `FileDiagnostics` instead of constructing a fresh one + // to preserve any extra state it might contain. + file_diagnostics.clear(); - if let Some(params) = new_file_diagnostics.to_lsp(db, trace_macro_diagnostics) { - notifier.notify::(params); + // We can safely assume `trace_macro_diagnostics` = false here, as we are explicitly + // sending a "no diagnostics" message. 
+ if let Some(params) = file_diagnostics.to_lsp(db, false) { + notifier.notify::(params); + } } } diff --git a/crates/cairo-lang-language-server/src/server/schedule/thread/mod.rs b/crates/cairo-lang-language-server/src/server/schedule/thread/mod.rs index c6644821280..4614c6f5b8b 100644 --- a/crates/cairo-lang-language-server/src/server/schedule/thread/mod.rs +++ b/crates/cairo-lang-language-server/src/server/schedule/thread/mod.rs @@ -38,8 +38,8 @@ use std::fmt; mod pool; mod priority; -pub(super) use pool::Pool; -pub use priority::ThreadPriority; +pub use self::pool::Pool; +pub use self::priority::ThreadPriority; pub struct Builder { priority: ThreadPriority, diff --git a/crates/cairo-lang-language-server/src/server/schedule/thread/pool.rs b/crates/cairo-lang-language-server/src/server/schedule/thread/pool.rs index 417b5461e31..69afe119205 100644 --- a/crates/cairo-lang-language-server/src/server/schedule/thread/pool.rs +++ b/crates/cairo-lang-language-server/src/server/schedule/thread/pool.rs @@ -23,6 +23,7 @@ //! the threading utilities in [`crate::server::schedule::thread`]. use std::cmp::min; +use std::num::NonZero; use std::thread::available_parallelism; use crossbeam::channel::{Receiver, Sender, bounded}; @@ -39,6 +40,8 @@ pub struct Pool { // before we join the worker threads! job_sender: Sender, _handles: Vec, + + parallelism: NonZero, } struct Job { @@ -90,7 +93,7 @@ impl Pool { handles.push(handle); } - Pool { _handles: handles, job_sender } + Pool { _handles: handles, job_sender, parallelism: NonZero::new(threads).unwrap() } } pub fn spawn(&self, priority: ThreadPriority, f: F) @@ -107,4 +110,9 @@ impl Pool { let job = Job { requested_priority: priority, f }; self.job_sender.send(job).unwrap(); } + + /// Returns a number of tasks that this pool can run concurrently. + pub fn parallelism(&self) -> NonZero { + self.parallelism + } }