from abc import abstractmethod
from torch.utils.data import (
    Dataset,
    ConcatDataset,
    ChainDataset,
    IterableDataset,
)


class Txt2ImgIterableBaseDataset(IterableDataset):
    """Base interface for chainable text-to-image ``IterableDataset`` classes.

    Concrete subclasses provide the actual record stream by implementing
    :meth:`__iter__`. This base class only stores bookkeeping attributes and
    reports the dataset length from ``num_records``.

    Args:
        num_records: The value returned by ``__len__``.
        valid_ids: Stored as both ``self.valid_ids`` and ``self.sample_ids``.
            The two attributes refer to the same object; no copy is made.
        size: Stored unchanged as ``self.size``. Presumably this is the
            target image size; confirm against subclasses.
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        # Both names alias the caller's object on purpose (no copy).
        self.valid_ids = self.sample_ids = valid_ids
        self.size = size

        # Go through __len__ rather than reading num_records directly, so a
        # subclass that overrides __len__ reports its own count here.
        example_count = self.__len__()
        dataset_name = type(self).__name__
        message = f'{dataset_name} dataset contains {example_count} examples.'
        print(message)

    def __len__(self):
        """Return the record count supplied at construction time."""
        return self.num_records

    @abstractmethod
    def __iter__(self):
        """Yield dataset records; must be implemented by subclasses."""