Pocket-Gen / utils /rmsd.py
Zaixi's picture
1
dcacefd
#!/usr/bin/python
# -*- coding:utf-8 -*-
import torch
import numpy as np
from scipy.spatial.transform import Rotation
# from https://github.com/charnley/rmsd/blob/master/rmsd/calculate_rmsd.py
def kabsch_rotation(P, Q):
    """
    Return the proper rotation that best maps the point set P onto Q.

    Both point sets must already be translated so that their centroids sit
    at the origin; this function only performs the rotational part of the
    Kabsch algorithm (covariance matrix followed by an SVD).
    For more info see http://en.wikipedia.org/wiki/Kabsch_algorithm

    Parameters
    ----------
    P : array
        (N,D) matrix, where N is points and D is dimension.
    Q : array
        (N,D) matrix, where N is points and D is dimension.

    Returns
    -------
    U : matrix
        Rotation matrix (D,D), to be applied on the right: ``P @ U`` is the
        rotated copy of P.
    """
    # Cross-covariance between the two (centered) point sets.
    covariance = np.matmul(np.transpose(P), Q)

    # covariance = left @ diag(singular values) @ right_t
    left, _, right_t = np.linalg.svd(covariance)

    # A negative determinant product means the best orthogonal fit would be
    # a reflection; negating the last column of the left factor turns it
    # into a proper (right-handed) rotation instead.
    if np.linalg.det(left) * np.linalg.det(right_t) < 0.0:
        left[:, -1] *= -1

    return np.matmul(left, right_t)
# has been validated against kabsch from RefineGNN
def kabsch(a, b):
    """
    Rigidly align point set a onto point set b.

    Finds the rotation and translation such that
    ``a_aligned = a @ rotation + t`` is the Kabsch superposition of a on b.

    Parameters
    ----------
    a, b : array-like
        Paired point sets, both [N, 3].

    Returns
    -------
    a_aligned : array, a after applying the rotation and translation
    rotation : array, rotation matrix (applied on the right of a)
    t : array, translation vector
    """
    source = np.array(a)
    target = np.array(b)

    source_centroid = np.mean(source, axis=0)
    target_centroid = np.mean(target, axis=0)

    # The rotation is estimated on centroid-free coordinates.
    rotation = kabsch_rotation(source - source_centroid, target - target_centroid)

    # Fold both centroid shifts into a single translation so the transform
    # can be applied directly to the un-centered source points.
    t = target_centroid - np.dot(source_centroid, rotation)
    a_aligned = np.dot(source, rotation) + t
    return a_aligned, rotation, t
# a: [N, 3], b: [N, 3]
def compute_rmsd(a, b, aligned=False):  # amino acids level rmsd
    """
    Root-mean-square deviation between two paired point sets.

    Parameters
    ----------
    a, b : array-like
        Paired coordinates, both [N, 3]. Anything ``np.asarray`` accepts
        (nested lists, tuples, arrays) is allowed.
    aligned : bool
        If True, a is compared to b as given. If False (default), a is first
        superimposed onto b with ``kabsch``.

    Returns
    -------
    float
        sqrt(sum of squared point-to-point distances / N).
    """
    # Coerce once up front: the arithmetic and the ``.shape`` lookup below
    # need real arrays, and the inputs may be plain Python sequences.
    a = np.asarray(a)
    b = np.asarray(b)

    if aligned:
        a_aligned = a
    else:
        a_aligned, _, _ = kabsch(a, b)

    # Squared Euclidean distance of every point pair.
    sq_dist = np.sum((a_aligned - b) ** 2, axis=-1)
    rmsd = np.sqrt(sq_dist.sum() / a.shape[0])
    return float(rmsd)
def kabsch_torch(A, B, requires_grad=False):
    """
    Rigidly align point cloud A onto B with the Kabsch algorithm.

    See: https://en.wikipedia.org/wiki/Kabsch_algorithm
    Registration with known correspondences (row i of A is paired with row i
    of B). The rotation is estimated in the zero-centered coordinate system
    and then transported back through the translation.

    Args:
        - A: Torch tensor of shape (N,D) -- Point Cloud to Align (source)
        - B: Torch tensor of shape (N,D) -- Reference Point Cloud (target)
        - requires_grad: if True, the covariance matrix is checked for NaNs
          and, while its singular values are tiny or nearly equal, a random
          value is added to its diagonal and the SVD is repeated
          (presumably to keep the SVD backward pass stable).
    Returns:
        - A_aligned: (N,D) aligned copy of A, i.e. (R @ A.T).T + t
        - R: (D,D) optimal rotation; always proper (det = +1), a reflection
          is never returned
        - t: (D,) optimal translation
    Raises:
        AssertionError: requires_grad is True and the covariance has NaNs.
        RuntimeError: requires_grad is True and the SVD is still
            ill-conditioned after repeated perturbation.

    Test on rotation + translation
    >>> A = torch.tensor([[1., 1.], [2., 2.], [1.5, 3.]])
    >>> R0 = torch.tensor([[0., -1.], [1., 0.]])
    >>> B = A.mm(R0.T) + torch.tensor([3., 3.])
    >>> A_aligned, R, t = kabsch_torch(A, B)
    >>> torch.allclose(A_aligned, B, atol=1e-5)
    True
    """
    a_mean = A.mean(axis=0)
    b_mean = B.mean(axis=0)
    A_c = A - a_mean
    B_c = B - b_mean

    # Covariance matrix, (D, D)
    H = A_c.T.mm(B_c)

    if requires_grad:  # try more times to find a stable solution
        assert not torch.isnan(H).any()
        dim = H.shape[0]
        eye = torch.eye(dim).to(H.device)
        U, S, Vt = torch.linalg.svd(H)
        num_it = 0
        # Retry while a singular value is close to zero or two squared
        # singular values nearly coincide. The identity is added so the
        # zero self-differences on the diagonal never trip the check.
        while torch.min(S) < 1e-3 or torch.min(
                torch.abs((S ** 2).view(1, dim) - (S ** 2).view(dim, 1) + eye)) < 1e-2:
            # Random perturbation restricted to the diagonal of H.
            H = H + torch.rand(dim, dim).to(H.device) * eye
            U, S, Vt = torch.linalg.svd(H)
            num_it += 1
            if num_it > 10:
                raise RuntimeError('SVD consistently numerically unstable! Exitting ... ')
    else:
        U, S, Vt = torch.linalg.svd(H)
    V = Vt.T

    # If the best orthogonal fit is a reflection, negate the last column of
    # U so that R below is a proper rotation.
    if (torch.linalg.det(U) * torch.linalg.det(V)) < 0.0:
        signs = torch.ones(len(U), device=U.device, dtype=U.dtype)
        signs[-1] = -1.
        U = U @ torch.diag(signs)

    # Rotation matrix (acts on column vectors: x -> R @ x)
    R = V.mm(U.T)

    # Translation vector
    t = b_mean[None, :] - R.mm(a_mean[None, :].T).T
    t = (t.T).squeeze()

    return R.mm(A.T).T + t, R, t
def batch_kabsch_torch(A, B):
    '''
    Batched Kabsch alignment: rigidly align every A[i] onto B[i].

    A: [B, N, 3] point clouds to align (source)
    B: [B, N, 3] reference point clouds (target), paired row by row with A

    Returns (A_aligned, R, t):
        A_aligned: [B, N, 3], equal to (R @ A^T)^T + t
        R: [B, 3, 3] proper rotations (det = +1, never a reflection)
        t: [B, 1, 3] translations
    '''
    a_mean = A.mean(dim=1, keepdim=True)
    b_mean = B.mean(dim=1, keepdim=True)
    A_c = A - a_mean
    B_c = B - b_mean

    # Covariance matrix
    H = torch.bmm(A_c.transpose(1, 2), B_c)  # [B, 3, 3]
    U, S, Vt = torch.linalg.svd(H)  # [B, 3, 3]
    V = Vt.transpose(1, 2)

    # Batch elements whose best orthogonal fit would be a reflection.
    reflected = (torch.linalg.det(U) * torch.linalg.det(V)) < 0.0  # [B]

    # Per-element column signs for U: all +1, except -1 on the last column
    # of reflected elements. The size comes from the spatial dimension
    # (U.shape[-1]); len(U) would be the batch size.
    col_sign = torch.ones(U.shape[0], U.shape[-1], device=U.device, dtype=U.dtype)
    col_sign[reflected, -1] = -1.
    U = U * col_sign.unsqueeze(1)  # same as U @ diag(col_sign) per element

    # Rotation matrix
    R = torch.bmm(V, U.transpose(1, 2))  # [B, 3, 3]

    # Translation vector
    t = b_mean - torch.bmm(R, a_mean.transpose(1, 2)).transpose(1, 2)

    A_aligned = torch.bmm(R, A.transpose(1, 2)).transpose(1, 2) + t
    return A_aligned, R, t