Spaces:
Sleeping
Sleeping
File size: 9,119 Bytes
42bfcc1 37b478d 6710f72 e4b324d 42bfcc1 6710f72 37b478d 6710f72 d3e8921 6710f72 d3e8921 6710f72 d3e8921 6710f72 37b478d 6710f72 37b478d 6710f72 37b478d 6710f72 37b478d cfb1903 6710f72 42bfcc1 6710f72 42bfcc1 6710f72 d3e8921 6710f72 d3e8921 6710f72 d3e8921 6710f72 7405298 d3e8921 42bfcc1 6710f72 7405298 6710f72 7405298 6710f72 7405298 42bfcc1 6710f72 7405298 6710f72 7405298 6710f72 7405298 6710f72 7405298 6710f72 7405298 c3d0b5d 6710f72 c3d0b5d 6710f72 c3d0b5d 6710f72 42bfcc1 7405298 42bfcc1 7405298 42bfcc1 7405298 6710f72 42bfcc1 37b478d 6710f72 d9fe210 7405298 c3d0b5d f4f7cd3 7405298 6710f72 37b478d 6710f72 f4f7cd3 6710f72 d3e8921 6710f72 d3e8921 f4f7cd3 6710f72 7405298 d3e8921 6710f72 7405298 f4f7cd3 7405298 6710f72 7405298 f4f7cd3 42bfcc1 f4f7cd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 |
import gradio as gr
import tensorflow as tf
import numpy as np
import requests
import math
from huggingface_hub import hf_hub_download
# --- Global Data Storage ---
# Filled once at import time by the fetch below; read by the UI callbacks.
ROUTE_DATA = {}      # route key -> route record (KMB/CTB routes only)
STOP_DATA = {}       # stop id -> stop record (stop names are read from ['name']['en'])
ALL_ROUTE_KEYS = []  # sorted keys of ROUTE_DATA, offered in the route search

# --- 1. Fetch Data (Routes & Stops) ---
print("Fetching route and stop data...")
try:
    # Without a timeout a stalled connection would block app startup forever.
    # requests applies the timeout to each connect/read wait, not to the whole
    # download, so a large but steadily flowing response is unaffected.
    resp = requests.get(
        "https://hkbus.github.io/hk-bus-crawling/routeFareList.min.json",
        timeout=30,
    )
    if resp.status_code == 200:
        json_db = resp.json()
        raw_routes = json_db['routeList']
        STOP_DATA = json_db['stopList']
        valid_companies = ['kmb', 'ctb']
        for key, info in raw_routes.items():
            # A route is classified by its first listed operator; routes with
            # no operator list are skipped.
            companies = info.get('co') or []
            if companies and companies[0] in valid_companies:
                ROUTE_DATA[key] = info
                ALL_ROUTE_KEYS.append(key)
        ALL_ROUTE_KEYS.sort()
    else:
        print(f"Failed to download route data (HTTP {resp.status_code})")
except Exception as e:
    # Best-effort: the app still starts (with no routes) if the fetch fails.
    print(f"Error fetching data: {e}")
print(f"Loaded {len(ALL_ROUTE_KEYS)} valid KMB/CTB routes.")
# --- 2. Download and Load Model ---
# Any download/load failure is logged and leaves `model` as None instead of
# aborting startup.
model = None
print("Downloading model...")
try:
    model_path = hf_hub_download(
        repo_id="WheelsTransit/HK-TransitFlow-Net",
        filename="hk_transit_flow_net.keras",
    )
    print("Loading Keras model...")
    model = tf.keras.models.load_model(model_path)
except Exception as load_error:
    print(f"Model load failed: {load_error}")
# --- Helpers ---
# Weekday name -> integer fed to the model's `day_of_week` input
# (Sunday = 0 ... Saturday = 6). Insertion order also sets the order of
# the Day dropdown choices.
_WEEKDAY_NAMES = (
    "Sunday", "Monday", "Tuesday", "Wednesday",
    "Thursday", "Friday", "Saturday",
)
DAY_MAP = {name: index for index, name in enumerate(_WEEKDAY_NAMES)}
def haversine_distance(coords):
    """Return the great-circle length, in metres, of a polyline.

    Args:
        coords: Sequence of positions in GeoJSON order ``[lon, lat, ...]``
            (degrees). Extra elements such as altitude are ignored.

    Returns:
        Sum of the haversine distances between consecutive points on a sphere
        of radius 6,371,000 m. Fewer than two points give 0.
    """
    R = 6371000
    total_dist = 0
    for start, end in zip(coords, coords[1:]):
        # GeoJSON positions may carry a third (altitude) element; plain
        # two-name unpacking would raise on those, so index lon/lat explicitly.
        lon1, lat1 = start[0], start[1]
        lon2, lat2 = end[0], end[1]
        dlon = math.radians(lon2 - lon1)
        dlat = math.radians(lat2 - lat1)
        a = (
            math.sin(dlat / 2) ** 2
            + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlon / 2) ** 2
        )
        # Rounding can push sqrt(a) a hair above 1 for near-antipodal points;
        # clamp so asin stays inside its domain instead of raising ValueError.
        c = 2 * math.asin(min(1.0, math.sqrt(a)))
        total_dist += R * c
    return total_dist
# --- Dynamic UI Logic ---
def filter_routes(search_text):
    """Return a route-dropdown update whose choices match *search_text*.

    Matching is a case-insensitive substring test against ALL_ROUTE_KEYS,
    after trimming surrounding whitespace from the query. An empty query
    shows the first 20 routes; otherwise up to 100 matches are shown.
    "UNKNOWN" is always the first choice and is selected in both cases.

    Args:
        search_text: Text from the search box (may be None or empty).

    Returns:
        A ``gr.Dropdown`` update for the route selector.
    """
    query = (search_text or "").strip().lower()
    # Reset to "UNKNOWN" in both branches: a previously chosen route may not
    # be among the new choices, and recent Gradio versions reject a Dropdown
    # value missing from `choices` unless allow_custom_value=True.
    if not query:
        return gr.Dropdown(choices=["UNKNOWN"] + ALL_ROUTE_KEYS[:20], value="UNKNOWN")
    filtered = [key for key in ALL_ROUTE_KEYS if query in key.lower()]
    return gr.Dropdown(choices=["UNKNOWN"] + filtered[:100], value="UNKNOWN")
def update_stop_dropdowns(route_key):
    """Rebuild the Start/End stop dropdowns for the selected route.

    Choices are labelled "<position>. <English stop name> (<stop id>)", with
    positions counted from 1 along the stop list of the route's first listed
    operator. Stops missing from STOP_DATA are named "Unknown". Both
    dropdowns receive the same choices and have their value cleared. An
    empty, "UNKNOWN" or unrecognised route gives two empty dropdowns.
    """
    def _dropdown(options):
        return gr.Dropdown(choices=options, value=None)

    if not (route_key and route_key != "UNKNOWN" and route_key in ROUTE_DATA):
        return _dropdown([]), _dropdown([])

    route = ROUTE_DATA[route_key]
    operator = route['co'][0]

    labels = []
    for position, stop_id in enumerate(route['stops'].get(operator, []), start=1):
        english_name = STOP_DATA[stop_id]['name']['en'] if stop_id in STOP_DATA else "Unknown"
        labels.append(f"{position}. {english_name} ({stop_id})")
    return _dropdown(labels), _dropdown(labels)
def calculate_real_metrics(route_key, start_str, end_str):
    """Measure the map distance and stop count between two stops of a route.

    Args:
        route_key: Key into ROUTE_DATA; empty/None or "UNKNOWN" means no
            route is selected.
        start_str: Start-stop label "<n>. <name> (<stop id>)"; only the
            leading 1-based number <n> is used.
        end_str: End-stop label in the same format.

    Returns:
        ``(distance_m, num_stops, None)`` on success, otherwise
        ``(None, None, message)`` explaining why no metrics are available.

    Route geometry is downloaded from hkbus.github.io/route-waypoints.
    NOTE(review): segment ``i`` is assumed to be the leg from stop ``i`` to
    stop ``i + 1`` (0-based) — confirm against that dataset.
    """
    if not route_key or route_key == "UNKNOWN" or not start_str or not end_str:
        return None, None, "Wait"
    try:
        start_idx = int(start_str.split(".")[0]) - 1
        end_idx = int(end_str.split(".")[0]) - 1
        if start_idx >= end_idx:
            return None, None, "Start must be before End"

        route_info = ROUTE_DATA[route_key]
        gtfs_id = route_info.get('gtfsId')
        company = route_info['co'][0]
        bound = route_info['bound'].get(company)
        if not gtfs_id or not bound:
            return None, None, "No Map Data"

        url = f"https://hkbus.github.io/route-waypoints/{gtfs_id}-{bound}.json"
        # Bounded wait so a slow or unreachable host cannot hang the UI callback.
        resp = requests.get(url, timeout=15)
        if resp.status_code != 200:
            return None, None, "Map Download Fail"

        features = resp.json().get('features', [])
        if not features:
            segments = []
        elif features[0]['geometry']['type'] == 'MultiLineString':
            # Presumably one feature holding every leg as a separate line.
            segments = features[0]['geometry']['coordinates']
        else:
            # Presumably one feature per leg.
            segments = [f['geometry']['coordinates'] for f in features]

        # Legs beyond the available geometry used to be skipped silently,
        # under-reporting the distance (0 m for an empty file). Report it
        # instead so no wrong distance is produced.
        if end_idx > len(segments):
            return None, None, "Incomplete Map Data"

        total_dist = sum(haversine_distance(segments[i]) for i in range(start_idx, end_idx))
        return total_dist, end_idx - start_idx, None
    except Exception as e:
        # UI boundary: surface any failure as a short message instead of raising.
        return None, None, str(e)
def auto_fill_metrics(route_key, start_str, end_str, current_dist, current_stops):
    """Refresh the Distance/Stops boxes after a stop selection changes.

    When calculate_real_metrics yields both values, return the distance
    rounded to one decimal place and the stop count as an int. Otherwise
    hand back the current box values unchanged, so manual entries survive.
    """
    dist, stops, _unused_error = calculate_real_metrics(route_key, start_str, end_str)
    if dist is None or stops is None:
        return current_dist, current_stops
    return round(dist, 1), int(stops)
# --- Prediction Logic ---
def predict_fn(manual_dist, manual_stops, hour, day_name, route_id, start_str, end_str):
    """Run the travel-time model and format the ETA for the Result box.

    Args:
        manual_dist: Distance in metres from the Distance box (None if cleared).
        manual_stops: Stop count from the Stops box (None if cleared).
        hour: Hour of day (0-23).
        day_name: A key of DAY_MAP.
        route_id: Selected route key or "UNKNOWN"; passed to the model as a string.
        start_str: Selected start-stop label (used only for the map-match check).
        end_str: Selected end-stop label (used only for the map-match check).

    Returns:
        A display string. On success: "Computed Data" when the boxes match the
        map-derived metrics (within 50 m, identical stop count), otherwise
        "Manual Inputs", followed by the ETA. On failure: an
        "Input Error: ..." or "Model Error: ..." message.
    """
    # Convert the numeric boxes first: gr.Number yields None when cleared, and
    # the original conversion ran outside any try, crashing the handler.
    try:
        dist_value = float(manual_dist)
        stops_value = float(manual_stops)
    except (TypeError, ValueError):
        return "Input Error: Distance and Stops Count must be numbers."

    if model is None:
        return "Model Error: model failed to load at startup."

    # Label the result by whether the boxes still match the map-derived values.
    map_dist, map_stops, _ = calculate_real_metrics(route_id, start_str, end_str)
    matches_map = (
        map_dist is not None
        and abs(map_dist - dist_value) <= 50  # 50 m tolerance
        and abs(map_stops - stops_value) == 0
    )
    status_tag = "Computed Data" if matches_map else "Manual Inputs"

    try:
        inputs = {
            'distance': np.array([[dist_value]]),
            'num_stops': np.array([[stops_value]]),
            'hour': np.array([[int(hour)]]),
            'day_of_week': np.array([[int(DAY_MAP[day_name])]]),
            'route_id': tf.constant([[str(route_id)]], dtype=tf.string),
        }
        prediction = model.predict(inputs, verbose=0)
        seconds = float(prediction[0][0])
        if not math.isfinite(seconds):
            return "Model Error: model returned a non-finite ETA"
        # A regression output can dip below zero for very short trips; never
        # display a negative ETA.
        seconds = max(seconds, 0.0)
        minutes, rem_seconds = divmod(int(seconds), 60)
        return f"{status_tag}\n\n⏱️ ETA: {minutes} min {rem_seconds} sec"
    except Exception as e:
        return f"Model Error: {e}"
# --- 3. Build the UI ---
# Two-column layout: inputs on the left, the prediction on the right.
# Event listeners are registered inside the same Blocks context (below).
with gr.Blocks(title="HK-TransitFlow-Net") as demo:
    gr.Markdown("# HK-TransitFlow-Net Demo")
    gr.Markdown("Predicts KMB/CTB bus travel time.")
    gr.Markdown("Model: https://huggingface.co/WheelsTransit/HK-TransitFlow-Net")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 1. Route Selection (Optional)")
            gr.Markdown("Select a route to auto-fill distance, or skip to type manually.")
            # Free-text filter; each change repopulates route_dropdown via filter_routes.
            route_search = gr.Textbox(label="Search Route", placeholder="Type e.g. '968'")
            # "UNKNOWN" is the no-route sentinel that the callbacks check for.
            route_dropdown = gr.Dropdown(label="Select Route ID", choices=["UNKNOWN"], value="UNKNOWN", interactive=True)
            with gr.Row():
                # Start empty; update_stop_dropdowns fills both once a route is chosen.
                start_dropdown = gr.Dropdown(label="Start Stop", choices=[], interactive=True)
                end_dropdown = gr.Dropdown(label="End Stop", choices=[], interactive=True)
            gr.Markdown("---")
            gr.Markdown("### 2. Time & Details")
            with gr.Row():
                hour_input = gr.Slider(minimum=0, maximum=23, step=1, label="Hour (0-23)", value=9)
                # Choices come from DAY_MAP, so every option has a model encoding.
                day_input = gr.Dropdown(choices=list(DAY_MAP.keys()), label="Day", value="Monday")
            with gr.Row():
                # These two boxes are the values fed to the model; they are
                # auto-filled from map data on stop changes and remain editable.
                dist_input = gr.Number(label="Distance (m)", value=5000)
                stops_input = gr.Number(label="Stops Count", value=10)
            predict_btn = gr.Button("Predict ETA", variant="primary")
        with gr.Column():
            gr.Markdown("### Result")
            output_text = gr.Textbox(label="Prediction", lines=3)
            gr.Markdown("*Tip: If you modify the Distance/Stops boxes manually, the model will use your typed values.*")
    # --- Event Wiring ---
    # Typing in the search box narrows the route list.
    route_search.change(fn=filter_routes, inputs=route_search, outputs=route_dropdown)
    # Picking a route rebuilds both stop lists.
    route_dropdown.change(
        fn=update_stop_dropdowns,
        inputs=route_dropdown,
        outputs=[start_dropdown, end_dropdown]
    )
    # Auto-fill triggers: changing either stop recomputes Distance/Stops.
    # The current box values are passed in so auto_fill_metrics can return
    # them unchanged when no map-derived metrics are available.
    start_dropdown.change(
        fn=auto_fill_metrics,
        inputs=[route_dropdown, start_dropdown, end_dropdown, dist_input, stops_input],
        outputs=[dist_input, stops_input]
    )
    end_dropdown.change(
        fn=auto_fill_metrics,
        inputs=[route_dropdown, start_dropdown, end_dropdown, dist_input, stops_input],
        outputs=[dist_input, stops_input]
    )
    # Predict trigger. route_dropdown also feeds the model's route_id input;
    # the start/end dropdowns are used only to check whether the boxes still
    # match the map-derived values (which sets the status label).
    predict_btn.click(
        fn=predict_fn,
        inputs=[dist_input, stops_input, hour_input, day_input, route_dropdown, start_dropdown, end_dropdown],
        outputs=output_text
    )
if __name__ == "__main__":
    # Launch the Gradio app built above. The stray trailing "|" after
    # demo.launch() in the original made this line a SyntaxError.
    demo.launch()