import base64
import io
import os
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel

app = FastAPI()

# Filled in by the startup hook below, which imports diffusers lazily.
pipe = None
export_to_video = None


class InferenceRequest(BaseModel):
    # Documents the expected request schema; the /predict handler below accepts
    # a raw dict so that payloads wrapped in an "inputs" key also work.
    image: str
    prompt: str
    negative_prompt: str = "ugly, static, blurry, low quality"
    num_frames: int = 93
    num_inference_steps: int = 35
    guidance_scale: float = 7.0
    seed: Optional[int] = None

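# An illustrative request body (the values are examples, not taken from this file):
#
#   {"image": "<base64 image bytes or an http(s) URL>",
#    "prompt": "a robot arm picks up a red cube",
#    "num_frames": 93, "seed": 42}
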
@app.on_event("startup")
async def load_model():
    global pipe, export_to_video
    # Import diffusers lazily so the web server can come up before the heavy
    # import and model download happen.
    from diffusers import Cosmos2VideoToWorldPipeline
    from diffusers.utils import export_to_video as etv

    export_to_video = etv
    model_id = "nvidia/Cosmos-Predict2-2B-Video2World"

    print("Loading model...")
    pipe = Cosmos2VideoToWorldPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        token=os.environ.get("HF_TOKEN"),  # used when the checkpoint requires authentication
    )
    pipe.to("cuda")
    print("Model loaded successfully!")


@app.post("/predict")
@app.post("/")
async def predict(request: dict):
    global pipe, export_to_video

    # Guard against requests that arrive before the startup hook has finished.
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    # Accept both a flat payload and one wrapped in an "inputs" key.
    inputs = request.get("inputs", request)

    image_data = inputs.get("image")
    if not image_data:
        raise HTTPException(status_code=400, detail="No image provided")

    prompt = inputs.get("prompt", "")
    if not prompt:
        raise HTTPException(status_code=400, detail="No prompt provided")

    try:
        if image_data.startswith("http"):
            # URLs are fetched with diffusers' helper; anything else is
            # treated as base64-encoded image bytes.
            from diffusers.utils import load_image

            image = load_image(image_data)
        else:
            image_bytes = base64.b64decode(image_data)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Resize to the 1280x704 resolution used for this pipeline.
        image = image.resize((1280, 704))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load image: {str(e)}")

    negative_prompt = inputs.get("negative_prompt", "ugly, static, blurry, low quality")
    num_frames = inputs.get("num_frames", 93)
    num_inference_steps = inputs.get("num_inference_steps", 35)
    guidance_scale = inputs.get("guidance_scale", 7.0)
    seed = inputs.get("seed")

    generator = None
    if seed is not None:
        generator = torch.Generator(device="cuda").manual_seed(int(seed))

    try:
        # torch.cuda.amp.autocast is deprecated; torch.autocast is the current
        # equivalent.
        with torch.autocast("cuda", dtype=torch.bfloat16):
            output = pipe(
                image=image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
            )

        # Write the generated frames to an MP4, then return it base64-encoded.
        video_path = "/tmp/output.mp4"
        export_to_video(output.frames[0], video_path, fps=16)

        with open(video_path, "rb") as f:
            video_b64 = base64.b64encode(f.read()).decode("utf-8")

        return {"video": video_b64, "content_type": "video/mp4"}
    except Exception as e:
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")


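# A minimal client sketch (illustrative only; the URL, file names, and prompt
# below are assumptions, not part of this service's contract):
#
#   import base64, requests
#   with open("frame.png", "rb") as f:
#       img_b64 = base64.b64encode(f.read()).decode("utf-8")
#   resp = requests.post("http://localhost:8000/predict",
#                        json={"image": img_b64, "prompt": "a robot arm picks up a red cube"})
#   with open("output.mp4", "wb") as f:
#       f.write(base64.b64decode(resp.json()["video"]))

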
@app.get("/health")
@app.get("/")
async def health():
    return {"status": "healthy", "message": "Cosmos-Predict2 Video2World API"}
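

# Local entry point: a minimal sketch, assuming uvicorn is installed; the host
# and port below are illustrative defaults, not taken from the original file.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)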