Skip to content

Commit

Permalink
refacto!: flatten structure
Browse files Browse the repository at this point in the history
  • Loading branch information
Caceresenzo committed Jan 15, 2024
1 parent e737534 commit 615f613
Show file tree
Hide file tree
Showing 17 changed files with 167 additions and 197 deletions.
3 changes: 1 addition & 2 deletions backtest/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from .data.source.base import DataSource
from .export import BaseExporter, Snapshot
from .fee import ConstantFeeModel, FeeModel
from .order import Order, OrderResult
from .order.provider import OrderProvider
from .order import Order, OrderResult, OrderProvider
from .price_provider import PriceProvider, SymbolMapper


Expand Down
8 changes: 4 additions & 4 deletions backtest/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def main(

quantity_in_decimal = quantity_mode == "percent"

from .order.provider import DataFrameOrderProvider
from .order import DataFrameOrderProvider
order_provider = DataFrameOrderProvider(
readwrite.read(order_file),
offset_before_trading,
Expand Down Expand Up @@ -168,9 +168,9 @@ def main(
)

if file_parquet:
from .data.source.file import RowParquetFileDataSource
file_data_source = RowParquetFileDataSource(
path=file_parquet,
from .data.source import DataFrameDataSource
file_data_source = DataFrameDataSource(
path=readwrite.read(file_parquet),
date_column=file_parquet_column_date,
symbol_column=file_parquet_column_symbol,
price_column=file_parquet_column_price
Expand Down
1 change: 1 addition & 0 deletions backtest/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
DEFAULT_DATE_COLUMN = "date"
DEFAULT_SYMBOL_COLUMN = "symbol"
DEFAULT_QUANTITY_COLUMN = "quantity"
DEFAULT_PRICE_COLUMN = "price"
2 changes: 0 additions & 2 deletions backtest/data/holidays.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import datetime

import dateutil.parser

holidays = list(map(lambda x: dateutil.parser.parse(x).date(), [
Expand Down
6 changes: 3 additions & 3 deletions backtest/data/source/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .base import DataSource
from .coinmarketcap import CoinMarketCapDataSource
from .file import RowParquetFileDataSource
from .yahoo import YahooDataSource
from .dataframe import DataFrameDataSource
from .delegate import DelegateDataSource
from .factset import FactsetDataSource
from .factset import FactsetDataSource
from .yahoo import YahooDataSource
19 changes: 11 additions & 8 deletions backtest/data/source/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,25 @@
import pandas


class DataSource:
class DataSource(metaclass=abc.ABCMeta):

@abc.abstractmethod
def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: datetime.date) -> pandas.DataFrame:
return None
def fetch_prices(
self,
symbols: typing.Set[str],
start: datetime.date,
end: datetime.date
) -> pandas.DataFrame:
raise NotImplementedError()

@abc.abstractmethod
def is_closeable(self) -> bool:
"""
Return whether or not the markat has closing hours.
Cryptocurrencies for examples does not.
"""

return True

@abc.abstractmethod
def get_name(self) -> str:
base_name = DataSource.__name__

Expand All @@ -30,5 +33,5 @@ def get_name(self) -> str:
class_name = self.__class__.__name__
if base_name in class_name:
return class_name.replace(base_name, "")
return class_name

return class_name
6 changes: 2 additions & 4 deletions backtest/data/source/coinmarketcap.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ def _fetch_mapping(self, page_size) -> None:

print(f"[info] [datasource] [coinmarketcap] mapping size is {len(self.symbol_to_id_mapping)}", file=sys.stderr)

@abc.abstractmethod
def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: datetime.date) -> pandas.DataFrame:
def fetch_prices(self, symbols, start, end):
today = pandas.to_datetime(datetime.date.today())

prices: pandas.DataFrame = None
Expand Down Expand Up @@ -121,6 +120,5 @@ def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: date

return prices

@abc.abstractmethod
def is_closeable(self) -> bool:
def is_closeable(self):
return False
64 changes: 64 additions & 0 deletions backtest/data/source/dataframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import numpy
import pandas

from .base import DataSource
from ... import constants


class DataFrameDataSource(DataSource):

def __init__(
self,
dataframe: pandas.DataFrame,
date_column=constants.DEFAULT_DATE_COLUMN,
symbol_column=constants.DEFAULT_SYMBOL_COLUMN,
price_column=constants.DEFAULT_PRICE_COLUMN,
closeable=True
) -> None:
super().__init__()

dataframe = dataframe.drop_duplicates(
subset=[symbol_column, date_column],
keep="first"
)

dataframe = dataframe.pivot(
index=date_column,
columns=symbol_column,
values=price_column
)

dataframe.index = pandas.to_datetime(dataframe.index)
dataframe.index.name = constants.DEFAULT_DATE_COLUMN

self.dataframe = dataframe
self.closeable = closeable

def fetch_prices(self, symbols, start, end):
symbols = set(symbols)

missings = symbols - set(self.dataframe.columns)
founds = symbols - missings

prices = None
if len(founds):
start = pandas.to_datetime(start)
end = pandas.to_datetime(end)

prices = self.dataframe[
(self.dataframe.index >= start) &
(self.dataframe.index <= end)
][list(founds)].copy()
else:
prices = pandas.DataFrame(
index=pandas.DatetimeIndex(
data=pandas.date_range(start=start, end=end)
)
)

prices[list(missings)] = numpy.nan

return prices

def is_closeable(self):
return self.closeable
9 changes: 3 additions & 6 deletions backtest/data/source/delegate.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import abc
import datetime
import typing
import numpy

import numpy
import pandas

from .base import DataSource
Expand All @@ -13,8 +12,7 @@ class DelegateDataSource(DataSource):
def __init__(self, delegates: typing.List[DataSource]):
self.delegates = delegates

@abc.abstractmethod
def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: datetime.date) -> pandas.DataFrame:
def fetch_prices(self, symbols, start, end):
prices = None

for delegate in self.delegates:
Expand Down Expand Up @@ -43,8 +41,7 @@ def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: date

return prices

@abc.abstractmethod
def is_closeable(self) -> bool:
def is_closeable(self):
return True

@staticmethod
Expand Down
17 changes: 8 additions & 9 deletions backtest/data/source/factset.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import datetime

import abc
import os
import sys
import typing

import pandas
import requests
import typing
import tqdm

from ...utils import ensure_not_blank
from .base import DataSource


def chunks(l, n):
n = max(1, n)
return [l[i: i + n] for i in range(0, len(l), n)]


class FactsetDataSource(DataSource):

def __init__(self, username_serial: str, api_key: str, chunk_size=100):
Expand All @@ -30,8 +30,7 @@ def __init__(self, username_serial: str, api_key: str, chunk_size=100):

self.chunk_size = chunk_size

@abc.abstractmethod
def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: datetime.date) -> pandas.DataFrame:
def fetch_prices(self, symbols, start, end) -> pandas.DataFrame:
prices = None

for chunk in tqdm.tqdm(chunks(list(symbols), self.chunk_size)):
Expand All @@ -50,7 +49,8 @@ def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: date

status_code = response.status_code
if status_code != 200:
print(f"got status {status_code}: {response.content}", file=sys.stderr)
print(
f"got status {status_code}: {response.content}", file=sys.stderr)
continue

dataframe = FactsetDataSource._to_dataframe(response.json())
Expand All @@ -64,10 +64,9 @@ def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: date
on="Date",
how="outer"
)

return prices

@abc.abstractmethod
def is_closeable(self) -> bool:
return True

Expand Down
1 change: 0 additions & 1 deletion backtest/data/source/file/__init__.py

This file was deleted.

80 changes: 0 additions & 80 deletions backtest/data/source/file/parquet.py

This file was deleted.

11 changes: 2 additions & 9 deletions backtest/data/source/yahoo.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,17 @@
import abc
import datetime
import typing

import pandas
import yfinance

from .base import DataSource


class YahooDataSource(DataSource):

@abc.abstractmethod
def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: datetime.date) -> pandas.DataFrame:
def fetch_prices(self, symbols, start, end):
return yfinance.download(
tickers=symbols,
start=start,
end=end,
show_errors=False
)["Adj Close"]

@abc.abstractmethod
def is_closeable(self) -> bool:
def is_closeable(self):
return True
Loading

0 comments on commit 615f613

Please sign in to comment.