altpuppet committed on
Commit
6cc66f0
·
1 Parent(s): 055af68

Fix ZeroGPU timeout issue - extend duration and optimize model loading

Browse files
Files changed (1) hide show
  1. app.py +16 -7
app.py CHANGED
@@ -457,26 +457,29 @@ def create_gradio_app():
457
 
458
  return metrics
459
 
460
- @spaces.GPU
461
  def run_gpu_inference(history_values_tensor, future_values_tensor, start, freq_object):
462
  """
463
  GPU-only inference function for ZeroGPU Spaces.
464
  ALL CUDA operations must happen inside this decorated function.
 
465
  """
466
  global model
467
 
468
- # Load model once on first call
469
  if model is None:
470
  print("--- Loading TempoPFN model for the first time ---")
471
- device = torch.device("cuda:0")
472
  print(f"Downloading model...")
473
  model_path = hf_hub_download(repo_id="AutoML-org/TempoPFN", filename="models/checkpoint_38M.pth")
474
- print(f"Loading model from {model_path} to {device}...")
475
- model = load_model(config_path="configs/example.yaml", model_path=model_path, device=device)
476
- print("--- Model loaded successfully ---")
 
477
 
478
- # Move tensors to GPU inside the decorated function
479
  device = torch.device("cuda:0")
 
 
480
 
481
  # Prepare container with GPU tensors
482
  container = BatchTimeSeriesContainer(
@@ -487,9 +490,14 @@ def create_gradio_app():
487
  )
488
 
489
  # Run inference with bfloat16 autocast
 
490
  with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
491
  model_output = model(container)
492
 
 
 
 
 
493
  return model_output
494
 
495
  def forecast_time_series(data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed, synth_generator="Sine Waves", synth_complexity=5):
@@ -760,6 +768,7 @@ def create_gradio_app():
760
  with gr.Blocks(title="TempoPFN") as app:
761
 
762
  gr.Markdown("# TempoPFN\n### Zero-Shot Forecasting & Analysis Terminal\n*Powered by synthetic pre-training • Forecast anything, anywhere*")
 
763
 
764
  with gr.Tabs() as tabs:
765
 
 
457
 
458
  return metrics
459
 
460
+ @spaces.GPU(duration=120) # Extend timeout to 120 seconds for first-run compilation
461
  def run_gpu_inference(history_values_tensor, future_values_tensor, start, freq_object):
462
  """
463
  GPU-only inference function for ZeroGPU Spaces.
464
  ALL CUDA operations must happen inside this decorated function.
465
+ Extended timeout for Triton kernel compilation on first run.
466
  """
467
  global model
468
 
469
+ # Load model once on first call (on CPU first to save GPU time)
470
  if model is None:
471
  print("--- Loading TempoPFN model for the first time ---")
 
472
  print(f"Downloading model...")
473
  model_path = hf_hub_download(repo_id="AutoML-org/TempoPFN", filename="models/checkpoint_38M.pth")
474
+ # Load on CPU first to save GPU allocation time
475
+ print(f"Loading model from {model_path} to CPU first...")
476
+ model = load_model(config_path="configs/example.yaml", model_path=model_path, device=torch.device("cpu"))
477
+ print("--- Model loaded successfully on CPU ---")
478
 
479
+ # Move model to GPU inside the decorated function
480
  device = torch.device("cuda:0")
481
+ print(f"Moving model to {device}...")
482
+ model.to(device)
483
 
484
  # Prepare container with GPU tensors
485
  container = BatchTimeSeriesContainer(
 
490
  )
491
 
492
  # Run inference with bfloat16 autocast
493
+ print("Running inference...")
494
  with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
495
  model_output = model(container)
496
 
497
+ # Move model back to CPU to free GPU memory
498
+ model.to(torch.device("cpu"))
499
+ print("Inference complete, model moved back to CPU")
500
+
501
  return model_output
502
 
503
  def forecast_time_series(data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed, synth_generator="Sine Waves", synth_complexity=5):
 
768
  with gr.Blocks(title="TempoPFN") as app:
769
 
770
  gr.Markdown("# TempoPFN\n### Zero-Shot Forecasting & Analysis Terminal\n*Powered by synthetic pre-training • Forecast anything, anywhere*")
771
+ gr.Markdown("⚠️ **First Run Note**: Initial inference may take 60-90 seconds due to Triton kernel compilation. Subsequent runs will be much faster!")
772
 
773
  with gr.Tabs() as tabs:
774