#!/usr/bin/python
# -*- coding:utf-8 -*-
import torch
import numpy as np
from scipy.spatial.transform import Rotation


# from https://github.com/charnley/rmsd/blob/master/rmsd/calculate_rmsd.py
def kabsch_rotation(P, Q):
    """
    Using the Kabsch algorithm with two sets of paired points P and Q, centered
    around the centroid. Each vector set is represented as an NxD
    matrix, where D is the dimension of the space.

    The algorithm works in three steps:
    - a centroid translation of P and Q (assumed done before this function
      call)
    - the computation of a covariance matrix C
    - computation of the optimal rotation matrix U

    For more info see http://en.wikipedia.org/wiki/Kabsch_algorithm

    Parameters
    ----------
    P : array
        (N, D) matrix, where N is points and D is dimension.
    Q : array
        (N, D) matrix, where N is points and D is dimension.

    Returns
    -------
    U : matrix
        Rotation matrix (D, D)
    """
    # Computation of the covariance matrix
    C = np.dot(np.transpose(P), Q)

    # Computation of the optimal rotation matrix.
    # This can be done using singular value decomposition (SVD).
    # Get the sign of det(V) * det(W) to decide whether we need to correct
    # the rotation matrix to ensure a right-handed coordinate system,
    # then calculate the optimal rotation matrix U.
    # See http://en.wikipedia.org/wiki/Kabsch_algorithm
    V, S, W = np.linalg.svd(C)
    d = (np.linalg.det(V) * np.linalg.det(W)) < 0.0
    if d:
        S[-1] = -S[-1]
        V[:, -1] = -V[:, -1]

    # Create the rotation matrix U
    U = np.dot(V, W)
    return U
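

# Quick sanity sketch (added for illustration, not part of the original file):
# `kabsch_rotation` expects both point sets to already be centered and returns
# a proper rotation (det(U) == +1) even when the best orthogonal map would be
# a reflection. The helper name and variables below are hypothetical.
def _check_kabsch_rotation():
    rng = np.random.default_rng(1)
    P = rng.normal(size=(8, 3))
    P -= P.mean(axis=0)                        # center the source set
    Q = -P                                     # a reflected copy (also centered)
    U = kabsch_rotation(P, Q)
    assert np.isclose(np.linalg.det(U), 1.0)   # the reflection is corrected away
    return U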
# Has been validated against the kabsch implementation from RefineGNN.
def kabsch(a, b):
    # Find the optimal rotation matrix (and translation) that transforms a into b.
    # a, b are both [N, 3]; the alignment is a_aligned = a @ R + t.
    a, b = np.array(a), np.array(b)
    a_mean = np.mean(a, axis=0)
    b_mean = np.mean(b, axis=0)
    a_c = a - a_mean
    b_c = b - b_mean

    rotation = kabsch_rotation(a_c, b_c)
    # a_aligned = np.dot(a_c, rotation)
    # t = b_mean - np.mean(a_aligned, axis=0)
    # a_aligned += t
    t = b_mean - np.dot(a_mean, rotation)
    a_aligned = np.dot(a, rotation) + t
    return a_aligned, rotation, t
# a: [N, 3], b: [N, 3]
def compute_rmsd(a, b, aligned=False):  # amino-acid-level RMSD
    if aligned:
        a_aligned = a
    else:
        a_aligned, _, _ = kabsch(a, b)
    dist = np.sum((a_aligned - b) ** 2, axis=-1)
    rmsd = np.sqrt(dist.sum() / a.shape[0])
    return float(rmsd)
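

# Usage sketch (added for illustration, not part of the original file): align a
# rotated + translated copy of a point cloud back onto the reference with `kabsch`
# and verify the fit with `compute_rmsd`. The helper name and all variables inside
# it are hypothetical.
def _example_numpy_alignment():
    rng = np.random.default_rng(0)
    P = rng.normal(size=(10, 3))
    theta = np.pi / 3
    R_true = np.array([[np.cos(theta), -np.sin(theta), 0.],
                       [np.sin(theta),  np.cos(theta), 0.],
                       [0.,             0.,            1.]])
    t_true = np.array([1., -2., 0.5])
    Q = P @ R_true + t_true                    # Q is P under a rigid motion
    P_aligned, R_est, t_est = kabsch(P, Q)     # recover that motion
    raw = compute_rmsd(P, Q, aligned=True)     # RMSD without alignment (large)
    fit = compute_rmsd(P, Q)                   # RMSD after Kabsch alignment (~0)
    return raw, fit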
def kabsch_torch(A, B, requires_grad=False):
    """
    See: https://en.wikipedia.org/wiki/Kabsch_algorithm
    2-D or 3-D registration with known correspondences.
    Registration occurs in the zero-centered coordinate system, and then
    must be transported back.

    Args:
        - A: Torch tensor of shape (N, D) -- Point Cloud to Align (source)
        - B: Torch tensor of shape (N, D) -- Reference Point Cloud (target)
    Returns:
        - A_aligned: A after applying the optimal rigid transform
        - R: optimal rotation
        - t: optimal translation

    Test on rotation + translation:
    >>> A = torch.tensor([[1., 1.], [2., 2.], [1.5, 3.]], dtype=torch.float)
    >>> R0 = torch.tensor([[np.cos(60), -np.sin(60)], [np.sin(60), np.cos(60)]], dtype=torch.float)
    >>> B = (R0.mm(A.T)).T
    >>> t0 = torch.tensor([3., 3.])
    >>> B += t0
    >>> A_aligned, R, t = kabsch_torch(A, B)
    >>> rmsd = torch.sqrt(((A_aligned - B) ** 2).sum(axis=1).mean())
    >>> bool(rmsd < 1e-5)
    True

    Note: R is constrained to be a proper rotation (det(R) = +1), so a target
    related to A by a reflection is not matched exactly.
    """
    a_mean = A.mean(axis=0)
    b_mean = B.mean(axis=0)
    A_c = A - a_mean
    B_c = B - b_mean

    # Covariance matrix
    H = A_c.T.mm(B_c)
    # U, S, V = torch.svd(H)
    if requires_grad:  # try more times to find a stable solution
        assert not torch.isnan(H).any()
        U, S, Vt = torch.linalg.svd(H)
        num_it = 0
        # Near-zero or nearly-equal singular values make the SVD gradient unstable,
        # so perturb the diagonal of H slightly and retry a few times.
        while torch.min(S) < 1e-3 or torch.min(
                torch.abs((S ** 2).view(1, 3) - (S ** 2).view(3, 1) + torch.eye(3).to(S.device))) < 1e-2:
            H = H + torch.rand(3, 3).to(H.device) * torch.eye(3).to(H.device)
            U, S, Vt = torch.linalg.svd(H)
            num_it += 1
            if num_it > 10:
                raise RuntimeError('SVD consistently numerically unstable! Exiting ...')
    else:
        U, S, Vt = torch.linalg.svd(H)
    V = Vt.T

    # Correct for reflection so that R is a proper rotation (det(R) = +1)
    d = (torch.linalg.det(U) * torch.linalg.det(V)) < 0.0
    if d:
        SS = torch.diag(torch.tensor([1. for _ in range(len(U) - 1)] + [-1.], device=U.device, dtype=U.dtype))
        U = U @ SS
        # U[:, -1] = -U[:, -1]

    # Rotation matrix
    R = V.mm(U.T)
    # Translation vector
    t = b_mean[None, :] - R.mm(a_mean[None, :].T).T
    t = (t.T).squeeze()
    return R.mm(A.T).T + t, R, t
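

# Usage sketch (added for illustration, not part of the original file):
# differentiate an RMSD loss through the torch Kabsch alignment using the
# requires_grad=True branch. `src`, `tgt`, and `loss` are hypothetical names;
# shapes follow the (N, 3) convention used above.
def _example_kabsch_torch_grad():
    torch.manual_seed(0)
    src = torch.randn(10, 3, requires_grad=True)
    tgt = torch.randn(10, 3)
    aligned, R, t = kabsch_torch(src, tgt, requires_grad=True)
    loss = ((aligned - tgt) ** 2).sum(dim=-1).mean().sqrt()  # RMSD after alignment
    loss.backward()                            # gradients flow back into `src`
    return loss.item(), src.grad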
def batch_kabsch_torch(A, B):
    '''
    A: [B, N, 3]
    B: [B, N, 3]
    '''
    a_mean = A.mean(dim=1, keepdim=True)
    b_mean = B.mean(dim=1, keepdim=True)
    A_c = A - a_mean
    B_c = B - b_mean

    # Covariance matrices
    H = torch.bmm(A_c.transpose(1, 2), B_c)  # [B, 3, 3]
    U, S, Vt = torch.linalg.svd(H)  # [B, 3, 3]
    V = Vt.transpose(1, 2)

    # Correct for reflections so every R is a proper rotation (det(R) = +1)
    d = ((torch.linalg.det(U) * torch.linalg.det(V)) < 0.0).long()  # [B]
    D = U.shape[-1]
    nSS = torch.eye(D, device=U.device, dtype=U.dtype)  # identity: keep U as is
    SS = torch.diag(torch.tensor([1.] * (D - 1) + [-1.], device=U.device, dtype=U.dtype))  # flip last column
    bSS = torch.stack([nSS, SS], dim=0)[d]  # [B, 3, 3], selected per batch element
    U = torch.bmm(U, bSS)

    # Rotation matrices
    R = torch.bmm(V, U.transpose(1, 2))  # [B, 3, 3]
    # Translation vectors
    t = b_mean - torch.bmm(R, a_mean.transpose(1, 2)).transpose(1, 2)
    A_aligned = torch.bmm(R, A.transpose(1, 2)).transpose(1, 2) + t
    return A_aligned, R, t
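

# Consistency sketch (added for illustration, not part of the original file):
# the batched version should agree with running `kabsch_torch` independently on
# every element of the batch. `A_b` and `B_b` below are hypothetical test tensors.
def _check_batch_kabsch_torch():
    torch.manual_seed(0)
    A_b = torch.randn(4, 10, 3)
    B_b = torch.randn(4, 10, 3)
    aligned_b, R_b, t_b = batch_kabsch_torch(A_b, B_b)
    for i in range(A_b.shape[0]):
        aligned_i, R_i, t_i = kabsch_torch(A_b[i], B_b[i])
        assert torch.allclose(R_b[i], R_i, atol=1e-4)
        assert torch.allclose(aligned_b[i], aligned_i, atol=1e-4)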