import os, yaml, pickle, shutil, tarfile, glob
import cv2
import albumentations
import PIL
import numpy as np
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from functools import partial
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, Subset

import taming.data.utils as tdu
from taming.data.imagenet import (
    str_to_indices,
    give_synsets_from_indices,
    download,
    retrieve,
)
from taming.data.imagenet import ImagePaths

from ldm.modules.image_degradation import (
    degradation_fn_bsr,
    degradation_fn_bsr_light,
)


def synset2idx(path_to_yaml='data/index_synset.yaml'):
    # Invert the index -> synset mapping from the YAML file into synset -> index.
    with open(path_to_yaml) as f:
        di2s = yaml.safe_load(f)
    return dict((v, k) for k, v in di2s.items())
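# Usage sketch for synset2idx (hedged: assumes index_synset.yaml maps integer
# class indices to WordNet synset ids, e.g. "0: n01440764"):
#
#   s2i = synset2idx(path_to_yaml='data/index_synset.yaml')
#   s2i['n01440764']  # -> 0 under that assumption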


class ImageNetBase(Dataset):
    """Shared base for the ILSVRC2012 splits. Subclasses implement _prepare()
    to locate/download the data and write filelist.txt; this base class then
    downloads the synset/label metadata and builds the dataset from that list."""

    def __init__(self, config=None):
        self.config = config or OmegaConf.create()
        if not isinstance(self.config, dict):
            self.config = OmegaConf.to_container(self.config)
        self.keep_orig_class_label = self.config.get(
            'keep_orig_class_label', False
        )
        self.process_images = True  # if False we skip loading & processing images and self.data contains filepaths
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._prepare_human_to_integer_label()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        ignore = set(
            [
                'n06596364_9591.JPEG',
            ]
        )
        relpaths = [
            rpath for rpath in relpaths if rpath.split('/')[-1] not in ignore
        ]
        if 'sub_indices' in self.config:
            indices = str_to_indices(self.config['sub_indices'])
            synsets = give_synsets_from_indices(
                indices, path_to_yaml=self.idx2syn
            )  # returns a list of strings
            self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
            files = []
            for rpath in relpaths:
                syn = rpath.split('/')[0]
                if syn in synsets:
                    files.append(rpath)
            return files
        else:
            return relpaths

    def _prepare_synset_to_human(self):
        SIZE = 2655750
        URL = 'https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1'
        self.human_dict = os.path.join(self.root, 'synset_human.txt')
        if (
            not os.path.exists(self.human_dict)
            or not os.path.getsize(self.human_dict) == SIZE
        ):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        URL = 'https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1'
        self.idx2syn = os.path.join(self.root, 'index_synset.yaml')
        if not os.path.exists(self.idx2syn):
            download(URL, self.idx2syn)

    def _prepare_human_to_integer_label(self):
        URL = 'https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1'
        self.human2integer = os.path.join(
            self.root, 'imagenet1000_clsidx_to_labels.txt'
        )
        if not os.path.exists(self.human2integer):
            download(URL, self.human2integer)
        with open(self.human2integer, 'r') as f:
            lines = f.read().splitlines()
            assert len(lines) == 1000
            self.human2integer_dict = dict()
            for line in lines:
                value, key = line.split(':')
                self.human2integer_dict[key] = int(value)

    def _load(self):
        with open(self.txt_filelist, 'r') as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print(
                'Removed {} files from filelist during filtering.'.format(
                    l1 - len(self.relpaths)
                )
            )

        self.synsets = [p.split('/')[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        class_dict = dict(
            (synset, i) for i, synset in enumerate(unique_synsets)
        )
        if not self.keep_orig_class_label:
            self.class_labels = [class_dict[s] for s in self.synsets]
        else:
            self.class_labels = [self.synset2idx[s] for s in self.synsets]

        with open(self.human_dict, 'r') as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            'relpath': np.array(self.relpaths),
            'synsets': np.array(self.synsets),
            'class_label': np.array(self.class_labels),
            'human_label': np.array(self.human_labels),
        }

        if self.process_images:
            self.size = retrieve(self.config, 'size', default=256)
            self.data = ImagePaths(
                self.abspaths,
                labels=labels,
                size=self.size,
                random_crop=self.random_crop,
            )
        else:
            self.data = self.abspaths
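# Config sketch (hedged, based on the keys read above; values are illustrative):
# ImageNetBase subclasses accept a dict or OmegaConf config such as
#
#   config = {
#       'size': 256,                     # passed to ImagePaths when process_images is True
#       'keep_orig_class_label': False,  # True: use original ImageNet indices instead of 0..N-1
#       'sub_indices': '0-9',            # optional; str_to_indices defines the accepted format
#   }
#
# The train/validation subclasses additionally read the nested keys
# 'ImageNetTrain/random_crop' and 'ImageNetValidation/random_crop' via retrieve().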


class ImageNetTrain(ImageNetBase):
    """ILSVRC2012 training split. If the tarball is not found locally, it is
    fetched via Academic Torrents and unpacked into per-synset folders."""

    NAME = 'ILSVRC2012_train'
    URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
    AT_HASH = 'a306397ccf9c2ead27155983c254227c0fd938e2'
    FILES = [
        'ILSVRC2012_img_train.tar',
    ]
    SIZES = [
        147897477120,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.process_images = process_images
        self.data_root = data_root
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get(
                'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
            )
            self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)

        self.datadir = os.path.join(self.root, 'data')
        self.txt_filelist = os.path.join(self.root, 'filelist.txt')
        self.expected_length = 1281167
        self.random_crop = retrieve(
            self.config, 'ImageNetTrain/random_crop', default=True
        )
        if not tdu.is_prepared(self.root):
            # prep
            print('Preparing dataset {} in {}'.format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if (
                    not os.path.exists(path)
                    or not os.path.getsize(path) == self.SIZES[0]
                ):
                    import academictorrents as at

                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print('Extracting {} to {}'.format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, 'r:') as tar:
                    tar.extractall(path=datadir)

                print('Extracting sub-tars.')
                subpaths = sorted(glob.glob(os.path.join(datadir, '*.tar')))
                for subpath in tqdm(subpaths):
                    subdir = subpath[: -len('.tar')]
                    os.makedirs(subdir, exist_ok=True)
                    with tarfile.open(subpath, 'r:') as tar:
                        tar.extractall(path=subdir)

            filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = '\n'.join(filelist) + '\n'
            with open(self.txt_filelist, 'w') as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)
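# Note (mirrors the logic in _prepare above): if ILSVRC2012_img_train.tar is already
# on disk with the expected size, place it at
#   <data_root>/ILSVRC2012_train/ILSVRC2012_img_train.tar
# or, when no data_root is given, at
#   $XDG_CACHE_HOME/autoencoders/data/ILSVRC2012_train/ILSVRC2012_img_train.tar
# and the academictorrents download is skipped; the archive is then only extracted
# and indexed into filelist.txt.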


class ImageNetValidation(ImageNetBase):
    """ILSVRC2012 validation split. The flat validation images are reorganized
    into per-synset folders using the downloaded validation_synset.txt mapping."""

    NAME = 'ILSVRC2012_validation'
    URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
    AT_HASH = '5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5'
    VS_URL = 'https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1'
    FILES = [
        'ILSVRC2012_img_val.tar',
        'validation_synset.txt',
    ]
    SIZES = [
        6744924160,
        1950000,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.data_root = data_root
        self.process_images = process_images
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get(
                'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
            )
            self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)
        self.datadir = os.path.join(self.root, 'data')
        self.txt_filelist = os.path.join(self.root, 'filelist.txt')
        self.expected_length = 50000
        self.random_crop = retrieve(
            self.config, 'ImageNetValidation/random_crop', default=False
        )
        if not tdu.is_prepared(self.root):
            # prep
            print('Preparing dataset {} in {}'.format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if (
                    not os.path.exists(path)
                    or not os.path.getsize(path) == self.SIZES[0]
                ):
                    import academictorrents as at

                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print('Extracting {} to {}'.format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, 'r:') as tar:
                    tar.extractall(path=datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if (
                    not os.path.exists(vspath)
                    or not os.path.getsize(vspath) == self.SIZES[1]
                ):
                    download(self.VS_URL, vspath)

                with open(vspath, 'r') as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)

                print('Reorganizing into synset folders')
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = '\n'.join(filelist) + '\n'
            with open(self.txt_filelist, 'w') as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)
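# Usage sketch (hedged: assumes the validation split has been prepared or can be
# downloaded; '/data/imagenet' is an illustrative path, and the example keys come
# from taming's ImagePaths plus the labels dict built in _load):
#
#   val = ImageNetValidation(data_root='/data/imagenet', config={'size': 256})
#   ex = val[0]
#   ex['image'].shape, ex['class_label'], ex['human_label']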


class ImageNetSR(Dataset):
    def __init__(
        self,
        size=None,
        degradation=None,
        downscale_f=4,
        min_crop_f=0.5,
        max_crop_f=1.0,
        random_crop=True,
    ):
        """
        ImageNet super-resolution dataloader.

        Performs the following ops in order:
        1. crops a square of side length s from the image, either as a random or a center crop
        2. resizes the crop to `size` with cv2 area interpolation
        3. degrades the resized crop with the chosen degradation function to get the LR image

        :param size: target side length after cropping and resizing
        :param degradation: degradation function, e.g. cv_bicubic or bsrgan_light
        :param downscale_f: low-resolution downsampling factor
        :param min_crop_f: determines the crop side length s,
            where s = c * min_img_side_len with c sampled uniformly from [min_crop_f, max_crop_f]
        :param max_crop_f: see min_crop_f
        :param random_crop: use random crops if True, center crops otherwise
        """
        self.base = self.get_base()
        assert size
        assert (size / downscale_f).is_integer()
        self.size = size
        self.LR_size = int(size / downscale_f)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert max_crop_f <= 1.0
        self.center_crop = not random_crop

        self.image_rescaler = albumentations.SmallestMaxSize(
            max_size=size, interpolation=cv2.INTER_AREA
        )

        self.pil_interpolation = (
            False  # set to True below if the interpolation op comes from Pillow
        )

        if degradation == 'bsrgan':
            self.degradation_process = partial(
                degradation_fn_bsr, sf=downscale_f
            )

        elif degradation == 'bsrgan_light':
            self.degradation_process = partial(
                degradation_fn_bsr_light, sf=downscale_f
            )

        else:
            interpolation_fn = {
                'cv_nearest': cv2.INTER_NEAREST,
                'cv_bilinear': cv2.INTER_LINEAR,
                'cv_bicubic': cv2.INTER_CUBIC,
                'cv_area': cv2.INTER_AREA,
                'cv_lanczos': cv2.INTER_LANCZOS4,
                'pil_nearest': PIL.Image.NEAREST,
                'pil_bilinear': PIL.Image.BILINEAR,
                'pil_bicubic': PIL.Image.BICUBIC,
                'pil_box': PIL.Image.BOX,
                'pil_hamming': PIL.Image.HAMMING,
                'pil_lanczos': PIL.Image.LANCZOS,
            }[degradation]

            self.pil_interpolation = degradation.startswith('pil_')

            if self.pil_interpolation:
                self.degradation_process = partial(
                    TF.resize,
                    size=self.LR_size,
                    interpolation=interpolation_fn,
                )

            else:
                self.degradation_process = albumentations.SmallestMaxSize(
                    max_size=self.LR_size, interpolation=interpolation_fn
                )

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        example = self.base[i]
        image = Image.open(example['file_path_'])

        if image.mode != 'RGB':
            image = image.convert('RGB')

        image = np.array(image).astype(np.uint8)

        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(
            self.min_crop_f, self.max_crop_f, size=None
        )
        crop_side_len = int(crop_side_len)

        if self.center_crop:
            self.cropper = albumentations.CenterCrop(
                height=crop_side_len, width=crop_side_len
            )

        else:
            self.cropper = albumentations.RandomCrop(
                height=crop_side_len, width=crop_side_len
            )

        image = self.cropper(image=image)['image']
        image = self.image_rescaler(image=image)['image']

        if self.pil_interpolation:
            image_pil = PIL.Image.fromarray(image)
            LR_image = self.degradation_process(image_pil)
            LR_image = np.array(LR_image).astype(np.uint8)

        else:
            LR_image = self.degradation_process(image=image)['image']

        # scale uint8 images from [0, 255] to [-1, 1] floats
        example['image'] = (image / 127.5 - 1.0).astype(np.float32)
        example['LR_image'] = (LR_image / 127.5 - 1.0).astype(np.float32)

        return example
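# Worked example of the size bookkeeping above (numbers are illustrative): with
# size=256 and downscale_f=4, LR_size = 256 / 4 = 64; with min_crop_f=0.5 and
# max_crop_f=1.0, an image whose smaller side is 500 px gets a square crop with a
# side length drawn uniformly from [250, 500] px, which is rescaled to 256 px and
# then degraded down to 64 px for the LR image.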


class ImageNetSRTrain(ImageNetSR):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        # Pre-computed indices of a high-resolution subset of the train split.
        with open('data/imagenet_train_hr_indices.p', 'rb') as f:
            indices = pickle.load(f)
        dset = ImageNetTrain(
            process_images=False,
        )
        return Subset(dset, indices)


class ImageNetSRValidation(ImageNetSR):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        # Pre-computed indices of a high-resolution subset of the validation split.
        with open('data/imagenet_val_hr_indices.p', 'rb') as f:
            indices = pickle.load(f)
        dset = ImageNetValidation(
            process_images=False,
        )
        return Subset(dset, indices)
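# Composition note (hedged): get_base() wraps the full split in
# torch.utils.data.Subset, so the SR datasets only expose the items whose integer
# positions are stored in the pickled index lists, e.g.
#
#   with open('data/imagenet_val_hr_indices.p', 'rb') as f:
#       indices = pickle.load(f)
#   base = Subset(ImageNetValidation(process_images=False), indices)
#   assert len(base) == len(indices)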