Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
# 1. Model location: the current working directory, which is expected to hold
#    the fine-tuned checkpoint (config, tokenizer files, weights).
model_dir = "."


# 2. Load the model.
def _load_components(path):
    """Load and return ``(config, tokenizer, model)`` from ``path``.

    ``path`` may be a local directory or a Hugging Face Hub model id.
    """
    cfg = AutoConfig.from_pretrained(path)
    tok = AutoTokenizer.from_pretrained(path)
    mdl = AutoModelForSequenceClassification.from_pretrained(path, config=cfg)
    return cfg, tok, mdl


# Fault tolerance: if the fine-tuned checkpoint cannot be loaded, fall back to
# the small public "prajjwal1/bert-tiny" model so the web page still starts.
# NOTE(review): bert-tiny presumably has no trained classification head, so
# predictions from the fallback are placeholders only — confirm before relying
# on them.
try:
    config, tokenizer, model = _load_components(model_dir)
except Exception as e:
    print(f"Error loading model: {e}")
    print("Loading fallback model (bert-tiny) for debugging...")
    model_dir = "prajjwal1/bert-tiny"
    config, tokenizer, model = _load_components(model_dir)
# 3. Inference function.

# Human-readable names used when the checkpoint's config does not carry
# meaningful label names. Per the original author's note, class 0 is "normal"
# and class 1 is "disaster" in the training data.
_FALLBACK_LABELS = {0: "Normal (正常)", 1: "Disaster (灾难)"}


def _resolve_label(class_id, id2label):
    """Map a predicted class index to a display label.

    Args:
        class_id: Predicted class index.
        id2label: The model config's ``id2label`` mapping, or ``None``.

    Returns:
        The real label name from ``id2label`` when one exists. Otherwise the
        hand-written fallback name, or ``"Class <id>"`` for unknown ids.
    """
    name = (id2label or {}).get(class_id)
    # Hugging Face configs default to placeholder names "LABEL_<i>" when a
    # model was trained without explicit label names; treat those as missing
    # so the human-readable mapping is actually used.
    if name and name != f"LABEL_{class_id}":
        return name
    return _FALLBACK_LABELS.get(class_id, f"Class {class_id}")


def inference(input_text):
    """Classify one piece of text and return its display label.

    Args:
        input_text: Text from the UI textbox (may be empty or ``None``).

    Returns:
        A prompt message for empty/blank input, otherwise the predicted label.
    """
    if not (input_text and input_text.strip()):
        return "Please input some text."
    # A single input needs no padding; padding to 512 tokens only wastes
    # compute. Truncate so long texts still fit the model's context.
    inputs = tokenizer(
        input_text,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    # logits has one row per input; take the best class of the single row.
    predicted_class_id = logits.argmax(dim=-1)[0].item()
    return _resolve_label(predicted_class_id, getattr(model.config, "id2label", None))
# 4. Build the UI (the custom css argument was removed from gr.Blocks).
with gr.Blocks() as demo:
    gr.Markdown("### 灾难推文检测器 (Disaster Tweet Detector)")

    # Left column: text input and trigger button. Right column: prediction.
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Insert your prompt here:",
            )
            generate_bt = gr.Button("Generate", variant="primary")
        with gr.Column():
            answer = gr.Textbox(label="Prediction Result")

    # Run the classifier on the textbox content when the button is clicked.
    generate_bt.click(
        inputs=[input_text],
        outputs=[answer],
        fn=inference,
    )

    # Clickable example tweets that fill the input textbox.
    gr.Examples(
        inputs=input_text,
        examples=[
            ["The sky is blue and I am happy."],
            ["Huge fire in the downtown building! We need help!"],
        ],
    )
if __name__ == "__main__":
    # Enable the request queue, then start the web server; queue() returns
    # the same Blocks instance, so the calls can be chained.
    demo.queue().launch()