import torch
import threading
import time
import queue
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
import gradio as gr
import logging
from rnn_model import RNNSeqRegressorHub, D_FEATURES
import spaces

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Global variables for model and tokenizer only
model = None
tokenizer = None

# Global variables for RNN model components
rnn_model = None
rnn_gru = None
rnn_head = None
rnn_device = None

# Global state variables for </think> detection (baseline only)
baseline_think_tag_detected = False
baseline_pre_think_content = ""
baseline_post_think_content = ""
baseline_progress_frozen = False

# Global variables for monotonic progress tracking (prevent progress bars from going down)
baseline_max_progress = 0.0

# Model loading status
model_loaded_successfully = False
model_loading_error = None


def load_model(model_name, torch_dtype="float16"):
    """Load model and tokenizer"""
    global model, tokenizer

    logging.info(f"Loading tokenizer from: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        logging.info("Set pad_token to eos_token.")

    logging.info(f"Loading model from: {model_name}")
    model_dtype = getattr(torch, torch_dtype)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=model_dtype,
        low_cpu_mem_usage=True
    )
    logging.info("Model loaded successfully")
    return "Model loaded successfully"


def load_rnn_model(rnn_path="rnn_seq_regressor.pt"):
    """Load RNN model for progress prediction"""
    global model, rnn_model, rnn_gru, rnn_head, rnn_device

    if model is None:
        return "Please load the main model first"

    try:
        # Initialize RNN model with D_FEATURES hidden dimension
        rnn_repo_id = "royeis/DeepSeek-R1-Distill-Qwen-32B-GRUThinkingProgressRegressor"
        rnn_model = RNNSeqRegressorHub.from_pretrained(rnn_repo_id).to(model.device)

        # Extract components for optimized access
        rnn_gru = rnn_model.rnn
        rnn_head = rnn_model.head

        # Store RNN device for later use
        rnn_device = next(rnn_model.parameters()).device

        logging.info(f"RNN model loaded successfully from {rnn_repo_id}")
        logging.info(f"RNN model device: {rnn_device}")
        return f"RNN model loaded successfully from {rnn_repo_id}"
    except Exception as e:
        logging.error(f"Error loading RNN model: {str(e)}")
        return f"Error loading RNN model: {str(e)}"


def reset_progress_tracking():
    """Reset progress tracking variables for monotonic progress enforcement"""
    global baseline_max_progress
    baseline_max_progress = 0.0
    logging.info("Progress tracking variables reset for new generation")


@spaces.GPU(duration=120)
def load_model_and_rnn():
    """Load both model and RNN model with hardcoded defaults"""
    model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
    rnn_path = "rnn_seq_regressor.pt"

    # Load model first
    model_status = load_model(model_name)
    if "successfully" not in model_status:
        return f"Model loading failed: {model_status}"

    # Then load RNN model
    rnn_status = load_rnn_model(rnn_path)
    if "successfully" not in rnn_status:
        return f"RNN model loading failed: {rnn_status}"

    return "Model and RNN model loaded successfully"
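
# Assumed interface of the progress regressor (inferred from how rnn_gru / rnn_head are
# used in the forward hook inside generate_baseline_only; the actual definitions live in
# rnn_model.py): the regressor steps a recurrent cell over one hidden-state vector per
# generated token and maps it to a scalar "thinking progress" estimate, roughly:
#
#   out_t, state = rnn_gru(hidden_t, state)   # hidden_t: [1, D_FEATURES]
#   progress_t = rnn_head(out_t)              # clamped to [0, 1] by the caller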

def generate_baseline_only(
    prompt,
    max_new_tokens=1024,
    baseline_progress_callback=None,
    baseline_tokens_callback=None,
    stop_event=None
):
    """Generates tokens with baseline only (no intervention)."""
    global model, tokenizer

    if model is None or tokenizer is None:
        return "Please load a model and tokenizer first", 0

    # Always add <think> tag if not already present
    if "<think>" not in prompt:
        prompt = prompt + "\nPlease reason step by step, and put your final answer within \\boxed{{}}. \n\n<think>"

    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask if hasattr(inputs, 'attention_mask') else torch.ones_like(input_ids)

    # Progress tracking
    baseline_progress_values = []
    baseline_smoothed_progress_values = [0]
    beta = 0.05
    baseline_generated_tokens = []
    baseline_token_ids = []

    # Track current progress for real-time monitoring
    baseline_last_reported_progress = 0

    # EOS tracking
    baseline_finished = False

    hook_handles = []

    # RNN state for baseline progress tracking
    baseline_rnn_state = {}

    def baseline_forward_hook(module, input_args, output):
        """Hook that runs on every forward pass through the targeted layer for baseline only"""
        nonlocal baseline_last_reported_progress

        original_output = output
        hidden_states_to_check = None
        if isinstance(output, tuple):
            hidden_states_to_check = output[0]
        elif isinstance(output, torch.Tensor):
            hidden_states_to_check = output
        else:
            logging.warning(f"Unexpected output type from hooked layer: {type(output)}")
            return output

        if hidden_states_to_check is None or not isinstance(hidden_states_to_check, torch.Tensor):
            logging.warning("Hooked layer output does not contain a tensor of hidden states as expected.")
            return output

        # Calculate predicted progress for baseline using RNN
        baseline_p_value = 0.0
        if rnn_gru is not None and rnn_head is not None and rnn_device is not None:
            try:
                with torch.no_grad():
                    # Convert to float and move to RNN device for processing
                    hidden_f = hidden_states_to_check[:, -1, :].clone().float().to(rnn_device)

                    # Initialize RNN state if not exists
                    if len(baseline_rnn_state) == 0:
                        baseline_rnn_state['state'] = hidden_f * 0.0

                    # RNN forward pass
                    out_t, baseline_rnn_state['state'] = rnn_gru(hidden_f, baseline_rnn_state['state'])
                    baseline_p_value = rnn_head(out_t).squeeze(-1)
                    baseline_p_value = max(0.0, min(1.0, baseline_p_value.item()))
            except Exception as e:
                logging.warning(f"RNN computation failed: {str(e)}")
                baseline_p_value = 0.0

        # Store the progress values
        baseline_progress_values.append(baseline_p_value)

        # Store smoothed progress for real-time updates
        baseline_smoothed_p_value = beta * baseline_p_value + (1 - beta) * baseline_smoothed_progress_values[-1]

        # Enforce monotonic progress (prevent going down)
        global baseline_max_progress
        baseline_smoothed_p_value = max(baseline_smoothed_p_value, baseline_max_progress)
        baseline_max_progress = baseline_smoothed_p_value

        baseline_smoothed_progress_values.append(baseline_smoothed_p_value)

        # Update progress callbacks
        global baseline_progress_frozen
        if baseline_progress_callback and not baseline_progress_frozen and abs(baseline_smoothed_p_value - baseline_last_reported_progress) > 0.001:
            baseline_progress_callback(int(baseline_smoothed_p_value * 100))
            baseline_last_reported_progress = baseline_smoothed_p_value
        elif baseline_progress_callback and baseline_progress_frozen:
            baseline_progress_callback(100.0)

        return original_output
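
    # Smoothing recap (illustrative numbers, not taken from the original code): with
    # beta = 0.05 the reported value is an exponential moving average,
    #     smoothed_t = 0.05 * p_t + 0.95 * smoothed_{t-1},
    # so a raw prediction jumping from 0.2 to 0.8 only moves the bar by about 3 points
    # in a single step, and the max() against baseline_max_progress keeps the bar from
    # ever moving backwards within one generation.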

    # Register hook for progress tracking
    if rnn_gru is not None and rnn_head is not None:
        try:
            target_module = model.model.norm
            hook_handles.append(target_module.register_forward_hook(baseline_forward_hook))
            logging.info("Baseline progress hook registered successfully")
        except AttributeError:
            logging.warning("Could not find model.model.norm. Trying to find another appropriate layer...")
            try:
                for name, module in model.named_modules():
                    if 'norm' in name and 'model' in name and isinstance(module, torch.nn.Module):
                        hook_handles.append(module.register_forward_hook(baseline_forward_hook))
                        logging.info(f"Baseline progress hook registered on {name}")
                        break
                else:
                    logging.warning("Could not find appropriate normalization layer for progress hook.")
            except Exception as e:
                logging.error(f"Error setting up progress hook: {str(e)}")

    try:
        # Initialize KV cache and input preparation
        past_key_values = DynamicCache()

        # Initial forward pass
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True
            )

        # Update cache with prompt's key-values
        past_key_values = outputs.past_key_values

        # Get logits and sample first token
        logits = outputs.logits[:, -1, :]  # Last token's logits
        next_token_id = torch.argmax(logits, dim=-1).unsqueeze(-1)  # Greedy sampling

        # Process token
        baseline_token_id = next_token_id[0].item()
        baseline_token_ids.append(baseline_token_id)
        baseline_new_token_text = tokenizer.decode([baseline_token_id])
        baseline_generated_tokens.append(baseline_new_token_text)

        # Update callback with the first token
        if baseline_tokens_callback:
            baseline_current_text = tokenizer.decode(baseline_token_ids, skip_special_tokens=True)
            baseline_tokens_callback(baseline_current_text, len(baseline_token_ids))

        # Check for EOS
        if baseline_token_id == tokenizer.eos_token_id:
            baseline_finished = True
            baseline_result = tokenizer.decode(baseline_token_ids, skip_special_tokens=True)
            baseline_total_tokens = len(baseline_token_ids)
            if baseline_progress_callback:
                baseline_progress_callback(100.0)
            logging.info(f"Baseline generation completed early with EOS. Total tokens: {baseline_total_tokens}")
            return baseline_result, baseline_total_tokens

        # Continue generation loop with cached states
        for step in range(1, max_new_tokens):
            # Check if we should stop
            if stop_event and stop_event.is_set():
                logging.info(f"Generation stopped by user at step {step}")
                break

            # If finished, break
            if baseline_finished:
                logging.info(f"Baseline generation completed at step {step}")
                break

            # Prepare input token for next step
            current_input_id = next_token_id  # [1, 1]

            # Extend attention mask for the new token
            if attention_mask is not None:
                new_attention = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)
                attention_mask = torch.cat([attention_mask, new_attention], dim=1)

            # Forward pass with cached key-values
            with torch.no_grad():
                outputs = model(
                    input_ids=current_input_id,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                    return_dict=True
                )

            # Update cache
            past_key_values = outputs.past_key_values

            # Get next token
            logits = outputs.logits[:, -1, :]
            next_token_id = torch.argmax(logits, dim=-1).unsqueeze(-1)

            # Process token
            baseline_token_id = next_token_id[0].item()
            baseline_token_ids.append(baseline_token_id)
            baseline_new_token_text = tokenizer.decode([baseline_token_id])
            baseline_generated_tokens.append(baseline_new_token_text)

            # Update callback
            if baseline_tokens_callback:
                baseline_current_text = tokenizer.decode(baseline_token_ids, skip_special_tokens=True)
                baseline_tokens_callback(baseline_current_text, len(baseline_token_ids))

            # Check if EOS token was generated
            if baseline_token_id == tokenizer.eos_token_id:
                baseline_finished = True
                logging.info(f"Baseline generation completed with EOS at step {step}. Total tokens: {len(baseline_token_ids)}")

        # Final decoding of all tokens
        baseline_result = tokenizer.decode(baseline_token_ids, skip_special_tokens=True)
        baseline_total_tokens = len(baseline_token_ids)

        # Ensure final progress is set to 100%
        if baseline_progress_callback:
            baseline_progress_callback(100.0)

        logging.info(f"Baseline generation completed. Total tokens: {baseline_total_tokens}")

    finally:
        for handle in hook_handles:
            handle.remove()
        if hook_handles:
            logging.info("All progress hooks removed")

    return baseline_result, baseline_total_tokens


# Automatically load model and RNN on startup
def initialize_app():
    """Initialize the app by loading model and RNN on startup"""
    global model_loaded_successfully, model_loading_error

    logging.info("Starting automatic model loading...")
    try:
        result = load_model_and_rnn()
        if "successfully" in result:
            model_loaded_successfully = True
            logging.info("Model and RNN model loaded successfully on startup")
        else:
            model_loaded_successfully = False
            model_loading_error = result
            logging.error(f"Failed to load model on startup: {result}")
    except Exception as e:
        model_loaded_successfully = False
        model_loading_error = str(e)
        logging.error(f"Exception during model loading: {str(e)}")


# Load model automatically when script starts
initialize_app()


# Global function for resetting UI state
def reset_ui():
    """Reset the UI elements for a new generation"""
    global baseline_think_tag_detected, baseline_progress_frozen
    global baseline_pre_think_content, baseline_post_think_content

    # Reset progress tracking for monotonic behavior
    reset_progress_tracking()

    baseline_think_tag_detected = False
    baseline_progress_frozen = False
    baseline_pre_think_content = ""
    baseline_post_think_content = ""

    return {
        "status": "**Starting generation...**",
        "progress": 0,
        "thinking": "",
        "answer": "",
        "tokens": "",
        "generate_btn_text": "Generating...",
        "generate_btn_interactive": False,
        "stop_btn_interactive": True
    }
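
# Shape of the update dicts yielded by generate_with_updates below (reset_ui above
# returns the same keys); generate_wrapper in create_interface maps these string keys
# onto Gradio components:
#
#   {
#       "status": "**Starting generation...**",  # markdown status line
#       "progress": 0,                           # slider value, 0-100
#       "thinking": "", "answer": "", "tokens": "",
#       "generate_btn_text": "Generating...",
#       "generate_btn_interactive": False,
#       "stop_btn_interactive": True,
#   }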

@spaces.GPU(duration=240)
def generate_with_updates(prompt, baseline_progress_queue, baseline_tokens_queue, stop_generation):
    """Wrapper around generation function that handles real-time updates"""
    # Check if model is loaded
    if not model_loaded_successfully:
        yield {
            "status": f"**Cannot generate: {model_loading_error}**"
        }
        return

    # Use default values
    max_tokens = 2048

    # Reset UI first
    yield reset_ui()

    # Start generation in a separate thread to allow for UI updates
    baseline_result = ""
    baseline_token_count = 0
    generation_error = None
    generation_thread = None

    def baseline_progress_updater(prog_value):
        """Update the baseline progress via the queue"""
        baseline_progress_queue.put(prog_value)

    def baseline_tokens_updater(text, token_count):
        """Update the baseline generated text via the queue"""
        global baseline_think_tag_detected, baseline_progress_frozen, baseline_pre_think_content, baseline_post_think_content

        # Check if </think> tag appears in the text
        if not baseline_think_tag_detected and "</think>" in text:
            baseline_think_tag_detected = True
            baseline_progress_frozen = True

            # Split content at </think>
            parts = text.split("</think>", 1)
            baseline_pre_think_content = parts[0] + "</think>"
            baseline_post_think_content = parts[1] if len(parts) > 1 else ""

            # Signal content split with token count
            baseline_tokens_queue.put(("THINK_TAG_DETECTED", baseline_pre_think_content, baseline_post_think_content, token_count))
        elif baseline_think_tag_detected:
            # Update post-think content
            if "</think>" in text:
                parts = text.split("</think>", 1)
                baseline_post_think_content = parts[1] if len(parts) > 1 else ""
                baseline_tokens_queue.put(("POST_THINK_UPDATE", baseline_post_think_content))
            else:
                baseline_tokens_queue.put(("NORMAL_UPDATE", text))
        else:
            # Normal pre-think streaming with token count
            baseline_tokens_queue.put(("NORMAL_UPDATE", text, token_count))

    def run_generation():
        nonlocal baseline_result, baseline_token_count, generation_error
        try:
            # Baseline-only generation
            baseline_result, baseline_token_count = generate_baseline_only(
                prompt=prompt,
                max_new_tokens=max_tokens,
                baseline_progress_callback=baseline_progress_updater,
                baseline_tokens_callback=baseline_tokens_updater,
                stop_event=stop_generation
            )
        except Exception as e:
            generation_error = str(e)

    # Start the generation thread
    generation_thread = threading.Thread(target=run_generation)
    generation_thread.start()

    # Monitor queues for updates while generation is running
    baseline_current_text = ""
    baseline_thinking_tokens = 0
    baseline_last_progress = 0

    while generation_thread.is_alive() or not baseline_tokens_queue.empty() or not baseline_progress_queue.empty():
        updates = {}

        # Check baseline tokens queue
        try:
            while not baseline_tokens_queue.empty():
                token_update = baseline_tokens_queue.get_nowait()
                if isinstance(token_update, tuple):
                    update_type = token_update[0]
                    if update_type == "THINK_TAG_DETECTED":
                        # </think> tag detected - split content
                        pre_content = token_update[1]
                        post_content = token_update[2]
                        thinking_token_count = token_update[3]
                        updates["thinking"] = pre_content
                        updates["answer"] = post_content
                        updates["progress"] = 100.0  # Freeze at 100%
                        # Use actual token count (before </think>)
                        baseline_thinking_tokens = thinking_token_count
                        updates["tokens"] = f"{baseline_thinking_tokens}"
                    elif update_type == "POST_THINK_UPDATE":
                        # Update only the final answer
                        post_content = token_update[1]
                        updates["answer"] = post_content
                        # Don't update token count - frozen at thinking tokens
                    elif update_type == "NORMAL_UPDATE":
                        # Normal text update
                        baseline_current_text = token_update[1]
                        if not baseline_think_tag_detected:
                            updates["thinking"] = baseline_current_text
                            # Update thinking token count with actual token count if available
                            if len(token_update) > 2:
                                baseline_thinking_tokens = token_update[2]
                            else:
                                # Fallback to word count for backward compatibility
                                baseline_thinking_tokens = len(baseline_current_text.split())
                            updates["tokens"] = f"{baseline_thinking_tokens}"
                        else:
                            # This shouldn't happen, but handle it gracefully
                            updates["answer"] = baseline_current_text
                else:
                    # Backward compatibility - treat as normal text
                    baseline_current_text = token_update
                    updates["thinking"] = baseline_current_text
                    if not baseline_think_tag_detected:
                        baseline_thinking_tokens = len(baseline_current_text.split())
                        updates["tokens"] = f"{baseline_thinking_tokens}"
        except queue.Empty:
            pass

        # Check baseline progress queue
        try:
            while not baseline_progress_queue.empty():
                baseline_last_progress = baseline_progress_queue.get_nowait()
                updates["progress"] = baseline_last_progress
        except queue.Empty:
            pass

        # If there are any updates, yield them
        if updates:
            yield updates

        # Sleep briefly to prevent excessive CPU usage
        time.sleep(0.05)

    # Final update
    final_updates = {
        "status": "**Generation complete!**" if not generation_error else f"**Error: {generation_error}**",
        "progress": 100,
        "generate_btn_text": "Generate",
        "generate_btn_interactive": True,
        "stop_btn_interactive": True
    }

    if not generation_error:
        # Handle baseline final display
        if baseline_think_tag_detected:
            # Split result for final display
            if "</think>" in baseline_result:
                parts = baseline_result.split("</think>", 1)
                final_updates["thinking"] = parts[0] + "</think>"
                final_updates["answer"] = parts[1] if len(parts) > 1 else ""
"" # Use actual token count from generation if baseline_thinking_tokens > 0: final_updates["tokens"] = f"{baseline_thinking_tokens}" else: # Fallback: use actual token count for thinking part thinking_text = parts[0] + "" thinking_token_count = len(tokenizer.encode(thinking_text, add_special_tokens=False)) final_updates["tokens"] = f"{thinking_token_count}" else: final_updates["thinking"] = baseline_result # Use actual token count if baseline_thinking_tokens > 0: final_updates["tokens"] = f"{baseline_thinking_tokens}" else: total_token_count = len(tokenizer.encode(baseline_result, add_special_tokens=False)) final_updates["tokens"] = f"{total_token_count}" else: final_updates["thinking"] = baseline_result # Use actual token count if baseline_thinking_tokens > 0: final_updates["tokens"] = f"{baseline_thinking_tokens}" else: total_token_count = len(tokenizer.encode(baseline_result, add_special_tokens=False)) final_updates["tokens"] = f"{total_token_count}" yield final_updates # Create the Gradio interface def create_interface(): # Create custom theme with light green progress bars custom_theme = gr.themes.Base().set( slider_color="#90EE90", # Light green slider_color_dark="#32CD32", # Lime green for dark mode # Additional slider styling color_accent="#90EE90", color_accent_soft="#E8F5E8" ) # Custom CSS to fix text box heights, enable scrolling, and enlarge fonts custom_css = """ /* Base font size increases for all text areas */ .gr-textbox textarea { font-size: 18px !important; line-height: 1.5em !important; overflow-y: auto !important; resize: none !important; } /* Prompt textbox - larger font */ .gr-textbox:has(textarea[placeholder*="math"]) textarea, .gr-textbox:has(textarea[placeholder*="Enter"]) textarea { font-size: 20px !important; line-height: 1.6em !important; } /* Thinking Process and Final Answer boxes - enhanced readability */ .fixed-height-textbox textarea { font-size: 16px !important; line-height: 1.5em !important; overflow-y: auto !important; resize: none !important; } /* Thinking Process boxes - adjusted height for larger fonts */ .fixed-height-textbox:has(textarea[data-testid*="textbox"]) textarea { height: 280px !important; max-height: 280px !important; min-height: 280px !important; } /* Progress bar labels and sliders */ .gr-slider .gr-label { font-size: 16px !important; font-weight: 500 !important; } /* All buttons - larger and more readable */ .gr-button { font-size: 16px !important; font-weight: 500 !important; padding: 8px 16px !important; } /* Token count displays - readable but not overwhelming */ .gr-textbox[data-testid*="textbox"][readonly] textarea, .gr-textbox.gr-readonly textarea { font-size: 14px !important; line-height: 1.4em !important; text-align: center !important; font-weight: 500 !important; } /* All labels - consistent sizing */ .gr-label { font-size: 15px !important; font-weight: 500 !important; } /* Status messages and markdown - larger and clearer */ .gr-markdown { font-size: 16px !important; line-height: 1.4em !important; } /* Checkbox labels */ .gr-checkbox .gr-label { font-size: 15px !important; } /* Headers and titles */ .gr-markdown h1 { font-size: 28px !important; } .gr-markdown h2 { font-size: 22px !important; } /* Progress percentage display */ .gr-slider .gr-number { font-size: 14px !important; font-weight: 500 !important; } /* Custom light green gradient for progress bars */ .gr-slider input[type="range"]::-webkit-slider-runnable-track { background: linear-gradient(90deg, #90EE90 0%, #32CD32 100%) !important; border: none !important; height: 
    /* Custom light green gradient for progress bars */
    .gr-slider input[type="range"]::-webkit-slider-runnable-track {
        background: linear-gradient(90deg, #90EE90 0%, #32CD32 100%) !important;
        border: none !important;
        height: 8px !important;
        border-radius: 4px !important;
    }
    .gr-slider input[type="range"]::-moz-range-track {
        background: linear-gradient(90deg, #90EE90 0%, #32CD32 100%) !important;
        border: none !important;
        height: 8px !important;
        border-radius: 4px !important;
    }

    /* Webkit slider thumb styling */
    .gr-slider input[type="range"]::-webkit-slider-thumb {
        background: #228B22 !important;
        border: 2px solid #ffffff !important;
        border-radius: 50% !important;
        cursor: pointer !important;
    }

    /* Firefox slider thumb styling */
    .gr-slider input[type="range"]::-moz-range-thumb {
        background: #228B22 !important;
        border: 2px solid #ffffff !important;
        border-radius: 50% !important;
        cursor: pointer !important;
    }

    /* Progress bar container background */
    .gr-slider {
        --slider-color: linear-gradient(90deg, #90EE90 0%, #32CD32 100%) !important;
    }

    /* Additional Gradio slider styling for better compatibility */
    .gr-slider .gr-range {
        background: linear-gradient(90deg, #90EE90 0%, #32CD32 100%) !important;
    }

    /* Style the slider track background (unfilled portion) */
    .gr-slider input[type="range"] {
        background: transparent !important;
    }
    .gr-slider input[type="range"]::-webkit-slider-runnable-track {
        background: linear-gradient(90deg, #E8E8E8 0%, #E8E8E8 var(--slider-percent, 0%), #90EE90 var(--slider-percent, 0%), #32CD32 100%) !important;
    }
    """

    with gr.Blocks(title="Baseline Generation Demo", css=custom_css, theme=custom_theme) as demo:
        # Display model loading status in the header
        if model_loaded_successfully:
            status_message = "✅ **DeepSeek-R1-Distill-Qwen-32B model loaded and ready for generation**"
        else:
            status_message = f"❌ **Model loading failed: {model_loading_error}**"

        gr.Markdown(f"""
        # 🚀 Reasoning Loading Bar Space ⏳

        {status_message}

        This Space demonstrates real-time progress tracking of a reasoning model.

        ## How it works:
        1. Enter a prompt below - it works best with math problems that require reasoning
        2. Click "Generate" to start generation with progress visualization
        """)
Click "Generate" to start generation with progress visualization """) # Generation section with gr.Row(): with gr.Column(scale=8): prompt = gr.Textbox( label="Prompt", placeholder="Enter your math problem or reasoning task here...", lines=3 ) with gr.Column(scale=1): generate_btn = gr.Button( "Generate", variant="primary" if model_loaded_successfully else "secondary", interactive=model_loaded_successfully ) stop_btn = gr.Button("Stop", variant="stop", interactive=model_loaded_successfully) # Generation progress and results with gr.Row(): generation_status = gr.Markdown("**Ready to generate**" if model_loaded_successfully else "**Model not loaded - cannot generate**") # Single progress section for baseline only with gr.Row(): with gr.Column(): with gr.Row(): baseline_progress_bar = gr.Slider( minimum=0, maximum=100, value=0, label="Generation Progress (%)", interactive=False, scale=5 ) baseline_tokens_count = gr.Textbox(label="Thinking tokens", value="", interactive=False, scale=1) baseline_thinking_output = gr.Textbox(label="🧠 Thinking Process", lines=10, value="", elem_classes=["fixed-height-textbox"]) baseline_answer_output = gr.Textbox(label="✅ Final Answer", lines=4, value="", elem_classes=["fixed-height-textbox"]) # Create queues to store progress and token updates baseline_progress_queue = queue.Queue() baseline_tokens_queue = queue.Queue() stop_generation = threading.Event() def stop_generation_fn(): """Stop the generation process""" stop_generation.set() return "Generation stopped" def generate_wrapper(prompt): """Wrapper to adapt the global generate_with_updates function for Gradio""" # Process updates from the global function and map to UI components for update_dict in generate_with_updates(prompt, baseline_progress_queue, baseline_tokens_queue, stop_generation): gradio_updates = {} # Map the string keys to actual Gradio components if "status" in update_dict: gradio_updates[generation_status] = update_dict["status"] if "progress" in update_dict: gradio_updates[baseline_progress_bar] = update_dict["progress"] if "thinking" in update_dict: gradio_updates[baseline_thinking_output] = update_dict["thinking"] if "answer" in update_dict: gradio_updates[baseline_answer_output] = update_dict["answer"] if "tokens" in update_dict: gradio_updates[baseline_tokens_count] = update_dict["tokens"] if "generate_btn_text" in update_dict: gradio_updates[generate_btn] = gr.Button( update_dict["generate_btn_text"], variant="secondary" if "Generating" in update_dict["generate_btn_text"] else "primary", interactive=update_dict.get("generate_btn_interactive", True) ) if "stop_btn_interactive" in update_dict: gradio_updates[stop_btn] = gr.Button( "Stop", variant="stop", interactive=update_dict["stop_btn_interactive"] ) yield gradio_updates # Connect the buttons to the handlers if model_loaded_successfully: generate_btn.click( generate_wrapper, inputs=[prompt], outputs=[ generation_status, baseline_progress_bar, baseline_thinking_output, baseline_answer_output, baseline_tokens_count, generate_btn, stop_btn ] ) stop_btn.click( stop_generation_fn, outputs=[generation_status] ) return demo # Launch the app if running directly if __name__ == "__main__": demo = create_interface() demo.launch(share=True)