Spaces: Running

Alex committed · Commit b2702fe · 1 Parent(s): 0e3833c

updated to onnx

Browse files:
- .gitignore +5 -0
- README.md +26 -1
- app.py +102 -147
- hf_onnx_converter.py +202 -0
- requirements.txt +4 -1
- response.json +0 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+.env
+.DS_Store
+models/
+model_cache
+onnx_models
README.md CHANGED
@@ -51,4 +51,29 @@
   -d '{
     "url": "https://plus.unsplash.com/premium_photo-1673210886161-bfcc40f54d1f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cGVyc29uJTIwc3RhbmRpbmd8ZW58MHx8MHx8&w=1000&q=80"
   }' \
-  -o response.json
+  -o response.json
+
+# Segment-clothes-url
+curl -X POST "https://alexgenovese-segmentation.hf.space/segment-clothes-url" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "url": "https://plus.unsplash.com/premium_photo-1673210886161-bfcc40f54d1f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cGVyc29uJTIwc3RhbmRpbmd8ZW58MHx8MHx8&w=1000&q=80"
+  }' \
+  -o response.json
+
+# Convert to ONNX file
+# For the fashion segmentation model:
+python convert_to_onnx.py --model "sayeed99/segformer-b3-fashion" --output "models/fashion_segformer.onnx"
+# For the clothes segmentation model:
+python convert_to_onnx.py --model "mattmdjaga/segformer_b2_clothes" --output "models/clothes_segformer.onnx"
+
+# Convert to ONNX and push to the Hub
+python3 hf_onnx_converter.py \
+  --source "mattmdjaga/segformer_b2_clothes" \
+  --target "alexgenovese/segformer-onnx"
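Both curl examples save the JSON response to response.json, with the mask embedded as a base64 PNG data URI. A minimal decoding sketch; the field names follow the return dict in app.py below, and the mask.png output name is arbitrary:

```python
import base64
import io
import json

from PIL import Image

# response.json is the file written by the curl examples above
with open("response.json") as f:
    payload = json.load(f)

# "mask" is a data URI: "data:image/png;base64,<data>", so strip the prefix
b64_data = payload["mask"].split(",", 1)[1]
mask = Image.open(io.BytesIO(base64.b64decode(b64_data)))
mask.save("mask.png")  # arbitrary output name
print("mask size:", mask.size)
```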
app.py CHANGED
@@ -1,193 +1,148 @@

Removed (the old in-process PyTorch implementation; abridged, unrendered spans shown as "..."):

 from fastapi import FastAPI, File, UploadFile, HTTPException
-from transformers import ...
 from pydantic import BaseModel
 from PIL import Image
 import numpy as np
-import io, base64, logging, requests, ...
-import ...
-
-# Add this class for the request body
 class ImageURL(BaseModel):
     url: str
-
-logger = logging.getLogger(__name__)
-...
-except Exception as e:
-    logger.error(f"Error loading clothes model: {str(e)}")
-    raise RuntimeError(f"Error loading clothes model: {str(e)}")
-...
-    logger.info("Preparing the image for inference...")
-    inputs = processor(images=image, return_tensors="pt").to("cpu")
-
-    # Inference
-    logger.info("Running inference...")
-    with torch.no_grad():
-        outputs = model(**inputs)
-    logits = outputs.logits
-    ...
-    buffered = io.BytesIO()
-    mask_img.save(buffered, format="PNG")
-    mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
-
-    # Annotations
-    annotations = {"mask": mask.tolist(), "label": logits}
-    return mask_base64, annotations
-
-# API endpoint
-@app.post("/segment")
-async def segment_endpoint(file: UploadFile = File(...)):
-    try:
-        logger.info("Receiving the file...")
-        image_data = await file.read()
-        image = Image.open(io.BytesIO(image_data)).convert("RGB")
-        ...
         return {
             "mask": f"data:image/png;base64,{mask_base64}",
-            ...
         }
-    except Exception as e:
-        logger.error(f"Error in endpoint: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Error during processing: {str(e)}")
-
-# Add new endpoint
 @app.post("/segment-url")
 async def segment_url_endpoint(image_data: ImageURL):
     try:
-        logger.info("Downloading image from URL...")
         response = requests.get(image_data.url, stream=True)
         if response.status_code != 200:
             raise HTTPException(status_code=400, detail="Could not download image from URL")

-        # Open image from URL
         image = Image.open(response.raw).convert("RGB")
-
-        # Process image with SegFormer
-        logger.info("Processing image...")
-        inputs = processor(images=image, return_tensors="pt")
-        outputs = model(**inputs)
-        logits = outputs.logits.cpu()
-
-        # Upsample logits to match original image size
-        upsampled_logits = nn.functional.interpolate(
-            logits,
-            size=image.size[::-1],
-            mode="bilinear",
-            align_corners=False,
-        )
-
-        # Get prediction
-        pred_seg = upsampled_logits.argmax(dim=1)[0]
-
-        # Convert to image
-        mask_img = Image.fromarray((pred_seg.numpy() * 255).astype(np.uint8))
-
-        # Convert to base64
-        buffered = io.BytesIO()
-        mask_img.save(buffered, format="PNG")
-        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
-
-        return {
-            "mask": f"data:image/png;base64,{mask_base64}",
-            "size": image.size,
-            "labels": pred_seg
-        }

     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")

-# Add new endpoint
 @app.post("/segment-clothes-url")
 async def segment_clothes_url_endpoint(image_data: ImageURL):
-    ... (same body as the old /segment-url above, but with clothes_processor and clothes_model, returning "predictions": pred_seg.numpy().tolist())

 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)

Added (the new implementation, which downloads pre-converted ONNX models from the Hub and runs them with onnxruntime; full file):

from fastapi import FastAPI, File, UploadFile, HTTPException
from transformers import SegformerImageProcessor
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
from PIL import Image
import numpy as np
import io, base64, logging, requests, os
import onnxruntime as ort
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

class ImageURL(BaseModel):
    url: str

class ModelManager:
    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.token = os.getenv("HF_TOKEN")
        if not self.token:
            raise ValueError("HF_TOKEN environment variable is required")
        self._initialize_models()

    def _initialize_models(self):
        try:
            # Initialize ONNX runtime sessions
            self.logger.info("Loading ONNX models...")

            # Download and load fashion model
            fashion_path = hf_hub_download(
                repo_id="alexgenovese/segformer-onnx",
                filename="segformer-b3-fashion.onnx",
                token=self.token
            )
            self.fashion_model = ort.InferenceSession(fashion_path)
            self.fashion_processor = SegformerImageProcessor.from_pretrained(
                "sayeed99/segformer-b3-fashion",
                token=self.token
            )

            # Download and load clothes model
            clothes_path = hf_hub_download(
                repo_id="alexgenovese/segformer-onnx",
                filename="segformer_b2_clothes.onnx",
                token=self.token
            )
            self.clothes_model = ort.InferenceSession(clothes_path)
            self.clothes_processor = SegformerImageProcessor.from_pretrained(
                "mattmdjaga/segformer_b2_clothes",
                token=self.token
            )

            self.logger.info("All models loaded successfully.")
        except Exception as e:
            self.logger.error(f"Error initializing models: {str(e)}")
            raise RuntimeError(f"Error initializing models: {str(e)}")

    def process_fashion_image(self, image: Image.Image):
        inputs = self.fashion_processor(images=image, return_tensors="np")
        onnx_inputs = {'input': inputs['pixel_values']}
        logits = self.fashion_model.run(None, onnx_inputs)[0]
        return self._post_process_outputs(logits, image.size)

    def process_clothes_image(self, image: Image.Image):
        inputs = self.clothes_processor(images=image, return_tensors="np")
        onnx_inputs = {'input': inputs['pixel_values']}
        logits = self.clothes_model.run(None, onnx_inputs)[0]
        return self._post_process_outputs(logits, image.size)

    def _post_process_outputs(self, logits, image_size):
        # Logits arrive as (batch, num_labels, h, w)
        logits = np.array(logits)

        # Resize predictions to the original image size; channels go last so
        # that only the two spatial dimensions are rescaled
        from skimage.transform import resize
        logits_hwc = np.transpose(logits[0], (1, 2, 0))  # (h, w, num_labels)
        resized_logits = resize(
            logits_hwc,
            (image_size[1], image_size[0]),
            order=1,
            preserve_range=True,
            mode='reflect'
        )

        # Per-pixel class prediction
        pred_seg = np.argmax(resized_logits, axis=-1)
        mask_img = Image.fromarray((pred_seg * 255).astype(np.uint8))

        # Convert to base64
        buffered = io.BytesIO()
        mask_img.save(buffered, format="PNG")
        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {
            "mask": f"data:image/png;base64,{mask_base64}",
            "size": image_size,
            "predictions": pred_seg.tolist()
        }

# Initialize FastAPI and ModelManager
app = FastAPI()
model_manager = ModelManager()

@app.post("/segment-url")
async def segment_url_endpoint(image_data: ImageURL):
    try:
        response = requests.get(image_data.url, stream=True)
        if response.status_code != 200:
            raise HTTPException(status_code=400, detail="Could not download image from URL")

        image = Image.open(response.raw).convert("RGB")
        return model_manager.process_fashion_image(image)

    except Exception as e:
        logging.error(f"Error processing URL: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")

@app.post("/segment-clothes-url")
async def segment_clothes_url_endpoint(image_data: ImageURL):
    try:
        response = requests.get(image_data.url, stream=True)
        if response.status_code != 200:
            raise HTTPException(status_code=400, detail="Could not download image from URL")

        image = Image.open(response.raw).convert("RGB")
        return model_manager.process_clothes_image(image)

    except Exception as e:
        logging.error(f"Error processing URL: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")

@app.post("/segment")
async def segment_endpoint(file: UploadFile = File(...)):
    try:
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        return model_manager.process_fashion_image(image)
    except Exception as e:
        logging.error(f"Error in endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error during processing: {str(e)}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
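A minimal smoke test for the new endpoints with FastAPI's TestClient. It assumes HF_TOKEN is set, since importing app constructs ModelManager and downloads the ONNX models; the image URL is a stand-in:

```python
from fastapi.testclient import TestClient

from app import app  # importing app triggers ModelManager initialization

client = TestClient(app)

# Hypothetical image URL; substitute any reachable JPEG/PNG
resp = client.post(
    "/segment-clothes-url",
    json={"url": "https://example.com/photo.jpg"},
)
print(resp.status_code)
print(resp.json()["size"])  # original (width, height) echoed back
```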
hf_onnx_converter.py ADDED
@@ -0,0 +1,202 @@

import torch
from transformers import AutoModelForSemanticSegmentation, SegformerImageProcessor
from huggingface_hub import HfApi, create_repo, upload_file, model_info
import os
from dotenv import load_dotenv
from pathlib import Path
import logging
import argparse
import tempfile

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

class ConfigurationError(Exception):
    """Raised when required environment variables are missing"""
    pass

class HFOnnxConverter:
    def __init__(self, token=None):
        # Load configuration from environment
        self.token = token or os.getenv("HF_TOKEN")
        self.model_cache_dir = os.getenv("MODEL_CACHE_DIR")
        self.onnx_output_dir = os.getenv("ONNX_OUTPUT_DIR")

        # Validate configuration
        if not self.token:
            raise ConfigurationError("HF_TOKEN is required in environment variables")

        # Create directories if they don't exist
        for directory in [self.model_cache_dir, self.onnx_output_dir]:
            if directory:
                Path(directory).mkdir(parents=True, exist_ok=True)

        self.api = HfApi()

        # Login to Hugging Face
        try:
            self.api.whoami(token=self.token)
            logger.info("Successfully authenticated with Hugging Face")
        except Exception as e:
            raise ConfigurationError(f"Failed to authenticate with Hugging Face: {str(e)}")

    def setup_repository(self, repo_name: str) -> str:
        """Create or get repository on Hugging Face Hub"""
        try:
            create_repo(
                repo_name,
                token=self.token,
                private=False,
                exist_ok=True
            )
            logger.info(f"Repository {repo_name} is ready")
            return repo_name
        except Exception as e:
            logger.error(f"Error setting up repository: {e}")
            raise

    def verify_model_exists(self, model_name: str) -> bool:
        """Verify if the model exists and is accessible"""
        try:
            model_info(model_name, token=self.token)
            return True
        except Exception as e:
            logger.error(f"Model verification failed: {str(e)}")
            return False

    def convert_and_push(self, source_model: str, target_repo: str):
        """Convert model to ONNX and push to Hugging Face Hub"""
        try:
            # Verify model exists and is accessible
            if not self.verify_model_exists(source_model):
                raise ValueError(f"Model {source_model} is not accessible. Check if the model exists and you have proper permissions.")

            # Use model cache directory if specified
            model_kwargs = {
                "token": self.token
            }
            if self.model_cache_dir:
                model_kwargs["cache_dir"] = self.model_cache_dir

            # Create working directory
            working_dir = self.onnx_output_dir or tempfile.mkdtemp()
            tmp_path = Path(working_dir) / f"{target_repo.split('/')[-1]}.onnx"

            logger.info(f"Loading model {source_model}...")
            model = AutoModelForSemanticSegmentation.from_pretrained(
                source_model,
                **model_kwargs
            )
            processor = SegformerImageProcessor.from_pretrained(
                source_model,
                **model_kwargs
            )

            # Set model to evaluation mode
            model.eval()

            # Create dummy input
            dummy_input = processor(
                images=torch.zeros(1, 3, 224, 224),
                return_tensors="pt"
            )

            # Export to ONNX
            logger.info(f"Converting to ONNX format... Output path: {tmp_path}")
            torch.onnx.export(
                model,
                (dummy_input['pixel_values'],),
                tmp_path,
                input_names=['input'],
                output_names=['output'],
                dynamic_axes={
                    'input': {0: 'batch_size', 2: 'height', 3: 'width'},
                    'output': {0: 'batch_size'}
                },
                opset_version=12,
                do_constant_folding=True
            )

            # Create model card with environment info
            model_card = f"""---
base_model: {source_model}
tags:
- onnx
- semantic-segmentation
---

# ONNX Model converted from {source_model}

This is an ONNX version of the model {source_model}, converted automatically.

## Model Information
- Original Model: {source_model}
- ONNX Opset Version: 12
- Input Shape: Dynamic (batch_size, 3, height, width)

## Usage

```python
import onnxruntime as ort
import numpy as np

# Load ONNX model
session = ort.InferenceSession("model.onnx")

# Prepare input
input_data = np.zeros((1, 3, 224, 224), dtype=np.float32)

# Run inference
outputs = session.run(None, {{"input": input_data}})
```
"""
            # Save model card
            readme_path = Path(working_dir) / "README.md"
            with open(readme_path, "w") as f:
                f.write(model_card)

            # Push files to hub
            logger.info(f"Pushing files to {target_repo}...")
            self.api.upload_file(
                path_or_fileobj=str(tmp_path),
                path_in_repo="model.onnx",
                repo_id=target_repo,
                token=self.token
            )
            self.api.upload_file(
                path_or_fileobj=str(readme_path),
                path_in_repo="README.md",
                repo_id=target_repo,
                token=self.token
            )

            logger.info(f"Successfully pushed ONNX model to {target_repo}")
            return True

        except Exception as e:
            logger.error(f"Error during conversion and upload: {e}")
            return False

def main():
    parser = argparse.ArgumentParser(description='Convert and push model to ONNX format on Hugging Face Hub')
    parser.add_argument('--source', type=str, required=True,
                        help='Source model name (e.g., "sayeed99/segformer-b3-fashion")')
    parser.add_argument('--target', type=str, required=True,
                        help='Target repository name (e.g., "your-username/model-name-onnx")')
    parser.add_argument('--token', type=str, help='Hugging Face token (optional)')

    args = parser.parse_args()

    converter = HFOnnxConverter(token=args.token)
    converter.setup_repository(args.target)
    success = converter.convert_and_push(args.source, args.target)

    if not success:
        exit(1)

if __name__ == "__main__":
    main()
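Once hf_onnx_converter.py has run (see the README commands above), the export can be sanity-checked with onnxruntime. A sketch assuming a hypothetical local path to the exported file; the "input" name and the dynamic spatial axes come from the torch.onnx.export call above:

```python
import numpy as np
import onnxruntime as ort

# Hypothetical local path to the exported model
session = ort.InferenceSession("onnx_models/segformer-onnx.onnx")

# Height/width were declared dynamic at export time, so any spatial size works
dummy = np.zeros((1, 3, 512, 512), dtype=np.float32)
(logits,) = session.run(None, {"input": dummy})

# SegFormer predicts at 1/4 resolution: expect (1, num_labels, 128, 128)
print(logits.shape)
```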
requirements.txt CHANGED
@@ -4,4 +4,7 @@ torch
 torchvision
 transformers
 pillow
-numpy
+numpy
+torch
+python-dotenv
+onnx
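The new app.py also imports onnxruntime and scikit-image at runtime, which this requirements change does not add. A hedged bootstrap sketch for a local environment; the package list is inferred from the imports in app.py and hf_onnx_converter.py, not taken from the repo:

```python
import subprocess
import sys

# Package names inferred from the imports in app.py and hf_onnx_converter.py
packages = [
    "fastapi", "uvicorn", "torch", "torchvision", "transformers",
    "pillow", "numpy", "python-dotenv", "onnx", "onnxruntime",
    "scikit-image", "huggingface_hub", "requests",
]
subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])
```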
response.json CHANGED
The diff for this file is too large to render. See raw diff.