MrUtakata commited on
Commit
c625993
·
verified ·
1 Parent(s): 8964dee

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import os
5
+ import gdown
6
+ import pickle
7
+ import joblib
8
+
9
+ # ---------------------- Download and Load Ensemble Model ----------------------
10
+ # Path to save the downloaded model
11
+ ensemble_model_path = "ensemble_model.pkl"
12
+
13
+ # Download the model if it does not exist yet
14
+ if not os.path.exists(ensemble_model_path):
15
+ url = "https://drive.google.com/uc?export=download&id=1jHtHOzfhtWMyYqX_pbQJ5akYAe6ZhfhU"
16
+ gdown.download(url, ensemble_model_path, quiet=False)
17
+
18
+ # Load the ensemble model using pickle
19
+ with open(ensemble_model_path, "rb") as f:
20
+ ensemble = pickle.load(f)
21
+
22
+ # ---------------------- Load Preprocessing Objects ----------------------
23
+ # Ensure that the files onehotencoder.pkl and scaler.pkl are in the same directory.
24
+ encoder = joblib.load("onehotencoder.pkl")
25
+ scaler = joblib.load("scaler.pkl")
26
+
27
+ # ---------------------- Set up the Streamlit App ----------------------
28
+ st.title("Customer Churn Predictor")
29
+ st.write("""
30
+ This app uses a trained machine learning model to predict whether a customer is likely to churn.
31
+ Please enter the customer details below.
32
+ """)
33
+
34
+ # ---------------------- User Inputs ----------------------
35
+ st.header("Customer Details")
36
+ age = st.number_input("Age", min_value=18, max_value=100, value=30)
37
+ tenure = st.number_input("Tenure (in months)", min_value=0, max_value=120, value=12)
38
+ usage_frequency = st.number_input("Usage Frequency", min_value=0, max_value=100, value=5)
39
+ support_calls = st.number_input("Support Calls", min_value=0, max_value=50, value=2)
40
+ total_spend = st.number_input("Total Spend", min_value=0.0, max_value=10000.0, value=100.0, step=10.0)
41
+
42
+ # Input for categorical fields (modify options as needed for your training data)
43
+ gender = st.selectbox("Gender", options=["Male", "Female"])
44
+ subscription_type = st.selectbox("Subscription Type", options=["Type A", "Type B", "Type C"])
45
+ contract_length = st.selectbox("Contract Length", options=["Monthly", "Quarterly", "Yearly"])
46
+
47
+ # Create a DataFrame for the input
48
+ input_df = pd.DataFrame({
49
+ "Age": [age],
50
+ "Tenure": [tenure],
51
+ "Usage Frequency": [usage_frequency],
52
+ "Support Calls": [support_calls],
53
+ "Total Spend": [total_spend],
54
+ "Gender": [gender],
55
+ "Subscription Type": [subscription_type],
56
+ "Contract Length": [contract_length]
57
+ })
58
+
59
+ st.write("### Input Data")
60
+ st.write(input_df)
61
+
62
+ # ---------------------- Preprocessing ----------------------
63
+ # 1. Encode categorical features using the loaded OneHotEncoder.
64
+ categorical_cols = ["Gender", "Subscription Type", "Contract Length"]
65
+ encoded_cat = encoder.transform(input_df[categorical_cols])
66
+ encoded_cat_df = pd.DataFrame(encoded_cat, columns=encoder.get_feature_names_out(categorical_cols))
67
+
68
+ # 2. Drop original categorical columns and concatenate encoded features.
69
+ input_df = input_df.drop(categorical_cols, axis=1)
70
+ input_transformed = pd.concat([input_df.reset_index(drop=True), encoded_cat_df.reset_index(drop=True)], axis=1)
71
+
72
+ # 3. Standardize numerical features using the loaded StandardScaler.
73
+ input_scaled = scaler.transform(input_transformed)
74
+
75
+ # ---------------------- Run Prediction ----------------------
76
+ if st.button("Predict Churn"):
77
+ prediction = ensemble.predict(input_scaled)
78
+ prediction_proba = ensemble.predict_proba(input_scaled)[:, 1]
79
+
80
+ # Interpret prediction: assuming 1 = churn, 0 = not churn
81
+ result = "Churned" if prediction[0] == 1 else "Not Churned"
82
+ st.write("### Prediction Results")
83
+ st.write(f"**Prediction:** {result}")
84
+ st.write(f"**Churn Probability:** {prediction_proba[0]:.2f}")