import cv2
import gradio as gr
import numpy as np
import sys
import torch

# --- Fix pytorchvideo import error on Kaggle / torchvision >= 0.17 ---
# torchvision removed the private torchvision.transforms.functional_tensor
# module that pytorchvideo still imports. Register the public functional
# module under the old name BEFORE importing pytorchvideo, so the import
# succeeds and attribute lookups still resolve.
from torchvision.transforms import functional as F

sys.modules["torchvision.transforms.functional_tensor"] = F

from pytorchvideo.transforms import (
    Normalize,
    UniformTemporalSubsample,
)
from torchvision.transforms import Compose, Lambda, Resize

# VideoMAEFeatureExtractor is a deprecated alias of VideoMAEImageProcessor;
# it still loads the same preprocessing config.
from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification

# Load the fine-tuned model and its processor
MODEL_CKPT = "Shawon16/VideoMAE_BdSLW401_20_epochs_p5_SR_10"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
PROCESSOR = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)

RESIZE_TO = PROCESSOR.size["shortest_edge"]
NUM_FRAMES_TO_SAMPLE = MODEL.config.num_frames
IMAGE_STATS = {"image_mean": [0.485, 0.456, 0.406], "image_std": [0.229, 0.224, 0.225]}

VAL_TRANSFORMS = Compose(
    [
        UniformTemporalSubsample(NUM_FRAMES_TO_SAMPLE),
        Lambda(lambda x: x / 255.0),
        Normalize(IMAGE_STATS["image_mean"], IMAGE_STATS["image_std"]),
        Resize((RESIZE_TO, RESIZE_TO)),
    ]
)

# Index labels by logit position via id2label so prediction i maps to the
# correct class name (label2id key order is not guaranteed to match id order).
LABELS = [MODEL.config.id2label[i] for i in range(len(MODEL.config.id2label))]


def parse_video(video_file):
    """Extract frames from a video file with a sample rate of 10."""
    vs = cv2.VideoCapture(video_file)
    frames = []
    frame_id = 0
    while True:
        grabbed, frame = vs.read()
        if not grabbed:
            break
        if frame_id % 10 == 0:  # Sample every 10th frame
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frame_id += 1
    vs.release()
    return frames


def preprocess_video(frames):
    """Preprocess sampled RGB frames into a (1, T, C, H, W) tensor for VideoMAE."""
    video_tensor = torch.from_numpy(np.stack(frames))  # (num_frames, height, width, num_channels)
    video_tensor = video_tensor.permute(3, 0, 1, 2)  # (num_channels, num_frames, height, width)
    video_tensor_pp = VAL_TRANSFORMS(video_tensor)
    video_tensor_pp = video_tensor_pp.permute(1, 0, 2, 3)  # (num_frames, num_channels, height, width)
    video_tensor_pp = video_tensor_pp.unsqueeze(0)  # Add batch dimension
    return video_tensor_pp.to(DEVICE)


def infer(video_file):
    """Classify an uploaded clip; return per-class confidences and the sampled frames."""
    frames = parse_video(video_file)
    video_tensor = preprocess_video(frames)
    inputs = {"pixel_values": video_tensor}

    # Forward pass
    with torch.no_grad():
        outputs = MODEL(**inputs)
        logits = outputs.logits

    softmax_scores = torch.nn.functional.softmax(logits, dim=-1).squeeze(0)
    confidences = {LABELS[i]: float(softmax_scores[i]) for i in range(len(LABELS))}
    return confidences, frames


custom_css = """
/* Hide the webcam button */
button[data-testid="webcam-button"] {
    display: none !important;
}

/* Reduce padding and margins */
.gradio-container {
    max-width: 700px !important; /* Set a smaller max width */
    margin: auto;
    padding: 10px !important;
}

/* Reduce the gallery size */
.gr-gallery {
    max-height: 200px !important; /* Make the frames smaller */
}

/* Center the title */
h1 {
    text-align: center !important;
}
"""

gr.Interface(
    fn=infer,
    inputs=[gr.Video(label="Upload Video")],  # Keep Video for preview
    outputs=[
        gr.Label(num_top_classes=5, label="Top 5 Predictions"),
        gr.Gallery(label="Sampled Frames (Rate: 10)", columns=4, height="200px"),  # Smaller gallery
    ],
    examples=[
        ["W002S08F_03.mp4"],
        ["W003S08F_11.mp4"],
        # ["W205S08F_02.mp4"],
        # ["W211S04F_03.mp4"],
        ["W389S08F_02.mp4"],
        ["W401S04F_06.mp4"],
        # [r"C:\Users\shawo\Desktop\BdSLW60 Full DataSet\FrameRate Corrected Clips\W2\U8W2F_trial_6_R.mp4"],
        # [r"C:\Users\shawo\Desktop\BdSLW60 Full DataSet\FrameRate Corrected Clips\W20\U4W20F_trial_9_R.mp4"],
    ],
    title="Bangla Word Level (BdSLW401) Sign Language Recognition Interface",
    description=(
        "This framework uses a fine-tuned video Transformer (VideoMAE) to classify"
        " Bangla Sign Language words from video inputs. Upload a video for predictions."
    ),
    flagging_mode="never",
    css=custom_css,  # Apply custom CSS for a compact layout
).launch()
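
# A minimal smoke test, left commented out because launch() above blocks and the
# hosted demo only needs the UI. The clip name is one of the bundled examples;
# swap in any local video path to try it. This is a sketch, not part of the app.
#
# confidences, sampled_frames = infer("W002S08F_03.mp4")
# print(f"Sampled {len(sampled_frames)} frames")
# for label, score in sorted(confidences.items(), key=lambda kv: kv[1], reverse=True)[:5]:
#     print(f"{label}: {score:.4f}")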