File size: 3,884 Bytes
28bf80d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
from itertools import product
from typing import Dict


def gauss_pt_eval(tensor: torch.Tensor, kernels: nn.ParameterList, stride: int = 1) -> torch.Tensor:
    """
    Convolve every channel of ``tensor`` independently with every kernel.

    Args:
        tensor: Input of shape [B, C, H, W].
        kernels: Non-empty sequence of single-channel kernels, each of shape
            [1, 1, kH, kW]; all kernels must share the same spatial size.
        stride: Convolution stride, forwarded to ``F.conv2d``.

    Returns:
        Tensor of shape [B, len(kernels), C, H_out, W_out] where
        ``out[:, k, c]`` is channel ``c`` of ``tensor`` convolved with
        ``kernels[k]``.

    Raises:
        ValueError: If ``kernels`` is empty.
    """
    if not kernels:
        raise ValueError("No Gauss kernels provided.")
    B, C = tensor.shape[0], tensor.shape[1]
    spatial = tensor.shape[2:]
    # Stack the K kernels into a single [K, 1, kH, kW] weight that matches the
    # input's device AND dtype (conv2d rejects e.g. float64 input with float32
    # weights).
    weight = torch.cat(
        [k.to(device=tensor.device, dtype=tensor.dtype) for k in kernels], dim=0
    )
    # Fold channels into the batch so one conv call applies every kernel to
    # every channel: [B*C, 1, H, W] -> [B*C, K, H_out, W_out].
    out = F.conv2d(tensor.reshape(B * C, 1, *spatial), weight, stride=stride)
    # Unfold to [B, C, K, H_out, W_out], then swap to [B, K, C, H_out, W_out].
    # contiguous() keeps the memory layout the per-channel torch.cat produced.
    return out.reshape(B, C, *out.shape[1:]).transpose(1, 2).contiguous()


class FEM2D(nn.Module):
    """
    Builds 2D FEM convolution kernels and evaluates derivatives.

    The grid is treated as a uniform mesh whose nodes are the grid points.
    For each of the 2x2 Gauss points, a 2x2 kernel holds the x-derivative
    (``kernels_dx``) or y-derivative (``kernels_dy``) of the four products of
    1D linear basis functions at that point. Convolving a nodal field with such
    a kernel gives the field's derivative at that Gauss point of each 2x2 cell.

    Args:
        height: Number of grid nodes along y (tensor rows); must be >= 2.
        width: Number of grid nodes along x (tensor columns); must be >= 2.
        domain_length_x: Physical length of the domain along x.
        domain_length_y: Physical length of the domain along y.
        device: Device on which the kernels are created.

    Raises:
        ValueError: If ``height`` or ``width`` is smaller than 2 (the element
            size would otherwise require a division by zero).
    """
    def __init__(
        self,
        height: int,
        width: int,
        domain_length_x: float,
        domain_length_y: float,
        device: torch.device
    ):
        super().__init__()
        if height < 2 or width < 2:
            raise ValueError(
                f"FEM2D needs at least 2 nodes per direction, "
                f"got height={height}, width={width}."
            )
        self.height, self.width = height, width
        self.device = device
        # 2-point Gauss-Legendre quadrature on [-1, 1]: nodes at +-1/sqrt(3),
        # computed exactly rather than from a 5-digit truncated literal.
        gauss = 3.0 ** -0.5
        self.gpx = [-gauss, gauss]
        self.kernels_dx = nn.ParameterList()
        self.kernels_dy = nn.ParameterList()
        self._build_kernels(domain_length_x, domain_length_y)

    def _build_kernels(self, Lx: float, Ly: float):
        """Append one [1, 1, 2, 2] dx and dy kernel per Gauss point (gx, gy)."""
        hx = Lx / (self.width - 1)   # element size along x
        hy = Ly / (self.height - 1)  # element size along y
        # Jacobians mapping d/dxi on the reference interval [-1, 1] to d/dx, d/dy.
        jac_x, jac_y = 2 / hx, 2 / hy

        def bf(xi):
            # Linear 1D basis functions on [-1, 1]: (node at -1, node at +1).
            return [0.5 * (1 - xi), 0.5 * (1 + xi)]

        # Their derivatives with respect to xi are constant.
        dbf = [-0.5, 0.5]

        for gx, gy in product(self.gpx, repeat=2):
            dx = torch.zeros(2, 2, device=self.device)
            dy = torch.zeros(2, 2, device=self.device)
            bx, by = bf(gx), bf(gy)
            for i in range(2):      # x node -> kernel column
                for j in range(2):  # y node -> kernel row
                    dx[j, i] = dbf[i] * jac_x * by[j]
                    dy[j, i] = bx[i] * (dbf[j] * jac_y)
            # store kernels with shape [1,1,2,2] as frozen (non-trainable) parameters
            self.kernels_dx.append(nn.Parameter(dx.unsqueeze(0).unsqueeze(0), requires_grad=False))
            self.kernels_dy.append(nn.Parameter(dy.unsqueeze(0).unsqueeze(0), requires_grad=False))

    def eval_derivative_x(self, tensor: torch.Tensor, stride: int = 1) -> torch.Tensor:
        """Evaluate d/dx of every channel of ``tensor`` at each Gauss point via ``gauss_pt_eval``."""
        return gauss_pt_eval(tensor, self.kernels_dx, stride=stride)

    def eval_derivative_y(self, tensor: torch.Tensor, stride: int = 1) -> torch.Tensor:
        """Evaluate d/dy of every channel of ``tensor`` at each Gauss point via ``gauss_pt_eval``."""
        return gauss_pt_eval(tensor, self.kernels_dy, stride=stride)


class DerivativeCalculator(nn.Module):
    """
    Computes first spatial derivatives for 'u' and 'v' channels.

    Channel 0 of the input is treated as ``u`` and channel 1 as ``v``. Only
    the fields in ``('u', 'v')[:channels]`` are differentiated, so with
    ``channels > 2`` the extra channels are ignored.
    """
    def __init__(
        self,
        height: int,
        width: int,
        domain_length_x: float,
        domain_length_y: float,
        device: torch.device,
        channels: int = 2  # number of channels: 2 for (u,v)
    ):
        super().__init__()
        self.channels = channels
        self.fem = FEM2D(height, width, domain_length_x, domain_length_y, device)

    def calculate_first_derivatives(self, y_spatial: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        y_spatial: [B, C, H, W] tensor with C >= the number of named fields.

        Returns a dict mapping '<name>_x' / '<name>_y' to the output of
        ``self.fem.eval_derivative_x`` / ``eval_derivative_y`` for each name in
        ('u', 'v')[:channels]. For the default channels=2 the keys are
        'u_x', 'u_y', 'v_x', 'v_y'.

        Raises ValueError if y_spatial has fewer channels than named fields.
        """
        names = ['u', 'v'][:self.channels]
        # Fail early with a clear message; otherwise a missing channel becomes
        # an empty slice that only errors deep inside the convolution.
        available = y_spatial.shape[1] if y_spatial.dim() >= 2 else 0
        if available < len(names):
            raise ValueError(
                f"DerivativeCalculator expects at least {len(names)} channel(s) "
                f"({', '.join(names)}), got input of shape {tuple(y_spatial.shape)}."
            )
        deriv = {}
        for idx, name in enumerate(names):
            field = y_spatial[:, idx:idx + 1]  # keep the channel dim: [B, 1, H, W]
            deriv[f'{name}_x'] = self.fem.eval_derivative_x(field)
            deriv[f'{name}_y'] = self.fem.eval_derivative_y(field)
        return deriv

    forward = calculate_first_derivatives