import torch
import threading
import time
import queue
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
import gradio as gr
import logging
from rnn_model import RNNSeqRegressorHub, D_FEATURES
import spaces

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Global variables for model and tokenizer only
model = None
tokenizer = None

# Global variables for RNN model components
rnn_model = None
rnn_gru = None
rnn_head = None
rnn_device = None

# Global state variables for </think> detection (baseline only)
baseline_think_tag_detected = False
baseline_pre_think_content = ""
baseline_post_think_content = ""
baseline_progress_frozen = False

# Global variable for monotonic progress tracking (prevents the progress bar from going down)
baseline_max_progress = 0.0

# Model loading status
model_loaded_successfully = False
model_loading_error = None
def load_model(model_name, torch_dtype="float16"):
    """Load model and tokenizer"""
    global model, tokenizer
    logging.info(f"Loading tokenizer from: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        logging.info("Set pad_token to eos_token.")
    logging.info(f"Loading model from: {model_name}")
    model_dtype = getattr(torch, torch_dtype)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=model_dtype,
        low_cpu_mem_usage=True
    )
    logging.info("Model loaded successfully")
    return "Model loaded successfully"
def load_rnn_model(rnn_path="rnn_seq_regressor.pt"):
    """Load RNN model for progress prediction"""
    global model, rnn_model, rnn_gru, rnn_head, rnn_device
    if model is None:
        return "Please load the main model first"
    try:
        # Load the pretrained GRU progress regressor from the Hugging Face Hub
        rnn_repo_id = "royeis/DeepSeek-R1-Distill-Qwen-32B-GRUThinkingProgressRegressor"
        rnn_model = RNNSeqRegressorHub.from_pretrained(rnn_repo_id).to(model.device)
        # Extract components for optimized access
        rnn_gru = rnn_model.rnn
        rnn_head = rnn_model.head
        # Store RNN device for later use
        rnn_device = next(rnn_model.parameters()).device
        logging.info(f"RNN model loaded successfully from {rnn_repo_id}")
        logging.info(f"RNN model device: {rnn_device}")
        return f"RNN model loaded successfully from {rnn_repo_id}"
    except Exception as e:
        logging.error(f"Error loading RNN model: {str(e)}")
        return f"Error loading RNN model: {str(e)}"
def reset_progress_tracking():
    """Reset progress tracking variables for monotonic progress enforcement"""
    global baseline_max_progress
    baseline_max_progress = 0.0
    logging.info("Progress tracking variables reset for new generation")


def load_model_and_rnn():
    """Load both model and RNN model with hardcoded defaults"""
    model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
    rnn_path = "rnn_seq_regressor.pt"
    # Load model first
    model_status = load_model(model_name)
    if "successfully" not in model_status:
        return f"Model loading failed: {model_status}"
    # Then load RNN model
    rnn_status = load_rnn_model(rnn_path)
    if "successfully" not in rnn_status:
        return f"RNN model loading failed: {rnn_status}"
    return "Model and RNN model loaded successfully"
def generate_baseline_only(
    prompt,
    max_new_tokens=1024,
    baseline_progress_callback=None,
    baseline_tokens_callback=None,
    stop_event=None
):
    """Generates tokens with baseline only (no intervention)."""
    global model, tokenizer
    if model is None or tokenizer is None:
        return "Please load a model and tokenizer first", 0
    # Always add thinking tag if not already present
    if "<think>" not in prompt:
        prompt = prompt + "\nPlease reason step by step, and put your final answer within \\boxed{}. \n<think>\n"
    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask if hasattr(inputs, 'attention_mask') else torch.ones_like(input_ids)

    # Progress tracking
    baseline_progress_values = []
    baseline_smoothed_progress_values = [0]
    beta = 0.05
    baseline_generated_tokens = []
    baseline_token_ids = []
    # Track current progress for real-time monitoring
    baseline_last_reported_progress = 0
    # EOS tracking
    baseline_finished = False
    hook_handles = []
    # RNN state for baseline progress tracking
    baseline_rnn_state = {}
    def baseline_forward_hook(module, input_args, output):
        """Hook that runs on every forward pass through the targeted layer for baseline only"""
        nonlocal baseline_last_reported_progress
        original_output = output
        hidden_states_to_check = None
        if isinstance(output, tuple):
            hidden_states_to_check = output[0]
        elif isinstance(output, torch.Tensor):
            hidden_states_to_check = output
        else:
            logging.warning(f"Unexpected output type from hooked layer: {type(output)}")
            return output
        if hidden_states_to_check is None or not isinstance(hidden_states_to_check, torch.Tensor):
            logging.warning("Hooked layer output does not contain a tensor of hidden states as expected.")
            return output

        # Calculate predicted progress for baseline using RNN
        baseline_p_value = 0.0
        if rnn_gru is not None and rnn_head is not None and rnn_device is not None:
            try:
                with torch.no_grad():
                    # Convert to float and move to RNN device for processing
                    hidden_f = hidden_states_to_check[:, -1, :].clone().float().to(rnn_device)
                    # Initialize RNN state if not exists
                    if len(baseline_rnn_state) == 0:
                        baseline_rnn_state['state'] = hidden_f * 0.0
                    # RNN forward pass
                    out_t, baseline_rnn_state['state'] = rnn_gru(hidden_f, baseline_rnn_state['state'])
                    baseline_p_value = rnn_head(out_t).squeeze(-1)
                    baseline_p_value = max(0.0, min(1.0, baseline_p_value.item()))
            except Exception as e:
                logging.warning(f"RNN computation failed: {str(e)}")
                baseline_p_value = 0.0

        # Store the progress values
        baseline_progress_values.append(baseline_p_value)
        # Store smoothed progress for real-time updates
        baseline_smoothed_p_value = beta * baseline_p_value + (1 - beta) * baseline_smoothed_progress_values[-1]
        # Enforce monotonic progress (prevent going down)
        global baseline_max_progress
        baseline_smoothed_p_value = max(baseline_smoothed_p_value, baseline_max_progress)
        baseline_max_progress = baseline_smoothed_p_value
        baseline_smoothed_progress_values.append(baseline_smoothed_p_value)

        # Update progress callbacks
        global baseline_progress_frozen
        if baseline_progress_callback and not baseline_progress_frozen and abs(baseline_smoothed_p_value - baseline_last_reported_progress) > 0.001:
            baseline_progress_callback(int(baseline_smoothed_p_value * 100))
            baseline_last_reported_progress = baseline_smoothed_p_value
        elif baseline_progress_callback and baseline_progress_frozen:
            baseline_progress_callback(100.0)
        return original_output
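    # The smoothing above is an exponential moving average: with beta = 0.05, a raw RNN
    # estimate of 0.80 arriving while the previous smoothed value is 0.40 gives
    # 0.05 * 0.80 + 0.95 * 0.40 = 0.42, so the bar drifts toward the raw estimate instead
    # of jumping, and the max() clamp keeps it from ever moving backwards.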
    # Register hook for progress tracking
    if rnn_gru is not None and rnn_head is not None:
        try:
            target_module = model.model.norm
            hook_handles.append(target_module.register_forward_hook(baseline_forward_hook))
            logging.info("Baseline progress hook registered successfully")
        except AttributeError:
            logging.warning("Could not find model.model.norm. Trying to find another appropriate layer...")
            try:
                for name, module in model.named_modules():
                    if 'norm' in name and 'model' in name and isinstance(module, torch.nn.Module):
                        hook_handles.append(module.register_forward_hook(baseline_forward_hook))
                        logging.info(f"Baseline progress hook registered on {name}")
                        break
                else:
                    logging.warning("Could not find appropriate normalization layer for progress hook.")
            except Exception as e:
                logging.error(f"Error setting up progress hook: {str(e)}")
    try:
        # Initialize KV cache and input preparation
        past_key_values = DynamicCache()
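        # Two-phase decoding: the forward pass below runs over the whole prompt once and
        # fills the KV cache (prefill); every later step feeds only the newest token plus
        # the cached key/value states, so the prompt is never re-encoded.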
        # Initial forward pass
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True
            )

        # Update cache with prompt's key-values
        past_key_values = outputs.past_key_values

        # Get logits and sample first token
        logits = outputs.logits[:, -1, :]  # Last token's logits
        next_token_id = torch.argmax(logits, dim=-1).unsqueeze(-1)  # Greedy sampling

        # Process token
        baseline_token_id = next_token_id[0].item()
        baseline_token_ids.append(baseline_token_id)
        baseline_new_token_text = tokenizer.decode([baseline_token_id])
        baseline_generated_tokens.append(baseline_new_token_text)

        # Update callback with the first token
        if baseline_tokens_callback:
            baseline_current_text = tokenizer.decode(baseline_token_ids, skip_special_tokens=True)
            baseline_tokens_callback(baseline_current_text, len(baseline_token_ids))

        # Check for EOS
        if baseline_token_id == tokenizer.eos_token_id:
            baseline_finished = True
            baseline_result = tokenizer.decode(baseline_token_ids, skip_special_tokens=True)
            baseline_total_tokens = len(baseline_token_ids)
            if baseline_progress_callback:
                baseline_progress_callback(100.0)
            logging.info(f"Baseline generation completed early with EOS. Total tokens: {baseline_total_tokens}")
            return baseline_result, baseline_total_tokens
        # Continue generation loop with cached states
        for step in range(1, max_new_tokens):
            # Check if we should stop
            if stop_event and stop_event.is_set():
                logging.info(f"Generation stopped by user at step {step}")
                break
            # If finished, break
            if baseline_finished:
                logging.info(f"Baseline generation completed at step {step}")
                break

            # Prepare input token for next step
            current_input_id = next_token_id  # [1, 1]
            # Extend attention mask for the new token
            if attention_mask is not None:
                new_attention = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)
                attention_mask = torch.cat([attention_mask, new_attention], dim=1)

            # Forward pass with cached key-values
            with torch.no_grad():
                outputs = model(
                    input_ids=current_input_id,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                    return_dict=True
                )

            # Update cache
            past_key_values = outputs.past_key_values

            # Get next token
            logits = outputs.logits[:, -1, :]
            next_token_id = torch.argmax(logits, dim=-1).unsqueeze(-1)

            # Process token
            baseline_token_id = next_token_id[0].item()
            baseline_token_ids.append(baseline_token_id)
            baseline_new_token_text = tokenizer.decode([baseline_token_id])
            baseline_generated_tokens.append(baseline_new_token_text)

            # Update callback
            if baseline_tokens_callback:
                baseline_current_text = tokenizer.decode(baseline_token_ids, skip_special_tokens=True)
                baseline_tokens_callback(baseline_current_text, len(baseline_token_ids))

            # Check if EOS token was generated
            if baseline_token_id == tokenizer.eos_token_id:
                baseline_finished = True
                logging.info(f"Baseline generation completed with EOS at step {step}. Total tokens: {len(baseline_token_ids)}")

        # Final decoding of all tokens
        baseline_result = tokenizer.decode(baseline_token_ids, skip_special_tokens=True)
        baseline_total_tokens = len(baseline_token_ids)

        # Ensure final progress is set to 100%
        if baseline_progress_callback:
            baseline_progress_callback(100.0)
        logging.info(f"Baseline generation completed. Total tokens: {baseline_total_tokens}")
    finally:
        for handle in hook_handles:
            handle.remove()
        if hook_handles:
            logging.info("All progress hooks removed")
    return baseline_result, baseline_total_tokens
# Automatically load model and RNN on startup
def initialize_app():
    """Initialize the app by loading model and RNN on startup"""
    global model_loaded_successfully, model_loading_error
    logging.info("Starting automatic model loading...")
    try:
        result = load_model_and_rnn()
        if "successfully" in result:
            model_loaded_successfully = True
            logging.info("Model and RNN model loaded successfully on startup")
        else:
            model_loaded_successfully = False
            model_loading_error = result
            logging.error(f"Failed to load model on startup: {result}")
    except Exception as e:
        model_loaded_successfully = False
        model_loading_error = str(e)
        logging.error(f"Exception during model loading: {str(e)}")


# Load model automatically when script starts
initialize_app()
# Global function for resetting UI state
def reset_ui():
    """Reset the UI elements for a new generation"""
    global baseline_think_tag_detected, baseline_progress_frozen
    global baseline_pre_think_content, baseline_post_think_content
    # Reset progress tracking for monotonic behavior
    reset_progress_tracking()
    baseline_think_tag_detected = False
    baseline_progress_frozen = False
    baseline_pre_think_content = ""
    baseline_post_think_content = ""
    return {
        "status": "**Starting generation...**",
        "progress": 0,
        "thinking": "",
        "answer": "",
        "tokens": "",
        "generate_btn_text": "Generating...",
        "generate_btn_interactive": False,
        "stop_btn_interactive": True
    }
def generate_with_updates(prompt, baseline_progress_queue, baseline_tokens_queue, stop_generation):
    """Wrapper around generation function that handles real-time updates"""
    # Check if model is loaded
    if not model_loaded_successfully:
        yield {
            "status": f"**Cannot generate: {model_loading_error}**"
        }
        return

    # Use default values
    max_tokens = 2048

    # Reset UI first
    yield reset_ui()

    # Start generation in a separate thread to allow for UI updates
    baseline_result = ""
    baseline_token_count = 0
    generation_error = None
    generation_thread = None
    def baseline_progress_updater(prog_value):
        """Update the baseline progress via the queue"""
        baseline_progress_queue.put(prog_value)

    def baseline_tokens_updater(text, token_count):
        """Update the baseline generated text via the queue"""
        global baseline_think_tag_detected, baseline_progress_frozen, baseline_pre_think_content, baseline_post_think_content
        # Check if </think> tag appears in the text
        if not baseline_think_tag_detected and "</think>" in text:
            baseline_think_tag_detected = True
            baseline_progress_frozen = True
            # Split content at </think>
            parts = text.split("</think>", 1)
            baseline_pre_think_content = parts[0] + "</think>"
            baseline_post_think_content = parts[1] if len(parts) > 1 else ""
            # Signal content split with token count
            baseline_tokens_queue.put(("THINK_TAG_DETECTED", baseline_pre_think_content, baseline_post_think_content, token_count))
        elif baseline_think_tag_detected:
            # Update post-think content
            if "</think>" in text:
                parts = text.split("</think>", 1)
                baseline_post_think_content = parts[1] if len(parts) > 1 else ""
                baseline_tokens_queue.put(("POST_THINK_UPDATE", baseline_post_think_content))
            else:
                baseline_tokens_queue.put(("NORMAL_UPDATE", text))
        else:
            # Normal pre-think streaming with token count
            baseline_tokens_queue.put(("NORMAL_UPDATE", text, token_count))
    def run_generation():
        nonlocal baseline_result, baseline_token_count, generation_error
        try:
            # Baseline-only generation
            baseline_result, baseline_token_count = generate_baseline_only(
                prompt=prompt,
                max_new_tokens=max_tokens,
                baseline_progress_callback=baseline_progress_updater,
                baseline_tokens_callback=baseline_tokens_updater,
                stop_event=stop_generation
            )
        except Exception as e:
            generation_error = str(e)

    # Start the generation thread
    generation_thread = threading.Thread(target=run_generation)
    generation_thread.start()
    # Monitor queues for updates while generation is running
    baseline_current_text = ""
    baseline_thinking_tokens = 0
    baseline_last_progress = 0
    while generation_thread.is_alive() or not baseline_tokens_queue.empty() or not baseline_progress_queue.empty():
        updates = {}

        # Check baseline tokens queue
        try:
            while not baseline_tokens_queue.empty():
                token_update = baseline_tokens_queue.get_nowait()
                if isinstance(token_update, tuple):
                    update_type = token_update[0]
                    if update_type == "THINK_TAG_DETECTED":
                        # </think> tag detected - split content
                        pre_content = token_update[1]
                        post_content = token_update[2]
                        thinking_token_count = token_update[3]
                        updates["thinking"] = pre_content
                        updates["answer"] = post_content
                        updates["progress"] = 100.0  # Freeze at 100%
                        # Use actual token count (before </think>)
                        baseline_thinking_tokens = thinking_token_count
                        updates["tokens"] = f"{baseline_thinking_tokens}"
                    elif update_type == "POST_THINK_UPDATE":
                        # Update only the final answer
                        post_content = token_update[1]
                        updates["answer"] = post_content
                        # Don't update token count - frozen at thinking tokens
                    elif update_type == "NORMAL_UPDATE":
                        # Normal text update
                        baseline_current_text = token_update[1]
                        if not baseline_think_tag_detected:
                            updates["thinking"] = baseline_current_text
                            # Update thinking token count with actual token count if available
                            if len(token_update) > 2:
                                baseline_thinking_tokens = token_update[2]
                            else:
                                # Fallback to word count for backward compatibility
                                baseline_thinking_tokens = len(baseline_current_text.split())
                            updates["tokens"] = f"{baseline_thinking_tokens}"
                        else:
                            # This shouldn't happen, but handle it gracefully
                            updates["answer"] = baseline_current_text
                else:
                    # Backward compatibility - treat as normal text
                    baseline_current_text = token_update
                    updates["thinking"] = baseline_current_text
                    if not baseline_think_tag_detected:
                        baseline_thinking_tokens = len(baseline_current_text.split())
                        updates["tokens"] = f"{baseline_thinking_tokens}"
        except queue.Empty:
            pass

        # Check baseline progress queue
        try:
            while not baseline_progress_queue.empty():
                baseline_last_progress = baseline_progress_queue.get_nowait()
                updates["progress"] = baseline_last_progress
        except queue.Empty:
            pass

        # If there are any updates, yield them
        if updates:
            yield updates

        # Sleep briefly to prevent excessive CPU usage
        time.sleep(0.05)
    # Final update
    final_updates = {
        "status": "**Generation complete!**" if not generation_error else f"**Error: {generation_error}**",
        "progress": 100,
        "generate_btn_text": "Generate",
        "generate_btn_interactive": True,
        "stop_btn_interactive": True
    }
    if not generation_error:
        # Handle baseline final display
        if baseline_think_tag_detected:
            # Split result for final display
            if "</think>" in baseline_result:
                parts = baseline_result.split("</think>", 1)
                final_updates["thinking"] = parts[0] + "</think>"
                final_updates["answer"] = parts[1] if len(parts) > 1 else ""
                # Use actual token count from generation
                if baseline_thinking_tokens > 0:
                    final_updates["tokens"] = f"{baseline_thinking_tokens}"
                else:
                    # Fallback: use actual token count for thinking part
                    thinking_text = parts[0] + "</think>"
                    thinking_token_count = len(tokenizer.encode(thinking_text, add_special_tokens=False))
                    final_updates["tokens"] = f"{thinking_token_count}"
            else:
                final_updates["thinking"] = baseline_result
                # Use actual token count
                if baseline_thinking_tokens > 0:
                    final_updates["tokens"] = f"{baseline_thinking_tokens}"
                else:
                    total_token_count = len(tokenizer.encode(baseline_result, add_special_tokens=False))
                    final_updates["tokens"] = f"{total_token_count}"
        else:
            final_updates["thinking"] = baseline_result
            # Use actual token count
            if baseline_thinking_tokens > 0:
                final_updates["tokens"] = f"{baseline_thinking_tokens}"
            else:
                total_token_count = len(tokenizer.encode(baseline_result, add_special_tokens=False))
                final_updates["tokens"] = f"{total_token_count}"
    yield final_updates
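# Producer/consumer layout: generate_baseline_only runs in a worker thread and pushes
# progress values and token updates onto the two queues, while generate_with_updates
# polls those queues and yields plain-dict updates that generate_wrapper (defined in
# create_interface below) maps onto the actual Gradio components.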
# Create the Gradio interface
def create_interface():
    # Create custom theme with light green progress bars
    custom_theme = gr.themes.Base().set(
        slider_color="#90EE90",       # Light green
        slider_color_dark="#32CD32",  # Lime green for dark mode
        # Additional slider styling
        color_accent="#90EE90",
        color_accent_soft="#E8F5E8"
    )
    # Custom CSS to fix text box heights, enable scrolling, and enlarge fonts
    custom_css = """
    /* Base font size increases for all text areas */
    .gr-textbox textarea {
        font-size: 18px !important;
        line-height: 1.5em !important;
        overflow-y: auto !important;
        resize: none !important;
    }
    /* Prompt textbox - larger font */
    .gr-textbox:has(textarea[placeholder*="math"]) textarea,
    .gr-textbox:has(textarea[placeholder*="Enter"]) textarea {
        font-size: 20px !important;
        line-height: 1.6em !important;
    }
    /* Thinking Process and Final Answer boxes - enhanced readability */
    .fixed-height-textbox textarea {
        font-size: 16px !important;
        line-height: 1.5em !important;
        overflow-y: auto !important;
        resize: none !important;
    }
    /* Thinking Process boxes - adjusted height for larger fonts */
    .fixed-height-textbox:has(textarea[data-testid*="textbox"]) textarea {
        height: 280px !important;
        max-height: 280px !important;
        min-height: 280px !important;
    }
    /* Progress bar labels and sliders */
    .gr-slider .gr-label {
        font-size: 16px !important;
        font-weight: 500 !important;
    }
    /* All buttons - larger and more readable */
    .gr-button {
        font-size: 16px !important;
        font-weight: 500 !important;
        padding: 8px 16px !important;
    }
    /* Token count displays - readable but not overwhelming */
    .gr-textbox[data-testid*="textbox"][readonly] textarea,
    .gr-textbox.gr-readonly textarea {
        font-size: 14px !important;
        line-height: 1.4em !important;
        text-align: center !important;
        font-weight: 500 !important;
    }
    /* All labels - consistent sizing */
    .gr-label {
        font-size: 15px !important;
        font-weight: 500 !important;
    }
    /* Status messages and markdown - larger and clearer */
    .gr-markdown {
        font-size: 16px !important;
        line-height: 1.4em !important;
    }
    /* Checkbox labels */
    .gr-checkbox .gr-label {
        font-size: 15px !important;
    }
    /* Headers and titles */
    .gr-markdown h1 {
        font-size: 28px !important;
    }
    .gr-markdown h2 {
        font-size: 22px !important;
    }
    /* Progress percentage display */
    .gr-slider .gr-number {
        font-size: 14px !important;
        font-weight: 500 !important;
    }
    /* Custom light green gradient for progress bars */
    .gr-slider input[type="range"]::-webkit-slider-runnable-track {
        background: linear-gradient(90deg, #90EE90 0%, #32CD32 100%) !important;
        border: none !important;
        height: 8px !important;
        border-radius: 4px !important;
    }
    .gr-slider input[type="range"]::-moz-range-track {
        background: linear-gradient(90deg, #90EE90 0%, #32CD32 100%) !important;
        border: none !important;
        height: 8px !important;
        border-radius: 4px !important;
    }
    /* Webkit slider thumb styling */
    .gr-slider input[type="range"]::-webkit-slider-thumb {
        background: #228B22 !important;
        border: 2px solid #ffffff !important;
        border-radius: 50% !important;
        cursor: pointer !important;
    }
    /* Firefox slider thumb styling */
    .gr-slider input[type="range"]::-moz-range-thumb {
        background: #228B22 !important;
        border: 2px solid #ffffff !important;
        border-radius: 50% !important;
        cursor: pointer !important;
    }
    /* Progress bar container background */
    .gr-slider {
        --slider-color: linear-gradient(90deg, #90EE90 0%, #32CD32 100%) !important;
    }
    /* Additional Gradio slider styling for better compatibility */
    .gr-slider .gr-range {
        background: linear-gradient(90deg, #90EE90 0%, #32CD32 100%) !important;
    }
    /* Style the slider track background (unfilled portion) */
    .gr-slider input[type="range"] {
        background: transparent !important;
    }
    .gr-slider input[type="range"]::-webkit-slider-runnable-track {
        background: linear-gradient(90deg, #E8E8E8 0%, #E8E8E8 var(--slider-percent, 0%), #90EE90 var(--slider-percent, 0%), #32CD32 100%) !important;
    }
    """
    with gr.Blocks(title="Baseline Generation Demo", css=custom_css, theme=custom_theme) as demo:
        # Display model loading status in the header
        if model_loaded_successfully:
            status_message = "✅ **DeepSeek-R1-Distill-Qwen-32B model loaded and ready for generation**"
        else:
            status_message = f"❌ **Model loading failed: {model_loading_error}**"

        gr.Markdown(f"""
# 🚀 Reasoning Loading Bar Space ⏳

{status_message}

This Space demonstrates real-time progress tracking of a reasoning model.

## How it works:
1. Enter a prompt below - it works best with math problems that require reasoning
2. Click "Generate" to start generation with progress visualization
        """)

        # Generation section
        with gr.Row():
            with gr.Column(scale=8):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your math problem or reasoning task here...",
                    lines=3
                )
            with gr.Column(scale=1):
                generate_btn = gr.Button(
                    "Generate",
                    variant="primary" if model_loaded_successfully else "secondary",
                    interactive=model_loaded_successfully
                )
                stop_btn = gr.Button("Stop", variant="stop", interactive=model_loaded_successfully)

        # Generation progress and results
        with gr.Row():
            generation_status = gr.Markdown("**Ready to generate**" if model_loaded_successfully else "**Model not loaded - cannot generate**")

        # Single progress section for baseline only
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    baseline_progress_bar = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=0,
                        label="Generation Progress (%)",
                        interactive=False,
                        scale=5
                    )
                    baseline_tokens_count = gr.Textbox(label="Thinking tokens", value="", interactive=False, scale=1)
                baseline_thinking_output = gr.Textbox(label="🧠 Thinking Process", lines=10, value="", elem_classes=["fixed-height-textbox"])
                baseline_answer_output = gr.Textbox(label="✅ Final Answer", lines=4, value="", elem_classes=["fixed-height-textbox"])

        # Create queues to store progress and token updates
        baseline_progress_queue = queue.Queue()
        baseline_tokens_queue = queue.Queue()
        stop_generation = threading.Event()
        def stop_generation_fn():
            """Stop the generation process"""
            stop_generation.set()
            return "Generation stopped"
        def generate_wrapper(prompt):
            """Wrapper to adapt the global generate_with_updates function for Gradio"""
            # Clear any stop request left over from a previous run so a new
            # generation does not abort immediately.
            stop_generation.clear()
            # Process updates from the global function and map to UI components
            for update_dict in generate_with_updates(prompt, baseline_progress_queue, baseline_tokens_queue, stop_generation):
                gradio_updates = {}
                # Map the string keys to actual Gradio components
                if "status" in update_dict:
                    gradio_updates[generation_status] = update_dict["status"]
                if "progress" in update_dict:
                    gradio_updates[baseline_progress_bar] = update_dict["progress"]
                if "thinking" in update_dict:
                    gradio_updates[baseline_thinking_output] = update_dict["thinking"]
                if "answer" in update_dict:
                    gradio_updates[baseline_answer_output] = update_dict["answer"]
                if "tokens" in update_dict:
                    gradio_updates[baseline_tokens_count] = update_dict["tokens"]
                if "generate_btn_text" in update_dict:
                    gradio_updates[generate_btn] = gr.Button(
                        update_dict["generate_btn_text"],
                        variant="secondary" if "Generating" in update_dict["generate_btn_text"] else "primary",
                        interactive=update_dict.get("generate_btn_interactive", True)
                    )
                if "stop_btn_interactive" in update_dict:
                    gradio_updates[stop_btn] = gr.Button(
                        "Stop",
                        variant="stop",
                        interactive=update_dict["stop_btn_interactive"]
                    )
                yield gradio_updates
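        # generate_wrapper yields dictionaries keyed by component objects; Gradio applies
        # each yielded dict to the components listed in `outputs` below, which is what
        # streams the progress bar, thinking text, and answer updates during generation.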
        # Connect the buttons to the handlers
        if model_loaded_successfully:
            generate_btn.click(
                generate_wrapper,
                inputs=[prompt],
                outputs=[
                    generation_status,
                    baseline_progress_bar,
                    baseline_thinking_output,
                    baseline_answer_output,
                    baseline_tokens_count,
                    generate_btn,
                    stop_btn
                ]
            )
            stop_btn.click(
                stop_generation_fn,
                outputs=[generation_status]
            )

    return demo
# Launch the app if running directly
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)
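# Runtime dependencies assumed from the imports above: torch, transformers, gradio, the
# `spaces` package (typically used for ZeroGPU hardware on Hugging Face Spaces; imported
# here but no @spaces.GPU decorator is applied), and the local rnn_model.py providing
# RNNSeqRegressorHub and D_FEATURES.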