diff --git a/tests/test_algos.py b/tests/test_algos.py index e6ea784..ec7379b 100644 --- a/tests/test_algos.py +++ b/tests/test_algos.py @@ -2302,3 +2302,42 @@ def test_hedge_risk_pseudo_over(): assert c1.position == 100 assert c2.position == -5 assert c3.position == -5 + + +def test_corporate_actions(): + dts = pd.date_range("2010-01-01", periods=3) + + data = pd.DataFrame(index=dts, columns=["c1", "c2"], data=100) + divs = pd.DataFrame(index=dts, columns=["c1", "c2"], data=0.0) + divs.loc[dts[1], "c1"] = 2.0 + splits = pd.DataFrame(index=dts, columns=["c1", "c2"], data=1.0) + splits.loc[dts[2], "c2"] = 10.0 + + algo = algos.CorporateActions(divs, splits) + + s = bt.Strategy("s", children=["c1", "c2"]) + s.setup(data) + s.adjust(20000) + + s.update(dts[0]) + s.allocate(10000, "c1", update=True) + s.allocate(10000, "c2", update=True) + + assert algo(s) + assert s.capital == 0 + assert s["c1"].position == 100 + assert s["c2"].position == 100 + + s.update(dts[1]) + + assert algo(s) + assert s.capital == 100 * 2.0 + assert s["c1"].position == 100 + assert s["c2"].position == 100 + + s.update(dts[2]) + + assert algo(s) + assert s.capital == 100 * 2.0 + assert s["c1"].position == 100 + assert s["c2"].position == 100 * 10.0