File size: 9,119 Bytes
42bfcc1
 
 
37b478d
6710f72
e4b324d
42bfcc1
6710f72
 
 
 
 
 
 
37b478d
 
 
6710f72
 
 
d3e8921
 
 
6710f72
d3e8921
 
 
6710f72
 
d3e8921
6710f72
37b478d
6710f72
37b478d
6710f72
37b478d
6710f72
37b478d
 
cfb1903
6710f72
 
 
 
 
 
 
42bfcc1
6710f72
42bfcc1
 
 
 
 
6710f72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3e8921
 
6710f72
d3e8921
6710f72
d3e8921
 
6710f72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7405298
d3e8921
42bfcc1
6710f72
 
 
 
7405298
6710f72
 
 
 
 
 
 
7405298
6710f72
 
 
 
7405298
42bfcc1
6710f72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7405298
6710f72
 
7405298
 
 
 
 
 
 
 
 
6710f72
 
 
7405298
 
 
 
 
6710f72
7405298
 
 
 
6710f72
7405298
c3d0b5d
6710f72
c3d0b5d
6710f72
c3d0b5d
6710f72
 
42bfcc1
7405298
 
42bfcc1
 
 
 
 
 
 
 
 
 
7405298
42bfcc1
 
7405298
6710f72
42bfcc1
37b478d
6710f72
d9fe210
7405298
c3d0b5d
f4f7cd3
 
 
7405298
 
 
6710f72
 
37b478d
6710f72
 
 
f4f7cd3
6710f72
 
 
 
 
d3e8921
6710f72
 
 
 
d3e8921
f4f7cd3
6710f72
 
7405298
 
d3e8921
6710f72
 
 
 
 
 
 
 
 
 
7405298
 
 
 
 
 
 
 
 
 
 
 
 
 
f4f7cd3
7405298
6710f72
7405298
f4f7cd3
42bfcc1
f4f7cd3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import functools
import math

import gradio as gr
import numpy as np
import requests
import tensorflow as tf
from huggingface_hub import hf_hub_download

# --- Global Data Storage ---
# Route key -> route record, restricted to KMB/CTB routes (filled below).
ROUTE_DATA = {}
# Stop ID -> stop record, as published in the crawler's "stopList".
STOP_DATA = {}
# Sorted keys of ROUTE_DATA; these drive the route search dropdown.
ALL_ROUTE_KEYS = []

# Seconds to wait for the crawler download. Without a timeout a stalled
# connection would block app start-up indefinitely.
ROUTE_LIST_TIMEOUT_S = 30

# --- 1. Fetch Data (Routes & Stops) ---
print("Fetching route and stop data...")
try:
    resp = requests.get(
        "https://hkbus.github.io/hk-bus-crawling/routeFareList.min.json",
        timeout=ROUTE_LIST_TIMEOUT_S,
    )
    if resp.status_code == 200:
        json_db = resp.json()
        raw_routes = json_db['routeList']
        STOP_DATA = json_db['stopList']

        valid_companies = {'kmb', 'ctb'}

        for key, info in raw_routes.items():
            companies = info.get('co')
            # A route is classified by its first listed operator only.
            if companies and companies[0] in valid_companies:
                ROUTE_DATA[key] = info
                ALL_ROUTE_KEYS.append(key)

        ALL_ROUTE_KEYS.sort()
    else:
        print("Failed to download route data")
except Exception as e:
    # Best-effort start-up: the app still launches (with no routes) so that
    # manual Distance/Stops prediction keeps working.
    print(f"Error fetching data: {e}")

print(f"Loaded {len(ALL_ROUTE_KEYS)} valid KMB/CTB routes.")

# --- 2. Download and Load Model ---
# `model` stays None if either the download or the Keras deserialization
# fails; the failure is only printed, so code using `model` must cope with None.
model = None
print("Downloading model...")
try:
    model_path = hf_hub_download(
        repo_id="WheelsTransit/HK-TransitFlow-Net",
        filename="hk_transit_flow_net.keras",
    )
    print("Loading Keras model...")
    model = tf.keras.models.load_model(model_path)
except Exception as err:
    print(f"Model load failed: {err}")

# --- Helpers ---
# Day name -> integer day-of-week index fed to the model, counting Sunday as 0.
# Insertion order (Sunday..Saturday) is also the order of the Day dropdown.
DAY_MAP = {
    day: index
    for index, day in enumerate(
        (
            "Sunday", "Monday", "Tuesday", "Wednesday",
            "Thursday", "Friday", "Saturday",
        )
    )
}

def haversine_distance(coords, radius=6371000):
    """Return the great-circle length of a polyline, in the units of *radius*.

    Args:
        coords: Sequence of positions in GeoJSON order, ``[lon, lat]`` in
            degrees. Extra trailing elements (e.g. an altitude) are ignored.
        radius: Sphere radius. The default, 6371000, is the Earth's mean
            radius in metres, so the result is then in metres.

    Returns:
        Sum of the haversine distances between consecutive positions; 0 when
        fewer than two positions are given.
    """
    total_dist = 0
    for (lon1, lat1, *_), (lon2, lat2, *_) in zip(coords, coords[1:]):
        phi1 = math.radians(lat1)
        phi2 = math.radians(lat2)
        dphi = math.radians(lat2 - lat1)
        dlam = math.radians(lon2 - lon1)
        a = (
            math.sin(dphi / 2) ** 2
            + math.cos(phi1) * math.cos(phi2) * math.sin(dlam / 2) ** 2
        )
        # Rounding can push `a` a hair above 1 for near-antipodal points,
        # which would make asin(sqrt(a)) raise "math domain error".
        c = 2 * math.asin(min(1.0, math.sqrt(a)))
        total_dist += radius * c
    return total_dist

# --- Dynamic UI Logic ---

def filter_routes(search_text):
    """Rebuild the route dropdown from the search box text.

    A blank (empty or whitespace-only) search lists the first 20 routes.
    Otherwise, up to 100 routes whose key contains the search text
    (case-insensitive) are listed. "UNKNOWN" is always the first choice and is
    selected in both cases. That way the dropdown never keeps a previously
    selected route that is missing from its new choice list.
    """
    query = (search_text or "").strip().casefold()
    if not query:
        return gr.Dropdown(choices=["UNKNOWN"] + ALL_ROUTE_KEYS[:20], value="UNKNOWN")
    filtered = [r for r in ALL_ROUTE_KEYS if query in r.casefold()]
    return gr.Dropdown(choices=["UNKNOWN"] + filtered[:100], value="UNKNOWN")

def update_stop_dropdowns(route_key):
    """Fill the Start/End stop dropdowns with the stops of the selected route.

    Each option reads "<1-based position>. <English stop name> (<stop id>)".
    Stops absent from STOP_DATA are named "Unknown". Both dropdowns receive
    the same options with nothing selected. An empty, "UNKNOWN", or
    unrecognised route key clears both dropdowns.
    """
    if route_key and route_key != "UNKNOWN" and route_key in ROUTE_DATA:
        record = ROUTE_DATA[route_key]
        operator = record['co'][0]
        labels = []
        for position, stop_id in enumerate(record['stops'].get(operator, []), start=1):
            stop_name = STOP_DATA[stop_id]['name']['en'] if stop_id in STOP_DATA else "Unknown"
            labels.append(f"{position}. {stop_name} ({stop_id})")
    else:
        labels = []

    return gr.Dropdown(choices=labels, value=None), gr.Dropdown(choices=labels, value=None)

# Seconds to wait for a route-waypoints download. Without a timeout a stalled
# request would hang the Gradio callback (and therefore the prediction).
WAYPOINT_TIMEOUT_S = 10


class _MapDownloadError(Exception):
    """Raised when the route-waypoints server answers with a non-200 status."""


@functools.lru_cache(maxsize=256)
def _route_segments(gtfs_id, bound):
    """Download one route direction's waypoint GeoJSON and return its line segments.

    Cached per (gtfs_id, bound). Otherwise the same file is re-downloaded
    whenever either stop dropdown changes and again when Predict is clicked.
    Failures raise rather than return, so they are never cached and the next
    call retries.
    """
    url = f"https://hkbus.github.io/route-waypoints/{gtfs_id}-{bound}.json"
    resp = requests.get(url, timeout=WAYPOINT_TIMEOUT_S)
    if resp.status_code != 200:
        raise _MapDownloadError(url)
    features = resp.json().get('features', [])
    if not features:
        return ()
    if features[0]['geometry']['type'] == 'MultiLineString':
        # A single feature whose coordinates are already a list of line strings.
        return tuple(features[0]['geometry']['coordinates'])
    # Otherwise each feature contributes one segment.
    return tuple(f['geometry']['coordinates'] for f in features)


def calculate_real_metrics(route_key, start_str, end_str):
    """Compute distance and stop count between two stops of a route from map data.

    Args:
        route_key: Key into ROUTE_DATA, or "UNKNOWN" when no route is selected.
        start_str, end_str: Stop dropdown labels of the form
            "<n>. <name> (<id>)", where <n> is the 1-based stop position.

    Returns:
        A tuple ``(distance, num_stops, error)``. On success ``distance`` is
        the summed haversine length of the waypoint segments between the two
        stops, ``num_stops`` is ``end - start`` and ``error`` is None. On any
        failure the first two items are None and ``error`` is a short message.
    """
    if not route_key or route_key == "UNKNOWN" or not start_str or not end_str:
        return None, None, "Wait"

    try:
        start_idx = int(start_str.split(".")[0]) - 1
        end_idx = int(end_str.split(".")[0]) - 1

        if start_idx >= end_idx:
            return None, None, "Start must be before End"

        route_info = ROUTE_DATA[route_key]
        gtfs_id = route_info.get('gtfsId')
        company = route_info['co'][0]
        bound = route_info['bound'].get(company)

        if not gtfs_id or not bound:
            return None, None, "No Map Data"

        try:
            segments = _route_segments(gtfs_id, bound)
        except _MapDownloadError:
            return None, None, "Map Download Fail"

        # Assumes segment i spans stop i -> stop i+1 (TODO confirm against the
        # route-waypoints format); segments missing from the file add nothing.
        total_dist = sum(haversine_distance(seg) for seg in segments[start_idx:end_idx])
        num_stops = end_idx - start_idx
        return total_dist, num_stops, None

    except Exception as e:
        # Any parsing/network/data error is reported to the caller as text.
        return None, None, str(e)

def auto_fill_metrics(route_key, start_str, end_str, current_dist, current_stops):
    """Refresh the Distance/Stops boxes after a stop selection changes.

    When calculate_real_metrics succeeds, returns the map distance rounded to
    one decimal place and the stop count as an int. Otherwise returns the
    boxes' current values unchanged, so anything the user typed is kept.
    """
    dist, stops, _ = calculate_real_metrics(route_key, start_str, end_str)
    if dist is None or stops is None:
        return current_dist, current_stops
    return round(dist, 1), int(stops)

# --- Prediction Logic ---

def predict_fn(manual_dist, manual_stops, hour, day_name, route_id, start_str, end_str):
    """Predict the travel time for the values in the Distance/Stops boxes.

    The model always runs on ``manual_dist``/``manual_stops``. The route and
    stop selections only decide the label on the result. When map data is
    available and the boxes still match it (distance within 50 m, identical
    stop count), the result is tagged "Computed Data"; otherwise it is tagged
    "Manual Inputs".

    Args:
        manual_dist: Distance box value in metres (may be None if the box was
            cleared — assumption about Gradio's Number component).
        manual_stops: Stops Count box value (same caveat).
        hour: Hour-of-day slider value (0-23).
        day_name: A key of DAY_MAP.
        route_id: Selected route key ("UNKNOWN" when none); also fed to the model.
        start_str, end_str: Stop dropdown labels, used only for the label check.

    Returns:
        On success, the tag, a blank line, then "⏱️ ETA: <m> min <s> sec".
        Otherwise an "Input Error: ..." or "Model Error: ..." message.
    """
    if model is None:
        # Loading failed at start-up; say so instead of surfacing an AttributeError.
        return "Model Error: model is not loaded (see server log)."
    if manual_dist is None or manual_stops is None:
        return "Input Error: please fill in both Distance and Stops Count."

    # Convert up front: previously float() ran outside any try block during the
    # validation check, so a bad or empty box crashed the callback.
    try:
        dist = float(manual_dist)
        stops = float(manual_stops)
    except (TypeError, ValueError) as e:
        return f"Input Error: {e}"

    # Validation check: do the boxes still match the map-derived values?
    map_dist, map_stops, _ = calculate_real_metrics(route_id, start_str, end_str)
    matches_map = (
        map_dist is not None
        and abs(map_dist - dist) <= 50  # 50 m tolerance
        and map_stops == stops
    )
    status_tag = "Computed Data" if matches_map else "Manual Inputs"

    try:
        inputs = {
            'distance': np.array([[dist]]),
            'num_stops': np.array([[stops]]),
            'hour': np.array([[int(hour)]]),
            'day_of_week': np.array([[int(DAY_MAP[day_name])]]),
            'route_id': tf.constant([[str(route_id)]], dtype=tf.string)
        }

        prediction = model.predict(inputs, verbose=0)
        seconds = float(prediction[0][0])
        # Round once to whole seconds, and clamp at zero so a slightly negative
        # regression output is not rendered as e.g. "-1 min 30 sec".
        total_seconds = max(0, round(seconds))
        minutes, rem_seconds = divmod(total_seconds, 60)

        return f"{status_tag}\n\n⏱️ ETA: {minutes} min {rem_seconds} sec"

    except Exception as e:
        return f"Model Error: {str(e)}"


# --- 3. Build the UI ---
# Two-column layout: inputs on the left, the prediction on the right. Inside
# each `with` context, components appear on screen in creation order.
with gr.Blocks(title="HK-TransitFlow-Net") as demo:
    gr.Markdown("# HK-TransitFlow-Net Demo")
    gr.Markdown("Predicts KMB/CTB bus travel time.")
    gr.Markdown("Model: https://huggingface.co/WheelsTransit/HK-TransitFlow-Net")
    
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 1. Route Selection (Optional)")
            gr.Markdown("Select a route to auto-fill distance, or skip to type manually.")
            
            route_search = gr.Textbox(label="Search Route", placeholder="Type e.g. '968'")
            # "UNKNOWN" is the placeholder for "no route selected"; the callbacks
            # check for it. The selected value is also passed to the model as route_id.
            route_dropdown = gr.Dropdown(label="Select Route ID", choices=["UNKNOWN"], value="UNKNOWN", interactive=True)
            
            with gr.Row():
                # Filled by update_stop_dropdowns once a route is chosen.
                start_dropdown = gr.Dropdown(label="Start Stop", choices=[], interactive=True)
                end_dropdown = gr.Dropdown(label="End Stop", choices=[], interactive=True)
            
            gr.Markdown("---")
            gr.Markdown("### 2. Time & Details")
            with gr.Row():
                hour_input = gr.Slider(minimum=0, maximum=23, step=1, label="Hour (0-23)", value=9)
                day_input = gr.Dropdown(choices=list(DAY_MAP.keys()), label="Day", value="Monday")
            
            # These two boxes are what the model actually consumes. They start
            # with placeholder values and are overwritten by auto_fill_metrics
            # when map data is available.
            with gr.Row():
                dist_input = gr.Number(label="Distance (m)", value=5000)
                stops_input = gr.Number(label="Stops Count", value=10)

            predict_btn = gr.Button("Predict ETA", variant="primary")

        with gr.Column():
            gr.Markdown("### Result")
            output_text = gr.Textbox(label="Prediction", lines=3)
            gr.Markdown("*Tip: If you modify the Distance/Stops boxes manually, the model will use your typed values.*")

    # --- Event Wiring ---
    
    # Typing in the search box re-filters the route choices.
    route_search.change(fn=filter_routes, inputs=route_search, outputs=route_dropdown)
    
    # Choosing a route repopulates (and clears) both stop dropdowns.
    route_dropdown.change(
        fn=update_stop_dropdowns, 
        inputs=route_dropdown, 
        outputs=[start_dropdown, end_dropdown]
    )
    
    # Auto-fill triggers
    # Either stop changing re-runs the map lookup. The current box values are
    # passed in so auto_fill_metrics can return them unchanged when the lookup fails.
    start_dropdown.change(
        fn=auto_fill_metrics,
        inputs=[route_dropdown, start_dropdown, end_dropdown, dist_input, stops_input],
        outputs=[dist_input, stops_input]
    )
    
    end_dropdown.change(
        fn=auto_fill_metrics,
        inputs=[route_dropdown, start_dropdown, end_dropdown, dist_input, stops_input],
        outputs=[dist_input, stops_input]
    )
    
    # Predict trigger (Passes dropdowns just for validation check)
    # The input order must match predict_fn's positional parameters.
    predict_btn.click(
        fn=predict_fn,
        inputs=[dist_input, stops_input, hour_input, day_input, route_dropdown, start_dropdown, end_dropdown],
        outputs=output_text
    )

# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()