Skip to content

Commit

Permalink
Refactoring CroppedBertSampleRowProvider implementation
Browse files (browse the repository at this point in the history)
  • Loading branch information
nicolay-r committed Nov 27, 2022
1 parent c1b0c94 commit c539597
Showing 1 changed file with 6 additions and 46 deletions.
52 changes: 6 additions & 46 deletions framework/arekit/serialize_bert.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -57,53 +57,13 @@ def __in():

return _from, _to

def _fill_row_core(self, row, text_opinion_linkage, index_in_linked, etalon_label,
parsed_news, sentence_ind, s_ind, t_ind):

def __assign_value(column, value):
row[column] = value

super(CroppedBertSampleRowProvider, self)._fill_row_core(row=row,
text_opinion_linkage=text_opinion_linkage,
index_in_linked=index_in_linked,
etalon_label=etalon_label,
parsed_news=parsed_news,
sentence_ind=sentence_ind,
s_ind=s_ind,
t_ind=t_ind)

# вырезаем часть текста. (English: crop out a part of the text.)

def _provide_sentence_terms(self, parsed_news, sentence_ind, s_ind, t_ind):
terms_iter, src_ind, tgt_ind = super(CroppedBertSampleRowProvider, self)._provide_sentence_terms(
parsed_news=parsed_news, sentence_ind=sentence_ind, s_ind=s_ind, t_ind=t_ind)
terms = list(terms_iter)
_from, _to = self.__calc_window_bounds(window_size=self.__crop_window_size,
s_ind=s_ind, t_ind=t_ind,
input_length=len(row["text_a"]))

expected_label = text_opinion_linkage.get_linked_label()

sentence_terms = list(self._provide_sentence_terms(parsed_news=parsed_news, sentence_ind=sentence_ind))

cropped_sentence_terms = sentence_terms[_from:_to]
s_ind = s_ind - _from
t_ind = t_ind - _from

self.TextProvider.add_text_in_row(
set_text_func=lambda column, value: __assign_value(column, value),
sentence_terms=cropped_sentence_terms,
s_ind=s_ind,
t_ind=t_ind,
expected_label=expected_label)

# обновляем содержимое. (English: update the row contents.)
entities_in_cropped = list(filter(lambda term: isinstance(term, Entity), cropped_sentence_terms))

cropped_entity_ids = [str(i) for i, term in enumerate(cropped_sentence_terms) if isinstance(term, Entity)]

row[const.ENTITY_VALUES] = ",".join([e.Value.replace(',', '') for e in entities_in_cropped])
row[const.ENTITY_TYPES] = ",".join([e.Type.replace(',', '') for e in entities_in_cropped])
row[const.ENTITIES] = ",".join(cropped_entity_ids)

row[const.S_IND] = s_ind
row[const.T_IND] = t_ind
s_ind=s_ind, t_ind=t_ind, input_length=len(terms))
return terms[_from:_to], src_ind - _from, tgt_ind - _from


def serialize_bert(split_filepath, terms_per_context, writer, sample_row_provider, output_dir,
Expand Down

0 comments on commit c539597

Please sign in to comment.