Spaces:
Sleeping
Sleeping
| import torch | |
| from torch.utils.data import Subset | |
| from .pl import PocketLigandPairDataset | |
def get_dataset(config, *args, **kwargs):
    """Build the dataset named by ``config`` and, optionally, its split subsets.

    Args:
        config: Object exposing ``name`` and ``path`` attributes and
            supporting the ``in`` operator. If it contains ``'split'``,
            ``config.split`` is passed to ``torch.load``.
        *args: Extra positional arguments forwarded to the dataset
            constructor after the root path.
        **kwargs: Extra keyword arguments forwarded to the dataset
            constructor.

    Returns:
        The dataset alone when ``config`` has no ``'split'`` entry;
        otherwise a ``(dataset, subsets)`` tuple where ``subsets`` maps each
        key of the loaded split object to a ``Subset`` of the dataset.

    Raises:
        NotImplementedError: If ``config.name`` is not ``'pl'``.
    """
    dataset_name = config.name
    dataset_root = config.path

    # Only the 'pl' dataset is supported; reject anything else up front.
    if dataset_name != 'pl':
        raise NotImplementedError('Unknown dataset: %s' % dataset_name)
    full_dataset = PocketLigandPairDataset(dataset_root, *args, **kwargs)

    # Without a split entry the caller gets just the dataset, not a tuple.
    if 'split' not in config:
        return full_dataset

    # The loaded object must provide .items(); each value is used as the
    # index collection of one Subset.
    split_indices = torch.load(config.split)
    subsets = {}
    for split_key, indices in split_indices.items():
        subsets[split_key] = Subset(full_dataset, indices=indices)
    return full_dataset, subsets