diff --git a/src/backlight/__init__.py b/src/backlight/__init__.py index 24b3a60..a32000a 100644 --- a/src/backlight/__init__.py +++ b/src/backlight/__init__.py @@ -1,4 +1,4 @@ __author__ = "AlpacaJapan Co., Ltd." -__version__ = "0.1.5" -__release__ = "0.1.5" +__version__ = "0.2.0" +__release__ = "0.2.0" __license__ = "MIT" diff --git a/src/backlight/metrics/position_metrics.py b/src/backlight/metrics/position_metrics.py index d88dd61..89f6b4b 100644 --- a/src/backlight/metrics/position_metrics.py +++ b/src/backlight/metrics/position_metrics.py @@ -6,7 +6,6 @@ from backlight.datasource.marketdata import MarketData from backlight.positions import calc_positions from backlight.positions.positions import Positions -from backlight.trades.trades import Trade, Trades def _sum(a: pd.Series) -> float: diff --git a/src/backlight/metrics/trade_metrics.py b/src/backlight/metrics/trade_metrics.py index c4c269f..309a3b4 100644 --- a/src/backlight/metrics/trade_metrics.py +++ b/src/backlight/metrics/trade_metrics.py @@ -5,7 +5,7 @@ import backlight.positions from backlight.datasource.marketdata import MarketData -from backlight.trades.trades import Trade, Trades +from backlight.trades.trades import Trades, make_trades from backlight.metrics.position_metrics import calc_pl, calc_position_performance @@ -17,9 +17,10 @@ def _sum(a: pd.Series) -> float: return a.sum() if len(a) != 0 else 0.0 -def _calc_pl(trade: Trade, mkt: MarketData) -> float: +def _calc_pl(trade: pd.Series, mkt: MarketData) -> float: mkt = mkt.loc[trade.index, :] - positions = backlight.positions.calc_positions((trade,), mkt) + trades = make_trades(mkt.symbol, [trade]) + positions = backlight.positions.calc_positions(trades, mkt) pl = calc_pl(positions) return _sum(pl) @@ -36,8 +37,12 @@ def count_trades(trades: Trades, mkt: MarketData) -> Tuple[int, int, int]: Returns: total count, win count, lose count """ - pls = [_calc_pl(t, mkt) for t in trades if len(t.index) > 1] - total = len(trades) + pls = [ + _calc_pl(trades.get_trade(i), mkt) + for i in trades.ids + if len(trades.get_trade(i).index) > 1 + ] + total = len(trades.ids) win = sum([pl > 0.0 for pl in pls]) lose = sum([pl < 0.0 for pl in pls]) return total, win, lose diff --git a/src/backlight/positions/positions.py b/src/backlight/positions/positions.py index 7f8f7f7..bb7a006 100644 --- a/src/backlight/positions/positions.py +++ b/src/backlight/positions/positions.py @@ -3,8 +3,7 @@ from typing import Type, Callable from backlight.datasource.marketdata import MarketData, MidMarketData, AskBidMarketData -from backlight.trades import flatten -from backlight.trades.trades import Trade, Trades +from backlight.trades.trades import Trades def _freq(idx: pd.Index) -> pd.Timedelta: @@ -44,14 +43,15 @@ def _constructor(self) -> Type["Positions"]: return Positions -def _pricer(trade: Trade, mkt: MarketData, principal: float) -> Positions: +def _pricer(trades: Trades, mkt: MarketData, principal: float) -> pd.DataFrame: + trade = trades.amount # historical data idx = mkt.index[trade.index[0] <= mkt.index] # only after first trades positions = pd.DataFrame(index=idx) - positions.loc[:, "amount"] = trade.amount.cumsum() + positions.loc[:, "amount"] = trade.cumsum() positions.loc[:, "price"] = mkt.mid.loc[idx] - fee = mkt.fee(trade.amount) + fee = mkt.fee(trade) positions.loc[:, "principal"] = -fee.cumsum() + principal positions = positions.ffill() @@ -61,10 +61,7 @@ def _pricer(trade: Trade, mkt: MarketData, principal: float) -> Positions: positions.loc[initial_idx, "price"] = 0.0 positions.loc[initial_idx, "principal"] = principal - pos = Positions(positions.sort_index()) - pos.reset_cols() - pos.symbol = trade.symbol - return pos + return positions.sort_index() def calc_positions( @@ -78,11 +75,10 @@ def calc_positions( mkt: Market data. principal: The initial principal value. """ - trade = flatten(trades) - - assert trade.symbol == mkt.symbol - assert (trade.index.isin(mkt.index)).all() + assert trades.symbol == mkt.symbol + assert trades.index.isin(mkt.index).all() - positions = _pricer(trade, mkt, principal) - positions.symbol = trade.symbol - return positions + pos = Positions(_pricer(trades, mkt, principal)) + pos.reset_cols() + pos.symbol = trades.symbol + return pos diff --git a/src/backlight/strategies/amount_based.py b/src/backlight/strategies/amount_based.py index fc81dc6..7eaeea2 100644 --- a/src/backlight/strategies/amount_based.py +++ b/src/backlight/strategies/amount_based.py @@ -6,7 +6,7 @@ from backlight.datasource.marketdata import MarketData from backlight.signal.signal import Signal from backlight.trades import make_trade -from backlight.trades.trades import Trade, Trades, Transaction, from_series +from backlight.trades.trades import Trades, make_trades from backlight.labelizer.common import TernaryDirection from backlight.strategies.common import Action from backlight.strategies.entry import direction_based_entry @@ -35,8 +35,8 @@ def direction_based_trades( amount = pd.Series(index=df.index, name="amount").astype(np.float64) for direction, action in direction_action_dict.items(): amount.loc[df["pred"] == direction.value] = action.act_on_amount() - trade = from_series(amount, df.symbol) - return (trade,) + trade = amount + return make_trades(df.symbol, [trade]) def only_take_long(mkt: MarketData, sig: Signal) -> Trades: @@ -79,8 +79,8 @@ def _entry_and_exit_at_max_holding_time( sig: Signal data direction_action_dict: Dictionary from signals to actions max_holding_time: maximum holding time - exit_condition: The entry is closed most closest time which - condition is `True`. + exit_condition: The entry is closed most closest time which condition is `True`. + Result: Trades """ diff --git a/src/backlight/strategies/entry.py b/src/backlight/strategies/entry.py index fc36028..6895209 100644 --- a/src/backlight/strategies/entry.py +++ b/src/backlight/strategies/entry.py @@ -1,18 +1,14 @@ import pandas as pd +from typing import List + from backlight.datasource.marketdata import MarketData from backlight.signal.signal import Signal from backlight.trades import make_trade -from backlight.trades.trades import Transaction, Trade, Trades +from backlight.trades.trades import Trades, from_dataframe from backlight.strategies.common import Action -def _entry(amount: float, idx: pd.Timestamp, symbol: str) -> Trade: - trade = make_trade(symbol) - trade.add(Transaction(timestamp=idx, amount=amount)) - return trade - - def direction_based_entry( mkt: MarketData, sig: Signal, direction_action_dict: dict ) -> Trades: @@ -28,12 +24,23 @@ def direction_based_entry( assert all([idx in mkt.index for idx in sig.index]) df = sig - trades = () # type: Trades + trades = [] # type: List[pd.Dataframe] for direction, action in direction_action_dict.items(): + amount = action.act_on_amount() if amount == 0.0: continue - target_index = df[df["pred"] == direction.value].index - trades += tuple(_entry(amount, idx, df.symbol) for idx in target_index) - return trades + trades.append( + pd.DataFrame( + index=df[df["pred"] == direction.value].index, + data=direction.value, + columns=["amount"], + ) + ) + + df_trades = pd.concat(trades, axis=0).sort_index() + df_trades.loc[:, "_id"] = range(len(df_trades.index)) + + t = from_dataframe(df_trades, df.symbol) + return t diff --git a/src/backlight/strategies/exit.py b/src/backlight/strategies/exit.py index 4c2af87..a554be3 100644 --- a/src/backlight/strategies/exit.py +++ b/src/backlight/strategies/exit.py @@ -1,13 +1,13 @@ import numpy as np import pandas as pd -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Tuple from backlight.datasource.marketdata import MarketData from backlight.labelizer.common import TernaryDirection from backlight.signal.signal import Signal from backlight.trades import make_trade -from backlight.trades.trades import Transaction, Trade, Trades +from backlight.trades.trades import Transaction, Trades, concat, from_dataframe from backlight.strategies.common import Action @@ -25,18 +25,18 @@ def _concat(mkt: MarketData, sig: Optional[Signal]) -> pd.DataFrame: def _exit_transaction( df: pd.DataFrame, - trade: Trade, - exit_condition: Callable[[pd.DataFrame, Trade], pd.Series], + trade: pd.Series, + exit_condition: Callable[[pd.DataFrame, pd.Series], pd.Series], ) -> Transaction: exit_indices = df[exit_condition(df, trade)].index if exit_indices.empty: exit_index = df.index[-1] else: exit_index = exit_indices[0] - return Transaction(timestamp=exit_index, amount=-trade.amount.sum()) + return Transaction(timestamp=exit_index, amount=-trade.sum()) -def _no_exit_condition(df: pd.DataFrame, trade: Trade) -> pd.Series: +def _no_exit_condition(df: pd.DataFrame, trade: pd.Series) -> pd.Series: return pd.Series(index=df.index, data=False) @@ -44,9 +44,9 @@ def exit( mkt: MarketData, sig: Optional[Signal], entries: Trades, - exit_condition: Callable[[pd.DataFrame, Trade], pd.Series], + exit_condition: Callable[[pd.DataFrame, pd.Series], pd.Series], ) -> Trades: - """Exit trade at max holding time or satisfying condition. + """Exit trade when satisfying condition. Args: mkt: Market data @@ -61,21 +61,32 @@ def exit( df = _concat(mkt, sig) def _exit( - trade: Trade, + trades: Trades, df: pd.DataFrame, - exit_condition: Callable[[pd.DataFrame, Trade], pd.Series], - ) -> Trade: - if trade.amount.sum() == 0: - return trade + exit_condition: Callable[[pd.DataFrame, pd.Series], pd.Series], + ) -> pd.Series: + + indices = [] # type: List[pd.Timestamp] + exits = [] # type: List[Tuple[float, int]] + for i in trades.ids: + trade = trades.get_trade(i) + if trade.sum() == 0: + continue + + idx = trade.index[0] + df_exit = df[idx <= df.index] + transaction = _exit_transaction(df_exit, trade, exit_condition) + + indices.append(transaction.timestamp) + exits.append((transaction.amount, i)) + + df = pd.DataFrame(index=indices, data=exits, columns=["amount", "_id"]) - idx = trade.index[0] - df_exit = df[idx <= df.index] - transaction = _exit_transaction(df_exit, trade, exit_condition) - trade.add(transaction) - return trade + return from_dataframe(df, symbol) - trades = tuple(_exit(trade, df, exit_condition) for trade in entries) - return trades + symbol = entries.symbol + exits = _exit(entries, df, exit_condition) + return concat([entries, exits]) def exit_by_max_holding_time( @@ -83,7 +94,7 @@ def exit_by_max_holding_time( sig: Optional[Signal], entries: Trades, max_holding_time: pd.Timedelta, - exit_condition: Callable[[pd.DataFrame, Trade], pd.Series], + exit_condition: Callable[[pd.DataFrame, pd.Series], pd.Series], ) -> Trades: """Exit trade at max holding time or satisfying condition. @@ -100,22 +111,32 @@ def exit_by_max_holding_time( df = _concat(mkt, sig) def _exit_by_max_holding_time( - trade: Trade, + trades: Trades, df: pd.DataFrame, max_holding_time: pd.Timedelta, - exit_condition: Callable[[pd.DataFrame, Trade], pd.Series], - ) -> Trade: - idx = trade.index[0] - df_exit = df[(idx <= df.index) & (df.index <= idx + max_holding_time)] - transaction = _exit_transaction(df_exit, trade, exit_condition) - trade.add(transaction) - return trade - - trades = tuple( - _exit_by_max_holding_time(trade, df, max_holding_time, exit_condition) - for trade in entries - ) - return trades + exit_condition: Callable[[pd.DataFrame, pd.Series], pd.Series], + ) -> Trades: + + indices = [] # type: List[pd.Timestamp] + exits = [] # type: List[Tuple[float, int]] + for i in trades.ids: + trade = trades.get_trade(i) + if trade.sum() == 0: + continue + + idx = trade.index[0] + df_exit = df[(idx <= df.index) & (df.index <= idx + max_holding_time)] + transaction = _exit_transaction(df_exit, trade, exit_condition) + + indices.append(transaction.timestamp) + exits.append((transaction.amount, i)) + + df = pd.DataFrame(index=indices, data=exits, columns=["amount", "_id"]) + return from_dataframe(df, symbol) + + symbol = entries.symbol + exits = _exit_by_max_holding_time(entries, df, max_holding_time, exit_condition) + return concat([entries, exits]) def exit_at_max_holding_time( @@ -162,7 +183,7 @@ def _exit_at_opposite_signals_condition( opposite_signals = opposite_signals_dict[current_signal] return df["pred"].isin(opposite_signals) - def _exit_condition(df: pd.DataFrame, trade: Trade) -> pd.Series: + def _exit_condition(df: pd.DataFrame, trade: pd.Series) -> pd.Series: return _exit_at_opposite_signals_condition(df, opposite_signals_dict) return exit_by_max_holding_time( @@ -184,7 +205,7 @@ def exit_by_expectation( Trades """ - def _exit_by_expectation_condition(df: pd.DataFrame, trade: Trade) -> pd.Series: + def _exit_by_expectation_condition(df: pd.DataFrame, trade: pd.Series) -> pd.Series: current_signal = TernaryDirection(df["pred"][0]) v = np.array([1.0, 0.0, -1.0]) expectation = np.dot(df[["up", "neutral", "down"]].values, v) @@ -219,10 +240,10 @@ def exit_by_trailing_stop( assert initial_stop >= 0.0 assert trailing_stop >= 0.0 - def _exit_by_trailing_stop(df: pd.DataFrame, trade: Trade) -> pd.Series: + def _exit_by_trailing_stop(df: pd.DataFrame, trade: pd.Series) -> pd.Series: prices = df.mid - amount = trade.amount.sum() + amount = trade.sum() entry_price = prices.iloc[0] pl_per_amount = np.sign(amount) * (prices - entry_price) is_initial_stop = pl_per_amount <= -initial_stop diff --git a/src/backlight/trades/__init__.py b/src/backlight/trades/__init__.py index d82888a..d281ed1 100644 --- a/src/backlight/trades/__init__.py +++ b/src/backlight/trades/__init__.py @@ -1 +1 @@ -from backlight.trades.trades import flatten, make_trade # noqa +from backlight.trades.trades import make_trade, make_trades # noqa diff --git a/src/backlight/trades/trades.py b/src/backlight/trades/trades.py index cd2fe7e..7fab73d 100644 --- a/src/backlight/trades/trades.py +++ b/src/backlight/trades/trades.py @@ -1,7 +1,7 @@ import pandas as pd from collections import namedtuple from functools import lru_cache -from typing import Any, Type, Tuple # noqa +from typing import Any, Type, List, Iterable, Optional # noqa from backlight.datasource.marketdata import MarketData @@ -9,88 +9,153 @@ Transaction = namedtuple("Transaction", ["timestamp", "amount"]) -class Trade: - """Series object like instance for Trade. The purpose of the class is - to improve computation speed. - """ +def _max(s: pd.Series) -> int: + if len(s) == 0: + return 0 + return max(s) - def __init__(self, symbol: str) -> None: - self._symbol = symbol - self._index = () # type: tuple - self._amount = () # type: tuple - def __eq__(self, other: Any) -> bool: - return self.__class__ == other.__class__ and self.__hash__() == other.__hash__() +class Trades(pd.DataFrame): + """A collection of trades. - def __repr__(self) -> str: - return str(self.amount) + This is designed to achieve following purposes + 1. Compute metrics which need individual trade perfomance + s.t. win_rate and lose_rate. + 2. Filter the trades. + """ - def __hash__(self) -> int: - return hash((self._index, self._amount, self.symbol)) + _metadata = ["symbol"] - def add(self, t: Transaction) -> None: - """Add transaction""" - self._index += (t.timestamp,) - self._amount += (t.amount,) + _target_columns = ["amount", "_id"] @property - def amount(self) -> pd.Series: - """Amount of transactions at that moment""" - amount = pd.Series(data=self._amount, index=self._index) - return amount.groupby(amount.index).sum().sort_index() + def ids(self) -> List[int]: + """Return all unique ids""" + if "_id" not in self.columns: + return [] + return self._id.unique().tolist() @property - def index(self) -> pd.Index: - """Index of transactions""" - return pd.Index(self._index).drop_duplicates().sort_values() + def amount(self) -> pd.Series: + """Flattend as one Trade""" + a = self["amount"] + return a.groupby(a.index).sum().sort_index() + + def get_trade(self, trade_id: int) -> pd.Series: + """Get trade. + + Args: + trade_id: Id for the trade. + Trades of the same id are recognized as one individual trade. + Returns: + Trade of pd.Series. + """ + return self.loc[self._id == trade_id, "amount"] + + def get_any(self, key: Any) -> Type["Trades"]: + """Filter trade which match conditions at least one element. + + Args: + key: Same arguments with pd.DataFrame.__getitem__. + + Returns: + Trades. + """ + filterd_ids = self[key].ids + trades = [self.get_trade(i) for i in filterd_ids] + return make_trades(self.symbol, trades, filterd_ids) + + def get_all(self, key: Any) -> Type["Trades"]: + """Filter trade which match conditions for all elements. + + Args: + key: Same arguments with pd.DataFrame.__getitem__. + + Returns: + Trades. + """ + filterd = self[key] + ids = [] + trades = [] + for i in filterd.ids: + t = self.get_trade(i) + if t.equals(filterd.get_trade(i)): + ids.append(i) + trades.append(t) + return make_trades(self.symbol, trades, ids) + + def reset_cols(self) -> None: + """Keep only _target_columns""" + for col in self.columns: + if col not in self._target_columns: + self.drop(col, axis=1, inplace=True) @property - def symbol(self) -> str: - """Asset symbol""" - return self._symbol + def _constructor(self) -> Type["Trades"]: + return Trades -Trades = Tuple[Trade, ...] -"""A collection of trades. +def _sum(a: pd.Series) -> float: + return a.sum() if len(a) != 0 else 0 -This is designed to achieve following purposes -1. Compute metrics which need individual trade perfomance s.t. win_rate and lose_rate. -2. Filter the trades s.t. `[t for t in trades if trades.index[0].hour in [0, 1, 2])]`. -""" +def _sort(t: Trades) -> Trades: + t["ind"] = t.index + t = t.sort_values(by=["ind", "_id"]) + t.reset_cols() + return t -def _sum(a: pd.Series) -> float: - return a.sum() if len(a) != 0 else 0 + +def make_trade(transactions: Iterable[Transaction]) -> pd.Series: + """Create Trade instance from transacsions""" + index = [t.timestamp for t in transactions] + data = [t.amount for t in transactions] + sr = pd.Series(index=index, data=data, name="amount") + return sr.groupby(sr.index).sum().sort_index() -def from_series(sr: pd.Series, symbol: str) -> Trade: - """Create a Trade instance from pd.Series. +def from_dataframe(df: pd.DataFrame, symbol: str) -> Trades: + """Create a Trades instance out of a DataFrame object Args: - sr : Series - symbol : A symbol + df: DataFrame + symbol: symbol to query Returns: - Trade + Trades """ - t = Trade(symbol) - t._index = tuple([i for i in sr.index]) - t._amount = tuple(sr.values.tolist()) - return t + trades = Trades(df.copy()) + trades.symbol = symbol + trades.reset_cols() -def make_trade(symbol: str) -> Trade: - """Initialize Trade instance""" - t = Trade(symbol) - return t + return _sort(trades) + + +def concat(trades: List[Trades]) -> Trades: + """Concatenate some fo Trades""" + t = Trades(pd.concat(trades, axis=0)) + t.symbol = trades[0].symbol + return _sort(t) + + +def make_trades( + symbol: str, trades: List[pd.Series], ids: Optional[List[int]] = None +) -> Trades: + """Create Trades from some of trades""" + if ids is None: + _ids = list(range(len(trades))) + else: + _ids = ids + assert len(_ids) == len(trades) -@lru_cache() -def flatten(trades: Trades) -> Trade: - """Flatten tuple of trade to a trade.""" - symbol = trades[0].symbol - assert all([t.symbol == symbol for t in trades]) + df = pd.concat(trades, axis=0).to_frame(name="amount") + df.loc[:, "_id"] = 0 + current = 0 + for i, t in zip(_ids, trades): + l = len(t.index) + df.iloc[current : current + l, 1] = i + current += l - amounts = pd.concat([t.amount for t in trades], axis=0) - amount = amounts.groupby(amounts.index).sum().sort_index() - return from_series(amount, symbol) + return from_dataframe(df, symbol) diff --git a/tests/metrics/test_position_metrics.py b/tests/metrics/test_position_metrics.py index 441ae12..ab7f575 100644 --- a/tests/metrics/test_position_metrics.py +++ b/tests/metrics/test_position_metrics.py @@ -30,12 +30,10 @@ def trades(symbol): index = pd.date_range(start="2018-06-06", freq="1D", periods=len(data)) trades = [] for i in range(0, len(data), 2): - trade = tr.from_series( - pd.Series(index=index[i : i + 2], data=data[i : i + 2], name="amount"), - symbol, - ) + trade = pd.Series(index=index[i : i + 2], data=data[i : i + 2], name="amount") trades.append(trade) - return tuple(trades) + trades = tr.make_trades(symbol, trades) + return trades @pytest.fixture diff --git a/tests/metrics/test_trade_metrics.py b/tests/metrics/test_trade_metrics.py index 7485bf9..68d4096 100644 --- a/tests/metrics/test_trade_metrics.py +++ b/tests/metrics/test_trade_metrics.py @@ -6,13 +6,6 @@ from backlight.trades import trades as tr -def _make_trade(transactions, symbol="hoge"): - trade = module.Trade(symbol) - for t in transactions: - trade.add(t) - return trade - - @pytest.fixture def symbol(): return "usdjpy" @@ -36,12 +29,10 @@ def trades(symbol): index = pd.date_range(start="2018-06-06", freq="1D", periods=len(data)) trades = [] for i in range(0, len(data), 2): - trade = tr.from_series( - pd.Series(index=index[i : i + 2], data=data[i : i + 2], name="amount"), - symbol, - ) + trade = pd.Series(index=index[i : i + 2], data=data[i : i + 2], name="amount") trades.append(trade) - return tuple(trades) + trades = tr.make_trades(symbol, trades) + return trades def test__calc_pl(): @@ -60,16 +51,16 @@ def test__calc_pl(): pd.DataFrame(index=dates, data=[[0], [1], [2]], columns=["mid"]), symbol ) - trade = _make_trade([t00, t11], symbol) + trade = tr.make_trade([t00, t11]) assert module._calc_pl(trade, mkt) == 1.0 - trade = _make_trade([t00, t01], symbol) + trade = tr.make_trade([t00, t01]) assert module._calc_pl(trade, mkt) == 0.0 - trade = _make_trade([t11, t20], symbol) + trade = tr.make_trade([t11, t20]) assert module._calc_pl(trade, mkt) == -1.0 - trade = _make_trade([t00, t10, t20], symbol) + trade = tr.make_trade([t00, t10, t20]) assert module._calc_pl(trade, mkt) == 3.0 diff --git a/tests/plot/test_plot_pl.py b/tests/plot/test_plot_pl.py index 8cb44ba..73f4d4b 100644 --- a/tests/plot/test_plot_pl.py +++ b/tests/plot/test_plot_pl.py @@ -6,7 +6,7 @@ import backlight.datasource import backlight.positions -from backlight.trades.trades import from_series +from backlight.trades.trades import make_trades @pytest.fixture @@ -25,12 +25,9 @@ def positions(): index = pd.date_range(start="2018-06-06", freq="1min", periods=len(data)) trades = [] for i in range(0, len(data), 2): - trade = from_series( - pd.Series(index=index[i : i + 2], data=data[i : i + 2], name="amount"), - symbol, - ) + trade = pd.Series(index=index[i : i + 2], data=data[i : i + 2], name="amount") trades.append(trade) - trades = tuple(trades) + trades = make_trades(symbol, trades) return backlight.positions.calc_positions(trades, market) diff --git a/tests/posiitons/test_positions.py b/tests/posiitons/test_positions.py index 0af20b5..005c2e9 100644 --- a/tests/posiitons/test_positions.py +++ b/tests/posiitons/test_positions.py @@ -4,7 +4,7 @@ import backlight.datasource import backlight.positions -from backlight.trades.trades import from_series +from backlight.trades.trades import make_trades @pytest.fixture @@ -53,8 +53,9 @@ def trades(symbol): data=data, name="amount", ) - trade = from_series(sr, symbol) - return (trade,) + trade = sr + trades = make_trades(symbol, [trade]) + return trades @pytest.fixture diff --git a/tests/strategies/test_amount_based.py b/tests/strategies/test_amount_based.py index 03d9706..f1184ea 100644 --- a/tests/strategies/test_amount_based.py +++ b/tests/strategies/test_amount_based.py @@ -92,8 +92,7 @@ def test_direction_based_trades(market, signal): ], name="amount", ) - trade = backlight.trades.flatten(trades) - assert (trade.amount == expected).all() + assert (trades.amount == expected).all() def test_entry_exit_trades(market, signal): @@ -134,8 +133,7 @@ def test_entry_exit_trades(market, signal): ], columns=["exist", "amount"], ) - trade = backlight.trades.flatten(trades) - assert (trade.amount == expected.amount[expected.exist]).all() + assert (trades.amount == expected.amount[expected.exist]).all() def test_simple_entry_and_exit(market, signal): @@ -169,8 +167,7 @@ def test_simple_entry_and_exit(market, signal): ], columns=["exist", "amount"], ) - trade = backlight.trades.flatten(trades) - assert (trade.amount == expected.amount[expected.exist]).all() + assert (trades.amount == expected.amount[expected.exist]).all() def test_only_entry_short_and_exit(market, signal): @@ -204,8 +201,7 @@ def test_only_entry_short_and_exit(market, signal): ], columns=["exist", "amount"], ) - trade = backlight.trades.flatten(trades) - assert (trade.amount == expected.amount[expected.exist]).all() + assert (trades.amount == expected.amount[expected.exist]).all() def test_only_entry_long_and_exit(market, signal): @@ -239,8 +235,7 @@ def test_only_entry_long_and_exit(market, signal): ], columns=["exist", "amount"], ) - trade = backlight.trades.flatten(trades) - assert (trade.amount == expected.amount[expected.exist]).all() + assert (trades.amount == expected.amount[expected.exist]).all() def test_entry_and_exit_opposite_signal(market, signal): @@ -274,8 +269,7 @@ def test_entry_and_exit_opposite_signal(market, signal): ], columns=["exist", "amount"], ) - trade = backlight.trades.flatten(trades) - assert (trade.amount == expected.amount[expected.exist]).all() + assert (trades.amount == expected.amount[expected.exist]).all() def test_entry_and_exit_other_signal(market, signal): @@ -309,8 +303,7 @@ def test_entry_and_exit_other_signal(market, signal): ], columns=["exist", "amount"], ) - trade = backlight.trades.flatten(trades) - assert (trade.amount == expected.amount[expected.exist]).all() + assert (trades.amount == expected.amount[expected.exist]).all() def test_entry_and_exit_by_expectation(market): @@ -351,5 +344,4 @@ def test_entry_and_exit_by_expectation(market): ], columns=["exist", "amount"], ) - trade = backlight.trades.flatten(trades) - assert (trade.amount == expected.amount[expected.exist]).all() + assert (trades.amount == expected.amount[expected.exist]).all() diff --git a/tests/strategies/test_exit.py b/tests/strategies/test_exit.py index 2444f25..bc37205 100644 --- a/tests/strategies/test_exit.py +++ b/tests/strategies/test_exit.py @@ -7,14 +7,7 @@ from backlight.labelizer.common import TernaryDirection from backlight.strategies.common import Action from backlight.strategies.entry import direction_based_entry -from backlight.trades.trades import Transaction, Trade - - -def _make_trade(transactions, symbol="hoge"): - trade = Trade(symbol) - for t in transactions: - trade.add(t) - return trade +from backlight.trades.trades import Transaction, make_trades, make_trade @pytest.fixture @@ -108,8 +101,7 @@ def test_exit_at_max_holding_time(market, signal, entries): ], columns=["exist", "amount"], ) - trade = backlight.trades.flatten(trades) - assert (trade.amount == expected.amount[expected.exist]).all() + assert (trades.amount == expected.amount[expected.exist]).all() def test_exit_by_trailing_stop(market, signal, entries): @@ -136,52 +128,60 @@ def test_exit_by_trailing_stop(market, signal, entries): ), symbol, ) - entries = ( - _make_trade([Transaction(pd.Timestamp("2018-06-06 00:00:00"), 1.0)]), - _make_trade([Transaction(pd.Timestamp("2018-06-06 00:00:00"), -1.0)]), - _make_trade([Transaction(pd.Timestamp("2018-06-06 00:00:00"), 0.0)]), - _make_trade([Transaction(pd.Timestamp("2018-06-06 00:03:00"), 1.0)]), - _make_trade([Transaction(pd.Timestamp("2018-06-06 00:03:00"), 0.5)]), - _make_trade([Transaction(pd.Timestamp("2018-06-06 00:03:00"), -1.0)]), + entries = make_trades( + symbol, + ( + make_trade([Transaction(pd.Timestamp("2018-06-06 00:00:00"), 1.0)]), + make_trade([Transaction(pd.Timestamp("2018-06-06 00:00:00"), -1.0)]), + make_trade([Transaction(pd.Timestamp("2018-06-06 00:00:00"), 0.0)]), + make_trade([Transaction(pd.Timestamp("2018-06-06 00:03:00"), 1.0)]), + make_trade([Transaction(pd.Timestamp("2018-06-06 00:03:00"), 0.5)]), + make_trade([Transaction(pd.Timestamp("2018-06-06 00:03:00"), -1.0)]), + ), ) initial_stop = 2.0 trailing_stop = 1.0 trades = module.exit_by_trailing_stop(market, entries, initial_stop, trailing_stop) - expected = ( - _make_trade( - [ - Transaction(pd.Timestamp("2018-06-06 00:00:00"), 1.0), - Transaction(pd.Timestamp("2018-06-06 00:05:00"), -1.0), # trail stop - ] - ), - _make_trade( - [ - Transaction(pd.Timestamp("2018-06-06 00:00:00"), -1.0), - Transaction(pd.Timestamp("2018-06-06 00:02:00"), 1.0), # loss cut - ] - ), - _make_trade([Transaction(pd.Timestamp("2018-06-06 00:00:00"), 0.0)]), - _make_trade( - [ - Transaction(pd.Timestamp("2018-06-06 00:03:00"), 1.0), - Transaction(pd.Timestamp("2018-06-06 00:05:00"), -1.0), # trail stop - ] - ), - _make_trade( - [ - Transaction(pd.Timestamp("2018-06-06 00:03:00"), 0.5), - Transaction(pd.Timestamp("2018-06-06 00:05:00"), -0.5), # loss cut - ] - ), - _make_trade( - [ - Transaction(pd.Timestamp("2018-06-06 00:03:00"), -1.0), - Transaction(pd.Timestamp("2018-06-06 00:10:00"), 1.0), # trail stop - ] + expected = make_trades( + symbol, + ( + make_trade( + [ + Transaction(pd.Timestamp("2018-06-06 00:00:00"), 1.0), + Transaction( + pd.Timestamp("2018-06-06 00:05:00"), -1.0 + ), # trail stop + ] + ), + make_trade( + [ + Transaction(pd.Timestamp("2018-06-06 00:00:00"), -1.0), + Transaction(pd.Timestamp("2018-06-06 00:02:00"), 1.0), # loss cut + ] + ), + make_trade([Transaction(pd.Timestamp("2018-06-06 00:00:00"), 0.0)]), + make_trade( + [ + Transaction(pd.Timestamp("2018-06-06 00:03:00"), 1.0), + Transaction( + pd.Timestamp("2018-06-06 00:05:00"), -1.0 + ), # trail stop + ] + ), + make_trade( + [ + Transaction(pd.Timestamp("2018-06-06 00:03:00"), 0.5), + Transaction(pd.Timestamp("2018-06-06 00:05:00"), -0.5), # loss cut + ] + ), + make_trade( + [ + Transaction(pd.Timestamp("2018-06-06 00:03:00"), -1.0), + Transaction(pd.Timestamp("2018-06-06 00:10:00"), 1.0), # trail stop + ] + ), ), ) - print(entries) - print(expected) - assert trades == expected + pd.testing.assert_frame_equal(trades, expected) diff --git a/tests/trades/test_trades.py b/tests/trades/test_trades.py index 4397eae..4655752 100644 --- a/tests/trades/test_trades.py +++ b/tests/trades/test_trades.py @@ -3,14 +3,6 @@ import pytest import pandas as pd -import backlight.datasource - - -def _make_trade(transactions, symbol="hoge"): - trade = module.Trade(symbol) - for t in transactions: - trade.add(t) - return trade @pytest.fixture @@ -24,26 +16,53 @@ def trades(symbol): index = pd.date_range(start="2018-06-06", freq="1min", periods=len(data)) trades = [] for i in range(0, len(data), 2): - trade = module.from_series( - pd.Series(index=index[i : i + 2], data=data[i : i + 2], name="amount"), - symbol, - ) + trade = pd.Series(index=index[i : i + 2], data=data[i : i + 2], name="amount") trades.append(trade) - return tuple(trades) + trades = module.make_trades(symbol, trades) + return trades -@pytest.fixture -def market(symbol): - data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0], [9.0], [9.0]] - df = pd.DataFrame( - index=pd.date_range(start="2018-06-06", freq="1min", periods=len(data)), - data=data, - columns=["mid"], - ) - return backlight.datasource.from_dataframe(df, symbol) +def test_trades_ids(trades): + expected = [0, 1, 2, 3, 4] + assert trades.ids == expected + + +def test_trades_amount(trades): + data = [1.0, -2.0, 1.0, 2.0, -4.0, 2.0, 1.0, 0.0, 1.0, 0.0] + index = pd.date_range(start="2018-06-06", freq="1min", periods=len(data)) + expected = pd.Series(data=data, index=index, name="amount") + pd.testing.assert_series_equal(trades.amount, expected) + + +def test_trades_get_any(trades): + data = [1.0, -2.0, -4.0, 2.0] + index = [ + pd.Timestamp("2018-06-06 00:00:00"), + pd.Timestamp("2018-06-06 00:01:00"), + pd.Timestamp("2018-06-06 00:04:00"), + pd.Timestamp("2018-06-06 00:05:00"), + ] + expected = pd.Series(data=data, index=index, name="amount") + result = trades.get_any(trades.index.minute.isin([0, 4, 5])) + pd.testing.assert_series_equal(result.amount, expected) -def test_Trade(): +def test_trades_get_all(trades): + data = [-4.0, 2.0] + index = [pd.Timestamp("2018-06-06 00:04:00"), pd.Timestamp("2018-06-06 00:05:00")] + expected = pd.Series(data=data, index=index, name="amount") + result = trades.get_all(trades.index.minute.isin([0, 4, 5])) + pd.testing.assert_series_equal(result.amount, expected) + + +def test_trades_get_trade(trades): + data = [1.0, -2.0] + index = pd.date_range(start="2018-06-06", freq="1min", periods=len(data)) + expected = pd.Series(data=data, index=index, name="amount") + pd.testing.assert_series_equal(trades.get_trade(0), expected) + + +def test_make_trade(): periods = 2 dates = pd.date_range(start="2018-12-01", periods=periods) amounts = range(periods) @@ -52,28 +71,18 @@ def test_Trade(): t11 = module.Transaction(timestamp=dates[1], amount=amounts[1]) t01 = module.Transaction(timestamp=dates[0], amount=amounts[1]) - trade = _make_trade([t00, t11]) + trade = module.make_trade([t00, t11]) expected = pd.Series(index=dates, data=amounts[:2], name="amount") - assert (trade.amount == expected).all() + pd.testing.assert_series_equal(trade, expected) - trade = _make_trade([t00, t01]) + trade = module.make_trade([t00, t01]) expected = pd.Series( index=[dates[0]], data=[amounts[0] + amounts[1]], name="amount" ) - assert (trade.amount == expected).all() + pd.testing.assert_series_equal(trade, expected) - trade = _make_trade([t11, t01, t00]) + trade = module.make_trade([t11, t01, t00]) expected = pd.Series( index=dates, data=[amounts[0] + amounts[1], amounts[1]], name="amount" ) - assert (trade.amount == expected).all() - - -def test_flatten(symbol, trades): - data = [1.0, -2.0, 1.0, 2.0, -4.0, 2.0, 1.0, 0.0, 1.0, 0.0] - index = pd.date_range(start="2018-06-06", freq="1min", periods=len(data)) - expected = module.from_series( - pd.Series(index=index, data=data, name="amount"), symbol - ) - trade = module.flatten(trades) - assert trade == expected + pd.testing.assert_series_equal(trade, expected)