antcar0929 commited on
Commit
42bfcc1
·
verified ·
1 Parent(s): 3510c8e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ import os
5
+
6
+ # 1. Load the Model
7
+ # We expect the .keras file to be in the same directory
8
+ model = tf.keras.models.load_model("hk_transit_flow_net.keras")
9
+
10
+ # Helper to map day names to integers
11
+ DAY_MAP = {
12
+ "Sunday": 0, "Monday": 1, "Tuesday": 2, "Wednesday": 3,
13
+ "Thursday": 4, "Friday": 5, "Saturday": 6
14
+ }
15
+
16
+ def predict_eta(distance_meters, num_stops, hour, day_name, route_id):
17
+ try:
18
+ # 1. Prepare Inputs
19
+ # We must match the exact shape and types used in training
20
+
21
+ # Handle empty route
22
+ if not route_id or route_id.strip() == "":
23
+ route_id = "UNKNOWN"
24
+
25
+ inputs = {
26
+ 'distance': np.array([[float(distance_meters)]]),
27
+ 'num_stops': np.array([[float(num_stops)]]),
28
+ 'hour': np.array([[int(hour)]]),
29
+ 'day_of_week': np.array([[int(DAY_MAP[day_name])]]),
30
+ 'route_id': tf.constant([[str(route_id)]], dtype=tf.string)
31
+ }
32
+
33
+ # 2. Run Prediction
34
+ prediction = model.predict(inputs, verbose=0)
35
+ seconds = float(prediction[0][0])
36
+
37
+ # 3. Format Output
38
+ minutes = int(seconds // 60)
39
+ rem_seconds = int(seconds % 60)
40
+
41
+ return f"{minutes} min {rem_seconds} sec ({seconds:.1f}s)"
42
+
43
+ except Exception as e:
44
+ return f"Error: {str(e)}"
45
+
46
+ # 3. Build the Interface
47
+ iface = gr.Interface(
48
+ fn=predict_eta,
49
+ inputs=[
50
+ gr.Number(label="Distance (meters)", value=5000),
51
+ gr.Number(label="Number of Stops", value=10),
52
+ gr.Slider(minimum=0, maximum=23, step=1, label="Hour of Day (0-23)", value=9),
53
+ gr.Dropdown(choices=list(DAY_MAP.keys()), label="Day of Week", value="Monday"),
54
+ gr.Textbox(label="Route ID (Optional)", placeholder="e.g. 968+1+Yuen Long+Tin Hau", value="UNKNOWN")
55
+ ],
56
+ outputs="text",
57
+ title="HK-TransitFlow-Net 🚌",
58
+ description="""
59
+ **Hong Kong Bus ETA Predictor**
60
+
61
+ This model uses Deep Learning to predict bus travel time based on distance, stops, and time context.
62
+
63
+ * **Distance:** Physical distance of the path in meters.
64
+ * **Route ID:** Internal ID (e.g., `968+1+...`). If unknown, leave as UNKNOWN.
65
+ * **Note:** Trained on KMB & CTB data.
66
+ """,
67
+ theme="soft"
68
+ )
69
+
70
+ # 4. Launch
71
+ if __name__ == "__main__":
72
+ iface.launch()