Skip to content

Commit

Permalink
Merge pull request #261 from ta-oliver/enable_allocation_limits_for_api
Browse files Browse the repository at this point in the history
Add limits by default to the API position calculations.
  • Loading branch information
ta-oliver authored Nov 25, 2021
2 parents 8706202 + 0b5486f commit 6f177d6
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 12 deletions.
11 changes: 6 additions & 5 deletions infertrade/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

# InferTrade packages
from infertrade.algos import algorithm_functions, ta_adaptor
from infertrade.utilities.operations import ReturnsFromPositions
from infertrade.utilities.operations import ReturnsFromPositions, restrict_allocation, limit_allocation
from infertrade.PandasEnum import PandasEnum


Expand Down Expand Up @@ -159,13 +159,14 @@ def _get_raw_callable(name_of_strategy_or_signal: str) -> callable:

@staticmethod
def calculate_allocations(
df: pd.DataFrame, name_of_strategy: str, name_of_price_series: str = PandasEnum.MID.value
df: pd.DataFrame, name_of_strategy: str, name_of_price_series: str = PandasEnum.MID.value,
allocation_lower_limit: float = -1.0, allocation_upper_limit: float = 1.0
) -> pd.DataFrame:
"""Calculates the allocations using the supplied strategy."""
if name_of_price_series is not PandasEnum.MID.value:
df[PandasEnum.MID.value] = df[name_of_price_series]
rule_function = Api._get_raw_callable(name_of_strategy)
df_with_positions = rule_function(df)
df_with_positions = limit_allocation(rule_function(df), allocation_lower_limit, allocation_upper_limit)
return df_with_positions

@staticmethod
Expand All @@ -176,10 +177,10 @@ def calculate_returns(df: pd.DataFrame) -> pd.DataFrame:

@staticmethod
def calculate_allocations_and_returns(
df: pd.DataFrame, name_of_strategy: str, name_of_price_series: str = PandasEnum.MID.value
df: pd.DataFrame, name_of_strategy: str, name_of_price_series: str = PandasEnum.MID.value, *args, **kwargs
) -> pd.DataFrame:
"""Calculates the returns using the supplied strategy."""
df_with_positions = Api.calculate_allocations(df, name_of_strategy, name_of_price_series)
df_with_positions = Api.calculate_allocations(df, name_of_strategy, name_of_price_series, *args, **kwargs)
df_with_returns = ReturnsFromPositions().transform(df_with_positions)
return df_with_returns

Expand Down
8 changes: 8 additions & 0 deletions tests/test_allocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from infertrade.algos import algorithm_functions
from infertrade.algos.community import allocations
from infertrade.algos.community.allocations import create_infertrade_export_allocations
from infertrade.api import Api


def test_under_minimum_length_to_calculate():
Expand Down Expand Up @@ -115,3 +116,10 @@ def test_create_infertrade_export_allocations():
"""Checks that a valid dictionary can be created."""
dictionary_algorithms = create_infertrade_export_allocations()
assert isinstance(dictionary_algorithms, dict) # could add checks for contents too


def test_all_allocations_list_required_series():
"""Checks that all allocation rules list required series."""
for ii_rule in Api.available_algorithms(filter_by_category="allocation"):
assert isinstance(Api.required_inputs_for_algorithm(ii_rule), list)

36 changes: 29 additions & 7 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
"""

# External imports
import copy

import pandas as pd
import pytest

Expand Down Expand Up @@ -54,7 +56,7 @@ def test_get_available_algorithms(algorithm):
assert Api.determine_package_of_algorithm(algorithm) in Api.available_packages()
try:
Api.determine_package_of_algorithm("not_available_algo")
except (NameError):
except NameError:
pass

inputs = Api.required_inputs_for_algorithm(algorithm)
Expand Down Expand Up @@ -172,8 +174,8 @@ def test_return_representations(algorithm):
)
for representation in dict_of_properties[algorithm]["available_representation_types"]:
assert (
returned_representations[representation]
== dict_of_properties[algorithm]["available_representation_types"][representation]
returned_representations[representation]
== dict_of_properties[algorithm]["available_representation_types"][representation]
)

# Check if the if the function returns the correct representation when given a string
Expand All @@ -185,8 +187,8 @@ def test_return_representations(algorithm):
type(returned_representations),
)
assert (
returned_representations[representation]
== dict_of_properties[algorithm]["available_representation_types"][representation]
returned_representations[representation]
== dict_of_properties[algorithm]["available_representation_types"][representation]
)

# Check if the function returns the correct representations when given a list
Expand All @@ -198,8 +200,8 @@ def test_return_representations(algorithm):
)
for representation in algorithm_representations:
assert (
returned_representations[representation]
== dict_of_properties[algorithm]["available_representation_types"][representation]
returned_representations[representation]
== dict_of_properties[algorithm]["available_representation_types"][representation]
)


Expand Down Expand Up @@ -357,3 +359,23 @@ def test_export_cross_prediction():
)

assert isinstance(sorted_dict, dict)


@pytest.mark.parametrize("test_df", test_dfs)
def test_allocation_limit(test_df):
"""Test used to see if calculated allocation values are inside of specified limit"""

test_df_copy = copy.deepcopy(test_df)
df_with_allocations = Api.calculate_allocations(
df=test_df_copy, name_of_strategy=available_allocation_algorithms[0], name_of_price_series="close",
allocation_lower_limit=0, allocation_upper_limit=0
)
if not all(df_with_allocations["allocation"] == 0.0):
raise ValueError("Allocation limits breached")

df_with_allocations = Api.calculate_allocations(
df=test_df_copy, name_of_strategy=available_allocation_algorithms[0], name_of_price_series="close",
allocation_lower_limit=-0.1, allocation_upper_limit=0.1
)
if any(-0.1 > df_with_allocations["allocation"]) or any(df_with_allocations["allocation"] > 0.1):
raise ValueError("Allocation limits breached")

0 comments on commit 6f177d6

Please sign in to comment.