Spaces:
Sleeping
Sleeping
File size: 3,884 Bytes
28bf80d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
from itertools import product
from typing import Dict
def gauss_pt_eval(tensor: torch.Tensor, kernels: nn.ParameterList, stride: int = 1) -> torch.Tensor:
    """
    Apply every kernel in ``kernels`` to every channel of ``tensor`` independently.

    Args:
        tensor: Input of shape ``[B, C, H, W]``.
        kernels: Non-empty sequence of conv weights, each expected to be
            ``[1, 1, kh, kw]``. They are cast to the input's device and dtype
            before use.
        stride: Convolution stride forwarded to ``F.conv2d``.

    Returns:
        Tensor of shape ``[B, len(kernels), C, H_out, W_out]`` where
        ``out[:, k, c] == F.conv2d(tensor[:, c:c+1], kernels[k], stride=stride)[:, 0]``.
        A warning is emitted if the result does not have that shape (e.g. a
        kernel with more than one output channel).

    Raises:
        ValueError: if ``kernels`` is empty or ``tensor`` is not 4-D.
    """
    if not kernels:
        raise ValueError("No Gauss kernels provided.")
    if tensor.dim() != 4:
        raise ValueError(
            f"gauss_pt_eval expects a [B, C, H, W] tensor, got shape {tuple(tensor.shape)}"
        )
    B, C = tensor.shape[0], tensor.shape[1]

    results = []
    for k in kernels:
        # Match device *and* dtype: F.conv2d rejects mixed dtypes, e.g. a
        # float64 field with kernels created in the default float32.
        weight = k.to(device=tensor.device, dtype=tensor.dtype)
        # A grouped convolution with the kernel tiled once per group applies
        # it to each channel separately -- equivalent to convolving every
        # single-channel slice in turn, but in one call.
        weight = weight.repeat(C, 1, 1, 1)
        per_kernel = F.conv2d(tensor, weight, stride=stride, groups=C)
        results.append(per_kernel.unsqueeze(1))
    out = torch.cat(results, dim=1)

    expected = (B, len(kernels), C) + tuple(out.shape[3:])
    if out.shape != expected:
        warnings.warn(f"Shape mismatch in gauss_pt_eval: {out.shape} != {expected}")
    return out
class FEM2D(nn.Module):
    """
    Builds 2D FEM convolution kernels and evaluates derivatives.

    The ``height x width`` node grid is treated as a mesh of bilinear elements,
    one per 2x2 block of neighbouring nodes, with node spacing
    ``hx = domain_length_x / (width - 1)`` and ``hy = domain_length_y / (height - 1)``.
    For each point of the 2x2 Gauss-Legendre rule, one kernel for d/dx and one
    for d/dy (each of shape ``[1, 1, 2, 2]``) are stored in ``kernels_dx`` /
    ``kernels_dy``. Kernel order follows ``itertools.product(self.gpx, repeat=2)``,
    i.e. ``(gx, gy)`` pairs with ``gy`` varying fastest.

    Raises:
        ValueError: if ``height`` or ``width`` is below 2, or a domain length
            is not positive (the spacing would be undefined or sign-flipped).
    """

    def __init__(
        self,
        height: int,
        width: int,
        domain_length_x: float,
        domain_length_y: float,
        device: torch.device,
    ):
        super().__init__()
        # hx/hy divide by (width - 1) / (height - 1), so each axis needs at
        # least two nodes; non-positive lengths give invalid derivative scaling.
        if height < 2 or width < 2:
            raise ValueError(
                f"FEM2D needs height >= 2 and width >= 2, got height={height}, width={width}"
            )
        if domain_length_x <= 0 or domain_length_y <= 0:
            raise ValueError(
                "FEM2D needs positive domain lengths, got "
                f"domain_length_x={domain_length_x}, domain_length_y={domain_length_y}"
            )
        self.height, self.width = height, width
        self.device = device
        # 2-point Gauss-Legendre quadrature points on [-1, 1]: exactly +/- 1/sqrt(3).
        gauss_pt = 3.0 ** -0.5
        self.gpx = [-gauss_pt, gauss_pt]
        # Frozen (requires_grad=False) kernels, filled by _build_kernels.
        self.kernels_dx = nn.ParameterList()
        self.kernels_dy = nn.ParameterList()
        self._build_kernels(domain_length_x, domain_length_y)

    def _build_kernels(self, Lx: float, Ly: float):
        """Append one d/dx and one d/dy ``[1, 1, 2, 2]`` kernel per Gauss point."""
        hx = Lx / (self.width - 1)
        hy = Ly / (self.height - 1)

        def bf(xi):
            # Linear 1D basis functions on the reference interval [-1, 1].
            return [0.5 * (1 - xi), 0.5 * (1 + xi)]

        # Derivatives of bf w.r.t. the reference coordinate (constant).
        dbf = [-0.5, 0.5]

        for gx, gy in product(self.gpx, repeat=2):
            dx = torch.zeros(2, 2, device=self.device)
            dy = torch.zeros(2, 2, device=self.device)
            # Kernel index [j, i]: j is the row (y) node, i the column (x)
            # node. F.conv2d cross-correlates, so for output (h, w) the entry
            # kernel[j, i] multiplies input[..., h + j, w + i] (stride 1).
            for i, bf_x in enumerate(bf(gx)):
                for j, bf_y in enumerate(bf(gy)):
                    # Chain rule: d/dx = (2 / hx) * d/dxi on the reference element.
                    dx[j, i] = dbf[i] * (2 / hx) * bf_y
                    dy[j, i] = bf_x * (dbf[j] * (2 / hy))
            # store kernels with shape [1,1,2,2]
            self.kernels_dx.append(nn.Parameter(dx.unsqueeze(0).unsqueeze(0), requires_grad=False))
            self.kernels_dy.append(nn.Parameter(dy.unsqueeze(0).unsqueeze(0), requires_grad=False))

    def eval_derivative_x(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return d/dx of the nodal field ``tensor`` ([B, C, H, W]) at every Gauss point via ``gauss_pt_eval``."""
        return gauss_pt_eval(tensor, self.kernels_dx)

    def eval_derivative_y(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return d/dy of the nodal field ``tensor`` ([B, C, H, W]) at every Gauss point via ``gauss_pt_eval``."""
        return gauss_pt_eval(tensor, self.kernels_dy)
class DerivativeCalculator(nn.Module):
    """
    Computes first spatial derivatives for 'u' and 'v' channels.

    Input channel 0 is treated as ``u`` and channel 1 as ``v``; only the first
    ``channels`` of these names are processed (at most two). Derivatives are
    evaluated by an internal ``FEM2D`` built from the grid/domain arguments.
    """

    def __init__(
        self,
        height: int,
        width: int,
        domain_length_x: float,
        domain_length_y: float,
        device: torch.device,
        channels: int = 2,  # number of channels: 2 for (u,v)
    ):
        super().__init__()
        if channels > 2:
            # Only 'u' and 'v' have names, so extra channels are never
            # differentiated; surface that instead of truncating silently.
            warnings.warn(
                "DerivativeCalculator only differentiates the first 2 channels "
                f"('u', 'v'); channels={channels} was requested."
            )
        self.channels = channels
        self.fem = FEM2D(height, width, domain_length_x, domain_length_y, device)

    def calculate_first_derivatives(self, y_spatial: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Differentiate each named channel of ``y_spatial`` in x and y.

        Args:
            y_spatial: ``[B, C, H, W]`` tensor with at least as many channels
                as there are names to process (``min(channels, 2)``); any extra
                channels are ignored.

        Returns:
            Dict with keys ``'u_x', 'u_y', 'v_x', 'v_y'`` (only the ``u`` keys
            when ``channels == 1``). Each value is the output of
            ``FEM2D.eval_derivative_x`` / ``eval_derivative_y`` on a single
            channel slice ``[B, 1, H, W]``.

        Raises:
            ValueError: if ``y_spatial`` is not 4-D or has too few channels.
        """
        names = ['u', 'v'][:self.channels]
        if y_spatial.dim() != 4:
            raise ValueError(
                f"Expected y_spatial of shape [B, C, H, W], got {tuple(y_spatial.shape)}"
            )
        if y_spatial.shape[1] < len(names):
            raise ValueError(
                f"y_spatial has {y_spatial.shape[1]} channel(s) but "
                f"{len(names)} are required for fields {names}"
            )
        deriv = {}
        for idx, name in enumerate(names):
            field = y_spatial[:, idx:idx + 1]  # keep channel dim: [B, 1, H, W]
            deriv[f'{name}_x'] = self.fem.eval_derivative_x(field)
            deriv[f'{name}_y'] = self.fem.eval_derivative_y(field)
        return deriv

    forward = calculate_first_derivatives
|