Skip to content

Commit

Permalink
Merge branch 'main' into pos-data-verification
Browse files Browse the repository at this point in the history
  • Loading branch information
poszu authored Jul 19, 2023
2 parents 34c0ea5 + f6b4ba5 commit 44c7379
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 28 deletions.
10 changes: 5 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ members = [".", "ffi", "scrypt-ocl", "initializer", "profiler"]

[package]
name = "post-rs"
version = "0.4.0"
version = "0.4.1"
edition = "2021"

[lib]
Expand Down
2 changes: 1 addition & 1 deletion ffi/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "post-cbindings"
version = "0.4.0"
version = "0.4.1"
edition = "2021"


Expand Down
26 changes: 16 additions & 10 deletions ffi/src/initialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,16 @@ pub extern "C" fn get_providers_count() -> usize {
#[no_mangle]
pub extern "C" fn get_providers(out: *mut Provider, out_len: usize) -> InitializeResult {
if out.is_null() {
log::error!("out is null");
return InitializeResult::InitializeInvalidArgument;
}

let providers = if let Ok(p) = scrypt_ocl::get_providers(Some(DeviceType::GPU)) {
p
} else {
return InitializeResult::InitializeFailedToGetProviders;
let providers = match scrypt_ocl::get_providers(Some(DeviceType::GPU)) {
Ok(providers) => providers,
Err(e) => {
log::error!("failed to get providers: {e}");
return InitializeResult::InitializeFailedToGetProviders;
}
};

let out = unsafe { std::slice::from_raw_parts_mut(out, out_len) };
Expand Down Expand Up @@ -109,15 +112,18 @@ pub extern "C" fn initialize(
) -> InitializeResult {
// Convert end to exclusive
if end == u64::MAX {
log::error!("end must be < u64::MAX");
return InitializeResult::InitializeInvalidLabelsRange;
}
let end = end + 1;

let initializer = unsafe { &mut *(initializer as *mut InitializerWrapper) };
let len = if let Ok(len) = usize::try_from(end - start) {
len * 16
} else {
return InitializeResult::InitializeInvalidLabelsRange;
let len = match usize::try_from(end - start) {
Ok(len) => len * 16,
Err(e) => {
log::error!("failed to calculate number of labels to initialize: {e}");
return InitializeResult::InitializeInvalidLabelsRange;
}
};

let mut labels = unsafe { std::slice::from_raw_parts_mut(out_buffer, len) };
Expand All @@ -129,7 +135,7 @@ pub extern "C" fn initialize(
) {
Ok(nonce) => nonce,
Err(e) => {
log::error!("Error initializing labels: {e:?}");
log::error!("error initializing labels: {e:?}");
return InitializeResult::InitializeError;
}
};
Expand All @@ -154,7 +160,7 @@ pub extern "C" fn new_initializer(
match _new_initializer(provider_id, n, commitment, vrf_difficulty) {
Ok(initializer) => Box::into_raw(initializer) as _,
Err(e) => {
log::error!("Error creating initializer: {e:?}");
log::error!("error creating initializer: {e:?}");
std::ptr::null_mut()
}
}
Expand Down
2 changes: 1 addition & 1 deletion initializer/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "initializer"
version = "0.4.0"
version = "0.4.1"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
2 changes: 1 addition & 1 deletion profiler/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "profiler"
version = "0.4.0"
version = "0.4.1"
edition = "2021"

[dependencies]
Expand Down
2 changes: 1 addition & 1 deletion scrypt-ocl/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "scrypt-ocl"
version = "0.4.0"
version = "0.4.1"
edition = "2021"

[dependencies]
Expand Down
27 changes: 19 additions & 8 deletions scrypt-ocl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,13 @@ impl Display for Provider {
}

pub fn get_providers_count(device_types: Option<DeviceType>) -> usize {
get_providers(device_types).map_or(0, |p| p.len())
get_providers(device_types).map_or_else(
|e| {
log::error!("failed to get providers: {e}");
0
},
|p| p.len(),
)
}

pub fn get_providers(device_types: Option<DeviceType>) -> Result<Vec<Provider>, ScryptError> {
Expand Down Expand Up @@ -143,7 +149,7 @@ impl Scrypter {
DeviceInfoResult::MaxComputeUnits
);
let max_wg_size = device.max_wg_size()?;
log::debug!(
log::info!(
"device memory: {} MB, max_mem_alloc_size: {} MB, max_compute_units: {max_compute_units}, max_wg_size: {max_wg_size}",
device_memory / 1024 / 1024,
max_mem_alloc_size / 1024 / 1024,
Expand Down Expand Up @@ -177,7 +183,7 @@ impl Scrypter {
);
let kernel_wg_size = kernel.wg_info(device, KernelWorkGroupInfo::WorkGroupSize)?;

log::debug!("preferred_wg_size_multiple: {preferred_wg_size_mult}, kernel_wg_size: {kernel_wg_size}");
log::info!("preferred_wg_size_multiple: {preferred_wg_size_mult}, kernel_wg_size: {kernel_wg_size}");

let max_global_work_size_based_on_total_mem =
((device_memory - INPUT_SIZE as u64) / kernel_memory as u64) as usize;
Expand All @@ -190,27 +196,27 @@ impl Scrypter {
let local_work_size = preferred_wg_size_mult;
// Round down to nearest multiple of local_work_size
let global_work_size = (max_global_work_size / local_work_size) * local_work_size;
log::debug!(
log::info!(
"Using: global_work_size: {global_work_size}, local_work_size: {local_work_size}"
);

log::trace!("Allocating buffer for input: {INPUT_SIZE} bytes");
log::info!("Allocating buffer for input: {INPUT_SIZE} bytes");
let input = Buffer::<u32>::builder()
.len(INPUT_SIZE / 4)
.flags(MemFlags::new().read_only())
.queue(pro_que.queue().clone())
.build()?;

let output_size = global_work_size * ENTIRE_LABEL_SIZE;
log::trace!("Allocating buffer for output: {output_size} bytes");
log::info!("Allocating buffer for output: {output_size} bytes");
let output = Buffer::<u8>::builder()
.len(output_size)
.flags(MemFlags::new().write_only())
.queue(pro_que.queue().clone())
.build()?;

let lookup_size = global_work_size * kernel_lookup_mem_size;
log::trace!("Allocating buffer for lookup: {lookup_size} bytes");
log::info!("Allocating buffer for lookup: {lookup_size} bytes");
let lookup_memory = Buffer::<u32>::builder()
.len(lookup_size / 4)
.flags(MemFlags::new().host_no_access())
Expand Down Expand Up @@ -312,6 +318,11 @@ impl OpenClInitializer {
) -> Result<Self, ScryptError> {
let providers = get_providers(device_types)?;
let provider = if let Some(id) = provider_id {
log::info!(
"selecting {} provider from {} available",
id.0,
providers.len()
);
providers
.get(id.0 as usize)
.ok_or(ScryptError::InvalidProviderId(id))?
Expand All @@ -320,7 +331,7 @@ impl OpenClInitializer {
};
let platform = provider.platform;
let device = provider.device;
log::trace!("Using provider: {provider}");
log::info!("Using provider: {provider}");

let scrypter = Scrypter::new(platform, device, n)?;

Expand Down

0 comments on commit 44c7379

Please sign in to comment.