From e0d3061897c38a4cb3530e49fe654195786bc960 Mon Sep 17 00:00:00 2001 From: MetzkerLior Date: Tue, 28 May 2024 14:30:41 +0300 Subject: [PATCH] chore: add the ability to load a data source containing returns --- bktest/data/source/base.py | 4 +++- bktest/data/source/dataframe.py | 4 +++- bktest/price_provider.py | 32 ++++++++++++++------------------ 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/bktest/data/source/base.py b/bktest/data/source/base.py index 6b274d0..5842081 100644 --- a/bktest/data/source/base.py +++ b/bktest/data/source/base.py @@ -6,7 +6,9 @@ class DataSource(metaclass=abc.ABCMeta): - + def __init__(self, data_source_contains_prices_not_returns = True) -> None: + self.data_source_contains_prices_not_returns = data_source_contains_prices_not_returns + @abc.abstractmethod def fetch_prices( self, diff --git a/bktest/data/source/dataframe.py b/bktest/data/source/dataframe.py index f04bc4b..e70dac6 100644 --- a/bktest/data/source/dataframe.py +++ b/bktest/data/source/dataframe.py @@ -13,7 +13,8 @@ def __init__( date_column=constants.DEFAULT_DATE_COLUMN, symbol_column=constants.DEFAULT_SYMBOL_COLUMN, price_column=constants.DEFAULT_PRICE_COLUMN, - closeable=True + closeable=True, + data_source_contains_prices_not_returns=True # True for price, False for returns. ) -> None: super().__init__() @@ -33,6 +34,7 @@ def __init__( self.dataframe = dataframe self.closeable = closeable + self.data_source_contains_prices_not_returns = data_source_contains_prices_not_returns def fetch_prices(self, symbols, start, end): symbols = set(symbols) diff --git a/bktest/price_provider.py b/bktest/price_provider.py index 68eaad0..e1824ce 100644 --- a/bktest/price_provider.py +++ b/bktest/price_provider.py @@ -11,6 +11,7 @@ from .data.source.base import DataSource from . import constants +import copy class SymbolMapper: @@ -78,9 +79,6 @@ def __init__(self, start: datetime.date, end: datetime.date, data_source: DataSo self.caching = caching self.storage = PriceProvider._create_storage(start, end, caching) - self.close_price = pandas.DataFrame() - self.adj_close_price = pandas.DataFrame() - #self.total_returns = pandas.DataFrame() self.total_returns = PriceProvider._create_storage(start, end, caching, name='returns') self.symbols = PriceProvider._create_symbols_set(self.storage) @@ -98,17 +96,10 @@ def download_missing(self, symbols: typing.Set[str]): prices = self.data_source.fetch_prices( symbols=self.mapper.maps(missing_symbols), - start=self.start - one_day, # Not enough since day before is not necessarily a trading day... or it's ok... because first day is the base + start=self.start - one_day, # Not enough since day before is not necessarily a trading day... or it's ok... because first day is the base? end=self.end + one_day ) -# self.close_price = self.data_source.fetch_prices( -# symbols=self.mapper.maps(missing_symbols), -# start=self.start - one_day, # Not enough since day before is not necessarily a trading day... or it's ok... because first day is the base -# end=self.end + one_day, -# "Close", -# ) - if prices is None: prices = pandas.DataFrame( index=pandas.Index([], name=constants.DEFAULT_DATE_COLUMN), @@ -135,12 +126,19 @@ def download_missing(self, symbols: typing.Set[str]): )) prices.columns = self.mapper.unmaps(prices.columns) + for column in prices.columns: if prices[column].isna().values.all(): print(f"[warning] {column} does not have a price", file=sys.stderr) - - - + + if self.data_source.data_source_contains_prices_not_returns: + total_returns = prices/prices.shift(1) -1 + else: + total_returns = copy.deepcopy(prices) + # If return for a specific stock exists, there was a price/trading in that day and since we work with returns the price is set to one else it is NaN. + prices = (prices.abs() + 1e-6)/(prices.abs() + 1e-6) + assert (prices.isna()==total_returns.isna()).all().all() + if self.storage is not None: with warnings.catch_warnings(): warnings.simplefilter(action='ignore', category=pandas.errors.PerformanceWarning) @@ -153,9 +151,7 @@ def download_missing(self, symbols: typing.Set[str]): ) else: self.storage = prices - - total_returns = prices/prices.shift(1) -1 - + if self.total_returns is not None: with warnings.catch_warnings(): warnings.simplefilter(action='ignore', category=pandas.errors.PerformanceWarning) @@ -191,7 +187,7 @@ def get_total_return(self, date: datetime.date, symbol: str): symbol = self.mapper.map(symbol) value = self.total_returns[symbol][numpy.datetime64(date)] - if not value or numpy.isnan(value): + if numpy.isnan(value): value = None return value