diff --git a/requirements-dev.txt b/requirements-dev.txt index b451d198..15fe5aa7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,8 +5,8 @@ flake8==3.8.4 flake8-bugbear==21.4.3 isort==5.7.0 mccabe==0.6.1 -mypy-extensions==0.4.3 mypy==0.812 +mypy-extensions==0.4.3 pathspec==0.8.1 pycodestyle==2.6.0 pyflakes==2.2.0 diff --git a/requirements.txt b/requirements.txt index 897756d5..30ba4a41 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,18 @@ +aiodns==2.0.0 +aiohttp==3.7.4.post0 +async-timeout==3.0.1 +attrs==20.3.0 +ccxt>=1.42.7 certifi==2020.12.5 +cffi==1.14.5 chardet==4.0.0 +cryptography==3.4.7 idna==2.10 +multidict==5.1.0 +pycares==3.1.1 +pycparser==2.20 python-dateutil==2.8.1 requests==2.25.1 six==1.15.0 +yarl==1.1.0 urllib3==1.26.5 diff --git a/setup.cfg b/setup.cfg index 10a17170..3ae0206f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,6 +9,9 @@ warn_return_any = True show_error_codes = True warn_unused_configs = True +[mypy-ccxt.*] +ignore_missing_imports = True + [flake8] exclude = *py*env*/ max_line_length = 88 diff --git a/src/book.py b/src/book.py index 26599aec..7880f4eb 100644 --- a/src/book.py +++ b/src/book.py @@ -1131,6 +1131,31 @@ def read_file(self, file_path: Path) -> None: log.info("Reading file from exchange %s at %s", exchange, file_path) read_file(file_path) + + # Check whether the given exchange is "supported" by our ccxt + # implementation, by comparing the platform with the listed + # ccxt exchanges in our config. + ccxt_mapping = { + "binance": "binance", + "binance_v2": "binance", + "coinbase": "coinbasepro", + "coinbase_pro": "coinbasepro", + "kraken_ledgers_old": "kraken", + "kraken_ledgers": "kraken", + "kraken_trades": "kraken", + "bitpanda_pro_trades": "bitpanda", + } + api = ccxt_mapping.get(exchange) + if api is None: + log.warning( + f"The exchange {exchange} is not mapped to a ccxt exchange. " + "Please add the exchange to the ccxt_mapping dictionary." + ) + elif api not in config.EXCHANGES: + log.warning( + f"Exchange `{api}` not found in EXCHANGES API list in config.ini. " + "Consider adding it to obtain more accurate price data." + ) else: log.warning( f"Unable to detect the exchange of file `{file_path}`. " diff --git a/src/config.py b/src/config.py index 5677f0e3..4883b514 100644 --- a/src/config.py +++ b/src/config.py @@ -60,3 +60,4 @@ def IS_LONG_TERM(buy: datetime, sell: datetime) -> bool: DATA_PATH = Path(BASE_PATH, "data") EXPORT_PATH = Path(BASE_PATH, "export") FIAT = FIAT_CLASS.name # Convert to string. 
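+# ccxt exchange ids used by graph.PricePath to build price conversion paths
+# when preloading price data (see src/graph.py and src/price_data.py).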
+EXCHANGES = ["binance", "coinbasepro"] diff --git a/src/graph.py b/src/graph.py new file mode 100644 index 00000000..fb41eb81 --- /dev/null +++ b/src/graph.py @@ -0,0 +1,318 @@ +import collections +import time +from typing import Optional + +import ccxt + +import config +import log_config + +log = log_config.getLogger(__name__) + + +class RateLimit: + exchangedict: dict[str, int] = {} + + def limit(self, exchange): + if lastcall := self.exchangedict.get(exchange.id): + now = time.time() + delay = exchange.rateLimit / 1000 + if exchange.name == "Kraken": + delay += 2 # the reported ratelimit gets exceeded sometimes + timepassed = now - lastcall + if (waitfor := delay - timepassed) > 0: + time.sleep(waitfor + 0.5) + self.exchangedict[exchange.id] = time.time() + else: + self.exchangedict[exchange.id] = time.time() + + +class PricePath: + def __init__( + self, + exchanges: Optional[list[str]] = None, + gdict: Optional[dict] = None, + cache: Optional[dict] = None, + ): + if exchanges is None: + exchanges = list(config.EXCHANGES) + if gdict is None: + gdict = {} + if cache is None: + cache = {} + + self.gdict = gdict + self.cache = cache + self.RateLimit = RateLimit() + + # Saves the priority for a certain path so that bad paths can be skipped. + self.priority: collections.defaultdict[str, int] = collections.defaultdict(int) + allpairs: set[tuple[str, str, str, str]] = set() + + for exchange_id in exchanges: + exchange_class = getattr(ccxt, exchange_id) + exchange = exchange_class() + markets = exchange.fetch_markets() + assert isinstance(markets, list) + if exchange_id == "kraken": + log.warning( + """Kraken is currently not supported due to only supporting + the last 720 candles of historic data""" + ) + continue + if exchange.has["fetchOHLCV"]: + toadd = [ + (i["base"], i["quote"], exchange_id, i["symbol"]) for i in markets + ] + for pair in toadd: + allpairs.add(pair) + else: + log.warning( + f"{exchange.name} does not support fetch ohlcv. " + f"Ignoring exchange and {len(markets)} pairs." + ) + + # Remove duplicate pairs. + # TODO It might be faster to create it directly as set. + # Is it even necessary to convert it to a list? + # allpairs = list(set(allpairs)) + allpairslist: list[tuple[str, str, str, str]] = list(allpairs) + del allpairs + # print("Total Pairs to check:", len(allpairs)) + + # Sorting by `symbol` to have the same result on every run due to the set. 
+ allpairslist.sort(key=lambda x: x[3]) + + for base, quote, exchange, symbol in allpairslist: + self.add_Vertex(base) + self.add_Vertex(quote) + self.add_Edge( + base, quote, {"exchange": exchange, "symbol": symbol, "inverted": False} + ) + self.add_Edge( + quote, base, {"exchange": exchange, "symbol": symbol, "inverted": True} + ) + + def edges(self): + return self.find_edges() + + # Find the distinct list of edges + + def find_edges(self): + edgename = [] + for vrtx in self.gdict: + for nxtvrtx in self.gdict[vrtx]: + if {nxtvrtx, vrtx} not in edgename: + edgename.append({vrtx, nxtvrtx}) + return edgename + + def get_Vertices(self): + return list(self.gdict.keys()) + + # Add the vertex as a key + def add_Vertex(self, vrtx): + if vrtx not in self.gdict: + self.gdict[vrtx] = [] + + def add_Edge(self, vrtx1, vrtx2, data): + if vrtx1 in self.gdict: + self.gdict[vrtx1].append((vrtx2, data)) + else: + self.gdict[vrtx1] = [ + (vrtx2, data), + ] + + def _get_path(self, start, stop, maxdepth, depth=0): + """ + a recursive function for finding all possible paths between to vertices + """ + paths = [] + if (edges := self.gdict.get(start)) and maxdepth > depth: + for edge in edges: # list of edges starting from the start vertice + if depth == 0 and edge[0] == stop: + paths.append([edge]) + elif edge[0] == stop: + paths.append(edge) + else: + path = self._get_path(edge[0], stop, maxdepth, depth=depth + 1) + if len(path) and path is not None: + for p in path: + if p[0] == stop: + newpath = [edge] + newpath.append(p) + paths.append(newpath) + return paths + + def change_prio(self, key, value): + ke = "-".join(key) + self.priority[ke] += value + + def get_path( + self, start, stop, starttime=0, stoptime=0, preferredexchange=None, maxdepth=3 + ): + def comb_sort_key(path): + """ + Sorting function which is used to prioritize paths by: + (in order of magnitude) + - smallest length -> +1 per element + - preferred exchange -> +1 per exchange which is not preferred + - priority -> +0.5 per unfinished execution of path + - volume (if known) -> 1/sum(avg_vol per pair) + - volume (if not known) -> 1 -> always smaller if volume is known + """ + # prioritze pairs with the preferred exchange + volume = 1 + volumenew = 1 + priority = self.priority.get("-".join([a[1]["symbol"] for a in path]), 0) + pathlis = (a if (a := check_cache(pair)) else None for pair in path) + for possiblepath in pathlis: + if possiblepath and possiblepath[0]: + if possiblepath[1][1]["stoptime"] == 0: + break + elif possiblepath[1][1]["avg_vol"] != 0: + # is very much off because volume is not in the same + # currency something for later + # volumenew*= volume of next thing in path + # (needs to be fixed for inverted paths) + volumenew *= possiblepath[1][1]["avg_vol"] + + else: + break + else: + volume = 1 / volumenew + temppriority = volume + priority + + if preferredexchange: + + return ( + len(path) + + sum( + [ + 0 if pair[1]["exchange"] == preferredexchange else 1 + for pair in path + ] + ) + + temppriority + ) + else: + return len(path) + temppriority + + def check_cache(pair): + """ + checking if the start and stoptime of a pair is already known + or if it needs to be downloaded + """ + if pair[1].get("starttime") or pair[1].get("stoptime"): + return True, pair + if cacheres := self.cache.get(pair[1]["exchange"] + pair[1]["symbol"]): + pair[1]["starttime"] = cacheres[0] + pair[1]["stoptime"] = cacheres[1] + pair[1]["avg_vol"] = cacheres[2] + return True, pair + return False, pair + + def get_active_timeframe(path, starttimestamp=0, 
stoptimestamp=-1): + rangeinms = 0 + timeframe = int(6.048e8) # week in ms + if starttimestamp == 0: + starttimestamp = 1325372400 * 1000 + if stoptimestamp == -1: + stoptimestamp = time.time_ns() // 1_000_000 # get cur time in ms + starttimestamp -= timeframe # to handle edge cases + if stoptimestamp > starttimestamp: + rangeinms = stoptimestamp - starttimestamp + else: + rangeinms = 0 # maybe throw error + + # add one candle to the end to ensure the needed + # timeslot is in the requested candles + rangeincandles = int(rangeinms / timeframe) + 1 + + # todo: cache already used pairs + globalstarttime = 0 + globalstoptime = 0 + for i in range(len(path)): + cached, path[i] = check_cache(path[i]) + if not cached: + exchange_class = getattr(ccxt, path[i][1]["exchange"]) + exchange = exchange_class() + + self.RateLimit.limit(exchange) + timeframeexchange = exchange.timeframes.get("1w") + if ( + timeframeexchange + ): # this must be handled better maybe choose timeframe dynamically + # maybe cache this per pair + ohlcv = exchange.fetch_ohlcv( + path[i][1]["symbol"], "1w", starttimestamp, rangeincandles + ) + else: + ohlcv = [] # do not check fail later + if len(ohlcv) > 1: + # (candle ends after the date + timeframe) + path[i][1]["stoptime"] = ohlcv[-1][0] + timeframe + path[i][1]["avg_vol"] = sum([vol[-1] for vol in ohlcv]) / len( + ohlcv + ) # avg vol in curr + path[i][1]["starttime"] = ohlcv[0][0] + if ( + path[i][1]["stoptime"] < globalstoptime + or globalstoptime == 0 + ): + globalstoptime = path[i][1]["stoptime"] + if path[i][1]["starttime"] > globalstarttime: + globalstarttime = path[i][1]["starttime"] + else: + path[i][1]["stoptime"] = 0 + path[i][1]["starttime"] = 0 + path[i][1]["avg_vol"] = 0 + self.cache[path[i][1]["exchange"] + path[i][1]["symbol"]] = ( + path[i][1]["starttime"], + path[i][1]["stoptime"], + path[i][1]["avg_vol"], + ) + else: + + if ( + path[i][1]["stoptime"] < globalstoptime or globalstoptime == 0 + ) and path[i][1]["stoptime"] != 0: + globalstoptime = path[i][1]["stoptime"] + if path[i][1]["starttime"] > globalstarttime: + globalstarttime = path[i][1]["starttime"] + ohlcv = [] + return (globalstarttime, globalstoptime), path + + # get all possible paths which are no longer than 4 pairs long + paths = self._get_path(start, stop, maxdepth) + # sort by path length to get minimal conversion chain to reduce error + paths = sorted(paths, key=comb_sort_key) + # get timeframe in which a path is viable + for path in paths: + timest, newpath = get_active_timeframe(path) + # this is implemented as a generator (hence the yield) to reduce + # the amount of computing needed. 
if the first path fails the next is used + if starttime == 0 and stoptime == 0: + yield timest, newpath + elif starttime == 0: + if stoptime < timest[1]: + yield timest, newpath + elif stoptime == 0: + if starttime > timest[0]: + yield timest, newpath + # The most ideal situation is if the timerange of the path is known + # and larger than the needed timerange + else: + if stoptime < timest[1] and starttime > timest[0]: + yield timest, newpath + + +if __name__ == "__main__": + g = PricePath(exchanges=["binance", "coinbasepro"]) + start = "IOTA" + to = "EUR" + preferredexchange = "binance" + path = g.get_path(start, to, maxdepth=2, preferredexchange=preferredexchange) + # debug only in actual use we would iterate over + # the path object fetching new paths as needed + path = list(path) + print(len(path)) diff --git a/src/log_config.py b/src/log_config.py index db7e8bf7..753b9c02 100644 --- a/src/log_config.py +++ b/src/log_config.py @@ -36,3 +36,4 @@ # Disable urllib debug messages getLogger("urllib3").propagate = False +getLogger("ccxt").propagate = False diff --git a/src/misc.py b/src/misc.py index 32f80cdb..13a07c3c 100644 --- a/src/misc.py +++ b/src/misc.py @@ -28,7 +28,6 @@ Optional, SupportsFloat, SupportsInt, - Tuple, TypeVar, Union, cast, @@ -122,7 +121,7 @@ def to_decimal_timestamp(d: datetime.datetime) -> decimal.Decimal: def get_offset_timestamps( utc_time: datetime.datetime, offset: datetime.timedelta, -) -> Tuple[int, int]: +) -> tuple[int, int]: """Return timestamps in milliseconds `offset/2` before/after `utc_time`. Args: @@ -130,7 +129,7 @@ def get_offset_timestamps( offset (datetime.timedelta) Returns: - Tuple[int, int]: Timestamps in milliseconds. + tuple[int, int]: Timestamps in milliseconds. """ start = utc_time - offset / 2 end = utc_time + offset / 2 diff --git a/src/price_data.py b/src/price_data.py index c1aa754b..5b41b5da 100644 --- a/src/price_data.py +++ b/src/price_data.py @@ -18,14 +18,17 @@ import datetime import decimal import json +import math import sqlite3 import time from pathlib import Path -from typing import Any, Union +from typing import Any, Optional, Union +import ccxt import requests import config +import graph import log_config import misc import transaction @@ -48,6 +51,9 @@ class PriceData: + def __init__(self): + self.path = graph.PricePath() + def get_db_path(self, platform: str) -> Path: return Path(config.DATA_PATH, f"{platform}.db") @@ -516,6 +522,70 @@ def get_price( return price + def get_missing_price_operations( + self, + operations: list[transaction.Operation], + coin: str, + platform: str, + reference_coin: str = config.FIAT, + ) -> list[transaction.Operation]: + """Return operations for which no price was found in the database. + + Requires the `operations` to have the same `coin` and `platform`. + + Args: + operations (list[transaction.Operation]) + coin (str) + platform (str) + reference_coin (str): Defaults to `config.FIAT`. + + Returns: + list[transaction.Operation] + """ + assert all(op.coin == coin for op in operations) + assert all(op.platform == platform for op in operations) + + # We do not have to calculate the price, if there are no operations or the + # coin is the same as the reference coin. + if not operations or coin == reference_coin: + return [] + + db_path = self.get_db_path(platform) + # If the price database does not exist, we need to query all prices. 
+ if not db_path.is_file(): + return operations + + tablename = self.get_tablename(coin, reference_coin) + utc_time_values = ",".join(f"('{op.utc_time}')" for op in operations) + + with sqlite3.connect(db_path) as conn: + cur = conn.cursor() + # The query returns a list with 0 and 1's. + # - 0: a price exists. + # - 1: the price is missing. + query = ( + "SELECT t.utc_time IS NULL " + f"FROM (VALUES {utc_time_values}) " + f"LEFT JOIN `{tablename}` t ON t.utc_time = COLUMN1;" + ) + + # Execute the query. + try: + cur.execute(query) + except sqlite3.OperationalError as e: + if str(e) == f"no such table: {tablename}": + # The corresponding price table does not exist yet. + # We need to query all prices. + return operations + raise e + + # Evaluate the result. + result = (bool(is_missing) for is_missing, in cur.fetchall()) + missing_prices_operations = [ + op for op, is_missing in zip(operations, result) if is_missing + ] + return missing_prices_operations + def get_cost( self, tr: Union[transaction.Operation, transaction.SoldCoin], @@ -529,6 +599,241 @@ def get_cost( return price * tr.sold raise NotImplementedError + def get_candles( + self, start: int, stop: int, symbol: str, exchange_id: str + ) -> list[tuple[int, float, float, float, float, float]]: + """Return list with candles starting 2 minutes before start. + + Args: + start (int): Start time in milliseconds since epoch. + stop (int): End time in milliseconds since epoch. + symbol (str) + exchange_id (str) + + Returns: + list: List of OHLCV candles gathered from ccxt containing: + + timestamp (int): Timestamp of candle in milliseconds since epoch. + open_price (float) + lowest_price (float) + highest_price (float) + close_price (float) + volume (float) + """ + assert stop >= start, f"`stop` must be after `start` {stop} !>= {start}." + + exchange_class = getattr(ccxt, exchange_id) + exchange = exchange_class() + assert isinstance(exchange, ccxt.Exchange) + + # Technically impossible. Unsupported exchanges should be detected earlier. + assert exchange.has["fetchOHLCV"] + + # time.sleep wants seconds + self.path.RateLimit.limit(exchange) + + # Get candles 2 min before and after start/stop. + since = start - 2 * 60 * 1000 + # `fetch_ohlcv` has no stop value but only a limit (amount of candles fetched). + # Calculate the amount of candles in the 1 min timeframe, + # so that we get enough candles. + # Most exchange have an upper limit (e.g. binance 1000, coinbasepro 300). + # `ccxt` throws an error if we exceed this limit. + limit = math.ceil((stop - start) / (1000 * 60)) + 2 + try: + candles = exchange.fetch_ohlcv(symbol, "1m", since, limit) + except ccxt.RateLimitExceeded: + # sometimes the ratelimit gets exceeded for kraken dunno why + log.warning("Ratelimit exceeded sleeping 10 seconds and retrying") + time.sleep(10) + self.path.RateLimit.limit(exchange) + candles = exchange.fetch_ohlcv(symbol, "1m", since, limit) + + assert isinstance(candles, list) + return candles + + def get_avg_candle_prices( + self, start: int, stop: int, symbol: str, exchange_id: str, invert: bool = False + ) -> list[tuple[int, decimal.Decimal]]: + """Return average price from ohlcv candles. + + The average price of the candle is calculated as the average from the + open and close price. + + Further information about candle-function can be found in `get_candles`. + + Args: + start (int) + stop (int) + symbol (str) + exchange_id (str) + invert (bool, optional): Defaults to False. 
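+                If True, each average price is returned as its reciprocal
+                (1 / price), which is used when the traded pair is inverted
+                relative to the requested conversion direction.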
+ + Returns: + list: Timestamp and average prices of candles containing: + + timestamp (int): Timestamp of candle in milliseconds since epoch. + avg_price (decimal.Decimal): Average price of candle. + """ + avg_candle_prices = [] + candle_prices = self.get_candles(start, stop, symbol, exchange_id) + + for timestamp_ms, _open, _high, _low, _close, _volume in candle_prices: + open = misc.force_decimal(_open) + close = misc.force_decimal(_close) + + avg_price = (open + close) / 2 + + if invert and avg_price != 0: + avg_price = 1 / avg_price + + avg_candle_prices.append((timestamp_ms, avg_price)) + return avg_candle_prices + + # TODO preferredexchange default is only for debug purposes and should be + # removed later on. + def _get_bulk_pair_data_path( + self, + operations: list, + coin: str, + reference_coin: str, + preferredexchange: str = "binance", + ) -> list: + def merge_prices(a: list, b: Optional[list] = None) -> list: + if not b: + return a + + prices = [] + for i in a: + factor = next(j[1] for j in b if i[0] == j[0]) + prices.append((i[0], i[1] * factor)) + + return prices + + # TODO Set `max_difference` to the platform specific ohlcv-limit. + max_difference = 300 # coinbasepro + # TODO Set `max_size` to the platform specific ohlcv-limit. + max_size = 300 # coinbasepro + time_batches = transaction.time_batches( + operations, max_difference=max_difference, max_size=max_size + ) + + datacomb = [] + + for batch in time_batches: + # ccxt works with timestamps in milliseconds + first = misc.to_ms_timestamp(batch[0]) + last = misc.to_ms_timestamp(batch[-1]) + firststr = batch[0].strftime("%d-%b-%Y (%H:%M)") + laststr = batch[-1].strftime("%d-%b-%Y (%H:%M)") + log.info( + f"getting data from {str(firststr)} to {str(laststr)} for {str(coin)}" + ) + path = self.path.get_path( + coin, reference_coin, first, last, preferredexchange=preferredexchange + ) + # Todo Move the path calculation out of the for loop + # and only filter after time + for p in path: + tempdatalis: list = [] + printstr = [f"{a[1]['symbol']} ({a[1]['exchange']})" for a in p[1]] + log.debug(f"found path over {' -> '.join(printstr)}") + for i in range(len(p[1])): + tempdatalis.append([]) + symbol = p[1][i][1]["symbol"] + exchange = p[1][i][1]["exchange"] + invert = p[1][i][1]["inverted"] + + if tempdata := self.get_avg_candle_prices( + first, last, symbol, exchange, invert + ): + for operation in batch: + # TODO discuss which candle is picked + # current is closest to original date + # (often off by about 1-20s, but can be after the Trade) + # times do not always line up perfectly so take one nearest + ts = list( + map( + lambda x: ( + abs(misc.to_ms_timestamp(operation) - x[0]), + x, + ), + tempdata, + ) + ) + tempdatalis[i].append( + (operation, min(ts, key=lambda x: x[0])[1][1]) + ) + else: + tempdatalis = [] + # do not try already failed again + self.path.change_prio(printstr, 0.2) + break + if tempdatalis: + wantedlen = len(tempdatalis[0]) + for li in tempdatalis: + if not len(li) == wantedlen: + self.path.change_prio(printstr, 0.2) + break + else: + prices: list = [] + for d in tempdatalis: + prices = merge_prices(d, prices) + datacomb.extend(prices) + break + log.debug("path failed trying new path") + + return datacomb + + def preload_prices( + self, + operations: list[transaction.Operation], + coin: str, + platform: str, + reference_coin: str = config.FIAT, + ) -> None: + """Preload price data. + + Requires the operations to have the same `coin` and `exchange`. 
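+
+        Missing prices are fetched in bulk as OHLCV candles via ccxt
+        (see `_get_bulk_pair_data_path`) and written to the local price
+        database with `set_price_db`.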
+
+        Args:
+            operations (list[transaction.Operation])
+            coin (str)
+            platform (str)
+            reference_coin (str): Defaults to `config.FIAT`.
+        """
+        assert all(op.coin == coin for op in operations)
+        assert all(op.platform == platform for op in operations)
+
+        # We do not have to preload prices, if there are no operations or the coin is
+        # the same as the reference coin.
+        if not operations or coin == reference_coin:
+            return
+
+        if platform == "kraken":
+            log.warning(
+                f"Will not preload prices for {platform}, reverting to default API."
+            )
+            return
+
+        # Only consider the operations for which we have no prices in the database.
+        missing_prices_operations = self.get_missing_price_operations(
+            operations, coin, platform, reference_coin
+        )
+
+        # Preload the prices.
+        if missing_prices_operations:
+            data = self._get_bulk_pair_data_path(
+                missing_prices_operations,
+                coin,
+                reference_coin,
+                preferredexchange=platform,
+            )
+
+            # TODO Use bulk insert to write all prices at once into the database.
+            for p in data:
+                set_price_db(platform, coin, reference_coin, p[0], p[1])
+
     def check_database(self):
         stats = {}
diff --git a/src/taxman.py b/src/taxman.py
index 7527991d..28e56741 100644
--- a/src/taxman.py
+++ b/src/taxman.py
@@ -251,6 +251,20 @@ def _evaluate_taxation_per_coin(
     def evaluate_taxation(self) -> None:
         """Evaluate the taxation using country specific function."""
         log.debug("Starting evaluation...")
+        counter = 0
+        total_operations = len(self.book.operations)
+        for plat, _ops in misc.group_by(self.book.operations, "platform").items():
+            for coin, coin_operations in misc.group_by(_ops, "coin").items():
+                s_operations = transaction.sort_operations(
+                    coin_operations, ["utc_time"]
+                )
+                self.price_data.preload_prices(s_operations, coin, plat)
+                counter += len(coin_operations)
+                log.info(
+                    "{:6.2f} % done, {:6d} out of {:d} operations processed".format(
+                        counter / total_operations * 100, counter, total_operations
+                    )
+                )
 
         if config.MULTI_DEPOT:
             # Evaluate taxation separated by platforms and coins.
diff --git a/src/transaction.py b/src/transaction.py
index 8e0140f7..3cb8de51 100644
--- a/src/transaction.py
+++ b/src/transaction.py
@@ -142,6 +142,66 @@ class TaxEvent:
     remark: str = ""
 
 
+# Functions
+
+
+def time_batches(
+    operations: list[Operation],
+    max_difference: typing.Optional[int],
+    max_size: typing.Optional[int] = None,
+) -> typing.Iterable[list[datetime.datetime]]:
+    """Return timestamps of operations in batches.
+
+    The batches are clustered such that the time difference between the first
+    and the last operation of a batch is less than `max_difference` minutes
+    and each batch has at most `max_size` entries.
+
+    TODO Solve the clustering optimally. (It's already optimal, if max_size is None.)
+
+    Args:
+        operations (list[Operation]): List of operations.
+        max_difference (Optional[int]):
+            Maximum time difference within a batch (in minutes).
+            None means an unlimited time difference.
+        max_size (Optional[int], optional):
+            Maximum size of a batch.
+            Defaults to None (unlimited size).
+
+    Yields:
+        list[datetime.datetime]: The timestamps of the operations, batch by batch.
+    """
+    assert max_difference is None or max_difference >= 0
+    assert max_size is None or max_size > 0
+
+    batch: list[datetime.datetime] = []
+
+    if not operations:
+        # Nothing to cluster, stop the generator.
+        return batch
+
+    # Calculate the latest time which is allowed to be in this cluster.
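+    # (i.e. the first operation's time plus `max_difference` minutes).
+    # Example: with max_difference=300 (minutes) and max_size=None, operations
+    # at 09:00, 12:00 and 16:00 UTC yield the batches [09:00, 12:00] and [16:00].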
+    if max_difference:
+        max_time = operations[0].utc_time + datetime.timedelta(minutes=max_difference)
+    else:
+        max_time = datetime.datetime.max
+
+    for op in operations:
+        timestamp = op.utc_time
+
+        # Check if timestamp is before max_time and
+        # that our cluster isn't too large already.
+        if timestamp < max_time and (not max_size or len(batch) < max_size):
+            batch.append(timestamp)
+        else:
+            yield batch
+
+            batch = [timestamp]
+
+            if max_difference:
+                max_time = timestamp + datetime.timedelta(minutes=max_difference)
+    yield batch  # fixes bug where last batch is not yielded
+
+
 gain_operations = [
     CoinLendEnd,
     StakingEnd,