diff --git a/myql/myql.py b/myql/myql.py index efceccb..177fb8e 100755 --- a/myql/myql.py +++ b/myql/myql.py @@ -78,6 +78,10 @@ def _payload_builder(self, query, format=None): 'debug': self.debug, 'jsonCompact': 'new' if self.jsonCompact else '' } + + if vars(self).get('_vars'): + payload.update(self._vars) + if self.crossProduct: payload['crossProduct'] = 'optimized' @@ -136,12 +140,13 @@ def _clause_formatter(self, cond): cond = " ".join(cond) else: - if isinstance(cond[2], str): - var = re.match('^@(\w+)$', cond[2]) - else: - var = None - if var : - cond[2] = "{0}".format(var.group(1)) + #if isinstance(cond[2], str): + # var = re.match('^@(\w+)$', cond[2]) + #else: + # var = None + #if var : + if isinstance(cond[2], str) and cond[2].startswith('@'): + cond[2] = "{0}".format(cond[2]) else : cond[2] = "'{0}'".format(cond[2]) cond = ' '.join(cond) @@ -212,6 +217,13 @@ def use(self, url, name='mytable'): self.yql_table_name = name return {'table url': url, 'table name': name} + ##SET + def set(self, myvars): + ''' + ''' + self._vars = myvars + return True + ##DESC def desc(self, table): '''Returns table description diff --git a/tests/tests.py b/tests/tests.py index 01692da..43a11d8 100755 --- a/tests/tests.py +++ b/tests/tests.py @@ -137,6 +137,15 @@ def test_cross_product(self): logging.debug("{0} {1}".format(response.status_code, response.reason)) self.assertEqual(response.status_code, 200) + def test_variable_substitution(self,): + yql = YQL() + var = {'home': 'Congo'} + yql.set(var) + + response = yql.select('geo.states', remote_filter=(5,)).where(['place', '=', '@home']) + logging.debug(pretty_json(response.content)) + self.assertEqual(response.status_code, 200) + def test_raise_exception_no_table_selected(self): with self.assertRaises(NoTableSelectedError): response = self.yql.select(None).where([]) @@ -478,10 +487,10 @@ def setUp(self,): self.oauth = OAuth1(None, None, from_file='credentials.json') self.yql = YQL(oauth=self.oauth) - def test_get_contacts(self,): - data = self.yql.select('social.contacts').where(['guid','=', '@me']) - logging.debug(pretty_json(data.content)) - self.assertEqual(data.status_code, 200) + #def test_get_contacts(self,): + # data = self.yql.select('social.contacts').where(['guid','=', '@me']) + # logging.debug(pretty_json(data.content)) + # self.assertEqual(data.status_code, 200) class TestTable(unittest.TestCase):