from abc import abstractmethod

from torch.utils.data import (
    Dataset,
    ConcatDataset,
    ChainDataset,
    IterableDataset,
)


class Txt2ImgIterableBaseDataset(IterableDataset):
    """Shared interface for chainable text-to-image ``IterableDataset``s.

    Concrete subclasses must implement ``__iter__``. This base only records
    a few bookkeeping attributes and reports the record count on creation.

    Args:
        num_records: Count returned by ``len()`` for this dataset.
        valid_ids: Stored as ``valid_ids``. The very same object (not a
            copy) is also stored as ``sample_ids``.
        size: Stored unchanged as ``size``. Presumably a target image
            resolution; confirm against subclasses.
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        self.valid_ids = valid_ids
        # Deliberate alias: sample_ids and valid_ids refer to one object.
        self.sample_ids = valid_ids
        self.size = size

        # Report how many examples this dataset exposes.
        class_name = self.__class__.__name__
        example_count = self.__len__()
        print(f'{class_name} dataset contains {example_count} examples.')

    def __len__(self):
        """Return the record count supplied at construction time."""
        return self.num_records

    @abstractmethod
    def __iter__(self):
        """Yield examples; must be provided by concrete subclasses."""