import torch import torch.nn as nn import torch.nn.functional as F import warnings from itertools import product from typing import Dict def gauss_pt_eval(tensor: torch.Tensor, kernels: nn.ParameterList, stride: int = 1) -> torch.Tensor: if not kernels: raise ValueError("No Gauss kernels provided.") conv = F.conv2d B, C = tensor.shape[0], tensor.shape[1] device = tensor.device # determine output spatial shape with torch.no_grad(): sample_out = conv(tensor[:, :1], kernels[0].to(device), stride=stride) out_spatial = sample_out.shape[2:] results = [] for k in kernels: k = k.to(device) # apply convolution per channel out_ch = [conv(tensor[:, i:i+1], k, stride=stride) for i in range(C)] results.append(torch.cat(out_ch, dim=1).unsqueeze(1)) out = torch.cat(results, dim=1) expected = (B, len(kernels), C) + out_spatial if out.shape != expected: warnings.warn(f"Shape mismatch in gauss_pt_eval: {out.shape} != {expected}") return out class FEM2D(nn.Module): """ Builds 2D FEM convolution kernels and evaluates derivatives. """ def __init__( self, height: int, width: int, domain_length_x: float, domain_length_y: float, device: torch.device ): super().__init__() self.height, self.width = height, width self.device = device # 2-point Gauss quadrature self.gpx = [-0.57735, 0.57735] self.kernels_dx = nn.ParameterList() self.kernels_dy = nn.ParameterList() self._build_kernels(domain_length_x, domain_length_y) def _build_kernels(self, Lx: float, Ly: float): hx = Lx / (self.width - 1) hy = Ly / (self.height - 1) # linear basis functions on [-1,1] bf = lambda x: [0.5 * (1 - x), 0.5 * (1 + x)] dbf = lambda x: [-0.5, 0.5] for gx, gy in product(self.gpx, repeat=2): dx = torch.zeros(2, 2, device=self.device) dy = torch.zeros(2, 2, device=self.device) for i, bf_x in enumerate(bf(gx)): for j, bf_y in enumerate(bf(gy)): dx[j, i] = dbf(gx)[i] * (2 / hx) * bf_y dy[j, i] = bf_x * (dbf(gy)[j] * (2 / hy)) # store kernels with shape [1,1,2,2] self.kernels_dx.append(nn.Parameter(dx.unsqueeze(0).unsqueeze(0), requires_grad=False)) self.kernels_dy.append(nn.Parameter(dy.unsqueeze(0).unsqueeze(0), requires_grad=False)) def eval_derivative_x(self, tensor: torch.Tensor) -> torch.Tensor: return gauss_pt_eval(tensor, self.kernels_dx) def eval_derivative_y(self, tensor: torch.Tensor) -> torch.Tensor: return gauss_pt_eval(tensor, self.kernels_dy) class DerivativeCalculator(nn.Module): """ Computes first spatial derivatives for 'u' and 'v' channels. """ def __init__( self, height: int, width: int, domain_length_x: float, domain_length_y: float, device: torch.device, channels: int = 2 # number of channels: 2 for (u,v) ): super().__init__() self.channels = channels self.fem = FEM2D(height, width, domain_length_x, domain_length_y, device) def calculate_first_derivatives(self, y_spatial: torch.Tensor) -> Dict[str, torch.Tensor]: """ y_spatial: [B, C, H, W] tensor where C == channels Returns a dict with keys 'u_x','u_y','v_x','v_y'. """ deriv = {} names = ['u', 'v'][:self.channels] for idx, name in enumerate(names): field = y_spatial[:, idx:idx+1] deriv[f'{name}_x'] = self.fem.eval_derivative_x(field) deriv[f'{name}_y'] = self.fem.eval_derivative_y(field) return deriv forward = calculate_first_derivatives