2021-12-21 02:23:41 +00:00
|
|
|
from abc import abstractmethod
|
2022-08-26 07:15:42 +00:00
|
|
|
from torch.utils.data import (
|
|
|
|
Dataset,
|
|
|
|
ConcatDataset,
|
|
|
|
ChainDataset,
|
|
|
|
IterableDataset,
|
|
|
|
)
|
2021-12-21 02:23:41 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Txt2ImgIterableBaseDataset(IterableDataset):
    """Base interface that makes text-to-image IterableDatasets chainable.

    Concrete subclasses provide ``__iter__``, which is declared abstract here
    (NOTE(review): this is only enforced at instantiation time if the torch
    ``IterableDataset`` in use has ``ABCMeta`` as its metaclass — confirm).

    The constructor just records bookkeeping values and prints a one-line
    size summary to stdout.

    Args:
        num_records: The value reported by ``__len__``.
        valid_ids: Stored unchanged as both ``valid_ids`` and ``sample_ids``;
            the two attributes refer to the same object (no copy is made).
        size: Stored as ``self.size``; its interpretation is left to
            subclasses.
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        # Chained assignment binds valid_ids first, then sample_ids, to the
        # very same object.
        self.valid_ids = self.sample_ids = valid_ids
        self.size = size

        # Go through __len__ (not self.num_records) so a subclass override of
        # __len__ is what gets reported.
        dataset_name = self.__class__.__name__
        example_count = self.__len__()
        print(f'{dataset_name} dataset contains {example_count} examples.')

    def __len__(self):
        """Return the ``num_records`` value given at construction."""
        return self.num_records

    @abstractmethod
    def __iter__(self):
        """Yield dataset samples; concrete subclasses must override this."""
|