# Hugging Face Space application script (page snapshot: revision 6ee4b83, file size 18,270 bytes).
import spaces
import gradio as gr
from gradio_molecule3d import Molecule3D
import os
import numpy as np
import torch
from rdkit import Chem
import argparse
import random
from tqdm import tqdm
from vina import Vina
import esm
from utils.relax import openmm_relax, relax_sdf
from utils.protein_ligand import PDBProtein, parse_sdf_file
from utils.data import torchify_dict
from torch_geometric.transforms import Compose
from utils.datasets import *
from utils.transforms import *
from utils.misc import *
from utils.data import *
from torch.utils.data import DataLoader
from models.PD import Pocket_Design_new
from functools import partial
import pickle
import yaml
from easydict import EasyDict
import uuid
from datetime import datetime
import tempfile
import shutil
from Bio import PDB
from Bio.PDB import MMCIFParser, PDBIO
import logging
import zipfile
# Configure logging: a module-level logger plus the root logger's format and level.
logger = logging.getLogger(__name__)
LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
logging.basicConfig(
    format=LOG_FORMAT,
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    # NOTE(review): no `filename` is passed, so `filemode` appears to have no effect here — confirm.
    filemode="w",
)
# Make sure the working directories exist (runs at import time).
os.makedirs("./generate/upload", exist_ok=True)
os.makedirs("./tmp", exist_ok=True)
# Custom CSS for the Gradio page. The `title`, `subtitle` and `footer` classes are
# attached to Markdown components through `elem_classes` in the UI definition.
custom_css = """
.title {
font-size: 32px;
font-weight: bold;
color: #4CAF50;
display: flex;
align-items: center;
}
.subtitle {
font-size: 20px;
color: #666;
margin-bottom: 20px;
}
.footer {
margin-top: 20px;
text-align: center;
color: #666;
}
"""
# 3D display representation settings - the default configuration handed to the
# Molecule3D component as `reps`.
default_reps = [
    # First representation: cartoon style with the "whiteCarbon" color scheme.
    {
        "model": 0,
        "chain": "",
        "resname": "",
        "style": "cartoon",
        "color": "whiteCarbon",
        "residue_range": "",
        "around": 0,
        "byres": False,
        "visible": True,
        "opacity": 1.0
    },
    # Second representation: stick style with the "greenCarbon" color scheme.
    # NOTE(review): unlike the first entry this one has no "residue_range" key — confirm that is intended.
    {
        "model": 0,
        "chain": "",
        "resname": "",
        "style": "stick",
        "color": "greenCarbon",
        "around": 5,  # show residues within 5 Å around the ligand (original author's note)
        "byres": True,
        "visible": True,
        "opacity": 0.8
    }
]
def create_zip_file(directory_path, zip_filename):
    """Compress every file under a directory into a ZIP archive.

    Entries are stored with paths relative to ``directory_path``. If the
    archive itself is located inside ``directory_path`` it is skipped, so the
    half-written archive is never packed into itself.

    Args:
        directory_path: Root directory that is walked recursively.
        zip_filename: Path of the ZIP file to create (opened in ``'w'`` mode).

    Returns:
        ``zip_filename`` on success, or ``None`` if any error occurred
        (the error is logged, not raised).
    """
    try:
        archive_abspath = os.path.abspath(zip_filename)
        with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for root, _dirs, files in os.walk(directory_path):
                for file in files:
                    file_path = os.path.join(root, file)
                    # The archive is created before the walk starts, so it would
                    # show up as a regular file when it lives under directory_path.
                    if os.path.abspath(file_path) == archive_abspath:
                        continue
                    arcname = os.path.relpath(file_path, directory_path)
                    zipf.write(file_path, arcname)
        logger.info(f"成功创建压缩文件: {zip_filename}")
        return zip_filename
    except Exception as e:
        # Best-effort by design: callers receive None instead of an exception.
        logger.error(f"创建压缩文件时出错: {str(e)}")
        return None
def load_config(config_path):
    """Read a YAML configuration file and return it wrapped in an ``EasyDict``.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        An ``EasyDict`` built from the parsed YAML document.
    """
    # NOTE(review): yaml.FullLoader is not the safe loader, and in this app
    # config_path is taken from a UI text box. Consider yaml.safe_load if the
    # configs only contain plain YAML — confirm before changing.
    with open(config_path, 'r') as stream:
        return EasyDict(yaml.load(stream, Loader=yaml.FullLoader))
# The Vina-related scoring functions were removed because only the RMSD result is needed.
def from_protein_ligand_dicts(protein_dict=None, ligand_dict=None, residue_dict=None, seq=None, full_seq_idx=None,
                              r10_idx=None):
    """Flatten protein, ligand and residue dictionaries into one data instance.

    Keys of ``protein_dict`` are prefixed with ``protein_``, keys of
    ``ligand_dict`` with ``ligand_``, and keys of ``residue_dict`` are copied
    unchanged. ``seq``, ``full_seq_idx`` and ``r10_idx`` are stored under
    their own names. Any argument left as ``None`` is skipped.

    Returns:
        A new ``dict`` holding all collected entries.
    """
    instance = {}
    # Prefixed sources first, in the same order as they are merged: protein, then ligand.
    for prefix, source in (('protein_', protein_dict), ('ligand_', ligand_dict)):
        if source is not None:
            instance.update((prefix + key, item) for key, item in source.items())
    # Residue-level entries keep their original keys.
    if residue_dict is not None:
        instance.update(residue_dict.items())
    # Optional scalar fields are added last.
    for field, value in (('seq', seq), ('full_seq_idx', full_seq_idx), ('r10_idx', r10_idx)):
        if value is not None:
            instance[field] = value
    return instance
def ith_true_index(tensor, i):
    """Return the position of the i-th non-zero (true) element of a 1-D tensor.

    Args:
        tensor: 1-D tensor (typically a boolean mask).
        i: Rank of the true element to look up, 0-based.

    Returns:
        The index into ``tensor`` of that element, as a Python ``int``.

    Raises:
        IndexError: If ``tensor`` has fewer than ``i + 1`` true elements.
    """
    # torch.nonzero yields shape (N, 1) for 1-D input. Only the trailing axis is
    # squeezed: a bare squeeze() would also collapse N == 1 to a 0-dim tensor,
    # which cannot be indexed.
    true_indices = torch.nonzero(tensor).squeeze(-1)
    return true_indices[i].item()
def name2data(pdb_path, lig_path):
    """Build one model input sample from a protein PDB file and a ligand SDF file.

    A pocket PDB file named ``<name>_pocket.pdb`` is written next to
    ``pdb_path``, where ``<name>`` is the part of the file name before the
    first dot.

    Args:
        pdb_path: Path of the protein PDB file.
        lig_path: Path of the ligand SDF file.

    Returns:
        The result of applying the module-level ``transform`` to the assembled
        data dictionary.

    Raises:
        ValueError: If the ligand cannot be parsed, no residue is found around
            the ligand, or the pocket has no editable residue.
        AssertionError: If the residue index/mask lengths are inconsistent.
        Exception: Any other error is logged and re-raised unchanged.
    """
    name = os.path.basename(pdb_path).split('.')[0]
    dir_name = os.path.dirname(pdb_path)
    pocket_path = os.path.join(dir_name, f"{name}_pocket.pdb")
    try:
        with open(pdb_path, 'r') as f:
            pdb_block = f.read()
        protein = PDBProtein(pdb_block)
        # Sequence of the whole protein, joined into a single string.
        seq = ''.join(protein.to_dict_residue()['seq'])
        ligand = parse_sdf_file(lig_path, feat=False)
        if ligand is None:
            raise ValueError(f"无法从{lig_path}解析配体")
        # Residues around the ligand: a wide selection (radius=10), then a narrow
        # one (radius=3.5) restricted to the wide selection.
        # Radii are presumably in Å (the error message below says 10Å) — TODO confirm.
        r10_idx, r10_residues = protein.query_residues_ligand(ligand, radius=10, selected_residue=None, return_mask=False)
        full_seq_idx, _ = protein.query_residues_ligand(ligand, radius=3.5, selected_residue=r10_residues, return_mask=False)
        if not r10_residues:
            raise ValueError("在配体10Å范围内未找到任何残基")
        assert len(r10_idx) == len(r10_residues)
        # Write the pocket residues to disk, then parse that file back as its own structure.
        pdb_block_pocket = protein.residues_to_pdb_block(r10_residues)
        with open(pocket_path, 'w') as f:
            f.write(pdb_block_pocket)
        with open(pocket_path, 'r') as f:
            pdb_block = f.read()
        pocket = PDBProtein(pdb_block)
        pocket_dict = pocket.to_dict_atom()
        residue_dict = pocket.to_dict_residue()
        # Second return value with default arguments; used below as a per-residue
        # mask (it is summed and its length compared) — presumably boolean, TODO confirm.
        _, residue_dict['protein_edit_residue'] = pocket.query_residues_ligand(ligand)
        if residue_dict['protein_edit_residue'].sum() == 0:
            raise ValueError("在口袋内未找到可编辑残基")
        # Consistency checks between the mask and the index lists computed on the full protein.
        assert residue_dict['protein_edit_residue'].sum() > 0 and residue_dict['protein_edit_residue'].sum() == len(full_seq_idx)
        assert len(residue_dict['protein_edit_residue']) == len(r10_idx)
        # In-place sort before the index lists are turned into tensors.
        full_seq_idx.sort()
        r10_idx.sort()
        data = from_protein_ligand_dicts(
            protein_dict=torchify_dict(pocket_dict),
            ligand_dict=torchify_dict(ligand),
            residue_dict=torchify_dict(residue_dict),
            seq=seq,
            full_seq_idx=torch.tensor(full_seq_idx),
            r10_idx=torch.tensor(r10_idx)
        )
        # Note: 'protein_filename' points at the pocket file, not the full protein.
        data['protein_filename'] = pocket_path
        data['ligand_filename'] = lig_path
        data['whole_protein_name'] = pdb_path
        # `transform` is a module-level global that process_files assigns before
        # calling this function; it is not defined inside this function.
        return transform(data)
    except Exception as e:
        logger.error(f"name2data中出错: {str(e)}")
        raise
def convert_cif_to_pdb(cif_path):
    """Convert an mmCIF file to PDB format and store it in a temporary file.

    The temporary file is created with ``delete=False``, so the caller owns it
    and is responsible for removing it. If the conversion fails, the temporary
    file is removed before the error is re-raised, so failed conversions do not
    leave orphaned files behind.

    Args:
        cif_path: Path of the mmCIF file to convert.

    Returns:
        Path of the newly written ``.pdb`` temporary file.

    Raises:
        Exception: Any parsing or writing error is logged and re-raised unchanged.
    """
    temp_pdb_path = None
    try:
        parser = MMCIFParser()
        structure = parser.get_structure("protein", cif_path)
        # Only reserve a unique path here; the handle is closed before PDBIO writes to it.
        with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as temp_file:
            temp_pdb_path = temp_file.name
        io = PDBIO()
        io.set_structure(structure)
        io.save(temp_pdb_path)
        return temp_pdb_path
    except Exception as e:
        logger.error(f"将CIF转换为PDB时出错: {str(e)}")
        if temp_pdb_path is not None:
            try:
                os.remove(temp_pdb_path)
            except OSError:
                # Best-effort cleanup; the original error below is what matters.
                pass
        raise
def align_pdb_files(pdb_file_1, pdb_file_2):
    """Superimpose the second PDB structure onto the first and overwrite the second file.

    Atoms named ``"CA"`` from the first model of each structure are paired by
    position in iteration order, truncated to the shorter list, and the
    resulting transformation is applied to the whole first model of the second
    structure. ``pdb_file_2`` is then rewritten in place.

    Args:
        pdb_file_1: Path of the reference PDB file (not modified).
        pdb_file_2: Path of the PDB file to align; overwritten with the result.

    Returns:
        ``None``. If either structure has no ``"CA"`` atom, a warning is logged
        and ``pdb_file_2`` is left untouched.

    Raises:
        Exception: Any parsing, superposition or writing error is logged and
            re-raised unchanged.
    """
    try:
        io = PDB.PDBIO()
        structure_1 = PDB.PDBParser(QUIET=True).get_structure('Structure_1', pdb_file_1)
        structure_2 = PDB.PDBParser(QUIET=True).get_structure('Structure_2', pdb_file_2)
        super_imposer = PDB.Superimposer()
        model_1 = structure_1[0]
        model_2 = structure_2[0]
        # NOTE(review): the filter is by atom name only — confirm that non-protein
        # atoms named "CA" cannot appear in the inputs.
        atoms_1 = [atom for atom in model_1.get_atoms() if atom.get_name() == "CA"]
        atoms_2 = [atom for atom in model_2.get_atoms() if atom.get_name() == "CA"]
        if not atoms_1 or not atoms_2:
            logger.warning("未找到用于对齐的CA原子")
            return
        # Both lists are non-empty here, so min_length is at least 1.
        min_length = min(len(atoms_1), len(atoms_2))
        super_imposer.set_atoms(atoms_1[:min_length], atoms_2[:min_length])
        super_imposer.apply(model_2)
        io.set_structure(structure_2)
        io.save(pdb_file_2)
    except Exception as e:
        logger.error(f"对齐PDB文件时出错: {str(e)}")
        raise
def create_combined_structure(protein_path, ligand_path, output_path):
    """Merge a protein PDB file and a ligand MOL/SDF file into one PDB file for visualization.

    The ligand is converted to a PDB block with RDKit and appended after the
    protein records. ``END`` records are stripped from the protein part first:
    the PDB format defines ``END`` as the last record of a file, so readers may
    stop there and never see a ligand appended after it.

    Args:
        protein_path: Path of the protein PDB file.
        ligand_path: Path of the ligand file, read with ``Chem.MolFromMolFile``.
        output_path: Path where the merged PDB file is written.

    Returns:
        ``output_path`` on success. ``protein_path`` is returned as a fallback
        if the ligand cannot be read or any other error occurs (errors are
        logged, not raised).
    """
    try:
        # Read the protein PDB file.
        with open(protein_path, 'r') as f:
            protein_content = f.read()
        # Read the ligand file.
        mol = Chem.MolFromMolFile(ligand_path)
        if mol is None:
            logger.error(f"无法读取配体文件: {ligand_path}")
            return protein_path
        # Convert the ligand to PDB-format text.
        ligand_pdb_block = Chem.MolToPDBBlock(mol)
        # Keep every protein record except END (ENDMDL has a different record
        # name and is left alone), so the ligand is not placed after the
        # end-of-file marker.
        protein_lines = [
            line
            for line in protein_content.rstrip().splitlines()
            if line[:6].strip() != "END"
        ]
        combined_content = "\n".join(protein_lines) + "\n" + ligand_pdb_block
        # Save the merged file.
        with open(output_path, 'w') as f:
            f.write(combined_content)
        return output_path
    except Exception as e:
        logger.error(f"创建合并结构时出错: {str(e)}")
        return protein_path  # on failure, fall back to the original protein file
@spaces.GPU(duration=500)
def process_files(pdb_file, sdf_file, config_path):
    """Run the full pipeline on an uploaded PDB/SDF pair and package the results.

    The inputs are copied into a fresh per-request directory under
    ``./generate/upload``, the model is built from the configured checkpoint,
    ``model.generate`` is run on the prepared sample, and the request directory
    is zipped.

    Args:
        pdb_file: Path of the uploaded protein PDB file.
        sdf_file: Path of the uploaded ligand SDF file.
        config_path: Path of the YAML config; ``config.model.checkpoint`` must
            point at the checkpoint to load.

    Returns:
        A tuple ``(visualization_path, zip_path, summary)`` on success, or
        ``(None, None, error_message)`` if any step raised an exception
        (the traceback is logged).
    """
    try:
        # Timestamp plus 8 random hex characters: one working directory per request.
        unique_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
        upload_dir = os.path.join("./generate/upload", unique_id)
        os.makedirs(upload_dir, exist_ok=True)
        logger.info(f"使用ID处理文件: {unique_id}")
        config = load_config(config_path)
        pdb_save_path = os.path.join(upload_dir, "protein.pdb")
        sdf_save_path = os.path.join(upload_dir, "ligand.sdf")
        shutil.copy(pdb_file, pdb_save_path)
        shutil.copy(sdf_file, sdf_save_path)
        logger.info(f"文件已保存到 {upload_dir}")
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        logger.info(f"使用设备: {device}")
        protein_featurizer = FeaturizeProteinAtom()
        ligand_featurizer = FeaturizeLigandAtom()
        # `transform` is published as a module-level global because name2data
        # reads it from module scope instead of taking it as an argument.
        global transform
        transform = Compose([
            protein_featurizer,
            ligand_featurizer,
        ])
        logger.info("加载ESM模型...")
        name = 'esm2_t33_650M_UR50D'
        # Only the alphabet (for the batch converter) is used from this call;
        # the loaded model object itself is deleted a few lines below.
        pretrained_model, alphabet = esm.pretrained.load_model_and_alphabet_hub(name)
        batch_converter = alphabet.get_batch_converter()
        checkpoint_path = config.model.checkpoint
        logger.info(f"从{checkpoint_path}加载检查点")
        # NOTE(review): weights_only=False unpickles arbitrary objects; the checkpoint
        # path comes from the config file, so only trusted checkpoints should be used.
        ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
        del pretrained_model
        logger.info("初始化模型...")
        model = Pocket_Design_new(
            config.model,
            protein_atom_feature_dim=protein_featurizer.feature_dim,
            ligand_atom_feature_dim=ligand_featurizer.feature_dim,
            device=device
        ).to(device)
        model.load_state_dict(ckpt['model'])
        logger.info("处理输入数据...")
        data = name2data(pdb_save_path, sdf_save_path)
        # The same sample is repeated batch_size times to form a single batch.
        batch_size = 2
        datalist = [data for _ in range(batch_size)]
        protein_filename = data['protein_filename']
        ligand_filename = data['ligand_filename']
        whole_protein_name = data['whole_protein_name']
        # Directory of the pocket file; passed to model.generate and used below
        # to locate the result files.
        dir_name = os.path.dirname(protein_filename)
        model.generate_id = 0
        model.generate_id1 = 0
        test_loader = DataLoader(
            datalist,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            collate_fn=partial(collate_mols_block, batch_converter=batch_converter)
        )
        logger.info("生成结构...")
        with torch.no_grad():
            model.eval()
            for batch in tqdm(test_loader, desc='Test'):
                # Move every tensor entry of the batch to the target device.
                for key in batch:
                    if torch.is_tensor(batch[key]):
                        batch[key] = batch[key].to(device)
                aar, rmsd, attend_logits = model.generate(batch, dir_name)
                logger.info(f'RMSD: {rmsd}')
        # Create the result file: use 0_relaxed.pdb if it exists, otherwise fall
        # back to the unmodified input protein.
        # 0_relaxed.pdb is presumably written by model.generate — TODO confirm.
        result_path = os.path.join(dir_name, "0_whole.pdb")
        relaxed_path = os.path.join(dir_name, "0_relaxed.pdb")
        if os.path.exists(relaxed_path):
            shutil.copy(relaxed_path, result_path)
        else:
            shutil.copy(pdb_save_path, result_path)
        # Create the merged protein + ligand file used for visualization.
        combined_path = os.path.join(dir_name, "combined_structure.pdb")
        visualization_path = create_combined_structure(result_path, sdf_save_path, combined_path)
        # Create the ZIP archive next to (not inside) the request directory.
        zip_filename = os.path.join("./generate/upload", f"{unique_id}_results.zip")
        zip_path = create_zip_file(upload_dir, zip_filename)
        logger.info(f"结果已保存到 {result_path}")
        logger.info(f"压缩文件已创建: {zip_path}")
        summary = f"""
处理完成!
结果摘要:
- 均方根偏差 (RMSD): {rmsd}
文件说明:
- 所有结果文件已打包为ZIP文件供下载
- 包含原始输入、处理结果等
- 任务ID: {unique_id}
"""
        return visualization_path, zip_path, summary
    except Exception as e:
        # Top-level boundary: log the full traceback and report a short message to the UI.
        import traceback
        error_trace = traceback.format_exc()
        logger.error(f"处理过程中出错: {error_trace}")
        return None, None, f"处理过程中出错: {str(e)}"
def gradio_interface(pdb_file, sdf_file, config_path):
    """Gradio callback: validate the uploads, run the pipeline and shape the outputs.

    Args:
        pdb_file: Path of the uploaded PDB file, or ``None`` if nothing was uploaded.
        sdf_file: Path of the uploaded SDF file, or ``None`` if nothing was uploaded.
        config_path: Path of the configuration file forwarded to ``process_files``.

    Returns:
        ``(viewer_path, zip_path, message)`` when a viewable structure file was
        produced, otherwise ``(None, None, message)``.
    """
    uploads_missing = pdb_file is None or sdf_file is None
    if uploads_missing:
        return None, None, "请上传PDB和SDF文件。"

    logger.info(f"开始处理{pdb_file}和{sdf_file}")
    viewer_path, archive_path, status = process_files(pdb_file, sdf_file, config_path)

    # Anything other than an existing structure file counts as a failure.
    if not (viewer_path and os.path.exists(viewer_path)):
        fallback = status if status else "处理失败,未知错误。"
        return None, None, fallback
    return viewer_path, archive_path, status
# Build the Gradio interface: inputs on the left, 3D viewer and outputs on the right.
with gr.Blocks(title="蛋白质-配体处理", css=custom_css) as demo:
    gr.Markdown("# 蛋白质-配体结构处理", elem_classes=["title"])
    gr.Markdown("上传PDB和SDF文件进行蛋白质口袋设计和配体对接分析", elem_classes=["subtitle"])
    with gr.Row():
        with gr.Column(scale=1):
            pdb_input = gr.File(label="上传PDB文件", file_types=[".pdb"])
            sdf_input = gr.File(label="上传SDF文件", file_types=[".sdf"])
            config_input = gr.Textbox(label="配置文件路径", value="./configs/train_model_moad.yml")
            submit_btn = gr.Button("处理文件", variant="primary")
        with gr.Column(scale=2):
            # Molecule3D component, fixed to the default representation styles.
            view3d = Molecule3D(
                label="3D结构可视化 (蛋白质卡通 + 配体周围残基棒状)",
                reps=default_reps
            )
            output_message = gr.Textbox(label="处理状态和结果摘要", lines=8)
            output_file = gr.File(label="下载完整结果包 (ZIP)")
    # Click handler: the three return values of gradio_interface map, in order,
    # to the viewer, the download file and the status text box.
    submit_btn.click(
        fn=gradio_interface,
        inputs=[pdb_input, sdf_input, config_input],
        outputs=[view3d, output_file, output_message]
    )
    # Usage instructions rendered below the controls.
    gr.Markdown("""
## 使用说明
1. **上传文件**: 上传蛋白质PDB文件和配体SDF文件
2. **配置设置**: 保持默认配置路径或调整为您的配置文件位置
3. **处理文件**: 点击"处理文件"按钮开始处理
4. **结果查看**:
- 在3D查看器中交互式查看优化后的蛋白质-配体复合物结构
- 查看详细的处理结果摘要
- 下载包含所有结果文件的ZIP压缩包
## 3D可视化功能
- **旋转**: 鼠标左键拖拽
- **缩放**: 鼠标滚轮或双指缩放
- **平移**: 鼠标右键拖拽
- **重置视图**: 双击重置到初始视角
可视化样式说明:
- 蛋白质以卡通形式显示(白色碳骨架)
- 配体周围5Å内的残基以棒状形式显示(绿色碳骨架)
## 下载文件说明
ZIP压缩包包含以下文件:
- **protein.pdb**: 原始输入蛋白质文件
- **ligand.sdf**: 原始输入配体文件
- **protein_pocket.pdb**: 提取的蛋白质口袋文件
- **0_whole.pdb**: 优化后的完整蛋白质结构
- **0_relaxed.pdb**: 松弛优化后的蛋白质结构
- **combined_structure.pdb**: 用于可视化的蛋白质-配体复合物
## 技术说明
该应用程序使用深度学习方法优化蛋白质口袋结构,提高与特定配体的结合能力。主要功能包括:
- **蛋白质口袋识别**: 自动识别并提取配体结合口袋
- **结构优化设计**: 使用AI模型优化口袋残基构象
- **分子对接评分**: 使用Vina进行结合能评估
- **交互式3D可视化**: 清晰展示蛋白质-配体相互作用
- **完整结果打包**: 所有中间和最终结果文件统一打包下载
处理可能需要几分钟时间,请耐心等待。
""")
    gr.Markdown("© 2025 zaixi", elem_classes=["footer"])
if __name__ == "__main__":
    # Script entry point: start the Gradio app, with share=True as in the original call.
    demo.launch(share=True)