LS: Implement multithreaded diagnostics refreshing
commit-id:2a95e833
mkaput committed Nov 18, 2024
1 parent 9e068e5 commit d62ea56
Showing 6 changed files with 253 additions and 80 deletions.
crates/cairo-lang-language-server/src/lang/diagnostics/file_diagnostics.rs
@@ -20,6 +20,16 @@ use crate::lang::lsp::LsProtoGroup;
use crate::server::panic::is_cancelled;

/// Result of refreshing diagnostics for a single file.
///
/// ## Comparisons
///
/// Diagnostics in this structure are stored as Arcs that directly come from Salsa caches.
/// This means that equality comparisons of `FileDiagnostics` are efficient.
///
/// ## Virtual files
///
/// When collecting diagnostics using [`FileDiagnostics::collect`], all virtual files related
/// to the given `file` will also be visited and their diagnostics collected.
#[derive(Clone, PartialEq, Eq)]
pub struct FileDiagnostics {
/// The file ID these diagnostics are associated with.
@@ -92,9 +102,11 @@ impl FileDiagnostics {
})
}

/// Returns `true` if this `FileDiagnostics` contains no diagnostics.
pub fn is_empty(&self) -> bool {
self.semantic.is_empty() && self.lowering.is_empty() && self.parser.is_empty()
/// Clears all diagnostics from this `FileDiagnostics`.
pub fn clear(&mut self) {
self.parser = Diagnostics::default();
self.semantic = Diagnostics::default();
self.lowering = Diagnostics::default();
}

/// Constructs a new [`lsp_types::PublishDiagnosticsParams`] from this `FileDiagnostics`.
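The cheap-equality property described in the new doc comment relies on the diagnostics being `Arc`-backed values taken straight from Salsa caches, so comparing two `FileDiagnostics` snapshots can short-circuit on pointer identity instead of walking every diagnostic. A minimal sketch of that idea with a hypothetical stand-in type (`CachedDiagnostics` and its field are illustrative, not part of this commit):

```rust
use std::sync::Arc;

/// Hypothetical stand-in for an Arc-backed diagnostics container (illustration only).
#[derive(Clone)]
struct CachedDiagnostics {
    messages: Arc<Vec<String>>,
}

impl PartialEq for CachedDiagnostics {
    fn eq(&self, other: &Self) -> bool {
        // Check pointer identity first: values taken from the same Salsa cache entry share
        // the same allocation, so the common case never inspects the contents.
        Arc::ptr_eq(&self.messages, &other.messages) || self.messages == other.messages
    }
}

fn main() {
    let from_cache = CachedDiagnostics { messages: Arc::new(vec!["unused variable".into()]) };
    let same_entry = from_cache.clone();
    // Equality is decided by the shared pointer, without comparing individual diagnostics.
    assert!(from_cache == same_entry);
}
```

In the commit itself, `FileDiagnostics` derives `PartialEq`/`Eq` over its Arc-backed `Diagnostics` fields, which is what the doc comment's efficiency claim refers to.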
157 changes: 124 additions & 33 deletions crates/cairo-lang-language-server/src/lang/diagnostics/mod.rs
@@ -1,13 +1,18 @@
use std::collections::HashMap;
use std::collections::HashSet;
use std::iter;
use std::iter::zip;
use std::num::NonZero;
use std::panic::{AssertUnwindSafe, catch_unwind};

use cairo_lang_filesystem::ids::FileId;
use lsp_types::Url;
use tracing::{error, trace};

use self::file_diagnostics::FileDiagnostics;
use self::refresh::refresh_diagnostics;
use self::project_diagnostics::ProjectDiagnostics;
use self::refresh::{clear_old_diagnostics, refresh_diagnostics};
use self::trigger::trigger;
use crate::lang::diagnostics::file_batches::{batches, find_primary_files, find_secondary_files};
use crate::lang::lsp::LsProtoGroup;
use crate::server::client::Notifier;
use crate::server::panic::cancelled_anyhow;
use crate::server::schedule::thread::{self, JoinHandle, ThreadPriority};
@@ -16,6 +21,7 @@ use crate::state::{State, StateSnapshot};
mod file_batches;
mod file_diagnostics;
mod lsp;
mod project_diagnostics;
mod refresh;
mod trigger;

@@ -28,50 +34,67 @@ pub struct DiagnosticsController {
// The trigger MUST be dropped before worker's join handle.
// Otherwise, the controller thread will never be requested to stop, and the controller's
// JoinHandle will never terminate.
trigger: trigger::Sender<StateSnapshot>,
trigger: trigger::Sender<StateSnapshots>,
_thread: JoinHandle,
state_snapshots_factory: StateSnapshotsFactory,
}

impl DiagnosticsController {
/// Creates a new diagnostics controller.
pub fn new(notifier: Notifier) -> Self {
let (trigger, receiver) = trigger();
let thread = DiagnosticsControllerThread::spawn(receiver, notifier);
Self { trigger, _thread: thread }
let (thread, parallelism) = DiagnosticsControllerThread::spawn(receiver, notifier);
Self {
trigger,
_thread: thread,
state_snapshots_factory: StateSnapshotsFactory { parallelism },
}
}

/// Schedules diagnostics refreshing on snapshot(s) of the current state.
pub fn refresh(&self, state: &State) {
self.trigger.activate(state.snapshot());
let state_snapshots = self.state_snapshots_factory.create(state);
self.trigger.activate(state_snapshots);
}
}

/// Stores entire state of diagnostics controller's worker thread.
struct DiagnosticsControllerThread {
receiver: trigger::Receiver<StateSnapshot>,
receiver: trigger::Receiver<StateSnapshots>,
notifier: Notifier,
// NOTE: Globally, we have to always identify files by URL instead of FileId,
// as the diagnostics state is independent of analysis database swaps,
// which invalidate FileIds.
file_diagnostics: HashMap<Url, FileDiagnostics>,
pool: thread::Pool,
project_diagnostics: ProjectDiagnostics,
}

impl DiagnosticsControllerThread {
/// Spawns a new diagnostics controller worker thread.
fn spawn(receiver: trigger::Receiver<StateSnapshot>, notifier: Notifier) -> JoinHandle {
let mut this = Self { receiver, notifier, file_diagnostics: Default::default() };
/// Spawns a new diagnostics controller worker thread
/// and returns a handle to it and the amount of parallelism it provides.
fn spawn(
receiver: trigger::Receiver<StateSnapshots>,
notifier: Notifier,
) -> (JoinHandle, NonZero<usize>) {
let this = Self {
receiver,
notifier,
pool: thread::Pool::new(),
project_diagnostics: ProjectDiagnostics::new(),
};

let parallelism = this.pool.parallelism();

thread::Builder::new(ThreadPriority::Worker)
let thread = thread::Builder::new(ThreadPriority::Worker)
.name("cairo-ls:diagnostics-controller".into())
.spawn(move || this.event_loop())
.expect("failed to spawn diagnostics controller thread")
.expect("failed to spawn diagnostics controller thread");

(thread, parallelism)
}

/// Runs diagnostics controller's event loop.
fn event_loop(&mut self) {
while let Some(state) = self.receiver.wait() {
fn event_loop(&self) {
while let Some(state_snapshots) = self.receiver.wait() {
if let Err(err) = catch_unwind(AssertUnwindSafe(|| {
self.diagnostics_controller_tick(state);
self.diagnostics_controller_tick(state_snapshots);
})) {
if let Ok(err) = cancelled_anyhow(err, "diagnostics refreshing has been cancelled")
{
@@ -85,19 +108,87 @@ impl DiagnosticsControllerThread {

/// Runs a single tick of the diagnostics controller's event loop.
#[tracing::instrument(skip_all)]
fn diagnostics_controller_tick(&mut self, state: StateSnapshot) {
// TODO(mkaput): Make multiple batches and run them in parallel.
let primary_files = find_primary_files(&state.db, &state.open_files);
let secondary_files = find_secondary_files(&state.db, &primary_files);
let files = primary_files.into_iter().chain(secondary_files).collect::<Vec<_>>();
for batch in batches(&files, 1.try_into().unwrap()) {
refresh_diagnostics(
&state.db,
batch,
state.config.trace_macro_diagnostics,
&mut self.file_diagnostics,
self.notifier.clone(),
);
fn diagnostics_controller_tick(&self, state_snapshots: StateSnapshots) {
let (state, primary_snapshots, secondary_snapshots) = state_snapshots.split();

let primary = find_primary_files(&state.db, &state.open_files);
self.spawn_refresh_worker(&primary, primary_snapshots);

let secondary = find_secondary_files(&state.db, &primary);
self.spawn_refresh_worker(&secondary, secondary_snapshots);

let files_to_preserve: HashSet<Url> = primary
.into_iter()
.chain(secondary)
.flat_map(|file| state.db.url_for_file(file))
.collect();

self.spawn_worker(move |project_diagnostics, notifier| {
clear_old_diagnostics(&state.db, files_to_preserve, project_diagnostics, notifier);
});
}

/// Shortcut for spawning a worker task which does the boilerplate around cloning state parts
/// and catching panics.
// FIXME(mkaput): Spawning tasks seems to hang on initial loads, this is probably due to
// task queue being overloaded.
fn spawn_worker(&self, f: impl FnOnce(ProjectDiagnostics, Notifier) + Send + 'static) {
let project_diagnostics = self.project_diagnostics.clone();
let notifier = self.notifier.clone();
let worker_fn = move || f(project_diagnostics, notifier);
self.pool.spawn(ThreadPriority::Worker, move || {
if let Err(err) = catch_unwind(AssertUnwindSafe(worker_fn)) {
if let Ok(err) = cancelled_anyhow(err, "diagnostics worker has been cancelled") {
trace!("{err:?}");
} else {
error!("caught panic in diagnostics worker");
}
}
});
}

/// Makes batches out of `files` and spawns workers to run [`refresh_diagnostics`] on them.
fn spawn_refresh_worker(&self, files: &HashSet<FileId>, state_snapshots: Vec<StateSnapshot>) {
let files_batches = batches(files, self.pool.parallelism());
assert_eq!(files_batches.len(), state_snapshots.len());
for (batch, state) in zip(files_batches, state_snapshots) {
self.spawn_worker(move |project_diagnostics, notifier| {
refresh_diagnostics(
&state.db,
batch,
state.config.trace_macro_diagnostics,
project_diagnostics,
notifier,
);
});
}
}
}

/// Holds multiple snapshots of the state.
///
/// It is not possible to clone Salsa snapshots nor share one between threads,
/// thus we explicitly create separate snapshots for all threads involved in advance.
struct StateSnapshots(Vec<StateSnapshot>);

impl StateSnapshots {
fn split(self) -> (StateSnapshot, Vec<StateSnapshot>, Vec<StateSnapshot>) {
let Self(mut snapshots) = self;
let control = snapshots.pop().unwrap();
assert_eq!(snapshots.len() % 2, 0);
let secondary = snapshots.split_off(snapshots.len() / 2);
(control, snapshots, secondary)
}
}

struct StateSnapshotsFactory {
parallelism: NonZero<usize>,
}

impl StateSnapshotsFactory {
fn create(&self, state: &State) -> StateSnapshots {
StateSnapshots(
iter::from_fn(|| Some(state.snapshot())).take(self.parallelism.get() * 2 + 1).collect(),
)
}
}
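To make the snapshot arithmetic concrete: with a worker pool whose parallelism is N, `StateSnapshotsFactory::create` takes 2·N + 1 snapshots per tick, and `split` peels one off for the controller's own work and hands N each to the primary and secondary refresh passes. A minimal sketch of that invariant, with plain integers standing in for Salsa snapshots (which can neither be cloned nor shared between threads, hence one per worker):

```rust
// Integers stand in for Salsa state snapshots; the split logic mirrors
// StateSnapshots::split from the diff above.
fn split(mut snapshots: Vec<u32>) -> (u32, Vec<u32>, Vec<u32>) {
    let control = snapshots.pop().unwrap();
    assert_eq!(snapshots.len() % 2, 0);
    let secondary = snapshots.split_off(snapshots.len() / 2);
    (control, snapshots, secondary)
}

fn main() {
    let parallelism: usize = 4;
    // Mirrors StateSnapshotsFactory::create: 2 * parallelism + 1 snapshots per tick.
    let snapshots: Vec<u32> = (0..(parallelism * 2 + 1) as u32).collect();
    let (control, primary, secondary) = split(snapshots);
    assert_eq!(primary.len(), parallelism); // one snapshot per primary-files batch
    assert_eq!(secondary.len(), parallelism); // one snapshot per secondary-files batch
    let _ = control; // backs the controller's own file discovery and the cleanup worker
}
```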
crates/cairo-lang-language-server/src/lang/diagnostics/project_diagnostics.rs
@@ -0,0 +1,71 @@
use std::collections::{HashMap, HashSet};
use std::mem;
use std::sync::{Arc, RwLock};

use itertools::{Either, Itertools};
use lsp_types::Url;

use crate::lang::diagnostics::file_diagnostics::FileDiagnostics;

/// Global storage of diagnostics for the entire analysed codebase(s).
///
/// This object can be shared between threads and accessed concurrently.
///
/// ## Identifying files
///
/// Globally, we have to always identify files by [`Url`] instead of [`FileId`],
/// as the diagnostics state is independent of analysis database swaps that invalidate interned IDs.
///
/// [`FileId`]: cairo_lang_filesystem::ids::FileId
#[derive(Clone)]
pub struct ProjectDiagnostics {
file_diagnostics: Arc<RwLock<HashMap<Url, FileDiagnostics>>>,
}

impl ProjectDiagnostics {
/// Creates new project diagnostics instance.
pub fn new() -> Self {
Self { file_diagnostics: Default::default() }
}

/// Inserts new diagnostics for a file if they update the existing diagnostics.
///
/// Returns `true` if stored diagnostics were updated; otherwise, returns `false`.
pub fn insert(&self, file_url: &Url, new_file_diagnostics: FileDiagnostics) -> bool {
if let Some(old_file_diagnostics) = self
.file_diagnostics
.read()
.expect("file diagnostics are poisoned, bailing out")
.get(file_url)
{
if *old_file_diagnostics == new_file_diagnostics {
return false;
}
};

self.file_diagnostics
.write()
.expect("file diagnostics are poisoned, bailing out")
.insert(file_url.clone(), new_file_diagnostics);
true
}

/// Removes diagnostics for files not present in the given set and returns a list of actually
/// removed entries.
pub fn clear_old(&self, files_to_retain: &HashSet<Url>) -> Vec<FileDiagnostics> {
let mut file_diagnostics =
self.file_diagnostics.write().expect("file diagnostics are poisoned, bailing out");

let (clean, removed) =
mem::take(&mut *file_diagnostics).into_iter().partition_map(|(file_url, diags)| {
if files_to_retain.contains(&file_url) {
Either::Left((file_url, diags))
} else {
Either::Right(diags)
}
});

*file_diagnostics = clean;
removed
}
}
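For reference, the insert-if-changed pattern above takes a read lock for the equality check and escalates to a write lock only when the diagnostics actually changed, so the common no-change case never blocks other readers. A self-contained sketch of the same pattern with stand-in names (`SharedDiagnostics` and the `Vec<String>` payload are illustrative, not the commit's types):

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// Stand-in for per-file diagnostics; in the commit this role is played by `FileDiagnostics`.
type Diags = Vec<String>;

#[derive(Clone, Default)]
struct SharedDiagnostics {
    inner: Arc<RwLock<HashMap<String, Diags>>>,
}

impl SharedDiagnostics {
    /// Returns `true` if the stored value changed, mirroring `ProjectDiagnostics::insert`.
    fn insert(&self, key: &str, new: Diags) -> bool {
        // Read lock first: the common case is "nothing changed" and needs no write access.
        if let Some(old) = self.inner.read().unwrap().get(key) {
            if *old == new {
                return false;
            }
        }
        // Escalate to a write lock only when an update is actually needed.
        self.inner.write().unwrap().insert(key.to_string(), new);
        true
    }
}

fn main() {
    let diags = SharedDiagnostics::default();
    assert!(diags.insert("file:///a.cairo", vec!["warning".into()]));
    // Re-inserting an equal value is reported as "no change".
    assert!(!diags.insert("file:///a.cairo", vec!["warning".into()]));
}
```

In the commit, equality checks against the previously stored `FileDiagnostics` stay cheap thanks to the Arc-backed contents described in file_diagnostics.rs.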