diff --git a/.gitignore b/.gitignore index cdf9c4c06..2e9acfb47 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,7 @@ doc/_build /notebooks /testcases/credentials.py +*.log +*.sqlite +.DS_Store +.vscode/settings.json diff --git a/.travis.yml b/.travis.yml index e8fd24050..8141cc767 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,6 @@ language: python python: - - "2.7" + - "3.6.8" env: global: @@ -18,7 +18,7 @@ services: - docker before_install: - - docker pull gbecedillas/pyalgotrade:0.20-py27 + - docker pull gbecedillas/pyalgotrade:0.20-py37 - cp travis/Dockerfile . - docker build -t pyalgotrade_testcases . - sudo pip install coveralls diff --git a/pyalgotrade/bar.py b/pyalgotrade/bar.py index cd9e13cd9..60b29cb1b 100644 --- a/pyalgotrade/bar.py +++ b/pyalgotrade/bar.py @@ -27,23 +27,28 @@ class Frequency(object): """Enum like class for bar frequencies. Valid values are: + * **Frequency.UNKNOWN**: The bar represents an unknown frequency trade. * **Frequency.TRADE**: The bar represents a single trade. * **Frequency.SECOND**: The bar summarizes the trading activity during 1 second. * **Frequency.MINUTE**: The bar summarizes the trading activity during 1 minute. * **Frequency.HOUR**: The bar summarizes the trading activity during 1 hour. + * **Frequency.HOUR_4**: The bar summarizes the trading activity during 4 hour. * **Frequency.DAY**: The bar summarizes the trading activity during 1 day. * **Frequency.WEEK**: The bar summarizes the trading activity during 1 week. * **Frequency.MONTH**: The bar summarizes the trading activity during 1 month. """ # It is important for frequency values to get bigger for bigger windows. - TRADE = -1 - SECOND = 1 - MINUTE = 60 - HOUR = 60*60 - DAY = 24*60*60 - WEEK = 24*60*60*7 - MONTH = 24*60*60*31 + UNKNOWN = -2 + TRADE = -1 + REALTIME = 0 + SECOND = 1 + MINUTE = 60 + HOUR = 60*60 + HOUR_4 = 60*60*4 + DAY = 24*60*60 + WEEK = 24*60*60*7 + MONTH = 24*60*60*31 @six.add_metaclass(abc.ABCMeta) @@ -260,10 +265,12 @@ def __init__(self, barDict): # Check that bar datetimes are in sync firstDateTime = None firstInstrument = None + firstFreq = None for instrument, currentBar in six.iteritems(barDict): if firstDateTime is None: firstDateTime = currentBar.getDateTime() firstInstrument = instrument + firstFreq = currentBar.getFrequency() elif currentBar.getDateTime() != firstDateTime: raise Exception("Bar data times are not in sync. %s %s != %s %s" % ( instrument, @@ -271,9 +278,11 @@ def __init__(self, barDict): firstInstrument, firstDateTime )) + assert firstFreq == currentBar.getFrequency() self.__barDict = barDict self.__dateTime = firstDateTime + self.__frequency = firstFreq def __getitem__(self, instrument): """Returns the :class:`pyalgotrade.bar.Bar` for the given instrument. @@ -301,3 +310,6 @@ def getDateTime(self): def getBar(self, instrument): """Returns the :class:`pyalgotrade.bar.Bar` for the given instrument or None if the instrument is not found.""" return self.__barDict.get(instrument, None) + + def getBarsFrequency(self): + return self.__frequency \ No newline at end of file diff --git a/pyalgotrade/barfeed/__init__.py b/pyalgotrade/barfeed/__init__.py index 1312620bd..f8123b0cf 100644 --- a/pyalgotrade/barfeed/__init__.py +++ b/pyalgotrade/barfeed/__init__.py @@ -43,16 +43,20 @@ class BaseBarFeed(feed.BaseFeed): This is a base class and should not be used directly. """ - def __init__(self, frequency, maxLen=None): + def __init__(self, frequencies, maxLen=None): super(BaseBarFeed, self).__init__(maxLen) - self.__frequency = frequency + if not isinstance(frequencies, list): + raise Exception('only frequencies list is accepted') + self.__frequencies = frequencies self.__useAdjustedValues = False self.__defaultInstrument = None self.__currentBars = None + self.__currentRealtimeBars = None self.__lastBars = {} def reset(self): self.__currentBars = None + self.__currentRealtimeBars = None self.__lastBars = {} super(BaseBarFeed, self).reset() @@ -92,30 +96,38 @@ def createDataSeries(self, key, maxLen): def getNextValues(self): dateTime = None + freq = None bars = self.getNextBars() if bars is not None: + freq = bars.getBarsFrequency() dateTime = bars.getDateTime() # Check that current bar datetimes are greater than the previous one. - if self.__currentBars is not None and self.__currentBars.getDateTime() >= dateTime: - raise Exception( - "Bar date times are not in order. Previous datetime was %s and current datetime is %s" % ( - self.__currentBars.getDateTime(), - dateTime + if self.__currentBars is not None and self.__currentBars.getDateTime() > dateTime: + if freq == self.__currentBars.getBarsFrequency(): + raise Exception( + "Bar date times are not in order. Previous datetime was %s and current datetime is %s" % ( + self.__currentBars.getDateTime(), + dateTime + ) ) - ) # Update self.__currentBars and self.__lastBars self.__currentBars = bars for instrument in bars.getInstruments(): self.__lastBars[instrument] = bars[instrument] - return (dateTime, bars) + return (dateTime, bars, freq) - def getFrequency(self): - return self.__frequency + def getAllFrequencies(self): + return self.__frequencies def isIntraday(self): - return self.__frequency < bar.Frequency.DAY + for i in self.__frequencies: + if i < bar.Frequency.DAY: + return True + + def getCurrentRealtimeBars(self): + return self.__currentRealtimeBars def getCurrentBars(self): """Returns the current :class:`pyalgotrade.bar.Bars`.""" @@ -133,11 +145,11 @@ def getRegisteredInstruments(self): """Returns a list of registered intstrument names.""" return self.getKeys() - def registerInstrument(self, instrument): + def registerInstrument(self, instrument, freq): self.__defaultInstrument = instrument - self.registerDataSeries(instrument) + self.registerDataSeries(instrument, freq) - def getDataSeries(self, instrument=None): + def getDataSeries(self, instrument=None, freq=None): """Returns the :class:`pyalgotrade.dataseries.bards.BarDataSeries` for a given instrument. :param instrument: Instrument identifier. If None, the default instrument is returned. @@ -146,7 +158,7 @@ def getDataSeries(self, instrument=None): """ if instrument is None: instrument = self.__defaultInstrument - return self[instrument] + return self[instrument, freq] if freq is not None else self[instrument] def getDispatchPriority(self): return dispatchprio.BAR_FEED @@ -158,7 +170,7 @@ class OptimizerBarFeed(BaseBarFeed): def __init__(self, frequency, instruments, bars, maxLen=None): super(OptimizerBarFeed, self).__init__(frequency, maxLen) for instrument in instruments: - self.registerInstrument(instrument) + self.registerInstrument(instrument, frequency) self.__bars = bars self.__nextPos = 0 self.__currDateTime = None diff --git a/pyalgotrade/barfeed/csvfeed.py b/pyalgotrade/barfeed/csvfeed.py index 52e2065ed..f74a5a4fd 100644 --- a/pyalgotrade/barfeed/csvfeed.py +++ b/pyalgotrade/barfeed/csvfeed.py @@ -286,8 +286,9 @@ def addBarsFromCSV(self, instrument, path, timezone=None, skipMalformedBars=Fals if timezone is None: timezone = self.__timezone + assert len(self.getAllFrequencies()) == 1 rowParser = GenericRowParser( - self.__columnNames, self.__dateTimeFormat, self.getDailyBarTime(), self.getFrequency(), + self.__columnNames, self.__dateTimeFormat, self.getDailyBarTime(), self.getAllFrequencies()[0], timezone, self.__barClass ) diff --git a/pyalgotrade/barfeed/dbfeed.py b/pyalgotrade/barfeed/dbfeed.py index bf0bb8ae6..087968be2 100644 --- a/pyalgotrade/barfeed/dbfeed.py +++ b/pyalgotrade/barfeed/dbfeed.py @@ -26,9 +26,10 @@ def addBars(self, bars, frequency): self.addBar(instrument, bar, frequency) def addBarsFromFeed(self, feed): - for dateTime, bars in feed: + assert len(feed.getAllFrequencies()) == 1 + for dateTime, bars, _ in feed: if bars: - self.addBars(bars, feed.getFrequency()) + self.addBars(bars, feed.getAllFrequencies()[0]) def addBar(self, instrument, bar, frequency): raise NotImplementedError() diff --git a/pyalgotrade/barfeed/googlefeed.py b/pyalgotrade/barfeed/googlefeed.py index d886cb814..2a788c3fb 100644 --- a/pyalgotrade/barfeed/googlefeed.py +++ b/pyalgotrade/barfeed/googlefeed.py @@ -145,6 +145,8 @@ def addBarsFromCSV(self, instrument, path, timezone=None, skipMalformedBars=Fals if timezone is None: timezone = self.__timezone + + assert len(self.getAllFrequencies()) == 1 - rowParser = RowParser(self.getDailyBarTime(), self.getFrequency(), timezone, self.__sanitizeBars) + rowParser = RowParser(self.getDailyBarTime(), self.getAllFrequencies()[0], timezone, self.__sanitizeBars) super(Feed, self).addBarsFromCSV(instrument, path, rowParser, skipMalformedBars=skipMalformedBars) diff --git a/pyalgotrade/barfeed/membf.py b/pyalgotrade/barfeed/membf.py index 7b2d5e21a..9f6d2f004 100644 --- a/pyalgotrade/barfeed/membf.py +++ b/pyalgotrade/barfeed/membf.py @@ -72,7 +72,10 @@ def addBarsFromSequence(self, instrument, bars): self.__bars[instrument].extend(bars) self.__bars[instrument].sort(key=lambda b: b.getDateTime()) - self.registerInstrument(instrument) + assert len(self.getAllFrequencies()) == 1 + for i in bars: + assert i.getFrequency() == self.getAllFrequencies()[0] + self.registerInstrument(instrument, self.getAllFrequencies()[0]) def eof(self): ret = True @@ -115,5 +118,5 @@ def getNextBars(self): return bar.Bars(ret) def loadAll(self): - for dateTime, bars in self: + for dateTime, bars, freq in self: pass diff --git a/pyalgotrade/barfeed/ninjatraderfeed.py b/pyalgotrade/barfeed/ninjatraderfeed.py index 4bc6b3dcd..3ce75de55 100644 --- a/pyalgotrade/barfeed/ninjatraderfeed.py +++ b/pyalgotrade/barfeed/ninjatraderfeed.py @@ -145,5 +145,6 @@ def addBarsFromCSV(self, instrument, path, timezone=None): if timezone is None: timezone = self.__timezone - rowParser = RowParser(self.getFrequency(), self.getDailyBarTime(), timezone) + assert len(self.getAllFrequencies()) == 1 + rowParser = RowParser(self.getAllFrequencies()[0], self.getDailyBarTime(), timezone) super(Feed, self).addBarsFromCSV(instrument, path, rowParser) diff --git a/pyalgotrade/barfeed/resampled.py b/pyalgotrade/barfeed/resampled.py index 821fff797..4a80da67e 100644 --- a/pyalgotrade/barfeed/resampled.py +++ b/pyalgotrade/barfeed/resampled.py @@ -62,7 +62,7 @@ def __init__(self, barFeed, frequency, maxLen=None): # Register the same instruments as in the underlying barfeed. for instrument in barFeed.getRegisteredInstruments(): - self.registerInstrument(instrument) + self.registerInstrument(instrument, frequency) self.__values = [] self.__barFeed = barFeed @@ -72,15 +72,16 @@ def __init__(self, barFeed, frequency, maxLen=None): barFeed.getNewValuesEvent().subscribe(self.__onNewValues) def __onNewValues(self, dateTime, value): + assert len(self.getAllFrequencies()) == 1 if self.__range is None: - self.__range = resamplebase.build_range(dateTime, self.getFrequency()) - self.__grouper = BarsGrouper(self.__range.getBeginning(), value, self.getFrequency()) + self.__range = resamplebase.build_range(dateTime, self.getAllFrequencies()[0]) + self.__grouper = BarsGrouper(self.__range.getBeginning(), value, self.getAllFrequencies()[0]) elif self.__range.belongs(dateTime): self.__grouper.addValue(value) else: self.__values.append(self.__grouper.getGrouped()) - self.__range = resamplebase.build_range(dateTime, self.getFrequency()) - self.__grouper = BarsGrouper(self.__range.getBeginning(), value, self.getFrequency()) + self.__range = resamplebase.build_range(dateTime, self.getAllFrequencies()[0]) + self.__grouper = BarsGrouper(self.__range.getBeginning(), value, self.getAllFrequencies()[0]) def getCurrentDateTime(self): return self.__barFeed.getCurrentDateTime() diff --git a/pyalgotrade/barfeed/sqlitefeed.py b/pyalgotrade/barfeed/sqlitefeed.py index e67ecf902..06e72e33d 100644 --- a/pyalgotrade/barfeed/sqlitefeed.py +++ b/pyalgotrade/barfeed/sqlitefeed.py @@ -152,5 +152,6 @@ def getDatabase(self): return self.__db def loadBars(self, instrument, timezone=None, fromDateTime=None, toDateTime=None): - bars = self.__db.getBars(instrument, self.getFrequency(), timezone, fromDateTime, toDateTime) + assert len(self.getAllFrequencies()) == 1 + bars = self.__db.getBars(instrument, self.getAllFrequencies()[0], timezone, fromDateTime, toDateTime) self.addBarsFromSequence(instrument, bars) diff --git a/pyalgotrade/barfeed/yahoofeed.py b/pyalgotrade/barfeed/yahoofeed.py index 613ff6a7c..d97a43e7a 100644 --- a/pyalgotrade/barfeed/yahoofeed.py +++ b/pyalgotrade/barfeed/yahoofeed.py @@ -146,7 +146,8 @@ def addBarsFromCSV(self, instrument, path, timezone=None): if timezone is None: timezone = self.__timezone + assert len(self.getAllFrequencies()) == 1 rowParser = RowParser( - self.getDailyBarTime(), self.getFrequency(), timezone, self.__sanitizeBars, self.__barClass + self.getDailyBarTime(), self.getAllFrequencies()[0], timezone, self.__sanitizeBars, self.__barClass ) super(Feed, self).addBarsFromCSV(instrument, path, rowParser) diff --git a/pyalgotrade/bitstamp/livebroker.py b/pyalgotrade/bitstamp/livebroker.py index 6bb8edf7e..713642f0c 100644 --- a/pyalgotrade/bitstamp/livebroker.py +++ b/pyalgotrade/bitstamp/livebroker.py @@ -218,7 +218,7 @@ def stop(self): self.__tradeMonitor.stop() def join(self): - if self.__tradeMonitor.isAlive(): + if self.__tradeMonitor.is_alive(): self.__tradeMonitor.join() def eof(self): diff --git a/pyalgotrade/bitstamp/livefeed.py b/pyalgotrade/bitstamp/livefeed.py index f874f06b6..6d9758dd6 100644 --- a/pyalgotrade/bitstamp/livefeed.py +++ b/pyalgotrade/bitstamp/livefeed.py @@ -112,7 +112,7 @@ class LiveTradeFeed(barfeed.BaseBarFeed): def __init__(self, maxLen=None): super(LiveTradeFeed, self).__init__(bar.Frequency.TRADE, maxLen) self.__barDicts = [] - self.registerInstrument(common.btc_symbol) + self.registerInstrument(common.btc_symbol, bar.Frequency.TRADE) self.__prevTradeDateTime = None self.__thread = None self.__wsClientConnected = False diff --git a/pyalgotrade/broker/backtesting.py b/pyalgotrade/broker/backtesting.py index d79889005..76868219a 100644 --- a/pyalgotrade/broker/backtesting.py +++ b/pyalgotrade/broker/backtesting.py @@ -423,9 +423,10 @@ def __preProcessOrder(self, order, bar_): def __postProcessOrder(self, order, bar_): # For non-GTC orders and daily (or greater) bars we need to check if orders should expire right now # before waiting for the next bar. + assert len(self.__barFeed.getAllFrequencies()) == 1 if not order.getGoodTillCanceled(): expired = False - if self.__barFeed.getFrequency() >= pyalgotrade.bar.Frequency.DAY: + if self.__barFeed.getAllFrequencies()[0] >= pyalgotrade.bar.Frequency.DAY: expired = bar_.getDateTime().date() >= order.getAcceptedDateTime().date() # Cancel the order if it will expire in the next bar. diff --git a/pyalgotrade/dataseries/__init__.py b/pyalgotrade/dataseries/__init__.py index c86419560..2ec777736 100644 --- a/pyalgotrade/dataseries/__init__.py +++ b/pyalgotrade/dataseries/__init__.py @@ -135,7 +135,8 @@ def appendWithDateTime(self, dateTime, value): """ if dateTime is not None and len(self.__dateTimes) != 0 and self.__dateTimes[-1] >= dateTime: - raise Exception("Invalid datetime. It must be bigger than that last one") + raise Exception("Invalid datetime. " + "It must be bigger than that last one {0} {1}".format(self.__dateTimes[-1], dateTime)) assert(len(self.__values) == len(self.__dateTimes)) self.__dateTimes.append(dateTime) diff --git a/pyalgotrade/eventprofiler.py b/pyalgotrade/eventprofiler.py index 8b3a26520..34151e91a 100644 --- a/pyalgotrade/eventprofiler.py +++ b/pyalgotrade/eventprofiler.py @@ -218,7 +218,7 @@ def run(self, feed, useAdjustedCloseForReturns=True): feed.getNewValuesEvent().unsubscribe(self.__onBars) -def build_plot(profilerResults): +def build_plot(profilerResults, alpha): # Calculate each value. x = [] mean = [] @@ -231,6 +231,7 @@ def build_plot(profilerResults): # Cleanup plt.clf() + # Plot a line with the mean cumulative returns. plt.plot(x, mean, color='#0000FF') @@ -255,12 +256,12 @@ def build_plot(profilerResults): plt.ylabel('Cumulative returns') -def plot(profilerResults): +def plot(profilerResults, alpha=0): """Plots the result of the analysis. :param profilerResults: The result of the analysis :type profilerResults: :class:`Results`. """ - build_plot(profilerResults) + build_plot(profilerResults, alpha) plt.show() diff --git a/pyalgotrade/feed/__init__.py b/pyalgotrade/feed/__init__.py index 2a56d6ae2..7bccf1c2d 100644 --- a/pyalgotrade/feed/__init__.py +++ b/pyalgotrade/feed/__init__.py @@ -22,7 +22,7 @@ from pyalgotrade import observer from pyalgotrade import dataseries - +from pyalgotrade import bar def feed_iterator(feed): feed.start() @@ -50,16 +50,15 @@ def __init__(self, maxLen): super(BaseFeed, self).__init__() maxLen = dataseries.get_checked_max_len(maxLen) - + self.__registered_ds = [] self.__ds = {} self.__event = observer.Event() self.__maxLen = maxLen def reset(self): - keys = list(self.__ds.keys()) self.__ds = {} - for key in keys: - self.registerDataSeries(key) + for key, freq in self.__registered_ds: + self.registerDataSeries(key, freq) # Subclasses should implement this and return the appropriate dataseries for the given key. @abc.abstractmethod @@ -73,28 +72,36 @@ def createDataSeries(self, key, maxLen): def getNextValues(self): raise NotImplementedError() - def registerDataSeries(self, key): + def registerDataSeries(self, key, freq = bar.Frequency.UNKNOWN): if key not in self.__ds: - self.__ds[key] = self.createDataSeries(key, self.__maxLen) + self.__ds[key] = {} + for i in self.__registered_ds: + k, v = i + if k == key and v == freq: + return + self.__ds[key][freq] = self.createDataSeries(key, self.__maxLen) + self.__registered_ds.append((key, freq)) def getNextValuesAndUpdateDS(self): - dateTime, values = self.getNextValues() + dateTime, values, freq = self.getNextValues() if dateTime is not None: for key, value in values.items(): # Get or create the datseries for each key. try: - ds = self.__ds[key] + ds = self.__ds[key][freq] except KeyError: ds = self.createDataSeries(key, self.__maxLen) - self.__ds[key] = ds + if key not in self.__ds.keys(): + self.__ds[key] = {} + self.__ds[key][freq] = ds ds.appendWithDateTime(dateTime, value) - return (dateTime, values) + return (dateTime, values, freq) def __iter__(self): return feed_iterator(self) def getNewValuesEvent(self): - """Returns the event that will be emitted when new values are available. + """Returns a event that will be emitted when new values are available. To subscribe you need to pass in a callable object that receives two parameters: 1. A :class:`datetime.datetime` instance. @@ -103,7 +110,7 @@ def getNewValuesEvent(self): return self.__event def dispatch(self): - dateTime, values = self.getNextValuesAndUpdateDS() + dateTime, values, _ = self.getNextValuesAndUpdateDS() if dateTime is not None: self.__event.emit(dateTime, values) return dateTime is not None @@ -111,10 +118,22 @@ def dispatch(self): def getKeys(self): return list(self.__ds.keys()) - def __getitem__(self, key): + def __getitem__(self, val): """Returns the :class:`pyalgotrade.dataseries.DataSeries` for a given key.""" - return self.__ds[key] - - def __contains__(self, key): + if isinstance(val, tuple): + key, freq = val + return self.__ds[key][freq] + else: + assert len(self.__ds[val]) == 1 + return list(self.__ds[val].values())[0] + + def __contains__(self, val): """Returns True if a :class:`pyalgotrade.dataseries.DataSeries` for the given key is available.""" - return key in self.__ds + if isinstance(val, tuple): + key, freq = val + else: + key, freq = val, None + if freq is None: + return key in self.__ds + else: + return (key in self.__ds and freq in self.__ds[key]) \ No newline at end of file diff --git a/pyalgotrade/feed/memfeed.py b/pyalgotrade/feed/memfeed.py index 2f6b533b5..e9a82b55e 100644 --- a/pyalgotrade/feed/memfeed.py +++ b/pyalgotrade/feed/memfeed.py @@ -18,13 +18,12 @@ .. moduleauthor:: Gabriel Martin Becedillas Ruiz """ -from pyalgotrade import feed -from pyalgotrade import dataseries +from pyalgotrade import bar, dataseries, feed class MemFeed(feed.BaseFeed): def __init__(self, maxLen=None): - super(MemFeed, self).__init__(maxLen) + super(MemFeed, self).__init__(maxLen, False) self.__values = [] self.__nextIdx = 0 @@ -60,11 +59,12 @@ def createDataSeries(self, key, maxLen): return dataseries.SequenceDataSeries(maxLen) def getNextValues(self): - ret = (None, None) + ret = (None, None, None) if self.__nextIdx < len(self.__values): ret = self.__values[self.__nextIdx] self.__nextIdx += 1 - return ret + assert isinstance(ret, tuple) + return ret[0], ret[1], bar.Frequency.UNKNOWN # Add values to the feed. values should be a sequence of tuples. The tuples should have two elements: # 1: datetime.datetime. diff --git a/pyalgotrade/fsm.py b/pyalgotrade/fsm.py new file mode 100644 index 000000000..28e243aac --- /dev/null +++ b/pyalgotrade/fsm.py @@ -0,0 +1,107 @@ +#state machine +import enum +import inspect +import sys + +import pyalgotrade.logger + + +logger = pyalgotrade.logger.getLogger(__name__) + + +def state(state_enum, is_initial_state=False): + def wrapper(func): + ''' decorator for state machine + ''' + assert callable(func) + assert isinstance(state_enum, enum.Enum) + func.__state__ = state_enum + if is_initial_state: + func.__initial_state__ = True + return func + return wrapper + + +class StateMachine(object): + ''' new state machine + ''' + + def __init__(self): + self.__states = {} + self.__current_state = None + self.__last_state = None + initial_set = False + methods = inspect.getmembers(self.__class__, + predicate=lambda x: (inspect.isfunction(x) or + inspect.ismethod(x))) + for i in methods: + if hasattr(i[1], '__state__'): + self.__register_state(i[1].__state__, getattr(self, i[0])) + if hasattr(i[1], '__initial_state__'): + if initial_set: + raise Exception('you can only have one initial state') + initial_set = True + self.__set_initial_state(i[1].__state__) + if not initial_set: + raise Exception('no initial state defined') + + def __register_state(self, name, function): + logger.debug('Registering state [%s]' % name) + if name in self.__states: + raise Exception("Duplicate state %s" % name) + self.__states[name] = function + + def __set_initial_state(self, name): + assert name in self.__states + logger.debug('Initial state [%s]' % name) + self.__current_state = name + + @property + def current_state(self): + return self.__current_state + + @current_state.setter + def current_state(self, new_state): + assert new_state in self.__states + logger.info('Setting state from ' + '[{}] to [{}]'.format(self.__current_state, new_state)) + self.__current_state = new_state + + @property + def last_state(self): + return self.__last_state + + def run(self, *args, **kwargs): + assert self.__current_state is not None + new_state = self.__states[self.__current_state](*args, **kwargs) + self.__last_state = self.__current_state + if new_state != self.__current_state: + logger.debug('Switch state [%s] -> [%s]' % (self.__current_state, + new_state)) + assert new_state in self.__states + self.__current_state = new_state + + def run_forever(self, *args, **kwargs): + while True: + self.run(*args, **kwargs) + + +class StrategyFSM(StateMachine): + ''' state machine used by strategy runner + each state should have 2 arguments. + The first one is "bars" and the second + one is "states" + ''' + + def __init__(self, barfeed, states): + super(StrategyFSM, self).__init__() + self.__barfeed = barfeed + self.__states = states + + @property + def barfeed(self): + return self.__barfeed + + @property + def state(self): + return self.__states diff --git a/pyalgotrade/logger.py b/pyalgotrade/logger.py index 0992fda7a..496cead22 100644 --- a/pyalgotrade/logger.py +++ b/pyalgotrade/logger.py @@ -20,11 +20,23 @@ import logging import threading +import sys +from logging.handlers import RotatingFileHandler, SysLogHandler +import os +import coloredlogs initLock = threading.Lock() rootLoggerInitialized = False -log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s" +if 'TESTING' in os.environ and os.environ['TESTING'] != '0': + log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s" + sys_log = False +else: + log_format = ("%(asctime)s %(name)s %(process)d [%(levelname)s] " + "%(module)s - %(funcName)s: %(message)s") + sys_log = True + coloredlogs.install(level='INFO') + level = logging.INFO file_log = None # File name console_log = True @@ -38,7 +50,8 @@ def init_logger(logger): logger.setLevel(level) if file_log is not None: - fileHandler = logging.FileHandler(file_log) + fileHandler = RotatingFileHandler(file_log, maxBytes=20*1024*1024, + backupCount=100) init_handler(fileHandler) logger.addHandler(fileHandler) @@ -47,6 +60,18 @@ def init_logger(logger): init_handler(consoleHandler) logger.addHandler(consoleHandler) + if sys_log: + if sys.platform == "darwin": + # Apple made 10.5 more secure by disabling network syslog: + address = "/var/run/syslog" + elif sys.platform == "linux": + address = "/dev/log" + else: + address = ('localhost', 514) + sysHandler = SysLogHandler(address) + init_handler(sysHandler) + logger.addHandler(sysHandler) + def initialize(): global rootLoggerInitialized diff --git a/pyalgotrade/optimizer/xmlrpcserver.py b/pyalgotrade/optimizer/xmlrpcserver.py index 699dc5c51..e6a4d1c2b 100644 --- a/pyalgotrade/optimizer/xmlrpcserver.py +++ b/pyalgotrade/optimizer/xmlrpcserver.py @@ -153,15 +153,16 @@ def stop(self): self.shutdown() def serve(self): + assert len(self.__barFeed.getAllFrequencies()) == 1 try: # Initialize instruments, bars and parameters. logger.info("Loading bars") loadedBars = [] - for dateTime, bars in self.__barFeed: + for dateTime, bars, freq in self.__barFeed: loadedBars.append(bars) instruments = self.__barFeed.getRegisteredInstruments() self.__instrumentsAndBars = serialization.dumps((instruments, loadedBars)) - self.__barsFreq = self.__barFeed.getFrequency() + self.__barsFreq = self.__barFeed.getAllFrequencies()[0] if self.__autoStopThread: self.__autoStopThread.start() diff --git a/pyalgotrade/strategy/__init__.py b/pyalgotrade/strategy/__init__.py index 765e90a01..ba7770341 100644 --- a/pyalgotrade/strategy/__init__.py +++ b/pyalgotrade/strategy/__init__.py @@ -24,12 +24,16 @@ import six import pyalgotrade.broker -from pyalgotrade.broker import backtesting -from pyalgotrade import observer -from pyalgotrade import dispatcher +import pyalgotrade.fsm as fsm +import pyalgotrade.logger import pyalgotrade.strategy.position -from pyalgotrade import logger +from pyalgotrade import dispatcher, logger, observer from pyalgotrade.barfeed import resampled +from pyalgotrade.broker import backtesting +from pyalgotrade.strategy.state import StrategyState + + +log = pyalgotrade.logger.getLogger('strategy') @six.add_metaclass(abc.ABCMeta) @@ -604,3 +608,26 @@ def setDebugMode(self, debugOn): level = logging.DEBUG if debugOn else logging.INFO self.getLogger().setLevel(level) self.getBroker().getLogger().setLevel(level) + + +class LiveFSMStrategy(BacktestingStrategy): + def __init__(self, barFeed, fsmclass, cash_or_brk=1000000): + BacktestingStrategy.__init__(self, barFeed, cash_or_brk=cash_or_brk) + assert(issubclass(fsmclass, fsm.StrategyFSM)) + self.__states = StrategyState() + self.__fsmclass = fsmclass + self.__barfeed = barFeed + + @property + def states(self): + return self.__states + + def onStart(self): + log.info('initializing StrategyFSM...') + self.__fsminst = self.__fsmclass(barfeed=self.__barfeed, states=self.__states) + + def onBars(self, bars): + try: + self.__fsminst.run(bars=bars, states=self.__states) + except Exception as e: + log.error('Exception while running sate machine. %s' % str(e)) diff --git a/pyalgotrade/strategy/state.py b/pyalgotrade/strategy/state.py new file mode 100644 index 000000000..2ea3a772c --- /dev/null +++ b/pyalgotrade/strategy/state.py @@ -0,0 +1,70 @@ +import threading +import json +import coloredlogs +import pyalgotrade.logger + + +logger = pyalgotrade.logger.getLogger('strategystate') + + +class StrategyState: + + def __init__(self): + super(StrategyState, self).__setattr__('__states', {}) + super(StrategyState, self).__setattr__('__state_lock', threading.Lock()) + + def __getattr__(self, key): + return self.__getitem__(key) + + def __getitem__(self, key): + state_lock = getattr(self, '__state_lock') + states = getattr(self, '__states') + state_lock.acquire() + if key in states: + rtn = states[key] + else: + rtn = None + state_lock.release() + return rtn + + def __setattr__(self, key, value): + self.__setitem__(key, value) + + def __setitem__(self, key, value): + state_lock = getattr(self, '__state_lock') + states = getattr(self, '__states') + state_lock.acquire() + states[key] = value + state_lock.release() + + def dumps(self): + ''' dump object to string + ''' + self.__state_lock.acquire() + obj = self.__states.copy() + self.__state_lock.release() + return json.dumps(obj) + + def loads(self, data): + ''' load state object from string + ''' + obj = json.loads(data) + self.__state_lock.acquire() + self.__states = obj + self.__state_lock.release() + + def __str__(self): + state_lock = getattr(self, '__state_lock') + states = getattr(self, '__states') + state_lock.acquire() + tmp = states.copy() + state_lock.release() + return json.dumps(tmp) + + +if __name__ == '__main__': + a = StrategyState() + a.test = 1 + logger.info(a['test']) + a['test2'] = 2 + logger.info(a.test2) \ No newline at end of file diff --git a/pyalgotrade/utils/misc.py b/pyalgotrade/utils/misc.py new file mode 100644 index 000000000..4b4c4a9d5 --- /dev/null +++ b/pyalgotrade/utils/misc.py @@ -0,0 +1,51 @@ +import base64 +import sys +import threading +import traceback +import zlib + +import pyalgotrade.logger + +logger = pyalgotrade.logger.getLogger(__name__) + + +bugart = ''' + / .' + .---. \/ +(._.' \() + ^"""^" +BUG!!! +''' + + +def pyGo(func, *args): + t = threading.Thread(target=func, daemon=True, args=args) + t.start() + return t + + +def protected_function(exception_rtn=None): + def wrapper_outter(func): + def wrapper(*args, **kwargs): + try: + rtn = func(*args, **kwargs) + return rtn + except Exception: + info = traceback.format_exc() + logger.error('-' * 60) + logger.error(bugart + '\n' + info) + logger.error('-' * 60) + return exception_rtn + except KeyboardInterrupt: + logger.info('KeyboardInterrupt received, terminating...') + sys.exit(0) + return wrapper + return wrapper_outter + + +if __name__ == '__main__': + @protected_function(1) + def test(num): + raise Exception(str(num)) + + print('we got %s' % str(test(3))) diff --git a/samples/strategy/__init__.py b/samples/strategy/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/samples/strategy/strategyfsm.py b/samples/strategy/strategyfsm.py new file mode 100644 index 000000000..a008a667e --- /dev/null +++ b/samples/strategy/strategyfsm.py @@ -0,0 +1,56 @@ +import enum +import sys + +import pyalgotrade.fsm as fsm +import pyalgotrade.logger + +logger = pyalgotrade.logger.getLogger('strategyfsm') + + +class SampleStrategyFSMState(enum.Enum): + + INIT = 1 + STATE1 = 2 + STATE2 = 3 + ERROR = 99 + + +class SampleStrategyFSM(fsm.StrategyFSM): + + def __init__(self, barfeed, states): + super(SampleStrategyFSM, self).__init__(barfeed, states) + + def print_bars(self, bars): + for i in bars.getInstruments(): + logger.info('{} {} {}'.format(i, bars[i].getDateTime(), bars[i].getClose())) + + @fsm.state(SampleStrategyFSMState.INIT, True) + def state_init(self, bars, states): + # You are only supposed to save states in states variable + # DO NOT save your local variable and it is not guaranteed to be supported later + logger.info('INIT') + print(states.prev) + states.prev = 'INIT' + self.print_bars(bars) + return SampleStrategyFSMState.STATE1 + + @fsm.state(SampleStrategyFSMState.STATE1, False) + def state_state1(self, bars, states): + logger.info('STATE1') + print(states.prev) + states.prev = 'STATE1' + self.print_bars(bars) + return SampleStrategyFSMState.STATE2 + + @fsm.state(SampleStrategyFSMState.STATE2, False) + def state_state2(self, bars, states): + logger.info('STATE2') + print(states.prev) + states.prev = 'STATE2' + self.print_bars(bars) + return SampleStrategyFSMState.ERROR + + @fsm.state(SampleStrategyFSMState.ERROR, False) + def state_error(self, bars, states): + logger.info('ERROR') + sys.exit(0) diff --git a/setup.py b/setup.py index cda428e80..71bd0113d 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ packages=[ 'pyalgotrade', 'pyalgotrade.barfeed', + 'pyalgotrade.barfeed.driver', 'pyalgotrade.bitcoincharts', 'pyalgotrade.bitstamp', 'pyalgotrade.broker', @@ -67,6 +68,9 @@ "tornado", "tweepy", "ws4py>=0.3.4", + "coloredlogs", + "pika", + "flask", ], extras_require={ "TALib": ["Cython", "TA-Lib"], diff --git a/testcases/barfeed_test.py b/testcases/barfeed_test.py index 27ba4c930..73e377afd 100644 --- a/testcases/barfeed_test.py +++ b/testcases/barfeed_test.py @@ -57,7 +57,7 @@ def testDateTimesNotInOrder(self): ] f = barfeed.OptimizerBarFeed(bar.Frequency.DAY, ["orcl"], bars) with self.assertRaisesRegexp(Exception, "Bar date times are not in order.*"): - for dt, b in f: + for dt, b, freq in f: pass def testBaseBarFeed(self): diff --git a/testcases/btcharts_test.py b/testcases/btcharts_test.py index b50399a93..524203b74 100644 --- a/testcases/btcharts_test.py +++ b/testcases/btcharts_test.py @@ -30,7 +30,7 @@ class TestCase(common.TestCase): def testLoadNoFilter(self): feed = barfeed.CSVTradeFeed() feed.addBarsFromCSV(common.get_data_file_path("bitstampUSD.csv")) - loaded = [(dateTime, bars) for dateTime, bars in feed] + loaded = [(dateTime, bars) for dateTime, bars, _ in feed] self.assertEquals(len(loaded), 9999) @@ -49,7 +49,7 @@ def testLoadNoFilter(self): def testLoadFilterFrom(self): feed = barfeed.CSVTradeFeed() feed.addBarsFromCSV(common.get_data_file_path("bitstampUSD.csv"), "bitstampUSD", fromDateTime=dt.as_utc(datetime.datetime(2012, 5, 29))) - loaded = [(dateTime, bars) for dateTime, bars in feed] + loaded = [(dateTime, bars) for dateTime, bars, _ in feed] self.assertEquals(len(loaded), 646) @@ -68,7 +68,7 @@ def testLoadFilterFrom(self): def testLoadFilterFromAndTo(self): feed = barfeed.CSVTradeFeed() feed.addBarsFromCSV(common.get_data_file_path("bitstampUSD.csv"), "bitstampUSD", fromDateTime=dt.as_utc(datetime.datetime(2012, 5, 29)), toDateTime=datetime.datetime(2012, 5, 31)) - loaded = [(dateTime, bars) for dateTime, bars in feed] + loaded = [(dateTime, bars) for dateTime, bars, _ in feed] self.assertEquals(len(loaded), 579) diff --git a/testcases/common.py b/testcases/common.py index 621f9b0cf..997c6e239 100644 --- a/testcases/common.py +++ b/testcases/common.py @@ -22,6 +22,7 @@ import os import shutil import subprocess +import sys import tempfile import unittest @@ -60,7 +61,7 @@ def run_cmd(cmd): def run_python_code(code): - cmd = ["python"] + cmd = [sys.executable] cmd.append("-u") cmd.append("-c") cmd.append(code) @@ -68,7 +69,7 @@ def run_python_code(code): def run_sample_module(module, params=[]): - cmd = ["python"] + cmd = [sys.executable] cmd.append("-u") cmd.append("-m") cmd.append("samples.%s" % module) diff --git a/testcases/ninjatraderfeed_test.py b/testcases/ninjatraderfeed_test.py index 85a5c7bcb..692d40935 100644 --- a/testcases/ninjatraderfeed_test.py +++ b/testcases/ninjatraderfeed_test.py @@ -87,7 +87,7 @@ def testLocalizeAndFilter(self): } barFeed = ninjatraderfeed.Feed(ninjatraderfeed.Frequency.MINUTE, timezone) barFeed.addBarsFromCSV("spy", common.get_data_file_path("nt-spy-minute-2011-03.csv")) - for dateTime, bars in barFeed: + for dateTime, bars, _ in barFeed: price = prices.get(bars.getDateTime(), None) if price is not None: self.assertTrue(price == bars.getBar("spy").getClose()) diff --git a/testcases/yahoofeed_test.py b/testcases/yahoofeed_test.py index 9ef2fa10f..48d63560d 100644 --- a/testcases/yahoofeed_test.py +++ b/testcases/yahoofeed_test.py @@ -181,7 +181,7 @@ def testWithoutTimezone(self): barFeed = yahoofeed.Feed() barFeed.addBarsFromCSV(FeedTestCase.TestInstrument, common.get_data_file_path("orcl-2000-yahoofinance.csv")) barFeed.addBarsFromCSV(FeedTestCase.TestInstrument, common.get_data_file_path("orcl-2001-yahoofinance.csv")) - for dateTime, bars in barFeed: + for dateTime, bars, _ in barFeed: bar = bars.getBar(FeedTestCase.TestInstrument) self.assertTrue(dt.datetime_is_naive(bar.getDateTime())) @@ -189,7 +189,7 @@ def testWithDefaultTimezone(self): barFeed = yahoofeed.Feed(timezone=marketsession.USEquities.getTimezone()) barFeed.addBarsFromCSV(FeedTestCase.TestInstrument, common.get_data_file_path("orcl-2000-yahoofinance.csv")) barFeed.addBarsFromCSV(FeedTestCase.TestInstrument, common.get_data_file_path("orcl-2001-yahoofinance.csv")) - for dateTime, bars in barFeed: + for dateTime, bars, _ in barFeed: bar = bars.getBar(FeedTestCase.TestInstrument) self.assertFalse(dt.datetime_is_naive(bar.getDateTime())) @@ -197,7 +197,7 @@ def testWithPerFileTimezone(self): barFeed = yahoofeed.Feed() barFeed.addBarsFromCSV(FeedTestCase.TestInstrument, common.get_data_file_path("orcl-2000-yahoofinance.csv"), marketsession.USEquities.getTimezone()) barFeed.addBarsFromCSV(FeedTestCase.TestInstrument, common.get_data_file_path("orcl-2001-yahoofinance.csv"), marketsession.USEquities.getTimezone()) - for dateTime, bars in barFeed: + for dateTime, bars, _ in barFeed: bar = bars.getBar(FeedTestCase.TestInstrument) self.assertFalse(dt.datetime_is_naive(bar.getDateTime())) @@ -218,7 +218,7 @@ def testWithIntegerTimezone(self): def testMapTypeOperations(self): barFeed = yahoofeed.Feed() barFeed.addBarsFromCSV(FeedTestCase.TestInstrument, common.get_data_file_path("orcl-2000-yahoofinance.csv"), marketsession.USEquities.getTimezone()) - for dateTime, bars in barFeed: + for dateTime, bars, _ in barFeed: self.assertTrue(FeedTestCase.TestInstrument in bars) self.assertFalse(FeedTestCase.TestInstrument not in bars) bars[FeedTestCase.TestInstrument] @@ -228,7 +228,7 @@ def testMapTypeOperations(self): def testBounded(self): barFeed = yahoofeed.Feed(maxLen=2) barFeed.addBarsFromCSV(FeedTestCase.TestInstrument, common.get_data_file_path("orcl-2000-yahoofinance.csv"), marketsession.USEquities.getTimezone()) - for dateTime, bars in barFeed: + for dateTime, bars, _ in barFeed: pass barDS = barFeed[FeedTestCase.TestInstrument] diff --git a/tools/yahoodbfeed/analyze_gaps.py b/tools/yahoodbfeed/analyze_gaps.py index 43bfb1283..7b90c69f1 100644 --- a/tools/yahoodbfeed/analyze_gaps.py +++ b/tools/yahoodbfeed/analyze_gaps.py @@ -92,7 +92,7 @@ def process_symbol(symbol, fromYear, toYear, missingDataVerifierClass): if filesFound > 0: # Process all items. - for dateTime, bars in feed: + for dateTime, bars, _ in feed: pass missingDataVerifier = missingDataVerifierClass(feed[symbol]) diff --git a/tox.ini b/tox.ini index 8ae3c3349..8b5e2df47 100644 --- a/tox.ini +++ b/tox.ini @@ -3,7 +3,9 @@ envlist = py27,py37 [testenv] # Disabling hash randomization to get deterministic dict prints -setenv = PYTHONHASHSEED=0 +setenv = + PYTHONHASHSEED=0 + TESTING=1 passenv = TWITTER_CONSUMER_KEY TWITTER_CONSUMER_SECRET TWITTER_ACCESS_TOKEN TWITTER_ACCESS_TOKEN_SECRET QUANDL_API_KEY extras = TALib diff --git a/travis/Dockerfile b/travis/Dockerfile index a5cd01ec4..d996b9679 100644 --- a/travis/Dockerfile +++ b/travis/Dockerfile @@ -1,6 +1,6 @@ # ARG PYALGOTRADE_TAG # FROM gbecedillas/pyalgotrade:${PYALGOTRADE_TAG} -FROM gbecedillas/pyalgotrade:0.20-py27 +FROM gbecedillas/pyalgotrade:0.20-py37 MAINTAINER Gabriel Martin Becedillas Ruiz diff --git a/travis/run_tests.sh b/travis/run_tests.sh index 0b43a3373..45078f2e1 100755 --- a/travis/run_tests.sh +++ b/travis/run_tests.sh @@ -4,4 +4,4 @@ export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH # This is needed to avoid "Coverage.py warning: No data was collected" from cov plugin. export PYTHONPATH=. -tox -v -e py27 +tox -v -e py37