| import numpy as np |
| import os |
| import pandas as pd |
| import torch |
| import matplotlib.pyplot as plt |
| from transformers import XLMRobertaModel, XLMRobertaTokenizer |
| import torch.nn as nn |
| import gradio as gr |
| from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay |
| from sklearn.model_selection import train_test_split |
| from sklearn.metrics import classification_report |
| from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification |
|
|
| |
# Hugging Face checkpoint shared by the encoder and its tokenizer.
_AFRIBERTA_CHECKPOINT = 'castorini/afriberta_large'

# Pretrained encoder wrapped by BERT_Arch, and the tokenizer used at inference.
bert = XLMRobertaModel.from_pretrained(_AFRIBERTA_CHECKPOINT)
tokenizer = XLMRobertaTokenizer.from_pretrained(_AFRIBERTA_CHECKPOINT)
|
|
| |
class BERT_Arch(nn.Module):
    """Classification head on top of a pretrained transformer encoder.

    The encoder's ``pooler_output`` goes through
    Linear -> ReLU -> Dropout -> Linear -> LogSoftmax, so ``forward`` returns
    log-probabilities over the classes.

    Args:
        bert: Encoder called as ``bert(ids, attention_mask=mask)``. It must
            return a mapping with a ``'pooler_output'`` entry. If it has a
            ``config.hidden_size`` attribute (as Hugging Face models do), that
            value sizes the first layer.
        num_classes: Number of output classes (default 3).
        dropout: Dropout probability applied after the first layer.
    """

    def __init__(self, bert, num_classes=3, dropout=0.1):
        super().__init__()
        self.bert = bert
        # Take the encoder width from its config rather than hard-coding it;
        # 768 is the previous fixed value and remains the fallback.
        hidden_size = getattr(getattr(bert, "config", None), "hidden_size", 768)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(hidden_size, 512)
        self.fc2 = nn.Linear(512, num_classes)
        # This is a LogSoftmax despite the attribute name. The name is kept so
        # previously pickled models still find the attribute forward() uses.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, sent_id, mask):
        """Return per-class log-probabilities for each row of the batch.

        Assumes ``pooler_output`` is (batch, hidden_size). TODO: confirm for
        encoders other than the XLM-R family.
        """
        cls_hs = self.bert(sent_id, attention_mask=mask)['pooler_output']
        x = self.fc1(cls_hs)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return self.softmax(x)
|
|
| |
model = BERT_Arch(bert)
fake_news_model_path = "Hate_Speech_model.pt"


def _load_classifier(path, template):
    """Return the fine-tuned classifier stored at *path*, mapped to CPU.

    The checkpoint may be either a state dict or a whole pickled module.
    A state dict is loaded into *template* (a freshly built BERT_Arch) and
    *template* is returned. Any other object is assumed to already be the
    model and is returned unchanged.
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. torch >= 2.6 defaults to weights_only=True, which accepts
    # state dicts but rejects whole pickled modules.
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    if isinstance(checkpoint, dict):
        template.load_state_dict(checkpoint)
        return template
    return checkpoint


fake_news_model = _load_classifier(fake_news_model_path, model)
# Inference only: puts dropout into evaluation mode.
fake_news_model.eval()
|
|
| |
# Maps the classifier's argmax class index to a human-readable label.
LABELS = dict(enumerate(("Free", "Hate", "Offensive")))
|
|
| |
def detect_fake_news(text):
    """Classify *text* and return the matching label name from ``LABELS``.

    Tokenization truncates the input to 128 tokens. Inference runs without
    gradient tracking.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )
    ids, attention = encoded['input_ids'], encoded['attention_mask']
    with torch.no_grad():
        log_probs = fake_news_model(ids, attention)
        best_class = log_probs.argmax(dim=1).item()
    return LABELS[best_class]
|
|
| |
def post_text(text, fake_news_result):
    """Decide whether *text* may be posted, given its detection result.

    Args:
        text: The message the user wants to post.
        fake_news_result: The label produced by detection ("Free", "Hate" or
            "Offensive"). Any other value, such as an empty string when
            detection was never run, is treated as "not yet checked".

    Returns:
        A ``(status_message, posted_text)`` tuple. ``posted_text`` is *text*
        only when the result is "Free", otherwise "".
    """
    if fake_news_result in ("Hate", "Offensive"):
        return f"Your message contains {fake_news_result} Speech and cannot be posted.", ""
    # Only an explicit "Free" verdict allows posting. Missing or unknown results
    # must not let unchecked text through.
    if fake_news_result != "Free":
        return "Please run Hate Speech Detection before posting.", ""
    return "The text is safe to post.", text
|
|
| |
def _classify_and_post(text):
    """Re-classify *text* and apply the posting policy to it.

    Posting always re-runs detection on the current text. Otherwise a stale
    result (text edited after clicking "Detect") or a missing one (never
    clicked) could let unchecked text be posted.

    Returns:
        ``(detection_label, posting_status, posted_text)``.
    """
    result = detect_fake_news(text)
    status, posted = post_text(text, result)
    return result, status, posted


with gr.Blocks() as interface:
    gr.Markdown("## Hate Speech Detection")
    with gr.Row():
        text_input = gr.Textbox(label="Enter Text", lines=5)
    with gr.Row():
        detect_fake_button = gr.Button("Detect Hate Speech")
    with gr.Row():
        fake_news_result_box = gr.Textbox(label="Hate Speech Detection Result", interactive=False)
    with gr.Row():
        post_button = gr.Button("Post Text")
    with gr.Row():
        post_result_box = gr.Textbox(label="Posting Status", interactive=False)
        posted_text_box = gr.Textbox(label="Posted Text", interactive=False)

    detect_fake_button.click(
        fn=detect_fake_news,
        inputs=text_input,
        outputs=fake_news_result_box,
    )

    # Classify the text as it is at post time and refresh the result box, rather
    # than trusting whatever the result box happens to show.
    post_button.click(
        fn=_classify_and_post,
        inputs=text_input,
        outputs=[fake_news_result_box, post_result_box, posted_text_box],
    )


if __name__ == "__main__":
    interface.launch()
|
|