# Source: luquiT4/DolphinInference, file handler.py (2.86 kB) on the Hugging Face Hub,
# commit f21c632 ("Create handler.py", verified).
import base64
import io
from typing import Dict, Any
import torch
from PIL import Image
from transformers import AutoProcessor, VisionEncoderDecoderModel
class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler for the Dolphin model.

    The processor and ``VisionEncoderDecoderModel`` are loaded once in
    ``__init__``; each call decodes one image and runs greedy generation
    conditioned on a text prompt, returning ``{"text": ...}`` on success or
    ``{"error": ...}`` for missing / undecodable input.
    """

    def __init__(self, path: str = ""):
        """Load the processor and model and place the model on the best device.

        Args:
            path: Local directory or Hub model ID to load from; falls back to
                ``"bytedance/Dolphin"`` when empty.
        """
        model_id = path or "bytedance/Dolphin"
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = VisionEncoderDecoderModel.from_pretrained(model_id)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Half precision only on GPU: that is where it speeds things up, while on
        # CPU many ops have no fp16 kernel ("... not implemented for 'Half'").
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        self.model.to(device=self.device, dtype=self.dtype)
        self.model.eval()
        self.tokenizer = self.processor.tokenizer

    @staticmethod
    def _strip_data_uri(payload: str) -> str:
        """Return the base64 body of *payload*, dropping a ``data:<mime>;base64,`` prefix if present."""
        payload = payload.strip()
        header, sep, body = payload.partition(",")
        # A comma is not part of the base64 alphabet, so it can only come from a data URI.
        if sep and header.startswith("data:"):
            return body
        return payload

    def decode_base64_image(self, image_base64: str) -> "Image.Image":
        """Decode a base64 string (optionally a data URI) into an RGB PIL image.

        Raises:
            binascii.Error: if the payload is not valid base64 (e.g. bad padding).
            PIL errors (such as ``UnidentifiedImageError``) when the decoded
            bytes are not a readable image.
        """
        image_bytes = base64.b64decode(self._strip_data_uri(image_base64))
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run Dolphin on a single image.

        Args:
            data: Request payload. ``data["inputs"]`` is the image: a base64
                string (optionally a ``data:`` URI), raw encoded image bytes, a
                PIL image, or any other object, which is handed to the processor
                unchanged. ``data["prompt"]`` optionally overrides the
                instruction (default ``"Read text in the image."``).

        Returns:
            ``{"text": <generated text>}`` on success, or ``{"error": <message>}``
            when ``"inputs"`` is missing or the image cannot be decoded.
        """
        if "inputs" not in data:
            return {"error": "No inputs provided"}
        image_input = data["inputs"]

        if isinstance(image_input, str):
            try:
                image = self.decode_base64_image(image_input)
            except Exception as e:
                return {"error": f"Invalid base64 image: {str(e)}"}
        elif isinstance(image_input, (bytes, bytearray)):
            try:
                image = Image.open(io.BytesIO(image_input)).convert("RGB")
            except Exception as e:
                return {"error": f"Invalid image bytes: {e}"}
        elif isinstance(image_input, Image.Image):
            # Match the base64 path: the processor gets 3-channel RGB, not RGBA / L / P.
            image = image_input.convert("RGB")
        else:
            # Anything else is left for the processor to interpret, as before.
            image = image_input

        prompt = data.get("prompt", "Read text in the image.")
        # The prompt is fed as the decoder prefix (decoder_input_ids below), so its
        # special tokens are written out here and add_special_tokens=False is used.
        full_prompt = f"<s>{prompt} <Answer/>"

        inputs = self.processor(image, return_tensors="pt")
        # Cast to the same dtype as the model weights (chosen in __init__).
        pixel_values = inputs.pixel_values.to(self.device, dtype=self.dtype)
        prompt_ids = self.tokenizer(
            full_prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(self.device)
        decoder_attention_mask = torch.ones_like(prompt_ids)  # same device as prompt_ids

        # Greedy decoding; the unknown token is banned from the output.
        outputs = self.model.generate(
            pixel_values=pixel_values,
            decoder_input_ids=prompt_ids,
            decoder_attention_mask=decoder_attention_mask,
            min_length=1,
            max_length=4096,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            use_cache=True,
            bad_words_ids=[[self.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            do_sample=False,
            num_beams=1,
        )
        sequence = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
        # The decoded text is expected to echo the prompt prefix; remove it along
        # with pad / EOS markers (other special tokens are deliberately kept).
        generated_text = (
            sequence.replace(full_prompt, "").replace("<pad>", "").replace("</s>", "").strip()
        )
        return {"text": generated_text}