antcar0929 commited on
Commit
6710f72
·
verified ·
1 Parent(s): f3746a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -60
app.py CHANGED
@@ -2,75 +2,191 @@ import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
4
  import requests
 
5
  from huggingface_hub import hf_hub_download
6
 
7
- # --- 1. Fetch Route Options (KMB/CTB Only) ---
8
- print("Fetching and filtering route list...")
 
 
 
 
 
9
  try:
10
  resp = requests.get("https://hkbus.github.io/hk-bus-crawling/routeFareList.min.json")
11
  if resp.status_code == 200:
12
- data = resp.json()
13
- raw_list = data['routeList']
 
14
 
15
  # Filter Logic: Only keep KMB or CTB
16
  valid_companies = ['kmb', 'ctb']
17
- filtered_routes = []
18
 
19
- for key, info in raw_list.items():
20
- # Check if company exists and is in our allowed list
21
  if 'co' in info and len(info['co']) > 0:
22
  company = info['co'][0]
23
  if company in valid_companies:
24
- filtered_routes.append(key)
 
 
25
 
26
- all_routes = sorted(filtered_routes)
27
  else:
28
- all_routes = []
29
  except Exception as e:
30
- print(f"Error fetching routes: {e}")
31
- all_routes = []
32
 
33
- print(f"Loaded {len(all_routes)} valid KMB/CTB routes.")
34
 
35
  # --- 2. Download and Load Model ---
36
  print("Downloading model...")
37
- model_path = hf_hub_download(repo_id="WheelsTransit/HK-TransitFlow-Net", filename="hk_transit_flow_net.keras")
38
-
39
- print("Loading Keras model...")
40
- model = tf.keras.models.load_model(model_path)
 
 
 
41
 
42
- # Helper to map day names to integers
43
  DAY_MAP = {
44
  "Sunday": 0, "Monday": 1, "Tuesday": 2, "Wednesday": 3,
45
  "Thursday": 4, "Friday": 5, "Saturday": 6
46
  }
47
 
48
- # --- Helper: Search Logic ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def filter_routes(search_text):
50
- """
51
- Returns a list of routes matching the search text.
52
- Limits to 100 results to prevent browser crash.
53
- """
54
  if not search_text:
55
- return gr.Dropdown(choices=["UNKNOWN"] + all_routes[:20]) # Default top 20
56
-
57
  search_text = search_text.lower()
58
- # Filter list
59
- filtered = [r for r in all_routes if search_text in r.lower()]
60
-
61
- # Cap at 100 results
62
  return gr.Dropdown(choices=["UNKNOWN"] + filtered[:100], value="UNKNOWN")
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- # --- Prediction Logic ---
66
- def predict_eta(distance_meters, num_stops, hour, day_name, route_id):
67
  try:
68
- if not route_id or route_id.strip() == "":
69
- route_id = "UNKNOWN"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  inputs = {
72
- 'distance': np.array([[float(distance_meters)]]),
73
- 'num_stops': np.array([[float(num_stops)]]),
74
  'hour': np.array([[int(hour)]]),
75
  'day_of_week': np.array([[int(DAY_MAP[day_name])]]),
76
  'route_id': tf.constant([[str(route_id)]], dtype=tf.string)
@@ -82,46 +198,67 @@ def predict_eta(distance_meters, num_stops, hour, day_name, route_id):
82
  minutes = int(seconds // 60)
83
  rem_seconds = int(seconds % 60)
84
 
85
- return f"{minutes} min {rem_seconds} sec ({seconds:.1f}s)"
 
 
 
86
 
87
  except Exception as e:
88
- return f"Error: {str(e)}"
 
89
 
90
  # --- 3. Build the UI ---
91
- with gr.Blocks() as demo:
92
  gr.Markdown("# HK-TransitFlow-Net Demo")
93
- gr.Markdown("Live inference for **KMB & CTB** Bus ETA prediction.")
94
- gr.Markdown("Predicts estimated journey time for a distance.")
95
- gr.Markdown("Model URL: https://huggingface.co/WheelsTransit/HK-TransitFlow-Net")
96
 
97
  with gr.Row():
 
98
  with gr.Column():
99
- gr.Markdown("### 1. Trip Details")
100
- dist_input = gr.Number(label="Distance (meters)", value=5000)
101
- stops_input = gr.Number(label="Number of Stops", value=10)
102
- hour_input = gr.Slider(minimum=0, maximum=23, step=1, label="Hour of Day (0-23)", value=9)
103
- day_input = gr.Dropdown(choices=list(DAY_MAP.keys()), label="Day of Week", value="Monday")
104
 
105
- with gr.Column():
106
- gr.Markdown("### 2. Route Selection (Optional - Select UNKNOWN to predict without route)")
107
- gr.Markdown("*Type in the box below to find your route (e.g. '968')*")
108
-
109
- # Search Box
110
- route_search = gr.Textbox(label="Search Route Number", placeholder="Type route number...")
111
 
112
- # Dropdown
113
- route_input = gr.Dropdown(label="Select Route ID", choices=["UNKNOWN"], value="UNKNOWN", interactive=True)
 
 
 
114
 
 
 
 
 
 
115
  predict_btn = gr.Button("Predict ETA", variant="primary")
116
- output_text = gr.Textbox(label="Estimated Travel Time", lines=1)
117
 
118
- # --- Interaction Logic ---
119
- route_search.change(fn=filter_routes, inputs=route_search, outputs=route_input)
 
 
 
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  predict_btn.click(
122
- fn=predict_eta,
123
- inputs=[dist_input, stops_input, hour_input, day_input, route_input],
124
- outputs=output_text
125
  )
126
 
127
  if __name__ == "__main__":
 
2
  import tensorflow as tf
3
  import numpy as np
4
  import requests
5
+ import math
6
  from huggingface_hub import hf_hub_download
7
 
8
+ # --- Global Data Storage ---
9
+ ROUTE_DATA = {}
10
+ STOP_DATA = {}
11
+ ALL_ROUTE_KEYS = []
12
+
13
+ # --- 1. Fetch Data (Routes & Stops) ---
14
+ print("Fetching route and stop data...")
15
  try:
16
  resp = requests.get("https://hkbus.github.io/hk-bus-crawling/routeFareList.min.json")
17
  if resp.status_code == 200:
18
+ json_db = resp.json()
19
+ raw_routes = json_db['routeList']
20
+ STOP_DATA = json_db['stopList']
21
 
22
  # Filter Logic: Only keep KMB or CTB
23
  valid_companies = ['kmb', 'ctb']
 
24
 
25
+ for key, info in raw_routes.items():
 
26
  if 'co' in info and len(info['co']) > 0:
27
  company = info['co'][0]
28
  if company in valid_companies:
29
+ # Store the whole object so we can look up stops later
30
+ ROUTE_DATA[key] = info
31
+ ALL_ROUTE_KEYS.append(key)
32
 
33
+ ALL_ROUTE_KEYS.sort()
34
  else:
35
+ print("Failed to download route data")
36
  except Exception as e:
37
+ print(f"Error fetching data: {e}")
 
38
 
39
+ print(f"Loaded {len(ALL_ROUTE_KEYS)} valid KMB/CTB routes.")
40
 
41
  # --- 2. Download and Load Model ---
42
  print("Downloading model...")
43
+ try:
44
+ model_path = hf_hub_download(repo_id="WheelsTransit/HK-TransitFlow-Net", filename="hk_transit_flow_net.keras")
45
+ print("Loading Keras model...")
46
+ model = tf.keras.models.load_model(model_path)
47
+ except Exception as e:
48
+ print(f"Model load failed: {e}")
49
+ model = None
50
 
51
+ # --- Helpers ---
52
  DAY_MAP = {
53
  "Sunday": 0, "Monday": 1, "Tuesday": 2, "Wednesday": 3,
54
  "Thursday": 4, "Friday": 5, "Saturday": 6
55
  }
56
 
57
+ def haversine_distance(coords):
58
+ """Calculates length of a coordinate list in meters."""
59
+ R = 6371000
60
+ total_dist = 0
61
+ for i in range(len(coords) - 1):
62
+ lon1, lat1 = coords[i]
63
+ lon2, lat2 = coords[i+1]
64
+ dlon = math.radians(lon2 - lon1)
65
+ dlat = math.radians(lat2 - lat1)
66
+ a = math.sin(dlat/2)**2 + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlon/2)**2
67
+ c = 2 * math.asin(math.sqrt(a))
68
+ total_dist += R * c
69
+ return total_dist
70
+
71
+ # --- Dynamic UI Logic ---
72
+
73
  def filter_routes(search_text):
74
+ """Filters the route dropdown based on search text."""
 
 
 
75
  if not search_text:
76
+ return gr.Dropdown(choices=["UNKNOWN"] + ALL_ROUTE_KEYS[:20])
 
77
  search_text = search_text.lower()
78
+ filtered = [r for r in ALL_ROUTE_KEYS if search_text in r.lower()]
 
 
 
79
  return gr.Dropdown(choices=["UNKNOWN"] + filtered[:100], value="UNKNOWN")
80
 
81
+ def update_stop_dropdowns(route_key):
82
+ """
83
+ When a route is selected, fetch its stops and populate the Start/End dropdowns.
84
+ Returns: (Start_Dropdown_Update, End_Dropdown_Update)
85
+ """
86
+ if not route_key or route_key == "UNKNOWN" or route_key not in ROUTE_DATA:
87
+ return gr.Dropdown(choices=[], value=None), gr.Dropdown(choices=[], value=None)
88
+
89
+ route_info = ROUTE_DATA[route_key]
90
+ company = route_info['co'][0]
91
+ stop_ids = route_info['stops'].get(company, [])
92
+
93
+ # Create readable names: "1. StopName (ID)"
94
+ stop_options = []
95
+ for idx, sid in enumerate(stop_ids):
96
+ # Fetch name from STOP_DATA
97
+ name_en = "Unknown"
98
+ if sid in STOP_DATA:
99
+ name_en = STOP_DATA[sid]['name']['en']
100
+
101
+ label = f"{idx+1}. {name_en} ({sid})"
102
+ stop_options.append(label)
103
+
104
+ return gr.Dropdown(choices=stop_options, value=None), gr.Dropdown(choices=stop_options, value=None)
105
+
106
+ def calculate_real_metrics(route_key, start_str, end_str):
107
+ """
108
+ Downloads Waypoints and calculates actual distance/stops between two selected stops.
109
+ """
110
+ if route_key == "UNKNOWN" or not start_str or not end_str:
111
+ return None, None, "Please select a Route, Start Stop, and End Stop."
112
 
 
 
113
  try:
114
+ # Extract Index from string "1. Name (ID)"
115
+ start_idx = int(start_str.split(".")[0]) - 1
116
+ end_idx = int(end_str.split(".")[0]) - 1
117
+
118
+ if start_idx >= end_idx:
119
+ return None, None, "Error: Start Stop must be before End Stop."
120
+
121
+ # Fetch Route Info for GTFS ID
122
+ route_info = ROUTE_DATA[route_key]
123
+ gtfs_id = route_info.get('gtfsId')
124
+ company = route_info['co'][0]
125
+ bound = route_info['bound'].get(company)
126
+
127
+ if not gtfs_id or not bound:
128
+ return None, None, "Error: No GTFS data for this route."
129
+
130
+ # Download Waypoints
131
+ url = f"https://hkbus.github.io/route-waypoints/{gtfs_id}-{bound}.json"
132
+ resp = requests.get(url)
133
+ if resp.status_code != 200:
134
+ return None, None, "Error: Could not download route path data."
135
 
136
+ geojson = resp.json()
137
+ features = geojson.get('features', [])
138
+
139
+ # Determine Segments
140
+ # Logic: Feature[i] is path from Stop[i] to Stop[i+1]
141
+ segments = []
142
+ if features and features[0]['geometry']['type'] == 'MultiLineString':
143
+ segments = features[0]['geometry']['coordinates']
144
+ elif features:
145
+ segments = [f['geometry']['coordinates'] for f in features]
146
+
147
+ # Sum distance for the specific range
148
+ total_dist = 0
149
+ # We need to sum segments from start_idx to end_idx - 1
150
+ # E.g. Stop 0 to Stop 2 requires Segment 0 (0->1) and Segment 1 (1->2)
151
+
152
+ for i in range(start_idx, end_idx):
153
+ if i < len(segments):
154
+ total_dist += haversine_distance(segments[i])
155
+
156
+ num_stops = end_idx - start_idx
157
+
158
+ return total_dist, num_stops, None # None = No error
159
+
160
+ except Exception as e:
161
+ return None, None, f"Calculation Error: {str(e)}"
162
+
163
+ # --- Prediction Logic ---
164
+
165
+ def smart_predict(manual_dist, manual_stops, hour, day_name, route_id, start_stop, end_stop):
166
+ status_msg = ""
167
+
168
+ # 1. Automatic Calculation Logic
169
+ if route_id != "UNKNOWN" and start_stop and end_stop:
170
+ calc_dist, calc_stops, error = calculate_real_metrics(route_id, start_stop, end_stop)
171
+
172
+ if error:
173
+ status_msg = f"⚠️ {error} Using manual inputs."
174
+ final_dist = manual_dist
175
+ final_stops = manual_stops
176
+ else:
177
+ final_dist = calc_dist
178
+ final_stops = calc_stops
179
+ status_msg = f"✅ Calculated from map: {final_dist:.0f}m / {final_stops} stops."
180
+ else:
181
+ final_dist = manual_dist
182
+ final_stops = manual_stops
183
+ status_msg = "ℹ️ Using manual inputs."
184
+
185
+ # 2. Model Inference
186
+ try:
187
  inputs = {
188
+ 'distance': np.array([[float(final_dist)]]),
189
+ 'num_stops': np.array([[float(final_stops)]]),
190
  'hour': np.array([[int(hour)]]),
191
  'day_of_week': np.array([[int(DAY_MAP[day_name])]]),
192
  'route_id': tf.constant([[str(route_id)]], dtype=tf.string)
 
198
  minutes = int(seconds // 60)
199
  rem_seconds = int(seconds % 60)
200
 
201
+ result_str = f"⏱️ ETA: {minutes} min {rem_seconds} sec"
202
+
203
+ # Return updated boxes + result
204
+ return final_dist, final_stops, f"{status_msg}\n\n{result_str}"
205
 
206
  except Exception as e:
207
+ return final_dist, final_stops, f"Model Error: {str(e)}"
208
+
209
 
210
  # --- 3. Build the UI ---
211
+ with gr.Blocks(title="HK-TransitFlow-Net") as demo:
212
  gr.Markdown("# HK-TransitFlow-Net Demo")
213
+ gr.Markdown("Predicts KMB/CTB bus travel time. Select a route and stops to auto-calculate distance.")
 
 
214
 
215
  with gr.Row():
216
+ # Left Column: Inputs
217
  with gr.Column():
218
+ gr.Markdown("### 1. Route Selection")
219
+ route_search = gr.Textbox(label="Search Route", placeholder="Type e.g. '968'")
220
+ route_dropdown = gr.Dropdown(label="Select Route ID", choices=["UNKNOWN"], value="UNKNOWN", interactive=True)
 
 
221
 
222
+ with gr.Row():
223
+ start_dropdown = gr.Dropdown(label="Start Stop", choices=[], interactive=True)
224
+ end_dropdown = gr.Dropdown(label="End Stop", choices=[], interactive=True)
 
 
 
225
 
226
+ gr.Markdown("---")
227
+ gr.Markdown("### 2. Time & Details")
228
+ with gr.Row():
229
+ hour_input = gr.Slider(minimum=0, maximum=23, step=1, label="Hour (0-23)", value=9)
230
+ day_input = gr.Dropdown(choices=list(DAY_MAP.keys()), label="Day", value="Monday")
231
 
232
+ with gr.Row():
233
+ # These update automatically if stops are picked
234
+ dist_input = gr.Number(label="Distance (m)", value=5000)
235
+ stops_input = gr.Number(label="Stops Count", value=10)
236
+
237
  predict_btn = gr.Button("Predict ETA", variant="primary")
 
238
 
239
+ # Right Column: Output
240
+ with gr.Column():
241
+ gr.Markdown("### Result")
242
+ output_text = gr.Textbox(label="Prediction", lines=4)
243
+ gr.Markdown("*Note: If you select stops, distance is calculated automatically from the map.*")
244
 
245
+ # --- Event Wiring ---
246
+
247
+ # 1. Search filter
248
+ route_search.change(fn=filter_routes, inputs=route_search, outputs=route_dropdown)
249
+
250
+ # 2. Populate Stops when Route Selected
251
+ route_dropdown.change(
252
+ fn=update_stop_dropdowns,
253
+ inputs=route_dropdown,
254
+ outputs=[start_dropdown, end_dropdown]
255
+ )
256
+
257
+ # 3. Predict Button
258
  predict_btn.click(
259
+ fn=smart_predict,
260
+ inputs=[dist_input, stops_input, hour_input, day_input, route_dropdown, start_dropdown, end_dropdown],
261
+ outputs=[dist_input, stops_input, output_text] # Updates the number boxes too!
262
  )
263
 
264
  if __name__ == "__main__":