"""SerraSafe Guardian: Gradio app that classifies greenhouse pests.

An EfficientNetV2-B0 backbone (frozen) plus a small classification head is
rebuilt here, fine-tuned weights are loaded from ``pest_classif.weights.h5``,
and the model is served through a Gradio UI plus a JSON endpoint named
``/predict`` (see ``predict_api``).
"""

import base64  # currently unused; kept from the original script
import logging
from io import BytesIO  # currently unused; kept from the original script

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.keras import layers, regularizers
from tensorflow.keras.applications.efficientnet_v2 import preprocess_input
from tensorflow.keras.utils import img_to_array

logger = logging.getLogger(__name__)

# Class labels, in the order of the model's output units.
# NOTE(review): "Verre de terre" looks like a typo for "Ver de terre"
# (earthworm); it is left unchanged because it is returned verbatim to API
# clients, which may match on it.
CLASS_LABELS = [
    'Fourmis', 'Abeilles', 'Scarabe', 'Chenille', 'Verre de terre',
    'Perce-oreille', 'Criquet', 'Papillon de nuit', 'Limace', 'Escargot',
    'Guêpes', 'Charançon',
]

# (height, width) the network input is built for; every frame is resized to it.
IMAGE_SIZE = (224, 224)

# Confidence, in percent, above which the API reports "threshold_exceeded".
CONFIDENCE_THRESHOLD = 85.0

# Create base model
input_shape = (*IMAGE_SIZE, 3)
base_model = tf.keras.applications.efficientnet_v2.EfficientNetV2B0(include_top=False)
base_model.trainable = False

# Create Functional model. Layer types and order must stay identical to the
# training script, otherwise load_weights below cannot match the saved weights.
inputs = layers.Input(shape=input_shape, name="input_layer")
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D(name="pooling_layer")(x)
x = layers.Dense(len(CLASS_LABELS), kernel_regularizer=regularizers.l2(0.001))(x)
# dtype=float32 forces the softmax output to float32 (presumably so training
# under a mixed-precision policy stays numerically stable).
outputs = layers.Activation("softmax", dtype=tf.float32, name="softmax_float32")(x)
model = tf.keras.Model(inputs, outputs)

# Compile the model (not required for predict(); kept as in the original).
model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(),
    metrics=["accuracy"],
)

# Load weights (path is relative to the current working directory).
model.load_weights('pest_classif.weights.h5')


def _to_rgb_uint8(frame):
    """Return *frame* as a C-contiguous ``uint8`` array of shape (H, W, 3).

    Accepts a PIL image, or an array-like that is grayscale (H, W),
    single-channel (H, W, 1), RGB (H, W, 3) or RGBA (H, W, 4). The alpha
    channel is dropped and single channels are replicated to three. Values
    that are not already ``uint8`` are cast with ``astype(np.uint8)`` without
    rescaling.

    Raises:
        ValueError: if the array does not have one of the shapes above.
    """
    if isinstance(frame, Image.Image):
        frame = frame.convert("RGB")
    arr = np.asarray(frame)
    if arr.ndim == 2:
        arr = arr[:, :, np.newaxis]
    if arr.ndim != 3 or arr.shape[2] not in (1, 3, 4):
        raise ValueError(f"Unsupported image shape: {arr.shape}")
    if arr.shape[2] == 1:
        arr = np.repeat(arr, 3, axis=2)
    elif arr.shape[2] == 4:
        arr = arr[:, :, :3]
    if arr.dtype != np.uint8:
        arr = arr.astype(np.uint8)
    # OpenCV drawing functions need a contiguous buffer (channel slicing is not).
    return np.ascontiguousarray(arr)


def preprocess_image(frame):
    """Turn one image into a model-ready batch of shape (1, H, W, 3).

    Args:
        frame: PIL image or numpy array (grayscale, RGB or RGBA); see
            ``_to_rgb_uint8``.

    Returns:
        float32 numpy array of shape ``(1, *IMAGE_SIZE, 3)``.
    """
    rgb = _to_rgb_uint8(frame)
    height, width = IMAGE_SIZE
    resized = cv2.resize(rgb, (width, height))  # cv2 expects (width, height)
    img_array = img_to_array(resized)
    # Per the Keras docs this is a pass-through for EfficientNetV2 (the model
    # rescales internally); kept so inference matches the training pipeline.
    img_array = preprocess_input(img_array)
    return np.expand_dims(img_array, axis=0)


def classify_image(frame):
    """Classify one image and draw the predicted label on a copy of it.

    Args:
        frame: PIL image or numpy array (grayscale, RGB or RGBA).

    Returns:
        Tuple ``(annotated_frame, label, confidence_text)``:
        - ``annotated_frame``: uint8 RGB array with the label drawn on it.
        - ``label``: an entry of ``CLASS_LABELS``.
        - ``confidence_text``: confidence as a string such as ``"95.50%"``.
    """
    rgb = _to_rgb_uint8(frame)
    predictions = model.predict(preprocess_image(rgb), verbose=0)
    predicted_class_idx = int(np.argmax(predictions[0]))
    confidence = predictions[0][predicted_class_idx]
    predicted_label = CLASS_LABELS[predicted_class_idx]

    # Draw on a copy so the caller's array is never modified.
    annotated_frame = rgb.copy()
    label_text = f"Classe: {predicted_label} ({confidence*100:.1f}%)"
    # NOTE: Hershey fonts are ASCII-only, so accented letters (ê, ç) render as '?'.
    cv2.putText(annotated_frame, label_text, (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
    return annotated_frame, predicted_label, f"{confidence*100:.2f}%"


def predict_image(image):
    """Gradio UI callback returning ``(annotated_image, label, confidence_text)``."""
    if image is None:
        return None, "No image provided", "N/A"
    return classify_image(image)


# API Function for external clients
def predict_api(image):
    """API endpoint for predictions; returns a JSON-serializable dict.

    On success, the dict has these keys:
    - ``success`` (True)
    - ``class``
    - ``confidence`` (string such as ``"95.50%"``)
    - ``threshold_exceeded`` (True when confidence > ``CONFIDENCE_THRESHOLD`` percent)

    On failure it is ``{"success": False, "error": <message>}``.
    """
    if image is None:
        return {"success": False, "error": "No image provided"}
    try:
        _, label, conf = classify_image(image)
    except Exception as e:  # API boundary: report to the client, keep a server-side trace
        logger.exception("Prediction failed")
        return {"success": False, "error": str(e)}
    return {
        "success": True,
        "class": label,
        "confidence": conf,
        "threshold_exceeded": float(conf.rstrip('%')) > CONFIDENCE_THRESHOLD,
    }


_CLASSES_TEXT = ", ".join(CLASS_LABELS)

HEADER_MARKDOWN = f"""
# 🌱 SerraSafe Guardian - Système de Détection de Pestes

Ce système utilise l'intelligence artificielle pour détecter et classifier les pestes dans votre serre.

**Classes détectées:** {_CLASSES_TEXT}

---

### 🔌 API disponible

- **Endpoint Gradio:** `/predict` (à appeler avec `gradio_client`, voir l'onglet « API Info »)
- **Format:** Envoyez une image et recevez les prédictions en JSON
"""

API_INFO_MARKDOWN = """
### Comment utiliser l'API

L'endpoint `/predict` est exposé par Gradio ; le plus simple est de l'appeler
avec le client officiel (`pip install gradio_client`).

**Exemple avec Python:**

```python
from gradio_client import Client, handle_file

client = Client("https://votre-space.hf.space/")
result = client.predict(
    handle_file("image.jpg"),
    api_name="/predict",
)
print(f"Classe: {result['class']}")
print(f"Confiance: {result['confidence']}")
```

**Réponse JSON:**

```json
{
  "success": true,
  "class": "Fourmis",
  "confidence": "95.50%",
  "threshold_exceeded": true
}
```

En cas d'erreur : `{"success": false, "error": "..."}`
"""

# Create Gradio Interface with API
with gr.Blocks(title="SerraSafe - Détection de Pestes") as demo:
    gr.Markdown(HEADER_MARKDOWN)

    with gr.Tab("📷 Télécharger Image"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="numpy", label="Télécharger une image")
                image_button = gr.Button("Analyser l'image", variant="primary")
            with gr.Column():
                image_output = gr.Image(label="Résultat")
                image_label = gr.Textbox(label="Classe détectée")
                image_confidence = gr.Textbox(label="Confiance")
        image_button.click(
            fn=predict_image,
            inputs=image_input,
            outputs=[image_output, image_label, image_confidence],
        )

    with gr.Tab("🔌 API Info"):
        gr.Markdown(API_INFO_MARKDOWN)

    # Hidden components whose only purpose is to register `predict_api` on
    # *this* app under api_name="predict". The original built a separate
    # gr.Interface after the blocking launch(), so that endpoint was never served.
    api_image = gr.Image(type="numpy", visible=False)
    api_result = gr.JSON(visible=False)
    api_trigger = gr.Button(visible=False)
    api_trigger.click(
        fn=predict_api,
        inputs=api_image,
        outputs=api_result,
        api_name="predict",
    )

# Standalone JSON interface around `predict_api`. It is NOT served by `demo`
# (the endpoint registered above is); kept so references to `api_demo` still work.
api_demo = gr.Interface(
    fn=predict_api,
    inputs=gr.Image(type="numpy"),
    outputs=gr.JSON(),
    api_name="predict",
)

if __name__ == "__main__":
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )