Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import numpy as np | |
| import tensorflow as tf | |
| from PIL import Image | |
| from io import BytesIO | |
app = FastAPI()

# CORS: accept browser requests from any origin, with any method and any header.
# NOTE(review): a wildcard origin together with allow_credentials=True is a
# contradictory combination under the CORS spec; Starlette appears to resolve it
# by echoing the caller's Origin back (verify against the installed version).
# Nothing in this file reads cookies or auth headers, so consider
# allow_credentials=False or an explicit origin list — confirm with the front end.
_cors_options = {
    "allow_origins": ["*"],  # narrow this to known front-end origins if needed
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Pre-trained Keras model, loaded once when this module is imported and then used
# by the prediction handler. The path is relative, so it is resolved against the
# process's current working directory, not this file's location.
MODEL_PATH = "./models/model_catdog1.h5"
model = tf.keras.models.load_model(MODEL_PATH)
# Registered at the root URL so a plain GET confirms the server is up.
# NOTE(review): "/" is inferred from the status-style message — confirm the
# intended path with whoever calls this endpoint.
@app.get("/")
def home():
    """Report that the API server is running.

    Returns:
        dict: a fixed status message, which FastAPI serialises to JSON.
    """
    return {"message": "FastAPI server is running on Hugging Face Spaces!"}
def read_image(file: UploadFile) -> Image.Image:
    """Decode an uploaded file into a Pillow image in RGB mode.

    The entire upload is read into memory, wrapped in an in-memory buffer,
    opened with Pillow and converted to RGB.

    Args:
        file: the uploaded file; its underlying ``file`` stream is read.

    Returns:
        Image.Image: the decoded image, converted to ``'RGB'``.

    Raises:
        Whatever Pillow raises for data it cannot decode (for example
        ``PIL.UnidentifiedImageError``).
    """
    raw_bytes = file.file.read()
    buffer = BytesIO(raw_bytes)
    decoded = Image.open(buffer)
    return decoded.convert('RGB')
def preprocess_image(image: "Image.Image", size=(128, 128)):
    """Turn a Pillow image into a single-image batch for the classifier.

    Args:
        image: a Pillow image; the prediction route passes the RGB image
            produced by ``read_image``.
        size: ``(width, height)`` handed to ``image.resize``. Defaults to
            ``(128, 128)``, the value this app has always used; change it if
            the loaded model expects a different input size.

    Returns:
        numpy.ndarray: the resized pixel array divided by 255.0 (so 8-bit
        values land in [0, 1]), with a leading batch axis of length 1 —
        ``(1, height, width, channels)`` for a multi-channel image.
    """
    resized = image.resize(size)
    pixels = np.array(resized) / 255.0  # scale 8-bit intensities into [0, 1]
    return np.expand_dims(pixels, axis=0)  # add the batch dimension
# Route for classifying an uploaded image.
# NOTE(review): "/predict" is an inferred path for this handler — confirm it
# matches what the front end calls.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image as "Dog" or "Cat".

    The upload is decoded with ``read_image``, resized and normalised with
    ``preprocess_image``, and scored by the module-level Keras ``model``.

    Args:
        file: the multipart form upload containing the image.

    Returns:
        JSONResponse: ``{"ok": 1, "prediction": "Dog" | "Cat"}`` on success;
        ``{"ok": -1, "message": ...}`` with status 400 when Pillow cannot
        identify the upload as an image, or with status 500 for any other
        failure (the traceback is also written to the server log).
    """
    # Function-scope imports keep this handler self-contained.
    import logging
    from PIL import UnidentifiedImageError

    # NOTE(review): decoding and model.predict are synchronous and run on the
    # event-loop thread, so other requests wait while they execute. Offloading
    # them to a worker thread would fix that but makes model calls concurrent;
    # check Keras thread-safety before doing so.
    try:
        image = read_image(file)
        batch = preprocess_image(image)
        prediction = model.predict(batch)
        # Assumes a single sigmoid output where values near 1 mean "Dog";
        # for scores in [0, 1] this matches rounding to 1. TODO confirm
        # against the training code.
        dog_score = float(prediction[0][0])
        predicted_class = "Dog" if dog_score > 0.5 else "Cat"
    except UnidentifiedImageError as e:
        # The bytes are not a recognisable image: bad client input, not a
        # server fault.
        return JSONResponse(
            content={"ok": -1, "message": f"Invalid image file! {e}"},
            status_code=400,
        )
    except Exception as e:
        # Top-level request boundary: keep the full traceback server-side.
        logging.getLogger(__name__).exception("Image classification failed")
        return JSONResponse(
            content={"ok": -1, "message": f"Something went wrong! {str(e)}"},
            status_code=500,
        )

    return JSONResponse(content={"ok": 1, "prediction": predicted_class})
if __name__ == "__main__":
    # Running this file directly serves `app` with Uvicorn, listening on all
    # network interfaces at port 7860.
    from uvicorn import run as serve_app

    serve_app(app, port=7860, host="0.0.0.0")