antcar0929 commited on
Commit
7405298
·
verified ·
1 Parent(s): 6710f72

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -66
app.py CHANGED
@@ -19,14 +19,12 @@ try:
19
  raw_routes = json_db['routeList']
20
  STOP_DATA = json_db['stopList']
21
 
22
- # Filter Logic: Only keep KMB or CTB
23
  valid_companies = ['kmb', 'ctb']
24
 
25
  for key, info in raw_routes.items():
26
  if 'co' in info and len(info['co']) > 0:
27
  company = info['co'][0]
28
  if company in valid_companies:
29
- # Store the whole object so we can look up stops later
30
  ROUTE_DATA[key] = info
31
  ALL_ROUTE_KEYS.append(key)
32
 
@@ -55,7 +53,6 @@ DAY_MAP = {
55
  }
56
 
57
  def haversine_distance(coords):
58
- """Calculates length of a coordinate list in meters."""
59
  R = 6371000
60
  total_dist = 0
61
  for i in range(len(coords) - 1):
@@ -71,7 +68,6 @@ def haversine_distance(coords):
71
  # --- Dynamic UI Logic ---
72
 
73
  def filter_routes(search_text):
74
- """Filters the route dropdown based on search text."""
75
  if not search_text:
76
  return gr.Dropdown(choices=["UNKNOWN"] + ALL_ROUTE_KEYS[:20])
77
  search_text = search_text.lower()
@@ -79,10 +75,6 @@ def filter_routes(search_text):
79
  return gr.Dropdown(choices=["UNKNOWN"] + filtered[:100], value="UNKNOWN")
80
 
81
  def update_stop_dropdowns(route_key):
82
- """
83
- When a route is selected, fetch its stops and populate the Start/End dropdowns.
84
- Returns: (Start_Dropdown_Update, End_Dropdown_Update)
85
- """
86
  if not route_key or route_key == "UNKNOWN" or route_key not in ROUTE_DATA:
87
  return gr.Dropdown(choices=[], value=None), gr.Dropdown(choices=[], value=None)
88
 
@@ -90,103 +82,92 @@ def update_stop_dropdowns(route_key):
90
  company = route_info['co'][0]
91
  stop_ids = route_info['stops'].get(company, [])
92
 
93
- # Create readable names: "1. StopName (ID)"
94
  stop_options = []
95
  for idx, sid in enumerate(stop_ids):
96
- # Fetch name from STOP_DATA
97
  name_en = "Unknown"
98
  if sid in STOP_DATA:
99
  name_en = STOP_DATA[sid]['name']['en']
100
-
101
  label = f"{idx+1}. {name_en} ({sid})"
102
  stop_options.append(label)
103
 
104
  return gr.Dropdown(choices=stop_options, value=None), gr.Dropdown(choices=stop_options, value=None)
105
 
106
  def calculate_real_metrics(route_key, start_str, end_str):
107
- """
108
- Downloads Waypoints and calculates actual distance/stops between two selected stops.
109
- """
110
  if route_key == "UNKNOWN" or not start_str or not end_str:
111
- return None, None, "Please select a Route, Start Stop, and End Stop."
112
 
113
  try:
114
- # Extract Index from string "1. Name (ID)"
115
  start_idx = int(start_str.split(".")[0]) - 1
116
  end_idx = int(end_str.split(".")[0]) - 1
117
 
118
  if start_idx >= end_idx:
119
- return None, None, "Error: Start Stop must be before End Stop."
120
 
121
- # Fetch Route Info for GTFS ID
122
  route_info = ROUTE_DATA[route_key]
123
  gtfs_id = route_info.get('gtfsId')
124
  company = route_info['co'][0]
125
  bound = route_info['bound'].get(company)
126
 
127
  if not gtfs_id or not bound:
128
- return None, None, "Error: No GTFS data for this route."
129
 
130
- # Download Waypoints
131
  url = f"https://hkbus.github.io/route-waypoints/{gtfs_id}-{bound}.json"
132
  resp = requests.get(url)
133
  if resp.status_code != 200:
134
- return None, None, "Error: Could not download route path data."
135
 
136
  geojson = resp.json()
137
  features = geojson.get('features', [])
138
 
139
- # Determine Segments
140
- # Logic: Feature[i] is path from Stop[i] to Stop[i+1]
141
  segments = []
142
  if features and features[0]['geometry']['type'] == 'MultiLineString':
143
  segments = features[0]['geometry']['coordinates']
144
  elif features:
145
  segments = [f['geometry']['coordinates'] for f in features]
146
 
147
- # Sum distance for the specific range
148
  total_dist = 0
149
- # We need to sum segments from start_idx to end_idx - 1
150
- # E.g. Stop 0 to Stop 2 requires Segment 0 (0->1) and Segment 1 (1->2)
151
-
152
  for i in range(start_idx, end_idx):
153
  if i < len(segments):
154
  total_dist += haversine_distance(segments[i])
155
 
156
  num_stops = end_idx - start_idx
157
-
158
- return total_dist, num_stops, None # None = No error
159
 
160
  except Exception as e:
161
- return None, None, f"Calculation Error: {str(e)}"
 
 
 
 
 
 
 
 
162
 
163
  # --- Prediction Logic ---
164
 
165
- def smart_predict(manual_dist, manual_stops, hour, day_name, route_id, start_stop, end_stop):
166
- status_msg = ""
 
 
 
167
 
168
- # 1. Automatic Calculation Logic
169
- if route_id != "UNKNOWN" and start_stop and end_stop:
170
- calc_dist, calc_stops, error = calculate_real_metrics(route_id, start_stop, end_stop)
 
171
 
172
- if error:
173
- status_msg = f"⚠️ {error} Using manual inputs."
174
- final_dist = manual_dist
175
- final_stops = manual_stops
176
  else:
177
- final_dist = calc_dist
178
- final_stops = calc_stops
179
- status_msg = f"✅ Calculated from map: {final_dist:.0f}m / {final_stops} stops."
180
  else:
181
- final_dist = manual_dist
182
- final_stops = manual_stops
183
- status_msg = "ℹ️ Using manual inputs."
184
 
185
- # 2. Model Inference
186
  try:
187
  inputs = {
188
- 'distance': np.array([[float(final_dist)]]),
189
- 'num_stops': np.array([[float(final_stops)]]),
190
  'hour': np.array([[int(hour)]]),
191
  'day_of_week': np.array([[int(DAY_MAP[day_name])]]),
192
  'route_id': tf.constant([[str(route_id)]], dtype=tf.string)
@@ -194,28 +175,25 @@ def smart_predict(manual_dist, manual_stops, hour, day_name, route_id, start_sto
194
 
195
  prediction = model.predict(inputs, verbose=0)
196
  seconds = float(prediction[0][0])
197
-
198
  minutes = int(seconds // 60)
199
  rem_seconds = int(seconds % 60)
200
 
201
- result_str = f"⏱️ ETA: {minutes} min {rem_seconds} sec"
202
-
203
- # Return updated boxes + result
204
- return final_dist, final_stops, f"{status_msg}\n\n{result_str}"
205
 
206
  except Exception as e:
207
- return final_dist, final_stops, f"Model Error: {str(e)}"
208
 
209
 
210
  # --- 3. Build the UI ---
211
  with gr.Blocks(title="HK-TransitFlow-Net") as demo:
212
  gr.Markdown("# HK-TransitFlow-Net Demo")
213
- gr.Markdown("Predicts KMB/CTB bus travel time. Select a route and stops to auto-calculate distance.")
214
 
215
  with gr.Row():
216
- # Left Column: Inputs
217
  with gr.Column():
218
- gr.Markdown("### 1. Route Selection")
 
 
219
  route_search = gr.Textbox(label="Search Route", placeholder="Type e.g. '968'")
220
  route_dropdown = gr.Dropdown(label="Select Route ID", choices=["UNKNOWN"], value="UNKNOWN", interactive=True)
221
 
@@ -230,35 +208,44 @@ with gr.Blocks(title="HK-TransitFlow-Net") as demo:
230
  day_input = gr.Dropdown(choices=list(DAY_MAP.keys()), label="Day", value="Monday")
231
 
232
  with gr.Row():
233
- # These update automatically if stops are picked
234
  dist_input = gr.Number(label="Distance (m)", value=5000)
235
  stops_input = gr.Number(label="Stops Count", value=10)
236
 
237
  predict_btn = gr.Button("Predict ETA", variant="primary")
238
 
239
- # Right Column: Output
240
  with gr.Column():
241
  gr.Markdown("### Result")
242
- output_text = gr.Textbox(label="Prediction", lines=4)
243
- gr.Markdown("*Note: If you select stops, distance is calculated automatically from the map.*")
244
 
245
  # --- Event Wiring ---
246
 
247
- # 1. Search filter
248
  route_search.change(fn=filter_routes, inputs=route_search, outputs=route_dropdown)
249
 
250
- # 2. Populate Stops when Route Selected
251
  route_dropdown.change(
252
  fn=update_stop_dropdowns,
253
  inputs=route_dropdown,
254
  outputs=[start_dropdown, end_dropdown]
255
  )
256
 
257
- # 3. Predict Button
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  predict_btn.click(
259
- fn=smart_predict,
260
  inputs=[dist_input, stops_input, hour_input, day_input, route_dropdown, start_dropdown, end_dropdown],
261
- outputs=[dist_input, stops_input, output_text] # Updates the number boxes too!
262
  )
263
 
264
  if __name__ == "__main__":
 
19
  raw_routes = json_db['routeList']
20
  STOP_DATA = json_db['stopList']
21
 
 
22
  valid_companies = ['kmb', 'ctb']
23
 
24
  for key, info in raw_routes.items():
25
  if 'co' in info and len(info['co']) > 0:
26
  company = info['co'][0]
27
  if company in valid_companies:
 
28
  ROUTE_DATA[key] = info
29
  ALL_ROUTE_KEYS.append(key)
30
 
 
53
  }
54
 
55
  def haversine_distance(coords):
 
56
  R = 6371000
57
  total_dist = 0
58
  for i in range(len(coords) - 1):
 
68
  # --- Dynamic UI Logic ---
69
 
70
  def filter_routes(search_text):
 
71
  if not search_text:
72
  return gr.Dropdown(choices=["UNKNOWN"] + ALL_ROUTE_KEYS[:20])
73
  search_text = search_text.lower()
 
75
  return gr.Dropdown(choices=["UNKNOWN"] + filtered[:100], value="UNKNOWN")
76
 
77
  def update_stop_dropdowns(route_key):
 
 
 
 
78
  if not route_key or route_key == "UNKNOWN" or route_key not in ROUTE_DATA:
79
  return gr.Dropdown(choices=[], value=None), gr.Dropdown(choices=[], value=None)
80
 
 
82
  company = route_info['co'][0]
83
  stop_ids = route_info['stops'].get(company, [])
84
 
 
85
  stop_options = []
86
  for idx, sid in enumerate(stop_ids):
 
87
  name_en = "Unknown"
88
  if sid in STOP_DATA:
89
  name_en = STOP_DATA[sid]['name']['en']
 
90
  label = f"{idx+1}. {name_en} ({sid})"
91
  stop_options.append(label)
92
 
93
  return gr.Dropdown(choices=stop_options, value=None), gr.Dropdown(choices=stop_options, value=None)
94
 
95
  def calculate_real_metrics(route_key, start_str, end_str):
 
 
 
96
  if route_key == "UNKNOWN" or not start_str or not end_str:
97
+ return None, None, "Wait"
98
 
99
  try:
 
100
  start_idx = int(start_str.split(".")[0]) - 1
101
  end_idx = int(end_str.split(".")[0]) - 1
102
 
103
  if start_idx >= end_idx:
104
+ return None, None, "Start must be before End"
105
 
 
106
  route_info = ROUTE_DATA[route_key]
107
  gtfs_id = route_info.get('gtfsId')
108
  company = route_info['co'][0]
109
  bound = route_info['bound'].get(company)
110
 
111
  if not gtfs_id or not bound:
112
+ return None, None, "No Map Data"
113
 
 
114
  url = f"https://hkbus.github.io/route-waypoints/{gtfs_id}-{bound}.json"
115
  resp = requests.get(url)
116
  if resp.status_code != 200:
117
+ return None, None, "Map Download Fail"
118
 
119
  geojson = resp.json()
120
  features = geojson.get('features', [])
121
 
 
 
122
  segments = []
123
  if features and features[0]['geometry']['type'] == 'MultiLineString':
124
  segments = features[0]['geometry']['coordinates']
125
  elif features:
126
  segments = [f['geometry']['coordinates'] for f in features]
127
 
 
128
  total_dist = 0
 
 
 
129
  for i in range(start_idx, end_idx):
130
  if i < len(segments):
131
  total_dist += haversine_distance(segments[i])
132
 
133
  num_stops = end_idx - start_idx
134
+ return total_dist, num_stops, None
 
135
 
136
  except Exception as e:
137
+ return None, None, str(e)
138
+
139
+ def auto_fill_metrics(route_key, start_str, end_str, current_dist, current_stops):
140
+ """Updates boxes when stops change."""
141
+ dist, stops, error = calculate_real_metrics(route_key, start_str, end_str)
142
+ if dist is not None and stops is not None:
143
+ return round(dist, 1), int(stops)
144
+ else:
145
+ return current_dist, current_stops
146
 
147
  # --- Prediction Logic ---
148
 
149
+ def predict_fn(manual_dist, manual_stops, hour, day_name, route_id, start_str, end_str):
150
+ status_tag = ""
151
+
152
+ # Check if inputs match the map data (Validation Check)
153
+ map_dist, map_stops, _ = calculate_real_metrics(route_id, start_str, end_str)
154
 
155
+ if map_dist is not None:
156
+ # We have map data. Check if manual input differs significantly.
157
+ diff_dist = abs(map_dist - float(manual_dist))
158
+ diff_stops = abs(map_stops - float(manual_stops))
159
 
160
+ if diff_dist > 50 or diff_stops > 0: # 50m tolerance
161
+ status_tag = "ℹ️ Based on Manual Inputs (Modified)"
 
 
162
  else:
163
+ status_tag = "✅ Based on Map Data"
 
 
164
  else:
165
+ status_tag = "ℹ️ Based on Manual Inputs"
 
 
166
 
 
167
  try:
168
  inputs = {
169
+ 'distance': np.array([[float(manual_dist)]]),
170
+ 'num_stops': np.array([[float(manual_stops)]]),
171
  'hour': np.array([[int(hour)]]),
172
  'day_of_week': np.array([[int(DAY_MAP[day_name])]]),
173
  'route_id': tf.constant([[str(route_id)]], dtype=tf.string)
 
175
 
176
  prediction = model.predict(inputs, verbose=0)
177
  seconds = float(prediction[0][0])
 
178
  minutes = int(seconds // 60)
179
  rem_seconds = int(seconds % 60)
180
 
181
+ return f"{status_tag}\n\n⏱️ ETA: {minutes} min {rem_seconds} sec"
 
 
 
182
 
183
  except Exception as e:
184
+ return f"Model Error: {str(e)}"
185
 
186
 
187
  # --- 3. Build the UI ---
188
  with gr.Blocks(title="HK-TransitFlow-Net") as demo:
189
  gr.Markdown("# HK-TransitFlow-Net Demo")
190
+ gr.Markdown("Predicts KMB/CTB bus travel time.")
191
 
192
  with gr.Row():
 
193
  with gr.Column():
194
+ gr.Markdown("### 1. Route Selection (Optional)")
195
+ gr.Markdown("Select a route to auto-fill distance, or skip to type manually.")
196
+
197
  route_search = gr.Textbox(label="Search Route", placeholder="Type e.g. '968'")
198
  route_dropdown = gr.Dropdown(label="Select Route ID", choices=["UNKNOWN"], value="UNKNOWN", interactive=True)
199
 
 
208
  day_input = gr.Dropdown(choices=list(DAY_MAP.keys()), label="Day", value="Monday")
209
 
210
  with gr.Row():
 
211
  dist_input = gr.Number(label="Distance (m)", value=5000)
212
  stops_input = gr.Number(label="Stops Count", value=10)
213
 
214
  predict_btn = gr.Button("Predict ETA", variant="primary")
215
 
 
216
  with gr.Column():
217
  gr.Markdown("### Result")
218
+ output_text = gr.Textbox(label="Prediction", lines=3)
219
+ gr.Markdown("*Tip: If you modify the Distance/Stops boxes manually, the model will use your typed values.*")
220
 
221
  # --- Event Wiring ---
222
 
 
223
  route_search.change(fn=filter_routes, inputs=route_search, outputs=route_dropdown)
224
 
 
225
  route_dropdown.change(
226
  fn=update_stop_dropdowns,
227
  inputs=route_dropdown,
228
  outputs=[start_dropdown, end_dropdown]
229
  )
230
 
231
+ # Auto-fill triggers
232
+ start_dropdown.change(
233
+ fn=auto_fill_metrics,
234
+ inputs=[route_dropdown, start_dropdown, end_dropdown, dist_input, stops_input],
235
+ outputs=[dist_input, stops_input]
236
+ )
237
+
238
+ end_dropdown.change(
239
+ fn=auto_fill_metrics,
240
+ inputs=[route_dropdown, start_dropdown, end_dropdown, dist_input, stops_input],
241
+ outputs=[dist_input, stops_input]
242
+ )
243
+
244
+ # Predict trigger (Passes dropdowns just for validation check)
245
  predict_btn.click(
246
+ fn=predict_fn,
247
  inputs=[dist_input, stops_input, hour_input, day_input, route_dropdown, start_dropdown, end_dropdown],
248
+ outputs=output_text
249
  )
250
 
251
  if __name__ == "__main__":