From 615f61316c750a4471c187b919fec74342c75fde Mon Sep 17 00:00:00 2001 From: Caceresenzo Date: Mon, 15 Jan 2024 19:04:21 +0400 Subject: [PATCH] refacto!: flatten structure --- backtest/backtest.py | 3 +- backtest/cli.py | 8 +-- backtest/constants.py | 1 + backtest/data/holidays.py | 2 - backtest/data/source/__init__.py | 6 +- backtest/data/source/base.py | 19 +++--- backtest/data/source/coinmarketcap.py | 6 +- backtest/data/source/dataframe.py | 64 +++++++++++++++++++ backtest/data/source/delegate.py | 9 +-- backtest/data/source/factset.py | 17 +++-- backtest/data/source/file/__init__.py | 1 - backtest/data/source/file/parquet.py | 80 ------------------------ backtest/data/source/yahoo.py | 11 +--- backtest/{order/provider.py => order.py} | 63 +++++++++++++++++-- backtest/order/__init__.py | 2 - backtest/order/_model.py | 55 ---------------- backtest/price_provider.py | 17 ++--- 17 files changed, 167 insertions(+), 197 deletions(-) create mode 100644 backtest/data/source/dataframe.py delete mode 100644 backtest/data/source/file/__init__.py delete mode 100644 backtest/data/source/file/parquet.py rename backtest/{order/provider.py => order.py} (68%) delete mode 100644 backtest/order/__init__.py delete mode 100644 backtest/order/_model.py diff --git a/backtest/backtest.py b/backtest/backtest.py index 0389e41..bf5dd0f 100755 --- a/backtest/backtest.py +++ b/backtest/backtest.py @@ -8,8 +8,7 @@ from .data.source.base import DataSource from .export import BaseExporter, Snapshot from .fee import ConstantFeeModel, FeeModel -from .order import Order, OrderResult -from .order.provider import OrderProvider +from .order import Order, OrderResult, OrderProvider from .price_provider import PriceProvider, SymbolMapper diff --git a/backtest/cli.py b/backtest/cli.py index cb7b620..6280987 100755 --- a/backtest/cli.py +++ b/backtest/cli.py @@ -120,7 +120,7 @@ def main( quantity_in_decimal = quantity_mode == "percent" - from .order.provider import DataFrameOrderProvider + from .order import DataFrameOrderProvider order_provider = DataFrameOrderProvider( readwrite.read(order_file), offset_before_trading, @@ -168,9 +168,9 @@ def main( ) if file_parquet: - from .data.source.file import RowParquetFileDataSource - file_data_source = RowParquetFileDataSource( - path=file_parquet, + from .data.source import DataFrameDataSource + file_data_source = DataFrameDataSource( + path=readwrite.read(file_parquet), date_column=file_parquet_column_date, symbol_column=file_parquet_column_symbol, price_column=file_parquet_column_price diff --git a/backtest/constants.py b/backtest/constants.py index 38dfc00..9e57e96 100644 --- a/backtest/constants.py +++ b/backtest/constants.py @@ -1,3 +1,4 @@ DEFAULT_DATE_COLUMN = "date" DEFAULT_SYMBOL_COLUMN = "symbol" DEFAULT_QUANTITY_COLUMN = "quantity" +DEFAULT_PRICE_COLUMN = "price" diff --git a/backtest/data/holidays.py b/backtest/data/holidays.py index ead8140..bcdf035 100644 --- a/backtest/data/holidays.py +++ b/backtest/data/holidays.py @@ -1,5 +1,3 @@ -import datetime - import dateutil.parser holidays = list(map(lambda x: dateutil.parser.parse(x).date(), [ diff --git a/backtest/data/source/__init__.py b/backtest/data/source/__init__.py index 248b52a..17dc735 100644 --- a/backtest/data/source/__init__.py +++ b/backtest/data/source/__init__.py @@ -1,6 +1,6 @@ from .base import DataSource from .coinmarketcap import CoinMarketCapDataSource -from .file import RowParquetFileDataSource -from .yahoo import YahooDataSource +from .dataframe import DataFrameDataSource from .delegate import DelegateDataSource -from .factset import FactsetDataSource \ No newline at end of file +from .factset import FactsetDataSource +from .yahoo import YahooDataSource diff --git a/backtest/data/source/base.py b/backtest/data/source/base.py index 3037286..6b274d0 100644 --- a/backtest/data/source/base.py +++ b/backtest/data/source/base.py @@ -5,22 +5,25 @@ import pandas -class DataSource: +class DataSource(metaclass=abc.ABCMeta): @abc.abstractmethod - def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: datetime.date) -> pandas.DataFrame: - return None + def fetch_prices( + self, + symbols: typing.Set[str], + start: datetime.date, + end: datetime.date + ) -> pandas.DataFrame: + raise NotImplementedError() - @abc.abstractmethod def is_closeable(self) -> bool: """ Return whether or not the markat has closing hours. Cryptocurrencies for examples does not. """ - + return True - @abc.abstractmethod def get_name(self) -> str: base_name = DataSource.__name__ @@ -30,5 +33,5 @@ def get_name(self) -> str: class_name = self.__class__.__name__ if base_name in class_name: return class_name.replace(base_name, "") - - return class_name \ No newline at end of file + + return class_name diff --git a/backtest/data/source/coinmarketcap.py b/backtest/data/source/coinmarketcap.py index 3659202..ff8fa87 100644 --- a/backtest/data/source/coinmarketcap.py +++ b/backtest/data/source/coinmarketcap.py @@ -75,8 +75,7 @@ def _fetch_mapping(self, page_size) -> None: print(f"[info] [datasource] [coinmarketcap] mapping size is {len(self.symbol_to_id_mapping)}", file=sys.stderr) - @abc.abstractmethod - def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: datetime.date) -> pandas.DataFrame: + def fetch_prices(self, symbols, start, end): today = pandas.to_datetime(datetime.date.today()) prices: pandas.DataFrame = None @@ -121,6 +120,5 @@ def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: date return prices - @abc.abstractmethod - def is_closeable(self) -> bool: + def is_closeable(self): return False diff --git a/backtest/data/source/dataframe.py b/backtest/data/source/dataframe.py new file mode 100644 index 0000000..1d84ca8 --- /dev/null +++ b/backtest/data/source/dataframe.py @@ -0,0 +1,64 @@ +import numpy +import pandas + +from .base import DataSource +from ... import constants + + +class DataFrameDataSource(DataSource): + + def __init__( + self, + dataframe: pandas.DataFrame, + date_column=constants.DEFAULT_DATE_COLUMN, + symbol_column=constants.DEFAULT_SYMBOL_COLUMN, + price_column=constants.DEFAULT_PRICE_COLUMN, + closeable=True + ) -> None: + super().__init__() + + dataframe = dataframe.drop_duplicates( + subset=[symbol_column, date_column], + keep="first" + ) + + dataframe = dataframe.pivot( + index=date_column, + columns=symbol_column, + values=price_column + ) + + dataframe.index = pandas.to_datetime(dataframe.index) + dataframe.index.name = constants.DEFAULT_DATE_COLUMN + + self.dataframe = dataframe + self.closeable = closeable + + def fetch_prices(self, symbols, start, end): + symbols = set(symbols) + + missings = symbols - set(self.dataframe.columns) + founds = symbols - missings + + prices = None + if len(founds): + start = pandas.to_datetime(start) + end = pandas.to_datetime(end) + + prices = self.dataframe[ + (self.dataframe.index >= start) & + (self.dataframe.index <= end) + ][list(founds)].copy() + else: + prices = pandas.DataFrame( + index=pandas.DatetimeIndex( + data=pandas.date_range(start=start, end=end) + ) + ) + + prices[list(missings)] = numpy.nan + + return prices + + def is_closeable(self): + return self.closeable diff --git a/backtest/data/source/delegate.py b/backtest/data/source/delegate.py index fae572a..83bd426 100644 --- a/backtest/data/source/delegate.py +++ b/backtest/data/source/delegate.py @@ -1,8 +1,7 @@ -import abc import datetime import typing -import numpy +import numpy import pandas from .base import DataSource @@ -13,8 +12,7 @@ class DelegateDataSource(DataSource): def __init__(self, delegates: typing.List[DataSource]): self.delegates = delegates - @abc.abstractmethod - def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: datetime.date) -> pandas.DataFrame: + def fetch_prices(self, symbols, start, end): prices = None for delegate in self.delegates: @@ -43,8 +41,7 @@ def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: date return prices - @abc.abstractmethod - def is_closeable(self) -> bool: + def is_closeable(self): return True @staticmethod diff --git a/backtest/data/source/factset.py b/backtest/data/source/factset.py index 0562c59..da928c0 100644 --- a/backtest/data/source/factset.py +++ b/backtest/data/source/factset.py @@ -1,20 +1,20 @@ import datetime - -import abc -import os import sys +import typing + import pandas import requests -import typing import tqdm from ...utils import ensure_not_blank from .base import DataSource + def chunks(l, n): n = max(1, n) return [l[i: i + n] for i in range(0, len(l), n)] + class FactsetDataSource(DataSource): def __init__(self, username_serial: str, api_key: str, chunk_size=100): @@ -30,8 +30,7 @@ def __init__(self, username_serial: str, api_key: str, chunk_size=100): self.chunk_size = chunk_size - @abc.abstractmethod - def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: datetime.date) -> pandas.DataFrame: + def fetch_prices(self, symbols, start, end) -> pandas.DataFrame: prices = None for chunk in tqdm.tqdm(chunks(list(symbols), self.chunk_size)): @@ -50,7 +49,8 @@ def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: date status_code = response.status_code if status_code != 200: - print(f"got status {status_code}: {response.content}", file=sys.stderr) + print( + f"got status {status_code}: {response.content}", file=sys.stderr) continue dataframe = FactsetDataSource._to_dataframe(response.json()) @@ -64,10 +64,9 @@ def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: date on="Date", how="outer" ) - + return prices - @abc.abstractmethod def is_closeable(self) -> bool: return True diff --git a/backtest/data/source/file/__init__.py b/backtest/data/source/file/__init__.py deleted file mode 100644 index 9c6eae3..0000000 --- a/backtest/data/source/file/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .parquet import RowParquetFileDataSource diff --git a/backtest/data/source/file/parquet.py b/backtest/data/source/file/parquet.py deleted file mode 100644 index 8a9c38c..0000000 --- a/backtest/data/source/file/parquet.py +++ /dev/null @@ -1,80 +0,0 @@ -import abc -import datetime -import json -import os -import sys -import typing - -import numpy -import pandas -import pyarrow -import pyarrow.parquet -import requests -import tqdm - -from ..base import DataSource - - -def _expect_column(table: pyarrow.lib.Table, name: str, type: str): - for field in table.schema: - if field.name == name: - if field.type != type: - raise ValueError(f"field {name} expected type {type} but got: {field.type}") - else: - return True - - raise ValueError(f"field {name} not found") - - -class RowParquetFileDataSource(DataSource): - - def __init__(self, path: str, date_column="date", symbol_column="symbol", price_column="price") -> None: - super().__init__() - - self.date_column = date_column - self.symbol_column = symbol_column - self.price_column = price_column - - table = pyarrow.parquet.read_table(path, memory_map=True) - - _expect_column(table, date_column, "date32[day]") - _expect_column(table, symbol_column, "string") - _expect_column(table, price_column, "double") - - dataframe = table.to_pandas() - dataframe = dataframe.drop_duplicates(subset=[symbol_column, date_column], keep="first") - dataframe = dataframe.pivot(index=date_column, columns=symbol_column, values=price_column) - dataframe.index = pandas.to_datetime(dataframe.index) - - dataframe.index.name = "Date" - - self.storage = dataframe - - @abc.abstractmethod - def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: datetime.date) -> pandas.DataFrame: - symbols = set(symbols) - - missings = symbols - set(self.storage.columns) - founds = symbols - missings - - prices = None - if len(founds): - start = pandas.to_datetime(start) - end = pandas.to_datetime(end) - - prices = self.storage[(self.storage.index >= start) & (self.storage.index <= end)][list(founds)].copy() - else: - prices = pandas.DataFrame( - index=pandas.DatetimeIndex( - data=pandas.date_range(start=start, end=end), - name="Date" - ) - ) - - prices[list(missings)] = numpy.nan - - return prices - - @abc.abstractmethod - def is_closeable(self) -> bool: - return True diff --git a/backtest/data/source/yahoo.py b/backtest/data/source/yahoo.py index cf09682..abd33a9 100644 --- a/backtest/data/source/yahoo.py +++ b/backtest/data/source/yahoo.py @@ -1,8 +1,3 @@ -import abc -import datetime -import typing - -import pandas import yfinance from .base import DataSource @@ -10,8 +5,7 @@ class YahooDataSource(DataSource): - @abc.abstractmethod - def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: datetime.date) -> pandas.DataFrame: + def fetch_prices(self, symbols, start, end): return yfinance.download( tickers=symbols, start=start, @@ -19,6 +13,5 @@ def fetch_prices(self, symbols: typing.Set[str], start: datetime.date, end: date show_errors=False )["Adj Close"] - @abc.abstractmethod - def is_closeable(self) -> bool: + def is_closeable(self): return True diff --git a/backtest/order/provider.py b/backtest/order.py similarity index 68% rename from backtest/order/provider.py rename to backtest/order.py index af65d24..77458a1 100644 --- a/backtest/order/provider.py +++ b/backtest/order.py @@ -1,14 +1,66 @@ import abc +import dataclasses import datetime +import enum import functools import typing import numpy import pandas -import typing -from ._model import Order -from .. import constants +from . import constants, utils + + +class OrderDirection(enum.IntEnum): + + SELL = -1 + HOLD = 0 + BUY = 1 + + +@dataclasses.dataclass() +class Order: + + symbol: str + quantity: int + price: float + + @property + def value(self) -> float: + return self.quantity * self.price + + @property + def direction(self) -> OrderDirection: + if self.quantity > 0: + return OrderDirection.BUY + + if self.quantity < 0: + return OrderDirection.SELL + + return OrderDirection.HOLD + + @property + def valid(self): + return not utils.is_blank(self.symbol) \ + and self.price > 0 + + +@dataclasses.dataclass() +class OrderResult: + + order: Order + success: bool = False + fee: float = 0.0 + + +@dataclasses.dataclass() +class CloseResult: + + order: Order + success: bool = False + missing: bool = False + fee: float = 0.0 + class OrderProvider(metaclass=abc.ABCMeta): @@ -31,6 +83,9 @@ def __init__( symbol_column=constants.DEFAULT_SYMBOL_COLUMN, quantity_column=constants.DEFAULT_QUANTITY_COLUMN ) -> None: + if not isinstance(dataframe, pandas.DataFrame): + dataframe = pandas.DataFrame(dataframe) + dataframe = dataframe[[ date_column, symbol_column, @@ -70,7 +125,7 @@ def get_orders(self, date, account): self.symbol_column, self.quantity_column ) - + @staticmethod def convert( dataframe: pandas.DataFrame, diff --git a/backtest/order/__init__.py b/backtest/order/__init__.py deleted file mode 100644 index 0dc118c..0000000 --- a/backtest/order/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from ._model import * -from . import provider diff --git a/backtest/order/_model.py b/backtest/order/_model.py deleted file mode 100644 index dfd088a..0000000 --- a/backtest/order/_model.py +++ /dev/null @@ -1,55 +0,0 @@ -import dataclasses -import enum - -from ..utils import is_blank - - -class OrderDirection(enum.IntEnum): - - SELL = -1 - HOLD = 0 - BUY = 1 - - -@dataclasses.dataclass() -class Order: - - symbol: str - quantity: int - price: float - - @property - def value(self) -> float: - return self.quantity * self.price - - @property - def direction(self) -> OrderDirection: - if self.quantity > 0: - return OrderDirection.BUY - - if self.quantity < 0: - return OrderDirection.SELL - - return OrderDirection.HOLD - - @property - def valid(self): - return not is_blank(self.symbol) \ - and self.price > 0 - - -@dataclasses.dataclass() -class OrderResult: - - order: Order - success: bool = False - fee: float = 0.0 - - -@dataclasses.dataclass() -class CloseResult: - - order: Order - success: bool = False - missing: bool = False - fee: float = 0.0 diff --git a/backtest/price_provider.py b/backtest/price_provider.py index 3457836..3884a47 100644 --- a/backtest/price_provider.py +++ b/backtest/price_provider.py @@ -7,7 +7,8 @@ import numpy import pandas -from backtest.data.source.base import DataSource +from .data.source.base import DataSource +from . import constants class SymbolMapper: @@ -98,7 +99,7 @@ def download_missing(self, symbols: typing.Set[str]): if prices is None: prices = pandas.DataFrame( - index=pandas.Index([], name="Date"), + index=pandas.Index([], name=constants.DEFAULT_DATE_COLUMN), columns=list(missing_symbols) ) @@ -111,14 +112,14 @@ def download_missing(self, symbols: typing.Set[str]): first: prices.values }, index=pandas.Index( prices.index, - name="Date" + name=constants.DEFAULT_DATE_COLUMN )) else: prices = pandas.DataFrame({ first: numpy.nan }, index=pandas.Index( self.storage.index, - name="Date" + name=constants.DEFAULT_DATE_COLUMN )) prices.columns = self.mapper.unmaps(prices.columns) @@ -130,7 +131,7 @@ def download_missing(self, symbols: typing.Set[str]): self.storage = pandas.merge( self.storage, prices, - on='Date', + on=constants.DEFAULT_DATE_COLUMN, how="left" ) else: @@ -171,7 +172,7 @@ def _create_storage(start: datetime.date, end: datetime.date, caching=True): path = PriceProvider._get_cache_path(start, end) if os.path.exists(path): - dataframe = pandas.read_csv(path, index_col="Date") + dataframe = pandas.read_csv(path, index_col=constants.DEFAULT_DATE_COLUMN) dataframe.index = dataframe.index.astype( 'datetime64[ns]', copy=False @@ -186,8 +187,8 @@ def _create_storage(start: datetime.date, end: datetime.date, caching=True): dates.append(numpy.datetime64(date)) date += datetime.timedelta(days=1) - dataframe = pandas.DataFrame({"Date": dates, "_": numpy.nan}) - dataframe.set_index("Date", inplace=True) + dataframe = pandas.DataFrame({constants.DEFAULT_DATE_COLUMN: dates, "_": numpy.nan}) + dataframe.set_index(constants.DEFAULT_DATE_COLUMN, inplace=True) return dataframe