Julian Bilcke
commited on
Commit
·
0ffe757
1
Parent(s):
5393526
let's try again
Browse files- api_engine.py +30 -1
- api_server.py +2 -2
- api_utils.py +25 -5
api_engine.py
CHANGED
|
@@ -113,8 +113,13 @@ class MatrixGameEngine:
|
|
| 113 |
def _init_models(self):
|
| 114 |
"""Initialize Matrix-Game V2 models"""
|
| 115 |
try:
|
|
|
|
| 116 |
# Load configuration
|
|
|
|
|
|
|
|
|
|
| 117 |
self.config = OmegaConf.load(self.config_path)
|
|
|
|
| 118 |
|
| 119 |
# Initialize generator
|
| 120 |
generator = WanDiffusionWrapper(
|
|
@@ -175,10 +180,14 @@ class MatrixGameEngine:
|
|
| 175 |
logger.info("Models loaded successfully")
|
| 176 |
|
| 177 |
# Preprocess initial images for all scenes
|
|
|
|
| 178 |
for scene_name, frames in self.scenes.items():
|
| 179 |
if frames and len(frames) > 0:
|
|
|
|
| 180 |
# Prepare the first frame as initial latent
|
| 181 |
self._prepare_scene_latent(scene_name, frames[0])
|
|
|
|
|
|
|
| 182 |
|
| 183 |
except Exception as e:
|
| 184 |
logger.error(f"Error loading models: {str(e)}")
|
|
@@ -264,6 +273,8 @@ class MatrixGameEngine:
|
|
| 264 |
raise RuntimeError(error_msg)
|
| 265 |
|
| 266 |
try:
|
|
|
|
|
|
|
| 267 |
# Map scene name to mode
|
| 268 |
mode_map = {
|
| 269 |
'universal': 'universal',
|
|
@@ -272,15 +283,19 @@ class MatrixGameEngine:
|
|
| 272 |
'templerun': 'templerun'
|
| 273 |
}
|
| 274 |
mode = mode_map.get(scene_name, 'universal')
|
|
|
|
| 275 |
|
| 276 |
# Get cached latent or prepare new one
|
| 277 |
if scene_name not in self.scene_latents:
|
|
|
|
| 278 |
scene_frames = self.scenes.get(scene_name, self.scenes.get('universal', []))
|
| 279 |
if scene_frames:
|
|
|
|
| 280 |
self._prepare_scene_latent(scene_name, scene_frames[0])
|
| 281 |
else:
|
| 282 |
error_msg = f"No initial frames available for scene: {scene_name}"
|
| 283 |
logger.error(error_msg)
|
|
|
|
| 284 |
raise ValueError(error_msg)
|
| 285 |
|
| 286 |
scene_data = self.scene_latents.get(scene_name)
|
|
@@ -289,6 +304,8 @@ class MatrixGameEngine:
|
|
| 289 |
logger.error(error_msg)
|
| 290 |
raise ValueError(error_msg)
|
| 291 |
|
|
|
|
|
|
|
| 292 |
# Prepare conditions
|
| 293 |
if keyboard_condition is None:
|
| 294 |
keyboard_condition = [[0, 0, 0, 0, 0, 0]]
|
|
@@ -321,6 +338,10 @@ class MatrixGameEngine:
|
|
| 321 |
|
| 322 |
# Generate frames with streaming pipeline
|
| 323 |
with torch.no_grad():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
# Set seed for reproducibility
|
| 325 |
set_seed(self.seed + self.frame_count)
|
| 326 |
|
|
@@ -334,6 +355,8 @@ class MatrixGameEngine:
|
|
| 334 |
mode=mode
|
| 335 |
)
|
| 336 |
|
|
|
|
|
|
|
| 337 |
# Decode first frame from latent
|
| 338 |
if outputs is not None and len(outputs) > 0:
|
| 339 |
# Extract first frame
|
|
@@ -353,7 +376,13 @@ class MatrixGameEngine:
|
|
| 353 |
|
| 354 |
except Exception as e:
|
| 355 |
error_msg = f"Error generating frame with Matrix-Game V2 model: {str(e)}"
|
| 356 |
-
logger.error(error_msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
raise RuntimeError(error_msg)
|
| 358 |
|
| 359 |
# Add visualization of input controls
|
|
|
|
| 113 |
def _init_models(self):
|
| 114 |
"""Initialize Matrix-Game V2 models"""
|
| 115 |
try:
|
| 116 |
+
logger.info(f"Loading configuration from: {self.config_path}")
|
| 117 |
# Load configuration
|
| 118 |
+
if not os.path.exists(self.config_path):
|
| 119 |
+
logger.error(f"Config file not found: {self.config_path}")
|
| 120 |
+
raise FileNotFoundError(f"Config file not found: {self.config_path}")
|
| 121 |
self.config = OmegaConf.load(self.config_path)
|
| 122 |
+
logger.debug(f"Configuration loaded: {self.config}")
|
| 123 |
|
| 124 |
# Initialize generator
|
| 125 |
generator = WanDiffusionWrapper(
|
|
|
|
| 180 |
logger.info("Models loaded successfully")
|
| 181 |
|
| 182 |
# Preprocess initial images for all scenes
|
| 183 |
+
logger.info("Preprocessing initial images for scenes...")
|
| 184 |
for scene_name, frames in self.scenes.items():
|
| 185 |
if frames and len(frames) > 0:
|
| 186 |
+
logger.debug(f"Preparing latent for scene: {scene_name} ({len(frames)} frames)")
|
| 187 |
# Prepare the first frame as initial latent
|
| 188 |
self._prepare_scene_latent(scene_name, frames[0])
|
| 189 |
+
else:
|
| 190 |
+
logger.warning(f"No frames found for scene: {scene_name}")
|
| 191 |
|
| 192 |
except Exception as e:
|
| 193 |
logger.error(f"Error loading models: {str(e)}")
|
|
|
|
| 273 |
raise RuntimeError(error_msg)
|
| 274 |
|
| 275 |
try:
|
| 276 |
+
logger.debug(f"Starting frame generation for scene: {scene_name}")
|
| 277 |
+
|
| 278 |
# Map scene name to mode
|
| 279 |
mode_map = {
|
| 280 |
'universal': 'universal',
|
|
|
|
| 283 |
'templerun': 'templerun'
|
| 284 |
}
|
| 285 |
mode = mode_map.get(scene_name, 'universal')
|
| 286 |
+
logger.debug(f"Using mode: {mode} for scene: {scene_name}")
|
| 287 |
|
| 288 |
# Get cached latent or prepare new one
|
| 289 |
if scene_name not in self.scene_latents:
|
| 290 |
+
logger.debug(f"Scene latents not cached for {scene_name}, preparing...")
|
| 291 |
scene_frames = self.scenes.get(scene_name, self.scenes.get('universal', []))
|
| 292 |
if scene_frames:
|
| 293 |
+
logger.debug(f"Found {len(scene_frames)} frames for scene {scene_name}")
|
| 294 |
self._prepare_scene_latent(scene_name, scene_frames[0])
|
| 295 |
else:
|
| 296 |
error_msg = f"No initial frames available for scene: {scene_name}"
|
| 297 |
logger.error(error_msg)
|
| 298 |
+
logger.error(f"Available scenes: {list(self.scenes.keys())}")
|
| 299 |
raise ValueError(error_msg)
|
| 300 |
|
| 301 |
scene_data = self.scene_latents.get(scene_name)
|
|
|
|
| 304 |
logger.error(error_msg)
|
| 305 |
raise ValueError(error_msg)
|
| 306 |
|
| 307 |
+
logger.debug(f"Scene data prepared for {scene_name}")
|
| 308 |
+
|
| 309 |
# Prepare conditions
|
| 310 |
if keyboard_condition is None:
|
| 311 |
keyboard_condition = [[0, 0, 0, 0, 0, 0]]
|
|
|
|
| 338 |
|
| 339 |
# Generate frames with streaming pipeline
|
| 340 |
with torch.no_grad():
|
| 341 |
+
logger.debug(f"Starting inference with mode: {mode}")
|
| 342 |
+
logger.debug(f"Conditional dict keys: {list(conditional_dict.keys())}")
|
| 343 |
+
logger.debug(f"Noise shape: {sampled_noise.shape}")
|
| 344 |
+
|
| 345 |
# Set seed for reproducibility
|
| 346 |
set_seed(self.seed + self.frame_count)
|
| 347 |
|
|
|
|
| 355 |
mode=mode
|
| 356 |
)
|
| 357 |
|
| 358 |
+
logger.debug(f"Inference completed, outputs type: {type(outputs)}")
|
| 359 |
+
|
| 360 |
# Decode first frame from latent
|
| 361 |
if outputs is not None and len(outputs) > 0:
|
| 362 |
# Extract first frame
|
|
|
|
| 376 |
|
| 377 |
except Exception as e:
|
| 378 |
error_msg = f"Error generating frame with Matrix-Game V2 model: {str(e)}"
|
| 379 |
+
logger.error(error_msg, exc_info=True)
|
| 380 |
+
logger.error(f"Scene: {scene_name}, Mode: {mode if 'mode' in locals() else 'unknown'}")
|
| 381 |
+
logger.error(f"Keyboard condition: {keyboard_condition}")
|
| 382 |
+
logger.error(f"Mouse condition: {mouse_condition}")
|
| 383 |
+
logger.error(f"Frame count: {self.frame_count}")
|
| 384 |
+
logger.error(f"Device: {self.device}")
|
| 385 |
+
logger.error(f"Weight dtype: {self.weight_dtype}")
|
| 386 |
raise RuntimeError(error_msg)
|
| 387 |
|
| 388 |
# Add visualization of input controls
|
api_server.py
CHANGED
|
@@ -41,8 +41,8 @@ class GameSession:
|
|
| 41 |
self.created_at = time.time()
|
| 42 |
self.last_activity = time.time()
|
| 43 |
|
| 44 |
-
# Game state
|
| 45 |
-
self.current_scene = "
|
| 46 |
self.is_streaming = False
|
| 47 |
self.stream_task = None
|
| 48 |
|
|
|
|
| 41 |
self.created_at = time.time()
|
| 42 |
self.last_activity = time.time()
|
| 43 |
|
| 44 |
+
# Game state
|
| 45 |
+
self.current_scene = "universal" # Default scene
|
| 46 |
self.is_streaming = False
|
| 47 |
self.stream_task = None
|
| 48 |
|
api_utils.py
CHANGED
|
@@ -18,7 +18,7 @@ from typing import Dict, List, Tuple, Any, Optional, Union
|
|
| 18 |
|
| 19 |
# Configure logging
|
| 20 |
logging.basicConfig(
|
| 21 |
-
level=logging.
|
| 22 |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 23 |
)
|
| 24 |
logger = logging.getLogger(__name__)
|
|
@@ -178,22 +178,42 @@ def load_scene_frames(scene_name: str, frame_width: int, frame_height: int) -> L
|
|
| 178 |
List[np.ndarray]: List of frames as numpy arrays
|
| 179 |
"""
|
| 180 |
frames = []
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
for img_file in image_files:
|
| 186 |
try:
|
| 187 |
img_path = os.path.join(scene_dir, img_file)
|
| 188 |
img = Image.open(img_path).convert("RGB")
|
| 189 |
img = img.resize((frame_width, frame_height))
|
| 190 |
frames.append(np.array(img))
|
|
|
|
| 191 |
except Exception as e:
|
| 192 |
logger.error(f"Error loading image {img_file}: {str(e)}")
|
| 193 |
|
| 194 |
# If no frames were loaded, create a default colored frame with text
|
| 195 |
if not frames:
|
| 196 |
-
|
|
|
|
| 197 |
# Add scene name as text
|
| 198 |
cv2.putText(frame, f"Scene: {scene_name}", (50, 180),
|
| 199 |
cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
|
|
|
|
| 18 |
|
| 19 |
# Configure logging
|
| 20 |
logging.basicConfig(
|
| 21 |
+
level=logging.DEBUG,
|
| 22 |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 23 |
)
|
| 24 |
logger = logging.getLogger(__name__)
|
|
|
|
| 178 |
List[np.ndarray]: List of frames as numpy arrays
|
| 179 |
"""
|
| 180 |
frames = []
|
| 181 |
+
# Try multiple possible directories for scene images
|
| 182 |
+
scene_dirs = [
|
| 183 |
+
f"./GameWorldScore/asset/init_image/{scene_name}",
|
| 184 |
+
f"./demo_images/{scene_name}",
|
| 185 |
+
f"./demo_images/{scene_name.replace('_', '-')}", # Handle gta_drive -> gta-drive
|
| 186 |
+
]
|
| 187 |
|
| 188 |
+
scene_dir = None
|
| 189 |
+
logger.debug(f"Looking for scene images for '{scene_name}' in: {scene_dirs}")
|
| 190 |
+
for potential_dir in scene_dirs:
|
| 191 |
+
logger.debug(f"Checking directory: {potential_dir}")
|
| 192 |
+
if os.path.exists(potential_dir):
|
| 193 |
+
scene_dir = potential_dir
|
| 194 |
+
logger.debug(f"Found scene directory: {scene_dir}")
|
| 195 |
+
break
|
| 196 |
+
|
| 197 |
+
if not scene_dir:
|
| 198 |
+
logger.warning(f"No scene directory found for '{scene_name}'")
|
| 199 |
+
|
| 200 |
+
if scene_dir and os.path.exists(scene_dir):
|
| 201 |
+
image_files = sorted([f for f in os.listdir(scene_dir) if f.endswith('.png') or f.endswith('.jpg') or f.endswith('.webp')])
|
| 202 |
+
logger.debug(f"Found {len(image_files)} image files in {scene_dir}: {image_files}")
|
| 203 |
for img_file in image_files:
|
| 204 |
try:
|
| 205 |
img_path = os.path.join(scene_dir, img_file)
|
| 206 |
img = Image.open(img_path).convert("RGB")
|
| 207 |
img = img.resize((frame_width, frame_height))
|
| 208 |
frames.append(np.array(img))
|
| 209 |
+
logger.debug(f"Successfully loaded image: {img_file}")
|
| 210 |
except Exception as e:
|
| 211 |
logger.error(f"Error loading image {img_file}: {str(e)}")
|
| 212 |
|
| 213 |
# If no frames were loaded, create a default colored frame with text
|
| 214 |
if not frames:
|
| 215 |
+
logger.warning(f"No frames loaded for scene '{scene_name}', creating default frame")
|
| 216 |
+
frame = np.ones((frame_height, frame_width, 3), dtype=np.uint8) * 100
|
| 217 |
# Add scene name as text
|
| 218 |
cv2.putText(frame, f"Scene: {scene_name}", (50, 180),
|
| 219 |
cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
|