# Pocket-Gen/utils/datasets/__init__.py (revision dcacefd, uploaded by Zaixi)
import torch
from torch.utils.data import Subset
from .pl import PocketLigandPairDataset
def get_dataset(config, *args, **kwargs):
    """Build the dataset named by ``config`` and, optionally, its split subsets.

    Args:
        config: Dataset configuration. It must expose ``name`` and ``path``
            as attributes and support ``'split' in config`` membership tests.
            This presumably means a dict subclass with attribute access;
            confirm against the config loader. When a ``split`` entry is
            present, ``config.split`` is passed to ``torch.load``.
        *args: Forwarded unchanged to the dataset constructor.
        **kwargs: Forwarded unchanged to the dataset constructor.

    Returns:
        Without a ``split`` entry in ``config``: the dataset alone.
        With one: a ``(dataset, subsets)`` tuple. ``subsets`` maps each key
        of the loaded split mapping to a ``torch.utils.data.Subset`` of
        ``dataset`` that uses that key's indices.

    Raises:
        NotImplementedError: If ``config.name`` is not a supported dataset.
            Only ``'pl'`` (``PocketLigandPairDataset``) is handled.
    """
    # Both attributes are read up front, before the name is validated.
    dataset_name, dataset_root = config.name, config.path

    if dataset_name != 'pl':
        raise NotImplementedError('Unknown dataset: %s' % dataset_name)
    dataset = PocketLigandPairDataset(dataset_root, *args, **kwargs)

    if 'split' not in config:
        return dataset

    # The loaded object must provide .items(). Presumably it maps a subset
    # name to an index collection (e.g. 'train' / 'test'); confirm against
    # the split file.
    split_indices = torch.load(config.split)
    subsets = {}
    for subset_name, indices in split_indices.items():
        subsets[subset_name] = Subset(dataset, indices=indices)
    return dataset, subsets