| | import csv |
| | import json |
| | import torch |
| | from transformers import BertTokenizer |
| |
|
| |
|
class CNerTokenizer(BertTokenizer):
    """BertTokenizer variant that tokenizes strictly character-by-character.

    Useful for (e.g. Chinese) NER, where every input character must map to
    exactly one token so that label alignment is preserved.
    """

    def __init__(self, vocab_file, do_lower_case=True):
        super().__init__(vocab_file=str(vocab_file), do_lower_case=do_lower_case)
        self.vocab_file = str(vocab_file)
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        """Split *text* into single characters; out-of-vocab chars become [UNK]."""
        chars = (ch.lower() for ch in text) if self.do_lower_case else iter(text)
        return [ch if ch in self.vocab else '[UNK]' for ch in chars]
| |
|
| |
|
class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file into a list of row lists."""
        with open(input_file, "r", encoding="utf-8-sig") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))

    @classmethod
    def _read_text(cls, input_file):
        """Reads a CoNLL-style file: one "word label" pair per line, with blank
        lines or ``-DOCSTART-`` markers separating sentences.

        Returns:
            list of ``{"words": [...], "labels": [...]}`` dicts, one per sentence.
        """
        # BUGFIX: the first parameter of this @classmethod was named `self`;
        # it actually receives the class, so it is named `cls` now.
        # BUGFIX: read as UTF-8 explicitly instead of the platform default
        # encoding (consistent with _read_json).
        lines = []
        words = []
        labels = []
        with open(input_file, 'r', encoding='utf-8') as f:
            for line in f:
                if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                    if words:
                        lines.append({"words": words, "labels": labels})
                        words = []
                        labels = []
                else:
                    splits = line.split(" ")
                    words.append(splits[0])
                    if len(splits) > 1:
                        labels.append(splits[-1].replace("\n", ""))
                    else:
                        # No label column present: pad with the outside tag.
                        labels.append("O")
        if words:
            # Flush the last sentence when the file does not end with a blank line.
            lines.append({"words": words, "labels": labels})
        return lines

    @classmethod
    def _read_json(cls, input_file):
        """Reads a json-lines file of ``{"text": ..., "label": {type: {name: [[s, e], ...]}}}``
        records and converts char-span annotations to BIOES-style tags
        (S- for single-char entities, B-/I-/E- for multi-char ones).

        Returns:
            list of ``{"words": [chars], "labels": [tags]}`` dicts, one per line.
        """
        # BUGFIX: first parameter renamed `self` -> `cls` (this is a @classmethod).
        lines = []
        with open(input_file, 'r', encoding='utf8') as f:
            for raw in f:
                line = json.loads(raw.strip())
                text = line['text']
                label_entities = line.get('label', None)
                words = list(text)
                labels = ['O'] * len(words)
                if label_entities is not None:
                    for key, value in label_entities.items():
                        for sub_name, sub_index in value.items():
                            for start_index, end_index in sub_index:
                                # Spans are inclusive; the annotated surface form
                                # must match the text slice.
                                assert ''.join(words[start_index:end_index + 1]) == sub_name
                                if start_index == end_index:
                                    labels[start_index] = 'S-' + key
                                else:
                                    labels[start_index] = 'B-' + key
                                    # BUGFIX: fill width derives from the span itself
                                    # (end - start - 1) instead of len(sub_name), so a
                                    # length mismatch cannot silently corrupt labels
                                    # (the values are equal when the assert holds).
                                    labels[start_index + 1:end_index] = ['I-' + key] * (end_index - start_index - 1)
                                    labels[end_index] = 'E-' + key
                lines.append({"words": words, "labels": labels})
        return lines
| |
|
| |
|
def get_entity_bios(seq, id2label, middle_prefix='I-'):
    """Gets entities from a BIOS-tagged sequence.

    Args:
        seq (list): sequence of labels; non-str items are mapped through id2label.
        id2label (dict): mapping from label id to label string.
        middle_prefix (str): prefix marking inside-of-chunk tags (default 'I-').

    Returns:
        list: list of [chunk_type, chunk_start, chunk_end] with inclusive indices.

    Example:
        >>> get_entity_bios(['B-PER', 'I-PER', 'O', 'S-LOC'], {})
        [['PER', 0, 1], ['LOC', 3, 3]]
    """
    chunks = []
    chunk = [-1, -1, -1]  # [type, start, end]; -1 means "unset"
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("S-"):
            if chunk[2] != -1:  # flush any completed open chunk first
                chunks.append(chunk)
            chunk = [tag.split('-')[1], indx, indx]
            chunks.append(chunk)
            # BUGFIX: was `chunk = (-1, -1, -1)` — a tuple; any later item
            # assignment would raise TypeError. It only avoided a crash
            # because the `else` branch happened to re-bind it to a list.
            # Use a list consistently.
            chunk = [-1, -1, -1]
        elif tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [tag.split('-')[1], indx, -1]
        elif tag.startswith(middle_prefix) and chunk[1] != -1:
            _type = tag.split('-')[1]
            if _type == chunk[0]:
                chunk[2] = indx
            if indx == len(seq) - 1:
                # Close the trailing chunk at end of sequence.
                chunks.append(chunk)
        else:
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    return chunks
| |
|
| |
|
def get_entity_bio(seq, id2label, middle_prefix='I-'):
    """Gets entities from a BIO-tagged sequence.

    Args:
        seq (list): sequence of labels; non-str items are mapped through id2label.
        id2label (dict): mapping from label id to label string.
        middle_prefix (str): prefix marking inside-of-chunk tags (default 'I-').

    Returns:
        list: list of [chunk_type, chunk_start, chunk_end] with inclusive indices.

    Example:
        >>> get_entity_bio(['B-PER', 'I-PER', 'O', 'B-LOC'], {})
        [['PER', 0, 1], ['LOC', 3, 3]]
    """
    spans = []
    span = [-1, -1, -1]  # [type, start, end]; -1 means "unset"
    last = len(seq) - 1
    for pos, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("B-"):
            # A new chunk begins; flush the previous one if it was completed.
            if span[2] != -1:
                spans.append(span)
            span = [tag.split('-')[1], pos, pos]
            if pos == last:
                spans.append(span)
        elif tag.startswith(middle_prefix) and span[1] != -1:
            # Extend the open chunk only when the entity type matches.
            if tag.split('-')[1] == span[0]:
                span[2] = pos
            if pos == last:
                spans.append(span)
        else:
            if span[2] != -1:
                spans.append(span)
            span = [-1, -1, -1]
    return spans
| |
|
| |
|
def get_entity_bioes(seq, id2label, middle_prefix='I-'):
    """Gets entities from a BIOES-tagged sequence.

    Args:
        seq (list): sequence of labels; non-str items are mapped through id2label.
        id2label (dict): mapping from label id to label string.
        middle_prefix (str): prefix marking inside-of-chunk tags (default 'I-').

    Returns:
        list: list of [chunk_type, chunk_start, chunk_end] with inclusive indices.

    Example:
        >>> get_entity_bioes(['B-PER', 'I-PER', 'E-PER', 'O', 'S-LOC'], {})
        [['PER', 0, 2], ['LOC', 4, 4]]
    """
    chunks = []
    chunk = [-1, -1, -1]  # [type, start, end]; -1 means "unset"
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("S-"):
            if chunk[2] != -1:  # flush any completed open chunk first
                chunks.append(chunk)
            chunk = [tag.split('-')[1], indx, indx]
            chunks.append(chunk)
            # BUGFIX: was `chunk = (-1, -1, -1)` — a tuple; any later item
            # assignment would raise TypeError. It only avoided a crash
            # because the `else` branch happened to re-bind it to a list.
            # Use a list consistently.
            chunk = [-1, -1, -1]
        elif tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [tag.split('-')[1], indx, -1]
        elif (tag.startswith(middle_prefix) or tag.startswith("E-")) and chunk[1] != -1:
            _type = tag.split('-')[1]
            if _type == chunk[0]:
                chunk[2] = indx
            if indx == len(seq) - 1:
                # Close the trailing chunk at end of sequence.
                chunks.append(chunk)
        else:
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    return chunks
| |
|
| |
|
def get_entities(seq, id2label, markup='bio', middle_prefix='I-'):
    """Extract entity spans from a tagged sequence, dispatching on the scheme.

    Args:
        seq: sequence of labels (str) or label ids.
        id2label: mapping from label id to label string.
        markup: tagging scheme, one of 'bio', 'bios', 'bioes'.
        middle_prefix: prefix marking inside-of-chunk tags.

    Returns:
        list of [chunk_type, chunk_start, chunk_end].

    Raises:
        ValueError: if markup is not a supported scheme.
    """
    # BUGFIX: input validation used `assert`, which is stripped under
    # `python -O`; raise an explicit ValueError instead.
    if markup == 'bio':
        return get_entity_bio(seq, id2label, middle_prefix)
    if markup == 'bios':
        return get_entity_bios(seq, id2label, middle_prefix)
    if markup == 'bioes':
        return get_entity_bioes(seq, id2label, middle_prefix)
    raise ValueError(f"Unsupported markup {markup!r}; expected 'bio', 'bios' or 'bioes'")
| |
|
| |
|
def bert_extract_item(start_logits, end_logits):
    """Pair start/end label predictions from a span-pointer head into spans.

    Args:
        start_logits: tensor, assumed (batch, seq_len, num_labels) — only
            batch element 0 is used. TODO(review): confirm shape with caller.
        end_logits: tensor of the same shape.

    Returns:
        list of (label_id, start, end) tuples; indices are relative to the
        sequence with its first and last positions stripped (presumably the
        [CLS]/[SEP] tokens — verify against the caller's tokenization).
    """
    # Argmax over the label dimension, keep batch item 0, drop the first and
    # last positions.
    starts = torch.argmax(start_logits, -1).cpu().numpy()[0][1:-1]
    ends = torch.argmax(end_logits, -1).cpu().numpy()[0][1:-1]
    spans = []
    for start_pos, label in enumerate(starts):
        if label == 0:  # 0 = non-entity
            continue
        # Take the nearest end position at/after the start carrying the same label.
        for offset, end_label in enumerate(ends[start_pos:]):
            if end_label == label:
                spans.append((label, start_pos, start_pos + offset))
                break
    return spans
| |
|