import onnxruntime as ort import numpy as np from PIL import Image import pandas as pd import requests import io import sys # CONFIG SECTION ONNX_MODEL_PATH = "./lsnet_xl_artist-dynamo-opset18_merged.onnx" CSV_PATH = "./class_mapping.csv" IMAGE_URL = "https://cdn.donmai.us/sample/9f/bb/__vampire_s_sister_original_drawn_by_gogalking__sample-9fbb30aa76bdc8242a1c122d3d6b41d9.jpg" IMAGE_SIZE = (224, 224) TOP_K = 5 PREDICTION_THRESHOLD = 0.0 # ------------------------------------------------------------ def preprocess_image_from_url(image_url, size=(224, 224)): """ Downloads an image from a URL, preprocesses it, and prepares it for the model. """ try: response = requests.get(image_url) response.raise_for_status() # Raise an exception for bad status codes image_bytes = io.BytesIO(response.content) image = Image.open(image_bytes).convert("RGB") except requests.exceptions.RequestException as e: print(f"Error: Failed to download image from URL '{image_url}'.\nDetails: {e}") sys.exit(1) except Exception as e: print(f"Error: Could not process the downloaded image. It may not be a valid image file.\nDetails: {e}") sys.exit(1) # Resize the image image = image.resize(size, Image.Resampling.LANCZOS) # Convert image to numpy array and scale to [0, 1] image_np = np.array(image, dtype=np.float32) / 255.0 # Define ImageNet mean and standard deviation for normalization mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) std = np.array([0.229, 0.224, 0.225], dtype=np.float32) # Normalize the image normalized_image = (image_np - mean) / std # Transpose the dimensions from (H, W, C) to (C, H, W) transposed_image = normalized_image.transpose((2, 0, 1)) # Add a batch dimension to create a shape of (1, C, H, W) batched_image = np.expand_dims(transposed_image, axis=0) return batched_image def load_labels(csv_path): """ Loads the class labels from the provided CSV file into a dictionary, handling the header row and stripping quotes from names. """ try: df = pd.read_csv(csv_path) if 'class_id' not in df.columns or 'class_name' not in df.columns: print(f"Error: CSV file must have 'class_id' and 'class_name' columns.") sys.exit(1) df['class_name'] = df['class_name'].str.strip("'") return dict(zip(df['class_id'], df['class_name'])) except FileNotFoundError: print(f"Error: CSV file not found at '{csv_path}'") sys.exit(1) except Exception as e: print(f"Error reading CSV file: {e}") sys.exit(1) def softmax(x): """Compute softmax values for a set of scores.""" e_x = np.exp(x - np.max(x)) return e_x / e_x.sum(axis=0) def main(): """ Main function to run the ONNX model inference. """ print("1. Loading class labels...") labels = load_labels(CSV_PATH) print(f" Loaded {len(labels)} labels.") print("\n2. Downloading and preprocessing image from URL...") input_tensor = preprocess_image_from_url(IMAGE_URL, IMAGE_SIZE) print(f" Image shape: {input_tensor.shape}, Data type: {input_tensor.dtype}") print("\n3. Initializing ONNX runtime session...") try: session = ort.InferenceSession(ONNX_MODEL_PATH, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) input_name = session.get_inputs()[0].name output_name = session.get_outputs()[0].name print(" ONNX session created successfully.") except Exception as e: print(f"Error loading ONNX model: {e}") sys.exit(1) print("\n4. Running inference...") results = session.run([output_name], {input_name: input_tensor}) logits = results[0][0] print(" Inference complete.") print("\n5. Processing results...") probabilities = softmax(logits) top_k_indices = np.argsort(probabilities)[-TOP_K:][::-1] print(f"\n--- Predictions for image URL (Top K: {TOP_K}, Threshold: {PREDICTION_THRESHOLD:.1%}) ---") predictions_found = 0 for i, index in enumerate(top_k_indices): score = probabilities[index] if score >= PREDICTION_THRESHOLD: class_name = labels.get(index, f"Unknown Class #{index}") print(f"Rank {i+1}: {class_name} (Score: {score:.2%})") predictions_found += 1 if predictions_found == 0: print("No predictions met the specified threshold.") if __name__ == "__main__": main()