Skip to content

Commit

Permalink
Add tqdm support
Browse files Browse the repository at this point in the history
  • Loading branch information
mjain-jump committed Apr 26, 2024
1 parent 5893352 commit 4048d4c
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 16 deletions.
4 changes: 4 additions & 0 deletions src/test_suite/multiprocessing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
import os


def lazy_starmap(args, function):
return function(*args)


def process_instruction(
library: ctypes.CDLL, serialized_instruction_context: str
) -> pb.InstrEffects | None:
Expand Down
105 changes: 89 additions & 16 deletions src/test_suite/test_suite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import Counter
import functools
import shutil
from typing import List
import typer
Expand All @@ -16,6 +17,7 @@
decode_single_test_case,
generate_test_case,
initialize_process_output_buffers,
lazy_starmap,
merge_results_over_iterations,
process_instruction,
process_single_test_case,
Expand All @@ -26,6 +28,7 @@
import test_suite.globals as globals
from test_suite.debugger import debug_host
import resource
import tqdm


app = typer.Typer(
Expand Down Expand Up @@ -266,10 +269,17 @@ def minimize_tests(

globals.feature_pool = get_feature_pool(lib)

num_test_cases = len(list(input_dir.iterdir()))

minimize_results = []
with Pool(
processes=num_processes, initializer=initialize_process_output_buffers
) as pool:
minimize_results = pool.map(minimize_single_test_case, input_dir.iterdir())
for result in tqdm.tqdm(
pool.imap(minimize_single_test_case, input_dir.iterdir()),
total=num_test_cases,
):
minimize_results.append(result)

lib.sol_compat_fini()
print("-" * LOG_FILE_SEPARATOR_LENGTH)
Expand Down Expand Up @@ -316,29 +326,57 @@ def create_fixtures(
lib.sol_compat_init()
globals.target_libraries[solana_shared_library] = lib

num_test_cases = len(list(input.dir()))

# Generate the test cases in parallel from files on disk
print("Reading test files...")
execution_contexts = []
with Pool(processes=num_processes) as pool:
execution_contexts = pool.map(generate_test_case, input_dir.iterdir())
for result in tqdm.tqdm(
pool.imap(generate_test_case, input_dir.iterdir()), total=num_test_cases
):
execution_contexts.append(result)

# Process the test cases in parallel through shared libraries
print("Executing tests...")
execution_results = []
with Pool(
processes=num_processes, initializer=initialize_process_output_buffers
) as pool:
execution_results = pool.starmap(process_single_test_case, execution_contexts)
for result in tqdm.tqdm(
pool.imap(
functools.partial(lazy_starmap, function=process_single_test_case),
execution_contexts,
),
total=num_test_cases,
):
execution_results.append(result)

print("Creating fixtures...")
# Prune effects and create fixtures
print("Creating fixtures...")
execution_fixtures = []
with Pool(processes=num_processes) as pool:
execution_fixtures = pool.starmap(
create_fixture, zip(execution_contexts, execution_results)
)
for result in tqdm.tqdm(
pool.imap(
functools.partial(lazy_starmap, function=create_fixture),
zip(execution_contexts, execution_results),
),
total=num_test_cases,
):
execution_fixtures.append(result)

# Write fixtures to disk
print("Writing results to disk...")
write_results = []
with Pool(processes=num_processes) as pool:
write_results = pool.starmap(write_fixture_to_disk, execution_fixtures)
for result in tqdm.tqdm(
pool.imap(
functools.partial(lazy_starmap, function=write_fixture_to_disk),
execution_fixtures,
),
total=num_test_cases,
):
write_results.append(result)

# Clean up
print("Cleaning up...")
Expand Down Expand Up @@ -408,31 +446,59 @@ def run_tests(
log_dir = globals.output_dir / target.stem
log_dir.mkdir(parents=True, exist_ok=True)

num_test_cases = len(list(input_dir.iterdir()))

# Generate the test cases in parallel from files on disk
execution_contexts = []
print("Reading test files...")
with Pool(processes=num_processes) as pool:
execution_contexts = pool.map(generate_test_case, input_dir.iterdir())
for result in tqdm.tqdm(
pool.imap(generate_test_case, input_dir.iterdir()), total=num_test_cases
):
execution_contexts.append(result)

# Process the test cases in parallel through shared libraries
print("Executing tests...")
execution_results = []
with Pool(
processes=num_processes,
initializer=initialize_process_output_buffers,
initargs=(randomize_output_buffer,),
) as pool:
execution_results = pool.starmap(process_single_test_case, execution_contexts)
for result in tqdm.tqdm(
pool.imap(
functools.partial(lazy_starmap, function=process_single_test_case),
execution_contexts,
),
total=num_test_cases,
):
execution_results.append(result)

print("Pruning results...")
# Prune modified accounts that were not actually modified
print("Pruning results...")
pruned_execution_results = []
with Pool(processes=num_processes) as pool:
pruned_execution_results = pool.starmap(
prune_execution_result, zip(execution_contexts, execution_results)
)
for result in tqdm.tqdm(
pool.imap(
functools.partial(lazy_starmap, function=prune_execution_result),
zip(execution_contexts, execution_results),
),
total=num_test_cases,
):
pruned_execution_results.append(result)

# Process the test results in parallel
print("Building test results...")
test_case_results = []
with Pool(processes=num_processes) as pool:
test_case_results = pool.starmap(build_test_results, pruned_execution_results)
for result in tqdm.tqdm(
pool.imap(
functools.partial(lazy_starmap, function=build_test_results),
pruned_execution_results,
),
total=num_test_cases,
):
test_case_results.append(result)

print("Logging results...")
passed = 0
Expand Down Expand Up @@ -504,8 +570,15 @@ def decode_protobuf(
shutil.rmtree(globals.output_dir)
globals.output_dir.mkdir(parents=True, exist_ok=True)

num_test_cases = len(list(input_dir.iterdir()))

write_results = []
with Pool(processes=num_processes) as pool:
write_results = pool.map(decode_single_test_case, input_dir.iterdir())
for result in tqdm.tqdm(
pool.imap(decode_single_test_case, input_dir.iterdir()),
total=num_test_cases,
):
write_results.append(result)

print("-" * LOG_FILE_SEPARATOR_LENGTH)
print(f"{len(write_results)} total files seen")
Expand Down

0 comments on commit 4048d4c

Please sign in to comment.