import gradio as gr
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
import torch

# 1. Set the model path: "." points at the Space root, which should contain
#    config.json, the tokenizer files, and the fine-tuned model weights.
model_dir = "."

# 2. Load the model.
#    Fallback: if loading the fine-tuned model fails, load a small official model
#    instead so the web page still comes up.
try:
    config = AutoConfig.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir, config=config)
except Exception as e:
    print(f"Error loading model: {e}")
    print("Loading fallback model (bert-tiny) for debugging...")
    model_dir = "prajjwal1/bert-tiny"
    config = AutoConfig.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir, config=config)

# 3. Define the inference function.
def inference(input_text):
    if not input_text:
        return "Please input some text."
    inputs = tokenizer(
        [input_text],
        max_length=512,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_id = logits.argmax().item()
    # Map the predicted class id to a human-readable label.
    if hasattr(model.config, "id2label") and model.config.id2label:
        output = model.config.id2label[predicted_class_id]
    else:
        # Manual mapping: for this training data, 0 is "normal" and 1 is "disaster".
        labels = {0: "Normal", 1: "Disaster"}
        output = labels.get(predicted_class_id, f"Class {predicted_class_id}")
    return output
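
# Optional sketch (not wired into the UI): the same forward pass, but returning a
# softmax confidence alongside the label. The helper name `inference_with_score`
# is illustrative and not part of the original app; it assumes the `model` and
# `tokenizer` loaded above.
def inference_with_score(input_text):
    """Return 'label (confidence%)' for a single tweet."""
    if not input_text:
        return "Please input some text."
    inputs = tokenizer([input_text], truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)[0]          # logits -> class probabilities
    predicted_class_id = int(probs.argmax().item())
    label = model.config.id2label.get(predicted_class_id, f"Class {predicted_class_id}")
    return f"{label} ({probs[predicted_class_id].item():.1%})"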

# 4. Build the UI (the css argument was removed).
with gr.Blocks() as demo:
    gr.Markdown("### Disaster Tweet Detector")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(placeholder="Insert your prompt here:", label="Input Text")
            generate_bt = gr.Button("Generate", variant="primary")
        with gr.Column():
            answer = gr.Textbox(label="Prediction Result")

    # Wire the button click to the inference function.
    generate_bt.click(
        fn=inference,
        inputs=[input_text],
        outputs=[answer]
    )

    # Example inputs users can click to try the model.
    gr.Examples(
        examples=[
            ["The sky is blue and I am happy."],
            ["Huge fire in the downtown building! We need help!"]
        ],
        inputs=input_text
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch()
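
# To run this Space locally (assuming the usual dependencies are installed):
#   pip install gradio transformers torch
#   python app.py
# then open the local URL that Gradio prints.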