Skip to content

Commit

Permalink
Better Threading 🔥
Browse files Browse the repository at this point in the history
	modified:   Cargo.toml
	modified:   src/capture.rs
	modified:   src/d3d11.rs
	modified:   src/graphics_capture_api.rs
	modified:   windows-capture-python/README.md
	modified:   windows-capture-python/src/lib.rs
	modified:   windows-capture-python/windows_capture/__init__.py
	modified:   windows-capture-python/windows_capture/windows_capture.cp311-win_amd64.pyd
  • Loading branch information
NiiightmareXD committed Nov 16, 2023
1 parent 93826a2 commit 6997866
Show file tree
Hide file tree
Showing 8 changed files with 240 additions and 67 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ windows = { version = "0.52.0", features = [
"System",
"Graphics_DirectX_Direct3D11",
"Foundation_Metadata",
"Win32_System_Com",
] }

[package.metadata.docs.rs]
Expand Down
152 changes: 126 additions & 26 deletions src/capture.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@ use std::{
error::Error,
mem,
os::windows::prelude::AsRawHandle,
sync::{
atomic::{self, AtomicBool},
mpsc, Arc,
},
thread::{self, JoinHandle},
};

use log::{info, trace, warn};
use log::{debug, info, trace, warn};
use windows::{
Foundation::AsyncActionCompletedHandler,
Win32::{
Foundation::{HANDLE, LPARAM, WPARAM},
System::{
Com::{CoInitializeEx, CoUninitialize, COINIT_MULTITHREADED, COINIT_SPEED_OVER_MEMORY},
Threading::GetThreadId,
Threading::{GetCurrentThreadId, GetThreadId},
WinRT::{
CreateDispatcherQueueController, DispatcherQueueOptions, DQTAT_COM_NONE,
DQTYPE_THREAD_CURRENT,
CreateDispatcherQueueController, DispatcherQueueOptions, RoInitialize,
RoUninitialize, DQTAT_COM_NONE, DQTYPE_THREAD_CURRENT, RO_INIT_MULTITHREADED,
},
},
UI::WindowsAndMessaging::{
Expand All @@ -35,38 +38,43 @@ use crate::{
#[derive(thiserror::Error, Eq, PartialEq, Clone, Copy, Debug)]
pub enum CaptureControlError {
#[error("Failed To Join Thread")]
FailedToJoin,
FailedToJoinThread,
}

/// Struct Used To Control Capture Thread
pub struct CaptureControl {
thread_handle: Option<JoinHandle<Result<(), Box<dyn Error + Send + Sync>>>>,
halt_handle: Arc<AtomicBool>,
}

impl CaptureControl {
/// Create A New Capture Control Struct
#[must_use]
pub fn new(thread_handle: JoinHandle<Result<(), Box<dyn Error + Send + Sync>>>) -> Self {
pub fn new(
thread_handle: JoinHandle<Result<(), Box<dyn Error + Send + Sync>>>,
halt_handle: Arc<AtomicBool>,
) -> Self {
Self {
thread_handle: Some(thread_handle),
halt_handle,
}
}

/// Check If Capture Thread Is Finished
/// Check To See If Capture Thread Is Finished
#[must_use]
pub fn is_finished(&self) -> bool {
self.thread_handle
.as_ref()
.map_or(true, |thread_handle| thread_handle.is_finished())
}

/// Wait Until The Thread Stops
/// Wait Until The Capturing Thread Stops
pub fn wait(mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
if let Some(thread_handle) = self.thread_handle.take() {
match thread_handle.join() {
Ok(result) => result?,
Err(_) => {
return Err(Box::new(CaptureControlError::FailedToJoin));
return Err(Box::new(CaptureControlError::FailedToJoinThread));
}
}
}
Expand All @@ -76,6 +84,8 @@ impl CaptureControl {

/// Gracefully Stop The Capture Thread
pub fn stop(mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
self.halt_handle.store(true, atomic::Ordering::Relaxed);

if let Some(thread_handle) = self.thread_handle.take() {
let handle = thread_handle.as_raw_handle();
let handle = HANDLE(handle as isize);
Expand Down Expand Up @@ -103,7 +113,7 @@ impl CaptureControl {
match thread_handle.join() {
Ok(result) => result?,
Err(_) => {
return Err(Box::new(CaptureControlError::FailedToJoin));
return Err(Box::new(CaptureControlError::FailedToJoinThread));
}
}
}
Expand All @@ -125,9 +135,9 @@ pub trait WindowsCaptureHandler: Sized {
Self: Send + 'static,
<Self as WindowsCaptureHandler>::Flags: Send,
{
// Initialize COM
trace!("Initializing COM");
unsafe { CoInitializeEx(None, COINIT_MULTITHREADED | COINIT_SPEED_OVER_MEMORY)? };
// Initialize WinRT
trace!("Initializing WinRT");
unsafe { RoInitialize(RO_INIT_MULTITHREADED)? };

// Create A Dispatcher Queue For Current Thread
trace!("Creating A Dispatcher Queue For Capture Thread");
Expand All @@ -150,6 +160,9 @@ pub trait WindowsCaptureHandler: Sized {
)?;
capture.start_capture()?;

// Debug Thread ID
debug!("Thread ID: {}", unsafe { GetCurrentThreadId() });

// Message Loop
trace!("Entering Message Loop");
let mut message = MSG::default();
Expand Down Expand Up @@ -184,9 +197,9 @@ pub trait WindowsCaptureHandler: Sized {
info!("Stopping Capture Thread");
capture.stop_capture();

// Uninitialize COM
trace!("Uninitializing COM");
unsafe { CoUninitialize() };
// Uninitialize WinRT
trace!("Uninitializing WinRT");
unsafe { RoUninitialize() };

// Check RESULT
trace!("Checking RESULT");
Expand All @@ -198,14 +211,107 @@ pub trait WindowsCaptureHandler: Sized {
}

/// Starts The Capture Without Taking Control Of The Current Thread
fn start_free_threaded(settings: WindowsCaptureSettings<Self::Flags>) -> CaptureControl
fn start_free_threaded(
settings: WindowsCaptureSettings<Self::Flags>,
) -> Result<CaptureControl, Box<dyn Error + Send + Sync>>
where
Self: Send + 'static,
<Self as WindowsCaptureHandler>::Flags: Send,
{
let thread_handle = thread::spawn(move || Self::start(settings));
let (sender, receiver) = mpsc::channel::<Arc<AtomicBool>>();

let thread_handle = thread::spawn(move || -> Result<(), Box<dyn Error + Send + Sync>> {
// Initialize WinRT
trace!("Initializing WinRT");
unsafe { RoInitialize(RO_INIT_MULTITHREADED)? };

// Create A Dispatcher Queue For Current Thread
trace!("Creating A Dispatcher Queue For Capture Thread");
let options = DispatcherQueueOptions {
dwSize: mem::size_of::<DispatcherQueueOptions>() as u32,
threadType: DQTYPE_THREAD_CURRENT,
apartmentType: DQTAT_COM_NONE,
};
let controller = unsafe { CreateDispatcherQueueController(options)? };

// Start Capture
info!("Starting Capture Thread");
let trigger = Self::new(settings.flags)?;
let mut capture = GraphicsCaptureApi::new(
settings.item,
trigger,
settings.capture_cursor,
settings.draw_border,
settings.color_format,
)?;
capture.start_capture()?;

// Send Halt Handle
trace!("Sending Halt Handle");
let halt_handle = capture.halt_handle();
sender.send(halt_handle)?;

// Debug Thread ID
debug!("Thread ID: {}", unsafe { GetCurrentThreadId() });

// Message Loop
trace!("Entering Message Loop");
let mut message = MSG::default();
unsafe {
while GetMessageW(&mut message, None, 0, 0).as_bool() {
TranslateMessage(&message);
DispatchMessageW(&message);
}
}

// Shutdown Dispatcher Queue
trace!("Shutting Down Dispatcher Queue");
let async_action = controller.ShutdownQueueAsync()?;
async_action.SetCompleted(&AsyncActionCompletedHandler::new(
move |_, _| -> Result<(), windows::core::Error> {
unsafe { PostQuitMessage(0) };
Ok(())
},
))?;

// Final Message Loop
trace!("Entering Final Message Loop");
let mut message = MSG::default();
unsafe {
while GetMessageW(&mut message, None, 0, 0).as_bool() {
TranslateMessage(&message);
DispatchMessageW(&message);
}
}

// Stop Capturing
info!("Stopping Capture Thread");
capture.stop_capture();

// Uninitialize WinRT
trace!("Uninitializing WinRT");
unsafe { RoUninitialize() };

// Check RESULT
trace!("Checking RESULT");
let result = RESULT.take().expect("Failed To Take RESULT");

result?;

Ok(())
});

CaptureControl::new(thread_handle)
let halt_handle = match receiver.recv() {
Ok(halt_handle) => halt_handle,
Err(_) => match thread_handle.join() {
Ok(result) => return Err(result.err().unwrap()),
Err(_) => {
return Err(Box::new(CaptureControlError::FailedToJoinThread));
}
},
};

Ok(CaptureControl::new(thread_handle, halt_handle))
}

/// Function That Will Be Called To Create The Struct The Flags Can Be
Expand All @@ -222,10 +328,4 @@ pub trait WindowsCaptureHandler: Sized {
/// Called When The Capture Item Closes Usually When The Window Closes,
/// Capture Session Will End After This Function Ends
fn on_closed(&mut self) -> Result<(), Box<dyn Error + Send + Sync>>;

/// Call To Stop The Capture Thread, You Might Receive A Few More Frames
/// Before It Stops
fn stop(&self) {
unsafe { PostQuitMessage(0) };
}
}
8 changes: 4 additions & 4 deletions src/d3d11.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use windows::{
},
Direct3D11::{
D3D11CreateDevice, ID3D11Device, ID3D11DeviceContext,
D3D11_CREATE_DEVICE_DISABLE_GPU_TIMEOUT, D3D11_SDK_VERSION,
D3D11_CREATE_DEVICE_BGRA_SUPPORT, D3D11_SDK_VERSION,
},
Dxgi::IDXGIDevice,
},
Expand Down Expand Up @@ -40,8 +40,8 @@ pub enum DirectXErrors {
}

/// Create ID3D11Device And ID3D11DeviceContext
pub fn create_d3d_device()
-> Result<(ID3D11Device, ID3D11DeviceContext), Box<dyn Error + Send + Sync>> {
pub fn create_d3d_device(
) -> Result<(ID3D11Device, ID3D11DeviceContext), Box<dyn Error + Send + Sync>> {
// Set Feature Flags
let feature_flags = [
D3D_FEATURE_LEVEL_11_1,
Expand All @@ -62,7 +62,7 @@ pub fn create_d3d_device()
None,
D3D_DRIVER_TYPE_HARDWARE,
None,
D3D11_CREATE_DEVICE_DISABLE_GPU_TIMEOUT,
D3D11_CREATE_DEVICE_BGRA_SUPPORT,
Some(&feature_flags),
D3D11_SDK_VERSION,
Some(&mut d3d_device),
Expand Down
10 changes: 9 additions & 1 deletion src/graphics_capture_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pub struct GraphicsCaptureApi {
_d3d_device_context: ID3D11DeviceContext,
frame_pool: Option<Arc<Direct3D11CaptureFramePool>>,
session: Option<GraphicsCaptureSession>,
closed: Arc<AtomicBool>,
active: bool,
capture_cursor: Option<bool>,
draw_border: Option<bool>,
Expand Down Expand Up @@ -154,7 +155,7 @@ impl GraphicsCaptureApi {
&TypedEventHandler::<Direct3D11CaptureFramePool, IInspectable>::new({
// Init
let frame_pool_recreate = frame_pool.clone();
let closed_frame_pool = closed;
let closed_frame_pool = closed.clone();
let d3d_device_frame_pool = d3d_device.clone();
let context = d3d_device_context.clone();

Expand Down Expand Up @@ -257,6 +258,7 @@ impl GraphicsCaptureApi {
_d3d_device_context: d3d_device_context,
frame_pool: Some(frame_pool),
session: Some(session),
closed,
active: false,
capture_cursor,
draw_border,
Expand Down Expand Up @@ -318,6 +320,12 @@ impl GraphicsCaptureApi {
}
}

/// Get Halt Handle
#[must_use]
pub fn halt_handle(&self) -> Arc<AtomicBool> {
self.closed.clone()
}

/// Check If Windows Graphics Capture Api Is Supported
pub fn is_supported() -> Result<bool, Box<dyn Error + Send + Sync>> {
Ok(ApiInformation::IsApiContractPresentByMajor(
Expand Down
6 changes: 3 additions & 3 deletions windows-capture-python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pip install windows-capture
## Usage

```python
from windows_capture import WindowsCapture, Frame, CaptureControl
from windows_capture import WindowsCapture, Frame, InternalCaptureControl

# Every Error From on_closed and on_frame_arrived Will End Up Here
capture = WindowsCapture(
Expand All @@ -36,7 +36,7 @@ capture = WindowsCapture(

# Called Every Time A New Frame Is Available
@capture.event
def on_frame_arrived(frame: Frame, capture_control: CaptureControl):
def on_frame_arrived(frame: Frame, capture_control: InternalCaptureControl):
print("New Frame Arrived")

# Save The Frame As An Image To The Specified Path
Expand All @@ -48,7 +48,7 @@ def on_frame_arrived(frame: Frame, capture_control: CaptureControl):

# Called When The Capture Item Closes Usually When The Window Closes, Capture
# Session Will End After This Function Ends
@capture.on_closed
@capture.event
def on_closed():
print("Capture Session Closed")

Expand Down
Loading

0 comments on commit 6997866

Please sign in to comment.