diff --git a/compiler/fm-equalize/fm-equalize b/compiler/fm-equalize/fm-equalize index 4e4e5395e18..36b4f99a003 100644 --- a/compiler/fm-equalize/fm-equalize +++ b/compiler/fm-equalize/fm-equalize @@ -62,6 +62,18 @@ def _get_parser(): help="Allow to create duplicate operations when a feature map matches " "with multiple equalization patterns. This can increase the size of " "the model. Default is false.") + parser.add_argument("--fme_detect", + type=str, + help="Path to fme-detect driver.", + required=False) + parser.add_argument("--dalgona", + type=str, + help="Path to dalgona driver.", + required=False) + parser.add_argument("--fme_apply", + type=str, + help="Path to fme-apply driver.", + required=False) parser.add_argument('--verbose', action='store_true', help='Print logs') return parser @@ -78,12 +90,9 @@ def _run_cmd(cmd: str, verbose: bool): raise -def _run_dalgona(model: str, data: Optional[str], analysis: str, save_dir: str, - verbose: bool): - dir_path = os.getenv('ONE_BIN_PATH') - assert dir_path != None - dalgona_path = os.path.join(dir_path, 'dalgona') - cmd = [dalgona_path] +def _run_dalgona(driver_path: str, model: str, data: Optional[str], analysis: str, + save_dir: str, verbose: bool): + cmd = [driver_path] cmd += ['--input_model', model] cmd += ['--analysis', analysis] if data != None: @@ -94,11 +103,9 @@ def _run_dalgona(model: str, data: Optional[str], analysis: str, save_dir: str, _run_cmd(cmd, verbose) -def _run_fme_detect(input_model: str, fme_patterns: str, verbose: bool, +def _run_fme_detect(driver_path: str, input_model: str, fme_patterns: str, verbose: bool, allow_dup_op: bool): - dir_path = Path(__file__).parent.resolve() - fme_detect_path = os.path.join(dir_path, 'fme-detect') - cmd = [fme_detect_path] + cmd = [driver_path] cmd += ['--input', input_model] cmd += ['--output', fme_patterns] if allow_dup_op: @@ -107,10 +114,9 @@ def _run_fme_detect(input_model: str, fme_patterns: str, verbose: bool, _run_cmd(cmd, verbose) -def _run_fme_apply(input_model: str, fme_patterns: str, output_model: str, verbose: bool): - dir_path = Path(__file__).parent.resolve() - fme_apply_path = os.path.join(dir_path, 'fme-apply') - cmd = [fme_apply_path] +def _run_fme_apply(driver_path: str, input_model: str, fme_patterns: str, + output_model: str, verbose: bool): + cmd = [driver_path] cmd += ['--input', input_model] cmd += ['--fme_patterns', fme_patterns] cmd += ['--output', output_model] @@ -128,6 +134,25 @@ def main(): data = args.data verbose = args.verbose allow_dup_op = args.allow_dup_op + fme_detect_path = args.fme_detect + fme_apply_path = args.fme_apply + dalgona_path = args.dalgona + + curr_dir = Path(__file__).parent.resolve() + dump_fme_param_py = curr_dir / 'fmelib' / 'DumpFMEParams.py' + if dump_fme_param_py.exists() == False: + raise FileNotFoundError('Error: DumpFMEParams.py not found') + + if not fme_detect_path: + dir_path = Path(__file__).parent.resolve() + fme_detect_path = os.path.join(dir_path, 'fme-detect') + if not dalgona_path: + dir_path = os.getenv('ONE_BIN_PATH') + assert dir_path != None + dalgona_path = os.path.join(dir_path, 'dalgona') + if not fme_apply_path: + dir_path = Path(__file__).parent.resolve() + fme_apply_path = os.path.join(dir_path, 'fme-apply') with tempfile.TemporaryDirectory() as tmp_dir: fme_patterns = os.path.join( @@ -135,7 +160,8 @@ def main(): Path(output_model).with_suffix('.fme_patterns.json').name) # Step 1. Run fme-detect to find equalization patterns - _run_fme_detect(str(input_model), + _run_fme_detect(fme_detect_path, + str(input_model), str(fme_patterns), verbose=verbose, allow_dup_op=allow_dup_op) @@ -144,8 +170,13 @@ def main(): if args.fme_patterns != None: os.system(f'cp {fme_patterns} {args.fme_patterns}') - # TODO Step 2. Run dalgona - # _run_dalgona + # Step 2. Run dalgona + _run_dalgona(dalgona_path, + str(input_model), + data, + str(dump_fme_param_py), + str(fme_patterns), + verbose=verbose) # Copy fme_patterns to the given path # Why copy twice? To observe the result of fme-detect too @@ -153,7 +184,8 @@ def main(): os.system(f'cp {fme_patterns} {args.fme_patterns}') # Step 3. Run fme-apply - _run_fme_apply(str(input_model), + _run_fme_apply(fme_apply_path, + str(input_model), str(fme_patterns), str(output_model), verbose=verbose)