import asyncio
import base64
import io
import os
import tempfile
import threading
import traceback
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel

app = FastAPI()

# Global pipeline state, populated by load_model() at startup (None until then).
pipe = None
export_to_video = None

# Inference runs in a worker thread (so the event loop stays responsive);
# this lock keeps at most one request on the GPU pipeline at a time.
_inference_lock = threading.Lock()

# Request defaults, shared by InferenceRequest and predict().
DEFAULT_NEGATIVE_PROMPT = "ugly, static, blurry, low quality"
DEFAULT_NUM_FRAMES = 93
DEFAULT_NUM_INFERENCE_STEPS = 35
DEFAULT_GUIDANCE_SCALE = 7.0
DEFAULT_FPS = 16
# Every input image is resized to this (width, height) before inference.
TARGET_SIZE = (1280, 704)


class InferenceRequest(BaseModel):
    """Declared request schema.

    NOTE: predict() accepts a raw dict (optionally nested under "inputs") and
    validates fields itself; this model documents the expected fields.
    """

    image: str  # base64 (optionally a data: URL) or http(s) URL
    prompt: str
    negative_prompt: str = DEFAULT_NEGATIVE_PROMPT
    num_frames: int = DEFAULT_NUM_FRAMES
    num_inference_steps: int = DEFAULT_NUM_INFERENCE_STEPS
    guidance_scale: float = DEFAULT_GUIDANCE_SCALE
    seed: Optional[int] = None
    fps: int = DEFAULT_FPS


@app.on_event("startup")
async def load_model():
    """Load the Cosmos-Predict2 Video2World pipeline in bfloat16 and move it to CUDA."""
    global pipe, export_to_video
    # diffusers is imported inside the startup hook rather than at module level.
    from diffusers import Cosmos2VideoToWorldPipeline
    from diffusers.utils import export_to_video as etv

    export_to_video = etv
    model_id = "nvidia/Cosmos-Predict2-2B-Video2World"
    print("Loading model...")
    pipe = Cosmos2VideoToWorldPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        # Read from the environment; os.environ.get returns None when unset.
        token=os.environ.get("HF_TOKEN"),
    )
    pipe.to("cuda")
    print("Model loaded successfully!")


def _decode_image(image_data):
    """Return a PIL image resized to TARGET_SIZE from a URL, data URL, or base64 string.

    Raises:
        ValueError: if ``image_data`` is not a string.
        Any error from fetching, base64-decoding, or PIL parsing propagates.
    """
    if not isinstance(image_data, str):
        raise ValueError("image must be a string (URL or base64)")
    if image_data.startswith("http"):
        # NOTE(review): this fetches an arbitrary client-supplied URL from the
        # server (SSRF risk); consider an allow-list if exposed publicly.
        from diffusers.utils import load_image

        image = load_image(image_data)
    else:
        if image_data.startswith("data:"):
            # Strip a "data:<mime>;base64," prefix sent by browser clients.
            image_data = image_data.partition(",")[2]
        image_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    # Resize to expected dimensions for Cosmos Video2World.
    return image.resize(TARGET_SIZE)


def _coerce(inputs, key, default, cast):
    """Return ``cast(inputs[key])``, or ``default`` when the key is missing or null.

    Raises:
        HTTPException: 400 when the value cannot be converted by ``cast``.
    """
    value = inputs.get(key)
    if value is None:
        return default
    try:
        return cast(value)
    except (TypeError, ValueError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid {key}: {value!r}") from e


def _generate_video_b64(
    image,
    prompt,
    negative_prompt,
    num_frames,
    num_inference_steps,
    guidance_scale,
    seed,
    fps,
):
    """Run the pipeline (blocking) and return the MP4 video as a base64 string."""
    with _inference_lock:
        # Generator lives on the same device as the pipeline.
        generator = None
        if seed is not None:
            generator = torch.Generator(device="cuda").manual_seed(seed)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            output = pipe(
                image=image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
            )
    # Per-request temp dir: no clashes between requests, removed on exit.
    with tempfile.TemporaryDirectory() as tmpdir:
        video_path = os.path.join(tmpdir, "output.mp4")
        export_to_video(output.frames[0], video_path, fps=fps)
        with open(video_path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")


@app.post("/predict")
@app.post("/")
async def predict(request: dict):
    """Generate a video from an image and a text prompt.

    The payload is either flat or nested under an ``"inputs"`` key. Required:
    ``image`` (http(s) URL, data URL, or base64) and ``prompt``. Optional:
    ``negative_prompt``, ``num_frames``, ``num_inference_steps``,
    ``guidance_scale``, ``seed``, ``fps``.

    Returns:
        ``{"video": <base64 MP4>, "content_type": "video/mp4"}``

    Raises:
        HTTPException: 400 for invalid input, 503 if the model is not loaded,
            500 if inference or export fails.
    """
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")

    # Handle both direct and nested input formats.
    inputs = request.get("inputs", request)
    if not isinstance(inputs, dict):
        raise HTTPException(status_code=400, detail="'inputs' must be a JSON object")

    image_data = inputs.get("image")
    if not image_data:
        raise HTTPException(status_code=400, detail="No image provided")
    prompt = inputs.get("prompt", "")
    if not prompt:
        raise HTTPException(status_code=400, detail="No prompt provided")

    try:
        image = _decode_image(image_data)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load image: {str(e)}") from e

    negative_prompt = inputs.get("negative_prompt", DEFAULT_NEGATIVE_PROMPT)
    num_frames = _coerce(inputs, "num_frames", DEFAULT_NUM_FRAMES, int)
    num_inference_steps = _coerce(inputs, "num_inference_steps", DEFAULT_NUM_INFERENCE_STEPS, int)
    guidance_scale = _coerce(inputs, "guidance_scale", DEFAULT_GUIDANCE_SCALE, float)
    seed = _coerce(inputs, "seed", None, int)
    fps = _coerce(inputs, "fps", DEFAULT_FPS, int)

    # Off-load the blocking GPU/file work so the event loop (and /health)
    # keeps serving while a video is being generated.
    loop = asyncio.get_running_loop()
    try:
        video_b64 = await loop.run_in_executor(
            None,
            _generate_video_b64,
            image,
            prompt,
            negative_prompt,
            num_frames,
            num_inference_steps,
            guidance_scale,
            seed,
            fps,
        )
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}") from e

    return {"video": video_b64, "content_type": "video/mp4"}


@app.get("/health")
@app.get("/")
async def health():
    """Return a static status payload (served at GET /health and GET /)."""
    return {"status": "healthy", "message": "Cosmos-Predict2 Video2World API"}