From 4f02b72c9ca464f53c0f7f23ee483f7a24b631ad Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Fri, 26 Aug 2022 03:15:42 -0400
Subject: [PATCH] prettified all the code using "blue" at the urging of
 @tildebyte

---
 ldm/data/base.py                               |   18 +-
 ldm/data/imagenet.py                           |  322 ++--
 ldm/data/lsun.py                               |  104 +-
 ldm/data/personalized.py                       |  116 +-
 ldm/data/personalized_style.py                 |  108 +-
 ldm/dream/pngwriter.py                         |   98 +-
 ldm/dream/readline.py                          |   95 +-
 ldm/lr_scheduler.py                            |   89 +-
 ldm/models/autoencoder.py                      |  425 +++--
 ldm/models/diffusion/classifier.py             |  226 ++-
 ldm/models/diffusion/ddim.py                   |  389 ++--
 ldm/models/diffusion/ddpm.py                   | 1581 ++++++++++++-----
 ldm/models/diffusion/ksampler.py               |   79 +-
 ldm/models/diffusion/plms.py                   |  373 ++--
 ldm/modules/attention.py                       |  167 +-
 ldm/modules/diffusionmodules/model.py          |  808 +++++----
 ldm/modules/diffusionmodules/openaimodel.py    |  203 ++-
 ldm/modules/diffusionmodules/util.py           |  105 +-
 ldm/modules/distributions/distributions.py     |   38 +-
 ldm/modules/ema.py                             |   34 +-
 ldm/modules/embedding_manager.py               |  191 +-
 ldm/modules/encoders/modules.py                |  386 ++--
 ldm/modules/image_degradation/__init__.py      |    8 +-
 ldm/modules/image_degradation/bsrgan.py        |  306 +++-
 ldm/modules/image_degradation/bsrgan_light.py  |  257 ++-
 ldm/modules/image_degradation/utils_image.py   |  346 ++--
 ldm/modules/losses/__init__.py                 |    2 +-
 ldm/modules/losses/contperceptual.py           |  146 +-
 ldm/modules/losses/vqperceptual.py             |  197 +-
 ldm/modules/x_transformer.py                   |  382 ++--
 ldm/simplet2i.py                               |  501 +++---
 ldm/util.py                                    |   56 +-
 main.py                                        |  685 ++++---
 scripts/dream.py                               |  482 +++--
 scripts/preload_models.py                      |   48 +-
 35 files changed, 6252 insertions(+), 3119 deletions(-)

diff --git a/ldm/data/base.py b/ldm/data/base.py
index b196c2f7aa..de9493fc1e 100644
--- a/ldm/data/base.py
+++ b/ldm/data/base.py
@@ -1,11 +1,17 @@
 from abc import abstractmethod
-from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
+from torch.utils.data import (
+    Dataset,
+    ConcatDataset,
+    ChainDataset,
+    IterableDataset,
+)
 
 
 class Txt2ImgIterableBaseDataset(IterableDataset):
-    '''
+    """
     Define an interface to make the IterableDatasets for text2img data chainable
-    '''
+    """
+
     def __init__(self, num_records=0, valid_ids=None, size=256):
         super().__init__()
         self.num_records = num_records
@@ -13,11 +19,13 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
         self.sample_ids = valid_ids
         self.size = size
 
-        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
+        print(
+            f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
+ ) def __len__(self): return self.num_records @abstractmethod def __iter__(self): - pass \ No newline at end of file + pass diff --git a/ldm/data/imagenet.py b/ldm/data/imagenet.py index 1c473f9c69..d155f6d6ae 100644 --- a/ldm/data/imagenet.py +++ b/ldm/data/imagenet.py @@ -11,24 +11,34 @@ from tqdm import tqdm from torch.utils.data import Dataset, Subset import taming.data.utils as tdu -from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve +from taming.data.imagenet import ( + str_to_indices, + give_synsets_from_indices, + download, + retrieve, +) from taming.data.imagenet import ImagePaths -from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light +from ldm.modules.image_degradation import ( + degradation_fn_bsr, + degradation_fn_bsr_light, +) -def synset2idx(path_to_yaml="data/index_synset.yaml"): +def synset2idx(path_to_yaml='data/index_synset.yaml'): with open(path_to_yaml) as f: di2s = yaml.load(f) - return dict((v,k) for k,v in di2s.items()) + return dict((v, k) for k, v in di2s.items()) class ImageNetBase(Dataset): def __init__(self, config=None): self.config = config or OmegaConf.create() - if not type(self.config)==dict: + if not type(self.config) == dict: self.config = OmegaConf.to_container(self.config) - self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) + self.keep_orig_class_label = self.config.get( + 'keep_orig_class_label', False + ) self.process_images = True # if False we skip loading & processing images and self.data contains filepaths self._prepare() self._prepare_synset_to_human() @@ -46,17 +56,23 @@ class ImageNetBase(Dataset): raise NotImplementedError() def _filter_relpaths(self, relpaths): - ignore = set([ - "n06596364_9591.JPEG", - ]) - relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] - if "sub_indices" in self.config: - indices = str_to_indices(self.config["sub_indices"]) - synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + ignore = set( + [ + 'n06596364_9591.JPEG', + ] + ) + relpaths = [ + rpath for rpath in relpaths if not rpath.split('/')[-1] in ignore + ] + if 'sub_indices' in self.config: + indices = str_to_indices(self.config['sub_indices']) + synsets = give_synsets_from_indices( + indices, path_to_yaml=self.idx2syn + ) # returns a list of strings self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) files = [] for rpath in relpaths: - syn = rpath.split("/")[0] + syn = rpath.split('/')[0] if syn in synsets: files.append(rpath) return files @@ -65,78 +81,89 @@ class ImageNetBase(Dataset): def _prepare_synset_to_human(self): SIZE = 2655750 - URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" - self.human_dict = os.path.join(self.root, "synset_human.txt") - if (not os.path.exists(self.human_dict) or - not os.path.getsize(self.human_dict)==SIZE): + URL = 'https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1' + self.human_dict = os.path.join(self.root, 'synset_human.txt') + if ( + not os.path.exists(self.human_dict) + or not os.path.getsize(self.human_dict) == SIZE + ): download(URL, self.human_dict) def _prepare_idx_to_synset(self): - URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" - self.idx2syn = os.path.join(self.root, "index_synset.yaml") - if (not os.path.exists(self.idx2syn)): + URL = 'https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1' + self.idx2syn = os.path.join(self.root, 'index_synset.yaml') + if not 
os.path.exists(self.idx2syn): download(URL, self.idx2syn) def _prepare_human_to_integer_label(self): - URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" - self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") - if (not os.path.exists(self.human2integer)): + URL = 'https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1' + self.human2integer = os.path.join( + self.root, 'imagenet1000_clsidx_to_labels.txt' + ) + if not os.path.exists(self.human2integer): download(URL, self.human2integer) - with open(self.human2integer, "r") as f: + with open(self.human2integer, 'r') as f: lines = f.read().splitlines() assert len(lines) == 1000 self.human2integer_dict = dict() for line in lines: - value, key = line.split(":") + value, key = line.split(':') self.human2integer_dict[key] = int(value) def _load(self): - with open(self.txt_filelist, "r") as f: + with open(self.txt_filelist, 'r') as f: self.relpaths = f.read().splitlines() l1 = len(self.relpaths) self.relpaths = self._filter_relpaths(self.relpaths) - print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) + print( + 'Removed {} files from filelist during filtering.'.format( + l1 - len(self.relpaths) + ) + ) - self.synsets = [p.split("/")[0] for p in self.relpaths] + self.synsets = [p.split('/')[0] for p in self.relpaths] self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] unique_synsets = np.unique(self.synsets) - class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) + class_dict = dict( + (synset, i) for i, synset in enumerate(unique_synsets) + ) if not self.keep_orig_class_label: self.class_labels = [class_dict[s] for s in self.synsets] else: self.class_labels = [self.synset2idx[s] for s in self.synsets] - with open(self.human_dict, "r") as f: + with open(self.human_dict, 'r') as f: human_dict = f.read().splitlines() human_dict = dict(line.split(maxsplit=1) for line in human_dict) self.human_labels = [human_dict[s] for s in self.synsets] labels = { - "relpath": np.array(self.relpaths), - "synsets": np.array(self.synsets), - "class_label": np.array(self.class_labels), - "human_label": np.array(self.human_labels), + 'relpath': np.array(self.relpaths), + 'synsets': np.array(self.synsets), + 'class_label': np.array(self.class_labels), + 'human_label': np.array(self.human_labels), } if self.process_images: - self.size = retrieve(self.config, "size", default=256) - self.data = ImagePaths(self.abspaths, - labels=labels, - size=self.size, - random_crop=self.random_crop, - ) + self.size = retrieve(self.config, 'size', default=256) + self.data = ImagePaths( + self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) else: self.data = self.abspaths class ImageNetTrain(ImageNetBase): - NAME = "ILSVRC2012_train" - URL = "http://www.image-net.org/challenges/LSVRC/2012/" - AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" + NAME = 'ILSVRC2012_train' + URL = 'http://www.image-net.org/challenges/LSVRC/2012/' + AT_HASH = 'a306397ccf9c2ead27155983c254227c0fd938e2' FILES = [ - "ILSVRC2012_img_train.tar", + 'ILSVRC2012_img_train.tar', ] SIZES = [ 147897477120, @@ -151,57 +178,64 @@ class ImageNetTrain(ImageNetBase): if self.data_root: self.root = os.path.join(self.data_root, self.NAME) else: - cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) - self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + cachedir = os.environ.get( + 'XDG_CACHE_HOME', 
os.path.expanduser('~/.cache') + ) + self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME) - self.datadir = os.path.join(self.root, "data") - self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.datadir = os.path.join(self.root, 'data') + self.txt_filelist = os.path.join(self.root, 'filelist.txt') self.expected_length = 1281167 - self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", - default=True) + self.random_crop = retrieve( + self.config, 'ImageNetTrain/random_crop', default=True + ) if not tdu.is_prepared(self.root): # prep - print("Preparing dataset {} in {}".format(self.NAME, self.root)) + print('Preparing dataset {} in {}'.format(self.NAME, self.root)) datadir = self.datadir if not os.path.exists(datadir): path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + if ( + not os.path.exists(path) + or not os.path.getsize(path) == self.SIZES[0] + ): import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) assert atpath == path - print("Extracting {} to {}".format(path, datadir)) + print('Extracting {} to {}'.format(path, datadir)) os.makedirs(datadir, exist_ok=True) - with tarfile.open(path, "r:") as tar: + with tarfile.open(path, 'r:') as tar: tar.extractall(path=datadir) - print("Extracting sub-tars.") - subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) + print('Extracting sub-tars.') + subpaths = sorted(glob.glob(os.path.join(datadir, '*.tar'))) for subpath in tqdm(subpaths): - subdir = subpath[:-len(".tar")] + subdir = subpath[: -len('.tar')] os.makedirs(subdir, exist_ok=True) - with tarfile.open(subpath, "r:") as tar: + with tarfile.open(subpath, 'r:') as tar: tar.extractall(path=subdir) - filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG')) filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" - with open(self.txt_filelist, "w") as f: + filelist = '\n'.join(filelist) + '\n' + with open(self.txt_filelist, 'w') as f: f.write(filelist) tdu.mark_prepared(self.root) class ImageNetValidation(ImageNetBase): - NAME = "ILSVRC2012_validation" - URL = "http://www.image-net.org/challenges/LSVRC/2012/" - AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" - VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" + NAME = 'ILSVRC2012_validation' + URL = 'http://www.image-net.org/challenges/LSVRC/2012/' + AT_HASH = '5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5' + VS_URL = 'https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1' FILES = [ - "ILSVRC2012_img_val.tar", - "validation_synset.txt", + 'ILSVRC2012_img_val.tar', + 'validation_synset.txt', ] SIZES = [ 6744924160, @@ -217,39 +251,49 @@ class ImageNetValidation(ImageNetBase): if self.data_root: self.root = os.path.join(self.data_root, self.NAME) else: - cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) - self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) - self.datadir = os.path.join(self.root, "data") - self.txt_filelist = os.path.join(self.root, "filelist.txt") + cachedir = os.environ.get( + 'XDG_CACHE_HOME', os.path.expanduser('~/.cache') + ) + self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME) + self.datadir = os.path.join(self.root, 'data') + self.txt_filelist = os.path.join(self.root, 'filelist.txt') self.expected_length = 50000 - self.random_crop = 
retrieve(self.config, "ImageNetValidation/random_crop", - default=False) + self.random_crop = retrieve( + self.config, 'ImageNetValidation/random_crop', default=False + ) if not tdu.is_prepared(self.root): # prep - print("Preparing dataset {} in {}".format(self.NAME, self.root)) + print('Preparing dataset {} in {}'.format(self.NAME, self.root)) datadir = self.datadir if not os.path.exists(datadir): path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + if ( + not os.path.exists(path) + or not os.path.getsize(path) == self.SIZES[0] + ): import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) assert atpath == path - print("Extracting {} to {}".format(path, datadir)) + print('Extracting {} to {}'.format(path, datadir)) os.makedirs(datadir, exist_ok=True) - with tarfile.open(path, "r:") as tar: + with tarfile.open(path, 'r:') as tar: tar.extractall(path=datadir) vspath = os.path.join(self.root, self.FILES[1]) - if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + if ( + not os.path.exists(vspath) + or not os.path.getsize(vspath) == self.SIZES[1] + ): download(self.VS_URL, vspath) - with open(vspath, "r") as f: + with open(vspath, 'r') as f: synset_dict = f.read().splitlines() synset_dict = dict(line.split() for line in synset_dict) - print("Reorganizing into synset folders") + print('Reorganizing into synset folders') synsets = np.unique(list(synset_dict.values())) for s in synsets: os.makedirs(os.path.join(datadir, s), exist_ok=True) @@ -258,21 +302,26 @@ class ImageNetValidation(ImageNetBase): dst = os.path.join(datadir, v) shutil.move(src, dst) - filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG')) filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" - with open(self.txt_filelist, "w") as f: + filelist = '\n'.join(filelist) + '\n' + with open(self.txt_filelist, 'w') as f: f.write(filelist) tdu.mark_prepared(self.root) - class ImageNetSR(Dataset): - def __init__(self, size=None, - degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., - random_crop=True): + def __init__( + self, + size=None, + degradation=None, + downscale_f=4, + min_crop_f=0.5, + max_crop_f=1.0, + random_crop=True, + ): """ Imagenet Superresolution Dataloader Performs following ops in order: @@ -296,67 +345,86 @@ class ImageNetSR(Dataset): self.LR_size = int(size / downscale_f) self.min_crop_f = min_crop_f self.max_crop_f = max_crop_f - assert(max_crop_f <= 1.) 
+ assert max_crop_f <= 1.0 self.center_crop = not random_crop - self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) + self.image_rescaler = albumentations.SmallestMaxSize( + max_size=size, interpolation=cv2.INTER_AREA + ) - self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + self.pil_interpolation = ( + False # gets reset later if incase interp_op is from pillow + ) - if degradation == "bsrgan": - self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) + if degradation == 'bsrgan': + self.degradation_process = partial( + degradation_fn_bsr, sf=downscale_f + ) - elif degradation == "bsrgan_light": - self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) + elif degradation == 'bsrgan_light': + self.degradation_process = partial( + degradation_fn_bsr_light, sf=downscale_f + ) else: interpolation_fn = { - "cv_nearest": cv2.INTER_NEAREST, - "cv_bilinear": cv2.INTER_LINEAR, - "cv_bicubic": cv2.INTER_CUBIC, - "cv_area": cv2.INTER_AREA, - "cv_lanczos": cv2.INTER_LANCZOS4, - "pil_nearest": PIL.Image.NEAREST, - "pil_bilinear": PIL.Image.BILINEAR, - "pil_bicubic": PIL.Image.BICUBIC, - "pil_box": PIL.Image.BOX, - "pil_hamming": PIL.Image.HAMMING, - "pil_lanczos": PIL.Image.LANCZOS, + 'cv_nearest': cv2.INTER_NEAREST, + 'cv_bilinear': cv2.INTER_LINEAR, + 'cv_bicubic': cv2.INTER_CUBIC, + 'cv_area': cv2.INTER_AREA, + 'cv_lanczos': cv2.INTER_LANCZOS4, + 'pil_nearest': PIL.Image.NEAREST, + 'pil_bilinear': PIL.Image.BILINEAR, + 'pil_bicubic': PIL.Image.BICUBIC, + 'pil_box': PIL.Image.BOX, + 'pil_hamming': PIL.Image.HAMMING, + 'pil_lanczos': PIL.Image.LANCZOS, }[degradation] - self.pil_interpolation = degradation.startswith("pil_") + self.pil_interpolation = degradation.startswith('pil_') if self.pil_interpolation: - self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) + self.degradation_process = partial( + TF.resize, + size=self.LR_size, + interpolation=interpolation_fn, + ) else: - self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, - interpolation=interpolation_fn) + self.degradation_process = albumentations.SmallestMaxSize( + max_size=self.LR_size, interpolation=interpolation_fn + ) def __len__(self): return len(self.base) def __getitem__(self, i): example = self.base[i] - image = Image.open(example["file_path_"]) + image = Image.open(example['file_path_']) - if not image.mode == "RGB": - image = image.convert("RGB") + if not image.mode == 'RGB': + image = image.convert('RGB') image = np.array(image).astype(np.uint8) min_side_len = min(image.shape[:2]) - crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) + crop_side_len = min_side_len * np.random.uniform( + self.min_crop_f, self.max_crop_f, size=None + ) crop_side_len = int(crop_side_len) if self.center_crop: - self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) + self.cropper = albumentations.CenterCrop( + height=crop_side_len, width=crop_side_len + ) else: - self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) + self.cropper = albumentations.RandomCrop( + height=crop_side_len, width=crop_side_len + ) - image = self.cropper(image=image)["image"] - image = self.image_rescaler(image=image)["image"] + image = self.cropper(image=image)['image'] + image = self.image_rescaler(image=image)['image'] if self.pil_interpolation: image_pil = PIL.Image.fromarray(image) @@ 
-364,10 +432,10 @@ class ImageNetSR(Dataset): LR_image = np.array(LR_image).astype(np.uint8) else: - LR_image = self.degradation_process(image=image)["image"] + LR_image = self.degradation_process(image=image)['image'] - example["image"] = (image/127.5 - 1.0).astype(np.float32) - example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + example['image'] = (image / 127.5 - 1.0).astype(np.float32) + example['LR_image'] = (LR_image / 127.5 - 1.0).astype(np.float32) return example @@ -377,9 +445,11 @@ class ImageNetSRTrain(ImageNetSR): super().__init__(**kwargs) def get_base(self): - with open("data/imagenet_train_hr_indices.p", "rb") as f: + with open('data/imagenet_train_hr_indices.p', 'rb') as f: indices = pickle.load(f) - dset = ImageNetTrain(process_images=False,) + dset = ImageNetTrain( + process_images=False, + ) return Subset(dset, indices) @@ -388,7 +458,9 @@ class ImageNetSRValidation(ImageNetSR): super().__init__(**kwargs) def get_base(self): - with open("data/imagenet_val_hr_indices.p", "rb") as f: + with open('data/imagenet_val_hr_indices.p', 'rb') as f: indices = pickle.load(f) - dset = ImageNetValidation(process_images=False,) + dset = ImageNetValidation( + process_images=False, + ) return Subset(dset, indices) diff --git a/ldm/data/lsun.py b/ldm/data/lsun.py index 6256e45715..4a7ecb147e 100644 --- a/ldm/data/lsun.py +++ b/ldm/data/lsun.py @@ -7,30 +7,33 @@ from torchvision import transforms class LSUNBase(Dataset): - def __init__(self, - txt_file, - data_root, - size=None, - interpolation="bicubic", - flip_p=0.5 - ): + def __init__( + self, + txt_file, + data_root, + size=None, + interpolation='bicubic', + flip_p=0.5, + ): self.data_paths = txt_file self.data_root = data_root - with open(self.data_paths, "r") as f: + with open(self.data_paths, 'r') as f: self.image_paths = f.read().splitlines() self._length = len(self.image_paths) self.labels = { - "relative_file_path_": [l for l in self.image_paths], - "file_path_": [os.path.join(self.data_root, l) - for l in self.image_paths], + 'relative_file_path_': [l for l in self.image_paths], + 'file_path_': [ + os.path.join(self.data_root, l) for l in self.image_paths + ], } self.size = size - self.interpolation = {"linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] + self.interpolation = { + 'linear': PIL.Image.LINEAR, + 'bilinear': PIL.Image.BILINEAR, + 'bicubic': PIL.Image.BICUBIC, + 'lanczos': PIL.Image.LANCZOS, + }[interpolation] self.flip = transforms.RandomHorizontalFlip(p=flip_p) def __len__(self): @@ -38,55 +41,86 @@ class LSUNBase(Dataset): def __getitem__(self, i): example = dict((k, self.labels[k][i]) for k in self.labels) - image = Image.open(example["file_path_"]) - if not image.mode == "RGB": - image = image.convert("RGB") + image = Image.open(example['file_path_']) + if not image.mode == 'RGB': + image = image.convert('RGB') # default to score-sde preprocessing img = np.array(image).astype(np.uint8) crop = min(img.shape[0], img.shape[1]) - h, w, = img.shape[0], img.shape[1] - img = img[(h - crop) // 2:(h + crop) // 2, - (w - crop) // 2:(w + crop) // 2] + h, w, = ( + img.shape[0], + img.shape[1], + ) + img = img[ + (h - crop) // 2 : (h + crop) // 2, + (w - crop) // 2 : (w + crop) // 2, + ] image = Image.fromarray(img) if self.size is not None: - image = image.resize((self.size, self.size), resample=self.interpolation) + image = image.resize( + (self.size, self.size), resample=self.interpolation + ) image = 
self.flip(image) image = np.array(image).astype(np.uint8) - example["image"] = (image / 127.5 - 1.0).astype(np.float32) + example['image'] = (image / 127.5 - 1.0).astype(np.float32) return example class LSUNChurchesTrain(LSUNBase): def __init__(self, **kwargs): - super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) + super().__init__( + txt_file='data/lsun/church_outdoor_train.txt', + data_root='data/lsun/churches', + **kwargs + ) class LSUNChurchesValidation(LSUNBase): - def __init__(self, flip_p=0., **kwargs): - super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", - flip_p=flip_p, **kwargs) + def __init__(self, flip_p=0.0, **kwargs): + super().__init__( + txt_file='data/lsun/church_outdoor_val.txt', + data_root='data/lsun/churches', + flip_p=flip_p, + **kwargs + ) class LSUNBedroomsTrain(LSUNBase): def __init__(self, **kwargs): - super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) + super().__init__( + txt_file='data/lsun/bedrooms_train.txt', + data_root='data/lsun/bedrooms', + **kwargs + ) class LSUNBedroomsValidation(LSUNBase): def __init__(self, flip_p=0.0, **kwargs): - super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", - flip_p=flip_p, **kwargs) + super().__init__( + txt_file='data/lsun/bedrooms_val.txt', + data_root='data/lsun/bedrooms', + flip_p=flip_p, + **kwargs + ) class LSUNCatsTrain(LSUNBase): def __init__(self, **kwargs): - super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) + super().__init__( + txt_file='data/lsun/cat_train.txt', + data_root='data/lsun/cats', + **kwargs + ) class LSUNCatsValidation(LSUNBase): - def __init__(self, flip_p=0., **kwargs): - super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", - flip_p=flip_p, **kwargs) + def __init__(self, flip_p=0.0, **kwargs): + super().__init__( + txt_file='data/lsun/cat_val.txt', + data_root='data/lsun/cats', + flip_p=flip_p, + **kwargs + ) diff --git a/ldm/data/personalized.py b/ldm/data/personalized.py index c8a57d09fa..15fc8a8d2d 100644 --- a/ldm/data/personalized.py +++ b/ldm/data/personalized.py @@ -72,31 +72,57 @@ imagenet_dual_templates_small = [ ] per_img_token_list = [ - 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', + 'א', + 'ב', + 'ג', + 'ד', + 'ה', + 'ו', + 'ז', + 'ח', + 'ט', + 'י', + 'כ', + 'ל', + 'מ', + 'נ', + 'ס', + 'ע', + 'פ', + 'צ', + 'ק', + 'ר', + 'ש', + 'ת', ] + class PersonalizedBase(Dataset): - def __init__(self, - data_root, - size=None, - repeats=100, - interpolation="bicubic", - flip_p=0.5, - set="train", - placeholder_token="*", - per_image_tokens=False, - center_crop=False, - mixing_prob=0.25, - coarse_class_text=None, - ): + def __init__( + self, + data_root, + size=None, + repeats=100, + interpolation='bicubic', + flip_p=0.5, + set='train', + placeholder_token='*', + per_image_tokens=False, + center_crop=False, + mixing_prob=0.25, + coarse_class_text=None, + ): self.data_root = data_root - self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] + self.image_paths = [ + os.path.join(self.data_root, file_path) + for file_path in os.listdir(self.data_root) + ] # self._length = len(self.image_paths) self.num_images = len(self.image_paths) - self._length = self.num_images + self._length = self.num_images self.placeholder_token = placeholder_token @@ 
-107,17 +133,20 @@ class PersonalizedBase(Dataset): self.coarse_class_text = coarse_class_text if per_image_tokens: - assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." + assert self.num_images < len( + per_img_token_list + ), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." - if set == "train": + if set == 'train': self._length = self.num_images * repeats self.size = size - self.interpolation = {"linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] + self.interpolation = { + 'linear': PIL.Image.LINEAR, + 'bilinear': PIL.Image.BILINEAR, + 'bicubic': PIL.Image.BICUBIC, + 'lanczos': PIL.Image.LANCZOS, + }[interpolation] self.flip = transforms.RandomHorizontalFlip(p=flip_p) def __len__(self): @@ -127,34 +156,47 @@ class PersonalizedBase(Dataset): example = {} image = Image.open(self.image_paths[i % self.num_images]) - if not image.mode == "RGB": - image = image.convert("RGB") + if not image.mode == 'RGB': + image = image.convert('RGB') placeholder_string = self.placeholder_token if self.coarse_class_text: - placeholder_string = f"{self.coarse_class_text} {placeholder_string}" + placeholder_string = ( + f'{self.coarse_class_text} {placeholder_string}' + ) if self.per_image_tokens and np.random.uniform() < self.mixing_prob: - text = random.choice(imagenet_dual_templates_small).format(placeholder_string, per_img_token_list[i % self.num_images]) + text = random.choice(imagenet_dual_templates_small).format( + placeholder_string, per_img_token_list[i % self.num_images] + ) else: - text = random.choice(imagenet_templates_small).format(placeholder_string) - - example["caption"] = text + text = random.choice(imagenet_templates_small).format( + placeholder_string + ) + + example['caption'] = text # default to score-sde preprocessing img = np.array(image).astype(np.uint8) - + if self.center_crop: crop = min(img.shape[0], img.shape[1]) - h, w, = img.shape[0], img.shape[1] - img = img[(h - crop) // 2:(h + crop) // 2, - (w - crop) // 2:(w + crop) // 2] + h, w, = ( + img.shape[0], + img.shape[1], + ) + img = img[ + (h - crop) // 2 : (h + crop) // 2, + (w - crop) // 2 : (w + crop) // 2, + ] image = Image.fromarray(img) if self.size is not None: - image = image.resize((self.size, self.size), resample=self.interpolation) + image = image.resize( + (self.size, self.size), resample=self.interpolation + ) image = self.flip(image) image = np.array(image).astype(np.uint8) - example["image"] = (image / 127.5 - 1.0).astype(np.float32) - return example \ No newline at end of file + example['image'] = (image / 127.5 - 1.0).astype(np.float32) + return example diff --git a/ldm/data/personalized_style.py b/ldm/data/personalized_style.py index b6be7b15c4..56d77d7e81 100644 --- a/ldm/data/personalized_style.py +++ b/ldm/data/personalized_style.py @@ -50,29 +50,55 @@ imagenet_dual_templates_small = [ ] per_img_token_list = [ - 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', + 'א', + 'ב', + 'ג', + 'ד', + 'ה', + 'ו', + 'ז', + 'ח', + 'ט', + 'י', + 'כ', + 'ל', + 'מ', + 'נ', + 'ס', + 'ע', + 'פ', + 'צ', + 'ק', + 'ר', + 'ש', + 'ת', ] + class PersonalizedBase(Dataset): - def __init__(self, - 
data_root, - size=None, - repeats=100, - interpolation="bicubic", - flip_p=0.5, - set="train", - placeholder_token="*", - per_image_tokens=False, - center_crop=False, - ): + def __init__( + self, + data_root, + size=None, + repeats=100, + interpolation='bicubic', + flip_p=0.5, + set='train', + placeholder_token='*', + per_image_tokens=False, + center_crop=False, + ): self.data_root = data_root - self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] + self.image_paths = [ + os.path.join(self.data_root, file_path) + for file_path in os.listdir(self.data_root) + ] # self._length = len(self.image_paths) self.num_images = len(self.image_paths) - self._length = self.num_images + self._length = self.num_images self.placeholder_token = placeholder_token @@ -80,17 +106,20 @@ class PersonalizedBase(Dataset): self.center_crop = center_crop if per_image_tokens: - assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." + assert self.num_images < len( + per_img_token_list + ), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." - if set == "train": + if set == 'train': self._length = self.num_images * repeats self.size = size - self.interpolation = {"linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] + self.interpolation = { + 'linear': PIL.Image.LINEAR, + 'bilinear': PIL.Image.BILINEAR, + 'bicubic': PIL.Image.BICUBIC, + 'lanczos': PIL.Image.LANCZOS, + }[interpolation] self.flip = transforms.RandomHorizontalFlip(p=flip_p) def __len__(self): @@ -100,30 +129,41 @@ class PersonalizedBase(Dataset): example = {} image = Image.open(self.image_paths[i % self.num_images]) - if not image.mode == "RGB": - image = image.convert("RGB") + if not image.mode == 'RGB': + image = image.convert('RGB') if self.per_image_tokens and np.random.uniform() < 0.25: - text = random.choice(imagenet_dual_templates_small).format(self.placeholder_token, per_img_token_list[i % self.num_images]) + text = random.choice(imagenet_dual_templates_small).format( + self.placeholder_token, per_img_token_list[i % self.num_images] + ) else: - text = random.choice(imagenet_templates_small).format(self.placeholder_token) - - example["caption"] = text + text = random.choice(imagenet_templates_small).format( + self.placeholder_token + ) + + example['caption'] = text # default to score-sde preprocessing img = np.array(image).astype(np.uint8) - + if self.center_crop: crop = min(img.shape[0], img.shape[1]) - h, w, = img.shape[0], img.shape[1] - img = img[(h - crop) // 2:(h + crop) // 2, - (w - crop) // 2:(w + crop) // 2] + h, w, = ( + img.shape[0], + img.shape[1], + ) + img = img[ + (h - crop) // 2 : (h + crop) // 2, + (w - crop) // 2 : (w + crop) // 2, + ] image = Image.fromarray(img) if self.size is not None: - image = image.resize((self.size, self.size), resample=self.interpolation) + image = image.resize( + (self.size, self.size), resample=self.interpolation + ) image = self.flip(image) image = np.array(image).astype(np.uint8) - example["image"] = (image / 127.5 - 1.0).astype(np.float32) - return example \ No newline at end of file + example['image'] = (image / 127.5 - 1.0).astype(np.float32) + return example diff --git 
a/ldm/dream/pngwriter.py b/ldm/dream/pngwriter.py index ecbbbd4ff7..3a3f205512 100644 --- a/ldm/dream/pngwriter.py +++ b/ldm/dream/pngwriter.py @@ -1,4 +1,4 @@ -''' +""" Two helper classes for dealing with PNG images and their path names. PngWriter -- Converts Images generated by T2I into PNGs, finds appropriate names for them, and writes prompt metadata @@ -7,95 +7,104 @@ PngWriter -- Converts Images generated by T2I into PNGs, finds prompt for file/directory names. PromptFormatter -- Utility for converting a Namespace of prompt parameters back into a formatted prompt string with command-line switches. -''' +""" import os import re -from math import sqrt,floor,ceil -from PIL import Image,PngImagePlugin +from math import sqrt, floor, ceil +from PIL import Image, PngImagePlugin # -------------------image generation utils----- class PngWriter: - - def __init__(self,outdir,prompt=None,batch_size=1): - self.outdir = outdir - self.batch_size = batch_size - self.prompt = prompt - self.filepath = None - self.files_written = [] + def __init__(self, outdir, prompt=None, batch_size=1): + self.outdir = outdir + self.batch_size = batch_size + self.prompt = prompt + self.filepath = None + self.files_written = [] os.makedirs(outdir, exist_ok=True) - def write_image(self,image,seed): - self.filepath = self.unique_filename(seed,self.filepath) # will increment name in some sensible way + def write_image(self, image, seed): + self.filepath = self.unique_filename( + seed, self.filepath + ) # will increment name in some sensible way try: prompt = f'{self.prompt} -S{seed}' - self.save_image_and_prompt_to_png(image,prompt,self.filepath) + self.save_image_and_prompt_to_png(image, prompt, self.filepath) except IOError as e: print(e) - self.files_written.append([self.filepath,seed]) + self.files_written.append([self.filepath, seed]) - def unique_filename(self,seed,previouspath=None): + def unique_filename(self, seed, previouspath=None): revision = 1 if previouspath is None: # sort reverse alphabetically until we find max+1 - dirlist = sorted(os.listdir(self.outdir),reverse=True) + dirlist = sorted(os.listdir(self.outdir), reverse=True) # find the first filename that matches our pattern or return 000000.0.png - filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png') - basecount = int(filename.split('.',1)[0]) + filename = next( + (f for f in dirlist if re.match('^(\d+)\..*\.png', f)), + '0000000.0.png', + ) + basecount = int(filename.split('.', 1)[0]) basecount += 1 if self.batch_size > 1: filename = f'{basecount:06}.{seed}.01.png' else: filename = f'{basecount:06}.{seed}.png' - return os.path.join(self.outdir,filename) + return os.path.join(self.outdir, filename) else: basename = os.path.basename(previouspath) - x = re.match('^(\d+)\..*\.png',basename) + x = re.match('^(\d+)\..*\.png', basename) if not x: - return self.unique_filename(seed,previouspath) + return self.unique_filename(seed, previouspath) basecount = int(x.groups()[0]) - series = 0 + series = 0 finished = False while not finished: series += 1 filename = f'{basecount:06}.{seed}.png' - if self.batch_size>1 or os.path.exists(os.path.join(self.outdir,filename)): + if self.batch_size > 1 or os.path.exists( + os.path.join(self.outdir, filename) + ): filename = f'{basecount:06}.{seed}.{series:02}.png' - finished = not os.path.exists(os.path.join(self.outdir,filename)) - return os.path.join(self.outdir,filename) + finished = not os.path.exists( + os.path.join(self.outdir, filename) + ) + return os.path.join(self.outdir, 
filename) - def save_image_and_prompt_to_png(self,image,prompt,path): + def save_image_and_prompt_to_png(self, image, prompt, path): info = PngImagePlugin.PngInfo() - info.add_text("Dream",prompt) - image.save(path,"PNG",pnginfo=info) + info.add_text('Dream', prompt) + image.save(path, 'PNG', pnginfo=info) - def make_grid(self,image_list,rows=None,cols=None): + def make_grid(self, image_list, rows=None, cols=None): image_cnt = len(image_list) - if None in (rows,cols): + if None in (rows, cols): rows = floor(sqrt(image_cnt)) # try to make it square - cols = ceil(image_cnt/rows) - width = image_list[0].width + cols = ceil(image_cnt / rows) + width = image_list[0].width height = image_list[0].height - grid_img = Image.new('RGB',(width*cols,height*rows)) - for r in range(0,rows): - for c in range (0,cols): - i = r*rows + c - grid_img.paste(image_list[i],(c*width,r*height)) + grid_img = Image.new('RGB', (width * cols, height * rows)) + for r in range(0, rows): + for c in range(0, cols): + i = r * rows + c + grid_img.paste(image_list[i], (c * width, r * height)) return grid_img - -class PromptFormatter(): - def __init__(self,t2i,opt): + + +class PromptFormatter: + def __init__(self, t2i, opt): self.t2i = t2i self.opt = opt def normalize_prompt(self): - '''Normalize the prompt and switches''' - t2i = self.t2i - opt = self.opt + """Normalize the prompt and switches""" + t2i = self.t2i + opt = self.opt switches = list() switches.append(f'"{opt.prompt}"') @@ -114,4 +123,3 @@ class PromptFormatter(): if t2i.full_precision: switches.append('-F') return ' '.join(switches) - diff --git a/ldm/dream/readline.py b/ldm/dream/readline.py index f46ac6e23a..f40fe83316 100644 --- a/ldm/dream/readline.py +++ b/ldm/dream/readline.py @@ -1,37 +1,40 @@ -''' +""" Readline helper functions for dream.py (linux and mac only). -''' +""" import os import re import atexit + # ---------------readline utilities--------------------- try: import readline + readline_available = True except: readline_available = False -class Completer(): - def __init__(self,options): + +class Completer: + def __init__(self, options): self.options = sorted(options) return - def complete(self,text,state): + def complete(self, text, state): buffer = readline.get_line_buffer() - if text.startswith(('-I','--init_img')): - return self._path_completions(text,state,('.png')) + if text.startswith(('-I', '--init_img')): + return self._path_completions(text, state, ('.png')) - if buffer.strip().endswith('cd') or text.startswith(('.','/')): - return self._path_completions(text,state,()) + if buffer.strip().endswith('cd') or text.startswith(('.', '/')): + return self._path_completions(text, state, ()) response = None if state == 0: # This is the first time for this text, so build a match list. 
if text: - self.matches = [s - for s in self.options - if s and s.startswith(text)] + self.matches = [ + s for s in self.options if s and s.startswith(text) + ] else: self.matches = self.options[:] @@ -43,32 +46,34 @@ class Completer(): response = None return response - def _path_completions(self,text,state,extensions): + def _path_completions(self, text, state, extensions): # get the path so far if text.startswith('-I'): - path = text.replace('-I','',1).lstrip() + path = text.replace('-I', '', 1).lstrip() elif text.startswith('--init_img='): - path = text.replace('--init_img=','',1).lstrip() + path = text.replace('--init_img=', '', 1).lstrip() else: path = text - matches = list() + matches = list() path = os.path.expanduser(path) - if len(path)==0: - matches.append(text+'./') + if len(path) == 0: + matches.append(text + './') else: - dir = os.path.dirname(path) + dir = os.path.dirname(path) dir_list = os.listdir(dir) for n in dir_list: - if n.startswith('.') and len(n)>1: + if n.startswith('.') and len(n) > 1: continue - full_path = os.path.join(dir,n) + full_path = os.path.join(dir, n) if full_path.startswith(path): if os.path.isdir(full_path): - matches.append(os.path.join(os.path.dirname(text),n)+'/') + matches.append( + os.path.join(os.path.dirname(text), n) + '/' + ) elif n.endswith(extensions): - matches.append(os.path.join(os.path.dirname(text),n)) + matches.append(os.path.join(os.path.dirname(text), n)) try: response = matches[state] @@ -76,19 +81,47 @@ class Completer(): response = None return response + if readline_available: - readline.set_completer(Completer(['cd','pwd', - '--steps','-s','--seed','-S','--iterations','-n','--batch_size','-b', - '--width','-W','--height','-H','--cfg_scale','-C','--grid','-g', - '--individual','-i','--init_img','-I','--strength','-f','-v','--variants']).complete) - readline.set_completer_delims(" ") + readline.set_completer( + Completer( + [ + 'cd', + 'pwd', + '--steps', + '-s', + '--seed', + '-S', + '--iterations', + '-n', + '--batch_size', + '-b', + '--width', + '-W', + '--height', + '-H', + '--cfg_scale', + '-C', + '--grid', + '-g', + '--individual', + '-i', + '--init_img', + '-I', + '--strength', + '-f', + '-v', + '--variants', + ] + ).complete + ) + readline.set_completer_delims(' ') readline.parse_and_bind('tab: complete') - histfile = os.path.join(os.path.expanduser('~'),".dream_history") + histfile = os.path.join(os.path.expanduser('~'), '.dream_history') try: readline.read_history_file(histfile) readline.set_history_length(1000) except FileNotFoundError: pass - atexit.register(readline.write_history_file,histfile) - + atexit.register(readline.write_history_file, histfile) diff --git a/ldm/lr_scheduler.py b/ldm/lr_scheduler.py index be39da9ca6..79c1d1978e 100644 --- a/ldm/lr_scheduler.py +++ b/ldm/lr_scheduler.py @@ -5,32 +5,49 @@ class LambdaWarmUpCosineScheduler: """ note: use with a base_lr of 1.0 """ - def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + + def __init__( + self, + warm_up_steps, + lr_min, + lr_max, + lr_start, + max_decay_steps, + verbosity_interval=0, + ): self.lr_warm_up_steps = warm_up_steps self.lr_start = lr_start self.lr_min = lr_min self.lr_max = lr_max self.lr_max_decay_steps = max_decay_steps - self.last_lr = 0. 
+ self.last_lr = 0.0 self.verbosity_interval = verbosity_interval def schedule(self, n, **kwargs): if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n % self.verbosity_interval == 0: + print( + f'current step: {n}, recent lr-multiplier: {self.last_lr}' + ) if n < self.lr_warm_up_steps: - lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + lr = ( + self.lr_max - self.lr_start + ) / self.lr_warm_up_steps * n + self.lr_start self.last_lr = lr return lr else: - t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = (n - self.lr_warm_up_steps) / ( + self.lr_max_decay_steps - self.lr_warm_up_steps + ) t = min(t, 1.0) lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( - 1 + np.cos(t * np.pi)) + 1 + np.cos(t * np.pi) + ) self.last_lr = lr return lr def __call__(self, n, **kwargs): - return self.schedule(n,**kwargs) + return self.schedule(n, **kwargs) class LambdaWarmUpCosineScheduler2: @@ -38,15 +55,30 @@ class LambdaWarmUpCosineScheduler2: supports repeated iterations, configurable via lists note: use with a base_lr of 1.0. """ - def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): - assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + + def __init__( + self, + warm_up_steps, + f_min, + f_max, + f_start, + cycle_lengths, + verbosity_interval=0, + ): + assert ( + len(warm_up_steps) + == len(f_min) + == len(f_max) + == len(f_start) + == len(cycle_lengths) + ) self.lr_warm_up_steps = warm_up_steps self.f_start = f_start self.f_min = f_min self.f_max = f_max self.cycle_lengths = cycle_lengths self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) - self.last_f = 0. 
+ self.last_f = 0.0 self.verbosity_interval = verbosity_interval def find_in_interval(self, n): @@ -60,17 +92,25 @@ class LambdaWarmUpCosineScheduler2: cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}") + if n % self.verbosity_interval == 0: + print( + f'current step: {n}, recent lr-multiplier: {self.last_f}, ' + f'current cycle {cycle}' + ) if n < self.lr_warm_up_steps[cycle]: - f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + f = ( + self.f_max[cycle] - self.f_start[cycle] + ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] self.last_f = f return f else: - t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = (n - self.lr_warm_up_steps[cycle]) / ( + self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] + ) t = min(t, 1.0) - f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( - 1 + np.cos(t * np.pi)) + f = self.f_min[cycle] + 0.5 * ( + self.f_max[cycle] - self.f_min[cycle] + ) * (1 + np.cos(t * np.pi)) self.last_f = f return f @@ -79,20 +119,25 @@ class LambdaWarmUpCosineScheduler2: class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): - def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}") + if n % self.verbosity_interval == 0: + print( + f'current step: {n}, recent lr-multiplier: {self.last_f}, ' + f'current cycle {cycle}' + ) if n < self.lr_warm_up_steps[cycle]: - f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + f = ( + self.f_max[cycle] - self.f_start[cycle] + ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] self.last_f = f return f else: - f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( + self.cycle_lengths[cycle] - n + ) / (self.cycle_lengths[cycle]) self.last_f = f return f - diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py index 6a9c4f4549..359f5688d1 100644 --- a/ldm/models/autoencoder.py +++ b/ldm/models/autoencoder.py @@ -6,29 +6,32 @@ from contextlib import contextmanager from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer from ldm.modules.diffusionmodules.model import Encoder, Decoder -from ldm.modules.distributions.distributions import DiagonalGaussianDistribution +from ldm.modules.distributions.distributions import ( + DiagonalGaussianDistribution, +) from ldm.util import instantiate_from_config class VQModel(pl.LightningModule): - def __init__(self, - ddconfig, - lossconfig, - n_embed, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - batch_resize_range=None, - scheduler_config=None, - lr_g_factor=1.0, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - use_ema=False - ): + def __init__( + self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key='image', + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + 
lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False, + ): super().__init__() self.embed_dim = embed_dim self.n_embed = n_embed @@ -36,24 +39,34 @@ class VQModel(pl.LightningModule): self.encoder = Encoder(**ddconfig) self.decoder = Decoder(**ddconfig) self.loss = instantiate_from_config(lossconfig) - self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, - remap=remap, - sane_index_shape=sane_index_shape) - self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.quantize = VectorQuantizer( + n_embed, + embed_dim, + beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape, + ) + self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d( + embed_dim, ddconfig['z_channels'], 1 + ) if colorize_nlabels is not None: - assert type(colorize_nlabels)==int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + assert type(colorize_nlabels) == int + self.register_buffer( + 'colorize', torch.randn(3, colorize_nlabels, 1, 1) + ) if monitor is not None: self.monitor = monitor self.batch_resize_range = batch_resize_range if self.batch_resize_range is not None: - print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + print( + f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.' + ) self.use_ema = use_ema if self.use_ema: self.model_ema = LitEma(self) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.') if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) @@ -66,28 +79,30 @@ class VQModel(pl.LightningModule): self.model_ema.store(self.parameters()) self.model_ema.copy_to(self) if context is not None: - print(f"{context}: Switched to EMA weights") + print(f'{context}: Switched to EMA weights') try: yield None finally: if self.use_ema: self.model_ema.restore(self.parameters()) if context is not None: - print(f"{context}: Restored training weights") + print(f'{context}: Restored training weights') def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] + sd = torch.load(path, map_location='cpu')['state_dict'] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) + print('Deleting key {} from state_dict.'.format(k)) del sd[k] missing, unexpected = self.load_state_dict(sd, strict=False) - print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + print( + f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys' + ) if len(missing) > 0: - print(f"Missing Keys: {missing}") - print(f"Unexpected Keys: {unexpected}") + print(f'Missing Keys: {missing}') + print(f'Unexpected Keys: {unexpected}') def on_train_batch_end(self, *args, **kwargs): if self.use_ema: @@ -115,7 +130,7 @@ class VQModel(pl.LightningModule): return dec def forward(self, input, return_pred_indices=False): - quant, diff, (_,_,ind) = self.encode(input) + quant, diff, (_, _, ind) = self.encode(input) dec = self.decode(quant) if return_pred_indices: return dec, diff, ind @@ -125,7 +140,11 @@ class VQModel(pl.LightningModule): x = batch[k] if len(x.shape) == 3: x = x[..., None] - x = x.permute(0, 3, 1, 
2).to(memory_format=torch.contiguous_format).float() + x = ( + x.permute(0, 3, 1, 2) + .to(memory_format=torch.contiguous_format) + .float() + ) if self.batch_resize_range is not None: lower_size = self.batch_resize_range[0] upper_size = self.batch_resize_range[1] @@ -133,9 +152,11 @@ class VQModel(pl.LightningModule): # do the first few batches with max size to avoid later oom new_resize = upper_size else: - new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + new_resize = np.random.choice( + np.arange(lower_size, upper_size + 16, 16) + ) if new_resize != x.shape[2]: - x = F.interpolate(x, size=new_resize, mode="bicubic") + x = F.interpolate(x, size=new_resize, mode='bicubic') x = x.detach() return x @@ -147,81 +168,139 @@ class VQModel(pl.LightningModule): if optimizer_idx == 0: # autoencode - aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train", - predicted_indices=ind) + aeloss, log_dict_ae = self.loss( + qloss, + x, + xrec, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split='train', + predicted_indices=ind, + ) - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + self.log_dict( + log_dict_ae, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=True, + ) return aeloss if optimizer_idx == 1: # discriminator - discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + discloss, log_dict_disc = self.loss( + qloss, + x, + xrec, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split='train', + ) + self.log_dict( + log_dict_disc, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=True, + ) return discloss def validation_step(self, batch, batch_idx): log_dict = self._validation_step(batch, batch_idx) with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + log_dict_ema = self._validation_step( + batch, batch_idx, suffix='_ema' + ) return log_dict - def _validation_step(self, batch, batch_idx, suffix=""): + def _validation_step(self, batch, batch_idx, suffix=''): x = self.get_input(batch, self.image_key) xrec, qloss, ind = self(x, return_pred_indices=True) - aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, - self.global_step, - last_layer=self.get_last_layer(), - split="val"+suffix, - predicted_indices=ind - ) + aeloss, log_dict_ae = self.loss( + qloss, + x, + xrec, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split='val' + suffix, + predicted_indices=ind, + ) - discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, - self.global_step, - last_layer=self.get_last_layer(), - split="val"+suffix, - predicted_indices=ind - ) - rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] - self.log(f"val{suffix}/rec_loss", rec_loss, - prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) - self.log(f"val{suffix}/aeloss", aeloss, - prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + discloss, log_dict_disc = self.loss( + qloss, + x, + xrec, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split='val' + suffix, + predicted_indices=ind, + ) + rec_loss = log_dict_ae[f'val{suffix}/rec_loss'] + self.log( + f'val{suffix}/rec_loss', + rec_loss, + prog_bar=True, + logger=True, + on_step=False, + on_epoch=True, + sync_dist=True, + ) + 
self.log( + f'val{suffix}/aeloss', + aeloss, + prog_bar=True, + logger=True, + on_step=False, + on_epoch=True, + sync_dist=True, + ) if version.parse(pl.__version__) >= version.parse('1.4.0'): - del log_dict_ae[f"val{suffix}/rec_loss"] + del log_dict_ae[f'val{suffix}/rec_loss'] self.log_dict(log_dict_ae) self.log_dict(log_dict_disc) return self.log_dict def configure_optimizers(self): lr_d = self.learning_rate - lr_g = self.lr_g_factor*self.learning_rate - print("lr_d", lr_d) - print("lr_g", lr_g) - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quantize.parameters())+ - list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), - lr=lr_g, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr_d, betas=(0.5, 0.9)) + lr_g = self.lr_g_factor * self.learning_rate + print('lr_d', lr_d) + print('lr_g', lr_g) + opt_ae = torch.optim.Adam( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quantize.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=lr_g, + betas=(0.5, 0.9), + ) + opt_disc = torch.optim.Adam( + self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9) + ) if self.scheduler_config is not None: scheduler = instantiate_from_config(self.scheduler_config) - print("Setting up LambdaLR scheduler...") + print('Setting up LambdaLR scheduler...') scheduler = [ { - 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'scheduler': LambdaLR( + opt_ae, lr_lambda=scheduler.schedule + ), 'interval': 'step', - 'frequency': 1 + 'frequency': 1, }, { - 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'scheduler': LambdaLR( + opt_disc, lr_lambda=scheduler.schedule + ), 'interval': 'step', - 'frequency': 1 + 'frequency': 1, }, ] return [opt_ae, opt_disc], scheduler @@ -235,7 +314,7 @@ class VQModel(pl.LightningModule): x = self.get_input(batch, self.image_key) x = x.to(self.device) if only_inputs: - log["inputs"] = x + log['inputs'] = x return log xrec, _ = self(x) if x.shape[1] > 3: @@ -243,21 +322,24 @@ class VQModel(pl.LightningModule): assert xrec.shape[1] > 3 x = self.to_rgb(x) xrec = self.to_rgb(xrec) - log["inputs"] = x - log["reconstructions"] = xrec + log['inputs'] = x + log['reconstructions'] = xrec if plot_ema: with self.ema_scope(): xrec_ema, _ = self(x) - if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) - log["reconstructions_ema"] = xrec_ema + if x.shape[1] > 3: + xrec_ema = self.to_rgb(xrec_ema) + log['reconstructions_ema'] = xrec_ema return log def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + assert self.image_key == 'segmentation' + if not hasattr(self, 'colorize'): + self.register_buffer( + 'colorize', torch.randn(3, x.shape[1], 1, 1).to(x) + ) x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 return x @@ -283,43 +365,50 @@ class VQModelInterface(VQModel): class AutoencoderKL(pl.LightningModule): - def __init__(self, - ddconfig, - lossconfig, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - ): + def __init__( + self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key='image', + colorize_nlabels=None, + monitor=None, + ): super().__init__() self.image_key = image_key self.encoder = Encoder(**ddconfig) self.decoder = Decoder(**ddconfig) self.loss = instantiate_from_config(lossconfig) - assert ddconfig["double_z"] - self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + assert ddconfig['double_z'] + self.quant_conv = torch.nn.Conv2d( + 2 * ddconfig['z_channels'], 2 * embed_dim, 1 + ) + self.post_quant_conv = torch.nn.Conv2d( + embed_dim, ddconfig['z_channels'], 1 + ) self.embed_dim = embed_dim if colorize_nlabels is not None: - assert type(colorize_nlabels)==int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + assert type(colorize_nlabels) == int + self.register_buffer( + 'colorize', torch.randn(3, colorize_nlabels, 1, 1) + ) if monitor is not None: self.monitor = monitor if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] + sd = torch.load(path, map_location='cpu')['state_dict'] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) + print('Deleting key {} from state_dict.'.format(k)) del sd[k] self.load_state_dict(sd, strict=False) - print(f"Restored from {path}") + print(f'Restored from {path}') def encode(self, x): h = self.encoder(x) @@ -345,7 +434,11 @@ class AutoencoderKL(pl.LightningModule): x = batch[k] if len(x.shape) == 3: x = x[..., None] - x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + x = ( + x.permute(0, 3, 1, 2) + .to(memory_format=torch.contiguous_format) + .float() + ) return x def training_step(self, batch, batch_idx, optimizer_idx): @@ -354,44 +447,102 @@ class AutoencoderKL(pl.LightningModule): if optimizer_idx == 0: # train encoder+decoder+logvar - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split='train', + ) + self.log( + 'aeloss', + aeloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict( + log_dict_ae, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=False, + ) return aeloss if optimizer_idx == 1: # train the discriminator - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split='train', + ) - 
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + self.log( + 'discloss', + discloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict( + log_dict_disc, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=False, + ) return discloss def validation_step(self, batch, batch_idx): inputs = self.get_input(batch, self.image_key) reconstructions, posterior = self(inputs) - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, - last_layer=self.get_last_layer(), split="val") + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split='val', + ) - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, - last_layer=self.get_last_layer(), split="val") + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split='val', + ) - self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log('val/rec_loss', log_dict_ae['val/rec_loss']) self.log_dict(log_dict_ae) self.log_dict(log_dict_disc) return self.log_dict def configure_optimizers(self): lr = self.learning_rate - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) + opt_ae = torch.optim.Adam( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=lr, + betas=(0.5, 0.9), + ) + opt_disc = torch.optim.Adam( + self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) + ) return [opt_ae, opt_disc], [] def get_last_layer(self): @@ -409,17 +560,19 @@ class AutoencoderKL(pl.LightningModule): assert xrec.shape[1] > 3 x = self.to_rgb(x) xrec = self.to_rgb(xrec) - log["samples"] = self.decode(torch.randn_like(posterior.sample())) - log["reconstructions"] = xrec - log["inputs"] = x + log['samples'] = self.decode(torch.randn_like(posterior.sample())) + log['reconstructions'] = xrec + log['inputs'] = x return log def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + assert self.image_key == 'segmentation' + if not hasattr(self, 'colorize'): + self.register_buffer( + 'colorize', torch.randn(3, x.shape[1], 1, 1).to(x) + ) x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 return x diff --git a/ldm/models/diffusion/classifier.py b/ldm/models/diffusion/classifier.py index 67e98b9d8f..be0d8c1919 100644 --- a/ldm/models/diffusion/classifier.py +++ b/ldm/models/diffusion/classifier.py @@ -10,13 +10,13 @@ from einops import rearrange from glob import glob from natsort import natsorted -from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.modules.diffusionmodules.openaimodel import ( + EncoderUNetModel, + UNetModel, +) from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config -__models__ = { - 'class_label': EncoderUNetModel, - 'segmentation': UNetModel -} +__models__ = {'class_label': EncoderUNetModel, 'segmentation': UNetModel} def disabled_train(self, mode=True): @@ -26,37 +26,49 @@ def disabled_train(self, mode=True): class NoisyLatentImageClassifier(pl.LightningModule): - - def __init__(self, - diffusion_path, - num_classes, - ckpt_path=None, - pool='attention', - label_key=None, - diffusion_ckpt_path=None, - scheduler_config=None, - weight_decay=1.e-2, - log_steps=10, - monitor='val/loss', - *args, - **kwargs): + def __init__( + self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.0e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs, + ): super().__init__(*args, **kwargs) self.num_classes = num_classes # get latest config of diffusion model - diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + diffusion_config = natsorted( + glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')) + )[-1] self.diffusion_config = OmegaConf.load(diffusion_config).model self.diffusion_config.params.ckpt_path = diffusion_ckpt_path self.load_diffusion() self.monitor = monitor - self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 - self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.numd = ( + self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + ) + self.log_time_interval = ( + self.diffusion_model.num_timesteps // log_steps + ) self.log_steps = log_steps - self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + self.label_key = ( + label_key + if not hasattr(self.diffusion_model, 'cond_stage_key') else self.diffusion_model.cond_stage_key + ) - assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + assert ( + self.label_key is not None + ), 'label_key neither in diffusion model nor in model.params' if self.label_key not in __models__: raise NotImplementedError() @@ -68,22 +80,27 @@ class NoisyLatentImageClassifier(pl.LightningModule): self.weight_decay = weight_decay def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): - sd = torch.load(path, map_location="cpu") - if "state_dict" in list(sd.keys()): - sd = sd["state_dict"] + sd = torch.load(path, map_location='cpu') + if 'state_dict' in list(sd.keys()): + sd = sd['state_dict'] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) + print('Deleting key {} from state_dict.'.format(k)) del sd[k] - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) - print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} 
unexpected keys") + missing, unexpected = ( + self.load_state_dict(sd, strict=False) + if not only_model + else self.model.load_state_dict(sd, strict=False) + ) + print( + f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys' + ) if len(missing) > 0: - print(f"Missing Keys: {missing}") + print(f'Missing Keys: {missing}') if len(unexpected) > 0: - print(f"Unexpected Keys: {unexpected}") + print(f'Unexpected Keys: {unexpected}') def load_diffusion(self): model = instantiate_from_config(self.diffusion_config) @@ -93,17 +110,25 @@ class NoisyLatentImageClassifier(pl.LightningModule): param.requires_grad = False def load_classifier(self, ckpt_path, pool): - model_config = deepcopy(self.diffusion_config.params.unet_config.params) - model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config = deepcopy( + self.diffusion_config.params.unet_config.params + ) + model_config.in_channels = ( + self.diffusion_config.params.unet_config.params.out_channels + ) model_config.out_channels = self.num_classes if self.label_key == 'class_label': model_config.pool = pool self.model = __models__[self.label_key](**model_config) if ckpt_path is not None: - print('#####################################################################') + print( + '#####################################################################' + ) print(f'load from ckpt "{ckpt_path}"') - print('#####################################################################') + print( + '#####################################################################' + ) self.init_from_ckpt(ckpt_path) @torch.no_grad() @@ -111,11 +136,19 @@ class NoisyLatentImageClassifier(pl.LightningModule): noise = default(noise, lambda: torch.randn_like(x)) continuous_sqrt_alpha_cumprod = None if self.diffusion_model.use_continuous_noise: - continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + continuous_sqrt_alpha_cumprod = ( + self.diffusion_model.sample_continuous_noise_level( + x.shape[0], t + 1 + ) + ) # todo: make sure t+1 is correct here - return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, - continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + return self.diffusion_model.q_sample( + x_start=x, + t=t, + noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod, + ) def forward(self, x_noisy, t, *args, **kwargs): return self.model(x_noisy, t) @@ -141,17 +174,21 @@ class NoisyLatentImageClassifier(pl.LightningModule): targets = rearrange(targets, 'b h w c -> b c h w') for down in range(self.numd): h, w = targets.shape[-2:] - targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + targets = F.interpolate( + targets, size=(h // 2, w // 2), mode='nearest' + ) # targets = rearrange(targets,'b c h w -> b h w c') return targets - def compute_top_k(self, logits, labels, k, reduction="mean"): + def compute_top_k(self, logits, labels, k, reduction='mean'): _, top_ks = torch.topk(logits, k, dim=1) - if reduction == "mean": - return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() - elif reduction == "none": + if reduction == 'mean': + return ( + (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + ) + elif reduction == 'none': return (top_ks == labels[:, None]).float().sum(dim=-1) def on_train_epoch_start(self): @@ -162,29 +199,59 @@ class NoisyLatentImageClassifier(pl.LightningModule): def write_logs(self, loss, logits, targets): log_prefix = 'train' if self.training else 
'val' log = {} - log[f"{log_prefix}/loss"] = loss.mean() - log[f"{log_prefix}/acc@1"] = self.compute_top_k( - logits, targets, k=1, reduction="mean" + log[f'{log_prefix}/loss'] = loss.mean() + log[f'{log_prefix}/acc@1'] = self.compute_top_k( + logits, targets, k=1, reduction='mean' ) - log[f"{log_prefix}/acc@5"] = self.compute_top_k( - logits, targets, k=5, reduction="mean" + log[f'{log_prefix}/acc@5'] = self.compute_top_k( + logits, targets, k=5, reduction='mean' ) - self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) - self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) - self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + self.log_dict( + log, + prog_bar=False, + logger=True, + on_step=self.training, + on_epoch=True, + ) + self.log( + 'loss', log[f'{log_prefix}/loss'], prog_bar=True, logger=False + ) + self.log( + 'global_step', + self.global_step, + logger=False, + on_epoch=False, + prog_bar=True, + ) lr = self.optimizers().param_groups[0]['lr'] - self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + self.log( + 'lr_abs', + lr, + on_step=True, + logger=True, + on_epoch=False, + prog_bar=True, + ) def shared_step(self, batch, t=None): - x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + x, *_ = self.diffusion_model.get_input( + batch, k=self.diffusion_model.first_stage_key + ) targets = self.get_conditioning(batch) if targets.dim() == 4: targets = targets.argmax(dim=1) if t is None: - t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + t = torch.randint( + 0, + self.diffusion_model.num_timesteps, + (x.shape[0],), + device=self.device, + ).long() else: - t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + t = torch.full( + size=(x.shape[0],), fill_value=t, device=self.device + ).long() x_noisy = self.get_x_noisy(x, t) logits = self(x_noisy, t) @@ -200,8 +267,14 @@ class NoisyLatentImageClassifier(pl.LightningModule): return loss def reset_noise_accs(self): - self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in - range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + self.noisy_acc = { + t: {'acc@1': [], 'acc@5': []} + for t in range( + 0, + self.diffusion_model.num_timesteps, + self.diffusion_model.log_every_t, + ) + } def on_validation_start(self): self.reset_noise_accs() @@ -212,24 +285,35 @@ class NoisyLatentImageClassifier(pl.LightningModule): for t in self.noisy_acc: _, logits, _, targets = self.shared_step(batch, t) - self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) - self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + self.noisy_acc[t]['acc@1'].append( + self.compute_top_k(logits, targets, k=1, reduction='mean') + ) + self.noisy_acc[t]['acc@5'].append( + self.compute_top_k(logits, targets, k=5, reduction='mean') + ) return loss def configure_optimizers(self): - optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + optimizer = AdamW( + self.model.parameters(), + lr=self.learning_rate, + weight_decay=self.weight_decay, + ) if self.use_scheduler: scheduler = instantiate_from_config(self.scheduler_config) - print("Setting up LambdaLR scheduler...") + print('Setting up LambdaLR scheduler...') scheduler = [ { - 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'scheduler': 
LambdaLR( + optimizer, lr_lambda=scheduler.schedule + ), 'interval': 'step', - 'frequency': 1 - }] + 'frequency': 1, + } + ] return [optimizer], scheduler return optimizer @@ -243,7 +327,7 @@ class NoisyLatentImageClassifier(pl.LightningModule): y = self.get_conditioning(batch) if self.label_key == 'class_label': - y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + y = log_txt_as_img((x.shape[2], x.shape[3]), batch['human_label']) log['labels'] = y if ismap(y): @@ -256,10 +340,14 @@ class NoisyLatentImageClassifier(pl.LightningModule): log[f'inputs@t{current_time}'] = x_noisy - pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = F.one_hot( + logits.argmax(dim=1), num_classes=self.num_classes + ) pred = rearrange(pred, 'b h w c -> b c h w') - log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb( + pred + ) for key in log: log[key] = log[key][:N] diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index ddf786b5a8..3d9086eb1d 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -5,12 +5,16 @@ import numpy as np from tqdm import tqdm from functools import partial -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ - extract_into_tensor +from ldm.modules.diffusionmodules.util import ( + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, + extract_into_tensor, +) class DDIMSampler(object): - def __init__(self, model, schedule="linear", device="cuda", **kwargs): + def __init__(self, model, schedule='linear', device='cuda', **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps @@ -23,70 +27,122 @@ class DDIMSampler(object): attr = attr.to(torch.device(self.device)) setattr(self, name, attr) - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + def make_schedule( + self, + ddim_num_steps, + ddim_discretize='uniform', + ddim_eta=0.0, + verbose=True, + ): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), 'alphas have to be defined for each timestep' + to_torch = ( + lambda x: x.clone() + .detach() + .to(torch.float32) + .to(self.model.device) + ) self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + self.register_buffer( + 'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev) + ) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. 
- alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + self.register_buffer( + 'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + 'sqrt_one_minus_alphas_cumprod', + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + 'log_one_minus_alphas_cumprod', + to_torch(np.log(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + 'sqrt_recip_alphas_cumprod', + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())), + ) + self.register_buffer( + 'sqrt_recipm1_alphas_cumprod', + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) + ( + ddim_sigmas, + ddim_alphas, + ddim_alphas_prev, + ) = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) self.register_buffer('ddim_sigmas', ddim_sigmas) self.register_buffer('ddim_alphas', ddim_alphas) self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + self.register_buffer( + 'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas) + ) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + 'ddim_sigmas_for_original_num_steps', + sigmas_for_original_sampling_steps, + ) @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ **kwargs, + ): if conditioning is not None: if isinstance(conditioning, dict): cbs = conditioning[list(conditioning.keys())[0]].shape[0] if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + print( + f'Warning: Got {cbs} conditionings but batch-size is {batch_size}' + ) else: if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + print( + f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}' + ) self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) # sampling @@ -94,30 +150,47 @@ class DDIMSampler(object): size = (batch_size, C, H, W) print(f'Data shape for DDIM sampling is {size}, eta {eta}') - samples, intermediates = self.ddim_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - ) + samples, intermediates = self.ddim_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) return samples, intermediates @torch.no_grad() - def ddim_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None,): + def ddim_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + ): device = self.model.betas.device b = shape[0] if x_T is None: @@ -126,17 +199,38 @@ class DDIMSampler(object): img = x_T if timesteps is None: - timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) elif timesteps is not None and not ddim_use_original_steps: - subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) timesteps = self.ddim_timesteps[:subset_end] intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - print(f"Running DDIM Sampling with {total_steps} timesteps") + time_range = ( + reversed(range(0, timesteps)) + if 
ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = ( + timesteps if ddim_use_original_steps else timesteps.shape[0] + ) + print(f'Running DDIM Sampling with {total_steps} timesteps') - iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps, dynamic_ncols=True) + iterator = tqdm( + time_range, + desc='DDIM Sampler', + total=total_steps, + dynamic_ncols=True, + ) for i, step in enumerate(iterator): index = total_steps - i - 1 @@ -144,18 +238,30 @@ class DDIMSampler(object): if mask is not None: assert x0 is not None - img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. - mask) * img + img_orig = self.model.q_sample( + x0, ts + ) # TODO: deterministic forward pass? + img = img_orig * mask + (1.0 - mask) * img - outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) img, pred_x0 = outs - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: intermediates['x_inter'].append(img) @@ -164,42 +270,82 @@ class DDIMSampler(object): return img, intermediates @torch.no_grad() - def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None): + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + ): b, *_, device = *x.shape, x.device - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + if ( + unconditional_conditioning is None + or unconditional_guidance_scale == 1.0 + ): e_t = self.model.apply_model(x, t, c) else: x_in = torch.cat([x] * 2) t_in = torch.cat([t] * 2) c_in = torch.cat([unconditional_conditioning, c]) e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + e_t = e_t_uncond + unconditional_guidance_scale * ( + e_t - e_t_uncond + ) if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + assert self.model.parameterization == 'eps' + e_t = score_corrector.modify_score( + self.model, e_t, x, t, c, **corrector_kwargs + ) - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - 
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + alphas = ( + self.model.alphas_cumprod + if use_original_steps + else self.ddim_alphas + ) + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) # select parameters corresponding to the currently considered timestep a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) # current prediction for x_0 pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() if quantize_denoised: pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = ( + sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + ) + if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise return x_prev, pred_x0 @@ -217,26 +363,51 @@ class DDIMSampler(object): if noise is None: noise = torch.randn_like(x0) - return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + - extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) + * noise + ) @torch.no_grad() - def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - use_original_steps=False): + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + ): - timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = ( + np.arange(self.ddpm_num_timesteps) + if use_original_steps + else self.ddim_timesteps + ) timesteps = timesteps[:t_start] time_range = np.flip(timesteps) total_steps = timesteps.shape[0] - print(f"Running DDIM Sampling with {total_steps} timesteps") + print(f'Running DDIM Sampling with {total_steps} timesteps') iterator = tqdm(time_range, desc='Decoding image', total=total_steps) x_dec = x_latent for i, step in enumerate(iterator): index = total_steps - i - 1 - ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) - x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) + ts = torch.full( + (x_latent.shape[0],), + step, + device=x_latent.device, + dtype=torch.long, + ) + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + 
use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) return x_dec diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index d5f74a0fbe..ccfffa9b9b 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -21,17 +21,39 @@ from torchvision.utils import make_grid from pytorch_lightning.utilities.distributed import rank_zero_only import urllib -from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.util import ( + log_txt_as_img, + exists, + default, + ismap, + isimage, + mean_flat, + count_params, + instantiate_from_config, +) from ldm.modules.ema import LitEma -from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution -from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL -from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.modules.distributions.distributions import ( + normal_kl, + DiagonalGaussianDistribution, +) +from ldm.models.autoencoder import ( + VQModelInterface, + IdentityFirstStage, + AutoencoderKL, +) +from ldm.modules.diffusionmodules.util import ( + make_beta_schedule, + extract_into_tensor, + noise_like, +) from ldm.models.diffusion.ddim import DDIMSampler -__conditioning_keys__ = {'concat': 'c_concat', - 'crossattn': 'c_crossattn', - 'adm': 'y'} +__conditioning_keys__ = { + 'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y', +} def disabled_train(self, mode=True): @@ -46,40 +68,46 @@ def uniform_on_device(r1, r2, shape, device): class DDPM(pl.LightningModule): # classic DDPM with Gaussian diffusion, in image space - def __init__(self, - unet_config, - timesteps=1000, - beta_schedule="linear", - loss_type="l2", - ckpt_path=None, - ignore_keys=[], - load_only_unet=False, - monitor="val/loss", - use_ema=True, - first_stage_key="image", - image_size=256, - channels=3, - log_every_t=100, - clip_denoised=True, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - given_betas=None, - original_elbo_weight=0., - embedding_reg_weight=0., - v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta - l_simple_weight=1., - conditioning_key=None, - parameterization="eps", # all assuming fixed variance schedules - scheduler_config=None, - use_positional_encodings=False, - learn_logvar=False, - logvar_init=0., - ): + def __init__( + self, + unet_config, + timesteps=1000, + beta_schedule='linear', + loss_type='l2', + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor='val/loss', + use_ema=True, + first_stage_key='image', + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0.0, + embedding_reg_weight=0.0, + v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1.0, + conditioning_key=None, + parameterization='eps', # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0.0, + ): super().__init__() - assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + assert parameterization in [ + 'eps', + 'x0', + ], 'currently only supporting "eps" and "x0"' self.parameterization = parameterization 
- print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + print( + f'{self.__class__.__name__}: Running in {self.parameterization}-prediction mode' + ) self.cond_stage_model = None self.clip_denoised = clip_denoised self.log_every_t = log_every_t @@ -92,7 +120,7 @@ class DDPM(pl.LightningModule): self.use_ema = use_ema if self.use_ema: self.model_ema = LitEma(self.model) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.') self.use_scheduler = scheduler_config is not None if self.use_scheduler: @@ -106,68 +134,131 @@ class DDPM(pl.LightningModule): if monitor is not None: self.monitor = monitor if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + self.init_from_ckpt( + ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet + ) - self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, - linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) self.loss_type = loss_type self.learn_logvar = learn_logvar - self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + self.logvar = torch.full( + fill_value=logvar_init, size=(self.num_timesteps,) + ) if self.learn_logvar: self.logvar = nn.Parameter(self.logvar, requires_grad=True) - - def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + def register_schedule( + self, + given_betas=None, + beta_schedule='linear', + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): if exists(given_betas): betas = given_betas else: - betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, - cosine_s=cosine_s) - alphas = 1. - betas + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) - timesteps, = betas.shape + (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end - assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), 'alphas have to be defined for each timestep' to_torch = partial(torch.tensor, dtype=torch.float32) self.register_buffer('betas', to_torch(betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + self.register_buffer( + 'alphas_cumprod_prev', to_torch(alphas_cumprod_prev) + ) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + self.register_buffer( + 'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)) + ) + self.register_buffer( + 'sqrt_one_minus_alphas_cumprod', + to_torch(np.sqrt(1.0 - alphas_cumprod)), + ) + self.register_buffer( + 'log_one_minus_alphas_cumprod', + to_torch(np.log(1.0 - alphas_cumprod)), + ) + self.register_buffer( + 'sqrt_recip_alphas_cumprod', + to_torch(np.sqrt(1.0 / alphas_cumprod)), + ) + self.register_buffer( + 'sqrt_recipm1_alphas_cumprod', + to_torch(np.sqrt(1.0 / alphas_cumprod - 1)), + ) # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( - 1. - alphas_cumprod) + self.v_posterior * betas + posterior_variance = (1 - self.v_posterior) * betas * ( + 1.0 - alphas_cumprod_prev + ) / (1.0 - alphas_cumprod) + self.v_posterior * betas # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.register_buffer('posterior_variance', to_torch(posterior_variance)) + self.register_buffer( + 'posterior_variance', to_torch(posterior_variance) + ) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) - self.register_buffer('posterior_mean_coef1', to_torch( - betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) - self.register_buffer('posterior_mean_coef2', to_torch( - (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + self.register_buffer( + 'posterior_log_variance_clipped', + to_torch(np.log(np.maximum(posterior_variance, 1e-20))), + ) + self.register_buffer( + 'posterior_mean_coef1', + to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) + ), + ) + self.register_buffer( + 'posterior_mean_coef2', + to_torch( + (1.0 - alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - alphas_cumprod) + ), + ) - if self.parameterization == "eps": - lvlb_weights = self.betas ** 2 / ( - 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) - elif self.parameterization == "x0": - lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. 
* 1 - torch.Tensor(alphas_cumprod)) + if self.parameterization == 'eps': + lvlb_weights = self.betas**2 / ( + 2 + * self.posterior_variance + * to_torch(alphas) + * (1 - self.alphas_cumprod) + ) + elif self.parameterization == 'x0': + lvlb_weights = ( + 0.5 + * np.sqrt(torch.Tensor(alphas_cumprod)) + / (2.0 * 1 - torch.Tensor(alphas_cumprod)) + ) else: - raise NotImplementedError("mu not supported") + raise NotImplementedError('mu not supported') # TODO how to choose this term lvlb_weights[0] = lvlb_weights[1] self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) @@ -179,32 +270,37 @@ class DDPM(pl.LightningModule): self.model_ema.store(self.model.parameters()) self.model_ema.copy_to(self.model) if context is not None: - print(f"{context}: Switched to EMA weights") + print(f'{context}: Switched to EMA weights') try: yield None finally: if self.use_ema: self.model_ema.restore(self.model.parameters()) if context is not None: - print(f"{context}: Restored training weights") + print(f'{context}: Restored training weights') def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): - sd = torch.load(path, map_location="cpu") - if "state_dict" in list(sd.keys()): - sd = sd["state_dict"] + sd = torch.load(path, map_location='cpu') + if 'state_dict' in list(sd.keys()): + sd = sd['state_dict'] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) + print('Deleting key {} from state_dict.'.format(k)) del sd[k] - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) - print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + missing, unexpected = ( + self.load_state_dict(sd, strict=False) + if not only_model + else self.model.load_state_dict(sd, strict=False) + ) + print( + f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys' + ) if len(missing) > 0: - print(f"Missing Keys: {missing}") + print(f'Missing Keys: {missing}') if len(unexpected) > 0: - print(f"Unexpected Keys: {unexpected}") + print(f'Unexpected Keys: {unexpected}') def q_mean_variance(self, x_start, t): """ @@ -213,46 +309,78 @@ class DDPM(pl.LightningModule): :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
""" - mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) - variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) - log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + mean = ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) + * x_start + ) + variance = extract_into_tensor( + 1.0 - self.alphas_cumprod, t, x_start.shape + ) + log_variance = extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) return mean, variance, log_variance def predict_start_from_noise(self, x_t, t, noise): return ( - extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) + * x_t + - extract_into_tensor( + self.sqrt_recipm1_alphas_cumprod, t, x_t.shape + ) + * noise ) def q_posterior(self, x_start, x_t, t): posterior_mean = ( - extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + - extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) + * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) + * x_t + ) + posterior_variance = extract_into_tensor( + self.posterior_variance, t, x_t.shape + ) + posterior_log_variance_clipped = extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + return ( + posterior_mean, + posterior_variance, + posterior_log_variance_clipped, ) - posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) - return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance(self, x, t, clip_denoised: bool): model_out = self.model(x, t) - if self.parameterization == "eps": + if self.parameterization == 'eps': x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) - elif self.parameterization == "x0": + elif self.parameterization == 'x0': x_recon = model_out if clip_denoised: - x_recon.clamp_(-1., 1.) 
+ x_recon.clamp_(-1.0, 1.0) - model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + ( + model_mean, + posterior_variance, + posterior_log_variance, + ) = self.q_posterior(x_start=x_recon, x_t=x, t=t) return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised + ) noise = noise_like(x.shape, device, repeat_noise) # no noise when t == 0 - nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + nonzero_mask = (1 - (t == 0).float()).reshape( + b, *((1,) * (len(x.shape) - 1)) + ) + return ( + model_mean + + nonzero_mask * (0.5 * model_log_variance).exp() * noise + ) @torch.no_grad() def p_sample_loop(self, shape, return_intermediates=False): @@ -260,9 +388,17 @@ class DDPM(pl.LightningModule): b = shape[0] img = torch.randn(shape, device=device) intermediates = [img] - for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps, dynamic_ncols=True): - img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), - clip_denoised=self.clip_denoised) + for i in tqdm( + reversed(range(0, self.num_timesteps)), + desc='Sampling t', + total=self.num_timesteps, + dynamic_ncols=True, + ): + img = self.p_sample( + img, + torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised, + ) if i % self.log_every_t == 0 or i == self.num_timesteps - 1: intermediates.append(img) if return_intermediates: @@ -273,13 +409,21 @@ class DDPM(pl.LightningModule): def sample(self, batch_size=16, return_intermediates=False): image_size = self.image_size channels = self.channels - return self.p_sample_loop((batch_size, channels, image_size, image_size), - return_intermediates=return_intermediates) + return self.p_sample_loop( + (batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates, + ) def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) + * x_start + + extract_into_tensor( + self.sqrt_one_minus_alphas_cumprod, t, x_start.shape + ) + * noise + ) def get_loss(self, pred, target, mean=True): if self.loss_type == 'l1': @@ -290,7 +434,9 @@ class DDPM(pl.LightningModule): if mean: loss = torch.nn.functional.mse_loss(target, pred) else: - loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + loss = torch.nn.functional.mse_loss( + target, pred, reduction='none' + ) else: raise NotImplementedError("unknown loss type '{loss_type}'") @@ -302,12 +448,14 @@ class DDPM(pl.LightningModule): model_out = self.model(x_noisy, t) loss_dict = {} - if self.parameterization == "eps": + if self.parameterization == 'eps': target = noise - elif self.parameterization == "x0": + elif self.parameterization == 'x0': target = x_start else: - raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + raise NotImplementedError( + 
f'Paramterization {self.parameterization} not yet supported' + ) loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) @@ -328,7 +476,9 @@ class DDPM(pl.LightningModule): def forward(self, x, *args, **kwargs): # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' - t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + t = torch.randint( + 0, self.num_timesteps, (x.shape[0],), device=self.device + ).long() return self.p_losses(x, t, *args, **kwargs) def get_input(self, batch, k): @@ -347,15 +497,29 @@ class DDPM(pl.LightningModule): def training_step(self, batch, batch_idx): loss, loss_dict = self.shared_step(batch) - self.log_dict(loss_dict, prog_bar=True, - logger=True, on_step=True, on_epoch=True) + self.log_dict( + loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True + ) - self.log("global_step", self.global_step, - prog_bar=True, logger=True, on_step=True, on_epoch=False) + self.log( + 'global_step', + self.global_step, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False, + ) if self.use_scheduler: lr = self.optimizers().param_groups[0]['lr'] - self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + self.log( + 'lr_abs', + lr, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False, + ) return loss @@ -364,9 +528,23 @@ class DDPM(pl.LightningModule): _, loss_dict_no_ema = self.shared_step(batch) with self.ema_scope(): _, loss_dict_ema = self.shared_step(batch) - loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} - self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) - self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + loss_dict_ema = { + key + '_ema': loss_dict_ema[key] for key in loss_dict_ema + } + self.log_dict( + loss_dict_no_ema, + prog_bar=False, + logger=True, + on_step=False, + on_epoch=True, + ) + self.log_dict( + loss_dict_ema, + prog_bar=False, + logger=True, + on_step=False, + on_epoch=True, + ) def on_train_batch_end(self, *args, **kwargs): if self.use_ema: @@ -380,13 +558,15 @@ class DDPM(pl.LightningModule): return denoise_grid @torch.no_grad() - def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + def log_images( + self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs + ): log = dict() x = self.get_input(batch, self.first_stage_key) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) x = x.to(self.device)[:N] - log["inputs"] = x + log['inputs'] = x # get diffusion row diffusion_row = list() @@ -400,15 +580,17 @@ class DDPM(pl.LightningModule): x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) diffusion_row.append(x_noisy) - log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + log['diffusion_row'] = self._get_rows_from_list(diffusion_row) if sample: # get denoise row - with self.ema_scope("Plotting"): - samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + with self.ema_scope('Plotting'): + samples, denoise_row = self.sample( + batch_size=N, return_intermediates=True + ) - log["samples"] = samples - log["denoise_row"] = self._get_rows_from_list(denoise_row) + log['samples'] = samples + log['denoise_row'] = self._get_rows_from_list(denoise_row) if return_keys: if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: @@ -428,19 +610,23 @@ class 
DDPM(pl.LightningModule): class LatentDiffusion(DDPM): """main class""" - def __init__(self, - first_stage_config, - cond_stage_config, - personalization_config, - num_timesteps_cond=None, - cond_stage_key="image", - cond_stage_trainable=False, - concat_mode=True, - cond_stage_forward=None, - conditioning_key=None, - scale_factor=1.0, - scale_by_std=False, - *args, **kwargs): + + def __init__( + self, + first_stage_config, + cond_stage_config, + personalization_config, + num_timesteps_cond=None, + cond_stage_key='image', + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + *args, + **kwargs, + ): self.num_timesteps_cond = default(num_timesteps_cond, 1) self.scale_by_std = scale_by_std @@ -450,15 +636,17 @@ class LatentDiffusion(DDPM): conditioning_key = 'concat' if concat_mode else 'crossattn' if cond_stage_config == '__is_unconditional__': conditioning_key = None - ckpt_path = kwargs.pop("ckpt_path", None) - ignore_keys = kwargs.pop("ignore_keys", []) + ckpt_path = kwargs.pop('ckpt_path', None) + ignore_keys = kwargs.pop('ignore_keys', []) super().__init__(conditioning_key=conditioning_key, *args, **kwargs) self.concat_mode = concat_mode self.cond_stage_trainable = cond_stage_trainable self.cond_stage_key = cond_stage_key try: - self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + self.num_downs = ( + len(first_stage_config.params.ddconfig.ch_mult) - 1 + ) except: self.num_downs = 0 if not scale_by_std: @@ -470,7 +658,7 @@ class LatentDiffusion(DDPM): self.cond_stage_forward = cond_stage_forward self.clip_denoised = False - self.bbox_tokenizer = None + self.bbox_tokenizer = None self.restarted_from_ckpt = False if ckpt_path is not None: @@ -485,8 +673,10 @@ class LatentDiffusion(DDPM): self.model.train = disabled_train for param in self.model.parameters(): param.requires_grad = False - - self.embedding_manager = self.instantiate_embedding_manager(personalization_config, self.cond_stage_model) + + self.embedding_manager = self.instantiate_embedding_manager( + personalization_config, self.cond_stage_model + ) self.emb_ckpt_counter = 0 @@ -496,32 +686,61 @@ class LatentDiffusion(DDPM): for param in self.embedding_manager.embedding_parameters(): param.requires_grad = True - def make_cond_schedule(self, ): - self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) - ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() - self.cond_ids[:self.num_timesteps_cond] = ids + def make_cond_schedule( + self, + ): + self.cond_ids = torch.full( + size=(self.num_timesteps,), + fill_value=self.num_timesteps - 1, + dtype=torch.long, + ) + ids = torch.round( + torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) + ).long() + self.cond_ids[: self.num_timesteps_cond] = ids @rank_zero_only @torch.no_grad() def on_train_batch_start(self, batch, batch_idx, dataloader_idx): # only for very first batch - if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: - assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + if ( + self.scale_by_std + and self.current_epoch == 0 + and self.global_step == 0 + and batch_idx == 0 + and not self.restarted_from_ckpt + ): + assert ( + self.scale_factor == 1.0 + ), 'rather not use custom rescaling and std-rescaling simultaneously' # set rescale weight to 
1./std of encodings - print("### USING STD-RESCALING ###") + print('### USING STD-RESCALING ###') x = super().get_input(batch, self.first_stage_key) x = x.to(self.device) encoder_posterior = self.encode_first_stage(x) z = self.get_first_stage_encoding(encoder_posterior).detach() del self.scale_factor - self.register_buffer('scale_factor', 1. / z.flatten().std()) - print(f"setting self.scale_factor to {self.scale_factor}") - print("### USING STD-RESCALING ###") + self.register_buffer('scale_factor', 1.0 / z.flatten().std()) + print(f'setting self.scale_factor to {self.scale_factor}') + print('### USING STD-RESCALING ###') - def register_schedule(self, - given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + def register_schedule( + self, + given_betas=None, + beta_schedule='linear', + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + super().register_schedule( + given_betas, + beta_schedule, + timesteps, + linear_start, + linear_end, + cosine_s, + ) self.shorten_cond_schedule = self.num_timesteps_cond > 1 if self.shorten_cond_schedule: @@ -536,11 +755,13 @@ class LatentDiffusion(DDPM): def instantiate_cond_stage(self, config): if not self.cond_stage_trainable: - if config == "__is_first_stage__": - print("Using first stage also as cond stage.") + if config == '__is_first_stage__': + print('Using first stage also as cond stage.') self.cond_stage_model = self.first_stage_model - elif config == "__is_unconditional__": - print(f"Training {self.__class__.__name__} as an unconditional model.") + elif config == '__is_unconditional__': + print( + f'Training {self.__class__.__name__} as an unconditional model.' + ) self.cond_stage_model = None # self.be_unconditional = True else: @@ -555,23 +776,32 @@ class LatentDiffusion(DDPM): try: model = instantiate_from_config(config) except urllib.error.URLError: - raise SystemExit("* Couldn't load a dependency. Try running scripts/preload_models.py from an internet-conected machine.") + raise SystemExit( + "* Couldn't load a dependency. Try running scripts/preload_models.py from an internet-conected machine." 
+ ) self.cond_stage_model = model - - + def instantiate_embedding_manager(self, config, embedder): model = instantiate_from_config(config, embedder=embedder) - if config.params.get("embedding_manager_ckpt", None): # do not load if missing OR empty string + if config.params.get( + 'embedding_manager_ckpt', None + ): # do not load if missing OR empty string model.load(config.params.embedding_manager_ckpt) - + return model - def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + def _get_denoise_row_from_list( + self, samples, desc='', force_no_decoder_quantization=False + ): denoise_row = [] for zd in tqdm(samples, desc=desc): - denoise_row.append(self.decode_first_stage(zd.to(self.device), - force_not_quantize=force_no_decoder_quantization)) + denoise_row.append( + self.decode_first_stage( + zd.to(self.device), + force_not_quantize=force_no_decoder_quantization, + ) + ) n_imgs_per_row = len(denoise_row) denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') @@ -585,13 +815,19 @@ class LatentDiffusion(DDPM): elif isinstance(encoder_posterior, torch.Tensor): z = encoder_posterior else: - raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) return self.scale_factor * z def get_learned_conditioning(self, c): if self.cond_stage_forward is None: - if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): - c = self.cond_stage_model.encode(c, embedding_manager=self.embedding_manager) + if hasattr(self.cond_stage_model, 'encode') and callable( + self.cond_stage_model.encode + ): + c = self.cond_stage_model.encode( + c, embedding_manager=self.embedding_manager + ) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() else: @@ -619,26 +855,37 @@ class LatentDiffusion(DDPM): arr = self.meshgrid(h, w) / lower_right_corner dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] - edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + edge_dist = torch.min( + torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1 + )[0] return edge_dist def get_weighting(self, h, w, Ly, Lx, device): weighting = self.delta_border(h, w) - weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], - self.split_input_params["clip_max_weight"], ) - weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + weighting = torch.clip( + weighting, + self.split_input_params['clip_min_weight'], + self.split_input_params['clip_max_weight'], + ) + weighting = ( + weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + ) - if self.split_input_params["tie_braker"]: + if self.split_input_params['tie_braker']: L_weighting = self.delta_border(Ly, Lx) - L_weighting = torch.clip(L_weighting, - self.split_input_params["clip_min_tie_weight"], - self.split_input_params["clip_max_tie_weight"]) + L_weighting = torch.clip( + L_weighting, + self.split_input_params['clip_min_tie_weight'], + self.split_input_params['clip_max_tie_weight'], + ) L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) weighting = weighting * L_weighting return weighting - def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + def get_fold_unfold( + self, x, 
kernel_size, stride, uf=1, df=1 + ): # todo load once not every time, shorten code """ :param x: img of size (bs, c, h, w) :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) @@ -650,40 +897,75 @@ class LatentDiffusion(DDPM): Lx = (w - kernel_size[1]) // stride[1] + 1 if uf == 1 and df == 1: - fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + fold_params = dict( + kernel_size=kernel_size, dilation=1, padding=0, stride=stride + ) unfold = torch.nn.Unfold(**fold_params) fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) - weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap - weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + weighting = self.get_weighting( + kernel_size[0], kernel_size[1], Ly, Lx, x.device + ).to(x.dtype) + normalization = fold(weighting).view( + 1, 1, h, w + ) # normalizes the overlap + weighting = weighting.view( + (1, 1, kernel_size[0], kernel_size[1], Ly * Lx) + ) elif uf > 1 and df == 1: - fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + fold_params = dict( + kernel_size=kernel_size, dilation=1, padding=0, stride=stride + ) unfold = torch.nn.Unfold(**fold_params) - fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), - dilation=1, padding=0, - stride=(stride[0] * uf, stride[1] * uf)) - fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + fold_params2 = dict( + kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, + padding=0, + stride=(stride[0] * uf, stride[1] * uf), + ) + fold = torch.nn.Fold( + output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2 + ) - weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap - weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + weighting = self.get_weighting( + kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device + ).to(x.dtype) + normalization = fold(weighting).view( + 1, 1, h * uf, w * uf + ) # normalizes the overlap + weighting = weighting.view( + (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx) + ) elif df > 1 and uf == 1: - fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + fold_params = dict( + kernel_size=kernel_size, dilation=1, padding=0, stride=stride + ) unfold = torch.nn.Unfold(**fold_params) - fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), - dilation=1, padding=0, - stride=(stride[0] // df, stride[1] // df)) - fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + fold_params2 = dict( + kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, + padding=0, + stride=(stride[0] // df, stride[1] // df), + ) + fold = torch.nn.Fold( + output_size=(x.shape[2] // df, x.shape[3] // df), + **fold_params2, + ) - weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap - weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + weighting = self.get_weighting( + kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device + ).to(x.dtype) + normalization = 
fold(weighting).view( + 1, 1, h // df, w // df + ) # normalizes the overlap + weighting = weighting.view( + (1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx) + ) else: raise NotImplementedError @@ -691,8 +973,16 @@ class LatentDiffusion(DDPM): return fold, unfold, normalization, weighting @torch.no_grad() - def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, - cond_key=None, return_original_cond=False, bs=None): + def get_input( + self, + batch, + k, + return_first_stage_outputs=False, + force_c_encode=False, + cond_key=None, + return_original_cond=False, + bs=None, + ): x = super().get_input(batch, k) if bs is not None: x = x[:bs] @@ -743,155 +1033,211 @@ class LatentDiffusion(DDPM): return out @torch.no_grad() - def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + def decode_first_stage( + self, z, predict_cids=False, force_not_quantize=False + ): if predict_cids: if z.dim() == 4: z = torch.argmax(z.exp(), dim=1).long() - z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = self.first_stage_model.quantize.get_codebook_entry( + z, shape=None + ) z = rearrange(z, 'b h w c -> b c h w').contiguous() - z = 1. / self.scale_factor * z + z = 1.0 / self.scale_factor * z - if hasattr(self, "split_input_params"): - if self.split_input_params["patch_distributed_vq"]: - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) - uf = self.split_input_params["vqf"] + if hasattr(self, 'split_input_params'): + if self.split_input_params['patch_distributed_vq']: + ks = self.split_input_params['ks'] # eg. (128, 128) + stride = self.split_input_params['stride'] # eg. (64, 64) + uf = self.split_input_params['vqf'] bs, nc, h, w = z.shape if ks[0] > h or ks[1] > w: ks = (min(ks[0], h), min(ks[1], w)) - print("reducing Kernel") + print('reducing Kernel') if stride[0] > h or stride[1] > w: stride = (min(stride[0], h), min(stride[1], w)) - print("reducing stride") + print('reducing stride') - fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + fold, unfold, normalization, weighting = self.get_fold_unfold( + z, ks, stride, uf=uf + ) z = unfold(z) # (bn, nc * prod(**ks), L) # 1. Reshape to img shape - z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z = z.view( + (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) + ) # (bn, nc, ks[0], ks[1], L ) # 2. apply model loop over last dim if isinstance(self.first_stage_model, VQModelInterface): - output_list = [self.first_stage_model.decode(z[:, :, :, :, i], - force_not_quantize=predict_cids or force_not_quantize) - for i in range(z.shape[-1])] + output_list = [ + self.first_stage_model.decode( + z[:, :, :, :, i], + force_not_quantize=predict_cids + or force_not_quantize, + ) + for i in range(z.shape[-1]) + ] else: - output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) - for i in range(z.shape[-1])] + output_list = [ + self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1]) + ] - o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = torch.stack( + output_list, axis=-1 + ) # # (bn, nc, ks[0], ks[1], L) o = o * weighting # Reverse 1. 
reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + o = o.view( + (o.shape[0], -1, o.shape[-1]) + ) # (bn, nc * ks[0] * ks[1], L) # stitch crops together decoded = fold(o) decoded = decoded / normalization # norm is shape (1, 1, h, w) return decoded else: if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + return self.first_stage_model.decode( + z, + force_not_quantize=predict_cids or force_not_quantize, + ) else: return self.first_stage_model.decode(z) else: if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + return self.first_stage_model.decode( + z, force_not_quantize=predict_cids or force_not_quantize + ) else: return self.first_stage_model.decode(z) # same as above but without decorator - def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + def differentiable_decode_first_stage( + self, z, predict_cids=False, force_not_quantize=False + ): if predict_cids: if z.dim() == 4: z = torch.argmax(z.exp(), dim=1).long() - z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = self.first_stage_model.quantize.get_codebook_entry( + z, shape=None + ) z = rearrange(z, 'b h w c -> b c h w').contiguous() - z = 1. / self.scale_factor * z + z = 1.0 / self.scale_factor * z - if hasattr(self, "split_input_params"): - if self.split_input_params["patch_distributed_vq"]: - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) - uf = self.split_input_params["vqf"] + if hasattr(self, 'split_input_params'): + if self.split_input_params['patch_distributed_vq']: + ks = self.split_input_params['ks'] # eg. (128, 128) + stride = self.split_input_params['stride'] # eg. (64, 64) + uf = self.split_input_params['vqf'] bs, nc, h, w = z.shape if ks[0] > h or ks[1] > w: ks = (min(ks[0], h), min(ks[1], w)) - print("reducing Kernel") + print('reducing Kernel') if stride[0] > h or stride[1] > w: stride = (min(stride[0], h), min(stride[1], w)) - print("reducing stride") + print('reducing stride') - fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + fold, unfold, normalization, weighting = self.get_fold_unfold( + z, ks, stride, uf=uf + ) z = unfold(z) # (bn, nc * prod(**ks), L) # 1. Reshape to img shape - z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z = z.view( + (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) + ) # (bn, nc, ks[0], ks[1], L ) # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): - output_list = [self.first_stage_model.decode(z[:, :, :, :, i], - force_not_quantize=predict_cids or force_not_quantize) - for i in range(z.shape[-1])] + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [ + self.first_stage_model.decode( + z[:, :, :, :, i], + force_not_quantize=predict_cids + or force_not_quantize, + ) + for i in range(z.shape[-1]) + ] else: - output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) - for i in range(z.shape[-1])] + output_list = [ + self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1]) + ] - o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = torch.stack( + output_list, axis=-1 + ) # # (bn, nc, ks[0], ks[1], L) o = o * weighting # Reverse 1. 
reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + o = o.view( + (o.shape[0], -1, o.shape[-1]) + ) # (bn, nc * ks[0] * ks[1], L) # stitch crops together decoded = fold(o) decoded = decoded / normalization # norm is shape (1, 1, h, w) return decoded else: if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + return self.first_stage_model.decode( + z, + force_not_quantize=predict_cids or force_not_quantize, + ) else: return self.first_stage_model.decode(z) else: if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + return self.first_stage_model.decode( + z, force_not_quantize=predict_cids or force_not_quantize + ) else: return self.first_stage_model.decode(z) @torch.no_grad() def encode_first_stage(self, x): - if hasattr(self, "split_input_params"): - if self.split_input_params["patch_distributed_vq"]: - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) - df = self.split_input_params["vqf"] + if hasattr(self, 'split_input_params'): + if self.split_input_params['patch_distributed_vq']: + ks = self.split_input_params['ks'] # eg. (128, 128) + stride = self.split_input_params['stride'] # eg. (64, 64) + df = self.split_input_params['vqf'] self.split_input_params['original_image_size'] = x.shape[-2:] bs, nc, h, w = x.shape if ks[0] > h or ks[1] > w: ks = (min(ks[0], h), min(ks[1], w)) - print("reducing Kernel") + print('reducing Kernel') if stride[0] > h or stride[1] > w: stride = (min(stride[0], h), min(stride[1], w)) - print("reducing stride") + print('reducing stride') - fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + fold, unfold, normalization, weighting = self.get_fold_unfold( + x, ks, stride, df=df + ) z = unfold(x) # (bn, nc * prod(**ks), L) # Reshape to img shape - z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z = z.view( + (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) + ) # (bn, nc, ks[0], ks[1], L ) - output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) - for i in range(z.shape[-1])] + output_list = [ + self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1]) + ] o = torch.stack(output_list, axis=-1) o = o * weighting # Reverse reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + o = o.view( + (o.shape[0], -1, o.shape[-1]) + ) # (bn, nc * ks[0] * ks[1], L) # stitch crops together decoded = fold(o) decoded = decoded / normalization @@ -908,18 +1254,24 @@ class LatentDiffusion(DDPM): return loss def forward(self, x, c, *args, **kwargs): - t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + t = torch.randint( + 0, self.num_timesteps, (x.shape[0],), device=self.device + ).long() if self.model.conditioning_key is not None: assert c is not None if self.cond_stage_trainable: c = self.get_learned_conditioning(c) if self.shorten_cond_schedule: # TODO: drop this option tc = self.cond_ids[t].to(self.device) - c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + c = self.q_sample( + x_start=c, t=tc, noise=torch.randn_like(c.float()) + ) return self.p_losses(x, c, t, *args, **kwargs) - def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def _rescale_annotations( 
+ self, bboxes, crop_coordinates + ): # TODO: move to dataset def rescale_bbox(bbox): x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) @@ -937,42 +1289,65 @@ class LatentDiffusion(DDPM): else: if not isinstance(cond, list): cond = [cond] - key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + key = ( + 'c_concat' + if self.model.conditioning_key == 'concat' + else 'c_crossattn' + ) cond = {key: cond} - if hasattr(self, "split_input_params"): - assert len(cond) == 1 # todo can only deal with one conditioning atm - assert not return_ids - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) + if hasattr(self, 'split_input_params'): + assert ( + len(cond) == 1 + ) # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params['ks'] # eg. (128, 128) + stride = self.split_input_params['stride'] # eg. (64, 64) h, w = x_noisy.shape[-2:] - fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + fold, unfold, normalization, weighting = self.get_fold_unfold( + x_noisy, ks, stride + ) z = unfold(x_noisy) # (bn, nc * prod(**ks), L) # Reshape to img shape - z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z = z.view( + (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) + ) # (bn, nc, ks[0], ks[1], L ) z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] - if self.cond_stage_key in ["image", "LR_image", "segmentation", - 'bbox_img'] and self.model.conditioning_key: # todo check for completeness + if ( + self.cond_stage_key + in ['image', 'LR_image', 'segmentation', 'bbox_img'] + and self.model.conditioning_key + ): # todo check for completeness c_key = next(iter(cond.keys())) # get key c = next(iter(cond.values())) # get value - assert (len(c) == 1) # todo extend to list with more than one elem + assert ( + len(c) == 1 + ) # todo extend to list with more than one elem c = c[0] # get element c = unfold(c) - c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + c = c.view( + (c.shape[0], -1, ks[0], ks[1], c.shape[-1]) + ) # (bn, nc, ks[0], ks[1], L ) - cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + cond_list = [ + {c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1]) + ] elif self.cond_stage_key == 'coordinates_bbox': - assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + assert ( + 'original_image_size' in self.split_input_params + ), 'BoudingBoxRescaling is missing original_image_size' # assuming padding of unfold is always 0 and its dilation is always 1 n_patches_per_row = int((w - ks[0]) / stride[0] + 1) - full_img_h, full_img_w = self.split_input_params['original_image_size'] + full_img_h, full_img_w = self.split_input_params[ + 'original_image_size' + ] # as we are operating on latents, we need the factor from the original image size to the # spatial latent size to properly rescale the crops for regenerating the bbox annotations num_downs = self.first_stage_model.encoder.num_resolutions - 1 @@ -980,47 +1355,84 @@ class LatentDiffusion(DDPM): # get top left postions of patches as conforming for the bbbox tokenizer, therefore we # need to rescale the tl patch coordinates to be in between (0,1) - tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, - rescale_latent * 
stride[1] * (patch_nr // n_patches_per_row) / full_img_h) - for patch_nr in range(z.shape[-1])] + tl_patch_coordinates = [ + ( + rescale_latent + * stride[0] + * (patch_nr % n_patches_per_row) + / full_img_w, + rescale_latent + * stride[1] + * (patch_nr // n_patches_per_row) + / full_img_h, + ) + for patch_nr in range(z.shape[-1]) + ] # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) - patch_limits = [(x_tl, y_tl, - rescale_latent * ks[0] / full_img_w, - rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + patch_limits = [ + ( + x_tl, + y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h, + ) + for x_tl, y_tl in tl_patch_coordinates + ] # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] # tokenize crop coordinates for the bounding boxes of the respective patches - patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) - for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + patch_limits_tknzd = [ + torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[ + None + ].to(self.device) + for bbox in patch_limits + ] # list of length l with tensors of shape (1, 2) print(patch_limits_tknzd[0].shape) # cut tknzd crop position from conditioning - assert isinstance(cond, dict), 'cond must be dict to be fed into model' + assert isinstance( + cond, dict + ), 'cond must be dict to be fed into model' cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) print(cut_cond.shape) - adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = torch.stack( + [ + torch.cat([cut_cond, p], dim=1) + for p in patch_limits_tknzd + ] + ) adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') print(adapted_cond.shape) adapted_cond = self.get_learned_conditioning(adapted_cond) print(adapted_cond.shape) - adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + adapted_cond = rearrange( + adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1] + ) print(adapted_cond.shape) cond_list = [{'c_crossattn': [e]} for e in adapted_cond] else: - cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + cond_list = [ + cond for i in range(z.shape[-1]) + ] # Todo make this more efficient # apply model by loop over crops - output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] - assert not isinstance(output_list[0], - tuple) # todo cant deal with multiple model outputs check this never happens + output_list = [ + self.model(z_list[i], t, **cond_list[i]) + for i in range(z.shape[-1]) + ] + assert not isinstance( + output_list[0], tuple + ) # todo cant deal with multiple model outputs check this never happens o = torch.stack(output_list, axis=-1) o = o * weighting # Reverse reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + o = o.view( + (o.shape[0], -1, o.shape[-1]) + ) # (bn, nc * ks[0] * ks[1], L) # stitch crops together x_recon = fold(o) / normalization @@ -1033,8 +1445,11 @@ class LatentDiffusion(DDPM): return x_recon def _predict_eps_from_xstart(self, x_t, t, pred_xstart): - return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) + * x_t + - 
pred_xstart + ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def _prior_bpd(self, x_start): """ @@ -1045,9 +1460,13 @@ class LatentDiffusion(DDPM): :return: a batch of [N] KL values (in bits), one per batch element. """ batch_size = x_start.shape[0] - t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + t = torch.tensor( + [self.num_timesteps - 1] * batch_size, device=x_start.device + ) qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) return mean_flat(kl_prior) / np.log(2.0) def p_losses(self, x_start, cond, t, noise=None): @@ -1058,14 +1477,16 @@ class LatentDiffusion(DDPM): loss_dict = {} prefix = 'train' if self.training else 'val' - if self.parameterization == "x0": + if self.parameterization == 'x0': target = x_start - elif self.parameterization == "eps": + elif self.parameterization == 'eps': target = noise else: raise NotImplementedError() - loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_simple = self.get_loss(model_output, target, mean=False).mean( + [1, 2, 3] + ) loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) logvar_t = self.logvar[t].to(self.device) @@ -1077,65 +1498,117 @@ class LatentDiffusion(DDPM): loss = self.l_simple_weight * loss.mean() - loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = self.get_loss(model_output, target, mean=False).mean( + dim=(1, 2, 3) + ) loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) - loss += (self.original_elbo_weight * loss_vlb) + loss += self.original_elbo_weight * loss_vlb loss_dict.update({f'{prefix}/loss': loss}) if self.embedding_reg_weight > 0: - loss_embedding_reg = self.embedding_manager.embedding_to_coarse_loss().mean() + loss_embedding_reg = ( + self.embedding_manager.embedding_to_coarse_loss().mean() + ) loss_dict.update({f'{prefix}/loss_emb_reg': loss_embedding_reg}) - loss += (self.embedding_reg_weight * loss_embedding_reg) + loss += self.embedding_reg_weight * loss_embedding_reg loss_dict.update({f'{prefix}/loss': loss}) return loss, loss_dict - def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, - return_x0=False, score_corrector=None, corrector_kwargs=None): + def p_mean_variance( + self, + x, + c, + t, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + ): t_in = t - model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + model_out = self.apply_model( + x, t_in, c, return_ids=return_codebook_ids + ) if score_corrector is not None: - assert self.parameterization == "eps" - model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + assert self.parameterization == 'eps' + model_out = score_corrector.modify_score( + self, model_out, x, t, c, **corrector_kwargs + ) if return_codebook_ids: model_out, logits = model_out - if self.parameterization == "eps": + if self.parameterization == 'eps': x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) - elif self.parameterization == "x0": + elif self.parameterization == 'x0': x_recon = model_out else: raise NotImplementedError() if clip_denoised: - x_recon.clamp_(-1., 1.) 
+ x_recon.clamp_(-1.0, 1.0) if quantize_denoised: - x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) - model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + x_recon, _, [_, _, indices] = self.first_stage_model.quantize( + x_recon + ) + ( + model_mean, + posterior_variance, + posterior_log_variance, + ) = self.q_posterior(x_start=x_recon, x_t=x, t=t) if return_codebook_ids: - return model_mean, posterior_variance, posterior_log_variance, logits + return ( + model_mean, + posterior_variance, + posterior_log_variance, + logits, + ) elif return_x0: - return model_mean, posterior_variance, posterior_log_variance, x_recon + return ( + model_mean, + posterior_variance, + posterior_log_variance, + x_recon, + ) else: return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() - def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, - return_codebook_ids=False, quantize_denoised=False, return_x0=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + def p_sample( + self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + ): b, *_, device = *x.shape, x.device - outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, - return_codebook_ids=return_codebook_ids, - quantize_denoised=quantize_denoised, - return_x0=return_x0, - score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + outputs = self.p_mean_variance( + x=x, + c=c, + t=t, + clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) if return_codebook_ids: - raise DeprecationWarning("Support dropped.") + raise DeprecationWarning('Support dropped.') model_mean, _, model_log_variance, logits = outputs elif return_x0: model_mean, _, model_log_variance, x0 = outputs @@ -1143,23 +1616,49 @@ class LatentDiffusion(DDPM): model_mean, _, model_log_variance = outputs noise = noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: + if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) # no noise when t == 0 - nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + nonzero_mask = (1 - (t == 0).float()).reshape( + b, *((1,) * (len(x.shape) - 1)) + ) if return_codebook_ids: - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + return model_mean + nonzero_mask * ( + 0.5 * model_log_variance + ).exp() * noise, logits.argmax(dim=1) if return_x0: - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + return ( + model_mean + + nonzero_mask * (0.5 * model_log_variance).exp() * noise, + x0, + ) else: - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + return ( + model_mean + + nonzero_mask * (0.5 * model_log_variance).exp() * noise + ) @torch.no_grad() - def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, - img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., - score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, - log_every_t=None): + def progressive_denoising( + self, + cond, + shape, + 
verbose=True, + callback=None, + quantize_denoised=False, + img_callback=None, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + batch_size=None, + x_T=None, + start_T=None, + log_every_t=None, + ): if not log_every_t: log_every_t = self.log_every_t timesteps = self.num_timesteps @@ -1175,16 +1674,30 @@ class LatentDiffusion(DDPM): intermediates = [] if cond is not None: if isinstance(cond, dict): - cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + cond = { + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } else: - cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + cond = ( + [c[:batch_size] for c in cond] + if isinstance(cond, list) + else cond[:batch_size] + ) if start_T is not None: timesteps = min(timesteps, start_T) - iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', - total=timesteps) if verbose else reversed( - range(0, timesteps)) + iterator = ( + tqdm( + reversed(range(0, timesteps)), + desc='Progressive Generation', + total=timesteps, + ) + if verbose + else reversed(range(0, timesteps)) + ) if type(temperature) == float: temperature = [temperature] * timesteps @@ -1193,29 +1706,52 @@ class LatentDiffusion(DDPM): if self.shorten_cond_schedule: assert self.model.conditioning_key != 'hybrid' tc = self.cond_ids[ts].to(cond.device) - cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + cond = self.q_sample( + x_start=cond, t=tc, noise=torch.randn_like(cond) + ) - img, x0_partial = self.p_sample(img, cond, ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised, return_x0=True, - temperature=temperature[i], noise_dropout=noise_dropout, - score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + img, x0_partial = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + return_x0=True, + temperature=temperature[i], + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) if mask is not None: assert x0 is not None img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1. 
- mask) * img + img = img_orig * mask + (1.0 - mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) return img, intermediates @torch.no_grad() - def p_sample_loop(self, cond, shape, return_intermediates=False, - x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, start_T=None, - log_every_t=None): + def p_sample_loop( + self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + start_T=None, + log_every_t=None, + ): if not log_every_t: log_every_t = self.log_every_t @@ -1232,100 +1768,170 @@ class LatentDiffusion(DDPM): if start_T is not None: timesteps = min(timesteps, start_T) - iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( - range(0, timesteps)) + iterator = ( + tqdm( + reversed(range(0, timesteps)), + desc='Sampling t', + total=timesteps, + ) + if verbose + else reversed(range(0, timesteps)) + ) if mask is not None: assert x0 is not None - assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + assert ( + x0.shape[2:3] == mask.shape[2:3] + ) # spatial size has to match for i in iterator: ts = torch.full((b,), i, device=device, dtype=torch.long) if self.shorten_cond_schedule: assert self.model.conditioning_key != 'hybrid' tc = self.cond_ids[ts].to(cond.device) - cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + cond = self.q_sample( + x_start=cond, t=tc, noise=torch.randn_like(cond) + ) - img = self.p_sample(img, cond, ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised) + img = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + ) if mask is not None: img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1. 
- mask) * img + img = img_orig * mask + (1.0 - mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) if return_intermediates: return img, intermediates return img @torch.no_grad() - def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, - verbose=True, timesteps=None, quantize_denoised=False, - mask=None, x0=None, shape=None,**kwargs): + def sample( + self, + cond, + batch_size=16, + return_intermediates=False, + x_T=None, + verbose=True, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + shape=None, + **kwargs, + ): if shape is None: - shape = (batch_size, self.channels, self.image_size, self.image_size) + shape = ( + batch_size, + self.channels, + self.image_size, + self.image_size, + ) if cond is not None: if isinstance(cond, dict): - cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + cond = { + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } else: - cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] - return self.p_sample_loop(cond, - shape, - return_intermediates=return_intermediates, x_T=x_T, - verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, - mask=mask, x0=x0) + cond = ( + [c[:batch_size] for c in cond] + if isinstance(cond, list) + else cond[:batch_size] + ) + return self.p_sample_loop( + cond, + shape, + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + quantize_denoised=quantize_denoised, + mask=mask, + x0=x0, + ) @torch.no_grad() - def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): if ddim: ddim_sampler = DDIMSampler(self) shape = (self.channels, self.image_size, self.image_size) - samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, - shape,cond,verbose=False,**kwargs) + samples, intermediates = ddim_sampler.sample( + ddim_steps, batch_size, shape, cond, verbose=False, **kwargs + ) else: - samples, intermediates = self.sample(cond=cond, batch_size=batch_size, - return_intermediates=True,**kwargs) + samples, intermediates = self.sample( + cond=cond, + batch_size=batch_size, + return_intermediates=True, + **kwargs, + ) return samples, intermediates @torch.no_grad() - def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, - quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False, - plot_diffusion_rows=False, **kwargs): + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1.0, + return_keys=None, + quantize_denoised=True, + inpaint=False, + plot_denoise_rows=False, + plot_progressive_rows=False, + plot_diffusion_rows=False, + **kwargs, + ): use_ddim = ddim_steps is not None log = dict() - z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=N) + z, c, x, xrec, xc = self.get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N, + ) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) 
- log["inputs"] = x - log["reconstruction"] = xrec + log['inputs'] = x + log['reconstruction'] = xrec if self.model.conditioning_key is not None: - if hasattr(self.cond_stage_model, "decode"): + if hasattr(self.cond_stage_model, 'decode'): xc = self.cond_stage_model.decode(c) - log["conditioning"] = xc - elif self.cond_stage_key in ["caption"]: - xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) - log["conditioning"] = xc + log['conditioning'] = xc + elif self.cond_stage_key in ['caption']: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch['caption']) + log['conditioning'] = xc elif self.cond_stage_key == 'class_label': - xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + xc = log_txt_as_img( + (x.shape[2], x.shape[3]), batch['human_label'] + ) log['conditioning'] = xc elif isimage(xc): - log["conditioning"] = xc + log['conditioning'] = xc if ismap(xc): - log["original_conditioning"] = self.to_rgb(xc) + log['original_conditioning'] = self.to_rgb(xc) if plot_diffusion_rows: # get diffusion row @@ -1339,75 +1945,114 @@ class LatentDiffusion(DDPM): z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_row = torch.stack( + diffusion_row + ) # n_log_step, n_row, C, H, W diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') - diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) - log["diffusion_row"] = diffusion_grid + diffusion_grid = rearrange( + diffusion_grid, 'b n c h w -> (b n) c h w' + ) + diffusion_grid = make_grid( + diffusion_grid, nrow=diffusion_row.shape[0] + ) + log['diffusion_row'] = diffusion_grid if sample: # get denoise row - with self.ema_scope("Plotting"): - samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, - ddim_steps=ddim_steps,eta=ddim_eta) + with self.ema_scope('Plotting'): + samples, z_denoise_row = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) - log["samples"] = x_samples + log['samples'] = x_samples if plot_denoise_rows: denoise_grid = self._get_denoise_row_from_list(z_denoise_row) - log["denoise_row"] = denoise_grid - - uc = self.get_learned_conditioning(len(c) * [""]) - sample_scaled, _ = self.sample_log(cond=c, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta, - unconditional_guidance_scale=5.0, - unconditional_conditioning=uc) - log["samples_scaled"] = self.decode_first_stage(sample_scaled) + log['denoise_row'] = denoise_grid - if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( - self.first_stage_model, IdentityFirstStage): + uc = self.get_learned_conditioning(len(c) * ['']) + sample_scaled, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=5.0, + unconditional_conditioning=uc, + ) + log['samples_scaled'] = self.decode_first_stage(sample_scaled) + + if ( + quantize_denoised + and not isinstance(self.first_stage_model, AutoencoderKL) + and not isinstance(self.first_stage_model, IdentityFirstStage) + ): # also display when quantizing x0 while sampling - with self.ema_scope("Plotting Quantized Denoised"): - samples, 
z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, - ddim_steps=ddim_steps,eta=ddim_eta, - quantize_denoised=True) + with self.ema_scope('Plotting Quantized Denoised'): + samples, z_denoise_row = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + quantize_denoised=True, + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, # quantize_denoised=True) x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_x0_quantized"] = x_samples + log['samples_x0_quantized'] = x_samples if inpaint: # make a simple center square b, h, w = z.shape[0], z.shape[2], z.shape[3] mask = torch.ones(N, h, w).to(self.device) # zeros will be filled in - mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0 mask = mask[:, None, ...] - with self.ema_scope("Plotting Inpaint"): + with self.ema_scope('Plotting Inpaint'): - samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, - ddim_steps=ddim_steps, x0=z[:N], mask=mask) + samples, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + eta=ddim_eta, + ddim_steps=ddim_steps, + x0=z[:N], + mask=mask, + ) x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_inpainting"] = x_samples - log["mask"] = mask + log['samples_inpainting'] = x_samples + log['mask'] = mask # outpaint - with self.ema_scope("Plotting Outpaint"): - samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, - ddim_steps=ddim_steps, x0=z[:N], mask=mask) + with self.ema_scope('Plotting Outpaint'): + samples, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + eta=ddim_eta, + ddim_steps=ddim_steps, + x0=z[:N], + mask=mask, + ) x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_outpainting"] = x_samples + log['samples_outpainting'] = x_samples if plot_progressive_rows: - with self.ema_scope("Plotting Progressives"): - img, progressives = self.progressive_denoising(c, - shape=(self.channels, self.image_size, self.image_size), - batch_size=N) - prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") - log["progressive_row"] = prog_row + with self.ema_scope('Plotting Progressives'): + img, progressives = self.progressive_denoising( + c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N, + ) + prog_row = self._get_denoise_row_from_list( + progressives, desc='Progressive Generation' + ) + log['progressive_row'] = prog_row if return_keys: if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: @@ -1425,7 +2070,9 @@ class LatentDiffusion(DDPM): else: params = list(self.model.parameters()) if self.cond_stage_trainable: - print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + print( + f'{self.__class__.__name__}: Also optimizing conditioner params!' 
+ ) params = params + list(self.cond_stage_model.parameters()) if self.learn_logvar: print('Diffusion model optimizing logvar') @@ -1435,34 +2082,44 @@ class LatentDiffusion(DDPM): assert 'target' in self.scheduler_config scheduler = instantiate_from_config(self.scheduler_config) - print("Setting up LambdaLR scheduler...") + print('Setting up LambdaLR scheduler...') scheduler = [ { 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', - 'frequency': 1 - }] + 'frequency': 1, + } + ] return [opt], scheduler return opt @torch.no_grad() def to_rgb(self, x): x = x.float() - if not hasattr(self, "colorize"): + if not hasattr(self, 'colorize'): self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) x = nn.functional.conv2d(x, weight=self.colorize) - x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 return x @rank_zero_only def on_save_checkpoint(self, checkpoint): checkpoint.clear() - + if os.path.isdir(self.trainer.checkpoint_callback.dirpath): - self.embedding_manager.save(os.path.join(self.trainer.checkpoint_callback.dirpath, "embeddings.pt")) + self.embedding_manager.save( + os.path.join( + self.trainer.checkpoint_callback.dirpath, 'embeddings.pt' + ) + ) if (self.global_step - self.emb_ckpt_counter) > 500: - self.embedding_manager.save(os.path.join(self.trainer.checkpoint_callback.dirpath, f"embeddings_gs-{self.global_step}.pt")) + self.embedding_manager.save( + os.path.join( + self.trainer.checkpoint_callback.dirpath, + f'embeddings_gs-{self.global_step}.pt', + ) + ) self.emb_ckpt_counter += 500 @@ -1472,7 +2129,13 @@ class DiffusionWrapper(pl.LightningModule): super().__init__() self.diffusion_model = instantiate_from_config(diff_model_config) self.conditioning_key = conditioning_key - assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + assert self.conditioning_key in [ + None, + 'concat', + 'crossattn', + 'hybrid', + 'adm', + ] def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): if self.conditioning_key is None: @@ -1499,7 +2162,9 @@ class DiffusionWrapper(pl.LightningModule): class Layout2ImgDiffusion(LatentDiffusion): # TODO: move all layout-specific hacks to this class def __init__(self, cond_stage_key, *args, **kwargs): - assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + assert ( + cond_stage_key == 'coordinates_bbox' + ), 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) def log_images(self, batch, N=8, *args, **kwargs): @@ -1510,9 +2175,13 @@ class Layout2ImgDiffusion(LatentDiffusion): mapper = dset.conditional_builders[self.cond_stage_key] bbox_imgs = [] - map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + map_fn = lambda catno: dset.get_textual_label( + dset.get_category_id(catno) + ) for tknzd_bbox in batch[self.cond_stage_key][:N]: - bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bboximg = mapper.plot( + tknzd_bbox.detach().cpu(), map_fn, (256, 256) + ) bbox_imgs.append(bboximg) cond_img = torch.stack(bbox_imgs, dim=0) diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 62912d1a07..1da81eee5a 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -1,8 +1,9 @@ -'''wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers''' +"""wrapper around part 
of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers""" import k_diffusion as K import torch import torch.nn as nn + class CFGDenoiser(nn.Module): def __init__(self, model): super().__init__() @@ -15,8 +16,9 @@ class CFGDenoiser(nn.Module): uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) return uncond + (cond - uncond) * cond_scale + class KSampler(object): - def __init__(self, model, schedule="lms", device="cuda", **kwargs): + def __init__(self, model, schedule='lms', device='cuda', **kwargs): super().__init__() self.model = K.external.CompVisDenoiser(model) self.schedule = schedule @@ -26,44 +28,57 @@ class KSampler(object): x_in = torch.cat([x] * 2) sigma_in = torch.cat([sigma] * 2) cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + uncond, cond = self.inner_model( + x_in, sigma_in, cond=cond_in + ).chunk(2) return uncond + (cond - uncond) * cond_scale - # most of these arguments are ignored and are only present for compatibility with # other samples @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ **kwargs, + ): sigmas = self.model.get_sigmas(S) if x_T: x = x_T else: - x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw + x = ( + torch.randn([batch_size, *shape], device=self.device) + * sigmas[0] + ) # for GPU draw model_wrap_cfg = CFGDenoiser(self.model) - extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale} - return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args), - None) + extra_args = { + 'cond': conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': unconditional_guidance_scale, + } + return ( + K.sampling.__dict__[f'sample_{self.schedule}']( + model_wrap_cfg, x, sigmas, extra_args=extra_args + ), + None, + ) diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 5eafe1d7ce..7b9dc4706b 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -5,11 +5,15 @@ import numpy as np from tqdm import tqdm from functools import partial -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from ldm.modules.diffusionmodules.util import ( + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, +) class PLMSSampler(object): - def __init__(self, model, schedule="linear", device="cuda", **kwargs): + def __init__(self, model, schedule='linear', device='cuda', **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps @@ -23,103 +27,172 @@ class PLMSSampler(object): setattr(self, name, attr) - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + def make_schedule( + self, + ddim_num_steps, + ddim_discretize='uniform', + ddim_eta=0.0, + verbose=True, + ): if ddim_eta != 0: raise ValueError('ddim_eta must be 0 for PLMS') - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), 'alphas have to be defined for each timestep' + to_torch = ( + lambda x: x.clone() + .detach() + .to(torch.float32) + .to(self.model.device) + ) self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + self.register_buffer( + 'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev) + ) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod.cpu() - 1))) + self.register_buffer( + 'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + 'sqrt_one_minus_alphas_cumprod', + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + 'log_one_minus_alphas_cumprod', + to_torch(np.log(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + 'sqrt_recip_alphas_cumprod', + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())), + ) + self.register_buffer( + 'sqrt_recipm1_alphas_cumprod', + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) + ( + ddim_sigmas, + ddim_alphas, + ddim_alphas_prev, + ) = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) self.register_buffer('ddim_sigmas', ddim_sigmas) self.register_buffer('ddim_alphas', ddim_alphas) self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + self.register_buffer( + 'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas) + ) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + 'ddim_sigmas_for_original_num_steps', + sigmas_for_original_sampling_steps, + ) @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
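# A small NumPy restatement of the sigma schedule registered above as
# 'ddim_sigmas_for_original_num_steps'; the array arguments are illustrative.
import numpy as np

def ddim_sigmas(alphas_cumprod, alphas_cumprod_prev, eta):
    # sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev)
    a_t = np.asarray(alphas_cumprod)
    a_prev = np.asarray(alphas_cumprod_prev)
    return eta * np.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))

# With the eta == 0 required by PLMS (make_schedule raises otherwise),
# every sigma is zero and the p_sample_plms update adds no noise.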
+ **kwargs, + ): if conditioning is not None: if isinstance(conditioning, dict): cbs = conditioning[list(conditioning.keys())[0]].shape[0] if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + print( + f'Warning: Got {cbs} conditionings but batch-size is {batch_size}' + ) else: if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + print( + f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}' + ) self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) # sampling C, H, W = shape size = (batch_size, C, H, W) -# print(f'Data shape for PLMS sampling is {size}') + # print(f'Data shape for PLMS sampling is {size}') - samples, intermediates = self.plms_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - ) + samples, intermediates = self.plms_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) return samples, intermediates @torch.no_grad() - def plms_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None,): + def plms_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + ): device = self.model.betas.device b = shape[0] if x_T is None: @@ -128,42 +201,81 @@ class PLMSSampler(object): img = x_T if timesteps is None: - timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) elif timesteps is not None and not ddim_use_original_steps: - subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) timesteps = self.ddim_timesteps[:subset_end] intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] -# print(f"Running PLMS Sampling with {total_steps} timesteps") + time_range = ( + 
list(reversed(range(0, timesteps))) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = ( + timesteps if ddim_use_original_steps else timesteps.shape[0] + ) + # print(f"Running PLMS Sampling with {total_steps} timesteps") - iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps, dynamic_ncols=True) + iterator = tqdm( + time_range, + desc='PLMS Sampler', + total=total_steps, + dynamic_ncols=True, + ) old_eps = [] for i, step in enumerate(iterator): index = total_steps - i - 1 ts = torch.full((b,), step, device=device, dtype=torch.long) - ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + ts_next = torch.full( + (b,), + time_range[min(i + 1, len(time_range) - 1)], + device=device, + dtype=torch.long, + ) if mask is not None: assert x0 is not None - img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. - mask) * img + img_orig = self.model.q_sample( + x0, ts + ) # TODO: deterministic forward pass? + img = img_orig * mask + (1.0 - mask) * img - outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - old_eps=old_eps, t_next=ts_next) + outs = self.p_sample_plms( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, + t_next=ts_next, + ) img, pred_x0, e_t = outs old_eps.append(e_t) if len(old_eps) >= 4: old_eps.pop(0) - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: intermediates['x_inter'].append(img) @@ -172,47 +284,95 @@ class PLMSSampler(object): return img, intermediates @torch.no_grad() - def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + def p_sample_plms( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + old_eps=None, + t_next=None, + ): b, *_, device = *x.shape, x.device def get_model_output(x, t): - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + if ( + unconditional_conditioning is None + or unconditional_guidance_scale == 1.0 + ): e_t = self.model.apply_model(x, t, c) else: x_in = torch.cat([x] * 2) t_in = torch.cat([t] * 2) c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + e_t_uncond, e_t = self.model.apply_model( + x_in, 
t_in, c_in + ).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * ( + e_t - e_t_uncond + ) if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + assert self.model.parameterization == 'eps' + e_t = score_corrector.modify_score( + self.model, e_t, x, t, c, **corrector_kwargs + ) return e_t - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + alphas = ( + self.model.alphas_cumprod + if use_original_steps + else self.ddim_alphas + ) + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) def get_x_prev_and_pred_x0(e_t, index): # select parameters corresponding to the currently considered timestep a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + a_prev = torch.full( + (b, 1, 1, 1), alphas_prev[index], device=device + ) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) # current prediction for x_0 pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() if quantize_denoised: pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) # direction pointing to x_t - dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = ( + sigma_t + * noise_like(x.shape, device, repeat_noise) + * temperature + ) + if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise return x_prev, pred_x0 @@ -231,7 +391,12 @@ class PLMSSampler(object): e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 elif len(old_eps) >= 3: # 4nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + e_t_prime = ( + 55 * e_t + - 59 * old_eps[-1] + + 37 * old_eps[-2] + - 9 * old_eps[-3] + ) / 24 x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index f4eff39ccb..960a112001 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -13,7 +13,7 @@ def exists(val): def uniq(arr): - return{el: True for el in arr}.keys() + return {el: True for el in arr}.keys() def default(val, d): @@ -45,19 +45,18 @@ class GEGLU(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) ) def forward(self, x): @@ -74,7 +73,9 @@ def zero_module(module): def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) class LinearAttention(nn.Module): @@ -82,17 +83,28 @@ class LinearAttention(nn.Module): super().__init__() self.heads = heads hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) - q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) - k = k.softmax(dim=-1) + q, k, v = rearrange( + qkv, + 'b (qkv heads c) h w -> qkv b heads c (h w)', + heads=self.heads, + qkv=3, + ) + k = k.softmax(dim=-1) context = torch.einsum('bhdn,bhen->bhde', k, v) out = torch.einsum('bhde,bhdn->bhen', context, q) - out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + out = rearrange( + out, + 'b heads c (h w) -> b (heads c) h w', + heads=self.heads, + h=h, + w=w, + ) return self.to_out(out) @@ -102,26 +114,18 @@ class SpatialSelfAttention(nn.Module): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - 
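# A compact restatement of the einsum pattern in LinearAttention.forward above,
# assuming q, k, v already have the (batch, heads, dim_head, positions) layout
# produced by the rearrange; variable names are illustrative.
import torch

def linear_attention(q, k, v):
    k = k.softmax(dim=-1)                               # normalise keys over positions
    context = torch.einsum('bhdn,bhen->bhde', k, v)     # d x e summary, no n x n matrix
    return torch.einsum('bhde,bhdn->bhen', context, q)  # cost grows linearly in n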
self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) def forward(self, x): h_ = x @@ -131,12 +135,12 @@ class SpatialSelfAttention(nn.Module): v = self.v(h_) # compute attention - b,c,h,w = q.shape + b, c, h, w = q.shape q = rearrange(q, 'b c h w -> b (h w) c') k = rearrange(k, 'b c h w -> b c (h w)') w_ = torch.einsum('bij,bjk->bik', q, k) - w_ = w_ * (int(c)**(-0.5)) + w_ = w_ * (int(c) ** (-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values @@ -146,16 +150,18 @@ class SpatialSelfAttention(nn.Module): h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) h_ = self.proj_out(h_) - return x+h_ + return x + h_ class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + def __init__( + self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0 + ): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) @@ -163,8 +169,7 @@ class CrossAttention(nn.Module): self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), - nn.Dropout(dropout) + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) def forward(self, x, context=None, mask=None): @@ -175,7 +180,9 @@ class CrossAttention(nn.Module): k = self.to_k(context) v = self.to_v(context) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map( + lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v) + ) sim = einsum('b i d, b j d -> b i j', q, k) * self.scale @@ -194,19 +201,37 @@ class CrossAttention(nn.Module): class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + ): super().__init__() - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) self.checkpoint = checkpoint def forward(self, x, context=None): - return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + return checkpoint( + self._forward, (x, context), self.parameters(), self.checkpoint + ) def _forward(self, x, context=None): x = self.attn1(self.norm1(x)) + x @@ -223,29 +248,43 @@ class 
SpatialTransformer(nn.Module): Then apply standard transformer action. Finally, reshape to image """ - def __init__(self, in_channels, n_heads, d_head, - depth=1, dropout=0., context_dim=None): + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + ): super().__init__() self.in_channels = in_channels inner_dim = n_heads * d_head self.norm = Normalize(in_channels) - self.proj_in = nn.Conv2d(in_channels, - inner_dim, - kernel_size=1, - stride=1, - padding=0) - - self.transformer_blocks = nn.ModuleList( - [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) - for d in range(depth)] + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 ) - self.proj_out = zero_module(nn.Conv2d(inner_dim, - in_channels, - kernel_size=1, - stride=1, - padding=0)) + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim, + ) + for d in range(depth) + ] + ) + + self.proj_out = zero_module( + nn.Conv2d( + inner_dim, in_channels, kernel_size=1, stride=1, padding=0 + ) + ) def forward(self, x, context=None): # note: if no context is given, cross-attention defaults to self-attention @@ -258,4 +297,4 @@ class SpatialTransformer(nn.Module): x = block(x, context=context) x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) x = self.proj_out(x) - return x + x_in \ No newline at end of file + return x + x_in diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 533e589a20..cd79e37565 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -26,17 +26,19 @@ def get_timestep_embedding(timesteps, embedding_dim): emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0,1,0,0)) + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb def nonlinearity(x): # swish - return x*torch.sigmoid(x) + return x * torch.sigmoid(x) def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) class Upsample(nn.Module): @@ -44,14 +46,14 @@ class Upsample(nn.Module): super().__init__() self.with_conv = with_conv if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = torch.nn.functional.interpolate( + x, scale_factor=2.0, mode='nearest' + ) if self.with_conv: x = self.conv(x) return x @@ -63,16 +65,14 @@ class Downsample(nn.Module): self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=2, - padding=0) + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) def forward(self, x): if self.with_conv: - pad = (0,1,0,1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode='constant', value=0) x = self.conv(x) else: x = 
torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) @@ -80,8 +80,15 @@ class Downsample(nn.Module): class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, - dropout, temb_channels=512): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels @@ -89,34 +96,33 @@ class ResnetBlock(nn.Module): self.use_conv_shortcut = conv_shortcut self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, - out_channels) + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_shortcut = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0) + self.nin_shortcut = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) def forward(self, x, temb): h = x @@ -125,7 +131,7 @@ class ResnetBlock(nn.Module): h = self.conv1(h) if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] h = self.norm2(h) h = nonlinearity(h) @@ -138,11 +144,12 @@ class ResnetBlock(nn.Module): else: x = self.nin_shortcut(x) - return x+h + return x + h class LinAttnBlock(LinearAttention): """to match AttnBlock usage""" + def __init__(self, in_channels): super().__init__(dim=in_channels, heads=1, dim_head=in_channels) @@ -153,27 +160,18 @@ class AttnBlock(nn.Module): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) def forward(self, x): h_ = x @@ -183,44 +181,66 @@ class AttnBlock(nn.Module): v = self.v(h_) # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * 
(int(c)**(-0.5)) + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_ + ) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) h_ = self.proj_out(h_) - return x+h_ + return x + h_ -def make_attn(in_channels, attn_type="vanilla"): - assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' - print(f"making attention of type '{attn_type}' with {in_channels} in_channels") - if attn_type == "vanilla": +def make_attn(in_channels, attn_type='vanilla'): + assert attn_type in [ + 'vanilla', + 'linear', + 'none', + ], f'attn_type {attn_type} unknown' + print( + f"making attention of type '{attn_type}' with {in_channels} in_channels" + ) + if attn_type == 'vanilla': return AttnBlock(in_channels) - elif attn_type == "none": + elif attn_type == 'none': return nn.Identity(in_channels) else: return LinAttnBlock(in_channels) class Model(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type='vanilla', + ): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = 'linear' self.ch = ch - self.temb_ch = self.ch*4 + self.temb_ch = self.ch * 4 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution @@ -230,70 +250,80 @@ class Model(nn.Module): if self.use_timestep: # timestep embedding self.temb = nn.Module() - self.temb.dense = nn.ModuleList([ - torch.nn.Linear(self.ch, - self.temb_ch), - torch.nn.Linear(self.temb_ch, - self.temb_ch), - ]) + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) + in_ch_mult = (1,) + tuple(ch_mult) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, 
attn_type=attn_type)) down = nn.Module() down.block = block down.attn = attn - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - skip_in = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): if i_block == self.num_res_blocks: - skip_in = ch*in_ch_mult[i_level] - block.append(ResnetBlock(in_channels=block_in+skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -303,18 +333,16 @@ class Model(nn.Module): if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) def forward(self, x, t=None, context=None): - #assert x.shape[2] == x.shape[3] == self.resolution + # assert x.shape[2] == x.shape[3] == self.resolution if context is not None: # assume aligned context, cat along channel axis x = torch.cat((x, context), dim=1) @@ -336,7 +364,7 @@ class Model(nn.Module): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle @@ -347,9 +375,10 @@ class Model(nn.Module): # upsampling for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): + for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb) + torch.cat([h, hs.pop()], dim=1), temb + ) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) if i_level != 0: @@ -366,12 +395,27 @@ class Model(nn.Module): class Encoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", - **ignore_kwargs): + def __init__( + self, + *, + ch, + 
out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type='vanilla', + **ignore_kwargs, + ): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = 'linear' self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) @@ -380,56 +424,64 @@ class Encoder(nn.Module): self.in_channels = in_channels # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) + in_ch_mult = (1,) + tuple(ch_mult) self.in_ch_mult = in_ch_mult self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block down.attn = attn - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - 2*z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) def forward(self, x): # timestep embedding @@ -443,7 +495,7 @@ class Encoder(nn.Module): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle @@ -460,12 +512,28 @@ class Encoder(nn.Module): class Decoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, - attn_type="vanilla", **ignorekwargs): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + 
tanh_out=False, + use_linear_attn=False, + attn_type='vanilla', + **ignorekwargs, + ): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = 'linear' self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) @@ -476,43 +544,52 @@ class Decoder(nn.Module): self.tanh_out = tanh_out # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,)+tuple(ch_mult) - block_in = ch*ch_mult[self.num_resolutions-1] - curr_res = resolution // 2**(self.num_resolutions-1) - self.z_shape = (1,z_channels,curr_res,curr_res) - print("Working with z of shape {} = {} dimensions.".format( - self.z_shape, np.prod(self.z_shape))) + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print( + 'Working with z of shape {} = {} dimensions.'.format( + self.z_shape, np.prod(self.z_shape) + ) + ) # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, - block_in, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -522,18 +599,16 @@ class Decoder(nn.Module): if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) def forward(self, z): - #assert z.shape[1:] == self.z_shape[1:] + # assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape # timestep embedding @@ -549,7 +624,7 @@ class Decoder(nn.Module): # upsampling for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): + for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) @@ -571,29 +646,40 @@ class Decoder(nn.Module): class SimpleDecoder(nn.Module): 
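# A worked example of the channel / resolution bookkeeping in the Encoder and
# Decoder constructors above, using the default ch_mult=(1, 2, 4, 8) and
# illustrative values for ch, resolution and z_channels:
ch, ch_mult, resolution, z_channels = 128, (1, 2, 4, 8), 256, 4
num_resolutions = len(ch_mult)                        # 4 resolution levels
in_ch_mult = (1,) + tuple(ch_mult)                    # (1, 1, 2, 4, 8) input multipliers
block_in = ch * ch_mult[num_resolutions - 1]          # 1024 channels at the lowest level
curr_res = resolution // 2 ** (num_resolutions - 1)   # 256 -> 32 after three downsamples
z_shape = (1, z_channels, curr_res, curr_res)         # (1, 4, 32, 32), as printed by Decoder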
def __init__(self, in_channels, out_channels, *args, **kwargs): super().__init__() - self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock(in_channels=in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=2 * in_channels, - out_channels=4 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=4 * in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - nn.Conv2d(2*in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True)]) + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock( + in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) # end self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) def forward(self, x): for i, layer in enumerate(self.model): - if i in [1,2,3]: + if i in [1, 2, 3]: x = layer(x, None) else: x = layer(x) @@ -605,8 +691,16 @@ class SimpleDecoder(nn.Module): class UpsampleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, - ch_mult=(2,2), dropout=0.0): + def __init__( + self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0, + ): super().__init__() # upsampling self.temb_ch = 0 @@ -620,10 +714,14 @@ class UpsampleDecoder(nn.Module): res_block = [] block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): - res_block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + res_block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) block_in = block_out self.res_blocks.append(nn.ModuleList(res_block)) if i_level != self.num_resolutions - 1: @@ -632,11 +730,9 @@ class UpsampleDecoder(nn.Module): # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) def forward(self, x): # upsampling @@ -653,35 +749,56 @@ class UpsampleDecoder(nn.Module): class LatentRescaler(nn.Module): - def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + def __init__( + self, factor, in_channels, mid_channels, out_channels, depth=2 + ): super().__init__() # residual block, interpolate, residual block self.factor = factor - self.conv_in = nn.Conv2d(in_channels, - mid_channels, - kernel_size=3, - stride=1, - padding=1) - self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) + self.conv_in = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, stride=1, padding=1 + ) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, 
+ ) + for _ in range(depth) + ] + ) self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) - self.conv_out = nn.Conv2d(mid_channels, - out_channels, - kernel_size=1, - ) + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) def forward(self, x): x = self.conv_in(x) for block in self.res_block1: x = block(x, None) - x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = torch.nn.functional.interpolate( + x, + size=( + int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)), + ), + ) x = self.attn(x) for block in self.res_block2: x = block(x, None) @@ -690,17 +807,42 @@ class LatentRescaler(nn.Module): class MergedRescaleEncoder(nn.Module): - def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, - ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): super().__init__() intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, - z_channels=intermediate_chn, double_z=False, resolution=resolution, - attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, - out_ch=None) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, - mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) def forward(self, x): x = self.encoder(x) @@ -709,15 +851,41 @@ class MergedRescaleEncoder(nn.Module): class MergedRescaleDecoder(nn.Module): - def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), - dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): super().__init__() - tmp_chn = z_channels*ch_mult[-1] - self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, - resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, - ch_mult=ch_mult, resolution=resolution, ch=ch) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, - out_channels=tmp_chn, depth=rescale_module_depth) + tmp_chn = z_channels * ch_mult[-1] + self.decoder 
= Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) def forward(self, x): x = self.rescaler(x) @@ -726,17 +894,32 @@ class MergedRescaleDecoder(nn.Module): class Upsampler(nn.Module): - def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + def __init__( + self, in_size, out_size, in_channels, out_channels, ch_mult=2 + ): super().__init__() assert out_size >= in_size - num_blocks = int(np.log2(out_size//in_size))+1 - factor_up = 1.+ (out_size % in_size) - print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") - self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, - out_channels=in_channels) - self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, - attn_resolutions=[], in_channels=None, ch=in_channels, - ch_mult=[ch_mult for _ in range(num_blocks)]) + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1.0 + (out_size % in_size) + print( + f'Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}' + ) + self.rescaler = LatentRescaler( + factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels, + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) def forward(self, x): x = self.rescaler(x) @@ -745,42 +928,55 @@ class Upsampler(nn.Module): class Resize(nn.Module): - def __init__(self, in_channels=None, learned=False, mode="bilinear"): + def __init__(self, in_channels=None, learned=False, mode='bilinear'): super().__init__() self.with_conv = learned self.mode = mode if self.with_conv: - print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + print( + f'Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode' + ) raise NotImplementedError() assert in_channels is not None # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=4, - stride=2, - padding=1) + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1 + ) def forward(self, x, scale_factor=1.0): - if scale_factor==1.0: + if scale_factor == 1.0: return x else: - x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + x = torch.nn.functional.interpolate( + x, + mode=self.mode, + align_corners=False, + scale_factor=scale_factor, + ) return x -class FirstStagePostProcessor(nn.Module): - def __init__(self, ch_mult:list, in_channels, - pretrained_model:nn.Module=None, - reshape=False, - n_channels=None, - dropout=0., - pretrained_config=None): +class FirstStagePostProcessor(nn.Module): + def __init__( + self, + ch_mult: list, + in_channels, + pretrained_model: nn.Module = None, + reshape=False, + n_channels=None, + dropout=0.0, + pretrained_config=None, + ): 
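# The sizing arithmetic from Upsampler.__init__ above, with illustrative sizes
# (out_size must be >= in_size, as asserted there):
import numpy as np

in_size, out_size = 32, 256
num_blocks = int(np.log2(out_size // in_size)) + 1   # log2(8) + 1 = 4 decoder levels (len(ch_mult))
factor_up = 1.0 + (out_size % in_size)               # 256 % 32 == 0, so the rescale factor stays 1.0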
super().__init__() if pretrained_config is None: - assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + assert ( + pretrained_model is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' self.pretrained_model = pretrained_model else: - assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + assert ( + pretrained_config is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' self.instantiate_pretrained(pretrained_config) self.do_reshape = reshape @@ -788,22 +984,28 @@ class FirstStagePostProcessor(nn.Module): if n_channels is None: n_channels = self.pretrained_model.encoder.ch - self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) - self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, - stride=1,padding=1) + self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) + self.proj = nn.Conv2d( + in_channels, n_channels, kernel_size=3, stride=1, padding=1 + ) blocks = [] downs = [] ch_in = n_channels for m in ch_mult: - blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + blocks.append( + ResnetBlock( + in_channels=ch_in, + out_channels=m * n_channels, + dropout=dropout, + ) + ) ch_in = m * n_channels downs.append(Downsample(ch_in, with_conv=False)) self.model = nn.ModuleList(blocks) self.downsampler = nn.ModuleList(downs) - def instantiate_pretrained(self, config): model = instantiate_from_config(config) self.pretrained_model = model.eval() @@ -811,25 +1013,23 @@ class FirstStagePostProcessor(nn.Module): for param in self.pretrained_model.parameters(): param.requires_grad = False - @torch.no_grad() - def encode_with_pretrained(self,x): + def encode_with_pretrained(self, x): c = self.pretrained_model.encode(x) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() - return c + return c - def forward(self,x): + def forward(self, x): z_fs = self.encode_with_pretrained(x) z = self.proj_norm(z_fs) z = self.proj(z) z = nonlinearity(z) - for submodel, downmodel in zip(self.model,self.downsampler): - z = submodel(z,temb=None) + for submodel, downmodel in zip(self.model, self.downsampler): + z = submodel(z, temb=None) z = downmodel(z) if self.do_reshape: - z = rearrange(z,'b c h w -> b (h w) c') + z = rearrange(z, 'b c h w -> b (h w) c') return z - diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py index fcf95d1ea8..d6baa76a1c 100644 --- a/ldm/modules/diffusionmodules/openaimodel.py +++ b/ldm/modules/diffusionmodules/openaimodel.py @@ -24,6 +24,7 @@ from ldm.modules.attention import SpatialTransformer def convert_module_to_f16(x): pass + def convert_module_to_f32(x): pass @@ -42,7 +43,9 @@ class AttentionPool2d(nn.Module): output_dim: int = None, ): super().__init__() - self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.num_heads = embed_dim // num_heads_channels @@ -97,37 +100,45 @@ class Upsample(nn.Module): upsampling occurs in the inner-two dimensions. 
""" - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + def __init__( + self, channels, use_conv, dims=2, out_channels=None, padding=1 + ): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims if use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) def forward(self, x): assert x.shape[1] == self.channels if self.dims == 3: x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest' ) else: - x = F.interpolate(x, scale_factor=2, mode="nearest") + x = F.interpolate(x, scale_factor=2, mode='nearest') if self.use_conv: x = self.conv(x) return x + class TransposedUpsample(nn.Module): - 'Learned 2x upsampling without padding' + """Learned 2x upsampling without padding""" + def __init__(self, channels, out_channels=None, ks=5): super().__init__() self.channels = channels self.out_channels = out_channels or channels - self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + self.up = nn.ConvTranspose2d( + self.channels, self.out_channels, kernel_size=ks, stride=2 + ) - def forward(self,x): + def forward(self, x): return self.up(x) @@ -140,7 +151,9 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + def __init__( + self, channels, use_conv, dims=2, out_channels=None, padding=1 + ): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -149,7 +162,12 @@ class Downsample(nn.Module): stride = 2 if dims != 3 else (1, 2, 2) if use_conv: self.op = conv_nd( - dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, ) else: assert self.channels == self.out_channels @@ -219,7 +237,9 @@ class ResBlock(TimestepBlock): nn.SiLU(), linear( emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + 2 * self.out_channels + if use_scale_shift_norm + else self.out_channels, ), ) self.out_layers = nn.Sequential( @@ -227,7 +247,9 @@ class ResBlock(TimestepBlock): nn.SiLU(), nn.Dropout(p=dropout), zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + conv_nd( + dims, self.out_channels, self.out_channels, 3, padding=1 + ) ), ) @@ -238,7 +260,9 @@ class ResBlock(TimestepBlock): dims, channels, self.out_channels, 3, padding=1 ) else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 1 + ) def forward(self, x, emb): """ @@ -251,7 +275,6 @@ class ResBlock(TimestepBlock): self._forward, (x, emb), self.parameters(), self.use_checkpoint ) - def _forward(self, x, emb): if self.updown: in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] @@ -297,7 +320,7 @@ class AttentionBlock(nn.Module): else: assert ( channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + ), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}' self.num_heads = channels // num_head_channels self.use_checkpoint = use_checkpoint self.norm = normalization(channels) @@ -312,8 +335,10 @@ class 
AttentionBlock(nn.Module): self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) def forward(self, x): - return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! - #return pt_checkpoint(self._forward, x) # pytorch + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch def _forward(self, x): b, c, *spatial = x.shape @@ -340,7 +365,7 @@ def count_flops_attn(model, _x, y): # We perform two matmuls with the same number of ops. # The first computes the weight matrix, the second computes # the combination of the value vectors. - matmul_ops = 2 * b * (num_spatial ** 2) * c + matmul_ops = 2 * b * (num_spatial**2) * c model.total_ops += th.DoubleTensor([matmul_ops]) @@ -362,13 +387,15 @@ class QKVAttentionLegacy(nn.Module): bs, width, length = qkv.shape assert width % (3 * self.n_heads) == 0 ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split( + ch, dim=1 + ) scale = 1 / math.sqrt(math.sqrt(ch)) weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale + 'bct,bcs->bts', q * scale, k * scale ) # More stable with f16 than dividing afterwards weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v) + a = th.einsum('bts,bcs->bct', weight, v) return a.reshape(bs, -1, length) @staticmethod @@ -397,12 +424,14 @@ class QKVAttention(nn.Module): q, k, v = qkv.chunk(3, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) weight = th.einsum( - "bct,bcs->bts", + 'bct,bcs->bts', (q * scale).view(bs * self.n_heads, ch, length), (k * scale).view(bs * self.n_heads, ch, length), ) # More stable with f16 than dividing afterwards weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + a = th.einsum( + 'bts,bcs->bct', weight, v.reshape(bs * self.n_heads, ch, length) + ) return a.reshape(bs, -1, length) @staticmethod @@ -461,19 +490,24 @@ class UNetModel(nn.Module): use_scale_shift_norm=False, resblock_updown=False, use_new_attention_order=False, - use_spatial_transformer=False, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model legacy=True, ): super().__init__() if use_spatial_transformer: - assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + assert ( + context_dim is not None + ), 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' if context_dim is not None: - assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + assert ( + use_spatial_transformer + ), 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
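# A minimal standalone sketch (illustration only, not part of this diff): the
# QKVAttention classes above scale q and k each by 1 / sqrt(sqrt(ch)) before the
# einsum instead of dividing the score matrix by sqrt(ch) afterwards. The two are
# equivalent in exact arithmetic, but pre-scaling keeps the intermediate products
# smaller, which is what the 'More stable with f16' comment refers to.
import math
import torch as th

bs_heads, ch, length = 8, 64, 16
q = th.randn(bs_heads, ch, length)
k = th.randn(bs_heads, ch, length)

scale = 1 / math.sqrt(math.sqrt(ch))   # ch ** -0.25
pre_scaled = th.einsum('bct,bcs->bts', q * scale, k * scale)
post_scaled = th.einsum('bct,bcs->bts', q, k) / math.sqrt(ch)
assert th.allclose(pre_scaled, post_scaled, atol=1e-5)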
from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: context_dim = list(context_dim) @@ -481,10 +515,14 @@ class UNetModel(nn.Module): num_heads_upsample = num_heads if num_heads == -1: - assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + assert ( + num_head_channels != -1 + ), 'Either num_heads or num_head_channels has to be set' if num_head_channels == -1: - assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + assert ( + num_heads != -1 + ), 'Either num_heads or num_head_channels has to be set' self.image_size = image_size self.in_channels = in_channels @@ -545,8 +583,12 @@ class UNetModel(nn.Module): num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) layers.append( AttentionBlock( ch, @@ -554,8 +596,14 @@ class UNetModel(nn.Module): num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) @@ -592,8 +640,12 @@ class UNetModel(nn.Module): num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) self.middle_block = TimestepEmbedSequential( ResBlock( ch, @@ -609,9 +661,15 @@ class UNetModel(nn.Module): num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim - ), + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + ), ResBlock( ch, time_embed_dim, @@ -646,8 +704,12 @@ class UNetModel(nn.Module): num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) layers.append( AttentionBlock( ch, @@ -655,8 +717,14 @@ class UNetModel(nn.Module): num_heads=num_heads_upsample, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, ) ) if level and i == num_res_blocks: @@ -673,7 +741,9 @@ class UNetModel(nn.Module): up=True, ) if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + else Upsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) ) ds //= 2 self.output_blocks.append(TimestepEmbedSequential(*layers)) @@ -682,14 +752,16 @@ class 
UNetModel(nn.Module): self.out = nn.Sequential( normalization(ch), nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + zero_module( + conv_nd(dims, model_channels, out_channels, 3, padding=1) + ), ) if self.predict_codebook_ids: self.id_predictor = nn.Sequential( - normalization(ch), - conv_nd(dims, model_channels, n_embed, 1), - #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits - ) + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) def convert_to_fp16(self): """ @@ -707,7 +779,7 @@ class UNetModel(nn.Module): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. @@ -718,9 +790,11 @@ class UNetModel(nn.Module): """ assert (y is not None) == ( self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" + ), 'must specify y if and only if the model is class-conditional' hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + t_emb = timestep_embedding( + timesteps, self.model_channels, repeat_only=False + ) emb = self.time_embed(t_emb) if self.num_classes is not None: @@ -768,9 +842,9 @@ class EncoderUNetModel(nn.Module): use_scale_shift_norm=False, resblock_updown=False, use_new_attention_order=False, - pool="adaptive", + pool='adaptive', *args, - **kwargs + **kwargs, ): super().__init__() @@ -888,7 +962,7 @@ class EncoderUNetModel(nn.Module): ) self._feature_size += ch self.pool = pool - if pool == "adaptive": + if pool == 'adaptive': self.out = nn.Sequential( normalization(ch), nn.SiLU(), @@ -896,7 +970,7 @@ class EncoderUNetModel(nn.Module): zero_module(conv_nd(dims, ch, out_channels, 1)), nn.Flatten(), ) - elif pool == "attention": + elif pool == 'attention': assert num_head_channels != -1 self.out = nn.Sequential( normalization(ch), @@ -905,13 +979,13 @@ class EncoderUNetModel(nn.Module): (image_size // ds), ch, num_head_channels, out_channels ), ) - elif pool == "spatial": + elif pool == 'spatial': self.out = nn.Sequential( nn.Linear(self._feature_size, 2048), nn.ReLU(), nn.Linear(2048, self.out_channels), ) - elif pool == "spatial_v2": + elif pool == 'spatial_v2': self.out = nn.Sequential( nn.Linear(self._feature_size, 2048), normalization(2048), @@ -919,7 +993,7 @@ class EncoderUNetModel(nn.Module): nn.Linear(2048, self.out_channels), ) else: - raise NotImplementedError(f"Unexpected {pool} pooling") + raise NotImplementedError(f'Unexpected {pool} pooling') def convert_to_fp16(self): """ @@ -942,20 +1016,21 @@ class EncoderUNetModel(nn.Module): :param timesteps: a 1-D batch of timesteps. :return: an [N x K] Tensor of outputs. 
""" - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb = self.time_embed( + timestep_embedding(timesteps, self.model_channels) + ) results = [] h = x.type(self.dtype) for module in self.input_blocks: h = module(h, emb) - if self.pool.startswith("spatial"): + if self.pool.startswith('spatial'): results.append(h.type(x.dtype).mean(dim=(2, 3))) h = self.middle_block(h, emb) - if self.pool.startswith("spatial"): + if self.pool.startswith('spatial'): results.append(h.type(x.dtype).mean(dim=(2, 3))) h = th.cat(results, axis=-1) return self.out(h) else: h = h.type(x.dtype) return self.out(h) - diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py index 6b5b9dc9e2..197b42b2bc 100644 --- a/ldm/modules/diffusionmodules/util.py +++ b/ldm/modules/diffusionmodules/util.py @@ -18,15 +18,24 @@ from einops import repeat from ldm.util import instantiate_from_config -def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - if schedule == "linear": +def make_beta_schedule( + schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): + if schedule == 'linear': betas = ( - torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + torch.linspace( + linear_start**0.5, + linear_end**0.5, + n_timestep, + dtype=torch.float64, + ) + ** 2 ) - elif schedule == "cosine": + elif schedule == 'cosine': timesteps = ( - torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + + cosine_s ) alphas = timesteps / (1 + cosine_s) * np.pi / 2 alphas = torch.cos(alphas).pow(2) @@ -34,23 +43,41 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, betas = 1 - alphas[1:] / alphas[:-1] betas = np.clip(betas, a_min=0, a_max=0.999) - elif schedule == "sqrt_linear": - betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) - elif schedule == "sqrt": - betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + elif schedule == 'sqrt_linear': + betas = torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64 + ) + elif schedule == 'sqrt': + betas = ( + torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64 + ) + ** 0.5 + ) else: raise ValueError(f"schedule '{schedule}' unknown.") return betas.numpy() -def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): +def make_ddim_timesteps( + ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True +): if ddim_discr_method == 'uniform': c = num_ddpm_timesteps // num_ddim_timesteps ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) elif ddim_discr_method == 'quad': - ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + ddim_timesteps = ( + ( + np.linspace( + 0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps + ) + ) + ** 2 + ).astype(int) else: - raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) # assert ddim_timesteps.shape[0] == num_ddim_timesteps # add one to get the final alpha values right (the ones from first scale to data during sampling) @@ -60,17 +87,27 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep 
return steps_out -def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): +def make_ddim_sampling_parameters( + alphacums, ddim_timesteps, eta, verbose=True +): # select alphas for computing the variance schedule alphas = alphacums[ddim_timesteps] - alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + alphas_prev = np.asarray( + [alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist() + ) # according the the formula provided in https://arxiv.org/abs/2010.02502 - sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) if verbose: - print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') - print(f'For the chosen value of eta, which is {eta}, ' - f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + print( + f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}' + ) + print( + f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}' + ) return sigmas, alphas, alphas_prev @@ -109,7 +146,9 @@ def checkpoint(func, inputs, params, flag): explicitly take as arguments. :param flag: if False, disable gradient checkpointing. """ - if False: # disabled checkpointing to allow requires_grad = False for main model + if ( + False + ): # disabled checkpointing to allow requires_grad = False for main model args = tuple(inputs) + tuple(params) return CheckpointFunction.apply(func, len(inputs), *args) else: @@ -129,7 +168,9 @@ class CheckpointFunction(torch.autograd.Function): @staticmethod def backward(ctx, *output_grads): - ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + ctx.input_tensors = [ + x.detach().requires_grad_(True) for x in ctx.input_tensors + ] with torch.enable_grad(): # Fixes a bug where the first op in run_function modifies the # Tensor storage in place, which is not allowed for detach()'d @@ -160,12 +201,16 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): if not repeat_only: half = dim // 2 freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half ).to(device=timesteps.device) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) else: embedding = repeat(timesteps, 'b -> b d', d=dim) return embedding @@ -215,6 +260,7 @@ class GroupNorm32(nn.GroupNorm): def forward(self, x): return super().forward(x.float()).type(x.dtype) + def conv_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D convolution module. 
@@ -225,7 +271,7 @@ def conv_nd(dims, *args, **kwargs): return nn.Conv2d(*args, **kwargs) elif dims == 3: return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") + raise ValueError(f'unsupported dimensions: {dims}') def linear(*args, **kwargs): @@ -245,15 +291,16 @@ def avg_pool_nd(dims, *args, **kwargs): return nn.AvgPool2d(*args, **kwargs) elif dims == 3: return nn.AvgPool3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") + raise ValueError(f'unsupported dimensions: {dims}') class HybridConditioner(nn.Module): - def __init__(self, c_concat_config, c_crossattn_config): super().__init__() self.concat_conditioner = instantiate_from_config(c_concat_config) - self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + self.crossattn_conditioner = instantiate_from_config( + c_crossattn_config + ) def forward(self, c_concat, c_crossattn): c_concat = self.concat_conditioner(c_concat) @@ -262,6 +309,8 @@ class HybridConditioner(nn.Module): def noise_like(shape, device, repeat=False): - repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() \ No newline at end of file + return repeat_noise() if repeat else noise() diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py index f2b8ef9011..67ed535791 100644 --- a/ldm/modules/distributions/distributions.py +++ b/ldm/modules/distributions/distributions.py @@ -30,33 +30,45 @@ class DiagonalGaussianDistribution(object): self.std = torch.exp(0.5 * self.logvar) self.var = torch.exp(self.logvar) if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) return x def kl(self, other=None): if self.deterministic: - return torch.Tensor([0.]) + return torch.Tensor([0.0]) else: if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) - + self.var - 1.0 - self.logvar, - dim=[1, 2, 3]) + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) else: return 0.5 * torch.sum( torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - 1.0 - self.logvar + other.logvar, - dim=[1, 2, 3]) + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) - def nll(self, sample, dims=[1,2,3]): + def nll(self, sample, dims=[1, 2, 3]): if self.deterministic: - return torch.Tensor([0.]) + return torch.Tensor([0.0]) logtwopi = np.log(2.0 * np.pi) return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims) + logtwopi + + self.logvar + + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) def mode(self): return self.mean @@ -74,7 +86,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2): if isinstance(obj, torch.Tensor): tensor = obj break - assert tensor is not None, "at least one argument must be a Tensor" + assert tensor is not None, 'at least one argument must be a Tensor' # Force variances to be Tensors. 
Broadcasting helps convert scalars to # Tensors, but it does not work for torch.exp(). diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py index c8c75af435..2ceec5f0e7 100644 --- a/ldm/modules/ema.py +++ b/ldm/modules/ema.py @@ -10,24 +10,30 @@ class LitEma(nn.Module): self.m_name2s_name = {} self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) - self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates - else torch.tensor(-1,dtype=torch.int)) + self.register_buffer( + 'num_updates', + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int), + ) for name, p in model.named_parameters(): if p.requires_grad: - #remove as '.'-character is not allowed in buffers - s_name = name.replace('.','') - self.m_name2s_name.update({name:s_name}) - self.register_buffer(s_name,p.clone().detach().data) + # remove as '.'-character is not allowed in buffers + s_name = name.replace('.', '') + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) self.collected_params = [] - def forward(self,model): + def forward(self, model): decay = self.decay if self.num_updates >= 0: self.num_updates += 1 - decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + decay = min( + self.decay, (1 + self.num_updates) / (10 + self.num_updates) + ) one_minus_decay = 1.0 - decay @@ -38,8 +44,12 @@ class LitEma(nn.Module): for key in m_param: if m_param[key].requires_grad: sname = self.m_name2s_name[key] - shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) - shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + shadow_params[sname] = shadow_params[sname].type_as( + m_param[key] + ) + shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) else: assert not key in self.m_name2s_name @@ -48,7 +58,9 @@ class LitEma(nn.Module): shadow_params = dict(self.named_buffers()) for key in m_param: if m_param[key].requires_grad: - m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + m_param[key].data.copy_( + shadow_params[self.m_name2s_name[key]].data + ) else: assert not key in self.m_name2s_name diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py index 7020a27b9a..677bc4ad3a 100644 --- a/ldm/modules/embedding_manager.py +++ b/ldm/modules/embedding_manager.py @@ -8,18 +8,29 @@ from ldm.data.personalized import per_img_token_list from transformers import CLIPTokenizer from functools import partial -DEFAULT_PLACEHOLDER_TOKEN = ["*"] +DEFAULT_PLACEHOLDER_TOKEN = ['*'] PROGRESSIVE_SCALE = 2000 + def get_clip_token_for_string(tokenizer, string): - batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") - tokens = batch_encoding["input_ids"] - assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string" + batch_encoding = tokenizer( + string, + truncation=True, + max_length=77, + return_length=True, + return_overflowing_tokens=False, + padding='max_length', + return_tensors='pt', + ) + tokens = batch_encoding['input_ids'] + assert ( + torch.count_nonzero(tokens - 49407) == 2 + ), f"String '{string}' maps to more than a single token. 
Please use another string" return tokens[0, 1] + def get_bert_token_for_string(tokenizer, string): token = tokenizer(string) # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string" @@ -28,42 +39,54 @@ def get_bert_token_for_string(tokenizer, string): return token + def get_embedding_for_clip_token(embedder, token): return embedder(token.unsqueeze(0))[0, 0] class EmbeddingManager(nn.Module): def __init__( - self, - embedder, - placeholder_strings=None, - initializer_words=None, - per_image_tokens=False, - num_vectors_per_token=1, - progressive_words=False, - **kwargs + self, + embedder, + placeholder_strings=None, + initializer_words=None, + per_image_tokens=False, + num_vectors_per_token=1, + progressive_words=False, + **kwargs, ): super().__init__() self.string_to_token_dict = {} - + self.string_to_param_dict = nn.ParameterDict() - self.initial_embeddings = nn.ParameterDict() # These should not be optimized + self.initial_embeddings = ( + nn.ParameterDict() + ) # These should not be optimized self.progressive_words = progressive_words self.progressive_counter = 0 self.max_vectors_per_token = num_vectors_per_token - if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder + if hasattr( + embedder, 'tokenizer' + ): # using Stable Diffusion's CLIP encoder self.is_clip = True - get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) - get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings) + get_token_for_string = partial( + get_clip_token_for_string, embedder.tokenizer + ) + get_embedding_for_tkn = partial( + get_embedding_for_clip_token, + embedder.transformer.text_model.embeddings, + ) token_dim = 1280 - else: # using LDM's BERT encoder + else: # using LDM's BERT encoder self.is_clip = False - get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn) + get_token_for_string = partial( + get_bert_token_for_string, embedder.tknz_fn + ) get_embedding_for_tkn = embedder.transformer.token_emb token_dim = 1280 @@ -71,79 +94,140 @@ class EmbeddingManager(nn.Module): placeholder_strings.extend(per_img_token_list) for idx, placeholder_string in enumerate(placeholder_strings): - + token = get_token_for_string(placeholder_string) if initializer_words and idx < len(initializer_words): init_word_token = get_token_for_string(initializer_words[idx]) with torch.no_grad(): - init_word_embedding = get_embedding_for_tkn(init_word_token.cpu()) + init_word_embedding = get_embedding_for_tkn( + init_word_token.cpu() + ) - token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True) - self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False) + token_params = torch.nn.Parameter( + init_word_embedding.unsqueeze(0).repeat( + num_vectors_per_token, 1 + ), + requires_grad=True, + ) + self.initial_embeddings[ + placeholder_string + ] = torch.nn.Parameter( + init_word_embedding.unsqueeze(0).repeat( + num_vectors_per_token, 1 + ), + requires_grad=False, + ) else: - token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True)) - + token_params = torch.nn.Parameter( + torch.rand( + size=(num_vectors_per_token, token_dim), + requires_grad=True, + ) + ) + self.string_to_token_dict[placeholder_string] = token self.string_to_param_dict[placeholder_string] = 
token_params def forward( - self, - tokenized_text, - embedded_text, + self, + tokenized_text, + embedded_text, ): b, n, device = *tokenized_text.shape, tokenized_text.device - for placeholder_string, placeholder_token in self.string_to_token_dict.items(): + for ( + placeholder_string, + placeholder_token, + ) in self.string_to_token_dict.items(): - placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) + placeholder_embedding = self.string_to_param_dict[ + placeholder_string + ].to(device) - if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement - placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) + if ( + self.max_vectors_per_token == 1 + ): # If there's only one vector per token, we can do a simple replacement + placeholder_idx = torch.where( + tokenized_text == placeholder_token.to(device) + ) embedded_text[placeholder_idx] = placeholder_embedding - else: # otherwise, need to insert and keep track of changing indices + else: # otherwise, need to insert and keep track of changing indices if self.progressive_words: self.progressive_counter += 1 - max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE + max_step_tokens = ( + 1 + self.progressive_counter // PROGRESSIVE_SCALE + ) else: max_step_tokens = self.max_vectors_per_token - num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens) + num_vectors_for_token = min( + placeholder_embedding.shape[0], max_step_tokens + ) - placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device)) + placeholder_rows, placeholder_cols = torch.where( + tokenized_text == placeholder_token.to(device) + ) if placeholder_rows.nelement() == 0: continue - sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True) + sorted_cols, sort_idx = torch.sort( + placeholder_cols, descending=True + ) sorted_rows = placeholder_rows[sort_idx] for idx in range(len(sorted_rows)): row = sorted_rows[idx] col = sorted_cols[idx] - new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n] - new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n] + new_token_row = torch.cat( + [ + tokenized_text[row][:col], + placeholder_token.repeat(num_vectors_for_token).to( + device + ), + tokenized_text[row][col + 1 :], + ], + axis=0, + )[:n] + new_embed_row = torch.cat( + [ + embedded_text[row][:col], + placeholder_embedding[:num_vectors_for_token], + embedded_text[row][col + 1 :], + ], + axis=0, + )[:n] - embedded_text[row] = new_embed_row + embedded_text[row] = new_embed_row tokenized_text[row] = new_token_row return embedded_text def save(self, ckpt_path): - torch.save({"string_to_token": self.string_to_token_dict, - "string_to_param": self.string_to_param_dict}, ckpt_path) + torch.save( + { + 'string_to_token': self.string_to_token_dict, + 'string_to_param': self.string_to_param_dict, + }, + ckpt_path, + ) def load(self, ckpt_path): ckpt = torch.load(ckpt_path, map_location='cpu') - self.string_to_token_dict = ckpt["string_to_token"] - self.string_to_param_dict = ckpt["string_to_param"] + self.string_to_token_dict = ckpt['string_to_token'] + self.string_to_param_dict = ckpt['string_to_param'] def get_embedding_norms_squared(self): - all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x 
embedding_dim - param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders + all_params = torch.cat( + list(self.string_to_param_dict.values()), axis=0 + ) # num_placeholders x embedding_dim + param_norm_squared = (all_params * all_params).sum( + axis=-1 + ) # num_placeholders return param_norm_squared @@ -151,14 +235,19 @@ class EmbeddingManager(nn.Module): return self.string_to_param_dict.parameters() def embedding_to_coarse_loss(self): - - loss = 0. + + loss = 0.0 num_embeddings = len(self.initial_embeddings) for key in self.initial_embeddings: optimized = self.string_to_param_dict[key] coarse = self.initial_embeddings[key].clone().to(optimized.device) - loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings + loss = ( + loss + + (optimized - coarse) + @ (optimized - coarse).T + / num_embeddings + ) - return loss \ No newline at end of file + return loss diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index def6d2136d..2c25948b5c 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -6,29 +6,39 @@ from einops import rearrange, repeat from transformers import CLIPTokenizer, CLIPTextModel import kornia -from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test +from ldm.modules.x_transformer import ( + Encoder, + TransformerWrapper, +) # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test -def _expand_mask(mask, dtype, tgt_len = None): + +def _expand_mask(mask, dtype, tgt_len=None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + expanded_mask = ( + mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + ) inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + def _build_causal_attention_mask(bsz, seq_len, dtype): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) - mask.fill_(torch.tensor(torch.finfo(dtype).min)) - mask.triu_(1) # zero out the lower diagonal - mask = mask.unsqueeze(1) # expand mask - return mask + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) + mask.fill_(torch.tensor(torch.finfo(dtype).min)) + mask.triu_(1) # zero out the lower diagonal + mask = mask.unsqueeze(1) # expand mask + return mask + class AbstractEncoder(nn.Module): def __init__(self): @@ -38,7 +48,6 @@ class AbstractEncoder(nn.Module): raise NotImplementedError - class ClassEmbedder(nn.Module): def __init__(self, embed_dim, n_classes=1000, key='class'): super().__init__() @@ -56,11 +65,17 @@ class ClassEmbedder(nn.Module): class TransformerEmbedder(AbstractEncoder): """Some transformer encoder layers""" - def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + + def __init__( + self, n_embed, n_layer, vocab_size, max_seq_len=77, device='cuda' + ): 
super().__init__() self.device = device - self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, - attn_layers=Encoder(dim=n_embed, depth=n_layer)) + self.transformer = TransformerWrapper( + num_tokens=vocab_size, + max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + ) def forward(self, tokens): tokens = tokens.to(self.device) # meh @@ -72,27 +87,42 @@ class TransformerEmbedder(AbstractEncoder): class BERTTokenizer(AbstractEncoder): - """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" - def __init__(self, device="cuda", vq_interface=True, max_length=77): + """Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" + + def __init__(self, device='cuda', vq_interface=True, max_length=77): super().__init__() - from transformers import BertTokenizerFast # TODO: add to reuquirements + from transformers import ( + BertTokenizerFast, + ) # TODO: add to reuquirements + # Modified to allow to run on non-internet connected compute nodes. # Model needs to be loaded into cache from an internet-connected machine # by running: # from transformers import BertTokenizerFast # BertTokenizerFast.from_pretrained("bert-base-uncased") try: - self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",local_files_only=True) + self.tokenizer = BertTokenizerFast.from_pretrained( + 'bert-base-uncased', local_files_only=True + ) except OSError: - raise SystemExit("* Couldn't load Bert tokenizer files. Try running scripts/preload_models.py from an internet-conected machine.") + raise SystemExit( + "* Couldn't load Bert tokenizer files. Try running scripts/preload_models.py from an internet-conected machine." + ) self.device = device self.vq_interface = vq_interface self.max_length = max_length def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") - tokens = batch_encoding["input_ids"].to(self.device) + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding='max_length', + return_tensors='pt', + ) + tokens = batch_encoding['input_ids'].to(self.device) return tokens @torch.no_grad() @@ -108,53 +138,84 @@ class BERTTokenizer(AbstractEncoder): class BERTEmbedder(AbstractEncoder): """Uses the BERT tokenizr model and add some transformer encoder layers""" - def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, - device="cuda",use_tokenizer=True, embedding_dropout=0.0): + + def __init__( + self, + n_embed, + n_layer, + vocab_size=30522, + max_seq_len=77, + device='cuda', + use_tokenizer=True, + embedding_dropout=0.0, + ): super().__init__() self.use_tknz_fn = use_tokenizer if self.use_tknz_fn: - self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) + self.tknz_fn = BERTTokenizer( + vq_interface=False, max_length=max_seq_len + ) self.device = device - self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, - attn_layers=Encoder(dim=n_embed, depth=n_layer), - emb_dropout=embedding_dropout) + self.transformer = TransformerWrapper( + num_tokens=vocab_size, + max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout, + ) def forward(self, text, embedding_manager=None): if self.use_tknz_fn: - tokens = self.tknz_fn(text)#.to(self.device) + tokens = 
self.tknz_fn(text) # .to(self.device) else: tokens = text - z = self.transformer(tokens, return_embeddings=True, embedding_manager=embedding_manager) + z = self.transformer( + tokens, return_embeddings=True, embedding_manager=embedding_manager + ) return z def encode(self, text, **kwargs): # output of length 77 return self(text, **kwargs) + class SpatialRescaler(nn.Module): - def __init__(self, - n_stages=1, - method='bilinear', - multiplier=0.5, - in_channels=3, - out_channels=None, - bias=False): + def __init__( + self, + n_stages=1, + method='bilinear', + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False, + ): super().__init__() self.n_stages = n_stages assert self.n_stages >= 0 - assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] + assert method in [ + 'nearest', + 'linear', + 'bilinear', + 'trilinear', + 'bicubic', + 'area', + ] self.multiplier = multiplier - self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.interpolator = partial( + torch.nn.functional.interpolate, mode=method + ) self.remap_output = out_channels is not None if self.remap_output: - print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') - self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) + print( + f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.' + ) + self.channel_mapper = nn.Conv2d( + in_channels, out_channels, 1, bias=bias + ) - def forward(self,x): + def forward(self, x): for stage in range(self.n_stages): x = self.interpolator(x, scale_factor=self.multiplier) - if self.remap_output: x = self.channel_mapper(x) return x @@ -162,57 +223,83 @@ class SpatialRescaler(nn.Module): def encode(self, x): return self(x) + class FrozenCLIPEmbedder(AbstractEncoder): """Uses the CLIP transformer encoder for text (from Hugging Face)""" - def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): + + def __init__( + self, + version='openai/clip-vit-large-patch14', + device='cuda', + max_length=77, + ): super().__init__() - self.tokenizer = CLIPTokenizer.from_pretrained(version,local_files_only=True) - self.transformer = CLIPTextModel.from_pretrained(version,local_files_only=True) + self.tokenizer = CLIPTokenizer.from_pretrained( + version, local_files_only=True + ) + self.transformer = CLIPTextModel.from_pretrained( + version, local_files_only=True + ) self.device = device self.max_length = max_length self.freeze() def embedding_forward( - self, - input_ids = None, - position_ids = None, - inputs_embeds = None, - embedding_manager = None, - ) -> torch.Tensor: + self, + input_ids=None, + position_ids=None, + inputs_embeds=None, + embedding_manager=None, + ) -> torch.Tensor: - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + seq_length = ( + input_ids.shape[-1] + if input_ids is not None + else inputs_embeds.shape[-2] + ) - if position_ids is None: - position_ids = self.position_ids[:, :seq_length] + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] - if inputs_embeds is None: - inputs_embeds = self.token_embedding(input_ids) + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) - if embedding_manager is not None: - inputs_embeds = embedding_manager(input_ids, inputs_embeds) + if embedding_manager is not None: + inputs_embeds = embedding_manager(input_ids, inputs_embeds) + position_embeddings = self.position_embedding(position_ids) + 
embeddings = inputs_embeds + position_embeddings - position_embeddings = self.position_embedding(position_ids) - embeddings = inputs_embeds + position_embeddings - - return embeddings + return embeddings - self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings) + self.transformer.text_model.embeddings.forward = ( + embedding_forward.__get__(self.transformer.text_model.embeddings) + ) def encoder_forward( self, inputs_embeds, - attention_mask = None, - causal_attention_mask = None, - output_attentions = None, - output_hidden_states = None, - return_dict = None, + attention_mask=None, + causal_attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict + if return_dict is not None + else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -239,44 +326,61 @@ class FrozenCLIPEmbedder(AbstractEncoder): return hidden_states - self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder) - + self.transformer.text_model.encoder.forward = encoder_forward.__get__( + self.transformer.text_model.encoder + ) def text_encoder_forward( self, - input_ids = None, - attention_mask = None, - position_ids = None, - output_attentions = None, - output_hidden_states = None, - return_dict = None, - embedding_manager = None, + input_ids=None, + attention_mask=None, + position_ids=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + embedding_manager=None, ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict + if return_dict is not None + else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is None: - raise ValueError("You have to specify either input_ids") + raise ValueError('You have to specify either input_ids') input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) - hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager) + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + embedding_manager=embedding_manager, + ) bsz, seq_len = input_shape # CLIP's text model uses causal mask, prepare it here. 
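# A minimal standalone sketch (illustration only, not part of this diff): the
# additive causal mask produced by _build_causal_attention_mask above is zero on
# and below the diagonal and a very large negative number strictly above it, so
# after softmax each token attends only to itself and earlier positions.
import torch

bsz, seq_len, dtype = 1, 4, torch.float32
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.finfo(dtype).min)
mask.triu_(1)                # keep only the strictly upper triangle
mask = mask.unsqueeze(1)     # [bsz, 1, seq_len, seq_len]
print(mask[0, 0])            # row i is masked for columns j > i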
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 - causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( - hidden_states.device - ) + causal_attention_mask = _build_causal_attention_mask( + bsz, seq_len, hidden_states.dtype + ).to(hidden_states.device) # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _expand_mask( + attention_mask, hidden_states.dtype + ) last_hidden_state = self.encoder( inputs_embeds=hidden_states, @@ -291,17 +395,19 @@ class FrozenCLIPEmbedder(AbstractEncoder): return last_hidden_state - self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model) + self.transformer.text_model.forward = text_encoder_forward.__get__( + self.transformer.text_model + ) def transformer_forward( self, - input_ids = None, - attention_mask = None, - position_ids = None, - output_attentions = None, - output_hidden_states = None, - return_dict = None, - embedding_manager = None, + input_ids=None, + attention_mask=None, + position_ids=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + embedding_manager=None, ): return self.text_model( input_ids=input_ids, @@ -310,11 +416,12 @@ class FrozenCLIPEmbedder(AbstractEncoder): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - embedding_manager = embedding_manager + embedding_manager=embedding_manager, ) - self.transformer.forward = transformer_forward.__get__(self.transformer) - + self.transformer.forward = transformer_forward.__get__( + self.transformer + ) def freeze(self): self.transformer = self.transformer.eval() @@ -322,9 +429,16 @@ class FrozenCLIPEmbedder(AbstractEncoder): param.requires_grad = False def forward(self, text, **kwargs): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") - tokens = batch_encoding["input_ids"].to(self.device) + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding='max_length', + return_tensors='pt', + ) + tokens = batch_encoding['input_ids'].to(self.device) z = self.transformer(input_ids=tokens, **kwargs) return z @@ -337,9 +451,17 @@ class FrozenCLIPTextEmbedder(nn.Module): """ Uses the CLIP transformer encoder for text. """ - def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): + + def __init__( + self, + version='ViT-L/14', + device='cuda', + max_length=77, + n_repeat=1, + normalize=True, + ): super().__init__() - self.model, _ = clip.load(version, jit=False, device="cpu") + self.model, _ = clip.load(version, jit=False, device='cpu') self.device = device self.max_length = max_length self.n_repeat = n_repeat @@ -359,7 +481,7 @@ class FrozenCLIPTextEmbedder(nn.Module): def encode(self, text): z = self(text) - if z.ndim==2: + if z.ndim == 2: z = z[:, None, :] z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) return z @@ -367,29 +489,42 @@ class FrozenCLIPTextEmbedder(nn.Module): class FrozenClipImageEmbedder(nn.Module): """ - Uses the CLIP image encoder. - """ + Uses the CLIP image encoder. 
+ """ + def __init__( - self, - model, - jit=False, - device='cuda' if torch.cuda.is_available() else 'cpu', - antialias=False, - ): + self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=False, + ): super().__init__() self.model, _ = clip.load(name=model, device=device, jit=jit) self.antialias = antialias - self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) - self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + self.register_buffer( + 'mean', + torch.Tensor([0.48145466, 0.4578275, 0.40821073]), + persistent=False, + ) + self.register_buffer( + 'std', + torch.Tensor([0.26862954, 0.26130258, 0.27577711]), + persistent=False, + ) def preprocess(self, x): # normalize to [0,1] - x = kornia.geometry.resize(x, (224, 224), - interpolation='bicubic',align_corners=True, - antialias=self.antialias) - x = (x + 1.) / 2. + x = kornia.geometry.resize( + x, + (224, 224), + interpolation='bicubic', + align_corners=True, + antialias=self.antialias, + ) + x = (x + 1.0) / 2.0 # renormalize according to clip x = kornia.enhance.normalize(x, self.mean, self.std) return x @@ -399,7 +534,8 @@ class FrozenClipImageEmbedder(nn.Module): return self.model.encode_image(self.preprocess(x)) -if __name__ == "__main__": +if __name__ == '__main__': from ldm.util import count_params + model = FrozenCLIPEmbedder() count_params(model, verbose=True) diff --git a/ldm/modules/image_degradation/__init__.py b/ldm/modules/image_degradation/__init__.py index 7836cada81..c6b3b62ea8 100644 --- a/ldm/modules/image_degradation/__init__.py +++ b/ldm/modules/image_degradation/__init__.py @@ -1,2 +1,6 @@ -from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr -from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light +from ldm.modules.image_degradation.bsrgan import ( + degradation_bsrgan_variant as degradation_fn_bsr, +) +from ldm.modules.image_degradation.bsrgan_light import ( + degradation_bsrgan_variant as degradation_fn_bsr_light, +) diff --git a/ldm/modules/image_degradation/bsrgan.py b/ldm/modules/image_degradation/bsrgan.py index 32ef561699..b51217bd48 100644 --- a/ldm/modules/image_degradation/bsrgan.py +++ b/ldm/modules/image_degradation/bsrgan.py @@ -27,16 +27,16 @@ import ldm.modules.image_degradation.utils_image as util def modcrop_np(img, sf): - ''' + """ Args: img: numpy image, WxH or WxHxC sf: scale factor Return: cropped image - ''' + """ w, h = img.shape[:2] im = np.copy(img) - return im[:w - w % sf, :h - h % sf, ...] + return im[: w - w % sf, : h - h % sf, ...] 
""" @@ -54,7 +54,9 @@ def analytic_kernel(k): # Loop over the small kernel to fill the big one for r in range(k_size): for c in range(k_size): - big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += ( + k[r, c] * k + ) # Crop the edges of the big kernel to ignore very small values and increase run time of SR crop = k_size // 2 cropped_big_k = big_k[crop:-crop, crop:-crop] @@ -63,7 +65,7 @@ def analytic_kernel(k): def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): - """ generate an anisotropic Gaussian kernel + """generate an anisotropic Gaussian kernel Args: ksize : e.g., 15, kernel size theta : [0, pi], rotation angle range @@ -74,7 +76,12 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): k : kernel """ - v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + v = np.dot( + np.array( + [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] + ), + np.array([1.0, 0.0]), + ) V = np.array([[v[0], v[1]], [v[1], -v[0]]]) D = np.array([[l1, 0], [0, l2]]) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) @@ -126,24 +133,32 @@ def shift_pixel(x, sf, upper_left=True): def blur(x, k): - ''' + """ x: image, NxcxHxW k: kernel, Nx1xhxw - ''' + """ n, c = x.shape[:2] p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') k = k.repeat(1, c, 1, 1) k = k.view(-1, 1, k.shape[2], k.shape[3]) x = x.view(1, -1, x.shape[2], x.shape[3]) - x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = torch.nn.functional.conv2d( + x, k, bias=None, stride=1, padding=0, groups=n * c + ) x = x.view(n, c, x.shape[2], x.shape[3]) return x -def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): - """" +def gen_kernel( + k_size=np.array([15, 15]), + scale_factor=np.array([4, 4]), + min_var=0.6, + max_var=10.0, + noise_level=0, +): + """ " # modified version of https://github.com/assafshocher/BlindSR_dataset_generator # Kai Zhang # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var @@ -157,13 +172,16 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var # Set COV matrix using Lambdas and Theta LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) + Q = np.array( + [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] + ) SIGMA = Q @ LAMBDA @ Q.T INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] # Set expectation position (shifting kernel for aligned image) - MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = k_size // 2 - 0.5 * ( + scale_factor - 1 + ) # - 0.5 * (scale_factor - k_size % 2) MU = MU[None, None, :, None] # Create meshgrid for Gaussian @@ -188,7 +206,9 @@ def fspecial_gaussian(hsize, sigma): hsize = [hsize, hsize] siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] std = sigma - [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + [x, y] = np.meshgrid( + np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1) + ) arg = -(x * x + y * y) / (2 * std * std) h = np.exp(arg) h[h < scipy.finfo(float).eps * h.max()] = 0 @@ -208,10 +228,10 @@ def fspecial_laplacian(alpha): def fspecial(filter_type, *args, **kwargs): - ''' + """ python code from: 
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py - ''' + """ if filter_type == 'gaussian': return fspecial_gaussian(*args, **kwargs) if filter_type == 'laplacian': @@ -226,19 +246,19 @@ def fspecial(filter_type, *args, **kwargs): def bicubic_degradation(x, sf=3): - ''' + """ Args: x: HxWxC image, [0, 1] sf: down-scale factor Return: bicubicly downsampled LR image - ''' + """ x = util.imresize_np(x, scale=1 / sf) return x def srmd_degradation(x, k, sf=3): - ''' blur + bicubic downsampling + """blur + bicubic downsampling Args: x: HxWxC image, [0, 1] k: hxw, double @@ -253,14 +273,16 @@ def srmd_degradation(x, k, sf=3): pages={3262--3271}, year={2018} } - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + """ + x = ndimage.filters.convolve( + x, np.expand_dims(k, axis=2), mode='wrap' + ) # 'nearest' | 'mirror' x = bicubic_degradation(x, sf=sf) return x def dpsr_degradation(x, k, sf=3): - ''' bicubic downsampling + blur + """bicubic downsampling + blur Args: x: HxWxC image, [0, 1] k: hxw, double @@ -275,21 +297,21 @@ def dpsr_degradation(x, k, sf=3): pages={1671--1681}, year={2019} } - ''' + """ x = bicubic_degradation(x, sf=sf) x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') return x def classical_degradation(x, k, sf=3): - ''' blur + downsampling + """blur + downsampling Args: x: HxWxC image, [0, 1]/[0, 255] k: hxw, double sf: down-scale factor Return: downsampled LR image - ''' + """ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) st = 0 @@ -328,10 +350,19 @@ def add_blur(img, sf=4): if random.random() < 0.5: l1 = wd2 * random.random() l2 = wd2 * random.random() - k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + k = anisotropic_Gaussian( + ksize=2 * random.randint(2, 11) + 3, + theta=random.random() * np.pi, + l1=l1, + l2=l2, + ) else: - k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) - img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + k = fspecial( + 'gaussian', 2 * random.randint(2, 11) + 3, wd * random.random() + ) + img = ndimage.filters.convolve( + img, np.expand_dims(k, axis=2), mode='mirror' + ) return img @@ -344,7 +375,11 @@ def add_resize(img, sf=4): sf1 = random.uniform(0.5 / sf, 1) else: sf1 = 1.0 - img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, + (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) img = np.clip(img, 0.0, 1.0) return img @@ -366,19 +401,26 @@ def add_resize(img, sf=4): # img = np.clip(img, 0.0, 1.0) # return img + def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) rnum = np.random.rand() if rnum > 0.6: # add color Gaussian noise - img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype( + np.float32 + ) elif rnum < 0.4: # add grayscale Gaussian noise - img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + img = img + np.random.normal( + 0, noise_level / 255.0, (*img.shape[:2], 1) + ).astype(np.float32) else: # add noise - L = noise_level2 / 255. 
+ L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = img + np.random.multivariate_normal( + [0, 0, 0], np.abs(L**2 * conv), img.shape[:2] + ).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img @@ -388,28 +430,37 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25): img = np.clip(img, 0.0, 1.0) rnum = random.random() if rnum > 0.6: - img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + img += img * np.random.normal( + 0, noise_level / 255.0, img.shape + ).astype(np.float32) elif rnum < 0.4: - img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + img += img * np.random.normal( + 0, noise_level / 255.0, (*img.shape[:2], 1) + ).astype(np.float32) else: - L = noise_level2 / 255. + L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img += img * np.random.multivariate_normal( + [0, 0, 0], np.abs(L**2 * conv), img.shape[:2] + ).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img def add_Poisson_noise(img): - img = np.clip((img * 255.0).round(), 0, 255) / 255. + img = np.clip((img * 255.0).round(), 0, 255) / 255.0 vals = 10 ** (2 * random.random() + 2.0) # [2, 4] if random.random() < 0.5: img = np.random.poisson(img * vals).astype(np.float32) / vals else: img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) - img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. - noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0 + noise_gray = ( + np.random.poisson(img_gray * vals).astype(np.float32) / vals + - img_gray + ) img += noise_gray[:, :, np.newaxis] img = np.clip(img, 0.0, 1.0) return img @@ -418,7 +469,9 @@ def add_Poisson_noise(img): def add_JPEG_noise(img): quality_factor = random.randint(30, 95) img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) - result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + result, encimg = cv2.imencode( + '.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor] + ) img = cv2.imdecode(encimg, 1) img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) return img @@ -428,10 +481,14 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64): h, w = lq.shape[:2] rnd_h = random.randint(0, h - lq_patchsize) rnd_w = random.randint(0, w - lq_patchsize) - lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :] rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) - hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + hq = hq[ + rnd_h_H : rnd_h_H + lq_patchsize * sf, + rnd_w_H : rnd_w_H + lq_patchsize * sf, + :, + ] return lq, hq @@ -452,7 +509,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): sf_ori = sf h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] 
# mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: @@ -462,8 +519,11 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, + (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: img = util.imresize_np(img, 1 / 2, True) img = np.clip(img, 0.0, 1.0) @@ -472,7 +532,10 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): shuffle_order = random.sample(range(7), 7) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) if idx1 > idx2: # keep downsample3 last - shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + shuffle_order[idx1], shuffle_order[idx2] = ( + shuffle_order[idx2], + shuffle_order[idx1], + ) for i in shuffle_order: @@ -487,19 +550,30 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) - img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, + (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) - k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + k_shifted = ( + k_shifted / k_shifted.sum() + ) # blur with shifted kernel + img = ndimage.filters.convolve( + img, np.expand_dims(k_shifted, axis=2), mode='mirror' + ) img = img[0::sf, 0::sf, ...] # nearest downsampling img = np.clip(img, 0.0, 1.0) elif i == 3: # downsample3 - img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, + (int(1 / sf * a), int(1 / sf * b)), + interpolation=random.choice([1, 2, 3]), + ) img = np.clip(img, 0.0, 1.0) elif i == 4: @@ -544,15 +618,18 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): sf_ori = sf h1, w1 = image.shape[:2] - image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] 
# mod crop h, w = image.shape[:2] hq = image.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: image = util.imresize_np(image, 1 / 2, True) image = np.clip(image, 0.0, 1.0) @@ -561,7 +638,10 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): shuffle_order = random.sample(range(7), 7) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) if idx1 > idx2: # keep downsample3 last - shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + shuffle_order[idx1], shuffle_order[idx2] = ( + shuffle_order[idx2], + shuffle_order[idx1], + ) for i in shuffle_order: @@ -576,19 +656,33 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) - image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + ( + int(1 / sf1 * image.shape[1]), + int(1 / sf1 * image.shape[0]), + ), + interpolation=random.choice([1, 2, 3]), + ) else: k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) - k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + k_shifted = ( + k_shifted / k_shifted.sum() + ) # blur with shifted kernel + image = ndimage.filters.convolve( + image, np.expand_dims(k_shifted, axis=2), mode='mirror' + ) image = image[0::sf, 0::sf, ...] # nearest downsampling image = np.clip(image, 0.0, 1.0) elif i == 3: # downsample3 - image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / sf * a), int(1 / sf * b)), + interpolation=random.choice([1, 2, 3]), + ) image = np.clip(image, 0.0, 1.0) elif i == 4: @@ -609,12 +703,19 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): # add final JPEG compression noise image = add_JPEG_noise(image) image = util.single2uint(image) - example = {"image":image} + example = {'image': image} return example # TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... -def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): +def degradation_bsrgan_plus( + img, + sf=4, + shuffle_prob=0.5, + use_sharp=True, + lq_patchsize=64, + isp_model=None, +): """ This is an extended degradation model by combining the degradation models of BSRGAN and Real-ESRGAN @@ -630,7 +731,7 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc """ h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] 
# mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: @@ -645,8 +746,12 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc else: shuffle_order = list(range(13)) # local shuffle for noise, JPEG is always the last one - shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) - shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + shuffle_order[2:6] = random.sample( + shuffle_order[2:6], len(range(2, 6)) + ) + shuffle_order[9:13] = random.sample( + shuffle_order[9:13], len(range(9, 13)) + ) poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 @@ -689,8 +794,11 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc print('check the shuffle!') # resize to desired size - img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, + (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) # add final JPEG compression noise img = add_JPEG_noise(img) @@ -702,29 +810,37 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc if __name__ == '__main__': - print("hey") - img = util.imread_uint('utils/test.png', 3) - print(img) - img = util.uint2single(img) - print(img) - img = img[:448, :448] - h = img.shape[0] // 4 - print("resizing to", h) - sf = 4 - deg_fn = partial(degradation_bsrgan_variant, sf=sf) - for i in range(20): - print(i) - img_lq = deg_fn(img) - print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] - print(img_lq.shape) - print("bicubic", img_lq_bicubic.shape) - print(img_hq.shape) - lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) - util.imsave(img_concat, str(i) + '.png') - - + print('hey') + img = util.imread_uint('utils/test.png', 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print('resizing to', h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize( + max_size=h, interpolation=cv2.INTER_CUBIC + )(image=img)['image'] + print(img_lq.shape) + print('bicubic', img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize( + util.single2uint(img_lq), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0, + ) + lq_bicubic_nearest = cv2.resize( + util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0, + ) + img_concat = np.concatenate( + [lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1 + ) + util.imsave(img_concat, str(i) + '.png') diff --git a/ldm/modules/image_degradation/bsrgan_light.py b/ldm/modules/image_degradation/bsrgan_light.py index 9e1f823996..3500ef7316 100644 --- a/ldm/modules/image_degradation/bsrgan_light.py +++ b/ldm/modules/image_degradation/bsrgan_light.py @@ -27,16 +27,16 @@ import ldm.modules.image_degradation.utils_image as util def modcrop_np(img, sf): - ''' + """ Args: img: numpy image, WxH or WxHxC sf: 
scale factor Return: cropped image - ''' + """ w, h = img.shape[:2] im = np.copy(img) - return im[:w - w % sf, :h - h % sf, ...] + return im[: w - w % sf, : h - h % sf, ...] """ @@ -54,7 +54,9 @@ def analytic_kernel(k): # Loop over the small kernel to fill the big one for r in range(k_size): for c in range(k_size): - big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += ( + k[r, c] * k + ) # Crop the edges of the big kernel to ignore very small values and increase run time of SR crop = k_size // 2 cropped_big_k = big_k[crop:-crop, crop:-crop] @@ -63,7 +65,7 @@ def analytic_kernel(k): def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): - """ generate an anisotropic Gaussian kernel + """generate an anisotropic Gaussian kernel Args: ksize : e.g., 15, kernel size theta : [0, pi], rotation angle range @@ -74,7 +76,12 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): k : kernel """ - v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + v = np.dot( + np.array( + [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] + ), + np.array([1.0, 0.0]), + ) V = np.array([[v[0], v[1]], [v[1], -v[0]]]) D = np.array([[l1, 0], [0, l2]]) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) @@ -126,24 +133,32 @@ def shift_pixel(x, sf, upper_left=True): def blur(x, k): - ''' + """ x: image, NxcxHxW k: kernel, Nx1xhxw - ''' + """ n, c = x.shape[:2] p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') k = k.repeat(1, c, 1, 1) k = k.view(-1, 1, k.shape[2], k.shape[3]) x = x.view(1, -1, x.shape[2], x.shape[3]) - x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = torch.nn.functional.conv2d( + x, k, bias=None, stride=1, padding=0, groups=n * c + ) x = x.view(n, c, x.shape[2], x.shape[3]) return x -def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): - """" +def gen_kernel( + k_size=np.array([15, 15]), + scale_factor=np.array([4, 4]), + min_var=0.6, + max_var=10.0, + noise_level=0, +): + """ " # modified version of https://github.com/assafshocher/BlindSR_dataset_generator # Kai Zhang # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var @@ -157,13 +172,16 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var # Set COV matrix using Lambdas and Theta LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) + Q = np.array( + [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] + ) SIGMA = Q @ LAMBDA @ Q.T INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] # Set expectation position (shifting kernel for aligned image) - MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = k_size // 2 - 0.5 * ( + scale_factor - 1 + ) # - 0.5 * (scale_factor - k_size % 2) MU = MU[None, None, :, None] # Create meshgrid for Gaussian @@ -188,7 +206,9 @@ def fspecial_gaussian(hsize, sigma): hsize = [hsize, hsize] siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] std = sigma - [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + [x, y] = np.meshgrid( + np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1) + ) arg = -(x * x + y * y) / (2 * std * std) h = np.exp(arg) h[h < 
scipy.finfo(float).eps * h.max()] = 0 @@ -208,10 +228,10 @@ def fspecial_laplacian(alpha): def fspecial(filter_type, *args, **kwargs): - ''' + """ python code from: https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py - ''' + """ if filter_type == 'gaussian': return fspecial_gaussian(*args, **kwargs) if filter_type == 'laplacian': @@ -226,19 +246,19 @@ def fspecial(filter_type, *args, **kwargs): def bicubic_degradation(x, sf=3): - ''' + """ Args: x: HxWxC image, [0, 1] sf: down-scale factor Return: bicubicly downsampled LR image - ''' + """ x = util.imresize_np(x, scale=1 / sf) return x def srmd_degradation(x, k, sf=3): - ''' blur + bicubic downsampling + """blur + bicubic downsampling Args: x: HxWxC image, [0, 1] k: hxw, double @@ -253,14 +273,16 @@ def srmd_degradation(x, k, sf=3): pages={3262--3271}, year={2018} } - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + """ + x = ndimage.filters.convolve( + x, np.expand_dims(k, axis=2), mode='wrap' + ) # 'nearest' | 'mirror' x = bicubic_degradation(x, sf=sf) return x def dpsr_degradation(x, k, sf=3): - ''' bicubic downsampling + blur + """bicubic downsampling + blur Args: x: HxWxC image, [0, 1] k: hxw, double @@ -275,21 +297,21 @@ def dpsr_degradation(x, k, sf=3): pages={1671--1681}, year={2019} } - ''' + """ x = bicubic_degradation(x, sf=sf) x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') return x def classical_degradation(x, k, sf=3): - ''' blur + downsampling + """blur + downsampling Args: x: HxWxC image, [0, 1]/[0, 255] k: hxw, double sf: down-scale factor Return: downsampled LR image - ''' + """ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) st = 0 @@ -326,16 +348,25 @@ def add_blur(img, sf=4): wd2 = 4.0 + sf wd = 2.0 + 0.2 * sf - wd2 = wd2/4 - wd = wd/4 + wd2 = wd2 / 4 + wd = wd / 4 if random.random() < 0.5: l1 = wd2 * random.random() l2 = wd2 * random.random() - k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + k = anisotropic_Gaussian( + ksize=random.randint(2, 11) + 3, + theta=random.random() * np.pi, + l1=l1, + l2=l2, + ) else: - k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) - img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + k = fspecial( + 'gaussian', random.randint(2, 4) + 3, wd * random.random() + ) + img = ndimage.filters.convolve( + img, np.expand_dims(k, axis=2), mode='mirror' + ) return img @@ -348,7 +379,11 @@ def add_resize(img, sf=4): sf1 = random.uniform(0.5 / sf, 1) else: sf1 = 1.0 - img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, + (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) img = np.clip(img, 0.0, 1.0) return img @@ -370,19 +405,26 @@ def add_resize(img, sf=4): # img = np.clip(img, 0.0, 1.0) # return img + def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) rnum = np.random.rand() if rnum > 0.6: # add color Gaussian noise - img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype( + np.float32 + ) elif rnum < 0.4: # add grayscale Gaussian noise - 
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + img = img + np.random.normal( + 0, noise_level / 255.0, (*img.shape[:2], 1) + ).astype(np.float32) else: # add noise - L = noise_level2 / 255. + L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = img + np.random.multivariate_normal( + [0, 0, 0], np.abs(L**2 * conv), img.shape[:2] + ).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img @@ -392,28 +434,37 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25): img = np.clip(img, 0.0, 1.0) rnum = random.random() if rnum > 0.6: - img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + img += img * np.random.normal( + 0, noise_level / 255.0, img.shape + ).astype(np.float32) elif rnum < 0.4: - img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + img += img * np.random.normal( + 0, noise_level / 255.0, (*img.shape[:2], 1) + ).astype(np.float32) else: - L = noise_level2 / 255. + L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img += img * np.random.multivariate_normal( + [0, 0, 0], np.abs(L**2 * conv), img.shape[:2] + ).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img def add_Poisson_noise(img): - img = np.clip((img * 255.0).round(), 0, 255) / 255. + img = np.clip((img * 255.0).round(), 0, 255) / 255.0 vals = 10 ** (2 * random.random() + 2.0) # [2, 4] if random.random() < 0.5: img = np.random.poisson(img * vals).astype(np.float32) / vals else: img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) - img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. 
- noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0 + noise_gray = ( + np.random.poisson(img_gray * vals).astype(np.float32) / vals + - img_gray + ) img += noise_gray[:, :, np.newaxis] img = np.clip(img, 0.0, 1.0) return img @@ -422,7 +473,9 @@ def add_Poisson_noise(img): def add_JPEG_noise(img): quality_factor = random.randint(80, 95) img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) - result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + result, encimg = cv2.imencode( + '.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor] + ) img = cv2.imdecode(encimg, 1) img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) return img @@ -432,10 +485,14 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64): h, w = lq.shape[:2] rnd_h = random.randint(0, h - lq_patchsize) rnd_w = random.randint(0, w - lq_patchsize) - lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :] rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) - hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + hq = hq[ + rnd_h_H : rnd_h_H + lq_patchsize * sf, + rnd_w_H : rnd_w_H + lq_patchsize * sf, + :, + ] return lq, hq @@ -456,7 +513,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): sf_ori = sf h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: @@ -466,8 +523,11 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, + (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: img = util.imresize_np(img, 1 / 2, True) img = np.clip(img, 0.0, 1.0) @@ -476,7 +536,10 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): shuffle_order = random.sample(range(7), 7) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) if idx1 > idx2: # keep downsample3 last - shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + shuffle_order[idx1], shuffle_order[idx2] = ( + shuffle_order[idx2], + shuffle_order[idx1], + ) for i in shuffle_order: @@ -491,19 +554,30 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) - img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, + (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) - k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + k_shifted = ( + k_shifted / k_shifted.sum() + ) # blur with shifted kernel + img = ndimage.filters.convolve( + img, np.expand_dims(k_shifted, axis=2), mode='mirror' + ) img = img[0::sf, 0::sf, ...] 
# nearest downsampling img = np.clip(img, 0.0, 1.0) elif i == 3: # downsample3 - img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, + (int(1 / sf * a), int(1 / sf * b)), + interpolation=random.choice([1, 2, 3]), + ) img = np.clip(img, 0.0, 1.0) elif i == 4: @@ -548,15 +622,18 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): sf_ori = sf h1, w1 = image.shape[:2] - image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop h, w = image.shape[:2] hq = image.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: image = util.imresize_np(image, 1 / 2, True) image = np.clip(image, 0.0, 1.0) @@ -565,7 +642,10 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): shuffle_order = random.sample(range(7), 7) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) if idx1 > idx2: # keep downsample3 last - shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + shuffle_order[idx1], shuffle_order[idx2] = ( + shuffle_order[idx2], + shuffle_order[idx1], + ) for i in shuffle_order: @@ -583,20 +663,34 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): # downsample2 if random.random() < 0.8: sf1 = random.uniform(1, 2 * sf) - image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + ( + int(1 / sf1 * image.shape[1]), + int(1 / sf1 * image.shape[0]), + ), + interpolation=random.choice([1, 2, 3]), + ) else: k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) - k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + k_shifted = ( + k_shifted / k_shifted.sum() + ) # blur with shifted kernel + image = ndimage.filters.convolve( + image, np.expand_dims(k_shifted, axis=2), mode='mirror' + ) image = image[0::sf, 0::sf, ...] 
# nearest downsampling image = np.clip(image, 0.0, 1.0) elif i == 3: # downsample3 - image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / sf * a), int(1 / sf * b)), + interpolation=random.choice([1, 2, 3]), + ) image = np.clip(image, 0.0, 1.0) elif i == 4: @@ -617,34 +711,41 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): # add final JPEG compression noise image = add_JPEG_noise(image) image = util.single2uint(image) - example = {"image": image} + example = {'image': image} return example - - if __name__ == '__main__': - print("hey") + print('hey') img = util.imread_uint('utils/test.png', 3) img = img[:448, :448] h = img.shape[0] // 4 - print("resizing to", h) + print('resizing to', h) sf = 4 deg_fn = partial(degradation_bsrgan_variant, sf=sf) for i in range(20): print(i) img_hq = img - img_lq = deg_fn(img)["image"] + img_lq = deg_fn(img)['image'] img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + img_lq_bicubic = albumentations.SmallestMaxSize( + max_size=h, interpolation=cv2.INTER_CUBIC + )(image=img_hq)['image'] print(img_lq.shape) - print("bicubic", img_lq_bicubic.shape) + print('bicubic', img_lq_bicubic.shape) print(img_hq.shape) - lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), - (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + lq_nearest = cv2.resize( + util.single2uint(img_lq), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0, + ) + lq_bicubic_nearest = cv2.resize( + util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0, + ) + img_concat = np.concatenate( + [lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1 + ) util.imsave(img_concat, str(i) + '.png') diff --git a/ldm/modules/image_degradation/utils_image.py b/ldm/modules/image_degradation/utils_image.py index 0175f155ad..4b6e64658a 100644 --- a/ldm/modules/image_degradation/utils_image.py +++ b/ldm/modules/image_degradation/utils_image.py @@ -6,13 +6,14 @@ import torch import cv2 from torchvision.utils import make_grid from datetime import datetime -#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + +# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py -os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" +os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' -''' +""" # -------------------------------------------- # Kai Zhang (github: https://github.com/cszn) # 03/Mar/2019 @@ -20,10 +21,22 @@ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" # https://github.com/twhui/SRGAN-pyTorch # https://github.com/xinntao/BasicSR # -------------------------------------------- -''' +""" -IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] +IMG_EXTENSIONS = [ + '.jpg', + '.JPG', + '.jpeg', + '.JPEG', + '.png', + '.PNG', + '.ppm', + '.PPM', + '.bmp', + '.BMP', + '.tif', +] def is_image_file(filename): @@ -49,19 +62,19 @@ def surf(Z, cmap='rainbow', figsize=None): ax3 = plt.axes(projection='3d') w, h = Z.shape[:2] - 
xx = np.arange(0,w,1) - yy = np.arange(0,h,1) + xx = np.arange(0, w, 1) + yy = np.arange(0, h, 1) X, Y = np.meshgrid(xx, yy) - ax3.plot_surface(X,Y,Z,cmap=cmap) - #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + ax3.plot_surface(X, Y, Z, cmap=cmap) + # ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) plt.show() -''' +""" # -------------------------------------------- # get image pathes # -------------------------------------------- -''' +""" def get_image_paths(dataroot): @@ -83,26 +96,26 @@ def _get_paths_from_images(path): return images -''' +""" # -------------------------------------------- # split large images into small images # -------------------------------------------- -''' +""" def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): w, h = img.shape[:2] patches = [] if w > p_max and h > p_max: - w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) - h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) - w1.append(w-p_size) - h1.append(h-p_size) -# print(w1) -# print(h1) + w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=np.int)) + w1.append(w - p_size) + h1.append(h - p_size) + # print(w1) + # print(h1) for i in w1: for j in h1: - patches.append(img[i:i+p_size, j:j+p_size,:]) + patches.append(img[i : i + p_size, j : j + p_size, :]) else: patches.append(img) @@ -118,11 +131,21 @@ def imssave(imgs, img_path): for i, img in enumerate(imgs): if img.ndim == 3: img = img[:, :, [2, 1, 0]] - new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + new_path = os.path.join( + os.path.dirname(img_path), + img_name + str('_s{:04d}'.format(i)) + '.png', + ) cv2.imwrite(new_path, img) -def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): +def split_imageset( + original_dataroot, + taget_dataroot, + n_channels=3, + p_size=800, + p_overlap=96, + p_max=1000, +): """ split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) @@ -139,15 +162,18 @@ def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, # img_name, ext = os.path.splitext(os.path.basename(img_path)) img = imread_uint(img_path, n_channels=n_channels) patches = patches_from_image(img, p_size, p_overlap, p_max) - imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) - #if original_dataroot == taget_dataroot: - #del img_path + imssave( + patches, os.path.join(taget_dataroot, os.path.basename(img_path)) + ) + # if original_dataroot == taget_dataroot: + # del img_path -''' + +""" # -------------------------------------------- # makedir # -------------------------------------------- -''' +""" def mkdir(path): @@ -171,12 +197,12 @@ def mkdir_and_rename(path): os.makedirs(path) -''' +""" # -------------------------------------------- # read image from path # opencv is fast, but read BGR numpy image # -------------------------------------------- -''' +""" # -------------------------------------------- @@ -206,6 +232,7 @@ def imsave(img, img_path): img = img[:, :, [2, 1, 0]] cv2.imwrite(img_path, img) + def imwrite(img, img_path): img = np.squeeze(img) if img.ndim == 3: @@ -213,7 +240,6 @@ def imwrite(img, img_path): cv2.imwrite(img_path, img) - # -------------------------------------------- # get single image of size HxWxn_channles (BGR) # 
-------------------------------------------- @@ -221,7 +247,7 @@ def read_img(path): # read image by cv2 # return: Numpy float32, HWC, BGR, [0,1] img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE - img = img.astype(np.float32) / 255. + img = img.astype(np.float32) / 255.0 if img.ndim == 2: img = np.expand_dims(img, axis=2) # some images have 4 channels @@ -230,7 +256,7 @@ def read_img(path): return img -''' +""" # -------------------------------------------- # image format conversion # -------------------------------------------- @@ -238,7 +264,7 @@ def read_img(path): # numpy(single) <---> tensor # numpy(unit) <---> tensor # -------------------------------------------- -''' +""" # -------------------------------------------- @@ -248,22 +274,22 @@ def read_img(path): def uint2single(img): - return np.float32(img/255.) + return np.float32(img / 255.0) def single2uint(img): - return np.uint8((img.clip(0, 1)*255.).round()) + return np.uint8((img.clip(0, 1) * 255.0).round()) def uint162single(img): - return np.float32(img/65535.) + return np.float32(img / 65535.0) def single2uint16(img): - return np.uint16((img.clip(0, 1)*65535.).round()) + return np.uint16((img.clip(0, 1) * 65535.0).round()) # -------------------------------------------- @@ -275,14 +301,25 @@ def single2uint16(img): def uint2tensor4(img): if img.ndim == 2: img = np.expand_dims(img, axis=2) - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + return ( + torch.from_numpy(np.ascontiguousarray(img)) + .permute(2, 0, 1) + .float() + .div(255.0) + .unsqueeze(0) + ) # convert uint to 3-dimensional torch tensor def uint2tensor3(img): if img.ndim == 2: img = np.expand_dims(img, axis=2) - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) 
+ return ( + torch.from_numpy(np.ascontiguousarray(img)) + .permute(2, 0, 1) + .float() + .div(255.0) + ) # convert 2/3/4-dimensional torch tensor to uint @@ -290,7 +327,7 @@ def tensor2uint(img): img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() if img.ndim == 3: img = np.transpose(img, (1, 2, 0)) - return np.uint8((img*255.0).round()) + return np.uint8((img * 255.0).round()) # -------------------------------------------- @@ -305,7 +342,12 @@ def single2tensor3(img): # convert single (HxWxC) to 4-dimensional torch tensor def single2tensor4(img): - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + return ( + torch.from_numpy(np.ascontiguousarray(img)) + .permute(2, 0, 1) + .float() + .unsqueeze(0) + ) # convert torch tensor to single @@ -316,6 +358,7 @@ def tensor2single(img): return img + # convert torch tensor to single def tensor2single3(img): img = img.data.squeeze().float().cpu().numpy() @@ -327,30 +370,48 @@ def tensor2single3(img): def single2tensor5(img): - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + return ( + torch.from_numpy(np.ascontiguousarray(img)) + .permute(2, 0, 1, 3) + .float() + .unsqueeze(0) + ) def single32tensor5(img): - return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + return ( + torch.from_numpy(np.ascontiguousarray(img)) + .float() + .unsqueeze(0) + .unsqueeze(0) + ) def single42tensor4(img): - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + return ( + torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + ) # from skimage.io import imread, imsave def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): - ''' + """ Converts a torch Tensor into an image Numpy array of BGR channel order Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) - ''' - tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp - tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + """ + tensor = ( + tensor.squeeze().float().cpu().clamp_(*min_max) + ) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / ( + min_max[1] - min_max[0] + ) # to range [0,1] n_dim = tensor.dim() if n_dim == 4: n_img = len(tensor) - img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = make_grid( + tensor, nrow=int(math.sqrt(n_img)), normalize=False + ).numpy() img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR elif n_dim == 3: img_np = tensor.numpy() @@ -359,14 +420,17 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): img_np = tensor.numpy() else: raise TypeError( - 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format( + n_dim + ) + ) if out_type == np.uint8: img_np = (img_np * 255.0).round() # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. 
return img_np.astype(out_type) -''' +""" # -------------------------------------------- # Augmentation, flipe and/or rotate # -------------------------------------------- @@ -374,12 +438,11 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): # (1) augmet_img: numpy image of WxHxC or WxH # (2) augment_img_tensor4: tensor image 1xCxWxH # -------------------------------------------- -''' +""" def augment_img(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' + """Kai Zhang (github: https://github.com/cszn)""" if mode == 0: return img elif mode == 1: @@ -399,8 +462,7 @@ def augment_img(img, mode=0): def augment_img_tensor4(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' + """Kai Zhang (github: https://github.com/cszn)""" if mode == 0: return img elif mode == 1: @@ -420,8 +482,7 @@ def augment_img_tensor4(img, mode=0): def augment_img_tensor(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' + """Kai Zhang (github: https://github.com/cszn)""" img_size = img.size() img_np = img.data.cpu().numpy() if len(img_size) == 3: @@ -484,11 +545,11 @@ def augment_imgs(img_list, hflip=True, rot=True): return [_augment(img) for img in img_list] -''' +""" # -------------------------------------------- # modcrop and shave # -------------------------------------------- -''' +""" def modcrop(img_in, scale): @@ -497,11 +558,11 @@ def modcrop(img_in, scale): if img.ndim == 2: H, W = img.shape H_r, W_r = H % scale, W % scale - img = img[:H - H_r, :W - W_r] + img = img[: H - H_r, : W - W_r] elif img.ndim == 3: H, W, C = img.shape H_r, W_r = H % scale, W % scale - img = img[:H - H_r, :W - W_r, :] + img = img[: H - H_r, : W - W_r, :] else: raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) return img @@ -511,11 +572,11 @@ def shave(img_in, border=0): # img_in: Numpy, HWC or HW img = np.copy(img_in) h, w = img.shape[:2] - img = img[border:h-border, border:w-border] + img = img[border : h - border, border : w - border] return img -''' +""" # -------------------------------------------- # image processing process on numpy image # channel_convert(in_c, tar_type, img_list): @@ -523,74 +584,92 @@ def shave(img_in, border=0): # bgr2ycbcr(img, only_y=True): # ycbcr2rgb(img): # -------------------------------------------- -''' +""" def rgb2ycbcr(img, only_y=True): - '''same as matlab rgb2ycbcr + """same as matlab rgb2ycbcr only_y: only return Y channel Input: uint8, [0, 255] float, [0, 1] - ''' + """ in_img_type = img.dtype img.astype(np.float32) if in_img_type != np.uint8: - img *= 255. + img *= 255.0 # convert if only_y: rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 else: - rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], - [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + rlt = np.matmul( + img, + [ + [65.481, -37.797, 112.0], + [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214], + ], + ) / 255.0 + [16, 128, 128] if in_img_type == np.uint8: rlt = rlt.round() else: - rlt /= 255. + rlt /= 255.0 return rlt.astype(in_img_type) def ycbcr2rgb(img): - '''same as matlab ycbcr2rgb + """same as matlab ycbcr2rgb Input: uint8, [0, 255] float, [0, 1] - ''' + """ in_img_type = img.dtype img.astype(np.float32) if in_img_type != np.uint8: - img *= 255. 
+ img *= 255.0 # convert - rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], - [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + rlt = np.matmul( + img, + [ + [0.00456621, 0.00456621, 0.00456621], + [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0], + ], + ) * 255.0 + [-222.921, 135.576, -276.836] if in_img_type == np.uint8: rlt = rlt.round() else: - rlt /= 255. + rlt /= 255.0 return rlt.astype(in_img_type) def bgr2ycbcr(img, only_y=True): - '''bgr version of rgb2ycbcr + """bgr version of rgb2ycbcr only_y: only return Y channel Input: uint8, [0, 255] float, [0, 1] - ''' + """ in_img_type = img.dtype img.astype(np.float32) if in_img_type != np.uint8: - img *= 255. + img *= 255.0 # convert if only_y: rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 else: - rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], - [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + rlt = np.matmul( + img, + [ + [24.966, 112.0, -18.214], + [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0], + ], + ) / 255.0 + [16, 128, 128] if in_img_type == np.uint8: rlt = rlt.round() else: - rlt /= 255. + rlt /= 255.0 return rlt.astype(in_img_type) @@ -608,11 +687,11 @@ def channel_convert(in_c, tar_type, img_list): return img_list -''' +""" # -------------------------------------------- # metric, PSNR and SSIM # -------------------------------------------- -''' +""" # -------------------------------------------- @@ -620,17 +699,17 @@ def channel_convert(in_c, tar_type, img_list): # -------------------------------------------- def calculate_psnr(img1, img2, border=0): # img1 and img2 have range [0, 255] - #img1 = img1.squeeze() - #img2 = img2.squeeze() + # img1 = img1.squeeze() + # img2 = img2.squeeze() if not img1.shape == img2.shape: raise ValueError('Input images must have the same dimensions.') h, w = img1.shape[:2] - img1 = img1[border:h-border, border:w-border] - img2 = img2[border:h-border, border:w-border] + img1 = img1[border : h - border, border : w - border] + img2 = img2[border : h - border, border : w - border] img1 = img1.astype(np.float64) img2 = img2.astype(np.float64) - mse = np.mean((img1 - img2)**2) + mse = np.mean((img1 - img2) ** 2) if mse == 0: return float('inf') return 20 * math.log10(255.0 / math.sqrt(mse)) @@ -640,17 +719,17 @@ def calculate_psnr(img1, img2, border=0): # SSIM # -------------------------------------------- def calculate_ssim(img1, img2, border=0): - '''calculate SSIM + """calculate SSIM the same outputs as MATLAB's img1, img2: [0, 255] - ''' - #img1 = img1.squeeze() - #img2 = img2.squeeze() + """ + # img1 = img1.squeeze() + # img2 = img2.squeeze() if not img1.shape == img2.shape: raise ValueError('Input images must have the same dimensions.') h, w = img1.shape[:2] - img1 = img1[border:h-border, border:w-border] - img2 = img2[border:h-border, border:w-border] + img1 = img1[border : h - border, border : w - border] + img2 = img2[border : h - border, border : w - border] if img1.ndim == 2: return ssim(img1, img2) @@ -658,7 +737,7 @@ def calculate_ssim(img1, img2, border=0): if img1.shape[2] == 3: ssims = [] for i in range(3): - ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + ssims.append(ssim(img1[:, :, i], img2[:, :, i])) return np.array(ssims).mean() elif img1.shape[2] == 1: return ssim(np.squeeze(img1), np.squeeze(img2)) @@ -667,8 +746,8 @@ def calculate_ssim(img1, img2, border=0): def ssim(img1, img2): - C1 = (0.01 * 255)**2 - C2 = (0.03 * 255)**2 + C1 = (0.01 * 255) ** 2 
+ C2 = (0.03 * 255) ** 2 img1 = img1.astype(np.float64) img2 = img2.astype(np.float64) @@ -684,16 +763,17 @@ def ssim(img1, img2): sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 - ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * - (sigma1_sq + sigma2_sq + C2)) + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( + (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) + ) return ssim_map.mean() -''' +""" # -------------------------------------------- # matlab's bicubic imresize (numpy and torch) [0, 1] # -------------------------------------------- -''' +""" # matlab 'imresize' function, now only support 'bicubic' @@ -701,11 +781,14 @@ def cubic(x): absx = torch.abs(x) absx2 = absx**2 absx3 = absx**3 - return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ - (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + ( + -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2 + ) * (((absx > 1) * (absx <= 2)).type_as(absx)) -def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): +def calculate_weights_indices( + in_length, out_length, scale, kernel, kernel_width, antialiasing +): if (scale < 1) and (antialiasing): # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width kernel_width = kernel_width / scale @@ -729,8 +812,9 @@ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width # The indices of the input pixels involved in computing the k-th output # pixel are in row k of the indices matrix. - indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( - 1, P).expand(out_length, P) + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace( + 0, P - 1, P + ).view(1, P).expand(out_length, P) # The weights used to compute the k-th output pixel are in row k of the # weights matrix. 
@@ -771,7 +855,11 @@ def imresize(img, scale, antialiasing=True): if need_squeeze: img.unsqueeze_(0) in_C, in_H, in_W = img.size() - out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + out_C, out_H, out_W = ( + in_C, + math.ceil(in_H * scale), + math.ceil(in_W * scale), + ) kernel_width = 4 kernel = 'cubic' @@ -782,9 +870,11 @@ def imresize(img, scale, antialiasing=True): # get weights and indices weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) + in_H, out_H, scale, kernel, kernel_width, antialiasing + ) weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) + in_W, out_W, scale, kernel, kernel_width, antialiasing + ) # process H dimension # symmetric copying img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) @@ -805,7 +895,11 @@ def imresize(img, scale, antialiasing=True): for i in range(out_H): idx = int(indices_H[i][0]) for j in range(out_C): - out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[j, i, :] = ( + img_aug[j, idx : idx + kernel_width, :] + .transpose(0, 1) + .mv(weights_H[i]) + ) # process W dimension # symmetric copying @@ -827,7 +921,9 @@ def imresize(img, scale, antialiasing=True): for i in range(out_W): idx = int(indices_W[i][0]) for j in range(out_C): - out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv( + weights_W[i] + ) if need_squeeze: out_2.squeeze_() return out_2 @@ -846,7 +942,11 @@ def imresize_np(img, scale, antialiasing=True): img.unsqueeze_(2) in_H, in_W, in_C = img.size() - out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + out_C, out_H, out_W = ( + in_C, + math.ceil(in_H * scale), + math.ceil(in_W * scale), + ) kernel_width = 4 kernel = 'cubic' @@ -857,9 +957,11 @@ def imresize_np(img, scale, antialiasing=True): # get weights and indices weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) + in_H, out_H, scale, kernel, kernel_width, antialiasing + ) weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) + in_W, out_W, scale, kernel, kernel_width, antialiasing + ) # process H dimension # symmetric copying img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) @@ -880,7 +982,11 @@ def imresize_np(img, scale, antialiasing=True): for i in range(out_H): idx = int(indices_H[i][0]) for j in range(out_C): - out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, j] = ( + img_aug[idx : idx + kernel_width, :, j] + .transpose(0, 1) + .mv(weights_H[i]) + ) # process W dimension # symmetric copying @@ -902,7 +1008,9 @@ def imresize_np(img, scale, antialiasing=True): for i in range(out_W): idx = int(indices_W[i][0]) for j in range(out_C): - out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv( + weights_W[i] + ) if need_squeeze: out_2.squeeze_() @@ -913,4 +1021,4 @@ if __name__ == '__main__': print('---') # img = imread_uint('test.bmp', 3) # img = uint2single(img) -# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file +# img_bicubic = imresize_np(img, 1/4) diff --git a/ldm/modules/losses/__init__.py 
b/ldm/modules/losses/__init__.py index 876d7c5bd6..d86294210c 100644 --- a/ldm/modules/losses/__init__.py +++ b/ldm/modules/losses/__init__.py @@ -1 +1 @@ -from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator \ No newline at end of file +from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator diff --git a/ldm/modules/losses/contperceptual.py b/ldm/modules/losses/contperceptual.py index 672c1e32a1..7fa4124346 100644 --- a/ldm/modules/losses/contperceptual.py +++ b/ldm/modules/losses/contperceptual.py @@ -5,13 +5,24 @@ from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/ class LPIPSWithDiscriminator(nn.Module): - def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, - disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, - perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, - disc_loss="hinge"): + def __init__( + self, + disc_start, + logvar_init=0.0, + kl_weight=1.0, + pixelloss_weight=1.0, + disc_num_layers=3, + disc_in_channels=3, + disc_factor=1.0, + disc_weight=1.0, + perceptual_weight=1.0, + use_actnorm=False, + disc_conditional=False, + disc_loss='hinge', + ): super().__init__() - assert disc_loss in ["hinge", "vanilla"] + assert disc_loss in ['hinge', 'vanilla'] self.kl_weight = kl_weight self.pixel_weight = pixelloss_weight self.perceptual_loss = LPIPS().eval() @@ -19,42 +30,68 @@ class LPIPSWithDiscriminator(nn.Module): # output log variance self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) - self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, - n_layers=disc_num_layers, - use_actnorm=use_actnorm - ).apply(weights_init) + self.discriminator = NLayerDiscriminator( + input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ).apply(weights_init) self.discriminator_iter_start = disc_start - self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_loss = ( + hinge_d_loss if disc_loss == 'hinge' else vanilla_d_loss + ) self.disc_factor = disc_factor self.discriminator_weight = disc_weight self.disc_conditional = disc_conditional def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): if last_layer is not None: - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + nll_grads = torch.autograd.grad( + nll_loss, last_layer, retain_graph=True + )[0] + g_grads = torch.autograd.grad( + g_loss, last_layer, retain_graph=True + )[0] else: - nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + nll_grads = torch.autograd.grad( + nll_loss, self.last_layer[0], retain_graph=True + )[0] + g_grads = torch.autograd.grad( + g_loss, self.last_layer[0], retain_graph=True + )[0] d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() d_weight = d_weight * self.discriminator_weight return d_weight - def forward(self, inputs, reconstructions, posteriors, optimizer_idx, - global_step, last_layer=None, cond=None, split="train", - weights=None): - rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + def forward( + self, + inputs, + reconstructions, + posteriors, + optimizer_idx, + global_step, + last_layer=None, + cond=None, + split='train', + weights=None, + ): + rec_loss = torch.abs( + inputs.contiguous() 
- reconstructions.contiguous() + ) if self.perceptual_weight > 0: - p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + p_loss = self.perceptual_loss( + inputs.contiguous(), reconstructions.contiguous() + ) rec_loss = rec_loss + self.perceptual_weight * p_loss nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar weighted_nll_loss = nll_loss if weights is not None: - weighted_nll_loss = weights*nll_loss - weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + weighted_nll_loss = weights * nll_loss + weighted_nll_loss = ( + torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + ) nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] kl_loss = posteriors.kl() kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] @@ -67,45 +104,72 @@ class LPIPSWithDiscriminator(nn.Module): logits_fake = self.discriminator(reconstructions.contiguous()) else: assert self.disc_conditional - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + logits_fake = self.discriminator( + torch.cat((reconstructions.contiguous(), cond), dim=1) + ) g_loss = -torch.mean(logits_fake) if self.disc_factor > 0.0: try: - d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + d_weight = self.calculate_adaptive_weight( + nll_loss, g_loss, last_layer=last_layer + ) except RuntimeError: assert not self.training d_weight = torch.tensor(0.0) else: d_weight = torch.tensor(0.0) - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + disc_factor = adopt_weight( + self.disc_factor, + global_step, + threshold=self.discriminator_iter_start, + ) + loss = ( + weighted_nll_loss + + self.kl_weight * kl_loss + + d_weight * disc_factor * g_loss + ) - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), - "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } + log = { + '{}/total_loss'.format(split): loss.clone().detach().mean(), + '{}/logvar'.format(split): self.logvar.detach(), + '{}/kl_loss'.format(split): kl_loss.detach().mean(), + '{}/nll_loss'.format(split): nll_loss.detach().mean(), + '{}/rec_loss'.format(split): rec_loss.detach().mean(), + '{}/d_weight'.format(split): d_weight.detach(), + '{}/disc_factor'.format(split): torch.tensor(disc_factor), + '{}/g_loss'.format(split): g_loss.detach().mean(), + } return loss, log if optimizer_idx == 1: # second pass for discriminator update if cond is None: logits_real = self.discriminator(inputs.contiguous().detach()) - logits_fake = self.discriminator(reconstructions.contiguous().detach()) + logits_fake = self.discriminator( + reconstructions.contiguous().detach() + ) else: - logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + logits_real = self.discriminator( + torch.cat((inputs.contiguous().detach(), cond), dim=1) + ) + logits_fake = self.discriminator( + torch.cat( + (reconstructions.contiguous().detach(), cond), dim=1 + ) + ) - disc_factor = 
adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + disc_factor = adopt_weight( + self.disc_factor, + global_step, + threshold=self.discriminator_iter_start, + ) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean() - } + log = { + '{}/disc_loss'.format(split): d_loss.clone().detach().mean(), + '{}/logits_real'.format(split): logits_real.detach().mean(), + '{}/logits_fake'.format(split): logits_fake.detach().mean(), + } return d_loss, log - diff --git a/ldm/modules/losses/vqperceptual.py b/ldm/modules/losses/vqperceptual.py index f69981769e..2f94bf5281 100644 --- a/ldm/modules/losses/vqperceptual.py +++ b/ldm/modules/losses/vqperceptual.py @@ -3,21 +3,25 @@ from torch import nn import torch.nn.functional as F from einops import repeat -from taming.modules.discriminator.model import NLayerDiscriminator, weights_init +from taming.modules.discriminator.model import ( + NLayerDiscriminator, + weights_init, +) from taming.modules.losses.lpips import LPIPS from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] - loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) - loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) + loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3]) + loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3]) loss_real = (weights * loss_real).sum() / weights.sum() loss_fake = (weights * loss_fake).sum() / weights.sum() d_loss = 0.5 * (loss_real + loss_fake) return d_loss -def adopt_weight(weight, global_step, threshold=0, value=0.): + +def adopt_weight(weight, global_step, threshold=0, value=0.0): if global_step < threshold: weight = value return weight @@ -26,57 +30,76 @@ def adopt_weight(weight, global_step, threshold=0, value=0.): def measure_perplexity(predicted_indices, n_embed): # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally - encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) + encodings = ( + F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) + ) avg_probs = encodings.mean(0) perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() cluster_use = torch.sum(avg_probs > 0) return perplexity, cluster_use + def l1(x, y): - return torch.abs(x-y) + return torch.abs(x - y) def l2(x, y): - return torch.pow((x-y), 2) + return torch.pow((x - y), 2) class VQLPIPSWithDiscriminator(nn.Module): - def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, - disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, - perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, - disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", - pixel_loss="l1"): + def __init__( + self, + disc_start, + codebook_weight=1.0, + pixelloss_weight=1.0, + disc_num_layers=3, + disc_in_channels=3, + disc_factor=1.0, + disc_weight=1.0, + perceptual_weight=1.0, + use_actnorm=False, + disc_conditional=False, + disc_ndf=64, + disc_loss='hinge', + n_classes=None, + perceptual_loss='lpips', + pixel_loss='l1', + ): super().__init__() - assert disc_loss in ["hinge", "vanilla"] - assert perceptual_loss in ["lpips", "clips", "dists"] - assert pixel_loss in ["l1", "l2"] + assert disc_loss in ['hinge', 'vanilla'] + assert perceptual_loss in ['lpips', 'clips', 'dists'] + assert pixel_loss in ['l1', 'l2'] self.codebook_weight = codebook_weight self.pixel_weight = pixelloss_weight - if perceptual_loss == "lpips": - print(f"{self.__class__.__name__}: Running with LPIPS.") + if perceptual_loss == 'lpips': + print(f'{self.__class__.__name__}: Running with LPIPS.') self.perceptual_loss = LPIPS().eval() else: - raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") + raise ValueError( + f'Unknown perceptual loss: >> {perceptual_loss} <<' + ) self.perceptual_weight = perceptual_weight - if pixel_loss == "l1": + if pixel_loss == 'l1': self.pixel_loss = l1 else: self.pixel_loss = l2 - self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, - n_layers=disc_num_layers, - use_actnorm=use_actnorm, - ndf=disc_ndf - ).apply(weights_init) + self.discriminator = NLayerDiscriminator( + input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf, + ).apply(weights_init) self.discriminator_iter_start = disc_start - if disc_loss == "hinge": + if disc_loss == 'hinge': self.disc_loss = hinge_d_loss - elif disc_loss == "vanilla": + elif disc_loss == 'vanilla': self.disc_loss = vanilla_d_loss else: raise ValueError(f"Unknown GAN loss '{disc_loss}'.") - print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + print(f'VQLPIPSWithDiscriminator running with {disc_loss} loss.') self.disc_factor = disc_factor self.discriminator_weight = disc_weight self.disc_conditional = disc_conditional @@ -84,31 +107,53 @@ class VQLPIPSWithDiscriminator(nn.Module): def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): if last_layer is not None: - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + nll_grads = torch.autograd.grad( + nll_loss, last_layer, retain_graph=True + )[0] + g_grads = torch.autograd.grad( + g_loss, last_layer, retain_graph=True + )[0] else: - nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], 
retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + nll_grads = torch.autograd.grad( + nll_loss, self.last_layer[0], retain_graph=True + )[0] + g_grads = torch.autograd.grad( + g_loss, self.last_layer[0], retain_graph=True + )[0] d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() d_weight = d_weight * self.discriminator_weight return d_weight - def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, - global_step, last_layer=None, cond=None, split="train", predicted_indices=None): + def forward( + self, + codebook_loss, + inputs, + reconstructions, + optimizer_idx, + global_step, + last_layer=None, + cond=None, + split='train', + predicted_indices=None, + ): if not exists(codebook_loss): - codebook_loss = torch.tensor([0.]).to(inputs.device) - #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) - rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) + codebook_loss = torch.tensor([0.0]).to(inputs.device) + # rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + rec_loss = self.pixel_loss( + inputs.contiguous(), reconstructions.contiguous() + ) if self.perceptual_weight > 0: - p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + p_loss = self.perceptual_loss( + inputs.contiguous(), reconstructions.contiguous() + ) rec_loss = rec_loss + self.perceptual_weight * p_loss else: p_loss = torch.tensor([0.0]) nll_loss = rec_loss - #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] nll_loss = torch.mean(nll_loss) # now the GAN part @@ -119,49 +164,77 @@ class VQLPIPSWithDiscriminator(nn.Module): logits_fake = self.discriminator(reconstructions.contiguous()) else: assert self.disc_conditional - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + logits_fake = self.discriminator( + torch.cat((reconstructions.contiguous(), cond), dim=1) + ) g_loss = -torch.mean(logits_fake) try: - d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + d_weight = self.calculate_adaptive_weight( + nll_loss, g_loss, last_layer=last_layer + ) except RuntimeError: assert not self.training d_weight = torch.tensor(0.0) - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + disc_factor = adopt_weight( + self.disc_factor, + global_step, + threshold=self.discriminator_iter_start, + ) + loss = ( + nll_loss + + d_weight * disc_factor * g_loss + + self.codebook_weight * codebook_loss.mean() + ) - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), - "{}/quant_loss".format(split): codebook_loss.detach().mean(), - "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/p_loss".format(split): p_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } + log = { + '{}/total_loss'.format(split): loss.clone().detach().mean(), + '{}/quant_loss'.format(split): codebook_loss.detach().mean(), + '{}/nll_loss'.format(split): nll_loss.detach().mean(), + '{}/rec_loss'.format(split): rec_loss.detach().mean(), + '{}/p_loss'.format(split): 
p_loss.detach().mean(), + '{}/d_weight'.format(split): d_weight.detach(), + '{}/disc_factor'.format(split): torch.tensor(disc_factor), + '{}/g_loss'.format(split): g_loss.detach().mean(), + } if predicted_indices is not None: assert self.n_classes is not None with torch.no_grad(): - perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) - log[f"{split}/perplexity"] = perplexity - log[f"{split}/cluster_usage"] = cluster_usage + perplexity, cluster_usage = measure_perplexity( + predicted_indices, self.n_classes + ) + log[f'{split}/perplexity'] = perplexity + log[f'{split}/cluster_usage'] = cluster_usage return loss, log if optimizer_idx == 1: # second pass for discriminator update if cond is None: logits_real = self.discriminator(inputs.contiguous().detach()) - logits_fake = self.discriminator(reconstructions.contiguous().detach()) + logits_fake = self.discriminator( + reconstructions.contiguous().detach() + ) else: - logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + logits_real = self.discriminator( + torch.cat((inputs.contiguous().detach(), cond), dim=1) + ) + logits_fake = self.discriminator( + torch.cat( + (reconstructions.contiguous().detach(), cond), dim=1 + ) + ) - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + disc_factor = adopt_weight( + self.disc_factor, + global_step, + threshold=self.discriminator_iter_start, + ) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean() - } + log = { + '{}/disc_loss'.format(split): d_loss.clone().detach().mean(), + '{}/logits_real'.format(split): logits_real.detach().mean(), + '{}/logits_fake'.format(split): logits_fake.detach().mean(), + } return d_loss, log diff --git a/ldm/modules/x_transformer.py b/ldm/modules/x_transformer.py index 1316dbd505..d6c4cc6881 100644 --- a/ldm/modules/x_transformer.py +++ b/ldm/modules/x_transformer.py @@ -11,15 +11,13 @@ from einops import rearrange, repeat, reduce DEFAULT_DIM_HEAD = 64 -Intermediates = namedtuple('Intermediates', [ - 'pre_softmax_attn', - 'post_softmax_attn' -]) +Intermediates = namedtuple( + 'Intermediates', ['pre_softmax_attn', 'post_softmax_attn'] +) -LayerIntermediates = namedtuple('Intermediates', [ - 'hiddens', - 'attn_intermediates' -]) +LayerIntermediates = namedtuple( + 'Intermediates', ['hiddens', 'attn_intermediates'] +) class AbsolutePositionalEmbedding(nn.Module): @@ -39,11 +37,16 @@ class AbsolutePositionalEmbedding(nn.Module): class FixedPositionalEmbedding(nn.Module): def __init__(self, dim): super().__init__() - inv_freq = 1. 
/ (10000 ** (torch.arange(0, dim, 2).float() / dim)) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq) def forward(self, x, seq_dim=1, offset=0): - t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + t = ( + torch.arange(x.shape[seq_dim], device=x.device).type_as( + self.inv_freq + ) + + offset + ) sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) return emb[None, :, :] @@ -51,6 +54,7 @@ class FixedPositionalEmbedding(nn.Module): # helpers + def exists(val): return val is not None @@ -64,18 +68,21 @@ def default(val, d): def always(val): def inner(*args, **kwargs): return val + return inner def not_equals(val): def inner(x): return x != val + return inner def equals(val): def inner(x): return x == val + return inner @@ -85,6 +92,7 @@ def max_neg_value(tensor): # keyword argument helpers + def pick_and_pop(keys, d): values = list(map(lambda key: d.pop(key), keys)) return dict(zip(keys, values)) @@ -108,8 +116,15 @@ def group_by_key_prefix(prefix, d): def groupby_prefix_and_trim(prefix, d): - kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) - kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + kwargs_with_prefix, kwargs = group_dict_by_key( + partial(string_begins_with, prefix), d + ) + kwargs_without_prefix = dict( + map( + lambda x: (x[0][len(prefix) :], x[1]), + tuple(kwargs_with_prefix.items()), + ) + ) return kwargs_without_prefix, kwargs @@ -139,7 +154,7 @@ class Rezero(nn.Module): class ScaleNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() - self.scale = dim ** -0.5 + self.scale = dim**-0.5 self.eps = eps self.g = nn.Parameter(torch.ones(1)) @@ -151,7 +166,7 @@ class ScaleNorm(nn.Module): class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-8): super().__init__() - self.scale = dim ** -0.5 + self.scale = dim**-0.5 self.eps = eps self.g = nn.Parameter(torch.ones(dim)) @@ -173,7 +188,7 @@ class GRUGating(nn.Module): def forward(self, x, residual): gated_output = self.gru( rearrange(x, 'b n d -> (b n) d'), - rearrange(residual, 'b n d -> (b n) d') + rearrange(residual, 'b n d -> (b n) d'), ) return gated_output.reshape_as(x) @@ -181,6 +196,7 @@ class GRUGating(nn.Module): # feedforward + class GEGLU(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() @@ -192,19 +208,18 @@ class GEGLU(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) ) def forward(self, x): @@ -214,23 +229,25 @@ class FeedForward(nn.Module): # attention. 
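Both LPIPSWithDiscriminator and VQLPIPSWithDiscriminator above reformat the same two pieces of GAN bookkeeping: adopt_weight(), which zeroes the adversarial term until global_step reaches discriminator_iter_start, and calculate_adaptive_weight(), which balances the reconstruction loss against the generator loss by the ratio of their gradient norms at the last decoder layer. A minimal standalone sketch of that logic follows; the toy decoder, tensor shapes, and the g_loss stand-in are illustrative and not part of the repository's API.

import torch
import torch.nn as nn

def adopt_weight(weight, global_step, threshold=0, value=0.0):
    # Disable the adversarial term until the warm-up threshold is reached.
    return value if global_step < threshold else weight

def adaptive_weight(nll_loss, g_loss, last_layer, discriminator_weight=1.0):
    # Balance the two losses by the ratio of their gradient norms w.r.t. the
    # last decoder layer, clamped as in the reformatted code above.
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    return torch.clamp(d_weight, 0.0, 1e4).detach() * discriminator_weight

# Toy usage (shapes, modules and the g_loss stand-in are hypothetical).
decoder = nn.Linear(8, 8)
x = torch.randn(4, 8)
recon = decoder(x)
nll_loss = (x - recon).abs().mean()
g_loss = -recon.mean()   # stand-in for -torch.mean(logits_fake)
w = adaptive_weight(nll_loss, g_loss, decoder.weight)
loss = nll_loss + w * adopt_weight(1.0, global_step=10, threshold=1000) * g_loss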
class Attention(nn.Module): def __init__( - self, - dim, - dim_head=DEFAULT_DIM_HEAD, - heads=8, - causal=False, - mask=None, - talking_heads=False, - sparse_topk=None, - use_entmax15=False, - num_mem_kv=0, - dropout=0., - on_attn=False + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0.0, + on_attn=False, ): super().__init__() if use_entmax15: - raise NotImplementedError("Check out entmax activation instead of softmax activation!") - self.scale = dim_head ** -0.5 + raise NotImplementedError( + 'Check out entmax activation instead of softmax activation!' + ) + self.scale = dim_head**-0.5 self.heads = heads self.causal = causal self.mask = mask @@ -252,7 +269,7 @@ class Attention(nn.Module): self.sparse_topk = sparse_topk # entmax - #self.attn_fn = entmax15 if use_entmax15 else F.softmax + # self.attn_fn = entmax15 if use_entmax15 else F.softmax self.attn_fn = F.softmax # add memory key / values @@ -263,20 +280,29 @@ class Attention(nn.Module): # attention on attention self.attn_on_attn = on_attn - self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) + if on_attn + else nn.Linear(inner_dim, dim) + ) def forward( - self, - x, - context=None, - mask=None, - context_mask=None, - rel_pos=None, - sinusoidal_emb=None, - prev_attn=None, - mem=None + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None, ): - b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + b, n, _, h, talking_heads, device = ( + *x.shape, + self.heads, + self.talking_heads, + x.device, + ) kv_input = default(context, x) q_input = x @@ -297,23 +323,35 @@ class Attention(nn.Module): k = self.to_k(k_input) v = self.to_v(v_input) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + q, k, v = map( + lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v) + ) input_mask = None if any(map(exists, (mask, context_mask))): - q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + q_mask = default( + mask, lambda: torch.ones((b, n), device=device).bool() + ) k_mask = q_mask if not exists(context) else context_mask - k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + k_mask = default( + k_mask, + lambda: torch.ones((b, k.shape[-2]), device=device).bool(), + ) q_mask = rearrange(q_mask, 'b i -> b () i ()') k_mask = rearrange(k_mask, 'b j -> b () () j') input_mask = q_mask * k_mask if self.num_mem_kv > 0: - mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + mem_k, mem_v = map( + lambda t: repeat(t, 'h n d -> b h n d', b=b), + (self.mem_k, self.mem_v), + ) k = torch.cat((mem_k, k), dim=-2) v = torch.cat((mem_v, v), dim=-2) if exists(input_mask): - input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + input_mask = F.pad( + input_mask, (self.num_mem_kv, 0), value=True + ) dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale mask_value = max_neg_value(dots) @@ -324,7 +362,9 @@ class Attention(nn.Module): pre_softmax_attn = dots if talking_heads: - dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + dots = einsum( + 'b h i j, h k -> b k i j', dots, self.pre_softmax_proj + ).contiguous() if 
exists(rel_pos): dots = rel_pos(dots) @@ -336,7 +376,9 @@ class Attention(nn.Module): if self.causal: i, j = dots.shape[-2:] r = torch.arange(i, device=device) - mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = rearrange(r, 'i -> () () i ()') < rearrange( + r, 'j -> () () () j' + ) mask = F.pad(mask, (j - i, 0), value=False) dots.masked_fill_(mask, mask_value) del mask @@ -354,14 +396,16 @@ class Attention(nn.Module): attn = self.dropout(attn) if talking_heads: - attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + attn = einsum( + 'b h i j, h k -> b k i j', attn, self.post_softmax_proj + ).contiguous() out = einsum('b h i j, b h j d -> b h i d', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') intermediates = Intermediates( pre_softmax_attn=pre_softmax_attn, - post_softmax_attn=post_softmax_attn + post_softmax_attn=post_softmax_attn, ) return self.to_out(out), intermediates @@ -369,28 +413,28 @@ class Attention(nn.Module): class AttentionLayers(nn.Module): def __init__( - self, - dim, - depth, - heads=8, - causal=False, - cross_attend=False, - only_cross=False, - use_scalenorm=False, - use_rmsnorm=False, - use_rezero=False, - rel_pos_num_buckets=32, - rel_pos_max_distance=128, - position_infused_attn=False, - custom_layers=None, - sandwich_coef=None, - par_ratio=None, - residual_attn=False, - cross_residual_attn=False, - macaron=False, - pre_norm=True, - gate_residual=False, - **kwargs + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs, ): super().__init__() ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) @@ -403,10 +447,14 @@ class AttentionLayers(nn.Module): self.layers = nn.ModuleList([]) self.has_pos_emb = position_infused_attn - self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.pia_pos_emb = ( + FixedPositionalEmbedding(dim) if position_infused_attn else None + ) self.rotary_pos_emb = always(None) - assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + assert ( + rel_pos_num_buckets <= rel_pos_max_distance + ), 'number of relative position buckets must be less than the relative position max distance' self.rel_pos = None self.pre_norm = pre_norm @@ -438,15 +486,27 @@ class AttentionLayers(nn.Module): assert 1 < par_ratio <= par_depth, 'par ratio out of range' default_block = tuple(filter(not_equals('f'), default_block)) par_attn = par_depth // par_ratio - depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + depth_cut = ( + par_depth * 2 // 3 + ) # 2 / 3 attention layer cutoff suggested by PAR paper par_width = (depth_cut + depth_cut // par_attn) // par_attn - assert len(default_block) <= par_width, 'default block is too large for par_ratio' - par_block = default_block + ('f',) * (par_width - len(default_block)) + assert ( + len(default_block) <= par_width + ), 'default block is too large for par_ratio' + par_block = default_block + ('f',) * ( + par_width - len(default_block) + ) par_head = par_block * par_attn layer_types = par_head + ('f',) * (par_depth - len(par_head)) elif 
exists(sandwich_coef): - assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' - layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + assert ( + sandwich_coef > 0 and sandwich_coef <= depth + ), 'sandwich coefficient should be less than the depth' + layer_types = ( + ('a',) * sandwich_coef + + default_block * (depth - sandwich_coef) + + ('f',) * sandwich_coef + ) else: layer_types = default_block * depth @@ -455,7 +515,9 @@ class AttentionLayers(nn.Module): for layer_type in self.layer_types: if layer_type == 'a': - layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + layer = Attention( + dim, heads=heads, causal=causal, **attn_kwargs + ) elif layer_type == 'c': layer = Attention(dim, heads=heads, **attn_kwargs) elif layer_type == 'f': @@ -472,21 +534,17 @@ class AttentionLayers(nn.Module): else: residual_fn = Residual() - self.layers.append(nn.ModuleList([ - norm_fn(), - layer, - residual_fn - ])) + self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) def forward( - self, - x, - context=None, - mask=None, - context_mask=None, - mems=None, - return_hiddens=False, - **kwargs + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False, + **kwargs, ): hiddens = [] intermediates = [] @@ -495,7 +553,9 @@ class AttentionLayers(nn.Module): mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers - for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + for ind, (layer_type, (norm, block, residual_fn)) in enumerate( + zip(self.layer_types, self.layers) + ): is_last = ind == (len(self.layers) - 1) if layer_type == 'a': @@ -508,10 +568,22 @@ class AttentionLayers(nn.Module): x = norm(x) if layer_type == 'a': - out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, - prev_attn=prev_attn, mem=layer_mem) + out, inter = block( + x, + mask=mask, + sinusoidal_emb=self.pia_pos_emb, + rel_pos=self.rel_pos, + prev_attn=prev_attn, + mem=layer_mem, + ) elif layer_type == 'c': - out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + out, inter = block( + x, + context=context, + mask=mask, + context_mask=context_mask, + prev_attn=prev_cross_attn, + ) elif layer_type == 'f': out = block(x) @@ -530,8 +602,7 @@ class AttentionLayers(nn.Module): if return_hiddens: intermediates = LayerIntermediates( - hiddens=hiddens, - attn_intermediates=intermediates + hiddens=hiddens, attn_intermediates=intermediates ) return x, intermediates @@ -545,23 +616,24 @@ class Encoder(AttentionLayers): super().__init__(causal=False, **kwargs) - class TransformerWrapper(nn.Module): def __init__( - self, - *, - num_tokens, - max_seq_len, - attn_layers, - emb_dim=None, - max_mem_len=0., - emb_dropout=0., - num_memory_tokens=None, - tie_embedding=False, - use_pos_emb=True + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0.0, + emb_dropout=0.0, + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True, ): super().__init__() - assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + assert isinstance( + attn_layers, AttentionLayers + ), 'attention layers must be one of Encoder or Decoder' dim = attn_layers.dim emb_dim = default(emb_dim, dim) @@ -571,23 +643,34 @@ class TransformerWrapper(nn.Module): self.num_tokens = num_tokens 
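The reformatted AttentionLayers constructor above builds the stack from a flat tuple of layer-type codes, 'a' for self-attention, 'c' for cross-attention and 'f' for feedforward, and then wraps each code in a (norm, block, residual) triple. A small sketch of the sandwich_coef schedule it computes is below; default_block is assumed to be ('a', 'f') for a non-cross-attending stack, since its definition lies outside this hunk.

def sandwich_layout(depth, sandwich_coef, default_block=('a', 'f')):
    # Mirrors the branch above: attention-only layers first, the default
    # block in the middle, feedforward-only layers at the end.
    assert 0 < sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
    return (('a',) * sandwich_coef
            + default_block * (depth - sandwich_coef)
            + ('f',) * sandwich_coef)

print(sandwich_layout(depth=6, sandwich_coef=2))
# ('a', 'a', 'a', 'f', 'a', 'f', 'a', 'f', 'a', 'f', 'f', 'f')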
self.token_emb = nn.Embedding(num_tokens, emb_dim) - self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( - use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.pos_emb = ( + AbsolutePositionalEmbedding(emb_dim, max_seq_len) + if (use_pos_emb and not attn_layers.has_pos_emb) + else always(0) + ) self.emb_dropout = nn.Dropout(emb_dropout) - self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.project_emb = ( + nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + ) self.attn_layers = attn_layers self.norm = nn.LayerNorm(dim) self.init_() - self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + self.to_logits = ( + nn.Linear(dim, num_tokens) + if not tie_embedding + else lambda t: t @ self.token_emb.weight.t() + ) # memory tokens (like [cls]) from Memory Transformers paper num_memory_tokens = default(num_memory_tokens, 0) self.num_memory_tokens = num_memory_tokens if num_memory_tokens > 0: - self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + self.memory_tokens = nn.Parameter( + torch.randn(num_memory_tokens, dim) + ) # let funnel encoder know number of memory tokens, if specified if hasattr(attn_layers, 'num_memory_tokens'): @@ -597,20 +680,20 @@ class TransformerWrapper(nn.Module): nn.init.normal_(self.token_emb.weight, std=0.02) def forward( - self, - x, - return_embeddings=False, - mask=None, - return_mems=False, - return_attn=False, - mems=None, - embedding_manager=None, - **kwargs + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + embedding_manager=None, + **kwargs, ): b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens embedded_x = self.token_emb(x) - + if embedding_manager: x = embedding_manager(x, embedded_x) else: @@ -629,7 +712,9 @@ class TransformerWrapper(nn.Module): if exists(mask): mask = F.pad(mask, (num_mem, 0), value=True) - x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x, intermediates = self.attn_layers( + x, mask=mask, mems=mems, return_hiddens=True, **kwargs + ) x = self.norm(x) mem, x = x[:, :num_mem], x[:, num_mem:] @@ -638,13 +723,30 @@ class TransformerWrapper(nn.Module): if return_mems: hiddens = intermediates.hiddens - new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens - new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + new_mems = ( + list( + map( + lambda pair: torch.cat(pair, dim=-2), + zip(mems, hiddens), + ) + ) + if exists(mems) + else hiddens + ) + new_mems = list( + map( + lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems + ) + ) return out, new_mems if return_attn: - attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + attn_maps = list( + map( + lambda t: t.post_softmax_attn, + intermediates.attn_intermediates, + ) + ) return out, attn_maps return out - diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index f1f88bba5e..8aad3557af 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -24,10 +24,10 @@ import re import traceback from ldm.util import instantiate_from_config -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.plms import PLMSSampler +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.ksampler import KSampler -from ldm.dream.pngwriter import 
PngWriter +from ldm.dream.pngwriter import PngWriter """Simplified text to image API for stable diffusion/latent diffusion @@ -93,67 +93,69 @@ still work. class T2I: """T2I class - Attributes - ---------- - model - config - iterations - batch_size - steps - seed - sampler_name - width - height - cfg_scale - latent_channels - downsampling_factor - precision - strength - embedding_path + Attributes + ---------- + model + config + iterations + batch_size + steps + seed + sampler_name + width + height + cfg_scale + latent_channels + downsampling_factor + precision + strength + embedding_path -The vast majority of these arguments default to reasonable values. + The vast majority of these arguments default to reasonable values. """ - def __init__(self, - batch_size=1, - iterations = 1, - steps=50, - seed=None, - cfg_scale=7.5, - weights="models/ldm/stable-diffusion-v1/model.ckpt", - config = "configs/stable-diffusion/v1-inference.yaml", - width=512, - height=512, - sampler_name="klms", - latent_channels=4, - downsampling_factor=8, - ddim_eta=0.0, # deterministic - precision='autocast', - full_precision=False, - strength=0.75, # default in scripts/img2img.py - embedding_path=None, - latent_diffusion_weights=False, # just to keep track of this parameter when regenerating prompt - device='cuda', - gfpgan=None, + + def __init__( + self, + batch_size=1, + iterations=1, + steps=50, + seed=None, + cfg_scale=7.5, + weights='models/ldm/stable-diffusion-v1/model.ckpt', + config='configs/stable-diffusion/v1-inference.yaml', + width=512, + height=512, + sampler_name='klms', + latent_channels=4, + downsampling_factor=8, + ddim_eta=0.0, # deterministic + precision='autocast', + full_precision=False, + strength=0.75, # default in scripts/img2img.py + embedding_path=None, + latent_diffusion_weights=False, # just to keep track of this parameter when regenerating prompt + device='cuda', + gfpgan=None, ): - self.batch_size = batch_size + self.batch_size = batch_size self.iterations = iterations - self.width = width - self.height = height - self.steps = steps - self.cfg_scale = cfg_scale - self.weights = weights - self.config = config - self.sampler_name = sampler_name - self.latent_channels = latent_channels + self.width = width + self.height = height + self.steps = steps + self.cfg_scale = cfg_scale + self.weights = weights + self.config = config + self.sampler_name = sampler_name + self.latent_channels = latent_channels self.downsampling_factor = downsampling_factor - self.ddim_eta = ddim_eta - self.precision = precision - self.full_precision = full_precision - self.strength = strength - self.embedding_path = embedding_path - self.model = None # empty for now - self.sampler = None - self.latent_diffusion_weights=latent_diffusion_weights + self.ddim_eta = ddim_eta + self.precision = precision + self.full_precision = full_precision + self.strength = strength + self.embedding_path = embedding_path + self.model = None # empty for now + self.sampler = None + self.latent_diffusion_weights = latent_diffusion_weights self.device = device self.gfpgan = gfpgan if seed is None: @@ -162,49 +164,55 @@ The vast majority of these arguments default to reasonable values. self.seed = seed transformers.logging.set_verbosity_error() - def prompt2png(self,prompt,outdir,**kwargs): - ''' + def prompt2png(self, prompt, outdir, **kwargs): + """ Takes a prompt and an output directory, writes out the requested number of PNG files, and returns an array of [[filename,seed],[filename,seed]...] 
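The T2I constructor above keeps defaults for every argument, and the prompt2png()/txt2img() wrappers that follow return the list of [filename, seed] pairs written to disk. Here is a usage sketch assembled only from the signatures and defaults visible in this diff; the prompt text is a placeholder, and a CUDA device plus the default checkpoint and config paths are assumed.

# Illustrative only: assumes the default weights/config paths above exist
# and that a CUDA-capable GPU is available; the prompt is a placeholder.
from ldm.simplet2i import T2I

t2i = T2I(
    sampler_name='klms',   # constructor default shown above
    steps=50,
    cfg_scale=7.5,
)
# txt2img() defaults outdir to 'outputs/img-samples' and returns
# [[filename, seed], ...] via PngWriter.
for filename, seed in t2i.txt2img('a watercolor painting of a lighthouse'):
    print(filename, seed)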
Optional named arguments are the same as those passed to T2I and prompt2image() - ''' - results = self.prompt2image(prompt,**kwargs) - pngwriter = PngWriter(outdir,prompt,kwargs.get('batch_size',self.batch_size)) + """ + results = self.prompt2image(prompt, **kwargs) + pngwriter = PngWriter( + outdir, prompt, kwargs.get('batch_size', self.batch_size) + ) for r in results: metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}' # gets written into the PNG - pngwriter.write_image(r[0],r[1]) + pngwriter.write_image(r[0], r[1]) return pngwriter.files_written - def txt2img(self,prompt,**kwargs): - outdir = kwargs.get('outdir','outputs/img-samples') - return self.prompt2png(prompt,outdir,**kwargs) + def txt2img(self, prompt, **kwargs): + outdir = kwargs.get('outdir', 'outputs/img-samples') + return self.prompt2png(prompt, outdir, **kwargs) - def img2img(self,prompt,**kwargs): - outdir = kwargs.get('outdir','outputs/img-samples') - assert 'init_img' in kwargs,'call to img2img() must include the init_img argument' - return self.prompt2png(prompt,outdir,**kwargs) + def img2img(self, prompt, **kwargs): + outdir = kwargs.get('outdir', 'outputs/img-samples') + assert ( + 'init_img' in kwargs + ), 'call to img2img() must include the init_img argument' + return self.prompt2png(prompt, outdir, **kwargs) - def prompt2image(self, - # these are common - prompt, - batch_size=None, - iterations=None, - steps=None, - seed=None, - cfg_scale=None, - ddim_eta=None, - skip_normalize=False, - image_callback=None, - # these are specific to txt2img - width=None, - height=None, - # these are specific to img2img - init_img=None, - strength=None, - gfpgan_strength=None, - variants=None, - **args): # eat up additional cruft - ''' + def prompt2image( + self, + # these are common + prompt, + batch_size=None, + iterations=None, + steps=None, + seed=None, + cfg_scale=None, + ddim_eta=None, + skip_normalize=False, + image_callback=None, + # these are specific to txt2img + width=None, + height=None, + # these are specific to img2img + init_img=None, + strength=None, + gfpgan_strength=None, + variants=None, + **args, + ): # eat up additional cruft + """ ldm.prompt2image() is the common entry point for txt2img() and img2img() It takes the following arguments: prompt // prompt string (no default) @@ -232,118 +240,157 @@ The vast majority of these arguments default to reasonable values. The callback used by the prompt2png() can be found in ldm/dream_util.py. It contains code to create the requested output directory, select a unique informative name for each image, and write the prompt into the PNG metadata. - ''' - steps = steps or self.steps - seed = seed or self.seed - width = width or self.width - height = height or self.height - cfg_scale = cfg_scale or self.cfg_scale - ddim_eta = ddim_eta or self.ddim_eta + """ + steps = steps or self.steps + seed = seed or self.seed + width = width or self.width + height = height or self.height + cfg_scale = cfg_scale or self.cfg_scale + ddim_eta = ddim_eta or self.ddim_eta batch_size = batch_size or self.batch_size iterations = iterations or self.iterations - strength = strength or self.strength + strength = strength or self.strength - model = self.load_model() # will instantiate the model or return it from cache - assert cfg_scale>1.0, "CFG_Scale (-C) must be >1.0" - assert 0. 
<= strength <= 1., 'can only work with strength in [0.0, 1.0]' - w = int(width/64) * 64 - h = int(height/64) * 64 + model = ( + self.load_model() + ) # will instantiate the model or return it from cache + assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0' + assert ( + 0.0 <= strength <= 1.0 + ), 'can only work with strength in [0.0, 1.0]' + w = int(width / 64) * 64 + h = int(height / 64) * 64 if h != height or w != width: - print(f'Height and width must be multiples of 64. Resizing to {h}x{w}') + print( + f'Height and width must be multiples of 64. Resizing to {h}x{w}' + ) height = h - width = w + width = w - scope = autocast if self.precision=="autocast" else nullcontext + scope = autocast if self.precision == 'autocast' else nullcontext tic = time.time() results = list() try: if init_img: - assert os.path.exists(init_img),f'{init_img}: File not found' - images_iterator = self._img2img(prompt, - precision_scope=scope, - batch_size=batch_size, - steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta, - skip_normalize=skip_normalize, - init_img=init_img,strength=strength) + assert os.path.exists(init_img), f'{init_img}: File not found' + images_iterator = self._img2img( + prompt, + precision_scope=scope, + batch_size=batch_size, + steps=steps, + cfg_scale=cfg_scale, + ddim_eta=ddim_eta, + skip_normalize=skip_normalize, + init_img=init_img, + strength=strength, + ) else: - images_iterator = self._txt2img(prompt, - precision_scope=scope, - batch_size=batch_size, - steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta, - skip_normalize=skip_normalize, - width=width,height=height) + images_iterator = self._txt2img( + prompt, + precision_scope=scope, + batch_size=batch_size, + steps=steps, + cfg_scale=cfg_scale, + ddim_eta=ddim_eta, + skip_normalize=skip_normalize, + width=width, + height=height, + ) with scope(self.device.type), self.model.ema_scope(): - for n in trange(iterations, desc="Sampling"): + for n in trange(iterations, desc='Sampling'): seed_everything(seed) iter_images = next(images_iterator) for image in iter_images: try: if gfpgan_strength > 0: - image = self._run_gfpgan(image, gfpgan_strength) + image = self._run_gfpgan( + image, gfpgan_strength + ) except Exception as e: - print(f"Error running GFPGAN - Your image was not enhanced.\n{e}") + print( + f'Error running GFPGAN - Your image was not enhanced.\n{e}' + ) results.append([image, seed]) if image_callback is not None: - image_callback(image,seed) + image_callback(image, seed) seed = self._new_seed() except KeyboardInterrupt: print('*interrupted*') - print('Partial results will be returned; if --grid was requested, nothing will be returned.') + print( + 'Partial results will be returned; if --grid was requested, nothing will be returned.' + ) except RuntimeError as e: print(str(e)) print('Are you sure your system has an adequate NVIDIA GPU?') - toc = time.time() - print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic)) + toc = time.time() + print(f'{len(results)} images generated in', '%4.2fs' % (toc - tic)) return results @torch.no_grad() - def _txt2img(self, - prompt, - precision_scope, - batch_size, - steps,cfg_scale,ddim_eta, - skip_normalize, - width,height): + def _txt2img( + self, + prompt, + precision_scope, + batch_size, + steps, + cfg_scale, + ddim_eta, + skip_normalize, + width, + height, + ): """ An infinite iterator of images from the prompt. 
""" - sampler = self.sampler while True: uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize) - shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor] - samples, _ = sampler.sample(S=steps, - conditioning=c, - batch_size=batch_size, - shape=shape, - verbose=False, - unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=uc, - eta=ddim_eta) + shape = [ + self.latent_channels, + height // self.downsampling_factor, + width // self.downsampling_factor, + ] + samples, _ = sampler.sample( + S=steps, + conditioning=c, + batch_size=batch_size, + shape=shape, + verbose=False, + unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=uc, + eta=ddim_eta, + ) yield self._samples_to_images(samples) @torch.no_grad() - def _img2img(self, - prompt, - precision_scope, - batch_size, - steps,cfg_scale,ddim_eta, - skip_normalize, - init_img,strength): + def _img2img( + self, + prompt, + precision_scope, + batch_size, + steps, + cfg_scale, + ddim_eta, + skip_normalize, + init_img, + strength, + ): """ An infinite iterator of images from the prompt and the initial image """ # PLMS sampler not supported yet, so ignore previous sampler - if self.sampler_name!='ddim': - print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler") + if self.sampler_name != 'ddim': + print( + f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler" + ) sampler = DDIMSampler(self.model, device=self.device) else: sampler = self.sampler @@ -351,9 +398,13 @@ The vast majority of these arguments default to reasonable values. init_image = self._load_img(init_img).to(self.device) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) with precision_scope(self.device.type): - init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space + init_latent = self.model.get_first_stage_encoding( + self.model.encode_first_stage(init_image) + ) # move to latent space - sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False) + sampler.make_schedule( + ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False + ) t_enc = int(strength * steps) # print(f"target t_enc is {t_enc} steps") @@ -362,31 +413,44 @@ The vast majority of these arguments default to reasonable values. uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize) # encode (scaled latent) - z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device)) + z_enc = sampler.stochastic_encode( + init_latent, torch.tensor([t_enc] * batch_size).to(self.device) + ) # decode it - samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=uc,) + samples = sampler.decode( + z_enc, + c, + t_enc, + unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=uc, + ) yield self._samples_to_images(samples) # TODO: does this actually need to run every loop? does anything in it vary by random seed? def _get_uc_and_c(self, prompt, batch_size, skip_normalize): - uc = self.model.get_learned_conditioning(batch_size * [""]) + uc = self.model.get_learned_conditioning(batch_size * ['']) # weighted sub-prompts - subprompts,weights = T2I._split_weighted_subprompts(prompt) + subprompts, weights = T2I._split_weighted_subprompts(prompt) if len(subprompts) > 1: # i dont know if this is correct.. 
but it works c = torch.zeros_like(uc) # get total weight for normalizing totalWeight = sum(weights) # normalize each "sub prompt" and add it - for i in range(0,len(subprompts)): + for i in range(0, len(subprompts)): weight = weights[i] if not skip_normalize: weight = weight / totalWeight - c = torch.add(c, self.model.get_learned_conditioning(batch_size * [subprompts[i]]), alpha=weight) - else: # just standard 1 prompt + c = torch.add( + c, + self.model.get_learned_conditioning( + batch_size * [subprompts[i]] + ), + alpha=weight, + ) + else: # just standard 1 prompt c = self.model.get_learned_conditioning(batch_size * [prompt]) return (uc, c) @@ -395,23 +459,29 @@ The vast majority of these arguments default to reasonable values. x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) images = list() for x_sample in x_samples: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = 255.0 * rearrange( + x_sample.cpu().numpy(), 'c h w -> h w c' + ) image = Image.fromarray(x_sample.astype(np.uint8)) images.append(image) return images def _new_seed(self): - self.seed = random.randrange(0,np.iinfo(np.uint32).max) + self.seed = random.randrange(0, np.iinfo(np.uint32).max) return self.seed def load_model(self): - """ Load and initialize the model from configuration variables passed at object creation time """ + """Load and initialize the model from configuration variables passed at object creation time""" if self.model is None: seed_everything(self.seed) try: config = OmegaConf.load(self.config) - self.device = torch.device(self.device) if torch.cuda.is_available() else torch.device("cpu") - model = self._load_model_from_config(config,self.weights) + self.device = ( + torch.device(self.device) + if torch.cuda.is_available() + else torch.device('cpu') + ) + model = self._load_model_from_config(config, self.weights) if self.embedding_path is not None: model.embedding_manager.load(self.embedding_path) self.model = model.to(self.device) @@ -421,18 +491,26 @@ The vast majority of these arguments default to reasonable values. raise SystemExit msg = f'setting sampler to {self.sampler_name}' - if self.sampler_name=='plms': + if self.sampler_name == 'plms': self.sampler = PLMSSampler(self.model, device=self.device) elif self.sampler_name == 'ddim': self.sampler = DDIMSampler(self.model, device=self.device) elif self.sampler_name == 'k_dpm_2_a': - self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device) + self.sampler = KSampler( + self.model, 'dpm_2_ancestral', device=self.device + ) elif self.sampler_name == 'k_dpm_2': - self.sampler = KSampler(self.model, 'dpm_2', device=self.device) + self.sampler = KSampler( + self.model, 'dpm_2', device=self.device + ) elif self.sampler_name == 'k_euler_a': - self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device) + self.sampler = KSampler( + self.model, 'euler_ancestral', device=self.device + ) elif self.sampler_name == 'k_euler': - self.sampler = KSampler(self.model, 'euler', device=self.device) + self.sampler = KSampler( + self.model, 'euler', device=self.device + ) elif self.sampler_name == 'k_heun': self.sampler = KSampler(self.model, 'heun', device=self.device) elif self.sampler_name == 'k_lms': @@ -446,32 +524,38 @@ The vast majority of these arguments default to reasonable values. 
return self.model def _load_model_from_config(self, config, ckpt): - print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location="cpu") -# if "global_step" in pl_sd: -# print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] + print(f'Loading model from {ckpt}') + pl_sd = torch.load(ckpt, map_location='cpu') + # if "global_step" in pl_sd: + # print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd['state_dict'] model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) model.to(self.device) model.eval() if self.full_precision: - print('Using slower but more accurate full-precision math (--full_precision)') + print( + 'Using slower but more accurate full-precision math (--full_precision)' + ) else: - print('Using half precision math. Call with --full_precision to use slower but more accurate full precision.') + print( + 'Using half precision math. Call with --full_precision to use slower but more accurate full precision.' + ) model.half() return model - def _load_img(self,path): - image = Image.open(path).convert("RGB") + def _load_img(self, path): + image = Image.open(path).convert('RGB') w, h = image.size - print(f"loaded input image of size ({w}, {h}) from {path}") - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + print(f'loaded input image of size ({w}, {h}) from {path}') + w, h = map( + lambda x: x - x % 32, (w, h) + ) # resize to integer multiple of 32 image = image.resize((w, h), resample=Image.Resampling.LANCZOS) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) - return 2.*image - 1. + return 2.0 * image - 1.0 def _split_weighted_subprompts(text): """ @@ -484,34 +568,36 @@ The vast majority of these arguments default to reasonable values. prompts = [] weights = [] while remaining > 0: - if ":" in text: - idx = text.index(":") # first occurrence from start + if ':' in text: + idx = text.index(':') # first occurrence from start # grab up to index as sub-prompt prompt = text[:idx] remaining -= idx # remove from main text - text = text[idx+1:] + text = text[idx + 1 :] # find value for weight - if " " in text: - idx = text.index(" ") # first occurence - else: # no space, read to end + if ' ' in text: + idx = text.index(' ') # first occurence + else: # no space, read to end idx = len(text) if idx != 0: try: weight = float(text[:idx]) - except: # couldn't treat as float - print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?") + except: # couldn't treat as float + print( + f"Warning: '{text[:idx]}' is not a value, are you missing a space?" + ) weight = 1.0 - else: # no value found + else: # no value found weight = 1.0 # remove from main text remaining -= idx - text = text[idx+1:] + text = text[idx + 1 :] # append the sub-prompt and its weight prompts.append(prompt) weights.append(weight) - else: # no : found - if len(text) > 0: # there is still text though + else: # no : found + if len(text) > 0: # there is still text though # take remainder as weight 1 prompts.append(text) weights.append(1.0) @@ -519,13 +605,20 @@ The vast majority of these arguments default to reasonable values. 
return prompts, weights def _run_gfpgan(self, image, strength): - if (self.gfpgan is None): - print(f"GFPGAN not initialized, it must be loaded via the --gfpgan argument") + if self.gfpgan is None: + print( + f'GFPGAN not initialized, it must be loaded via the --gfpgan argument' + ) return image - - image = image.convert("RGB") - cropped_faces, restored_faces, restored_img = self.gfpgan.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True) + image = image.convert('RGB') + + cropped_faces, restored_faces, restored_img = self.gfpgan.enhance( + np.array(image, dtype=np.uint8), + has_aligned=False, + only_center_face=False, + paste_back=True, + ) res = Image.fromarray(restored_img) if strength < 1.0: diff --git a/ldm/util.py b/ldm/util.py index 3affd249de..d1379cae2b 100644 --- a/ldm/util.py +++ b/ldm/util.py @@ -13,22 +13,25 @@ from queue import Queue from inspect import isfunction from PIL import Image, ImageDraw, ImageFont + def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) # xc a list of captions to plot b = len(xc) txts = list() for bi in range(b): - txt = Image.new("RGB", wh, color="white") + txt = Image.new('RGB', wh, color='white') draw = ImageDraw.Draw(txt) font = ImageFont.load_default() nc = int(40 * (wh[0] / 256)) - lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + lines = '\n'.join( + xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) + ) try: - draw.text((0, 0), lines, fill="black", font=font) + draw.text((0, 0), lines, fill='black', font=font) except UnicodeEncodeError: - print("Cant encode string for logging. Skipping.") + print('Cant encode string for logging. Skipping.') txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txts.append(txt) @@ -70,22 +73,26 @@ def mean_flat(tensor): def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: - print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + print( + f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.' + ) return total_params def instantiate_from_config(config, **kwargs): - if not "target" in config: + if not 'target' in config: if config == '__is_first_stage__': return None - elif config == "__is_unconditional__": + elif config == '__is_unconditional__': return None - raise KeyError("Expected key `target` to instantiate.") - return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs) + raise KeyError('Expected key `target` to instantiate.') + return get_obj_from_str(config['target'])( + **config.get('params', dict()), **kwargs + ) def get_obj_from_str(string, reload=False): - module, cls = string.rsplit(".", 1) + module, cls = string.rsplit('.', 1) if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) @@ -101,31 +108,36 @@ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): else: res = func(data) Q.put([idx, res]) - Q.put("Done") + Q.put('Done') def parallel_data_prefetch( - func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False + func: callable, + data, + n_proc, + target_data_type='ndarray', + cpu_intensive=True, + use_worker_id=False, ): # if target_data_type not in ["ndarray", "list"]: # raise ValueError( # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 
# ) - if isinstance(data, np.ndarray) and target_data_type == "list": - raise ValueError("list expected but function got ndarray.") + if isinstance(data, np.ndarray) and target_data_type == 'list': + raise ValueError('list expected but function got ndarray.') elif isinstance(data, abc.Iterable): if isinstance(data, dict): print( f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' ) data = list(data.values()) - if target_data_type == "ndarray": + if target_data_type == 'ndarray': data = np.asarray(data) else: data = list(data) else: raise TypeError( - f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." + f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.' ) if cpu_intensive: @@ -135,7 +147,7 @@ def parallel_data_prefetch( Q = Queue(1000) proc = Thread # spawn processes - if target_data_type == "ndarray": + if target_data_type == 'ndarray': arguments = [ [func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc)) @@ -149,7 +161,7 @@ def parallel_data_prefetch( arguments = [ [func, Q, part, i, use_worker_id] for i, part in enumerate( - [data[i: i + step] for i in range(0, len(data), step)] + [data[i : i + step] for i in range(0, len(data), step)] ) ] processes = [] @@ -158,7 +170,7 @@ def parallel_data_prefetch( processes += [p] # start processes - print(f"Start prefetching...") + print(f'Start prefetching...') import time start = time.time() @@ -171,13 +183,13 @@ def parallel_data_prefetch( while k < n_proc: # get result res = Q.get() - if res == "Done": + if res == 'Done': k += 1 else: gather_res[res[0]] = res[1] except Exception as e: - print("Exception: ", e) + print('Exception: ', e) for p in processes: p.terminate() @@ -185,7 +197,7 @@ def parallel_data_prefetch( finally: for p in processes: p.join() - print(f"Prefetching complete. [{time.time() - start} sec.]") + print(f'Prefetching complete. 
[{time.time() - start} sec.]') if target_data_type == 'ndarray': if not isinstance(gather_res[0], np.ndarray): diff --git a/main.py b/main.py index 5653cf5e06..8c36c270b1 100644 --- a/main.py +++ b/main.py @@ -14,145 +14,171 @@ from PIL import Image from pytorch_lightning import seed_everything from pytorch_lightning.trainer import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor +from pytorch_lightning.callbacks import ( + ModelCheckpoint, + Callback, + LearningRateMonitor, +) from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities import rank_zero_info from ldm.data.base import Txt2ImgIterableBaseDataset from ldm.util import instantiate_from_config + def load_model_from_config(config, ckpt, verbose=False): - print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location="cpu") - sd = pl_sd["state_dict"] + print(f'Loading model from {ckpt}') + pl_sd = torch.load(ckpt, map_location='cpu') + sd = pl_sd['state_dict'] config.model.params.ckpt_path = ckpt model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: - print("missing keys:") + print('missing keys:') print(m) if len(u) > 0 and verbose: - print("unexpected keys:") + print('unexpected keys:') print(u) model.cuda() return model + def get_parser(**parser_kwargs): def str2bool(v): if isinstance(v, bool): return v - if v.lower() in ("yes", "true", "t", "y", "1"): + if v.lower() in ('yes', 'true', 't', 'y', '1'): return True - elif v.lower() in ("no", "false", "f", "n", "0"): + elif v.lower() in ('no', 'false', 'f', 'n', '0'): return False else: - raise argparse.ArgumentTypeError("Boolean value expected.") + raise argparse.ArgumentTypeError('Boolean value expected.') parser = argparse.ArgumentParser(**parser_kwargs) parser.add_argument( - "-n", - "--name", + '-n', + '--name', type=str, const=True, - default="", - nargs="?", - help="postfix for logdir", + default='', + nargs='?', + help='postfix for logdir', ) parser.add_argument( - "-r", - "--resume", + '-r', + '--resume', type=str, const=True, - default="", - nargs="?", - help="resume from logdir or checkpoint in logdir", + default='', + nargs='?', + help='resume from logdir or checkpoint in logdir', ) parser.add_argument( - "-b", - "--base", - nargs="*", - metavar="base_config.yaml", - help="paths to base configs. Loaded from left-to-right. " - "Parameters can be overwritten or added with command-line options of the form `--key value`.", + '-b', + '--base', + nargs='*', + metavar='base_config.yaml', + help='paths to base configs. Loaded from left-to-right. 
' + 'Parameters can be overwritten or added with command-line options of the form `--key value`.', default=list(), ) parser.add_argument( - "-t", - "--train", + '-t', + '--train', type=str2bool, const=True, default=False, - nargs="?", - help="train", + nargs='?', + help='train', ) parser.add_argument( - "--no-test", + '--no-test', type=str2bool, const=True, default=False, - nargs="?", - help="disable test", + nargs='?', + help='disable test', ) parser.add_argument( - "-p", - "--project", - help="name of new or path to existing project" + '-p', '--project', help='name of new or path to existing project' ) parser.add_argument( - "-d", - "--debug", + '-d', + '--debug', type=str2bool, - nargs="?", + nargs='?', const=True, default=False, - help="enable post-mortem debugging", + help='enable post-mortem debugging', ) parser.add_argument( - "-s", - "--seed", + '-s', + '--seed', type=int, default=23, - help="seed for seed_everything", + help='seed for seed_everything', ) parser.add_argument( - "-f", - "--postfix", + '-f', + '--postfix', type=str, - default="", - help="post-postfix for default name", + default='', + help='post-postfix for default name', ) parser.add_argument( - "-l", - "--logdir", + '-l', + '--logdir', type=str, - default="logs", - help="directory for logging dat shit", + default='logs', + help='directory for logging dat shit', ) parser.add_argument( - "--scale_lr", + '--scale_lr', type=str2bool, - nargs="?", + nargs='?', const=True, default=True, - help="scale base-lr by ngpu * batch_size * n_accumulate", + help='scale base-lr by ngpu * batch_size * n_accumulate', ) parser.add_argument( - "--datadir_in_name", - type=str2bool, - nargs="?", - const=True, - default=True, - help="Prepend the final directory in the data_root to the output directory name") + '--datadir_in_name', + type=str2bool, + nargs='?', + const=True, + default=True, + help='Prepend the final directory in the data_root to the output directory name', + ) - parser.add_argument("--actual_resume", type=str, default="", help="Path to model to actually resume from") - parser.add_argument("--data_root", type=str, required=True, help="Path to directory with training images") + parser.add_argument( + '--actual_resume', + type=str, + default='', + help='Path to model to actually resume from', + ) + parser.add_argument( + '--data_root', + type=str, + required=True, + help='Path to directory with training images', + ) - parser.add_argument("--embedding_manager_ckpt", type=str, default="", help="Initialize embedding manager from a checkpoint") - parser.add_argument("--placeholder_tokens", type=str, nargs="+", default=["*"]) + parser.add_argument( + '--embedding_manager_ckpt', + type=str, + default='', + help='Initialize embedding manager from a checkpoint', + ) + parser.add_argument( + '--placeholder_tokens', type=str, nargs='+', default=['*'] + ) - parser.add_argument("--init_word", type=str, help="Word to use as source for initial token embedding.") + parser.add_argument( + '--init_word', + type=str, + help='Word to use as source for initial token embedding.', + ) return parser @@ -186,7 +212,9 @@ def worker_init_fn(_): if isinstance(dataset, Txt2ImgIterableBaseDataset): split_size = dataset.num_records // worker_info.num_workers # reset num_records to the true number to retain reliable length information - dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] + dataset.sample_ids = dataset.valid_ids[ + worker_id * split_size : (worker_id + 1) * split_size + ] current_id = 
np.random.choice(len(np.random.get_state()[1]), 1) return np.random.seed(np.random.get_state()[1][current_id] + worker_id) else: @@ -194,25 +222,41 @@ def worker_init_fn(_): class DataModuleFromConfig(pl.LightningDataModule): - def __init__(self, batch_size, train=None, validation=None, test=None, predict=None, - wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False, - shuffle_val_dataloader=False): + def __init__( + self, + batch_size, + train=None, + validation=None, + test=None, + predict=None, + wrap=False, + num_workers=None, + shuffle_test_loader=False, + use_worker_init_fn=False, + shuffle_val_dataloader=False, + ): super().__init__() self.batch_size = batch_size self.dataset_configs = dict() - self.num_workers = num_workers if num_workers is not None else batch_size * 2 + self.num_workers = ( + num_workers if num_workers is not None else batch_size * 2 + ) self.use_worker_init_fn = use_worker_init_fn if train is not None: - self.dataset_configs["train"] = train + self.dataset_configs['train'] = train self.train_dataloader = self._train_dataloader if validation is not None: - self.dataset_configs["validation"] = validation - self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader) + self.dataset_configs['validation'] = validation + self.val_dataloader = partial( + self._val_dataloader, shuffle=shuffle_val_dataloader + ) if test is not None: - self.dataset_configs["test"] = test - self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader) + self.dataset_configs['test'] = test + self.test_dataloader = partial( + self._test_dataloader, shuffle=shuffle_test_loader + ) if predict is not None: - self.dataset_configs["predict"] = predict + self.dataset_configs['predict'] = predict self.predict_dataloader = self._predict_dataloader self.wrap = wrap @@ -223,34 +267,48 @@ class DataModuleFromConfig(pl.LightningDataModule): def setup(self, stage=None): self.datasets = dict( (k, instantiate_from_config(self.dataset_configs[k])) - for k in self.dataset_configs) + for k in self.dataset_configs + ) if self.wrap: for k in self.datasets: self.datasets[k] = WrappedDataset(self.datasets[k]) def _train_dataloader(self): - is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + is_iterable_dataset = isinstance( + self.datasets['train'], Txt2ImgIterableBaseDataset + ) if is_iterable_dataset or self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None - return DataLoader(self.datasets["train"], batch_size=self.batch_size, - num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True, - worker_init_fn=init_fn) + return DataLoader( + self.datasets['train'], + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False if is_iterable_dataset else True, + worker_init_fn=init_fn, + ) def _val_dataloader(self, shuffle=False): - if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + if ( + isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) + or self.use_worker_init_fn + ): init_fn = worker_init_fn else: init_fn = None - return DataLoader(self.datasets["validation"], - batch_size=self.batch_size, - num_workers=self.num_workers, - worker_init_fn=init_fn, - shuffle=shuffle) + return DataLoader( + self.datasets['validation'], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle, + ) def _test_dataloader(self, shuffle=False): - is_iterable_dataset = 
isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + is_iterable_dataset = isinstance( + self.datasets['train'], Txt2ImgIterableBaseDataset + ) if is_iterable_dataset or self.use_worker_init_fn: init_fn = worker_init_fn else: @@ -259,20 +317,34 @@ class DataModuleFromConfig(pl.LightningDataModule): # do not shuffle dataloader for iterable dataset shuffle = shuffle and (not is_iterable_dataset) - return DataLoader(self.datasets["test"], batch_size=self.batch_size, - num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle) + return DataLoader( + self.datasets['test'], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle, + ) def _predict_dataloader(self, shuffle=False): - if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + if ( + isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) + or self.use_worker_init_fn + ): init_fn = worker_init_fn else: init_fn = None - return DataLoader(self.datasets["predict"], batch_size=self.batch_size, - num_workers=self.num_workers, worker_init_fn=init_fn) + return DataLoader( + self.datasets['predict'], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + ) class SetupCallback(Callback): - def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): + def __init__( + self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config + ): super().__init__() self.resume = resume self.now = now @@ -284,8 +356,8 @@ class SetupCallback(Callback): def on_keyboard_interrupt(self, trainer, pl_module): if trainer.global_rank == 0: - print("Summoning checkpoint.") - ckpt_path = os.path.join(self.ckptdir, "last.ckpt") + print('Summoning checkpoint.') + ckpt_path = os.path.join(self.ckptdir, 'last.ckpt') trainer.save_checkpoint(ckpt_path) def on_pretrain_routine_start(self, trainer, pl_module): @@ -295,24 +367,36 @@ class SetupCallback(Callback): os.makedirs(self.ckptdir, exist_ok=True) os.makedirs(self.cfgdir, exist_ok=True) - if "callbacks" in self.lightning_config: - if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']: - os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True) - print("Project config") + if 'callbacks' in self.lightning_config: + if ( + 'metrics_over_trainsteps_checkpoint' + in self.lightning_config['callbacks'] + ): + os.makedirs( + os.path.join(self.ckptdir, 'trainstep_checkpoints'), + exist_ok=True, + ) + print('Project config') print(OmegaConf.to_yaml(self.config)) - OmegaConf.save(self.config, - os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) + OmegaConf.save( + self.config, + os.path.join(self.cfgdir, '{}-project.yaml'.format(self.now)), + ) - print("Lightning config") + print('Lightning config') print(OmegaConf.to_yaml(self.lightning_config)) - OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), - os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))) + OmegaConf.save( + OmegaConf.create({'lightning': self.lightning_config}), + os.path.join( + self.cfgdir, '{}-lightning.yaml'.format(self.now) + ), + ) else: # ModelCheckpoint callback created log directory --- remove it if not self.resume and os.path.exists(self.logdir): dst, name = os.path.split(self.logdir) - dst = os.path.join(dst, "child_runs", name) + dst = os.path.join(dst, 'child_runs', name) os.makedirs(os.path.split(dst)[0], exist_ok=True) try: os.rename(self.logdir, dst) @@ -321,9 +405,18 @@ class 
SetupCallback(Callback): class ImageLogger(Callback): - def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True, - rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, - log_images_kwargs=None): + def __init__( + self, + batch_frequency, + max_images, + clamp=True, + increase_log_steps=True, + rescale=True, + disabled=False, + log_on_batch_idx=False, + log_first_step=False, + log_images_kwargs=None, + ): super().__init__() self.rescale = rescale self.batch_freq = batch_frequency @@ -331,7 +424,9 @@ class ImageLogger(Callback): self.logger_log_images = { pl.loggers.TestTubeLogger: self._testtube, } - self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)] + self.log_steps = [ + 2**n for n in range(int(np.log2(self.batch_freq)) + 1) + ] if not increase_log_steps: self.log_steps = [self.batch_freq] self.clamp = clamp @@ -346,15 +441,16 @@ class ImageLogger(Callback): grid = torchvision.utils.make_grid(images[k]) grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w - tag = f"{split}/{k}" + tag = f'{split}/{k}' pl_module.logger.experiment.add_image( - tag, grid, - global_step=pl_module.global_step) + tag, grid, global_step=pl_module.global_step + ) @rank_zero_only - def log_local(self, save_dir, split, images, - global_step, current_epoch, batch_idx): - root = os.path.join(save_dir, "images", split) + def log_local( + self, save_dir, split, images, global_step, current_epoch, batch_idx + ): + root = os.path.join(save_dir, 'images', split) for k in images: grid = torchvision.utils.make_grid(images[k], nrow=4) if self.rescale: @@ -362,21 +458,25 @@ class ImageLogger(Callback): grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) grid = grid.numpy() grid = (grid * 255).astype(np.uint8) - filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format( - k, - global_step, - current_epoch, - batch_idx) + filename = '{}_gs-{:06}_e-{:06}_b-{:06}.png'.format( + k, global_step, current_epoch, batch_idx + ) path = os.path.join(root, filename) os.makedirs(os.path.split(path)[0], exist_ok=True) Image.fromarray(grid).save(path) - def log_img(self, pl_module, batch, batch_idx, split="train"): - check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step - if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 - hasattr(pl_module, "log_images") and - callable(pl_module.log_images) and - self.max_images > 0): + def log_img(self, pl_module, batch, batch_idx, split='train'): + check_idx = ( + batch_idx if self.log_on_batch_idx else pl_module.global_step + ) + if ( + self.check_frequency(check_idx) + and hasattr( # batch_idx % self.batch_freq == 0 + pl_module, 'log_images' + ) + and callable(pl_module.log_images) + and self.max_images > 0 + ): logger = type(pl_module.logger) is_train = pl_module.training @@ -384,7 +484,9 @@ class ImageLogger(Callback): pl_module.eval() with torch.no_grad(): - images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) + images = pl_module.log_images( + batch, split=split, **self.log_images_kwargs + ) for k in images: N = min(images[k].shape[0], self.max_images) @@ -392,20 +494,29 @@ class ImageLogger(Callback): if isinstance(images[k], torch.Tensor): images[k] = images[k].detach().cpu() if self.clamp: - images[k] = torch.clamp(images[k], -1., 1.) 
+ images[k] = torch.clamp(images[k], -1.0, 1.0) - self.log_local(pl_module.logger.save_dir, split, images, - pl_module.global_step, pl_module.current_epoch, batch_idx) + self.log_local( + pl_module.logger.save_dir, + split, + images, + pl_module.global_step, + pl_module.current_epoch, + batch_idx, + ) - logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) + logger_log_images = self.logger_log_images.get( + logger, lambda *args, **kwargs: None + ) logger_log_images(pl_module, images, pl_module.global_step, split) if is_train: pl_module.train() def check_frequency(self, check_idx): - if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and ( - check_idx > 0 or self.log_first_step): + if ( + (check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps) + ) and (check_idx > 0 or self.log_first_step): try: self.log_steps.pop(0) except IndexError as e: @@ -414,15 +525,23 @@ class ImageLogger(Callback): return True return False - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): - self.log_img(pl_module, batch, batch_idx, split="train") + def on_train_batch_end( + self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx + ): + if not self.disabled and ( + pl_module.global_step > 0 or self.log_first_step + ): + self.log_img(pl_module, batch, batch_idx, split='train') - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_validation_batch_end( + self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx + ): if not self.disabled and pl_module.global_step > 0: - self.log_img(pl_module, batch, batch_idx, split="val") + self.log_img(pl_module, batch, batch_idx, split='val') if hasattr(pl_module, 'calibrate_grad_norm'): - if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: + if ( + pl_module.calibrate_grad_norm and batch_idx % 25 == 0 + ) and batch_idx > 0: self.log_gradients(trainer, pl_module, batch_idx=batch_idx) @@ -436,20 +555,22 @@ class CUDACallback(Callback): def on_train_epoch_end(self, trainer, pl_module, outputs): torch.cuda.synchronize(trainer.root_gpu) - max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20 + max_memory = ( + torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20 + ) epoch_time = time.time() - self.start_time try: max_memory = trainer.training_type_plugin.reduce(max_memory) epoch_time = trainer.training_type_plugin.reduce(epoch_time) - rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") - rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") + rank_zero_info(f'Average Epoch time: {epoch_time:.2f} seconds') + rank_zero_info(f'Average Peak memory {max_memory:.2f}MiB') except AttributeError: pass -if __name__ == "__main__": +if __name__ == '__main__': # custom parser to specify config files, train, test and debug mode, # postfix, resume. # `--key value` arguments are interpreted as arguments to the trainer. 
@@ -491,7 +612,7 @@ if __name__ == "__main__": # params: # key: value - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + now = datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S') # add cwd for convenience and to make classes in this file available when # running as `python main.py` @@ -504,47 +625,49 @@ if __name__ == "__main__": opt, unknown = parser.parse_known_args() if opt.name and opt.resume: raise ValueError( - "-n/--name and -r/--resume cannot be specified both." - "If you want to resume training in a new log folder, " - "use -n/--name in combination with --resume_from_checkpoint" + '-n/--name and -r/--resume cannot be specified both.' + 'If you want to resume training in a new log folder, ' + 'use -n/--name in combination with --resume_from_checkpoint' ) if opt.resume: if not os.path.exists(opt.resume): - raise ValueError("Cannot find {}".format(opt.resume)) + raise ValueError('Cannot find {}'.format(opt.resume)) if os.path.isfile(opt.resume): - paths = opt.resume.split("/") + paths = opt.resume.split('/') # idx = len(paths)-paths[::-1].index("logs")+1 # logdir = "/".join(paths[:idx]) - logdir = "/".join(paths[:-2]) + logdir = '/'.join(paths[:-2]) ckpt = opt.resume else: assert os.path.isdir(opt.resume), opt.resume - logdir = opt.resume.rstrip("/") - ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") + logdir = opt.resume.rstrip('/') + ckpt = os.path.join(logdir, 'checkpoints', 'last.ckpt') opt.resume_from_checkpoint = ckpt - base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) + base_configs = sorted( + glob.glob(os.path.join(logdir, 'configs/*.yaml')) + ) opt.base = base_configs + opt.base - _tmp = logdir.split("/") + _tmp = logdir.split('/') nowname = _tmp[-1] else: if opt.name: - name = "_" + opt.name + name = '_' + opt.name elif opt.base: cfg_fname = os.path.split(opt.base[0])[-1] cfg_name = os.path.splitext(cfg_fname)[0] - name = "_" + cfg_name + name = '_' + cfg_name else: - name = "" + name = '' if opt.datadir_in_name: now = os.path.basename(os.path.normpath(opt.data_root)) + now - + nowname = now + name + opt.postfix logdir = os.path.join(opt.logdir, nowname) - ckptdir = os.path.join(logdir, "checkpoints") - cfgdir = os.path.join(logdir, "configs") + ckptdir = os.path.join(logdir, 'checkpoints') + cfgdir = os.path.join(logdir, 'configs') seed_everything(opt.seed) try: @@ -552,19 +675,19 @@ if __name__ == "__main__": configs = [OmegaConf.load(cfg) for cfg in opt.base] cli = OmegaConf.from_dotlist(unknown) config = OmegaConf.merge(*configs, cli) - lightning_config = config.pop("lightning", OmegaConf.create()) + lightning_config = config.pop('lightning', OmegaConf.create()) # merge trainer cli with config - trainer_config = lightning_config.get("trainer", OmegaConf.create()) + trainer_config = lightning_config.get('trainer', OmegaConf.create()) # default to ddp - trainer_config["accelerator"] = "ddp" + trainer_config['accelerator'] = 'ddp' for k in nondefault_trainer_args(opt): trainer_config[k] = getattr(opt, k) - if not "gpus" in trainer_config: - del trainer_config["accelerator"] + if not 'gpus' in trainer_config: + del trainer_config['accelerator'] cpu = True else: - gpuinfo = trainer_config["gpus"] - print(f"Running on GPUs {gpuinfo}") + gpuinfo = trainer_config['gpus'] + print(f'Running on GPUs {gpuinfo}') cpu = False trainer_opt = argparse.Namespace(**trainer_config) lightning_config.trainer = trainer_config @@ -572,11 +695,17 @@ if __name__ == "__main__": # model # config.model.params.personalization_config.params.init_word = 
opt.init_word - config.model.params.personalization_config.params.embedding_manager_ckpt = opt.embedding_manager_ckpt - config.model.params.personalization_config.params.placeholder_tokens = opt.placeholder_tokens + config.model.params.personalization_config.params.embedding_manager_ckpt = ( + opt.embedding_manager_ckpt + ) + config.model.params.personalization_config.params.placeholder_tokens = ( + opt.placeholder_tokens + ) if opt.init_word: - config.model.params.personalization_config.params.initializer_words[0] = opt.init_word + config.model.params.personalization_config.params.initializer_words[ + 0 + ] = opt.init_word if opt.actual_resume: model = load_model_from_config(config, opt.actual_resume) @@ -588,123 +717,136 @@ if __name__ == "__main__": # default logger configs default_logger_cfgs = { - "wandb": { - "target": "pytorch_lightning.loggers.WandbLogger", - "params": { - "name": nowname, - "save_dir": logdir, - "offline": opt.debug, - "id": nowname, - } + 'wandb': { + 'target': 'pytorch_lightning.loggers.WandbLogger', + 'params': { + 'name': nowname, + 'save_dir': logdir, + 'offline': opt.debug, + 'id': nowname, + }, }, - "testtube": { - "target": "pytorch_lightning.loggers.TestTubeLogger", - "params": { - "name": "testtube", - "save_dir": logdir, - } + 'testtube': { + 'target': 'pytorch_lightning.loggers.TestTubeLogger', + 'params': { + 'name': 'testtube', + 'save_dir': logdir, + }, }, } - default_logger_cfg = default_logger_cfgs["testtube"] - if "logger" in lightning_config: + default_logger_cfg = default_logger_cfgs['testtube'] + if 'logger' in lightning_config: logger_cfg = lightning_config.logger else: logger_cfg = OmegaConf.create() logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) - trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) + trainer_kwargs['logger'] = instantiate_from_config(logger_cfg) # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to # specify which metric is used to determine best models default_modelckpt_cfg = { - "target": "pytorch_lightning.callbacks.ModelCheckpoint", - "params": { - "dirpath": ckptdir, - "filename": "{epoch:06}", - "verbose": True, - "save_last": True, - } + 'target': 'pytorch_lightning.callbacks.ModelCheckpoint', + 'params': { + 'dirpath': ckptdir, + 'filename': '{epoch:06}', + 'verbose': True, + 'save_last': True, + }, } - if hasattr(model, "monitor"): - print(f"Monitoring {model.monitor} as checkpoint metric.") - default_modelckpt_cfg["params"]["monitor"] = model.monitor - default_modelckpt_cfg["params"]["save_top_k"] = 3 + if hasattr(model, 'monitor'): + print(f'Monitoring {model.monitor} as checkpoint metric.') + default_modelckpt_cfg['params']['monitor'] = model.monitor + default_modelckpt_cfg['params']['save_top_k'] = 3 - if "modelcheckpoint" in lightning_config: + if 'modelcheckpoint' in lightning_config: modelckpt_cfg = lightning_config.modelcheckpoint else: - modelckpt_cfg = OmegaConf.create() + modelckpt_cfg = OmegaConf.create() modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) - print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}") + print(f'Merged modelckpt-cfg: \n{modelckpt_cfg}') if version.parse(pl.__version__) < version.parse('1.4.0'): - trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) + trainer_kwargs['checkpoint_callback'] = instantiate_from_config( + modelckpt_cfg + ) # add callback which sets up log directory default_callbacks_cfg = { - "setup_callback": { - "target": "main.SetupCallback", - "params": { - "resume": opt.resume, - 
"now": now, - "logdir": logdir, - "ckptdir": ckptdir, - "cfgdir": cfgdir, - "config": config, - "lightning_config": lightning_config, - } + 'setup_callback': { + 'target': 'main.SetupCallback', + 'params': { + 'resume': opt.resume, + 'now': now, + 'logdir': logdir, + 'ckptdir': ckptdir, + 'cfgdir': cfgdir, + 'config': config, + 'lightning_config': lightning_config, + }, }, - "image_logger": { - "target": "main.ImageLogger", - "params": { - "batch_frequency": 750, - "max_images": 4, - "clamp": True - } + 'image_logger': { + 'target': 'main.ImageLogger', + 'params': { + 'batch_frequency': 750, + 'max_images': 4, + 'clamp': True, + }, }, - "learning_rate_logger": { - "target": "main.LearningRateMonitor", - "params": { - "logging_interval": "step", + 'learning_rate_logger': { + 'target': 'main.LearningRateMonitor', + 'params': { + 'logging_interval': 'step', # "log_momentum": True - } - }, - "cuda_callback": { - "target": "main.CUDACallback" + }, }, + 'cuda_callback': {'target': 'main.CUDACallback'}, } if version.parse(pl.__version__) >= version.parse('1.4.0'): - default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg}) + default_callbacks_cfg.update( + {'checkpoint_callback': modelckpt_cfg} + ) - if "callbacks" in lightning_config: + if 'callbacks' in lightning_config: callbacks_cfg = lightning_config.callbacks else: callbacks_cfg = OmegaConf.create() if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg: print( - 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.') + 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.' + ) default_metrics_over_trainsteps_ckpt_dict = { - 'metrics_over_trainsteps_checkpoint': - {"target": 'pytorch_lightning.callbacks.ModelCheckpoint', - 'params': { - "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), - "filename": "{epoch:06}-{step:09}", - "verbose": True, - 'save_top_k': -1, - 'every_n_train_steps': 10000, - 'save_weights_only': True - } - } + 'metrics_over_trainsteps_checkpoint': { + 'target': 'pytorch_lightning.callbacks.ModelCheckpoint', + 'params': { + 'dirpath': os.path.join( + ckptdir, 'trainstep_checkpoints' + ), + 'filename': '{epoch:06}-{step:09}', + 'verbose': True, + 'save_top_k': -1, + 'every_n_train_steps': 10000, + 'save_weights_only': True, + }, + } } - default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) + default_callbacks_cfg.update( + default_metrics_over_trainsteps_ckpt_dict + ) callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) - if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'): - callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint + if 'ignore_keys_callback' in callbacks_cfg and hasattr( + trainer_opt, 'resume_from_checkpoint' + ): + callbacks_cfg.ignore_keys_callback.params[ + 'ckpt_path' + ] = trainer_opt.resume_from_checkpoint elif 'ignore_keys_callback' in callbacks_cfg: del callbacks_cfg['ignore_keys_callback'] - trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] - trainer_kwargs["max_steps"] = opt.max_steps + trainer_kwargs['callbacks'] = [ + instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg + ] + trainer_kwargs['max_steps'] = opt.max_steps trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) trainer.logdir = logdir ### @@ -720,47 +862,60 @@ if __name__ == "__main__": # lightning still takes care of proper 
multiprocessing though data.prepare_data() data.setup() - print("#### Data #####") + print('#### Data #####') for k in data.datasets: - print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") + print( + f'{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}' + ) # configure learning rate - bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate + bs, base_lr = ( + config.data.params.batch_size, + config.model.base_learning_rate, + ) if not cpu: - ngpu = len(lightning_config.trainer.gpus.strip(",").split(',')) + ngpu = len(lightning_config.trainer.gpus.strip(',').split(',')) else: ngpu = 1 if 'accumulate_grad_batches' in lightning_config.trainer: - accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches + accumulate_grad_batches = ( + lightning_config.trainer.accumulate_grad_batches + ) else: accumulate_grad_batches = 1 - print(f"accumulate_grad_batches = {accumulate_grad_batches}") - lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches + print(f'accumulate_grad_batches = {accumulate_grad_batches}') + lightning_config.trainer.accumulate_grad_batches = ( + accumulate_grad_batches + ) if opt.scale_lr: model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr print( - "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format( - model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) + 'Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)'.format( + model.learning_rate, + accumulate_grad_batches, + ngpu, + bs, + base_lr, + ) + ) else: model.learning_rate = base_lr - print("++++ NOT USING LR SCALING ++++") - print(f"Setting learning rate to {model.learning_rate:.2e}") - + print('++++ NOT USING LR SCALING ++++') + print(f'Setting learning rate to {model.learning_rate:.2e}') # allow checkpointing via USR1 def melk(*args, **kwargs): # run all checkpoint hooks if trainer.global_rank == 0: - print("Summoning checkpoint.") - ckpt_path = os.path.join(ckptdir, "last.ckpt") + print('Summoning checkpoint.') + ckpt_path = os.path.join(ckptdir, 'last.ckpt') trainer.save_checkpoint(ckpt_path) - def divein(*args, **kwargs): if trainer.global_rank == 0: - import pudb; - pudb.set_trace() + import pudb + pudb.set_trace() import signal @@ -788,7 +943,7 @@ if __name__ == "__main__": # move newly created debug project to debug_runs if opt.debug and not opt.resume and trainer.global_rank == 0: dst, name = os.path.split(logdir) - dst = os.path.join(dst, "debug_runs", name) + dst = os.path.join(dst, 'debug_runs', name) os.makedirs(os.path.split(dst)[0], exist_ok=True) os.rename(logdir, dst) # if trainer.global_rank == 0: diff --git a/scripts/dream.py b/scripts/dream.py index 6a17656593..f6feb10adc 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -8,62 +8,66 @@ import sys import copy import warnings import ldm.dream.readline -from ldm.dream.pngwriter import PngWriter,PromptFormatter +from ldm.dream.pngwriter import PngWriter, PromptFormatter debugging = False + def main(): - ''' Initialize command-line parsers and the diffusion model ''' + """Initialize command-line parsers and the diffusion model""" arg_parser = create_argv_parser() - opt = arg_parser.parse_args() + opt = arg_parser.parse_args() if opt.laion400m: # defaults suitable to the older latent diffusion weights - width = 256 - height = 256 - config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml" - weights 
= "models/ldm/text2img-large/model.ckpt" + width = 256 + height = 256 + config = 'configs/latent-diffusion/txt2img-1p4B-eval.yaml' + weights = 'models/ldm/text2img-large/model.ckpt' else: # some defaults suitable for stable diffusion weights - width = 512 - height = 512 - config = "configs/stable-diffusion/v1-inference.yaml" - weights = "models/ldm/stable-diffusion-v1/model.ckpt" + width = 512 + height = 512 + config = 'configs/stable-diffusion/v1-inference.yaml' + weights = 'models/ldm/stable-diffusion-v1/model.ckpt' - print("* Initializing, be patient...\n") + print('* Initializing, be patient...\n') sys.path.append('.') from pytorch_lightning import logging from ldm.simplet2i import T2I + # these two lines prevent a horrible warning message from appearing # when the frozen CLIP tokenizer is imported import transformers + transformers.logging.set_verbosity_error() - + # creating a simple text2image object with a handful of # defaults passed on the command line. # additional parameters will be added (or overriden) during # the user input loop - t2i = T2I(width=width, - height=height, - sampler_name=opt.sampler_name, - weights=weights, - full_precision=opt.full_precision, - config=config, - latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt - embedding_path=opt.embedding_path, - device=opt.device + t2i = T2I( + width=width, + height=height, + sampler_name=opt.sampler_name, + weights=weights, + full_precision=opt.full_precision, + config=config, + latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt + embedding_path=opt.embedding_path, + device=opt.device, ) # make sure the output directory exists if not os.path.exists(opt.outdir): os.makedirs(opt.outdir) - + # gets rid of annoying messages about random seed - logging.getLogger("pytorch_lightning").setLevel(logging.ERROR) + logging.getLogger('pytorch_lightning').setLevel(logging.ERROR) infile = None try: if opt.infile is not None: - infile = open(opt.infile,'r') + infile = open(opt.infile, 'r') except FileNotFoundError as e: print(e) exit(-1) @@ -73,135 +77,156 @@ def main(): # load GFPGAN if requested if opt.use_gfpgan: - print("\n* --gfpgan was specified, loading gfpgan...") + print('\n* --gfpgan was specified, loading gfpgan...') with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) + warnings.filterwarnings('ignore', category=DeprecationWarning) try: - model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path) + model_path = os.path.join( + opt.gfpgan_dir, opt.gfpgan_model_path + ) if not os.path.isfile(model_path): - raise Exception("GFPGAN model not found at path "+model_path) + raise Exception( + 'GFPGAN model not found at path ' + model_path + ) sys.path.append(os.path.abspath(opt.gfpgan_dir)) from gfpgan import GFPGANer - bg_upsampler = load_gfpgan_bg_upsampler(opt.gfpgan_bg_upsampler, opt.gfpgan_bg_tile) + bg_upsampler = load_gfpgan_bg_upsampler( + opt.gfpgan_bg_upsampler, opt.gfpgan_bg_tile + ) - t2i.gfpgan = GFPGANer(model_path=model_path, upscale=opt.gfpgan_upscale, arch='clean', channel_multiplier=2, bg_upsampler=bg_upsampler) + t2i.gfpgan = GFPGANer( + model_path=model_path, + upscale=opt.gfpgan_upscale, + arch='clean', + channel_multiplier=2, + bg_upsampler=bg_upsampler, + ) except Exception: import traceback - print("Error loading GFPGAN:", file=sys.stderr) + + print('Error loading GFPGAN:', file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - print("\n* Initialization done! 
Awaiting your command (-h for help, 'q' to quit, 'cd' to change output dir, 'pwd' to print output dir)...") + print( + "\n* Initialization done! Awaiting your command (-h for help, 'q' to quit, 'cd' to change output dir, 'pwd' to print output dir)..." + ) - log_path = os.path.join(opt.outdir,'dream_log.txt') - with open(log_path,'a') as log: + log_path = os.path.join(opt.outdir, 'dream_log.txt') + with open(log_path, 'a') as log: cmd_parser = create_cmd_parser() - main_loop(t2i,opt.outdir,cmd_parser,log,infile) + main_loop(t2i, opt.outdir, cmd_parser, log, infile) log.close() if infile: infile.close() -def main_loop(t2i,outdir,parser,log,infile): - ''' prompt/read/execute loop ''' - done = False +def main_loop(t2i, outdir, parser, log, infile): + """prompt/read/execute loop""" + done = False last_seeds = [] - + while not done: try: - command = infile.readline() if infile else input("dream> ") + command = infile.readline() if infile else input('dream> ') except EOFError: done = True break - if infile and len(command)==0: + if infile and len(command) == 0: done = True break - if command.startswith(('#','//')): + if command.startswith(('#', '//')): continue # before splitting, escape single quotes so as not to mess # up the parser - command = command.replace("'","\\'") + command = command.replace("'", "\\'") try: elements = shlex.split(command) except ValueError as e: print(str(e)) continue - - if len(elements)==0: + + if len(elements) == 0: continue - if elements[0]=='q': + if elements[0] == 'q': done = True break - if elements[0]=='cd' and len(elements)>1: + if elements[0] == 'cd' and len(elements) > 1: if os.path.exists(elements[1]): - print(f"setting image output directory to {elements[1]}") - outdir=elements[1] + print(f'setting image output directory to {elements[1]}') + outdir = elements[1] else: - print(f"directory {elements[1]} does not exist") + print(f'directory {elements[1]} does not exist') continue - if elements[0]=='pwd': - print(f"current output directory is {outdir}") + if elements[0] == 'pwd': + print(f'current output directory is {outdir}') continue - - if elements[0].startswith('!dream'): # in case a stored prompt still contains the !dream command + + if elements[0].startswith( + '!dream' + ): # in case a stored prompt still contains the !dream command elements.pop(0) - + # rearrange the arguments to mimic how it works in the Dream bot. switches = [''] switches_started = False for el in elements: - if el[0]=='-' and not switches_started: + if el[0] == '-' and not switches_started: switches_started = True if switches_started: switches.append(el) else: switches[0] += el switches[0] += ' ' - switches[0] = switches[0][:len(switches[0])-1] + switches[0] = switches[0][: len(switches[0]) - 1] try: - opt = parser.parse_args(switches) + opt = parser.parse_args(switches) except SystemExit: parser.print_help() continue - if len(opt.prompt)==0: - print("Try again with a prompt!") + if len(opt.prompt) == 0: + print('Try again with a prompt!') continue - if opt.seed is not None and opt.seed<0: # retrieve previous value! + if opt.seed is not None and opt.seed < 0: # retrieve previous value! 
try: opt.seed = last_seeds[opt.seed] - print(f"reusing previous seed {opt.seed}") + print(f'reusing previous seed {opt.seed}') except IndexError: - print(f"No previous seed at position {opt.seed} found") + print(f'No previous seed at position {opt.seed} found') opt.seed = None - - normalized_prompt = PromptFormatter(t2i,opt).normalize_prompt() - individual_images = not opt.grid + + normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt() + individual_images = not opt.grid try: - file_writer = PngWriter(outdir,normalized_prompt,opt.batch_size) - callback = file_writer.write_image if individual_images else None + file_writer = PngWriter(outdir, normalized_prompt, opt.batch_size) + callback = file_writer.write_image if individual_images else None - image_list = t2i.prompt2image(image_callback=callback,**vars(opt)) - results = file_writer.files_written if individual_images else image_list + image_list = t2i.prompt2image(image_callback=callback, **vars(opt)) + results = ( + file_writer.files_written if individual_images else image_list + ) if opt.grid and len(results) > 0: grid_img = file_writer.make_grid([r[0] for r in results]) filename = file_writer.unique_filename(results[0][1]) - seeds = [a[1] for a in results] - results = [[filename,seeds]] - metadata_prompt = f'{normalized_prompt} -S{results[0][1]}' - file_writer.save_image_and_prompt_to_png(grid_img,metadata_prompt,filename) + seeds = [a[1] for a in results] + results = [[filename, seeds]] + metadata_prompt = f'{normalized_prompt} -S{results[0][1]}' + file_writer.save_image_and_prompt_to_png( + grid_img, metadata_prompt, filename + ) last_seeds = [r[1] for r in results] @@ -213,10 +238,11 @@ def main_loop(t2i,outdir,parser,log,infile): print(e) continue - print("Outputs:") - write_log_message(t2i,normalized_prompt,results,log) + print('Outputs:') + write_log_message(t2i, normalized_prompt, results, log) + + print('goodbye!') - print("goodbye!") def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400): import torch @@ -224,13 +250,24 @@ def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400): if bg_upsampler == 'realesrgan': if not torch.cuda.is_available(): # CPU import warnings - warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. ' - 'If you really want to use it, please modify the corresponding codes.') + + warnings.warn( + 'The unoptimized RealESRGAN is slow on CPU. We do not use it. ' + 'If you really want to use it, please modify the corresponding codes.' 
+ ) bg_upsampler = None else: from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer - model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) + + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=2, + ) bg_upsampler = RealESRGANer( scale=2, model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', @@ -238,12 +275,14 @@ def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400): tile=bg_tile, tile_pad=10, pre_pad=0, - half=True) # need to set False in CPU mode + half=True, + ) # need to set False in CPU mode else: bg_upsampler = None return bg_upsampler + # variant generation is going to be superseded by a generalized # "prompt-morph" functionality # def generate_variants(t2i,outdir,opt,previous_gens): @@ -268,110 +307,209 @@ def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400): # continue # print(f'{opt.variants} variants generated') # return variants - -def write_log_message(t2i,prompt,results,logfile): - ''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata''' - last_seed = None - img_num = 1 - seenit = {} + +def write_log_message(t2i, prompt, results, logfile): + """logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata""" + last_seed = None + img_num = 1 + seenit = {} for r in results: seed = r[1] - log_message = (f'{r[0]}: {prompt} -S{seed}') + log_message = f'{r[0]}: {prompt} -S{seed}' print(log_message) - logfile.write(log_message+"\n") + logfile.write(log_message + '\n') logfile.flush() + def create_argv_parser(): - parser = argparse.ArgumentParser(description="Parse script's command line args") - parser.add_argument("--laion400m", - "--latent_diffusion", - "-l", - dest='laion400m', - action='store_true', - help="fallback to the latent diffusion (laion400m) weights and config") - parser.add_argument("--from_file", - dest='infile', - type=str, - help="if specified, load prompts from this file") - parser.add_argument('-n','--iterations', - type=int, - default=1, - help="number of images to generate") - parser.add_argument('-F','--full_precision', - dest='full_precision', - action='store_true', - help="use slower full precision math for calculations") - parser.add_argument('--sampler','-m', - dest="sampler_name", - choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'], - default='k_lms', - help="which sampler to use (k_lms) - can only be set on command line") - parser.add_argument('--outdir', - '-o', - type=str, - default="outputs/img-samples", - help="directory in which to place generated images and a log of prompts and seeds (outputs/img-samples") - parser.add_argument('--embedding_path', - type=str, - help="Path to a pre-trained embedding manager checkpoint - can only be set on command line") - parser.add_argument('--device', - '-d', - type=str, - default="cuda", - help="device to run stable diffusion on. 
defaults to cuda `torch.cuda.current_device()` if avalible") + parser = argparse.ArgumentParser( + description="Parse script's command line args" + ) + parser.add_argument( + '--laion400m', + '--latent_diffusion', + '-l', + dest='laion400m', + action='store_true', + help='fallback to the latent diffusion (laion400m) weights and config', + ) + parser.add_argument( + '--from_file', + dest='infile', + type=str, + help='if specified, load prompts from this file', + ) + parser.add_argument( + '-n', + '--iterations', + type=int, + default=1, + help='number of images to generate', + ) + parser.add_argument( + '-F', + '--full_precision', + dest='full_precision', + action='store_true', + help='use slower full precision math for calculations', + ) + parser.add_argument( + '--sampler', + '-m', + dest='sampler_name', + choices=[ + 'ddim', + 'k_dpm_2_a', + 'k_dpm_2', + 'k_euler_a', + 'k_euler', + 'k_heun', + 'k_lms', + 'plms', + ], + default='k_lms', + help='which sampler to use (k_lms) - can only be set on command line', + ) + parser.add_argument( + '--outdir', + '-o', + type=str, + default='outputs/img-samples', + help='directory in which to place generated images and a log of prompts and seeds (outputs/img-samples', + ) + parser.add_argument( + '--embedding_path', + type=str, + help='Path to a pre-trained embedding manager checkpoint - can only be set on command line', + ) + parser.add_argument( + '--device', + '-d', + type=str, + default='cuda', + help='device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if avalible', + ) # GFPGAN related args - parser.add_argument('--gfpgan', - dest='use_gfpgan', - action='store_true', - help="load gfpgan for use in the dreambot. Note: Enabling GFPGAN will require more GPU memory") - parser.add_argument("--gfpgan_upscale", - type=int, - default=2, - help="The final upsampling scale of the image. Default: 2. Only used if --gfpgan is specified") - parser.add_argument("--gfpgan_bg_upsampler", - type=str, - default='realesrgan', - help="Background upsampler. Default: None. Options: realesrgan, none. Only used if --gfpgan is specified") - parser.add_argument("--gfpgan_bg_tile", - type=int, - default=400, - help="Tile size for background sampler, 0 for no tile during testing. Default: 400. Only used if --gfpgan is specified") - parser.add_argument("--gfpgan_model_path", - type=str, - default='experiments/pretrained_models/GFPGANv1.3.pth', - help="indicates the path to the GFPGAN model, relative to --gfpgan_dir. Only used if --gfpgan is specified") - parser.add_argument("--gfpgan_dir", - type=str, - default='../GFPGAN', - help="indicates the directory containing the GFPGAN code. Only used if --gfpgan is specified") + parser.add_argument( + '--gfpgan', + dest='use_gfpgan', + action='store_true', + help='load gfpgan for use in the dreambot. Note: Enabling GFPGAN will require more GPU memory', + ) + parser.add_argument( + '--gfpgan_upscale', + type=int, + default=2, + help='The final upsampling scale of the image. Default: 2. Only used if --gfpgan is specified', + ) + parser.add_argument( + '--gfpgan_bg_upsampler', + type=str, + default='realesrgan', + help='Background upsampler. Default: None. Options: realesrgan, none. Only used if --gfpgan is specified', + ) + parser.add_argument( + '--gfpgan_bg_tile', + type=int, + default=400, + help='Tile size for background sampler, 0 for no tile during testing. Default: 400. 
Only used if --gfpgan is specified', + ) + parser.add_argument( + '--gfpgan_model_path', + type=str, + default='experiments/pretrained_models/GFPGANv1.3.pth', + help='indicates the path to the GFPGAN model, relative to --gfpgan_dir. Only used if --gfpgan is specified', + ) + parser.add_argument( + '--gfpgan_dir', + type=str, + default='../GFPGAN', + help='indicates the directory containing the GFPGAN code. Only used if --gfpgan is specified', + ) return parser - - + + def create_cmd_parser(): - parser = argparse.ArgumentParser(description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12') + parser = argparse.ArgumentParser( + description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12' + ) parser.add_argument('prompt') - parser.add_argument('-s','--steps',type=int,help="number of steps") - parser.add_argument('-S','--seed',type=int,help="image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc") - parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform (slower, but will provide seeds for individual images)") - parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling (will not provide seeds for individual images!)") - parser.add_argument('-W','--width',type=int,help="image width, multiple of 64") - parser.add_argument('-H','--height',type=int,help="image height, multiple of 64") - parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale") - parser.add_argument('-g','--grid',action='store_true',help="generate a grid") - parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)") - parser.add_argument('-I','--init_img',type=str,help="path to input image for img2img mode (supersedes width and height)") - parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 
0.0 preserves image exactly, 1.0 replaces it completely") - parser.add_argument('-G','--gfpgan_strength', default=0.5, type=float, help="The strength at which to apply the GFPGAN model to the result, in order to improve faces.") -# variants is going to be superseded by a generalized "prompt-morph" function -# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants") - parser.add_argument('-x','--skip_normalize',action='store_true',help="skip subprompt weight normalization") + parser.add_argument('-s', '--steps', type=int, help='number of steps') + parser.add_argument( + '-S', + '--seed', + type=int, + help='image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc', + ) + parser.add_argument( + '-n', + '--iterations', + type=int, + default=1, + help='number of samplings to perform (slower, but will provide seeds for individual images)', + ) + parser.add_argument( + '-b', + '--batch_size', + type=int, + default=1, + help='number of images to produce per sampling (will not provide seeds for individual images!)', + ) + parser.add_argument( + '-W', '--width', type=int, help='image width, multiple of 64' + ) + parser.add_argument( + '-H', '--height', type=int, help='image height, multiple of 64' + ) + parser.add_argument( + '-C', + '--cfg_scale', + default=7.5, + type=float, + help='prompt configuration scale', + ) + parser.add_argument( + '-g', '--grid', action='store_true', help='generate a grid' + ) + parser.add_argument( + '-i', + '--individual', + action='store_true', + help='generate individual files (default)', + ) + parser.add_argument( + '-I', + '--init_img', + type=str, + help='path to input image for img2img mode (supersedes width and height)', + ) + parser.add_argument( + '-f', + '--strength', + default=0.75, + type=float, + help='strength for noising/unnoising. 
0.0 preserves image exactly, 1.0 replaces it completely', + ) + parser.add_argument( + '-G', + '--gfpgan_strength', + default=0.5, + type=float, + help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.', + ) + # variants is going to be superseded by a generalized "prompt-morph" function + # parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants") + parser.add_argument( + '-x', + '--skip_normalize', + action='store_true', + help='skip subprompt weight normalization', + ) return parser - -if __name__ == "__main__": +if __name__ == '__main__': main() - diff --git a/scripts/preload_models.py b/scripts/preload_models.py index d7538c82b8..624b61e48e 100755 --- a/scripts/preload_models.py +++ b/scripts/preload_models.py @@ -11,26 +11,28 @@ import warnings transformers.logging.set_verbosity_error() # this will preload the Bert tokenizer fles -print("preloading bert tokenizer...") +print('preloading bert tokenizer...') from transformers import BertTokenizerFast -tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") -print("...success") + +tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') +print('...success') # this will download requirements for Kornia -print("preloading Kornia requirements (ignore the deprecation warnings)...") +print('preloading Kornia requirements (ignore the deprecation warnings)...') with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) + warnings.filterwarnings('ignore', category=DeprecationWarning) import kornia -print("...success") +print('...success') -version='openai/clip-vit-large-patch14' +version = 'openai/clip-vit-large-patch14' print('preloading CLIP model (Ignore the deprecation warnings)...') sys.stdout.flush() import clip from transformers import CLIPTokenizer, CLIPTextModel -tokenizer =CLIPTokenizer.from_pretrained(version) -transformer=CLIPTextModel.from_pretrained(version) + +tokenizer = CLIPTokenizer.from_pretrained(version) +transformer = CLIPTextModel.from_pretrained(version) print('\n\n...success') # In the event that the user has installed GFPGAN and also elected to use @@ -38,23 +40,33 @@ print('\n\n...success') gfpgan = False try: from realesrgan import RealESRGANer + gfpgan = True except ModuleNotFoundError: pass if gfpgan: - print("Loading models from RealESRGAN and facexlib") + print('Loading models from RealESRGAN and facexlib') try: from basicsr.archs.rrdbnet_arch import RRDBNet from facexlib.utils.face_restoration_helper import FaceRestoreHelper - RealESRGANer(scale=2, - model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', - model=RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)) - FaceRestoreHelper(1,det_model='retinaface_resnet50') - print("...success") + + RealESRGANer( + scale=2, + model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', + model=RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=2, + ), + ) + FaceRestoreHelper(1, det_model='retinaface_resnet50') + print('...success') except Exception: import traceback - print("Error loading GFPGAN:") + + print('Error loading GFPGAN:') print(traceback.format_exc()) - -