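"""Gradio Space demo: streams token-by-token generation from
DeepSeek-R1-Distill-Qwen-32B while a pretrained GRU regressor, attached via a
forward hook on the model's final normalization layer, estimates how far along
the model's reasoning ("thinking") phase is and drives a live progress bar.
Once the </think> tag is emitted, the progress bar is frozen at 100% and the
remaining tokens are streamed into the final-answer box."""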
import torch
import threading
import time
import queue
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
import gradio as gr
import logging
from rnn_model import RNNSeqRegressorHub, D_FEATURES
import spaces
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Global variables for model and tokenizer only
model = None
tokenizer = None
# Global variables for RNN model components
rnn_model = None
rnn_gru = None
rnn_head = None
rnn_device = None
# Global state variables for </think> detection (baseline only)
baseline_think_tag_detected = False
baseline_pre_think_content = ""
baseline_post_think_content = ""
baseline_progress_frozen = False
# Global variables for monotonic progress tracking (prevent progress bars from going down)
baseline_max_progress = 0.0
# Model loading status
model_loaded_successfully = False
model_loading_error = None
def load_model(model_name, torch_dtype="float16"):
"""Load model and tokenizer"""
global model, tokenizer
logging.info(f"Loading tokenizer from: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
logging.info("Set pad_token to eos_token.")
logging.info(f"Loading model from: {model_name}")
model_dtype = getattr(torch, torch_dtype)
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
torch_dtype=model_dtype,
low_cpu_mem_usage=True
)
logging.info(f"Model loaded successfully")
return "Model loaded successfully"
def load_rnn_model(rnn_path="rnn_seq_regressor.pt"):
"""Load RNN model for progress prediction"""
global model, rnn_model, rnn_gru, rnn_head, rnn_device
if model is None:
return "Please load the main model first"
try:
# Load the pretrained GRU progress regressor from the Hugging Face Hub
# (the rnn_path argument is kept for signature compatibility but is unused here)
rnn_repo_id = "royeis/DeepSeek-R1-Distill-Qwen-32B-GRUThinkingProgressRegressor"
rnn_model = RNNSeqRegressorHub.from_pretrained(rnn_repo_id).to(model.device)
# Extract components for optimized access
rnn_gru = rnn_model.rnn
rnn_head = rnn_model.head
# Store RNN device for later use
rnn_device = next(rnn_model.parameters()).device
logging.info(f"RNN model loaded successfully from {rnn_path}")
logging.info(f"RNN model device: {rnn_device}")
return f"RNN model loaded successfully from {rnn_path}"
except Exception as e:
logging.error(f"Error loading RNN model: {str(e)}")
return f"Error loading RNN model: {str(e)}"
def reset_progress_tracking():
"""Reset progress tracking variables for monotonic progress enforcement"""
global baseline_max_progress
baseline_max_progress = 0.0
logging.info("Progress tracking variables reset for new generation")
@spaces.GPU(duration=120)
def load_model_and_rnn():
"""Load both model and RNN model with hardcoded defaults"""
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
rnn_path = "rnn_seq_regressor.pt"
# Load model first
model_status = load_model(model_name)
if "successfully" not in model_status:
return f"Model loading failed: {model_status}"
# Then load RNN model
rnn_status = load_rnn_model(rnn_path)
if "successfully" not in rnn_status:
return f"RNN model loading failed: {rnn_status}"
return "Model and RNN model loaded successfully"
def generate_baseline_only(
prompt,
max_new_tokens=1024,
baseline_progress_callback=None,
baseline_tokens_callback=None,
stop_event=None
):
"""Generates tokens with baseline only (no intervention)."""
global model, tokenizer
if model is None or tokenizer is None:
return "Please load a model and tokenizer first", 0
# Always add thinking tag if not already present
if "<think>" not in prompt:
prompt = prompt + "\nPlease reason step by step, and put your final answer within \\boxed{{}}. \n<think>\n"
# Tokenize input
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask if hasattr(inputs, 'attention_mask') else torch.ones_like(input_ids)
# Progress tracking
baseline_progress_values = []
baseline_smoothed_progress_values = [0]
beta = 0.05
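# Raw per-token predictions are smoothed with an exponential moving average:
# smoothed[t] = beta * raw[t] + (1 - beta) * smoothed[t - 1]
# A small beta (0.05) damps token-to-token noise in the progress signal.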
baseline_generated_tokens = []
baseline_token_ids = []
# Track current progress for real-time monitoring
baseline_last_reported_progress = 0
# EOS tracking
baseline_finished = False
hook_handles = []
# RNN state for baseline progress tracking
baseline_rnn_state = {}
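# The GRU hidden state lives in this dict so the forward hook, which fires once
# per generated token, can carry the recurrent state across invocations.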
def baseline_forward_hook(module, input_args, output):
"""Hook that runs on every forward pass through the targeted layer for baseline only"""
nonlocal baseline_last_reported_progress
original_output = output
hidden_states_to_check = None
if isinstance(output, tuple):
hidden_states_to_check = output[0]
elif isinstance(output, torch.Tensor):
hidden_states_to_check = output
else:
logging.warning(f"Unexpected output type from hooked layer: {type(output)}")
return output
if hidden_states_to_check is None or not isinstance(hidden_states_to_check, torch.Tensor):
logging.warning("Hooked layer output does not contain a tensor of hidden states as expected.")
return output
# Calculate predicted progress for baseline using RNN
baseline_p_value = 0.0
if rnn_gru is not None and rnn_head is not None and rnn_device is not None:
try:
with torch.no_grad():
# Convert to float and move to RNN device for processing
hidden_f = hidden_states_to_check[:, -1, :].clone().float().to(rnn_device)
# Initialize RNN state if not exists
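# The state is created lazily on the first generated token: zero-valued, with
# the same shape, dtype, and device as the extracted hidden features.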
if len(baseline_rnn_state) == 0:
baseline_rnn_state['state'] = hidden_f * 0.0
# RNN forward pass
out_t, baseline_rnn_state['state'] = rnn_gru(hidden_f, baseline_rnn_state['state'])
baseline_p_value = rnn_head(out_t).squeeze(-1)
baseline_p_value = max(0.0, min(1.0, baseline_p_value.item()))
except Exception as e:
logging.warning(f"RNN computation failed: {str(e)}")
baseline_p_value = 0.0
# Store the progress values
baseline_progress_values.append(baseline_p_value)
# Store smoothed progress for real-time updates
baseline_smoothed_p_value = beta * baseline_p_value + (1 - beta) * baseline_smoothed_progress_values[-1]
# Enforce monotonic progress (prevent going down)
global baseline_max_progress
baseline_smoothed_p_value = max(baseline_smoothed_p_value, baseline_max_progress)
baseline_max_progress = baseline_smoothed_p_value
baseline_smoothed_progress_values.append(baseline_smoothed_p_value)
# Update progress callbacks
global baseline_progress_frozen
if baseline_progress_callback and not baseline_progress_frozen and abs(baseline_smoothed_p_value - baseline_last_reported_progress) > 0.001:
baseline_progress_callback(int(baseline_smoothed_p_value * 100))
baseline_last_reported_progress = baseline_smoothed_p_value
elif baseline_progress_callback and baseline_progress_frozen:
baseline_progress_callback(100.0)
return original_output
# Register hook for progress tracking
if rnn_gru is not None and rnn_head is not None:
try:
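# For Qwen2-based checkpoints such as DeepSeek-R1-Distill-Qwen-32B,
# model.model.norm is the final normalization layer before the LM head, so the
# hook observes the last hidden state of every forward pass.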
target_module = model.model.norm
hook_handles.append(target_module.register_forward_hook(baseline_forward_hook))
logging.info(f"Baseline progress hook registered successfully")
except AttributeError:
logging.warning("Could not find model.model.norm. Trying to find another appropriate layer...")
try:
for name, module in model.named_modules():
if 'norm' in name and 'model' in name and isinstance(module, torch.nn.Module):
hook_handles.append(module.register_forward_hook(baseline_forward_hook))
logging.info(f"Baseline progress hook registered on {name}")
break
else:
logging.warning("Could not find appropriate normalization layer for progress hook.")
except Exception as e:
logging.error(f"Error setting up progress hook: {str(e)}")
try:
# Initialize KV cache and input preparation
past_key_values = DynamicCache()
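# The prompt is encoded once here to fill the cache; each later step feeds only
# the newly generated token and reuses the cached key/value states.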
# Initial forward pass
with torch.no_grad():
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
return_dict=True
)
# Update cache with prompt's key-values
past_key_values = outputs.past_key_values
# Get logits and sample first token
logits = outputs.logits[:, -1, :] # Last token's logits
next_token_id = torch.argmax(logits, dim=-1).unsqueeze(-1) # Greedy sampling
# Process token
baseline_token_id = next_token_id[0].item()
baseline_token_ids.append(baseline_token_id)
baseline_new_token_text = tokenizer.decode([baseline_token_id])
baseline_generated_tokens.append(baseline_new_token_text)
# Update callback with the first token
if baseline_tokens_callback:
baseline_current_text = tokenizer.decode(baseline_token_ids, skip_special_tokens=True)
baseline_tokens_callback(baseline_current_text, len(baseline_token_ids))
# Check for EOS
if baseline_token_id == tokenizer.eos_token_id:
baseline_finished = True
baseline_result = tokenizer.decode(baseline_token_ids, skip_special_tokens=True)
baseline_total_tokens = len(baseline_token_ids)
if baseline_progress_callback:
baseline_progress_callback(100.0)
logging.info(f"Baseline generation completed early with EOS. Total tokens: {baseline_total_tokens}")
return baseline_result, baseline_total_tokens
# Continue generation loop with cached states
for step in range(1, max_new_tokens):
# Check if we should stop
if stop_event and stop_event.is_set():
logging.info(f"Generation stopped by user at step {step}")
break
# If finished, break
if baseline_finished:
logging.info(f"Baseline generation completed at step {step}")
break
# Prepare input token for next step
current_input_id = next_token_id # [1, 1]
# Extend attention mask for the new token
if attention_mask is not None:
new_attention = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)
attention_mask = torch.cat([attention_mask, new_attention], dim=1)
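# The mask grows by one column per step so all cached positions remain visible
# to the newly generated token.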
# Forward pass with cached key-values
with torch.no_grad():
outputs = model(
input_ids=current_input_id,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
return_dict=True
)
# Update cache
past_key_values = outputs.past_key_values
# Get next token
logits = outputs.logits[:, -1, :]
next_token_id = torch.argmax(logits, dim=-1).unsqueeze(-1)
# Process token
baseline_token_id = next_token_id[0].item()
baseline_token_ids.append(baseline_token_id)
baseline_new_token_text = tokenizer.decode([baseline_token_id])
baseline_generated_tokens.append(baseline_new_token_text)
# Update callback
if baseline_tokens_callback:
baseline_current_text = tokenizer.decode(baseline_token_ids, skip_special_tokens=True)
baseline_tokens_callback(baseline_current_text, len(baseline_token_ids))
# Check if EOS token was generated
if baseline_token_id == tokenizer.eos_token_id:
baseline_finished = True
logging.info(f"Baseline generation completed with EOS at step {step}. Total tokens: {len(baseline_token_ids)}")
# Final decoding of all tokens
baseline_result = tokenizer.decode(baseline_token_ids, skip_special_tokens=True)
baseline_total_tokens = len(baseline_token_ids)
# Ensure final progress is set to 100%
if baseline_progress_callback:
baseline_progress_callback(100.0)
logging.info(f"Baseline generation completed. Total tokens: {baseline_total_tokens}")
finally:
for handle in hook_handles:
handle.remove()
if hook_handles:
logging.info("All progress hooks removed")
return baseline_result, baseline_total_tokens
# Automatically load model and RNN on startup
def initialize_app():
"""Initialize the app by loading model and RNN on startup"""
global model_loaded_successfully, model_loading_error
logging.info("Starting automatic model loading...")
try:
result = load_model_and_rnn()
if "successfully" in result:
model_loaded_successfully = True
logging.info("Model and RNN model loaded successfully on startup")
else:
model_loaded_successfully = False
model_loading_error = result
logging.error(f"Failed to load model on startup: {result}")
except Exception as e:
model_loaded_successfully = False
model_loading_error = str(e)
logging.error(f"Exception during model loading: {str(e)}")
# Load model automatically when script starts
initialize_app()
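# Loading at import time lets the Gradio UI below reflect the model status
# (status banner, button interactivity) when the interface is built.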
# Global function for resetting UI state
def reset_ui():
"""Reset the UI elements for a new generation"""
global baseline_think_tag_detected, baseline_progress_frozen
global baseline_pre_think_content, baseline_post_think_content
# Reset progress tracking for monotonic behavior
reset_progress_tracking()
baseline_think_tag_detected = False
baseline_progress_frozen = False
baseline_pre_think_content = ""
baseline_post_think_content = ""
return {
"status": "**Starting generation...**",
"progress": 0,
"thinking": "",
"answer": "",
"tokens": "",
"generate_btn_text": "Generating...",
"generate_btn_interactive": False,
"stop_btn_interactive": True
}
@spaces.GPU(duration=240)
def generate_with_updates(prompt, baseline_progress_queue, baseline_tokens_queue, stop_generation):
"""Wrapper around generation function that handles real-time updates"""
# Check if model is loaded
if not model_loaded_successfully:
yield {
"status": f"**Cannot generate: {model_loading_error}**"
}
return
# Use default values
max_tokens = 2048
# Reset UI first
yield reset_ui()
# Start generation in a separate thread to allow for UI updates
baseline_result = ""
baseline_token_count = 0
generation_error = None
generation_thread = None
def baseline_progress_updater(prog_value):
"""Update the baseline progress via the queue"""
baseline_progress_queue.put(prog_value)
def baseline_tokens_updater(text, token_count):
"""Update the baseline generated text via the queue"""
global baseline_think_tag_detected, baseline_progress_frozen, baseline_pre_think_content, baseline_post_think_content
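# Messages on baseline_tokens_queue are tuples tagged by their first element:
#   ("THINK_TAG_DETECTED", pre_think_text, post_think_text, token_count)
#   ("POST_THINK_UPDATE", post_think_text)
#   ("NORMAL_UPDATE", text[, token_count])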
# Check if </think> tag appears in the text
if not baseline_think_tag_detected and "</think>" in text:
baseline_think_tag_detected = True
baseline_progress_frozen = True
# Split content at </think>
parts = text.split("</think>", 1)
baseline_pre_think_content = parts[0] + "</think>"
baseline_post_think_content = parts[1] if len(parts) > 1 else ""
# Signal content split with token count
baseline_tokens_queue.put(("THINK_TAG_DETECTED", baseline_pre_think_content, baseline_post_think_content, token_count))
elif baseline_think_tag_detected:
# Update post-think content
if "</think>" in text:
parts = text.split("</think>", 1)
baseline_post_think_content = parts[1] if len(parts) > 1 else ""
baseline_tokens_queue.put(("POST_THINK_UPDATE", baseline_post_think_content))
else:
baseline_tokens_queue.put(("NORMAL_UPDATE", text))
else:
# Normal pre-think streaming with token count
baseline_tokens_queue.put(("NORMAL_UPDATE", text, token_count))
def run_generation():
nonlocal baseline_result, baseline_token_count, generation_error
try:
# Baseline-only generation
baseline_result, baseline_token_count = generate_baseline_only(
prompt=prompt,
max_new_tokens=max_tokens,
baseline_progress_callback=baseline_progress_updater,
baseline_tokens_callback=baseline_tokens_updater,
stop_event=stop_generation
)
except Exception as e:
generation_error = str(e)
# Start the generation thread
generation_thread = threading.Thread(target=run_generation)
generation_thread.start()
# Monitor queues for updates while generation is running
baseline_current_text = ""
baseline_thinking_tokens = 0
baseline_last_progress = 0
while generation_thread.is_alive() or not baseline_tokens_queue.empty() or not baseline_progress_queue.empty():
updates = {}
# Check baseline tokens queue
try:
while not baseline_tokens_queue.empty():
token_update = baseline_tokens_queue.get_nowait()
if isinstance(token_update, tuple):
update_type = token_update[0]
if update_type == "THINK_TAG_DETECTED":
# </think> tag detected - split content
pre_content = token_update[1]
post_content = token_update[2]
thinking_token_count = token_update[3]
updates["thinking"] = pre_content
updates["answer"] = post_content
updates["progress"] = 100.0 # Freeze at 100%
# Use actual token count (before </think>)
baseline_thinking_tokens = thinking_token_count
updates["tokens"] = f"{baseline_thinking_tokens}"
elif update_type == "POST_THINK_UPDATE":
# Update only the final answer
post_content = token_update[1]
updates["answer"] = post_content
# Don't update token count - frozen at thinking tokens
elif update_type == "NORMAL_UPDATE":
# Normal text update
baseline_current_text = token_update[1]
if not baseline_think_tag_detected:
updates["thinking"] = baseline_current_text
# Update thinking token count with actual token count if available
if len(token_update) > 2:
baseline_thinking_tokens = token_update[2]
else:
# Fallback to word count for backward compatibility
baseline_thinking_tokens = len(baseline_current_text.split())
updates["tokens"] = f"{baseline_thinking_tokens}"
else:
# This shouldn't happen, but handle it gracefully
updates["answer"] = baseline_current_text
else:
# Backward compatibility - treat as normal text
baseline_current_text = token_update
updates["thinking"] = baseline_current_text
if not baseline_think_tag_detected:
baseline_thinking_tokens = len(baseline_current_text.split())
updates["tokens"] = f"{baseline_thinking_tokens}"
except queue.Empty:
pass
# Check baseline progress queue
try:
while not baseline_progress_queue.empty():
baseline_last_progress = baseline_progress_queue.get_nowait()
updates["progress"] = baseline_last_progress
except queue.Empty:
pass
# If there are any updates, yield them
if updates:
yield updates
# Sleep briefly to prevent excessive CPU usage
time.sleep(0.05)
# Final update
final_updates = {
"status": "**Generation complete!**" if not generation_error else f"**Error: {generation_error}**",
"progress": 100,
"generate_btn_text": "Generate",
"generate_btn_interactive": True,
"stop_btn_interactive": True
}
if not generation_error:
# Handle baseline final display
if baseline_think_tag_detected:
# Split result for final display
if "</think>" in baseline_result:
parts = baseline_result.split("</think>", 1)
final_updates["thinking"] = parts[0] + "</think>"
final_updates["answer"] = parts[1] if len(parts) > 1 else ""
# Use actual token count from generation
if baseline_thinking_tokens > 0:
final_updates["tokens"] = f"{baseline_thinking_tokens}"
else:
# Fallback: use actual token count for thinking part
thinking_text = parts[0] + "</think>"
thinking_token_count = len(tokenizer.encode(thinking_text, add_special_tokens=False))
final_updates["tokens"] = f"{thinking_token_count}"
else:
final_updates["thinking"] = baseline_result
# Use actual token count
if baseline_thinking_tokens > 0:
final_updates["tokens"] = f"{baseline_thinking_tokens}"
else:
total_token_count = len(tokenizer.encode(baseline_result, add_special_tokens=False))
final_updates["tokens"] = f"{total_token_count}"
else:
final_updates["thinking"] = baseline_result
# Use actual token count
if baseline_thinking_tokens > 0:
final_updates["tokens"] = f"{baseline_thinking_tokens}"
else:
total_token_count = len(tokenizer.encode(baseline_result, add_special_tokens=False))
final_updates["tokens"] = f"{total_token_count}"
yield final_updates
# Create the Gradio interface
def create_interface():
# Create custom theme with light green progress bars
custom_theme = gr.themes.Base().set(
slider_color="#90EE90", # Light green
slider_color_dark="#32CD32", # Lime green for dark mode
# Additional slider styling
color_accent="#90EE90",
color_accent_soft="#E8F5E8"
)
# Custom CSS to fix text box heights, enable scrolling, and enlarge fonts
custom_css = """
/* Base font size increases for all text areas */
.gr-textbox textarea {
font-size: 18px !important;
line-height: 1.5em !important;
overflow-y: auto !important;
resize: none !important;
}
/* Prompt textbox - larger font */
.gr-textbox:has(textarea[placeholder*="math"]) textarea,
.gr-textbox:has(textarea[placeholder*="Enter"]) textarea {
font-size: 20px !important;
line-height: 1.6em !important;
}
/* Thinking Process and Final Answer boxes - enhanced readability */
.fixed-height-textbox textarea {
font-size: 16px !important;
line-height: 1.5em !important;
overflow-y: auto !important;
resize: none !important;
}
/* Thinking Process boxes - adjusted height for larger fonts */
.fixed-height-textbox:has(textarea[data-testid*="textbox"]) textarea {
height: 280px !important;
max-height: 280px !important;
min-height: 280px !important;
}
/* Progress bar labels and sliders */
.gr-slider .gr-label {
font-size: 16px !important;
font-weight: 500 !important;
}
/* All buttons - larger and more readable */
.gr-button {
font-size: 16px !important;
font-weight: 500 !important;
padding: 8px 16px !important;
}
/* Token count displays - readable but not overwhelming */
.gr-textbox[data-testid*="textbox"][readonly] textarea,
.gr-textbox.gr-readonly textarea {
font-size: 14px !important;
line-height: 1.4em !important;
text-align: center !important;
font-weight: 500 !important;
}
/* All labels - consistent sizing */
.gr-label {
font-size: 15px !important;
font-weight: 500 !important;
}
/* Status messages and markdown - larger and clearer */
.gr-markdown {
font-size: 16px !important;
line-height: 1.4em !important;
}
/* Checkbox labels */
.gr-checkbox .gr-label {
font-size: 15px !important;
}
/* Headers and titles */
.gr-markdown h1 {
font-size: 28px !important;
}
.gr-markdown h2 {
font-size: 22px !important;
}
/* Progress percentage display */
.gr-slider .gr-number {
font-size: 14px !important;
font-weight: 500 !important;
}
/* Custom light green gradient for progress bars */
.gr-slider input[type="range"]::-webkit-slider-runnable-track {
background: linear-gradient(90deg, #90EE90 0%, #32CD32 100%) !important;
border: none !important;
height: 8px !important;
border-radius: 4px !important;
}
.gr-slider input[type="range"]::-moz-range-track {
background: linear-gradient(90deg, #90EE90 0%, #32CD32 100%) !important;
border: none !important;
height: 8px !important;
border-radius: 4px !important;
}
/* Webkit slider thumb styling */
.gr-slider input[type="range"]::-webkit-slider-thumb {
background: #228B22 !important;
border: 2px solid #ffffff !important;
border-radius: 50% !important;
cursor: pointer !important;
}
/* Firefox slider thumb styling */
.gr-slider input[type="range"]::-moz-range-thumb {
background: #228B22 !important;
border: 2px solid #ffffff !important;
border-radius: 50% !important;
cursor: pointer !important;
}
/* Progress bar container background */
.gr-slider {
--slider-color: linear-gradient(90deg, #90EE90 0%, #32CD32 100%) !important;
}
/* Additional Gradio slider styling for better compatibility */
.gr-slider .gr-range {
background: linear-gradient(90deg, #90EE90 0%, #32CD32 100%) !important;
}
/* Style the slider track background (unfilled portion) */
.gr-slider input[type="range"] {
background: transparent !important;
}
.gr-slider input[type="range"]::-webkit-slider-runnable-track {
background: linear-gradient(90deg, #E8E8E8 0%, #E8E8E8 var(--slider-percent, 0%), #90EE90 var(--slider-percent, 0%), #32CD32 100%) !important;
}
"""
with gr.Blocks(title="Baseline Generation Demo", css=custom_css, theme=custom_theme) as demo:
# Display model loading status in the header
if model_loaded_successfully:
status_message = "✅ **DeepSeek-R1-Distill-Qwen-32B model loaded and ready for generation**"
else:
status_message = f"❌ **Model loading failed: {model_loading_error}**"
gr.Markdown(f"""
# 🚀 Reasoning Loading Bar Space ⏳
{status_message}
This Space demonstrates real-time progress tracking of a reasoning model.
## How it works:
1. Enter a prompt below - it works best with math problems that require reasoning
2. Click "Generate" to start generation with progress visualization
""")
# Generation section
with gr.Row():
with gr.Column(scale=8):
prompt = gr.Textbox(
label="Prompt",
placeholder="Enter your math problem or reasoning task here...",
lines=3
)
with gr.Column(scale=1):
generate_btn = gr.Button(
"Generate",
variant="primary" if model_loaded_successfully else "secondary",
interactive=model_loaded_successfully
)
stop_btn = gr.Button("Stop", variant="stop", interactive=model_loaded_successfully)
# Generation progress and results
with gr.Row():
generation_status = gr.Markdown("**Ready to generate**" if model_loaded_successfully else "**Model not loaded - cannot generate**")
# Single progress section for baseline only
with gr.Row():
with gr.Column():
with gr.Row():
baseline_progress_bar = gr.Slider(
minimum=0,
maximum=100,
value=0,
label="Generation Progress (%)",
interactive=False,
scale=5
)
baseline_tokens_count = gr.Textbox(label="Thinking tokens", value="", interactive=False, scale=1)
baseline_thinking_output = gr.Textbox(label="🧠 Thinking Process", lines=10, value="", elem_classes=["fixed-height-textbox"])
baseline_answer_output = gr.Textbox(label="✅ Final Answer", lines=4, value="", elem_classes=["fixed-height-textbox"])
# Create queues to store progress and token updates
baseline_progress_queue = queue.Queue()
baseline_tokens_queue = queue.Queue()
stop_generation = threading.Event()
def stop_generation_fn():
"""Stop the generation process"""
stop_generation.set()
return "Generation stopped"
def generate_wrapper(prompt):
"""Wrapper to adapt the global generate_with_updates function for Gradio"""
# Process updates from the global function and map to UI components
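# Gradio generator callbacks may yield a dict keyed by output components;
# only the components present in the dict are updated on that step.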
for update_dict in generate_with_updates(prompt, baseline_progress_queue, baseline_tokens_queue, stop_generation):
gradio_updates = {}
# Map the string keys to actual Gradio components
if "status" in update_dict:
gradio_updates[generation_status] = update_dict["status"]
if "progress" in update_dict:
gradio_updates[baseline_progress_bar] = update_dict["progress"]
if "thinking" in update_dict:
gradio_updates[baseline_thinking_output] = update_dict["thinking"]
if "answer" in update_dict:
gradio_updates[baseline_answer_output] = update_dict["answer"]
if "tokens" in update_dict:
gradio_updates[baseline_tokens_count] = update_dict["tokens"]
if "generate_btn_text" in update_dict:
gradio_updates[generate_btn] = gr.Button(
update_dict["generate_btn_text"],
variant="secondary" if "Generating" in update_dict["generate_btn_text"] else "primary",
interactive=update_dict.get("generate_btn_interactive", True)
)
if "stop_btn_interactive" in update_dict:
gradio_updates[stop_btn] = gr.Button(
"Stop",
variant="stop",
interactive=update_dict["stop_btn_interactive"]
)
yield gradio_updates
# Connect the buttons to the handlers
if model_loaded_successfully:
generate_btn.click(
generate_wrapper,
inputs=[prompt],
outputs=[
generation_status,
baseline_progress_bar,
baseline_thinking_output,
baseline_answer_output,
baseline_tokens_count,
generate_btn,
stop_btn
]
)
stop_btn.click(
stop_generation_fn,
outputs=[generation_status]
)
return demo
# Launch the app if running directly
if __name__ == "__main__":
demo = create_interface()
demo.launch(share=True)