|
| 1 | +from udapi.core.block import Block |
| 2 | +import udapi.core.coref |
| 3 | +import itertools |
| 4 | + |
class MarkNested(Block):
    """Find pairs of nested coreference mentions within each tree.

    Two mentions are reported as nested when the set of words of one of them
    is a subset of (or equal to) the set of words of the other one.

    Args:
        same_cluster_only: if true, only pairs whose mentions have the same
            ``cluster`` are reported.
        both_discontinuous: if true, only pairs where the ``span`` string of
            BOTH mentions contains a comma are reported.
        multiword_only: if true, pairs where either mention has exactly one
            distinct word are skipped.
        print_form: if true, mentions are described by their word forms
            instead of by their ``span`` string.
        log: if true, one line per reported pair is printed to stdout.
        mark: if true, ``misc['Mark']`` is set on the words of reported pairs.
        **kwargs: forwarded unchanged to ``Block.__init__``.
    """

    def __init__(self, same_cluster_only=True, both_discontinuous=False, multiword_only=False,
                 print_form=False, log=True, mark=True, **kwargs):
        super().__init__(**kwargs)
        self.same_cluster_only = same_cluster_only
        self.both_discontinuous = both_discontinuous
        self.multiword_only = multiword_only
        self.print_form = print_form
        self.log = log
        self.mark = mark

    def _print(self, mention):
        """Return ``"<cluster_id>:<description>"`` for the given mention.

        The description is the space-joined word forms when ``print_form``
        is set, otherwise the mention's ``span`` string.
        """
        if self.print_form:
            description = ' '.join(w.form for w in mention.words)
        else:
            description = mention.span
        return f"{mention.cluster.cluster_id}:{description}"

    def process_tree(self, tree):
        """Report (and optionally mark) every nested pair of mentions in ``tree``.

        Mentions are gathered from ``coref_mentions`` of all nodes returned
        by ``tree.descendants_and_empty``.
        """
        # De-duplicate while keeping first-encounter order. Iterating a plain
        # set of objects gives an arbitrary order, which made both the order
        # of the log lines and the choice of the labelled mention (mA)
        # unpredictable from run to run.
        mentions = dict.fromkeys(
            m for node in tree.descendants_and_empty for m in node.coref_mentions)

        for mA, mB in itertools.combinations(mentions, 2):
            if self.same_cluster_only and mA.cluster != mB.cluster:
                continue
            # A comma in the span string is what this block treats as
            # the sign of a discontinuous mention.
            if self.both_discontinuous and (',' not in mA.span or ',' not in mB.span):
                continue
            sA, sB = set(mA.words), set(mB.words)
            # Nested means one word set is contained in the other (or equal).
            if not (sA <= sB or sB <= sA):
                continue
            if self.multiword_only and (len(sA) == 1 or len(sB) == 1):
                continue
            if self.mark:
                # Flag every word of both mentions first, then overwrite the
                # flag on the first word of mA with a label naming the pair.
                for w in mA.words + mB.words:
                    w.misc['Mark'] = 1
                mA.words[0].misc['Mark'] = f"{self._print(mA)}+{self._print(mB)}"
            if self.log:
                print(f"nested mentions at {tree.sent_id}: {self._print(mA)} + {self._print(mB)}")
0 commit comments