ford442 committed on
Commit
59f891f
·
verified ·
1 Parent(s): 54992f8

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +16 -11
inference.py CHANGED
@@ -66,14 +66,7 @@ def load_image_to_tensor_with_resize_and_crop(
66
  target_width: int = 768,
67
  just_crop: bool = False,
68
  ) -> torch.Tensor:
69
- """Load and process an image into a tensor.
70
-
71
- Args:
72
- image_input: Either a file path (str) or a PIL Image object
73
- target_height: Desired height of output tensor
74
- target_width: Desired width of output tensor
75
- just_crop: If True, only crop the image to the target size without resizing
76
- """
77
  if isinstance(image_input, str):
78
  image = Image.open(image_input).convert("RGB")
79
  elif isinstance(image_input, Image.Image):
@@ -84,6 +77,7 @@ def load_image_to_tensor_with_resize_and_crop(
84
  input_width, input_height = image.size
85
  aspect_ratio_target = target_width / target_height
86
  aspect_ratio_frame = input_width / input_height
 
87
  if aspect_ratio_frame > aspect_ratio_target:
88
  new_width = int(input_height * aspect_ratio_target)
89
  new_height = input_height
@@ -95,16 +89,27 @@ def load_image_to_tensor_with_resize_and_crop(
95
  x_start = 0
96
  y_start = (input_height - new_height) // 2
97
 
 
98
  image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
 
99
  if not just_crop:
100
- image = image.resize((target_width, target_height))
 
101
 
 
102
  image = np.array(image)
103
- image = cv2.GaussianBlur(image, (3, 3), 0)
 
 
104
  frame_tensor = torch.from_numpy(image).float()
105
- frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0
 
 
 
 
106
  frame_tensor = frame_tensor.permute(2, 0, 1)
107
  frame_tensor = (frame_tensor / 127.5) - 1.0
 
108
  # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
109
  return frame_tensor.unsqueeze(0).unsqueeze(2)
110
 
 
66
  target_width: int = 768,
67
  just_crop: bool = False,
68
  ) -> torch.Tensor:
69
+ """Load and process an image into a tensor with high-quality scaling."""
 
 
 
 
 
 
 
70
  if isinstance(image_input, str):
71
  image = Image.open(image_input).convert("RGB")
72
  elif isinstance(image_input, Image.Image):
 
77
  input_width, input_height = image.size
78
  aspect_ratio_target = target_width / target_height
79
  aspect_ratio_frame = input_width / input_height
80
+
81
  if aspect_ratio_frame > aspect_ratio_target:
82
  new_width = int(input_height * aspect_ratio_target)
83
  new_height = input_height
 
89
  x_start = 0
90
  y_start = (input_height - new_height) // 2
91
 
92
+ # Crop the center of the image
93
  image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
94
+
95
  if not just_crop:
96
+ # Use LANCZOS for high-quality downscaling/upscaling
97
+ image = image.resize((target_width, target_height), Image.LANCZOS)
98
 
99
+ # Convert to numpy and standard processing WITHOUT blur or crf_compression
100
  image = np.array(image)
101
+
102
+ # REMOVED: cv2.GaussianBlur(image, (3, 3), 0)
103
+
104
  frame_tensor = torch.from_numpy(image).float()
105
+
106
+ # REMOVED: crf_compressor.compress(...)
107
+
108
+ # Normalize to [-1, 1] range expected by the VAE
109
+ # Note: The tensor is in (H, W, C) from numpy, we need (C, H, W)
110
  frame_tensor = frame_tensor.permute(2, 0, 1)
111
  frame_tensor = (frame_tensor / 127.5) - 1.0
112
+
113
  # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
114
  return frame_tensor.unsqueeze(0).unsqueeze(2)
115