SadraCoding commited on
Commit
2ded6cd
·
verified ·
1 Parent(s): 2e1dac3

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +25 -17
README.md CHANGED
@@ -43,32 +43,40 @@ pip install transformers torch pillow
43
  ```
44
 
45
  ```python
 
46
  from transformers import AutoModelForImageClassification, AutoFeatureExtractor
47
  from PIL import Image
48
  import torch
49
 
50
- image_path = 'path/to/your/image.jpg'
51
- model_id = "SADRACODING/SDXL-Deepfake-Detector"
 
 
52
 
53
- # Load Model and Feature Extractor
54
- model = AutoModelForImageClassification.from_pretrained(model_id)
55
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
 
56
 
57
- # Preprocessing
58
- image = Image.open(image_path).convert("RGB")
59
- inputs = feature_extractor(images=image, return_tensors="pt")
60
 
61
- # Inference
62
- with torch.no_grad():
63
- logits = model(**inputs).logits
64
 
65
- # Post-processing and Prediction
66
- predicted_class_id = logits.argmax().item()
67
- labels = ["REAL", "FAKE"]
68
- prediction = labels[predicted_class_id]
69
- confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
70
 
71
- print(f"Prediction: {prediction} (Confidence: {confidence:.4f})")
 
 
 
 
 
 
 
 
72
  ```
73
  ---
74
  license: mit
 
43
  ```
44
 
45
  ```python
46
+ import argparse
47
  from transformers import AutoModelForImageClassification, AutoFeatureExtractor
48
  from PIL import Image
49
  import torch
50
 
51
+ def main():
52
+ parser = argparse.ArgumentParser(description="Predict image class using fine-tuned model from Hugging Face Hub")
53
+ parser.add_argument("--image", type=str, required=True, help="Path to the input image")
54
+ args = parser.parse_args()
55
 
56
+ # Load model and feature extractor directly from Hugging Face Hub
57
+ model_name = "SADRACODING/SDXL-Deepfake-Detector"
58
+ model = AutoModelForImageClassification.from_pretrained(model_name)
59
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
60
 
61
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
+ model.to(device)
63
+ model.eval()
64
 
65
+ image = Image.open(args.image).convert("RGB")
66
+ inputs = feature_extractor(images=image, return_tensors="pt").to(device)
 
67
 
68
+ with torch.no_grad():
69
+ outputs = model(**inputs)
 
 
 
70
 
71
+ logits = outputs.logits
72
+ predicted_class_idx = logits.argmax(-1).item()
73
+ predicted_label = model.config.id2label[predicted_class_idx]
74
+
75
+ print(f"Predicted class index: {predicted_class_idx}")
76
+ print(f"Predicted label: {predicted_label}")
77
+
78
+ if __name__ == "__main__":
79
+ main()
80
  ```
81
  ---
82
  license: mit