diff --git a/src/widgetastic/widget/table.py b/src/widgetastic/widget/table.py
index 6de4b3bc..85719fed 100644
--- a/src/widgetastic/widget/table.py
+++ b/src/widgetastic/widget/table.py
@@ -10,7 +10,9 @@
from collections import defaultdict, deque
from copy import copy
from jsmin import jsmin
+from operator import attrgetter
+from widgetastic.browser import Browser
from widgetastic.exceptions import RowNotFound
from widgetastic.log import create_child_logger, create_item_logger
from widgetastic.utils import (ParametrizedLocator, ConstructorResolvable, attributize_string)
@@ -24,6 +26,40 @@
Pattern = re._pattern_type
+def resolve_table_widget(parent, wcls):
+ """
+ Used for applying a parent to a WidgetDescriptor passed in at Table init time.
+
+ This turns the WidgetDescriptor into a completed Widget, tied to the
+ appropriate parent.
+
+ For example, if you pass in 'column_widgets' to a table like so:
+ {'columnA': MyWidget(somearg1, somearg2)}
+
+ Then upon calling the TableColumn '.widget' function, this function will resolve
+ the "MyWidget" WidgetDescriptor into a "MyWidget" Widget with parent set to the TableColumn.
+ """
+ if not isinstance(parent, (Browser, Widget)):
+ raise TypeError("'parent' must be an instance of widgetastic.Widget or widgetastic.Browser")
+
+ # Verpick, ...
+ if isinstance(wcls, ConstructorResolvable):
+ return wcls.resolve(parent)
+
+ # We cannot use WidgetDescriptor's facility for instantiation as it does caching and all
+ # that stuff
+ args = ()
+ kwargs = {}
+
+ if isinstance(wcls, WidgetDescriptor):
+ args = wcls.args
+ kwargs.update(wcls.kwargs)
+ wcls = wcls.klass
+ if 'logger' not in kwargs:
+ kwargs['logger'] = create_child_logger(parent.logger, wcls.__name__)
+ return wcls(parent, *args, **kwargs)
+
+
class TableColumn(Widget, ClickableMixin):
"""Represents a cell in the row."""
def __init__(self, parent, position, absolute_position=None, logger=None):
@@ -32,7 +68,7 @@ def __init__(self, parent, position, absolute_position=None, logger=None):
self.absolute_position = absolute_position # absolute position according to row/colspan
def __locator__(self):
- return self.browser.element('./td[{}]'.format(self.position + 1), parent=self.parent)
+ return self.parent.table.COLUMN_AT_POSITION.format(self.position + 1)
def __repr__(self):
return '{}({!r}, {!r})'.format(type(self).__name__, self.parent, self.position)
@@ -52,9 +88,6 @@ def column_name(self):
@cached_property
def widget(self):
"""Returns the associated widget if defined. If there is none defined, returns None."""
- args = ()
- kwargs = {}
-
if self.absolute_position and self.absolute_position != self.position:
position = self.absolute_position
else:
@@ -68,21 +101,7 @@ def widget(self):
if self.column_name not in self.table.column_widgets:
return None
wcls = self.table.column_widgets[self.column_name]
-
- # Verpick, ...
- if isinstance(wcls, ConstructorResolvable):
- return wcls.resolve(self)
-
- # We cannot use WidgetDescriptor's facility for instantiation as it does caching and all
- # that stuff
- if isinstance(wcls, WidgetDescriptor):
- args = wcls.args
- kwargs = wcls.kwargs
- wcls = wcls.klass
- kwargs = copy(kwargs)
- if 'logger' not in kwargs:
- kwargs['logger'] = create_child_logger(self.logger, wcls.__name__)
- return wcls(self, *args, **kwargs)
+ return resolve_table_widget(self, wcls)
@property
def text(self):
@@ -165,8 +184,7 @@ def __repr__(self):
return '{}({!r}, {!r})'.format(type(self).__name__, self.parent, self.index)
def __locator__(self):
- loc = self.parent.ROW_AT_INDEX.format(self.index + 1)
- return self.browser.element(loc, parent=self.parent)
+ return self.parent.ROW_AT_INDEX.format(self.index + 1)
def position_to_column_name(self, position):
"""Maps the position index into the column name (pretty)"""
@@ -181,13 +199,29 @@ def __getitem__(self, item):
raise TypeError('row[] accepts only integers and strings')
if self.table.table_tree:
- # todo: add support of xpath and/or iteration to anytree lib
- return self.table.resolver.glob(self.table.table_tree,
- '/table/tbody/tr[{}]/*[{}]'.format(self.index,
- index))[0].obj
+ # We could find either a TableColumn or a TableReference node at this position...
+ cols = self.table.resolver.glob(
+ self.table.table_tree,
+ '{}[{}]{}[{}]'.format(
+ self.table.ROW_RESOLVER_PATH,
+ self.index,
+ self.table.COLUMN_RESOLVER_PATH,
+ index
+ ),
+ handle_resolver_error=True
+ )
+ if not cols:
+ raise IndexError(
+ "Row {} has no TableColumn or TableReference node at position {}".format(
+ repr(self), index
+ )
+ )
+ return cols[0].obj
else:
- return self.Column(self, index, logger=create_item_logger(self.logger, item))
+ return self.table._create_column(
+ self, index, logger=create_item_logger(self.logger, item)
+ )
def __getattr__(self, attr):
try:
@@ -311,7 +345,21 @@ class Table(Widget):
If you subclass :py:class:`Table`, :py:class:`TableRow`, or :py:class:`TableColumn`, do not
forget to update the :py:attr:`Table.Row` and :py:attr:`TableRow.Column` in order for the
- classes to use the correct class.
+ the Table to use the correct class. You can also adjust the class variable constants to change
+ the way :py:class:`Table` looks for rows. For example, you could use the following to create
+ a table class that builds rows based on each 'tbody' tag within the table, with each
+ row being a custom defined class.
+
+ .. code-block:: python
+
+ class MyCustomTable(Table):
+ ROWS = './tbody'
+ ROW_RESOLVER_PATH = '/table/tbody'
+ ROW_AT_INDEX = './tbody[{0}]'
+ COLUMN_RESOLVER_PATH = '/tr[0]/td'
+ COLUMN_AT_POSITION = '/tr[1]/td[{0}]'
+ ROW_TAG = 'tbody'
+ Row = MyCustomTableRowClass
Args:
locator: A locator to the table ``
`` tag.
@@ -324,13 +372,26 @@ class Table(Widget):
bottom_ignore_fill: Whether to also strip these top rows for fill.
"""
ROWS = './tbody/tr[./td]|./tr[not(./th) and ./td]'
+
+ # Resolve path is used for self.resolver for anytree node lookups
+ # where position starts at '0' for elements in the node tree
+ ROW_RESOLVER_PATH = '/table/tbody/tr'
+ COLUMN_RESOLVER_PATH = '/td'
+
+ # These path vars are used for selenium browser.element lookups,
+ # where position starts at '1' for elements
+ COLUMN_AT_POSITION = './td[{0}]'
+ ROW_AT_INDEX = './tbody/tr[{0}]|./tr[not(./th)][{0}]'
+
+ ROW_TAG = 'tr'
+ COLUMN_TAG = 'td'
HEADER_IN_ROWS = './tbody/tr[1]/th'
HEADERS = './thead/tr/th|./tr/th|./thead/tr/td' + '|' + HEADER_IN_ROWS
- ROW_AT_INDEX = './tbody/tr[{0}]|./tr[not(./th)][{0}]'
ROOT = ParametrizedLocator('{@locator}')
Row = TableRow
+ Column = TableColumn
def __init__(
self, parent, locator, column_widgets=None, assoc_column=None,
@@ -446,6 +507,14 @@ def assoc_column_position(self):
raise TypeError(
'Wrong type passed for assoc_column= : {}'.format(type(self.assoc_column).__name__))
+ def _create_row(self, parent, index, logger=None):
+ """Override these if you wish to change row behavior in a child class."""
+ return self.Row(parent, index, logger)
+
+ def _create_column(self, parent, position, absolute_position=None, logger=None):
+ """Override this if you wish to change column behavior in a child class."""
+ return self.Column(parent, position, absolute_position, logger)
+
def __getitem__(self, item):
if isinstance(item, six.string_types):
if self.assoc_column is None:
@@ -469,14 +538,14 @@ def __getitem__(self, item):
at_index = self._process_negative_index(at_index)
if self.table_tree:
- nodes = self.resolver.glob(self.table_tree, '/table/tbody/tr*')
+ nodes = self.resolver.glob(self.table_tree, self.ROW_RESOLVER_PATH)
at_index = at_index + 1 if self._is_header_in_body else at_index
try:
return six.next(n.obj for n in nodes if n.position == at_index)
except StopIteration:
raise RowNotFound('Row not found by index {} via {}'.format(at_index, item))
else:
- return self.Row(self, at_index, logger=create_item_logger(self.logger, item))
+ return self._create_row(self, at_index, logger=create_item_logger(self.logger, item))
def row(self, *extra_filters, **filters):
try:
@@ -534,11 +603,13 @@ def _all_rows(self):
# passing index to TableRow, should not be <1
# +1 offset on end because xpath index vs 0-based range()
if self.table_tree:
- for node in self.resolver.glob(self.table_tree, '/table/tbody/tr*'):
+ for node in self.resolver.glob(self.table_tree, self.ROW_RESOLVER_PATH):
yield node.obj
else:
for row_pos in range(self.row_count):
- yield self.Row(self, row_pos, logger=create_item_logger(self.logger, row_pos))
+ yield self._create_row(
+ self, row_pos, logger=create_item_logger(self.logger, row_pos)
+ )
def _process_filters(self, *extra_filters, **filters):
# Pre-process the filters
@@ -610,8 +681,12 @@ def _build_query(self, processed_filters, row_filters):
raise ValueError('Unknown method {}'.format(method))
col_query_parts.append(q)
- query_parts.append('./td[{}][{}]'.format(column_index + 1,
- ' and '.join(col_query_parts)))
+ query_parts.append(
+ '{}[{}]'.format(
+ self.COLUMN_AT_POSITION.format(column_index + 1),
+ ' and '.join(col_query_parts)
+ )
+ )
# Row query
row_parts = []
@@ -641,14 +716,16 @@ def _build_query(self, processed_filters, row_filters):
raise ValueError('Unsupported action {}'.format(row_action))
if query_parts and row_parts:
- query = './/tr[{}][{}]'.format(' and '.join(row_parts), ' and '.join(query_parts))
+ query = '({})[{}][{}]'.format(
+ self.ROW_AT_INDEX.format('*'), ' and '.join(row_parts), ' and '.join(query_parts)
+ )
elif query_parts:
- query = './/tr[{}]'.format(' and '.join(query_parts))
+ query = '({})[{}]'.format(self.ROW_AT_INDEX.format('*'), ' and '.join(query_parts))
elif row_parts:
- query = './/tr[{}]'.format(' and '.join(row_parts))
+ query = '({})[{}]'.format(self.ROW_AT_INDEX.format('*'), ' and '.join(row_parts))
else:
# When using ONLY regexps, we might see no query_parts, therefore default query
- query = self.ROWS
+ query = self.ROW_AT_INDEX.format('*')
return query
@@ -662,8 +739,13 @@ def _filter_rows_by_query(self, query):
# incorrect and has to be decreased
# If the header is not in the body of the table, number of preceeding rows is 0-based
# what is correct
- rows.append(self.Row(self, row_pos - 1 if self._is_header_in_body else row_pos,
- logger=create_item_logger(self.logger, row_pos)))
+ rows.append(
+ self._create_row(
+ self,
+ row_pos - 1 if self._is_header_in_body else row_pos,
+ logger=create_item_logger(self.logger, row_pos)
+ )
+ )
return rows
def _apply_row_filter(self, rows, row_filters):
@@ -898,7 +980,7 @@ def has_rowcolspan(self):
def _process_table(self):
queue = deque()
- tree = Node(name=self.browser.tag(self), obj=self, position=None)
+ tree = Node(name=self.browser.tag(self), obj=self, position=0)
queue.append(tree)
while len(queue) > 0:
@@ -907,15 +989,21 @@ def _process_table(self):
children = self.browser.elements('./*[descendant-or-self::node()]', parent=node.obj)
for position, child in enumerate(children):
cur_tag = self.browser.tag(child)
- if cur_tag == 'tr':
+ if cur_tag == self.ROW_TAG:
# todo: add logger
- cur_obj = TableRow(parent=self._get_ancestor_node_obj(node), index=position)
+ cur_obj = self._create_row(
+ parent=self._get_ancestor_node_obj(node),
+ index=position
+ )
cur_node = Node(name=cur_tag, parent=node, obj=cur_obj, position=position)
queue.append(cur_node)
- elif cur_tag == 'td':
+ elif cur_tag == self.COLUMN_TAG:
cur_position = self._get_position_respecting_spans(node)
- cur_obj = TableColumn(parent=node.obj, position=cur_position,
- absolute_position=cur_position)
+ cur_obj = self._create_column(
+ parent=self._get_ancestor_node_obj(node),
+ position=cur_position,
+ absolute_position=cur_position
+ )
Node(name=cur_tag, parent=node, obj=cur_obj, position=cur_position)
rowsteps = range(1, int(child.get_attribute('rowspan') or 0))
@@ -941,24 +1029,20 @@ def _process_table(self):
ref_parent = node
ref_obj = TableReference(parent=ref_parent, reference=cur_obj)
ref_position = cur_position if col_step == 0 else cur_position + col_step
- Node(name='ref', parent=ref_parent, obj=ref_obj,
+ Node(name=cur_tag, parent=ref_parent, obj=ref_obj,
position=ref_position)
else:
- if cur_tag == 'thead':
- # not necessary now since current Table implementation
- # analyzes headers itself
- # todo: move headers to tree later
- continue
- cur_node = Node(name=cur_tag, parent=node, obj=child, position=None)
+ cur_node = Node(name=cur_tag, parent=node, obj=child, position=position)
queue.append(cur_node)
return tree
def _recalc_column_positions(self, tree):
- for row in self.resolver.glob(tree, '/table/tbody/tr'):
+ for row in self.resolver.glob(tree, self.ROW_RESOLVER_PATH):
modifier = 0
- cols = self.resolver.glob(row, './*')
- for col in cols:
+ # Look for column nodes
+ cols = self.resolver.glob(row, './{}'.format(self.COLUMN_RESOLVER_PATH))
+ for col in sorted(cols, key=attrgetter('position')):
if getattr(col.obj, 'refers_to', None):
modifier -= 1
continue
@@ -1018,9 +1102,15 @@ def get(self, node, path):
node = self._Resolver__get(node, part)
return node
- def glob(self, node, path):
+ def glob(self, node, path, handle_resolver_error=False):
node, parts = self._Resolver__start(node, path)
- return self.__glob(node, parts)
+ try:
+ return self.__glob(node, parts)
+ except ResolverError:
+ if handle_resolver_error:
+ return []
+ else:
+ raise
def __glob(self, node, parts):
nodes = []
@@ -1044,12 +1134,22 @@ def __glob(self, node, parts):
def _get_node_by_index(self, node, part):
part, position = self.index_regexp.match(part).groups()
- if self.is_wildcard(part):
- cur_node = self.__glob(node, part)[0]
- else:
- cur_node = self._Resolver__get(node, part)
- return cur_node.parent.children[int(position)]
+ matching_nodes = self._Resolver__find(node, part, None)
+ for node_at_pos in filter(lambda n: n.position == int(position), matching_nodes):
+ return node_at_pos
+ else:
+ names = [
+ "{}[{}]".format(repr(getattr(c, self.pathattr, None)), c.position)
+ for c in node.children
+ ]
+ raise ResolverError(
+ node,
+ part,
+ "{} has no child '{}' with position={}. Children are: {}".format(
+ repr(node), part, position, ", ".join(names)
+ )
+ )
def __find(self, node, pat, remainder):
matches = []
diff --git a/testing/test_basic_widgets.py b/testing/test_basic_widgets.py
index 7a25d284..558959b4 100644
--- a/testing/test_basic_widgets.py
+++ b/testing/test_basic_widgets.py
@@ -5,7 +5,8 @@
from widgetastic.exceptions import DoNotReadThisWidget
from widgetastic.widget import (
- View, Table, Text, TextInput, FileInput, Checkbox, Select, ColourInput)
+ View, Table, Text, TextInput, FileInput, Checkbox, Select, ColourInput, Widget)
+from widgetastic.widget.table import TableRow
from widgetastic.utils import Fillable, ParametrizedString, VersionPick, Version
@@ -337,6 +338,144 @@ class TestForm(View):
u'Widget': u'widget6'}]
+def test_table_multiple_tbody(browser):
+ class TBodyRow(TableRow):
+ ROW = "./tr[1]"
+ HIDDEN_CONTENT = "./tr[2]/td[1]"
+
+ def __init__(self, parent, index, logger=None):
+ Widget.__init__(self, parent, logger=logger)
+ # We don't need to adjust index by +1 because anytree Node position will
+ # already be '+1' due to presence of 'thead' among the 'tbody' rows
+ self.index = index
+ self.hidden_content = Text(parent=self, locator=self.HIDDEN_CONTENT)
+
+ def __locator__(self):
+ # We don't need to adjust index by +1 because anytree Node position will
+ # already be '+1' due to presence of 'thead' among the 'tbody' rows
+ return self.parent.ROW_AT_INDEX.format(self.index)
+
+ @property
+ def is_displayed(self):
+ return self.browser.is_displayed(self.ROW, parent=self)
+
+ class TBodyTable(Table):
+ ROWS = "./tbody"
+ ROW_RESOLVER_PATH = "/table/tbody"
+ ROW_AT_INDEX = "./tbody[{0}]"
+ COLUMN_RESOLVER_PATH = "/tr[0]/td"
+ COLUMN_AT_POSITION = "./tr[1]/td[{0}]"
+ ROW_TAG = "tbody"
+ Row = TBodyRow
+
+ @property
+ def _is_header_in_body(self):
+ """Override this to always return true.
+
+ Since we are resolving rows by the 'tbody' tag, widgetastic.Table._process_table
+ creates the rows with a position starting at 1 (because a tag is present
+ when enumerating through the tag's children)
+ """
+ return True
+
+ class TestForm(View):
+ table1 = TBodyTable(
+ '#multiple_tbody_table',
+ column_widgets={
+ 'First Name': TextInput(locator='./input'),
+ 'Last Name': TextInput(locator='./input'),
+ 'Widget': TextInput(locator='./input'),
+ }
+ )
+
+ view = TestForm(browser)
+
+ assert view.table1.headers == ('#', 'First Name', 'Last Name', 'Username', 'Widget')
+ assert len(list(view.table1.rows())) == 3
+
+ assert len(list(view.table1.rows(first_name='Mark'))) == 1
+ assert len(list(view.table1.rows(username__startswith='@slacker'))) == 1
+ assert len(list(view.table1.rows(first_name__startswith='Larry',
+ first_name__endswith='Bird'))) == 1
+ assert len(list(view.table1.rows(_row__attr=('data-test', 'def-345')))) == 1
+ assert len(list(view.table1.rows(_row__attr_startswith=('data-test', 'abc')))) == 2
+ assert len(list(view.table1.rows(_row__attr_endswith=('data-test', '345')))) == 2
+ assert len(list(view.table1.rows(_row__attr_contains=('data-test', '3')))) == 3
+ assert len(list(view.table1.rows(
+ _row__attr_contains=('data-test', '3'), _row__attr_startswith=('data-test', 'abc')))) == 2
+ assert len(list(view.table1.rows(_row__attr=('data-test', 'abc-345'), first_name='qwer'))) == 0
+
+ with pytest.raises(ValueError):
+ list(view.table1.rows(_row__papalala=('foo', 'bar')))
+
+ with pytest.raises(ValueError):
+ list(view.table1.rows(_row__attr_papalala=('foo', 'bar')))
+
+ with pytest.raises(ValueError):
+ list(view.table1.rows(_row__attr='foobar'))
+
+ assert len(list(view.table1.rows((0, '1')))) == 1
+ assert len(list(view.table1.rows((1, 'startswith', 'Jacob')))) == 1
+ assert len(list(view.table1.rows((1, 'startswith', 'Jacob'), username__endswith='at'))) == 1
+
+ assert len(list(view.table1.rows((1, re.compile(r'Mark$'))))) == 1
+ assert len(list(view.table1.rows((1, re.compile(r'^Jacob'))))) == 1
+ assert len(list(view.table1.rows(('Last Name', re.compile(r'^Otto'))))) == 1
+ assert len(list(view.table1.rows((0, re.compile(r'^2')), (3, re.compile(r'fat$'))))) == 1
+
+ row = view.table1.row(username='@slacker')
+ assert row[0].text == '3'
+ assert row['First Name'].text == 'Larry the Bird'
+ assert row.first_name.text == 'Larry the Bird'
+
+ assert row.read() == {u'#': u'3',
+ u'First Name': u'Larry the Bird',
+ u'Last Name': u'Larry the Bird',
+ u'Username': u'@slacker',
+ u'Widget': u'widget3'}
+
+ unpacking_fake_read = [(header, column.text) for header, column in row]
+ assert unpacking_fake_read == [(u'#', u'3'),
+ (u'First Name', u'Larry the Bird'),
+ (u'Last Name', u'Larry the Bird'),
+ (u'Username', u'@slacker'),
+ (u'Widget', u'')]
+
+ assert view.table1[1].last_name.text == 'Thornton'
+
+ with pytest.raises(AttributeError):
+ row.papalala
+
+ with pytest.raises(TypeError):
+ view.table1['boom!']
+
+ with pytest.raises(IndexError):
+ view.table1[1000]
+
+ row = next(view.table1.rows())
+ assert row.first_name.text == 'Mark'
+
+ assert view.table1.read() == [{u'#': u'1',
+ u'First Name': u'Mark',
+ u'Last Name': u'Otto',
+ u'Username': u'@mdo',
+ u'Widget': u'widget1'},
+ {u'#': u'2',
+ u'First Name': u'Jacob',
+ u'Last Name': u'Thornton',
+ u'Username': u'@fat',
+ u'Widget': u'widget2'},
+ {u'#': u'3',
+ u'First Name': u'Larry the Bird',
+ u'Last Name': u'Larry the Bird',
+ u'Username': u'@slacker',
+ u'Widget': u'widget3'}]
+
+ for row in view.table1:
+ assert row.is_displayed
+ assert not row.hidden_content.is_displayed
+
+
def test_table_no_header(browser):
class TestForm(View):
nohead_table = Table('#without_thead')
diff --git a/testing/testing_page.html b/testing/testing_page.html
index fe7374a9..c1613701 100644
--- a/testing/testing_page.html
+++ b/testing/testing_page.html
@@ -237,5 +237,53 @@ bartest
+
+
+