# serrasafe / app_legacy.py
# Committed by danyanderson — commit cd1025f (verified): "Rename app.py to app_legacy.py"
import numpy as np
import tensorflow as tf
import cv2
import base64
from tensorflow.keras.utils import img_to_array
from tensorflow.keras import layers, regularizers
from tensorflow.keras.applications.efficientnet_v2 import preprocess_input
from PIL import Image
from io import BytesIO
import gradio as gr
# Rebuild the classifier architecture in code so the saved weights below can
# be loaded onto it. NOTE(review): layer structure, order and names must match
# whatever produced pest_classif.weights.h5 — the Dense layer has no explicit
# name, so its auto-generated name depends on creation order; do not reorder.
# Input resolution fed to the network: 224x224, 3 channels (matches the
# cv2.resize target in preprocess_image).
input_shape = (224, 224, 3)
# EfficientNetV2-B0 backbone without its classification head. With the
# default `weights` argument Keras fetches ImageNet weights here; they are
# presumably overwritten by load_weights below — confirm the weights file
# covers the backbone before changing this.
base_model = tf.keras.applications.efficientnet_v2.EfficientNetV2B0(include_top=False)
# Freeze the backbone so its weights would not be updated by further training.
base_model.trainable = False
# Functional model: backbone -> global average pooling -> 12-way softmax.
inputs = layers.Input(shape=input_shape, name="input_layer")
# training=False runs the backbone (including BatchNormalization) in inference mode.
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D(name="pooling_layer")(x)
# One logit per class; 12 must equal len(CLASS_LABELS), which is indexed by
# the argmax of this layer's softmax output.
x = layers.Dense(12, kernel_regularizer=regularizers.l2(0.001))(x)
# Softmax forced to float32 so the probabilities stay float32 even if a
# mixed-precision policy were enabled (none is set in the visible code).
outputs = layers.Activation("softmax", dtype=tf.float32, name="softmax_float32")(x)
model = tf.keras.Model(inputs, outputs)
# Compile the model. Loss/optimizer/metrics only matter for training or
# evaluation; the visible code in this file only runs inference.
model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(),
    metrics=["accuracy"]
)
# Load the trained weights. The relative path is resolved against the current
# working directory, not against this file's location.
model.load_weights('pest_classif.weights.h5')
# Human-readable (French) class names. CLASS_LABELS[i] is the label reported
# for softmax output i of the model; the order presumably mirrors the class
# order used at training time — verify before editing.
# NOTE(review): "Verre de terre" (glass) looks like a typo for "Ver de terre"
# (earthworm), and "Scarabe" lacks its accent ("Scarabée"). Both are left
# unchanged because API clients may match on these exact strings.
CLASS_LABELS = [
    "Fourmis",           # 0
    "Abeilles",          # 1
    "Scarabe",           # 2
    "Chenille",          # 3
    "Verre de terre",    # 4
    "Perce-oreille",     # 5
    "Criquet",           # 6
    "Papillon de nuit",  # 7
    "Limace",            # 8
    "Escargot",          # 9
    "Guêpes",            # 10
    "Charançon",         # 11
]
def preprocess_image(frame):
    """Turn an input image into a single-image batch ready for the model.

    Args:
        frame: A ``PIL.Image.Image`` or a NumPy array. Arrays may be
            grayscale ``(H, W)`` / ``(H, W, 1)``, 3-channel ``(H, W, 3)`` or
            4-channel ``(H, W, 4)``. Channel order is assumed to be RGB(A),
            as delivered by ``gr.Image(type="numpy")`` — TODO confirm for
            other callers.

    Returns:
        A float32 array of shape ``(1, 224, 224, 3)`` that has been passed
        through EfficientNetV2's ``preprocess_input``.
    """
    # Accepts PIL images and arrays alike.
    frame = np.asarray(frame)
    if frame.dtype != np.uint8:
        # Same cast as before: truncates rather than rescales, so non-uint8
        # input is expected to already be in the 0-255 range.
        frame = frame.astype(np.uint8)
    # The network takes exactly three channels: expand grayscale and drop
    # alpha so such uploads no longer fail inside the model.
    if frame.ndim == 2 or (frame.ndim == 3 and frame.shape[2] == 1):
        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    elif frame.ndim == 3 and frame.shape[2] == 4:
        frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
    frame_resized = cv2.resize(frame, (224, 224))
    img_array = img_to_array(frame_resized)
    img_array = preprocess_input(img_array)
    # Add the batch dimension expected by the model.
    return np.expand_dims(img_array, axis=0)
def classify_image(frame):
    """Classify *frame* and return an annotated copy plus the prediction.

    Args:
        frame: Image as a NumPy array or anything ``np.array`` accepts
            (e.g. a ``PIL.Image.Image``).

    Returns:
        ``(annotated_frame, predicted_label, confidence_text)``:

        - ``annotated_frame`` is a copy of the input with
          ``"Classe: <label> (<xx.x>%)"`` drawn in green near the top-left
          corner.
        - ``predicted_label`` is the entry of ``CLASS_LABELS`` with the
          highest softmax score.
        - ``confidence_text`` is that score as a percentage with two
          decimals, e.g. ``"95.50%"``.
    """
    # np.array copies by default, so the caller's image is never drawn on.
    annotated_frame = np.array(frame)
    processed_frame = preprocess_image(frame)
    # For a single batch, calling the model directly avoids the per-call
    # overhead of model.predict (as recommended by the Keras docs).
    predictions = np.asarray(model(processed_frame, training=False))
    predicted_class_idx = int(np.argmax(predictions[0]))
    confidence = predictions[0][predicted_class_idx]
    predicted_label = CLASS_LABELS[predicted_class_idx]
    label_text = f"Classe: {predicted_label} ({confidence*100:.1f}%)"
    # NOTE: Hershey fonts cannot render non-ASCII characters; OpenCV replaces
    # them with '?', so labels like "Guêpes" appear as "Gu?pes" on the image.
    cv2.putText(annotated_frame, label_text, (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
    return annotated_frame, predicted_label, f"{confidence*100:.2f}%"
def predict_image(image):
    """Click handler for the upload tab.

    Returns the ``(annotated_image, label, confidence)`` triple produced by
    ``classify_image``. When no image was uploaded, returns
    ``(None, "No image provided", "N/A")`` so all three output components
    are still filled.
    """
    if image is not None:
        return classify_image(image)
    return None, "No image provided", "N/A"
# API function for external clients
def predict_api(image):
    """API endpoint for predictions - returns JSON.

    Args:
        image: Image as a NumPy array (or anything ``classify_image``
            accepts), or ``None`` when nothing was sent.

    Returns:
        On success::

            {"success": True, "class": <label>, "confidence": "<xx.xx>%",
             "threshold_exceeded": <confidence above 85 %>}

        On failure, including a missing image::

            {"success": False, "error": <message>}

        Every response carries ``"success"``, so clients can always branch
        on it.
    """
    if image is None:
        return {"success": False, "error": "No image provided"}
    try:
        _, label, conf = classify_image(image)
    except Exception as e:  # API boundary: report the failure to the client instead of raising
        return {"success": False, "error": str(e)}
    return {
        "success": True,
        "class": label,
        "confidence": conf,
        # conf is formatted like "95.50%"; drop the sign to compare numerically.
        "threshold_exceeded": float(conf.rstrip("%")) > 85.0,
    }
# Gradio app: an upload tab, an API documentation tab, and a named JSON API
# endpoint ("/predict") backed by predict_api.
with gr.Blocks(title="SerraSafe - Détection de Pestes") as demo:
    # gr.Markdown dedents its text, so the indentation below does not render.
    # Blank lines separate paragraphs; without one before "---", the previous
    # line would render as a heading.
    gr.Markdown(
        """
        # 🌱 SerraSafe Guardian - Système de Détection de Pestes

        Ce système utilise l'intelligence artificielle pour détecter et classifier les pestes dans votre serre.

        **Classes détectées:** Fourmis, Abeilles, Scarabe, Chenille, Verre de terre, Perce-oreille, Criquet, Papillon de nuit, Limace, Escargot, Guêpes, Charançon

        ---

        ### 🔌 API disponible

        **Endpoint:** `/predict` (nom d'API Gradio, à appeler avec `gradio_client` — voir l'onglet « API Info »)

        **Format:** Envoyez une image et recevez les prédictions en JSON
        """
    )
    with gr.Tab("📷 Télécharger Image"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="numpy", label="Télécharger une image")
                image_button = gr.Button("Analyser l'image", variant="primary")
            with gr.Column():
                image_output = gr.Image(label="Résultat")
                image_label = gr.Textbox(label="Classe détectée")
                image_confidence = gr.Textbox(label="Confiance")
        image_button.click(
            fn=predict_image,
            inputs=image_input,
            outputs=[image_output, image_label, image_confidence]
        )
    with gr.Tab("🔌 API Info"):
        gr.Markdown(
            """
            ### Comment utiliser l'API

            **Nom de l'endpoint:** `/predict`

            **Exemple avec Python** (`pip install gradio_client`, version ≥ 1.0):

            ```python
            from gradio_client import Client, handle_file

            client = Client("https://votre-space.hf.space")
            # Envoyer une image locale à l'API
            result = client.predict(handle_file("image.jpg"), api_name="/predict")
            print(f"Classe: {result['class']}")
            print(f"Confiance: {result['confidence']}")
            ```

            **Réponse JSON:**

            ```json
            {
              "success": true,
              "class": "Fourmis",
              "confidence": "95.50%",
              "threshold_exceeded": true
            }
            ```
            """
        )
    # Expose predict_api as the named API endpoint "predict" on this app.
    # Previously it was only attached to a separate Interface that was never
    # served. The trigger button and the JSON output are hidden; they exist
    # only so the event can carry api_name, and they reuse the upload tab's
    # image input.
    api_output = gr.JSON(visible=False)
    api_trigger = gr.Button(visible=False)
    api_trigger.click(
        fn=predict_api,
        inputs=image_input,
        outputs=api_output,
        api_name="predict",
    )
# Start the Gradio server for `demo`. By default launch() blocks the calling
# thread until the server is stopped, so statements below only run afterwards.
demo.launch(
    server_name="0.0.0.0",  # bind on all network interfaces
    server_port=7860,
    share=False,            # no public share link
    show_error=True,        # surface handler errors in the UI
)
# Standalone JSON Interface around predict_api, registered under the API name
# "predict" if this Interface is ever launched.
# NOTE(review): this statement runs only after demo.launch() above returns,
# and nothing in the visible code launches or renders api_demo, so this route
# is presumably never served — confirm before relying on it.
api_demo = gr.Interface(
    predict_api,
    gr.Image(type="numpy"),
    gr.JSON(),
    api_name="predict",
)