teacher-evaluation-api / predictor.py
import json
import re
import os
import hashlib
import onnxruntime as ort
import numpy as np
from typing import Any, Dict, List, Optional, Set
score_map = {'A': 5, 'B': 4, 'C': 3, 'D': 2, 'E': 1}
class SentenceExtractor:
def __init__(
self,
eval_keywords_path: str,
model_path: str = "distilled_model.onnx",
*,
        # Configurable switches for sentence splitting and aggregation
merge_leading_punct: bool = True,
min_sentence_char_len: int = 6,
        aggregation_mode: str = "max",  # one of "max" | "mean"
        # Thresholds for the +/- suffix (>0 / <0 was the original logic; consider raising to 2/-2)
word_score_plus_threshold: int = 1,
word_score_minus_threshold: int = -1,
):
        # Resolve all resources relative to this file's directory so a different working directory doesn't break loading
self.base_dir = os.path.dirname(os.path.abspath(__file__))
self.tokenizer_dir = self.base_dir
        # Relative paths are allowed: convert them to absolute
if not os.path.isabs(model_path):
model_path = os.path.join(self.base_dir, model_path)
if not os.path.isabs(eval_keywords_path):
eval_keywords_path = os.path.join(self.base_dir, eval_keywords_path)
self.eval_keywords = self._load_eval_keywords(eval_keywords_path)
self.all_keywords = self._extract_all_keywords()
self.ort_session = None
self.input_name = None
self.output_name = None
        # Configuration
self.merge_leading_punct = merge_leading_punct
self.min_sentence_char_len = max(0, int(min_sentence_char_len))
self.aggregation_mode = aggregation_mode.lower().strip()
if self.aggregation_mode not in {"max", "mean"}:
self.aggregation_mode = "max"
self.word_score_plus_threshold = int(word_score_plus_threshold)
self.word_score_minus_threshold = int(word_score_minus_threshold)
self.providers: Optional[List[str]] = None
self.tokenizer_loaded: bool = False
self.last_tokenizer_error: Optional[str] = None
try:
            # Force the CPU provider; some environments otherwise select an unavailable GPU provider and fail to load
self.ort_session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
self.input_name = self.ort_session.get_inputs()[0].name
self.output_name = self.ort_session.get_outputs()[0].name
try:
self.providers = self.ort_session.get_providers()
except Exception:
self.providers = None
print("ONNX 模型加载成功")
self.model_loaded: bool = True
except Exception as e:
print(f"ONNX 模型加载失败: {e}")
self.ort_session = None
self.model_loaded = False
        # Record model file info to make "wrong model file" issues easier to diagnose
try:
self.model_path_abs: Optional[str] = os.path.abspath(model_path)
self.model_sha256: Optional[str] = None
if os.path.exists(model_path):
sha = hashlib.sha256()
with open(model_path, 'rb') as f:
for chunk in iter(lambda: f.read(8192), b''):
sha.update(chunk)
self.model_sha256 = sha.hexdigest()
except Exception:
self.model_path_abs = None
self.model_sha256 = None
    def _preprocess_text(self, text: str) -> Dict[str, np.ndarray]:
try:
from transformers import AutoTokenizer
            # 1) Prefer the local tokenizer next to this script (deploy tokenizer.json etc. alongside it)
try:
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_dir, local_files_only=True)
except Exception:
try:
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_dir)
except Exception:
                    # 2) Last resort: the online model (requires network access)
tokenizer = AutoTokenizer.from_pretrained("uer/chinese_roberta_L-4_H-256")
inputs = tokenizer(
text,
truncation=True,
padding=True,
max_length=512,
return_tensors='np'
)
self.tokenizer_loaded = True
self.last_tokenizer_error = None
            # BatchEncoding is a UserDict, not a dict subclass; convert so the
            # isinstance(inputs, dict) check in _predict_grade_with_model passes
            return dict(inputs)
except Exception as e:
self.tokenizer_loaded = False
self.last_tokenizer_error = str(e)
            # Re-raise so the caller can catch it and fall back; the failure reason is recorded above
raise
    def _rule_based_grade(self, text: str) -> Dict[str, Any]:
        # Keyword-score fallback used whenever the ONNX model is unavailable
        word_score = self._calculate_word_scores(text)["total_score"]
        grade = "C"
        if word_score > 1:
            grade = "B"
        elif word_score < -1:
            grade = "D"
        return {"grade": grade, "source": "rule", "word_score_total": word_score}

    def _predict_grade_with_model(self, text: str) -> Dict[str, Any]:
        try:
            if not self.ort_session:
                return self._rule_based_grade(text)
inputs = self._preprocess_text(text)
model_input_names = [i.name for i in self.ort_session.get_inputs()]
input_data = {}
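            # Map tokenizer outputs onto whatever input names the ONNX graph
            # declares, matching heuristically by substring in the input name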
if isinstance(inputs, dict) and 'input_ids' in inputs:
token_type = inputs.get('token_type_ids')
attn = inputs.get('attention_mask')
ids = inputs['input_ids']
for name in model_input_names:
lowered = name.lower()
if 'mask' in lowered:
input_data[name] = attn if attn is not None else np.ones_like(ids)
elif 'token_type' in lowered or 'segment' in lowered:
if token_type is None:
token_type = np.zeros_like(ids)
input_data[name] = token_type
elif 'input_ids' in lowered or 'input' in lowered or 'ids' in lowered:
input_data[name] = ids
else:
input_data[name] = np.zeros_like(ids)
else:
target_input = self.input_name or (model_input_names[0] if model_input_names else 'input')
input_data = {target_input: inputs}
outputs = self.ort_session.run([self.output_name], input_data)
predictions = outputs[0]
grade_index = int(np.argmax(predictions))
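            # Assumes the model's output classes are ordered A..E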
grades = ['A', 'B', 'C', 'D', 'E']
probs = self._softmax(predictions)[0].tolist()
return {
"grade": grades[grade_index],
"source": "model",
"prob": float(probs[grade_index]),
"probs": probs,
"logits": predictions[0].tolist(),
}
        except Exception as e:
            print(f"Model prediction failed: {e}")
            result = self._rule_based_grade(text)
            result.update({
                "reason": str(e),
                "tokenizer_loaded": self.tokenizer_loaded,
                "last_tokenizer_error": self.last_tokenizer_error,
            })
            return result
@staticmethod
def _softmax(x: np.ndarray) -> np.ndarray:
x = x - np.max(x, axis=-1, keepdims=True)
exp_x = np.exp(x)
return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
def _load_eval_keywords(self, file_path: str) -> Dict[str, Dict[str, List[str]]]:
try:
with open(file_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
print(f"加载评估关键词库失败: {e}")
return {}
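
    # Expected lexicon shape (illustrative sketch; the sample keywords below are
    # assumptions, not the shipped file's contents):
    # {
    #   "student_performance": {"positive": ["积极"], "negative": ["走神"], "nature": [...], "suggestion": [...]},
    #   "content_quality": {"positive": ["有创新性"], "negative": ["枯燥"], ...},
    #   "cross_scene": {...}
    # }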
def _extract_all_keywords(self) -> Set[str]:
keywords_set = set()
for category, types in self.eval_keywords.items():
for _, keywords in types.items():
keywords_set.update(keywords)
return keywords_set
def _split_into_sentences(self, text: str) -> List[str]:
if not text:
return []
        # First split on strong sentence terminators
normalized = re.sub(r'([。!?.!?])', r'\1\n', text)
normalized = re.sub(r'[;;]\s*', ';\n', normalized)
candidates = [s.strip() for s in re.split(r'[\r\n]+', normalized) if s.strip()]
        # Further split long sentences on commas
rough_sentences: List[str] = []
for s in candidates:
if len(s) > 80 and not re.search(r'[。!?.!?;;]', s):
parts = re.split(r'[,,]', s)
rough_sentences.extend([p.strip() for p in parts if p.strip()])
else:
rough_sentences.append(s)
        # Merge fragments that start with punctuation, and filter out very short sentences
sentences: List[str] = []
leading_punct_pattern = r'^[,,。;;::、\s]+'
for s in rough_sentences:
if self.merge_leading_punct and re.match(leading_punct_pattern, s):
                # Strip the leading punctuation and merge into the previous sentence
cleaned = re.sub(leading_punct_pattern, '', s)
if sentences:
sentences[-1] = f"{sentences[-1]}{cleaned}"
else:
if cleaned:
sentences.append(cleaned)
continue
            # Filter very short sentences (measured after stripping punctuation)
plain = re.sub(r'[,,。;;::、!!??\s]', '', s)
if self.min_sentence_char_len > 0 and len(plain) < self.min_sentence_char_len:
                # Don't drop it outright: merge into the previous sentence if one exists
if sentences:
sentences[-1] = f"{sentences[-1]}{s}"
else:
sentences.append(s)
continue
sentences.append(s)
return [s.strip() for s in sentences if s and s.strip()]
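
    # Example (illustrative) with the defaults: "教学认真。不过,内容偏难" splits on
    # "。", and the 4-character fragment "教学认真。" survives only because there is
    # no previous sentence to merge it into:
    #   ["教学认真。", "不过,内容偏难"]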
def _fuzzy_match_keyword(self, sentence: str, keyword: str) -> bool:
"""更严格的中文关键词匹配。
- 长度 < 2 的关键词(如“好”)仅按分词后的精确词匹配,避免所有句子都命中。
- 其余关键词采用去标点后的包含匹配。
"""
if not keyword:
return False
        # Normalize by trimming whitespace
sentence = sentence.strip()
keyword = keyword.strip()
        # Very short keywords use exact token matching to avoid over-matching
if len(keyword) < 2:
try:
                import jieba  # already listed in requirements
tokens = set(jieba.lcut(sentence))
return keyword in tokens
except Exception:
                # Fallback: no fuzzy matching for very short keywords
return False
        # General keywords: containment match after stripping punctuation.
        # Note: string.punctuation is ASCII-only, so common Chinese punctuation
        # is appended to the strip set as well.
        import string
        punct = string.punctuation + ",。;:、!?“”‘’()《》…"
        trans = str.maketrans('', '', punct)
        sentence_clean = sentence.translate(trans)
        keyword_clean = keyword.translate(trans)
return keyword_clean in sentence_clean
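
    # Examples (illustrative): "细致" matches "老师讲解很细致" by containment, while
    # the single character "好" counts only when jieba emits it as a standalone
    # token, not whenever "好" appears inside a longer word.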
def _is_negated_positive(self, text: str, keyword: str) -> bool:
"""检测积极关键词是否被否定词修饰,例如:
- 没有/无/不/非/未/并不/毫无 + 关键词
- 对以“有”开头的积极词(如“有创新性”),也匹配“没有/无/不/未/并不/毫无 + 去掉‘有’后的部分(如“创新性”)”
- 缺乏/不足/欠缺/缺少/不具备 + 关键词 或 关键词去“有”后的部分
"""
if not keyword:
return False
sentence = text.strip()
neg_prefixes = [
"没有", "没", "无", "不", "非", "未", "并不", "并没有", "并无", "毫无"
]
lack_prefixes = [
"缺乏", "不足", "欠缺", "缺少", "不具备", "不够"
]
        # Build safe regex fragments (re is already imported at module level)
def any_prefix(prefixes: List[str]) -> str:
return "(?:" + "|".join(re.escape(p) for p in prefixes) + ")"
patterns: List[str] = []
        # Direct: negation prefix + keyword
patterns.append(rf"{any_prefix(neg_prefixes)}\s*{re.escape(keyword)}")
        # Direct: "lack"-type prefix + keyword
patterns.append(rf"{any_prefix(lack_prefixes)}\s*{re.escape(keyword)}")
        # If a positive word starts with "有", also match the tail without it (e.g. "有创新性" → "创新性")
if keyword.startswith("有") and len(keyword) > 1:
tail = keyword[1:]
patterns.append(rf"{any_prefix(neg_prefixes)}\s*{re.escape(tail)}")
patterns.append(rf"{any_prefix(lack_prefixes)}\s*{re.escape(tail)}")
for pat in patterns:
if re.search(pat, sentence):
return True
return False
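
    # Example (illustrative): _is_negated_positive("该课缺乏创新性", "有创新性")
    # returns True via the tail rule: "缺乏" + "创新性" matches even though the
    # full keyword "有创新性" never appears in the text.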
    def _extract_relevant_sentences(self, text: str) -> List[str]:
        sentences = self._split_into_sentences(text)
        relevant_sentences: List[str] = []
        for sentence in sentences:
            # A sentence is relevant as soon as any keyword in any category and
            # sentiment matches; any() short-circuits like the original nested breaks
            matched = any(
                self._fuzzy_match_keyword(sentence, keyword)
                for category in ["student_performance", "content_quality", "cross_scene"]
                if category in self.eval_keywords
                for sentiment in ["positive", "negative", "nature", "suggestion"]
                if sentiment in self.eval_keywords[category]
                for keyword in self.eval_keywords[category][sentiment]
            )
            if matched and sentence not in relevant_sentences:
                relevant_sentences.append(sentence)
        return relevant_sentences
def _calculate_word_scores(self, text: str) -> Dict[str, int]:
positive_count = 0
negative_count = 0
neutral_count = 0
total_score = 0
for category in ["student_performance", "content_quality", "cross_scene"]:
if category not in self.eval_keywords:
continue
for keyword in self.eval_keywords[category].get("positive", []):
if self._fuzzy_match_keyword(text, keyword):
                    # A negated positive word (e.g. "没有创新性" contains "有创新性") counts as negative
if self._is_negated_positive(text, keyword):
negative_count += 1
total_score -= 1
else:
positive_count += 1
total_score += 1
for keyword in self.eval_keywords[category].get("negative", []):
if self._fuzzy_match_keyword(text, keyword):
negative_count += 1
total_score -= 1
for keyword in self.eval_keywords[category].get("nature", []):
if self._fuzzy_match_keyword(text, keyword):
neutral_count += 1
return {
"positive_count": positive_count,
"negative_count": negative_count,
"neutral_count": neutral_count,
"total_score": total_score,
}
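
    # Worked example (assuming "认真" and "有创新性" are both positive keywords):
    # for "老师认真但没有创新性", "认真" scores +1, while "有创新性" matches by
    # containment but is detected as negated and scores -1, so total_score is 0.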
    def extract(self, text: str) -> Dict[str, Any]:
if not text:
return {
"comprehensive_grade": "C",
"positive_word_count": 0,
"negative_word_count": 0,
"neutral_word_count": 0,
"scored_sentences": [],
"count": 0,
}
relevant_sentences = self._extract_relevant_sentences(text)
scored_sentences = []
total_sentence_score = 0
for sentence in relevant_sentences:
info = self._predict_grade_with_model(sentence)
grade = info.get("grade", "C")
score = score_map.get(grade, 3)
            # Attach debug info alongside each sentence
scored_sentences.append({
"sentence": sentence,
"grade": grade,
"source": info.get("source", "unknown"),
"prob": info.get("prob"),
"word_score_total": info.get("word_score_total"),
})
total_sentence_score += score
comprehensive_grade = "C"
if relevant_sentences:
reverse_map = {5: 'A', 4: 'B', 3: 'C', 2: 'D', 1: 'E'}
if self.aggregation_mode == "max":
                # Take the highest grade (more robust: short fragments can't drag down the mean)
max_score = max(score_map.get(item["grade"], 3) for item in scored_sentences)
comprehensive_grade = reverse_map.get(max_score, "C")
else:
avg_score = total_sentence_score / len(relevant_sentences)
rounded_score = int(round(avg_score))
comprehensive_grade = reverse_map.get(rounded_score, "C")
word_scores = self._calculate_word_scores(text)
final_grade = comprehensive_grade
if word_scores["total_score"] > self.word_score_plus_threshold:
final_grade = comprehensive_grade + "+"
elif word_scores["total_score"] < self.word_score_minus_threshold:
final_grade = comprehensive_grade + "-"
return {
"comprehensive_grade": final_grade,
"positive_word_count": word_scores["positive_count"],
"negative_word_count": word_scores["negative_count"],
"neutral_word_count": word_scores["neutral_count"],
"scored_sentences": scored_sentences,
"count": len(relevant_sentences),
            # Debug fields
"debug": {
"model_loaded": getattr(self, "model_loaded", False),
"model_path_abs": getattr(self, "model_path_abs", None),
"model_sha256": getattr(self, "model_sha256", None),
"providers": self.providers,
"tokenizer_loaded": self.tokenizer_loaded,
"last_tokenizer_error": self.last_tokenizer_error,
"aggregation_mode": self.aggregation_mode,
"min_sentence_char_len": self.min_sentence_char_len,
"merge_leading_punct": self.merge_leading_punct,
"word_score_plus_threshold": self.word_score_plus_threshold,
"word_score_minus_threshold": self.word_score_minus_threshold,
"relevant_sentences": relevant_sentences,
"word_score_total": word_scores["total_score"],
}
}
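
# Minimal usage sketch. The keyword filename "eval_keywords.json" is an
# illustrative assumption; pass whatever path is deployed alongside this file.
if __name__ == "__main__":
    extractor = SentenceExtractor("eval_keywords.json")
    result = extractor.extract("老师讲解很细致,课堂氛围活跃。")
    print(result["comprehensive_grade"], result["count"])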