Skip to content

Commit

Permalink
chore: add the ability to load a data source containing returns
Browse files Browse the repository at this point in the history
  • Loading branch information
MetzkerLior committed May 28, 2024
1 parent 9b92a3d commit e0d3061
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 20 deletions.
4 changes: 3 additions & 1 deletion bktest/data/source/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@


class DataSource(metaclass=abc.ABCMeta):

# Store on the instance whether this data source yields prices or returns.
# data_source_contains_prices_not_returns: True (the default) when the fetched
# values are prices; False when they are already returns.
def __init__(self, data_source_contains_prices_not_returns = True) -> None:
self.data_source_contains_prices_not_returns = data_source_contains_prices_not_returns

@abc.abstractmethod
def fetch_prices(
self,
Expand Down
4 changes: 3 additions & 1 deletion bktest/data/source/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def __init__(
date_column=constants.DEFAULT_DATE_COLUMN,
symbol_column=constants.DEFAULT_SYMBOL_COLUMN,
price_column=constants.DEFAULT_PRICE_COLUMN,
closeable=True
closeable=True,
data_source_contains_prices_not_returns=True # True for price, False for returns.
) -> None:
super().__init__()

Expand All @@ -33,6 +34,7 @@ def __init__(

self.dataframe = dataframe
self.closeable = closeable
self.data_source_contains_prices_not_returns = data_source_contains_prices_not_returns

def fetch_prices(self, symbols, start, end):
symbols = set(symbols)
Expand Down
32 changes: 14 additions & 18 deletions bktest/price_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .data.source.base import DataSource
from . import constants

import copy

class SymbolMapper:

Expand Down Expand Up @@ -78,9 +79,6 @@ def __init__(self, start: datetime.date, end: datetime.date, data_source: DataSo
self.caching = caching

self.storage = PriceProvider._create_storage(start, end, caching)
self.close_price = pandas.DataFrame()
self.adj_close_price = pandas.DataFrame()
#self.total_returns = pandas.DataFrame()
self.total_returns = PriceProvider._create_storage(start, end, caching, name='returns')
self.symbols = PriceProvider._create_symbols_set(self.storage)

Expand All @@ -98,17 +96,10 @@ def download_missing(self, symbols: typing.Set[str]):

prices = self.data_source.fetch_prices(
symbols=self.mapper.maps(missing_symbols),
start=self.start - one_day, # Not enough since day before is not necessarily a trading day... or it's ok... because first day is the base
start=self.start - one_day, # Not enough since day before is not necessarily a trading day... or it's ok... because first day is the base?
end=self.end + one_day
)

# self.close_price = self.data_source.fetch_prices(
# symbols=self.mapper.maps(missing_symbols),
# start=self.start - one_day, # Not enough since day before is not necessarily a trading day... or it's ok... because first day is the base
# end=self.end + one_day,
# "Close",
# )

if prices is None:
prices = pandas.DataFrame(
index=pandas.Index([], name=constants.DEFAULT_DATE_COLUMN),
Expand All @@ -135,12 +126,19 @@ def download_missing(self, symbols: typing.Set[str]):
))

prices.columns = self.mapper.unmaps(prices.columns)

for column in prices.columns:
if prices[column].isna().values.all():
print(f"[warning] {column} does not have a price", file=sys.stderr)




if self.data_source.data_source_contains_prices_not_returns:
total_returns = prices/prices.shift(1) -1
else:
total_returns = copy.deepcopy(prices)
# If a return exists for a given stock on a given day, the stock traded that day. Because the source supplies returns rather than prices, the price is set to 1 wherever a return exists and left as NaN otherwise.
prices = (prices.abs() + 1e-6)/(prices.abs() + 1e-6)
assert (prices.isna()==total_returns.isna()).all().all()

if self.storage is not None:
with warnings.catch_warnings():
warnings.simplefilter(action='ignore', category=pandas.errors.PerformanceWarning)
Expand All @@ -153,9 +151,7 @@ def download_missing(self, symbols: typing.Set[str]):
)
else:
self.storage = prices

total_returns = prices/prices.shift(1) -1


if self.total_returns is not None:
with warnings.catch_warnings():
warnings.simplefilter(action='ignore', category=pandas.errors.PerformanceWarning)
Expand Down Expand Up @@ -191,7 +187,7 @@ def get_total_return(self, date: datetime.date, symbol: str):
symbol = self.mapper.map(symbol)

value = self.total_returns[symbol][numpy.datetime64(date)]
if not value or numpy.isnan(value):
if numpy.isnan(value):
value = None

return value
Expand Down

0 comments on commit e0d3061

Please sign in to comment.