Spaces:
Sleeping
Sleeping
import functools
import math

import gradio as gr
import numpy as np
import requests
import tensorflow as tf
from huggingface_hub import hf_hub_download
# --- Global Data Storage ---
# Filled once at import time by the fetch below and read by the UI callbacks.
ROUTE_DATA = {}      # route key -> route info, KMB/CTB routes only
STOP_DATA = {}       # stop id -> stop info
ALL_ROUTE_KEYS = []  # sorted keys of ROUTE_DATA, offered in the route dropdown

# --- 1. Fetch Data (Routes & Stops) ---
print("Fetching route and stop data...")
try:
    # Without a timeout a stalled server would block app start-up forever.
    # (requests applies it per connect/read operation, not to the whole body.)
    resp = requests.get(
        "https://hkbus.github.io/hk-bus-crawling/routeFareList.min.json",
        timeout=60,
    )
    if resp.status_code == 200:
        json_db = resp.json()
        valid_companies = ('kmb', 'ctb')
        # Build into locals and publish only once everything has parsed, so a
        # malformed payload cannot leave the globals half-populated.
        loaded_routes = {
            key: info
            for key, info in json_db['routeList'].items()
            # Keep a route only if its first listed operator is KMB or CTB;
            # a missing/empty/None 'co' entry is skipped rather than aborting.
            if info.get('co') and info['co'][0] in valid_companies
        }
        loaded_stops = json_db['stopList']
        STOP_DATA = loaded_stops
        ROUTE_DATA.update(loaded_routes)
        ALL_ROUTE_KEYS.extend(sorted(loaded_routes))
    else:
        print("Failed to download route data")
except Exception as e:
    # Best effort: the app still starts (with manual inputs only) if this fails.
    print(f"Error fetching data: {e}")
print(f"Loaded {len(ALL_ROUTE_KEYS)} valid KMB/CTB routes.")
# --- 2. Download and Load Model ---
def _load_transit_model():
    """Fetch the Keras model file from the Hugging Face Hub and load it.

    Returns the loaded model, or None if downloading or loading failed; the
    error is printed rather than raised so the rest of the app can still start.
    """
    print("Downloading model...")
    try:
        local_file = hf_hub_download(
            repo_id="WheelsTransit/HK-TransitFlow-Net",
            filename="hk_transit_flow_net.keras",
        )
        print("Loading Keras model...")
        return tf.keras.models.load_model(local_file)
    except Exception as e:
        print(f"Model load failed: {e}")
        return None


model = _load_transit_model()
# --- Helpers ---
# Day name -> integer code fed to the model's `day_of_week` input.
# Codes follow the tuple order below, i.e. Sunday = 0 ... Saturday = 6; the
# insertion order also fixes the order of the "Day" dropdown choices.
DAY_MAP = {
    day_name: code
    for code, day_name in enumerate(
        ("Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday")
    )
}
def haversine_distance(coords):
    """Return the length in metres of a polyline measured on a spherical Earth.

    Args:
        coords: sequence of GeoJSON-style positions ``[lon, lat, ...]`` in
            degrees. Only the first two elements of each position are used, so
            positions carrying an optional altitude are accepted.

    Returns:
        Sum of the haversine (great-circle) distances between consecutive
        positions, using an Earth radius of 6,371,000 m. ``0`` when fewer
        than two positions are given.
    """
    R = 6371000  # mean Earth radius in metres
    total_dist = 0
    for start, end in zip(coords, coords[1:]):
        lon1, lat1 = start[0], start[1]
        lon2, lat2 = end[0], end[1]
        dlon = math.radians(lon2 - lon1)
        dlat = math.radians(lat2 - lat1)
        a = (
            math.sin(dlat / 2) ** 2
            + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlon / 2) ** 2
        )
        # Floating-point rounding can push `a` fractionally above 1 for
        # near-antipodal points, which would make asin(sqrt(a)) raise.
        c = 2 * math.asin(math.sqrt(min(1.0, a)))
        total_dist += R * c
    return total_dist
# --- Dynamic UI Logic ---
def filter_routes(search_text):
    """Rebuild the route dropdown from a free-text search.

    An empty or whitespace-only query offers "UNKNOWN" plus the first 20 route
    keys; otherwise "UNKNOWN" plus up to 100 keys containing the query
    (case-insensitive, surrounding whitespace ignored). The selection is reset
    to "UNKNOWN" in both cases so the current value is always one of the
    offered choices.
    """
    query = (search_text or "").strip().lower()
    if not query:
        return gr.Dropdown(choices=["UNKNOWN"] + ALL_ROUTE_KEYS[:20], value="UNKNOWN")
    matches = [key for key in ALL_ROUTE_KEYS if query in key.lower()]
    return gr.Dropdown(choices=["UNKNOWN"] + matches[:100], value="UNKNOWN")
def update_stop_dropdowns(route_key):
    """Return (start, end) dropdown updates listing the stops of ``route_key``.

    Each choice is labelled "<1-based position>. <English stop name> (<stop id>)";
    stop ids absent from STOP_DATA are named "Unknown". Both dropdowns are
    cleared when no valid route is selected.
    """
    has_route = bool(route_key) and route_key != "UNKNOWN" and route_key in ROUTE_DATA
    if not has_route:
        return gr.Dropdown(choices=[], value=None), gr.Dropdown(choices=[], value=None)

    route = ROUTE_DATA[route_key]
    # Stops are looked up under the route's first listed operator.
    first_operator = route['co'][0]
    stop_ids = route['stops'].get(first_operator, [])

    def english_name(stop_id):
        return STOP_DATA[stop_id]['name']['en'] if stop_id in STOP_DATA else "Unknown"

    labels = [
        f"{position}. {english_name(stop_id)} ({stop_id})"
        for position, stop_id in enumerate(stop_ids, start=1)
    ]
    return gr.Dropdown(choices=labels, value=None), gr.Dropdown(choices=labels, value=None)
class _WaypointDownloadError(Exception):
    """The route-waypoints server did not answer a GeoJSON request with HTTP 200."""


@functools.lru_cache(maxsize=256)
def _fetch_route_segments(url):
    """Download a route-waypoints GeoJSON file and return its polylines.

    Returns the ``coordinates`` of the first feature when it is a
    MultiLineString, otherwise the ``coordinates`` of every feature in order;
    an empty list when the file has no features.

    Memoised per URL because the same route is requested on every stop change
    and again on every Predict click. Failures raise (and so are not cached),
    letting a later call retry. Callers share the returned list and must not
    mutate it.
    """
    resp = requests.get(url, timeout=15)
    if resp.status_code != 200:
        raise _WaypointDownloadError(f"HTTP {resp.status_code} for {url}")
    features = resp.json().get('features', [])
    if not features:
        return []
    if features[0]['geometry']['type'] == 'MultiLineString':
        return features[0]['geometry']['coordinates']
    return [feature['geometry']['coordinates'] for feature in features]


def calculate_real_metrics(route_key, start_str, end_str):
    """Derive (distance, stop count) between two stops from the route's map data.

    Args:
        route_key: key into ROUTE_DATA, or "UNKNOWN" when no route is picked.
        start_str, end_str: stop labels of the form "<n>. <name> (<id>)" as
            built by update_stop_dropdowns; only the leading number is used.

    Returns:
        ``(distance_m, num_stops, None)`` on success, otherwise
        ``(None, None, message)`` explaining why no metrics are available.
        Errors are returned, never raised.
    """
    if not route_key or route_key == "UNKNOWN" or not start_str or not end_str:
        return None, None, "Wait"
    try:
        # Labels are 1-based; convert to 0-based stop indices.
        start_idx = int(start_str.split(".")[0]) - 1
        end_idx = int(end_str.split(".")[0]) - 1
        if start_idx >= end_idx:
            return None, None, "Start must be before End"

        route_info = ROUTE_DATA[route_key]
        gtfs_id = route_info.get('gtfsId')
        company = route_info['co'][0]
        bound = route_info['bound'].get(company)
        if not gtfs_id or not bound:
            return None, None, "No Map Data"

        url = f"https://hkbus.github.io/route-waypoints/{gtfs_id}-{bound}.json"
        try:
            segments = _fetch_route_segments(url)
        except _WaypointDownloadError:
            return None, None, "Map Download Fail"

        # NOTE(review): assumes segments[i] is the path from stop i to stop
        # i+1. Ranges past the end of `segments` are silently under-counted;
        # confirm against the waypoint data.
        total_dist = sum(haversine_distance(seg) for seg in segments[start_idx:end_idx])
        return total_dist, end_idx - start_idx, None
    except Exception as e:
        return None, None, str(e)
def auto_fill_metrics(route_key, start_str, end_str, current_dist, current_stops):
    """Return (distance, stops) for the metric boxes after a stop selection changes.

    When calculate_real_metrics can derive both values from the map, the
    distance is rounded to one decimal place and the stop count cast to int.
    Otherwise the boxes' current values are returned unchanged, so nothing the
    user typed is overwritten.
    """
    map_dist, map_stops, _error = calculate_real_metrics(route_key, start_str, end_str)
    if map_dist is None or map_stops is None:
        return current_dist, current_stops
    return round(map_dist, 1), int(map_stops)
# --- Prediction Logic ---
def _input_source_tag(manual_dist, manual_stops, route_id, start_str, end_str):
    """Label whether the metric boxes still hold the map-derived values.

    Returns "Computed Data" when map metrics are available, the typed distance
    is within 50 m of them and the stop count matches exactly; otherwise
    "Manual Inputs".
    """
    map_dist, map_stops, _ = calculate_real_metrics(route_id, start_str, end_str)
    if map_dist is None:
        return "Manual Inputs"
    if abs(map_dist - float(manual_dist)) > 50 or abs(map_stops - float(manual_stops)) > 0:
        return "Manual Inputs"
    return "Computed Data"


def predict_fn(manual_dist, manual_stops, hour, day_name, route_id, start_str, end_str):
    """Run the ETA model and format its prediction for the result box.

    The Distance/Stops boxes are the model inputs. The route and stop
    dropdowns are used for the model's ``route_id`` input and to tag whether
    the boxes still hold map-derived values. Problems are reported as a
    message string rather than raised, so the UI always shows something.
    """
    # A cleared gr.Number arrives as None; float(None) would otherwise crash
    # the callback before reaching the error handling below.
    if manual_dist is None or manual_stops is None:
        return "Input Error: please enter both Distance and Stops Count."
    if model is None:
        return "Model Error: the model failed to load at startup."
    try:
        status_tag = _input_source_tag(manual_dist, manual_stops, route_id, start_str, end_str)
        inputs = {
            'distance': np.array([[float(manual_dist)]]),
            'num_stops': np.array([[float(manual_stops)]]),
            'hour': np.array([[int(hour)]]),
            'day_of_week': np.array([[int(DAY_MAP[day_name])]]),
            'route_id': tf.constant([[str(route_id)]], dtype=tf.string),
        }
        prediction = model.predict(inputs, verbose=0)
        seconds = float(prediction[0][0])
        # NOTE(review): a negative model output would render as e.g.
        # "-1 min 30 sec"; confirm whether the model can emit one.
        minutes = int(seconds // 60)
        rem_seconds = int(seconds % 60)
        return f"{status_tag}\n\n⏱️ ETA: {minutes} min {rem_seconds} sec"
    except Exception as e:
        return f"Model Error: {e}"
# --- 3. Build the UI ---
with gr.Blocks(title="HK-TransitFlow-Net") as demo:
    # Page header.
    for _intro_md in (
        "# HK-TransitFlow-Net Demo",
        "Predicts KMB/CTB bus travel time.",
        "Model: https://huggingface.co/WheelsTransit/HK-TransitFlow-Net",
    ):
        gr.Markdown(_intro_md)

    with gr.Row():
        # Left column: optional route/stop pickers, then the model inputs.
        with gr.Column():
            gr.Markdown("### 1. Route Selection (Optional)")
            gr.Markdown("Select a route to auto-fill distance, or skip to type manually.")
            route_search = gr.Textbox(label="Search Route", placeholder="Type e.g. '968'")
            route_dropdown = gr.Dropdown(
                label="Select Route ID",
                choices=["UNKNOWN"],
                value="UNKNOWN",
                interactive=True,
            )
            with gr.Row():
                start_dropdown = gr.Dropdown(label="Start Stop", choices=[], interactive=True)
                end_dropdown = gr.Dropdown(label="End Stop", choices=[], interactive=True)

            gr.Markdown("---")
            gr.Markdown("### 2. Time & Details")
            with gr.Row():
                hour_input = gr.Slider(minimum=0, maximum=23, step=1, label="Hour (0-23)", value=9)
                day_input = gr.Dropdown(choices=list(DAY_MAP.keys()), label="Day", value="Monday")
            with gr.Row():
                dist_input = gr.Number(label="Distance (m)", value=5000)
                stops_input = gr.Number(label="Stops Count", value=10)
            predict_btn = gr.Button("Predict ETA", variant="primary")

        # Right column: prediction output.
        with gr.Column():
            gr.Markdown("### Result")
            output_text = gr.Textbox(label="Prediction", lines=3)
            gr.Markdown("*Tip: If you modify the Distance/Stops boxes manually, the model will use your typed values.*")

    # --- Event Wiring ---
    # Typing narrows the route list; picking a route repopulates both stop pickers.
    route_search.change(fn=filter_routes, inputs=route_search, outputs=route_dropdown)
    route_dropdown.change(
        fn=update_stop_dropdowns,
        inputs=route_dropdown,
        outputs=[start_dropdown, end_dropdown],
    )

    # Changing either stop re-derives distance and stop count from the map data.
    for _stop_picker in (start_dropdown, end_dropdown):
        _stop_picker.change(
            fn=auto_fill_metrics,
            inputs=[route_dropdown, start_dropdown, end_dropdown, dist_input, stops_input],
            outputs=[dist_input, stops_input],
        )

    # The route/stop dropdowns are also passed so predict_fn can tell whether
    # the metric boxes still hold the map-derived values.
    predict_btn.click(
        fn=predict_fn,
        inputs=[dist_input, stops_input, hour_input, day_input, route_dropdown, start_dropdown, end_dropdown],
        outputs=output_text,
    )


if __name__ == "__main__":
    demo.launch()