Skip to content

Commit

Permalink
fix memory
Browse files Browse the repository at this point in the history
  • Loading branch information
qiweiii committed Aug 12, 2024
1 parent 245334b commit 22dcbce
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 14 deletions.
41 changes: 30 additions & 11 deletions Utils/Sources/Utils/ErasureCoding.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,32 @@ public enum ErasureCodeError: Error {
case reconstructFailed
}

public class Segment {
public var csegment: CSegment

public let data: Data
public let index: Int

public init?(data: Data, index: UInt32) {
guard data.count == SEGMENT_SIZE else {
return nil
}
csegment = CSegment(
data: UnsafeMutablePointer(mutating: data.withUnsafeBytes { $0.baseAddress!.assumingMemoryBound(to: UInt8.self) }),
index: index
)
self.data = Data(bytes: csegment.data, count: Int(SEGMENT_SIZE))
self.index = Int(csegment.index)
}

deinit {
csegment_data_free(&csegment)
}
}

/// Split original data into segments
public func split(data: Data) -> [CSegment] {
var segments: [CSegment] = []
public func split(data: Data) -> [Segment] {
var segments: [Segment] = []
let segmentSize = Int(SEGMENT_SIZE)

// Create a new data with padding
Expand All @@ -25,24 +48,20 @@ public func split(data: Data) -> [CSegment] {
let segmentData = paddedData[i ..< end]
let index = UInt32(i / segmentSize)

let segment = CSegment(
data: UnsafeMutablePointer(mutating: segmentData.withUnsafeBytes { $0.baseAddress!.assumingMemoryBound(to: UInt8.self) }),
index: index
)
let segment = Segment(data: segmentData, index: index)!
segments.append(segment)
}

return segments
}

/// Join segments into original data (padding not removed)
private func join(segments: [CSegment]) -> Data {
public func join(segments: [Segment]) -> Data {
var data = Data()
let sortedSegments = segments.sorted { $0.index < $1.index }

for segment in sortedSegments {
let segmentData = UnsafeBufferPointer(start: segment.data, count: Int(SEGMENT_SIZE))
data.append(segmentData)
data.append(segment.data)
}

return data
Expand All @@ -60,14 +79,14 @@ public class SubShardEncoder {
}

/// Construct erasure-coded chunks from segments
public func construct(segments: [CSegment]) -> Result<[UInt8], ErasureCodeError> {
public func construct(segments: [Segment]) -> Result<[UInt8], ErasureCodeError> {
var success = false
var out_len: UInt = 0

let expectedOutLen = Int(SUBSHARD_SIZE) * Int(TOTAL_SHARDS) * segments.count
var out_chunks = [UInt8](repeating: 0, count: expectedOutLen)

segments.withUnsafeBufferPointer { segmentsPtr in
segments.map(\.csegment).withUnsafeBufferPointer { segmentsPtr in
subshard_encoder_construct(encoder, segmentsPtr.baseAddress, UInt(segments.count), &success, &out_chunks, &out_len)
}

Expand Down
5 changes: 5 additions & 0 deletions Utils/Sources/erasure-coding/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ typedef struct SubShardTuple {
*/
typedef uint8_t SubShard[SUBSHARD_SIZE];

/**
* Frees CSegment's data.
*/
void csegment_data_free(struct CSegment *c_segment);

/**
* Initializes a new SubShardEncoder.
*/
Expand Down
20 changes: 17 additions & 3 deletions Utils/Sources/erasure-coding/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ impl From<Segment> for CSegment {
index: segment.index,
};

// prevent Rust from freeing the Vec while CSegment is in use
std::mem::forget(vec_data);

c_segment
Expand All @@ -42,6 +41,21 @@ impl From<CSegment> for Segment {
}
}

/// Frees CSegment's data.
#[no_mangle]
pub extern "C" fn csegment_data_free(c_segment: *mut CSegment) {
if !c_segment.is_null() {
let csegment = unsafe { &*c_segment };

if !csegment.data.is_null() {
unsafe {
let vec_data = Vec::from_raw_parts(csegment.data, SEGMENT_SIZE, SEGMENT_SIZE);
drop(vec_data);
}
}
}
}

/// Initializes a new SubShardEncoder.
#[no_mangle]
pub extern "C" fn subshard_encoder_new() -> *mut SubShardEncoder {
Expand Down Expand Up @@ -101,7 +115,6 @@ pub extern "C" fn subshard_encoder_construct(
*out_len = total_subshards;
}

std::mem::forget(data);
unsafe { *success = true };
}
Err(_) => {
Expand Down Expand Up @@ -202,7 +215,8 @@ pub extern "C" fn subshard_decoder_reconstruct(
let segments_len = segments_vec.len();
let segments_ptr = segments_vec.as_mut_ptr();

std::mem::forget(segments_vec); // prevent the Vec from being deallocated
// prevent the Vec from being deallocated, will be freed in reconstruct_result_free
std::mem::forget(segments_vec);

let result = ReconstructResult {
segments: segments_ptr,
Expand Down

0 comments on commit 22dcbce

Please sign in to comment.