|
|
import sys

import torch
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model(repo_id="ash12321/ai-image-detector-deepsvdd", filename="model.ckpt"):
    """Download a checkpoint from the Hugging Face Hub and load it for inference.

    Args:
        repo_id: Hub repository that hosts the checkpoint.
        filename: Name of the checkpoint file inside the repository.
            Defaults to ``"model.ckpt"``.

    Returns:
        The ``AdvancedDeepSVDD`` instance restored from the checkpoint, already
        switched to evaluation mode via ``eval()``.
    """
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # Kept function-local: importing this file then does not require the
    # project's ``model`` module to be importable until a model is loaded.
    from model import AdvancedDeepSVDD

    model = AdvancedDeepSVDD.load_from_checkpoint(checkpoint_path)
    model.eval()
    return model
|
|
|
|
|
def predict_image(image_path, model):
    """Run the detector on a single image file and summarise the result.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        model: Object exposing ``predict_anomaly(batch)`` that returns three
            per-sample sequences ``(is_fake, scores, distances)``.

    Returns:
        dict with the following keys, all taken from the first (only) sample:
            ``is_ai_generated`` (bool): the model's flag for the image.
            ``confidence`` (float): the model's score for the image.
            ``anomaly_score`` (float): the same value as ``confidence``.
            ``distance`` (float): the model's distance for the image.
    """
    # NOTE(review): these mean/std values match the commonly quoted CIFAR-10
    # channel statistics, and 32x32 is the CIFAR image size -- presumably the
    # preprocessing the checkpoint was trained with; confirm before changing.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2470, 0.2435, 0.2616],
        ),
    ])

    # Open inside a context manager so the underlying file handle is released
    # deterministically; convert() returns an independent in-memory copy.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')

    # Add the leading batch dimension the model call expects.
    image_tensor = transform(image).unsqueeze(0)

    # Feed the batch on the same device as the model's weights; a CPU tensor
    # passed to a model living on another device raises a device mismatch.
    # Guarded so that duck-typed models without parameters() still work.
    parameters = getattr(model, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        image_tensor = image_tensor.to(first_param.device)

    with torch.no_grad():
        is_fake, scores, distances = model.predict_anomaly(image_tensor)

    # 'confidence' and 'anomaly_score' intentionally carry the same raw score;
    # both keys are kept so existing callers keep working.
    # NOTE(review): the score may not be a calibrated probability in [0, 1].
    score = float(scores[0].item())
    return {
        'is_ai_generated': bool(is_fake[0].item()),
        'confidence': score,
        'anomaly_score': score,
        'distance': float(distances[0].item()),
    }
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: classify one image and print the verdict.
    # The image path may be passed as the first command-line argument; with
    # no argument the original default, "test_image.jpg", is used.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "test_image.jpg"

    model = load_model()
    result = predict_image(image_path, model)

    print(f"AI-Generated: {result['is_ai_generated']}")
    # The raw score is scaled by 100 and shown as a percentage.
    print(f"Confidence: {result['confidence']*100:.1f}%")
|
|
|