| | from dataclasses import dataclass |
| | from typing import Optional, List |
| | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig |
| | import regex as re |
| | import torch |
| | import torch.nn.functional as F |
| |
|
# Sentinel tokens from the ByT5 extra_id vocabulary, repurposed as markers
# in model inputs/outputs. Roles below are inferred from usage in this file;
# confirm against the training data format.
PROGRAM_SPECIAL_TOKEN="<extra_id_124>"      # marks the start of a program
UTTERANCES_SPECIAL_TOKEN="<extra_id_123>"   # marks/separates spec utterances
GT_PROGRAM_SPECIAL_TOKEN="<extra_id_122>"   # marks a ground-truth program
| |
|
def consistent(rx, spec):
    """Check whether regex ``rx`` is consistent with a labeled spec.

    Args:
        rx: a regular-expression string.
        spec: iterable of ``(string, label)`` pairs where label is ``'+'``
            (the string must fully match) or ``'-'`` (it must not).

    Returns:
        True  if every '+' example fully matches and no '-' example does,
        False if the regex mislabels some example,
        None  if a label is invalid, the regex fails to compile, or
              matching times out (1 second per example).
    """
    for s, label in spec:
        if label not in ('+', '-'):
            return None
        try:
            # The third-party `regex` module's timeout guards against
            # catastrophic backtracking on pathological patterns.
            if re.fullmatch(rx, s, timeout=1):
                if label == '-':
                    return False
            else:
                if label == '+':
                    return False
        except re.error:
            return None
        except TimeoutError:
            return None
    return True
| |
|
def get_utterance_processing_functions(label_pos, idx, separator=' '):
    """Build serializer/deserializer functions for labeled utterance specs.

    Args:
        label_pos: "suffix" places the +/- label after each string;
            any other value places it before.
        idx: if True, each utterance is prefixed with a position token
            ``<extra_id_{i}>`` instead of being joined with ``separator``.
        separator: join/split string used when ``idx`` is False, and the
            delimiter restored when parsing position-token encodings.

    Returns:
        ``(utterances_to_string, string_to_utterances)`` converting a spec
        ``[(string, label), ...]`` to a flat string and back.
    """
    if label_pos == "suffix":
        if idx:
            def utterances_to_string(spec):
                return ''.join([f"<extra_id_{i}>{s}{label}" for i, (s, label) in enumerate(spec)])
        else:
            def utterances_to_string(spec):
                return separator.join([f"{s}{label}" for s, label in spec])
    else:
        if idx:
            def utterances_to_string(spec):
                return ''.join([f"<extra_id_{i}>{label}{s}" for i, (s, label) in enumerate(spec)])
        else:
            def utterances_to_string(spec):
                return separator.join([f"{label}{s}" for s, label in spec])

    if label_pos == "suffix":
        if idx:
            def string_to_utterances(string):
                string = re.sub(r'<extra_id_\d+>', ' ', string)
                return [(s[:-1], s[-1]) for s in string.split(' ') if len(s) > 0]
        else:
            def string_to_utterances(string):
                return [(s[:-1], s[-1]) for s in string.split(separator) if len(s) > 0]
    else:
        if idx:
            def string_to_utterances(string):
                # BUG FIX: in the idx encoding the position tokens are the
                # ONLY delimiters; the old code deleted them (sub to '') and
                # then split on `separator`, which never occurs in the
                # string, so multiple utterances collapsed into one.
                # Replacing them WITH the separator makes the round trip
                # with utterances_to_string work.
                string = re.sub(r'<extra_id_\d+>', separator, string)
                return [(s[1:], s[0]) for s in string.split(separator) if len(s) > 0]
        else:
            def string_to_utterances(string):
                return [(s[1:], s[0]) for s in string.split(separator) if len(s) > 0]

    return utterances_to_string, string_to_utterances
| |
|
def decode(c):
    """Map a ByT5 token id to its string form.

    Ids 0-2 are special tokens, ids 3-258 are the 256 raw byte values
    offset by 3, and ids from 259 upward are sentinel (extra_id) tokens.
    """
    if c < 3:
        return f"<{c}>"
    elif c < 259:
        # BUG FIX: ids 3..258 cover all 256 bytes; the old bound (c < 258)
        # sent byte 255 (id 258) into the sentinel branch, yielding the
        # bogus string "<extra_id_-1>". byt5_decode_batch's position-token
        # filter (t <= 258 kept as bytes) confirms 258 is a byte id.
        return chr(c - 3)
    else:
        # NOTE(review): numbering sentinels as c - 259 assumes <extra_id_0>
        # has id 259; HF's T5/ByT5 convention numbers them in reverse
        # (extra_id_0 is the highest id) — confirm against the tokenizer.
        return f"<extra_id_{c - 259}>"
| | |
def byt5_decode_batch(outputs, skip_special_tokens=True, skip_position_token=False):
    """Decode a batch of ByT5 token-id sequences into strings.

    ``outputs`` is a nested list ``[batch][beam][token_id]``. Special
    tokens (ids < 3) and/or sentinel "position" tokens (ids > 258) can be
    filtered out before decoding; the nesting structure is preserved.
    """
    def keep(token_id):
        # Single predicate replacing the original's two rebuild passes.
        if skip_special_tokens and token_id < 3:
            return False
        if skip_position_token and token_id > 258:
            return False
        return True

    return [
        [''.join(decode(t) for t in seq if keep(t)) for seq in group]
        for group in outputs
    ]
| |
|
class Agent:
    """Bundle a seq2seq model, its tokenizer, and a generation
    configuration, all placed on a single device."""

    def __init__(self,
                 model_path: str,
                 gen_config: dict,
                 device: str = "cuda",
                 ):
        # Assignments are independent; the generation config is built
        # straight from the supplied keyword dictionary.
        self.device = device
        self.gen_config = GenerationConfig(**gen_config)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
| |
|
@dataclass
class ListenerOutput:
    # One list of kept (spec-consistent) programs per context in the batch.
    programs: List[List[str]]
    # Beam indices of the kept programs within the decoded candidates.
    idx: Optional[List[List[int]]] = None
    # All decoded candidate programs, per context.
    decoded: Optional[List[List[str]]] = None
    # Sequence log-probability of each decoded candidate.
    decoded_scores: Optional[List[List[float]]] = None
    # Pruned candidates; never populated by the code shown in this file.
    pruned: Optional[List[List[str]]] = None
| |
|
| |
|
class Listener(Agent):
    """Seq2seq agent that synthesizes regex programs from labeled example
    specs ("utterances") and scores candidate programs against contexts."""

    def __init__(self,
                 model_path,
                 gen_config,
                 device="cuda",
                 label_pos="suffix",
                 idx: bool = True,
                 program_special_token=PROGRAM_SPECIAL_TOKEN,
                 utterances_special_token=UTTERANCES_SPECIAL_TOKEN
                 ):
        super().__init__(
            model_path,
            gen_config,
            device=device
        )
        self.label_pos = label_pos
        self.idx = idx
        self.program_special_token = program_special_token
        self.utterances_special_token = utterances_special_token
        # Serializer/deserializer for specs; the utterances marker doubles
        # as the separator in the non-idx encoding.
        self.utterances_to_string, self.string_to_utterances = (
            get_utterance_processing_functions(
                label_pos, idx, separator=utterances_special_token
            )
        )

    def synthesize(self, context, return_scores=False, enforce_consistency=True):
        """Generate candidate programs for a batch of specs.

        Args:
            context: list of raw context strings, or list of specs (each a
                list of ``(string, label)`` pairs).
            return_scores: also return all decoded candidates and their
                sequence log-probabilities.
            enforce_consistency: keep only programs consistent with the
                spec (meaningful only when ``context`` holds specs).

        Returns:
            ListenerOutput; only ``programs`` is populated unless
            ``return_scores`` is True.
        """
        # Specs arrive as lists of (utterance, label) pairs; raw strings
        # pass through unchanged.
        if isinstance(context[0], list):
            context_str = list(map(self.utterances_to_string, context))
        else:
            context_str = context

        # Ensure every encoder input begins with the utterances marker.
        context_tokens = self.tokenizer(
            [f"{self.utterances_special_token}{c}" if not c.startswith(self.utterances_special_token) else c
             for c in context_str],
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Force decoding to start from the program marker token.
        decoder_inputs = self.tokenizer(
            [self.program_special_token for _ in context], return_tensors="pt",
            add_special_tokens=False
        ).to(self.device)

        outputs = self.model.generate(**context_tokens,
                                      decoder_input_ids=decoder_inputs.input_ids,
                                      generation_config=self.gen_config,
                                      return_dict_in_generate=True,
                                      output_scores=True
                                      )

        # Flat (batch * beams) sequences -> per-context beams, then decode.
        decoded_batch = byt5_decode_batch(
            outputs.sequences.reshape((len(context), -1, outputs.sequences.shape[-1])).tolist(),
            skip_position_token=True,
            skip_special_tokens=True
        )

        consistent_programs = []
        idxs = []
        for decoded, ctx in zip(decoded_batch, context):
            cp = []
            idx = []
            for i, p in enumerate(decoded):
                if enforce_consistency:
                    # consistent() may return None (bad regex/label or
                    # timeout); such candidates are dropped like failures.
                    if consistent(p, ctx):
                        cp.append(p)
                        idx.append(i)
                else:
                    cp.append(p)
                    idx.append(i)

            consistent_programs.append(cp)
            idxs.append(idx)

        # Per-step log-probs of the sampled tokens. sequences[:, 0] is the
        # forced program token, so scores align with sequences[:, 1:].
        logprobs = torch.stack(outputs.scores, dim=1).log_softmax(dim=-1)
        gen_probs = torch.gather(logprobs, 2, outputs.sequences[:, 1:, None]).squeeze(-1)
        # -inf marks positions past a beam's EOS; zero them so the sum is a
        # finite sequence log-probability.
        gen_probs.masked_fill_(gen_probs.isinf(), 0)
        scores = gen_probs.sum(-1)
        n_decoded = scores.shape[0]
        n_seq = n_decoded // len(context)
        scores = scores.reshape((len(context), n_seq))
        scores_list = scores.tolist()

        if return_scores:
            return ListenerOutput(
                consistent_programs,
                idxs,
                decoded_batch,
                scores_list
            )
        else:
            return ListenerOutput(consistent_programs)

    def score_program(self, contexts, programs):
        """Return log P(program | context) for each aligned pair.

        Args:
            contexts: raw context strings or specs (lists of pairs).
            programs: candidate program strings, one per context.

        Returns:
            List of float sequence log-probabilities, one per pair.
        """
        if isinstance(contexts[0], list):
            context_str = list(map(self.utterances_to_string, contexts))
        else:
            context_str = contexts

        context_tokens = self.tokenizer(
            [f"{self.utterances_special_token}{c}" if not c.startswith(self.utterances_special_token) else c
             for c in context_str],
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # BUG FIX: pad the program batch too — without padding=True the
        # tokenizer cannot build a rectangular tensor when programs differ
        # in length. The pad-mask below (id 0) already anticipated padding.
        program_tokens = self.tokenizer(
            [f"{self.program_special_token}{p}" for p in programs],
            return_tensors="pt",
            padding=True
        ).to(self.device)

        outputs = self.model(input_ids=context_tokens.input_ids,
                             decoder_input_ids=program_tokens.input_ids,
                             return_dict=True)

        # Teacher forcing: logits at position t predict token t+1, so
        # gather the log-probs of input_ids[:, 1:].
        logprobs = torch.gather(F.log_softmax(outputs.logits, dim=-1), 2,
                                program_tokens.input_ids[:, 1:, None]).squeeze(-1)

        # Zero out contributions from pad positions (token id 0).
        logprobs.masked_fill_(program_tokens.input_ids[:, 1:] == 0, 0)

        scores = logprobs.sum(-1)

        return scores.tolist()