"""Gradio demo that classifies a piece of text with a sequence-classification model.

The model is loaded from ``model_dir``; if that fails, a small public checkpoint
is loaded instead so the web page still starts.
"""

import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

# 1. Model location: the directory this script runs from.
model_dir = "."

# Human-readable names used when the model config only carries placeholder
# labels. Per the original author's note, 0 is "normal" and 1 is "disaster"
# in the training data -- TODO confirm against the actual dataset.
_DEFAULT_LABELS = {0: "Normal (正常)", 1: "Disaster (灾难)"}


def _load(path):
    """Return ``(config, tokenizer, model)`` loaded from *path*."""
    cfg = AutoConfig.from_pretrained(path)
    tok = AutoTokenizer.from_pretrained(path)
    mdl = AutoModelForSequenceClassification.from_pretrained(path, config=cfg)
    return cfg, tok, mdl


# 2. Load the model.
# Deliberate best-effort: if the trained model cannot be loaded, fall back to a
# small public checkpoint so the page can still open. The broad ``except`` is
# intentional because loading can fail in several different ways.
try:
    config, tokenizer, model = _load(model_dir)
except Exception as e:
    print(f"Error loading model: {e}")
    print("Loading fallback model (bert-tiny) for debugging...")
    # NOTE(review): presumably this checkpoint has no classification head
    # trained for this task, so its predictions are only useful for checking
    # that the UI works -- confirm before relying on its output.
    model_dir = "prajjwal1/bert-tiny"
    config, tokenizer, model = _load(model_dir)

# Make inference mode explicit so dropout is disabled.
model.eval()


def _label_for(class_id: int) -> str:
    """Map a predicted class index to a display label.

    A label from ``model.config.id2label`` is used when it is a real name.
    The auto-generated placeholder ``LABEL_<id>`` (or a missing entry) falls
    back to ``_DEFAULT_LABELS``, and finally to ``"Class <id>"``.
    """
    id2label = getattr(model.config, "id2label", None) or {}
    name = id2label.get(class_id)
    if name is not None and name != f"LABEL_{class_id}":
        return name
    return _DEFAULT_LABELS.get(class_id, f"Class {class_id}")


# 3. Inference function.
def inference(input_text: str) -> str:
    """Classify *input_text* and return the predicted label as a string.

    Args:
        input_text: Text typed into the input box.

    Returns:
        The display label of the highest-scoring class, or a prompt asking for
        input when the text is empty or only whitespace.
    """
    if not input_text or not input_text.strip():
        return "Please input some text."

    # A single sequence needs no padding; padding to 512 tokens would only make
    # the forward pass slower. Truncation still caps the length at 512.
    inputs = tokenizer(
        input_text,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(**inputs).logits

    # Take the argmax over the class axis rather than the flattened tensor.
    predicted_class_id = logits.argmax(dim=-1).item()
    return _label_for(predicted_class_id)


# 4. Build the interface (the css parameter was removed).
with gr.Blocks() as demo:
    gr.Markdown("### 灾难推文检测器 (Disaster Tweet Detector)")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                placeholder="Insert your prompt here:",
                label="Input Text",
            )
            generate_bt = gr.Button("Generate", variant="primary")
        with gr.Column():
            answer = gr.Textbox(label="Prediction Result")

    # Bind the click event.
    generate_bt.click(
        fn=inference,
        inputs=[input_text],
        outputs=[answer],
    )

    # Example inputs.
    gr.Examples(
        examples=[
            ["The sky is blue and I am happy."],
            ["Huge fire in the downtown building! We need help!"],
        ],
        inputs=input_text,
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch()