antcar0929 commited on
Commit
d3e8921
Β·
verified Β·
1 Parent(s): 37b478d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -24
app.py CHANGED
@@ -4,23 +4,33 @@ import numpy as np
4
  import requests
5
  from huggingface_hub import hf_hub_download
6
 
7
- # --- 1. Fetch Route Options for Dropdown ---
8
- print("Fetching route list for dropdown...")
9
  try:
10
  resp = requests.get("https://hkbus.github.io/hk-bus-crawling/routeFareList.min.json")
11
  if resp.status_code == 200:
12
  data = resp.json()
13
- # Get all keys (e.g. "968+1+Yuen Long+Tin Hau")
14
- # We sort them so they are easy to find
15
- all_routes = sorted(list(data['routeList'].keys()))
 
 
 
 
 
 
 
 
 
 
 
16
  else:
17
  all_routes = []
18
  except Exception as e:
19
  print(f"Error fetching routes: {e}")
20
  all_routes = []
21
 
22
- # Add "UNKNOWN" as the first/default option
23
- route_choices = ["UNKNOWN"] + all_routes
24
 
25
  # --- 2. Download and Load Model ---
26
  print("Downloading model...")
@@ -35,13 +45,29 @@ DAY_MAP = {
35
  "Thursday": 4, "Friday": 5, "Saturday": 6
36
  }
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def predict_eta(distance_meters, num_stops, hour, day_name, route_id):
39
  try:
40
- # Handle empty route
41
  if not route_id or route_id.strip() == "":
42
  route_id = "UNKNOWN"
43
 
44
- # Prepare inputs exactly as the model expects
45
  inputs = {
46
  'distance': np.array([[float(distance_meters)]]),
47
  'num_stops': np.array([[float(num_stops)]]),
@@ -50,7 +76,6 @@ def predict_eta(distance_meters, num_stops, hour, day_name, route_id):
50
  'route_id': tf.constant([[str(route_id)]], dtype=tf.string)
51
  }
52
 
53
- # Run Prediction
54
  prediction = model.predict(inputs, verbose=0)
55
  seconds = float(prediction[0][0])
56
 
@@ -65,32 +90,32 @@ def predict_eta(distance_meters, num_stops, hour, day_name, route_id):
65
  # --- 3. Build the UI ---
66
  with gr.Blocks() as demo:
67
  gr.Markdown("# HK-TransitFlow-Net Demo 🚌")
68
- gr.Markdown("Live inference for HK Bus ETA prediction.")
69
 
70
  with gr.Row():
71
  with gr.Column():
72
- # Inputs
73
  dist_input = gr.Number(label="Distance (meters)", value=5000)
74
  stops_input = gr.Number(label="Number of Stops", value=10)
75
  hour_input = gr.Slider(minimum=0, maximum=23, step=1, label="Hour of Day (0-23)", value=9)
76
  day_input = gr.Dropdown(choices=list(DAY_MAP.keys()), label="Day of Week", value="Monday")
77
 
78
- # UPDATED: Dropdown with search capabilities
79
- route_input = gr.Dropdown(
80
- choices=route_choices,
81
- label="Route ID",
82
- value="UNKNOWN",
83
- filterable=True, # Allows typing to search
84
- interactive=True
85
- )
86
 
87
- predict_btn = gr.Button("Predict ETA", variant="primary")
 
88
 
89
- with gr.Column():
90
- # Output
 
 
91
  output_text = gr.Textbox(label="Estimated Travel Time", lines=1)
92
 
93
- # Event Listener
 
 
94
  predict_btn.click(
95
  fn=predict_eta,
96
  inputs=[dist_input, stops_input, hour_input, day_input, route_input],
 
4
  import requests
5
  from huggingface_hub import hf_hub_download
6
 
7
+ # --- 1. Fetch Route Options (KMB/CTB Only) ---
8
+ print("Fetching and filtering route list...")
9
  try:
10
  resp = requests.get("https://hkbus.github.io/hk-bus-crawling/routeFareList.min.json")
11
  if resp.status_code == 200:
12
  data = resp.json()
13
+ raw_list = data['routeList']
14
+
15
+ # Filter Logic: Only keep KMB or CTB
16
+ valid_companies = ['kmb', 'ctb']
17
+ filtered_routes = []
18
+
19
+ for key, info in raw_list.items():
20
+ # Check if company exists and is in our allowed list
21
+ if 'co' in info and len(info['co']) > 0:
22
+ company = info['co'][0]
23
+ if company in valid_companies:
24
+ filtered_routes.append(key)
25
+
26
+ all_routes = sorted(filtered_routes)
27
  else:
28
  all_routes = []
29
  except Exception as e:
30
  print(f"Error fetching routes: {e}")
31
  all_routes = []
32
 
33
+ print(f"Loaded {len(all_routes)} valid KMB/CTB routes.")
 
34
 
35
  # --- 2. Download and Load Model ---
36
  print("Downloading model...")
 
45
  "Thursday": 4, "Friday": 5, "Saturday": 6
46
  }
47
 
48
+ # --- Helper: Search Logic ---
49
+ def filter_routes(search_text):
50
+ """
51
+ Returns a list of routes matching the search text.
52
+ Limits to 100 results to prevent browser crash.
53
+ """
54
+ if not search_text:
55
+ return gr.Dropdown(choices=["UNKNOWN"] + all_routes[:20]) # Default top 20
56
+
57
+ search_text = search_text.lower()
58
+ # Filter list
59
+ filtered = [r for r in all_routes if search_text in r.lower()]
60
+
61
+ # Cap at 100 results
62
+ return gr.Dropdown(choices=["UNKNOWN"] + filtered[:100], value="UNKNOWN")
63
+
64
+
65
+ # --- Prediction Logic ---
66
  def predict_eta(distance_meters, num_stops, hour, day_name, route_id):
67
  try:
 
68
  if not route_id or route_id.strip() == "":
69
  route_id = "UNKNOWN"
70
 
 
71
  inputs = {
72
  'distance': np.array([[float(distance_meters)]]),
73
  'num_stops': np.array([[float(num_stops)]]),
 
76
  'route_id': tf.constant([[str(route_id)]], dtype=tf.string)
77
  }
78
 
 
79
  prediction = model.predict(inputs, verbose=0)
80
  seconds = float(prediction[0][0])
81
 
 
90
  # --- 3. Build the UI ---
91
  with gr.Blocks() as demo:
92
  gr.Markdown("# HK-TransitFlow-Net Demo 🚌")
93
+ gr.Markdown("Live inference for **KMB & CTB** Bus ETA prediction.")
94
 
95
  with gr.Row():
96
  with gr.Column():
97
+ gr.Markdown("### 1. Trip Details")
98
  dist_input = gr.Number(label="Distance (meters)", value=5000)
99
  stops_input = gr.Number(label="Number of Stops", value=10)
100
  hour_input = gr.Slider(minimum=0, maximum=23, step=1, label="Hour of Day (0-23)", value=9)
101
  day_input = gr.Dropdown(choices=list(DAY_MAP.keys()), label="Day of Week", value="Monday")
102
 
103
+ with gr.Column():
104
+ gr.Markdown("### 2. Route Selection")
105
+ gr.Markdown("*Type in the box below to find your route (e.g. '968')*")
 
 
 
 
 
106
 
107
+ # Search Box
108
+ route_search = gr.Textbox(label="Search Route Number", placeholder="Type route number...")
109
 
110
+ # Dropdown
111
+ route_input = gr.Dropdown(label="Select Route ID", choices=["UNKNOWN"], value="UNKNOWN", interactive=True)
112
+
113
+ predict_btn = gr.Button("Predict ETA", variant="primary")
114
  output_text = gr.Textbox(label="Estimated Travel Time", lines=1)
115
 
116
+ # --- Interaction Logic ---
117
+ route_search.change(fn=filter_routes, inputs=route_search, outputs=route_input)
118
+
119
  predict_btn.click(
120
  fn=predict_eta,
121
  inputs=[dist_input, stops_input, hour_input, day_input, route_input],