# antcar0929's picture
# Update app.py
# c3d0b5d verified
import gradio as gr
import tensorflow as tf
import numpy as np
import requests
import math
from huggingface_hub import hf_hub_download
# --- Global Data Storage ---
# Filled once at start-up by the data-fetch step; left empty if the download fails.
ROUTE_DATA = dict()      # route key -> route record, KMB/CTB routes only
STOP_DATA = dict()       # stop id -> stop record
ALL_ROUTE_KEYS = list()  # keys of ROUTE_DATA, sorted
# --- 1. Fetch Data (Routes & Stops) ---
# Downloads the hkbus route/stop database once at start-up and keeps only routes
# whose first operating company is KMB or CTB. On any failure the globals stay
# as they are and the app still starts (route auto-fill is then unavailable).
print("Fetching route and stop data...")
# Seconds to wait for the database; without a timeout a stalled request would
# hang app start-up indefinitely.
_ROUTE_DB_TIMEOUT_S = 30
try:
    resp = requests.get(
        "https://hkbus.github.io/hk-bus-crawling/routeFareList.min.json",
        timeout=_ROUTE_DB_TIMEOUT_S,
    )
    if resp.status_code == 200:
        json_db = resp.json()
        raw_routes = json_db['routeList']
        STOP_DATA = json_db['stopList']
        valid_companies = ['kmb', 'ctb']
        for key, info in raw_routes.items():
            # Treat a missing or null company list as "no company" instead of
            # letting len(None) abort the whole load.
            companies = info.get('co') or []
            if companies and companies[0] in valid_companies:
                ROUTE_DATA[key] = info
                ALL_ROUTE_KEYS.append(key)
        ALL_ROUTE_KEYS.sort()
    else:
        print("Failed to download route data")
except Exception as e:
    # Best-effort start-up: report and continue with whatever was loaded.
    print(f"Error fetching data: {e}")
print(f"Loaded {len(ALL_ROUTE_KEYS)} valid KMB/CTB routes.")
# --- 2. Download and Load Model ---
# Fetch the Keras model file from the Hugging Face Hub and load it. `model`
# stays None if either step fails, so the UI can still start.
print("Downloading model...")
model = None
try:
    model_path = hf_hub_download(
        repo_id="WheelsTransit/HK-TransitFlow-Net",
        filename="hk_transit_flow_net.keras",
    )
    print("Loading Keras model...")
    model = tf.keras.models.load_model(model_path)
except Exception as load_err:
    print(f"Model load failed: {load_err}")
# --- Helpers ---
# Day names shown in the UI mapped to the integer day_of_week fed to the model
# (Sunday = 0 ... Saturday = 6). Insertion order is also the dropdown order.
DAY_MAP = {
    day_name: day_number
    for day_number, day_name in enumerate(
        ("Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday")
    )
}
def haversine_distance(coords):
    """Return the length in metres of a polyline on a spherical Earth.

    Args:
        coords: Sequence of GeoJSON positions ``[lon, lat]`` in degrees. Extra
            trailing values in a position (such as altitude) are ignored.

    Returns:
        Sum of the great-circle (haversine) distances between consecutive
        points. 0 when fewer than two points are given.
    """
    R = 6371000  # Earth radius in metres
    total_dist = 0
    for start, end in zip(coords, coords[1:]):
        # Index instead of unpacking so 3-D positions [lon, lat, alt] work too.
        lon1, lat1 = start[0], start[1]
        lon2, lat2 = end[0], end[1]
        dlon = math.radians(lon2 - lon1)
        dlat = math.radians(lat2 - lat1)
        a = (
            math.sin(dlat / 2) ** 2
            + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlon / 2) ** 2
        )
        # Rounding can push `a` marginally above 1 for near-antipodal points,
        # which would make asin(sqrt(a)) raise a math domain error.
        c = 2 * math.asin(math.sqrt(min(a, 1.0)))
        total_dist += R * c
    return total_dist
# --- Dynamic UI Logic ---
def filter_routes(search_text):
    """Return a route-dropdown update whose choices match *search_text*.

    Matching is a case-insensitive substring test on the route key. Whitespace
    around the query is ignored, so a blank or space-only query behaves like an
    empty one. "UNKNOWN" is always offered first so route selection can be
    skipped.

    Args:
        search_text: Contents of the search textbox (may be None or empty).

    Returns:
        A ``gr.Dropdown`` update. With no query it offers the first 20 routes
        and does not set ``value``. Otherwise it offers up to 100 matches and
        resets the value to "UNKNOWN".
    """
    query = (search_text or "").strip().lower()
    if not query:
        return gr.Dropdown(choices=["UNKNOWN"] + ALL_ROUTE_KEYS[:20])
    filtered = [r for r in ALL_ROUTE_KEYS if query in r.lower()]
    return gr.Dropdown(choices=["UNKNOWN"] + filtered[:100], value="UNKNOWN")
def update_stop_dropdowns(route_key):
    """Populate the start/end stop dropdowns for the selected route.

    Each choice is labelled ``"<n>. <English name> (<stop id>)"``, where ``n``
    is the 1-based position along the route. calculate_real_metrics parses
    ``n`` back out of the label, so the two formats must stay in sync.

    Args:
        route_key: Key into ROUTE_DATA, or None/"UNKNOWN" when no route is chosen.

    Returns:
        A pair of ``gr.Dropdown`` updates (start, end) with identical choices
        and no selection. Both have no choices for an unknown route.
    """
    if not route_key or route_key == "UNKNOWN" or route_key not in ROUTE_DATA:
        return gr.Dropdown(choices=[], value=None), gr.Dropdown(choices=[], value=None)
    route_info = ROUTE_DATA[route_key]
    company = route_info['co'][0]
    # Missing 'stops' or incomplete stop records degrade to fewer or "Unknown"
    # labels instead of raising KeyError inside the Gradio event handler.
    stop_ids = route_info.get('stops', {}).get(company, [])
    stop_options = []
    for position, sid in enumerate(stop_ids, start=1):
        name_en = (STOP_DATA.get(sid, {}).get('name') or {}).get('en') or "Unknown"
        stop_options.append(f"{position}. {name_en} ({sid})")
    return gr.Dropdown(choices=stop_options, value=None), gr.Dropdown(choices=stop_options, value=None)
# Seconds to wait for a waypoint file; without a timeout a stalled request
# would block the Gradio event handler indefinitely.
_WAYPOINT_TIMEOUT_S = 10


def _fetch_route_segments(gtfs_id, bound):
    """Download a route's waypoint GeoJSON and return its list of line segments.

    Returns None when the server does not answer with HTTP 200. If the first
    feature is a MultiLineString, its coordinate lists are the segments;
    otherwise each feature's coordinates form one segment.
    """
    url = f"https://hkbus.github.io/route-waypoints/{gtfs_id}-{bound}.json"
    resp = requests.get(url, timeout=_WAYPOINT_TIMEOUT_S)
    if resp.status_code != 200:
        return None
    features = resp.json().get('features', [])
    if not features:
        return []
    if features[0]['geometry']['type'] == 'MultiLineString':
        return features[0]['geometry']['coordinates']
    return [f['geometry']['coordinates'] for f in features]


def calculate_real_metrics(route_key, start_str, end_str):
    """Compute the map distance and stop count between two stops of a route.

    Args:
        route_key: Key into ROUTE_DATA, or None/"UNKNOWN" when no route is chosen.
        start_str: Selected start-stop label, formatted "<n>. <name> (<id>)".
        end_str: Selected end-stop label, same format.

    Returns:
        ``(distance_m, num_stops, None)`` on success, where distance_m is the
        summed haversine length of the waypoint segments between the stops.
        Otherwise ``(None, None, message)`` explaining why no metrics are
        available. Never raises: unexpected errors become the message.
    """
    if not route_key or route_key == "UNKNOWN" or not start_str or not end_str:
        return None, None, "Wait"
    try:
        # The leading number of the label is the 1-based stop position.
        start_idx = int(start_str.split(".")[0]) - 1
        end_idx = int(end_str.split(".")[0]) - 1
        if start_idx >= end_idx:
            return None, None, "Start must be before End"
        route_info = ROUTE_DATA.get(route_key)
        if route_info is None:
            return None, None, "No Map Data"
        gtfs_id = route_info.get('gtfsId')
        company = route_info['co'][0]
        bound = route_info['bound'].get(company)
        if not gtfs_id or not bound:
            return None, None, "No Map Data"
        segments = _fetch_route_segments(gtfs_id, bound)
        if segments is None:
            return None, None, "Map Download Fail"
        # NOTE(review): assumes segments[i] is the path from stop i to stop i+1;
        # confirm against the route-waypoints data. Refuse to sum a partial range,
        # because skipping missing segments would under-report the distance.
        if end_idx > len(segments):
            return None, None, "Incomplete Map Data"
        total_dist = sum(haversine_distance(segments[i]) for i in range(start_idx, end_idx))
        num_stops = end_idx - start_idx
        return total_dist, num_stops, None
    except Exception as e:
        return None, None, str(e)
def auto_fill_metrics(route_key, start_str, end_str, current_dist, current_stops):
    """Refresh the Distance/Stops boxes after a stop selection changes.

    When calculate_real_metrics produces values, return the distance rounded to
    one decimal place and the stop count as an int. Otherwise echo the current
    box values, so manually typed numbers are left untouched.
    """
    computed_dist, computed_stops, _reason = calculate_real_metrics(route_key, start_str, end_str)
    if computed_dist is None or computed_stops is None:
        return current_dist, current_stops
    return round(computed_dist, 1), int(computed_stops)
# --- Prediction Logic ---
def _input_source_tag(route_id, start_str, end_str, dist_m, stops):
    """Label whether the typed Distance/Stops match the map-derived metrics.

    Returns "Computed Data" when calculate_real_metrics yields a distance within
    50 m and exactly the same stop count. Otherwise, including when no map data
    is available, returns "Manual Inputs".
    """
    map_dist, map_stops, _ = calculate_real_metrics(route_id, start_str, end_str)
    if map_dist is None:
        return "Manual Inputs"
    if abs(map_dist - dist_m) > 50 or abs(map_stops - stops) > 0:  # 50m tolerance
        return "Manual Inputs"
    return "Computed Data"


def predict_fn(manual_dist, manual_stops, hour, day_name, route_id, start_str, end_str):
    """Run the ETA model on the form inputs and format the result for display.

    The model always uses the values in the Distance/Stops boxes. The stop
    selections only decide the status line ("Computed Data" vs "Manual Inputs").
    The selected route is also fed to the model as its route_id feature.

    Returns:
        A display string: the status line plus "ETA: <m> min <s> sec" on
        success, or a message starting with "Input Error:" / "Model Error:".
    """
    # Validate before any arithmetic: a cleared gr.Number box delivers None,
    # which previously crashed the handler outside any try block.
    if manual_dist is None or manual_stops is None:
        return "Input Error: Distance and Stops Count are required."
    try:
        dist_m = float(manual_dist)
        stops = float(manual_stops)
    except (TypeError, ValueError):
        return "Input Error: Distance and Stops Count must be numbers."
    if model is None:
        return "Model Error: model is not loaded (see startup logs)."
    status_tag = _input_source_tag(route_id, start_str, end_str, dist_m, stops)
    try:
        inputs = {
            'distance': np.array([[dist_m]]),
            'num_stops': np.array([[stops]]),
            'hour': np.array([[int(hour)]]),
            'day_of_week': np.array([[int(DAY_MAP[day_name])]]),
            'route_id': tf.constant([[str(route_id)]], dtype=tf.string),
        }
        prediction = model.predict(inputs, verbose=0)
        # Clamp so a negative model output is not rendered as a negative ETA.
        seconds = max(float(prediction[0][0]), 0.0)
    except Exception as e:
        return f"Model Error: {str(e)}"
    minutes, rem_seconds = divmod(int(seconds), 60)
    return f"{status_tag}\n\n⏱️ ETA: {minutes} min {rem_seconds} sec"
# --- 3. Build the UI ---
# Two-column layout: inputs on the left, prediction on the right. Gradio lays
# components out in the order they are created inside the nested Row/Column
# contexts, so statement order here is the visual order.
with gr.Blocks(title="HK-TransitFlow-Net") as demo:
    gr.Markdown("# HK-TransitFlow-Net Demo")
    gr.Markdown("Predicts KMB/CTB bus travel time.")
    gr.Markdown("Model: https://huggingface.co/WheelsTransit/HK-TransitFlow-Net")
    with gr.Row():
        # Left column: optional route/stop selection plus the numeric inputs.
        with gr.Column():
            gr.Markdown("### 1. Route Selection (Optional)")
            gr.Markdown("Select a route to auto-fill distance, or skip to type manually.")
            route_search = gr.Textbox(label="Search Route", placeholder="Type e.g. '968'")
            # "UNKNOWN" is the no-route sentinel checked by update_stop_dropdowns
            # and calculate_real_metrics.
            route_dropdown = gr.Dropdown(label="Select Route ID", choices=["UNKNOWN"], value="UNKNOWN", interactive=True)
            with gr.Row():
                # Choices are filled by update_stop_dropdowns once a route is chosen.
                start_dropdown = gr.Dropdown(label="Start Stop", choices=[], interactive=True)
                end_dropdown = gr.Dropdown(label="End Stop", choices=[], interactive=True)
            gr.Markdown("---")
            gr.Markdown("### 2. Time & Details")
            with gr.Row():
                hour_input = gr.Slider(minimum=0, maximum=23, step=1, label="Hour (0-23)", value=9)
                # Keys of DAY_MAP, in insertion order (Sunday first).
                day_input = gr.Dropdown(choices=list(DAY_MAP.keys()), label="Day", value="Monday")
            with gr.Row():
                # These two boxes are what the model actually receives; stop
                # selection only pre-fills them.
                dist_input = gr.Number(label="Distance (m)", value=5000)
                stops_input = gr.Number(label="Stops Count", value=10)
            predict_btn = gr.Button("Predict ETA", variant="primary")
        # Right column: prediction output.
        with gr.Column():
            gr.Markdown("### Result")
            output_text = gr.Textbox(label="Prediction", lines=3)
            gr.Markdown("*Tip: If you modify the Distance/Stops boxes manually, the model will use your typed values.*")
    # --- Event Wiring ---
    # Typing in the search box narrows the route dropdown.
    route_search.change(fn=filter_routes, inputs=route_search, outputs=route_dropdown)
    # Choosing a route repopulates both stop dropdowns and clears their selection.
    route_dropdown.change(
        fn=update_stop_dropdowns,
        inputs=route_dropdown,
        outputs=[start_dropdown, end_dropdown]
    )
    # Auto-fill triggers
    # Either stop changing recomputes distance/stops from the map. The current
    # box values are passed in so they are kept when nothing can be computed.
    start_dropdown.change(
        fn=auto_fill_metrics,
        inputs=[route_dropdown, start_dropdown, end_dropdown, dist_input, stops_input],
        outputs=[dist_input, stops_input]
    )
    end_dropdown.change(
        fn=auto_fill_metrics,
        inputs=[route_dropdown, start_dropdown, end_dropdown, dist_input, stops_input],
        outputs=[dist_input, stops_input]
    )
    # Predict trigger (Passes dropdowns just for validation check)
    predict_btn.click(
        fn=predict_fn,
        inputs=[dist_input, stops_input, hour_input, day_input, route_dropdown, start_dropdown, end_dropdown],
        outputs=output_text
    )
# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()