import json
import re
import os
import hashlib
import onnxruntime as ort
import numpy as np
from typing import Any, List, Dict, Set, Optional

# Letter grade -> numeric score mapping
score_map = {'A': 5, 'B': 4, 'C': 3, 'D': 2, 'E': 1}

class SentenceExtractor:
    """Extracts evaluation-relevant sentences from Chinese review text and grades
    them (A-E) with an ONNX model, falling back to keyword rules when needed."""

    def __init__(
        self,
        eval_keywords_path: str,
        model_path: str = "distilled_model.onnx",
        *,
        # Configurable switches for sentence splitting and aggregation
        merge_leading_punct: bool = True,
        min_sentence_char_len: int = 6,
        aggregation_mode: str = "max",  # one of "max" | "mean"
        # Thresholds for the +/- suffix (>0 / <0 is the original logic; raising to 2/-2 is recommended)
        word_score_plus_threshold: int = 1,
        word_score_minus_threshold: int = -1,
    ):
        # Resolve everything relative to this file's directory, so a different
        # working directory does not break resource lookup
        self.base_dir = os.path.dirname(os.path.abspath(__file__))
        self.tokenizer_dir = self.base_dir
        # Allow relative paths: convert them to absolute
        if not os.path.isabs(model_path):
            model_path = os.path.join(self.base_dir, model_path)
        if not os.path.isabs(eval_keywords_path):
            eval_keywords_path = os.path.join(self.base_dir, eval_keywords_path)

        self.eval_keywords = self._load_eval_keywords(eval_keywords_path)
        self.all_keywords = self._extract_all_keywords()

        self.ort_session = None
        self.input_name = None
        self.output_name = None

        # Configuration
        self.merge_leading_punct = merge_leading_punct
        self.min_sentence_char_len = max(0, int(min_sentence_char_len))
        self.aggregation_mode = aggregation_mode.lower().strip()
        if self.aggregation_mode not in {"max", "mean"}:
            self.aggregation_mode = "max"
        self.word_score_plus_threshold = int(word_score_plus_threshold)
        self.word_score_minus_threshold = int(word_score_minus_threshold)

        self.providers: Optional[List[str]] = None
        self.tokenizer_loaded: bool = False
        self.last_tokenizer_error: Optional[str] = None

        try:
            # Force the CPU provider; in some environments an unavailable GPU
            # provider gets selected and loading fails
            self.ort_session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
            self.input_name = self.ort_session.get_inputs()[0].name
            self.output_name = self.ort_session.get_outputs()[0].name
            try:
                self.providers = self.ort_session.get_providers()
            except Exception:
                self.providers = None
            print("ONNX model loaded successfully")
            self.model_loaded: bool = True
        except Exception as e:
            print(f"Failed to load ONNX model: {e}")
            self.ort_session = None
            self.model_loaded = False

        # Record model file info to help diagnose "wrong model file" issues
        try:
            self.model_path_abs: Optional[str] = os.path.abspath(model_path)
            self.model_sha256: Optional[str] = None
            if os.path.exists(model_path):
                sha = hashlib.sha256()
                with open(model_path, 'rb') as f:
                    for chunk in iter(lambda: f.read(8192), b''):
                        sha.update(chunk)
                self.model_sha256 = sha.hexdigest()
        except Exception:
            self.model_path_abs = None
            self.model_sha256 = None

    def _preprocess_text(self, text: str) -> Dict[str, np.ndarray]:
        try:
            from transformers import AutoTokenizer

            # 1) Prefer the local tokenizer shipped next to this script
            #    (deploy tokenizer.json etc. alongside it)
            try:
                tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_dir, local_files_only=True)
            except Exception:
                try:
                    tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_dir)
                except Exception:
                    # 2) Fallback: online model (requires internet access)
                    tokenizer = AutoTokenizer.from_pretrained("uer/chinese_roberta_L-4_H-256")

            inputs = tokenizer(
                text,
                truncation=True,
                padding=True,
                max_length=512,
                return_tensors='np'
            )
            self.tokenizer_loaded = True
            self.last_tokenizer_error = None
            # BatchEncoding is not a dict subclass, so convert it to a plain dict
            # for the isinstance(dict) check in the caller
            return dict(inputs)
        except Exception as e:
            self.tokenizer_loaded = False
            self.last_tokenizer_error = str(e)
            # Re-raise so the caller can catch it and fall back, with the reason recorded
            raise

    def _predict_grade_with_model(self, text: str) -> Dict[str, Any]:
        try:
            if not self.ort_session:
                word_score = self._calculate_word_scores(text)["total_score"]
                grade = "C"
                if word_score > 1:
                    grade = "B"
                if word_score < -1:
                    grade = "D"
                return {"grade": grade, "source": "rule", "word_score_total": word_score}

            inputs = self._preprocess_text(text)
            model_input_names = [i.name for i in self.ort_session.get_inputs()]
            input_data = {}
            if isinstance(inputs, dict) and 'input_ids' in inputs:
                token_type = inputs.get('token_type_ids')
                attn = inputs.get('attention_mask')
                ids = inputs['input_ids']
                # Match tokenizer outputs to the model's declared input names
                for name in model_input_names:
                    lowered = name.lower()
                    if 'mask' in lowered:
                        input_data[name] = attn if attn is not None else np.ones_like(ids)
                    elif 'token_type' in lowered or 'segment' in lowered:
                        if token_type is None:
                            token_type = np.zeros_like(ids)
                        input_data[name] = token_type
                    elif 'input_ids' in lowered or 'input' in lowered or 'ids' in lowered:
                        input_data[name] = ids
                    else:
                        input_data[name] = np.zeros_like(ids)
            else:
                target_input = self.input_name or (model_input_names[0] if model_input_names else 'input')
                input_data = {target_input: inputs}

            outputs = self.ort_session.run([self.output_name], input_data)
            predictions = outputs[0]
            grade_index = int(np.argmax(predictions))
            grades = ['A', 'B', 'C', 'D', 'E']
            probs = self._softmax(predictions)[0].tolist()
            return {
                "grade": grades[grade_index],
                "source": "model",
                "prob": float(probs[grade_index]),
                "probs": probs,
                "logits": predictions[0].tolist(),
            }
        except Exception as e:
            print(f"Model prediction failed: {e}")
            word_score = self._calculate_word_scores(text)["total_score"]
            grade = "C"
            if word_score > 1:
                grade = "B"
            if word_score < -1:
                grade = "D"
            return {
                "grade": grade,
                "source": "rule",
                "word_score_total": word_score,
                "reason": str(e),
                "tokenizer_loaded": self.tokenizer_loaded,
                "last_tokenizer_error": self.last_tokenizer_error,
            }

    @staticmethod
    def _softmax(x: np.ndarray) -> np.ndarray:
        # Numerically stable softmax over the last axis
        x = x - np.max(x, axis=-1, keepdims=True)
        exp_x = np.exp(x)
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def _load_eval_keywords(self, file_path: str) -> Dict[str, Dict[str, List[str]]]:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            print(f"Failed to load the evaluation keyword library: {e}")
            return {}

    def _extract_all_keywords(self) -> Set[str]:
        keywords_set = set()
        for category, types in self.eval_keywords.items():
            for _, keywords in types.items():
                keywords_set.update(keywords)
        return keywords_set

    def _split_into_sentences(self, text: str) -> List[str]:
        if not text:
            return []
        # First split on strong sentence terminators
        normalized = re.sub(r'([。!?.!?])', r'\1\n', text)
        normalized = re.sub(r'[;;]\s*', ';\n', normalized)
        candidates = [s.strip() for s in re.split(r'[\r\n]+', normalized) if s.strip()]

        # Further split long sentences on commas
        rough_sentences: List[str] = []
        for s in candidates:
            if len(s) > 80 and not re.search(r'[。!?.!?;;]', s):
                parts = re.split(r'[,,]', s)
                rough_sentences.extend([p.strip() for p in parts if p.strip()])
            else:
                rough_sentences.append(s)

        # Merge fragments that start with punctuation and filter out very short sentences
        sentences: List[str] = []
        leading_punct_pattern = r'^[,,。;;::、\s]+'
        for s in rough_sentences:
            if self.merge_leading_punct and re.match(leading_punct_pattern, s):
                # Strip the leading punctuation and merge into the previous sentence
                cleaned = re.sub(leading_punct_pattern, '', s)
                if sentences:
                    sentences[-1] = f"{sentences[-1]}{cleaned}"
                else:
                    if cleaned:
                        sentences.append(cleaned)
                continue
            # Filter very short sentences (length measured with punctuation removed)
            plain = re.sub(r'[,,。;;::、!!??\s]', '', s)
            if self.min_sentence_char_len > 0 and len(plain) < self.min_sentence_char_len:
                # Do not drop outright: merge into the previous sentence if there is one
                if sentences:
                    sentences[-1] = f"{sentences[-1]}{s}"
                else:
                    sentences.append(s)
                continue
            sentences.append(s)
        return [s.strip() for s in sentences if s and s.strip()]

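    # Illustrative sketch (not from the original file): expected splitter behaviour
    # on a short review comment under the default settings
    # (merge_leading_punct=True, min_sentence_char_len=6):
    #
    #   extractor._split_into_sentences("课堂气氛活跃,学生参与度高!但是作业布置偏少。")
    #   -> ["课堂气氛活跃,学生参与度高!", "但是作业布置偏少。"]
    #
    # The second piece keeps 8 characters after punctuation is stripped, so it
    # stands on its own; a shorter fragment would be merged into the previous sentence.
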
    def _fuzzy_match_keyword(self, sentence: str, keyword: str) -> bool:
        """Stricter Chinese keyword matching.

        - Keywords shorter than 2 characters (e.g. "好") only match exact tokens
          after word segmentation, so they do not hit every sentence.
        - All other keywords use substring matching after punctuation is stripped.
        """
        if not keyword:
            return False
        # Normalize whitespace
        sentence = sentence.strip()
        keyword = keyword.strip()
        # Very short keywords go through exact token matching to avoid over-matching
        if len(keyword) < 2:
            try:
                import jieba  # already listed in requirements
                tokens = set(jieba.lcut(sentence))
                return keyword in tokens
            except Exception:
                # Fallback: no fuzzy matching for very short keywords
                return False
        # General keywords: substring match after stripping (ASCII) punctuation
        import string
        trans = str.maketrans('', '', string.punctuation)
        sentence_clean = sentence.translate(trans)
        keyword_clean = keyword.translate(trans)
        return keyword_clean in sentence_clean

    def _is_negated_positive(self, text: str, keyword: str) -> bool:
        """Detect whether a positive keyword is modified by a negation, e.g.:

        - 没有/无/不/非/未/并不/毫无 + keyword
        - for positive words starting with "有" (e.g. "有创新性"), also match
          "没有/无/不/未/并不/毫无 + the part after '有'" (e.g. "创新性")
        - 缺乏/不足/欠缺/缺少/不具备 + keyword, or + the part of the keyword after "有"
        """
        if not keyword:
            return False
        sentence = text.strip()
        neg_prefixes = [
            "没有", "没", "无", "不", "非", "未", "并不", "并没有", "并无", "毫无"
        ]
        lack_prefixes = [
            "缺乏", "不足", "欠缺", "缺少", "不具备", "不够"
        ]

        # Build regex fragments safely ("re" is imported at module level)
        def any_prefix(prefixes: List[str]) -> str:
            return "(?:" + "|".join(re.escape(p) for p in prefixes) + ")"

        patterns: List[str] = []
        # Direct: negation prefix + keyword
        patterns.append(rf"{any_prefix(neg_prefixes)}\s*{re.escape(keyword)}")
        # Direct: "lack of" prefix + keyword
        patterns.append(rf"{any_prefix(lack_prefixes)}\s*{re.escape(keyword)}")
        # If the positive word starts with "有", also match its tail without "有"
        # (e.g. '有创新性' -> '创新性')
        if keyword.startswith("有") and len(keyword) > 1:
            tail = keyword[1:]
            patterns.append(rf"{any_prefix(neg_prefixes)}\s*{re.escape(tail)}")
            patterns.append(rf"{any_prefix(lack_prefixes)}\s*{re.escape(tail)}")

        for pat in patterns:
            if re.search(pat, sentence):
                return True
        return False

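    # Illustrative sketch (not from the original file): with the keyword "有创新性",
    # the patterns above include "没有\s*有创新性" and, because the keyword starts
    # with "有", also "没有\s*创新性". A comment such as "这节课没有创新性" therefore
    # counts as a negated positive and is scored as negative in _calculate_word_scores.
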
    def _extract_relevant_sentences(self, text: str) -> List[str]:
        sentences = self._split_into_sentences(text)
        relevant_sentences = []
        for sentence in sentences:
            # Keep a sentence as soon as any keyword from any category/sentiment matches
            for category in ["student_performance", "content_quality", "cross_scene"]:
                if category not in self.eval_keywords:
                    continue
                for sentiment in ["positive", "negative", "nature", "suggestion"]:
                    if sentiment not in self.eval_keywords[category]:
                        continue
                    for keyword in self.eval_keywords[category][sentiment]:
                        if self._fuzzy_match_keyword(sentence, keyword):
                            if sentence not in relevant_sentences:
                                relevant_sentences.append(sentence)
                            break
                    else:
                        continue
                    break
                else:
                    continue
                break
        return relevant_sentences

    def _calculate_word_scores(self, text: str) -> Dict[str, int]:
        positive_count = 0
        negative_count = 0
        neutral_count = 0
        total_score = 0
        for category in ["student_performance", "content_quality", "cross_scene"]:
            if category not in self.eval_keywords:
                continue
            for keyword in self.eval_keywords[category].get("positive", []):
                if self._fuzzy_match_keyword(text, keyword):
                    # A negated positive word (e.g. "没有创新性" contains "有创新性")
                    # is scored as negative
                    if self._is_negated_positive(text, keyword):
                        negative_count += 1
                        total_score -= 1
                    else:
                        positive_count += 1
                        total_score += 1
            for keyword in self.eval_keywords[category].get("negative", []):
                if self._fuzzy_match_keyword(text, keyword):
                    negative_count += 1
                    total_score -= 1
            for keyword in self.eval_keywords[category].get("nature", []):
                if self._fuzzy_match_keyword(text, keyword):
                    neutral_count += 1
        return {
            "positive_count": positive_count,
            "negative_count": negative_count,
            "neutral_count": neutral_count,
            "total_score": total_score,
        }

    def extract(self, text: str) -> Dict[str, Any]:
        if not text:
            return {
                "comprehensive_grade": "C",
                "positive_word_count": 0,
                "negative_word_count": 0,
                "neutral_word_count": 0,
                "scored_sentences": [],
                "count": 0,
            }
        relevant_sentences = self._extract_relevant_sentences(text)
        scored_sentences = []
        total_sentence_score = 0
        for sentence in relevant_sentences:
            info = self._predict_grade_with_model(sentence)
            grade = info.get("grade", "C")
            score = score_map.get(grade, 3)
            # Attach debugging information
            scored_sentences.append({
                "sentence": sentence,
                "grade": grade,
                "source": info.get("source", "unknown"),
                "prob": info.get("prob"),
                "word_score_total": info.get("word_score_total"),
            })
            total_sentence_score += score

        comprehensive_grade = "C"
        if relevant_sentences:
            reverse_map = {5: 'A', 4: 'B', 3: 'C', 2: 'D', 1: 'E'}
            if self.aggregation_mode == "max":
                # Take the highest grade (more robust; avoids short fragments dragging the mean down)
                max_score = max(score_map.get(item["grade"], 3) for item in scored_sentences)
                comprehensive_grade = reverse_map.get(max_score, "C")
            else:
                avg_score = total_sentence_score / len(relevant_sentences)
                rounded_score = int(round(avg_score))
                comprehensive_grade = reverse_map.get(rounded_score, "C")

        word_scores = self._calculate_word_scores(text)
        final_grade = comprehensive_grade
        if word_scores["total_score"] > self.word_score_plus_threshold:
            final_grade = comprehensive_grade + "+"
        elif word_scores["total_score"] < self.word_score_minus_threshold:
            final_grade = comprehensive_grade + "-"

        return {
            "comprehensive_grade": final_grade,
            "positive_word_count": word_scores["positive_count"],
            "negative_word_count": word_scores["negative_count"],
            "neutral_word_count": word_scores["neutral_count"],
            "scored_sentences": scored_sentences,
            "count": len(relevant_sentences),
            # Debug fields
            "debug": {
                "model_loaded": getattr(self, "model_loaded", False),
                "model_path_abs": getattr(self, "model_path_abs", None),
                "model_sha256": getattr(self, "model_sha256", None),
                "providers": self.providers,
                "tokenizer_loaded": self.tokenizer_loaded,
                "last_tokenizer_error": self.last_tokenizer_error,
                "aggregation_mode": self.aggregation_mode,
                "min_sentence_char_len": self.min_sentence_char_len,
                "merge_leading_punct": self.merge_leading_punct,
                "word_score_plus_threshold": self.word_score_plus_threshold,
                "word_score_minus_threshold": self.word_score_minus_threshold,
                "relevant_sentences": relevant_sentences,
                "word_score_total": word_scores["total_score"],
            }
        }
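

# Minimal usage sketch (not part of the original file). It assumes a keyword file
# named "eval_keywords.json" sits next to this script with the structure the
# extractor expects, e.g.:
#   {"content_quality": {"positive": ["有创新性", "生动"], "negative": ["枯燥"], "nature": []}}
# The file name and the sample keywords are illustrative assumptions; if the model
# or keyword file is missing, the class falls back to the rule-based path.
if __name__ == "__main__":
    extractor = SentenceExtractor("eval_keywords.json")
    result = extractor.extract("课堂讲解生动,有创新性,但个别环节略显枯燥。")
    print(result["comprehensive_grade"], result["count"])
    for item in result["scored_sentences"]:
        print(item["grade"], item["sentence"])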