InvokeAI/invokeai/backend/stable_diffusion/data/base.py

26 lines
696 B
Python
Raw Normal View History

2021-12-21 02:23:41 +00:00
from abc import abstractmethod
2023-03-03 06:02:00 +00:00
from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset
2021-12-21 02:23:41 +00:00
class Txt2ImgIterableBaseDataset(IterableDataset):
    """
    Define an interface to make the IterableDatasets for text2img data chainable.

    This base class only stores bookkeeping attributes and reports the record
    count. Concrete subclasses must implement ``__iter__`` to yield the actual
    samples.
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        """Store dataset bookkeeping and announce the dataset size on stdout.

        Args:
            num_records: Number of records; this is the value ``len()`` reports.
            valid_ids: Optional collection of record ids. ``sample_ids`` is bound
                to the very same object (not a copy).
            size: Stored as ``self.size`` and not used by this class itself;
                presumably a resolution consumed by subclasses — confirm.
        """
        super().__init__()
        self.num_records, self.size = num_records, size
        # Both attributes refer to the same object passed in by the caller.
        self.valid_ids = self.sample_ids = valid_ids

        cls_name = self.__class__.__name__
        # Call __len__ directly (as before) so subclasses overriding it are honored
        # and non-int counts are still printed rather than rejected by len().
        count = self.__len__()
        print(f"{cls_name} dataset contains {count} examples.")

    def __len__(self):
        """Return the record count supplied at construction time."""
        return self.num_records

    @abstractmethod
    def __iter__(self):
        """Yield the dataset's samples; must be provided by subclasses."""