diff --git a/example/xtest_constituents.py b/example/xtest_constituents.py index 8747d1f..2e67936 100644 --- a/example/xtest_constituents.py +++ b/example/xtest_constituents.py @@ -15,9 +15,6 @@ from bktest.export import DumpExporter from bktest.export import QuantStatsExporter - - - start = datetime.date(1998,11,1) end = datetime.date(2000,11,30) initial_cash = 1_000_000 @@ -87,11 +84,11 @@ def compare_two_files_containing_dataframes(filename1, filename2) -> bool: assert filename1.split('.')[1] == filename2.split('.')[1], "mismatch in file types" if filename1.split('.')[1] == 'parquet': - df0 = pandas.read_parquet('dump.parquet') - df1 = pandas.read_parquet('example/test_constituents_dump.parquet') + df0 = pandas.read_parquet(filename1) + df1 = pandas.read_parquet(filename2) elif filename1.split('.')[1] == 'csv': - df0 = pandas.read_csv('dump.parquet') - df1 = pandas.read_csv('example/test_constituents_dump.parquet') + df0 = pandas.read_csv(filename1) + df1 = pandas.read_csv(filename2) else: assert False, "invalid file format specified" @@ -104,8 +101,10 @@ def compare_two_files_containing_dataframes(filename1, filename2) -> bool: assert (df0[~df0[column].isna()][column] == df1[~df1[column].isna()][column]).all(), "not all non-NaN are equal in comparing int" elif pandas.api.types.is_float_dtype(df0[column]): print(f"Column {column} is of float type.") + print(max(abs(df0[~df0[column].isna()][column]/df1[~df1[column].isna()][column]-1))) + print(max(abs(df0[~df0[column].isna()][column]-df1[~df1[column].isna()][column]))) assert (df0[column].isna()==df1[column].isna()).all(), "not all NaN are equal in comparing float" - assert np.allclose(df0[~df0[column].isna()][column],df1[~df1[column].isna()][column]), "not all non-NaN are equal in comparing float" + assert np.allclose(df0[~df0[column].isna()][column],df1[~df1[column].isna()][column],atol=1.e-7), "not all non-NaN are equal in comparing float" elif pandas.api.types.is_string_dtype(df0[column]): print(f"Column {column} is of string type.") assert (df0[column] == df1[column]).all(), "not all string elements are equal in comparing str" @@ -118,4 +117,4 @@ def compare_two_files_containing_dataframes(filename1, filename2) -> bool: return True -compare_two_files_containing_dataframes('dump.parquet', 'example/test_constituents_dump.parquet') +compare_two_files_containing_dataframes('dump_last.csv', 'tests/fixtures/integration/yahoo/prices/dump.csv') \ No newline at end of file diff --git a/tests/helper.py b/tests/helper.py index 2654a20..e403311 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -21,6 +21,7 @@ def assertDataFramesEqual(self: unittest.TestCase, first: pandas.DataFrame, seco self.assertTrue(numpy.allclose( first[~first[column].isna()][column], second[~second[column].isna()][column], + atol=1.e-7, equal_nan=True ), f"{column}: not all non-NaN are equal in comparing float")