-
Notifications
You must be signed in to change notification settings - Fork 467
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add initial fusion regression script and initial yaml config (#2611)
- Loading branch information
Showing
3 changed files
with
267 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Fusion Regression Test Setup | ||
|
||
This document provides instructions for setting up and downloading the necessary run files to perform fusion regression tests. | ||
|
||
|
||
## Perform two regression runs for test fusion-regression-bge-flat-int8-robust04-2 | ||
|
||
Generate the two runs needed for the test fusion-regression-bge-flat-int8-robust04-2 by following these regression guides: | ||
- https://github.com/castorini/anserini/blob/master/docs/regressions/regressions-beir-v1.0.0-robust04.bge-base-en-v1.5.flat-int8.cached.md | ||
- https://github.com/castorini/anserini/blob/master/docs/regressions/regressions-beir-v1.0.0-robust04.bge-base-en-v1.5.flat.cached.md | ||
|
||
## Run the fusion regression script with the YAML test configuration | ||
```bash | ||
python src/main/python/run_fusion_regression.py --regression fusion-regression-bge-flat-int8-robust04-2 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
# | ||
# Anserini: A Lucene toolkit for reproducible information retrieval research | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import os | ||
import argparse | ||
import logging | ||
import time | ||
import yaml | ||
from subprocess import call, Popen, PIPE | ||
|
||
# Constants
# Base shell command that every fusion invocation starts with.
FUSE_COMMAND = 'bin/run.sh io.anserini.fusion.FuseTrecRuns'


# Set up logging: a dedicated logger that writes INFO-and-above records
# through a stream handler using a timestamped format.
logger = logging.getLogger('fusion_regression_test')
formatter = logging.Formatter('%(asctime)s %(levelname)s [python] %(message)s')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
ch.setLevel(logging.INFO)
logger.addHandler(ch)
logger.setLevel(logging.INFO)
|
||
def is_close(a: float, b: float, rel_tol: float = 1e-9, abs_tol: float = 0.0) -> bool:
    """Return True when ``a`` and ``b`` differ by no more than the allowed tolerance.

    The allowed difference is the larger of ``abs_tol`` and ``rel_tol`` scaled
    by the larger magnitude of the two inputs.
    """
    larger_magnitude = max(abs(a), abs(b))
    allowed_gap = max(rel_tol * larger_magnitude, abs_tol)
    return abs(a - b) <= allowed_gap
|
||
def check_output(command: str) -> bytes:
    """Run a shell command and return its raw standard output.

    Args:
        command (str): Command line handed to the shell.
    Returns:
        bytes: The undecoded stdout of the command; callers decode it themselves.
    Raises:
        RuntimeError: If the command exits with a non-zero return code. The
            message carries whatever the command wrote to stderr.
    """
    # stderr has to be piped too: without it communicate() returns None for
    # the error stream and the failure message would only ever say "None".
    process = Popen(command, shell=True, stdout=PIPE, stderr=PIPE)
    output, err = process.communicate()
    if process.returncode != 0:
        detail = err.decode('utf-8', errors='replace').strip()
        raise RuntimeError(f"Command {command} failed with error: {detail}")
    return output
|
||
def construct_fusion_commands(yaml_data: dict) -> list:
    """
    Build one fusion command per method listed in the YAML configuration.

    Args:
        yaml_data (dict): The loaded YAML configuration; its 'runs' and
            'methods' entries are read.
    Returns:
        list: One argument list per method. Each list starts with
            FUSE_COMMAND followed by flag/value pairs.
    """
    # Options that are stringified and fall back to a default when the
    # method entry does not set them: (flag, config key, default).
    defaulted_options = (
        ('-k', 'k', 1000),
        ('-depth', 'depth', 1000),
        ('-rrf_k', 'rrf_k', 60),
        ('-alpha', 'alpha', 0.5),
    )

    commands = []
    for spec in yaml_data['methods']:
        argv = [
            FUSE_COMMAND,
            '-runs', ' '.join(yaml_data['runs']),
            '-output', spec.get('output'),
            '-method', spec.get('name', 'average'),
        ]
        for flag, key, fallback in defaulted_options:
            argv.extend((flag, str(spec.get(key, fallback))))
        commands.append(argv)
    return commands
|
||
def run_fusion_commands(cmds: list):
    """
    Execute each fusion command through the shell, logging what happens.

    A non-zero return code or an exception while launching is logged as an
    error; the remaining commands are still attempted.

    Args:
        cmds (list): List of fusion commands to run, each a list of string parts.
    """
    for argv in cmds:
        command_line = ' '.join(argv)
        logger.info(f'Running command: {command_line}')
        try:
            status = call(command_line, shell=True)
            if status:
                logger.error(f"Command failed with return code {status}: {command_line}")
        except Exception as exc:
            logger.error(f"Error executing command {command_line}: {exc}")
|
||
def evaluate_and_verify(yaml_data: dict, dry_run: bool):
    """
    Runs the evaluation and verification of the fusion results.

    For every (method, topic set, metric) combination an evaluation command is
    built and executed, the score is parsed from the last non-blank line of its
    output, and the rounded score is compared with the expected value stored in
    the method's 'results' entry. A score passes when it is close to, or higher
    than, the expected value. An evaluation command that fails, prints nothing,
    or prints something unparseable is counted as a failure instead of being
    skipped silently.

    Args:
        yaml_data (dict): The loaded YAML configuration; its 'methods',
            'topics' and 'metrics' entries are read.
        dry_run (bool): If True, output commands without executing them.
    """
    fail_str = '\033[91m[FAIL]\033[0m '
    ok_str = ' [OK] '
    failures = False
    verify_start = time.time()

    logger.info('=' * 10 + ' Verifying Fusion Results ' + '=' * 10)

    for method in yaml_data['methods']:
        output_runfile = str(method.get('output'))

        for i, topic_set in enumerate(yaml_data['topics']):
            qrel = topic_set.get('qrel')
            qrel_path = os.path.join('tools/topics-and-qrels', qrel) if qrel else ''

            for metric in yaml_data['metrics']:
                # Build evaluation command
                eval_cmd = ' '.join([
                    metric['command'],
                    metric.get('params') or '',
                    qrel_path,
                    output_runfile,
                ])

                if dry_run:
                    logger.info(eval_cmd)
                    continue

                try:
                    raw_output = check_output(eval_cmd).decode('utf-8')
                except Exception as e:
                    # A score that could not be computed must not count as a pass.
                    logger.error(f"Failed to execute evaluation command: {str(e)}")
                    failures = True
                    continue

                lines = [line for line in raw_output.split('\n') if line.strip()]
                if not lines:
                    logger.error(f"Evaluation command produced no output: {eval_cmd}")
                    failures = True
                    continue

                # The score is read from the last non-blank output line.
                out = lines[-1].strip()
                try:
                    eval_out = out.split(metric['separator'])[metric['parse_index']]
                    actual = round(float(eval_out), metric['metric_precision'])
                except (IndexError, ValueError) as e:
                    logger.error(f"Failed to parse evaluation output {out!r}: {e}")
                    failures = True
                    continue

                expected = round(method['results'][metric['metric']][i], metric['metric_precision'])
                result_str = (
                    f'expected: {expected:.4f} actual: {actual:.4f} (delta={abs(expected-actual):.4f}) - '
                    f'metric: {metric["metric"]:<8} method: {method["name"]} topics: {topic_set["id"]}'
                )

                # A score above the expected value is accepted as well.
                if is_close(expected, actual) or actual > expected:
                    logger.info(ok_str + result_str)
                else:
                    logger.error(fail_str + result_str)
                    failures = True

    # start_time only exists at module level when this file runs as a script;
    # fall back to this function's own start so an importing caller does not
    # hit a NameError.
    started = globals().get('start_time', verify_start)
    logger.info(f"Total execution time: {time.time() - started:.2f} seconds")
    if failures:
        logger.error(f'{fail_str}Some tests failed.')
    else:
        logger.info('All tests passed successfully!')
|
||
if __name__ == '__main__':
    # Kept at module level: evaluate_and_verify reads it for its timing log.
    start_time = time.time()

    # Command-line argument parsing
    parser = argparse.ArgumentParser(description='Run Fusion regression tests.')
    parser.add_argument('--regression', required=True, help='Name of the regression test configuration.')
    parser.add_argument('--dry-run', dest='dry_run', action='store_true',
                        help='Output commands without actual execution.')
    args = parser.parse_args()

    # Load YAML configuration
    config_path = f'src/main/resources/fuse_regression/{args.regression}.yaml'
    try:
        with open(config_path, encoding='utf-8') as f:
            yaml_data = yaml.safe_load(f)
    except (OSError, yaml.YAMLError) as e:
        # Covers a missing or unreadable file as well as malformed YAML.
        logger.error(f"Failed to load configuration file: {e}")
        # SystemExit instead of the site-provided exit(), which is meant for
        # interactive use and is absent when the site module is not loaded.
        raise SystemExit(1)

    # Construct the fusion command
    fusion_commands = construct_fusion_commands(yaml_data)

    # Run the fusion process
    if args.dry_run:
        # One log line per command; joining every command into a single line
        # made the individual invocations impossible to tell apart.
        for cmd_list in fusion_commands:
            logger.info(' '.join(cmd_list))
    else:
        run_fusion_commands(fusion_commands)

    # Evaluate and verify results
    evaluate_and_verify(yaml_data, args.dry_run)

    logger.info(f"Total execution time: {time.time() - start_time:.2f} seconds")
71 changes: 71 additions & 0 deletions
71
...s/fuse_regression/beir-v1.0.0-robust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
--- | ||
corpus: beir-v1.0.0-robust04 | ||
corpus_path: collections/beir-v1.0.0/corpus/robust04/ | ||
|
||
metrics: | ||
- metric: nDCG@10 | ||
command: bin/trec_eval | ||
params: -c -m ndcg_cut.10 | ||
separator: "\t" | ||
parse_index: 2 | ||
metric_precision: 4 | ||
can_combine: false | ||
- metric: R@100 | ||
command: bin/trec_eval | ||
params: -c -m recall.100 | ||
separator: "\t" | ||
parse_index: 2 | ||
metric_precision: 4 | ||
can_combine: false | ||
- metric: R@1000 | ||
command: bin/trec_eval | ||
params: -c -m recall.1000 | ||
separator: "\t" | ||
parse_index: 2 | ||
metric_precision: 4 | ||
can_combine: false | ||
|
||
topic_reader: TsvString | ||
topics: | ||
- name: "BEIR (v1.0.0): Robust04" | ||
id: test | ||
path: topics.beir-v1.0.0-robust04.test.tsv.gz | ||
qrel: qrels.beir-v1.0.0-robust04.test.txt | ||
|
||
# Fusion Regression Test Configuration | ||
runs: | ||
- runs/run.beir-v1.0.0-robust04.flat.bm25.topics.beir-v1.0.0-robust04.test.txt | ||
- runs/run.beir-v1.0.0-robust04.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt | ||
|
||
methods: | ||
- name: rrf | ||
k: 1000 | ||
depth: 1000 | ||
rrf_k: 60 | ||
output: runs/runs.fuse.rrf.beir-v1.0.0-robust04.flat.bm25.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt | ||
results: | ||
nDCG@10: | ||
- 0.5070 | ||
R@100: | ||
- 0.4465 | ||
R@1000: | ||
- 0.7219 | ||
- name: average | ||
output: runs/runs.fuse.avg.beir-v1.0.0-robust04.flat.bm25.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt | ||
results: | ||
nDCG@10: | ||
- 0.4324 | ||
R@100: | ||
- 0.3963 | ||
R@1000: | ||
- 0.6345 | ||
- name: interpolation | ||
alpha: 0.5 | ||
output: runs/runs.fuse.interp.beir-v1.0.0-robust04.flat.bm25.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt | ||
results: | ||
nDCG@10: | ||
- 0.4324 | ||
R@100: | ||
- 0.3963 | ||
R@1000: | ||
- 0.6345 |