Antcar commited on
Commit
cfb1903
·
verified ·
1 Parent(s): 4fb4fbc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -24
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
4
- import os
5
 
6
- # 1. Load the Model
7
- # We expect the .keras file to be in the same directory
8
- model = tf.keras.models.load_model("hk_transit_flow_net.keras")
 
9
 
10
  # Helper to map day names to integers
11
  DAY_MAP = {
@@ -15,13 +16,11 @@ DAY_MAP = {
15
 
16
  def predict_eta(distance_meters, num_stops, hour, day_name, route_id):
17
  try:
18
- # 1. Prepare Inputs
19
- # We must match the exact shape and types used in training
20
-
21
  # Handle empty route
22
  if not route_id or route_id.strip() == "":
23
  route_id = "UNKNOWN"
24
 
 
25
  inputs = {
26
  'distance': np.array([[float(distance_meters)]]),
27
  'num_stops': np.array([[float(num_stops)]]),
@@ -30,11 +29,10 @@ def predict_eta(distance_meters, num_stops, hour, day_name, route_id):
30
  'route_id': tf.constant([[str(route_id)]], dtype=tf.string)
31
  }
32
 
33
- # 2. Run Prediction
34
  prediction = model.predict(inputs, verbose=0)
35
  seconds = float(prediction[0][0])
36
 
37
- # 3. Format Output
38
  minutes = int(seconds // 60)
39
  rem_seconds = int(seconds % 60)
40
 
@@ -43,7 +41,7 @@ def predict_eta(distance_meters, num_stops, hour, day_name, route_id):
43
  except Exception as e:
44
  return f"Error: {str(e)}"
45
 
46
- # 3. Build the Interface
47
  iface = gr.Interface(
48
  fn=predict_eta,
49
  inputs=[
@@ -51,22 +49,12 @@ iface = gr.Interface(
51
  gr.Number(label="Number of Stops", value=10),
52
  gr.Slider(minimum=0, maximum=23, step=1, label="Hour of Day (0-23)", value=9),
53
  gr.Dropdown(choices=list(DAY_MAP.keys()), label="Day of Week", value="Monday"),
54
- gr.Textbox(label="Route ID (Optional)", placeholder="e.g. 968+1+Yuen Long+Tin Hau", value="UNKNOWN")
55
  ],
56
  outputs="text",
57
- title="HK-TransitFlow-Net 🚌",
58
- description="""
59
- **Hong Kong Bus ETA Predictor**
60
-
61
- This model uses Deep Learning to predict bus travel time based on distance, stops, and time context.
62
-
63
- * **Distance:** Physical distance of the path in meters.
64
- * **Route ID:** Internal ID (e.g., `968+1+...`). If unknown, leave as UNKNOWN.
65
- * **Note:** Trained on KMB & CTB data.
66
- """,
67
  theme="soft"
68
  )
69
 
70
- # 4. Launch
71
- if __name__ == "__main__":
72
- iface.launch()
 
1
  import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
4
+ from huggingface_hub import from_pretrained_keras
5
 
6
+ # 1. Download the Model from your Repository
7
+ # This connects this Space to your uploaded model
8
+ print("Downloading model...")
9
+ model = from_pretrained_keras("WheelsTransit/HK-TransitFlow-Net")
10
 
11
  # Helper to map day names to integers
12
  DAY_MAP = {
 
16
 
17
  def predict_eta(distance_meters, num_stops, hour, day_name, route_id):
18
  try:
 
 
 
19
  # Handle empty route
20
  if not route_id or route_id.strip() == "":
21
  route_id = "UNKNOWN"
22
 
23
+ # Prepare inputs exactly as the model expects
24
  inputs = {
25
  'distance': np.array([[float(distance_meters)]]),
26
  'num_stops': np.array([[float(num_stops)]]),
 
29
  'route_id': tf.constant([[str(route_id)]], dtype=tf.string)
30
  }
31
 
32
+ # Run Prediction
33
  prediction = model.predict(inputs, verbose=0)
34
  seconds = float(prediction[0][0])
35
 
 
36
  minutes = int(seconds // 60)
37
  rem_seconds = int(seconds % 60)
38
 
 
41
  except Exception as e:
42
  return f"Error: {str(e)}"
43
 
44
+ # Build the Interface
45
  iface = gr.Interface(
46
  fn=predict_eta,
47
  inputs=[
 
49
  gr.Number(label="Number of Stops", value=10),
50
  gr.Slider(minimum=0, maximum=23, step=1, label="Hour of Day (0-23)", value=9),
51
  gr.Dropdown(choices=list(DAY_MAP.keys()), label="Day of Week", value="Monday"),
52
+ gr.Textbox(label="Route ID", placeholder="968+1+...", value="UNKNOWN")
53
  ],
54
  outputs="text",
55
+ title="HK-TransitFlow-Net Demo",
56
+ description="Live inference for HK Bus ETA prediction.",
 
 
 
 
 
 
 
 
57
  theme="soft"
58
  )
59
 
60
+ iface.launch()