diff --git a/riot/lib/process.ml b/riot/lib/process.ml index 3ac10fd..b41289e 100644 --- a/riot/lib/process.ml +++ b/riot/lib/process.ml @@ -7,12 +7,17 @@ end) type t = P.t type priority = P.priority = High | Normal | Low -type process_flag = P.process_flag = Trap_exit of bool | Priority of priority + +type process_flag = P.process_flag = + | Trap_exit of bool + | Priority of priority + | IsBlockingProc of bool let pp_flag fmt t = match t with | Trap_exit b -> Format.fprintf fmt "trap_exit <- %b" b | Priority p -> Format.fprintf fmt "priority <- %s" (P.priority_to_string p) + | _ -> failwith "TODO" type exit_reason = P.exit_reason = | Normal diff --git a/riot/riot.mli b/riot/riot.mli index 1b1b946..9644a0b 100644 --- a/riot/riot.mli +++ b/riot/riot.mli @@ -113,6 +113,7 @@ module Process : sig (** Processes with a [High] priority will be scheduled before processes with a [Normal] priority which will be scheduled before processes with a [Low] priority. *) + | IsBlockingProc of bool (* An [exit_reason] describes why a process finished. *) type exit_reason = @@ -225,6 +226,9 @@ val spawn_pinned : (unit -> unit) -> Pid.t val spawn_link : (unit -> unit) -> Pid.t (** Spawns a new process and links it to the current process before returning. *) +val spawn_blocking : (unit -> unit) -> Pid.t +(** Spawns a new isolated process that does not yield to the Riot scheduler. *) + exception Name_already_registered of string * Pid.t val register : string -> Pid.t -> unit diff --git a/riot/runtime/core/process.ml b/riot/runtime/core/process.ml index 798879a..e4eeaea 100644 --- a/riot/runtime/core/process.ml +++ b/riot/runtime/core/process.ml @@ -39,12 +39,20 @@ let priority_to_string = function type process_flags = { trap_exits : bool Atomic.t; priority : priority Atomic.t; + is_blocking_proc : bool Atomic.t; } -type process_flag = Trap_exit of bool | Priority of priority +type process_flag = + | Trap_exit of bool + | Priority of priority + | IsBlockingProc of bool let default_flags () = - { trap_exits = Atomic.make false; priority = Atomic.make Normal } + { + trap_exits = Atomic.make false; + priority = Atomic.make Normal; + is_blocking_proc = Atomic.make false; + } type t = { pid : Pid.t; @@ -169,6 +177,7 @@ let is_runnable t = Atomic.get t.state = Runnable let is_running t = Atomic.get t.state = Running let is_finalized t = Atomic.get t.state = Finalized let is_main t = Pid.equal (pid t) Pid.main +let is_blocking_proc t = Atomic.get t.flags.is_blocking_proc let has_empty_mailbox t = Mailbox.is_empty t.save_queue && Mailbox.is_empty t.mailbox @@ -274,6 +283,10 @@ let rec set_flag t flag = let old_flag = Atomic.get t.flags.priority in if Atomic.compare_and_set t.flags.priority old_flag p then () else set_flag t flag + | IsBlockingProc b -> + let old_flag = Atomic.get t.flags.is_blocking_proc in + if Atomic.compare_and_set t.flags.is_blocking_proc old_flag b then () + else set_flag t flag let set_cont t c = t.cont <- Some c let set_sid t sid = Atomic.set t.sid sid diff --git a/riot/runtime/import.ml b/riot/runtime/import.ml index c1c5e40..568bc58 100644 --- a/riot/runtime/import.ml +++ b/riot/runtime/import.ml @@ -139,6 +139,35 @@ let spawn_pinned fn = let spawn_link fn = _spawn ~do_link:true fn +let spawn_blocking fn = + let pool = _get_pool () in + (* Create a scheduler *) + let blocking_scheduler = Scheduler.Pool.spawn_blocking_scheduler pool in + + (* Start the process *) + let proc = + Process.make blocking_scheduler.scheduler.uid (fun () -> + try + fn (); + Normal + with + | Proc_state.Unwind -> Normal + | exn -> + Log.error (fun f -> + f "Process %a died with unhandled exception %s:\n%s" Pid.pp + (self ()) (Printexc.to_string exn) + (Printexc.get_backtrace ())); + + Exception exn) + in + Process.set_flag proc (IsBlockingProc true); + Scheduler.Pool.register_process pool proc; + let _ = + Scheduler.kickstart_blocking_process pool blocking_scheduler.scheduler proc + in + proc.pid +(* _spawn ~do_link:false ~scheduler:blocking_scheduler fn *) + let monitor pid = let pool = _get_pool () in let this = _get_proc (self ()) in diff --git a/riot/runtime/scheduler/scheduler.ml b/riot/runtime/scheduler/scheduler.ml index 2c7384a..cc81b35 100644 --- a/riot/runtime/scheduler/scheduler.ml +++ b/riot/runtime/scheduler/scheduler.ml @@ -12,6 +12,7 @@ type t = { idle_mutex : Mutex.t; idle_condition : Condition.t; currently_stealing : Mutex.t; + mutable stop : bool; } type io = { @@ -26,12 +27,15 @@ type io = { mutable calls_send : int; } +type blocking = { scheduler : t; domain : unit Domain.t } + type pool = { mutable stop : bool; mutable status : int; io_scheduler : io; schedulers : t list; processes : Proc_table.t; + blocking_schedulers : blocking list Atomic.t; mutable proc_count : int; registry : Proc_registry.t; } @@ -60,6 +64,7 @@ module Scheduler = struct idle_mutex = Mutex.create (); idle_condition = Condition.create (); currently_stealing = Mutex.create (); + stop = false; } let get_current_scheduler, (set_current_scheduler : t -> unit) = @@ -95,6 +100,10 @@ module Scheduler = struct add_to_run_queue sch proc) pool.schedulers + let kickstart_blocking_process pool sch (proc : Process.t) = + add_to_run_queue sch proc; + pool.schedulers + let handle_receive k pool sch (proc : Process.t) ~(ref : 'a Ref.t option) ~timeout ~selector = Trace.handle_receive_span @@ fun () -> @@ -359,6 +368,11 @@ module Scheduler = struct awake_process pool linked_proc) linked_pids; + if Process.is_blocking_proc proc then ( + Log.debug (fun f -> f "Set scheduler.stop to true"); + sch.stop <- true) + else (); + Proc_queue.remove sch.run_queue proc; Proc_table.remove pool.processes proc.pid; Proc_registry.remove pool.registry proc.pid; @@ -454,6 +468,7 @@ module Scheduler = struct (try while true do if pool.stop then raise_notrace Exit; + if sch.stop then raise_notrace Exit; Mutex.lock sch.idle_mutex; while @@ -471,6 +486,34 @@ module Scheduler = struct Log.trace (fun f -> f "< exit worker loop") end +module Blocking_scheduler = struct + (* include Scheduler *) + type t = blocking + + let make sch domain = { scheduler = sch; domain } + + let rec add_to_pool pool blocking = + let dom_list = Atomic.get pool.blocking_schedulers in + if + Atomic.compare_and_set pool.blocking_schedulers dom_list + (blocking :: dom_list) + then () + else add_to_pool pool blocking + + let rec remove_from_pool pool blocking = + let cur = Atomic.get pool.blocking_schedulers in + let without_removee = List.filter (fun sch -> sch.domain != blocking.domain) cur in + if Atomic.compare_and_set pool.blocking_schedulers cur without_removee then + () + else remove_from_pool pool blocking + + + (* Override the handle exit function *) + (* let handle_exit_blocking_proc pool sch proc reason = *) + (* Scheduler.handle_exit_proc pool sch.scheduler proc reason; *) + (* remove_from_pool pool sch *) +end + include Scheduler module Io_scheduler = struct @@ -535,6 +578,24 @@ module Pool = struct sockets and handle that as a regular value rather than as a signal. *) Sys.set_signal Sys.sigpipe Sys.Signal_ignore + let spawn_scheduler_on_pool pool (scheduler : t) : unit Domain.t = + Stdlib.Domain.spawn (fun () -> + setup (); + Stdlib.Domain.at_exit (fun () -> Log.warn (fun f -> f "Domain freed")); + set_pool pool; + Scheduler.set_current_scheduler scheduler; + try + Scheduler.run pool scheduler (); + Log.trace (fun f -> + f "<<< shutting down scheduler #%a" Scheduler_uid.pp scheduler.uid) + with exn -> + Log.error (fun f -> + f "Scheduler.run exception: %s due to: %s%!" + (Printexc.to_string exn) + (Printexc.raw_backtrace_to_string + (Printexc.get_raw_backtrace ()))); + shutdown pool 1) + let make ?(rnd = Random.State.make_self_init ()) ~domains ~main () = setup (); @@ -550,28 +611,13 @@ module Pool = struct io_scheduler; schedulers = [ main ] @ schedulers; processes = Proc_table.create (); + blocking_schedulers = Atomic.make []; registry = Proc_registry.create (); } in - let spawn (scheduler : t) = - Stdlib.Domain.spawn (fun () -> - setup (); - set_pool pool; - Scheduler.set_current_scheduler scheduler; - try - Scheduler.run pool scheduler (); - Log.trace (fun f -> - f "<<< shutting down scheduler #%a" Scheduler_uid.pp - scheduler.uid) - with exn -> - Log.error (fun f -> - f "Scheduler.run exception: %s due to: %s%!" - (Printexc.to_string exn) - (Printexc.raw_backtrace_to_string - (Printexc.get_raw_backtrace ()))); - shutdown pool 1) - in - Log.debug (fun f -> f "Created %d schedulers" (List.length schedulers)); + Log.debug (fun f -> + f "Created %d schedulers excluding the main scheduler" + (List.length schedulers)); let io_thread = Stdlib.Domain.spawn (fun () -> @@ -585,6 +631,16 @@ module Pool = struct shutdown pool 2) in - let scheduler_threads = List.map spawn schedulers in + let scheduler_threads = + List.map (spawn_scheduler_on_pool pool) schedulers + in (pool, io_thread :: scheduler_threads) + + (** Creates a new blocking scheduler in the pool *) + let spawn_blocking_scheduler ?(rnd = Random.State.make_self_init ()) pool = + let new_scheduler = Scheduler.make ~rnd () in + let domain = spawn_scheduler_on_pool pool new_scheduler in + let blocking_sch = Blocking_scheduler.make new_scheduler domain in + Blocking_scheduler.add_to_pool pool blocking_sch; + blocking_sch end diff --git a/test/dune b/test/dune index dff1e47..c0a60fc 100644 --- a/test/dune +++ b/test/dune @@ -115,6 +115,12 @@ (modules link_processes_test) (libraries riot)) +(test + (package riot) + (name process_blocking_test) + (modules process_blocking_test) + (libraries riot)) + (test (package riot) (name process_registration_test) diff --git a/test/process_blocking_test.ml b/test/process_blocking_test.ml new file mode 100644 index 0000000..b3fecee --- /dev/null +++ b/test/process_blocking_test.ml @@ -0,0 +1,52 @@ +[@@@warning "-8"] + +open Riot + +type Message.t += AnswerToAllTheWorldsProblems of int + +(* Factorial is too fast so just a little function that eats some more CPU time*) +let rec block_longer n = if n == 0 then () else block_longer (n - 1) + +let factorial n = + let rec aux n acc = + Logger.info (fun f -> f "Factorial %d" n); + block_longer 100000; + match n with 1 -> acc | x -> aux (n - 1) (acc * x) + in + aux n 1 + +let busy_worker recipient_pid () = + let number = factorial 30 in + send recipient_pid (AnswerToAllTheWorldsProblems number) + +let rec countdown_worker n = + Logger.info (fun f -> f "Countdown loop n = %d" n); + + if n = 0 then () + else ( + yield (); + countdown_worker (n - 1)) + +let rec wait_for_answer () = + match receive_any () with + | AnswerToAllTheWorldsProblems n -> + Printf.printf + "Got the answer!\n\ + The answer to all the worlds problems has been calculated to be %d\n" + n + | _ -> wait_for_answer () + +let () = + Runtime.set_log_level (Some Trace); + print_endline "Test spawn_blocking"; + Riot.run ~workers:0 @@ fun () -> + let _ = Logger.start () |> Result.get_ok in + Logger.set_log_level (Some Info); + + let pid_waiting = spawn wait_for_answer in + + let _countdown_pid = spawn (fun () -> countdown_worker 100) in + let _factorial_answer_pid = spawn_blocking (busy_worker pid_waiting) in + wait_pids [ pid_waiting ]; + flush_all (); + shutdown ()