Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import tensorflow as tf | |
| import cv2 | |
| from tensorflow.keras.utils import img_to_array | |
| from tensorflow.keras import layers, regularizers | |
| from tensorflow.keras.applications.efficientnet_v2 import preprocess_input | |
| from PIL import Image | |
| import gradio as gr | |
# Create base model.
# The network consumes 224x224 RGB images.
input_shape = (224, 224, 3)

# weights=None: the load_weights() call at the end of this block restores the
# variables of the whole model (backbone included), so fetching the ImageNet
# checkpoint here would only add a download to every cold start.
# NOTE(review): assumes pest_classif.weights.h5 was written by save_weights()
# on this exact architecture — confirm if the training script changes.
base_model = tf.keras.applications.efficientnet_v2.EfficientNetV2B0(
    include_top=False,
    weights=None,
)
base_model.trainable = False

# Create Functional model:
# frozen backbone -> global average pooling -> 12-unit dense -> float32 softmax.
inputs = layers.Input(shape=input_shape, name="input_layer")
# training=False keeps the backbone's BatchNormalization layers in inference mode.
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D(name="pooling_layer")(x)
# One output unit per entry of CLASS_LABELS (12 classes).
x = layers.Dense(12, kernel_regularizer=regularizers.l2(0.001))(x)
# The softmax is forced to float32 so the probabilities stay float32 even if a
# mixed-precision policy is active.
outputs = layers.Activation("softmax", dtype=tf.float32, name="softmax_float32")(x)
model = tf.keras.Model(inputs, outputs)

# Compile the model (not required for model.predict(), which is all this app uses).
model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(),
    metrics=["accuracy"],
)

# Load the trained weights; the path is resolved relative to the working directory.
model.load_weights("pest_classif.weights.h5")
# Class labels, indexed by model output unit: classify_image() takes the
# arg-max of the softmax and looks it up here, so this order must match the
# class order used when the model was trained.
# NOTE(review): "Verre de terre" (earthworm is "Ver de terre") and "Scarabe"
# (likely "Scarabée") look like typos, but these strings are returned to API
# clients verbatim — correcting them is a client-visible change.
CLASS_LABELS = [
    "Fourmis",            # 0
    "Abeilles",           # 1
    "Scarabe",            # 2
    "Chenille",           # 3
    "Verre de terre",     # 4
    "Perce-oreille",      # 5
    "Criquet",            # 6
    "Papillon de nuit",   # 7
    "Limace",             # 8
    "Escargot",           # 9
    "Guêpes",             # 10
    "Charançon",          # 11
]
def preprocess_image(frame):
    """Turn one image into a single-image batch the classifier can consume.

    Args:
        frame: A ``PIL.Image.Image`` (any mode; converted to RGB) or a NumPy
            image array. 3-channel arrays are assumed to already be in RGB
            order. Grayscale ``(H, W)`` / ``(H, W, 1)`` arrays and 4-channel
            ``(H, W, 4)`` arrays (4th channel assumed to be alpha) are
            converted to 3-channel RGB. Non-uint8 arrays are cast to uint8.

    Returns:
        An array of shape ``(1, 224, 224, 3)`` produced by
        ``img_to_array`` + ``efficientnet_v2.preprocess_input``.
    """
    if isinstance(frame, Image.Image):
        # convert("RGB") also handles palette ("P"), "L", "RGBA", "CMYK", ...
        frame = np.array(frame.convert("RGB"))
    frame = np.asarray(frame)
    if frame.dtype != np.uint8:
        frame = frame.astype(np.uint8)

    # The backbone expects exactly 3 channels.
    if frame.ndim == 2 or (frame.ndim == 3 and frame.shape[2] == 1):
        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    elif frame.ndim == 3 and frame.shape[2] == 4:
        # Drop the (presumed) alpha channel.
        frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)

    # cv2.resize takes the target size as (width, height).
    frame_resized = cv2.resize(frame, (input_shape[1], input_shape[0]))
    img_array = img_to_array(frame_resized)
    img_array = preprocess_input(img_array)
    return np.expand_dims(img_array, axis=0)
def _ascii_for_hershey(text):
    """Return *text* with accents stripped and any remaining non-ASCII dropped.

    ``cv2.putText`` with a Hershey font replaces symbols it cannot render with
    question marks, so "Guêpes" would be drawn as "Gu??pes".
    """
    import unicodedata

    decomposed = unicodedata.normalize("NFKD", text)
    return decomposed.encode("ascii", "ignore").decode("ascii")


def classify_image(frame):
    """Classify one image and caption a copy of it with the prediction.

    Args:
        frame: An image as a NumPy array or ``PIL.Image.Image`` (anything
            ``preprocess_image`` accepts).

    Returns:
        ``(annotated_frame, predicted_label, confidence)`` where:
        ``annotated_frame`` is a NumPy copy of the input with
        ``"Classe: <label> (<pct>%)"`` drawn near the top-left corner;
        ``predicted_label`` is the arg-max entry of ``CLASS_LABELS`` (accents
        preserved); and ``confidence`` is that class's softmax probability
        formatted as a percentage string with two decimals, e.g. ``"93.41%"``.
    """
    # Draw on a private copy so the caller's image is never modified.
    if isinstance(frame, np.ndarray):
        annotated_frame = frame.copy()
    else:
        annotated_frame = np.array(frame)

    processed_frame = preprocess_image(frame)
    predictions = model.predict(processed_frame, verbose=0)
    predicted_class_idx = int(np.argmax(predictions[0]))
    confidence = predictions[0][predicted_class_idx]
    predicted_label = CLASS_LABELS[predicted_class_idx]

    # Only the on-image caption is ASCII-folded; the returned label keeps its
    # accents.
    caption_label = _ascii_for_hershey(predicted_label)
    label_text = f"Classe: {caption_label} ({confidence*100:.1f}%)"
    cv2.putText(
        annotated_frame,
        label_text,
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (0, 255, 0),
        2,
    )
    return annotated_frame, predicted_label, f"{confidence*100:.2f}%"
def predict_image(image):
    """Click handler for the "Analyse Image" tab.

    Returns a 3-tuple matching the tab's outputs: the annotated image, the
    class label and the confidence string. When no image was supplied, a
    placeholder triple is returned and the model is not called.
    """
    if image is not None:
        annotated, predicted, confidence_text = classify_image(image)
        return annotated, predicted, confidence_text
    return None, "No image provided", "N/A"
# API function - returns a JSON-serialisable dict for external clients.

# Confidence (in percent) strictly above which a detection is flagged with
# "threshold_exceeded": True in the API response.
CONFIDENCE_THRESHOLD = 85.0


def predict_api(image):
    """Classify *image* and return the result as a JSON-friendly dict.

    Args:
        image: The uploaded image (``None`` when nothing was provided).

    Returns:
        On success: ``{"success": True, "class": <label>,
        "confidence": "<pct>%", "confidence_value": <float pct>,
        "threshold_exceeded": <bool>}``.
        On failure: ``{"success": False, "error": <message>}``. A missing
        image instead yields ``{"error": "No image provided", "success": False}``.
    """
    if image is None:
        return {"error": "No image provided", "success": False}
    try:
        _, label, conf = classify_image(image)
        conf_value = float(conf.replace('%', ''))
    except Exception as e:
        # API boundary: report the failure to the client as JSON, but keep
        # the full traceback in the server log for debugging.
        import logging

        logging.getLogger(__name__).exception("predict_api failed")
        return {"success": False, "error": str(e)}
    return {
        "success": True,
        "class": label,
        "confidence": conf,
        "confidence_value": conf_value,
        "threshold_exceeded": conf_value > CONFIDENCE_THRESHOLD,
    }
# Create main Gradio interface.

# Class list shown in the header, derived from CLASS_LABELS so the page can
# never drift out of sync with the labels the model actually returns.
_CLASS_LIST_MD = ", ".join(CLASS_LABELS)

with gr.Blocks(title="SerraSafe - Détection de Pestes") as demo:
    gr.Markdown(
        f"""
        # 🌱 SerraSafe Guardian - Système de Détection de Pestes

        Ce système utilise l'intelligence artificielle pour détecter et classifier les pestes dans votre serre.

        **Classes détectées:** {_CLASS_LIST_MD}
        """
    )

    # Tab 1: interactive analysis -> annotated image, label and confidence.
    with gr.Tab("📷 Analyse Image"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="numpy", label="Télécharger une image")
                image_button = gr.Button("Analyser l'image", variant="primary")
            with gr.Column():
                image_output = gr.Image(label="Résultat")
                image_label = gr.Textbox(label="Classe détectée")
                image_confidence = gr.Textbox(label="Confiance")
        image_button.click(
            fn=predict_image,
            inputs=image_input,
            outputs=[image_output, image_label, image_confidence],
        )

    # Tab 2: same model, JSON response produced by predict_api.
    with gr.Tab("🔌 API (JSON)"):
        gr.Markdown(
            """
            ### API pour clients externes

            Utilisez cet endpoint pour obtenir des prédictions en JSON.
            """
        )
        with gr.Column():
            api_image_input = gr.Image(type="numpy", label="Image")
            api_button = gr.Button("Tester l'API", variant="primary")
            api_output = gr.JSON(label="Réponse JSON")
        # api_name="predict" publishes this handler as a named API endpoint.
        # NOTE(review): the docs tab advertises `/api/predict`; the concrete
        # HTTP route for a named endpoint depends on the installed Gradio
        # version — confirm against the Space's "Use via API" page.
        api_button.click(
            fn=predict_api,
            inputs=api_image_input,
            outputs=api_output,
            api_name="predict",
        )

    # Tab 3: static usage notes (plain string; no interpolation needed).
    with gr.Tab("📖 Documentation"):
        gr.Markdown(
            """
            ### 🔌 Utilisation de l'API

            **Endpoint:** `/api/predict`

            **URL complète:** `https://danyanderson-serrasafe.hf.space/api/predict`

            Utilisez l'onglet "API (JSON)" ci-dessus pour tester l'API directement.

            Pour l'utiliser depuis Python avec le client local, téléchargez `client_local.py` et lancez-le.
            """
        )
def _launch_app():
    """Start the Gradio server that hosts the ``demo`` Blocks app."""
    demo.launch()


if __name__ == "__main__":
    _launch_app()