# invokeai/backend/stable_diffusion/data/imagenet.py

import glob
import os
import pickle
import shutil
import tarfile
from functools import partial

import albumentations
import cv2
import numpy as np
import PIL
import taming.data.utils as tdu
import torchvision.transforms.functional as TF
import yaml
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
from omegaconf import OmegaConf
from PIL import Image
from taming.data.imagenet import (
    ImagePaths,
    download,
    give_synsets_from_indices,
    retrieve,
    str_to_indices,
)
from torch.utils.data import Dataset, Subset
from tqdm import tqdm


def synset2idx(path_to_yaml="data/index_synset.yaml"):
    # safe_load avoids the deprecated Loader-less yaml.load, which errors on
    # PyYAML >= 6.
    with open(path_to_yaml) as f:
        di2s = yaml.safe_load(f)
    return dict((v, k) for k, v in di2s.items())
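# A minimal sketch of the mapping being inverted, assuming index_synset.yaml
# maps integer class indices to WordNet synset ids:
#   di2s = {0: "n01440764", 1: "n01443537", ...}
#   synset2idx(...) -> {"n01440764": 0, "n01443537": 1, ...}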


class ImageNetBase(Dataset):
    def __init__(self, config=None):
        self.config = config or OmegaConf.create()
        if not isinstance(self.config, dict):
            self.config = OmegaConf.to_container(self.config)
        self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
        # If False, we skip loading & processing images and self.data contains filepaths.
        self.process_images = True
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._prepare_human_to_integer_label()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        ignore = set(
            [
                "n06596364_9591.JPEG",
            ]
        )
        relpaths = [rpath for rpath in relpaths if rpath.split("/")[-1] not in ignore]
        if "sub_indices" in self.config:
            indices = str_to_indices(self.config["sub_indices"])
            synsets = give_synsets_from_indices(
                indices, path_to_yaml=self.idx2syn
            )  # returns a list of strings
            self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
            files = []
            for rpath in relpaths:
                syn = rpath.split("/")[0]
                if syn in synsets:
                    files.append(rpath)
            return files
        else:
            return relpaths

    def _prepare_synset_to_human(self):
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if not os.path.exists(self.human_dict) or os.path.getsize(self.human_dict) != SIZE:
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if not os.path.exists(self.idx2syn):
            download(URL, self.idx2syn)

    def _prepare_human_to_integer_label(self):
        URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
        self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
        if not os.path.exists(self.human2integer):
            download(URL, self.human2integer)
        with open(self.human2integer, "r") as f:
            lines = f.read().splitlines()
            assert len(lines) == 1000
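            # Each line is assumed to look like "0: 'tench, Tinca tinca',", so
            # splitting on ":" yields the integer index and the human label.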
            self.human2integer_dict = dict()
            for line in lines:
                value, key = line.split(":")
                self.human2integer_dict[key] = int(value)

    def _load(self):
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print(
                "Removed {} files from filelist during filtering.".format(
                    l1 - len(self.relpaths)
                )
            )

        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
        if not self.keep_orig_class_label:
            self.class_labels = [class_dict[s] for s in self.synsets]
        else:
            self.class_labels = [self.synset2idx[s] for s in self.synsets]

        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)
        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }
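
        # ImagePaths (from taming) pairs each file with these per-image label
        # arrays and yields dict examples that include a "file_path_" key,
        # which the SR datasets below rely on.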
        if self.process_images:
            self.size = retrieve(self.config, "size", default=256)
            self.data = ImagePaths(
                self.abspaths,
                labels=labels,
                size=self.size,
                random_crop=self.random_crop,
            )
        else:
            self.data = self.abspaths


class ImageNetTrain(ImageNetBase):
    NAME = "ILSVRC2012_train"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
    FILES = [
        "ILSVRC2012_img_train.tar",
    ]
    SIZES = [
        147897477120,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.process_images = process_images
        self.data_root = data_root
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)

        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 1281167
        self.random_crop = retrieve(
            self.config, "ImageNetTrain/random_crop", default=True
        )
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or os.path.getsize(path) != self.SIZES[0]:
                    import academictorrents as at

                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path
                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)
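
                # The train archive unpacks into one sub-tar per synset; each is
                # extracted into its own class folder below.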
                print("Extracting sub-tars.")
                subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
                for subpath in tqdm(subpaths):
                    subdir = subpath[: -len(".tar")]
                    os.makedirs(subdir, exist_ok=True)
                    with tarfile.open(subpath, "r:") as tar:
                        tar.extractall(path=subdir)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist) + "\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)


class ImageNetValidation(ImageNetBase):
    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    SIZES = [
        6744924160,
        1950000,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.data_root = data_root
        self.process_images = process_images
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)

        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 50000
        self.random_crop = retrieve(
            self.config, "ImageNetValidation/random_crop", default=False
        )
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or os.path.getsize(path) != self.SIZES[0]:
                    import academictorrents as at

                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path
                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if not os.path.exists(vspath) or os.path.getsize(vspath) != self.SIZES[1]:
                    download(self.VS_URL, vspath)

                with open(vspath, "r") as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)
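                # validation_synset.txt is assumed to map each validation
                # filename to its synset, one pair per line, e.g.
                #   ILSVRC2012_val_00000001.JPEG n01751748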

                print("Reorganizing into synset folders")
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist) + "\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)


class ImageNetSR(Dataset):
    def __init__(
        self,
        size=None,
        degradation=None,
        downscale_f=4,
        min_crop_f=0.5,
        max_crop_f=1.0,
        random_crop=True,
    ):
        """
        ImageNet super-resolution dataloader.
        Performs the following ops, in order:
        1. crops a patch of side length s from the image (random or center crop)
        2. resizes the crop to `size` with cv2 area interpolation
        3. degrades the resized crop with `degradation_fn`
        :param size: target side length after cropping and resizing
        :param degradation: name of the degradation_fn, e.g. "cv_bicubic" or "bsrgan_light"
        :param downscale_f: low-resolution downsample factor
        :param min_crop_f: lower bound for the crop size s,
            where s = c * min_img_side_len with c sampled from the interval (min_crop_f, max_crop_f)
        :param max_crop_f: upper bound for the crop size s
        :param random_crop: use random crops instead of center crops
        """
        self.base = self.get_base()
        assert size
        assert (size / downscale_f).is_integer()
        self.size = size
        self.LR_size = int(size / downscale_f)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert max_crop_f <= 1.0
        self.center_crop = not random_crop

        self.image_rescaler = albumentations.SmallestMaxSize(
            max_size=size, interpolation=cv2.INTER_AREA
        )

        # Reset below if the interpolation op comes from Pillow.
        self.pil_interpolation = False

        if degradation == "bsrgan":
            self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
        elif degradation == "bsrgan_light":
            self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
        else:
            interpolation_fn = {
                "cv_nearest": cv2.INTER_NEAREST,
                "cv_bilinear": cv2.INTER_LINEAR,
                "cv_bicubic": cv2.INTER_CUBIC,
                "cv_area": cv2.INTER_AREA,
                "cv_lanczos": cv2.INTER_LANCZOS4,
                "pil_nearest": PIL.Image.NEAREST,
                "pil_bilinear": PIL.Image.BILINEAR,
                "pil_bicubic": PIL.Image.BICUBIC,
                "pil_box": PIL.Image.BOX,
                "pil_hamming": PIL.Image.HAMMING,
                "pil_lanczos": PIL.Image.LANCZOS,
            }[degradation]

            self.pil_interpolation = degradation.startswith("pil_")
            if self.pil_interpolation:
                self.degradation_process = partial(
                    TF.resize,
                    size=self.LR_size,
                    interpolation=interpolation_fn,
                )
            else:
                self.degradation_process = albumentations.SmallestMaxSize(
                    max_size=self.LR_size, interpolation=interpolation_fn
                )

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        example = self.base[i]
        image = Image.open(example["file_path_"])
        if image.mode != "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)

        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(
            self.min_crop_f, self.max_crop_f, size=None
        )
        crop_side_len = int(crop_side_len)
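
        # A fresh cropper is built per example because the crop side length is
        # resampled on every call.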
        if self.center_crop:
            self.cropper = albumentations.CenterCrop(
                height=crop_side_len, width=crop_side_len
            )
        else:
            self.cropper = albumentations.RandomCrop(
                height=crop_side_len, width=crop_side_len
            )

        image = self.cropper(image=image)["image"]
        image = self.image_rescaler(image=image)["image"]

        if self.pil_interpolation:
            image_pil = PIL.Image.fromarray(image)
            LR_image = self.degradation_process(image_pil)
            LR_image = np.array(LR_image).astype(np.uint8)
        else:
            LR_image = self.degradation_process(image=image)["image"]
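
        # Scale uint8 pixels from [0, 255] to float32 in [-1, 1].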
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        example["LR_image"] = (LR_image / 127.5 - 1.0).astype(np.float32)

        return example


class ImageNetSRTrain(ImageNetSR):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        with open("data/imagenet_train_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        dset = ImageNetTrain(
            process_images=False,
        )
        return Subset(dset, indices)


class ImageNetSRValidation(ImageNetSR):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        with open("data/imagenet_val_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        dset = ImageNetValidation(
            process_images=False,
        )
        return Subset(dset, indices)
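

if __name__ == "__main__":
    # Minimal usage sketch, assuming the ILSVRC2012 archives and the
    # data/imagenet_*_hr_indices.p index pickles are available locally.
    dataset = ImageNetSRValidation(size=256, degradation="pil_bicubic")
    example = dataset[0]
    print(example["image"].shape, example["LR_image"].shape)  # (256, 256, 3) (64, 64, 3)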